Skip to content

SEGSPreview

Documentation

  • Class name: SEGSPreview
  • Category: ImpactPack/Util
  • Output node: True

The SEGSPreview node generates preview images for the segments contained in a SEGS object. For each segment it takes the segment's cropped image (or crops the segment's region from the optional fallback image), optionally applies the segment mask as an alpha channel, saves the preview to the temporary directory, and returns the resulting images along with the saved files' metadata for the UI.

Input types

Required

  • segs
    • The 'segs' parameter is the segmentation data to preview. Each segment supplies the cropped image, mask, and crop region from which its preview is built.
    • Comfy dtype: SEGS
    • Python dtype: Tuple[Tuple[str, str], List[Segment]]
  • alpha_mode
    • When enabled, the segment mask is applied as the alpha channel of each saved preview and the returned image is multiplied by the mask; when disabled, the cropped image is used unchanged.
    • Comfy dtype: BOOLEAN
    • Python dtype: bool
  • min_alpha
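    • Specifies the minimum alpha value used in the preview images when alpha_mode is enabled: mask values below this level are raised to it, so regions outside the mask remain faintly visible instead of becoming fully transparent.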
    • Comfy dtype: FLOAT
    • Python dtype: float

Optional

  • fallback_image_opt
    • An optional image from which each segment's region is cropped when the segment does not carry its own cropped image. When it is provided, the segmentation data is also scale-matched to this image.
    • Comfy dtype: IMAGE
    • Python dtype: Optional[Image]

Output types

  • image
    • Comfy dtype: IMAGE
    • The 'image' output type represents the final image or images generated by the node, encapsulating the visual result of the processing.
    • Python dtype: List[Image]
  • ui
    • The 'ui' output contains the results of the operation: a list of the preview images saved by the node, each described by its filename, subfolder, and type.

Usage tips

  • Infra type: CPU
  • Common nodes: unknown

Source code

class SEGSPreview:
    """Output node that saves one preview image per segment (and batch item) of a SEGS.

    For every segment the node takes the segment's own ``cropped_image`` or,
    when that is missing, crops ``seg.crop_region`` out of the optional
    fallback image.  Each preview is written as a ``.webp`` file into the
    temp directory and reported to the UI; the (optionally mask-multiplied)
    image tensors are also returned as a list with one batch per segment.
    """

    def __init__(self):
        # Previews are written to the temp directory and reported to the UI as "temp" files.
        self.output_dir = folder_paths.get_temp_directory()
        self.type = "temp"

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required and optional inputs."""
        return {"required": {
                     "segs": ("SEGS", ),
                     "alpha_mode": ("BOOLEAN", {"default": True, "label_on": "enable", "label_off": "disable"}),
                     "min_alpha": ("FLOAT", {"default": 0.2, "min": 0.0, "max": 1.0, "step": 0.01}),
                    },
                "optional": {
                     "fallback_image_opt": ("IMAGE", ),
                    }
                }

    RETURN_TYPES = ("IMAGE", )
    OUTPUT_IS_LIST = (True, )
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/Util"

    OUTPUT_NODE = True

    @staticmethod
    def _combine_masks(masks):
        """Merge a batch of masks into one binary mask and return it as a numpy array."""
        if isinstance(masks, np.ndarray):
            masks = torch.tensor(masks)

        # OR the masks together in 8-bit space, then go back to a float mask in [0, 1].
        combined = (masks[0] * 255).to(torch.uint8)
        for mask in masks[1:]:
            combined |= (mask * 255).to(torch.uint8)
        combined = (combined / 255.0).to(torch.float32)

        return utils.to_binary_mask(combined, 0.1).numpy()

    def doit(self, segs, alpha_mode=True, min_alpha=0.0, fallback_image_opt=None):
        """Save a preview for every segment and return the images plus UI metadata.

        Args:
            segs: Pair whose second item is the list of segments; each segment is
                read through ``cropped_image``, ``cropped_mask`` and ``crop_region``.
                The first item is indexed as ``[0]``/``[1]`` and forwarded to the
                save-path helper (presumably (height, width) -- TODO confirm).
            alpha_mode: When true, the segment mask becomes the alpha channel of
                the saved preview and the returned image is multiplied by the mask.
            min_alpha: Fraction in [0, 1]; mask alpha values below
                ``int(255 * min_alpha)`` are raised to that level in the preview.
            fallback_image_opt: Optional image batch used as crop source for
                segments without a ``cropped_image``; when given, ``segs`` is
                scale-matched to its shape first.

        Returns:
            ``{"ui": {"images": [...]}, "result": (image_list,)}`` where every UI
            entry holds ``filename``, ``subfolder`` and ``type`` and ``image_list``
            has one concatenated image batch per segment.  Both keys are present
            on every path, including when there is nothing to preview.
        """
        results = []
        result_image_list = []
        # Single return value for every exit path; the two lists are filled in place below.
        output = {"ui": {"images": results}, "result": (result_image_list,)}

        # Keep the dimensions of the incoming SEGS for the save-path helper,
        # since `segs` may be replaced by its scale-matched version below.
        preview_width, preview_height = segs[0][1], segs[0][0]

        if fallback_image_opt is not None:
            segs = core.segs_scale_match(segs, fallback_image_opt.shape)

        segments = segs[1]
        if len(segments) == 0:
            return output

        # The batch size comes from the first segment's images, else from the fallback image.
        if segments[0].cropped_image is not None:
            batch_count = len(segments[0].cropped_image)
        elif fallback_image_opt is not None:
            batch_count = len(fallback_image_opt)
        else:
            # No image source at all: nothing to preview, but keep the output shape intact.
            return output

        # Alpha floor in 8-bit space; 0 means "do not clamp".
        alpha_floor = int(255 * min_alpha)

        # Resolved only once we know that previews will actually be written.
        full_output_folder, filename, counter, subfolder, _ = \
            folder_paths.get_save_image_path("impact_seg_preview", self.output_dir, preview_width, preview_height)

        for seg in segments:
            frames = []
            combined_mask = None  # computed lazily, at most once per segment

            for i in range(batch_count):
                if seg.cropped_image is not None:
                    cropped_image = seg.cropped_image[i, None]
                elif fallback_image_opt is not None:
                    # take from original image
                    ref_image = fallback_image_opt[i].unsqueeze(0)
                    cropped_image = crop_image(ref_image, seg.crop_region)
                else:
                    continue

                if isinstance(cropped_image, np.ndarray):
                    cropped_image = torch.from_numpy(cropped_image)

                # Work on a copy: the in-place mask multiplication below must not
                # modify the segment's own data.
                cropped_image = cropped_image.clone()
                # The preview is rendered before masking; the mask only becomes its alpha channel.
                cropped_pil = to_pil(cropped_image)

                if alpha_mode:
                    if isinstance(seg.cropped_mask, np.ndarray):
                        cropped_mask = seg.cropped_mask
                    elif seg.cropped_image is not None and len(seg.cropped_image) != len(seg.cropped_mask):
                        # Image and mask batches differ in length, so use one merged mask for all frames.
                        if combined_mask is None:
                            combined_mask = self._combine_masks(seg.cropped_mask)
                        cropped_mask = combined_mask
                    else:
                        cropped_mask = seg.cropped_mask[i].numpy()

                    mask_array = (cropped_mask * 255).astype(np.uint8)

                    if alpha_floor != 0:
                        mask_array[mask_array < alpha_floor] = alpha_floor

                    mask_pil = Image.fromarray(mask_array, mode='L').resize(cropped_pil.size)
                    cropped_pil.putalpha(mask_pil)

                    # The returned tensor uses the raw mask, not the clamped preview alpha.
                    cropped_image *= torch.tensor(cropped_mask)[None, ..., None]

                frames.append(cropped_image)

                file = f"{filename}_{counter:05}_.webp"
                cropped_pil.save(os.path.join(full_output_folder, file))
                results.append({
                    "filename": file,
                    "subfolder": subfolder,
                    "type": self.type
                })

                counter += 1

            if frames:
                # Concatenate once per segment instead of re-copying a growing batch per frame.
                result_image_list.append(torch.cat(frames, dim=0))

        return output