Assistant model is not passed onto the custom_generate method

Hi all. I’ve been working on an implementation of Speculative Cascades for a class project, and I’ve noticed that if I pass the assistant model to .generate(), my custom_generate function doesn’t receive it. Could someone explain why this is the case?


Seems to be by design?


This is expected in Transformers. You are passing two generation-control features that live at different layers, and only one of them exposes the assistant model to your code.

The short version you can keep in your head:

  • assistant_model=... is a mode switch for Transformers’ built-in assisted / speculative decoding path. It is handled by generate() itself. It is not a “regular kwarg” that gets forwarded into arbitrary hooks. (Hugging Face)
  • custom_generate=<callable> is a custom decoding-loop hook with a fixed signature. That signature does not include assistant_model, so your function will never “receive” it unless you capture it yourself. (Hugging Face)

Below is the detailed “why,” with background and the practical patterns you want for Speculative Cascades.


Background: what assistant_model actually means in Transformers

Speculative decoding in one paragraph

Speculative (assisted) decoding uses two models:

  1. A small “draft” model proposes a chunk of tokens cheaply.
  2. The big “target” model verifies those proposed tokens in one forward pass, accepting as many as it can.

This can reduce the number of expensive target-model forward passes. Transformers explicitly frames speculative decoding as “adds a second smaller model” and “the main model verifies candidate tokens in a single forward pass.” (Hugging Face)
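To make the draft-then-verify idea concrete, here is a toy, greedy-only sketch in plain Python. The two “models” are hypothetical functions that map a token prefix to the next token; nothing here is Transformers code. The target is queried position by position to emulate what a real model does in a single batched forward pass, and target_passes counts those emulated passes.

```python
from typing import Callable, List, Tuple

NextToken = Callable[[List[int]], int]  # hypothetical "model": prefix -> next token


def speculative_greedy(target_next: NextToken, draft_next: NextToken,
                       prompt: List[int], k: int,
                       max_new_tokens: int) -> Tuple[List[int], int]:
    """Toy greedy speculative decoding; returns (tokens, target verification passes)."""
    tokens = list(prompt)
    target_passes = 0
    while len(tokens) - len(prompt) < max_new_tokens:
        # 1) The draft model proposes k tokens cheaply, one at a time.
        draft: List[int] = []
        for _ in range(k):
            draft.append(draft_next(tokens + draft))
        # 2) The target verifies the whole chunk; a real model does this in ONE forward pass.
        target_passes += 1
        accepted: List[int] = []
        for tok in draft:
            expected = target_next(tokens + accepted)
            if tok != expected:
                accepted.append(expected)  # first mismatch: keep the target's token and stop
                break
            accepted.append(tok)
        else:
            # Every draft token matched, so the same pass yields one extra "bonus" token.
            accepted.append(target_next(tokens + accepted))
        tokens.extend(accepted)
    return tokens[:len(prompt) + max_new_tokens], target_passes
```

With greedy acceptance the output is identical to plain greedy decoding with the target alone; the saving is in target passes. For example, with a target that counts up mod 10 and a draft that guesses wrong after a 5, eight new tokens from the prompt [0] take 3 verification passes with k=3 instead of 8.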

How Transformers exposes this

Transformers exposes speculative decoding primarily through the assistant_model parameter to generate(). The docs say speculative decoding is enabled with assistant_model. (Hugging Face)

Also, Transformers documents constraints that matter for real projects:

  • Only greedy and multinomial sampling are supported for speculative decoding. (Hugging Face)
  • Batched inputs are not supported (at least in the documented behavior). (Hugging Face)

So assistant_model is not “just an extra argument.” It is “please run a different decoding algorithm.”


Background: what custom_generate means (two different mechanisms)

Transformers uses the same name custom_generate, but there are two distinct ways to use it.

1) custom_generate=<callable> means “reuse prep, replace only the decoding loop”

This is the “custom loop” feature.

Transformers will still do all the normal generate() preparation (batch expansion, attention masks, logits processors, stopping criteria, etc.), then it calls your loop with this signature: (Hugging Face)

def custom_loop(model, input_ids, attention_mask, logits_processor, stopping_criteria, generation_config, **model_kwargs):
    ...

Key point: no assistant_model parameter exists in this callable interface. (Hugging Face)

So if your “custom_generate function” is this callable loop, Transformers has nowhere to pass the assistant model, and it will not show up in **model_kwargs either.

2) custom_generate="repo-or-path" means “override generate() entirely”

This is the “custom generation repo” feature: you ship a custom_generate/generate.py with a generate(model, ...) function.

In that mode, the docs explicitly state:

“All received arguments and model are forwarded to your custom generate method, with the exception of … trust_remote_code and custom_generate.” (Hugging Face)

So this mode can receive assistant_model normally, because it forwards essentially everything you pass to generate().


Why your custom loop does not receive assistant_model

Reason 1: assistant_model is “generate-level control,” not “forward-pass input”

The callable custom loop receives:

  • prepared tensors (input_ids, attention_mask)
  • generation machinery (logits_processor, stopping_criteria, generation_config)
  • **model_kwargs meant for model(..., **model_kwargs) inside your loop (Hugging Face)

But assistant_model is not an input to the main model’s forward pass. It is a second model object used by the generation algorithm. So it is not part of model_kwargs, and the callable interface does not include it. (Hugging Face)
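The “not part of model_kwargs” point is ordinary Python argument binding. The sketch below uses a hypothetical toy_generate (not Transformers’ real generate(), whose internals do much more) that declares assistant_model as a named parameter. Once Python binds that keyword to the parameter, it can no longer appear in **model_kwargs, so a loop that is only handed **model_kwargs never sees it.

```python
from typing import Any, Callable, List, Optional


def toy_generate(input_ids: List[int],
                 custom_generate: Optional[Callable[..., Any]] = None,
                 assistant_model: Any = None,
                 **model_kwargs: Any) -> Any:
    """Hypothetical stand-in for a generate() entry point (illustration only)."""
    if custom_generate is not None:
        # assistant_model was bound to the named parameter above, so it is
        # not inside model_kwargs; the loop only gets what we hand it here.
        return custom_generate(input_ids, **model_kwargs)
    return input_ids


def my_loop(input_ids: List[int], **model_kwargs: Any) -> List[str]:
    # Report which extra keyword arguments reached the loop.
    return sorted(model_kwargs)


print(toy_generate([1, 2], custom_generate=my_loop, assistant_model="draft", use_cache=True))
# ['use_cache']
```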

Reason 2: conceptually, assistant_model and a custom loop conflict

When you pass assistant_model, you are telling Transformers: “run the built-in speculative decoding loop.” (Hugging Face)
When you pass a callable custom_generate, you are telling Transformers: “I am providing the decoding loop.”

Those are mutually exclusive at the “who owns the loop” level. The library does not attempt to merge them. It gives you the standard callable-loop contract. That contract has no assistant model. (Hugging Face)


What you should do for Speculative Cascades

Speculative Cascades usually means you want more control than “one assistant model proposes”:

  • multiple assistants (different sizes)
  • stage escalation logic
  • different “draft length” schedules per stage
  • careful cache management and rollback rules

That aligns better with “I own the loop” than “use Transformers’ built-in assisted decoding loop.”

So the practical answer is: do not expect assistant_model to arrive via the callable custom loop. Pass assistants into your loop yourself.

Pattern A: capture the assistant model(s) via closure or functools.partial

This is the cleanest approach with custom_generate=<callable>.

from functools import partial

def cascade_loop(model, input_ids, attention_mask, logits_processor, stopping_criteria, generation_config, *, assistants, **model_kwargs):
    # assistants could be [small_draft, medium_draft, ...]
    # Your cascade logic here: propose with assistants[k], verify with model, accept/reject/escalate.
    ...
    return input_ids

loop = partial(cascade_loop, assistants=[assistant_small, assistant_medium])

out = target_model.generate(
    **inputs,
    custom_generate=loop,
    max_new_tokens=128,
)

You stop using assistant_model= entirely in this run. You treat the assistant models as your state.
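One way to check that the partial-based injection works, without loading any models, is to drive the loop from a stub that calls it with the documented positional arguments. toy_driver below is a hypothetical stand-in for what generate() does before calling your callable; the point is only that assistants, bound by functools.partial, arrives alongside the fixed arguments.

```python
from functools import partial
from typing import Any, Callable, Dict, List


def toy_driver(model: Any, input_ids: List[int], custom_generate: Callable[..., Any],
               **model_kwargs: Any) -> Any:
    """Hypothetical stand-in for generate(): prepare inputs, then call the loop."""
    attention_mask = [1] * len(input_ids)
    logits_processor: List[Any] = []
    stopping_criteria: List[Any] = []
    generation_config: Dict[str, Any] = {}
    return custom_generate(model, input_ids, attention_mask, logits_processor,
                           stopping_criteria, generation_config, **model_kwargs)


def cascade_loop(model, input_ids, attention_mask, logits_processor, stopping_criteria,
                 generation_config, *, assistants, **model_kwargs):
    # Real cascade logic would go here; this stub just reports what it received.
    return {"assistants": assistants, "model_kwargs": model_kwargs}


loop = partial(cascade_loop, assistants=["small_draft", "medium_draft"])
print(toy_driver("target_model", [101, 102], custom_generate=loop, use_cache=True))
```

Because the assistants are bound when the partial is created, they do not depend on which arguments the caller chooses to forward.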

Pattern B: use repo-based custom_generate so args are forwarded

If you really want to call generate(..., assistant_model=...) and have it arrive inside your custom generation function, then use the repo/path override.

Transformers documents that all args are forwarded to your custom generate (except the trigger args). (Hugging Face)

Your custom_generate/generate.py can then accept assistant_model (or your own assistant_models) explicitly.
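A sketch of such a custom_generate/generate.py, assuming you want to keep calling generate(..., assistant_model=...) but run your own loop: the repo-level generate() receives assistant_model in kwargs, pops it, binds it into the loop with functools.partial, and passes the loop to GenerationMixin.generate as a callable custom_generate. _cascade_loop and bind_assistant are hypothetical names. Popping the argument keeps it out of the inner generate() call so that only your loop uses it. The Transformers import sits inside generate() only so the helper can be exercised without Transformers installed; in a real repo a top-level import is fine.

```python
from functools import partial
from typing import Any, Callable, Dict, Tuple


def _cascade_loop(model, input_ids, attention_mask, logits_processor, stopping_criteria,
                  generation_config, *, assistant_model, **model_kwargs):
    # Placeholder: draft with assistant_model, verify with model, accept/reject/escalate.
    raise NotImplementedError("cascade logic goes here")


def bind_assistant(loop: Callable[..., Any],
                   kwargs: Dict[str, Any]) -> Tuple[Callable[..., Any], Dict[str, Any]]:
    """Pop assistant_model out of the forwarded kwargs and bind it into the loop."""
    kwargs = dict(kwargs)  # copy so the caller's dict is left untouched
    assistant_model = kwargs.pop("assistant_model", None)
    return partial(loop, assistant_model=assistant_model), kwargs


def generate(model, *args, **kwargs):
    # Entry point that repo-based custom_generate calls with the forwarded arguments.
    from transformers import GenerationMixin

    loop, kwargs = bind_assistant(_cascade_loop, kwargs)
    return GenerationMixin.generate(model, *args, custom_generate=loop, **kwargs)
```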


Extra context: why assisted generation often behaves “different” in edge cases

Your observation matches a larger theme: assisted/speculative generation has distinct plumbing, and argument/config propagation can differ.

Two concrete examples from Hugging Face’s own issue tracker:

  • A bug report where assisted decoding did not forward model_kwargs the same way regular generation did. (GitHub)
  • A recent bug report that num_assistant_tokens and related assistant parameters were not properly passed down to the assistant model’s generation config in some versions. (GitHub)

You do not need these issues to explain your situation, but they are useful “project background”: assisted generation is a special path with its own implementation details and occasional sharp edges.


Where to read next (high-signal)

All are worth bookmarking while you implement cascades:

  • Transformers “Generation strategies” page

    • Speculative decoding overview, constraints, assistant_model usage, plus the exact callable custom_generate signature. (Hugging Face)
  • Transformers “Custom generation methods” section

    • Clear explanation of repo-based custom_generate and the “all arguments forwarded” rule. (Hugging Face)
  • Transformers generate() API docs

    • Shows assistant_model and custom_generate are top-level parameters on GenerationMixin.generate. (Hugging Face)

Summary

  • assistant_model triggers Transformers’ built-in speculative decoding mode. It is handled by generate() and not part of a generic hook payload. (Hugging Face)
  • If your custom_generate is a callable loop, it has a fixed signature that does not include assistant_model, so your function will not receive it. (Hugging Face)
  • For Speculative Cascades, pass assistants into your loop via closure/partial, or use repo-based custom_generate where arguments are forwarded. (Hugging Face)

So I managed to put it in an HF repo and do it that way, but I ran into this error. It’s driving me insane, because my HF repo shows that the folder definitely exists. I have no idea what is going on here.


Hmm…?


What should happen with assistant_model and custom_generate

In Transformers, assistant_model is the knob that turns on assisted / speculative-style decoding in the built-in generation code path. The docs show it used directly in model.generate(..., assistant_model=assistant_model, ...). (Hugging Face)

Separately, custom_generate is a hook that lets you override the decoding loop with your own function (either a callable, or a repo that contains custom_generate/generate.py). (Hugging Face)

Key contract (from HF docs):

  • When you call the base generate() with custom_generate="<repo id or path>", Transformers loads custom_generate/generate.py from that location and then calls your custom generate(). (Hugging Face)
  • All arguments are forwarded into your custom generate, except the two “trigger” args: trust_remote_code and custom_generate. (Hugging Face)

So in principle, assistant_model should be forwarded into your custom generation function.

If you observe it “not arriving”, the most common reality is: you are not actually executing the custom generation function you think you are (stale cache, wrong entrypoint, or you never reached the hook due to the load error you’re now seeing).


Why you get the OSError even though the folder exists on the Hub

The error in your screenshot is raised by GenerationMixin.load_custom_generate.

That loader has a very specific rule:

  1. It first checks whether the string you passed (e.g. "radia/speculative-cascades") is a local path using os.path.exists(pretrained_model_name_or_path).
  2. If it is a local path, it checks locally for custom_generate/generate.py.
  3. If it’s not a local path, it checks the Hub for custom_generate/generate.py.
  4. If the check fails, it raises the exact OSError you see. (matsolutions.in)

This “local first” behavior is the big gotcha. It means:

The #1 cause

You accidentally have a local directory named radia/speculative-cascades (relative to your current working directory), so Transformers treats your argument as a filesystem path, not a Hub repo id.

Then it looks for:

./radia/speculative-cascades/custom_generate/generate.py

If that local clone is older or missing the folder, you get the OSError, even if the Hub repo is correct. That is exactly how the loader is written. (matsolutions.in)
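The branching can be reproduced in a few lines to see why a stray local directory wins. This is a paraphrase of the four steps listed above, not Transformers’ actual source. remote_file_exists is injected so the sketch runs offline; in a real check you could pass huggingface_hub’s file_exists.

```python
import os
from typing import Callable

RemoteCheck = Callable[[str, str], bool]


def find_custom_generate(repo_or_path: str, remote_file_exists: RemoteCheck) -> str:
    """Return "local" or "hub" depending on where custom_generate/generate.py was found."""
    target = "custom_generate/generate.py"
    if os.path.exists(repo_or_path):
        # The string resolves on disk, so it is treated as a local directory
        # and the Hub is never consulted.
        where = "local"
        found = os.path.exists(os.path.join(repo_or_path, target))
    else:
        where = "hub"
        found = remote_file_exists(repo_or_path, target)
    if not found:
        raise OSError(f"{repo_or_path!r} ({where}) does not contain {target}")
    return where
```

Usage: find_custom_generate("radia/speculative-cascades", file_exists) raises OSError whenever an empty radia/speculative-cascades directory exists under the current working directory, even if the Hub copy is fine.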


Confirm quickly (2 checks)

Run these in the same environment where it fails:

import os
from huggingface_hub import file_exists

repo = "radia/speculative-cascades"

print("Local path exists?", os.path.exists(repo))
print("Remote file exists?", file_exists(repo, "custom_generate/generate.py"))

Interpretation:

  • If “Local path exists?” prints True, you have a path collision. Rename or move that local directory (or run from a different working directory) so the repo id is no longer also a valid relative path.
  • If “Remote file exists?” prints False, your environment cannot see the file on the Hub (offline mode, networking, auth, wrong repo type, or you pushed to a branch other than the default).

For a stronger remote check, download the exact file with hf_hub_download() from huggingface_hub. (Hugging Face)
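A minimal sketch of that stronger check, assuming huggingface_hub is installed and the Hub is reachable: it downloads custom_generate/generate.py with hf_hub_download (force_download=True so a stale local copy is not reused) and returns the file’s text, so you can confirm it is the version you pushed. Passing revision lets you check a branch other than the default. If the file is missing, hf_hub_download raises an error instead of returning a path. fetch_remote_generate_py is a hypothetical helper name.

```python
from typing import Optional


def fetch_remote_generate_py(repo_id: str, revision: Optional[str] = None) -> str:
    """Download custom_generate/generate.py from the Hub and return its source text."""
    # Imported here so this helper can be defined without huggingface_hub installed.
    from huggingface_hub import hf_hub_download

    path = hf_hub_download(
        repo_id=repo_id,
        filename="custom_generate/generate.py",
        revision=revision,    # None means the repo's default branch
        force_download=True,  # don't reuse a possibly stale cached copy
    )
    with open(path, encoding="utf-8") as f:
        return f.read()


# Usage (needs network access to the Hub):
# print(fetch_remote_generate_py("radia/speculative-cascades")[:500])
```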


Your repo code looks structurally correct

From the file you linked, you do have a custom_generate/generate.py, and it defines a top-level generate(model, *args, **kwargs) wrapper that calls GenerationMixin.generate(..., custom_generate=_speculative_cascades, **kwargs). (Hugging Face)

Inside _speculative_cascades, assistant_model is explicitly used when constructing the candidate generator. (Hugging Face)

So if you can get past the loader error, assistant_model should be usable in your decoding loop.


Concrete fixes (most likely to work)

Fix A (most likely): remove the local-path collision

  1. Check if you have ./radia/speculative-cascades/ on disk (relative to where you run Python).
  2. Rename ./radia/ or run from a different directory so os.path.exists("radia/speculative-cascades") becomes False.
  3. Retry.

This aligns exactly with the loader’s branching logic. (matsolutions.in)

Fix B: you are in offline mode or blocked from Hub

Check env vars:

  • HF_HUB_OFFLINE=1
  • TRANSFORMERS_OFFLINE=1

If either is set, file_exists(...) cannot query the Hub, so you can hit this error even for a valid repo.
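A quick way to see whether either flag is active in the failing process. The set of truthy spellings below (1, ON, YES, TRUE, case-insensitive) is an assumption about how huggingface_hub parses these variables, and offline_flags is a hypothetical helper.

```python
import os
from typing import Dict, Mapping, Optional

TRUE_VALUES = {"1", "ON", "YES", "TRUE"}  # assumed truthy spellings


def offline_flags(env: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Return the offline-mode variables that are set to a truthy value."""
    env = os.environ if env is None else env
    return {
        name: env[name]
        for name in ("HF_HUB_OFFLINE", "TRANSFORMERS_OFFLINE")
        if env.get(name, "").strip().upper() in TRUE_VALUES
    }


print(offline_flags())  # {} means neither flag is active (under the assumption above)
```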

Fix C: stale cached dynamic module

Even after you push updates to the Hub, Transformers can keep using cached dynamic modules. If you initially loaded the repo before custom_generate/ existed, you might still be running old cached code.

Typical nuke option (Linux/macOS):

  • Remove Hugging Face cache directories under ~/.cache/huggingface/ (both hub and modules).

(Do this only if you understand it will force redownloads.)
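If you do go that route, it helps to see exactly which directories would be removed first. The sketch below resolves the hub and modules cache locations from HF_HOME, HF_HUB_CACHE, HF_MODULES_CACHE and XDG_CACHE_HOME, falling back to ~/.cache/huggingface. Those defaults are an assumption based on how huggingface_hub and Transformers usually resolve them and may differ across versions (for example, the legacy HUGGINGFACE_HUB_CACHE variable is ignored here). clear_hf_caches only lists paths unless dry_run=False.

```python
import os
import shutil
from typing import Dict, List, Mapping, Optional


def hf_cache_dirs(env: Optional[Mapping[str, str]] = None) -> Dict[str, str]:
    """Resolve the (assumed) hub and dynamic-modules cache directories."""
    env = os.environ if env is None else env
    cache_root = env.get("XDG_CACHE_HOME", "~/.cache")
    hf_home = os.path.expanduser(env.get("HF_HOME", os.path.join(cache_root, "huggingface")))
    return {
        "hub": os.path.expanduser(env.get("HF_HUB_CACHE", os.path.join(hf_home, "hub"))),
        "modules": os.path.expanduser(env.get("HF_MODULES_CACHE", os.path.join(hf_home, "modules"))),
    }


def clear_hf_caches(dry_run: bool = True, env: Optional[Mapping[str, str]] = None) -> List[str]:
    """List existing cache dirs; delete them only when dry_run is False."""
    found = []
    for path in hf_cache_dirs(env).values():
        if os.path.isdir(path):
            found.append(path)
            if not dry_run:
                shutil.rmtree(path)
    return found


print(clear_hf_caches())  # dry run: shows what would be deleted
```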


Why you might think assistant_model is “not passed”

Given your current state, the simplest explanation is:

  • You never reach your custom generate code because load_custom_generate raises first.
  • Or you reach an older cached copy of your code that didn’t accept/propagate the argument yet.

Once the loader is fixed, add a trivial diagnostic at the top of your custom generate():

def generate(model, *args, **kwargs):
    print("assistant_model in kwargs?", "assistant_model" in kwargs, type(kwargs.get("assistant_model")))
    ...

If that prints True, forwarding is working, and any remaining issue is inside your logic (e.g., you overwrite assistant_model later).


References and good docs (the ones that actually matter here)

  • Transformers docs: Custom generation methods (required folder layout, hardcoded paths, argument forwarding rules). (Hugging Face)
  • Transformers docs: Universal assisted decoding / assistant_model usage (what assistant_model is supposed to do). (Hugging Face)
  • HF Hub docs: hf_hub_download() for verifying the file exists remotely from Python. (Hugging Face)

Summary

  • custom_generate should forward assistant_model (it forwards everything except custom_generate and trust_remote_code). (Hugging Face)
  • Your OSError most likely comes from a local path collision: radia/speculative-cascades exists on disk, so Transformers stops treating it as a Hub repo id. (matsolutions.in)
  • Verify with os.path.exists(repo) and file_exists(repo, "custom_generate/generate.py").
  • Fix the collision or offline/cache issues, then re-test argument passing with a print at the top of your custom generate.