Source code for nemo_automodel.components.datasets.reservoir_sampler
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import random
from typing import Any, Dict, Iterable, Iterator, Optional
[docs]
class ReservoirSampler:
"""Streaming shuffle with a fixed-size buffer.
This is a bounded-memory shuffling wrapper for streaming datasets/iterables.
It maintains a buffer of ``buffer_size`` items. Once the buffer is filled,
it repeatedly:
- samples a random buffer slot
- yields the evicted item
- replaces it with the next item from the underlying iterator
When the underlying iterator is exhausted, the remaining buffer items are
yielded.
"""
def __init__(self, iterator: Iterable[Dict[str, Any]], buffer_size: int, seed: Optional[int] = None):
"""
Reservoir sampler is a sampler that samples items from an iterator using a buffer.
It is used to sample items from an iterator in a way that is memory efficient.
Args:
iterator: Iterator to sample from.
buffer_size: Size of the buffer.
seed: Seed for the random number generator.
"""
if iterator is None:
raise ValueError("iterator must be provided")
if buffer_size <= 0:
raise ValueError(f"buffer_size must be > 0, got {buffer_size}")
self._buffer_size = int(buffer_size)
self._seed = seed
self._iterable = iterator
[docs]
def __iter__(self) -> Iterator[Dict[str, Any]]:
"""
Iterate over the iterator and sample items from the buffer.
"""
rng = random.Random(self._seed)
it = iter(self._iterable)
buffer: list[Optional[Dict[str, Any]]] = []
for item in it:
buffer.append(item)
if len(buffer) == self._buffer_size:
break
if not buffer:
return
rng.shuffle(buffer)
while True:
new_pos = rng.randrange(len(buffer))
evicted_item = buffer[new_pos]
try:
buffer[new_pos] = next(it)
except StopIteration:
yield evicted_item
buffer[new_pos] = None
break
else:
yield evicted_item
# handle tail
yield from filter(lambda x: x is not None, buffer)
[docs]
def __len__(self) -> int:
"""
No len methods is supported with ReservoirSampler.
"""
raise RuntimeError("__len__ is not supported with ReservoirSampler.")
[docs]
def __getitem__(self, idx: int) -> Dict[str, Any]:
"""
No getitem method is supported with ReservoirSampler.
"""
raise RuntimeError("__getitem__ is not supported with ReservoirSampler.")