Latent Batch¶
Documentation¶
- Class name:
Latent Batch
- Category:
WAS Suite/Latent
- Output node:
False
This node merges up to four sets of latent samples into a single batch by concatenating them along the batch dimension. All provided latents must have matching dimensions (apart from batch size); mismatches raise an error rather than being resized. It is useful in operations that require combining or comparing different sets of latent data.
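In practice, "batching" here is a concatenation of the "samples" tensors along their first (batch) dimension. A minimal sketch of the idea, assuming the common ComfyUI latent layout of a dict whose "samples" tensor has shape [batch, 4, height/8, width/8] (the sizes below are arbitrary):

import torch

# Two hypothetical latents in the usual ComfyUI layout.
latent_a = {"samples": torch.zeros(1, 4, 64, 64)}   # one 512x512 image latent
latent_b = {"samples": torch.zeros(2, 4, 64, 64)}   # two 512x512 image latents

# Batching is a concatenation along the batch dimension (dim=0).
merged = torch.cat([latent_a["samples"], latent_b["samples"]], dim=0)
print(merged.shape)  # torch.Size([3, 4, 64, 64])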
Input types¶
Required¶
This node declares no required inputs; all four latent sockets are optional, though at least one must be connected at run time.
Optional¶
latent_a
- The first set of latent samples to be batched. Its dimensions serve as the reference against which the other inputs are checked.
- Comfy dtype:
LATENT
- Python dtype:
Dict[str, torch.Tensor]
latent_b
- The second set of latent samples to be batched. Its dimensions must match those of the first set for concatenation to succeed.
- Comfy dtype:
LATENT
- Python dtype:
Dict[str, torch.Tensor]
latent_c
- An optional third set of latent samples that can be included in the batch. Like the other inputs, its dimensions must be consistent with the rest.
- Comfy dtype:
LATENT
- Python dtype:
Dict[str, torch.Tensor]
latent_d
- An optional fourth set of latent samples for inclusion in the batch. Its dimensions must likewise be consistent with the other sets. A sketch of the expected latent dict layout follows this list.
- Comfy dtype:
LATENT
- Python dtype:
Dict[str, torch.Tensor]
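For reference, each of these sockets expects a standard ComfyUI latent dict. A minimal sketch of a value that would satisfy the Dict[str, torch.Tensor] type above (the tensor size is an arbitrary assumption):

import torch

# A latent dict as produced by nodes such as VAE Encode or Empty Latent Image:
# "samples" holds the latent tensor; "batch_index" is optional metadata.
example_latent = {
    "samples": torch.randn(2, 4, 64, 64),  # [batch, channels, height, width]
    "batch_index": [0, 1],                 # optional; defaults to range(batch) when absent
}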
Output types¶
latent
- Comfy dtype:
LATENT
- The combined batch of latent samples from all provided inputs, concatenated along the batch dimension together with an aggregated batch_index; no resizing is performed (see the sketch below).
- Python dtype:
Dict[str, torch.Tensor]
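How the combined batch_index is built can be stated in a few lines. This is a condensed restatement of the loop in the source code below, assuming two inputs with batch sizes 1 and 2 and no pre-existing indices:

# Inputs without a "batch_index" get a default range over their batch size.
index_a = list(range(1))          # [0]
index_b = list(range(2))          # [0, 1]

# The output's batch_index is simply the concatenation of those lists.
combined_index = index_a + index_b
print(combined_index)             # [0, 0, 1]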
Usage tips¶
- Infra type:
GPU
- Common nodes: unknown
Source code¶
import torch


class WAS_Latent_Batch:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {},
            "optional": {
                "latent_a": ("LATENT",),
                "latent_b": ("LATENT",),
                "latent_c": ("LATENT",),
                "latent_d": ("LATENT",),
            },
        }

    RETURN_TYPES = ("LATENT",)
    RETURN_NAMES = ("latent",)
    FUNCTION = "latent_batch"

    CATEGORY = "WAS Suite/Latent"

    def _check_latent_dimensions(self, tensors, names):
        # Compare the shapes of all provided latents and report the names of
        # any whose channel dimension differs from the first latent's.
        dimensions = [tensor["samples"].shape for tensor in tensors]
        if len(set(dimensions)) > 1:
            mismatched_indices = [i for i, dim in enumerate(dimensions) if dim[1] != dimensions[0][1]]
            mismatched_latents = [names[i] for i in mismatched_indices]
            if mismatched_latents:
                raise ValueError(f"WAS latent Batch Warning: Input latent dimensions do not match for latents: {mismatched_latents}")

    def latent_batch(self, **kwargs):
        # Keep only the inputs that were actually connected (non-None).
        batched_tensors = [kwargs[key] for key in kwargs if kwargs[key] is not None]
        latent_names = [key for key in kwargs if kwargs[key] is not None]
        if not batched_tensors:
            raise ValueError("At least one input latent must be provided.")
        self._check_latent_dimensions(batched_tensors, latent_names)
        # Concatenate along the batch dimension and merge the batch indices,
        # generating a default index range for latents that do not carry one.
        samples_out = {}
        samples_out["samples"] = torch.cat([tensor["samples"] for tensor in batched_tensors], dim=0)
        samples_out["batch_index"] = []
        for tensor in batched_tensors:
            cindex = tensor.get("batch_index", list(range(tensor["samples"].shape[0])))
            samples_out["batch_index"] += cindex
        return (samples_out,)
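As a quick sanity check outside of a ComfyUI graph, the class above can be exercised directly. The tensor sizes and batch_index values below are arbitrary assumptions, not part of the node:

import torch

node = WAS_Latent_Batch()

latent_a = {"samples": torch.randn(1, 4, 64, 64)}
latent_b = {"samples": torch.randn(2, 4, 64, 64), "batch_index": [5, 6]}

# Only the connected (non-None) inputs are batched; latent_c and latent_d are omitted.
(result,) = node.latent_batch(latent_a=latent_a, latent_b=latent_b)

print(result["samples"].shape)  # torch.Size([3, 4, 64, 64])
print(result["batch_index"])    # [0, 5, 6]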