UNetCrossAttentionMultiply

Documentation

  • Class name: UNetCrossAttentionMultiply
  • Category: _for_testing/attention_experiments
  • Output node: False

The UNetCrossAttentionMultiply node modifies a model's attention behavior by applying scalar multipliers to the query (q), key (k), value (v), and output (out) projections of the UNet's cross-attention layers (the attn2 blocks). It is an experimental node, intended for probing how rescaling these components shifts what the model attends to during generation.
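
Concretely, scaling a linear projection's weights by a constant scales its output by the same constant, so the effect of the four multipliers can be illustrated on the activations of a standalone cross-attention block. The sketch below is a minimal, self-contained illustration, not the node's implementation; the function name, dimensions, and tensors are all made up for the example.

import torch
import torch.nn as nn
import torch.nn.functional as F

def scaled_cross_attention(x, context, to_q, to_k, to_v, to_out,
                           q=1.0, k=1.0, v=1.0, out=1.0):
    # Scaling a projection's weight matrix scales its output identically,
    # so the node's weight multipliers appear here on the activations.
    Q = to_q(x) * q            # queries from the image latents
    K = to_k(context) * k      # keys from the conditioning (e.g. text)
    V = to_v(context) * v      # values from the conditioning
    attn = F.scaled_dot_product_attention(Q, K, V)
    return to_out(attn) * out  # scaled output projection

dim = 64
to_q, to_k, to_v, to_out = (nn.Linear(dim, dim) for _ in range(4))
x = torch.randn(1, 16, dim)        # 16 latent tokens
context = torch.randn(1, 8, dim)   # 8 conditioning tokens
y = scaled_cross_attention(x, context, to_q, to_k, to_v, to_out, v=1.1, out=0.9)
print(y.shape)  # torch.Size([1, 16, 64])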

Input types

Required

  • model
    • The neural network model whose cross-attention projections will be patched. The node returns a modified copy, so this input model itself is left unchanged.
    • Comfy dtype: MODEL
    • Python dtype: torch.nn.Module
  • q
    • The q parameter scales the query projection of the cross-attention layers. Together with k it rescales the pre-softmax attention logits (see the demonstration after this list).
    • Comfy dtype: FLOAT
    • Python dtype: float
  • k
    • The k parameter scales the key projection, which with q determines how sharply each query selects among the conditioning tokens.
    • Comfy dtype: FLOAT
    • Python dtype: float
  • v
    • The v parameter scales the value projection, directly scaling the content that the attention mechanism retrieves from the conditioning.
    • Comfy dtype: FLOAT
    • Python dtype: float
  • out
    • The out parameter scales the attention block's output projection, raising or lowering that block's overall contribution to the model's activations.
    • Comfy dtype: FLOAT
    • Python dtype: float
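
Because q and k multiply the two sides of the attention dot product, their product effectively rescales the pre-softmax logits: setting both to s multiplies every logit by s², which sharpens (s > 1) or flattens (s < 1) the attention distribution, much like an inverse temperature. A toy demonstration of that effect (illustrative only; the node patches weights, not activations):

import torch

torch.manual_seed(0)
logits = torch.randn(6)  # stand-in for one row of pre-softmax attention scores
for s in (0.5, 1.0, 2.0):
    # With q = k = s, every logit is scaled by s**2 before the softmax.
    probs = torch.softmax(logits * s * s, dim=-1)
    print(f"s={s}: max attention weight = {probs.max():.3f}")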

Output types

  • model
    • The modified model: a clone of the input with the q, k, v, and out multipliers applied to its cross-attention projections.
    • Comfy dtype: MODEL
    • Python dtype: torch.nn.Module

Usage tips

  • Infra type: GPU
  • Common nodes: unknown
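
For reference, this is how the node might appear in an API-format workflow submitted to a running ComfyUI instance. The node IDs, checkpoint name, and server address below are placeholders, and a real workflow would continue from the patched MODEL into a sampler and decoder.

import json
import urllib.request

workflow = {
    "1": {"class_type": "CheckpointLoaderSimple",
          "inputs": {"ckpt_name": "model.safetensors"}},
    "2": {"class_type": "UNetCrossAttentionMultiply",
          "inputs": {"model": ["1", 0],   # MODEL output of node "1"
                     "q": 1.0, "k": 1.0, "v": 1.1, "out": 0.9}},
    # ...the patched MODEL from ["2", 0] would feed a KSampler here
}
req = urllib.request.Request("http://127.0.0.1:8188/prompt",
                             data=json.dumps({"prompt": workflow}).encode(),
                             headers={"Content-Type": "application/json"})
urllib.request.urlopen(req)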

Source code

class UNetCrossAttentionMultiply:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": { "model": ("MODEL",),
                              "q": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                              "k": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                              "v": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                              "out": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}),
                              }}
    RETURN_TYPES = ("MODEL",)
    FUNCTION = "patch"

    CATEGORY = "_for_testing/attention_experiments"

    def patch(self, model, q, k, v, out):
        # "attn2" selects the UNet's cross-attention blocks ("attn1" would
        # target self-attention); attention_multiply returns a patched clone.
        m = attention_multiply("attn2", model, q, k, v, out)
        return (m, )
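
The attention_multiply helper is defined alongside this node in comfy_extras/nodes_attention_multiply.py rather than shown above. A sketch of how such a helper can be written against ComfyUI's ModelPatcher API follows, assuming clone() and add_patches() behave as described in the comments; this is a reconstruction for illustration, not the verbatim source.

def attention_multiply(attn, model, q, k, v, out):
    # Work on a clone so the input model is left unpatched.
    m = model.clone()
    sd = model.model_state_dict()
    for key in sd:
        # attn is "attn1" (self-attention) or "attn2" (cross-attention);
        # match each projection's weight and bias by key suffix.
        if key.endswith("{}.to_q.weight".format(attn)) or key.endswith("{}.to_q.bias".format(attn)):
            m.add_patches({key: (None,)}, 0.0, q)   # strength_model=q rescales the existing weight
        elif key.endswith("{}.to_k.weight".format(attn)) or key.endswith("{}.to_k.bias".format(attn)):
            m.add_patches({key: (None,)}, 0.0, k)
        elif key.endswith("{}.to_v.weight".format(attn)) or key.endswith("{}.to_v.bias".format(attn)):
            m.add_patches({key: (None,)}, 0.0, v)
        elif key.endswith("{}.to_out.0.weight".format(attn)) or key.endswith("{}.to_out.0.bias".format(attn)):
            m.add_patches({key: (None,)}, 0.0, out)
    return m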