Model Pruner (mtb)

Documentation

  • Class name: Model Pruner (mtb)
  • Category: mtb/prune
  • Output node: True

The MTB_ModelPruner node prunes and optimizes checkpoints to reduce file size and loading overhead. For each component (U-Net, CLIP, VAE) it can convert weight precision, copy weights unchanged, or drop them entirely, and it additionally handles EMA weights, removes known junk keys, and repairs broken CLIP position ids.

Input types

Required

  • save_separately
    • Determines whether model components should be saved separately, affecting the organization of the output files.
    • Comfy dtype: BOOLEAN
    • Python dtype: bool
  • save_folder
    • Specifies the directory where the pruned model and its components will be saved.
    • Comfy dtype: STRING
    • Python dtype: str
  • fix_clip
    • Indicates whether to repair the CLIP text encoder's position_ids tensor when it is broken or missing (SD1.x only; SDXL models are skipped).
    • Comfy dtype: BOOLEAN
    • Python dtype: bool
  • remove_junk
    • Controls the removal of known junk keys from the saved state dict, reducing file size.
    • Comfy dtype: BOOLEAN
    • Python dtype: bool
  • ema_mode
    • Defines how Exponential Moving Average (EMA) weights are handled: disabled keeps the weights untouched, remove_ema strips them, and ema_only keeps only the EMA weights.
    • Comfy dtype: COMBO[STRING]
    • Python dtype: str
  • precision_unet
    • Sets the precision level for the U-Net model component, impacting memory usage and computational efficiency.
    • Comfy dtype: COMBO[STRING]
    • Python dtype: str
  • operation_unet
    • Specifies the operation applied to the U-Net component: convert its precision, copy it unchanged, or delete it from the output.
    • Comfy dtype: COMBO[STRING]
    • Python dtype: str
  • precision_clip
    • Sets the precision level for the CLIP model component, impacting memory usage and computational efficiency.
    • Comfy dtype: COMBO[STRING]
    • Python dtype: str
  • operation_clip
    • Specifies the operation applied to the CLIP component: convert its precision, copy it unchanged, or delete it from the output.
    • Comfy dtype: COMBO[STRING]
    • Python dtype: str
  • precision_vae
    • Sets the precision level for the VAE model component, impacting memory usage and computational efficiency.
    • Comfy dtype: COMBO[STRING]
    • Python dtype: str
  • operation_vae
    • Specifies the operation applied to the VAE component: convert its precision, copy it unchanged, or delete it from the output.
    • Comfy dtype: COMBO[STRING]
    • Python dtype: str

Optional

  • unet
    • Optional U-Net model component to be pruned or optimized, provided as a dictionary of tensors.
    • Comfy dtype: MODEL
    • Python dtype: dict[str, torch.Tensor] | None
  • clip
    • Optional CLIP model component to be pruned or optimized, provided as a dictionary of tensors.
    • Comfy dtype: CLIP
    • Python dtype: dict[str, torch.Tensor] | None
  • vae
    • Optional VAE model component to be pruned or optimized, provided as a dictionary of tensors.
    • Comfy dtype: VAE
    • Python dtype: dict[str, torch.Tensor] | None
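
For illustration, a direct Python call to the node's prune function could look like the sketch below. Inside ComfyUI these values come from the node's widgets and from upstream loader nodes; the enum strings shown ("fp16", "convert", ...) are assumptions, the actual choices are whatever Precision.list_members() and Operation.list_members() return, and model, clip and vae stand for the outputs of a checkpoint loader.

pruner = MTB_ModelPruner()
pruner.prune(
    save_separately=False,
    save_folder="checkpoints/ComfyUI",  # relative: resolved under the ComfyUI output dir
    fix_clip=True,
    remove_junk=True,
    ema_mode="remove_ema",
    precision_unet="fp16",     # assumed enum value
    operation_unet="convert",  # assumed enum value
    precision_clip="fp16",
    operation_clip="convert",
    precision_vae="full",
    operation_vae="copy",
    unet=model,  # MODEL output of a checkpoint loader
    clip=clip,   # CLIP output
    vae=vae,     # VAE output
)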

Output types

This node has no output types: it is an output node that writes the pruned checkpoint to disk as one or more .safetensors files.

Usage tips

  • Infra type: CPU
  • Common nodes: unknown
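  • Output location: when save_folder is not an absolute path it is resolved under the ComfyUI output directory. The file is saved as <folder-name>-<precision_unet>.safetensors, with -<ema_mode> appended unless EMA handling is disabled and -clip-fix appended when fix_clip is on; with save_separately enabled, each component is written to its own <part>-<name>.safetensors file (see prune() below).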

Source code

from pathlib import Path

import safetensors.torch
import torch
import tqdm

# Precision, Operation, PRUNE_DATA, comfy_out_dir, log and the dtypes_to_*
# allow-lists are defined elsewhere in the mtb module (see the sketch after
# this listing).


class MTB_ModelPruner:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "optional": {
                "unet": ("MODEL",),
                "clip": ("CLIP",),
                "vae": ("VAE",),
            },
            "required": {
                "save_separately": ("BOOLEAN", {"default": False}),
                "save_folder": ("STRING", {"default": "checkpoints/ComfyUI"}),
                "fix_clip": ("BOOLEAN", {"default": True}),
                "remove_junk": ("BOOLEAN", {"default": True}),
                "ema_mode": (
                    ("disabled", "remove_ema", "ema_only"),
                    {"default": "remove_ema"},
                ),
                "precision_unet": (
                    Precision.list_members(),
                    {"default": Precision.FULL.value},
                ),
                "operation_unet": (
                    Operation.list_members(),
                    {"default": Operation.CONVERT.value},
                ),
                "precision_clip": (
                    Precision.list_members(),
                    {"default": Precision.FULL.value},
                ),
                "operation_clip": (
                    Operation.list_members(),
                    {"default": Operation.CONVERT.value},
                ),
                "precision_vae": (
                    Precision.list_members(),
                    {"default": Precision.FULL.value},
                ),
                "operation_vae": (
                    Operation.list_members(),
                    {"default": Operation.CONVERT.value},
                ),
            },
        }

    OUTPUT_NODE = True
    RETURN_TYPES = ()
    CATEGORY = "mtb/prune"
    FUNCTION = "prune"

    def convert_precision(self, tensor: torch.Tensor, precision: Precision):
        precision = Precision.from_str(precision)
        log.debug(f"Converting to {precision}")
        match precision:
            case Precision.FP8:
                if tensor.dtype in dtypes_to_fp8:
                    return tensor.to(torch.float8_e4m3fn)
                log.error(f"Cannot convert {tensor.dtype} to fp8")
                return tensor
            case Precision.FP16:
                if tensor.dtype in dtypes_to_fp16:
                    return tensor.half()
                log.error(f"Cannot convert {tensor.dtype} to f16")
                return tensor
            case Precision.BF16:
                if tensor.dtype in dtypes_to_bf16:
                    return tensor.bfloat16()
                log.error(f"Cannot convert {tensor.dtype} to bf16")
                return tensor
            case Precision.FULL | Precision.FP32:
                return tensor

    def is_sdxl_model(self, clip: dict[str, torch.Tensor] | None):
        # SDXL checkpoints keep their text encoders under conditioner.embedders
        if clip:
            return any(k.startswith("conditioner.embedders") for k in clip)
        return False

    def has_ema(self, unet: dict[str, torch.Tensor]):
        return any(k.startswith("model_ema") for k in unet)

    def fix_clip(self, clip: dict[str, torch.Tensor] | None):
        if clip is None:
            return

        if self.is_sdxl_model(clip):
            log.warn("[fix clip] SDXL not supported")
            return

        position_id_key = (
            "cond_stage_model.transformer.text_model.embeddings.position_ids"
        )
        if position_id_key in clip:
            # the expected tensor is [[0, 1, ..., 76]] (CLIP context length 77)
            correct = torch.Tensor([list(range(77))]).to(torch.int64)
            now = clip[position_id_key].to(torch.int64)

            broken = correct.ne(now)
            broken = [i for i in range(77) if broken[0][i]]

            if len(broken) != 0:
                clip[position_id_key] = correct
                log.info(f"[Converter] Fixed broken clip\n{broken}")
            else:
                log.info(
                    "[Converter] Clip in this model is fine, skip fixing..."
                )

        else:
            log.info("[Converter] Missing position id in model, try fixing...")
            clip[position_id_key] = torch.Tensor([list(range(77))]).to(
                torch.int64
            )
        return clip

    def get_dicts(self, unet, clip, vae):
        clip_sd = clip.get_sd()
        state_dict = unet.model.state_dict_for_saving(
            clip_sd, vae.get_sd(), None
        )

        unet = {
            k: v
            for k, v in state_dict.items()
            if k.startswith("model.diffusion_model")
        }
        clip = {
            k: v
            for k, v in state_dict.items()
            if k.startswith("cond_stage_model")
            or k.startswith("conditioner.embedders")
        }
        vae = {
            k: v
            for k, v in state_dict.items()
            if k.startswith("first_stage_model")
        }

        other = {
            k: v
            for k, v in state_dict.items()
            if k not in unet and k not in vae and k not in clip
        }

        return (unet, clip, vae, other)

    def do_remove_junk(self, tensors: dict[str, dict[str, torch.Tensor]]):
        # collect (part, key) pairs first so the dicts aren't mutated mid-iteration
        need_delete: list[tuple[str, str]] = []
        for part, weights in tensors.items():
            for key in weights:
                for jk in PRUNE_DATA["known_junk_prefix"]:
                    if key.startswith(jk):
                        need_delete.append((part, key))
                        break

        for part, key in need_delete:
            log.info(f"Removing junk data: {part}.{key}")
            del tensors[part][key]

        return tensors

    def prune(
        self,
        *,
        save_separately: bool,
        save_folder: str,
        fix_clip: bool,
        remove_junk: bool,
        ema_mode: str,
        precision_unet: Precision,
        precision_clip: Precision,
        precision_vae: Precision,
        operation_unet: str,
        operation_clip: str,
        operation_vae: str,
        unet: dict[str, torch.Tensor] | None = None,
        clip: dict[str, torch.Tensor] | None = None,
        vae: dict[str, torch.Tensor] | None = None,
    ):
        operation = {
            "unet": Operation.from_str(operation_unet),
            "clip": Operation.from_str(operation_clip),
            "vae": Operation.from_str(operation_vae),
        }
        precision = {
            "unet": Precision.from_str(precision_unet),
            "clip": Precision.from_str(precision_clip),
            "vae": Precision.from_str(precision_vae),
        }

        unet, clip, vae, _other = self.get_dicts(unet, clip, vae)

        out_dir = Path(save_folder)
        folder = out_dir.parent
        if not out_dir.is_absolute():
            folder = (comfy_out_dir / save_folder).parent

        if not folder.exists():
            if folder.parent.exists():
                folder.mkdir()
            else:
                raise FileNotFoundError(
                    f"Folder {folder.parent} does not exist"
                )

        name = out_dir.name
        save_name = f"{name}-{precision_unet}"
        if ema_mode != "disabled":
            save_name += f"-{ema_mode}"
        if fix_clip:
            save_name += "-clip-fix"

        if (
            any(o == Operation.CONVERT for o in operation.values())
            and any(p == Precision.FP8 for p in precision.values())
            and torch.__version__ < "2.1.0"
        ):
            raise NotImplementedError(
                "PyTorch 2.1.0 or newer is required for fp8 conversion"
            )

        if not self.is_sdxl_model(clip):
            for part in [unet, vae, clip]:
                if part:
                    nai_keys = PRUNE_DATA["nai_keys"]
                    for k in list(part.keys()):
                        for r in nai_keys:
                            if isinstance(k, str) and k.startswith(r):
                                new_key = k.replace(r, nai_keys[r])
                                part[new_key] = part[k]
                                del part[k]
                                log.info(
                                    f"[Converter] Fixed novelai error key {k}"
                                )
                                break

            if fix_clip:
                clip = self.fix_clip(clip)

        ok: dict[str, dict[str, torch.Tensor]] = {
            "unet": {},
            "clip": {},
            "vae": {},
        }

        def _hf(part: str, wk: str, t: torch.Tensor):
            if not isinstance(t, torch.Tensor):
                log.debug("Not a torch tensor, skipping key")
                return

            log.debug(f"Operation {operation[part]}")
            if operation[part] == Operation.CONVERT:
                ok[part][wk] = self.convert_precision(t, precision[part])
            elif operation[part] == Operation.COPY:
                ok[part][wk] = t
            elif operation[part] == Operation.DELETE:
                return

        log.info("[Converter] Converting model...")

        for part_name, part in zip(
            ["unet", "vae", "clip"],
            [unet, vae, clip],
        ):
            if part:
                match ema_mode:
                    case "remove_ema":
                        for k, v in tqdm.tqdm(part.items()):
                            if "model_ema." not in k:
                                _hf(part_name, k, v)
                    case "ema_only":
                        if not self.has_ema(part):
                            log.warn("No EMA to extract")
                            return
                        for k in tqdm.tqdm(part):
                            ema_k = "___"
                            try:
                                ema_k = "model_ema." + k[6:].replace(".", "")
                            except Exception:
                                pass
                            if ema_k in part:
                                _hf(part_name, k, part[ema_k])
                            elif not k.startswith("model_ema.") or k in [
                                "model_ema.num_updates",
                                "model_ema.decay",
                            ]:
                                _hf(part_name, k, part[k])
                    case "disabled" | _:
                        for k, v in tqdm.tqdm(part.items()):
                            _hf(part_name, k, v)

                if save_separately:
                    if remove_junk:
                        ok = self.do_remove_junk(ok)

                    flat_ok = {
                        k: v
                        for _, subdict in ok.items()
                        for k, v in subdict.items()
                    }
                    save_path = (
                        folder / f"{part_name}-{save_name}.safetensors"
                    ).as_posix()
                    safetensors.torch.save_file(flat_ok, save_path)
                    ok: dict[str, dict[str, torch.Tensor]] = {
                        "unet": {},
                        "clip": {},
                        "vae": {},
                    }

        if save_separately:
            return ()

        if remove_junk:
            ok = self.do_remove_junk(ok)

        flat_ok = {
            k: v for _, subdict in ok.items() for k, v in subdict.items()
        }

        try:
            safetensors.torch.save_file(
                flat_ok, (folder / f"{save_name}.safetensors").as_posix()
            )
        except Exception as e:
            log.error(e)

        return ()
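
The class above references helpers defined elsewhere in the mtb module: Precision, Operation, PRUNE_DATA, comfy_out_dir, log and the dtypes_to_* allow-lists. The sketch below shows one plausible shape for the enums and allow-lists, assuming str-valued enums with from_str/list_members helpers; the actual definitions in the repository may differ.

import torch
from enum import Enum


class Precision(Enum):
    FULL = "full"
    FP32 = "fp32"
    FP16 = "fp16"
    BF16 = "bf16"
    FP8 = "fp8"

    @classmethod
    def list_members(cls) -> list[str]:
        return [m.value for m in cls]

    @classmethod
    def from_str(cls, value: "Precision | str") -> "Precision":
        # tolerate being handed an already-converted member
        return value if isinstance(value, cls) else cls(value)


class Operation(Enum):
    CONVERT = "convert"
    COPY = "copy"
    DELETE = "delete"

    @classmethod
    def list_members(cls) -> list[str]:
        return [m.value for m in cls]

    @classmethod
    def from_str(cls, value: "Operation | str") -> "Operation":
        return value if isinstance(value, cls) else cls(value)


# dtypes accepted for conversion to each target precision (including no-ops)
dtypes_to_fp8 = (torch.float32, torch.float64, torch.float16, torch.bfloat16)
dtypes_to_fp16 = (torch.float32, torch.float64, torch.bfloat16, torch.float16)
dtypes_to_bf16 = (torch.float32, torch.float64, torch.float16, torch.bfloat16)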