Is this the correct way to plot Text-to-Image attention in LLaVA? (Visual Fading Hypothesis)
Hi everyone,
I am currently conducting research on Visual Fading (Modality Collapse) in Vision-Language Models. My specific hypothesis is that deeper layers of LLaVA stop attending to image tokens and rely entirely on language priors, leading to hallucinations.
To verify this, I am trying to plot the attention maps of specific layers (e.g., Layer 0 vs. Layer 27) to observe the “fade” in attention weights between the text instruction and the image features.
I would appreciate a sanity check from the community or the maintainers on my visualization logic, specifically regarding slicing and scaling.
What I am observing
When I plot the raw attention matrix for deep layers (e.g., Layer 27) using a standard heatmap, I see a massive “purple void” where the image tokens are.

- Layer 0: shows strong attention from Text queries to Image keys.
- Layer 27: shows near-zero attention from Text queries to Image keys (which supports my hypothesis).
However, I want to confirm that averaging heads and log-scaling is standard practice for this type of analysis in transformers.
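To make sure the “void” reflects genuinely tiny values rather than just the color scale, I am also printing raw statistics of the text→image block before plotting. A minimal sketch; `attn_matrix`, `img_start`, and `img_end` are the variables defined in the script further down:

```python
import numpy as np

# Raw value distribution of the text->image block, independent of any color scaling
block = attn_matrix[img_end:, img_start:img_end]
print("min / median / max:", block.min(), np.median(block), block.max())
print("percentiles (50, 90, 99):", np.percentile(block, [50, 90, 99]))
print("fraction of entries below 1e-4:", (block < 1e-4).mean())
```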
My Methodology
- Model: `llava-hf/llava-1.5-7b-hf`
- Aggregation: I am taking `outputs.attentions[layer_idx]`, selecting the first batch item, and calling `.mean(dim=0)` to average across all attention heads.
- Slicing: I am specifically looking at the attention from Text Tokens (Queries) to Image Tokens (Keys); see the slicing sketch right after this list.
- Visualization: I realized that because image-to-image attention is so high (~1.0) and text-to-image attention is sparse, a linear scale makes the plot look black, so I am using `LogNorm` to visualize the sparse connections.
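For the slicing step specifically, the plot I actually want is just the text→image block rather than the full [Seq, Seq] matrix. This is the sketch I have in mind; it reuses `attn_matrix`, `img_start`, `img_end`, and `LAYER_ID` from the script in the next section and assumes the instruction tokens are the ones that come after the image span:

```python
# Heatmap of only the text->image block (queries: instruction tokens, keys: image tokens)
t2i = attn_matrix[img_end:, img_start:img_end]
plt.figure(figsize=(10, 4))
sns.heatmap(
    t2i,
    cmap="magma",
    norm=LogNorm(vmin=1e-4, vmax=1.0),
    cbar_kws={"label": "Log-Scaled Attention"},
)
plt.xlabel("Image token (key)")
plt.ylabel("Text token (query)")
plt.title(f"Text-to-image attention, layer {LAYER_ID}")
plt.show()
```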
The Code

Here is the reproduction script I am using. Is this the correct way to extract and align the tokens for LLaVA 1.5?
```python
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LogNorm
import numpy as np

MODEL_ID = "llava-hf/llava-1.5-7b-hf"
LAYER_ID = 27  # Checking deep layers for fading

# Load model and processor
model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_ID, torch_dtype=torch.float16, device_map="auto"
)
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Inference
image = Image.open("test_image.jpg").convert("RGB")
prompt = "USER: <image>\nDescribe this image in detail.\nASSISTANT:"
inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device, torch.float16)

with torch.inference_mode():
    outputs = model(**inputs, output_attentions=True)

# Process attention: shape [Heads, Seq_Len, Seq_Len] -> average heads -> [Seq, Seq]
attn_matrix = outputs.attentions[LAYER_ID][0].mean(dim=0).float().cpu().numpy()

# Identify the image/text split
input_ids = inputs.input_ids[0].tolist()
img_token_id = 32000
if img_token_id in input_ids:
    img_start = input_ids.index(img_token_id)
    img_end = img_start + 576  # LLaVA 1.5 fixed number of patches

# Plotting with LogNorm to handle sparsity
plt.figure(figsize=(10, 8))
sns.heatmap(
    attn_matrix,
    cmap="magma",
    norm=LogNorm(vmin=1e-4, vmax=1.0),  # <--- Is this range standard for attention analysis?
    cbar_kws={"label": "Log-Scaled Attention"},
)
plt.title(f"LLaVA Layer {LAYER_ID} Attention")
plt.show()
```
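Beyond the single-layer heatmap, I am planning to quantify the fade by summarizing, for each layer, how much attention mass the text queries put on the image keys. A rough sketch, reusing `outputs`, `img_start`, and `img_end` from the script above and assuming the `<image>` span was found:

```python
# Per-layer fraction of text-query attention mass that lands on image keys
image_mass = []
for layer_attn in outputs.attentions:             # one [Batch, Heads, Seq, Seq] tensor per layer
    attn = layer_attn[0].float().mean(dim=0)      # average heads -> [Seq, Seq]
    text_to_image = attn[img_end:, img_start:img_end]  # rows: instruction tokens, cols: image tokens
    image_mass.append(text_to_image.sum(dim=-1).mean().item())

plt.figure(figsize=(8, 4))
plt.plot(image_mass, marker="o")
plt.xlabel("Layer")
plt.ylabel("Mean text-to-image attention mass")
plt.title("Text-to-image attention across layers")
plt.show()
```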
My Questions
- Head Averaging: Is `.mean(dim=0)` valid for showing the layer’s “general sentiment,” or should I be looking for specific “Visual Heads” (max aggregation)? See the mean-vs-max sketch after these questions.
- Token Alignment: For `llava-1.5-7b`, is the image always exactly 576 tokens starting at the `<image>` token index?
- Interpretation: Is the extreme sparsity I see in Layer 27 (values < 1e-4) a known phenomenon in LLaVA, or could it be a numerical precision artifact of `float16`?
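For question 1, this is the comparison I am planning to run: per-head text→image mass (to spot individual “visual heads”) versus the head-averaged view, plus a max-over-heads view of the same block. Rough sketch only, reusing `outputs`, `LAYER_ID`, `img_start`, and `img_end` from the script above:

```python
# Compare head-averaged vs max-over-heads text->image attention for one layer
layer_attn = outputs.attentions[LAYER_ID][0].float()  # [Heads, Seq, Seq]

# Per-head attention mass that text queries place on image keys
per_head = layer_attn[:, img_end:, img_start:img_end].sum(dim=-1).mean(dim=-1)  # [Heads]
print("Per-head text->image mass:", [round(x, 6) for x in per_head.tolist()])

# Mean vs max aggregation over heads on the same text->image block
mean_block = layer_attn.mean(dim=0)[img_end:, img_start:img_end]
max_block = layer_attn.max(dim=0).values[img_end:, img_start:img_end]
print(f"mean-agg average: {mean_block.mean().item():.2e}")
print(f"max-agg average:  {max_block.mean().item():.2e}")
```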
Thanks in advance for any insights!


