Model Pruner (mtb)


  • Class name: Model Pruner (mtb)
  • Category: mtb/prune
  • Output node: True

The MTB_ModelPruner node prunes a Stable Diffusion checkpoint and saves the result as .safetensors. For each component (U-Net, CLIP, VAE) it can convert tensors to another precision, copy them unchanged, or drop them. It can also remove or extract EMA weights, strip known junk keys, repair the CLIP position_ids tensor, and rename legacy NovelAI keys. The last two steps run only for non-SDXL models.
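Here is a minimal sketch of the first step. It assumes the checkpoint has already been flattened into one state dict. The node's get_dicts method then sorts keys into U-Net, CLIP, VAE and leftover groups by prefix. The helper name split_state_dict is hypothetical. Placeholder strings stand in for torch.Tensor values so the example runs without PyTorch.

```python
# Key prefixes used by MTB_ModelPruner.get_dicts to identify each component.
UNET_PREFIXES = ("model.diffusion_model",)
CLIP_PREFIXES = ("cond_stage_model", "conditioner.embedders")  # SD1.x / SDXL
VAE_PREFIXES = ("first_stage_model",)


def split_state_dict(state_dict: dict) -> tuple[dict, dict, dict, dict]:
    """Hypothetical helper: partition a checkpoint state dict into
    (unet, clip, vae, other) by key prefix, mirroring get_dicts."""
    unet = {k: v for k, v in state_dict.items() if k.startswith(UNET_PREFIXES)}
    clip = {k: v for k, v in state_dict.items() if k.startswith(CLIP_PREFIXES)}
    vae = {k: v for k, v in state_dict.items() if k.startswith(VAE_PREFIXES)}
    # Anything not claimed by a component (e.g. EMA bookkeeping) lands here.
    other = {
        k: v
        for k, v in state_dict.items()
        if k not in unet and k not in clip and k not in vae
    }
    return unet, clip, vae, other


sd = {
    "model.diffusion_model.input_blocks.0.0.weight": "w0",
    "cond_stage_model.transformer.text_model.embeddings.position_ids": "ids",
    "first_stage_model.encoder.conv_in.weight": "w1",
    "model_ema.decay": "d",
}
unet, clip, vae, other = split_state_dict(sd)
print(len(unet), len(clip), len(vae), list(other))  # 1 1 1 ['model_ema.decay']
```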

Input types

Required

  • save_separately
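    • When enabled, each component (unet, vae, clip) is written to its own file whose name is prefixed with the component name (for example unet-<name>.safetensors). When disabled, all components are merged into a single .safetensors file.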
    • Comfy dtype: BOOLEAN
    • Python dtype: bool
  • save_folder
    • Output path for the pruned model. Files are written to the parent directory of this path. Relative paths are resolved against ComfyUI's output directory.
    • Comfy dtype: STRING
    • Python dtype: str
  • fix_clip
    • Repairs the position_ids tensor of the CLIP text encoder by resetting it to 0..76 when it is corrupted or missing. SDXL models are not supported by this fix.
    • Comfy dtype: BOOLEAN
    • Python dtype: bool
  • remove_junk
    • Removes keys that match a list of known junk prefixes (PRUNE_DATA["known_junk_prefix"] in the source) before saving.
    • Comfy dtype: BOOLEAN
    • Python dtype: bool
  • ema_mode
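    • Controls Exponential Moving Average (EMA) weights. disabled keeps every key. remove_ema drops keys containing model_ema. ema_only replaces each weight with its EMA copy where one exists. Any mode other than disabled is appended to the output file name.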
    • Comfy dtype: COMBO[STRING]
    • Python dtype: str
  • precision_unet
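    • Target precision for U-Net tensors when operation_unet is convert. Lower precision (fp16, bf16, fp8) reduces memory use and file size, and fp8 requires PyTorch 2.1.0 or newer. This value is also part of the output file name.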
    • Comfy dtype: COMBO[STRING]
    • Python dtype: str
  • operation_unet
    • What to do with U-Net tensors: convert them to precision_unet, copy them unchanged, or delete them from the output.
    • Comfy dtype: COMBO[STRING]
    • Python dtype: str
  • precision_clip
    • Target precision for CLIP tensors when operation_clip is convert. Lower precision (fp16, bf16, fp8) reduces memory use and file size.
    • Comfy dtype: COMBO[STRING]
    • Python dtype: str
  • operation_clip
    • What to do with CLIP tensors: convert them to precision_clip, copy them unchanged, or delete them from the output.
    • Comfy dtype: COMBO[STRING]
    • Python dtype: str
  • precision_vae
    • Target precision for VAE tensors when operation_vae is convert. Lower precision (fp16, bf16, fp8) reduces memory use and file size.
    • Comfy dtype: COMBO[STRING]
    • Python dtype: str
  • operation_vae
    • What to do with VAE tensors: convert them to precision_vae, copy them unchanged, or delete them from the output.
    • Comfy dtype: COMBO[STRING]
    • Python dtype: str
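The ema_mode options can be sketched on a flat state dict that still contains model_ema.* keys. This is an illustrative re-implementation of the branches in prune(), not the node's own code. The helper name apply_ema_mode is hypothetical. The two bookkeeping keys kept in ema_only mode are an assumption.

```python
# Assumed EMA bookkeeping keys that ema_only mode keeps as-is.
EMA_BOOKKEEPING = ("model_ema.num_updates", "model_ema.decay")


def apply_ema_mode(part: dict, ema_mode: str) -> dict:
    """Hypothetical helper mirroring the ema_mode branches in prune()."""
    if ema_mode == "remove_ema":
        # Drop every key that belongs to the EMA copy of the weights.
        return {k: v for k, v in part.items() if "model_ema." not in k}
    if ema_mode == "ema_only":
        out = {}
        for k, v in part.items():
            # "model.diffusion_model.a.b" -> "model_ema.diffusion_modelab":
            # drop the 6-char "model." prefix, then remove the remaining dots.
            ema_k = "model_ema." + k[6:].replace(".", "")
            if ema_k in part:
                out[k] = part[ema_k]  # store the EMA weight under the live key
            elif not k.startswith("model_ema.") or k in EMA_BOOKKEEPING:
                out[k] = v
        return out
    # "disabled": keep everything unchanged.
    return dict(part)


sd = {
    "model.diffusion_model.out.weight": "live",
    "model_ema.diffusion_modeloutweight": "ema",
    "model_ema.decay": 0.9999,
}
print(apply_ema_mode(sd, "remove_ema"))  # {'model.diffusion_model.out.weight': 'live'}
print(apply_ema_mode(sd, "ema_only"))
```

In ema_only mode the EMA tensor is stored under the live key name, so downstream loaders see a normal checkpoint that happens to contain the averaged weights.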

Optional

  • unet
    • Optional MODEL input whose U-Net (diffusion model) weights are pruned or converted. The node extracts them from the model's combined state dict as a dictionary of tensors.
    • Comfy dtype: MODEL
    • Python dtype: dict[str, torch.Tensor] | None
  • clip
    • Optional CLIP input whose text-encoder weights are pruned or converted. Its state dict is merged into the checkpoint and split back out as a dictionary of tensors.
    • Comfy dtype: CLIP
    • Python dtype: dict[str, torch.Tensor] | None
  • vae
    • Optional VAE input whose weights are pruned or converted. Its state dict is merged into the checkpoint and split back out as a dictionary of tensors.
    • Comfy dtype: VAE
    • Python dtype: dict[str, torch.Tensor] | None
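Each component is then handled according to its operation setting. The sketch below uses a stand-in Operation enum, and its string values are assumptions. It takes the conversion step as a callable. In the node that step is convert_precision, which for example calls tensor.half() for fp16 or tensor.bfloat16() for bf16.

```python
from enum import Enum
from typing import Any, Callable


class Operation(Enum):
    # Assumed stand-in for the node's Operation enum; real values may differ.
    CONVERT = "convert"
    COPY = "copy"
    DELETE = "delete"


def process_part(part: dict, op: Operation, convert: Callable[[Any], Any]) -> dict:
    """Apply one component's operation to every value, like _hf in prune()."""
    if op is Operation.DELETE:
        return {}  # the component is left out of the saved file entirely
    if op is Operation.COPY:
        return dict(part)  # values are kept unchanged
    return {k: convert(v) for k, v in part.items()}  # CONVERT


weights = {"a": 0.123456, "b": 1.0}
print(process_part(weights, Operation.CONVERT, lambda x: round(x, 2)))
# {'a': 0.12, 'b': 1.0}
```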

Output types

The node has no outputs. It is an output node that writes the pruned model to disk as one or more .safetensors files.
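The node's only effect is writing files, so it may help to see how their names are built. This sketch follows the save_name logic in prune(). output_paths is a hypothetical helper. name stands for the base name the node derives from save_folder, and "fp16" is an assumed precision string.

```python
from pathlib import Path


def output_paths(
    folder: Path,
    name: str,
    precision_unet: str,
    ema_mode: str,
    fix_clip: bool,
    save_separately: bool,
) -> list[Path]:
    """Hypothetical helper: build output file paths the way prune() does."""
    save_name = f"{name}-{precision_unet}"
    if ema_mode != "disabled":
        save_name += f"-{ema_mode}"
    if fix_clip:
        save_name += "-clip-fix"
    if save_separately:
        # One file per component; the node only writes components that are present.
        return [folder / f"{part}-{save_name}.safetensors" for part in ("unet", "vae", "clip")]
    return [folder / f"{save_name}.safetensors"]


paths = output_paths(Path("checkpoints"), "ComfyUI", "fp16", "remove_ema", True, False)
print([p.as_posix() for p in paths])
# ['checkpoints/ComfyUI-fp16-remove_ema-clip-fix.safetensors']
```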

Usage tips

  • Infra type: CPU
  • Common nodes: unknown

Source code

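from pathlib import Path

import safetensors.torch
import torch
import tqdm

# Precision, Operation, PRUNE_DATA, dtypes_to_fp8/fp16/bf16, comfy_out_dir and
# log are module-level helpers defined elsewhere in the mtb source (not shown).


class MTB_ModelPruner:
    @classmethod
    def INPUT_TYPES(cls):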
        return {
            "optional": {
                "unet": ("MODEL",),
                "clip": ("CLIP",),
                "vae": ("VAE",),
            },
            "required": {
                "save_separately": ("BOOLEAN", {"default": False}),
                "save_folder": ("STRING", {"default": "checkpoints/ComfyUI"}),
                "fix_clip": ("BOOLEAN", {"default": True}),
                "remove_junk": ("BOOLEAN", {"default": True}),
                "ema_mode": (
                    ("disabled", "remove_ema", "ema_only"),
                    {"default": "remove_ema"},
                ),
                # Assumed: the combo choices are the enum values.
                "precision_unet": (
                    [p.value for p in Precision],
                    {"default": Precision.FULL.value},
                ),
                "operation_unet": (
                    [o.value for o in Operation],
                    {"default": Operation.CONVERT.value},
                ),
                "precision_clip": (
                    [p.value for p in Precision],
                    {"default": Precision.FULL.value},
                ),
                "operation_clip": (
                    [o.value for o in Operation],
                    {"default": Operation.CONVERT.value},
                ),
                "precision_vae": (
                    [p.value for p in Precision],
                    {"default": Precision.FULL.value},
                ),
                "operation_vae": (
                    [o.value for o in Operation],
                    {"default": Operation.CONVERT.value},
                ),
            },
        }

    OUTPUT_NODE = True
    RETURN_TYPES = ()  # no outputs; the node only writes files
    CATEGORY = "mtb/prune"
    FUNCTION = "prune"

    def convert_precision(self, tensor: torch.Tensor, precision: Precision):
        precision = Precision.from_str(precision)
        log.debug(f"Converting to {precision}")
        match precision:
            case Precision.FP8:
                if tensor.dtype in dtypes_to_fp8:
                    # Assumed fp8 variant; float8 dtypes need PyTorch >= 2.1
                    return tensor.to(torch.float8_e4m3fn)
                log.error(f"Cannot convert {tensor.dtype} to fp8")
                return tensor
            case Precision.FP16:
                if tensor.dtype in dtypes_to_fp16:
                    return tensor.half()
                log.error(f"Cannot convert {tensor.dtype} to fp16")
                return tensor
            case Precision.BF16:
                if tensor.dtype in dtypes_to_bf16:
                    return tensor.bfloat16()
                log.error(f"Cannot convert {tensor.dtype} to bf16")
                return tensor
            case Precision.FULL | Precision.FP32:
                return tensor

    def is_sdxl_model(self, clip: dict[str, torch.Tensor] | None):
        if clip:
            # Return a plain bool: a one-element tuple would always be truthy.
            return any(k.startswith("conditioner.embedders") for k in clip)
        return False

    def has_ema(self, unet: dict[str, torch.Tensor]):
        return any(k.startswith("model_ema") for k in unet)

    def fix_clip(self, clip: dict[str, torch.Tensor] | None):
        if self.is_sdxl_model(clip):
            log.warning("[fix clip] SDXL not supported")
            return clip

        if not clip:
            return clip

        # Assumed key: the position_ids buffer of the SD1.x CLIP text encoder.
        position_id_key = (
            "cond_stage_model.transformer.text_model.embeddings.position_ids"
        )
        if position_id_key in clip:
            correct = torch.Tensor([list(range(77))]).to(torch.int64)
            now = clip[position_id_key].to(torch.int64)

            # Boolean mask of positions that differ from 0..76.
            broken = correct.ne(now)
            broken = [i for i in range(77) if broken[0][i]]

            if len(broken) != 0:
                clip[position_id_key] = correct
                log.info(f"[Converter] Fixed broken clip\n{broken}")
            else:
                log.info(
                    "[Converter] Clip in this model is fine, skip fixing..."
                )
        else:
            log.info("[Converter] Missing position id in model, try fixing...")
            clip[position_id_key] = torch.Tensor([list(range(77))]).to(
                torch.int64
            )
        return clip

    def get_dicts(self, unet, clip, vae):
        if unet is None:
            raise ValueError("A MODEL input is required to extract weights")
        # CLIP and VAE are optional: pass None instead of calling get_sd().
        clip_sd = clip.get_sd() if clip is not None else None
        vae_sd = vae.get_sd() if vae is not None else None
        state_dict = unet.model.state_dict_for_saving(clip_sd, vae_sd, None)

        unet = {
            k: v
            for k, v in state_dict.items()
            if k.startswith("model.diffusion_model")
        }
        clip = {
            k: v
            for k, v in state_dict.items()
            if k.startswith("cond_stage_model")
            or k.startswith("conditioner.embedders")
        }
        vae = {
            k: v
            for k, v in state_dict.items()
            if k.startswith("first_stage_model")
        }

        other = {
            k: v
            for k, v in state_dict.items()
            if k not in unet and k not in vae and k not in clip
        }

        return (unet, clip, vae, other)

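    def do_remove_junk(self, tensors: dict[str, dict[str, torch.Tensor]]):
        # Collect (component, key) pairs first; a dict cannot be mutated
        # while it is being iterated.
        need_delete: list[tuple[str, str]] = []
        for layer, layer_tensors in tensors.items():
            for key in layer_tensors:
                for jk in PRUNE_DATA["known_junk_prefix"]:
                    if key.startswith(jk):
                        need_delete.append((layer, key))
                        break

        for layer, key in need_delete:
            log.info(f"Removing junk data: {layer}.{key}")
            del tensors[layer][key]

        return tensors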

    def prune(
        self,
        save_separately: bool,
        save_folder: str,
        fix_clip: bool,
        remove_junk: bool,
        ema_mode: str,
        precision_unet: Precision,
        precision_clip: Precision,
        precision_vae: Precision,
        operation_unet: str,
        operation_clip: str,
        operation_vae: str,
        unet: dict[str, torch.Tensor] | None = None,
        clip: dict[str, torch.Tensor] | None = None,
        vae: dict[str, torch.Tensor] | None = None,
    ):
        operation = {
            "unet": Operation.from_str(operation_unet),
            "clip": Operation.from_str(operation_clip),
            "vae": Operation.from_str(operation_vae),
        }
        precision = {
            "unet": Precision.from_str(precision_unet),
            "clip": Precision.from_str(precision_clip),
            "vae": Precision.from_str(precision_vae),
        }

        # Split the merged checkpoint into per-component dicts.
        unet, clip, vae, _other = self.get_dicts(unet, clip, vae)

        out_dir = Path(save_folder)
        folder = out_dir.parent
        if not out_dir.is_absolute():
            folder = (comfy_out_dir / save_folder).parent

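        if not folder.exists():
            if folder.parent.exists():
                folder.mkdir()
            else:
                raise FileNotFoundError(
                    f"Folder {folder.parent} does not exist"
                )

        # Assumed: the last segment of save_folder is the base file name.
        name = out_dir.name
        save_name = f"{name}-{precision_unet}"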
        if ema_mode != "disabled":
            save_name += f"-{ema_mode}"
        if fix_clip:
            save_name += "-clip-fix"

        if (
            any(o == Operation.CONVERT for o in operation.values())
            and any(p == Precision.FP8 for p in precision.values())
            and torch.__version__ < "2.1.0"
        ):
            raise NotImplementedError(
                "PyTorch 2.1.0 or newer is required for fp8 conversion"
            )

        if not self.is_sdxl_model(clip):
            # Rename legacy NovelAI key prefixes to their standard names.
            nai_keys = PRUNE_DATA["nai_keys"]
            for part in [unet, vae, clip]:
                if part:
                    for k in list(part.keys()):
                        for r in nai_keys:
                            if isinstance(k, str) and k.startswith(r):
                                new_key = k.replace(r, nai_keys[r])
                                part[new_key] = part[k]
                                del part[k]
                                log.info(
                                    f"[Converter] Fixed novelai error key {k}"
                                )
                                break  # k no longer exists in part

            if fix_clip:
                clip = self.fix_clip(clip)

        ok: dict[str, dict[str, torch.Tensor]] = {
            "unet": {},
            "clip": {},
            "vae": {},
        }

        def _hf(part: str, wk: str, t: torch.Tensor):
            """Apply the part's operation to one tensor and store the result."""
            if not isinstance(t, torch.Tensor):
                log.debug("Not a torch tensor, skipping key")
                return

            log.debug(f"Operation {operation[part]}")
            if operation[part] == Operation.CONVERT:
                ok[part][wk] = self.convert_precision(t, precision[part])
            elif operation[part] == Operation.COPY:
                ok[part][wk] = t
            elif operation[part] == Operation.DELETE:
                return

        log.info("[Converter] Converting model...")

        for part_name, part in zip(
            ["unet", "vae", "clip"],
            [unet, vae, clip],
        ):
            if part:
                match ema_mode:
                    case "remove_ema":
                        for k, v in tqdm.tqdm(part.items()):
                            if "model_ema." not in k:
                                _hf(part_name, k, v)
                    case "ema_only":
                        if not self.has_ema(part):
                            log.warning("No EMA to extract")
                        for k in tqdm.tqdm(part):
                            # "model.a.b" -> "model_ema.ab": drop the 6-char
                            # "model." prefix and strip the remaining dots.
                            ema_k = "model_ema." + k[6:].replace(".", "")
                            if ema_k in part:
                                _hf(part_name, k, part[ema_k])
                            elif not k.startswith("model_ema.") or k in [
                                # assumed EMA bookkeeping keys
                                "model_ema.num_updates",
                                "model_ema.decay",
                            ]:
                                _hf(part_name, k, part[k])
                    case _:  # "disabled": process every key
                        for k, v in tqdm.tqdm(part.items()):
                            _hf(part_name, k, v)

                if save_separately:
                    if remove_junk:
                        ok = self.do_remove_junk(ok)

                    flat_ok = {
                        k: v
                        for _, subdict in ok.items()
                        for k, v in subdict.items()
                    }
                    save_path = (
                        folder / f"{part_name}-{save_name}.safetensors"
                    )
                    safetensors.torch.save_file(flat_ok, save_path.as_posix())
                    # Reset the buffer so the next part is saved on its own.
                    ok = {
                        "unet": {},
                        "clip": {},
                        "vae": {},
                    }

        if save_separately:
            return ()

        if remove_junk:
            ok = self.do_remove_junk(ok)

        flat_ok = {
            k: v for _, subdict in ok.items() for k, v in subdict.items()
        }
        try:
            safetensors.torch.save_file(
                flat_ok, (folder / f"{save_name}.safetensors").as_posix()
            )
        except Exception as e:
            log.error(f"Failed to save {save_name}: {e}")

        return ()
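The position-id check in fix_clip can be illustrated without PyTorch. This sketch assumes the tensor has been flattened to a Python list of 77 integers; the node itself compares a [1, 77] int64 tensor against 0..76. The names find_broken_position_ids and fix_position_ids are hypothetical.

```python
SEQ_LEN = 77  # CLIP text-encoder context length assumed by fix_clip


def find_broken_position_ids(position_ids: list[int]) -> list[int]:
    """Return indices whose value differs from the expected 0..76 sequence."""
    if len(position_ids) != SEQ_LEN:
        raise ValueError(f"expected {SEQ_LEN} position ids, got {len(position_ids)}")
    return [i for i, value in enumerate(position_ids) if value != i]


def fix_position_ids(position_ids: list[int]) -> tuple[list[int], list[int]]:
    """Reset the ids to 0..76 if any entry is wrong; also return the bad indices."""
    broken = find_broken_position_ids(position_ids)
    if broken:
        return list(range(SEQ_LEN)), broken
    return position_ids, broken


ids = list(range(SEQ_LEN))
ids[76] = 75  # a single off-by-one entry
fixed, broken = fix_position_ids(ids)
print(broken)  # [76]
```

Like the node, the sketch replaces the whole sequence rather than patching individual entries. The correct values are fully determined by the position index.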