Why your predictions match the fine-tune label distribution
When you fine-tune a decoder LM with standard cross-entropy to generate a label/topic string, the training objective tends to bake in the label prior of the fine-tune set. On a small, imbalanced dataset, the easiest way to reduce loss is often “guess the common label more often,” especially when inputs are ambiguous or when most loss is spent modeling the prompt rather than the label.
This is a well-studied “long-tail” failure mode: naive training becomes biased toward frequent classes, and methods like logit adjustment are explicitly motivated as fixes for this phenomenon. (arXiv)
Step 0: Decide the target behavior (this changes the correct fix)
You need to choose which of these you actually want:
- Equal importance across classes (better macro-F1 / minority recall): you want training not to be dominated by head classes.
- Correct real-world frequency (calibration to the deployment prior): you want predicted frequencies to match the true base rates at inference time, which may differ from your fine-tune set.
Most practical setups do both:
- train to improve minority performance;
- then calibrate/correct priors at inference.
Highest-impact change: stop free-form generation and turn inference into scoring among known labels
If the label space is fixed (K topics), do not let decoding roam. Instead score each label candidate and pick the best.
Candidate scoring (recommended)
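For each label y, represented as a short token sequence y = (y_1, \dots, y_T), compute its full-sequence log-likelihood under the fine-tuned model:
$$s(y) = \log p_\theta(y \mid x) = \sum_{t=1}^{T} \log p_\theta(y_t \mid x, y_{<t})$$
Then predict the highest-scoring label from the fixed label set \mathcal{Y}:
$$\hat{y} = \arg\max_{y \in \mathcal{Y}} s(y)$$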
Why this helps:
- avoids greedy/beam “mode seeking” that favors frequent labels,
- makes prior correction (below) trivial,
- removes most decoding artifacts tied to label strings and tokenization (summed log-probabilities still penalize longer labels, so keep label strings short and comparable in length).
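A minimal sketch of candidate scoring, assuming you can get next-token log-probabilities from your fine-tuned model. The NextTokenLogProbs type, score_label, predict, and the toy 4-token model are hypothetical illustrations, not a library API. Each label is scored on all of its tokens with teacher forcing; a real implementation would obtain these from one forward pass over prompt + label instead of one call per token.

```python
import math
from typing import Callable, Dict, List, Sequence

# Hypothetical interface: maps a token-id prefix to log-probabilities over the
# vocabulary for the next token (e.g. log_softmax of your model's last logits).
NextTokenLogProbs = Callable[[Sequence[int]], Sequence[float]]


def score_label(next_logprobs: NextTokenLogProbs, prompt_ids: List[int],
                label_ids: List[int], length_normalize: bool = False) -> float:
    """Sum of log p(y_t | x, y_<t) over every token of the label (teacher forcing)."""
    total = 0.0
    prefix = list(prompt_ids)
    for tok in label_ids:
        total += next_logprobs(prefix)[tok]
        prefix.append(tok)  # condition the next step on the gold label token
    return total / len(label_ids) if length_normalize else total


def predict(next_logprobs: NextTokenLogProbs, prompt_ids: List[int],
            candidates: Dict[str, List[int]]) -> str:
    """Score every candidate label in full and return the best one."""
    scores = {name: score_label(next_logprobs, prompt_ids, ids)
              for name, ids in candidates.items()}
    return max(scores, key=scores.get)


def toy_lm(prefix: Sequence[int]) -> List[float]:
    """Toy 4-token model: prefers token 2 after the prompt, then token 3."""
    probs = [0.1, 0.1, 0.1, 0.7] if prefix[-1] == 2 else [0.1, 0.2, 0.6, 0.1]
    return [math.log(p) for p in probs]


candidates = {"sports": [1], "politics": [2, 3]}  # hypothetical label token ids
print(predict(toy_lm, [0], candidates))  # politics
```

Here "politics" wins with log 0.6 + log 0.7 ≈ -0.87 against log 0.2 ≈ -1.61 for "sports". The length_normalize flag is an optional knob to compare on validation, not part of the technique itself.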
If you must constrain generation with prefix_allowed_tokens_fn, note that there are real-world reports where constraints are not respected under some conditions; test thoroughly. (GitHub)
Training technique 1: compute loss only on the label tokens (completion-only / prompt masking)
A common hidden cause of majority collapse in “classification-as-generation” is that you compute loss over the entire prompt+label sequence. The prompt can dominate gradient updates, leaving the label decision under-trained.
Use “completion-only” training: mask prompt tokens so loss applies only to the label span. TRL documents this pattern via DataCollatorForCompletionOnlyLM, and explicitly notes it works when packing=False. (Hugging Face)
Practical effect:
- more gradient budget on “choose correct label,”
- less incentive to fall back to the fine-tune prior.
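If you build training examples yourself instead of using the TRL collator, prompt masking amounts to setting the label positions of prompt tokens to -100, the index that PyTorch's cross-entropy ignores by default and that Hugging Face causal-LM losses skip. A minimal sketch, assuming prompt and label are tokenized separately (tokenizing the concatenated string can split differently at the boundary, so check your tokenizer); appending EOS is an assumption that teaches the model to stop after the label. The function name is hypothetical.

```python
from typing import Dict, List

# Positions labeled -100 are ignored by PyTorch's cross-entropy (default
# ignore_index) and by Hugging Face causal-LM losses.
IGNORE_INDEX = -100


def build_completion_only_example(prompt_ids: List[int], label_ids: List[int],
                                  eos_id: int) -> Dict[str, List[int]]:
    """Concatenate prompt + label + EOS and exclude the prompt from the loss."""
    input_ids = prompt_ids + label_ids + [eos_id]
    # Hugging Face causal LMs shift labels internally, so labels align 1:1 with input_ids.
    labels = [IGNORE_INDEX] * len(prompt_ids) + label_ids + [eos_id]
    return {"input_ids": input_ids, "labels": labels}


example = build_completion_only_example([5, 6, 7], [42], eos_id=2)  # hypothetical ids
print(example["labels"])  # [-100, -100, -100, 42, 2]
```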
Training technique 2: rebalance the optimizer’s view of the data
You typically need sampling and/or loss shaping.
A) Balanced or temperature sampling
Instead of sampling examples proportional to dataset frequency, adjust class sampling so minority classes appear often enough to matter (balanced batches, or smoothed sampling).
Pitfall: do not resample the full dataset before train/test split; it creates leakage and inflated metrics. (imbalanced-learn.org)
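A sketch of smoothed class sampling, assuming a plain list of training labels. Class c is drawn with probability proportional to n_c^α: α = 1 reproduces natural frequencies, α = 0 balances classes, and values in between interpolate. The function name and the α parameterization are illustrative choices. Weights are computed from the training split only, after splitting; the same per-example weights could also be handed to torch.utils.data.WeightedRandomSampler.

```python
import random
from collections import Counter
from typing import List, Sequence


def per_example_weights(labels: Sequence[str], alpha: float) -> List[float]:
    """Per-example weights that draw class c with probability ∝ n_c ** alpha.

    alpha = 1.0 keeps natural frequencies, alpha = 0.0 balances classes,
    values in between smooth the distribution.
    """
    counts = Counter(labels)
    class_mass = {c: n ** alpha for c, n in counts.items()}
    z = sum(class_mass.values())
    # Split each class's sampling mass evenly across its examples.
    return [class_mass[y] / z / counts[y] for y in labels]


train_labels = ["sports"] * 90 + ["politics"] * 10  # hypothetical, training split only
weights = per_example_weights(train_labels, alpha=0.5)
batch = random.choices(range(len(train_labels)), weights=weights, k=32)
```

With α = 0.5 the 10% class is drawn in about a quarter of samples (√10 / (√90 + √10) = 1/4).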
B) Class-balanced loss (effective number of samples)
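A robust default is class-balanced weighting using the “effective number of samples” idea. For a class with n samples:
$$E(n) = \frac{1 - \beta^{n}}{1 - \beta}, \qquad \beta \in [0, 1)$$
Use weights inversely proportional to E(n) rather than raw 1/n. Because E(n) grows sublinearly and saturates at 1/(1 - β), this avoids extreme weights when a class is tiny; β = 0 means no reweighting and β → 1 recovers inverse frequency. (arXiv)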
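A minimal sketch of class-balanced weights with E(n) = (1 - β^n)/(1 - β). Normalizing the weights to sum to the number of classes is a common convention and an assumption here; the class names and counts are made up.

```python
from typing import Dict


def class_balanced_weights(counts: Dict[str, int], beta: float = 0.999) -> Dict[str, float]:
    """Weights ∝ 1 / E(n_c) with E(n) = (1 - beta**n) / (1 - beta), summing to the number of classes."""
    raw = {c: (1.0 - beta) / (1.0 - beta ** n) for c, n in counts.items()}
    scale = len(raw) / sum(raw.values())
    return {c: w * scale for c, w in raw.items()}


counts = {"sports": 9000, "politics": 900, "weather": 100}  # hypothetical label counts
w = class_balanced_weights(counts, beta=0.999)
print({c: round(v, 3) for c, v in w.items()})
```

With these counts the rarest class gets roughly 10× the head class's weight, instead of the 90× that raw 1/n would assign.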
C) Focal loss (focus on hard examples)
Focal loss down-weights easy (often majority) examples so they don’t dominate training and pushes learning onto harder cases. (arXiv)
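For one example, focal loss is FL = -α_t (1 - p_t)^γ log p_t, where p_t is the probability of the true class. A pure-Python sketch, assuming single-token labels so that the K label logits at the answer position act as class logits; γ = 2 is the common starting point from the original paper, and alpha is an optional per-class weight. Function names are illustrative.

```python
import math
from typing import List, Optional, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - lse for z in logits]


def focal_loss(logits: Sequence[float], target: int, gamma: float = 2.0,
               alpha: Optional[Sequence[float]] = None) -> float:
    """FL = -alpha_t * (1 - p_t) ** gamma * log p_t for a single example."""
    log_p_t = log_softmax(logits)[target]
    p_t = math.exp(log_p_t)
    alpha_t = 1.0 if alpha is None else alpha[target]
    return -alpha_t * (1.0 - p_t) ** gamma * log_p_t


easy = [5.0, 0.0, 0.0]  # confidently correct for class 0
print(focal_loss(easy, 0, gamma=0.0), focal_loss(easy, 0, gamma=2.0))
```

With γ = 0 this is ordinary cross-entropy; for the confidently correct example, γ = 2 shrinks the loss by a factor below 10⁻³, so such examples stop dominating the gradient.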
D) LDAM + Deferred Re-Weighting (two-phase training)
LDAM introduces a minority-favoring margin, and DRW recommends a schedule: learn representations first, then apply reweighting. This is useful when reweighting early destabilizes training. (arXiv)
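A sketch of both pieces under simplifying assumptions. LDAM subtracts a class-dependent margin Δ_j ∝ n_j^(-1/4) from the true-class logit before cross-entropy; here margins are rescaled so the rarest class gets max_margin (the authors' reference code uses a maximum margin of 0.5 together with a logit scale on normalized last-layer features). Raw LM logits are not normalized, so scale defaults to 1.0 here, an assumption you would tune. DRW keeps weights uniform until defer_epoch and then switches to class-balanced weights; defer_epoch and β are hyperparameters. All names are hypothetical.

```python
import math
from typing import List, Optional, Sequence


def ldam_margins(counts: Sequence[int], max_margin: float = 0.5) -> List[float]:
    """Margins ∝ n_j ** (-1/4), rescaled so the rarest class gets max_margin."""
    raw = [n ** -0.25 for n in counts]
    top = max(raw)
    return [max_margin * r / top for r in raw]


def ldam_loss(logits: Sequence[float], target: int, margins: Sequence[float],
              scale: float = 1.0, class_weights: Optional[Sequence[float]] = None) -> float:
    """Cross-entropy after subtracting the true class's margin from its logit.

    scale = 1.0 is an assumption for raw (unnormalized) logits.
    """
    adjusted = [scale * (z - margins[j]) if j == target else scale * z
                for j, z in enumerate(logits)]
    m = max(adjusted)
    lse = m + math.log(sum(math.exp(a - m) for a in adjusted))
    weight = 1.0 if class_weights is None else class_weights[target]
    return -weight * (adjusted[target] - lse)


def drw_weights(counts: Sequence[int], epoch: int, defer_epoch: int,
                beta: float = 0.9999) -> List[float]:
    """Deferred re-weighting: uniform before defer_epoch, class-balanced afterwards."""
    if epoch < defer_epoch:
        return [1.0] * len(counts)
    raw = [(1.0 - beta) / (1.0 - beta ** n) for n in counts]
    total = sum(raw)
    return [len(counts) * r / total for r in raw]


counts = [1000, 10]  # hypothetical: head class 0, tail class 1
margins = ldam_margins(counts)
for epoch in (0, 8):
    w = drw_weights(counts, epoch, defer_epoch=5)
    print(epoch, ldam_loss([1.0, 0.5], 1, margins, class_weights=w))
```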
E) Balanced Softmax (prior-aware objective)
Balanced Softmax is designed to correct biased gradients under long-tail training and explicitly address mismatch between training and testing label distributions. (NeurIPS Proceedings)
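Balanced Softmax trains with probabilities φ_j = n_j e^{z_j} / Σ_i n_i e^{z_i}, which is the same as adding log n_j to each logit before cross-entropy; at test time you predict from the unshifted logits. A minimal per-example sketch with an illustrative name, assuming the K label logits at the answer position act as class logits.

```python
import math
from typing import Sequence


def balanced_softmax_loss(logits: Sequence[float], target: int,
                          class_counts: Sequence[int]) -> float:
    """Training loss: cross-entropy on logits shifted by log n_j.

    Equivalent to using probabilities n_j * exp(z_j) / sum_i n_i * exp(z_i).
    At inference, predict from the unshifted logits.
    """
    shifted = [z + math.log(n) for z, n in zip(logits, class_counts)]
    m = max(shifted)
    lse = m + math.log(sum(math.exp(s - m) for s in shifted))
    return -(shifted[target] - lse)


print(balanced_softmax_loss([0.0, 0.0], 1, [90, 10]))  # log(10) ≈ 2.303
```

With equal logits and a 90/10 split, the rare-class loss is log 10 ≈ 2.30 instead of log 2 ≈ 0.69, so training pushes the rare logit up harder; with equal counts it reduces to plain cross-entropy.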
Training technique 3: reduce “10k imbalanced fine-tune overwrites everything”
Full fine-tuning on 10k skewed data commonly increases prior imprinting and forgetting. Two practical mitigations:
- freeze most layers (tune only top blocks),
- or use parameter-efficient approaches (adapters/LoRA).
Even without adding new methods, lowering learning rate and early-stopping on macro metrics helps.
(Implementation note: if you use Hugging Face Trainer, class-weighted objectives and custom losses are commonly implemented by subclassing Trainer and overriding compute_loss. (Hugging Face Forums))
Inference technique 1: logit adjustment / prior correction (directly targets your symptom)
If your complaint is specifically “predicted frequencies mirror fine-tune frequencies,” correct the prior explicitly at inference.
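A standard form, where s(y) is the model's score for label y (e.g. its log-likelihood):
$$\hat{y} = \arg\max_{y} \Big[\, s(y) - \tau \log \pi_{\text{train}}(y) + \tau \log \pi_{\text{target}}(y) \,\Big]$$
where:
- $\pi_{\text{train}}(y)$ is the fine-tune label prior (from your 10k),
- $\pi_{\text{target}}(y)$ is the deployment prior you want (or uniform if you want equalized rates, in which case the term is constant and drops out of the argmax),
- $\tau \ge 0$ is a strength you tune on validation.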
This is the central idea in “logit adjustment” for long-tail learning. (arXiv)
When this works best:
- the model’s conditional ranking signal is decent,
- but the prior dominates the final decision.
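A minimal sketch of post-hoc prior correction, assuming you already have one score per label (e.g. the summed log-likelihood of each candidate) and a fine-tune prior computed from label counts. It returns s(y) - τ log π_train(y) + τ log π_target(y); passing no target prior corresponds to the uniform case, since log(1/K) is the same for every label. Names and numbers are illustrative.

```python
import math
from typing import Dict, Optional


def adjust_scores(scores: Dict[str, float], train_prior: Dict[str, float],
                  target_prior: Optional[Dict[str, float]] = None,
                  tau: float = 1.0) -> Dict[str, float]:
    """Return s(y) - tau * log pi_train(y) + tau * log pi_target(y).

    target_prior=None means a uniform target, whose log term is constant and dropped.
    """
    adjusted = {}
    for y, s in scores.items():
        value = s - tau * math.log(train_prior[y])
        if target_prior is not None:
            value += tau * math.log(target_prior[y])
        adjusted[y] = value
    return adjusted


counts = {"sports": 9000, "politics": 1000}  # hypothetical fine-tune label counts
total = sum(counts.values())
train_prior = {y: n / total for y, n in counts.items()}
scores = {"sports": -0.5, "politics": -1.0}  # hypothetical s(y) for one input
for tau in (0.0, 1.0):
    adj = adjust_scores(scores, train_prior, tau=tau)
    print(tau, max(adj, key=adj.get))  # 0.0 sports, then 1.0 politics
```

The raw scores favor "sports", but once the 90/10 training prior is removed the prediction flips to "politics"; τ sets how strongly that happens, which is why it is tuned on validation.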
Inference technique 2: if deployment priors are unknown, estimate label shift
If you suspect test/deployment priors differ from fine-tune priors, you can estimate the test priors without labeling test data using BBSE (Black Box Shift Estimation), then plug the estimate into the prior correction term above. (arXiv)
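A minimal sketch of BBSE, assuming hard predictions (label indices 0..K-1) from your model on a labeled held-out split drawn from the fine-tune distribution and on unlabeled deployment data. It estimates the joint confusion matrix C[i][j] = P(ŷ = i, y = j) and the deployment prediction rates μ, solves C w = μ for the importance weights w(y) = q(y)/p(y), and returns q. C must be invertible; clipping negative estimates and renormalizing is a practical safeguard added here, not part of the basic estimator. Function names and the toy data are hypothetical.

```python
from typing import List, Sequence


def solve_linear(a: List[List[float]], b: List[float]) -> List[float]:
    """Gaussian elimination with partial pivoting (fine for small K)."""
    n = len(b)
    m = [row[:] + [b[i]] for i, row in enumerate(a)]
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(m[r][col]))
        m[col], m[pivot] = m[pivot], m[col]
        for r in range(col + 1, n):
            f = m[r][col] / m[col][col]
            for c in range(col, n + 1):
                m[r][c] -= f * m[col][c]
    x = [0.0] * n
    for r in range(n - 1, -1, -1):
        x[r] = (m[r][n] - sum(m[r][c] * x[c] for c in range(r + 1, n))) / m[r][r]
    return x


def bbse_target_prior(val_true: Sequence[int], val_pred: Sequence[int],
                      test_pred: Sequence[int], k: int) -> List[float]:
    """Black Box Shift Estimation: solve C w = mu, then q(y) = w(y) * p(y)."""
    # Joint confusion matrix C[i][j] = P(pred = i, true = j) on labeled held-out data.
    conf = [[0.0] * k for _ in range(k)]
    for y, yhat in zip(val_true, val_pred):
        conf[yhat][y] += 1.0 / len(val_true)
    # mu[i] = P(pred = i) on unlabeled deployment data.
    mu = [0.0] * k
    for yhat in test_pred:
        mu[yhat] += 1.0 / len(test_pred)
    w = solve_linear(conf, mu)  # importance weights q(y) / p(y)
    p = [sum(conf[i][j] for i in range(k)) for j in range(k)]  # held-out label prior
    q = [max(wj * pj, 0.0) for wj, pj in zip(w, p)]  # clip negatives from estimation noise
    z = sum(q)
    return [qj / z for qj in q]


val_true = [0] * 80 + [1] * 20
val_pred = [0] * 72 + [1] * 8 + [0] * 6 + [1] * 14
test_pred = [0] * 60 + [1] * 40
print(bbse_target_prior(val_true, val_pred, test_pred, k=2))  # ≈ [0.5, 0.5]
```

The toy held-out split has an 80/20 label prior; with deployment predictions at 60/40, the estimated deployment prior comes out at 50/50, which could then serve as the target prior in the logit-adjustment step.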
Inference technique 3: calibration for label-token bias (useful with label strings)
If your labels are natural words (not special tokens), the LM can be biased toward some labels independent of input. Contextual calibration estimates this bias using content-free inputs and corrects it without additional labeled data. (Proceedings of Machine Learning Research)
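A minimal sketch, assuming you already have each label's probability renormalized over the K labels, both for the real input and for a few content-free inputs (the paper uses strings such as "N/A"). Following contextual calibration, it sets W = diag(p_cf)^-1 and b = 0, i.e. divides each label probability by its content-free probability. The paper passes W p̂ through a softmax; this sketch renormalizes instead, which leaves the argmax unchanged. The probability values are made up.

```python
from typing import List, Sequence


def normalize(v: Sequence[float]) -> List[float]:
    s = sum(v)
    return [x / s for x in v]


def content_free_prior(cf_probs: Sequence[Sequence[float]]) -> List[float]:
    """Average label distribution over several content-free inputs, renormalized."""
    k = len(cf_probs[0])
    mean = [sum(p[j] for p in cf_probs) / len(cf_probs) for j in range(k)]
    return normalize(mean)


def calibrate(label_probs: Sequence[float], p_cf: Sequence[float]) -> List[float]:
    """W = diag(p_cf)^-1, b = 0: divide by the content-free probability, then renormalize."""
    return normalize([p / c for p, c in zip(normalize(label_probs), p_cf)])


# Hypothetical label distributions for content-free inputs such as "N/A" and "".
p_cf = content_free_prior([[0.7, 0.3], [0.6, 0.4], [0.8, 0.2]])
print(calibrate([0.6, 0.4], p_cf))  # ≈ [0.39, 0.61]: the prediction flips to label 1
```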
Practical “best default” plan for your exact pipeline
1) Make labels easy to score
- represent each topic as a canonical short token sequence (ideally 1 token via special tokens),
- use candidate scoring instead of free generation.
2) Fine-tune with label-only loss
- mask prompt tokens (completion-only). (Hugging Face)
3) Add imbalance handling
Pick one to start:
- class-balanced loss via effective number of samples, (arXiv)
- or a DRW schedule (warm up, then reweight), (arXiv)
- or focal loss if head classes are overwhelmingly easy. (arXiv)
4) Stabilize updates
- partial freezing or PEFT-style tuning,
- early stop on macro-F1 / per-class recall.
5) Correct priors at inference
- apply logit adjustment using fine-tune priors and your intended deployment prior. (arXiv)
Common pitfalls that cause your exact behavior
- Multi-token labels with different lengths: comparing only first-token probability biases results; always score the full label sequence.
- Not masking prompts: label decision gets too little gradient. (Hugging Face)
- Resampling before splitting: leakage inflates validation and hides failures. (imbalanced-learn.org)
- Assuming constrained decoding always works: there are reported cases where prefix_allowed_tokens_fn is not followed. (GitHub)