MarigoldDepthEstimation¶
Documentation¶
- Class name:
MarigoldDepthEstimation
- Category:
Marigold
- Output node:
False
The MarigoldDepthEstimation node performs diffusion-based monocular depth estimation, using the Marigold model to generate depth maps from single images. It runs a configurable number of denoising steps per prediction and repeats the prediction several times, ensembling the repeats into a single depth map that is more detailed and stable than any individual prediction. Configuration options include model selection (Marigold or its LCM variant), scheduler choice, and inversion of the depth map colors.
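At a high level, the node predicts depth for the same image several times and then reduces the stack of predictions to one map. The snippet below is a minimal sketch of that reduction step only, assuming a stack of already-normalized predictions; the real node additionally aligns the predictions before reducing them (see the note after the source code).

```python
import torch

def reduce_predictions(depth_predictions: torch.Tensor, reduction: str = "median") -> torch.Tensor:
    """Collapse a stack of per-image depth predictions (N, H, W) into one map (H, W).

    Simplified stand-in for the node's ensembling step; the actual implementation
    also aligns the scale and shift of each prediction before reducing.
    """
    if reduction == "median":
        return depth_predictions.median(dim=0).values
    return depth_predictions.mean(dim=0)

# Hypothetical example: 10 noisy predictions of a 512x512 depth map.
preds = torch.rand(10, 512, 512)
depth = reduce_predictions(preds, reduction="median")  # shape (512, 512)
```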
Input types¶
Required¶
image
- The input image for which the depth map is estimated.
- Comfy dtype:
IMAGE
- Python dtype:
torch.Tensor
seed
- A seed value for random number generation, ensuring reproducibility of the depth estimation process.
- Comfy dtype:
INT
- Python dtype:
int
denoise_steps
- Specifies the number of denoising steps used per depth map. Increasing this value results in more accurate depth maps at the cost of longer processing times.
- Comfy dtype:
INT
- Python dtype:
int
n_repeat
- Determines the number of repeated predictions that are ensembled into a single depth map, affecting overall depth accuracy and detail.
- Comfy dtype:
INT
- Python dtype:
int
regularizer_strength
- Adjusts the strength of the regularization during the ensembling process, typically left at default settings.
- Comfy dtype:
FLOAT
- Python dtype:
float
reduction_method
- Specifies the method used for reducing the ensemble of depth maps into a single map, impacting the final depth representation.
- Comfy dtype:
COMBO[STRING]
- Python dtype:
str
max_iter
- Sets the maximum number of iterations for the ensembling process, controlling the depth map's refinement level.
- Comfy dtype:
INT
- Python dtype:
int
tol
- Defines the tolerance level for the ensembling process, determining when the process should terminate based on convergence criteria.
- Comfy dtype:
FLOAT
- Python dtype:
float
invert
- Inverts the depth map colors. By default Marigold produces maps in which black is nearest; enabling this option flips them so that white is nearest, which is what ControlNets and most other consumers expect.
- Comfy dtype:
BOOLEAN
- Python dtype:
bool
keep_model_loaded
- Indicates whether the Marigold model should remain loaded between invocations, improving performance at the expense of memory usage.
- Comfy dtype:
BOOLEAN
- Python dtype:
bool
n_repeat_batch_size
- Controls how many of the n_repeat iterations are processed in a single pipeline call, trading VRAM usage against processing speed; see the sketch after the required inputs below.
- Comfy dtype:
INT
- Python dtype:
int
use_fp16
- Determines whether to use fp16 precision for processing, affecting VRAM usage and potentially the depth map's quality.
- Comfy dtype:
BOOLEAN
- Python dtype:
bool
scheduler
- Chooses the scheduler for the denoising steps, affecting the depth map's final quality and characteristics.
- Comfy dtype:
COMBO[STRING]
- Python dtype:
str
normalize
- Specifies whether the depth map should be normalized, affecting the range and distribution of depth values.
- Comfy dtype:
BOOLEAN
- Python dtype:
bool
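To illustrate how n_repeat and n_repeat_batch_size interact (the values below are hypothetical), the node duplicates each input image n_repeat times and runs the diffusion pipeline on those duplicates in chunks of n_repeat_batch_size; raising the batch size toward n_repeat speeds up processing if VRAM allows:

```python
n_repeat = 10            # depth predictions to ensemble per image
n_repeat_batch_size = 2  # duplicates sent through the pipeline per call

# The node iterates over the duplicated image in sub-batches:
for start in range(0, n_repeat, n_repeat_batch_size):
    size = min(n_repeat_batch_size, n_repeat - start)
    print(f"pipeline call on duplicates {start}..{start + size - 1}")  # 5 calls of 2 here
```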
Optional¶
model
- Selects between the Marigold model and its LCM version for depth estimation, influencing the depth map's accuracy and appearance.
- Comfy dtype:
COMBO[STRING]
- Python dtype:
str
Output types¶
ensembled_image
- Comfy dtype:
IMAGE
- The resulting depth map, ensembled from multiple iterations and denoising steps, providing a detailed representation of depth in the image.
- Python dtype:
torch.Tensor
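If the depth map is needed outside the graph, the returned tensor can be written out directly. The snippet below is a minimal sketch assuming the standard ComfyUI IMAGE layout of (batch, height, width, 3) float32 values in [0, 1]; the three channels are identical copies of the depth value.

```python
import torch
from PIL import Image

def save_depth_png(ensembled_image: torch.Tensor, path: str = "depth.png") -> None:
    """Save the first depth map of a (B, H, W, 3) ComfyUI IMAGE tensor as 8-bit grayscale."""
    depth = ensembled_image[0, :, :, 0]                    # channels are identical
    depth_8bit = (depth.clamp(0, 1) * 255).to(torch.uint8)
    Image.fromarray(depth_8bit.cpu().numpy(), mode="L").save(path)
```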
Usage tips¶
- Infra type:
GPU
- Common nodes: unknown
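As a starting point, the settings below reflect the node's defaults and the guidance in its description string: the LCM checkpoint needs only around 4 denoising steps together with the LCMScheduler. Treat these as hypothetical starting values rather than tuned recommendations.

```python
# Defaults for the standard Marigold checkpoint
marigold_settings = {
    "model": "Marigold",
    "scheduler": "DDIMScheduler",
    "denoise_steps": 10,
    "n_repeat": 10,
    "n_repeat_batch_size": 2,   # raise toward n_repeat if VRAM allows
}

# Faster LCM variant
lcm_settings = {
    "model": "marigold-lcm-v1-0",
    "scheduler": "LCMScheduler",
    "denoise_steps": 4,          # the LCM model needs only ~4 steps
    "n_repeat": 10,
    "n_repeat_batch_size": 2,
}
```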
Source code¶
class MarigoldDepthEstimation:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "image": ("IMAGE", ),
            "seed": ("INT", {"default": 123, "min": 0, "max": 0xffffffffffffffff, "step": 1}),
            "denoise_steps": ("INT", {"default": 10, "min": 1, "max": 4096, "step": 1}),
            "n_repeat": ("INT", {"default": 10, "min": 1, "max": 4096, "step": 1}),
            "regularizer_strength": ("FLOAT", {"default": 0.02, "min": 0.001, "max": 4096, "step": 0.001}),
            "reduction_method": (
                [
                    'median',
                    'mean',
                ], {
                    "default": 'median'
                }),
            "max_iter": ("INT", {"default": 5, "min": 1, "max": 4096, "step": 1}),
            "tol": ("FLOAT", {"default": 1e-3, "min": 1e-6, "max": 1e-1, "step": 1e-6}),
            "invert": ("BOOLEAN", {"default": True}),
            "keep_model_loaded": ("BOOLEAN", {"default": True}),
            "n_repeat_batch_size": ("INT", {"default": 2, "min": 1, "max": 4096, "step": 1}),
            "use_fp16": ("BOOLEAN", {"default": True}),
            "scheduler": (
                [
                    'DDIMScheduler',
                    'DDPMScheduler',
                    'PNDMScheduler',
                    'DEISMultistepScheduler',
                    'LCMScheduler',
                ], {
                    "default": 'DDIMScheduler'
                }),
            "normalize": ("BOOLEAN", {"default": True}),
            },
            "optional": {
                "model": (
                    [
                        'Marigold',
                        'marigold-lcm-v1-0',
                    ], {
                        "default": 'Marigold'
                    }),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("ensembled_image",)
    FUNCTION = "process"
    CATEGORY = "Marigold"
    DESCRIPTION = """
Diffusion-based monocular depth estimation:
https://github.com/prs-eth/Marigold

- denoise_steps: steps per depth map, increase for accuracy in exchange of processing time
- n_repeat: amount of iterations to be ensembled into single depth map
- n_repeat_batch_size: how many of the n_repeats are processed as a batch,
  if you have the VRAM this can match the n_repeats for faster processing
- model: Marigold or it's LCM version marigold-lcm-v1-0
  For the LCM model use around 4 steps and the LCMScheduler
- scheduler: Different schedulers give bit different results
- invert: marigold by default produces depth map where black is front,
  for controlnets etc. we want the opposite.
- regularizer_strength, reduction_method, max_iter, tol (tolerance) are settings
  for the ensembling process, generally do not touch.
- use_fp16: if true, use fp16, if false use fp32
  fp16 uses much less VRAM, but in some cases can lead to loss of quality.
"""

    def process(self, image, seed, denoise_steps, n_repeat, regularizer_strength, reduction_method, max_iter, tol, invert, keep_model_loaded, n_repeat_batch_size, use_fp16, scheduler, normalize, model="Marigold"):
        batch_size = image.shape[0]
        precision = torch.float16 if use_fp16 else torch.float32
        device = model_management.get_torch_device()
        torch.manual_seed(seed)

        image = image.permute(0, 3, 1, 2).to(device).to(dtype=precision)
        if normalize:
            image = image * 2.0 - 1.0

        diffusers_model_path = os.path.join(folder_paths.models_dir, 'diffusers')
        # load the diffusers model
        if model == "Marigold":
            folders_to_check = [
                os.path.join(script_directory, "checkpoints", "Marigold_v1_merged",),
                os.path.join(script_directory, "checkpoints", "Marigold",),
                os.path.join(diffusers_model_path, "Marigold_v1_merged"),
                os.path.join(diffusers_model_path, "Marigold")
            ]
        elif model == "marigold-lcm-v1-0":
            folders_to_check = [
                os.path.join(diffusers_model_path, "marigold-lcm-v1-0"),
                os.path.join(diffusers_model_path, "checkpoints", "marigold-lcm-v1-0")
            ]

        self.custom_config = {
            "model": model,
            "use_fp16": use_fp16,
            "scheduler": scheduler,
        }
        if not hasattr(self, 'marigold_pipeline') or self.marigold_pipeline is None or self.current_config != self.custom_config:
            self.current_config = self.custom_config
            # Load the model only if it hasn't been loaded before
            checkpoint_path = None
            for folder in folders_to_check:
                if os.path.exists(folder):
                    checkpoint_path = folder
                    break

            to_ignore = ["*.bin", "*fp16*"]

            if checkpoint_path is None:
                if model == "Marigold":
                    try:
                        from huggingface_hub import snapshot_download
                        checkpoint_path = os.path.join(diffusers_model_path, "Marigold")
                        snapshot_download(repo_id="Bingxin/Marigold", ignore_patterns=to_ignore, local_dir=checkpoint_path, local_dir_use_symlinks=False)
                    except:
                        raise FileNotFoundError(f"No checkpoint directory found at {checkpoint_path}")
                if model == "marigold-lcm-v1-0":
                    try:
                        from huggingface_hub import snapshot_download
                        checkpoint_path = os.path.join(diffusers_model_path, "marigold-lcm-v1-0")
                        snapshot_download(repo_id="prs-eth/marigold-lcm-v1-0", ignore_patterns=to_ignore, local_dir=checkpoint_path, local_dir_use_symlinks=False)
                    except:
                        raise FileNotFoundError(f"No checkpoint directory found at {checkpoint_path}")

            self.marigold_pipeline = MarigoldPipeline.from_pretrained(checkpoint_path, enable_xformers=False, empty_text_embed=empty_text_embed, noise_scheduler_type=scheduler)
            self.marigold_pipeline = self.marigold_pipeline.to(device).half() if use_fp16 else self.marigold_pipeline.to(device)
            self.marigold_pipeline.unet.eval()  # Set the model to evaluation mode

        pbar = comfy.utils.ProgressBar(batch_size * n_repeat)

        out = []
        with torch.no_grad():
            for i in range(batch_size):
                # Duplicate the current image n_repeat times
                duplicated_batch = image[i].unsqueeze(0).repeat(n_repeat, 1, 1, 1)
                # Process the duplicated batch in sub-batches
                depth_maps = []
                for j in range(0, n_repeat, n_repeat_batch_size):
                    # Get the current sub-batch
                    sub_batch = duplicated_batch[j:j + n_repeat_batch_size]

                    # Process the sub-batch
                    depth_maps_sub_batch = self.marigold_pipeline(sub_batch, num_inference_steps=denoise_steps, show_pbar=False)

                    # Process each depth map in the sub-batch if necessary
                    for depth_map in depth_maps_sub_batch:
                        depth_map = torch.clip(depth_map, -1.0, 1.0)
                        depth_map = (depth_map + 1.0) / 2.0
                        depth_maps.append(depth_map)
                        pbar.update(1)

                depth_predictions = torch.cat(depth_maps, dim=0).squeeze()
                del duplicated_batch, depth_maps_sub_batch
                torch.cuda.empty_cache()  # clear vram cache for ensembling

                # Test-time ensembling
                if n_repeat > 1:
                    depth_map, pred_uncert = ensemble_depths(
                        depth_predictions,
                        regularizer_strength=regularizer_strength,
                        max_iter=max_iter,
                        tol=tol,
                        reduction=reduction_method,
                        max_res=None,
                        device=device,
                    )
                    print(depth_map.shape)
                    depth_map = depth_map.unsqueeze(2).repeat(1, 1, 3)
                    print(depth_map.shape)
                else:
                    depth_map = depth_map.permute(1, 2, 0)
                    depth_map = depth_map.repeat(1, 1, 3)
                    print(depth_map.shape)

                out.append(depth_map)
                del depth_map, depth_predictions

        if invert:
            outstack = 1.0 - torch.stack(out, dim=0).cpu().to(torch.float32)
        else:
            outstack = torch.stack(out, dim=0).cpu().to(torch.float32)

        if not keep_model_loaded:
            self.marigold_pipeline = None
            model_management.soft_empty_cache()
        return (outstack,)
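The ensemble_depths call above is imported from the Marigold code base and is not shown here. Conceptually it fits a per-prediction scale and shift so the predictions agree with one another, iterating up to max_iter times or until the improvement drops below tol, keeps the result anchored to a fixed range via regularizer_strength, and finally reduces the aligned stack with reduction_method. The sketch below illustrates that idea with a plain gradient-based loop; it is an assumption-laden simplification, not the actual implementation.

```python
import torch

def ensemble_depths_sketch(preds, regularizer_strength=0.02, max_iter=5, tol=1e-3,
                           reduction="median", lr=0.1):
    """Rough sketch of test-time ensembling: align each prediction (N, H, W) with a
    per-map scale and shift, then reduce the aligned stack to one depth map."""
    n = preds.shape[0]
    scale = torch.ones(n, 1, 1, requires_grad=True)
    shift = torch.zeros(n, 1, 1, requires_grad=True)
    optimizer = torch.optim.Adam([scale, shift], lr=lr)
    prev_loss = None
    for _ in range(max_iter):
        aligned = preds * scale + shift
        consensus = aligned.mean(dim=0, keepdim=True)
        # Penalize disagreement between the predictions ...
        loss = ((aligned - consensus) ** 2).mean()
        # ... and keep the ensemble anchored to [0, 1] so scale/shift cannot collapse.
        loss = loss + regularizer_strength * ((aligned.min() - 0.0) ** 2 + (aligned.max() - 1.0) ** 2)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if prev_loss is not None and abs(prev_loss - loss.item()) < tol:
            break
        prev_loss = loss.item()
    aligned = (preds * scale + shift).detach()
    return aligned.median(dim=0).values if reduction == "median" else aligned.mean(dim=0)
```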