Skip to content

SUPIR Model Loader (Legacy)

Documentation

  • Class name: SUPIR_model_loader
  • Category: SUPIR
  • Output node: False

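This legacy node is responsible for loading the SUPIR model, a key component in the SUPIR framework for image processing and enhancement, and merging it with a selected SDXL checkpoint. It handles the initialization and configuration of the model, ensuring it is ready for subsequent image processing tasks. The node's own description marks it as an old loader that is not recommended for use.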

Input types

Required

  • supir_model
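    • Specifies the filename of the SUPIR model checkpoint, chosen from the files in the "checkpoints" model folder. The node resolves it to a full path and loads its weights into the model, which is crucial for image processing tasks.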
    • Comfy dtype: COMBO[STRING]
    • Python dtype: List[str]
  • sdxl_model
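    • Specifies the filename of the SDXL model checkpoint, chosen from the files in the "checkpoints" model folder. Its weights are loaded first and the SUPIR weights are then loaded on top of them, merging the two models to enhance the SUPIR model's capabilities.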
    • Comfy dtype: COMBO[STRING]
    • Python dtype: List[str]
  • fp8_unet
    • A flag that determines whether the UNet weights are cast to 8-bit floating point (torch.float8_e4m3fn) to save VRAM, with a slight impact on quality.
    • Comfy dtype: BOOLEAN
    • Python dtype: bool
  • diffusion_dtype
    • Specifies the data type for diffusion operations, with options to optimize for performance or compatibility.
    • Comfy dtype: COMBO[STRING]
    • Python dtype: str

Output types

  • SUPIR_model
    • Comfy dtype: SUPIRMODEL
    • The loaded and configured SUPIR model, ready for image processing tasks.
    • Python dtype: torch.nn.Module
  • SUPIR_VAE
    • Comfy dtype: SUPIRVAE
    • The loaded VAE component of the SUPIR model, essential for certain image processing operations.
    • Python dtype: torch.nn.Module

Usage tips

  • Infra type: CPU
  • Common nodes: unknown

Source code

class SUPIR_model_loader:
    """Legacy ComfyUI node that loads an SDXL checkpoint and a SUPIR checkpoint into one model.

    The fully loaded model is cached on the node instance together with the
    settings it was built from, so calling ``process`` again with identical
    settings returns the cached model instead of reloading it.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the ComfyUI input specification for this node."""
        return {"required": {
            "supir_model": (folder_paths.get_filename_list("checkpoints"),),
            "sdxl_model": (folder_paths.get_filename_list("checkpoints"),),
            "fp8_unet": ("BOOLEAN", {"default": False}),
            "diffusion_dtype": (
                    [
                        'fp16',
                        'bf16',
                        'fp32',
                        'auto'
                    ], {
                        "default": 'auto'
                    }),
            },
        }

    RETURN_TYPES = ("SUPIRMODEL", "SUPIRVAE")
    RETURN_NAMES = ("SUPIR_model","SUPIR_VAE",)
    FUNCTION = "process"
    CATEGORY = "SUPIR"
    DESCRIPTION = """
Old loader, not recommended to be used.  
Loads the SUPIR model and the selected SDXL model and merges them.
"""

    @staticmethod
    def _resolve_dtype(diffusion_dtype):
        """Map the ``diffusion_dtype`` choice to ``(torch dtype, config dtype string)``.

        ``'auto'`` asks ComfyUI's model management which precision to use
        (fp16 first, then bf16, otherwise fp32); any other value is converted
        with ``convert_dtype`` and passed through as the config string.

        Raises:
            AttributeError: if the automatic detection fails for any reason.
        """
        if diffusion_dtype != 'auto':
            print(f"Diffusion using {diffusion_dtype}")
            return convert_dtype(diffusion_dtype), diffusion_dtype

        try:
            if mm.should_use_fp16():
                print("Diffusion using fp16")
                return torch.float16, 'fp16'
            if mm.should_use_bf16():
                print("Diffusion using bf16")
                return torch.bfloat16, 'bf16'
            print("Diffusion using fp32")
            return torch.float32, 'fp32'
        except Exception as err:
            # Keep the historical exception type, but chain the real cause so
            # it is visible in the traceback.
            raise AttributeError("ComfyUI version too old, can't autodetect properly. Set your dtypes manually.") from err

    def process(self, supir_model: str, sdxl_model: str, diffusion_dtype: str, fp8_unet: bool):
        """Load (or reuse) the merged SUPIR/SDXL model.

        Args:
            supir_model: Filename of the SUPIR checkpoint inside the "checkpoints" folder.
            sdxl_model: Filename of the SDXL checkpoint inside the "checkpoints" folder.
            diffusion_dtype: One of 'fp16', 'bf16', 'fp32' or 'auto'.
            fp8_unet: When true, ``model.model`` is cast to ``torch.float8_e4m3fn``
                instead of the selected diffusion dtype.

        Returns:
            A tuple ``(model, model.first_stage_model)``.

        Raises:
            AttributeError: if 'auto' dtype detection fails.
            Exception: if any of the loading stages fails; the original error
                is attached as ``__cause__``.
        """
        device = mm.get_torch_device()
        mm.unload_all_models()

        SUPIR_MODEL_PATH = folder_paths.get_full_path("checkpoints", supir_model)
        SDXL_MODEL_PATH = folder_paths.get_full_path("checkpoints", sdxl_model)

        config_path = os.path.join(script_directory, "options/SUPIR_v0.yaml")
        clip_config_path = os.path.join(script_directory, "configs/clip_vit_config.json")
        tokenizer_path = os.path.join(script_directory, "configs/tokenizer")

        custom_config = {
            'sdxl_model': sdxl_model,
            'diffusion_dtype': diffusion_dtype,
            'supir_model': supir_model,
            'fp8_unet': fp8_unet,
        }

        dtype, model_dtype = self._resolve_dtype(diffusion_dtype)

        # Reuse the cached model only if it was fully loaded with the same settings.
        if getattr(self, "model", None) is not None and getattr(self, "current_config", None) == custom_config:
            return (self.model, self.model.first_stage_model,)

        # Release any previous model before building the new one. The cache is
        # only repopulated after every stage below succeeded, so a failed load
        # can never be returned as a "cached" model on the next call.
        self.model = None
        self.current_config = None

        mm.soft_empty_cache()

        config = OmegaConf.load(config_path)

        if mm.XFORMERS_IS_AVAILABLE:
            print("Using XFORMERS")
            config.model.params.control_stage_config.params.spatial_transformer_attn_type = "softmax-xformers"
            config.model.params.network_config.params.spatial_transformer_attn_type = "softmax-xformers"
            config.model.params.first_stage_config.params.ddconfig.attn_type = "vanilla-xformers"

        config.model.params.diffusion_dtype = model_dtype
        config.model.target = ".SUPIR.models.SUPIR_model_v2.SUPIRModel"
        pbar = comfy.utils.ProgressBar(5)

        model = instantiate_from_config(config.model).cpu()
        model.model.dtype = dtype
        pbar.update(1)

        # Stage 1: SDXL weights (non-strict load).
        try:
            print(f"Attempting to load SDXL model: [{SDXL_MODEL_PATH}]")
            sdxl_state_dict = load_state_dict(SDXL_MODEL_PATH)
            model.load_state_dict(sdxl_state_dict, strict=False)
            if fp8_unet:
                model.model.to(torch.float8_e4m3fn)
            else:
                model.model.to(dtype)
            pbar.update(1)
        except Exception as err:
            raise Exception("Failed to load SDXL model") from err

        # Stage 2: first clip model from SDXL checkpoint
        try:
            print("Loading first clip model from SDXL checkpoint")

            replace_prefix = {}
            replace_prefix["conditioner.embedders.0.transformer."] = ""

            sd = comfy.utils.state_dict_prefix_replace(sdxl_state_dict, replace_prefix, filter_keys=False)
            clip_text_config = CLIPTextConfig.from_pretrained(clip_config_path)
            model.conditioner.embedders[0].tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)
            model.conditioner.embedders[0].transformer = CLIPTextModel(clip_text_config)
            model.conditioner.embedders[0].transformer.load_state_dict(sd, strict=False)
            model.conditioner.embedders[0].eval()
            model.conditioner.embedders[0].to(dtype)
            for param in model.conditioner.embedders[0].parameters():
                param.requires_grad = False
            pbar.update(1)
        except Exception as err:
            raise Exception("Failed to load first clip model from SDXL checkpoint") from err

        del sdxl_state_dict

        # Stage 3: second clip model from SDXL checkpoint
        try:
            print("Loading second clip model from SDXL checkpoint")
            replace_prefix2 = {}
            replace_prefix2["conditioner.embedders.1.model."] = ""
            sd = comfy.utils.state_dict_prefix_replace(sd, replace_prefix2, filter_keys=True)
            clip_g = build_text_model_from_openai_state_dict(sd, device, cast_dtype=dtype)
            model.conditioner.embedders[1].model = clip_g
            model.conditioner.embedders[1].to(dtype)
            pbar.update(1)
        except Exception as err:
            raise Exception("Failed to load second clip model from SDXL checkpoint") from err

        del sd, clip_g

        # Stage 4: SUPIR weights, loaded on top of the SDXL weights (non-strict load).
        try:
            print(f'Attempting to load SUPIR model: [{SUPIR_MODEL_PATH}]')
            supir_state_dict = load_state_dict(SUPIR_MODEL_PATH)
            model.load_state_dict(supir_state_dict, strict=False)
            if fp8_unet:
                model.model.to(torch.float8_e4m3fn)
            else:
                model.model.to(dtype)
            del supir_state_dict
            pbar.update(1)
        except Exception as err:
            raise Exception("Failed to load SUPIR model") from err
        mm.soft_empty_cache()

        # Publish to the cache only now that every stage completed.
        self.model = model
        self.current_config = custom_config

        return (self.model, self.model.first_stage_model,)