SAG delayed activation¶
Documentation¶
- Class name:
SAG delayed activation
- Category:
model_patches
- Output node:
False
This node introduces a custom implementation of Self-Attention Guidance (SAG) for neural networks, specifically designed to enhance model performance by dynamically adjusting attention mechanisms based on conditional inputs. It focuses on modifying the attention layers within a model to incorporate additional guidance, thereby potentially improving the model's ability to generate or process data with greater precision and relevance.
Input types¶
Required¶
model
- The neural network model to which the Self-Attention Guidance modifications will be applied. This parameter is crucial as it determines the base structure that will be enhanced with the custom SAG implementation.
- Comfy dtype:
MODEL
- Python dtype:
torch.nn.Module
scale
- A scaling factor that adjusts the intensity of the SAG effect, influencing how strongly the guidance affects the model's attention mechanisms.
- Comfy dtype:
FLOAT
- Python dtype:
float
blur_sigma
- Specifies the standard deviation of the Gaussian blur applied for adversarial blurring, part of the SAG technique to manipulate attention scores.
- Comfy dtype:
FLOAT
- Python dtype:
float
sigma_start
- The starting value of sigma for which the SAG modifications are applied, setting an operational range for the attention guidance.
- Comfy dtype:
FLOAT
- Python dtype:
float
sigma_end
- The ending value of sigma, marking the lower bound of the operational range for applying the SAG modifications.
- Comfy dtype:
FLOAT
- Python dtype:
float
Output types¶
model
- Comfy dtype:
MODEL
- The modified neural network model, now enhanced with the custom Self-Attention Guidance implementation, aimed at improving its performance by dynamically adjusting its attention mechanisms.
- Python dtype:
torch.nn.Module
Usage tips¶
- Infra type:
GPU
- Common nodes: unknown
Source code¶
class SelfAttentionGuidanceCustom:
    """Model-patch node: Self-Attention Guidance (SAG) with delayed activation.

    Clones the input model and installs two hooks:
      * an attn1 replacement on the UNet middle block that captures the
        unconditional branch's self-attention scores, and
      * a post-CFG function that uses those scores to build an adversarially
        blurred prediction and nudge the CFG result by ``scale``.
    The SAG correction is only applied while the current sigma satisfies
    ``sigma_end <= sigma <= sigma_start`` ("delayed activation").
    """

    @classmethod
    def INPUT_TYPES(s):
        # ComfyUI input schema: one MODEL plus four FLOAT knobs.
        # "scale" may be negative (min -2.0), inverting the guidance direction.
        return {"required": { "model": ("MODEL",),
                              "scale": ("FLOAT", {"default": 0.5, "min": -2.0, "max": 5.0, "step": 0.1}),
                              "blur_sigma": ("FLOAT", {"default": 2.0, "min": 0.0, "max": 10.0, "step": 0.1}),
                              "sigma_start": ("FLOAT", {"default": 15.0, "min": 0.0, "max": 1000.0, "step": 0.1, "round": 0.1}),
                              "sigma_end": ("FLOAT", {"default": 0.0, "min": 0.0, "max": 1000.0, "step": 0.1, "round": 0.1}),
                              }}

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "model_patches"

    def patch(self, model, scale, blur_sigma, sigma_start, sigma_end):
        """Return a 1-tuple with a clone of ``model`` patched for SAG.

        The original ``model`` is left untouched; all SAG state lives in
        closures shared between the two installed hooks.
        """
        m = model.clone()

        # Last captured uncond attention map; written by attn_and_record,
        # read by post_cfg_function (shared via closure).
        attn_scores = None

        # TODO: make this work properly with chunked batches
        # currently, we can only save the attn from one UNet call
        def attn_and_record(q, k, v, extra_options):
            nonlocal attn_scores
            # if uncond, save the attention scores
            heads = extra_options["n_heads"]
            cond_or_uncond = extra_options["cond_or_uncond"]
            # Per-condition sub-batch size: comfy stacks the cond/uncond
            # batches along dim 0, so q.shape[0] is an exact multiple.
            b = q.shape[0] // len(cond_or_uncond)
            if 1 in cond_or_uncond:
                uncond_index = cond_or_uncond.index(1)
                # do the entire attention operation, but save the attention scores to attn_scores
                (out, sim) = attention_basic_with_sim(q, k, v, heads=heads)
                # when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn]
                # NOTE(review): slicing below assumes sim's leading dim is
                # grouped as heads*b per condition — TODO confirm against
                # attention_basic_with_sim's layout.
                n_slices = heads * b
                attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)]
                return out
            else:
                # No uncond in this call: plain attention, nothing to record.
                return optimized_attention(q, k, v, heads=heads)

        def post_cfg_function(args):
            nonlocal attn_scores
            uncond_attn = attn_scores

            sag_scale = scale
            sag_sigma = blur_sigma
            # Fixed attention-mask threshold passed to create_blur_map;
            # presumably selects which attention regions get blurred — not
            # exposed as a node input here.
            sag_threshold = 1.0
            model = args["model"]
            uncond_pred = args["uncond_denoised"]
            uncond = args["uncond"]
            cfg_result = args["denoised"]
            sigma = args["sigma"]
            model_options = args["model_options"]
            x = args["input"]
            # No uncond computed (e.g. cfg == 1 optimization) -> SAG is a no-op.
            if not isinstance(uncond, torch.Tensor):
                return cfg_result
            if min(cfg_result.shape[2:]) <= 4: #skip when too small to add padding
                return cfg_result
            # Delayed activation: only apply SAG inside [sigma_end, sigma_start].
            if sigma[0] > sigma_start or sigma[0] < sigma_end:
                return cfg_result

            # create the adversarially blurred image
            degraded = create_blur_map(uncond_pred, uncond_attn, sag_sigma, sag_threshold)
            # Re-noise the degraded prediction (presumably restoring the noise
            # the uncond pass removed from x) — TODO confirm.
            degraded_noised = degraded + x - uncond_pred
            # call into the UNet
            (sag, _) = comfy.samplers.calc_cond_batch(model, [uncond, None], degraded_noised, sigma, model_options)
            # comfy.samplers.calc_cond_uncond_batch(model, uncond, None, degraded_noised, sigma, model_options)
            # Push the CFG result away from the SAG prediction, scaled by "scale".
            return cfg_result + (degraded - sag) * sag_scale

        m.set_model_sampler_post_cfg_function(post_cfg_function, disable_cfg1_optimization=False)

        # from diffusers:
        # unet.mid_block.attentions[0].transformer_blocks[0].attn1.patch
        # Hook only the middle block's first self-attention (attn1) layer.
        m.set_model_attn1_replace(attn_and_record, "middle", 0, 0)
        return (m, )