Hmm…?
Your saturation is a predictable outcome of the specific fusion you chose: you compressed n discrete sequences into one sequence of averaged embeddings. That compression usually (1) throws away exactly the information you hoped to add, and (2) pushes the encoder into an input regime it was not pretrained on.
Below are the most likely causes in your exact setup, plus fixes that have strong precedent in multi-input seq2seq work.
1) What your “frequency × embedding” fusion really does
You built “soft tokens.”
At position t, you estimate a distribution over vocab IDs from the n candidates, then compute the expected embedding:
e_t = Σ_v p(t,v) * E[v]
This is a valid way to feed T5 because the model forward accepts inputs_embeds. (Hugging Face)
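Here is a minimal pure-Python sketch of the fusion as described. It assumes the n candidates are already aligned and of equal length. The toy `embedding_table` is a stand-in: in the real setup it would be the T5 input embedding matrix (for example `model.get_input_embeddings().weight`), and the fused rows would be passed as `inputs_embeds`. The function name is hypothetical.

```python
from collections import Counter

def soft_token_embeddings(candidates, embedding_table):
    """Fuse n aligned candidates into expected embeddings: e_t = sum_v p(t, v) * E[v].

    candidates: n equal-length lists of token IDs (alignment is assumed, not checked).
    embedding_table: list of embedding vectors indexed by token ID (toy stand-in for E).
    """
    n = len(candidates)
    dim = len(embedding_table[0])
    fused = []
    for t in range(len(candidates[0])):
        counts = Counter(seq[t] for seq in candidates)
        e_t = [0.0] * dim
        for token_id, count in counts.items():
            p = count / n  # p(t, v): share of candidates with token v at position t
            for k in range(dim):
                e_t[k] += p * embedding_table[token_id][k]
        fused.append(e_t)
    return fused

table = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]  # toy 2-D embeddings for token IDs 0, 1, 2
print(soft_token_embeddings([[1, 2], [1, 1]], table))  # [[1.0, 0.0], [0.5, 0.5]]
```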
But it has two structural problems.
Problem A: it keeps only per-position marginals, not sequence structure
Your fusion preserves “which tokens appear at position t.”
It destroys:
- which tokens co-occur in the same candidate sequence
- cross-position dependencies inside each candidate
- “candidate identity” (which source said what)
So you reduce n full sequences to a much weaker signal. The model learns the easy part quickly (positions where candidates agree), then hits a ceiling.
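This toy check of Problem A uses placeholder token IDs 7 and 9. Two candidate sets that share no sequence can still produce identical per-position histograms. Their averaged embeddings are then identical too, so the model cannot tell them apart.

```python
from collections import Counter

def position_marginals(candidates):
    """Per-position token histograms: the only information the averaging fusion keeps."""
    return [Counter(seq[t] for seq in candidates) for t in range(len(candidates[0]))]

set_a = [[7, 9], [9, 7]]  # each candidate uses both tokens, in opposite order
set_b = [[7, 7], [9, 9]]  # each candidate repeats a single token

print(set_a == set_b)                                          # False
print(position_marginals(set_a) == position_marginals(set_b))  # True
```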
Problem B: if candidates are even slightly misaligned, you are averaging unrelated tokens
Your method assumes token position t means the same thing across all candidates. If any candidate has an insertion, deletion, or shifted tokenization, then the histogram at position t mixes unrelated tokens. That behaves like structured noise and causes early plateau.
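This small illustration of Problem B uses made-up word tokens in place of IDs. A single inserted word in one candidate leaves only the first position in agreement, so three of the four per-position histograms mix unrelated words. The helper name is hypothetical.

```python
def positional_agreement(cand_a, cand_b):
    """Fraction of positions where two candidates carry the same token."""
    pairs = list(zip(cand_a, cand_b))
    return sum(a == b for a, b in pairs) / len(pairs)

reference = ["the", "cat", "sat", "down"]
substituted = ["the", "cat", "sat", "up"]  # one substitution, still aligned
inserted = ["the", "big", "cat", "sat"]    # one insertion shifts every later token

print(positional_agreement(reference, substituted))  # 0.75
print(positional_agreement(reference, inserted))     # 0.25
```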
2) Out-of-distribution embeddings and “centroid collapse”
T5 was pretrained on embeddings that come from one discrete token ID per position. Your mixture embeddings often land in regions of embedding space that do not correspond to any real token embedding the model frequently saw.
Averaging embeddings also tends to shrink distinctions: many different distributions can produce similar mean vectors. So lots of examples become “more similar” to the encoder than they should be. That reduces separability and limits achievable loss reduction.
This is exactly why uncertain-input literature usually avoids naive averaging and instead preserves alternatives as lattices/confusion networks and consumes them with an architecture designed for uncertainty. (arXiv)
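The collapse shows up even in a two-dimensional toy example. The embedding values here are assumed, not real T5 weights. A position where every candidate agrees on one token and a position where candidates split evenly between two other tokens produce the same mean vector.

```python
def expected_embedding(probs, table):
    """Mean embedding sum_v p(v) * E[v] for one position."""
    dim = len(table[0])
    return [sum(p * row[k] for p, row in zip(probs, table)) for k in range(dim)]

table = [[1.0, 0.0], [-1.0, 0.0], [0.0, 0.0]]  # toy embeddings for token IDs 0, 1, 2

confident = [0.0, 0.0, 1.0]  # all candidates say token 2
split = [0.5, 0.5, 0.0]      # candidates split between tokens 0 and 1

print(expected_embedding(confident, table))  # [0.0, 0.0]
print(expected_embedding(split, table))      # [0.0, 0.0]
```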
3) The most common silent bug: wrong scaling and PAD contamination
Even if the modeling idea were fine, two implementation details can kill learning.
Scaling
If you used raw counts (row sums ≈ n), then embedding magnitudes scale with n. Even with layer norms, you changed the distribution of activations the encoder sees, which can flatten attention or make gradients small.
Fix: per position, normalize to probabilities (row-sum = 1). Then optionally sharpen (below).
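This sketch shows why normalization matters, using two toy unit-norm embeddings. When all n candidates agree on token 0, the count-weighted sum is n times that embedding, so its norm grows linearly with n. The probability-weighted version keeps norm 1 for any n. Function names are illustrative.

```python
import math

def l2_norm(vec):
    return math.sqrt(sum(x * x for x in vec))

def weighted_sum(weights, table):
    """sum_v w(v) * E[v]; weights are either raw counts or probabilities."""
    dim = len(table[0])
    return [sum(w * row[k] for w, row in zip(weights, table)) for k in range(dim)]

table = [[0.6, 0.8], [0.0, 1.0]]  # toy embeddings, each with L2 norm 1

for n in (2, 4, 8):
    counts = [n, 0]                            # all n candidates pick token 0
    probs = [c / sum(counts) for c in counts]  # row-sum = 1
    print(n, round(l2_norm(weighted_sum(counts, table)), 6),
          round(l2_norm(weighted_sum(probs, table)), 6))
```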
PAD contamination
If some candidates are shorter and you pad them, PAD gets counted. Then PAD embedding leaks into e_t. That is poison because you are injecting a strong “nothing here” vector into real positions.
Fix: exclude PAD tokens from the histogram entirely. Keep attention_mask correct for real vs padded time steps.
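Here is a sketch of a PAD-safe histogram. It assumes the candidates are already padded to a common length and that the PAD ID is 0. That is true for the standard T5 tokenizers, but read `tokenizer.pad_token_id` rather than hard-coding it. Positions where every candidate is PAD get an empty distribution and attention_mask 0. All other positions are normalized so each row sums to 1. The function name is hypothetical.

```python
from collections import Counter

def position_distributions(candidates, pad_id=0):
    """Per-position probabilities over token IDs with PAD excluded, plus an attention mask.

    candidates: n lists of token IDs padded to the same length.
    pad_id: 0 is the usual T5 PAD ID (an assumption; use tokenizer.pad_token_id).
    """
    dists, attention_mask = [], []
    for t in range(len(candidates[0])):
        counts = Counter(seq[t] for seq in candidates if seq[t] != pad_id)
        total = sum(counts.values())
        if total == 0:
            dists.append({})          # every candidate is PAD here: no evidence
            attention_mask.append(0)  # so the encoder must not attend to this step
        else:
            dists.append({tok: c / total for tok, c in counts.items()})  # row-sum = 1
            attention_mask.append(1)
    return dists, attention_mask

dists, mask = position_distributions([[5, 6, 0], [5, 7, 8], [5, 0, 0]])
print(dists)  # [{5: 1.0}, {6: 0.5, 7: 0.5}, {8: 1.0}]
print(mask)   # [1, 1, 1]
```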
4) Why your loss curve has spikes and saturates (training dynamics explanation)
Two common patterns create a “down then stall with spikes” plot.
Warm restarts (scheduler)
If you use cosine annealing with warm restarts, the LR jumps back up at each restart. That can cause periodic loss spikes even if everything else is correct. (PyTorch Docs)
Debug choice: switch to linear warmup + decay while you diagnose modeling. Avoid restarts until stable.
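The two LR shapes can be compared as plain multiplier functions, which makes the restart jump visible without a training run. `linear_warmup_decay` reproduces the formula behind Transformers' `get_linear_schedule_with_warmup`. `cosine_warm_restarts` follows the CosineAnnealingWarmRestarts formula from the PyTorch docs with `T_mult=1`. Both helper names are hypothetical; with a real optimizer you would use the library schedulers directly.

```python
import math

def linear_warmup_decay(step, num_warmup_steps, num_training_steps):
    """LR multiplier: ramps 0 -> 1 over warmup, then decays linearly 1 -> 0."""
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    return max(0.0, (num_training_steps - step) / max(1, num_training_steps - num_warmup_steps))

def cosine_warm_restarts(step, t_0, eta_max=1.0, eta_min=0.0):
    """LR under cosine annealing with warm restarts (T_mult=1): resets to eta_max every t_0 steps."""
    t_cur = step % t_0
    return eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(math.pi * t_cur / t_0))

for step in (98, 99, 100, 101):
    # The restart at step 100 takes the LR from near 0 straight back to eta_max.
    print(step, round(cosine_warm_restarts(step, t_0=100), 4),
          round(linear_warmup_decay(step, 10, 1000), 4))
```

With a real optimizer, `torch.optim.lr_scheduler.LambdaLR(optimizer, lambda s: linear_warmup_decay(s, 500, 10_000))` applies the same multiplier to the base LR. The 500 and 10,000 step counts are placeholders.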
Adafactor configuration traps (common with T5)
Adafactor settings are easy to misconfigure. For example, warmup_init=True requires relative_step=True, and that conflicts with setting a manual LR. This is a documented Transformers pitfall. (GitHub)
Debug choice: use AdamW first (simple), or use Adafactor with a known-good configuration from the docs.
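A hypothetical pre-flight helper can make the two constraints explicit. It mirrors the argument checks in the Transformers `Adafactor` constructor, whose defaults are `lr=None`, `relative_step=True`, `warmup_init=False`, and `scale_parameter=True`. The manual-LR configuration below matches the T5 fine-tuning example in the Transformers Adafactor docs; check the docs for your installed version.

```python
def check_adafactor_kwargs(lr=None, relative_step=True, warmup_init=False, scale_parameter=True):
    """Hypothetical check mirroring the argument validation in transformers' Adafactor."""
    if lr is not None and relative_step:
        raise ValueError("Cannot combine a manual lr with relative_step=True")
    if warmup_init and not relative_step:
        raise ValueError("warmup_init=True requires relative_step=True")
    return {"lr": lr, "relative_step": relative_step,
            "warmup_init": warmup_init, "scale_parameter": scale_parameter}

# Manual LR: relative_step and warmup_init must both be off.
manual = check_adafactor_kwargs(lr=1e-3, relative_step=False, warmup_init=False, scale_parameter=False)

# Internal relative-step schedule: no manual lr at all.
relative = check_adafactor_kwargs(lr=None, relative_step=True, warmup_init=True, scale_parameter=True)

# The trap: a manual lr combined with the default relative_step=True is rejected.
try:
    check_adafactor_kwargs(lr=1e-3)
except ValueError as err:
    print(err)
```

Once the arguments pass, they can be unpacked into the real optimizer, for example `Adafactor(model.parameters(), **manual)`.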
5) The core issue: you want n:1 evidence fusion, but you implemented “early fusion by averaging”
For n:1, the strongest practical template is:
- keep each candidate as a real sequence
- encode each candidate separately
- fuse later (decoder attention or learned pooling)
This preserves structure and lets the model learn “which candidate to trust.”
The proven pattern: Fusion-in-Decoder (FiD)
FiD encodes each input separately and concatenates encoder states; the decoder cross-attends to all of them. It is explicitly designed for “many inputs to one output.” (arXiv)
FiD is usually used for retrieved passages, but your “n candidate sequences” is the same abstraction: multiple evidence streams.
Classic multi-source seq2seq precedent
Multi-source neural translation uses multiple encoders and a single decoder, exploring combination methods. Same structural problem, older but foundational. (arXiv)
Uncertain-input precedent
If your n sequences represent alternative hypotheses with implicit probabilities, lattice-to-seq models show why preserving posterior structure matters and how to incorporate it. (arXiv)
6) Solutions, ranked by “probability of fixing your saturation”
Solution 1 (fastest baseline): treat candidates as augmentation, not fusion
Create n training pairs per original example:
(candidate_i → target) for i in 1..n
This often beats embedding averaging because you stay fully in-distribution (discrete tokens) and you do not destroy structure.
Inference options:
- run all candidates and choose the best by log-likelihood
- ensemble outputs (logprob sum)
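The augmentation baseline and the first inference option can be sketched in a few lines. Both function names are hypothetical. `score_fn` is an assumed callable that runs your model on one candidate and returns `(output_text, sequence_logprob)`; how you compute the log-probability depends on your generation code.

```python
def expand_to_pairs(examples):
    """Turn each (candidates, target) example into n (candidate_i, target) training pairs."""
    return [(cand, target) for candidates, target in examples for cand in candidates]

def pick_best(candidates, score_fn):
    """Run every candidate and keep the output with the highest log-likelihood."""
    scored = [score_fn(cand) for cand in candidates]  # [(output, logprob), ...]
    return max(scored, key=lambda pair: pair[1])[0]

examples = [(["teh cat sat", "the cat sat", "the cat sta"], "The cat sat.")]
pairs = expand_to_pairs(examples)
print(len(pairs))  # 3

# Stand-in scorer with made-up log-probabilities, for illustration only.
fake_scores = {"teh cat sat": -4.1, "the cat sat": -1.2, "the cat sta": -3.7}
print(pick_best(examples[0][0], lambda c: (c.capitalize() + ".", fake_scores[c])))  # The cat sat.
```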
Solution 2 (simple single-pass fusion): concatenate candidates with separators and tags
Input text like:
cand1: ... </s> cand2: ... </s> ...
Pros: trivial to implement, preserves order.
Cons: context length grows; attention cost grows.
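Here is a sketch of an input builder for this format. The `candN:` tags and the `</s>` separator follow the example text above. Whether the literal `</s>` is parsed as the EOS special token depends on your tokenizer, so check the tokenized IDs, or use a plain-text separator instead. The function name is hypothetical.

```python
def concat_candidates(candidates, sep=" </s> "):
    """Join candidates into one input string, tagging each so the model knows its source."""
    return sep.join(f"cand{i}: {text}" for i, text in enumerate(candidates, start=1))

print(concat_candidates(["teh cat sat", "the cat sat"]))
# cand1: teh cat sat </s> cand2: the cat sat
```

T5 tokenizers append a final `</s>` themselves. Set the maximum input length high enough for all n candidates, because right-side truncation would silently drop the last ones.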
Solution 3 (best “correct fusion”): FiD-style encode-separately then fuse in decoder
High level:
- reshape batch to encode each candidate independently
- concatenate encoder hidden states
- feed concatenated states to decoder
If you want an existence proof, the FiD repo shows the exact reshape-and-concatenate trick. (GitHub)
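The reshape-and-concatenate idea can be sketched with nested lists so the shapes are explicit. `toy_encode` is a hypothetical stand-in for the T5 encoder. With tensors, the same steps are `input_ids.view(B * n, L)` before the encoder and `hidden.view(B, n * L, d)` after it. The attention mask needs the same regrouping to `(B, n * L)` before the decoder cross-attends to all candidates. This illustrates the idea; it is not the FiD repository's code.

```python
def fid_encode(batch, encode):
    """Encode each candidate separately, then concatenate the states of each example.

    batch: B examples, each a list of n token-ID sequences of equal length L.
    encode: callable mapping a list of sequences to per-token hidden vectors.
    Returns B fused sequences of n * L hidden vectors for the decoder to cross-attend to.
    """
    n = len(batch[0])
    # (B, n, L) -> (B * n, L): every candidate is an independent encoder input
    flat = [cand for example in batch for cand in example]
    hidden = encode(flat)  # (B * n, L, d)
    # (B * n, L, d) -> (B, n * L, d): concatenate the n encodings along the length axis
    return [
        [vec for cand_states in hidden[b * n:(b + 1) * n] for vec in cand_states]
        for b in range(len(batch))
    ]

def toy_encode(seqs):
    """Hypothetical stand-in encoder: token ID t becomes the 1-D hidden state [float(t)]."""
    return [[[float(tok)] for tok in seq] for seq in seqs]

batch = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]  # B = 2 examples, n = 2 candidates, L = 2
print(fid_encode(batch, toy_encode))
# [[[1.0], [2.0], [3.0], [4.0]], [[5.0], [6.0], [7.0], [8.0]]]
```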
Solution 4 (if you insist on your histogram idea): make it less lossy and learnable
If you keep “distribution over vocab per position,” do not map it to embeddings with a fixed linear average and stop.
Do this instead:
- Normalize counts to probabilities.
- Sharpen to reduce blur: q ∝ p^α with α > 1.
- Add a confidence channel (entropy or max prob) so the model knows which positions are ambiguous.
- Add a trainable projection: e'_t = W e_t, or a small MLP, before feeding the encoder.
This lets the model learn how to interpret your soft evidence.
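The first two additions can be sketched for one position, assuming `probs` is already a PAD-free, normalized dict from token ID to probability. The trainable projection is not shown. In PyTorch it would typically be a `torch.nn.Linear(d_model, d_model)` (or a small MLP) applied to the fused embeddings before they are passed as `inputs_embeds`. Function names and the default alpha are illustrative.

```python
import math

def sharpen(probs, alpha=2.0):
    """q ∝ p^alpha, renormalized; alpha > 1 moves mass toward the majority token."""
    powered = {tok: p ** alpha for tok, p in probs.items()}
    total = sum(powered.values())
    return {tok: v / total for tok, v in powered.items()}

def confidence_features(probs):
    """Ambiguity signals for one position: entropy in nats and the max probability."""
    entropy = -sum(p * math.log(p) for p in probs.values() if p > 0)
    return entropy, max(probs.values())

p = {7: 0.75, 9: 0.25}
q = sharpen(p, alpha=2.0)
print(q)                       # {7: 0.9, 9: 0.1}
print(confidence_features(p))  # higher entropy, lower max prob than for q
print(confidence_features(q))
```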
7) “Plumbing” checks that you should run regardless
These do not solve the fundamental fusion issue, but they can mimic saturation.
Check A: label padding is ignored (-100)
For seq2seq, padded label positions must be -100 so loss ignores them. DataCollatorForSeq2Seq defaults label_pad_token_id=-100 and documents that -100 is ignored by PyTorch losses. (Hugging Face)
If you accidentally compute loss on padded labels, you will see an artificial floor.
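This stdlib sketch shows what the collator does for labels, plus a quick check for the bug. The helper names are hypothetical. `DataCollatorForSeq2Seq` already pads labels with -100 by default, so this mainly helps when you build batches yourself. PAD ID 0 and EOS ID 1 are the standard T5 values, assumed here.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def pad_labels(label_seqs, ignore_index=IGNORE_INDEX):
    """Right-pad label sequences with -100 (not the PAD token ID) so the loss skips them."""
    width = max(len(seq) for seq in label_seqs)
    return [seq + [ignore_index] * (width - len(seq)) for seq in label_seqs]

def count_supervised_pad(labels, pad_token_id=0):
    """Label positions still holding the PAD ID; each one is trained on like a real token."""
    return sum(tok == pad_token_id for seq in labels for tok in seq)

good = pad_labels([[37, 53, 1], [71, 1]])  # 1 = EOS in standard T5 vocabularies (assumed)
print(good)                                # [[37, 53, 1], [71, 1, -100]]
print(count_supervised_pad(good))          # 0
print(count_supervised_pad([[37, 53, 1], [71, 1, 0]]))  # 1: this padded label would be learned
```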
Check B: generation with inputs_embeds is special
Forward can use inputs_embeds, but generate() historically does not accept it cleanly for seq2seq. People work around it by calling the encoder first and passing encoder_outputs to generate(). (GitHub)
This matters for evaluation: you can think you are evaluating the model you trained, but you are not actually using the same conditioning path unless you do this correctly. A recent discussion clarifies how generate() behaves when you pass encoder_outputs. (Hugging Face Forums)
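Here is a minimal sketch of the encoder-first route, assuming a recent Transformers release with PyTorch. A tiny randomly initialized T5 built from `T5Config` keeps it runnable offline. The config sizes are arbitrary; in practice you would load your fine-tuned model instead. The soft-token distributions are random placeholders for your fused p(t,·).

```python
import torch
from transformers import T5Config, T5ForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput

# Tiny random T5 so the sketch runs offline; replace with your fine-tuned model.
config = T5Config(vocab_size=32, d_model=16, d_kv=8, d_ff=32, num_layers=1,
                  num_decoder_layers=1, num_heads=2, decoder_start_token_id=0)
model = T5ForConditionalGeneration(config).eval()

batch, length = 2, 5
# Placeholder soft-token distributions over the vocabulary: (batch, length, vocab).
probs = torch.softmax(torch.randn(batch, length, config.vocab_size), dim=-1)
attention_mask = torch.ones(batch, length, dtype=torch.long)

with torch.no_grad():
    # Same conditioning path as training: expected embeddings -> encoder.
    inputs_embeds = probs @ model.get_input_embeddings().weight  # (batch, length, d_model)
    enc = model.get_encoder()(inputs_embeds=inputs_embeds, attention_mask=attention_mask)
    # generate() skips its own encoder pass when encoder_outputs is supplied.
    generated = model.generate(
        encoder_outputs=BaseModelOutput(last_hidden_state=enc.last_hidden_state),
        attention_mask=attention_mask,
        max_new_tokens=4,
    )
print(generated.shape)  # (batch, decoder start token + up to 4 new tokens)
```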
8) A concrete debug plan that should tell you what is wrong within a few experiments
Experiment 1: can the model overfit a tiny subset with normal discrete input?
Take 32 examples, pick one candidate per example, train until near-zero loss.
- If you cannot overfit, you have a training/pipeline bug (labels, masking, LR, optimizer).
- If you can, your saturation is caused by your fusion method.
Experiment 2: compare three n:1 strategies head-to-head
Hold everything constant and run:
- augmentation baseline (n separate pairs)
- concatenation baseline
- your histogram-embedding fusion
If the augmentation or concatenation baseline quickly beats the histogram-embedding fusion, your averaging destroyed signal.
Experiment 3: measure blur
On a batch, compute:
- mean per-position entropy of p(t,·)
- mean norm of fused e_t
- cosine similarity between fused sequences across examples
High entropy + high similarity between examples is the signature of “centroid collapse.”
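Here is a stdlib sketch of the three measurements. To compare fused sequences across examples, this sketch mean-pools each sequence over positions, which is an assumption made here so that sequences of different lengths can be compared. The function names are hypothetical, and `cosine` does not guard against all-zero vectors.

```python
import math

def mean_entropy(dists):
    """Mean per-position entropy (nats) of p(t, ·); each dist maps token ID -> probability."""
    ents = [-sum(p * math.log(p) for p in d.values() if p > 0) for d in dists]
    return sum(ents) / len(ents)

def mean_norm(fused):
    """Mean L2 norm of the fused vectors e_t."""
    return sum(math.sqrt(sum(x * x for x in e)) for e in fused) / len(fused)

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def mean_pairwise_cosine(sequences):
    """Mean cosine similarity between examples, each mean-pooled over positions."""
    pooled = [[sum(col) / len(seq) for col in zip(*seq)] for seq in sequences]
    sims = [cosine(pooled[i], pooled[j])
            for i in range(len(pooled)) for j in range(i + 1, len(pooled))]
    return sum(sims) / len(sims)

print(mean_entropy([{1: 0.5, 2: 0.5}, {3: 1.0}]))
print(mean_norm([[3.0, 4.0], [0.0, 1.0]]))
print(mean_pairwise_cosine([[[1.0, 0.0]], [[1.0, 0.0]], [[0.0, 1.0]]]))
```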
Curated links (high-signal, directly relevant)
- T5 supports inputs_embeds in forward (signature documented): https://huggingface.co/transformers/v3.0.2/model_doc/t5.html (Hugging Face)
- inputs_embeds with generate() workaround discussions:
  - https://github.com/huggingface/transformers/issues/12844 (GitHub)
  - https://github.com/huggingface/transformers/issues/6535 (GitHub)
  - https://huggingface.co/static-proxy/discuss.huggingface.co/t/how-to-use-inputs-embeds-in-generate/713 (Hugging Face Forums)
  - https://huggingface.co/static-proxy/discuss.huggingface.co/t/understanding-t5-with-custom-embedding/162325 (Hugging Face Forums)
- FiD (n inputs to 1 output, proven fusion):
- Multi-source seq2seq baseline paper: https://arxiv.org/abs/1601.00710 (arXiv)
- Uncertain input, do not average away structure:
  - Lattice-to-seq paper: https://arxiv.org/abs/1704.00559 (arXiv)
  - Self-attentional models for lattice inputs: https://arxiv.org/pdf/1906.01617v1 (isl.iar.kit.edu)
- Seq2seq label padding and label_pad_token_id=-100:
- Adafactor pitfalls and recommended settings:
  - Issue showing constraint: https://github.com/huggingface/transformers/issues/7789 (GitHub)
  - Official optimizer schedule docs (includes Adafactor guidance): https://huggingface.co/docs/transformers/en/main_classes/optimizer_schedules (Hugging Face)
- Warm restarts explanation (why spikes happen): https://docs.pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.CosineAnnealingWarmRestarts.html (PyTorch Docs)
Summary
- Your fusion averages away sequence structure and candidate identity, so the model learns the easy consensus then saturates.
- Misalignment across candidates and PAD leakage can turn your histogram into noise.
- Scaling issues (counts vs probabilities) can cause early plateaus.
- Strong fixes: augmentation, concatenation with separators, or FiD-style encode-separately and fuse in decoder.
- Also verify that label padding uses -100, avoid warm restarts while debugging, and confirm Adafactor settings.