Perturbed-Attention Guidance (Advanced)¶
Documentation¶
- Class name:
PerturbedAttention
- Category:
advanced/model
- Output node:
False
The PerturbedAttention node implements Perturbed-Attention Guidance (PAG) for diffusion models. During sampling it runs an additional conditioned forward pass in which selected self-attention (attn1) blocks of the U-Net are replaced by an identity mapping over their value tensors, then pushes the final prediction away from that perturbed result. This can improve the structural coherence of generated images, at the cost of one extra model evaluation per step.
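Concretely, the perturbation replaces the self-attention map with the identity: instead of softmax(QK^T / sqrt(d)) V, a patched block returns its value tensor unchanged. A minimal standalone sketch of the two computations (plain tensors, not the node's actual attention-hook signature):

import torch
from torch import Tensor

def self_attention(q: Tensor, k: Tensor, v: Tensor) -> Tensor:
    # Standard scaled dot-product self-attention.
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    return torch.softmax(scores, dim=-1) @ v

def identity_attention(q: Tensor, k: Tensor, v: Tensor) -> Tensor:
    # The PAG perturbation: every token attends only to itself,
    # so the block's output is just the value tensor.
    return v

q = k = v = torch.randn(1, 4, 8)
assert torch.equal(identity_attention(q, k, v), v)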
Input types¶
Required¶
model
- The diffusion model to patch. The node works on a clone, so the original model is left unmodified.
- Comfy dtype:
MODEL
- Python dtype:
torch.nn.Module
scale
- The guidance scale: the strength with which the final prediction is pushed away from the perturbed prediction. Higher values produce a stronger effect.
- Comfy dtype:
FLOAT
- Python dtype:
float
adaptive_scale
- Controls a timestep-dependent decay of the guidance scale: the larger the value, the faster the effective scale falls off as sampling approaches low-noise timesteps; 0.0 disables the decay (see the worked example after this entry).
- Comfy dtype:
FLOAT
- Python dtype:
float
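adaptive_scale interacts with scale and the sampling timestep t (roughly 1000 at the start of sampling, 0 at the end), as computed in post_cfg_function in the source below. A worked example with assumed values:

# Effective guidance scale at timestep t, per the source code below:
#   signal_scale = scale - scale * adaptive_scale**4 * (1000 - t), clamped at 0.
scale, adaptive_scale = 3.0, 0.25
for t in (1000, 900, 600):
    signal_scale = max(0.0, scale - scale * adaptive_scale**4 * (1000 - t))
    print(t, round(signal_scale, 3))  # 1000 -> 3.0, 900 -> 1.828, 600 -> 0.0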
unet_block
- Identifies the specific U-Net block (input, middle, or output) where the perturbed attention mechanism will be applied, enabling precise targeting within the model.
- Comfy dtype:
COMBO[STRING]
- Python dtype:
str
unet_block_id
- The numerical identifier for the U-Net block, further specifying the exact location for perturbation application within the chosen block.
- Comfy dtype:
INT
- Python dtype:
int
sigma_start
- The sigma (noise level) at which perturbed-attention guidance becomes active: guidance is applied only while the current sigma is less than or equal to this value. A negative value (the default) removes the upper bound.
- Comfy dtype:
FLOAT
- Python dtype:
float
sigma_end
- The sigma below which perturbed-attention guidance is disabled; together with sigma_start this bounds the range of noise levels over which guidance is applied (see the sketch after this entry). The default of -1.0 keeps guidance active down to the final step.
- Comfy dtype:
FLOAT
- Python dtype:
float
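Together, sigma_start and sigma_end gate when PAG runs: guidance is applied only while sigma_end < sigma <= sigma_start, mirroring the check in post_cfg_function below. A standalone sketch of that gate:

def pag_active(sigma: float, sigma_start: float = -1.0, sigma_end: float = -1.0) -> bool:
    # A negative sigma_start is treated as infinity, as in the node's patch().
    sigma_start = float("inf") if sigma_start < 0 else sigma_start
    return sigma_end < sigma <= sigma_start

print(pag_active(5.0, sigma_start=10.0, sigma_end=1.0))  # True: inside the window
print(pag_active(0.5, sigma_start=10.0, sigma_end=1.0))  # False: below sigma_end
print(pag_active(14.6))                                  # True: defaults keep PAG always active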
rescale
- A factor between 0.0 and 1.0 that blends in a rescaled version of the guidance signal, offering additional control over its intensity; 0.0 disables rescaling.
- Comfy dtype:
FLOAT
- Python dtype:
float
rescale_mode
- Determines how rescaling of the guidance signal is applied; the options are 'full' and 'partial'.
- Comfy dtype:
COMBO[STRING]
- Python dtype:
str
Optional¶
unet_block_list
- A comma-separated string of U-Net block specifiers that, when non-empty, overrides unet_block and unet_block_id and applies the perturbation to every listed block. Each entry starts with d (input/down), m (middle), or u (output/up), followed by a block-group number and an optional .index selecting a single attention layer within that group; see the examples below.
- Comfy dtype:
STRING
- Python dtype:
str
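A few illustrative values (parsed by parse_unet_blocks in the source below; whether a given group or index exists depends on the model architecture):

# "m0"     -> every attn1 layer in the first middle-block group
# "d2,u1"  -> input (down) group 2 plus output (up) group 1
# "m0.1"   -> only the second attn1 layer inside middle group 0
#             (e.g. on SDXL, whose middle block stacks several layers)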
Output types¶
model
- The patched model: a clone of the input with the perturbed-attention guidance hook attached to the specified U-Net blocks.
- Comfy dtype:
MODEL
- Python dtype:
torch.nn.Module
Usage tips¶
- Infra type:
GPU
- Common nodes: unknown
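The node can also be invoked directly from Python. A minimal hypothetical sketch, assuming model is an already-loaded ComfyUI MODEL (e.g. from a checkpoint loader node):

# Hypothetical direct use of the node class defined below.
node = PerturbedAttention()
(patched_model,) = node.patch(
    model,                # a comfy ModelPatcher from a checkpoint loader
    scale=3.0,            # PAG guidance strength
    unet_block="middle",  # perturb the middle block's self-attention
    unet_block_id=0,
)
# patched_model is a clone of model; pass it to the sampler in place of the original.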
Source code¶
from itertools import groupby

from torch import Tensor

from comfy.model_patcher import ModelPatcher, set_model_options_patch_replace
from comfy.samplers import calc_cond_batch

# BACKEND, rescale_pag and (on Forge) calc_cond_uncond_batch come from this
# extension's own helpers; the module name below is an assumption.
from .pag_utils import BACKEND, calc_cond_uncond_batch, rescale_pag


class PerturbedAttention:
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"model": ("MODEL",),
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 100.0, "step": 0.1, "round": 0.01}),
"adaptive_scale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.001, "round": 0.0001}),
"unet_block": (["input", "middle", "output"], {"default": "middle"}),
"unet_block_id": ("INT", {"default": 0}),
"sigma_start": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
"sigma_end": ("FLOAT", {"default": -1.0, "min": -1.0, "max": 10000.0, "step": 0.01, "round": False}),
"rescale": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1.0, "step": 0.01}),
"rescale_mode": (["full", "partial"], {"default": "full"}),
},
"optional": {
"unet_block_list": ("STRING", {"default": ""}),
},
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "patch"
CATEGORY = "advanced/model"
def parse_unet_blocks(self, model: ModelPatcher, unet_block_list: str):
output: list[tuple[str, int, int | None]] = []
# Get all Self-attention blocks
input_blocks, middle_blocks, output_blocks = [], [], []
for name, module in model.model.diffusion_model.named_modules():
if module.__class__.__name__ == "CrossAttention" and name.endswith("attn1"):
parts = name.split(".")
block_name = parts[0]
block_id = int(parts[1])
if block_name.startswith("input"):
input_blocks.append(block_id)
elif block_name.startswith("middle"):
middle_blocks.append(block_id - 1)
elif block_name.startswith("output"):
output_blocks.append(block_id)
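        # Collapse consecutive attn1 entries per block into (block_id, count) pairs.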
def group_blocks(blocks: list[int]):
return [(i, len(list(gr))) for i, gr in groupby(blocks)]
input_blocks, middle_blocks, output_blocks = group_blocks(input_blocks), group_blocks(middle_blocks), group_blocks(output_blocks)
unet_blocks = [b.strip() for b in unet_block_list.split(",")]
for block in unet_blocks:
name, indices = block[0], block[1:].split(".")
match name:
case "d":
layer, cur_blocks = "input", input_blocks
case "m":
layer, cur_blocks = "middle", middle_blocks
case "u":
layer, cur_blocks = "output", output_blocks
if len(indices) >= 2:
number, index = cur_blocks[int(indices[0])][0], int(indices[1])
assert 0 <= index < cur_blocks[int(indices[0])][1]
else:
number, index = cur_blocks[int(indices[0])][0], None
output.append((layer, number, index))
return output
def patch(
self,
model: ModelPatcher,
scale: float = 3.0,
adaptive_scale: float = 0.0,
unet_block: str = "middle",
unet_block_id: int = 0,
sigma_start: float = -1.0,
sigma_end: float = -1.0,
rescale: float = 0.0,
rescale_mode: str = "full",
unet_block_list: str = "",
):
m = model.clone()
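        # A negative sigma_start means "no upper bound" on the active sigma range.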
sigma_start = float("inf") if sigma_start < 0 else sigma_start
if unet_block_list:
blocks = self.parse_unet_blocks(model, unet_block_list)
else:
blocks = [(unet_block, unet_block_id, None)]
def perturbed_attention(q: Tensor, k: Tensor, v: Tensor, extra_options, mask=None):
"""Perturbed self-attention"""
return v
def post_cfg_function(args):
"""CFG+PAG"""
model = args["model"]
cond_pred = args["cond_denoised"]
cond = args["cond"]
cfg_result = args["denoised"]
sigma = args["sigma"]
model_options = args["model_options"].copy()
x = args["input"]
signal_scale = scale
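            # Optionally fade the guidance scale as sampling approaches low-noise timesteps.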
if adaptive_scale > 0:
t = model.model_sampling.timestep(sigma)[0].item()
signal_scale -= scale * (adaptive_scale**4) * (1000 - t)
if signal_scale < 0:
signal_scale = 0
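            # Skip PAG when the scale vanishes or sigma is outside (sigma_end, sigma_start].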
if signal_scale == 0 or not (sigma_end < sigma[0] <= sigma_start):
return cfg_result
# Replace Self-attention with PAG
for block in blocks:
layer, number, index = block
model_options = set_model_options_patch_replace(model_options, perturbed_attention, "attn1", layer, number, index)
if BACKEND == "ComfyUI":
(pag_cond_pred,) = calc_cond_batch(model, [cond], x, sigma, model_options)
if BACKEND == "Forge":
(pag_cond_pred, _) = calc_cond_uncond_batch(model, cond, None, x, sigma, model_options)
pag = (cond_pred - pag_cond_pred) * signal_scale
return cfg_result + rescale_pag(pag, cond_pred, cfg_result, rescale, rescale_mode)
m.set_model_sampler_post_cfg_function(post_cfg_function)
return (m,)