🔧 Image Expand Batch
Documentation
- Class name:
ImageExpandBatch+
- Category:
essentials
- Output node:
False
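The ImageExpandBatch+ node resizes an image batch to a target number of images. Depending on the selected method, it resamples the existing images at evenly spaced positions, tiles the whole batch, or pads the batch with copies of the first or last image. It does not add new images from another source, and if the target size is smaller than the current batch, the batch is reduced instead. This is useful when a downstream operation expects a specific batch size, such as batch transformations, augmentations, or model processing.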
Input types
Required
image
- The 'image' parameter is the input image batch to resize. Its first tensor dimension is the batch size; the remaining dimensions are left unchanged.
- Comfy dtype:
IMAGE
- Python dtype:
torch.Tensor
size
- The 'size' parameter is the target batch size (default 16, minimum 1). If it equals the current batch size, the input is returned unchanged; if it is smaller, the batch is reduced.
- Comfy dtype:
INT
- Python dtype:
int
method
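- The 'method' parameter selects how the batch is resized. 'expand' picks images at evenly spaced positions, so each source image is repeated (or skipped) roughly uniformly; 'repeat all' tiles the whole batch and trims it to the target size; 'repeat first' prepends copies of the first image; 'repeat last' appends copies of the last image. When the target size is smaller than the batch, both 'repeat first' and 'repeat last' simply keep the first 'size' images.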
- Comfy dtype:
COMBO[STRING]
- Python dtype:
str
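The effect of each method is easiest to see as a mapping from output position to source image index. The following is a minimal pure-Python sketch of that mapping, mirroring the logic in the node's source code. `batch_indices` is a hypothetical helper name introduced here for illustration and is not part of the node.

```python
import math


def batch_indices(orig_size: int, size: int, method: str) -> list[int]:
    """Hypothetical helper: which source index ends up at each output position."""
    source = list(range(orig_size))
    if orig_size == size:
        return source
    if size <= 1:
        return source[:size]

    if "expand" in method:
        if size < orig_size:
            # Shrinking: sample evenly, keeping the first and last image.
            scale = (orig_size - 1) / (size - 1)
            return [min(round(i * scale), orig_size - 1) for i in range(size)]
        # Growing: nearest-neighbour lookup at the centre of each output slot.
        scale = orig_size / size
        return [min(math.floor((i + 0.5) * scale), orig_size - 1) for i in range(size)]

    if "all" in method:
        # Tile the whole batch, then trim to the target size.
        return (source * math.ceil(size / orig_size))[:size]

    if "first" in method:
        if size < orig_size:
            return source[:size]
        # Copies of the first image are placed in front of the batch.
        return [0] * (size - orig_size) + source

    if "last" in method:
        if size < orig_size:
            return source[:size]
        # Copies of the last image are placed after the batch.
        return source + [orig_size - 1] * (size - orig_size)

    raise ValueError(f"unknown method: {method}")


print(batch_indices(3, 6, "expand"))        # [0, 0, 1, 1, 2, 2]
print(batch_indices(5, 3, "expand"))        # [0, 2, 4]
print(batch_indices(3, 7, "repeat all"))    # [0, 1, 2, 0, 1, 2, 0]
print(batch_indices(3, 5, "repeat first"))  # [0, 0, 0, 1, 2]
print(batch_indices(3, 5, "repeat last"))   # [0, 1, 2, 2, 2]
```

Note that the shrinking branch of 'expand' relies on Python's built-in `round`, which rounds exact halves to the nearest even integer.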
Output types
image
- The output 'image' is the resized batch. Its batch dimension equals 'size' and its other dimensions match the input.
- Comfy dtype:
IMAGE
- Python dtype:
torch.Tensor
Usage tips
- Infra type:
GPU
- Common nodes: unknown
Source code
import math

import torch


class ImageExpandBatch:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "image": ("IMAGE",),
                "size": ("INT", { "default": 16, "min": 1, "step": 1, }),
                "method": (["expand", "repeat all", "repeat first", "repeat last"],)
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"
    CATEGORY = "essentials"

    def execute(self, image, size, method):
        orig_size = image.shape[0]

        # Nothing to do when the batch already has the requested size.
        if orig_size == size:
            return (image,)

        if size <= 1:
            return (image[:size],)

        if 'expand' in method:
            out = torch.empty([size] + list(image.shape)[1:], dtype=image.dtype, device=image.device)
            if size < orig_size:
                # Shrinking: sample evenly, keeping the first and last image.
                scale = (orig_size - 1) / (size - 1)
                for i in range(size):
                    out[i] = image[min(round(i * scale), orig_size - 1)]
            else:
                # Growing: nearest-neighbour lookup at the centre of each output slot.
                scale = orig_size / size
                for i in range(size):
                    out[i] = image[min(math.floor((i + 0.5) * scale), orig_size - 1)]
        elif 'all' in method:
            # Tile the whole batch, then trim to the requested size.
            out = image.repeat([math.ceil(size / image.shape[0])] + [1] * (len(image.shape) - 1))[:size]
        elif 'first' in method:
            if size < image.shape[0]:
                out = image[:size]
            else:
                # Prepend copies of the first image.
                out = torch.cat([image[:1].repeat(size-image.shape[0], 1, 1, 1), image], dim=0)
        elif 'last' in method:
            if size < image.shape[0]:
                out = image[:size]
            else:
                # Append copies of the last image.
                out = torch.cat((image, image[-1:].repeat((size-image.shape[0], 1, 1, 1))), dim=0)

        return (out,)
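The 'expand' branch above copies images one at a time in a Python loop. As an illustration only, the same selection can be written as a single indexed gather. This sketch assumes PyTorch is installed; `expand_batch_vectorized` is a hypothetical name that is not part of the node, and the index formulas are copied from the source above.

```python
import math

import torch


def expand_batch_vectorized(image: torch.Tensor, size: int) -> torch.Tensor:
    """Hypothetical alternative to the 'expand' branch using one gather."""
    orig_size = image.shape[0]
    if orig_size == size:
        return image
    if size <= 1:
        return image[:size]

    if size < orig_size:
        # Shrinking: evenly spaced samples that keep the first and last image.
        scale = (orig_size - 1) / (size - 1)
        idx = [min(round(i * scale), orig_size - 1) for i in range(size)]
    else:
        # Growing: nearest source image for the centre of each output slot.
        scale = orig_size / size
        idx = [min(math.floor((i + 0.5) * scale), orig_size - 1) for i in range(size)]

    index = torch.tensor(idx, dtype=torch.long, device=image.device)
    # index_select copies the chosen images along the batch dimension.
    return image.index_select(0, index)


# Three 2x2 RGB images filled with 0, 1 and 2 so each one is identifiable.
frames = torch.arange(3, dtype=torch.float32).view(3, 1, 1, 1).expand(3, 2, 2, 3).contiguous()
out = expand_batch_vectorized(frames, 6)
print(tuple(out.shape))          # (6, 2, 2, 3)
print(out[:, 0, 0, 0].tolist())  # [0.0, 0.0, 1.0, 1.0, 2.0, 2.0]
```

The result has the same dtype and device as the input, as in the original loop, because `index_select` allocates its output from the input tensor.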