1. Context & Goal
I am implementing "Visual KV Cache Steering" for LLaVA (based on `llava-hf/llava-1.5-7b-hf`). The goal is to:
- Run a prefill step to populate the KV cache for the prompt (`USER: <image>\n{text}...`).
- Intervene in the cache by adding steering vectors specifically to the visual token positions (the 576 tokens corresponding to the image); see the sketch after this list for what I mean by that edit.
- Generate the rest of the response using this modified cache.
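For concreteness, this is roughly the per-layer edit I have in mind. The helper name, the `visual_start`/`num_visual` indices, and the steering vectors are illustrative assumptions, not part of the actual repro below; the `(batch, num_heads, seq_len, head_dim)` layout is the usual LLaMA-style cache layout:

```python
import torch

def add_visual_steering(key, value, visual_start, num_visual, steer_k, steer_v):
    """Hypothetical helper: add steering vectors to the visual token span of one
    layer's cached keys/values, shaped (batch, num_heads, seq_len, head_dim).
    steer_k / steer_v are assumed to be (head_dim,) and broadcast over positions."""
    vis = slice(visual_start, visual_start + num_visual)  # e.g. the 576 image positions
    key = key.clone()
    value = value.clone()
    key[:, :, vis, :] += steer_k.to(key.dtype)
    value[:, :, vis, :] += steer_v.to(value.dtype)
    return key, value
```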
2. Implementation Strategy
My current logic follows this pattern:
- Prefill: call `model(**inputs, use_cache=True)` to get `past_key_values`.
- Modify: convert `past_key_values` (which is a `DynamicCache` in newer `transformers`) to the legacy list/tuple format (or modify it in place) in order to inject vectors at the visual token indices; a sketch of the in-place variant follows this list.
- Generate: call `model.generate()`, passing the modified `past_key_values` and the last prompt token as `input_ids`.
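The in-place variant of the Modify step I have experimented with looks like the sketch below. It assumes `DynamicCache` exposes per-layer `key_cache` / `value_cache` lists of `(batch, num_heads, seq_len, head_dim)` tensors; I am not sure that attribute layout is stable across transformers versions, which is part of my question:

```python
import torch

def steer_cache_in_place(cache, visual_slice, steer_k, steer_v):
    # Assumes DynamicCache.key_cache / value_cache are lists of per-layer tensors.
    # If the prefill ran under torch.inference_mode(), the in-place write may also
    # need to happen inside inference_mode, otherwise PyTorch rejects the update.
    with torch.inference_mode():
        for layer_idx in range(len(cache.key_cache)):
            k = cache.key_cache[layer_idx]
            v = cache.value_cache[layer_idx]
            k[:, :, visual_slice, :] += steer_k.to(k.dtype)
            v[:, :, visual_slice, :] += steer_v.to(v.dtype)
    return cache
```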
3. The Issue
When passing the modified cache back to `model.generate`, I encounter an `IndexError`. It appears that `prepare_inputs_for_generation` inside `LlavaForConditionalGeneration` is misinterpreting the cache length or structure, likely due to conflicts between the new `DynamicCache` format and the legacy tuple format I am trying to use.
```
Traceback (most recent call last):
  ...
  File "transformers/generation/utils.py", line 2781, in _sample
    model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
  File "transformers/models/llava/modeling_llava.py", line 466, in prepare_inputs_for_generation
    model_inputs = super().prepare_inputs_for_generation(
  File "transformers/generation/utils.py", line 574, in prepare_inputs_for_generation
    inputs_embeds, input_ids = self._cache_dependant_input_preparation(
  File "transformers/generation/utils.py", line 476, in _cache_dependant_input_preparation
    or (cache_position[-1] >= input_ids.shape[1])  # Exception 3
IndexError: index -1 is out of bounds for dimension 0 with size 0
```
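My reading of the failing check is that `cache_position` reaches it with zero elements, since indexing an empty tensor with `[-1]` produces exactly this message; a two-line illustration (nothing LLaVA-specific):

```python
import torch

cache_position = torch.arange(0, 0)  # empty tensor, apparently what generate ends up with
cache_position[-1]                   # IndexError: index -1 is out of bounds for dimension 0 with size 0
```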
4. Reproduction Code
Here is the simplified reproduction script. The failure occurs at the final `model.generate` call.
```python
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration, DynamicCache
from PIL import Image

# Setup
model_id = "llava-hf/llava-1.5-7b-hf"
model = LlavaForConditionalGeneration.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="cuda"
)
processor = AutoProcessor.from_pretrained(model_id)

# Dummy inputs
image = Image.new("RGB", (336, 336), color="red")
prompt_text = "Describe this image."
prompt = f"USER: <image>\n{prompt_text}\nASSISTANT:"
inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda", torch.float16)

# 1. Prefill
with torch.inference_mode():
    out = model(**inputs, use_cache=True, return_dict=True)
cache = DynamicCache.from_legacy_cache(out.past_key_values)

# 2. Modify cache (simulated steering)
# Converting to legacy format to iterate and modify
legacy = list(cache.to_legacy_cache())
for i, (k, v) in enumerate(legacy):
    # Example modification: simply cloning for reproduction
    # In real code, I add vectors to specific indices here
    legacy[i] = (k.clone(), v.clone())

# Re-wrap as DynamicCache (Attempt 1) or tuple (Attempt 2)
# Both approaches lead to issues in LLaVA's generate step
steered_cache = DynamicCache.from_legacy_cache(tuple(legacy))

# 3. Generate
# We provide the last token of the prompt as the seed
seed_ids = inputs["input_ids"][:, -1:]

# Calculate cache position
past_len = steered_cache.get_seq_length()
cache_pos = torch.arange(past_len, past_len + seed_ids.shape[1], device=seed_ids.device)

# Construct attention mask
attn_mask = torch.cat(
    [inputs["attention_mask"], inputs["attention_mask"].new_ones((1, 1))],
    dim=-1
)

# FAILURE HAPPENS HERE
generated_ids = model.generate(
    input_ids=seed_ids,
    past_key_values=steered_cache,  # Passing the modified cache
    cache_position=cache_pos,
    attention_mask=attn_mask,
    max_new_tokens=100
)
```
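For completeness, the shapes feeding into the failing call can be checked like this (diagnostic only, using the variables already defined above). My understanding is that the attention mask should cover `cache length + new tokens` for generation to be consistent; if these numbers do not line up (e.g. because of image-token expansion), that may be related to the failure, but I am not certain:

```python
# Diagnostic: lengths feeding into model.generate
print("prompt input_ids length:", inputs["input_ids"].shape[1])
print("cache length:           ", steered_cache.get_seq_length())
print("seed_ids shape:         ", tuple(seed_ids.shape))   # expected (1, 1)
print("cache_pos:              ", cache_pos)                # expected tensor([past_len])
print("attention_mask length:  ", attn_mask.shape[1])       # should equal cache length + 1?
```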
5. Questions
- How should we correctly modify the KV cache in place between the prefill and generate steps for LLaVA models in the latest Transformers version?
- Does `LlavaForConditionalGeneration` require `pixel_values` to be passed to `generate` even if the image tokens are already present in the `past_key_values`? (A sketch of what I mean is below this list.)
- How do we resolve the `IndexError` regarding `cache_position`? It seems the model thinks the cache is empty or misaligned with the `input_ids`.
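To make the second question concrete, this is the variant I would try if `pixel_values` are still required alongside a pre-filled cache; whether this call is even supposed to work is exactly what I am asking (it reuses the variables from the repro above):

```python
# Hypothetical variant for question 2: also pass pixel_values despite the
# visual tokens already living in the steered cache.
generated_ids = model.generate(
    input_ids=seed_ids,
    pixel_values=inputs["pixel_values"],
    past_key_values=steered_cache,
    cache_position=cache_pos,
    attention_mask=attn_mask,
    max_new_tokens=100,
)
```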