🔧 Compile Model
Documentation
- Class name:
ModelCompile+
- Category:
essentials
- Output node:
False
The ModelCompile node optimizes PyTorch models for faster execution by wrapping the model's diffusion module with torch.compile. It exposes options for full-graph compilation, dynamic shapes, and several optimization modes.
Input types
Required
model
- The PyTorch model to be compiled. Only its diffusion module is wrapped with torch.compile; the rest of the model patcher is cloned and left untouched.
- Comfy dtype:
MODEL
- Python dtype:
torch.nn.Module
fullgraph
- A boolean flag indicating whether to compile the model as a single graph. When enabled, compilation fails rather than falling back to eager execution if a graph break is encountered; when it succeeds, it can yield better performance.
- Comfy dtype:
BOOLEAN
- Python dtype:
bool
dynamic
- A boolean flag that enables dynamic-shape compilation. When enabled, the compiler generates kernels that tolerate varying input shapes, avoiding recompilation when shapes change between calls.
- Comfy dtype:
BOOLEAN
- Python dtype:
bool
mode
- Specifies the compilation mode: default, reduce-overhead, max-autotune, or max-autotune-no-cudagraphs. The non-default modes trade longer compile times for lower runtime overhead or more aggressive kernel autotuning, with or without CUDA graphs.
- Comfy dtype:
COMBO[STRING]
- Python dtype:
str
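The four inputs map directly onto torch.compile's keyword arguments. A minimal sketch, using a toy nn.Module in place of the diffusion model (the `ToyModel` class is an assumption for illustration only):

```python
import torch

# `ToyModel` is a hypothetical stand-in for the diffusion model.
class ToyModel(torch.nn.Module):
    def forward(self, x):
        return x * 2.0

model = ToyModel()

# fullgraph / dynamic / mode correspond one-to-one to the node's inputs.
compiled = torch.compile(model, dynamic=False, fullgraph=False, mode="default")
```

Calling `compiled(x)` triggers compilation on first use; subsequent calls with matching shapes reuse the compiled graph.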
Output types
model
- The optimized and compiled PyTorch model, ready for improved execution.
- Comfy dtype:
MODEL
- Python dtype:
torch.nn.Module
Usage tips
- Infra type:
GPU
- Common nodes: unknown
Source code
```python
import torch

class ModelCompile():
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "fullgraph": ("BOOLEAN", { "default": False }),
                "dynamic": ("BOOLEAN", { "default": False }),
                "mode": (["default", "reduce-overhead", "max-autotune", "max-autotune-no-cudagraphs"],),
            },
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "execute"
    CATEGORY = "essentials"

    def execute(self, model, fullgraph, dynamic, mode):
        work_model = model.clone()
        # Suppress compiler errors so the workflow falls back to eager execution
        torch._dynamo.config.suppress_errors = True
        work_model.model.diffusion_model = torch.compile(work_model.model.diffusion_model, dynamic=dynamic, fullgraph=fullgraph, mode=mode)
        return (work_model,)
```
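Note that execute clones the model patcher before swapping in the compiled module, so the original model object is never mutated. A minimal pure-Python sketch of that pattern (the `Patcher` class here is hypothetical, not the ComfyUI API):

```python
import copy

class Patcher:
    """Hypothetical stand-in for ComfyUI's model patcher."""
    def __init__(self, inner):
        self.inner = inner

    def clone(self):
        # Shallow copy: the wrapper is new, so reassigning its attributes
        # does not affect the original.
        return copy.copy(self)

original = Patcher(inner="eager_module")
work = original.clone()
work.inner = "compiled_module"   # only the clone is modified

print(original.inner)  # -> eager_module
print(work.inner)      # -> compiled_module
```

This is why downstream nodes that received the uncompiled model keep working unchanged.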