> Tokenize and export text datasets to Megatron-LM binary format for large-scale language model pretraining

# Save and Export Text Data

After processing your text datasets with NeMo Curator, use writer stages to export the curated data for downstream tasks. Curator provides writers for common formats (JSONL, Parquet) as well as specialized writers for training frameworks.

## Megatron Tokenization

`MegatronTokenizerWriter` tokenizes text documents and writes the `.bin` and `.idx` files required by [Megatron-LM](https://github.com/NVIDIA/Megatron-LM) for data loading during pretraining. This replaces the need to run Megatron's `preprocess_data.py` script separately and integrates tokenization directly into your curation pipeline.

### How It Works

1. **Tokenizer loading**: Downloads and loads a Hugging Face tokenizer specified by `model_identifier`. The tokenizer is downloaded once per node and loaded once per worker.
2. **Batched tokenization**: Documents are tokenized in batches (controlled by `tokenization_batch_size`) to avoid out-of-memory issues on large datasets.
3. **Binary output**: Tokenized data is written to a `.bin` file containing packed token IDs. Vocabularies larger than 65,536 tokens use 4 bytes per token (`int32`); vocabularies of 65,536 tokens or fewer use 2 bytes (`uint16`). A sketch of this packing rule follows the list.
4. **Index output**: A `.idx` file stores metadata including sequence lengths, byte offsets, and document boundaries for efficient random access during training.
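
The following sketch illustrates the packing rule from steps 3 and 4. The `pack_tokens` helper is hypothetical, not a Curator API; it only shows how the dtype choice and token concatenation work in principle:

```python
import numpy as np

# Illustrative sketch of the .bin packing rule; `pack_tokens` is a
# hypothetical helper, not a Curator API.
def pack_tokens(documents: list[list[int]], vocab_size: int, bin_path: str) -> list[int]:
    # Step 3's rule: 4-byte int32 for vocabularies above 65,536 tokens,
    # 2-byte uint16 otherwise.
    dtype = np.int32 if vocab_size > 65_536 else np.uint16
    lengths = []
    with open(bin_path, "wb") as f:
        for token_ids in documents:
            f.write(np.asarray(token_ids, dtype=dtype).tobytes())
            lengths.append(len(token_ids))
    # The per-sequence lengths are what the .idx file records (step 4).
    return lengths

# Two short "documents" with a GPT-2-sized vocabulary (50,257 tokens -> uint16).
lengths = pack_tokens([[15496, 995], [40, 716, 257, 1332]], vocab_size=50_257, bin_path="example.bin")
```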

### Quick Start

```python
from nemo_curator.core.client import RayClient
from nemo_curator.pipeline import Pipeline
from nemo_curator.stages.text.io.reader import JsonlReader
from nemo_curator.stages.text.io.writer.megatron_tokenizer import MegatronTokenizerWriter

# Initialize Ray client
ray_client = RayClient()
ray_client.start()

# Define pipeline stages
stages = [
    JsonlReader(
        file_paths="/path/to/data",
        fields=["text"],
    ),
    MegatronTokenizerWriter(
        path="/path/to/output",
        model_identifier="nvidia/NVIDIA-Nemotron-Nano-12B-v2",
        append_eod=True,
    ),
]

# Create and run the pipeline
pipeline = Pipeline(
    name="megatron-tokenize",
    description="Tokenize dataset for Megatron-LM.",
    stages=stages,
)

results = pipeline.run()

ray_client.stop()
```

### Configuration

| Parameter                 | Type        | Default  | Description                                                      |
| ------------------------- | ----------- | -------- | ---------------------------------------------------------------- |
| `path`                    | str         | Required | Output directory for `.bin` and `.idx` files                     |
| `model_identifier`        | str         | Required | Hugging Face model identifier or local path for the tokenizer    |
| `text_field`              | str         | `"text"` | Name of the column containing text to tokenize                   |
| `append_eod`              | bool        | `False`  | Append the tokenizer's EOS token at the end of each document     |
| `tokenization_batch_size` | int         | `1000`   | Number of documents to tokenize per batch before writing to disk |
| `cache_dir`               | str \| None | `None`   | Local cache directory for the downloaded tokenizer               |
| `hf_token`                | str \| None | `None`   | Hugging Face API token for accessing gated models                |
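
For reference, here is a writer with every option from the table set explicitly; the paths and batch size are illustrative values:

```python
MegatronTokenizerWriter(
    path="/data/tokenized",                                  # output directory
    model_identifier="nvidia/NVIDIA-Nemotron-Nano-12B-v2",   # HF model or local path
    text_field="text",                                       # column to tokenize
    append_eod=True,                                         # add EOS per document
    tokenization_batch_size=4000,                            # documents per batch
    cache_dir="/data/hf-cache",                              # tokenizer cache location
    hf_token=None,                                           # set for gated models
)
```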

### Output Format

The writer produces paired files for each input partition:

```bash
output/
├── {hash_1}.bin    # Packed token IDs (binary)
├── {hash_1}.idx    # Sequence metadata (lengths, offsets, document boundaries)
├── {hash_2}.bin
└── {hash_2}.idx
```

<Accordion title="File format details">
  **`.bin` file**: Contains concatenated token IDs for all documents in the partition. Token IDs are stored as `int32` (4 bytes) when the tokenizer vocabulary exceeds 65,536 tokens, or `uint16` (2 bytes) for smaller vocabularies such as GPT-2.

  **`.idx` file**: Contains a fixed header followed by per-sequence metadata:

  * 9-byte magic header (`MMIDIDX\x00\x00`)
  * 8-byte version number
  * 1-byte dtype code
  * 8-byte sequence count
  * 8-byte document count
  * Per-sequence lengths: 4-byte `int32` array (one entry per sequence)
  * Per-sequence byte offsets: 8-byte `int64` array (one entry per sequence)
  * Document boundary indices: 8-byte `int64` array (sequence count + 1 entries)

  These files are directly compatible with Megatron-LM's `MMapIndexedDataset` data loader.
</Accordion>
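
If you need to verify an output file, the header can be decoded directly from the layout above. The following is a minimal inspection sketch; `read_idx_header` is a hypothetical debugging helper (not a Curator or Megatron API) and assumes little-endian byte order:

```python
import struct

import numpy as np

# Minimal sketch for inspecting an .idx header using the layout above;
# assumes little-endian byte order.
def read_idx_header(idx_path: str) -> dict:
    with open(idx_path, "rb") as f:
        magic = f.read(9)                               # 9-byte magic header
        assert magic == b"MMIDIDX\x00\x00", "not a recognized .idx file"
        version = struct.unpack("<Q", f.read(8))[0]     # 8-byte version number
        dtype_code = struct.unpack("<B", f.read(1))[0]  # 1-byte dtype code
        seq_count = struct.unpack("<Q", f.read(8))[0]   # 8-byte sequence count
        doc_count = struct.unpack("<Q", f.read(8))[0]   # 8-byte document count
        # Per-sequence lengths: int32 array, one entry per sequence.
        lengths = np.frombuffer(f.read(4 * seq_count), dtype=np.int32)
    return {
        "version": version,
        "dtype_code": dtype_code,
        "sequences": seq_count,
        "documents": doc_count,
        "total_tokens": int(lengths.sum()),
    }
```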

### End-of-Document Tokens

When `append_eod=True`, the tokenizer's EOS token is appended to each document's token sequence. This matches the behavior of Megatron's `preprocess_data.py` when run with `--append-eod` and is required for some training configurations that use document boundaries for attention masking.

If the tokenizer does not define an EOS token, `append_eod` is automatically disabled with a warning.
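
Conceptually, this is equivalent to appending `eos_token_id` after encoding each document. The sketch below illustrates the behavior with a Hugging Face tokenizer (`gpt2` is a stand-in model); it is not the writer's actual implementation:

```python
from transformers import AutoTokenizer

# Illustration of append_eod; "gpt2" is a stand-in model, and this is
# not the writer's actual implementation.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
token_ids = tokenizer.encode("Once upon a time.")

# With append_eod=True, the EOS token follows each document's tokens.
# If eos_token_id were None, the writer would skip this step and warn instead.
if tokenizer.eos_token_id is not None:
    token_ids.append(tokenizer.eos_token_id)
```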

### Using Different Tokenizers

`MegatronTokenizerWriter` supports any tokenizer available through Hugging Face's `AutoTokenizer`:

<Tabs>
  <Tab title="Standard Model">
    ```python
    MegatronTokenizerWriter(
        path="output/",
        model_identifier="nvidia/NVIDIA-Nemotron-Nano-12B-v2",
        append_eod=True,
    )
    ```
  </Tab>

  <Tab title="Gated Model">
    ```python
    MegatronTokenizerWriter(
        path="output/",
        model_identifier="meta-llama/Llama-3.1-8B",
        hf_token="hf_...",
        append_eod=True,
    )
    ```
  </Tab>

  <Tab title="Local Tokenizer">
    ```python
    MegatronTokenizerWriter(
        path="output/",
        model_identifier="/local/path/to/tokenizer",
        append_eod=True,
    )
    ```
  </Tab>
</Tabs>

### Complete Pipeline Example

This example reads the TinyStories dataset from Parquet files and tokenizes it for Megatron-LM:

```python
from nemo_curator.core.client import RayClient
from nemo_curator.pipeline import Pipeline
from nemo_curator.stages.text.io.reader import ParquetReader
from nemo_curator.stages.text.io.writer.megatron_tokenizer import MegatronTokenizerWriter

ray_client = RayClient()
ray_client.start()

stages = [
    ParquetReader(
        file_paths="datasets/tinystories/",
    ),
    MegatronTokenizerWriter(
        path="datasets/tinystories-tokens/",
        model_identifier="nvidia/NVIDIA-Nemotron-Nano-12B-v2",
        append_eod=True,
        tokenization_batch_size=2000,
    ),
]

pipeline = Pipeline(
    name="megatron-tokenize",
    description="Tokenize TinyStories for Megatron-LM.",
    stages=stages,
)

results = pipeline.run()

ray_client.stop()
```

A runnable version of this example is available in the [tutorials directory](https://github.com/NVIDIA-NeMo/Curator/blob/main/tutorials/text/megatron-tokenizer/main.py).

***

For more information on using tokenized data with Megatron-LM, see the [Related Tools](/reference/related-tools) page.