AMT VFI
Documentation
- Class name:
AMT VFI
- Category:
ComfyUI-Frame-Interpolation/VFI
- Output node:
False
The AMT_VFI node performs video frame interpolation (VFI) with a pretrained AMT (All-Pairs Multi-Field Transforms) model. For each pair of consecutive input frames it predicts new intermediate frames, which raises the frame rate of the sequence and makes motion appear smoother on playback.
Input types
Required
ckpt_name
- The name of the AMT checkpoint to use, chosen from the keys of CKPT_CONFIGS. The name selects both the file downloaded from the lalala125/AMT repository on Hugging Face and the network class and constructor parameters used to load it.
- Comfy dtype:
COMBO[STRING]
- Python dtype:
str
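In the source code, the checkpoint name is used twice: it is appended to the Hugging Face download URL, and it is the key used to look up the network class and its constructor arguments in CKPT_CONFIGS. The following is a minimal sketch of that lookup. It assumes a hypothetical config table, `EXAMPLE_CKPT_CONFIGS`, with a hypothetical `DummyNet` class; the real CKPT_CONFIGS entries and AMT network classes are not shown on this page.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict

BASE_URL = "https://huggingface.co/lalala125/AMT/resolve/main/"


@dataclass
class DummyNet:
    # Hypothetical stand-in for an AMT network class.
    channels: int = 32


# Hypothetical config table with the {"network": cls, "params": kwargs}
# shape that the node's source code indexes into.
EXAMPLE_CKPT_CONFIGS: Dict[str, Dict[str, Any]] = {
    "example-ckpt.pth": {"network": DummyNet, "params": {"channels": 32}},
}


def checkpoint_url(ckpt_name: str) -> str:
    """Build the download URL the same way the node's f-string does."""
    return f"{BASE_URL}{ckpt_name}"


def build_model(ckpt_name: str, configs: Dict[str, Dict[str, Any]]) -> Any:
    """Instantiate the network class registered for ckpt_name."""
    config = configs[ckpt_name]  # raises KeyError for unknown names
    network: Callable[..., Any] = config["network"]
    return network(**config["params"])


model = build_model("example-ckpt.pth", EXAMPLE_CKPT_CONFIGS)
print(checkpoint_url("example-ckpt.pth"))
```

Keeping the URL and the architecture keyed on the same string ensures that a listed checkpoint always loads into a network built with matching parameters.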
frames
- The input frames to interpolate, as a ComfyUI IMAGE tensor of shape [batch, height, width, channels]. The batch is treated as a temporal sequence, and new frames are generated between each consecutive pair.
- Comfy dtype:
IMAGE
- Python dtype:
torch.Tensor
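Before padding, the node calls preprocess_frames, and it calls postprocess_frames before returning. The source of those helpers is not shown here. The sketch below illustrates the usual conversion between ComfyUI's channels-last IMAGE layout and the channels-first layout that PyTorch convolutional models generally expect. Dropping any alpha channel is an assumption about what the package does.

```python
import torch


def to_model_layout(images: torch.Tensor) -> torch.Tensor:
    """ComfyUI IMAGE [B, H, W, C] -> channels-first [B, 3, H, W].

    Assumption: only the first three (RGB) channels are kept, so an
    alpha channel, if present, is discarded.
    """
    return images[..., :3].permute(0, 3, 1, 2).contiguous()


def to_comfy_layout(frames: torch.Tensor) -> torch.Tensor:
    """Channels-first [B, C, H, W] -> ComfyUI IMAGE [B, H, W, C]."""
    return frames.permute(0, 2, 3, 1).contiguous()


batch = torch.rand(4, 64, 96, 3)  # 4 frames, 64x96 RGB, values in [0, 1]
chw = to_model_layout(batch)
print(chw.shape)  # torch.Size([4, 3, 64, 96])
```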
clear_cache_after_n_frames
- How often cached GPU memory is released, measured in processed frames (default 1, range 1 to 100). Lower values clear the cache more often and keep peak memory usage down, at some cost in speed.
- Comfy dtype:
INT
- Python dtype:
int
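The following is a minimal sketch of counter-based cache clearing. It assumes the frame loop counts processed frame pairs and resets the counter after each clear. The clearing function is injected as a callback so the sketch runs without a GPU. On a CUDA device, that callback could be `torch.cuda.empty_cache`.

```python
from typing import Callable, Iterable, List


def run_with_periodic_clearing(
    pairs: Iterable[int],
    clear_cache_after_n_frames: int,
    clear_cache: Callable[[], None],
) -> List[int]:
    """Process frame pairs, calling clear_cache every N processed pairs.

    Returns the indices of the pairs after which the cache was cleared.
    """
    cleared_after: List[int] = []
    processed_since_clear = 0
    for index in pairs:
        # ... interpolation work for this pair would happen here ...
        processed_since_clear += 1
        if processed_since_clear >= clear_cache_after_n_frames:
            clear_cache()  # injected; e.g. torch.cuda.empty_cache on CUDA
            cleared_after.append(index)
            processed_since_clear = 0
    return cleared_after


calls = []
print(run_with_periodic_clearing(range(10), 3, lambda: calls.append(1)))
# [2, 5, 8]
```

With the default of 1, the cache is cleared after every pair. This minimizes memory usage but pays the clearing overhead most often.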
multiplier
- The frame-rate factor (default 2, range 2 to 1000). multiplier - 1 intermediate frames are generated between each pair of consecutive input frames, so a value of 2 doubles the frame rate.
- Comfy dtype:
INT
- Python dtype:
int
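Because multiplier is a frame-rate factor, each consecutive pair of input frames receives multiplier - 1 new frames at evenly spaced timesteps. The sketch below works through that arithmetic. It assumes timesteps of i / multiplier and assumes the last input frame is kept; the shared frame loop itself is not shown on this page. `timestep_embedding` builds the same [B, 1, 1, 1] float tensor as the `embt` argument in the node's source code.

```python
from typing import List

import torch


def middle_timesteps(multiplier: int) -> List[float]:
    """Assumed timesteps in (0, 1) at which new frames are predicted."""
    return [i / multiplier for i in range(1, multiplier)]


def output_frame_count(num_input_frames: int, multiplier: int) -> int:
    """Frames out when every pair is interpolated: originals plus new ones."""
    if num_input_frames < 2:
        return num_input_frames
    return (num_input_frames - 1) * multiplier + 1


def timestep_embedding(timestep: float, batch_size: int) -> torch.Tensor:
    """One timestep per batch item, shaped [B, 1, 1, 1] like embt."""
    return torch.full((batch_size, 1, 1, 1), timestep, dtype=torch.float32)


print(middle_timesteps(4))               # [0.25, 0.5, 0.75]
print(output_frame_count(10, 2))         # 19
print(timestep_embedding(0.5, 1).shape)  # torch.Size([1, 1, 1, 1])
```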
Optional
optional_interpolation_states
- Optional interpolation states that control which frame pairs are interpolated. The value is passed unchanged to the shared frame loop (generic_frame_loop) as interpolation_states, and pairs marked as skipped receive no intermediate frames.
- Comfy dtype:
INTERPOLATION_STATES
- Python dtype:
InterpolationStateList
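The role of the interpolation states can be illustrated with a toy frame loop. In this loop, scalars stand in for frames and linear blending stands in for the AMT model. Representing the states as a set of skipped pair indices is a hypothetical simplification; the real InterpolationStateList API is not documented on this page.

```python
from typing import AbstractSet, Callable, List, Sequence


def interpolate_sequence(
    frames: Sequence[float],
    multiplier: int,
    middle_fn: Callable[[float, float, float], float],
    skipped_pairs: AbstractSet[int] = frozenset(),
) -> List[float]:
    """Insert multiplier - 1 predicted frames between each consecutive pair.

    skipped_pairs is a hypothetical stand-in for interpolation states:
    pairs whose index is in it get no new frames.
    """
    out: List[float] = []
    for i in range(len(frames) - 1):
        out.append(frames[i])
        if i in skipped_pairs:
            continue  # skipped pair: keep the original frame only
        for k in range(1, multiplier):
            out.append(middle_fn(frames[i], frames[i + 1], k / multiplier))
    if frames:
        out.append(frames[-1])  # the last input frame is kept as-is
    return out


def lerp(a: float, b: float, t: float) -> float:
    # Linear blend as a stand-in for the model's prediction.
    return a + (b - a) * t


print(interpolate_sequence([0.0, 1.0, 2.0], 2, lerp))       # [0.0, 0.5, 1.0, 1.5, 2.0]
print(interpolate_sequence([0.0, 1.0, 2.0], 2, lerp, {0}))  # [0.0, 1.0, 1.5, 2.0]
```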
Output types
image
- Comfy dtype:
IMAGE
- The interpolated sequence: the original frames with the generated intermediate frames inserted between them in temporal order, returned as a ComfyUI IMAGE tensor.
- Python dtype:
torch.Tensor
Usage tips
- Infra type:
GPU
- Common nodes: unknown
Source code
import typing

import torch

# CKPT_CONFIGS, MODEL_TYPE, InputPadder, InterpolationStateList,
# load_file_from_direct_url, preprocess_frames, postprocess_frames,
# generic_frame_loop and get_torch_device are provided by
# ComfyUI-Frame-Interpolation and ComfyUI; their imports are omitted here.


class AMT_VFI:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "ckpt_name": (list(CKPT_CONFIGS.keys()), ),
                "frames": ("IMAGE", ),
                "clear_cache_after_n_frames": ("INT", {"default": 1, "min": 1, "max": 100}),
                "multiplier": ("INT", {"default": 2, "min": 2, "max": 1000})
            },
            "optional": {
                "optional_interpolation_states": ("INTERPOLATION_STATES", )
            }
        }

    RETURN_TYPES = ("IMAGE", )
    FUNCTION = "vfi"
    CATEGORY = "ComfyUI-Frame-Interpolation/VFI"

    def vfi(
        self,
        ckpt_name: typing.AnyStr,
        frames: torch.Tensor,
        clear_cache_after_n_frames: typing.SupportsInt = 1,
        multiplier: typing.SupportsInt = 2,
        optional_interpolation_states: InterpolationStateList = None,
        **kwargs
    ):
        # Fetch the checkpoint and build the network registered for this name.
        model_path = load_file_from_direct_url(MODEL_TYPE, f"https://huggingface.co/lalala125/AMT/resolve/main/{ckpt_name}")
        ckpt_config = CKPT_CONFIGS[ckpt_name]
        interpolation_model = ckpt_config["network"](**ckpt_config["params"])
        interpolation_model.load_state_dict(torch.load(model_path)["state_dict"])
        interpolation_model.eval().to(get_torch_device())

        # Convert the IMAGE batch for the model and pad H and W to multiples of 16.
        frames = preprocess_frames(frames)
        padder = InputPadder(frames.shape, 16)
        frames = padder.pad(frames)

        def return_middle_frame(frame_0, frame_1, timestep, model):
            # Predict the frame at `timestep` between frame_0 and frame_1;
            # embt holds one timestep per batch item, shaped [B, 1, 1, 1].
            return model(
                frame_0,
                frame_1,
                embt=torch.FloatTensor([timestep] * frame_0.shape[0]).view(frame_0.shape[0], 1, 1, 1).to(get_torch_device()),
                scale_factor=1.0,
                eval=True
            )["imgt_pred"]

        # The shared loop calls return_middle_frame for each consecutive frame pair.
        args = [interpolation_model]
        out = generic_frame_loop(frames, clear_cache_after_n_frames, multiplier, return_middle_frame, *args,
                                 interpolation_states=optional_interpolation_states, dtype=torch.float32)

        # Crop away the padding and convert back to ComfyUI's IMAGE layout.
        out = padder.unpad(out)
        out = postprocess_frames(out)
        return (out,)
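The source pads every frame so that its height and width are multiples of 16 before inference, then crops the output back afterwards. This is typical for networks that downsample their input several times. The sketch below shows that pad/unpad round trip with a hypothetical `SimplePadder` in place of the package's InputPadder. Replicate padding and an even split of the padding between opposite sides are assumptions.

```python
import torch
import torch.nn.functional as F


class SimplePadder:
    """Pad H and W up to a multiple of `divisor`, then crop back.

    Hypothetical stand-in for InputPadder; the padding mode and the
    split between sides are assumptions.
    """

    def __init__(self, shape: torch.Size, divisor: int = 16) -> None:
        height, width = shape[-2], shape[-1]
        pad_h = (divisor - height % divisor) % divisor
        pad_w = (divisor - width % divisor) % divisor
        # F.pad order for the last two dims: (left, right, top, bottom).
        self.padding = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)

    def pad(self, x: torch.Tensor) -> torch.Tensor:
        return F.pad(x, self.padding, mode="replicate")

    def unpad(self, x: torch.Tensor) -> torch.Tensor:
        left, right, top, bottom = self.padding
        h, w = x.shape[-2], x.shape[-1]
        # Use h - bottom / w - right so a zero pad still keeps the full extent.
        return x[..., top : h - bottom, left : w - right]


frames = torch.rand(3, 3, 270, 480)  # [B, C, H, W]
padder = SimplePadder(frames.shape, 16)
padded = padder.pad(frames)
print(padded.shape)  # torch.Size([3, 3, 272, 480])
```

Cropping with the same offsets used for padding guarantees that the interpolated frames come back at the input resolution.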