Is Hugging Face Datasets suitable for DDP training?

Hi, I’m currently reproducing a project that uses the PyTorch Lightning Trainer to manage DDP training, but its home-made dataset implementation blows up my CPU memory. The implementation is doubly bad: the raw data is ~300 GB, yet during training it ends up needing more than 720 GB. So I wonder whether an HF Dataset would perform better.

I attach its dataset implementation below; I think the most suspicious part is the CachedDataset wrapper:

from typing import Any, TypeVar

from multiprocessing import Manager

import torch
from torch.utils.data import Dataset

__all__ = ["CachedDataset"]


class NumpiedTensor:
    def __init__(self, tensor: torch.Tensor) -> None:
        self.array = tensor.numpy()

    def to_tensor(self) -> torch.Tensor:
        return torch.tensor(self.array)


def numpize_sample(sample: Any) -> Any:
    if isinstance(sample, torch.Tensor):
        return NumpiedTensor(sample)
    elif isinstance(sample, tuple):
        return tuple(numpize_sample(s) for s in sample)
    elif isinstance(sample, list):
        return [numpize_sample(s) for s in sample]
    elif isinstance(sample, dict):
        return {k: numpize_sample(v) for k, v in sample.items()}
    else:
        return sample


def tensorize_sample(sample: Any) -> Any:
    if isinstance(sample, NumpiedTensor):
        return sample.to_tensor()
    elif isinstance(sample, tuple):
        return tuple(tensorize_sample(s) for s in sample)
    elif isinstance(sample, list):
        return [tensorize_sample(s) for s in sample]
    elif isinstance(sample, dict):
        return {k: tensorize_sample(v) for k, v in sample.items()}
    else:
        return sample


T_co = TypeVar("T_co", covariant=True)


class CachedDataset(Dataset[T_co]):
    def __init__(self, dataset: Dataset[T_co]) -> None:
        self.dataset = dataset

        self.manager = Manager()
        self.cache = self.manager.dict()

    def __len__(self) -> int:
        return len(self.dataset)  # type: ignore[arg-type]

    def __getitem__(self, index: int) -> Any:
        if index not in self.cache:
            self.cache[index] = numpize_sample(self.dataset[index])

        return tensorize_sample(self.cache[index])

It wraps an HDF5Dataset:

from __future__ import annotations

from typing import Any

from pathlib import Path
import pickle as pkl

import torch
from torch.utils.data import Dataset

import h5py as h5

__all__ = [
    "RawHDF5Dataset",
    "HDF5Dataset",
]


class RawHDF5Dataset(Dataset[int]):
    def __init__(self, dataset_path: Path | str, grp_list: Path | str | list[str] | None = None) -> None:
        self.dataset_path = dataset_path

        if grp_list is None:
            with h5.File(self.dataset_path, "r") as f:
                self.grp_list = list(f.keys())
        elif isinstance(grp_list, (str, Path)):
            with open(grp_list, "rb") as f:
                self.grp_list = pkl.load(f)
        elif isinstance(grp_list, list):
            self.grp_list = grp_list
        else:
            raise NotImplementedError()
        self.grp_list.sort()

        self.f: h5.File | None = None

    def __len__(self) -> int:
        return len(self.grp_list)

    def __getitem__(self, index: int) -> dict[str, Any]:
        if self.f is None:
            self.f = h5.File(self.dataset_path, "r")

        return {k: v[:] for k, v in self.f[self.grp_list[index]].items()}

    def __del__(self) -> None:
        if self.f is not None:
            self.f.close()


class HDF5Dataset(RawHDF5Dataset):
    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
        return {k: torch.as_tensor(v) for k, v in super().__getitem__(index).items()}

I asked ChatGPT and Claude; both told me that every GPU process will create its own CachedDataset instance (and hence its own cache), but I’m not sure. Would an HF Dataset handle this better?

1 Like

I don’t have a multi-GPU setup, so I can’t test it here, but I’ve gathered the resources for now.

1 Like

Hi John,

Sorry for replying late. Thank you for your kind response.

I tried your solution in section 5.2 with the following script:

from pathlib import Path
import h5py
from datasets import Dataset, DatasetDict

DATA_PATH = Path('...')

splits = {
    "test": "hdf5/chunk.derev-hop160-tt.hdf5",
    "val": "hdf5/chunk.derev-hop160-cv.hdf5", 
    "train": "hdf5/chunk.derev-hop160-tr.hdf5",
}

def h5_generator(h5_path):
    with h5py.File(h5_path, "r") as f:
        grp_names = sorted(f.keys())
        for g in grp_names:
            grp = f[g]
            yield {
                "act": grp["act"][:],
                "wav": grp["wav"][:],
            }

dataset_dict = {}
for split_name, h5_path in splits.items():
    ds = Dataset.from_generator(
        h5_generator,
        gen_kwargs={"h5_path": DATA_PATH/h5_path},
        writer_batch_size=500,
        cache_dir=str(DATA_PATH)
    )
    dataset_dict[split_name] = ds

dataset = DatasetDict(dataset_dict)
dataset.save_to_disk(DATA_PATH/"hf_datasets")

It generates multiple shards of the dataset in Arrow format without any problem.

However, there’s one thing I don’t understand: when I check the saved dataset (e.g. the test split), it is composed of shards of around 480 MB each,
while the dataset_info for this split is

{
  "builder_name": "generator",
  "citation": "",
  "config_name": "default",
  "dataset_name": "generator",
  "dataset_size": 33296374180,
  "description": "",
  "download_checksums": {},
  "download_size": 0,
  "features": {
    "act": {
      "feature": {
        "feature": {
          "dtype": "float32",
          "_type": "Value"
        },
        "_type": "List"
      },
      "_type": "List"
    },
    "wav": {
      "feature": {
        "feature": {
          "dtype": "float32",
          "_type": "Value"
        },
        "_type": "List"
      },
      "_type": "List"
    }
  },
  "homepage": "",
  "license": "",
  "size_in_bytes": 33296374180,
  "splits": {
    "train": {
      "name": "train",
      "num_bytes": 33296374180,
      "num_examples": 3239,
      "shard_lengths": [
        500,
        500,
        500,
        500,
        500,
        500,
        239
      ],
      "dataset_name": "generator"
    }
  },
  "version": {
    "version_str": "0.0.0",
    "major": 0,
    "minor": 0,
    "patch": 0
  }
}

which intuitively suggests there should be only 7 shards according to the shard_lengths, but I actually have 33 GB / 480 MB ≈ 67 shards, which doesn’t make much sense to me. On the other hand, I also checked the cache_dir, where the cached version is stored, and it does contain 7 shards, each about 5 GB (except the last one).

Could you explain this weird behavior?

1 Like

Hmm…?


You’re seeing two different “sharding layers” at once:

  • Logical shards: what dataset_info.splits.*.shard_lengths is describing.
  • Physical shards: the actual .arrow files created by save_to_disk, controlled by max_shard_size / num_shards.

They are not the same thing, and HF Datasets does not keep them in sync anymore. That is why “7 shards in shard_lengths” and “~67 files on disk” both exist and both are “correct” in their own sense.

I’ll walk through what your script is doing and how HF Datasets behaves internally.


1. What your script actually does

You run:

def h5_generator(h5_path):
    with h5py.File(h5_path, "r") as f:
        grp_names = sorted(f.keys())
        for g in grp_names:
            grp = f[g]
            yield {
                "act": grp["act"][:],
                "wav": grp["wav"][:],
            }

ds = Dataset.from_generator(
    h5_generator,
    gen_kwargs={"h5_path": DATA_PATH/h5_path},
    writer_batch_size=500,
    cache_dir=str(DATA_PATH)
)

Key points:

  1. Dataset.from_generator(...) builds a temporary Arrow dataset in cache_dir.

    • By default, HF Datasets writes to disk every 1000 rows; you changed that to 500 via writer_batch_size=500. (Hugging Face Forums)
    • Each “writer batch” is a logical shard of ~500 rows.
  2. You then do:

    dataset = DatasetDict(dataset_dict)
    dataset.save_to_disk(DATA_PATH/"hf_datasets")
    

    save_to_disk will re-write the dataset into a new directory, using its own sharding logic based on max_shard_size and/or num_shards. It does not preserve your original writer-batch boundaries. (Hugging Face)

So there are effectively two separate representations:

  • A cached representation in cache_dir (your 7×~5 GB files).
  • A final “saved to disk” representation in hf_datasets/ (your ~67×480 MB files).

2. Why shard_lengths says 7 but you see ~67 .arrow files

2.1 What shard_lengths actually means now

dataset_info.splits["train"].shard_lengths is part of the legacy SplitInfo metadata that comes from TensorFlow Datasets semantics (number of examples per original shard). HF Datasets keeps it for backward compatibility, but it is effectively deprecated now:

  • In the datasets.splits code, you can see both shard_lengths and original_shard_lengths fields, with comments marking them as deprecated and kept only so dataset_infos stay compatible. (lovelace.cluster.earlham.edu)

So:

  • shard_lengths = [500, 500, 500, 500, 500, 500, 239]
  • is describing the original generator write batches used when constructing the dataset, not the current on-disk .arrow file layout after save_to_disk.

Hugging Face’s own docs emphasize that saving creates “arrow files” plus a dataset_info.json, but it doesn’t say the shard_lengths field is updated to match the new file slicing. (Hugging Face)

In other words:

shard_lengths is logical/legacy metadata, not a faithful description of how many .arrow files are in your saved folder.

2.2 How save_to_disk decides how many .arrow files to create

Dataset.save_to_disk shards data based on size in bytes, not your earlier writer_batch_size.

Internally, it calls into a sharding routine that uses:

  • num_shards (if you pass it explicitly), or
  • max_shard_size (if you don’t), with a default of 500 MB (config.MAX_SHARD_SIZE = "500MB"). (GitHub)

The logic (simplified) is:

num_shards = floor(dataset_nbytes / max_shard_size) + 1  (if num_shards is None)

Given your num_bytes ≈ 33,296,374,180:

  • In HF Datasets’ size parsing, “500MB” means 500 × 10^6 = 500,000,000 bytes (decimal MB; “500MiB” would be the 2^20-based unit).
  • 33,296,374,180 / 500,000,000 ≈ 66.6
  • → num_shards = 66 + 1 = 67.

You report:

  • ~33 GB total,
  • shard files around ~480 MB each,
  • total ~67 files.

That’s exactly what you expect for “dataset size divided by 500 MB”. The files showing up as ~480 MB rather than exactly 500 MB is mostly a units artifact (≈497 MB per shard in decimal units is ≈474 MiB, which file browsers often report as roughly 480 MB) plus a little Arrow overhead.
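
A quick back-of-the-envelope check of that arithmetic (plain Python, assuming “500MB” is parsed as 500 × 10^6 bytes as described above):

dataset_nbytes = 33_296_374_180        # dataset_info.dataset_size
max_shard_size = 500 * 10**6           # default config.MAX_SHARD_SIZE = "500MB"

num_shards = dataset_nbytes // max_shard_size + 1
print(num_shards)                                  # 67
print(round(dataset_nbytes / num_shards / 2**20))  # ~474 MiB per shard, i.e. the "~480 MB" files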

So the picture is:

  • Generator build phase:

    • writer_batch_size=500 → HF creates 7 logical shards of 500 examples (last smaller).
    • These are stored under cache_dir as ~5 GB Arrow files (your 7 files).
  • save_to_disk phase:

    • Ignores those original shards.
    • Re-shards based on size, using ~500 MB target size.
    • Produces ~67 .arrow files of ~480 MB each under hf_datasets/....

dataset_info.splits["train"].shard_lengths still reflects the first phase (500-example shards), not the second.


3. Why you see “7×5 GB in cache_dir” vs “67×480 MB in hf_datasets”

This comes from the two different uses of “shard” inside HF Datasets:

  1. Cache-level shards (generator / map / transforms)

    • Every call to from_generator or .map uses a cache writer.
    • This writer flushes data to disk every writer_batch_size examples. (Hugging Face Forums)
    • Those flushes create intermediate Arrow tables, tracked in dataset.cache_files.
    • That’s where your 7×~5 GB files in cache_dir come from (500 examples per writer batch, but examples are huge arrays so each file is big).
  2. Final dataset shards (save_to_disk / push_to_hub)

    • save_to_disk re-chunks the dataset based on overall byte size, using max_shard_size (default 500 MB) or num_shards. (Hugging Face)
    • Each new shard is written as a separate .arrow file.
    • That’s your ~67 files of ~480 MB.

The dataset_info you’re inspecting is just copying the DatasetInfo object from the generator stage, including the old shard_lengths. That metadata is not regenerated to match the final arrow files.

If you want to see the actual physical files the final dataset is using, you should check:

from datasets import load_from_disk

ds = load_from_disk(DATA_PATH/"hf_datasets")["train"]
print(ds.cache_files)

ds.cache_files will list the .arrow files and the number of rows they each contain. (Hugging Face)

That list should have ~67 entries for the saved dataset, even though dataset_info.splits["train"].shard_lengths still has 7 numbers.


4. How to control the number/size of saved shards

If you want the final saved dataset to have a particular sharding (not 67 × 480 MB), you can control it directly when saving:

  • Limit shard size by bytes:

    ds.save_to_disk("hf_datasets", max_shard_size="5GB")
    
  • Or specify exact number of shards:

    ds.save_to_disk("hf_datasets", num_shards=7)
    

These options are documented (or at least exemplified) in HF Datasets’ docs and helper code. (Hugging Face)

That way:

  • max_shard_size="5GB" would give you roughly the same 7×5 GB layout you see in your generator cache.
  • num_shards=7 forces exactly 7 files regardless of size, as long as they fit.

Note: the shard_lengths field in dataset_info is still informational / legacy and may not equal the physical file count even then; to inspect the real structure, always look at cache_files and the actual .arrow files.


5. Sanity check with your numbers

Let’s reconcile your numbers explicitly:

  • dataset_info.dataset_size = 33,296,374,180 bytes (~31 GiB).

  • shard_lengths = [500, 500, 500, 500, 500, 500, 239] → 3239 examples total.

  • cache_dir:

    • 7 cache files, each ~5 GB → matches the 7 generator “writer batches” of 500 rows (your writer_batch_size=500).
  • hf_datasets/:

    • 67 arrow files, each ~480 MB on disk.
    • 33.3 GB / 67 ≈ 0.50 GB per file → consistent with the default max_shard_size of 500 MB.

So the behavior is exactly:

  • Generator-level sharding controlled by writer_batch_size → 7 logical shards (what shard_lengths remembers).
  • Final save-level sharding controlled by max_shard_size (default 500 MB) → ~67 physical arrow files.

No data is lost or duplicated; it’s just two distinct sharding layers.


TL;DR

  • The shard_lengths in dataset_info comes from the generator phase and describes logical shard sizes (here: batches of 500 examples). It’s deprecated / legacy metadata and not kept in sync with final .arrow files. (lovelace.cluster.earlham.edu)
  • save_to_disk re-shards the dataset based on byte size, using a default max_shard_size of 500 MB, which is why you see ~67 arrow files of ~480 MB each. (GitHub)
  • The 7 big (~5 GB) files in cache_dir are from from_generator’s writer_batch_size=500, not from save_to_disk.
  • To understand the actual on-disk layout of the final dataset, ignore shard_lengths and inspect ds.cache_files and the .arrow files themselves.
1 Like

Hi John,

I have another question, about data preprocessing. Now that I have an HF dataset, I need to apply on-the-fly data augmentation to each item before collecting them into a batch. Before using HF Datasets, I would just define a collate function and pass it to DataLoader(collate_fn=...), but datasets also provides a set_transform() method that seems to play the same role as collate_fn, so I wonder which one is the better choice, or whether they are actually equivalent.

1 Like

They seem similar, but they are different things.


They are not equal. They can look similar because both can “modify data on the fly”, but they run at different stages and receive different inputs.

Background: what changes when you switch to :hugs: Datasets

A :hugs: datasets.Dataset is still a “map-style dataset” from PyTorch’s point of view: DataLoader repeatedly calls dataset[i] to fetch samples, then it batches them. PyTorch describes automatic batching as “fetch a list of samples, then call collate_fn on that list”. (PyTorch Docs)

:hugs: Datasets additionally exposes a formatting layer (via set_format and set_transform/with_transform) that changes what dataset[i] returns. The docs explicitly call set_transform/with_transform a formatting transform that is applied “right before returning the objects in __getitem__”, takes a batch as a dict, returns a batch, and replaces the format defined by set_format(). (Hugging Face)

So you now have two hooks:

  1. a hook that changes what each item looks like when fetched (__getitem__ formatting), and
  2. a hook that changes how a list of items becomes one batch (collation).

That is the core difference.

What each hook actually does

1) dataset.set_transform() / dataset.with_transform()

What it is:

  • A formatting transform for __getitem__.
  • Applied on-the-fly when __getitem__ is called.
  • Signature is batch dict in → batch dict out (“takes a batch (as a dict) as input and returns a batch”).
  • It replaces set_format() formatting.
  • with_transform() returns a new dataset object; set_transform() mutates in place. (Hugging Face)

Implication:

  • This is best for per-item work (decode, per-example augmentation, tokenize one example, convert types).
  • You are changing what each sample looks like before the DataLoader sees it.

2) DataLoader(collate_fn=...)

What it is:

  • A batch assembly function.
  • In the default (automatic batching) mode, PyTorch fetches a list of samples then calls collate_fn(list_of_samples). PyTorch even gives the “roughly equivalent” pseudocode:
    yield collate_fn([dataset[i] for i in indices]). (PyTorch Docs)
  • collate_fn is where you handle padding, stacking, mask creation, and anything that needs to look across the batch. (PyTorch Docs)

Implication:

  • This is best for per-batch work (dynamic padding, stacking tensors, building batch-level augmentations like mixup or in-batch negatives).

Why they “seem” the same

Because both can implement augmentation, and both run during data loading. But they are different units of work:

  • set_transform/with_transform: “Given one requested index (or a slice), what do I return from dataset[...]?” (Hugging Face)
  • collate_fn: “Given a list of already-returned examples, how do I combine them into one batch?” (PyTorch Docs)

Also, the input structures differ:

  • HF transform: dict of columns (batch as dict). (Hugging Face)
  • PyTorch collate: list of samples (list of dicts/tuples). (PyTorch Docs)

That mismatch alone causes a lot of confusion.

Which is “better” for on-the-fly per-item augmentation?

If you truly mean: “augment each item independently before batching”, then:

Prefer with_transform (or set_transform) for the augmentation itself

Because it is literally designed to alter __getitem__ output on-the-fly. (Hugging Face)

Still keep collate_fn when you need batching logic

Most real pipelines still need collation for at least one of:

  • padding variable length sequences
  • stacking images into a tensor
  • building masks
  • handling optional fields

PyTorch expects that “after fetching a list of samples … collate_fn is used to collate lists of samples into batches.” (PyTorch Docs)

So the practical answer is usually: use both.

  • with_transform: per-example augmentation
  • collate_fn: batch assembly (and any batch-dependent augmentation)

Transformer ecosystem note: “data collator” is just a collate_fn

In Transformers docs, a “data collator” is explicitly an object/function that “will form a batch by using a list of dataset elements as input” and may apply padding; some even apply random augmentation “on the formed batch” (example: random masking for LM). (Hugging Face)

So if you are in NLP, it is common to:

  • do tokenization via map (cached, efficient), and
  • do padding or random masking via a data collator (collate_fn).
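
For example, a minimal sketch of that split (the checkpoint name and the train_ds “text” column are placeholders, not from your setup):

# deps: datasets, transformers, torch
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Tokenize once via map (cached on disk); no padding yet.
train_ds = train_ds.map(lambda ex: tokenizer(ex["text"], truncation=True), batched=True)
train_ds = train_ds.remove_columns(["text"])

# Pad dynamically per batch via a data collator, i.e. a collate_fn.
collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors="pt")
loader = DataLoader(train_ds, batch_size=32, collate_fn=collator)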

Important pitfalls people hit (so you do not)

1) Performance: set_transform can be much slower than expected

There is a real report where a trivial set_transform makes iteration ~10× slower in a benchmark. (GitHub)
If performance matters, benchmark:

  • augmentation in with_transform
    vs
  • augmentation inside collate_fn
    vs
  • preprocessing via map plus lightweight on-the-fly augmentation

2) Output formatting surprises: set_transform replaces set_format

HF docs are explicit that the formatting transform “replaces the format defined by set_format()”. (Hugging Face)
So if you expected “torch formatting + transform”, you can get type surprises.

3) Multi-worker randomness and reproducibility

With num_workers > 0, PyTorch passes the dataset and collate_fn to worker processes; dataset access and transforms run in workers. (PyTorch Docs)
If your augmentation is random, you need to think about seeding per worker.
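
A minimal per-worker seeding sketch (the standard PyTorch pattern; train_ds and the loader arguments are placeholders):

import random
import numpy as np
import torch
from torch.utils.data import DataLoader

def seed_worker(worker_id):
    # torch.initial_seed() already differs per worker; derive NumPy/random seeds from it
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

loader = DataLoader(train_ds, batch_size=32, num_workers=4, worker_init_fn=seed_worker)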

4) Transform signature mismatch

HF transform takes a batch dict and returns a batch dict. (Hugging Face)
If you write it like a pure “single example in → single example out” function, you often get subtle bugs.

A simple “decision checklist”

Use with_transform/set_transform when:

  • augmentation is independent per example
  • you want dataset[i] to already return augmented tensors/features
  • you want the augmentation to be part of the dataset view itself (Hugging Face)

Use collate_fn (or Transformers data collator) when:

  • you need padding/stacking
  • augmentation needs batch context
  • you want to operate on a list of samples right before the model step (PyTorch Docs)

Use both when:

  • per-item augmentation exists and you still need padding/stacking (common case).

Minimal example patterns

Pattern A: per-item augmentation in with_transform, simple collate

# deps: datasets, torch
from torch.utils.data import DataLoader

def aug(batch):
    # batch is a dict-of-columns. HF calls this a "batch (as a dict)".
    # Do per-example random augmentation here.
    return batch

train_ds = train_ds.with_transform(aug)

loader = DataLoader(train_ds, batch_size=32)  # default collate stacks if shapes match

Pattern B: per-item tokenize/augment, per-batch pad in collate (NLP)

# deps: datasets, torch
import torch
from torch.utils.data import DataLoader

def per_item(batch):
    # tokenize or per-item augmentation here (batch dict in -> batch dict out)
    return batch

train_ds = train_ds.with_transform(per_item)

def collate(list_of_examples):
    # list_of_examples is what PyTorch passes to collate_fn under automatic batching.
    # Pad/stack here; this placeholder just stacks equal-length tensors per key.
    return {
        key: torch.stack([ex[key] for ex in list_of_examples])
        for key in list_of_examples[0]
    }

loader = DataLoader(train_ds, batch_size=32, collate_fn=collate)

Reading list (high-signal, primary sources)

  • :hugs: Datasets docs: set_transform / with_transform semantics, “batch dict in”, applied before returning from __getitem__, replaces set_format. (Hugging Face)
  • PyTorch docs: where collate_fn runs, list-of-samples input, and the “roughly equivalent” pseudocode. (PyTorch Docs)
  • Transformers docs: data collators form a batch from a list of dataset elements; can do padding and batch-level augmentation. (Hugging Face)
  • HF Datasets issue: “Super slow iteration with trivial custom transform” (performance pitfall). (GitHub)

Summary

  • set_transform/with_transform changes what dataset[i] returns. It is a formatting layer applied at __getitem__. (Hugging Face)
  • collate_fn changes how a list of samples becomes one batch. PyTorch calls it with a list under normal batching. (PyTorch Docs)
  • Best default: do per-item augmentation in with_transform, do padding/stacking and batch-dependent logic in collate_fn or a Transformers data collator. (Hugging Face)
1 Like

Hi John,
I implemented the on-the-fly data augmentation in two ways, but the result quite surprised me:

  1. implemented via set_transform:
    import torch
    import numpy as np
    from datasets import load_from_disk
    
    class WavActTransform:
        def __init__(
            self,
            duration: int | None,
            sr: int,
            hop_length: int,
            randperm_mic: bool = True,
            randperm_spk: bool = True,
        ):
            self.duration = duration
            self.sr = sr
            self.hop_length = hop_length
            self.randperm_mic = randperm_mic
            self.randperm_spk = randperm_spk
            self.duration_frame = self.sr * self.duration // self.hop_length if self.duration is not None else None
        
        def __call__(self, batch):
            batch_wav = torch.tensor(batch["wav"])
            batch_act = torch.tensor(batch["act"])
            transformed_wav = []
            transformed_act = []
    
            for wav, act in zip(batch_wav, batch_act):
                if self.duration_frame is not None:
                    t_start_act = np.random.randint(0, act.shape[1] - self.duration_frame + 1)
                    t_end_act = t_start_act + self.duration_frame
                    act = act[:, t_start_act:t_end_act]
                    
                    t_start = self.hop_length * t_start_act
                    t_end = self.hop_length * t_end_act
                    wav = wav[:, t_start:t_end]
                
                if self.randperm_mic:
                    wav = wav[torch.randperm(wav.shape[0])]
                
                if self.randperm_spk:
                    act = act[torch.randperm(act.shape[0])]
    
                transformed_wav.append(wav)
                transformed_act.append(act)
    
            batch_wav = torch.stack(transformed_wav)
            batch_act = torch.stack(transformed_act)
    
            return {"wav": batch_wav, "act": batch_act}
    transform = WavActTransform(10, 16000, 160)
    dataset.set_transform(transform)
    
  2. implemented by overriding the __getitem__ method of a torch Dataset that wraps the HF dataset:
     class WavActDataset(torch.utils.data.Dataset):
         def __init__(
             self,
             dataset_path,
             duration: int | None,
             sr: int,
             hop_length: int,
             randperm_mic: bool = True,
             randperm_spk: bool = True,
             cached: bool = False,
             ):
     
             self._dataset = load_from_disk(dataset_path, keep_in_memory=cached)
             self._dataset.set_format('torch')
             self.duration = duration
             self.sr = sr
             self.hop_length = hop_length
     
             self.randperm_mic = randperm_mic
             self.randperm_spk = randperm_spk
             self.duration_frame = self.sr * self.duration // self.hop_length if self.duration is not None else None
         
         def __len__(self):
             return len(self._dataset)
     
         def __getitem__(self, index):
             item = self._dataset[index]
             wav: torch.Tensor = item['wav']
             act: torch.Tensor = item['act']
             if self.duration_frame is not None:
                 t_start_act = np.random.randint(0, act.shape[1] - self.duration_frame + 1)
                 t_end_act = t_start_act + self.duration_frame
                 act = act[:, t_start_act:t_end_act]
                 
                 t_start = self.hop_length * t_start_act
                 t_end = self.hop_length * t_end_act
                 wav = wav[:, t_start:t_end]
             
             if self.randperm_mic:
                 wav = wav[torch.randperm(wav.shape[0])]
             
             if self.randperm_spk:
                 act = act[torch.randperm(act.shape[0])]
     
             return {'wav': wav, 'act': act}
    
    

For both implementations, I timed the indexing operation dataset[0]: the former takes 1.3 s and the latter only 0.1 s, and by switching implementations I made my training at least 3x faster.

I just don’t see a fundamental difference between the two; maybe it’s the time spent on the for-loop and torch.stack in the first one?

1 Like

Yeah. The difference in processing speed between those two approaches (strictly speaking, mainly a difference in when the processing happens) comes up in discussions from time to time.


The difference is real, and in your specific code it is expected.

Two big causes dominate:

  1. set_transform forces you onto the “formatting transform” path, which often hands you Python objects or NumPy arrays, not Torch tensors.
  2. Your set_transform implementation does extra full-tensor allocations and (likely) double-batching.

Below is the “why”, mapped to your exact code.


0) What set_transform is doing under the hood

HF documents set_transform/with_transform as a formatting transform that:

  • replaces set_format()
  • takes a batch as a dict and returns a batch
  • is applied “right before returning the objects in __getitem__
  • can be restricted to specific columns (columns=[...]) (Hugging Face)

So if you call dataset.set_transform(transform) you are telling HF: “ignore the normal formatting pipeline, I will produce the final returned objects myself.” (Hugging Face)

That detail matters because in your wrapper approach you do:

self._dataset.set_format("torch")

and then you operate on already-created torch tensors.

With set_transform, you did not set torch formatting, and even if you did, the transform is documented as replacing that formatting layer. (Hugging Face)


1) The #1 time sink: torch.tensor(...) always copies

In your set_transform version you do:

batch_wav = torch.tensor(batch["wav"])
batch_act = torch.tensor(batch["act"])

torch.tensor(x) always makes a copy. PyTorch explicitly states torch.tensor(tensor) is equivalent to tensor.detach().clone(), and points you to torch.as_tensor() (avoid copies where possible) and torch.from_numpy() (shares storage with NumPy) as alternatives. (PyTorch Documentation)

So every dataset[0] in the set_transform path likely does:

  • Arrow / Python / NumPy object materialization (HF side)
  • then a full copy into a new torch tensor (torch.tensor)
  • then later another full allocation (see next section)

In the wrapper path you do:

item = self._dataset[index]   # already torch because set_format('torch')
wav = item["wav"]
act = item["act"]

So you skip at least one huge copy per sample.

Why it can be much slower than you expect

If batch["wav"] is a Python nested list (possible depending on your dataset feature types), torch.tensor(list_of_lists) is particularly slow because it must walk Python objects and infer dtype/shape. If it is a NumPy array, it is still a copy (by definition). (PyTorch Documentation)


2) The #2 time sink: you allocate again with torch.stack

Inside your transform you do:

transformed_wav.append(wav)
...
batch_wav = torch.stack(transformed_wav)

torch.stack allocates a brand new tensor and copies data into it.

So in the set_transform path you likely allocate at least:

  • Copy 1: torch.tensor(batch["wav"])
  • Copy 2: torch.stack(transformed_wav)

In the wrapper path, you allocate less:

  • You slice (often views) and permute (advanced indexing makes a copy, but you do that in both implementations)
  • Then the DataLoader collates and stacks once

So the wrapper path tends to have:

  • only the “necessary” copies (permutation + final collation)
    while the set_transform path can add extra copies on top.

3) You may be “batching twice” with set_transform + DataLoader

This is the most common hidden footgun with set_transform.

Your set_transform returns batched tensors:

  • dataset[0] returns {"wav": (1, …), "act": (1, …)} because you stack, even for one item.

Then PyTorch DataLoader with automatic batching does:

fetch a list of samples, then call collate_fn([dataset[i] for i in indices]) (PyTorch Documentation)

So if each dataset[i] already contains a leading batch dimension of 1, the default collate will typically produce something shaped like (batch_size, 1, …) and will stack/copy again.

Even if you later squeeze that dimension, you already paid the cost.

In contrast, your wrapper returns per-sample wav: (mic, time) and act: (spk, frames), and the DataLoader stacks once into (batch_size, mic, time).

This single structural mismatch can easily turn into a multi-x slowdown at training time.
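
To make the structural risk concrete, here is what PyTorch’s default collate does when each fetched sample still carries a leading singleton dimension (synthetic shapes, not taken from your dataset):

import torch
from torch.utils.data import default_collate

# Four samples that each already have a leading batch dimension of 1:
samples = [{"wav": torch.zeros(1, 8, 160000)} for _ in range(4)]
batch = default_collate(samples)
print(batch["wav"].shape)  # torch.Size([4, 1, 8, 160000]) -> extra singleton dim, extra copy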


4) HF set_transform has known perf issues even for trivial transforms

Independent of your code, there are real reports that set_transform can drastically slow iteration even when the transform itself is trivial. (GitHub)

So you have:

  • genuine framework overhead (HF formatting transform path) (GitHub)
  • plus your own extra allocations (torch.tensor, torch.stack)
  • plus possibly double batching (extra stacking in DataLoader) (PyTorch Documentation)

A 10× delta on dataset[0] and 3× on full training is plausible under that combination.


5) Why your wrapper is “fundamentally faster”

Because it lines up with the normal PyTorch data pipeline:

  • __getitem__ returns one sample
  • DataLoader batches once (collation) (PyTorch Documentation)
  • You avoid unnecessary tensor construction and stacking

Also, set_format("torch") keeps the conversion logic inside HF’s optimized formatting code rather than redoing it manually via torch.tensor on every call. (And set_transform is explicitly a replacement for that formatting layer.) (Hugging Face)


6) How to confirm exactly which factor dominates (quick checks)

Do these prints/timings inside your WavActTransform.__call__:

  1. Check input types

print(type(batch["wav"]), type(batch["act"]))

If you see Python lists, that explains a lot.

  2. Time just the conversion
    Time these separately:
  • torch.tensor(batch["wav"])
  • torch.as_tensor(batch["wav"]) (if it is NumPy)
    PyTorch explicitly says as_tensor tries to avoid copies and from_numpy shares storage. (PyTorch Documentation)
  3. Check output shapes
    Print the batch_wav.shape returned by the transform and the shape your training step receives from the DataLoader. If you see (B, 1, …) anywhere, you are double batching.
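
For step 2, a rough timing sketch you could drop into __call__ (wav stands for whatever batch["wav"] turns out to be; if it is a list of arrays, time the conversion per element instead):

import time
import torch

wav = batch["wav"]  # inside WavActTransform.__call__

t0 = time.perf_counter()
a = torch.tensor(wav)     # always copies (clone-like)
t1 = time.perf_counter()
b = torch.as_tensor(wav)  # avoids the copy when wav is already a NumPy array
t2 = time.perf_counter()

print(f"torch.tensor:    {t1 - t0:.4f} s")
print(f"torch.as_tensor: {t2 - t1:.4f} s")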

7) What to do going forward (practical recommendations)

Best “clean + fast” option

Keep your wrapper dataset (your current fast solution). It is idiomatic PyTorch and avoids the set_transform formatting overhead.

If you want to stay purely in HF Dataset without a wrapper

Do augmentation in collate_fn instead:

  • Keep dataset.set_format("torch")
  • DataLoader gives you a list of torch tensors
  • You augment each item and then stack once

This matches PyTorch’s model: DataLoader fetches samples then collates to a batch. (PyTorch Documentation)
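
A minimal sketch of that variant (augment_one is a hypothetical per-item function doing your crop/permutation, and the path is a placeholder):

import torch
from torch.utils.data import DataLoader
from datasets import load_from_disk

dataset = load_from_disk("hf_datasets/train")
dataset.set_format("torch", columns=["wav", "act"])

def collate(items):
    # items is a list of per-sample dicts of torch tensors
    items = [augment_one(it) for it in items]  # per-item random crop / channel permutation
    return {
        "wav": torch.stack([it["wav"] for it in items]),
        "act": torch.stack([it["act"] for it in items]),
    }

loader = DataLoader(dataset, batch_size=32, num_workers=10, collate_fn=collate)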

If you insist on set_transform

Then rewrite the transform to avoid the biggest traps:

  • Avoid torch.tensor(...) on every call. It always copies. Prefer torch.as_tensor / torch.from_numpy when possible. (PyTorch Documentation)
  • Avoid producing an extra batch dimension (avoid torch.stack for the single-item __getitem__ case), or disable DataLoader automatic batching and let the dataset return batches (more complex).
  • Use columns=["wav","act"] in set_transform to minimize what HF passes into your transform. (Hugging Face)

Even then, be aware of the known set_transform slowdown reports. (GitHub)
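
If you do go that route, a leaner transform could look roughly like this (a sketch assuming "wav"/"act" arrive as NumPy arrays; if they decode to Python lists, as_tensor still has to copy):

import torch

def lean_transform(batch):
    # batch dict in -> batch dict out; keep per-example tensors in lists and let the
    # DataLoader's collate do the single stack at batch time.
    wavs = [torch.as_tensor(w) for w in batch["wav"]]  # no copy when already NumPy
    acts = [torch.as_tensor(a) for a in batch["act"]]
    # ... per-example cropping / channel permutation would go here ...
    return {"wav": wavs, "act": acts}

dataset.set_transform(lean_transform, columns=["wav", "act"])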


Good references (directly relevant)

  • HF docs: set_transform replaces set_format, batch-dict signature, applied in __getitem__, columns option (Hugging Face)
  • PyTorch docs: DataLoader batches by fetching samples then calling collate_fn([dataset[i]...]) (PyTorch Documentation)
  • PyTorch docs: torch.tensor(t) is clone-like (copies); as_tensor avoids copies where possible; from_numpy shares storage (PyTorch Documentation)
  • HF issue: trivial set_transform causing ~10× slower iteration (GitHub)

Summary

  • set_transform is a formatting hook that replaces set_format, so you often lose the fast “return torch tensors” path. (Hugging Face)
  • Your set_transform code does extra full copies: torch.tensor (always copies) plus torch.stack (allocates again). (PyTorch Documentation)
  • You may also be batching twice when DataLoader stacks samples that already have a leading batch dimension. (PyTorch Documentation)
  • HF has known reports of set_transform iteration slowdowns even for trivial transforms. (GitHub)
1 Like

Thank you John!
Now I’m sticking with the torch Dataset wrapper implementation, and I use a Lightning DataModule to manage the dataloaders in DDP training. Here is the code (where WavActDataset is as above):

from typing import Any
from pathlib import Path
from torch.utils.data import DataLoader
import lightning as lt
from datasets import Dataset as HFDataset, load_from_disk
from neural_fcasa.datasets.hf_wavact_dataset import WavActDataset, WavActTransform


class DataModule(lt.LightningDataModule):
    def __init__(
        self,
        train_dataset_path: str | Path,
        val_dataset_path: str | Path,
        batch_size: int,
        duration: int | None,
        sr: int,
        hop_length: int,
        randperm_mic: bool = True,
        randperm_spk: bool = True,
        num_workers: int = 10,
        cached: bool | None = None,
    ):
        super().__init__()

        self.train_dataset_path = train_dataset_path
        self.val_dataset_path = val_dataset_path
        self.transform = WavActTransform(
            duration=duration,
            sr=sr,
            hop_length=hop_length,
            randperm_mic=randperm_mic,
            randperm_spk=randperm_spk
        )

        self.dataset_kwargs: dict[str, Any] = dict(
            duration=duration,
            sr=sr,
            hop_length=hop_length,
            randperm_mic=randperm_mic,
            randperm_spk=randperm_spk,
            cached=cached,
        )

        self.dataloader_configs: dict[str, Any] = dict(
            batch_size=batch_size,
            num_workers=num_workers,
            persistent_workers=True,
            shuffle=False, # let lightning handle the distributed sampler, see https://pytorch-lightning.readthedocs.io/en/0.9.0/multi_gpu.html
        )

    def setup(self, stage: str | None):
        if stage == "fit":
            self.train_dataset = WavActDataset(self.train_dataset_path, **self.dataset_kwargs)
            self.val_dataset = WavActDataset(self.val_dataset_path, **self.dataset_kwargs)

            print(f"Dataset size: {len(self.train_dataset)=},  {len(self.val_dataset)=}")
        else:
            raise ValueError("`stage` is not 'fit'.")

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            drop_last=True,
            **self.dataloader_configs,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_dataset,
            drop_last=False,
            **self.dataloader_configs,
        )

When I set cached=False (hence keep_in_memory=False for the HF dataset in load_from_disk), training is fine with no memory issue. But when I set cached=True, trying to accelerate training, I hit an OOM issue.
In my case I have a train set of 272 GB and a val set of 33 GB, and I do DDP training on 4 GPUs. From what you taught me above, I know this DDP training launches 4 processes, each using its own independent memory, so with cached=True there will be 4 copies of the dataset, requiring at least 4 × (272 + 33) = 1220 GB, which exceeds my 500 GB quota. So I wonder if I can use some shared-memory method to avoid the redundant caching, or is there a compromise that doesn’t cache the whole dataset but uses as much memory as the quota allows to accelerate training?

1 Like

There seem to be some workarounds.


You are seeing OOM because cached=True in your code is effectively keep_in_memory=True for datasets.load_from_disk, and that means “copy the dataset into RAM.” HF documents keep_in_memory exactly that way. (Hugging Face)

In DDP, each rank is a separate process. Lightning’s DDP strategy is multi-process and (by default) starts processes via start_method='popen', meaning separate interpreter processes with no shared Python heap. (Lightning AI)
So you do not get one shared in-RAM dataset. You get N independent in-RAM copies.

On top of that, your PyTorch DataLoader(num_workers=10) creates worker processes and “the dataset, collate_fn, worker_init_fn are passed to each worker,” and each worker process has its own dataset object copy. (PyTorch Documentation)
This matters a lot if the dataset is truly in RAM (keep_in_memory). With memory-mapped datasets it matters much less.


1) Your memory math is actually worse than 4×

You already computed:

  • train + val = 272 GB + 33 GB = 305 GB
  • world_size = 4 ranks
  • naive duplication across ranks: 4 × 305 = 1220 GB

That alone exceeds 500 GB.

But you also have DataLoader workers per rank (num_workers=10 in your DataModule).

So the “worst-case” number of dataset replicas is closer to:

  • world_size × (num_workers + 1) = 4 × 11 = 44 processes that may each hold large allocations (depending on start method and what gets materialized).

Even if copy-on-write reduces some duplication for workers on Linux, DDP ranks are still separate processes started via popen in Lightning by default, so rank-to-rank sharing does not happen. (Lightning AI)


2) The key conceptual distinction: “in-memory copy” vs “memory-mapped + OS page cache”

HF Datasets is built on Arrow and is designed to be memory-mapped from disk for large datasets. HF explicitly says datasets can be “backed by an on-disk cache, which is memory-mapped,” and that this is why large datasets can be used without loading them into RAM. (Hugging Face)

HF even states that loading Wikipedia only uses a few MB of RAM because “the Arrow data is actually memory-mapped from disk, and not loaded in memory,” and memory-mapping leverages virtual memory for fast lookup. (Hugging Face)

That architecture gives you a very important property:

OS page cache is shared across processes

If all ranks read the same underlying Arrow files, the kernel page cache stores recently used file pages once and shares them across all processes. That is the “shared memory” you want, and it is already the default path when keep_in_memory=False. (Hugging Face)

So the practical lesson is:

  • keep_in_memory=True → copies into each process heap. Not shared across DDP ranks. (Hugging Face)
  • keep_in_memory=False → memory-mapped Arrow files; the OS can cache pages globally and share them. (Hugging Face)

This is exactly why your run is stable with cached=False.


3) Answer to your two questions

Q1: “Can I use some shared memory method to avoid redundant data-caching?”

Not in the way keep_in_memory=True works.

  • Lightning DDP uses multiple independent processes (default start_method='popen'). (Lightning AI)
  • Python objects and allocations created by keep_in_memory=True live in each process heap, so they are not shared.

You can force “sharing” only by not copying into per-process RAM in the first place and instead using memory mapping, which is what HF recommends for large datasets. (Hugging Face)

Q2: “Is there a compromise that uses as much quota as allowed to accelerate training?”

Yes. Use the OS page cache as your “partial cache.”

If your dataset files are memory-mapped (keep_in_memory=False), the OS will naturally keep hot file pages in RAM, up to your available memory, and that cache is shared across all ranks and workers. (Hugging Face)

In your case, 305 GB total dataset size < 500 GB RAM quota, so in principle the OS page cache can hold almost the entire dataset after it has been read once. That gets you close to “in-memory speed” without 4× duplication.


4) What I would do (ranked by leverage)

Option A (recommended): keep keep_in_memory=False, rely on memory-mapping + page cache

This is the path HF designed for “big data.” (Hugging Face)

How to make it fast in practice:

  1. Put the dataset Arrow files on local NVMe if possible (network FS can bottleneck).
  2. Run normally. After part of the first epoch, the OS cache warms up. After 1 full epoch, it is often dramatically faster if RAM can hold most of the dataset.

You already observed training is stable in this mode, so you are most of the way there.

“Warm the cache” intentionally (optional)

If the first epoch speed matters, do a one-time sequential scan before training to populate the OS page cache. Example idea:

  • rank 0 only: iterate through indices sequentially with num_workers=0 and a big batch size, just to read pages.
  • then start training.

This is not a HF feature. It is just using the OS cache.
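
A minimal sketch of such a warm-up pass (not an HF feature; assumes each split was saved as its own dataset directory so load_from_disk returns a single Dataset, and that you call it from rank 0 only):

from datasets import load_from_disk

def warm_page_cache(dataset_path):
    ds = load_from_disk(dataset_path)      # memory-mapped, keep_in_memory=False
    for start in range(0, len(ds), 1000):
        _ = ds[start : start + 1000]       # touching the rows pulls the Arrow pages into the OS cache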


Option B: pre-shard the dataset per rank and only keep that shard in memory

If you insist on true in-RAM storage, the only sane way is: each rank loads only 1/4 of train.

  • shard0 size ≈ 272/4 ≈ 68 GB
  • total across 4 ranks ≈ 272 GB in RAM (plus whatever you do for val)

But that requires physical sharding (separate datasets on disk per shard) or a load path that does not materialize the full dataset before sharding.

If you load the full dataset and then call .shard, you already paid the RAM cost if keep_in_memory=True.

So this approach is viable only if you restructure storage:

  • create 4 HF datasets on disk (save_to_disk) each containing disjoint subsets
  • each rank reads only its own path

This also reduces contention because each rank reads mostly different files.
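
A one-time offline sketch of that physical sharding (paths are placeholders; the split stays memory-mapped while sharding, so this step itself does not need 272 GB of RAM):

from datasets import load_from_disk

world_size = 4
full = load_from_disk("hf_datasets/train")

for rank in range(world_size):
    shard = full.shard(num_shards=world_size, index=rank, contiguous=True)
    shard.save_to_disk(f"hf_datasets/train_rank{rank}")

Each rank then calls load_from_disk on its own train_rank{rank} directory in setup().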


Option C: if you keep keep_in_memory=True, you must reduce process multiplication

If you really flip keep_in_memory=True, then you cannot also keep a large num_workers safely.

PyTorch is explicit that worker processes are created, and dataset access and transforms run in workers, and each worker has its own dataset object. (PyTorch Documentation)

So with keep_in_memory=True, you typically set:

  • num_workers=0 (or very small)
  • maybe persistent_workers=False
  • and accept that augmentation happens in the main rank process

This often loses the speed you were trying to gain, so it is usually not worth it.


5) Small but important note: your naming “cached” conflates two different “caches”

HF has:

  • an on-disk Arrow cache (normal, memory-mapped) (Hugging Face)
  • an explicit “copy in memory” mode (keep_in_memory=True) (Hugging Face)

What you call cached=True is really “copy to RAM,” not “use HF cache.” That naming makes it harder to reason about.



Summary

  • keep_in_memory=True copies the dataset into each process memory. HF says it “copies the dataset in-memory.” (Hugging Face)
  • Lightning DDP launches separate processes (default start_method='popen'), so those copies are not shared across ranks. (Lightning AI)
  • DataLoader workers also replicate the dataset object per worker process. (PyTorch Documentation)
  • The “shared memory” solution is to keep keep_in_memory=False and rely on Arrow memory-mapping plus the OS page cache, which is shared across processes and can use your RAM quota automatically. (Hugging Face)
1 Like

This topic was automatically closed 12 hours after the last reply. New replies are no longer allowed.