Model Pruner (mtb)

Documentation

- Class name: Model Pruner (mtb)
- Category: mtb/prune
- Output node: True
The MTB_ModelPruner node prunes and converts Stable Diffusion checkpoints to reduce file size and improve loading efficiency. For each component (U-Net, CLIP, VAE) it can convert the weights to a lower precision, copy them unchanged, or delete them, and it can additionally strip EMA weights, remove known junk keys, and repair a broken CLIP text encoder.
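For a sense of what the node automates, here is a minimal standalone sketch of the core precision-conversion step, assuming a checkpoint loadable with safetensors; the file names are placeholders, not part of the node itself.

```python
import torch
import safetensors.torch

# Load a checkpoint, downcast fp32 tensors to fp16, and re-save it.
state_dict = safetensors.torch.load_file("input.safetensors")
converted = {
    k: v.half() if v.dtype == torch.float32 else v
    for k, v in state_dict.items()
}
safetensors.torch.save_file(converted, "input-fp16.safetensors")
```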
Input types

Required

save_separately
- Whether to save the U-Net, CLIP, and VAE as separate .safetensors files instead of one combined checkpoint.
- Comfy dtype: BOOLEAN
- Python dtype: bool

save_folder
- Directory and base file name for the saved output; relative paths are resolved under ComfyUI's output directory (default: checkpoints/ComfyUI).
- Comfy dtype: STRING
- Python dtype: str

fix_clip
- Whether to repair the CLIP text encoder's position_ids tensor, a common defect in merged SD1.x checkpoints (not supported for SDXL).
- Comfy dtype: BOOLEAN
- Python dtype: bool

remove_junk
- Whether to strip known junk keys from the model to reduce file size.
- Comfy dtype: BOOLEAN
- Python dtype: bool

ema_mode
- How Exponential Moving Average (EMA) weights are handled: disabled keeps them, remove_ema strips them (the default), and ema_only keeps only the EMA weights. See the sketch after this list.
- Comfy dtype: COMBO[STRING]
- Python dtype: str

precision_unet
- Target precision for the U-Net weights (full/fp32, fp16, bf16, or fp8); lower precision reduces file size and memory usage, and fp8 requires PyTorch 2.1.0 or newer.
- Comfy dtype: COMBO[STRING]
- Python dtype: str

operation_unet
- Operation applied to the U-Net weights: convert to the chosen precision, copy unchanged, or delete from the output.
- Comfy dtype: COMBO[STRING]
- Python dtype: str

precision_clip
- Target precision for the CLIP weights, with the same options as precision_unet.
- Comfy dtype: COMBO[STRING]
- Python dtype: str

operation_clip
- Operation applied to the CLIP weights: convert, copy, or delete.
- Comfy dtype: COMBO[STRING]
- Python dtype: str

precision_vae
- Target precision for the VAE weights, with the same options as precision_unet.
- Comfy dtype: COMBO[STRING]
- Python dtype: str

operation_vae
- Operation applied to the VAE weights: convert, copy, or delete.
- Comfy dtype: COMBO[STRING]
- Python dtype: str
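The three ema_mode choices roughly correspond to the following filtering, distilled from the node's source below. `filter_ema` is an illustrative helper, not part of the node, and it omits the bookkeeping keys (model_ema.num_updates, model_ema.decay) that the real implementation preserves.

```python
import torch

def filter_ema(state_dict: dict[str, torch.Tensor], ema_mode: str):
    if ema_mode == "disabled":
        # Keep everything, EMA weights included.
        return dict(state_dict)
    if ema_mode == "remove_ema":
        # Drop the EMA shadow weights, keep the live weights.
        return {k: v for k, v in state_dict.items() if "model_ema." not in k}
    if ema_mode == "ema_only":
        # Swap each live weight for its EMA counterpart when one exists;
        # EMA keys drop the "model." prefix and all remaining dots.
        out = {}
        for k, v in state_dict.items():
            if k.startswith("model_ema."):
                continue
            ema_k = "model_ema." + k[6:].replace(".", "")
            out[k] = state_dict.get(ema_k, v)
        return out
    raise ValueError(f"unknown ema_mode: {ema_mode}")
```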
Optional

unet
- U-Net (MODEL) input to prune or convert; together with clip and vae it is flattened into a checkpoint-style state dict of tensors. See the sketch after this list.
- Comfy dtype: MODEL
- Python dtype: dict[str, torch.Tensor] | None

clip
- CLIP input to prune or convert, handled the same way.
- Comfy dtype: CLIP
- Python dtype: dict[str, torch.Tensor] | None

vae
- VAE input to prune or convert, handled the same way.
- Comfy dtype: VAE
- Python dtype: dict[str, torch.Tensor] | None
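Although these inputs arrive as ComfyUI objects, the node immediately flattens them into state dicts and slices them by key prefix (see get_dicts in the source). A sketch of the same prefix-based split on a raw checkpoint, with a placeholder path:

```python
import safetensors.torch

state_dict = safetensors.torch.load_file("checkpoint.safetensors")  # placeholder

# Same prefixes the node uses to separate the three components.
unet = {k: v for k, v in state_dict.items() if k.startswith("model.diffusion_model")}
clip = {
    k: v
    for k, v in state_dict.items()
    if k.startswith(("cond_stage_model", "conditioner.embedders"))
}
vae = {k: v for k, v in state_dict.items() if k.startswith("first_stage_model")}
```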
Output types

This node has no output types. It is an output node (OUTPUT_NODE = True) that writes the pruned model to disk as one or more .safetensors files.
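To confirm the result, the written file can be loaded back with safetensors. The path below is illustrative only; the node names its output `<save_folder basename>-<precision>[-<ema_mode>][-clip-fix].safetensors`, and the exact precision label depends on the Precision enum values.

```python
import safetensors.torch

# Placeholder path following the default save_folder and suffix scheme.
pruned = safetensors.torch.load_file(
    "output/checkpoints/ComfyUI-fp16-remove_ema-clip-fix.safetensors"
)
print(f"{len(pruned)} tensors")
print({str(t.dtype) for t in pruned.values()})  # dtypes after conversion
```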
Usage tips

- Infra type: CPU
- Common nodes: unknown
Source code
```python
from pathlib import Path

import safetensors.torch
import torch
import tqdm

# NOTE: Precision, Operation, PRUNE_DATA, dtypes_to_fp8/fp16/bf16, log and
# comfy_out_dir are defined elsewhere in the comfy_mtb package.


class MTB_ModelPruner:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "optional": {
                "unet": ("MODEL",),
                "clip": ("CLIP",),
                "vae": ("VAE",),
            },
            "required": {
                "save_separately": ("BOOLEAN", {"default": False}),
                "save_folder": ("STRING", {"default": "checkpoints/ComfyUI"}),
                "fix_clip": ("BOOLEAN", {"default": True}),
                "remove_junk": ("BOOLEAN", {"default": True}),
                "ema_mode": (
                    ("disabled", "remove_ema", "ema_only"),
                    {"default": "remove_ema"},
                ),
                "precision_unet": (
                    Precision.list_members(),
                    {"default": Precision.FULL.value},
                ),
                "operation_unet": (
                    Operation.list_members(),
                    {"default": Operation.CONVERT.value},
                ),
                "precision_clip": (
                    Precision.list_members(),
                    {"default": Precision.FULL.value},
                ),
                "operation_clip": (
                    Operation.list_members(),
                    {"default": Operation.CONVERT.value},
                ),
                "precision_vae": (
                    Precision.list_members(),
                    {"default": Precision.FULL.value},
                ),
                "operation_vae": (
                    Operation.list_members(),
                    {"default": Operation.CONVERT.value},
                ),
            },
        }

    OUTPUT_NODE = True
    RETURN_TYPES = ()
    CATEGORY = "mtb/prune"
    FUNCTION = "prune"

    def convert_precision(self, tensor: torch.Tensor, precision: Precision):
        precision = Precision.from_str(precision)
        log.debug(f"Converting to {precision}")
        # Fall back to the original tensor when a dtype cannot be converted.
        match precision:
            case Precision.FP8:
                if tensor.dtype in dtypes_to_fp8:
                    return tensor.to(torch.float8_e4m3fn)
                log.error(f"Cannot convert {tensor.dtype} to fp8")
                return tensor
            case Precision.FP16:
                if tensor.dtype in dtypes_to_fp16:
                    return tensor.half()
                log.error(f"Cannot convert {tensor.dtype} to f16")
                return tensor
            case Precision.BF16:
                if tensor.dtype in dtypes_to_bf16:
                    return tensor.bfloat16()
                log.error(f"Cannot convert {tensor.dtype} to bf16")
                return tensor
            case Precision.FULL | Precision.FP32:
                return tensor

    def is_sdxl_model(self, clip: dict[str, torch.Tensor] | None):
        # SDXL checkpoints store their text encoders under
        # "conditioner.embedders" instead of "cond_stage_model".
        if clip:
            return any(k.startswith("conditioner.embedders") for k in clip)
        return False

    def has_ema(self, unet: dict[str, torch.Tensor]):
        return any(k.startswith("model_ema") for k in unet)

    def fix_clip(self, clip: dict[str, torch.Tensor] | None):
        if self.is_sdxl_model(clip):
            log.warn("[fix clip] SDXL not supported")
            return clip
        if clip is None:
            return clip

        position_id_key = (
            "cond_stage_model.transformer.text_model.embeddings.position_ids"
        )
        if position_id_key in clip:
            correct = torch.Tensor([list(range(77))]).to(torch.int64)
            now = clip[position_id_key].to(torch.int64)

            broken = correct.ne(now)
            broken = [i for i in range(77) if broken[0][i]]
            if len(broken) != 0:
                clip[position_id_key] = correct
                log.info(f"[Converter] Fixed broken clip\n{broken}")
            else:
                log.info("[Converter] Clip in this model is fine, skip fixing...")
        else:
            log.info("[Converter] Missing position id in model, try fixing...")
            clip[position_id_key] = torch.Tensor([list(range(77))]).to(torch.int64)

        return clip

    def get_dicts(self, unet, clip, vae):
        # Flatten the ComfyUI MODEL/CLIP/VAE objects into a single
        # checkpoint-style state dict, then slice it back into parts
        # by key prefix.
        clip_sd = clip.get_sd()
        state_dict = unet.model.state_dict_for_saving(clip_sd, vae.get_sd(), None)
        unet = {
            k: v
            for k, v in state_dict.items()
            if k.startswith("model.diffusion_model")
        }
        clip = {
            k: v
            for k, v in state_dict.items()
            if k.startswith("cond_stage_model")
            or k.startswith("conditioner.embedders")
        }
        vae = {
            k: v
            for k, v in state_dict.items()
            if k.startswith("first_stage_model")
        }
        other = {
            k: v
            for k, v in state_dict.items()
            if k not in unet and k not in vae and k not in clip
        }
        return (unet, clip, vae, other)

    def do_remove_junk(self, tensors: dict[str, dict[str, torch.Tensor]]):
        need_delete: list[tuple[str, str]] = []
        for part, weights in tensors.items():
            for key in weights:
                for jk in PRUNE_DATA["known_junk_prefix"]:
                    if key.startswith(jk):
                        need_delete.append((part, key))

        for part, key in need_delete:
            log.info(f"Removing junk data: {part}.{key}")
            del tensors[part][key]

        return tensors

    def prune(
        self,
        *,
        save_separately: bool,
        save_folder: str,
        fix_clip: bool,
        remove_junk: bool,
        ema_mode: str,
        precision_unet: str,
        precision_clip: str,
        precision_vae: str,
        operation_unet: str,
        operation_clip: str,
        operation_vae: str,
        unet: dict[str, torch.Tensor] | None = None,
        clip: dict[str, torch.Tensor] | None = None,
        vae: dict[str, torch.Tensor] | None = None,
    ):
        operation = {
            "unet": Operation.from_str(operation_unet),
            "clip": Operation.from_str(operation_clip),
            "vae": Operation.from_str(operation_vae),
        }
        precision = {
            "unet": Precision.from_str(precision_unet),
            "clip": Precision.from_str(precision_clip),
            "vae": Precision.from_str(precision_vae),
        }

        unet, clip, vae, _other = self.get_dicts(unet, clip, vae)

        # Resolve the target folder; relative paths land under ComfyUI's
        # output directory.
        out_dir = Path(save_folder)
        folder = out_dir.parent
        if not out_dir.is_absolute():
            folder = (comfy_out_dir / save_folder).parent

        if not folder.exists():
            if folder.parent.exists():
                folder.mkdir()
            else:
                raise FileNotFoundError(f"Folder {folder.parent} does not exist")

        name = out_dir.name
        save_name = f"{name}-{precision_unet}"
        if ema_mode != "disabled":
            save_name += f"-{ema_mode}"
        if fix_clip:
            save_name += "-clip-fix"

        if (
            any(o == Operation.CONVERT for o in operation.values())
            and any(p == Precision.FP8 for p in precision.values())
            and torch.__version__ < "2.1.0"
        ):
            raise NotImplementedError(
                "PyTorch 2.1.0 or newer is required for fp8 conversion"
            )

        if not self.is_sdxl_model(clip):
            # Rename keys left behind by NovelAI-style checkpoints.
            for part in [unet, vae, clip]:
                if part:
                    nai_keys = PRUNE_DATA["nai_keys"]
                    for k in list(part.keys()):
                        for r in nai_keys:
                            if isinstance(k, str) and k.startswith(r):
                                new_key = k.replace(r, nai_keys[r])
                                part[new_key] = part[k]
                                del part[k]
                                log.info(f"[Converter] Fixed novelai error key {k}")
                                break

        if fix_clip:
            clip = self.fix_clip(clip)

        ok: dict[str, dict[str, torch.Tensor]] = {
            "unet": {},
            "clip": {},
            "vae": {},
        }

        def _hf(part: str, wk: str, t: torch.Tensor):
            if not isinstance(t, torch.Tensor):
                log.debug("Not a torch tensor, skipping key")
                return
            log.debug(f"Operation {operation[part]}")
            if operation[part] == Operation.CONVERT:
                ok[part][wk] = self.convert_precision(t, precision[part])
            elif operation[part] == Operation.COPY:
                ok[part][wk] = t
            elif operation[part] == Operation.DELETE:
                return

        log.info("[Converter] Converting model...")
        for part_name, part in zip(
            ["unet", "vae", "clip"],
            [unet, vae, clip],
            strict=False,
        ):
            if part:
                match ema_mode:
                    case "remove_ema":
                        for k, v in tqdm.tqdm(part.items()):
                            if "model_ema." not in k:
                                _hf(part_name, k, v)
                    case "ema_only":
                        if not self.has_ema(part):
                            log.warn("No EMA to extract")
                            return ()
                        for k in tqdm.tqdm(part):
                            # EMA keys drop the "model." prefix and all dots.
                            ema_k = "___"
                            try:
                                ema_k = "model_ema." + k[6:].replace(".", "")
                            except Exception:
                                pass
                            if ema_k in part:
                                _hf(part_name, k, part[ema_k])
                            elif not k.startswith("model_ema.") or k in [
                                "model_ema.num_updates",
                                "model_ema.decay",
                            ]:
                                _hf(part_name, k, part[k])
                    case "disabled" | _:
                        for k, v in tqdm.tqdm(part.items()):
                            _hf(part_name, k, v)

                if save_separately:
                    # Flush this component to its own file, then reset
                    # the accumulator for the next one.
                    if remove_junk:
                        ok = self.do_remove_junk(ok)
                    flat_ok = {
                        k: v
                        for _, subdict in ok.items()
                        for k, v in subdict.items()
                    }
                    save_path = (
                        folder / f"{part_name}-{save_name}.safetensors"
                    ).as_posix()
                    safetensors.torch.save_file(flat_ok, save_path)
                    ok = {"unet": {}, "clip": {}, "vae": {}}

        if save_separately:
            return ()

        if remove_junk:
            ok = self.do_remove_junk(ok)

        flat_ok = {k: v for _, subdict in ok.items() for k, v in subdict.items()}
        try:
            safetensors.torch.save_file(
                flat_ok, (folder / f"{save_name}.safetensors").as_posix()
            )
        except Exception as e:
            log.error(e)
        return ()
```
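Outside ComfyUI, the position-ids repair performed by fix_clip can be reproduced on a raw checkpoint along the following lines. This is a minimal sketch assuming an SD1.x key layout; the paths are placeholders.

```python
import torch
import safetensors.torch

sd = safetensors.torch.load_file("model.safetensors")
key = "cond_stage_model.transformer.text_model.embeddings.position_ids"
correct = torch.arange(77, dtype=torch.int64).unsqueeze(0)  # shape (1, 77)

if key not in sd or not torch.equal(sd[key].to(torch.int64), correct):
    sd[key] = correct
    safetensors.torch.save_file(sd, "model-clip-fix.safetensors")
    print("position_ids repaired")
else:
    print("position_ids already correct")
```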