StyleAligned Reference Sampler
Documentation
- Class name:
StyleAlignedReferenceSampler
- Category:
style_aligned
- Output node:
False
The StyleAlignedReferenceSampler node samples a batch of new latents whose style is aligned with a reference. It places the reference latent in the first batch slot, shares normalization and self-attention between that slot and the rest of the batch, and after each sampling step except the last overwrites the slot with the next recorded latent from `ref_latents`. The reference slot is removed before the results are returned.
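The source code at the end of this page suggests the following per-step behaviour: slot 0 of the batch holds the reference latent, and a callback pins that slot to the recorded reference trajectory after every step except the last. The sketch below is a toy model in plain Python, with strings standing in for tensors. `run_steps` and the string-rewriting "denoising step" are hypothetical stand-ins for the real sampler, and the sketch assumes `ref_latents` holds one entry per sigma.

```python
from typing import List


def run_steps(ref_latents: List[str], batch_size: int) -> List[str]:
    """Toy model of the sampling loop; strings stand in for latent tensors."""
    steps = len(ref_latents) - 1  # assumption: one recorded latent per sigma
    # Slot 0 starts as the first reference latent; the rest are new samples.
    x = [ref_latents[0]] + [f"sample{i}@0" for i in range(batch_size)]
    for step in range(steps):
        # Placeholder for one denoising step applied to every batch element.
        x = [f"{name.split('@')[0]}@{step + 1}" for name in x]
        # Mirror of the node's callback: overwrite slot 0 with the recorded
        # reference latent, except after the final step.
        if step + 1 < steps:
            x[0] = ref_latents[step + 1]
    # The reference slot is dropped before the result is returned.
    return x[1:]
```

Calling `run_steps(["ref@0", "ref@1", "ref@2"], 2)` returns only the two generated entries; the reference never appears in the result.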
Input types
Required
model
- The model to sample with. The node clones it and applies the shared-normalization and shared-attention patches to the clone.
- Comfy dtype:
MODEL
- Python dtype:
ModelPatcher
share_norm
- Selects which normalization layers are shared between the reference and the generated latents. The source checks for `group`, `layer`, and `both`; a value outside these three enables neither.
- Comfy dtype:
COMBO[STRING]
- Python dtype:
str (one of SHARE_NORM_OPTIONS)
share_attn
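- Selects which parts of attention are shared between the reference and the generated latents. The value is passed to `StyleAlignedArgs`; the available choices come from `SHARE_ATTN_OPTIONS`, which is not listed on this page.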
- Comfy dtype:
COMBO[STRING]
- Python dtype:
str (one of SHARE_ATTN_OPTIONS)
scale
- Scale factor passed to `SharedAttentionProcessor` along with the attention-sharing settings. Range 0 to 2.0, default 1, step 0.01.
- Comfy dtype:
FLOAT
- Python dtype:
float
batch_size
- Number of latents to generate, from 1 to 8 (default 2). The reference latent is added as an extra batch element during sampling, so the sampler processes `batch_size + 1` latents.
- Comfy dtype:
INT
- Python dtype:
int
noise_seed
- Seed for the initial noise of the generated latents, which makes results reproducible. The reference slot receives zero noise.
- Comfy dtype:
INT
- Python dtype:
int
cfg
- Classifier-free guidance scale passed to the sampler. Range 0.0 to 100.0, default 8.0, step 0.1.
- Comfy dtype:
FLOAT
- Python dtype:
float
positive
- Positive conditioning for the generated latents. Each entry is repeated `batch_size` times and concatenated after the matching `ref_positive` entry.
- Comfy dtype:
CONDITIONING
- Python dtype:
Conditioning
negative
- Negative conditioning, passed to the sampler unchanged.
- Comfy dtype:
CONDITIONING
- Python dtype:
Conditioning
ref_positive
- Positive conditioning for the reference latent. It is placed first in the conditioning batch so that it lines up with the reference slot. Entries are matched to `positive` by index, so it needs at least as many entries as `positive`.
- Comfy dtype:
CONDITIONING
- Python dtype:
Conditioning
sampler
- The sampler object used for the custom sampling run.
- Comfy dtype:
SAMPLER
- Python dtype:
Sampler
sigmas
- The sigma (noise level) schedule for sampling. Its length minus one is used as the step count.
- Comfy dtype:
SIGMAS
- Python dtype:
Sigmas
ref_latents
- Per-step latents recorded for the reference. `ref_latents[0]` is the starting latent for the reference slot, and after each step except the last the slot is overwritten with `ref_latents[step + 1]`.
- Comfy dtype:
STEP_LATENTS
- Python dtype:
torch.Tensor
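In the source, `share_norm` is reduced to two booleans before the normalization layers are registered. The helper below is a hypothetical restatement of that mapping. The strings `"group"`, `"layer"`, and `"both"` come from the membership tests in `patch`; the full contents of `SHARE_NORM_OPTIONS` are not shown on this page, so the `"disabled"` value used in the usage note is only an assumption.

```python
from typing import Tuple


def norm_flags(share_norm: str) -> Tuple[bool, bool]:
    """Return (share_group_norm, share_layer_norm) for a share_norm choice."""
    share_group_norm = share_norm in ["group", "both"]
    share_layer_norm = share_norm in ["layer", "both"]
    return share_group_norm, share_layer_norm
```

For example, `norm_flags("both")` enables both flags, while an unrecognized value such as `norm_flags("disabled")` enables neither.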
Output types
output
- The sampled latents, with the reference slot removed.
- Comfy dtype:
LATENT
- Python dtype:
Latent
denoised_output
- The denoised (x0) prediction captured during sampling, with the reference slot removed. If no prediction was captured, this is the same as `output`.
- Comfy dtype:
LATENT
- Python dtype:
Latent
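Both outputs are produced by dropping batch slot 0, which holds the reference, and `denoised_output` falls back to `output` when the sampler callback captured no denoised prediction. A minimal sketch with lists in place of tensors; `split_outputs` is a hypothetical helper, and it omits the `process_latent_out` call the node applies to the denoised tensor.

```python
from typing import Dict, List, Tuple


def split_outputs(
    samples: List[str], x0_output: Dict[str, List[str]]
) -> Tuple[Dict[str, List[str]], Dict[str, List[str]]]:
    """Strip the reference slot and build the (output, denoised_output) pair."""
    out = {"samples": samples[1:]}
    if "x0" in x0_output:
        # A denoised prediction was captured: strip its reference slot too.
        out_denoised = {"samples": x0_output["x0"][1:]}
    else:
        # Nothing captured: reuse the regular output.
        out_denoised = out
    return out, out_denoised
```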
Usage tips
- Infra type:
GPU
- Common nodes: unknown
Source code
# Imports this listing relies on. SHARE_NORM_OPTIONS, SHARE_ATTN_OPTIONS,
# StyleAlignedArgs, SharedAttentionProcessor, register_shared_norm, and the
# type alias T are defined elsewhere in the node's module.
import torch
import comfy.sample
import comfy.utils
import latent_preview
from comfy.model_patcher import ModelPatcher


class StyleAlignedReferenceSampler:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "model": ("MODEL",),
                "share_norm": (SHARE_NORM_OPTIONS,),
                "share_attn": (SHARE_ATTN_OPTIONS,),
                "scale": ("FLOAT", {"default": 1, "min": 0, "max": 2.0, "step": 0.01}),
                "batch_size": ("INT", {"default": 2, "min": 1, "max": 8, "step": 1}),
                "noise_seed": (
                    "INT",
                    {"default": 0, "min": 0, "max": 0xFFFFFFFFFFFFFFFF},
                ),
                "cfg": (
                    "FLOAT",
                    {
                        "default": 8.0,
                        "min": 0.0,
                        "max": 100.0,
                        "step": 0.1,
                        "round": 0.01,
                    },
                ),
                "positive": ("CONDITIONING",),
                "negative": ("CONDITIONING",),
                "ref_positive": ("CONDITIONING",),
                "sampler": ("SAMPLER",),
                "sigmas": ("SIGMAS",),
                "ref_latents": ("STEP_LATENTS",),
            },
        }

    RETURN_TYPES = ("LATENT", "LATENT")
    RETURN_NAMES = ("output", "denoised_output")
    FUNCTION = "patch"
    CATEGORY = "style_aligned"

    def patch(
        self,
        model: ModelPatcher,
        share_norm: str,
        share_attn: str,
        scale: float,
        batch_size: int,
        noise_seed: int,
        cfg: float,
        positive: T,
        negative: T,
        ref_positive: T,
        sampler: T,
        sigmas: T,
        ref_latents: T,
    ) -> "tuple[dict, dict]":
        m = model.clone()
        args = StyleAlignedArgs(share_attn)

        # Concat batch with style latent
        style_latent_tensor = ref_latents[0].unsqueeze(0)
        height, width = style_latent_tensor.shape[-2:]
        latent_t = torch.zeros(
            [batch_size, 4, height, width], device=ref_latents.device
        )
        latent = {"samples": latent_t}
        noise = comfy.sample.prepare_noise(latent_t, noise_seed)

        latent_t = torch.cat((style_latent_tensor, latent_t), dim=0)
        # The reference slot gets zero noise
        ref_noise = torch.zeros_like(noise[0]).unsqueeze(0)
        noise = torch.cat((ref_noise, noise), dim=0)

        x0_output = {}
        preview_callback = latent_preview.prepare_callback(m, sigmas.shape[-1] - 1, x0_output)

        # Replace first latent with the corresponding reference latent after each step
        def callback(step: int, x0: T, x: T, steps: int):
            preview_callback(step, x0, x, steps)
            if step + 1 < steps:
                x[0] = ref_latents[step + 1]
                x0[0] = ref_latents[step + 1]

        # Register shared norms
        share_group_norm = share_norm in ["group", "both"]
        share_layer_norm = share_norm in ["layer", "both"]
        register_shared_norm(m, share_group_norm, share_layer_norm)

        # Patch self-attention (attn1) with the shared attention processor
        m.set_model_attn1_patch(SharedAttentionProcessor(args, scale))

        # Add reference conditioning to batch
        batched_condition = []
        for i, condition in enumerate(positive):
            additional = condition[1].copy()
            batch_with_reference = torch.cat([ref_positive[i][0], condition[0].repeat([batch_size] + [1] * len(condition[0].shape[1:]))], dim=0)
            if 'pooled_output' in additional and 'pooled_output' in ref_positive[i][1]:
                # combine pooled output
                pooled_output = torch.cat([ref_positive[i][1]['pooled_output'], additional['pooled_output'].repeat([batch_size]
                    + [1] * len(additional['pooled_output'].shape[1:]))], dim=0)
                additional['pooled_output'] = pooled_output
            if 'control' in additional:
                if 'control' in ref_positive[i][1]:
                    # combine control conditioning
                    control_hint = torch.cat([ref_positive[i][1]['control'].cond_hint_original, additional['control'].cond_hint_original.repeat([batch_size]
                        + [1] * len(additional['control'].cond_hint_original.shape[1:]))], dim=0)
                    cloned_controlnet = additional['control'].copy()
                    cloned_controlnet.set_cond_hint(control_hint, strength=additional['control'].strength, timestep_percent_range=additional['control'].timestep_percent_range)
                    additional['control'] = cloned_controlnet
                else:
                    # add zeros for first in batch
                    control_hint = torch.cat([torch.zeros_like(additional['control'].cond_hint_original), additional['control'].cond_hint_original.repeat([batch_size]
                        + [1] * len(additional['control'].cond_hint_original.shape[1:]))], dim=0)
                    cloned_controlnet = additional['control'].copy()
                    cloned_controlnet.set_cond_hint(control_hint, strength=additional['control'].strength, timestep_percent_range=additional['control'].timestep_percent_range)
                    additional['control'] = cloned_controlnet
            batched_condition.append([batch_with_reference, additional])

        disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED
        samples = comfy.sample.sample_custom(
            m,
            noise,
            cfg,
            sampler,
            sigmas,
            batched_condition,
            negative,
            latent_t,
            callback=callback,
            disable_pbar=disable_pbar,
            seed=noise_seed,
        )

        # remove reference image
        samples = samples[1:]

        out = latent.copy()
        out["samples"] = samples
        if "x0" in x0_output:
            out_denoised = latent.copy()
            x0 = x0_output["x0"][1:]
            out_denoised["samples"] = m.model.process_latent_out(x0.cpu())
        else:
            out_denoised = out
        return (out, out_denoised)
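The conditioning loop above repeats each `positive` tensor `batch_size` times along the first dimension and concatenates it after the matching `ref_positive` tensor. The repeat argument is built as `[batch_size] + [1] * len(shape[1:])`. The sketch below works through the resulting shapes in plain Python, without torch. `repeat_arg` and `batched_shape` are hypothetical helpers, and the example shape `(1, 77, 768)` is only an illustrative assumption. Concatenation along dimension 0 requires the remaining dimensions to match, which the sketch checks explicitly.

```python
from typing import List, Sequence


def repeat_arg(shape: Sequence[int], batch_size: int) -> List[int]:
    """Repeat counts: batch_size on the first dim, 1 on every other dim."""
    return [batch_size] + [1] * len(shape[1:])


def batched_shape(
    ref_shape: Sequence[int], cond_shape: Sequence[int], batch_size: int
) -> List[int]:
    """Shape after repeating the conditioning and prepending the reference."""
    reps = repeat_arg(cond_shape, batch_size)
    repeated = [dim * rep for dim, rep in zip(cond_shape, reps)]
    # Concatenation along dim 0 needs identical trailing dimensions.
    if list(ref_shape[1:]) != repeated[1:]:
        raise ValueError("trailing dimensions must match for concatenation")
    return [ref_shape[0] + repeated[0]] + repeated[1:]
```

With `batch_size=2`, `batched_shape((1, 77, 768), (1, 77, 768), 2)` gives `[3, 77, 768]`: one reference entry followed by two copies of the positive conditioning, matching the `batch_size + 1` latents passed to the sampler.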