StableCascade_SuperResolutionControlnet¶
Documentation¶
- Class name:
StableCascade_SuperResolutionControlnet
- Category:
_for_testing/stable_cascade
- Output node:
False
This node is designed for generating super-resolution control inputs for a cascading image generation process, utilizing a VAE to encode images into latent representations that are then adjusted for different stages of image resolution enhancement.
Input types¶
Required¶
image
- The input image to be processed and encoded into latent representations for super-resolution enhancement.
- Comfy dtype:
IMAGE
- Python dtype:
torch.Tensor
vae
- A variational autoencoder (VAE) model used to encode the input image into a latent space, facilitating the super-resolution process.
- Comfy dtype:
VAE
- Python dtype:
torch.nn.Module
Output types¶
controlnet_input
- Comfy dtype:
IMAGE
- The encoded latent representation of the input image, prepared for the control network of the super-resolution process.
- Python dtype:
torch.Tensor
- Comfy dtype:
stage_c
- Comfy dtype:
LATENT
- A latent representation tailored for the initial stage of the super-resolution process.
- Python dtype:
Dict[str, torch.Tensor]
- Comfy dtype:
stage_b
- Comfy dtype:
LATENT
- A latent representation tailored for the subsequent stage of the super-resolution process.
- Python dtype:
Dict[str, torch.Tensor]
- Comfy dtype:
Usage tips¶
- Infra type:
GPU
- Common nodes: unknown
Source code¶
class StableCascade_SuperResolutionControlnet:
def __init__(self, device="cpu"):
self.device = device
@classmethod
def INPUT_TYPES(s):
return {"required": {
"image": ("IMAGE",),
"vae": ("VAE", ),
}}
RETURN_TYPES = ("IMAGE", "LATENT", "LATENT")
RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b")
FUNCTION = "generate"
CATEGORY = "_for_testing/stable_cascade"
def generate(self, image, vae):
width = image.shape[-2]
height = image.shape[-3]
batch_size = image.shape[0]
controlnet_input = vae.encode(image[:,:,:,:3]).movedim(1, -1)
c_latent = torch.zeros([batch_size, 16, height // 16, width // 16])
b_latent = torch.zeros([batch_size, 4, height // 2, width // 2])
return (controlnet_input, {
"samples": c_latent,
}, {
"samples": b_latent,
})