FilterZeroMasksAndCorrespondingImages
Documentation
- Class name:
FilterZeroMasksAndCorrespondingImages
- Category:
KJNodes/masking
- Output node:
False
This node removes every all-zero (empty) mask from a batch of masks. If a batch of images is also supplied, the node splits the images the same way. Images whose mask contains at least one non-zero pixel are passed on alongside the filtered masks. Images whose mask is empty are returned separately, together with their batch indexes. As a result, only relevant, non-empty masks and their associated images reach downstream processing.
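The filtering rule can also be expressed with a few tensor operations. The following is a hypothetical vectorized sketch, not the node's own code (the node loops over the masks one at a time). It assumes the usual ComfyUI layouts, `[B, H, W]` for masks and `[B, H, W, C]` for images. One difference from the node: when every mask is empty, this sketch returns an empty batch instead of failing. The function name `split_by_mask_emptiness` is illustrative only.

```python
from typing import List, Optional, Tuple

import torch


def split_by_mask_emptiness(
    masks: torch.Tensor, images: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[List[int]]]:
    """Hypothetical vectorized equivalent of the node's per-mask loop."""
    # True for every mask with at least one non-zero pixel.
    keep = masks.flatten(start_dim=1).ne(0).any(dim=1)
    non_zero_masks = masks[keep]

    # Like the node, ignore images that are missing or not one-per-mask.
    if images is None or images.shape[0] != masks.shape[0]:
        return non_zero_masks, None, None, None

    zero_indexes = torch.nonzero(~keep).flatten().tolist()
    non_zero_images = images[keep]
    # Report no zero-mask images (and no indexes) when nothing was empty.
    zero_images = images[~keep] if zero_indexes else None
    return non_zero_masks, non_zero_images, zero_images, (zero_indexes or None)
```

Boolean indexing preserves the original batch order within each group, which matches the order in which the node appends masks and images.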
Input types
Required
masks
- The batch of masks to filter. Every mask whose pixels are all zero is removed. At least one mask must contain a non-zero pixel; otherwise there is nothing to output and the node fails.
- Comfy dtype:
MASK
- Python dtype:
torch.Tensor
Optional
original_images
- An optional batch of images, one per mask. If provided, it must have the same length as masks; otherwise it is ignored and a warning is printed. The images are split according to whether their corresponding mask is empty.
- Comfy dtype:
IMAGE
- Python dtype:
Optional[torch.Tensor]
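The pairing rule for the two inputs can be checked ahead of time. This sketch assumes the common ComfyUI tensor layouts (`[B, H, W]` for MASK, `[B, H, W, C]` float images for IMAGE). The helper `images_are_usable` is hypothetical; it only mirrors the length check the node performs before deciding whether to split the images.

```python
from typing import Optional

import torch

# Assumed ComfyUI layouts: MASK is [B, H, W], IMAGE is [B, H, W, C] with values in 0..1.
B, H, W = 3, 64, 64
masks = torch.zeros(B, H, W)
masks[1, 10:20, 10:20] = 1.0  # only mask 1 is non-empty
original_images = torch.rand(B, H, W, 3)


def images_are_usable(masks: torch.Tensor, images: Optional[torch.Tensor]) -> bool:
    """Hypothetical helper: images are used only when there is one image per mask."""
    if images is None:
        return False
    if len(images) != len(masks):
        # The node also ignores mismatched images and prints a warning.
        print(f"[WARNING] ignoring original_images: {len(images)} images for {len(masks)} masks")
        return False
    return True
```

With these inputs, `images_are_usable(masks, original_images)` is true, while passing `original_images[:2]` (two images for three masks) is rejected.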
Output types
non_zero_masks_out
- The masks that contain at least one non-zero pixel, stacked into a batch in their original order.
- Comfy dtype:
MASK
- Python dtype:
torch.Tensor
non_zero_mask_images_out
- The images whose masks are non-zero, in their original order. None if original_images was not provided or was ignored because of a length mismatch.
- Comfy dtype:
IMAGE
- Python dtype:
Optional[torch.Tensor]
zero_mask_images_out
- The images whose masks are entirely zero, in their original order. None if original_images was not used or if no mask was empty.
- Comfy dtype:
IMAGE
- Python dtype:
Optional[torch.Tensor]
zero_mask_images_out_indexes
- The batch indexes of the empty masks, i.e. the original positions of the images in zero_mask_images_out. Useful for tracking which images were filtered out. None under the same conditions as zero_mask_images_out.
- Comfy dtype:
INDEXES
- Python dtype:
Optional[List[int]]
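The indexes output makes it possible to restore the original batch order later. For example, this is useful after only the non-zero-mask images have been processed further. The sketch below assumes both image batches share the same `H, W, C`. `restore_batch_order` is a hypothetical helper, not part of KJNodes.

```python
from typing import List, Optional

import torch


def restore_batch_order(non_zero_images: torch.Tensor,
                        zero_images: Optional[torch.Tensor],
                        zero_indexes: Optional[List[int]]) -> torch.Tensor:
    """Hypothetical helper: interleave both batches back into their original order."""
    # Nothing was filtered out, so the batch is already complete and ordered.
    if zero_images is None or not zero_indexes:
        return non_zero_images
    total = non_zero_images.shape[0] + zero_images.shape[0]
    # Mark the original batch positions that held an empty mask.
    is_zero = torch.zeros(total, dtype=torch.bool, device=non_zero_images.device)
    is_zero[zero_indexes] = True
    out = torch.empty((total, *non_zero_images.shape[1:]),
                      dtype=non_zero_images.dtype, device=non_zero_images.device)
    # Boolean-mask assignment fills the marked positions in ascending order,
    # which matches the order in which each group was collected.
    out[is_zero] = zero_images.to(out.device, out.dtype)
    out[~is_zero] = non_zero_images
    return out
```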
Usage tips
- Infra type:
CPU
- Common nodes: unknown
Source code
import numpy as np
import torch


class FilterZeroMasksAndCorrespondingImages:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "masks": ("MASK",),
            },
            "optional": {
                "original_images": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("MASK", "IMAGE", "IMAGE", "INDEXES",)
    RETURN_NAMES = ("non_zero_masks_out", "non_zero_mask_images_out", "zero_mask_images_out", "zero_mask_images_out_indexes",)
    FUNCTION = "filter"
    CATEGORY = "KJNodes/masking"
    DESCRIPTION = """
Filter out all the empty (i.e. all-zero) masks in masks.
Also filter out the corresponding images in original_images by index, if provided.
original_images (optional): if provided, must have the same length as masks.
"""

    def filter(self, masks, original_images=None):
        non_zero_masks = []
        non_zero_mask_images = []
        zero_mask_images = []
        zero_mask_images_indexes = []

        masks_num = len(masks)
        # Images are only used when there is exactly one image per mask.
        also_process_images = False
        if original_images is not None:
            imgs_num = len(original_images)
            if imgs_num == masks_num:
                also_process_images = True
            else:
                print(f"[WARNING] ignore input: original_images, due to number of original_images ({imgs_num}) is not equal to number of masks ({masks_num})")

        for i in range(masks_num):
            # A mask counts as non-empty if at least one pixel is non-zero.
            non_zero_num = np.count_nonzero(np.array(masks[i]))
            if non_zero_num > 0:
                non_zero_masks.append(masks[i])
                if also_process_images:
                    non_zero_mask_images.append(original_images[i])
            else:
                # Guard the image access: original_images may be None or have a
                # different length than masks, and indexing it would then fail.
                if also_process_images:
                    zero_mask_images.append(original_images[i])
                zero_mask_images_indexes.append(i)

        # torch.stack() cannot stack an empty list, so fail with a clear message.
        if not non_zero_masks:
            raise ValueError("All masks are empty (all zero); nothing to output.")

        non_zero_masks_out = torch.stack(non_zero_masks, dim=0)
        non_zero_mask_images_out = zero_mask_images_out = zero_mask_images_out_indexes = None
        if also_process_images:
            non_zero_mask_images_out = torch.stack(non_zero_mask_images, dim=0)
            # Zero-mask outputs stay None when no mask was empty.
            if len(zero_mask_images) > 0:
                zero_mask_images_out = torch.stack(zero_mask_images, dim=0)
                zero_mask_images_out_indexes = zero_mask_images_indexes
        return (non_zero_masks_out, non_zero_mask_images_out, zero_mask_images_out, zero_mask_images_out_indexes)