Batch Crop From Mask Advanced
Documentation
- Class name:
BatchCropFromMaskAdvanced
- Category:
KJNodes/masking
- Output node:
False
This node performs advanced cropping on a batch of images based on their associated masks. It finds the largest mask bounding box in the batch, scales it by the smoothing factor and the crop size multiplier, rounds it up to a multiple of 16 (capped to the image dimensions), and then crops every image to that common square size around the smoothed center of its mask's non-zero pixels. It additionally crops every image and mask to the single bounding box that encloses all masks in the batch. This process is aimed at enhancing the focus on relevant image areas while keeping the crop size consistent across the batch.
Input types
Required
original_images
- The batch of original images that will be cropped according to the calculated bounding boxes derived from their corresponding masks.
- Comfy dtype:
IMAGE
- Python dtype:
List[torch.Tensor]
masks
- A batch of masks used to determine the areas of interest within the corresponding images. These masks guide the cropping process by identifying non-zero regions that signify important content.
- Comfy dtype:
MASK
- Python dtype:
torch.Tensor
crop_size_mult
- A multiplier applied to the size of the bounding boxes to adjust the final crop size, allowing for flexibility in the cropping process.
- Comfy dtype:
FLOAT
- Python dtype:
float
bbox_smooth_alpha
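- A smoothing factor in the range 0.0 to 1.0 applied to the maximum bounding box size and to the crop center of each frame. For the center, a value of 1.0 uses the current frame's center unchanged, while lower values blend in more of the previous frame's center to mitigate abrupt changes between crops.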
- Comfy dtype:
FLOAT
- Python dtype:
float
Output types
original_images
- Comfy dtype:
IMAGE
- The original images provided as input, returned without modification.
- Python dtype:
torch.Tensor
- Comfy dtype:
cropped_images
- Comfy dtype:
IMAGE
- The images after being cropped according to the calculated bounding boxes and applied adjustments.
- Python dtype:
torch.Tensor
- Comfy dtype:
cropped_masks
- Comfy dtype:
MASK
- The masks after being cropped to match the dimensions of the cropped images.
- Python dtype:
torch.Tensor
- Comfy dtype:
combined_crop_image
- Comfy dtype:
IMAGE
- The batch of original images, each cropped to the single combined bounding box that encloses the non-zero regions of all masks in the batch.
- Python dtype:
torch.Tensor
- Comfy dtype:
combined_crop_masks
- Comfy dtype:
MASK
- The batch of masks, each cropped to the same combined bounding box used for the combined cropped images.
- Python dtype:
torch.Tensor
- Comfy dtype:
bboxes
- Comfy dtype:
BBOX
- The bounding boxes used to crop each image, one per mask, each given as an (x, y, width, height) tuple.
- Python dtype:
List[tuple]
- Comfy dtype:
combined_bounding_box
- Comfy dtype:
BBOX
- The combined bounding box, given as (x, y, width, height), that encloses the non-zero regions of all masks; it is used for creating the combined cropped images and masks.
- Python dtype:
List[tuple]
- Comfy dtype:
bbox_width
- Comfy dtype:
INT
- The final square crop size: the largest bounding box size across all masks after smoothing, the crop size multiplier, and rounding. This is the same value as bbox_height.
- Python dtype:
int
- Comfy dtype:
bbox_height
- Comfy dtype:
INT
- The final square crop size: the largest bounding box size across all masks after smoothing, the crop size multiplier, and rounding. This is the same value as bbox_width.
- Python dtype:
int
- Comfy dtype:
Usage tips
- Infra type:
GPU
- Common nodes: unknown
Source code
class BatchCropFromMaskAdvanced:
    """Crop a batch of images around their masks using one shared square size.

    The node computes the largest mask bounding box in the batch, derives a
    common square crop size from it, and crops every image/mask pair around
    the (smoothed) mean position of the mask's non-zero pixels. It also crops
    every image/mask to the single bounding box enclosing all masks.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input specification (names, types, defaults, ranges)."""
        return {
            "required": {
                "original_images": ("IMAGE",),
                "masks": ("MASK",),
                "crop_size_mult": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                "bbox_smooth_alpha": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
            },
        }

    RETURN_TYPES = (
        "IMAGE",
        "IMAGE",
        "MASK",
        "IMAGE",
        "MASK",
        "BBOX",
        "BBOX",
        "INT",
        "INT",
    )
    RETURN_NAMES = (
        "original_images",
        "cropped_images",
        "cropped_masks",
        "combined_crop_image",
        "combined_crop_masks",
        "bboxes",
        "combined_bounding_box",
        "bbox_width",
        "bbox_height",
    )
    FUNCTION = "crop"
    CATEGORY = "KJNodes/masking"

    def smooth_bbox_size(self, prev_bbox_size, curr_bbox_size, alpha):
        """Blend the previous and current box sizes and round to an integer.

        ``alpha`` weights the current size; ``1 - alpha`` weights the previous.
        """
        return round(alpha * curr_bbox_size + (1 - alpha) * prev_bbox_size)

    def smooth_center(self, prev_center, curr_center, alpha=0.5):
        """Blend two (x, y) centers component-wise and round each coordinate.

        ``alpha`` weights the current center; ``1 - alpha`` weights the previous.
        """
        return (
            round(alpha * curr_center[0] + (1 - alpha) * prev_center[0]),
            round(alpha * curr_center[1] + (1 - alpha) * prev_center[1]),
        )

    def crop(self, masks, original_images, crop_size_mult, bbox_smooth_alpha):
        """Crop every image/mask pair to a shared square size around its mask.

        Args:
            masks: Batch of masks; iterated one mask per image and reduced
                with ``torch.max(masks, dim=0)`` to build the combined mask.
            original_images: Batch of images; each image is indexed as
                (height, width, channels).
            crop_size_mult: Multiplier applied to the maximum bounding box size.
            bbox_smooth_alpha: Blend weight used for the bounding box size and
                for smoothing the crop center between consecutive frames.

        Returns:
            A 9-tuple: the unmodified ``original_images``, the stacked cropped
            images, the stacked cropped masks, the stacked images cropped to
            the combined bounding box, the stacked masks cropped to the
            combined bounding box, the list of per-frame ``(x, y, w, h)``
            boxes, a one-element list holding the combined ``(x, y, w, h)``
            box, and the final square crop size twice (width and height).
        """
        bounding_boxes = []
        combined_bounding_box = []
        cropped_images = []
        cropped_masks = []
        combined_cropped_images = []
        combined_cropped_masks = []

        # Reset the smoothing state for every call. Previously the attribute
        # survived between invocations, so when the first non-empty mask was
        # not at index 0 its center was blended with a stale center left over
        # from an earlier, unrelated batch.
        self.prev_center = None

        def calculate_bbox(mask):
            """Return (min_x, max_x, min_y, max_y, size) of the non-zero area."""
            non_zero_indices = np.nonzero(np.array(mask))

            # handle empty masks
            min_x, max_x, min_y, max_y = 0, 0, 0, 0
            if len(non_zero_indices[1]) > 0 and len(non_zero_indices[0]) > 0:
                min_x, max_x = np.min(non_zero_indices[1]), np.max(non_zero_indices[1])
                min_y, max_y = np.min(non_zero_indices[0]), np.max(non_zero_indices[0])

            width = max_x - min_x
            height = max_y - min_y
            bbox_size = max(width, height)
            return min_x, max_x, min_y, max_y, bbox_size

        image_height = original_images[0].shape[0]
        image_width = original_images[0].shape[1]

        # Bounding box that encloses the non-zero area of every mask at once.
        combined_mask = torch.max(masks, dim=0)[0]
        _mask = tensor2pil(combined_mask)[0]
        new_min_x, new_max_x, new_min_y, new_max_y, combined_bbox_size = calculate_bbox(_mask)
        center_x = (new_min_x + new_max_x) / 2
        center_y = (new_min_y + new_max_y) / 2
        half_box_size = combined_bbox_size // 2
        new_min_x = max(0, round(center_x - half_box_size))
        new_max_x = min(image_width, round(center_x + half_box_size))
        new_min_y = max(0, round(center_y - half_box_size))
        new_max_y = min(image_height, round(center_y + half_box_size))

        combined_bounding_box.append((new_min_x, new_min_y, new_max_x - new_min_x, new_max_y - new_min_y))

        self.max_bbox_size = 0

        # First, calculate the maximum bounding box size across all masks
        curr_max_bbox_size = max(calculate_bbox(tensor2pil(mask)[0])[-1] for mask in masks)
        # Smooth the changes in the bounding box size.
        # NOTE(review): the previous size is always 0 here, so this effectively
        # scales the maximum size by bbox_smooth_alpha — confirm this is intended.
        self.max_bbox_size = self.smooth_bbox_size(self.max_bbox_size, curr_max_bbox_size, bbox_smooth_alpha)
        # Apply the crop size multiplier
        self.max_bbox_size = round(self.max_bbox_size * crop_size_mult)
        # Make sure max_bbox_size is divisible by 16, if not, round it upwards so it is
        self.max_bbox_size = math.ceil(self.max_bbox_size / 16) * 16

        if self.max_bbox_size > image_height or self.max_bbox_size > image_width:
            # max_bbox_size can only be as big as our input's width or height, and it has to be even
            self.max_bbox_size = math.floor(min(image_height, image_width) / 2) * 2

        # Then, for each mask and corresponding image...
        for i, (mask, img) in enumerate(zip(masks, original_images)):
            _mask = tensor2pil(mask)[0]
            non_zero_indices = np.nonzero(np.array(_mask))

            # check for empty masks
            if len(non_zero_indices[0]) > 0 and len(non_zero_indices[1]) > 0:
                # Center of the crop: mean position of the non-zero mask pixels
                center_x = np.mean(non_zero_indices[1])
                center_y = np.mean(non_zero_indices[0])
                curr_center = (round(center_x), round(center_y))

                # The first frame of the batch, or the first non-empty mask of
                # this call, has nothing to smooth against and uses its own
                # center; later frames are blended with the previous center.
                if i == 0 or self.prev_center is None:
                    center = curr_center
                else:
                    center = self.smooth_center(self.prev_center, curr_center, bbox_smooth_alpha)

                # Update prev_center for the next frame
                self.prev_center = center

                # Create bounding box using max_bbox_size, clamped to the image
                half_box_size = self.max_bbox_size // 2
                min_x = max(0, center[0] - half_box_size)
                max_x = min(img.shape[1], center[0] + half_box_size)
                min_y = max(0, center[1] - half_box_size)
                max_y = min(img.shape[0], center[1] + half_box_size)

                # Append bounding box coordinates
                bounding_boxes.append((min_x, min_y, max_x - min_x, max_y - min_y))

                # Crop the image from the bounding box
                cropped_img = img[min_y:max_y, min_x:max_x, :]
                cropped_mask = mask[min_y:max_y, min_x:max_x]

                # Resize the cropped image to a fixed size
                new_size = max(cropped_img.shape[0], cropped_img.shape[1])
                resize_transform = Resize(new_size, interpolation=InterpolationMode.NEAREST, max_size=max(img.shape[0], img.shape[1]))
                resized_mask = resize_transform(cropped_mask.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
                resized_img = resize_transform(cropped_img.permute(2, 0, 1))

                # Perform the center crop to the desired size
                # Constrain the crop to the smaller of our bbox or our image so we don't expand past the image dimensions.
                crop_transform = CenterCrop((min(self.max_bbox_size, resized_img.shape[1]), min(self.max_bbox_size, resized_img.shape[2])))

                cropped_resized_img = crop_transform(resized_img)
                cropped_images.append(cropped_resized_img.permute(1, 2, 0))

                cropped_resized_mask = crop_transform(resized_mask)
                cropped_masks.append(cropped_resized_mask)

                combined_cropped_img = original_images[i][new_min_y:new_max_y, new_min_x:new_max_x, :]
                combined_cropped_images.append(combined_cropped_img)

                combined_cropped_mask = masks[i][new_min_y:new_max_y, new_min_x:new_max_x]
                combined_cropped_masks.append(combined_cropped_mask)
            else:
                # Empty mask: pass the full image and mask through untouched.
                # NOTE(review): these are not cropped, so the torch.stack calls
                # below presumably fail when other frames were cropped to a
                # different size — confirm the intended behavior.
                bounding_boxes.append((0, 0, img.shape[1], img.shape[0]))
                cropped_images.append(img)
                cropped_masks.append(mask)
                combined_cropped_images.append(img)
                combined_cropped_masks.append(mask)

        cropped_out = torch.stack(cropped_images, dim=0)
        combined_crop_out = torch.stack(combined_cropped_images, dim=0)
        cropped_masks_out = torch.stack(cropped_masks, dim=0)
        combined_crop_mask_out = torch.stack(combined_cropped_masks, dim=0)

        return (original_images, cropped_out, cropped_masks_out, combined_crop_out, combined_crop_mask_out, bounding_boxes, combined_bounding_box, self.max_bbox_size, self.max_bbox_size)