Batch Normalize (Latent)¶
Documentation¶
- Class name:
BatchNormalizeLatent
- Category:
latent/filters
- Output node:
False
This node applies batch normalization to a batch of latent representations: each sample's per-channel statistics are standardized and then rescaled to the mean and standard deviation of that channel across the whole batch, which can stabilize and potentially improve the generative process.
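Concretely, the per-channel transform looks like the following sketch (dummy tensors stand in for real latents; it mirrors the loop in the source code below):

import torch

# Each sample's channel is standardized with its own mean/std, then rescaled
# to the mean/std of that channel computed across the whole batch.
x = torch.randn(4, 4, 8, 8)  # [B x C x H x W] dummy latent batch
out = x.clone()
for c in range(x.size(1)):
    batch_sd, batch_mean = torch.std_mean(x[:, c])  # channel stats over the batch
    for i in range(x.size(0)):
        i_sd, i_mean = torch.std_mean(x[i, c])      # channel stats for one sample
        out[i, c] = (x[i, c] - i_mean) / i_sd * batch_sd + batch_mean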
Input types¶
Required¶
latents
- The batch of latent representations to be normalized. The node adjusts the per-channel distribution of each sample in this batch toward the shared batch statistics.
- Comfy dtype:
LATENT
- Python dtype:
Dict[str, torch.Tensor]
factor
- A scaling factor that interpolates between the original and normalized latent representations: 0.0 leaves the latents unchanged, 1.0 applies full normalization, and values in between blend the two (see the sketch after this entry).
- Comfy dtype:
FLOAT
- Python dtype:
float
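For intuition, here is how factor blends the two results with torch.lerp (dummy tensors stand in for real latents; this is a sketch, not part of the node):

import torch

# Sketch of the blending step: torch.lerp(a, b, w) computes a + w * (b - a),
# so factor 0.0 keeps the original latents, 1.0 takes the fully normalized
# batch, and values outside [0, 1] extrapolate past either endpoint.
original = torch.randn(2, 4, 8, 8)    # stand-in for the incoming latents
normalized = torch.randn(2, 4, 8, 8)  # stand-in for the normalized batch
blended = torch.lerp(original, normalized, 0.5)  # halfway between the two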
Output types¶
latent
- Comfy dtype:
LATENT
- The batch-normalized latent representations, blended with the original latents according to the specified factor.
- Python dtype:
Dict[str, torch.Tensor]
Usage tips¶
- Infra type:
GPU
- Common nodes: unknown
Source code¶
import copy

import torch

class BatchNormalizeLatent:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "latents": ("LATENT", ),
                "factor": ("FLOAT", {"default": 1.0, "min": -10.0, "max": 10.0, "step": 0.01, "round": 0.01}),
            },
        }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "batch_normalize"
    CATEGORY = "latent/filters"

    def batch_normalize(self, latents, factor):
        latents_copy = copy.deepcopy(latents)
        t = latents_copy["samples"]  # [B x C x H x W]
        t = t.movedim(0, 1)  # [C x B x H x W]
        for c in range(t.size(0)):
            # statistics of this channel across the whole batch
            c_sd, c_mean = torch.std_mean(t[c], dim=None)
            for i in range(t.size(1)):
                # standardize each sample's channel with its own statistics
                i_sd, i_mean = torch.std_mean(t[c, i], dim=None)
                t[c, i] = (t[c, i] - i_mean) / i_sd
            # then rescale to the batch-wide channel statistics
            t[c] = t[c] * c_sd + c_mean
        # blend between the original and normalized latents by factor
        latents_copy["samples"] = torch.lerp(latents["samples"], t.movedim(1, 0), factor)  # [B x C x H x W]
        return (latents_copy,)
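For reference, a minimal standalone call might look like this (the dummy latent dict is an illustrative assumption; in a graph the node receives latents from an upstream node):

import torch

# Hypothetical standalone invocation outside a ComfyUI graph; the dict
# mimics the {"samples": tensor} structure ComfyUI passes between nodes.
node = BatchNormalizeLatent()
latents = {"samples": torch.randn(4, 4, 64, 64)}  # batch of 4 dummy latents
(result,) = node.batch_normalize(latents, factor=1.0)
print(result["samples"].shape)  # torch.Size([4, 4, 64, 64])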