VAE Decode Batched 🎥🅥🅗🅢

Documentation

  • Class name: VHS_VAEDecodeBatched
  • Category: Video Helper Suite 🎥🅥🅗🅢/batched nodes
  • Output node: False

This node decodes latent samples into images with a Variational Autoencoder (VAE). Rather than decoding every sample in one pass, it decodes per_batch samples at a time, concatenates the results along the batch dimension, and updates a progress bar after each batch.

Input types

Required

  • samples
    • The latent representations to decode into images. The decode function reads the tensor from the "samples" key of this input.
    • Comfy dtype: LATENT
    • Python dtype: Dict[str, torch.Tensor]
  • vae
    • The VAE model used to decode the latent representations into images. Its decode method is called once per batch.
    • Comfy dtype: VAE
    • Python dtype: torch.nn.Module
  • per_batch
    • The number of samples decoded in each batch (default 16, minimum 1). Splitting the workload into smaller chunks keeps memory use manageable when there are many latent samples.
    • Comfy dtype: INT
    • Python dtype: int
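
To illustrate how per_batch partitions the input, the sketch below lists the index ranges that a loop stepping through range(0, N, per_batch) would visit. The helper name batch_slices is hypothetical and is not part of Video Helper Suite; it only mirrors the stepping used in the source code further down.

```python
def batch_slices(total, per_batch):
    """Return (start, stop) index pairs covering `total` samples in chunks.

    Hypothetical helper for illustration; not part of Video Helper Suite.
    """
    if per_batch < 1:
        raise ValueError("per_batch must be at least 1")
    # The final chunk is clipped to `total`, so it may be shorter than per_batch.
    return [(start, min(start + per_batch, total))
            for start in range(0, total, per_batch)]


# 40 latent frames with the default per_batch of 16 give three chunks.
print(batch_slices(40, 16))  # [(0, 16), (16, 32), (32, 40)]
```

The last chunk holds whatever remains, so the total number of samples does not need to be a multiple of per_batch.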

Output types

  • image
    • The images decoded from the input latent samples, concatenated along the batch dimension in their original order.
    • Comfy dtype: IMAGE
    • Python dtype: torch.Tensor

Usage tips

  • Infra type: GPU
  • Common nodes: unknown

Source code

import torch
from comfy.utils import ProgressBar  # ComfyUI's progress reporting helper

class VAEDecodeBatched:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "samples": ("LATENT", ),
                "vae": ("VAE", ),
                "per_batch": ("INT", {"default": 16, "min": 1})
                }
            }

    CATEGORY = "Video Helper Suite 🎥🅥🅗🅢/batched nodes"

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "decode"

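    def decode(self, vae, samples, per_batch):
        decoded = []
        # The progress bar total is the number of latent samples (batch dimension).
        pbar = ProgressBar(samples["samples"].shape[0])
        # Step through the batch dimension in chunks of per_batch samples.
        for start_idx in range(0, samples["samples"].shape[0], per_batch):
            # Slicing clips at the end, so the final chunk may be shorter.
            decoded.append(vae.decode(samples["samples"][start_idx:start_idx+per_batch]))
            # Progress advances by per_batch on every iteration, including the last.
            pbar.update(per_batch)
        # Join the per-chunk results back into a single image batch.
        return (torch.cat(decoded, dim=0), )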
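
The loop above can be exercised without ComfyUI by substituting stand-ins. In the sketch below, FakeVAE and FakeProgressBar are hypothetical placeholders, and plain Python lists stand in for tensors, so list extension replaces torch.cat. As a small variation on the original, it reports the actual chunk length to the progress bar instead of per_batch.

```python
class FakeProgressBar:
    """Hypothetical stand-in for a progress bar; it only counts steps."""
    def __init__(self, total):
        self.total = total
        self.current = 0

    def update(self, value):
        self.current += value


class FakeVAE:
    """Hypothetical stand-in for a VAE; 'decodes' by doubling each value."""
    def decode(self, latents):
        return [x * 2 for x in latents]


def decode_batched(vae, samples, per_batch):
    # A LATENT input is a dict that holds its data under the "samples" key.
    latents = samples["samples"]
    pbar = FakeProgressBar(len(latents))
    decoded = []
    for start in range(0, len(latents), per_batch):
        chunk = latents[start:start + per_batch]
        # List analogue of collecting chunks and calling torch.cat(..., dim=0).
        decoded.extend(vae.decode(chunk))
        # Variation on the original: report the real chunk size.
        pbar.update(len(chunk))
    return decoded, pbar


vae = FakeVAE()
samples = {"samples": list(range(40))}
images, pbar = decode_batched(vae, samples, per_batch=16)

# The chunked result matches decoding everything in one call.
print(images == vae.decode(samples["samples"]))
print(pbar.current, pbar.total)
```

Because each chunk is decoded independently and the results are joined in order, the stand-in produces the same output as a single decode call. In the real node the difference lies in how much data is passed to the VAE at once.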