Batch Transform (mtb)¶
Documentation¶
- Class name:
Batch Transform (mtb)
- Category:
mtb/batch
- Output node:
False
This node is designed to apply a series of 2D transformations to batches of images, including operations such as translation, rotation, and scaling. It abstracts complex image manipulation processes into a more accessible interface for batch processing, enhancing the efficiency and flexibility of image transformation tasks.
Input types¶
Required¶
image
- The input image tensor to be transformed, serving as the primary data upon which the batch transformations are applied.
- Comfy dtype:
IMAGE
- Python dtype:
torch.Tensor
border_handling
- Specifies the method for handling borders during the transformation process, affecting how the image is extended or cropped.
- Comfy dtype:
COMBO[STRING]
- Python dtype:
str
constant_color
- Defines the color used to fill any new areas of the image created as a result of the transformation, ensuring visual consistency.
- Comfy dtype:
COLOR
- Python dtype:
str
Optional¶
x
- Optional list of floats representing the translation distances along the x-axis for each image in the batch.
- Comfy dtype:
FLOATS
- Python dtype:
list[float] | None
y
- Optional list of floats representing the translation distances along the y-axis for each image in the batch.
- Comfy dtype:
FLOATS
- Python dtype:
list[float] | None
zoom
- Optional list of floats indicating the zoom levels to be applied to each image in the batch.
- Comfy dtype:
FLOATS
- Python dtype:
list[float] | None
angle
- Optional list of floats specifying the rotation angles for each image in the batch, in degrees.
- Comfy dtype:
FLOATS
- Python dtype:
list[float] | None
shear
- Optional list of floats detailing the shear angles to be applied to each image in the batch, adding a skew effect.
- Comfy dtype:
FLOATS
- Python dtype:
list[float] | None
Output types¶
image
- Comfy dtype:
IMAGE
- The transformed image tensor, resulting from the applied batch of 2D transformations.
- Python dtype:
torch.Tensor
- Comfy dtype:
Usage tips¶
- Infra type:
GPU
- Common nodes: unknown
Source code¶
class MTB_Batch2dTransform:
"""Transform a batch of images using a batch of keyframes"""
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"image": ("IMAGE",),
"border_handling": (
["edge", "constant", "reflect", "symmetric"],
{"default": "edge"},
),
"constant_color": ("COLOR", {"default": "#000000"}),
},
"optional": {
"x": ("FLOATS",),
"y": ("FLOATS",),
"zoom": ("FLOATS",),
"angle": ("FLOATS",),
"shear": ("FLOATS",),
},
}
RETURN_TYPES = ("IMAGE",)
FUNCTION = "transform_batch"
CATEGORY = "mtb/batch"
def get_num_elements(
self, param: None | torch.Tensor | list[torch.Tensor] | list[float]
) -> int:
if isinstance(param, torch.Tensor):
return torch.numel(param)
elif isinstance(param, list):
return len(param)
return 0
def transform_batch(
self,
image: torch.Tensor,
border_handling: str,
constant_color: str,
x: list[float] | None = None,
y: list[float] | None = None,
zoom: list[float] | None = None,
angle: list[float] | None = None,
shear: list[float] | None = None,
):
if all(
self.get_num_elements(param) <= 0
for param in [x, y, zoom, angle, shear]
):
raise ValueError(
"At least one transform parameter must be provided"
)
keyframes: dict[str, list[float]] = {
"x": [],
"y": [],
"zoom": [],
"angle": [],
"shear": [],
}
default_vals = {"x": 0, "y": 0, "zoom": 1.0, "angle": 0, "shear": 0}
if x and self.get_num_elements(x) > 0:
keyframes["x"] = x
if y and self.get_num_elements(y) > 0:
keyframes["y"] = y
if zoom and self.get_num_elements(zoom) > 0:
# some easing types like elastic can pull back... maybe it should abs the value?
keyframes["zoom"] = [max(x, 0.00001) for x in zoom]
if angle and self.get_num_elements(angle) > 0:
keyframes["angle"] = angle
if shear and self.get_num_elements(shear) > 0:
keyframes["shear"] = shear
for name, values in keyframes.items():
count = len(values)
if count > 0 and count != image.shape[0]:
raise ValueError(
f"Length of {name} values ({count}) must match number of images ({image.shape[0]})"
)
if count == 0:
keyframes[name] = [default_vals[name]] * image.shape[0]
transformer = MTB_TransformImage()
res = [
transformer.transform(
image[i].unsqueeze(0),
keyframes["x"][i],
keyframes["y"][i],
keyframes["zoom"][i],
keyframes["angle"][i],
keyframes["shear"][i],
border_handling,
constant_color,
)[0]
for i in range(image.shape[0])
]
return (torch.cat(res, dim=0),)