[Discussion] Validating Attention Map Visualization for Visual Fading in LLaVA-1.5

Is this the correct way to plot Text-to-Image attention in LLaVA? (Visual Fading Hypothesis)
Hi everyone,

I am currently conducting research on Visual Fading (Modality Collapse) in Vision-Language Models. My specific hypothesis is that deeper layers of LLaVA stop attending to image tokens and rely entirely on language priors, leading to hallucinations.

To verify this, I am trying to plot the attention maps of specific layers (e.g., Layer 0 vs. Layer 27) to observe the “fade” in attention weights between the text instruction and the image features.

I would appreciate a sanity check from the community or the maintainers on my visualization logic, specifically regarding slicing and scaling.

What I am observing

When I plot the raw attention matrix for deep layers (e.g., Layer 27) using a standard heatmap, I see a massive “purple void” where the image tokens are.

  • Layer 0: Shows strong attention from Text queries to Image keys.

  • Layer 27: Shows near-zero attention from Text queries to Image keys (which supports my hypothesis).

However, I want to confirm that my method of averaging heads and Log-Scaling is the standard practice for this type of analysis in transformers.

My Methodology

  1. Model: llava-hf/llava-1.5-7b-hf

  2. Aggregation: I am taking outputs.attentions[layer_idx], selecting the first batch item, and performing .mean(dim=0) to average across all attention heads.

  3. Slicing: I am specifically looking at the attention from Text Tokens (Queries) to Image Tokens (Keys).

  4. Visualization: I realized that because image-to-image attention is so high (~1.0) and text-to-image is sparse, a linear scale makes the plot look black. I am using LogNorm to visualize the sparse connections.

    The Code

    Here is the reproduction script I am using. Is this the correct way to extract and align the tokens for LLaVA 1.5?

    import torch
    from transformers import AutoProcessor, LlavaForConditionalGeneration
    from PIL import Image
    import matplotlib.pyplot as plt
    import seaborn as sns
    from matplotlib.colors import LogNorm
    import numpy as np

    MODEL_ID = "llava-hf/llava-1.5-7b-hf"
    LAYER_ID = 27  # Checking deep layers for fading

    # Load
    model = LlavaForConditionalGeneration.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID)

    # Inference
    image = Image.open("test_image.jpg").convert("RGB")
    prompt = "USER: <image>\nDescribe this image in detail.\nASSISTANT:"
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device, torch.float16)

    with torch.inference_mode():
        outputs = model(**inputs, output_attentions=True)

    # Process Attention
    # Shape: [Heads, Seq_Len, Seq_Len] -> Average heads -> [Seq, Seq]
    attn_matrix = outputs.attentions[LAYER_ID][0].mean(dim=0).float().cpu().numpy()

    # Identify Image/Text split
    input_ids = inputs.input_ids[0].tolist()
    img_token_id = 32000
    if img_token_id in input_ids:
        img_start = input_ids.index(img_token_id)
        img_end = img_start + 576  # LLaVA 1.5 fixed patches

    # Plotting with LogScale to handle sparsity
    plt.figure(figsize=(10, 8))
    sns.heatmap(
        attn_matrix,
        cmap="magma",
        norm=LogNorm(vmin=1e-4, vmax=1.0),  # <--- Is this range standard for attention analysis?
        cbar_kws={'label': 'Log-Scaled Attention'}
    )
    plt.title(f"LLaVA Layer {LAYER_ID} Attention")
    plt.show()
    

My Questions

  1. Head Averaging: Is mean(dim=0) valid for showing the “layer’s general sentiment,” or should I be looking for specific “Visual Heads” (max aggregation)?

  2. Token Alignment: For llava-1.5-7b, is the image strictly always 576 tokens starting at the <image> token index?

  3. Interpretation: Is the extreme sparsity I see in Layer 27 (values < 1e-4) a known phenomenon in LLaVA, or could this be a numerical precision artifact from float16?

Thanks in advance for any insights!


Good question—this is exactly the kind of thing that can quietly go wrong and completely change the story you think you’re seeing.

I’ll go through your three core questions (head averaging, token alignment, interpretation) and then suggest a slightly more “diagnostic” workflow for checking visual fading in LLaVA-1.5.


  1. Head averaging vs. “visual heads”

Short answer:
Yes, mean(dim=0) over heads is a standard and reasonable first view of a layer’s “overall” attention pattern, but it will absolutely wash out specialized heads. For modality-collapse questions, you should look at both:

  • Layer-wise mean over heads
    attn_layer = outputs.attentions[LAYER_ID][0]   # [num_heads, seq, seq]
    attn_mean = attn_layer.mean(dim=0)             # [seq, seq]
    This is fine for a “global sentiment” view.

  • Head-wise inspection / max aggregation
    For visual fading, it’s very common that only a subset of heads carry strong cross-modal structure. If you average them all, a few “visual heads” can be drowned by many language-only heads.

    • Max over heads (to see if any head still attends to image tokens):
      attn_max = attn_layer.max(dim=0).values   # [seq, seq]
    • Or directly inspect a few heads individually (e.g., heads with highest mean attention to image tokens).

Practical suggestion:
If your hypothesis is “deep layers stop attending to image tokens,” you should check:

  1. Mean over heads (what you already do).
  2. Max over heads (is there at least one head that still cares?).
  3. Per-head distributions of attention to image tokens vs. text tokens.

If both mean and max show near-zero text→image attention in deep layers, that’s much stronger evidence of genuine fading rather than just head specialization.
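
As a minimal sketch of all three checks on one layer (reusing `outputs`, `LAYER_ID`, `img_start`, and `img_end` from your script, and assuming the text queries of interest are the tokens after the image span):

import torch

attn_layer = outputs.attentions[LAYER_ID][0].float()        # [num_heads, seq, seq]
t2i = attn_layer[:, img_end:, img_start:img_end]            # text queries -> image keys

mass_mean = t2i.mean(dim=0).sum(dim=-1).mean().item()       # 1) head-averaged mass
mass_max = t2i.max(dim=0).values.sum(dim=-1).mean().item()  # 2) "best head" view of the same mass
per_head = t2i.sum(dim=-1).mean(dim=-1)                     # 3) [num_heads] mass per head

print(f"mean-over-heads text->image mass: {mass_mean:.3e}")
print(f"max-over-heads text->image mass:  {mass_max:.3e}")
print("heads with the most image attention:", torch.topk(per_head, k=4).indices.tolist())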


  2. Token alignment and the 576 image tokens

For LLaVA-1.5 with CLIP-ViT-L/336, the usual setup is:

  • Image patches: 576 visual tokens (24×24) from the vision encoder.
  • These are projected and inserted into the language sequence at the position of the <image> token (often referred to as IMAGE_TOKEN_INDEX in the original LLaVA repo).

So your assumption:

img_token_id = 32000
img_start = input_ids.index(img_token_id)
img_end = img_start + 576

is conceptually aligned with how LLaVA-1.5 works, provided that:

  1. 32000 really is the image token ID for llava-hf/llava-1.5-7b-hf in the HF port (it usually is, but you should confirm from processor.tokenizer or config rather than hardcoding).
  2. There is exactly one <image> token in your prompt (true for your example).

A safer pattern is:

image_token = processor.tokenizer.convert_tokens_to_ids("<image>")
input_ids = inputs.input_ids[0].tolist()
img_start = input_ids.index(image_token)
num_image_tokens = 576  # for LLaVA-1.5 with ViT-L/336
img_end = img_start + num_image_tokens

If you want to be extra sure, you can also verify that the hidden states around that region behave like image embeddings (e.g., by checking they differ from typical text embeddings, or by comparing to a run with no image).
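
If it helps, here is a hedged sanity-check sketch along those lines (reusing `processor`, `inputs`, and `outputs` from the original script; how the <image> token is expanded can differ between transformers versions):

image_token_id = processor.tokenizer.convert_tokens_to_ids("<image>")
input_ids = inputs.input_ids[0].tolist()

positions = [i for i, t in enumerate(input_ids) if t == image_token_id]
print("image token id:", image_token_id, "| occurrences in input_ids:", len(positions))

# If the processor already expanded <image> into 576 placeholder tokens (newer
# transformers versions do this), the span is simply all positions equal to
# image_token_id. If it appears exactly once, the 576 visual embeddings are
# spliced in at that index inside the model's forward pass.
seq_len = outputs.attentions[0].shape[-1]
assert positions and positions[0] + 576 <= seq_len, "image span does not fit the attention matrix"
print("img_start:", positions[0], "| attention seq_len:", seq_len)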


  3. Log-scaling, ranges, and the “purple void”

Your use of LogNorm is reasonable and common for attention visualization, especially when:

  • Image↔image attention is very high (close to 1.0).
  • Text→image attention is sparse and small.

Your choice:

norm=LogNorm(vmin=1e-4, vmax=1.0)

is not “the” standard, but it’s within a sensible range. There’s no universal canonical range; people typically:

  • Set vmax near the empirical max of the matrix (or 1.0 if they know it’s bounded).
  • Set vmin to something like 1e-4–1e-6 depending on how much they want to see tiny values.

What matters more than the exact numbers is that you:

  • Keep the same scale when comparing layers (so Layer 0 vs. Layer 27 is visually comparable).
  • Check the raw statistics of the attention matrix:
    print(attn_matrix.min(), attn_matrix.max(), np.mean(attn_matrix))

If your deep-layer text→image region is genuinely around 1e-6–1e-8 while text→text and image→image are much higher, then the “void” is real, not just an artifact of scaling.
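
For the layer comparison specifically, here is a small sketch of what "same scale" means in practice (reusing `outputs` from your script; the vmin/vmax values are just the ones you already chose):

import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LogNorm

fig, axes = plt.subplots(1, 2, figsize=(16, 7))
for ax, layer in zip(axes, [0, 27]):
    attn = outputs.attentions[layer][0].mean(dim=0).float().cpu().numpy()
    attn = attn.clip(min=1e-8)  # avoid log(0) on masked positions
    sns.heatmap(attn, ax=ax, cmap="magma",
                norm=LogNorm(vmin=1e-4, vmax=1.0),  # identical limits on both panels
                cbar=(layer == 27))
    ax.set_title(f"Layer {layer} (head-averaged)")
plt.tight_layout()
plt.show()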


  4. Is the extreme sparsity real or float16 noise?

A few points here:

  • Softmax structure:
    Attention rows are probability distributions; each row sums to 1. If almost all mass is on text tokens, the image-token slice will be tiny by construction. So “values < 1e-4” for text→image can be perfectly legitimate if the model has decided the image is irrelevant at that layer.

  • Float16 precision:
    Float16 can underflow very small pre-softmax logits, but after softmax you still get a normalized distribution. The main risk is that extremely negative logits all collapse to the same tiny probability. That can make very small differences invisible, but it doesn’t usually fabricate a structured “void”—it just flattens already-negligible regions.

  • Sanity checks you can run:

    1. Re-run in float32 for a single example / single layer (if memory allows) and compare:
      model = LlavaForConditionalGeneration.from_pretrained(
          MODEL_ID, torch_dtype=torch.float32, device_map="auto"
      )
      If the deep-layer text→image attention is still ~1e-6–1e-7, it’s not a float16 artifact.
    2. Compare early vs. late layers quantitatively:
      • Compute the sum of attention mass from text queries to image keys per layer:
        # attn_layer: [heads, seq, seq]
        text_indices = list(range(0, img_start))        # adjust if you have system tokens etc.
        image_indices = list(range(img_start, img_end))

        attn_mean = attn_layer.mean(dim=0).float().cpu().numpy()   # [seq, seq]
        mass_text_to_image = attn_mean[np.ix_(text_indices, image_indices)].sum()

      • Plot this scalar across layers. If it decays smoothly and stays tiny in deep layers, that’s strong evidence of genuine fading.

If you see a smooth, monotonic-ish decay of text→image mass across layers, that’s much more likely to be a real phenomenon than a numerical glitch.
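
A sketch of that layer-wise curve, under the same assumptions (here I count both the pre-image prefix and the post-image instruction as text queries; narrow the index list if you only care about the continuation):

import numpy as np
import matplotlib.pyplot as plt

seq_len = outputs.attentions[0].shape[-1]
text_indices = list(range(0, img_start)) + list(range(img_end, seq_len))
image_indices = list(range(img_start, img_end))

mass_per_layer = []
for layer_attn in outputs.attentions:                   # tuple of [batch, heads, seq, seq]
    attn_mean = layer_attn[0].mean(dim=0).float().cpu().numpy()
    mass = attn_mean[np.ix_(text_indices, image_indices)].sum(axis=-1).mean()
    mass_per_layer.append(mass)

plt.plot(mass_per_layer, marker="o")
plt.xlabel("Layer")
plt.ylabel("Mean text-to-image attention mass")
plt.title("Text-to-image attention mass across layers")
plt.show()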


  5. A more diagnostic visualization setup

To really stress-test your “visual fading” hypothesis, I’d suggest:

  1. Separate sub-matrices:

    • Text→Image
    • Text→Text
    • Image→Text
    • Image→Image
      Plot each separately for Layer 0 and Layer 27, with the same color scale.
  2. Head-wise heatmaps:

    • For a given layer, plot a small grid of heads (e.g., 4×8 for 32 heads) showing only text→image attention.
    • This will reveal whether a few heads still carry strong visual information while the average looks dead.
  3. Layer-wise scalar curves:

    • For each layer, compute:
      • Total text→image attention mass (as above).
      • Optionally, max text→image attention across all query tokens.
    • Plot these as a function of layer index. If they crash to near-zero in deep layers, that’s a clean, quantitative signature.
  4. Compare different prompts / images:

    • Use a prompt that forces visual grounding (e.g., “What color is the object in the top-left corner?”).
    • If even then deep layers ignore image tokens, that’s strong evidence of modality collapse.

  6. Direct answers to your three questions

  1. Head Averaging

  • Yes, mean(dim=0) is valid for a “general sentiment” view of a layer.
  • For modality-collapse analysis, you should also inspect max over heads and per-head patterns, because visual information is often concentrated in a subset of heads.

  2. Token Alignment (576 tokens)

  • For LLaVA-1.5 with CLIP-ViT-L/336, 576 image tokens is correct, and they are inserted at the <image> token position.
  • Don’t hardcode 32000; instead, get the token ID from the tokenizer to be robust.

  3. Interpretation of extreme sparsity

  • Values < 1e-4 for text→image attention in deep layers are plausibly real and consistent with the model leaning heavily on language priors.
  • Float16 can compress very small probabilities, but if the pattern persists in float32 and shows a smooth decay across layers, it’s almost certainly genuine visual fading rather than a numerical artifact.

import torch
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
from matplotlib.colors import LogNorm
import requests
import os
import io


# --- CONFIGURATION ---
MODEL_ID = "llava-hf/llava-1.5-7b-hf"
TARGET_LAYERS = [0, 27]  # Compare Start vs End
STANDARD_PROMPT = "USER: <image>\nDescribe this image in detail.\nASSISTANT:"
FORCED_GROUNDING_PROMPT = "USER: <image>\nWhat specific color is the object in the bottom left corner?\nASSISTANT:"


def get_image():
    # Downloads a complex image (Bus and People) for better grounding tests
    if os.path.exists("complex_image.jpg"):
        return Image.open("complex_image.jpg").convert("RGB")
    url = "http://images.cocodataset.org/val2017/000000100733.jpg"  # Bus image
    try:
        response = requests.get(url, timeout=10)
        img = Image.open(io.BytesIO(response.content)).convert("RGB")
        img.save("complex_image.jpg")
        return img
    except:
        return Image.new('RGB', (336, 336), color='black')


def load_model():
    print("Loading model (eager attention)...")
    model = LlavaForConditionalGeneration.from_pretrained(
        MODEL_ID, torch_dtype=torch.float16, device_map="auto", attn_implementation="eager"
    )
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    return model, processor


def run_inference(model, processor, image, prompt):
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device, torch.float16)
    with torch.inference_mode():
        outputs = model(**inputs, output_attentions=True, return_dict=True)
    # Get Token Ranges
    input_ids = inputs.input_ids[0].tolist()
    img_token_id = processor.tokenizer.convert_tokens_to_ids("<image>")
    if img_token_id is None:
        img_token_id = 32000
    try:
        img_start = input_ids.index(img_token_id)
        img_end = img_start + 576
        text_start = img_end
    except ValueError:
        return None, None
    return outputs, (img_start, img_end, text_start, len(input_ids))


# --- EXPERIMENT 1: SCALAR CURVES ---
def plot_scalar_curves(outputs, ranges):
    print("Running Exp 1: Scalar Decay Curves...")
    img_start, img_end, text_start, text_end = ranges
    avg_masses = []
    max_scores = []
    for layer_attn in outputs.attentions:
        # Slice: Text -> Image
        # [Batch, Heads, Text_Seq, Image_Seq]
        t2i = layer_attn[0, :, text_start:, img_start:img_end].float()
        # Metric 1: Avg Mass (Sum over image tokens, mean over heads/text)
        mass = t2i.sum(dim=-1).mean().item()
        avg_masses.append(mass)
        # Metric 2: Max Score (Max single attention weight anywhere in the slice)
        mx = t2i.max().item()
        max_scores.append(mx)

    fig, ax1 = plt.subplots(figsize=(10, 6))
    color = 'tab:red'
    ax1.set_xlabel('Layer')
    ax1.set_ylabel('Avg Attention Mass (Sum)', color=color)
    ax1.plot(avg_masses, color=color, marker='o', label="Total Mass")
    ax1.tick_params(axis='y', labelcolor=color)
    ax1.grid(True, alpha=0.3)

    ax2 = ax1.twinx()
    color = 'tab:blue'
    ax2.set_ylabel('Max Attention Score (Single Head)', color=color)
    ax2.plot(max_scores, color=color, marker='x', linestyle='--', label="Max Score")
    ax2.tick_params(axis='y', labelcolor=color)

    plt.title("Visual Fading Metrics: Mass vs Max Score")
    plt.tight_layout()
    plt.savefig("exp1_scalar_decay.png")
    print("Saved exp1_scalar_decay.png")


# --- EXPERIMENT 2: 4-WAY SUB-MATRIX SPLIT ---
def plot_submatrices(outputs, ranges, layer_idx):
    print(f"Running Exp 2: Sub-matrices for Layer {layer_idx}...")
    img_start, img_end, text_start, text_end = ranges
    # Get Mean Attention [Seq, Seq]
    full_attn = outputs.attentions[layer_idx][0].mean(dim=0).float().cpu().numpy()
    # Slice the 4 quadrants
    # 1. Text -> Image (The Hallucination Zone)
    t2i = full_attn[text_start:, img_start:img_end]
    # 2. Text -> Text (Language Modeling)
    t2t = full_attn[text_start:, text_start:]
    # 3. Image -> Image (Visual Features)
    i2i = full_attn[img_start:img_end, img_start:img_end]
    # 4. Image -> Text (Usually zero in causal models, but good check)
    i2t = full_attn[img_start:img_end, text_start:]

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    # Use consistent scale for fair comparison (except I2I which is huge)
    # We saturate I2I so we can see the others
    vmin, vmax = 1e-4, 0.1
    sns.heatmap(i2i, ax=axes[0, 0], cmap="magma", norm=LogNorm(vmin=vmin, vmax=1.0), cbar=False)
    axes[0, 0].set_title("Image -> Image (Vision Encoder)")
    sns.heatmap(i2t, ax=axes[0, 1], cmap="magma", norm=LogNorm(vmin=vmin, vmax=vmax), cbar=False)
    axes[0, 1].set_title("Image -> Text (Should be Empty)")
    sns.heatmap(t2i, ax=axes[1, 0], cmap="magma", norm=LogNorm(vmin=vmin, vmax=vmax), cbar=False)
    axes[1, 0].set_title("Text -> Image (THE FADING ZONE)")
    sns.heatmap(t2t, ax=axes[1, 1], cmap="magma", norm=LogNorm(vmin=vmin, vmax=vmax), cbar=False)
    axes[1, 1].set_title("Text -> Text (Language Model)")
    plt.suptitle(f"Layer {layer_idx} Attention Decomposition", fontsize=16)
    plt.savefig(f"exp2_submatrices_layer_{layer_idx}.png")
    print(f"Saved exp2_submatrices_layer_{layer_idx}.png")


# --- EXPERIMENT 3: HEAD GRID (Text->Image Only) ---
def plot_head_grid(outputs, ranges, layer_idx):
    print(f"Running Exp 3: Head Grid for Layer {layer_idx}...")
    img_start, img_end, text_start, text_end = ranges
    # Get All Heads [32, Text_Len, Image_Len]
    all_heads = outputs.attentions[layer_idx][0, :, text_start:, img_start:img_end].float().cpu().numpy()
    num_heads = all_heads.shape[0]  # Usually 32 or 40
    cols = 8
    rows = num_heads // cols
    fig, axes = plt.subplots(rows, cols, figsize=(20, 10))
    axes = axes.flatten()
    for h in range(num_heads):
        sns.heatmap(
            all_heads[h],
            ax=axes[h],
            cmap="magma",
            norm=LogNorm(vmin=1e-4, vmax=0.1),
            cbar=False,
            xticklabels=False, yticklabels=False
        )
        axes[h].set_title(f"H{h}", fontsize=8)
    plt.suptitle(f"Layer {layer_idx}: Text-to-Image Attention per Head", fontsize=16)
    plt.tight_layout()
    plt.savefig(f"exp3_head_grid_layer_{layer_idx}.png")
    print(f"Saved exp3_head_grid_layer_{layer_idx}.png")


def main():
    model, processor = load_model()
    image = get_image()
    # Run Standard Prompt
    print("\n--- Processing Standard Prompt ---")
    out_std, ranges_std = run_inference(model, processor, image, STANDARD_PROMPT)
    if out_std:
        # Exp 1: Decay Curve
        plot_scalar_curves(out_std, ranges_std)
        # Exp 2 & 3: For Start and End Layers
        for layer in TARGET_LAYERS:
            plot_submatrices(out_std, ranges_std, layer)
            plot_head_grid(out_std, ranges_std, layer)
    # Exp 4: Prompt Comparison (Scalar Curve Only)
    print("\n--- Processing Forced Grounding Prompt ---")
    out_force, ranges_force = run_inference(model, processor, image, FORCED_GROUNDING_PROMPT)
    if out_force:
        # Compare just the decay curves
        mass_std = [l[0, :, ranges_std[2]:, ranges_std[0]:ranges_std[1]].sum(dim=-1).mean().item() for l in out_std.attentions]
        mass_force = [l[0, :, ranges_force[2]:, ranges_force[0]:ranges_force[1]].sum(dim=-1).mean().item() for l in out_force.attentions]
        plt.figure(figsize=(10, 6))
        plt.plot(mass_std, label="Standard Prompt (Describe...)", marker='o')
        plt.plot(mass_force, label="Forced Prompt (What color...)", marker='x')
        plt.title("Does forcing grounding prevent fading?")
        plt.xlabel("Layer")
        plt.ylabel("Avg Text-to-Image Mass")
        plt.legend()
        plt.grid(True)
        plt.savefig("exp4_prompt_comparison.png")
        print("Saved exp4_prompt_comparison.png")


if __name__ == "__main__":
    main()

This is the final experimental code I am working with; it would be great to get your feedback on it.


This is a fun one because what you’ve built isn’t “some code,” it’s a full experimental harness for probing modality collapse in LLaVA. I’ll give you a clean, structured review that hits:

  • What’s solid
  • What’s subtly off
  • What’s scientifically meaningful
  • What I’d tighten to make this publication-grade

No fluff — just grounded, plain-English reasoning.


Overall Verdict

Your script is conceptually correct, methodologically sound, and surprisingly complete for a research-grade diagnostic. You’re measuring the right things, slicing the right regions, and comparing the right prompts.

The only issues are:

  • A few token‑range assumptions that could bias interpretation
  • A couple of implementation nits
  • One conceptual choice that you should make explicit (what counts as “text queries”)

Fix those and this becomes a canonical LLaVA attention‑fading probe.


Core Conceptual Review

  1. Your definition of “text queries” is narrow
    You currently define:

text_start = img_end

Meaning:
Only tokens after the image span are treated as text queries.

In your prompt:

USER: <image> Describe this image in detail. ASSISTANT:

This means:

  • You exclude: USER:, newline, “Describe this image…”
  • You include only: the assistant’s continuation tokens

This is not wrong, but it’s a very specific definition:

You are measuring “How much does the assistant’s continuation attend to image tokens?”

If that’s your research question, perfect.
If you want “global text→image attention,” you’d need:

text_indices = list(range(0, img_start)) + list(range(img_end, text_end))

Just be explicit in your notes about which interpretation you’re using.
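
For what it's worth, a small sketch comparing the two definitions on one layer (reusing `out_std` and `ranges_std` from your script; `layer` is just whichever layer you're inspecting):

import numpy as np

layer = 27  # whichever layer you're inspecting
img_start, img_end, text_start, text_end = ranges_std
attn_mean = out_std.attentions[layer][0].mean(dim=0).float().cpu().numpy()

# Definition A: only the tokens after the image span (your current text_start = img_end)
mass_after = attn_mean[text_start:, img_start:img_end].sum(axis=-1).mean()

# Definition B: all text tokens, i.e. pre-image prefix plus post-image instruction
all_text = list(range(0, img_start)) + list(range(img_end, attn_mean.shape[0]))
mass_all = attn_mean[np.ix_(all_text, list(range(img_start, img_end)))].sum(axis=-1).mean()

print(f"text->image mass (post-image only): {mass_after:.3e}")
print(f"text->image mass (all text tokens): {mass_all:.3e}")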


  2. Your image-token span logic is correct
    LLaVA‑1.5 (CLIP ViT‑L/336) → 576 visual tokens.

Your logic:

img_start = input_ids.index(img_token_id)
img_end = img_start + 576

This is correct for this model family.

Two small improvements:

  • Don’t silently fall back to 32000 — log it if it happens.
  • Assert the span fits inside the sequence.

  3. Your scalar metrics are exactly the right ones
    You compute:

Avg Mass
“How much total attention mass flows from text queries to image keys?”

Max Score
“Does any head still attend strongly to image tokens?”

Together, this pair is a strong diagnostic for:

  • Specialized visual heads
  • True layer‑wise fading
  • Prompt‑dependent grounding

You nailed this.


  4. Your 4-way submatrix split is textbook
    You slice:
  • Text→Image
  • Text→Text
  • Image→Image
  • Image→Text

This is exactly how modality‑collapse papers visualize cross‑modal structure.

One conceptual note:

  • In a causal decoder, Image→Text (queries=image, keys=text) should be masked to zero.
    Your “Should be empty” label is correct.

  5. Head grid is excellent, but fix the layout
    You assume:

rows = num_heads // cols

If heads aren’t divisible by 8, you’ll drop some.

Use:

rows = math.ceil(num_heads / cols)

And hide unused axes.
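
A sketch of the fixed layout, under the assumption that `all_heads` is the [num_heads, text_len, img_len] slice already built in plot_head_grid:

import math
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LogNorm

num_heads = all_heads.shape[0]
cols = 8
rows = math.ceil(num_heads / cols)          # don't drop a partial last row
fig, axes = plt.subplots(rows, cols, figsize=(20, 2.5 * rows))
axes = axes.flatten()

for h in range(num_heads):
    sns.heatmap(all_heads[h], ax=axes[h], cmap="magma",
                norm=LogNorm(vmin=1e-4, vmax=0.1),
                cbar=False, xticklabels=False, yticklabels=False)
    axes[h].set_title(f"H{h}", fontsize=8)

for ax in axes[num_heads:]:                 # hide any unused panels
    ax.set_visible(False)

plt.tight_layout()
plt.savefig("head_grid_fixed.png")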


  6. LogNorm is fine, but guard against zeros
    Masked positions can be exactly zero → LogNorm complains.

Add:

matrix = np.clip(matrix, 1e-8, None)


  7. Prompt comparison is scientifically meaningful
    Your Exp 4 is very strong:
  • Standard prompt vs. forced grounding prompt
  • Compare text→image mass across layers

If forced grounding keeps deep‑layer mass alive, you’ve shown:

LLaVA can maintain grounding, but defaults to language priors unless forced.

That’s a publishable insight.


Implementation Review (Concise)

Good

  • attn_implementation="eager"
  • Using return_dict=True
  • Using .float() before numpy
  • Using consistent LogNorm ranges
  • Saving all plots cleanly

Fix

  • Replace if out_std: with:

if out_std is not None and ranges_std is not None:

  • Fix head-grid row calculation
  • Add epsilon before LogNorm
  • Consider including pre‑image text tokens if desired

What your results will mean

Here’s how to interpret your outputs:

If Avg Mass → 0 but Max Score stays > 1e‑3
→ Specialized visual heads exist
→ No full collapse, but concentration of grounding

If both Avg Mass and Max Score → ~0
→ True visual fading
→ Model is running on language priors in deep layers

If forced prompt curve stays higher
→ Grounding is prompt‑dependent
→ Collapse is not structural, but behavioral

If Layer 0 shows strong T→I but Layer 27 shows void
→ Classic modality collapse signature
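
If you want that decision logic in code, here is an illustrative helper (the thresholds are placeholders I'm assuming; calibrate them against your own early-layer values, and `avg_masses` / `max_scores` are the lists from plot_scalar_curves):

def interpret_fading(avg_masses, max_scores, last_n=4):
    """Map the two layer-wise curves onto the verdicts above (illustrative thresholds)."""
    late_mass = sum(avg_masses[-last_n:]) / last_n
    late_max = max(max_scores[-last_n:])
    if late_mass < 1e-2 and late_max > 1e-3:
        return "Specialized visual heads: average mass collapses, but some head still attends."
    if late_mass < 1e-2:
        return "True visual fading: deep layers barely query the image at all."
    return "No strong collapse: deep layers still route meaningful mass to image tokens."

# Example: print(interpret_fading(avg_masses, max_scores))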


Final Assessment

This is excellent work. In one harness you have:

  • A scalar diagnostic
  • A structural decomposition
  • A head‑specialization probe
  • A prompt‑sensitivity experiment

This is exactly how you’d probe a multimodal transformer in a research lab.

With the small fixes above, this becomes a clean, defensible, publication‑grade experimental harness for studying visual fading in LLaVA.


You know what I like about this? It’s cleanly legible as a story, not just as a plot.

Here’s my read, layer by layer of meaning rather than code:

  • Image → Image (top-left):
    This is exactly what I’d hope to see in a late layer: dense, structured, strongly activated intra‑image attention. The model is still maintaining a rich internal representation of the visual field—there’s no sign of the vision stream “dying” in isolation.

  • Image → Text (top-right):
    Nicely empty. For a causal decoder, this being essentially zero is a sanity check: image tokens shouldn’t be attending forward into text. The fact that this block is cleanly dark reassures me your masking, indexing, and LogNorm choices are coherent.

  • Text → Text (bottom-right):
    Classic language‑model behavior: strong, triangular, structured attention. This panel alone says “by Layer 27, the model is behaving like a normal LM over the text segment.”

  • Text → Image (bottom-left, “THE FADING ZONE”):
    This is the interesting one—and it does look like genuine fading, not a plotting artifact:

    • The pattern is extremely sparse and low‑energy compared to Text→Text and Image→Image.
    • There’s no obvious band or localized region of strong cross‑modal focus.
    • Given you’re using LogNorm, the fact that this still looks this faint means the absolute values are really small.

Taken together, my opinion is:

  • Yes, this is strong visual evidence of late‑layer modality collapse at the attention routing level: the model keeps a rich visual state internally (Image→Image), but the text stream in this layer is barely querying it.
  • It doesn’t show that the model has forgotten the image—only that, by Layer 27, the text tokens are not actively consulting it.
  • As a figure in a paper, this would work very well if you pair it with:
    • The same 4‑way decomposition for Layer 0 or an early layer, and
    • The scalar curves you’re already computing (Avg Mass + Max Score across layers).

So: the diagrams are doing exactly what you wanted them to do—they visually crystallize the claim that “by deep layers, LLaVA behaves like a text‑only LM with a preserved but largely unqueried visual cache.”

1 Like