Adding the OAS Dataset: Customizing Dataset Object and Dataloader Functions#

This tutorial is the second part of a series focused on adding a new dataset to BioNeMo using the Observed Antibody Space (OAS) database. There are three steps to this task:

  1. Preprocessing includes downloading the raw data and any additional preparation steps, such as extracting the files. It also includes dividing the data into train, validation, and test splits. The preprocessing step can make use of two BioNeMo base classes, RemoteResource and ResourcePreprocessor, from bionemo.utils.remote and bionemo.data.preprocess, respectively. Their use is optional, but they provide some basic functionality that can accelerate development. This step was covered by the first tutorial, Downloading and Preprocessing.

  2. Development of the new dataset class. Here, the NeMo dataset class CSVMemMapDataset will be used. This step was covered in the last tutorial, Modifying the Dataset Class.

  3. Modification of the dataloader classes. This tutorial will cover customizing DataLoader objects using the newly created OAS datasets. This will include specifics on instantiating actual Dataset classes, customizing the collate function, and instantiating a dataloader. We will also review how these steps are executed within the BioNeMo model classes.

Setup and Assumptions#

This tutorial assumes that a copy of the BioNeMo framework repo exists on a workstation or server and has been mounted inside the container at /workspace/bionemo as described in the Code Development section of the Quickstart Guide. This path will be referred to with the variable BIONEMO_WORKSPACE in the tutorial.

All commands should be executed inside the BioNeMo docker container.

BIONEMO_WORKSPACE = '/workspace/bionemo'
TUTORIAL_FILE_VERSION = 'step_999_final'
stage_files(TUTORIAL_FILE_VERSION, source_directory=f'{BIONEMO_WORKSPACE}/examples/oas_dataset')

Customizing a collate function#

In the last tutorial, we saw how to modify the YAML file to use a different set of data with existing tooling. In some cases, this isn't enough. The collate_fn parameter of PyTorch DataLoaders is used for last-minute adjustments to batches, including masking, shuffling, batching, padding, and other slight modifications to the input data. In BioNeMo, we build our collate function on top of collators used for language modeling (bionemo/data/dataloader/collate.py).
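To make the role of a collate function concrete, here is a minimal, BioNeMo-independent sketch. The function name `pad_batch` and the toy sequences are hypothetical; a function like this would be passed as the `collate_fn` argument to `torch.utils.data.DataLoader`.

```python
def pad_batch(batch):
    """Right-pad each sequence string to the longest length in the batch.

    This is the kind of last-minute, per-batch adjustment (here, padding)
    that a collate function performs before tensors are built.
    """
    max_len = max(len(seq) for seq in batch)
    return [seq.ljust(max_len, "-") for seq in batch]


print(pad_batch(["MKT", "MKTAYIA", "MK"]))  # ['MKT----', 'MKTAYIA', 'MK-----']
```

Because the padding target is computed from the batch itself, this logic cannot live in the Dataset's per-item `__getitem__`; the collate hook is the natural place for it.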

The collate function is ultimately injected into the dataloader upon construction. To customize it further, we can simply extend the existing ProteinBertCollate class with our own additional collation, followed by a call to the parent's method.

# Copyright (c) 2022, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List

from bionemo.data.dataloader.protein_collate import ProteinBertCollate


__all__ = ['CustomProteinBertCollate']


class CustomProteinBertCollate(ProteinBertCollate):
    def collate_fn(self, batch: List[str], label_pad: int = -1):
        '''
        The parent does things like padding, masking, one-hot transformations, etc.
        Here we do the same, but first reorder our batch by swapping its two halves.
        Does this do anything useful in practice? Maybe not, but that's okay.

        This method ultimately gets injected into a DataLoader. We do this as
        part of the standard dataloader setup method inside our ESM1nv model, by
        instantiating this class and then injecting the collate_fn.
        '''
        extra = []  # Set aside the last element of odd-length batches
        if len(batch) % 2 == 1:
            batch, extra = batch[:-1], batch[-1:]

        # Swap the two halves of the (now even-length) batch
        mid = len(batch) // 2
        front, back = batch[:mid], batch[mid:]
        batch = back + front + extra

        # Replace every character with 'A' -- easy to verify downstream
        new_batch = ['A' * len(seq) for seq in batch]

        return super().collate_fn(new_batch, label_pad)
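Since running the class above requires the full BioNeMo stack, the batch pre-processing it performs can be checked in isolation with a pure-Python sketch. The helper name `preprocess_batch` is hypothetical; it mirrors the steps before the `super().collate_fn` call.

```python
def preprocess_batch(batch):
    """Sketch of CustomProteinBertCollate's pre-processing: set aside the odd
    element, swap the two halves of the batch, then replace all characters
    with 'A'."""
    extra = []
    if len(batch) % 2 == 1:
        batch, extra = batch[:-1], batch[-1:]
    mid = len(batch) // 2
    front, back = batch[:mid], batch[mid:]
    batch = back + front + extra
    return ['A' * len(seq) for seq in batch]


print(preprocess_batch(["MK", "MKTA", "MKT"]))  # ['AAAA', 'AA', 'AAA']
```

Each output string keeps the length of its source sequence, so the 'A'-substitution is trivial to verify by eye in a debugger or log line.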

Injecting a custom collate object into an existing model#

The implemented collate function serves a single purpose: it replaces all characters with the character 'A'. This is both easy to implement and simple to check for correctness. After this substitution, the batch is passed back to the parent collate function for padding and masking. Next, we will inject this into our ESM1nv model so it is applied to the dataset. You can see below that this occurs in the build_pretraining_data_loader method, which primarily operates on an already-existing Dataset object.

# Copyright (c) 2022, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from typing import Dict, Optional
from omegaconf.dictconfig import DictConfig
from pytorch_lightning.trainer.trainer import Trainer

from nemo.core.neural_types import NeuralType
from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel
from nemo.collections.nlp.modules.common.megatron.utils import (
    average_losses_across_data_parallel_group,
)
from nemo.utils import logging

from bionemo.model.protein.esm1nv.esm1nv_model import ESM1nvModel
from bionemo.data.molecule import megamolbart_build_train_valid_test_datasets
from bionemo.data.dataloader.custom_protein_collate import CustomProteinBertCollate
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

try:
    from apex.transformer import tensor_parallel


    HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
    HAVE_APEX = False


__all__ = ["CustomESM1nvModel"]

class CustomESM1nvModel(ESM1nvModel):
    """ CustomESM1nv model that extends the dataloader function to use our custom collate function.
    Checkout the base class for more information on how it all fits together.
    """

    def __init__(self, cfg: DictConfig, trainer: Trainer):
        super().__init__(cfg, trainer=trainer)

    def build_pretraining_data_loader(self, dataset, consumed_samples):
        """Build a dataloader for the given input dataset."""

        assert self._cfg.data.dataloader_type == 'single', (
            f'Only the Megatron sequential ("single") sampler is currently supported. '
            f'{self._cfg.data.dataloader_type} was chosen.'
        )

        dataloader = super().build_pretraining_data_loader(dataset=dataset, consumed_samples=consumed_samples)

        # Add collate function and unpin memory to avoid crash with CUDA misaligned address
        dataloader.pin_memory = False # must be False with CSV dataset TODO check with binary
        pad_size_divisible_by_8 = True if self._cfg.masked_softmax_fusion else False

        dataloader.collate_fn = CustomProteinBertCollate(
            tokenizer=self.tokenizer,
            seq_length=self._cfg.seq_length,
            pad_size_divisible_by_8=pad_size_divisible_by_8,
            modify_percent=self._cfg.data.modify_percent,
            perturb_percent=self._cfg.data.perturb_percent,
        ).collate_fn

        return dataloader
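The injection pattern itself -- build the dataloader first, then overwrite its collate_fn attribute with a bound method of a freshly constructed collate object -- can be illustrated without any framework dependencies. Everything in this sketch (`UppercaseCollate`, `FakeDataLoader`) is a hypothetical stand-in for the real classes.

```python
class UppercaseCollate:
    """Toy collate object; its bound collate_fn method is what gets injected."""
    def collate_fn(self, batch):
        return [s.upper() for s in batch]


class FakeDataLoader:
    """Stand-in for torch.utils.data.DataLoader, for illustration only."""
    def __init__(self, data, collate_fn=None):
        self.data = data
        self.collate_fn = collate_fn or (lambda b: b)

    def __iter__(self):
        yield self.collate_fn(self.data)


loader = FakeDataLoader(["mkt", "mka"])
loader.collate_fn = UppercaseCollate().collate_fn  # the injection step
print(next(iter(loader)))  # ['MKT', 'MKA']
```

Because `collate_fn` is a bound method, the injected callable carries the collate object's configuration (tokenizer, sequence length, and so on in the real model) along with it.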
std_out = ! cd {BIONEMO_WORKSPACE}/examples/protein/esm1nv && python pretrain_oas.py ++trainer.max_steps=101
print('\n'.join(std_out))
[NeMo W 2023-08-25 18:46:43 experimental:27] Module <class 'nemo.collections.nlp.models.text_normalization_as_tagging.thutmose_tagger.ThutmoseTaggerModel'> is experimental, not ready for production and is not fully supported. Use at your own risk.
[NeMo W 2023-08-25 18:46:44 experimental:27] Module <class 'nemo.collections.asr.modules.audio_modules.SpectrogramToMultichannelFeatures'> is experimental, not ready for production and is not fully supported. Use at your own risk.
[NeMo W 2023-08-25 18:46:44 nemo_logging:349] /usr/local/lib/python3.8/dist-packages/hydra/_internal/defaults_list.py:251: UserWarning: In 'pretrain_oas': Defaults list is missing `_self_`. See https://hydra.cc/docs/upgrades/1.0_to_1.1/default_composition_order for more information
      warnings.warn(msg, UserWarning)
    
[NeMo W 2023-08-25 18:46:44 nemo_logging:349] /usr/local/lib/python3.8/dist-packages/hydra/_internal/hydra.py:119: UserWarning: Future Hydra versions will no longer change working directory at job runtime by default.
    See https://hydra.cc/docs/next/upgrades/1.1_to_1.2/changes_to_job_working_dir/ for more information.
      ret = run_job(
    
[NeMo I 2023-08-25 18:46:44 pretrain_oas:12] 
    
    ************** Experiment configuration ***********
[NeMo I 2023-08-25 18:46:44 pretrain_oas:13] 
    name: esm1nv-oas
    do_training: true
    do_testing: false
    restore_from_path: null
    trainer:
      devices: 1
      num_nodes: 1
      accelerator: gpu
      precision: 16-mixed
      logger: false
      enable_checkpointing: false
      use_distributed_sampler: false
      max_epochs: null
      max_steps: 101
      log_every_n_steps: 10
      val_check_interval: 100
      limit_val_batches: 10
      limit_test_batches: 500
      accumulate_grad_batches: 1
      gradient_clip_val: 1.0
      benchmark: false
    exp_manager:
      name: ${name}
      exp_dir: /result/nemo_experiments/${.name}/${.wandb_logger_kwargs.name}
      explicit_log_dir: ${.exp_dir}
      create_wandb_logger: false
      create_tensorboard_logger: true
      wandb_logger_kwargs:
        project: ${name}_pretraining
        name: ${name}_pretraining
        group: ${name}
        job_type: Localhost_nodes_${trainer.num_nodes}_gpus_${trainer.devices}
        notes: 'date: ${now:%y%m%d-%H%M%S}'
        tags:
        - ${name}
        offline: false
      resume_if_exists: false
      resume_ignore_no_checkpoint: true
      create_checkpoint_callback: true
      checkpoint_callback_params:
        monitor: val_loss
        save_top_k: 10
        mode: min
        always_save_nemo: false
        filename: megatron_bert--{val_loss:.2f}-{step}-{consumed_samples}
        model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}}
    model:
      micro_batch_size: 8
      tensor_model_parallel_size: 1
      pipeline_model_parallel_size: 1
      seq_length: 512
      max_position_embeddings: ${.seq_length}
      encoder_seq_length: ${.seq_length}
      num_layers: 6
      hidden_size: 768
      ffn_hidden_size: 3072
      num_attention_heads: 12
      init_method_std: 0.02
      hidden_dropout: 0.1
      kv_channels: null
      apply_query_key_layer_scaling: true
      layernorm_epsilon: 1.0e-05
      make_vocab_size_divisible_by: 128
      pre_process: true
      post_process: true
      bert_binary_head: false
      resume_from_checkpoint: null
      masked_softmax_fusion: true
      tokenizer:
        library: sentencepiece
        type: null
        model: /tokenizers/protein/esm1nv/vocab/protein_sequence_sentencepiece.model
        vocab_file: /tokenizers/vocab/protein_sequence_sentencepiece.vocab
        merge_file: null
      native_amp_init_scale: 4294967296
      native_amp_growth_interval: 1000
      fp32_residual_connection: false
      fp16_lm_cross_entropy: false
      seed: 1234
      use_cpu_initialization: false
      onnx_safe: false
      activations_checkpoint_method: null
      activations_checkpoint_num_layers: 1
      data:
        ngc_registry_target: uniref50_2022_05
        ngc_registry_version: v23.06
        data_prefix: ''
        num_workers: 10
        dataloader_type: single
        reset_position_ids: false
        reset_attention_mask: false
        eod_mask_loss: false
        masked_lm_prob: 0.15
        short_seq_prob: 0.1
        skip_lines: 0
        drop_last: false
        pin_memory: false
        data_impl: csv_mmap
        data_impl_kwargs:
          csv_mmap:
            header_lines: 1
            newline_int: 10
            workers: ${model.data.num_workers}
            sort_dataset_paths: true
            data_sep: ','
            data_col: 1
        use_upsampling: true
        seed: ${model.seed}
        max_seq_length: ${model.seq_length}
        dataset_path: /data/OASpaired/processed/heavy
        dataset:
          train: x[000..005]
          test: x[000..001]
          val: x[000..001]
        micro_batch_size: ${model.micro_batch_size}
        modify_percent: 0.1
        perturb_percent: 0.5
      optim:
        name: fused_adam
        lr: 0.0002
        weight_decay: 0.01
        betas:
        - 0.9
        - 0.98
        sched:
          name: CosineAnnealing
          warmup_steps: 500
          constant_steps: 50000
          min_lr: 2.0e-05
    do_dataloader: false
    
[NeMo W 2023-08-25 18:46:44 nemo_logging:349] /usr/local/lib/python3.8/dist-packages/pytorch_lightning/plugins/precision/native_amp.py:131: LightningDeprecationWarning: The `NativeMixedPrecisionPlugin` class has been renamed in v1.9.0 and will be removed in v2.0.0. Please use `pytorch_lightning.plugins.MixedPrecisionPlugin` instead.
      rank_zero_deprecation(
    
[NeMo I 2023-08-25 18:46:44 utils:168] Selected Callbacks: [<class 'pytorch_lightning.callbacks.model_summary.ModelSummary'>]
Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[NeMo E 2023-08-25 18:46:44 exp_manager:646] exp_manager received explicit_log_dir: /result/nemo_experiments/esm1nv-oas/esm1nv-oas_pretraining and at least one of exp_dir: /result/nemo_experiments/esm1nv-oas/esm1nv-oas_pretraining, or version: None. Please note that exp_dir, name, and version will be ignored.
[NeMo W 2023-08-25 18:46:44 exp_manager:651] Exp_manager is logging to /result/nemo_experiments/esm1nv-oas/esm1nv-oas_pretraining, but it already exists.
[NeMo I 2023-08-25 18:46:44 exp_manager:374] Experiments will be logged at /result/nemo_experiments/esm1nv-oas/esm1nv-oas_pretraining
[NeMo I 2023-08-25 18:46:44 exp_manager:797] TensorboardLogger has been set up
[NeMo W 2023-08-25 18:46:44 exp_manager:893] The checkpoint callback was told to monitor a validation value and trainer's max_steps was set to 101. Please ensure that max_steps will run for at least 1 epochs to ensure that checkpointing will not error out.
[NeMo I 2023-08-25 18:46:44 utils:191] Resuming training from checkpoint: None
[NeMo I 2023-08-25 18:46:44 utils:234] 
    
    ************** Trainer configuration ***********
[NeMo I 2023-08-25 18:46:44 utils:235] 
    name: esm1nv-oas
    do_training: true
    do_testing: false
    restore_from_path: null
    trainer:
      devices: 1
      num_nodes: 1
      accelerator: gpu
      precision: 16-mixed
      logger: false
      enable_checkpointing: false
      use_distributed_sampler: false
      max_epochs: null
      max_steps: 101
      log_every_n_steps: 10
      val_check_interval: 100
      limit_val_batches: 10
      limit_test_batches: 500
      accumulate_grad_batches: 1
      gradient_clip_val: 1.0
      benchmark: false
    exp_manager:
      name: ${name}
      exp_dir: /result/nemo_experiments/${.name}/${.wandb_logger_kwargs.name}
      explicit_log_dir: ${.exp_dir}
      create_wandb_logger: false
      create_tensorboard_logger: true
      wandb_logger_kwargs:
        project: ${name}_pretraining
        name: ${name}_pretraining
        group: ${name}
        job_type: Localhost_nodes_${trainer.num_nodes}_gpus_${trainer.devices}
        notes: 'date: ${now:%y%m%d-%H%M%S}'
        tags:
        - ${name}
        offline: false
      resume_if_exists: false
      resume_ignore_no_checkpoint: true
      create_checkpoint_callback: true
      checkpoint_callback_params:
        monitor: val_loss
        save_top_k: 10
        mode: min
        always_save_nemo: false
        filename: megatron_bert--{val_loss:.2f}-{step}-{consumed_samples}
        model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}}
    model:
      micro_batch_size: 8
      tensor_model_parallel_size: 1
      pipeline_model_parallel_size: 1
      seq_length: 512
      max_position_embeddings: ${.seq_length}
      encoder_seq_length: ${.seq_length}
      num_layers: 6
      hidden_size: 768
      ffn_hidden_size: 3072
      num_attention_heads: 12
      init_method_std: 0.02
      hidden_dropout: 0.1
      kv_channels: null
      apply_query_key_layer_scaling: true
      layernorm_epsilon: 1.0e-05
      make_vocab_size_divisible_by: 128
      pre_process: true
      post_process: true
      bert_binary_head: false
      resume_from_checkpoint: null
      masked_softmax_fusion: true
      tokenizer:
        library: sentencepiece
        type: null
        model: /tokenizers/protein/esm1nv/vocab/protein_sequence_sentencepiece.model
        vocab_file: /tokenizers/vocab/protein_sequence_sentencepiece.vocab
        merge_file: null
      native_amp_init_scale: 4294967296
      native_amp_growth_interval: 1000
      fp32_residual_connection: false
      fp16_lm_cross_entropy: false
      seed: 1234
      use_cpu_initialization: false
      onnx_safe: false
      activations_checkpoint_method: null
      activations_checkpoint_num_layers: 1
      data:
        ngc_registry_target: uniref50_2022_05
        ngc_registry_version: v23.06
        data_prefix: ''
        num_workers: 10
        dataloader_type: single
        reset_position_ids: false
        reset_attention_mask: false
        eod_mask_loss: false
        masked_lm_prob: 0.15
        short_seq_prob: 0.1
        skip_lines: 0
        drop_last: false
        pin_memory: false
        data_impl: csv_mmap
        data_impl_kwargs:
          csv_mmap:
            header_lines: 1
            newline_int: 10
            workers: ${model.data.num_workers}
            sort_dataset_paths: true
            data_sep: ','
            data_col: 1
        use_upsampling: true
        seed: ${model.seed}
        max_seq_length: ${model.seq_length}
        dataset_path: /data/OASpaired/processed/heavy
        dataset:
          train: x[000..005]
          test: x[000..001]
          val: x[000..001]
        micro_batch_size: ${model.micro_batch_size}
        modify_percent: 0.1
        perturb_percent: 0.5
      optim:
        name: fused_adam
        lr: 0.0002
        weight_decay: 0.01
        betas:
        - 0.9
        - 0.98
        sched:
          name: CosineAnnealing
          warmup_steps: 500
          constant_steps: 50000
          min_lr: 2.0e-05
      global_batch_size: 8
      precision: 16-mixed
    do_dataloader: false
    
[NeMo I 2023-08-25 18:46:44 pretrain_oas:19] ************** Starting Training ***********
[NeMo I 2023-08-25 18:46:44 megatron_init:231] Rank 0 has data parallel group: [0]
[NeMo I 2023-08-25 18:46:44 megatron_init:234] All data parallel group ranks: [[0]]
[NeMo I 2023-08-25 18:46:44 megatron_init:235] Ranks 0 has data parallel rank: 0
[NeMo I 2023-08-25 18:46:44 megatron_init:243] Rank 0 has model parallel group: [0]
[NeMo I 2023-08-25 18:46:44 megatron_init:244] All model parallel group ranks: [[0]]
[NeMo I 2023-08-25 18:46:44 megatron_init:254] Rank 0 has tensor model parallel group: [0]
[NeMo I 2023-08-25 18:46:44 megatron_init:258] All tensor model parallel group ranks: [[0]]
[NeMo I 2023-08-25 18:46:44 megatron_init:259] Rank 0 has tensor model parallel rank: 0
[NeMo I 2023-08-25 18:46:44 megatron_init:273] Rank 0 has pipeline model parallel group: [0]
[NeMo I 2023-08-25 18:46:44 megatron_init:285] Rank 0 has embedding group: [0]
[NeMo I 2023-08-25 18:46:44 megatron_init:291] All pipeline model parallel group ranks: [[0]]
[NeMo I 2023-08-25 18:46:44 megatron_init:292] Rank 0 has pipeline model parallel rank 0
[NeMo I 2023-08-25 18:46:44 megatron_init:293] All embedding group ranks: [[0]]
[NeMo I 2023-08-25 18:46:44 megatron_init:294] Rank 0 has embedding rank: 0
23-08-25 18:46:44 - PID:2297 - rank:(0, 0, 0, 0) - microbatches.py:39 - INFO - setting number of micro-batches to constant 1
[NeMo I 2023-08-25 18:46:44 tokenizer_utils:191] Getting SentencePiece with model: /tokenizers/protein/esm1nv/vocab/protein_sequence_sentencepiece.model
[NeMo I 2023-08-25 18:46:44 megatron_base_model:229] Padded vocab_size: 128, original vocab_size: 30, dummy tokens: 98.
[NeMo W 2023-08-25 18:46:44 nemo_logging:349] /usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/configuration_validator.py:175: UserWarning: The `batch_idx` argument in `CustomESM1nvModel.on_train_batch_start` hook may not match with the actual batch index when using a `dataloader_iter` argument in your `training_step`.
      rank_zero_warn(
    
[NeMo W 2023-08-25 18:46:44 nemo_logging:349] /usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/configuration_validator.py:175: UserWarning: The `batch_idx` argument in `CustomESM1nvModel.on_train_batch_end` hook may not match with the actual batch index when using a `dataloader_iter` argument in your `training_step`.
      rank_zero_warn(
    
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
Added key: store_based_barrier_key:1 to store for rank: 0
Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 1 nodes.
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------

Added key: store_based_barrier_key:2 to store for rank: 0
Rank 0: Completed store-based barrier for key:store_based_barrier_key:2 with 1 nodes.
Added key: store_based_barrier_key:3 to store for rank: 0
Rank 0: Completed store-based barrier for key:store_based_barrier_key:3 with 1 nodes.
Added key: store_based_barrier_key:4 to store for rank: 0
Rank 0: Completed store-based barrier for key:store_based_barrier_key:4 with 1 nodes.
Added key: store_based_barrier_key:5 to store for rank: 0
Rank 0: Completed store-based barrier for key:store_based_barrier_key:5 with 1 nodes.
Added key: store_based_barrier_key:6 to store for rank: 0
Rank 0: Completed store-based barrier for key:store_based_barrier_key:6 with 1 nodes.
Added key: store_based_barrier_key:7 to store for rank: 0
Rank 0: Completed store-based barrier for key:store_based_barrier_key:7 with 1 nodes.
[NeMo W 2023-08-25 18:46:45 nemo_logging:349] /usr/local/lib/python3.8/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:613: UserWarning: Checkpoint directory /result/nemo_experiments/esm1nv-oas/esm1nv-oas_pretraining/checkpoints exists and is not empty.
      rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
    
[NeMo I 2023-08-25 18:46:45 megatron_bert_model:563] Pipeline model parallel rank: 0, Tensor model parallel rank: 0, Number of model parameters on device: 4.36e+07. Total number of model parameters: 4.36e+07.
[NeMo I 2023-08-25 18:46:45 esm1nv_model:96] Building Bert datasets.
train:808
Loading data from /data/OASpaired/processed/heavy/train/x000.csv, /data/OASpaired/processed/heavy/train/x001.csv, /data/OASpaired/processed/heavy/train/x002.csv, /data/OASpaired/processed/heavy/train/x003.csv, /data/OASpaired/processed/heavy/train/x004.csv, /data/OASpaired/processed/heavy/train/x005.csv
[NeMo I 2023-08-25 18:46:45 text_memmap_dataset:104] Building data files
[NeMo I 2023-08-25 18:46:45 text_memmap_dataset:343] Processing 6 data files using 10 workers
[NeMo I 2023-08-25 18:46:45 text_memmap_dataset:349] Time building 0 / 6 mem-mapped files: 0:00:00.235856
[NeMo I 2023-08-25 18:46:45 text_memmap_dataset:114] Loading data files
[NeMo I 2023-08-25 18:46:45 text_memmap_dataset:205] Loading /data/OASpaired/processed/heavy/train/x000.csv
[NeMo I 2023-08-25 18:46:45 text_memmap_dataset:205] Loading /data/OASpaired/processed/heavy/train/x001.csv
[NeMo I 2023-08-25 18:46:45 text_memmap_dataset:205] Loading /data/OASpaired/processed/heavy/train/x002.csv
[NeMo I 2023-08-25 18:46:45 text_memmap_dataset:205] Loading /data/OASpaired/processed/heavy/train/x003.csv
[NeMo I 2023-08-25 18:46:45 text_memmap_dataset:205] Loading /data/OASpaired/processed/heavy/train/x004.csv
[NeMo I 2023-08-25 18:46:45 text_memmap_dataset:205] Loading /data/OASpaired/processed/heavy/train/x005.csv
[NeMo I 2023-08-25 18:46:45 text_memmap_dataset:117] Time loading 6 mem-mapped files: 0:00:00.002752
[NeMo I 2023-08-25 18:46:45 text_memmap_dataset:121] Computing global indices
[NeMo I 2023-08-25 18:46:45 dataset_utils:1341]  > loading indexed mapping from /data/OASpaired/processed/heavy/train/__indexmap_808mns_512msl_0.00ssp_1234s.npy
[NeMo I 2023-08-25 18:46:45 dataset_utils:1344]     loaded indexed file in 0.001 seconds
[NeMo I 2023-08-25 18:46:45 dataset_utils:1345]     total number of samples: 21129
val:160
Loading data from /data/OASpaired/processed/heavy/val/x000.csv, /data/OASpaired/processed/heavy/val/x001.csv
[NeMo I 2023-08-25 18:46:45 text_memmap_dataset:104] Building data files
[NeMo I 2023-08-25 18:46:45 text_memmap_dataset:343] Processing 2 data files using 10 workers
[NeMo I 2023-08-25 18:46:46 text_memmap_dataset:349] Time building 0 / 2 mem-mapped files: 0:00:00.231405
[NeMo I 2023-08-25 18:46:46 text_memmap_dataset:114] Loading data files
[NeMo I 2023-08-25 18:46:46 text_memmap_dataset:205] Loading /data/OASpaired/processed/heavy/val/x000.csv
[NeMo I 2023-08-25 18:46:46 text_memmap_dataset:205] Loading /data/OASpaired/processed/heavy/val/x001.csv
[NeMo I 2023-08-25 18:46:46 text_memmap_dataset:117] Time loading 2 mem-mapped files: 0:00:00.001060
[NeMo I 2023-08-25 18:46:46 text_memmap_dataset:121] Computing global indices
[NeMo I 2023-08-25 18:46:46 dataset_utils:1341]  > loading indexed mapping from /data/OASpaired/processed/heavy/val/__indexmap_160mns_512msl_0.00ssp_1234s.npy
[NeMo I 2023-08-25 18:46:46 dataset_utils:1344]     loaded indexed file in 0.000 seconds
[NeMo I 2023-08-25 18:46:46 dataset_utils:1345]     total number of samples: 294
test:4000
Loading data from /data/OASpaired/processed/heavy/test/x000.csv, /data/OASpaired/processed/heavy/test/x001.csv
[NeMo I 2023-08-25 18:46:46 text_memmap_dataset:104] Building data files
[NeMo I 2023-08-25 18:46:46 text_memmap_dataset:343] Processing 2 data files using 10 workers
[NeMo I 2023-08-25 18:46:46 text_memmap_dataset:349] Time building 0 / 2 mem-mapped files: 0:00:00.246505
[NeMo I 2023-08-25 18:46:46 text_memmap_dataset:114] Loading data files
[NeMo I 2023-08-25 18:46:46 text_memmap_dataset:205] Loading /data/OASpaired/processed/heavy/test/x000.csv
[NeMo I 2023-08-25 18:46:46 text_memmap_dataset:205] Loading /data/OASpaired/processed/heavy/test/x001.csv
[NeMo I 2023-08-25 18:46:46 text_memmap_dataset:117] Time loading 2 mem-mapped files: 0:00:00.001123
[NeMo I 2023-08-25 18:46:46 text_memmap_dataset:121] Computing global indices
[NeMo I 2023-08-25 18:46:46 dataset_utils:1341]  > loading indexed mapping from /data/OASpaired/processed/heavy/test/__indexmap_4000mns_512msl_0.00ssp_1234s.npy
[NeMo I 2023-08-25 18:46:46 dataset_utils:1344]     loaded indexed file in 0.000 seconds
[NeMo I 2023-08-25 18:46:46 dataset_utils:1345]     total number of samples: 6499
[NeMo I 2023-08-25 18:46:46 esm1nv_model:114] Length of train dataset: 808
[NeMo I 2023-08-25 18:46:46 esm1nv_model:115] Length of val dataset: 160
[NeMo I 2023-08-25 18:46:46 esm1nv_model:116] Length of test dataset: 4000
[NeMo I 2023-08-25 18:46:46 esm1nv_model:117] Finished building Bert datasets.
[NeMo I 2023-08-25 18:46:46 megatron_bert_model:662] Setting up train dataloader with len(len(self._train_ds)): 808 and consumed samples: 0
[NeMo I 2023-08-25 18:46:46 data_samplers:76] Instantiating MegatronPretrainingSampler with total_samples: 808 and consumed_samples: 0
[NeMo I 2023-08-25 18:46:46 megatron_bert_model:670] Setting up validation dataloader with len(len(self._validation_ds)): 160 and consumed samples: 0
[NeMo I 2023-08-25 18:46:46 data_samplers:76] Instantiating MegatronPretrainingSampler with total_samples: 160 and consumed_samples: 0
[NeMo I 2023-08-25 18:46:46 megatron_bert_model:678] Setting up test dataloader with len(len(self._test_ds)): 4000 and consumed samples: 0
[NeMo I 2023-08-25 18:46:46 data_samplers:76] Instantiating MegatronPretrainingSampler with total_samples: 4000 and consumed_samples: 0
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
[NeMo I 2023-08-25 18:46:46 nlp_overrides:124] Configuring DDP for model parallelism.
[NeMo I 2023-08-25 18:46:46 modelPT:722] Optimizer config = FusedAdam (
    Parameter Group 0
        betas: [0.9, 0.98]
        bias_correction: True
        eps: 1e-08
        lr: 0.0002
        weight_decay: 0.01
    
    Parameter Group 1
        betas: [0.9, 0.98]
        bias_correction: True
        eps: 1e-08
        lr: 0.0002
        weight_decay: 0.0
    )
[NeMo I 2023-08-25 18:46:46 lr_scheduler:910] Scheduler "<nemo.core.optim.lr_scheduler.CosineAnnealing object at 0x7f39a83d3400>" 
    will be used during training (effective maximum steps = 101) - 
    Parameters : 
    (warmup_steps: 500
    constant_steps: 50000
    min_lr: 2.0e-05
    max_steps: 101
    )

  | Name                           | Type                     | Params
----------------------------------------------------------------------------
0 | model                          | BertModel                | 43.6 M
1 | model.language_model           | TransformerLanguageModel | 43.0 M
2 | model.language_model.embedding | Embedding                | 491 K 
3 | model.language_model.encoder   | ParallelTransformer      | 42.5 M
4 | model.lm_head                  | BertLMHead               | 592 K 
5 | model.lm_head.dense            | Linear                   | 590 K 
6 | model.lm_head.layernorm        | MixedFusedLayerNorm      | 1.5 K 
----------------------------------------------------------------------------
43.6 M    Trainable params
0         Non-trainable params
43.6 M    Total params
87.225    Total estimated model params size (MB)

Sanity Checking: 0it [00:00, ?it/s][NeMo W 2023-08-25 18:46:46 nemo_logging:349] /usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py:401: UserWarning: Found `dataloader_iter` argument in the `validation_step`. Note that the support for this signature is experimental and the behavior is subject to change.
      rank_zero_warn(

Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:01<00:00,  1.26it/s][NeMo W 2023-08-25 18:46:48 nemo_logging:349] /usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:232: UserWarning: You called `self.log('consumed_samples', ...)` in your `validation_epoch_end` but the value needs to be floating point. Converting it to torch.float32.
      warning_cache.warn(
[NeMo W 2023-08-25 18:46:48 nemo_logging:349] /usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:536: PossibleUserWarning: It is recommended to use `self.log('val_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
      warning_cache.warn(
[NeMo W 2023-08-25 18:46:48 nemo_logging:349] /usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:536: PossibleUserWarning: It is recommended to use `self.log('val_loss_ECE', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
      warning_cache.warn(
[NeMo W 2023-08-25 18:46:48 nemo_logging:349] /usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:536: PossibleUserWarning: It is recommended to use `self.log('consumed_samples', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
      warning_cache.warn(
[NeMo W 2023-08-25 18:46:48 nemo_logging:349] /usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/fit_loop.py:344: UserWarning: Found `dataloader_iter` argument in the `training_step`. Note that the support for this signature is experimental and the behavior is subject to change.
      rank_zero_warn(


Training: 0it [00:00, ?it/s]
Epoch 0:   0%|          | 0/111 [00:00<?, ?it/s] [NeMo W 2023-08-25 18:46:50 nemo_logging:349] /usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:232: UserWarning: You called `self.log('global_step', ...)` in your `training_step` but the value needs to be floating point. Converting it to torch.float32.
      warning_cache.warn(
[NeMo W 2023-08-25 18:46:50 nemo_logging:349] /usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:232: UserWarning: You called `self.log('consumed_samples', ...)` in your `training_step` but the value needs to be floating point. Converting it to torch.float32.
      warning_cache.warn(
[NeMo W 2023-08-25 18:46:50 nemo_logging:349] /usr/local/lib/python3.8/dist-packages/torch/optim/lr_scheduler.py:139: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
      warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "

Epoch 0:   1%|          | 1/111 [00:01<03:36,  1.97s/it, loss=4.82, v_num=, reduced_train_loss=4.820, global_step=0.000, consumed_samples=0.000]
...
Epoch 0:  23%|██▎       | 25/111 [00:03<00:10,  8.02it/s, loss=4.75, v_num=, reduced_train_loss=3.900, global_step=24.00, consumed_samples=192.0]
...
Epoch 0:  45%|████▌     | 50/111 [00:04<00:05, 11.57it/s, loss=0.104, v_num=, reduced_train_loss=0.00427, global_step=49.00, consumed_samples=392.0]
...
Epoch 0:  68%|██████▊   | 75/111 [00:05<00:02, 13.42it/s, loss=0.019, v_num=, reduced_train_loss=0.0169, global_step=74.00, consumed_samples=592.0]
...
Epoch 0:  90%|█████████ | 100/111 [00:06<00:00, 14.56it/s, loss=0.00978, v_num=, reduced_train_loss=0.000861, global_step=99.00, consumed_samples=792.0]

Validation: 0it [00:00, ?it/s]

Validation DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 41.05it/s]
Epoch 0:  99%|█████████▉| 110/111 [00:07<00:00, 15.42it/s, loss=0.00978, v_num=, reduced_train_loss=0.000861, global_step=99.00, consumed_samples=792.0, val_loss=0.000501]

                                                                        Epoch 0, global step 100: 'val_loss' reached 0.00050 (best 0.00050), saving model to '/result/nemo_experiments/esm1nv-oas/esm1nv-oas_pretraining/checkpoints/megatron_bert--val_loss=0.00-step=100-consumed_samples=800.0-v11.ckpt' as top 10

Epoch 0: 100%|██████████| 111/111 [00:08<00:00, 13.07it/s, loss=0.00745, v_num=, reduced_train_loss=0.000713, global_step=100.0, consumed_samples=800.0, val_loss=0.000501]
`Trainer.fit` stopped: `max_steps=101` reached.
[NeMo I 2023-08-25 18:46:57 nlp_overrides:226] Removing checkpoint: /result/nemo_experiments/esm1nv-oas/esm1nv-oas_pretraining/checkpoints/megatron_bert--val_loss=0.00-step=100-consumed_samples=800.0-last.ckpt
[NeMo I 2023-08-25 18:46:58 pretrain_oas:24] ************** Finished Training ***********

Creating the Dataset object#

Underneath the abstractions we provide, the dataset is ultimately constructed by invoking the relevant NeMo class, specified by model.data.data_impl in the config file, together with the requisite keyword arguments, specified in the model.data.data_impl_kwargs field. Look around in NeMo for additional dataset types, or implement your own!

We can do this manually as well!

dataset_paths = [
    '/data/OASpaired/processed/heavy/train/x000.csv',
    '/data/OASpaired/processed/heavy/train/x001.csv',
    '/data/OASpaired/processed/heavy/train/x002.csv',
]
# Check out NeMo for examples of other dataset types, or add your own!
from nemo.collections.nlp.data.language_modeling.text_memmap_dataset import CSVMemMapDataset

# The kwargs here are taken from our YAML config file.
dataset = CSVMemMapDataset(
    dataset_paths=dataset_paths,
    header_lines=1,           # skip the CSV header row
    newline_int=10,           # ASCII code for '\n', used to delimit rows
    workers=1,
    sort_dataset_paths=True,
    data_sep=',',
    data_col=1,               # column containing the sequence
)

# Print the first few sequences from the memory-mapped dataset.
for i, item in enumerate(iter(dataset)):
    if i > 10:
        break
    print(item)
[NeMo W 2023-08-25 18:47:09 experimental:27] Module <class 'nemo.collections.nlp.models.text_normalization_as_tagging.thutmose_tagger.ThutmoseTaggerModel'> is experimental, not ready for production and is not fully supported. Use at your own risk.
[NeMo W 2023-08-25 18:47:10 experimental:27] Module <class 'nemo.collections.asr.modules.audio_modules.SpectrogramToMultichannelFeatures'> is experimental, not ready for production and is not fully supported. Use at your own risk.
[NeMo I 2023-08-25 18:47:10 text_memmap_dataset:104] Building data files
[NeMo I 2023-08-25 18:47:10 text_memmap_dataset:343] Processing 3 data files using 1 workers
[NeMo I 2023-08-25 18:47:10 text_memmap_dataset:349] Time building 0 / 3 mem-mapped files: 0:00:00.051294
[NeMo I 2023-08-25 18:47:10 text_memmap_dataset:114] Loading data files
[NeMo I 2023-08-25 18:47:10 text_memmap_dataset:205] Loading /data/OASpaired/processed/heavy/train/x000.csv
[NeMo I 2023-08-25 18:47:10 text_memmap_dataset:205] Loading /data/OASpaired/processed/heavy/train/x001.csv
[NeMo I 2023-08-25 18:47:10 text_memmap_dataset:205] Loading /data/OASpaired/processed/heavy/train/x002.csv
[NeMo I 2023-08-25 18:47:10 text_memmap_dataset:117] Time loading 3 mem-mapped files: 0:00:00.005260
[NeMo I 2023-08-25 18:47:10 text_memmap_dataset:121] Computing global indices
GGGAGAGGAGGCCTGTCCTGGATTCGATTCCCAGTTCCTCACATTCAGTCAGCACTGAACACGGACCCCTCACCATGAACTTCGGGCTCAGCTTGATTTTCCTTGTCCTTGTTTTAAAAGGTGTCCAGTGTGAAGTGATGCTGGTGGAGTCTGGGGGAGGCTTAGTGAAGCCTGGAGGGTCCCTGAAACTCTCCTGTGCAGCCTCTGGATTCACTTTCAGTAGCTATGCCATGTCTTGGGTTCGCCAGACTCCGGAGAAGAGGCTGGAGTGGGTCGCAACCATTAGTAGTGGTGGTAGTTACACCTACTATCCAGACAGTGTGAAGGGGCGATTCACCATCTCCAGAGACAATGCCAAGAACACCCTGTACCTGCAAATGAGCAGTCTGAGGTCTGAGGACACGGCCATGTATTACTGTGCAAGACGGGGGAATGATGGTTACTACGAAGACTACTGGGGCCAAGGCACCACTCTCACAGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG
GAGCTCTGACAGAGGAGGCCAGTCCTGGAATTGATTCCCAGTTCCTCACGTTCAGTGATGAGCACTGAACACAGACACCTCACCATGAACTTTGGGCTCAGATTGATTTTCCTTGTCCTTACTTTAAAAGGTGTGAAGTGTGAAGTGCAGCTGGTGGAGTCTGGGGGAGGCTTAGTGAAGCCTGGAGGGTCCCTGAAACTCTCCTGTGCAGCCTCTGGATTCGCTTTCAGTAGCTATGACATGTCTTGGGTTCGCCAGACTCCGGAGAAGAGGCTGGAGTGGGTCGCATACATTAGTAGTGGTGGTGGTATCACCTACTATCCAGACACTGTGAAGGGCCGATTCACCATCTCCAGAGACAATGCCAAGAACACCCTGTACCTGCAAATGAGCAGTCTGAAGTCTGAGGACACAGCCATGTATTACTGTGCAAGGCCCCCGGGACGGGGCTACTGGTACTTCGATGTCTGGGGCGCAGGGACCACGGTCACCGTCTCCTCAGCCAAAACAACAGCCCCATCGGTCTATCCACTGGCCCCTGTGTGTGGAGATACAACTGGCTCCTCGGTGACTCTAGGGTGCCTGGTCAAGGATTATT
AACATATGTCCAATGTCCTCTCCACAGACACTGAACACACTGACTCTAACCATGGGATGGAGCTGGATCTTTCTCTTCCTCCTGTCAGGAACTGCAGGCGTCCACTCTGAGGTCCAGCTTCAGCAGTCAGGACCTGAGCTGGTGAAACCTGGGGCCTCAGTGAAGATATCCTGCAAGGCTTCTGGATACACATTCACTGACTACAACATGCACTGGGTGAAGCAGAGCCATGGAAAGAGCCTTGAGTGGATTGGATATATTTATCCTTACAATGGTGGTACTGGCTACAACCAGAAGTTCAAGAGCAAGGCCACATTGACTGTAGACAATTCCTCCAGCACAGCCTACATGGAGCTCCGCAGCCTGACATCTGAGGACTCTGCAGTCTATTACTGTGCAAGATGGGGGCTAACTGGTGATGCTATGGACTACTGGGGTCAAGGAACCTCAGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG
GACATAACAGCAAGAGAGTGTCCGGTTAGTCTCAAGGAAGACTGAGACACAGTCTTAGATATCATGGAATGGCTGTGGAACTTGCTATTTCTCATGGCAGCAGCTCAAAGTATCCAAGCACAGATCCAGTTGGTGCAGTCTGGACCTGAGCTGAAGAAGCCTGGAGAGACAGTCAGGATCTCCTGCAAGGCTTCTGGGTATACCTTCACAACTGCTGGAATGCAGTGGGTGCAAAAGATGCCAGGAAAGGGTTTGAAGTGGATTGGCTGGATAAACACCCACTCTGGAGTGCCAAAATATGCAGAAGACTTCAAGGGACGGTTTGCCTTCTCTTTGGAAACCTCTGCCAGCACTGCATATTTACAGATAAGCAACCTCAAAAATGAGGACACGGCTACGTATTTCTGTGCGAGATCAGGTTACGACGCCTTTGACTACTGGGGCCAAGGCACCACTCTCACAGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG
GGGGAGCATATGATCAGTGTCCTCTCCAAAGTCCTTGAACATAGACTCTAACCATGGAATGGACCTGGGTCTTTCTCTTCCTCCTGTCAGTAACTGCAGGTGTCCACTCCCAGGTTCAGCTGCAGCAGTCTGGAGTTGAGCTGATGAAGCCTGGGGCCTCAGTGAAGATATCCTGCAAGGCTACTGGCTACACACTCAGTAACTACTGGATAGAGTGGGTAAAGCAGAGGCCTGGACATGGCCTTGAGTGGATTGGAGAGATTTTACCTGGAGATGTTATTACTAACTACAATGAGAGGTTCAAGGACAAGGCCACATTCACTGCAGATACATCCTCCAACACAGCCTACATGCAACTCAGCAGCCTGACATCTGAGGATTCTGCCGTCTATTACTGTGCAAGAAGGGTTATTAAGGGGGGGTTTGCTTACTGGGGCCAAGGGACTCTGGTCACTGTCTCTGCAGCCAAAACAACAGCCCCATCGGTCTATCCACTGGCCCCTGTGTGTGGAGATACAACTGGCTCCTCGGTGACTCTAGGATGCCTGGTCAAGG
ACATCGCTCTCACTGGAGGCTGATCTCTGAAGATAAGGAGGTGTAGCCTAAAAGATGAGAGTGCTGATTCTTTTGTGGCTGTTCACAGCCTTTCCTGGTATCCTGTCTGATGTGCAGCTTCAGGAGTCGGGACCTGGCCTGGTGAAACCTTCTCAGTCTCTGTCCCTCACCTGCACTGTCACTGGCTACTCAATCACCAGTGATTATGCCTGGAACTGGATCCGGCAGTTTCCAGGAAACAAACTGGAGTGGATGGGCTACATAAGCTACAGTGGTAGCACTAGCTACAACCCATCTCTCAAAAGTCGAATCTCTATCACTCGAGACACATCCAAGAACCAGTTCTTCCTGCAGTTGAATTCTGTGACTACTGAGGACACAGCCACATATTACTGTGCAAGAAGGTACTTCGATGTCTGGGGCGCAGGGACCACGGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG
ATCGCTCTCACTGGAGGCTGATCTCTGAAGATAAGGAGGTGTAGCCTAAAAGATGAGAGTGCTGATTCTTTTGTGGCTGTTCACAGCCTTTCCTGGTATCCTGTCTGATGTGCAGCTTCAGGAGTCGGGACCTGGCCTGGTGAAACCTTCTCAGTCTCTGTCCCTCACCTGCACTGTCACTGGCTACTCAATCACCAGTGATTATGCCTGGAACTGGATCCGGCAGTTTCCAGGAAACAAACTGGAGTGGATGGGCTACATAAGCTACAGTGGTAGCACTAGCTACAACCCATCTCTCAAAAGTCGAATCTCTATCACTCGAGACACATCCAAGAACCAGTTCTTCCTGCAGTTGAATTCTGTGACTACTGAGGACACAGCCACATATTACTGTGCAAGGAGGTACTTCGATGTCTGGGGCGCAGGGACCACGGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG
GGGGAAAAACATGAGATCACAGTTCTCTCTACAGTTACTGAGCACACAGGAACTCACCATGGGATGGAGCTATATCATCCTCTTTTTGGTAGCAACAGCTACAGGTGTCCACTCCCAGGTCCAACTGCAGCAGCCTGGGGCTGAACTGGTGAAGCCTGGGGCTTCAGTGAAGTTGTCCTGCAAGGCTTCTGGCTACACCTTCACCAGCTACTATATGTACTGGGTGAAGCAGAGGCCTGGACAAGGCCTTGAGTGGATTGGGGGGATTAATCCTAGCAATGGTGGTACTAACTTCAATGAGAAGTTCAAGAGCAAGGCCACACTGACTGTAGACAAATCCTCCAGCACAGCCTACATGCAACTCAGCAGCCTGACATCTGAGGACTCTGCGGTCTATTACTGTACAAGATACGGCCTCTATGCTATGGACTACTGGGGTCAAGGAACCTCAGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG
ACTAGTGTGCAGATATGGACAGGCTTACTTCCTCATTGCTGCTGCTGATTGTCCCTGCATATGTCCTGTCCCAGGTTACTCTGAAAGAGTCTGGCCCTGGGATATTGCAGCCCTCCCAGACCCTCAGTCTGACTTGTTCTTTCTCTGGGTTTTCACTGACCACTTCTGGTATGGGTGTGACCTGGATTCGTCAGCCTTCAGGAAAGGGTCTGGAGTGGCTGGCACACATTTACTGGGATAATGACAAGCGCTATAATACATCCCTGAAGAGCCGGCTCACAATCTCCAAGGATACCTCCAGCAACCGGGTATTCCTCAAGATAACCAGTGTGGACACTGCAGATACTGCCACATACTACTGTACTCGAGTTACTACGTTGGTGGGCTACTTTGACCAATGGGGCCAAGGCACCACTCTCACAGTCTCCTCAGCCAAAACGACACCCCCATCTGTCTATCCACTGGCCCCTGGATCTGCTGCCCAAACTAACTCCATGGTGACCCTGGGATGCCTGGTCAAGGG
AACACCACCAACAACGACATCGACAATCATTCCCTACACAAAGCTCTTCCGATCTAAACGGGAGAATAGAGTCAATGATTTATTCTTATATGAGGAGAAAAACATGAGATCACAGTTCTCTCTACAGTTACTGAGCACACAGGACCTCACCATGGGATGGAGCTATATCATTTTCTTTTTGGTAGCAACAGCTACAGGTGTCCACTCCCAGGTCCAACTCCAGCAGCCTGGGGCTGAACTGGTGAAGCCTGGGGCTTCAGTGAAGTTGTCCTGCAAGGCTTCTGGCTACACCTTCACCAGCTACTGGATGCACTGGGTGAAGCTGAGGCCTGGACAAGGCTTTGAGTGGATTGGAGAGATTAATCCTAGCAATGGTGGTACTAACTACAATGAGAAGTTCAAGAGAAAGGCCACACTGACTGTAGACAAATCCTCCAGCACAGCCTACATGCAACTCAGCAGCCTGACATCTGAGGACTCTGCGGTCTATTACTGTACAATACGGAATTACTACGGTAGTAGCTACGAGGACTACTGGGGCCAAGGCACCACTCTCACAGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG
GTCTATGGCAGTTCCTATCTCTCTCACTGGAGGCTGATTTTTGAAGAAAGGGGTTGTAGCCTAAAAGATGATGGTGTTAAGTCTTCTGTACCTGTTGACAGCCCTTCCGGGTATCCTGTCAGAGGTGCAGCTTCAGGAGTCAGGACCTAGCCTCGTGAAACCTTCTCAGACTCTGTCCCTCACCTGTTCTGTCACTGGCGACTCCATCACCAGTGGTTACTGGAACTGGATCCGGAAATTCCCAGGGAATAAACTTGAGTACATGGGGTACATAAGCTACAGTGGTAGCACTTACTACAATCCATCTCTCAAAAGTCGAATCTCCATCACTCGAGACACATCCAAGAACCAGTACTACCTGCAGTTGAATTCTGTGACTACTGAGGACACAGCCACATATTACTGTGCAAGATGGGACTATGACTGGGGTCAAGGAACCTCAGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG
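
Under the hood, a memory-mapped CSV dataset like this works by building an index of newline byte offsets per file, so any row can be fetched in constant time without reading the whole file into RAM. Below is a minimal, self-contained sketch of that idea using Python's `mmap` module; the helper names (`build_line_index`, `fetch_row_col`) are illustrative and not the NeMo implementation, but the parameters mirror the kwargs above (`newline_int=10`, `header_lines=1`, `data_sep=','`, `data_col=1`).

```python
import mmap
import tempfile

def build_line_index(mm, newline_int=10, header_lines=1):
    """Record the byte offset of the start of each data row."""
    offsets, pos = [], 0
    while True:
        nl = mm.find(bytes([newline_int]), pos)
        if nl == -1:
            break
        offsets.append(pos)
        pos = nl + 1
    return offsets[header_lines:]  # skip the CSV header row

def fetch_row_col(mm, offsets, idx, data_sep=',', data_col=1):
    """O(1) lookup of one row via its offset, then split out the requested column."""
    start = offsets[idx]
    end = mm.find(b'\n', start)
    return mm[start:end].decode().split(data_sep)[data_col]

# Tiny demo file standing in for one of the OAS CSV shards.
with tempfile.NamedTemporaryFile('w', suffix='.csv', delete=False) as f:
    f.write('id,sequence\n0,GATTACA\n1,CCGGTT\n')
    path = f.name

with open(path, 'rb') as f:
    mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
offsets = build_line_index(mm)
print(fetch_row_col(mm, offsets, 0))  # GATTACA
print(fetch_row_col(mm, offsets, 1))  # CCGGTT
```

Because only the offset index lives in memory, this scales to datasets far larger than RAM, which is why the first pass through the real dataset logs "Building data files" before any rows are served.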

Testing our new collate function#

Before we inject our collate function into a dataloader, let's first take a look at what it actually does. As we saw previously, it simply replaces every character with 'A', which should be visually obvious in the output. We must also supply a tokenizer, which is required by the default language modeling collate function to tokenize the input, apply padding, and align it for distributed training. We call collate_fn on some dummy data and watch the output get transformed to 'AAAA...'.
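
The character-replacement step on its own can be sketched in a couple of lines. This is a toy stand-in for the custom behavior only, without the tokenization, masking, or padding that the real collate class adds:

```python
def replace_with_a(batch):
    """Toy version of the custom step: map every residue to 'A'."""
    return ['A' * len(seq) for seq in batch]

print(replace_with_a(['ACTGT', 'ADFASDFA']))  # ['AAAAA', 'AAAAAAAA']
```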

from bionemo.data.dataloader.custom_protein_collate import CustomProteinBertCollate

# Some magic to get our NeMo tokenizer, filled with arguments from our config file.
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
tokenizer = get_nmt_tokenizer(
    library='sentencepiece',
    tokenizer_model='/tokenizers/protein/esm1nv/vocab/protein_sequence_sentencepiece.model',
    vocab_file='/tokenizers/vocab/protein_sequence_sentencepiece.vocab',
    legacy=False,
)

# The extra kwargs are again taken from our config file.
collate_fn = CustomProteinBertCollate(
    tokenizer=tokenizer,
    seq_length=512,
    pad_size_divisible_by_8=True,
    modify_percent=0.1,   # fraction of tokens to mask or perturb
    perturb_percent=0.5,  # fraction of modified tokens to perturb; the rest (1 - perturb_percent) are masked
).collate_fn
collate_fn(['ACTGT', 'ADFASDFA'])
[NeMo I 2023-08-25 18:47:10 tokenizer_utils:191] Getting SentencePiece with model: /tokenizers/protein/esm1nv/vocab/protein_sequence_sentencepiece.model
{'text': tensor([[1, 6, 6, 6, 6, 6, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3],
         [1, 6, 6, 4, 6, 6, 6, 6, 6, 2, 3, 3, 3, 3, 3, 3]]),
 'types': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),
 'is_random': tensor([0, 1]),
 'loss_mask': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),
 'labels': tensor([[1, 6, 6, 6, 6, 6, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3],
         [1, 6, 6, 6, 6, 6, 6, 6, 6, 2, 3, 3, 3, 3, 3, 3]]),
 'padding_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
         [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]]),
 'batch': ['AAAAA', 'AAAAAAAA']}
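
Note the tensor width of 16 in the output above. With pad_size_divisible_by_8=True, the batch is first padded to the longest sequence (here 'ADFASDFA': 8 tokens plus begin/end tokens, i.e. 10) and then rounded up to the next multiple of 8, a common trick to keep half-precision GPU kernels efficient. The rounding itself is just:

```python
def pad_to_multiple(length, multiple=8):
    """Round a sequence length up to the next multiple (here, of 8)."""
    return ((length + multiple - 1) // multiple) * multiple

# Longest dummy sequence: 'ADFASDFA' (8 tokens) + begin/end tokens = 10 -> padded to 16.
print(pad_to_multiple(10))  # 16
```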

DataLoader!#

Lastly, we construct a dataloader from our dataset object and our collate function. We can then iterate over the result and confirm that the data is transformed in the same way as when we called the collate function manually.
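
Conceptually, what a DataLoader does with a collate function is simple: slice the dataset into lists of batch_size items, then pass each list through collate_fn. A framework-free sketch of that composition, using a toy dataset and toy collate (names here are illustrative):

```python
def simple_dataloader(dataset, batch_size=2, collate_fn=list):
    """Yield collated batches, mimicking DataLoader(dataset, batch_size, collate_fn=...)."""
    for start in range(0, len(dataset), batch_size):
        yield collate_fn(dataset[start:start + batch_size])

toy_dataset = ['ACTGT', 'ADFASDFA', 'GATTACA']
toy_collate = lambda batch: ['A' * len(s) for s in batch]

print(list(simple_dataloader(toy_dataset)))                          # raw batches
print(list(simple_dataloader(toy_dataset, collate_fn=toy_collate)))  # collated batches
```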

from torch.utils.data import DataLoader
print("Before:")
dl = DataLoader(dataset, batch_size=2, shuffle=False)
for i, item in enumerate(dl):
    if i > 10:
        break
    print(item)
print("\n\n\nAfter:")
dl = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=collate_fn)
for i, item in enumerate(dl):
    if i > 10:
        break
    print(item)
Before:
['GGGAGAGGAGGCCTGTCCTGGATTCGATTCCCAGTTCCTCACATTCAGTCAGCACTGAACACGGACCCCTCACCATGAACTTCGGGCTCAGCTTGATTTTCCTTGTCCTTGTTTTAAAAGGTGTCCAGTGTGAAGTGATGCTGGTGGAGTCTGGGGGAGGCTTAGTGAAGCCTGGAGGGTCCCTGAAACTCTCCTGTGCAGCCTCTGGATTCACTTTCAGTAGCTATGCCATGTCTTGGGTTCGCCAGACTCCGGAGAAGAGGCTGGAGTGGGTCGCAACCATTAGTAGTGGTGGTAGTTACACCTACTATCCAGACAGTGTGAAGGGGCGATTCACCATCTCCAGAGACAATGCCAAGAACACCCTGTACCTGCAAATGAGCAGTCTGAGGTCTGAGGACACGGCCATGTATTACTGTGCAAGACGGGGGAATGATGGTTACTACGAAGACTACTGGGGCCAAGGCACCACTCTCACAGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG', 'GAGCTCTGACAGAGGAGGCCAGTCCTGGAATTGATTCCCAGTTCCTCACGTTCAGTGATGAGCACTGAACACAGACACCTCACCATGAACTTTGGGCTCAGATTGATTTTCCTTGTCCTTACTTTAAAAGGTGTGAAGTGTGAAGTGCAGCTGGTGGAGTCTGGGGGAGGCTTAGTGAAGCCTGGAGGGTCCCTGAAACTCTCCTGTGCAGCCTCTGGATTCGCTTTCAGTAGCTATGACATGTCTTGGGTTCGCCAGACTCCGGAGAAGAGGCTGGAGTGGGTCGCATACATTAGTAGTGGTGGTGGTATCACCTACTATCCAGACACTGTGAAGGGCCGATTCACCATCTCCAGAGACAATGCCAAGAACACCCTGTACCTGCAAATGAGCAGTCTGAAGTCTGAGGACACAGCCATGTATTACTGTGCAAGGCCCCCGGGACGGGGCTACTGGTACTTCGATGTCTGGGGCGCAGGGACCACGGTCACCGTCTCCTCAGCCAAAACAACAGCCCCATCGGTCTATCCACTGGCCCCTGTGTGTGGAGATACAACTGGCTCCTCGGTGACTCTAGGGTGCCTGGTCAAGGATTATT']
['AACATATGTCCAATGTCCTCTCCACAGACACTGAACACACTGACTCTAACCATGGGATGGAGCTGGATCTTTCTCTTCCTCCTGTCAGGAACTGCAGGCGTCCACTCTGAGGTCCAGCTTCAGCAGTCAGGACCTGAGCTGGTGAAACCTGGGGCCTCAGTGAAGATATCCTGCAAGGCTTCTGGATACACATTCACTGACTACAACATGCACTGGGTGAAGCAGAGCCATGGAAAGAGCCTTGAGTGGATTGGATATATTTATCCTTACAATGGTGGTACTGGCTACAACCAGAAGTTCAAGAGCAAGGCCACATTGACTGTAGACAATTCCTCCAGCACAGCCTACATGGAGCTCCGCAGCCTGACATCTGAGGACTCTGCAGTCTATTACTGTGCAAGATGGGGGCTAACTGGTGATGCTATGGACTACTGGGGTCAAGGAACCTCAGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG', 'GACATAACAGCAAGAGAGTGTCCGGTTAGTCTCAAGGAAGACTGAGACACAGTCTTAGATATCATGGAATGGCTGTGGAACTTGCTATTTCTCATGGCAGCAGCTCAAAGTATCCAAGCACAGATCCAGTTGGTGCAGTCTGGACCTGAGCTGAAGAAGCCTGGAGAGACAGTCAGGATCTCCTGCAAGGCTTCTGGGTATACCTTCACAACTGCTGGAATGCAGTGGGTGCAAAAGATGCCAGGAAAGGGTTTGAAGTGGATTGGCTGGATAAACACCCACTCTGGAGTGCCAAAATATGCAGAAGACTTCAAGGGACGGTTTGCCTTCTCTTTGGAAACCTCTGCCAGCACTGCATATTTACAGATAAGCAACCTCAAAAATGAGGACACGGCTACGTATTTCTGTGCGAGATCAGGTTACGACGCCTTTGACTACTGGGGCCAAGGCACCACTCTCACAGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG']
['GGGGAGCATATGATCAGTGTCCTCTCCAAAGTCCTTGAACATAGACTCTAACCATGGAATGGACCTGGGTCTTTCTCTTCCTCCTGTCAGTAACTGCAGGTGTCCACTCCCAGGTTCAGCTGCAGCAGTCTGGAGTTGAGCTGATGAAGCCTGGGGCCTCAGTGAAGATATCCTGCAAGGCTACTGGCTACACACTCAGTAACTACTGGATAGAGTGGGTAAAGCAGAGGCCTGGACATGGCCTTGAGTGGATTGGAGAGATTTTACCTGGAGATGTTATTACTAACTACAATGAGAGGTTCAAGGACAAGGCCACATTCACTGCAGATACATCCTCCAACACAGCCTACATGCAACTCAGCAGCCTGACATCTGAGGATTCTGCCGTCTATTACTGTGCAAGAAGGGTTATTAAGGGGGGGTTTGCTTACTGGGGCCAAGGGACTCTGGTCACTGTCTCTGCAGCCAAAACAACAGCCCCATCGGTCTATCCACTGGCCCCTGTGTGTGGAGATACAACTGGCTCCTCGGTGACTCTAGGATGCCTGGTCAAGG', 'ACATCGCTCTCACTGGAGGCTGATCTCTGAAGATAAGGAGGTGTAGCCTAAAAGATGAGAGTGCTGATTCTTTTGTGGCTGTTCACAGCCTTTCCTGGTATCCTGTCTGATGTGCAGCTTCAGGAGTCGGGACCTGGCCTGGTGAAACCTTCTCAGTCTCTGTCCCTCACCTGCACTGTCACTGGCTACTCAATCACCAGTGATTATGCCTGGAACTGGATCCGGCAGTTTCCAGGAAACAAACTGGAGTGGATGGGCTACATAAGCTACAGTGGTAGCACTAGCTACAACCCATCTCTCAAAAGTCGAATCTCTATCACTCGAGACACATCCAAGAACCAGTTCTTCCTGCAGTTGAATTCTGTGACTACTGAGGACACAGCCACATATTACTGTGCAAGAAGGTACTTCGATGTCTGGGGCGCAGGGACCACGGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG']
['ATCGCTCTCACTGGAGGCTGATCTCTGAAGATAAGGAGGTGTAGCCTAAAAGATGAGAGTGCTGATTCTTTTGTGGCTGTTCACAGCCTTTCCTGGTATCCTGTCTGATGTGCAGCTTCAGGAGTCGGGACCTGGCCTGGTGAAACCTTCTCAGTCTCTGTCCCTCACCTGCACTGTCACTGGCTACTCAATCACCAGTGATTATGCCTGGAACTGGATCCGGCAGTTTCCAGGAAACAAACTGGAGTGGATGGGCTACATAAGCTACAGTGGTAGCACTAGCTACAACCCATCTCTCAAAAGTCGAATCTCTATCACTCGAGACACATCCAAGAACCAGTTCTTCCTGCAGTTGAATTCTGTGACTACTGAGGACACAGCCACATATTACTGTGCAAGGAGGTACTTCGATGTCTGGGGCGCAGGGACCACGGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG', 'GGGGAAAAACATGAGATCACAGTTCTCTCTACAGTTACTGAGCACACAGGAACTCACCATGGGATGGAGCTATATCATCCTCTTTTTGGTAGCAACAGCTACAGGTGTCCACTCCCAGGTCCAACTGCAGCAGCCTGGGGCTGAACTGGTGAAGCCTGGGGCTTCAGTGAAGTTGTCCTGCAAGGCTTCTGGCTACACCTTCACCAGCTACTATATGTACTGGGTGAAGCAGAGGCCTGGACAAGGCCTTGAGTGGATTGGGGGGATTAATCCTAGCAATGGTGGTACTAACTTCAATGAGAAGTTCAAGAGCAAGGCCACACTGACTGTAGACAAATCCTCCAGCACAGCCTACATGCAACTCAGCAGCCTGACATCTGAGGACTCTGCGGTCTATTACTGTACAAGATACGGCCTCTATGCTATGGACTACTGGGGTCAAGGAACCTCAGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG']
['ACTAGTGTGCAGATATGGACAGGCTTACTTCCTCATTGCTGCTGCTGATTGTCCCTGCATATGTCCTGTCCCAGGTTACTCTGAAAGAGTCTGGCCCTGGGATATTGCAGCCCTCCCAGACCCTCAGTCTGACTTGTTCTTTCTCTGGGTTTTCACTGACCACTTCTGGTATGGGTGTGACCTGGATTCGTCAGCCTTCAGGAAAGGGTCTGGAGTGGCTGGCACACATTTACTGGGATAATGACAAGCGCTATAATACATCCCTGAAGAGCCGGCTCACAATCTCCAAGGATACCTCCAGCAACCGGGTATTCCTCAAGATAACCAGTGTGGACACTGCAGATACTGCCACATACTACTGTACTCGAGTTACTACGTTGGTGGGCTACTTTGACCAATGGGGCCAAGGCACCACTCTCACAGTCTCCTCAGCCAAAACGACACCCCCATCTGTCTATCCACTGGCCCCTGGATCTGCTGCCCAAACTAACTCCATGGTGACCCTGGGATGCCTGGTCAAGGG', 'AACACCACCAACAACGACATCGACAATCATTCCCTACACAAAGCTCTTCCGATCTAAACGGGAGAATAGAGTCAATGATTTATTCTTATATGAGGAGAAAAACATGAGATCACAGTTCTCTCTACAGTTACTGAGCACACAGGACCTCACCATGGGATGGAGCTATATCATTTTCTTTTTGGTAGCAACAGCTACAGGTGTCCACTCCCAGGTCCAACTCCAGCAGCCTGGGGCTGAACTGGTGAAGCCTGGGGCTTCAGTGAAGTTGTCCTGCAAGGCTTCTGGCTACACCTTCACCAGCTACTGGATGCACTGGGTGAAGCTGAGGCCTGGACAAGGCTTTGAGTGGATTGGAGAGATTAATCCTAGCAATGGTGGTACTAACTACAATGAGAAGTTCAAGAGAAAGGCCACACTGACTGTAGACAAATCCTCCAGCACAGCCTACATGCAACTCAGCAGCCTGACATCTGAGGACTCTGCGGTCTATTACTGTACAATACGGAATTACTACGGTAGTAGCTACGAGGACTACTGGGGCCAAGGCACCACTCTCACAGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG']
['GTCTATGGCAGTTCCTATCTCTCTCACTGGAGGCTGATTTTTGAAGAAAGGGGTTGTAGCCTAAAAGATGATGGTGTTAAGTCTTCTGTACCTGTTGACAGCCCTTCCGGGTATCCTGTCAGAGGTGCAGCTTCAGGAGTCAGGACCTAGCCTCGTGAAACCTTCTCAGACTCTGTCCCTCACCTGTTCTGTCACTGGCGACTCCATCACCAGTGGTTACTGGAACTGGATCCGGAAATTCCCAGGGAATAAACTTGAGTACATGGGGTACATAAGCTACAGTGGTAGCACTTACTACAATCCATCTCTCAAAAGTCGAATCTCCATCACTCGAGACACATCCAAGAACCAGTACTACCTGCAGTTGAATTCTGTGACTACTGAGGACACAGCCACATATTACTGTGCAAGATGGGACTATGACTGGGGTCAAGGAACCTCAGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG', 'AGTTGCGTCTTTTCTTATATGGGATCCTCTTCTCATAGAGCCTCCATCAGAGCATGGCTGTCTTGGGGCTGCTCTTCTGCCTGGTGACATTCCCAAGCTGTGTCCTATCCCAGGTGCAGCTGAAGCAGTCAGGACCTGGCCTAGTGCAGCCCTCACAGAGCCTGTCCATCACCTGCACAGTCTCTGGTTTCTCATTAACTAGCTATGGTGTACACTGGGTTCGCCAGTCTCCAGGAAAGGGTCTGGAGTGGCTGGGAGTGATATGGAGTGGTGGAAGCACAGACTATAATGCAGCTTTCATATCCAGACTGAGCATCAGCAAGGACAATTCCAAGAGCCAAGTTTTCTTTAAAATGAACAGTCTGCAAGCTAATGACACAGCCATATATTACTGTGCCAGAAATTCGGGGGGGTATGGTAACTACGCCCTTTTTGCTTACTGGGGCCAAGGGACTCTGGTCACTGTCTCTGCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG']
['GGGGAACAGACACACAAACCTGGACTCACAAGTTTTTCTCTTCAGTGACAGACACAGACATAGAACATTCACGATGTACTTGGGACTGAACTATGTATTCATAGTTTTTCTCTTAAATGGTGTCCAGAGTGAAGTGAAGCTTGAGGAGTCTGGAGGAGGCTTGGTGCAACCTGGAGGATCCATGAAACTCTCTTGTGCTGCCTCTGGATTCACTTTTAGTGACGCCTGGATGGACTGGGTCCGCCAGTCTCCAGAGAAGGGGCTTGAGTGGGTTGCTGAAATTAGAAGCAAAGCTAATAATCATGCAACATACTATGCTGAGTCTGTGAAAGGGAGGTTCACCATCTCAAGAGATGATTCCAAAAGTAGTGTCTACCTGCAAATGAACAGCTTAAGAGCTGAAGACACTGGCATTTATTACTGTACCCGGTATGGTAACTGGCGGTACTTCGATGTCTGGGGCGCAGGGACCACGGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG', 'GGAGCTCTGACAGAGGAGGCAGGTCCTGGATTCGATTCCCAGTTCCTCACATTCAGTCAGCACTGAACACGGACCCCTCACCATGAACTTTGTGCTCAGCTTGATTTTCCTTGCCCTCATTTTAAAAGGTGTCCAGTGTGAAGTGCAGCTGGTGGAGTCTGGGGGAGGCTTAGTGAAGCCTGGAGGGTCCCTGAAACTCTCCTGTGCAGCCTCTGGATTCACTTTCAGTAGCTATGCCATGTCTTGGGTTCGCCAGACTCCGGAGAAGAGGCTGGAGTGGGTCGCAACCATTAGTAGTGGTGGTAGTTACACCTACTATCCAGACAGTGTGAAGGGTCGATTCACCATCTCCAGAGACAATGCCAAGAACACCCTGTACCTGCAAATGAGCAGTCTGAGGTCTGAGGACACGGCCATGTATTACTGTGCAAGACTAGACCCAACTGGGAACTATGCTATGGACTACTGGGGTCAAGGAACCTCAGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG']
['ATCTCCTCACTAGAGCCCCCATCAGAGCATGGCTGTCCTGGTGCTGTTCCTCTGCCTGGTTGCATTTCCAAGCTGTGTCCTGTCCCAGGTGCAACTGAAGGAGTCAGGACCTGGCCTGGTGGCGCCCTCACAGAGCCTGTCCATCACTTGCACTGTCTCTGGGTTTTCCTTAACCAGCTATGGTGTACACTGGGTTCGCCAGCCTCCAGGAAAGGGTCTGGAGTGGCTGGGAGTAATATGGGCTGGTGGAATCACAAATTATAATTCGGCTCTCATGTCCAGACTGAGCATCAGCAAAGACAACTCCAAGAGCCAAGTTTTCTTAAAAATGAACAGTCTGCAAACTGTTGACACAGCCATGTACTACTGTGCCAGAGATAGGGCCGGCTACTATGGTAACTACTTTGACTACTGGGGCCAAGGCACCACTCTCACAGTCTCCTCAGCCAAAACGACACCCCCATCTGTCTATCCACTGGCCCCTGGATCTGCTGCCCAAACTAACTCCATGGTGACCCTGGGATGCCTGGTCAAGGG', 'GACATACCAGCAAGGGAGTGACCAGTTTGTCTTAAGGCACCACTGAGCCCAAGTCTTAGACATCATGGATTGGCTGTGGAACTTGCTATTCCTGATGGCAGCTGCCCAAAGTGCCCAAGCACAGATCCAGTTGGTGCAGTCTGGACCTGAGCTGAAGAAGCCTGGAGAGACAGTCAAGATCTCCTGCAAGGCTTCTGGGTATACCTTCACAAACTATGGAATGAACTGGGTGAAGCAGGCTCCAGGAAAGGGTTTAAAGTGGATGGGCTGGATAAACACCTACACTGGAGAGCCAACATATGCTGATGACTTCAAGGGACGGTTTGCCTTCTCTTTGGAAACCTCTGCCAGCACTGCCTATTTGCAGATCAACAACCTCAAAAATGAGGACATGGCTACATATTTCTGTGCAAGAGGGGGTAGTAGCTACAGGGACTGGTACTTCGATGTCTGGGGCGCAGGGACCACGGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG']
['GAGCTCTGACAGAGGAGGCCAGTCCTGGAATTGATTCCCAGTTCCTCACGTTCAGTGATGAGCAGTGAACACAGACCCCTCACCATGAACTTCGGGCTCAGATTGATTTTCCTTGTCCTTACTTTAAAAGGTGTCCAGTGTGACGTGAAGCTGGTGGAGTCTGGGGGAGGCTTAGTGAAGCCTGGAGGGTCCCTGAAACTCTCCTGTGCAGCCTCTGGATTCACTTTCAGTAGCTATACCATGTCTTGGGTTCGCCAGACTCCGGAGAAGAGGCTGGAGTGGGTCGCAACCATTAGTAGTGGTGGTAGTTACACCTACTATCCAGACAGTGTGAAGGGCCGATTCACCATCTCCAGAGACAATGCCAAGAACACCCTGTACCTGCAAATGAGCAGTCTGAAGTCTGAGGACACAGCCATGTATTACTGTACAAGCTCCCACCCTGATTGGGGAGCTATGGACTACTGGGGTCAAGGAACCTCAGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG', 'TGGGGAGCTCTGACAGAGGAGGCCGGTCCTGGATTCGATTCCCAGTTCCTCACATTCAGTCAGCACTGAACACAGACACCTCACCATGAACTTCGGGCTCAGCTTGATTTTCCTTGTCCTTATTTTAAAAGGTGTCCAGTGTGAAGTGCAGCTGGTGGAGTCTGGGGGAGGCTTAGTGAAGCCTGGAGGGTCCCTGAAACTCTCCTGTGCAGCCTCTGGATTCACTTTCAGTAGCTATGCCATGTCTTGGGTTCGCCAGTCTCCAGAGAAGAGGCTGGAGTGGGTCGCAGAAATTAGTAGTGGTGGTAGTTACACCTACTATCCAGACACTGTGACGGGCCGATTCACCATCTCCAGAGACAATGCCAAGAACACCCTGTACCTGGAAATGAGCAGTCTGAGGTCTGAGGACACGGCCATGTATTACTGTGCAAGGGATCAGTCTACTATGATTACGTCGTTTGCTTACTGGGGCCAAGGGACTCTGGTCACTGTCTCTGCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG']
['TCATCTCCTCACTAGAGCCCCCATCAGAGCATGGCTGTCCTGGTGCTGTTCCTCTGCCTGGTTGCATTTCCAAGCTGTGTCCTGTCCCAGGTGCAGCTGAAGGAGTCAGGACCTGGCCTGGTGGCGCCCTCACAGAGCCTGTCCATCACTTGCACTGTCTCTGGGTTTTCATTAACCAGCTATGGTGTACACTGGGTTCGCCAGCCTCCAGGAAAGGGTCTGGAGTGGCTGGGAGTAATATGGGCTGGTGGAAGCACAAATTATAATTCGGCTCTCATGTCCAGACTGAGCATCAGCAAAGACAACTCCAAGAGCCAAGTTTTCTTAAAAATGAACAGTCTGCAAACTGATGACACAGCCATGTACTACTGTGCCAGAGATCTATTCTATGCTATGGACTACTGGGGTCAAGGAACCTCAGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG', 'TATTTTTCTTATATGGGGATCCTCTTCTCATAGAGCCTCCATCAGAGCATGGCTGTCCTGGTGCTGCTCTTCTGCCTGGTGACATTCCCAAGCTGTGTCCTATCCCAGGTGCAGCTGAAGCAGTCAGGACCTGGCCTAGTGCAGCCCTCACAGAGCCTGTCCATCACCTGCACAGTCTCTGGTTTCTCATTAACTAGCTATGGTGTACACTGGGTTCGCCAGCCTCCAGGAAAGGGTCTGGAGTGGCTGGGAGTGATATGGAGTGGTGGAAGCACAGACTATAATGCTGCTTTCATATCCAGACTGAGCATCAGCAAGGACAACTCCAAGAGCCAAGTTTTCTTTAAAATGAACAGTCTGCAAGCTGATGACACAGCCATATACTACTGTGCCAGAAAGGCCCCCTATGCTATGGACTACTGGGGTCAAGGAACCTCAGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG']
['CTGAGCTTTCTTATATGGGGAGCTCTGACAGAGGAGGCCTGTCCTGGATTCGATTCCCAGTTCCTCACATTCAGTCAGCACTGAACACGGACCCCTCACCATGAACTTCGGGCTCAGCTTGATTTTCCTTGTCCTTGTTTTAAAAGGTGTCCAGTGTGAAGTGATGCTGGTGGAGTCTGGGGGAGGCTTAGTGAAGCCTGGAGGGTCCCTGAAACTCTCCTGTGCAGCCTCTGGATTCACTTTCAGTAGCTATGCCATGTCTTGGGTTCGCCAGACTCCGGAGAAGAGGCTGGAGTGGGTCGCAACCATTAGTAGTGGTGGTAGTTACACCTACTATCCAGACAGTGTGAAGGGGCGATTCACCATCTCCAGAGACAATGCCAAGAACACCCTGTACCTGCAAATGAGCAGTCTGAGGTCTGAGGACACGGCCATGTATTACTGTGCAAGAGGAGGGGACTATGGTAACGGGTACTTCGATGTCTGGGGCGCAGGGACCACGGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG', 'ACAGTCATTGAAAACACTGACTCTAATCATGGAATGTAACTGGATACTTCCTTTTATTCTGTCGGTAATTTCAGGGGTCTACTCAGAGGTTCAGCTCCAGCAGTCTGGGACTGTGCTGGCAAGGCCTGGGGCTTCCGTGAAGATGTCCTGCAAGGCTTCTGGCTACAGCTTTACCAGCTACTGGATGCACTGGGTAAAACAGAGGCCTGGACAGGGTCTAGAATGGATTGGTGCTATTTATCCTGGAAATAGTGATACTAGCTACAACCAGAAGTTCAAGGGCAAGGCCAAACTGACTGCAGTCACATCCGCCAGCACTGCCTACATGGAGCTCAGCAGCCTGACAAATGAGGACTCTGCGGTCTATTACTGTACCCTTATGATTACGACGACGGTTTTTGCTTACTGGGGCCAAGGGACTCTGGTCACTGTCTCTGCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG']



After:
{'text': tensor([[ 1,  6, 13,  ...,  6,  6,  6],
        [16,  6,  6,  ...,  4,  9,  6]]), 'types': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'is_random': tensor([0, 1]), 'loss_mask': tensor([[0, 0, 1,  ..., 0, 0, 0],
        [1, 0, 0,  ..., 1, 1, 0]]), 'labels': tensor([[1, 6, 6,  ..., 6, 6, 6],
        [1, 6, 6,  ..., 6, 6, 6]]), 'padding_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]]), 'batch': ['AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA', 'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA']}
{'text': tensor([[1, 6, 6,  ..., 6, 6, 6],
        [1, 6, 6,  ..., 6, 6, 6]]), 'types': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'is_random': tensor([0, 1]), 'loss_mask': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'labels': tensor([[1, 6, 6,  ..., 6, 6, 6],
        [1, 6, 6,  ..., 6, 6, 6]]), 'padding_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]]), 'batch': ['AAAA...AAAA', 'AAAA...AAAA']}
{'text': tensor([[ 1, 25,  6,  ...,  6,  6,  6],
        [ 1,  6,  6,  ...,  6,  6,  6]]), 'types': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'is_random': tensor([0, 1]), 'loss_mask': tensor([[0, 1, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'labels': tensor([[1, 6, 6,  ..., 6, 6, 6],
        [1, 6, 6,  ..., 6, 6, 6]]), 'padding_mask': tensor([[1, 1, 1,  ..., 1, 1, 1],
        [1, 1, 1,  ..., 1, 1, 1]]), 'batch': ['AAAA...AAAA', 'AAAA...AAAA']}
(remaining batches omitted for brevity; each has the same keys, with random token masking reflected in 'text' and 'loss_mask' and the raw amino-acid sequences preserved under the custom 'batch' key)
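Each printed batch contains the tokenized tensors produced by the masking pipeline plus a `batch` key holding the raw amino-acid strings, which the customized collate function appends after the standard collation. The core of that behavior can be illustrated with a minimal, framework-free sketch; the helper name `pad_collate`, the sample dict keys, and the fixed pad id below are illustrative only, not the BioNeMo API:

```python
def pad_collate(samples, pad_id=0):
    """Collate a list of {'token_ids': [...], 'sequence': str} samples.

    Pads token-id lists to the longest sample, builds a padding mask,
    and passes the raw sequences through under a 'batch' key, mirroring
    the dict structure shown in the printed output above.
    """
    max_len = max(len(s["token_ids"]) for s in samples)
    text, padding_mask = [], []
    for s in samples:
        ids = s["token_ids"]
        pad = max_len - len(ids)
        text.append(ids + [pad_id] * pad)          # pad to max_len
        padding_mask.append([1] * len(ids) + [0] * pad)
    return {
        "text": text,
        "padding_mask": padding_mask,
        "batch": [s["sequence"] for s in samples],  # raw strings pass through
    }

samples = [
    {"token_ids": [1, 6, 6, 6], "sequence": "AAAA"},
    {"token_ids": [1, 6], "sequence": "AA"},
]
out = pad_collate(samples)
# out["text"] → [[1, 6, 6, 6], [1, 6, 0, 0]]
```

In the actual dataloader, a function with this role is supplied via the `collate_fn` argument of `torch.utils.data.DataLoader`, and the padded lists become tensors.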

Conclusion and Further Reading#

This concludes our tutorial series on including custom data in the BioNeMo framework. Throughout these tutorials, we described how to manually update a model with a new dataset and how those changes propagate throughout the framework. Check out the other Dataset classes and tokenizers in NeMo to learn about further customization options.