Handling class imbalance when finetuning a decoder model on text generation

I am trying to reformulate a classification task as a topic generation task using a custom pretrained decoder model (~xM params).

I have pretrained the model on a labelled pretraining dataset (~20M records) which has an error rate of X% in the ground-truth labels for some records.
I also have a labelled finetuning dataset (~10k records) of high-quality, error-free labelled pairs. The issue is that this dataset is class-imbalanced, and when I finetune the full model on it and run inference on a test set, the distribution of the model's predictions mirrors the distribution of the finetuning dataset.

What techniques can be used to stop this from happening? Thanks a ton.


hmm…?


Why your predictions match the fine-tune label distribution

When you fine-tune a decoder LM with standard cross-entropy to generate a label/topic string, the training objective tends to bake in the label prior of the fine-tune set. On a small, imbalanced dataset, the easiest way to reduce loss is often “guess the common label more often,” especially when inputs are ambiguous or when most loss is spent modeling the prompt rather than the label.

This is a well-studied “long-tail” failure mode: naive training becomes biased toward frequent classes, and methods like logit adjustment are explicitly motivated as fixes for this phenomenon. (arXiv)


Step 0: Decide the target behavior (this changes the correct fix)

You need to choose which of these you actually want:

  1. Equal importance across classes (better macro-F1 / minority recall)
    You want training to not be dominated by head classes.

  2. Correct real-world frequency (calibration to deployment prior)
    You want predicted frequencies to match the true base rates at inference time (which may differ from your fine-tune set).

Most practical setups do both:

  • train to improve minority performance;
  • then calibrate/correct priors at inference.

Highest-impact change: stop free-form generation and turn inference into scoring among known labels

If the label space is fixed (K topics), do not let decoding roam. Instead score each label candidate and pick the best.

Candidate scoring (recommended)

For each label y represented as a short token sequence, compute:

s(y) = \log P(y \mid x)

Then predict:

\hat{y} = \arg\max_y s(y)

Why this helps:

  • avoids greedy/beam “mode seeking” that favors frequent labels,
  • makes prior correction (below) trivial,
  • removes label-string length and tokenization artifacts from decoding as much as possible.
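
A minimal sketch of candidate scoring with a Hugging Face-style causal LM and tokenizer; the helper names `score_label` / `predict_label` are mine, not a library API:

```python
import torch

def score_label(model, tokenizer, prompt: str, label: str) -> float:
    """Return the sum of log P(label tokens | prompt) under the decoder."""
    prompt_ids = tokenizer(prompt, return_tensors="pt").input_ids
    label_ids = tokenizer(label, add_special_tokens=False, return_tensors="pt").input_ids
    input_ids = torch.cat([prompt_ids, label_ids], dim=1)
    with torch.no_grad():
        logits = model(input_ids).logits               # (1, seq_len, vocab)
    label_len = label_ids.shape[1]
    # the logit at position t predicts token t+1, so take the slice that predicts the label span
    pred_logits = logits[0, -label_len - 1:-1, :]
    log_probs = torch.log_softmax(pred_logits, dim=-1)
    token_logps = log_probs.gather(1, label_ids[0].unsqueeze(1)).squeeze(1)
    return token_logps.sum().item()

def predict_label(model, tokenizer, prompt: str, candidates: list[str]) -> str:
    scores = {y: score_label(model, tokenizer, prompt, y) for y in candidates}
    return max(scores, key=scores.get)
```

If label strings tokenize to different lengths, also try length-normalized scores (sum divided by the number of label tokens) and check on validation which ranking works better.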

If you must constrain generation with prefix_allowed_tokens_fn, note that there are real-world reports where constraints are not respected under some conditions; test thoroughly. (GitHub)


Training technique 1: compute loss only on the label tokens (completion-only / prompt masking)

A common hidden cause of majority collapse in “classification-as-generation” is that you compute loss over the entire prompt+label sequence. The prompt can dominate gradient updates, leaving the label decision under-trained.

Use “completion-only” training: mask prompt tokens so loss applies only to the label span. TRL documents this pattern via DataCollatorForCompletionOnlyLM, and explicitly notes it works when packing=False. (Hugging Face)

Practical effect:

  • more gradient budget on “choose correct label,”
  • less incentive to fall back to the fine-tune prior.
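
A sketch of that setup with TRL, assuming a TRL version that ships `DataCollatorForCompletionOnlyLM` and training texts formatted like `<text> ### Label: <topic>` (the `### Label:` response template is a placeholder of mine):

```python
from trl import SFTConfig, SFTTrainer, DataCollatorForCompletionOnlyLM

response_template = "### Label:"   # must literally appear in every training example
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)

trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,   # examples like "<text>\n### Label: <topic>" in a "text" column
    args=SFTConfig(output_dir="out", dataset_text_field="text", packing=False),  # packing must stay off
    data_collator=collator,
    # depending on TRL version you may also need to pass the tokenizer explicitly
)
trainer.train()
```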

Training technique 2: rebalance the optimizer’s view of the data

You typically need sampling and/or loss shaping.

A) Balanced or temperature sampling

Instead of sampling examples proportional to dataset frequency, adjust class sampling so minority classes appear often enough to matter (balanced batches, or smoothed sampling).

Pitfall: do not resample the full dataset before train/test split; it creates leakage and inflated metrics. (imbalanced-learn.org)
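
One way to implement smoothed (temperature) sampling is PyTorch's `WeightedRandomSampler`; here `labels` and `train_dataset` are placeholders for your post-split fine-tune data:

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler

# labels: one class id per training example (computed AFTER the train/val split)
counts = np.bincount(labels)
T = 0.5                                                # 1.0 keeps the original distribution, 0.0 is fully balanced
per_class_weight = counts.astype(float) ** (T - 1.0)   # class sampling prob ends up ∝ counts**T
sample_weights = per_class_weight[labels]              # one weight per example

sampler = WeightedRandomSampler(
    weights=torch.as_tensor(sample_weights, dtype=torch.double),
    num_samples=len(labels),
    replacement=True,
)
loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)
```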

B) Class-balanced loss (effective number of samples)

A robust default is class-balanced weighting using the “effective number of samples” idea:

E(n) = \frac{1-\beta^n}{1-\beta}

Use weights inversely proportional to E(n) rather than raw 1/n. This avoids extreme weights when a class is tiny. (arXiv)
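
A minimal implementation of those weights (normalized so they average to 1); the example counts are made up:

```python
import numpy as np

def class_balanced_weights(counts, beta=0.999):
    """Cui et al. (2019): weight_c ∝ (1 - beta) / (1 - beta**n_c)."""
    counts = np.asarray(counts, dtype=np.float64)
    effective_num = (1.0 - np.power(beta, counts)) / (1.0 - beta)
    weights = 1.0 / effective_num
    return weights * len(counts) / weights.sum()   # normalize to mean 1

print(class_balanced_weights([9000, 700, 200, 100]))   # hypothetical class counts from the 10k set
```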

C) Focal loss (focus on hard examples)

Focal loss down-weights easy (often majority) examples so they don’t dominate training and pushes learning onto harder cases. (arXiv)
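
A standard multi-class focal loss in PyTorch; the optional `alpha` per-class weights can be combined with the class-balanced weights above:

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, gamma=2.0, alpha=None):
    """(1 - p_t)^gamma * CE; logits are per-class scores, targets are class ids."""
    ce = F.cross_entropy(logits, targets, reduction="none")   # = -log p_t
    pt = torch.exp(-ce)
    loss = (1.0 - pt) ** gamma * ce
    if alpha is not None:                                     # optional per-class weights
        loss = alpha[targets] * loss
    return loss.mean()
```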

D) LDAM + Deferred Re-Weighting (two-phase training)

LDAM introduces a minority-favoring margin, and DRW recommends a schedule: learn representations first, then apply reweighting. This is useful when reweighting early destabilizes training. (arXiv)

E) Balanced Softmax (prior-aware objective)

Balanced Softmax is designed to correct biased gradients under long-tail training and explicitly address mismatch between training and testing label distributions. (NeurIPS Proceedings)
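
In practice Balanced Softmax amounts to adding the log class prior to the logits during training only and dropping it at test time. A sketch, assuming you can reduce each example to one logit per label (e.g. single-token labels, or the candidate scores above):

```python
import torch
import torch.nn.functional as F

# counts: per-class counts in the fine-tune set, shape (K,)
log_prior = torch.log(counts.float() / counts.sum())

def balanced_softmax_loss(label_logits, targets):
    """Train-time loss: softmax over (logit + log prior). At inference, use the raw logits."""
    return F.cross_entropy(label_logits + log_prior, targets)
```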


Training technique 3: reduce “10k imbalanced fine-tune overwrites everything”

Full fine-tuning on 10k skewed data commonly increases prior imprinting and forgetting. Two practical mitigations:

  • freeze most layers (tune only top blocks),
  • or use parameter-efficient approaches (adapters/LoRA).

Even without adding new methods, lowering learning rate and early-stopping on macro metrics helps.
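
For the parameter-efficient option, a minimal LoRA setup with the `peft` library; `target_modules` must be adapted to whatever your custom decoder names its attention projections:

```python
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],   # placeholder: match your model's module names
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()         # sanity check: only a small fraction should be trainable
```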

(Implementation note: if you use Hugging Face Trainer, class-weighted objectives and custom losses are commonly implemented by subclassing Trainer and overriding compute_loss. (Hugging Face Forums))
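
A sketch of that pattern: a `Trainer` subclass applying per-token class weights, assuming prompt positions in `labels` are already masked to -100 and `class_weights` is a vocab-sized tensor holding the class weight at each label-token id and 1.0 elsewhere:

```python
import torch.nn.functional as F
from transformers import Trainer

class WeightedLabelTrainer(Trainer):
    def __init__(self, *args, class_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights          # tensor of shape (vocab_size,)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        # shift so that position t predicts token t+1, as in causal LM training
        logits = outputs.logits[:, :-1, :]
        shifted = labels[:, 1:]
        weight = self.class_weights.to(logits.device) if self.class_weights is not None else None
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            shifted.reshape(-1),
            weight=weight,
            ignore_index=-100,
        )
        return (loss, outputs) if return_outputs else loss
```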


Inference technique 1: logit adjustment / prior correction (directly targets your symptom)

If your complaint is specifically “predicted frequencies mirror fine-tune frequencies,” correct the prior explicitly at inference.

A standard form:

\mathrm{score}(y) = s(y) - \tau \log p_{\mathrm{ft}}(y) + \tau \log p_{\mathrm{target}}(y)

where:

  • p_{\mathrm{ft}}(y) is the fine-tune label prior (from your 10k),
  • p_{\mathrm{target}}(y) is the deployment prior you want (or uniform if you want equalized rates),
  • \tau is a strength you tune on validation.

This is the central idea in “logit adjustment” for long-tail learning. (arXiv)
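
Combined with candidate scoring, the correction is a few lines; `p_ft` comes from your 10k label counts and `p_target` is whatever deployment prior you choose (uniform in the commented example):

```python
import numpy as np

def adjusted_argmax(scores, p_ft, p_target, tau=1.0):
    """scores: dict label -> log P(label | x); returns the prior-corrected prediction."""
    labels = list(scores)
    s = np.array([scores[y] for y in labels])
    adj = s - tau * np.log([p_ft[y] for y in labels]) + tau * np.log([p_target[y] for y in labels])
    return labels[int(np.argmax(adj))]

# p_ft = {label: count / 10_000 for label, count in finetune_counts.items()}
# p_target = {label: 1 / len(p_ft) for label in p_ft}   # uniform target rates
# tune tau (e.g. 0.5–2.0) on validation macro-F1
```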

When this works best:

  • the model’s conditional ranking signal is decent,
  • but the prior dominates the final decision.

Inference technique 2: if deployment priors are unknown, estimate label shift

If you suspect test/deployment priors differ from fine-tune priors, you can estimate the test priors without labeling test data using BBSE (Black Box Shift Estimation), then plug the estimate into the prior correction term above. (arXiv)
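
A compact NumPy sketch of BBSE; it needs labeled held-out data from the fine-tune (source) distribution plus unlabeled test predictions, and the confusion matrix can be ill-conditioned for very rare classes, so treat the estimate with care:

```python
import numpy as np

def bbse_test_prior(val_preds, val_labels, test_preds, num_classes):
    """Estimate the test label prior via Black Box Shift Estimation (Lipton et al., 2018)."""
    C = np.zeros((num_classes, num_classes))
    for yhat, y in zip(val_preds, val_labels):
        C[yhat, y] += 1
    C /= len(val_labels)                                   # C[i, j] ≈ P(f(x)=i, y=j) on source data
    mu = np.bincount(test_preds, minlength=num_classes) / len(test_preds)
    w = np.linalg.solve(C, mu)                             # w[y] ≈ q(y) / p(y) under label shift
    p_source = np.bincount(val_labels, minlength=num_classes) / len(val_labels)
    q = np.clip(w, 0.0, None) * p_source
    return q / q.sum()                                     # plug this in as p_target above
```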


Inference technique 3: calibration for label-token bias (useful with label strings)

If your labels are natural words (not special tokens), the LM can be biased toward some labels independent of input. Contextual calibration estimates this bias using content-free inputs and corrects it without additional labeled data. (Proceedings of Machine Learning Research)
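
A sketch of the correction step; `content_free_probs` would be the softmaxed candidate-label scores when the input text is replaced by content-free strings such as "N/A" or the empty string, averaged as in the paper:

```python
import numpy as np

def contextual_calibrate(label_probs, content_free_probs):
    """Divide out the input-independent label bias, then renormalize."""
    calibrated = np.asarray(label_probs) / np.asarray(content_free_probs)
    return calibrated / calibrated.sum()
```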


Practical “best default” plan for your exact pipeline

1) Make labels easy to score

  • represent each topic as a canonical short token sequence (ideally 1 token via special tokens),
  • use candidate scoring instead of free generation.

2) Fine-tune with label-only loss

3) Add imbalance handling

Pick one to start:

  • class-balanced loss via effective number of samples, (arXiv)
    or
  • DRW schedule (warmup then reweight), (arXiv)
    or
  • focal loss if head classes are overwhelmingly easy. (arXiv)

4) Stabilize updates

  • partial freezing or PEFT-style tuning,
  • early stop on macro-F1 / per-class recall.

5) Correct priors at inference

  • apply logit adjustment using fine-tune priors and your intended deployment prior. (arXiv)

Common pitfalls that cause your exact behavior

  • Multi-token labels with different lengths: comparing only first-token probability biases results; always score the full label sequence.
  • Not masking prompts: label decision gets too little gradient. (Hugging Face)
  • Resampling before splitting: leakage inflates validation and hides failures. (imbalanced-learn.org)
  • Assuming constrained decoding always works: there are reported cases where prefix_allowed_tokens_fn is not followed. (GitHub)

What you’re seeing is a very common and very predictable behaviour of decoder‑only models when you turn a classification problem into a generation problem: the model learns the label distribution of the fine‑tuning set as a prior, and unless you actively counteract it, that prior dominates inference.

Let’s break this down cleanly and then walk through the techniques that actually work.


:brain: Why this happens
A decoder model trained to generate labels is essentially learning:

P(\text{label} \mid \text{input})

But during fine‑tuning, because the dataset is small and imbalanced, the model also implicitly learns:

P(\text{label}) \approx \text{empirical distribution of fine‑tuning set}

When the fine‑tuning set is small (10k) and the pretraining set is noisy, the model overfits to the clean but skewed distribution.

This is especially strong in decoder‑only architectures because they are trained autoregressively and treat the label token(s) as part of the language distribution.


:hammer_and_wrench: Techniques that actually work

Below are the methods that reliably break the “distribution copying” behaviour.


  1. Loss Re‑weighting / Class‑Balanced Loss
    This is the most direct fix.

You assign higher loss weight to minority classes and lower weight to majority classes.

Common strategies:

  • Inverse frequency weighting
  • Effective number of samples (Cui et al., 2019)
  • Focal loss (helps with imbalance + noisy labels)

This prevents the model from learning the skewed prior.


  2. Oversampling / Undersampling (but carefully)
    For generation tasks, oversampling minority classes works surprisingly well.

But you must avoid:

  • duplicating identical samples (causes overfitting)
  • oversampling too aggressively (destroys natural priors)

A good rule of thumb:

  • oversample until each class is within 2–3× of each other, not fully balanced.

  3. Use instruction‑style prompts that force the model to ignore priors
    This is underrated.

Instead of training the model to output just the label, train it to follow a task instruction:

Input: <text> Task: Identify the correct topic regardless of frequency. Output: <topic>

This reduces the model’s reliance on unconditional priors.


  4. Add a calibration layer at inference
    Even if the model is biased, you can correct it post‑hoc.

Two strong methods:

a) Temperature scaling per class
You adjust logits so that rare classes are not suppressed.

b) Prior correction
If you know the true expected class distribution, you can apply Bayes correction:

P_{\text{corrected}}(y \mid x) \propto \frac{P_{\text{model}}(y \mid x)}{P_{\text{model}}(y)}

This is extremely effective for generative classifiers.


  5. Use a two‑stage model: encoder classifier + decoder generator
    This is a hybrid approach:

  a) Use a small encoder classifier trained with class‑balanced loss.

  b) Feed its predicted class into the decoder as a conditioning token.

This gives you:

  • balanced classification
  • rich generative output

  6. Freeze most of the decoder during fine‑tuning
    If you fine‑tune the entire decoder on a small imbalanced dataset, it will absolutely learn the skew.

Instead:

  • freeze 90–99% of layers
  • fine‑tune only the top 1–3 transformer blocks
  • or use LoRA adapters

This preserves the pretrained distributional knowledge and reduces overfitting to the fine‑tuning distribution.


  7. Use synthetic balancing
    If you can generate synthetic examples for minority classes (even using your own model), you can rebalance the dataset without oversampling.

This works well for topic generation because the model can generate paraphrases.


  8. Add a “label prior suppression” token
    This is a trick used in some LLM classification setups.

You add a special token like:

<force_fairness>

And train the model so that when this token is present, it ignores prior label frequencies.

This works because decoder models are extremely sensitive to conditioning tokens.


:test_tube: What I’d recommend for your exact setup
Given your description:

  • large noisy pretraining set
  • small clean but imbalanced fine‑tuning set
  • decoder‑only architecture
  • generation‑based classification

The most effective combination is:

  1. Freeze most of the model + LoRA adapters
    Prevents overfitting to skew.

  2. Class‑balanced loss (effective number of samples)
    Prevents the model from copying the distribution.

  3. Instruction‑style prompting
    Reduces reliance on unconditional priors.

  4. Optional: post‑hoc calibration
    If you need perfect distribution alignment.

This combo usually fixes the issue without needing to modify the dataset too much.

