CCSR_Upscale¶
Documentation¶
- Class name:
CCSR_Upscale
- Category:
CCSR
- Output node:
False
The node CCSR_Upscale
is designed to enhance the resolution of images or latent representations through advanced upscaling techniques. It leverages custom algorithms and models to upscale images with improved quality and detail, aiming to achieve higher fidelity outputs compared to traditional upscaling methods.
Input types¶
Required¶
ccsr_model
- Specifies the model used for the upscaling process, central to determining the upscaling technique and its effectiveness.
- Comfy dtype:
CCSRMODEL
- Python dtype:
str
image
- The image to be upscaled, serving as the primary input for the upscaling process.
- Comfy dtype:
IMAGE
- Python dtype:
torch.Tensor
resize_method
- Defines the method used to resize the image, impacting the upscaling quality and characteristics.
- Comfy dtype:
COMBO[STRING]
- Python dtype:
str
scale_by
- Determines the scaling factor for the upscaling process, affecting the final size of the output.
- Comfy dtype:
FLOAT
- Python dtype:
float
steps
- Specifies the number of steps to perform in the upscaling process, influencing the detail and quality of the upscaled image.
- Comfy dtype:
INT
- Python dtype:
int
t_max
- The upper bound (a fraction between 0 and 1) of the diffusion timestep range used during sampling; together with t_min it controls how much of the diffusion trajectory is applied to the image.
- Comfy dtype:
FLOAT
- Python dtype:
float
t_min
- The lower bound (a fraction between 0 and 1) of the diffusion timestep range used during sampling, setting where the sampling trajectory stops.
- Comfy dtype:
FLOAT
- Python dtype:
float
sampling_method
- Determines the sampling strategy used during upscaling, affecting the texture and quality of the output.
- Comfy dtype:
COMBO[STRING]
- Python dtype:
str
tile_size
- The size of tiles used in the upscaling process, impacting the processing efficiency and detail capture.
- Comfy dtype:
INT
- Python dtype:
int
tile_stride
- The stride of tiles during upscaling, affecting overlap and detail continuity between tiles.
- Comfy dtype:
INT
- Python dtype:
int
vae_tile_size_encode
- Tile size for the VAE encoding step, influencing the detail preservation during encoding.
- Comfy dtype:
INT
- Python dtype:
int
vae_tile_size_decode
- Tile size for the VAE decoding step, affecting the detail reconstruction during decoding.
- Comfy dtype:
INT
- Python dtype:
int
color_fix_type
- Specifies the method used for color correction, crucial for maintaining color accuracy in the upscaled image.
- Comfy dtype:
COMBO[STRING]
- Python dtype:
str
keep_model_loaded
- Indicates whether the upscaling model should remain loaded between invocations, affecting processing speed and resource usage.
- Comfy dtype:
BOOLEAN
- Python dtype:
bool
seed
- The random seed for the upscaling process, ensuring reproducibility of the results.
- Comfy dtype:
INT
- Python dtype:
int
Output types¶
upscaled_image
- Comfy dtype:
IMAGE
- The output of the upscaling process, providing enhanced resolution images with improved quality and detail.
- Python dtype:
torch.Tensor
Usage tips¶
- Infra type:
GPU
- Common nodes: unknown
Source code¶
class CCSR_Upscale:
    """ComfyUI node: upscale an image batch with a CCSR diffusion model.

    Pipeline: conventional resize by ``scale_by`` -> snap dimensions down to a
    multiple of 64 -> per-image diffusion sampling (plain, MixDiff-tiled, or
    gaussian-weighted tiled-VAE variant) -> rescale the result back to the
    input's aspect ratio.  The loaded model is cached on ``self`` between
    invocations when ``keep_model_loaded`` is enabled.
    """

    # Conventional resize kernels offered for the pre-upscale step.
    upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node's input sockets, widget ranges, and defaults."""
        return {"required": {
            "ccsr_model": ("CCSRMODEL", ),
            "image": ("IMAGE", ),
            "resize_method": (s.upscale_methods, {"default": "lanczos"}),
            "scale_by": ("FLOAT", {"default": 1.0, "min": 0.01, "max": 20.0, "step": 0.01}),
            "steps": ("INT", {"default": 45, "min": 3, "max": 4096, "step": 1}),
            # t_max / t_min bound the diffusion timestep range the sampler
            # uses; both are fractions in [0, 1].
            "t_max": ("FLOAT", {"default": 0.6667,"min": 0, "max": 1, "step": 0.01}),
            "t_min": ("FLOAT", {"default": 0.3333,"min": 0, "max": 1, "step": 0.01}),
            "sampling_method": (
                [
                    'ccsr',
                    'ccsr_tiled_mixdiff',
                    'ccsr_tiled_vae_gaussian_weights',
                ], {
                    "default": 'ccsr_tiled_mixdiff'
                }),
            # Latent-space tiling parameters for the tiled sampling methods.
            "tile_size": ("INT", {"default": 512, "min": 1, "max": 4096, "step": 1}),
            "tile_stride": ("INT", {"default": 256, "min": 1, "max": 4096, "step": 1}),
            # Pixel-space VAE tile sizes; converted to latent units (// 8)
            # before use in process().
            "vae_tile_size_encode": ("INT", {"default": 1024, "min": 2, "max": 4096, "step": 8}),
            "vae_tile_size_decode": ("INT", {"default": 1024, "min": 2, "max": 4096, "step": 8}),
            "color_fix_type": (
                [
                    'none',
                    'adain',
                    'wavelet',
                ], {
                    "default": 'adain'
                }),
            "keep_model_loaded": ("BOOLEAN", {"default": False}),
            "seed": ("INT", {"default": 123,"min": 0, "max": 0xffffffffffffffff, "step": 1}),
        },
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES =("upscaled_image",)
    FUNCTION = "process"
    CATEGORY = "CCSR"

    @torch.no_grad()
    def process(self, ccsr_model, image, resize_method, scale_by, steps, t_max, t_min, tile_size, tile_stride, color_fix_type, keep_model_loaded, vae_tile_size_encode, vae_tile_size_decode, sampling_method, seed):
        """Run the CCSR upscaling pipeline on a batch of images.

        Parameters mirror INPUT_TYPES; ``image`` is a [B, H, W, C] tensor in
        ComfyUI layout.  Returns a one-tuple containing the upscaled
        [B, H', W', C] float32 image batch.
        """
        # Seed CPU and all CUDA devices so the initial noise is reproducible.
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

        comfy.model_management.unload_all_models()
        device = comfy.model_management.get_torch_device()
        config_path = os.path.join(script_directory, "configs/model/ccsr_stage2.yaml")
        # Precomputed embedding for the empty prompt — presumably the
        # (un)conditional text context for the sampler calls below; the
        # positive/negative prompts passed there are always "".
        empty_text_embed = torch.load(os.path.join(script_directory, "empty_text_embed.pt"), map_location=device)

        # fp16 when the backend supports it (and the device is not MPS),
        # otherwise fp32.
        dtype = torch.float16 if comfy.model_management.should_use_fp16() and not comfy.model_management.is_device_mps(device) else torch.float32

        # Lazily (re)build the model; it stays cached on self when
        # keep_model_loaded was set on a previous invocation.
        if not hasattr(self, "model") or self.model is None:
            config = OmegaConf.load(config_path)
            self.model = instantiate_from_config(config)
            # ccsr_model is used as a checkpoint file path here, despite the
            # CCSRMODEL socket type.
            load_state_dict(self.model, torch.load(ccsr_model, map_location="cpu"), strict=True)
            # reload preprocess model if specified
            self.model.freeze()
            self.model.to(device, dtype=dtype)

        sampler = SpacedSampler(self.model, var_type="fixed_small")

        batch_size = image.shape[0]

        # Borrow ComfyUI's stock scale-by node for the conventional resize;
        # upscale() returns a one-tuple, hence the trailing-comma unpack.
        image, = ImageScaleBy.upscale(self, image, resize_method, scale_by)

        # Assuming 'image' is a PyTorch tensor with shape [B, H, W, C] and you want to resize it.
        B, H, W, C = image.shape

        # Calculate the new height and width, rounding down to the nearest multiple of 64.
        # NOTE(review): a dimension under 64 px rounds to 0 here — this
        # assumes the scaled input is at least 64 px on each side; confirm.
        new_height = H // 64 * 64
        new_width = W // 64 * 64

        # Reorder to [B, C, H, W] before using interpolate.
        image = image.permute(0, 3, 1, 2).contiguous()

        # Resize the image tensor.
        resized_image = F.interpolate(image, size=(new_height, new_width), mode='bicubic', align_corners=False)

        # Move the tensor to the GPU.
        #resized_image = resized_image.to(device)

        # Uniform conditioning strength across the model's 13 control scales.
        strength = 1.0
        self.model.control_scales = [strength] * 13

        height, width = resized_image.size(-2), resized_image.size(-1)
        # Latent shape: 4 channels, spatial dims 1/8 of the pixel dims.
        shape = (1, 4, height // 8, width // 8)
        # The same initial noise x_T is reused for every image in the batch.
        x_T = torch.randn(shape, device=self.model.device, dtype=torch.float32)

        # Autocast only when actually running fp16 (and not on MPS).
        autocast_condition = dtype == torch.float16 and not comfy.model_management.is_device_mps(device)

        out = []
        pbar = comfy.utils.ProgressBar(batch_size)
        with torch.autocast(comfy.model_management.get_autocast_device(device), dtype=dtype) if autocast_condition else nullcontext():
            for i in range(batch_size):
                img = resized_image[i].unsqueeze(0).to(device)
                if sampling_method == 'ccsr_tiled_mixdiff':
                    # Tiled latent sampling with MixDiff blending; plain VAE.
                    self.model.reset_encoder_decoder()
                    print("Using tiled mixdiff")
                    samples = sampler.sample_with_mixdiff_ccsr(
                        empty_text_embed, tile_size=tile_size, tile_stride=tile_stride,
                        steps=steps, t_max=t_max, t_min=t_min, shape=shape, cond_img=img,
                        positive_prompt="", negative_prompt="", x_T=x_T,
                        cfg_scale=1.0,
                        color_fix_type=color_fix_type
                    )
                elif sampling_method == 'ccsr_tiled_vae_gaussian_weights':
                    # Tiled VAE (tile sizes converted from pixel to latent
                    # units via // 8) with gaussian-weighted tile blending.
                    self.model._init_tiled_vae(encoder_tile_size=vae_tile_size_encode // 8, decoder_tile_size=vae_tile_size_decode // 8)
                    print("Using gaussian weights")
                    samples = sampler.sample_with_tile_ccsr(
                        empty_text_embed, tile_size=tile_size, tile_stride=tile_stride,
                        steps=steps, t_max=t_max, t_min=t_min, shape=shape, cond_img=img,
                        positive_prompt="", negative_prompt="", x_T=x_T,
                        cfg_scale=1.0,
                        color_fix_type=color_fix_type
                    )
                else:
                    # Plain (untiled) CCSR sampling.
                    self.model.reset_encoder_decoder()
                    print("no tiling")
                    samples = sampler.sample_ccsr(
                        empty_text_embed, steps=steps, t_max=t_max, t_min=t_min, shape=shape, cond_img=img,
                        positive_prompt="", negative_prompt="", x_T=x_T,
                        cfg_scale=1.0,
                        color_fix_type=color_fix_type
                    )

                out.append(samples.squeeze(0).cpu())
                comfy.model_management.throw_exception_if_processing_interrupted()
                pbar.update(1)
                print("Sampled image ", i, " out of ", batch_size)

        # Undo the multiple-of-64 snap: keep the processed height and derive
        # the width from the original H:W aspect ratio.
        original_height, original_width = H, W
        processed_height = samples.size(2)
        target_width = int(processed_height * (original_width / original_height))
        # Back to ComfyUI's [B, H, W, C] float32 layout before the resize node.
        out_stacked = torch.stack(out, dim=0).cpu().to(torch.float32).permute(0, 2, 3, 1)
        resized_back_image, = ImageScale.upscale(self, out_stacked, "lanczos", target_width, processed_height, crop="disabled")

        # Release the model unless the user asked to keep it resident.
        if not keep_model_loaded:
            self.model = None
            comfy.model_management.soft_empty_cache()

        return(resized_back_image,)