Contour Mask Regions Extraction
Documentation
- Class name:
SaltMaskContourExtraction
- Category:
SALT/Masking/Filter
- Output node:
False
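This node extracts region outlines from masks. Each mask is inverted and converted to an 8-bit grayscale image. For every threshold, the image is binarized (pixels above the threshold become white), the outermost contours of the white regions are found with OpenCV (cv2.findContours with cv2.RETR_EXTERNAL), and those contours are drawn as 1-pixel white lines onto one shared output mask. Because of the inversion, the regions that pass a threshold are the darker (lower-valued) areas of the input mask, and contours around holes inside a region are not drawn.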
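The per-mask procedure can be approximated without OpenCV or PIL. The sketch below is a simplified stand-in, not the node's code: masks are plain nested lists of 0-255 values, and the hypothetical helper `contour_outline` marks region pixels that touch the background (4-neighbourhood) instead of calling `cv2.findContours` and `cv2.drawContours`. Holes are filled first so that only outer boundaries survive, which approximates `cv2.RETR_EXTERNAL`.

```python
from typing import List, Sequence

Grid = List[List[int]]  # rows of 8-bit grayscale values (0-255)
Binary = List[List[bool]]

NEIGHBOURS = ((-1, 0), (1, 0), (0, -1), (0, 1))  # 4-neighbourhood offsets


def invert(img: Grid) -> Grid:
    # Same per-pixel arithmetic as inverting an 8-bit grayscale image.
    return [[255 - v for v in row] for row in img]


def binarize(img: Grid, threshold: int) -> Binary:
    # Like cv2.THRESH_BINARY: foreground where value > threshold.
    return [[v > threshold for v in row] for row in img]


def fill_holes(fg: Binary) -> Binary:
    # Background pixels that are not 4-connected to the image border are holes;
    # filling them leaves only outer boundaries, approximating RETR_EXTERNAL.
    h, w = len(fg), len(fg[0])
    outside = [[False] * w for _ in range(h)]
    stack = [(y, x) for y in range(h) for x in range(w)
             if (y in (0, h - 1) or x in (0, w - 1)) and not fg[y][x]]
    while stack:
        y, x = stack.pop()
        if outside[y][x]:
            continue
        outside[y][x] = True
        for dy, dx in NEIGHBOURS:
            ny, nx = y + dy, x + dx
            if 0 <= ny < h and 0 <= nx < w and not fg[ny][nx] and not outside[ny][nx]:
                stack.append((ny, nx))
    return [[not outside[y][x] for x in range(w)] for y in range(h)]


def on_outline(region: Binary, y: int, x: int) -> bool:
    # A region pixel is on the outline if a 4-neighbour is off-image or outside the region.
    h, w = len(region), len(region[0])
    for dy, dx in NEIGHBOURS:
        ny, nx = y + dy, x + dx
        if not (0 <= ny < h and 0 <= nx < w) or not region[ny][nx]:
            return True
    return False


def contour_outline(mask: Grid, thresholds: Sequence[int]) -> Grid:
    img = invert(mask)
    h, w = len(img), len(img[0])
    canvas = [[0] * w for _ in range(h)]  # plays the role of combined_contours
    for t in thresholds:
        region = fill_holes(binarize(img, t))
        for y in range(h):
            for x in range(w):
                if region[y][x] and on_outline(region, y, x):
                    canvas[y][x] = 255  # 1-pixel white outline
    return canvas


if __name__ == "__main__":
    # 5x5 mask: white (255) border around a dark (0) 3x3 block.
    mask = [[255] * 5] + [[255, 0, 0, 0, 255] for _ in range(3)] + [[255] * 5]
    for row in contour_outline(mask, [128]):
        print(row)
    # [0, 0, 0, 0, 0]
    # [0, 255, 255, 255, 0]
    # [0, 255, 0, 255, 0]
    # [0, 255, 255, 255, 0]
    # [0, 0, 0, 0, 0]
```

Because the image is inverted before thresholding, it is the dark block of the input mask that gets outlined. With several thresholds, the outlines from each level are simply combined on the same canvas.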
Input types
Required
masks
- The batch of masks to process. Each mask in the batch is handled independently and produces one contour mask in the output.
- Comfy dtype:
MASK
- Python dtype:
torch.Tensor
Optional
thresholds
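- The threshold applied to the inverted 8-bit mask, from 0 to 255 (default 128). Pixels brighter than the threshold count as foreground when contours are traced. The UI widget supplies a single integer, which the node wraps in a list; when contour_extraction is called directly from Python, a list of integers can be passed instead, and the contours found at every threshold are drawn onto the same output mask. A direct call that omits the argument uses the thresholds 50, 100, 150 and 200.
- Comfy dtype:
INT
- Python dtype:
int (or List[int] in a direct Python call)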
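The widget sends one integer, while the method loops over a list. A minimal sketch of that normalization follows; the helper name `normalize_thresholds` and its range check are hypothetical additions, since the node relies on the widget's 0-255 limits and performs no check of its own on direct calls.

```python
from typing import List, Sequence, Union


def normalize_thresholds(thresholds: Union[int, Sequence[int]]) -> List[int]:
    # A single int (what the INT widget sends) is wrapped in a list, as the node does.
    if isinstance(thresholds, int):
        values = [thresholds]
    else:
        values = list(thresholds)  # also accepts tuples
    # Hypothetical check: the widget limits input to 0-255, but a direct call does not.
    for t in values:
        if not 0 <= t <= 255:
            raise ValueError(f"threshold {t} is outside 0-255")
    return values


print(normalize_thresholds(128))             # [128]
print(normalize_thresholds([50, 100, 150]))  # [50, 100, 150]
```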
Output types
MASKS
- A batch containing one contour mask per input mask, concatenated along dimension 0. Contour pixels are white (the 1-pixel lines from every threshold combined) and all other pixels are black.
- Comfy dtype:
MASK
- Python dtype:
torch.Tensor
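Each 8-bit outline image is converted back into a mask, and the per-mask results are concatenated into one batch. The sketch below mimics that with nested lists. It assumes, without confirmation from this page, that `pil2mask` rescales 0-255 pixel values to the 0.0-1.0 range; `to_mask_batch` is a hypothetical name.

```python
from typing import List

Grid = List[List[int]]  # one 8-bit outline image (0 or 255 per pixel)


def to_mask_batch(outlines: List[Grid]) -> List[List[List[float]]]:
    batch = []
    for outline in outlines:
        # Assumption: 255 -> 1.0 and 0 -> 0.0, as a 0-255 to 0.0-1.0 rescale gives.
        batch.append([[v / 255.0 for v in row] for row in outline])
    # One entry per input mask, mirroring torch.cat(regions, dim=0).
    return batch


print(to_mask_batch([[[0, 255]], [[255, 0]]]))  # [[[0.0, 1.0]], [[1.0, 0.0]]]
```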
Usage tips
- Infra type:
CPU
- Common nodes: unknown
Source code
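import cv2
import numpy as np
import torch
from PIL import Image, ImageOps

# NAME, mask2pil and pil2mask are defined elsewhere in the SALT node pack.


class SaltMaskContourExtraction:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "masks": ("MASK",),
            },
            "optional": {
                # The UI exposes a single integer threshold in the 0-255 range.
                "thresholds": ("INT", {"default": 128, "min": 0, "max": 255, "step": 1}),
            }
        }

    CATEGORY = f"{NAME}/Masking/Filter"
    RETURN_TYPES = ("MASK",)
    RETURN_NAMES = ("MASKS",)
    FUNCTION = "contour_extraction"

    def contour_extraction(self, masks, thresholds=(50, 100, 150, 200)):
        regions = []
        # Accept a single int (as sent by the widget) or a list/tuple of ints.
        if not isinstance(thresholds, (list, tuple)):
            thresholds = [thresholds]
        for mask in masks:
            # Invert the mask and convert it to an 8-bit grayscale array.
            pil_image = ImageOps.invert(mask2pil(mask.unsqueeze(0)))
            image_array = np.array(pil_image.convert('L'))
            # Blank canvas that collects the contours from every threshold.
            combined_contours = np.zeros_like(image_array)
            for threshold in thresholds:
                # Pixels above the threshold become 255, the rest 0.
                _, thresh_image = cv2.threshold(image_array, threshold, 255, cv2.THRESH_BINARY)
                # Keep only the outermost contour of each white region.
                contours, _ = cv2.findContours(thresh_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
                # Draw every contour as a 1-pixel white line.
                cv2.drawContours(combined_contours, contours, -1, 255, 1)
            contour_pil = Image.fromarray(combined_contours)
            region_tensor = pil2mask(contour_pil)
            regions.append(region_tensor)
        # Concatenate the per-mask results along dimension 0.
        regions_tensor = torch.cat(regions, dim=0)
        return (regions_tensor,)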