Batch Image CLIPSeg Masking¶
Documentation¶
- Class name:
SaltCLIPSegMasking
- Category:
SALT/Masking
- Output node:
False
This node performs batch segmentation of images using the CLIPSeg model, generating a mask for each image from a textual description. By pairing every image in the batch with the text prompt, it produces segmentation masks that highlight the regions described by the text, supporting downstream image manipulation and analysis tasks.
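For context, the node essentially batches the standard CLIPSeg inference call from the transformers library. A minimal, standalone sketch of that underlying call (the image path and prompt are illustrative placeholders, not part of the node):

import torch
from PIL import Image
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

image = Image.open("example.png").convert("RGB")  # placeholder input image
inputs = processor(text="a cat", images=image, padding=True, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits  # low-resolution relevance map

mask = torch.sigmoid(logits)             # map logits to [0, 1]
mask = (mask - mask.min()) / mask.max()  # same normalization the node applies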
Input types¶
Required¶
images
- A batch of images to be processed for segmentation. The size of the first image in the batch sets the master size, and every generated mask is resized to it so that all outputs share consistent dimensions.
- Comfy dtype:
IMAGE
- Python dtype:
List[torch.Tensor]
text
- The textual description that guides the segmentation process, allowing targeted mask generation based on textual cues. Defaults to an empty string.
- Comfy dtype:
STRING
- Python dtype:
Optional[str]
Optional¶
clipseg_model
- An optional pre-loaded CLIPSeg processor and model pair to use for segmentation. If not provided, the node loads the default CIDAS/clipseg-rd64-refined model (a pre-loading sketch follows below).
- Comfy dtype:
CLIPSEG_MODEL
- Python dtype:
Optional[Tuple[CLIPSegProcessor, CLIPSegForImageSegmentation]]
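If you build this input yourself rather than receiving it from a loader node, the source below reads it as a (processor, model) pair. A hedged sketch of pre-loading that pair:

from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

# Order matters: the node indexes [0] as the processor and [1] as the model.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
clipseg_model = (processor, model)  # pass this tuple as the clipseg_model input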
Output types¶
masks
- Comfy dtype:
MASK
- Segmentation masks corresponding to the input images, adjusted to the master image size.
- Python dtype:
torch.Tensor
mask_images
- Comfy dtype:
IMAGE
- RGB representations of the segmentation masks for visualization purposes.
- Python dtype:
torch.Tensor
Usage tips¶
- Infra type:
GPU
- Common nodes: unknown
Source code¶
class SaltCLIPSegMasking:
    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "images": ("IMAGE",),
                "text": ("STRING", {"default": "", "multiline": False}),
            },
            "optional": {
                "clipseg_model": ("CLIPSEG_MODEL",),
            }
        }

    RETURN_TYPES = ("MASK", "IMAGE")
    RETURN_NAMES = ("masks", "mask_images")

    FUNCTION = "CLIPSeg_image"
    CATEGORY = f"{NAME}/Masking"

    def CLIPSeg_image(self, images, text=None, clipseg_model=None):
        from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

        masks = []
        image_masks = []
        master_size = None

        for image in images:
            image = tensor2pil(image)
            cache = os.path.join(models_dir, 'clipseg')

            # The first image in the batch defines the output (master) size.
            if not master_size:
                master_size = image.size

            # Use the supplied (processor, model) pair, or load the default CLIPSeg model.
            if clipseg_model:
                inputs = clipseg_model[0]
                model = clipseg_model[1]
            else:
                inputs = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined", cache_dir=cache)
                model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined", cache_dir=cache)

            # Run CLIPSeg and normalize the sigmoid of the logits into a [0, 1] mask.
            with torch.no_grad():
                result = model(**inputs(text=text, images=image, padding=True, return_tensors="pt"))

            tensor = torch.sigmoid(result[0])
            mask = (tensor - tensor.min()) / tensor.max()
            mask = mask.unsqueeze(0)
            mask = tensor2pil(mask)
            mask = mask.resize(master_size)

            # Keep both a single-channel MASK tensor and an RGB preview image.
            mask_image_tensor = pil2tensor(mask.convert("RGB"))
            mask_tensor = image2mask(mask_image_tensor)

            masks.append(mask_tensor)
            image_masks.append(mask_image_tensor)

        masks = torch.cat(masks, dim=0)
        image_masks = torch.cat(image_masks, dim=0)

        return (masks, image_masks)
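For reference, a hedged sketch of invoking the node directly from Python. The image_batch variable is assumed to be a ComfyUI-style IMAGE tensor batch (for example, from an image loader node), and the prompt is an illustrative placeholder:

node = SaltCLIPSegMasking()
masks, mask_images = node.CLIPSeg_image(
    images=image_batch,   # assumed: batch of image tensors in ComfyUI IMAGE format
    text="a red car",     # illustrative prompt
    clipseg_model=None,   # or a pre-loaded (processor, model) tuple
)
# masks: single-channel MASK tensor batch; mask_images: RGB previews of the masks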