Skip to content

Mask Smooth Region

Documentation

  • Class name: Mask Smooth Region
  • Category: WAS Suite/Image/Masking
  • Output node: False

This node applies a smoothing operation to the specified regions within the input masks, utilizing a specified sigma value to control the smoothness level. It's designed to refine mask edges, making them less jagged and more visually appealing.

Input types

Required

  • masks
    • The input masks to be smoothed. This parameter is crucial for determining which areas of the image will undergo the smoothing process.
    • Comfy dtype: MASK
    • Python dtype: torch.Tensor
  • sigma
    • A parameter controlling the smoothness level of the mask's edges. Higher values result in smoother edges.
    • Comfy dtype: FLOAT
    • Python dtype: float

Output types

  • MASKS
    • Comfy dtype: MASK
    • The smoothed masks, with refined edges for improved visual quality.
    • Python dtype: torch.Tensor

Usage tips

  • Infra type: GPU
  • Common nodes: unknown

Source code

class WAS_Mask_Smooth_Region:

    def __init__(self):
        self.WT = WAS_Tools_Class()

    @classmethod
    def INPUT_TYPES(cls):
        return {
                    "required": {
                        "masks": ("MASK",),
                        "sigma": ("FLOAT", {"default":5.0, "min":0.0, "max":128.0, "step":0.1}),
                    }
                }

    CATEGORY = "WAS Suite/Image/Masking"

    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("MASKS",)

    FUNCTION = "smooth_region"

    def smooth_region(self, masks, sigma=128):
        if masks.ndim > 3:
            regions = []
            for mask in masks:
                mask_np = np.clip(255. * mask.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
                pil_image = Image.fromarray(mask_np, mode="L")
                region_mask = self.WT.Masking.smooth_region(pil_image, sigma)
                region_tensor = pil2mask(region_mask).unsqueeze(0).unsqueeze(1)
                regions.append(region_tensor)
            regions_tensor = torch.cat(regions, dim=0)
            return (regions_tensor,)
        else:
            mask_np = np.clip(255. * masks.cpu().numpy().squeeze(), 0, 255).astype(np.uint8)
            pil_image = Image.fromarray(mask_np, mode="L")
            region_mask = self.WT.Masking.smooth_region(pil_image, sigma)
            region_tensor = pil2mask(region_mask).unsqueeze(0).unsqueeze(1)
            return (region_tensor,)