Adding the OAS Dataset: Customizing Dataset Object and Dataloader Functions
This tutorial is the third part of a series focused on adding a new dataset to BioNeMo using the Observed Antibody Space (OAS) database. There are three steps to this task:
Preprocessing, which includes downloading the raw data and any additional preparation steps, such as extracting the files. It also includes dividing the data into train, validation, and test splits. The preprocessing step can make use of two BioNeMo base classes, RemoteResource and ResourcePreprocessor, from bionemo.utils.remote and bionemo.data.preprocess, respectively. Their use is optional, but they provide basic functionality that can accelerate development. This step was covered in the first tutorial, Downloading and Preprocessing.
Development of the new dataset class. Here, the NeMo dataset class CSVMemMapDataset is used. This step was covered in the previous tutorial, Modifying the Dataset Class.
Modification of the dataloader classes. This tutorial covers customizing DataLoader objects using the newly created OAS datasets, including instantiating the Dataset classes, customizing the collate function, and instantiating a dataloader. We will also review how these steps are executed within the BioNeMo model classes.
Setup and Assumptions#
This tutorial assumes that a copy of the BioNeMo framework repo exists on a workstation or server and has been mounted inside the container at /workspace/bionemo, as described in the Code Development section of the Quickstart Guide. This path is referred to with the variable BIONEMO_WORKSPACE in the tutorial.
All commands should be executed inside the BioNeMo docker container.
BIONEMO_WORKSPACE = '/workspace/bionemo'
TUTORIAL_FILE_VERSION = 'step_999_final'
stage_files(TUTORIAL_FILE_VERSION, source_directory=f'{BIONEMO_WORKSPACE}/examples/oas_dataset')
Customizing a collate function#
In the last tutorial we saw how to modify the yaml file to use a different set of data with existing tooling. In some cases, this isn't enough. The collate_fn parameter of PyTorch DataLoaders is used for last-minute adjustments to batches, including masking, shuffling, batching, padding, and other slight modifications to the input data. In BioNeMo, we build our collate function on top of the collators used for language modeling (bionemo/data/dataloader/collate.py).
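As a rough illustration of what a collate function does, the sketch below pads a batch of variable-length token-id lists to a common length and builds a matching padding mask. It is plain Python rather than BioNeMo code: pad_collate, PAD_ID, and the toy batch are hypothetical names chosen for this example. In PyTorch, a callable of this shape would be handed to the DataLoader through its collate_fn argument.

```python
from typing import Dict, List

PAD_ID = 0  # hypothetical padding token id, chosen for this sketch


def pad_collate(batch: List[List[int]], pad_id: int = PAD_ID) -> Dict[str, List[List[int]]]:
    """Pad every sample to the longest length in the batch and build a mask."""
    max_len = max(len(sample) for sample in batch)
    padded, mask = [], []
    for sample in batch:
        n_pad = max_len - len(sample)
        padded.append(sample + [pad_id] * n_pad)
        # 1 marks a real token, 0 marks padding
        mask.append([1] * len(sample) + [0] * n_pad)
    return {"tokens": padded, "padding_mask": mask}


toy_batch = [[5, 6, 7], [8], [9, 10]]
collated = pad_collate(toy_batch)
print(collated["tokens"])        # [[5, 6, 7], [8, 0, 0], [9, 10, 0]]
print(collated["padding_mask"])  # [[1, 1, 1], [1, 0, 0], [1, 1, 0]]
```

The BioNeMo collators do considerably more than this (tokenization and masking, for instance), so treat the sketch only as a picture of where such adjustments happen.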
The collate function is ultimately injected into the dataloader upon construction. To customize further, we can simply extend the existing ProteinBertCollate class with our own additional collation, followed by a call to the parent's method.
# Copyright (c) 2022, NVIDIA CORPORATION.
# SPDX-License-Identifier: Apache-2.0
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from typing import List
from nemo.collections.common.tokenizers import TokenizerSpec
from bionemo.data.dataloader.collate import (
    BertCollate,
    BertMasking,
    SentencePieceTokenizerAdapter,
)
from bionemo.data.dataloader.protein_collate import ProteinBertCollate

__all__ = ['CustomProteinBertCollate']


class CustomProteinBertCollate(ProteinBertCollate):
    def collate_fn(self, batch: List[str], label_pad: int = -1):
        '''
        The parent does things like add padding, masking, one-hot transformations, etc.
        In this case, we do the same thing, but first we replace every character of
        every sequence with 'A'. Does this do anything useful in practice? Maybe not,
        but that's okay.

        This method ultimately gets injected into a DataLoader. We do this as a
        part of the standard dataloader setup method inside our ESM1nv model, by instantiating
        this class and then injecting the collate_fn.
        '''
        extra = []  # Holds the leftover sample when the batch length is odd
        if len(batch) % 2 == 1:
            batch, extra = batch[:-1], batch[-1:]
        half = len(batch) // 2
        back, front = batch[:half], batch[half:]
        batch = back + front + extra  # reassembled in the original order
        new_batch = ['A' * len(seq) for seq in batch]
        return super().collate_fn(new_batch, label_pad)
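The BioNeMo parent class needs a tokenizer and the framework installed, so one way to sanity-check the override pattern in isolation is to use a stand-in parent. In the sketch below, EchoCollate is a hypothetical placeholder for ProteinBertCollate that simply returns the batch it receives, and AllACollate is a hypothetical subclass. Only the subclass-then-call-super pattern mirrors the code above.

```python
from typing import List


class EchoCollate:
    """Hypothetical stand-in for the parent collate class; returns the batch unchanged."""

    def collate_fn(self, batch: List[str], label_pad: int = -1) -> List[str]:
        return list(batch)


class AllACollate(EchoCollate):
    """Replace every character with 'A', then defer to the parent."""

    def collate_fn(self, batch: List[str], label_pad: int = -1) -> List[str]:
        new_batch = ["A" * len(seq) for seq in batch]
        return super().collate_fn(new_batch, label_pad)


sequences = ["QVQLVQ", "EVQ", "DIQMTQSP"]
result = AllACollate().collate_fn(sequences)

# Every output keeps its original length and contains only 'A'.
assert [len(seq) for seq in result] == [len(seq) for seq in sequences]
assert all(set(seq) == {"A"} for seq in result)
print(result)  # ['AAAAAA', 'AAA', 'AAAAAAAA']
```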
Injecting a custom collate object into an existing model.#
The implemented collate function serves a single purpose: it replaces every character in each sequence with the character 'A'. This is both easy to implement and simple to check for correctness. The modified batch is then passed to the parent collate function for padding and masking. Next, we will inject this into our ESM1nv model so that it is applied to the dataset. You can see below that this occurs in the build_pretraining_data_loader method, which primarily operates on an already existing Dataset object.
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from typing import Dict, Optional
from omegaconf.dictconfig import DictConfig
from pytorch_lightning.trainer.trainer import Trainer
from nemo.core.neural_types import NeuralType
from nemo.collections.nlp.models.language_modeling.megatron_bert_model import MegatronBertModel
from nemo.collections.nlp.modules.common.megatron.utils import (
    average_losses_across_data_parallel_group,
)
from nemo.utils import logging
from bionemo.model.protein.esm1nv.esm1nv_model import ESM1nvModel
from bionemo.data.molecule import megamolbart_build_train_valid_test_datasets
from bionemo.data.dataloader.custom_protein_collate import CustomProteinBertCollate
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer

try:
    from apex.transformer import tensor_parallel

    HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
    HAVE_APEX = False

__all__ = ["CustomESM1nvModel"]


class CustomESM1nvModel(ESM1nvModel):
    """CustomESM1nv model that extends the dataloader function to use our custom collate function.

    Check out the base class for more information on how it all fits together.
    """

    def __init__(self, cfg: DictConfig, trainer: Trainer):
        super().__init__(cfg, trainer=trainer)

    def build_pretraining_data_loader(self, dataset, consumed_samples):
        """Build dataloader given an input dataset."""
        assert self._cfg.data.dataloader_type == 'single', (
            f'Only the Megatron sequential ("single") sampler is currently supported. {self._cfg.data.dataloader_type} was chosen.'
        )
        # Let the parent build the dataloader, then adjust it below.
        dataloader = super().build_pretraining_data_loader(dataset=dataset, consumed_samples=consumed_samples)
        # Add collate function and unpin memory to avoid crash with CUDA misaligned address
        dataloader.pin_memory = False  # must be False with CSV dataset TODO check with binary
        pad_size_divisible_by_8 = bool(self._cfg.masked_softmax_fusion)
        dataloader.collate_fn = CustomProteinBertCollate(
            tokenizer=self.tokenizer,
            seq_length=self._cfg.seq_length,
            pad_size_divisible_by_8=pad_size_divisible_by_8,
            modify_percent=self._cfg.data.modify_percent,
            perturb_percent=self._cfg.data.perturb_percent,
        ).collate_fn
        return dataloader
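The injection pattern used above (let the parent build the loader, then swap its collate function) can be sketched without BioNeMo or PyTorch. ToyLoader, BaseModel, and CustomModel below are hypothetical stand-ins written for this illustration; they are not BioNeMo or NeMo classes and only mimic the control flow.

```python
from typing import Callable, Iterator, List


class ToyLoader:
    """Hypothetical minimal loader: yields fixed-size batches through collate_fn."""

    def __init__(self, dataset: List[str], batch_size: int,
                 collate_fn: Callable[[List[str]], List[str]] = list):
        self.dataset = dataset
        self.batch_size = batch_size
        self.collate_fn = collate_fn

    def __iter__(self) -> Iterator[List[str]]:
        for start in range(0, len(self.dataset), self.batch_size):
            yield self.collate_fn(self.dataset[start:start + self.batch_size])


class BaseModel:
    """Hypothetical stand-in for the parent model class."""

    def build_loader(self, dataset: List[str]) -> ToyLoader:
        return ToyLoader(dataset, batch_size=2)


class CustomModel(BaseModel):
    def build_loader(self, dataset: List[str]) -> ToyLoader:
        loader = super().build_loader(dataset)  # reuse the parent's setup
        # Swap in the custom collate function on the finished loader.
        loader.collate_fn = lambda batch: ["A" * len(seq) for seq in batch]
        return loader


batches = list(CustomModel().build_loader(["QVQ", "EV", "DIQMT"]))
print(batches)  # [['AAA', 'AA'], ['AAAAA']]
```

Overriding only the loader-building method keeps the rest of the parent model untouched, which is the same design choice CustomESM1nvModel makes.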
std_out = ! cd {BIONEMO_WORKSPACE}/examples/protein/esm1nv && python pretrain_oas.py ++trainer.max_steps=101
print('\n'.join(std_out))
[NeMo W 2023-08-25 18:46:43 experimental:27] Module <class 'nemo.collections.nlp.models.text_normalization_as_tagging.thutmose_tagger.ThutmoseTaggerModel'> is experimental, not ready for production and is not fully supported. Use at your own risk.
[NeMo W 2023-08-25 18:46:44 experimental:27] Module <class 'nemo.collections.asr.modules.audio_modules.SpectrogramToMultichannelFeatures'> is experimental, not ready for production and is not fully supported. Use at your own risk.
[NeMo W 2023-08-25 18:46:44 nemo_logging:349] /usr/local/lib/python3.8/dist-packages/hydra/_internal/defaults_list.py:251: UserWarning: In 'pretrain_oas': Defaults list is missing `_self_`. See https://hydra.cc/docs/upgrades/1.0_to_1.1/default_composition_order for more information
warnings.warn(msg, UserWarning)
[NeMo W 2023-08-25 18:46:44 nemo_logging:349] /usr/local/lib/python3.8/dist-packages/hydra/_internal/hydra.py:119: UserWarning: Future Hydra versions will no longer change working directory at job runtime by default.
See https://hydra.cc/docs/next/upgrades/1.1_to_1.2/changes_to_job_working_dir/ for more information.
ret = run_job(
[NeMo I 2023-08-25 18:46:44 pretrain_oas:12]
************** Experiment configuration ***********
[NeMo I 2023-08-25 18:46:44 pretrain_oas:13]
name: esm1nv-oas
do_training: true
do_testing: false
restore_from_path: null
trainer:
devices: 1
num_nodes: 1
accelerator: gpu
precision: 16
logger: false
enable_checkpointing: false
replace_sampler_ddp: false
max_epochs: null
max_steps: 101
log_every_n_steps: 10
val_check_interval: 100
limit_val_batches: 10
limit_test_batches: 500
accumulate_grad_batches: 1
gradient_clip_val: 1.0
benchmark: false
exp_manager:
name: ${name}
exp_dir: /result/nemo_experiments/${.name}/${.wandb_logger_kwargs.name}
explicit_log_dir: ${.exp_dir}
create_wandb_logger: false
create_tensorboard_logger: true
wandb_logger_kwargs:
project: ${name}_pretraining
name: ${name}_pretraining
group: ${name}
job_type: Localhost_nodes_${trainer.num_nodes}_gpus_${trainer.devices}
notes: 'date: ${now:%y%m%d-%H%M%S}'
tags:
- ${name}
offline: false
resume_if_exists: false
resume_ignore_no_checkpoint: true
create_checkpoint_callback: true
checkpoint_callback_params:
monitor: val_loss
save_top_k: 10
mode: min
always_save_nemo: false
filename: megatron_bert--{val_loss:.2f}-{step}-{consumed_samples}
model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}}
model:
micro_batch_size: 8
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
seq_length: 512
max_position_embeddings: ${.seq_length}
encoder_seq_length: ${.seq_length}
num_layers: 6
hidden_size: 768
ffn_hidden_size: 3072
num_attention_heads: 12
init_method_std: 0.02
hidden_dropout: 0.1
kv_channels: null
apply_query_key_layer_scaling: true
layernorm_epsilon: 1.0e-05
make_vocab_size_divisible_by: 128
pre_process: true
post_process: true
bert_binary_head: false
resume_from_checkpoint: null
masked_softmax_fusion: true
tokenizer:
library: sentencepiece
type: null
model: /tokenizers/protein/esm1nv/vocab/protein_sequence_sentencepiece.model
vocab_file: /tokenizers/vocab/protein_sequence_sentencepiece.vocab
merge_file: null
native_amp_init_scale: 4294967296
native_amp_growth_interval: 1000
fp32_residual_connection: false
fp16_lm_cross_entropy: false
seed: 1234
use_cpu_initialization: false
onnx_safe: false
activations_checkpoint_method: null
activations_checkpoint_num_layers: 1
data:
ngc_registry_target: uniref50_2022_05
ngc_registry_version: v23.06
data_prefix: ''
num_workers: 10
dataloader_type: single
reset_position_ids: false
reset_attention_mask: false
eod_mask_loss: false
masked_lm_prob: 0.15
short_seq_prob: 0.1
skip_lines: 0
drop_last: false
pin_memory: false
data_impl: csv_mmap
data_impl_kwargs:
csv_mmap:
header_lines: 1
newline_int: 10
workers: ${model.data.num_workers}
sort_dataset_paths: true
data_sep: ','
data_col: 1
use_upsampling: true
seed: ${model.seed}
max_seq_length: ${model.seq_length}
dataset_path: /data/OASpaired/processed/heavy
dataset:
train: x[000..005]
test: x[000..001]
val: x[000..001]
micro_batch_size: ${model.micro_batch_size}
modify_percent: 0.1
perturb_percent: 0.5
optim:
name: fused_adam
lr: 0.0002
weight_decay: 0.01
betas:
- 0.9
- 0.98
sched:
name: CosineAnnealing
warmup_steps: 500
constant_steps: 50000
min_lr: 2.0e-05
do_dataloader: false
[NeMo W 2023-08-25 18:46:44 nemo_logging:349] /usr/local/lib/python3.8/dist-packages/pytorch_lightning/plugins/precision/native_amp.py:131: LightningDeprecationWarning: The `NativeMixedPrecisionPlugin` class has been renamed in v1.9.0 and will be removed in v2.0.0. Please use `pytorch_lightning.plugins.MixedPrecisionPlugin` instead.
rank_zero_deprecation(
[NeMo I 2023-08-25 18:46:44 utils:168] Selected Callbacks: [<class 'pytorch_lightning.callbacks.model_summary.ModelSummary'>]
Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[NeMo E 2023-08-25 18:46:44 exp_manager:646] exp_manager received explicit_log_dir: /result/nemo_experiments/esm1nv-oas/esm1nv-oas_pretraining and at least one of exp_dir: /result/nemo_experiments/esm1nv-oas/esm1nv-oas_pretraining, or version: None. Please note that exp_dir, name, and version will be ignored.
[NeMo W 2023-08-25 18:46:44 exp_manager:651] Exp_manager is logging to /result/nemo_experiments/esm1nv-oas/esm1nv-oas_pretraining, but it already exists.
[NeMo I 2023-08-25 18:46:44 exp_manager:374] Experiments will be logged at /result/nemo_experiments/esm1nv-oas/esm1nv-oas_pretraining
[NeMo I 2023-08-25 18:46:44 exp_manager:797] TensorboardLogger has been set up
[NeMo W 2023-08-25 18:46:44 exp_manager:893] The checkpoint callback was told to monitor a validation value and trainer's max_steps was set to 101. Please ensure that max_steps will run for at least 1 epochs to ensure that checkpointing will not error out.
[NeMo I 2023-08-25 18:46:44 utils:191] Resuming training from checkpoint: None
[NeMo I 2023-08-25 18:46:44 utils:234]
************** Trainer configuration ***********
[NeMo I 2023-08-25 18:46:44 utils:235]
name: esm1nv-oas
do_training: true
do_testing: false
restore_from_path: null
trainer:
devices: 1
num_nodes: 1
accelerator: gpu
precision: 16
logger: false
enable_checkpointing: false
replace_sampler_ddp: false
max_epochs: null
max_steps: 101
log_every_n_steps: 10
val_check_interval: 100
limit_val_batches: 10
limit_test_batches: 500
accumulate_grad_batches: 1
gradient_clip_val: 1.0
benchmark: false
exp_manager:
name: ${name}
exp_dir: /result/nemo_experiments/${.name}/${.wandb_logger_kwargs.name}
explicit_log_dir: ${.exp_dir}
create_wandb_logger: false
create_tensorboard_logger: true
wandb_logger_kwargs:
project: ${name}_pretraining
name: ${name}_pretraining
group: ${name}
job_type: Localhost_nodes_${trainer.num_nodes}_gpus_${trainer.devices}
notes: 'date: ${now:%y%m%d-%H%M%S}'
tags:
- ${name}
offline: false
resume_if_exists: false
resume_ignore_no_checkpoint: true
create_checkpoint_callback: true
checkpoint_callback_params:
monitor: val_loss
save_top_k: 10
mode: min
always_save_nemo: false
filename: megatron_bert--{val_loss:.2f}-{step}-{consumed_samples}
model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}}
model:
micro_batch_size: 8
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1
seq_length: 512
max_position_embeddings: ${.seq_length}
encoder_seq_length: ${.seq_length}
num_layers: 6
hidden_size: 768
ffn_hidden_size: 3072
num_attention_heads: 12
init_method_std: 0.02
hidden_dropout: 0.1
kv_channels: null
apply_query_key_layer_scaling: true
layernorm_epsilon: 1.0e-05
make_vocab_size_divisible_by: 128
pre_process: true
post_process: true
bert_binary_head: false
resume_from_checkpoint: null
masked_softmax_fusion: true
tokenizer:
library: sentencepiece
type: null
model: /tokenizers/protein/esm1nv/vocab/protein_sequence_sentencepiece.model
vocab_file: /tokenizers/vocab/protein_sequence_sentencepiece.vocab
merge_file: null
native_amp_init_scale: 4294967296
native_amp_growth_interval: 1000
fp32_residual_connection: false
fp16_lm_cross_entropy: false
seed: 1234
use_cpu_initialization: false
onnx_safe: false
activations_checkpoint_method: null
activations_checkpoint_num_layers: 1
data:
ngc_registry_target: uniref50_2022_05
ngc_registry_version: v23.06
data_prefix: ''
num_workers: 10
dataloader_type: single
reset_position_ids: false
reset_attention_mask: false
eod_mask_loss: false
masked_lm_prob: 0.15
short_seq_prob: 0.1
skip_lines: 0
drop_last: false
pin_memory: false
data_impl: csv_mmap
data_impl_kwargs:
csv_mmap:
header_lines: 1
newline_int: 10
workers: ${model.data.num_workers}
sort_dataset_paths: true
data_sep: ','
data_col: 1
use_upsampling: true
seed: ${model.seed}
max_seq_length: ${model.seq_length}
dataset_path: /data/OASpaired/processed/heavy
dataset:
train: x[000..005]
test: x[000..001]
val: x[000..001]
micro_batch_size: ${model.micro_batch_size}
modify_percent: 0.1
perturb_percent: 0.5
optim:
name: fused_adam
lr: 0.0002
weight_decay: 0.01
betas:
- 0.9
- 0.98
sched:
name: CosineAnnealing
warmup_steps: 500
constant_steps: 50000
min_lr: 2.0e-05
global_batch_size: 8
precision: 16
do_dataloader: false
[NeMo I 2023-08-25 18:46:44 pretrain_oas:19] ************** Starting Training ***********
[NeMo I 2023-08-25 18:46:44 megatron_init:231] Rank 0 has data parallel group: [0]
[NeMo I 2023-08-25 18:46:44 megatron_init:234] All data parallel group ranks: [[0]]
[NeMo I 2023-08-25 18:46:44 megatron_init:235] Ranks 0 has data parallel rank: 0
[NeMo I 2023-08-25 18:46:44 megatron_init:243] Rank 0 has model parallel group: [0]
[NeMo I 2023-08-25 18:46:44 megatron_init:244] All model parallel group ranks: [[0]]
[NeMo I 2023-08-25 18:46:44 megatron_init:254] Rank 0 has tensor model parallel group: [0]
[NeMo I 2023-08-25 18:46:44 megatron_init:258] All tensor model parallel group ranks: [[0]]
[NeMo I 2023-08-25 18:46:44 megatron_init:259] Rank 0 has tensor model parallel rank: 0
[NeMo I 2023-08-25 18:46:44 megatron_init:273] Rank 0 has pipeline model parallel group: [0]
[NeMo I 2023-08-25 18:46:44 megatron_init:285] Rank 0 has embedding group: [0]
[NeMo I 2023-08-25 18:46:44 megatron_init:291] All pipeline model parallel group ranks: [[0]]
[NeMo I 2023-08-25 18:46:44 megatron_init:292] Rank 0 has pipeline model parallel rank 0
[NeMo I 2023-08-25 18:46:44 megatron_init:293] All embedding group ranks: [[0]]
[NeMo I 2023-08-25 18:46:44 megatron_init:294] Rank 0 has embedding rank: 0
23-08-25 18:46:44 - PID:2297 - rank:(0, 0, 0, 0) - microbatches.py:39 - INFO - setting number of micro-batches to constant 1
[NeMo I 2023-08-25 18:46:44 tokenizer_utils:191] Getting SentencePiece with model: /tokenizers/protein/esm1nv/vocab/protein_sequence_sentencepiece.model
[NeMo I 2023-08-25 18:46:44 megatron_base_model:229] Padded vocab_size: 128, original vocab_size: 30, dummy tokens: 98.
[NeMo W 2023-08-25 18:46:44 nemo_logging:349] /usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/configuration_validator.py:175: UserWarning: The `batch_idx` argument in `CustomESM1nvModel.on_train_batch_start` hook may not match with the actual batch index when using a `dataloader_iter` argument in your `training_step`.
rank_zero_warn(
[NeMo W 2023-08-25 18:46:44 nemo_logging:349] /usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/configuration_validator.py:175: UserWarning: The `batch_idx` argument in `CustomESM1nvModel.on_train_batch_end` hook may not match with the actual batch index when using a `dataloader_iter` argument in your `training_step`.
rank_zero_warn(
Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
Added key: store_based_barrier_key:1 to store for rank: 0
Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 1 nodes.
----------------------------------------------------------------------------------------------------
distributed_backend=nccl
All distributed processes registered. Starting with 1 processes
----------------------------------------------------------------------------------------------------
Added key: store_based_barrier_key:2 to store for rank: 0
Rank 0: Completed store-based barrier for key:store_based_barrier_key:2 with 1 nodes.
Added key: store_based_barrier_key:3 to store for rank: 0
Rank 0: Completed store-based barrier for key:store_based_barrier_key:3 with 1 nodes.
Added key: store_based_barrier_key:4 to store for rank: 0
Rank 0: Completed store-based barrier for key:store_based_barrier_key:4 with 1 nodes.
Added key: store_based_barrier_key:5 to store for rank: 0
Rank 0: Completed store-based barrier for key:store_based_barrier_key:5 with 1 nodes.
Added key: store_based_barrier_key:6 to store for rank: 0
Rank 0: Completed store-based barrier for key:store_based_barrier_key:6 with 1 nodes.
Added key: store_based_barrier_key:7 to store for rank: 0
Rank 0: Completed store-based barrier for key:store_based_barrier_key:7 with 1 nodes.
[NeMo W 2023-08-25 18:46:45 nemo_logging:349] /usr/local/lib/python3.8/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:613: UserWarning: Checkpoint directory /result/nemo_experiments/esm1nv-oas/esm1nv-oas_pretraining/checkpoints exists and is not empty.
rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
[NeMo I 2023-08-25 18:46:45 megatron_bert_model:563] Pipeline model parallel rank: 0, Tensor model parallel rank: 0, Number of model parameters on device: 4.36e+07. Total number of model parameters: 4.36e+07.
[NeMo I 2023-08-25 18:46:45 esm1nv_model:96] Building Bert datasets.
train:808
Loading data from /data/OASpaired/processed/heavy/train/x000.csv, /data/OASpaired/processed/heavy/train/x001.csv, /data/OASpaired/processed/heavy/train/x002.csv, /data/OASpaired/processed/heavy/train/x003.csv, /data/OASpaired/processed/heavy/train/x004.csv, /data/OASpaired/processed/heavy/train/x005.csv
[NeMo I 2023-08-25 18:46:45 text_memmap_dataset:104] Building data files
[NeMo I 2023-08-25 18:46:45 text_memmap_dataset:343] Processing 6 data files using 10 workers
[NeMo I 2023-08-25 18:46:45 text_memmap_dataset:349] Time building 0 / 6 mem-mapped files: 0:00:00.235856
[NeMo I 2023-08-25 18:46:45 text_memmap_dataset:114] Loading data files
[NeMo I 2023-08-25 18:46:45 text_memmap_dataset:205] Loading /data/OASpaired/processed/heavy/train/x000.csv
[NeMo I 2023-08-25 18:46:45 text_memmap_dataset:205] Loading /data/OASpaired/processed/heavy/train/x001.csv
[NeMo I 2023-08-25 18:46:45 text_memmap_dataset:205] Loading /data/OASpaired/processed/heavy/train/x002.csv
[NeMo I 2023-08-25 18:46:45 text_memmap_dataset:205] Loading /data/OASpaired/processed/heavy/train/x003.csv
[NeMo I 2023-08-25 18:46:45 text_memmap_dataset:205] Loading /data/OASpaired/processed/heavy/train/x004.csv
[NeMo I 2023-08-25 18:46:45 text_memmap_dataset:205] Loading /data/OASpaired/processed/heavy/train/x005.csv
[NeMo I 2023-08-25 18:46:45 text_memmap_dataset:117] Time loading 6 mem-mapped files: 0:00:00.002752
[NeMo I 2023-08-25 18:46:45 text_memmap_dataset:121] Computing global indices
[NeMo I 2023-08-25 18:46:45 dataset_utils:1341] > loading indexed mapping from /data/OASpaired/processed/heavy/train/__indexmap_808mns_512msl_0.00ssp_1234s.npy
[NeMo I 2023-08-25 18:46:45 dataset_utils:1344] loaded indexed file in 0.001 seconds
[NeMo I 2023-08-25 18:46:45 dataset_utils:1345] total number of samples: 21129
val:160
Loading data from /data/OASpaired/processed/heavy/val/x000.csv, /data/OASpaired/processed/heavy/val/x001.csv
[NeMo I 2023-08-25 18:46:45 text_memmap_dataset:104] Building data files
[NeMo I 2023-08-25 18:46:45 text_memmap_dataset:343] Processing 2 data files using 10 workers
[NeMo I 2023-08-25 18:46:46 text_memmap_dataset:349] Time building 0 / 2 mem-mapped files: 0:00:00.231405
[NeMo I 2023-08-25 18:46:46 text_memmap_dataset:114] Loading data files
[NeMo I 2023-08-25 18:46:46 text_memmap_dataset:205] Loading /data/OASpaired/processed/heavy/val/x000.csv
[NeMo I 2023-08-25 18:46:46 text_memmap_dataset:205] Loading /data/OASpaired/processed/heavy/val/x001.csv
[NeMo I 2023-08-25 18:46:46 text_memmap_dataset:117] Time loading 2 mem-mapped files: 0:00:00.001060
[NeMo I 2023-08-25 18:46:46 text_memmap_dataset:121] Computing global indices
[NeMo I 2023-08-25 18:46:46 dataset_utils:1341] > loading indexed mapping from /data/OASpaired/processed/heavy/val/__indexmap_160mns_512msl_0.00ssp_1234s.npy
[NeMo I 2023-08-25 18:46:46 dataset_utils:1344] loaded indexed file in 0.000 seconds
[NeMo I 2023-08-25 18:46:46 dataset_utils:1345] total number of samples: 294
test:4000
Loading data from /data/OASpaired/processed/heavy/test/x000.csv, /data/OASpaired/processed/heavy/test/x001.csv
[NeMo I 2023-08-25 18:46:46 text_memmap_dataset:104] Building data files
[NeMo I 2023-08-25 18:46:46 text_memmap_dataset:343] Processing 2 data files using 10 workers
[NeMo I 2023-08-25 18:46:46 text_memmap_dataset:349] Time building 0 / 2 mem-mapped files: 0:00:00.246505
[NeMo I 2023-08-25 18:46:46 text_memmap_dataset:114] Loading data files
[NeMo I 2023-08-25 18:46:46 text_memmap_dataset:205] Loading /data/OASpaired/processed/heavy/test/x000.csv
[NeMo I 2023-08-25 18:46:46 text_memmap_dataset:205] Loading /data/OASpaired/processed/heavy/test/x001.csv
[NeMo I 2023-08-25 18:46:46 text_memmap_dataset:117] Time loading 2 mem-mapped files: 0:00:00.001123
[NeMo I 2023-08-25 18:46:46 text_memmap_dataset:121] Computing global indices
[NeMo I 2023-08-25 18:46:46 dataset_utils:1341] > loading indexed mapping from /data/OASpaired/processed/heavy/test/__indexmap_4000mns_512msl_0.00ssp_1234s.npy
[NeMo I 2023-08-25 18:46:46 dataset_utils:1344] loaded indexed file in 0.000 seconds
[NeMo I 2023-08-25 18:46:46 dataset_utils:1345] total number of samples: 6499
[NeMo I 2023-08-25 18:46:46 esm1nv_model:114] Length of train dataset: 808
[NeMo I 2023-08-25 18:46:46 esm1nv_model:115] Length of val dataset: 160
[NeMo I 2023-08-25 18:46:46 esm1nv_model:116] Length of test dataset: 4000
[NeMo I 2023-08-25 18:46:46 esm1nv_model:117] Finished building Bert datasets.
[NeMo I 2023-08-25 18:46:46 megatron_bert_model:662] Setting up train dataloader with len(len(self._train_ds)): 808 and consumed samples: 0
[NeMo I 2023-08-25 18:46:46 data_samplers:76] Instantiating MegatronPretrainingSampler with total_samples: 808 and consumed_samples: 0
[NeMo I 2023-08-25 18:46:46 megatron_bert_model:670] Setting up validation dataloader with len(len(self._validation_ds)): 160 and consumed samples: 0
[NeMo I 2023-08-25 18:46:46 data_samplers:76] Instantiating MegatronPretrainingSampler with total_samples: 160 and consumed_samples: 0
[NeMo I 2023-08-25 18:46:46 megatron_bert_model:678] Setting up test dataloader with len(len(self._test_ds)): 4000 and consumed samples: 0
[NeMo I 2023-08-25 18:46:46 data_samplers:76] Instantiating MegatronPretrainingSampler with total_samples: 4000 and consumed_samples: 0
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
[NeMo I 2023-08-25 18:46:46 nlp_overrides:124] Configuring DDP for model parallelism.
[NeMo I 2023-08-25 18:46:46 modelPT:722] Optimizer config = FusedAdam (
Parameter Group 0
betas: [0.9, 0.98]
bias_correction: True
eps: 1e-08
lr: 0.0002
weight_decay: 0.01
Parameter Group 1
betas: [0.9, 0.98]
bias_correction: True
eps: 1e-08
lr: 0.0002
weight_decay: 0.0
)
[NeMo I 2023-08-25 18:46:46 lr_scheduler:910] Scheduler "<nemo.core.optim.lr_scheduler.CosineAnnealing object at 0x7f39a83d3400>"
will be used during training (effective maximum steps = 101) -
Parameters :
(warmup_steps: 500
constant_steps: 50000
min_lr: 2.0e-05
max_steps: 101
)
| Name | Type | Params
----------------------------------------------------------------------------
0 | model | BertModel | 43.6 M
1 | model.language_model | TransformerLanguageModel | 43.0 M
2 | model.language_model.embedding | Embedding | 491 K
3 | model.language_model.encoder | ParallelTransformer | 42.5 M
4 | model.lm_head | BertLMHead | 592 K
5 | model.lm_head.dense | Linear | 590 K
6 | model.lm_head.layernorm | MixedFusedLayerNorm | 1.5 K
----------------------------------------------------------------------------
43.6 M Trainable params
0 Non-trainable params
43.6 M Total params
87.225 Total estimated model params size (MB)
Sanity Checking: 0it [00:00, ?it/s][NeMo W 2023-08-25 18:46:46 nemo_logging:349] /usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py:401: UserWarning: Found `dataloader_iter` argument in the `validation_step`. Note that the support for this signature is experimental and the behavior is subject to change.
rank_zero_warn(
Sanity Checking: 0%| | 0/2 [00:00<?, ?it/s]
Sanity Checking DataLoader 0: 0%| | 0/2 [00:00<?, ?it/s]
Sanity Checking DataLoader 0: 50%|█████ | 1/2 [00:01<00:01, 1.56s/it]
Sanity Checking DataLoader 0: 100%|██████████| 2/2 [00:01<00:00, 1.26it/s][NeMo W 2023-08-25 18:46:48 nemo_logging:349] /usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:232: UserWarning: You called `self.log('consumed_samples', ...)` in your `validation_epoch_end` but the value needs to be floating point. Converting it to torch.float32.
warning_cache.warn(
[NeMo W 2023-08-25 18:46:48 nemo_logging:349] /usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:536: PossibleUserWarning: It is recommended to use `self.log('val_loss', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
warning_cache.warn(
[NeMo W 2023-08-25 18:46:48 nemo_logging:349] /usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:536: PossibleUserWarning: It is recommended to use `self.log('val_loss_ECE', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
warning_cache.warn(
[NeMo W 2023-08-25 18:46:48 nemo_logging:349] /usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:536: PossibleUserWarning: It is recommended to use `self.log('consumed_samples', ..., sync_dist=True)` when logging on epoch level in distributed setting to accumulate the metric across devices.
warning_cache.warn(
[NeMo W 2023-08-25 18:46:48 nemo_logging:349] /usr/local/lib/python3.8/dist-packages/pytorch_lightning/loops/fit_loop.py:344: UserWarning: Found `dataloader_iter` argument in the `training_step`. Note that the support for this signature is experimental and the behavior is subject to change.
rank_zero_warn(
Training: 0it [00:00, ?it/s]
Training: 0%| | 0/111 [00:00<?, ?it/s]
Epoch 0: 0%| | 0/111 [00:00<?, ?it/s] [NeMo W 2023-08-25 18:46:50 nemo_logging:349] /usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:232: UserWarning: You called `self.log('global_step', ...)` in your `training_step` but the value needs to be floating point. Converting it to torch.float32.
warning_cache.warn(
[NeMo W 2023-08-25 18:46:50 nemo_logging:349] /usr/local/lib/python3.8/dist-packages/pytorch_lightning/trainer/connectors/logger_connector/result.py:232: UserWarning: You called `self.log('consumed_samples', ...)` in your `training_step` but the value needs to be floating point. Converting it to torch.float32.
warning_cache.warn(
[NeMo W 2023-08-25 18:46:50 nemo_logging:349] /usr/local/lib/python3.8/dist-packages/torch/optim/lr_scheduler.py:139: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. "
Epoch 0: 1%| | 1/111 [00:01<03:36, 1.97s/it]
Epoch 0: 1%| | 1/111 [00:01<03:36, 1.97s/it, loss=4.82, v_num=, reduced_train_loss=4.820, global_step=0.000, consumed_samples=0.000]
Epoch 0: 2%|▏ | 2/111 [00:02<01:49, 1.01s/it, loss=4.82, v_num=, reduced_train_loss=4.820, global_step=0.000, consumed_samples=0.000]
Epoch 0: 2%|▏ | 2/111 [00:02<01:49, 1.01s/it, loss=4.83, v_num=, reduced_train_loss=4.840, global_step=1.000, consumed_samples=8.000]
Epoch 0: 3%|▎ | 3/111 [00:02<01:15, 1.42it/s, loss=4.83, v_num=, reduced_train_loss=4.840, global_step=1.000, consumed_samples=8.000]
Epoch 0: 3%|▎ | 3/111 [00:02<01:15, 1.42it/s, loss=4.83, v_num=, reduced_train_loss=4.830, global_step=2.000, consumed_samples=16.00]
Epoch 0: 4%|▎ | 4/111 [00:02<00:57, 1.86it/s, loss=4.83, v_num=, reduced_train_loss=4.830, global_step=2.000, consumed_samples=16.00]
Epoch 0: 4%|▎ | 4/111 [00:02<00:57, 1.86it/s, loss=4.84, v_num=, reduced_train_loss=4.860, global_step=3.000, consumed_samples=24.00]
Epoch 0: 5%|▍ | 5/111 [00:02<00:46, 2.28it/s, loss=4.84, v_num=, reduced_train_loss=4.860, global_step=3.000, consumed_samples=24.00]
Epoch 0: 5%|▍ | 5/111 [00:02<00:46, 2.28it/s, loss=4.84, v_num=, reduced_train_loss=4.840, global_step=4.000, consumed_samples=32.00]
Epoch 0: 5%|▌ | 6/111 [00:02<00:39, 2.68it/s, loss=4.84, v_num=, reduced_train_loss=4.840, global_step=4.000, consumed_samples=32.00]
Epoch 0: 5%|▌ | 6/111 [00:02<00:39, 2.68it/s, loss=4.84, v_num=, reduced_train_loss=4.840, global_step=5.000, consumed_samples=40.00]
Epoch 0: 6%|▋ | 7/111 [00:02<00:33, 3.06it/s, loss=4.84, v_num=, reduced_train_loss=4.840, global_step=5.000, consumed_samples=40.00]
Epoch 0: 6%|▋ | 7/111 [00:02<00:33, 3.06it/s, loss=4.84, v_num=, reduced_train_loss=4.810, global_step=6.000, consumed_samples=48.00]
Epoch 0: 7%|▋ | 8/111 [00:02<00:29, 3.44it/s, loss=4.84, v_num=, reduced_train_loss=4.810, global_step=6.000, consumed_samples=48.00]
Epoch 0: 7%|▋ | 8/111 [00:02<00:29, 3.44it/s, loss=4.84, v_num=, reduced_train_loss=4.840, global_step=7.000, consumed_samples=56.00]
Epoch 0: 8%|▊ | 9/111 [00:02<00:26, 3.80it/s, loss=4.84, v_num=, reduced_train_loss=4.840, global_step=7.000, consumed_samples=56.00]
Epoch 0: 8%|▊ | 9/111 [00:02<00:26, 3.80it/s, loss=4.84, v_num=, reduced_train_loss=4.840, global_step=8.000, consumed_samples=64.00]
Epoch 0: 9%|▉ | 10/111 [00:02<00:24, 4.14it/s, loss=4.84, v_num=, reduced_train_loss=4.840, global_step=8.000, consumed_samples=64.00]
Epoch 0: 9%|▉ | 10/111 [00:02<00:24, 4.14it/s, loss=4.84, v_num=, reduced_train_loss=4.850, global_step=9.000, consumed_samples=72.00]
Epoch 0: 10%|▉ | 11/111 [00:02<00:22, 4.46it/s, loss=4.84, v_num=, reduced_train_loss=4.850, global_step=9.000, consumed_samples=72.00]
Epoch 0: 10%|▉ | 11/111 [00:02<00:22, 4.46it/s, loss=4.84, v_num=, reduced_train_loss=4.840, global_step=10.00, consumed_samples=80.00]
Epoch 0: 11%|█ | 12/111 [00:02<00:20, 4.78it/s, loss=4.84, v_num=, reduced_train_loss=4.840, global_step=10.00, consumed_samples=80.00]
Epoch 0: 11%|█ | 12/111 [00:02<00:20, 4.78it/s, loss=4.84, v_num=, reduced_train_loss=4.850, global_step=11.00, consumed_samples=88.00]
Epoch 0: 12%|█▏ | 13/111 [00:02<00:19, 5.07it/s, loss=4.84, v_num=, reduced_train_loss=4.850, global_step=11.00, consumed_samples=88.00]
Epoch 0: 12%|█▏ | 13/111 [00:02<00:19, 5.07it/s, loss=4.84, v_num=, reduced_train_loss=4.820, global_step=12.00, consumed_samples=96.00]
Epoch 0: 13%|█▎ | 14/111 [00:02<00:18, 5.37it/s, loss=4.84, v_num=, reduced_train_loss=4.820, global_step=12.00, consumed_samples=96.00]
Epoch 0: 13%|█▎ | 14/111 [00:02<00:18, 5.37it/s, loss=4.84, v_num=, reduced_train_loss=4.800, global_step=13.00, consumed_samples=104.0]
Epoch 0: 14%|█▎ | 15/111 [00:02<00:16, 5.66it/s, loss=4.84, v_num=, reduced_train_loss=4.800, global_step=13.00, consumed_samples=104.0]
Epoch 0: 14%|█▎ | 15/111 [00:02<00:16, 5.66it/s, loss=4.84, v_num=, reduced_train_loss=4.840, global_step=14.00, consumed_samples=112.0]
Epoch 0: 14%|█▍ | 16/111 [00:02<00:16, 5.93it/s, loss=4.84, v_num=, reduced_train_loss=4.840, global_step=14.00, consumed_samples=112.0]
Epoch 0: 14%|█▍ | 16/111 [00:02<00:16, 5.93it/s, loss=4.83, v_num=, reduced_train_loss=4.810, global_step=15.00, consumed_samples=120.0]
Epoch 0: 15%|█▌ | 17/111 [00:02<00:15, 6.20it/s, loss=4.83, v_num=, reduced_train_loss=4.810, global_step=15.00, consumed_samples=120.0]
Epoch 0: 15%|█▌ | 17/111 [00:02<00:15, 6.20it/s, loss=4.83, v_num=, reduced_train_loss=4.820, global_step=16.00, consumed_samples=128.0]
Epoch 0: 16%|█▌ | 18/111 [00:02<00:14, 6.46it/s, loss=4.83, v_num=, reduced_train_loss=4.820, global_step=16.00, consumed_samples=128.0]
Epoch 0: 16%|█▌ | 18/111 [00:02<00:14, 6.46it/s, loss=4.83, v_num=, reduced_train_loss=4.830, global_step=17.00, consumed_samples=136.0]
Epoch 0: 17%|█▋ | 19/111 [00:02<00:13, 6.71it/s, loss=4.83, v_num=, reduced_train_loss=4.830, global_step=17.00, consumed_samples=136.0]
Epoch 0: 17%|█▋ | 19/111 [00:02<00:13, 6.71it/s, loss=4.83, v_num=, reduced_train_loss=4.830, global_step=18.00, consumed_samples=144.0]
Epoch 0: 18%|█▊ | 20/111 [00:02<00:13, 6.96it/s, loss=4.83, v_num=, reduced_train_loss=4.830, global_step=18.00, consumed_samples=144.0]
Epoch 0: 18%|█▊ | 20/111 [00:02<00:13, 6.96it/s, loss=4.83, v_num=, reduced_train_loss=4.820, global_step=19.00, consumed_samples=152.0]
Epoch 0: 19%|█▉ | 21/111 [00:02<00:12, 7.19it/s, loss=4.83, v_num=, reduced_train_loss=4.820, global_step=19.00, consumed_samples=152.0]
Epoch 0: 19%|█▉ | 21/111 [00:02<00:12, 7.18it/s, loss=4.83, v_num=, reduced_train_loss=4.850, global_step=20.00, consumed_samples=160.0]
Epoch 0: 20%|█▉ | 22/111 [00:02<00:12, 7.41it/s, loss=4.83, v_num=, reduced_train_loss=4.850, global_step=20.00, consumed_samples=160.0]
Epoch 0: 20%|█▉ | 22/111 [00:02<00:12, 7.41it/s, loss=4.83, v_num=, reduced_train_loss=4.860, global_step=21.00, consumed_samples=168.0]
Epoch 0: 21%|██ | 23/111 [00:03<00:11, 7.63it/s, loss=4.83, v_num=, reduced_train_loss=4.860, global_step=21.00, consumed_samples=168.0]
Epoch 0: 21%|██ | 23/111 [00:03<00:11, 7.63it/s, loss=4.83, v_num=, reduced_train_loss=4.650, global_step=22.00, consumed_samples=176.0]
Epoch 0: 22%|██▏ | 24/111 [00:03<00:11, 7.83it/s, loss=4.83, v_num=, reduced_train_loss=4.650, global_step=22.00, consumed_samples=176.0]
Epoch 0: 22%|██▏ | 24/111 [00:03<00:11, 7.83it/s, loss=4.8, v_num=, reduced_train_loss=4.350, global_step=23.00, consumed_samples=184.0]
Epoch 0: 23%|██▎ | 25/111 [00:03<00:10, 8.02it/s, loss=4.8, v_num=, reduced_train_loss=4.350, global_step=23.00, consumed_samples=184.0]
Epoch 0: 23%|██▎ | 25/111 [00:03<00:10, 8.02it/s, loss=4.75, v_num=, reduced_train_loss=3.900, global_step=24.00, consumed_samples=192.0]
Epoch 0: 23%|██▎ | 26/111 [00:03<00:10, 8.21it/s, loss=4.75, v_num=, reduced_train_loss=3.900, global_step=24.00, consumed_samples=192.0]
Epoch 0: 23%|██▎ | 26/111 [00:03<00:10, 8.21it/s, loss=4.68, v_num=, reduced_train_loss=3.310, global_step=25.00, consumed_samples=200.0]
Epoch 0: 24%|██▍ | 27/111 [00:03<00:10, 8.39it/s, loss=4.68, v_num=, reduced_train_loss=3.310, global_step=25.00, consumed_samples=200.0]
Epoch 0: 24%|██▍ | 27/111 [00:03<00:10, 8.39it/s, loss=4.57, v_num=, reduced_train_loss=2.680, global_step=26.00, consumed_samples=208.0]
Epoch 0: 25%|██▌ | 28/111 [00:03<00:09, 8.56it/s, loss=4.57, v_num=, reduced_train_loss=2.680, global_step=26.00, consumed_samples=208.0]
Epoch 0: 25%|██▌ | 28/111 [00:03<00:09, 8.56it/s, loss=4.43, v_num=, reduced_train_loss=2.000, global_step=27.00, consumed_samples=216.0]
Epoch 0: 26%|██▌ | 29/111 [00:03<00:09, 8.75it/s, loss=4.43, v_num=, reduced_train_loss=2.000, global_step=27.00, consumed_samples=216.0]
Epoch 0: 26%|██▌ | 29/111 [00:03<00:09, 8.74it/s, loss=4.26, v_num=, reduced_train_loss=1.400, global_step=28.00, consumed_samples=224.0]
Epoch 0: 27%|██▋ | 30/111 [00:03<00:09, 8.92it/s, loss=4.26, v_num=, reduced_train_loss=1.400, global_step=28.00, consumed_samples=224.0]
Epoch 0: 27%|██▋ | 30/111 [00:03<00:09, 8.92it/s, loss=4.06, v_num=, reduced_train_loss=0.983, global_step=29.00, consumed_samples=232.0]
Epoch 0: 28%|██▊ | 31/111 [00:03<00:08, 9.09it/s, loss=4.06, v_num=, reduced_train_loss=0.983, global_step=29.00, consumed_samples=232.0]
Epoch 0: 28%|██▊ | 31/111 [00:03<00:08, 9.09it/s, loss=3.85, v_num=, reduced_train_loss=0.624, global_step=30.00, consumed_samples=240.0]
Epoch 0: 29%|██▉ | 32/111 [00:03<00:08, 9.26it/s, loss=3.85, v_num=, reduced_train_loss=0.624, global_step=30.00, consumed_samples=240.0]
Epoch 0: 29%|██▉ | 32/111 [00:03<00:08, 9.26it/s, loss=3.63, v_num=, reduced_train_loss=0.420, global_step=31.00, consumed_samples=248.0]
Epoch 0: 30%|██▉ | 33/111 [00:03<00:08, 9.43it/s, loss=3.63, v_num=, reduced_train_loss=0.420, global_step=31.00, consumed_samples=248.0]
Epoch 0: 30%|██▉ | 33/111 [00:03<00:08, 9.43it/s, loss=3.4, v_num=, reduced_train_loss=0.256, global_step=32.00, consumed_samples=256.0]
Epoch 0: 31%|███ | 34/111 [00:03<00:08, 9.57it/s, loss=3.4, v_num=, reduced_train_loss=0.256, global_step=32.00, consumed_samples=256.0]
Epoch 0: 31%|███ | 34/111 [00:03<00:08, 9.57it/s, loss=3.17, v_num=, reduced_train_loss=0.180, global_step=33.00, consumed_samples=264.0]
Epoch 0: 32%|███▏ | 35/111 [00:03<00:07, 9.72it/s, loss=3.17, v_num=, reduced_train_loss=0.180, global_step=33.00, consumed_samples=264.0]
Epoch 0: 32%|███▏ | 35/111 [00:03<00:07, 9.72it/s, loss=2.93, v_num=, reduced_train_loss=0.121, global_step=34.00, consumed_samples=272.0]
Epoch 0: 32%|███▏ | 36/111 [00:03<00:07, 9.86it/s, loss=2.93, v_num=, reduced_train_loss=0.121, global_step=34.00, consumed_samples=272.0]
Epoch 0: 32%|███▏ | 36/111 [00:03<00:07, 9.86it/s, loss=2.7, v_num=, reduced_train_loss=0.0845, global_step=35.00, consumed_samples=280.0]
Epoch 0: 33%|███▎ | 37/111 [00:03<00:07, 9.99it/s, loss=2.7, v_num=, reduced_train_loss=0.0845, global_step=35.00, consumed_samples=280.0]
Epoch 0: 33%|███▎ | 37/111 [00:03<00:07, 9.99it/s, loss=2.46, v_num=, reduced_train_loss=0.081, global_step=36.00, consumed_samples=288.0]
Epoch 0: 34%|███▍ | 38/111 [00:03<00:07, 10.14it/s, loss=2.46, v_num=, reduced_train_loss=0.081, global_step=36.00, consumed_samples=288.0]
Epoch 0: 34%|███▍ | 38/111 [00:03<00:07, 10.14it/s, loss=2.22, v_num=, reduced_train_loss=0.0499, global_step=37.00, consumed_samples=296.0]
Epoch 0: 35%|███▌ | 39/111 [00:03<00:06, 10.29it/s, loss=2.22, v_num=, reduced_train_loss=0.0499, global_step=37.00, consumed_samples=296.0]
Epoch 0: 35%|███▌ | 39/111 [00:03<00:06, 10.29it/s, loss=1.98, v_num=, reduced_train_loss=0.042, global_step=38.00, consumed_samples=304.0]
Epoch 0: 36%|███▌ | 40/111 [00:03<00:06, 10.42it/s, loss=1.98, v_num=, reduced_train_loss=0.042, global_step=38.00, consumed_samples=304.0]
Epoch 0: 36%|███▌ | 40/111 [00:03<00:06, 10.42it/s, loss=1.74, v_num=, reduced_train_loss=0.0138, global_step=39.00, consumed_samples=312.0]
Epoch 0: 37%|███▋ | 41/111 [00:03<00:06, 10.54it/s, loss=1.74, v_num=, reduced_train_loss=0.0138, global_step=39.00, consumed_samples=312.0]
Epoch 0: 37%|███▋ | 41/111 [00:03<00:06, 10.54it/s, loss=1.5, v_num=, reduced_train_loss=0.0586, global_step=40.00, consumed_samples=320.0]
Epoch 0: 38%|███▊ | 42/111 [00:03<00:06, 10.67it/s, loss=1.5, v_num=, reduced_train_loss=0.0586, global_step=40.00, consumed_samples=320.0]
Epoch 0: 38%|███▊ | 42/111 [00:03<00:06, 10.67it/s, loss=1.26, v_num=, reduced_train_loss=0.00855, global_step=41.00, consumed_samples=328.0]
Epoch 0: 39%|███▊ | 43/111 [00:03<00:06, 10.79it/s, loss=1.26, v_num=, reduced_train_loss=0.00855, global_step=41.00, consumed_samples=328.0]
Epoch 0: 39%|███▊ | 43/111 [00:03<00:06, 10.79it/s, loss=1.03, v_num=, reduced_train_loss=0.0298, global_step=42.00, consumed_samples=336.0]
Epoch 0: 40%|███▉ | 44/111 [00:04<00:06, 10.91it/s, loss=1.03, v_num=, reduced_train_loss=0.0298, global_step=42.00, consumed_samples=336.0]
Epoch 0: 40%|███▉ | 44/111 [00:04<00:06, 10.91it/s, loss=0.813, v_num=, reduced_train_loss=0.0272, global_step=43.00, consumed_samples=344.0]
Epoch 0: 41%|████ | 45/111 [00:04<00:05, 11.04it/s, loss=0.813, v_num=, reduced_train_loss=0.0272, global_step=43.00, consumed_samples=344.0]
Epoch 0: 41%|████ | 45/111 [00:04<00:05, 11.03it/s, loss=0.619, v_num=, reduced_train_loss=0.00627, global_step=44.00, consumed_samples=352.0]
Epoch 0: 41%|████▏ | 46/111 [00:04<00:05, 11.15it/s, loss=0.619, v_num=, reduced_train_loss=0.00627, global_step=44.00, consumed_samples=352.0]
Epoch 0: 41%|████▏ | 46/111 [00:04<00:05, 11.15it/s, loss=0.454, v_num=, reduced_train_loss=0.0058, global_step=45.00, consumed_samples=360.0]
Epoch 0: 42%|████▏ | 47/111 [00:04<00:05, 11.26it/s, loss=0.454, v_num=, reduced_train_loss=0.0058, global_step=45.00, consumed_samples=360.0]
Epoch 0: 42%|████▏ | 47/111 [00:04<00:05, 11.26it/s, loss=0.322, v_num=, reduced_train_loss=0.0466, global_step=46.00, consumed_samples=368.0]
Epoch 0: 43%|████▎ | 48/111 [00:04<00:05, 11.38it/s, loss=0.322, v_num=, reduced_train_loss=0.0466, global_step=46.00, consumed_samples=368.0]
Epoch 0: 43%|████▎ | 48/111 [00:04<00:05, 11.38it/s, loss=0.222, v_num=, reduced_train_loss=0.00506, global_step=47.00, consumed_samples=376.0]
Epoch 0: 44%|████▍ | 49/111 [00:04<00:05, 11.49it/s, loss=0.222, v_num=, reduced_train_loss=0.00506, global_step=47.00, consumed_samples=376.0]
Epoch 0: 44%|████▍ | 49/111 [00:04<00:05, 11.49it/s, loss=0.152, v_num=, reduced_train_loss=0.00475, global_step=48.00, consumed_samples=384.0]
Epoch 0: 45%|████▌ | 50/111 [00:04<00:05, 11.57it/s, loss=0.152, v_num=, reduced_train_loss=0.00475, global_step=48.00, consumed_samples=384.0]
Epoch 0: 45%|████▌ | 50/111 [00:04<00:05, 11.57it/s, loss=0.104, v_num=, reduced_train_loss=0.00427, global_step=49.00, consumed_samples=392.0]
Epoch 0: 46%|████▌ | 51/111 [00:04<00:05, 11.65it/s, loss=0.104, v_num=, reduced_train_loss=0.00427, global_step=49.00, consumed_samples=392.0]
Epoch 0: 46%|████▌ | 51/111 [00:04<00:05, 11.65it/s, loss=0.0735, v_num=, reduced_train_loss=0.0238, global_step=50.00, consumed_samples=400.0]
Epoch 0: 47%|████▋ | 52/111 [00:04<00:05, 11.72it/s, loss=0.0735, v_num=, reduced_train_loss=0.0238, global_step=50.00, consumed_samples=400.0]
Epoch 0: 47%|████▋ | 52/111 [00:04<00:05, 11.72it/s, loss=0.0536, v_num=, reduced_train_loss=0.0226, global_step=51.00, consumed_samples=408.0]
Epoch 0: 48%|████▊ | 53/111 [00:04<00:04, 11.82it/s, loss=0.0536, v_num=, reduced_train_loss=0.0226, global_step=51.00, consumed_samples=408.0]
Epoch 0: 48%|████▊ | 53/111 [00:04<00:04, 11.82it/s, loss=0.0419, v_num=, reduced_train_loss=0.0215, global_step=52.00, consumed_samples=416.0]
Epoch 0: 49%|████▊ | 54/111 [00:04<00:04, 11.93it/s, loss=0.0419, v_num=, reduced_train_loss=0.0215, global_step=52.00, consumed_samples=416.0]
Epoch 0: 49%|████▊ | 54/111 [00:04<00:04, 11.93it/s, loss=0.0339, v_num=, reduced_train_loss=0.022, global_step=53.00, consumed_samples=424.0]
Epoch 0: 50%|████▉ | 55/111 [00:04<00:04, 12.04it/s, loss=0.0339, v_num=, reduced_train_loss=0.022, global_step=53.00, consumed_samples=424.0]
Epoch 0: 50%|████▉ | 55/111 [00:04<00:04, 12.04it/s, loss=0.0304, v_num=, reduced_train_loss=0.0491, global_step=54.00, consumed_samples=432.0]
Epoch 0: 50%|█████ | 56/111 [00:04<00:04, 12.07it/s, loss=0.0304, v_num=, reduced_train_loss=0.0491, global_step=54.00, consumed_samples=432.0]
Epoch 0: 50%|█████ | 56/111 [00:04<00:04, 12.07it/s, loss=0.0263, v_num=, reduced_train_loss=0.00395, global_step=55.00, consumed_samples=440.0]
Epoch 0: 51%|█████▏ | 57/111 [00:04<00:04, 12.16it/s, loss=0.0263, v_num=, reduced_train_loss=0.00395, global_step=55.00, consumed_samples=440.0]
Epoch 0: 51%|█████▏ | 57/111 [00:04<00:04, 12.16it/s, loss=0.0225, v_num=, reduced_train_loss=0.00413, global_step=56.00, consumed_samples=448.0]
Epoch 0: 52%|█████▏ | 58/111 [00:04<00:04, 12.25it/s, loss=0.0225, v_num=, reduced_train_loss=0.00413, global_step=56.00, consumed_samples=448.0]
Epoch 0: 52%|█████▏ | 58/111 [00:04<00:04, 12.25it/s, loss=0.0217, v_num=, reduced_train_loss=0.0349, global_step=57.00, consumed_samples=456.0]
Epoch 0: 53%|█████▎ | 59/111 [00:04<00:04, 12.32it/s, loss=0.0217, v_num=, reduced_train_loss=0.0349, global_step=57.00, consumed_samples=456.0]
Epoch 0: 53%|█████▎ | 59/111 [00:04<00:04, 12.32it/s, loss=0.0214, v_num=, reduced_train_loss=0.0356, global_step=58.00, consumed_samples=464.0]
Epoch 0: 54%|█████▍ | 60/111 [00:04<00:04, 12.40it/s, loss=0.0214, v_num=, reduced_train_loss=0.0356, global_step=58.00, consumed_samples=464.0]
Epoch 0: 54%|█████▍ | 60/111 [00:04<00:04, 12.40it/s, loss=0.0217, v_num=, reduced_train_loss=0.0194, global_step=59.00, consumed_samples=472.0]
Epoch 0: 55%|█████▍ | 61/111 [00:04<00:04, 12.46it/s, loss=0.0217, v_num=, reduced_train_loss=0.0194, global_step=59.00, consumed_samples=472.0]
Epoch 0: 55%|█████▍ | 61/111 [00:04<00:04, 12.46it/s, loss=0.0198, v_num=, reduced_train_loss=0.0199, global_step=60.00, consumed_samples=480.0]
Epoch 0: 56%|█████▌ | 62/111 [00:04<00:03, 12.52it/s, loss=0.0198, v_num=, reduced_train_loss=0.0199, global_step=60.00, consumed_samples=480.0]
Epoch 0: 56%|█████▌ | 62/111 [00:04<00:03, 12.52it/s, loss=0.0203, v_num=, reduced_train_loss=0.0202, global_step=61.00, consumed_samples=488.0]
Epoch 0: 57%|█████▋ | 63/111 [00:05<00:03, 12.56it/s, loss=0.0203, v_num=, reduced_train_loss=0.0202, global_step=61.00, consumed_samples=488.0]
Epoch 0: 57%|█████▋ | 63/111 [00:05<00:03, 12.56it/s, loss=0.0204, v_num=, reduced_train_loss=0.0319, global_step=62.00, consumed_samples=496.0]
Epoch 0: 58%|█████▊ | 64/111 [00:05<00:03, 12.64it/s, loss=0.0204, v_num=, reduced_train_loss=0.0319, global_step=62.00, consumed_samples=496.0]
Epoch 0: 58%|█████▊ | 64/111 [00:05<00:03, 12.64it/s, loss=0.0194, v_num=, reduced_train_loss=0.00546, global_step=63.00, consumed_samples=504.0]
Epoch 0: 59%|█████▊ | 65/111 [00:05<00:03, 12.72it/s, loss=0.0194, v_num=, reduced_train_loss=0.00546, global_step=63.00, consumed_samples=504.0]
Epoch 0: 59%|█████▊ | 65/111 [00:05<00:03, 12.72it/s, loss=0.0199, v_num=, reduced_train_loss=0.0178, global_step=64.00, consumed_samples=512.0]
Epoch 0: 59%|█████▉ | 66/111 [00:05<00:03, 12.77it/s, loss=0.0199, v_num=, reduced_train_loss=0.0178, global_step=64.00, consumed_samples=512.0]
Epoch 0: 59%|█████▉ | 66/111 [00:05<00:03, 12.76it/s, loss=0.0205, v_num=, reduced_train_loss=0.018, global_step=65.00, consumed_samples=520.0]
Epoch 0: 60%|██████ | 67/111 [00:05<00:03, 12.85it/s, loss=0.0205, v_num=, reduced_train_loss=0.018, global_step=65.00, consumed_samples=520.0]
Epoch 0: 60%|██████ | 67/111 [00:05<00:03, 12.85it/s, loss=0.0183, v_num=, reduced_train_loss=0.00246, global_step=66.00, consumed_samples=528.0]
Epoch 0: 61%|██████▏ | 68/111 [00:05<00:03, 12.94it/s, loss=0.0183, v_num=, reduced_train_loss=0.00246, global_step=66.00, consumed_samples=528.0]
Epoch 0: 61%|██████▏ | 68/111 [00:05<00:03, 12.94it/s, loss=0.0195, v_num=, reduced_train_loss=0.0292, global_step=67.00, consumed_samples=536.0]
Epoch 0: 62%|██████▏ | 69/111 [00:05<00:03, 13.02it/s, loss=0.0195, v_num=, reduced_train_loss=0.0292, global_step=67.00, consumed_samples=536.0]
Epoch 0: 62%|██████▏ | 69/111 [00:05<00:03, 13.01it/s, loss=0.0211, v_num=, reduced_train_loss=0.0361, global_step=68.00, consumed_samples=544.0]
Epoch 0: 63%|██████▎ | 70/111 [00:05<00:03, 13.08it/s, loss=0.0211, v_num=, reduced_train_loss=0.0361, global_step=68.00, consumed_samples=544.0]
Epoch 0: 63%|██████▎ | 70/111 [00:05<00:03, 13.08it/s, loss=0.0218, v_num=, reduced_train_loss=0.0176, global_step=69.00, consumed_samples=552.0]
Epoch 0: 64%|██████▍ | 71/111 [00:05<00:03, 13.15it/s, loss=0.0218, v_num=, reduced_train_loss=0.0176, global_step=69.00, consumed_samples=552.0]
Epoch 0: 64%|██████▍ | 71/111 [00:05<00:03, 13.15it/s, loss=0.0207, v_num=, reduced_train_loss=0.00187, global_step=70.00, consumed_samples=560.0]
Epoch 0: 65%|██████▍ | 72/111 [00:05<00:02, 13.21it/s, loss=0.0207, v_num=, reduced_train_loss=0.00187, global_step=70.00, consumed_samples=560.0]
Epoch 0: 65%|██████▍ | 72/111 [00:05<00:02, 13.21it/s, loss=0.0212, v_num=, reduced_train_loss=0.0326, global_step=71.00, consumed_samples=568.0]
Epoch 0: 66%|██████▌ | 73/111 [00:05<00:02, 13.29it/s, loss=0.0212, v_num=, reduced_train_loss=0.0326, global_step=71.00, consumed_samples=568.0]
Epoch 0: 66%|██████▌ | 73/111 [00:05<00:02, 13.29it/s, loss=0.0209, v_num=, reduced_train_loss=0.0162, global_step=72.00, consumed_samples=576.0]
Epoch 0: 67%|██████▋ | 74/111 [00:05<00:02, 13.36it/s, loss=0.0209, v_num=, reduced_train_loss=0.0162, global_step=72.00, consumed_samples=576.0]
Epoch 0: 67%|██████▋ | 74/111 [00:05<00:02, 13.35it/s, loss=0.0206, v_num=, reduced_train_loss=0.0163, global_step=73.00, consumed_samples=584.0]
Epoch 0: 68%|██████▊ | 75/111 [00:05<00:02, 13.42it/s, loss=0.0206, v_num=, reduced_train_loss=0.0163, global_step=73.00, consumed_samples=584.0]
Epoch 0: 68%|██████▊ | 75/111 [00:05<00:02, 13.42it/s, loss=0.019, v_num=, reduced_train_loss=0.0169, global_step=74.00, consumed_samples=592.0]
Epoch 0: 68%|██████▊ | 76/111 [00:05<00:02, 13.49it/s, loss=0.019, v_num=, reduced_train_loss=0.0169, global_step=74.00, consumed_samples=592.0]
Epoch 0: 68%|██████▊ | 76/111 [00:05<00:02, 13.48it/s, loss=0.0192, v_num=, reduced_train_loss=0.00711, global_step=75.00, consumed_samples=600.0]
Epoch 0: 69%|██████▉ | 77/111 [00:05<00:02, 13.52it/s, loss=0.0192, v_num=, reduced_train_loss=0.00711, global_step=75.00, consumed_samples=600.0]
Epoch 0: 69%|██████▉ | 77/111 [00:05<00:02, 13.51it/s, loss=0.0197, v_num=, reduced_train_loss=0.015, global_step=76.00, consumed_samples=608.0]
Epoch 0: 70%|███████ | 78/111 [00:05<00:02, 13.58it/s, loss=0.0197, v_num=, reduced_train_loss=0.015, global_step=76.00, consumed_samples=608.0]
Epoch 0: 70%|███████ | 78/111 [00:05<00:02, 13.58it/s, loss=0.0187, v_num=, reduced_train_loss=0.0147, global_step=77.00, consumed_samples=616.0]
Epoch 0: 71%|███████ | 79/111 [00:05<00:02, 13.65it/s, loss=0.0187, v_num=, reduced_train_loss=0.0147, global_step=77.00, consumed_samples=616.0]
Epoch 0: 71%|███████ | 79/111 [00:05<00:02, 13.65it/s, loss=0.017, v_num=, reduced_train_loss=0.00149, global_step=78.00, consumed_samples=624.0]
Epoch 0: 72%|███████▏ | 80/111 [00:05<00:02, 13.71it/s, loss=0.017, v_num=, reduced_train_loss=0.00149, global_step=78.00, consumed_samples=624.0]
Epoch 0: 72%|███████▏ | 80/111 [00:05<00:02, 13.71it/s, loss=0.0169, v_num=, reduced_train_loss=0.0169, global_step=79.00, consumed_samples=632.0]
Epoch 0: 73%|███████▎ | 81/111 [00:05<00:02, 13.77it/s, loss=0.0169, v_num=, reduced_train_loss=0.0169, global_step=79.00, consumed_samples=632.0]
Epoch 0: 73%|███████▎ | 81/111 [00:05<00:02, 13.77it/s, loss=0.0183, v_num=, reduced_train_loss=0.0474, global_step=80.00, consumed_samples=640.0]
Epoch 0: 74%|███████▍ | 82/111 [00:05<00:02, 13.83it/s, loss=0.0183, v_num=, reduced_train_loss=0.0474, global_step=80.00, consumed_samples=640.0]
Epoch 0: 74%|███████▍ | 82/111 [00:05<00:02, 13.83it/s, loss=0.0173, v_num=, reduced_train_loss=0.00144, global_step=81.00, consumed_samples=648.0]
Epoch 0: 75%|███████▍ | 83/111 [00:05<00:02, 13.85it/s, loss=0.0173, v_num=, reduced_train_loss=0.00144, global_step=81.00, consumed_samples=648.0]
Epoch 0: 75%|███████▍ | 83/111 [00:05<00:02, 13.85it/s, loss=0.0168, v_num=, reduced_train_loss=0.0219, global_step=82.00, consumed_samples=656.0]
Epoch 0: 76%|███████▌ | 84/111 [00:06<00:01, 13.89it/s, loss=0.0168, v_num=, reduced_train_loss=0.0219, global_step=82.00, consumed_samples=656.0]
Epoch 0: 76%|███████▌ | 84/111 [00:06<00:01, 13.89it/s, loss=0.0177, v_num=, reduced_train_loss=0.0229, global_step=83.00, consumed_samples=664.0]
Epoch 0: 77%|███████▋ | 85/111 [00:06<00:01, 13.94it/s, loss=0.0177, v_num=, reduced_train_loss=0.0229, global_step=83.00, consumed_samples=664.0]
Epoch 0: 77%|███████▋ | 85/111 [00:06<00:01, 13.94it/s, loss=0.0184, v_num=, reduced_train_loss=0.031, global_step=84.00, consumed_samples=672.0]
Epoch 0: 77%|███████▋ | 86/111 [00:06<00:01, 13.99it/s, loss=0.0184, v_num=, reduced_train_loss=0.031, global_step=84.00, consumed_samples=672.0]
Epoch 0: 77%|███████▋ | 86/111 [00:06<00:01, 13.99it/s, loss=0.0186, v_num=, reduced_train_loss=0.022, global_step=85.00, consumed_samples=680.0]
Epoch 0: 78%|███████▊ | 87/111 [00:06<00:01, 14.02it/s, loss=0.0186, v_num=, reduced_train_loss=0.022, global_step=85.00, consumed_samples=680.0]
Epoch 0: 78%|███████▊ | 87/111 [00:06<00:01, 14.02it/s, loss=0.0186, v_num=, reduced_train_loss=0.00275, global_step=86.00, consumed_samples=688.0]
Epoch 0: 79%|███████▉ | 88/111 [00:06<00:01, 14.06it/s, loss=0.0186, v_num=, reduced_train_loss=0.00275, global_step=86.00, consumed_samples=688.0]
Epoch 0: 79%|███████▉ | 88/111 [00:06<00:01, 14.05it/s, loss=0.0174, v_num=, reduced_train_loss=0.00636, global_step=87.00, consumed_samples=696.0]
Epoch 0: 80%|████████ | 89/111 [00:06<00:01, 14.10it/s, loss=0.0174, v_num=, reduced_train_loss=0.00636, global_step=87.00, consumed_samples=696.0]
Epoch 0: 80%|████████ | 89/111 [00:06<00:01, 14.10it/s, loss=0.0168, v_num=, reduced_train_loss=0.0227, global_step=88.00, consumed_samples=704.0]
Epoch 0: 81%|████████ | 90/111 [00:06<00:01, 14.14it/s, loss=0.0168, v_num=, reduced_train_loss=0.0227, global_step=88.00, consumed_samples=704.0]
Epoch 0: 81%|████████ | 90/111 [00:06<00:01, 14.14it/s, loss=0.0159, v_num=, reduced_train_loss=0.000591, global_step=89.00, consumed_samples=712.0]
Epoch 0: 82%|████████▏ | 91/111 [00:06<00:01, 14.19it/s, loss=0.0159, v_num=, reduced_train_loss=0.000591, global_step=89.00, consumed_samples=712.0]
Epoch 0: 82%|████████▏ | 91/111 [00:06<00:01, 14.18it/s, loss=0.0158, v_num=, reduced_train_loss=0.000612, global_step=90.00, consumed_samples=720.0]
Epoch 0: 83%|████████▎ | 92/111 [00:06<00:01, 14.23it/s, loss=0.0158, v_num=, reduced_train_loss=0.000612, global_step=90.00, consumed_samples=720.0]
Epoch 0: 83%|████████▎ | 92/111 [00:06<00:01, 14.23it/s, loss=0.0147, v_num=, reduced_train_loss=0.0107, global_step=91.00, consumed_samples=728.0]
Epoch 0: 84%|████████▍ | 93/111 [00:06<00:01, 14.28it/s, loss=0.0147, v_num=, reduced_train_loss=0.0107, global_step=91.00, consumed_samples=728.0]
Epoch 0: 84%|████████▍ | 93/111 [00:06<00:01, 14.27it/s, loss=0.014, v_num=, reduced_train_loss=0.000628, global_step=92.00, consumed_samples=736.0]
Epoch 0: 85%|████████▍ | 94/111 [00:06<00:01, 14.32it/s, loss=0.014, v_num=, reduced_train_loss=0.000628, global_step=92.00, consumed_samples=736.0]
Epoch 0: 85%|████████▍ | 94/111 [00:06<00:01, 14.32it/s, loss=0.0132, v_num=, reduced_train_loss=0.000621, global_step=93.00, consumed_samples=744.0]
Epoch 0: 86%|████████▌ | 95/111 [00:06<00:01, 14.36it/s, loss=0.0132, v_num=, reduced_train_loss=0.000621, global_step=93.00, consumed_samples=744.0]
Epoch 0: 86%|████████▌ | 95/111 [00:06<00:01, 14.36it/s, loss=0.0124, v_num=, reduced_train_loss=0.000607, global_step=94.00, consumed_samples=752.0]
Epoch 0: 86%|████████▋ | 96/111 [00:06<00:01, 14.40it/s, loss=0.0124, v_num=, reduced_train_loss=0.000607, global_step=94.00, consumed_samples=752.0]
Epoch 0: 86%|████████▋ | 96/111 [00:06<00:01, 14.40it/s, loss=0.012, v_num=, reduced_train_loss=0.000597, global_step=95.00, consumed_samples=760.0]
Epoch 0: 87%|████████▋ | 97/111 [00:06<00:00, 14.44it/s, loss=0.012, v_num=, reduced_train_loss=0.000597, global_step=95.00, consumed_samples=760.0]
Epoch 0: 87%|████████▋ | 97/111 [00:06<00:00, 14.44it/s, loss=0.0113, v_num=, reduced_train_loss=0.000654, global_step=96.00, consumed_samples=768.0]
Epoch 0: 88%|████████▊ | 98/111 [00:06<00:00, 14.48it/s, loss=0.0113, v_num=, reduced_train_loss=0.000654, global_step=96.00, consumed_samples=768.0]
Epoch 0: 88%|████████▊ | 98/111 [00:06<00:00, 14.48it/s, loss=0.0106, v_num=, reduced_train_loss=0.000605, global_step=97.00, consumed_samples=776.0]
Epoch 0: 89%|████████▉ | 99/111 [00:06<00:00, 14.52it/s, loss=0.0106, v_num=, reduced_train_loss=0.000605, global_step=97.00, consumed_samples=776.0]
Epoch 0: 89%|████████▉ | 99/111 [00:06<00:00, 14.52it/s, loss=0.0106, v_num=, reduced_train_loss=0.000655, global_step=98.00, consumed_samples=784.0]
Epoch 0: 90%|█████████ | 100/111 [00:06<00:00, 14.56it/s, loss=0.0106, v_num=, reduced_train_loss=0.000655, global_step=98.00, consumed_samples=784.0]
Epoch 0: 90%|█████████ | 100/111 [00:06<00:00, 14.56it/s, loss=0.00978, v_num=, reduced_train_loss=0.000861, global_step=99.00, consumed_samples=792.0]
Validation: 0it [00:00, ?it/s]
Validation: 0%| | 0/10 [00:00<?, ?it/s]
Validation DataLoader 0: 0%| | 0/10 [00:00<?, ?it/s]
Validation DataLoader 0: 10%|█ | 1/10 [00:00<00:00, 11.68it/s]
Epoch 0: 91%|█████████ | 101/111 [00:06<00:00, 14.48it/s, loss=0.00978, v_num=, reduced_train_loss=0.000861, global_step=99.00, consumed_samples=792.0]
Validation DataLoader 0: 20%|██ | 2/10 [00:00<00:00, 17.98it/s]
Epoch 0: 92%|█████████▏| 102/111 [00:06<00:00, 14.57it/s, loss=0.00978, v_num=, reduced_train_loss=0.000861, global_step=99.00, consumed_samples=792.0]
Validation DataLoader 0: 30%|███ | 3/10 [00:00<00:00, 24.02it/s]
Epoch 0: 93%|█████████▎| 103/111 [00:07<00:00, 14.69it/s, loss=0.00978, v_num=, reduced_train_loss=0.000861, global_step=99.00, consumed_samples=792.0]
Validation DataLoader 0: 40%|████ | 4/10 [00:00<00:00, 28.93it/s]
Epoch 0: 94%|█████████▎| 104/111 [00:07<00:00, 14.80it/s, loss=0.00978, v_num=, reduced_train_loss=0.000861, global_step=99.00, consumed_samples=792.0]
Validation DataLoader 0: 50%|█████ | 5/10 [00:00<00:00, 33.06it/s]
Epoch 0: 95%|█████████▍| 105/111 [00:07<00:00, 14.92it/s, loss=0.00978, v_num=, reduced_train_loss=0.000861, global_step=99.00, consumed_samples=792.0]
Validation DataLoader 0: 60%|██████ | 6/10 [00:00<00:00, 32.88it/s]
Epoch 0: 95%|█████████▌| 106/111 [00:07<00:00, 14.99it/s, loss=0.00978, v_num=, reduced_train_loss=0.000861, global_step=99.00, consumed_samples=792.0]
Validation DataLoader 0: 70%|███████ | 7/10 [00:00<00:00, 34.25it/s]
Epoch 0: 96%|█████████▋| 107/111 [00:07<00:00, 15.09it/s, loss=0.00978, v_num=, reduced_train_loss=0.000861, global_step=99.00, consumed_samples=792.0]
Validation DataLoader 0: 80%|████████ | 8/10 [00:00<00:00, 36.78it/s]
Epoch 0: 97%|█████████▋| 108/111 [00:07<00:00, 15.20it/s, loss=0.00978, v_num=, reduced_train_loss=0.000861, global_step=99.00, consumed_samples=792.0]
Validation DataLoader 0: 90%|█████████ | 9/10 [00:00<00:00, 39.04it/s]
Epoch 0: 98%|█████████▊| 109/111 [00:07<00:00, 15.31it/s, loss=0.00978, v_num=, reduced_train_loss=0.000861, global_step=99.00, consumed_samples=792.0]
Validation DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 41.05it/s]
Epoch 0: 99%|█████████▉| 110/111 [00:07<00:00, 15.43it/s, loss=0.00978, v_num=, reduced_train_loss=0.000861, global_step=99.00, consumed_samples=792.0]
Epoch 0: 99%|█████████▉| 110/111 [00:07<00:00, 15.42it/s, loss=0.00978, v_num=, reduced_train_loss=0.000861, global_step=99.00, consumed_samples=792.0, val_loss=0.000501]
Epoch 0, global step 100: 'val_loss' reached 0.00050 (best 0.00050), saving model to '/result/nemo_experiments/esm1nv-oas/esm1nv-oas_pretraining/checkpoints/megatron_bert--val_loss=0.00-step=100-consumed_samples=800.0-v11.ckpt' as top 10
Epoch 0: 100%|██████████| 111/111 [00:08<00:00, 13.07it/s, loss=0.00978, v_num=, reduced_train_loss=0.000861, global_step=99.00, consumed_samples=792.0, val_loss=0.000501]
Epoch 0: 100%|██████████| 111/111 [00:08<00:00, 13.07it/s, loss=0.00745, v_num=, reduced_train_loss=0.000713, global_step=100.0, consumed_samples=800.0, val_loss=0.000501]
Epoch 0: 100%|██████████| 111/111 [00:08<00:00, 13.07it/s, loss=0.00745, v_num=, reduced_train_loss=0.000713, global_step=100.0, consumed_samples=800.0, val_loss=0.000501]`Trainer.fit` stopped: `max_steps=101` reached.
Epoch 0: 100%|██████████| 111/111 [00:08<00:00, 13.07it/s, loss=0.00745, v_num=, reduced_train_loss=0.000713, global_step=100.0, consumed_samples=800.0, val_loss=0.000501]
[NeMo I 2023-08-25 18:46:57 nlp_overrides:226] Removing checkpoint: /result/nemo_experiments/esm1nv-oas/esm1nv-oas_pretraining/checkpoints/megatron_bert--val_loss=0.00-step=100-consumed_samples=800.0-last.ckpt
[NeMo I 2023-08-25 18:46:58 pretrain_oas:24] ************** Finished Training ***********
Creating the Dataset object#
Underneath the abstractions we provide, the dataset is ultimately constructed by invoking the relevant NeMo class, which is selected by the model.data.data_impl field in the config file. The keyword arguments that class requires are given in the model.data.data_impl_kwargs field. Look around in NeMo for additional dataset types, or implement your own!
We can also do this manually:
dataset_paths = [
    '/data/OASpaired/processed/heavy/train/x000.csv',
    '/data/OASpaired/processed/heavy/train/x001.csv',
    '/data/OASpaired/processed/heavy/train/x002.csv',
]
# Check out NeMo for examples of other dataset types, or add your own!
from nemo.collections.nlp.data.language_modeling.text_memmap_dataset import CSVMemMapDataset
# The kwargs here are taken from our yaml file.
dataset = CSVMemMapDataset(dataset_paths=dataset_paths, header_lines=1, newline_int=10, workers=1, sort_dataset_paths=True, data_sep=',', data_col=1)
# Print the first 11 items (indices 0 through 10).
for i, item in enumerate(iter(dataset)):
    if i > 10:
        break
    print(item)
[NeMo W 2023-08-25 18:47:09 experimental:27] Module <class 'nemo.collections.nlp.models.text_normalization_as_tagging.thutmose_tagger.ThutmoseTaggerModel'> is experimental, not ready for production and is not fully supported. Use at your own risk.
[NeMo W 2023-08-25 18:47:10 experimental:27] Module <class 'nemo.collections.asr.modules.audio_modules.SpectrogramToMultichannelFeatures'> is experimental, not ready for production and is not fully supported. Use at your own risk.
[NeMo I 2023-08-25 18:47:10 text_memmap_dataset:104] Building data files
[NeMo I 2023-08-25 18:47:10 text_memmap_dataset:343] Processing 3 data files using 1 workers
[NeMo I 2023-08-25 18:47:10 text_memmap_dataset:349] Time building 0 / 3 mem-mapped files: 0:00:00.051294
[NeMo I 2023-08-25 18:47:10 text_memmap_dataset:114] Loading data files
[NeMo I 2023-08-25 18:47:10 text_memmap_dataset:205] Loading /data/OASpaired/processed/heavy/train/x000.csv
[NeMo I 2023-08-25 18:47:10 text_memmap_dataset:205] Loading /data/OASpaired/processed/heavy/train/x001.csv
[NeMo I 2023-08-25 18:47:10 text_memmap_dataset:205] Loading /data/OASpaired/processed/heavy/train/x002.csv
[NeMo I 2023-08-25 18:47:10 text_memmap_dataset:117] Time loading 3 mem-mapped files: 0:00:00.005260
[NeMo I 2023-08-25 18:47:10 text_memmap_dataset:121] Computing global indices
GGGAGAGGAGGCCTGTCCTGGATTCGATTCCCAGTTCCTCACATTCAGTCAGCACTGAACACGGACCCCTCACCATGAACTTCGGGCTCAGCTTGATTTTCCTTGTCCTTGTTTTAAAAGGTGTCCAGTGTGAAGTGATGCTGGTGGAGTCTGGGGGAGGCTTAGTGAAGCCTGGAGGGTCCCTGAAACTCTCCTGTGCAGCCTCTGGATTCACTTTCAGTAGCTATGCCATGTCTTGGGTTCGCCAGACTCCGGAGAAGAGGCTGGAGTGGGTCGCAACCATTAGTAGTGGTGGTAGTTACACCTACTATCCAGACAGTGTGAAGGGGCGATTCACCATCTCCAGAGACAATGCCAAGAACACCCTGTACCTGCAAATGAGCAGTCTGAGGTCTGAGGACACGGCCATGTATTACTGTGCAAGACGGGGGAATGATGGTTACTACGAAGACTACTGGGGCCAAGGCACCACTCTCACAGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG
GAGCTCTGACAGAGGAGGCCAGTCCTGGAATTGATTCCCAGTTCCTCACGTTCAGTGATGAGCACTGAACACAGACACCTCACCATGAACTTTGGGCTCAGATTGATTTTCCTTGTCCTTACTTTAAAAGGTGTGAAGTGTGAAGTGCAGCTGGTGGAGTCTGGGGGAGGCTTAGTGAAGCCTGGAGGGTCCCTGAAACTCTCCTGTGCAGCCTCTGGATTCGCTTTCAGTAGCTATGACATGTCTTGGGTTCGCCAGACTCCGGAGAAGAGGCTGGAGTGGGTCGCATACATTAGTAGTGGTGGTGGTATCACCTACTATCCAGACACTGTGAAGGGCCGATTCACCATCTCCAGAGACAATGCCAAGAACACCCTGTACCTGCAAATGAGCAGTCTGAAGTCTGAGGACACAGCCATGTATTACTGTGCAAGGCCCCCGGGACGGGGCTACTGGTACTTCGATGTCTGGGGCGCAGGGACCACGGTCACCGTCTCCTCAGCCAAAACAACAGCCCCATCGGTCTATCCACTGGCCCCTGTGTGTGGAGATACAACTGGCTCCTCGGTGACTCTAGGGTGCCTGGTCAAGGATTATT
AACATATGTCCAATGTCCTCTCCACAGACACTGAACACACTGACTCTAACCATGGGATGGAGCTGGATCTTTCTCTTCCTCCTGTCAGGAACTGCAGGCGTCCACTCTGAGGTCCAGCTTCAGCAGTCAGGACCTGAGCTGGTGAAACCTGGGGCCTCAGTGAAGATATCCTGCAAGGCTTCTGGATACACATTCACTGACTACAACATGCACTGGGTGAAGCAGAGCCATGGAAAGAGCCTTGAGTGGATTGGATATATTTATCCTTACAATGGTGGTACTGGCTACAACCAGAAGTTCAAGAGCAAGGCCACATTGACTGTAGACAATTCCTCCAGCACAGCCTACATGGAGCTCCGCAGCCTGACATCTGAGGACTCTGCAGTCTATTACTGTGCAAGATGGGGGCTAACTGGTGATGCTATGGACTACTGGGGTCAAGGAACCTCAGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG
GACATAACAGCAAGAGAGTGTCCGGTTAGTCTCAAGGAAGACTGAGACACAGTCTTAGATATCATGGAATGGCTGTGGAACTTGCTATTTCTCATGGCAGCAGCTCAAAGTATCCAAGCACAGATCCAGTTGGTGCAGTCTGGACCTGAGCTGAAGAAGCCTGGAGAGACAGTCAGGATCTCCTGCAAGGCTTCTGGGTATACCTTCACAACTGCTGGAATGCAGTGGGTGCAAAAGATGCCAGGAAAGGGTTTGAAGTGGATTGGCTGGATAAACACCCACTCTGGAGTGCCAAAATATGCAGAAGACTTCAAGGGACGGTTTGCCTTCTCTTTGGAAACCTCTGCCAGCACTGCATATTTACAGATAAGCAACCTCAAAAATGAGGACACGGCTACGTATTTCTGTGCGAGATCAGGTTACGACGCCTTTGACTACTGGGGCCAAGGCACCACTCTCACAGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG
GGGGAGCATATGATCAGTGTCCTCTCCAAAGTCCTTGAACATAGACTCTAACCATGGAATGGACCTGGGTCTTTCTCTTCCTCCTGTCAGTAACTGCAGGTGTCCACTCCCAGGTTCAGCTGCAGCAGTCTGGAGTTGAGCTGATGAAGCCTGGGGCCTCAGTGAAGATATCCTGCAAGGCTACTGGCTACACACTCAGTAACTACTGGATAGAGTGGGTAAAGCAGAGGCCTGGACATGGCCTTGAGTGGATTGGAGAGATTTTACCTGGAGATGTTATTACTAACTACAATGAGAGGTTCAAGGACAAGGCCACATTCACTGCAGATACATCCTCCAACACAGCCTACATGCAACTCAGCAGCCTGACATCTGAGGATTCTGCCGTCTATTACTGTGCAAGAAGGGTTATTAAGGGGGGGTTTGCTTACTGGGGCCAAGGGACTCTGGTCACTGTCTCTGCAGCCAAAACAACAGCCCCATCGGTCTATCCACTGGCCCCTGTGTGTGGAGATACAACTGGCTCCTCGGTGACTCTAGGATGCCTGGTCAAGG
ACATCGCTCTCACTGGAGGCTGATCTCTGAAGATAAGGAGGTGTAGCCTAAAAGATGAGAGTGCTGATTCTTTTGTGGCTGTTCACAGCCTTTCCTGGTATCCTGTCTGATGTGCAGCTTCAGGAGTCGGGACCTGGCCTGGTGAAACCTTCTCAGTCTCTGTCCCTCACCTGCACTGTCACTGGCTACTCAATCACCAGTGATTATGCCTGGAACTGGATCCGGCAGTTTCCAGGAAACAAACTGGAGTGGATGGGCTACATAAGCTACAGTGGTAGCACTAGCTACAACCCATCTCTCAAAAGTCGAATCTCTATCACTCGAGACACATCCAAGAACCAGTTCTTCCTGCAGTTGAATTCTGTGACTACTGAGGACACAGCCACATATTACTGTGCAAGAAGGTACTTCGATGTCTGGGGCGCAGGGACCACGGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG
ATCGCTCTCACTGGAGGCTGATCTCTGAAGATAAGGAGGTGTAGCCTAAAAGATGAGAGTGCTGATTCTTTTGTGGCTGTTCACAGCCTTTCCTGGTATCCTGTCTGATGTGCAGCTTCAGGAGTCGGGACCTGGCCTGGTGAAACCTTCTCAGTCTCTGTCCCTCACCTGCACTGTCACTGGCTACTCAATCACCAGTGATTATGCCTGGAACTGGATCCGGCAGTTTCCAGGAAACAAACTGGAGTGGATGGGCTACATAAGCTACAGTGGTAGCACTAGCTACAACCCATCTCTCAAAAGTCGAATCTCTATCACTCGAGACACATCCAAGAACCAGTTCTTCCTGCAGTTGAATTCTGTGACTACTGAGGACACAGCCACATATTACTGTGCAAGGAGGTACTTCGATGTCTGGGGCGCAGGGACCACGGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG
GGGGAAAAACATGAGATCACAGTTCTCTCTACAGTTACTGAGCACACAGGAACTCACCATGGGATGGAGCTATATCATCCTCTTTTTGGTAGCAACAGCTACAGGTGTCCACTCCCAGGTCCAACTGCAGCAGCCTGGGGCTGAACTGGTGAAGCCTGGGGCTTCAGTGAAGTTGTCCTGCAAGGCTTCTGGCTACACCTTCACCAGCTACTATATGTACTGGGTGAAGCAGAGGCCTGGACAAGGCCTTGAGTGGATTGGGGGGATTAATCCTAGCAATGGTGGTACTAACTTCAATGAGAAGTTCAAGAGCAAGGCCACACTGACTGTAGACAAATCCTCCAGCACAGCCTACATGCAACTCAGCAGCCTGACATCTGAGGACTCTGCGGTCTATTACTGTACAAGATACGGCCTCTATGCTATGGACTACTGGGGTCAAGGAACCTCAGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG
ACTAGTGTGCAGATATGGACAGGCTTACTTCCTCATTGCTGCTGCTGATTGTCCCTGCATATGTCCTGTCCCAGGTTACTCTGAAAGAGTCTGGCCCTGGGATATTGCAGCCCTCCCAGACCCTCAGTCTGACTTGTTCTTTCTCTGGGTTTTCACTGACCACTTCTGGTATGGGTGTGACCTGGATTCGTCAGCCTTCAGGAAAGGGTCTGGAGTGGCTGGCACACATTTACTGGGATAATGACAAGCGCTATAATACATCCCTGAAGAGCCGGCTCACAATCTCCAAGGATACCTCCAGCAACCGGGTATTCCTCAAGATAACCAGTGTGGACACTGCAGATACTGCCACATACTACTGTACTCGAGTTACTACGTTGGTGGGCTACTTTGACCAATGGGGCCAAGGCACCACTCTCACAGTCTCCTCAGCCAAAACGACACCCCCATCTGTCTATCCACTGGCCCCTGGATCTGCTGCCCAAACTAACTCCATGGTGACCCTGGGATGCCTGGTCAAGGG
AACACCACCAACAACGACATCGACAATCATTCCCTACACAAAGCTCTTCCGATCTAAACGGGAGAATAGAGTCAATGATTTATTCTTATATGAGGAGAAAAACATGAGATCACAGTTCTCTCTACAGTTACTGAGCACACAGGACCTCACCATGGGATGGAGCTATATCATTTTCTTTTTGGTAGCAACAGCTACAGGTGTCCACTCCCAGGTCCAACTCCAGCAGCCTGGGGCTGAACTGGTGAAGCCTGGGGCTTCAGTGAAGTTGTCCTGCAAGGCTTCTGGCTACACCTTCACCAGCTACTGGATGCACTGGGTGAAGCTGAGGCCTGGACAAGGCTTTGAGTGGATTGGAGAGATTAATCCTAGCAATGGTGGTACTAACTACAATGAGAAGTTCAAGAGAAAGGCCACACTGACTGTAGACAAATCCTCCAGCACAGCCTACATGCAACTCAGCAGCCTGACATCTGAGGACTCTGCGGTCTATTACTGTACAATACGGAATTACTACGGTAGTAGCTACGAGGACTACTGGGGCCAAGGCACCACTCTCACAGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG
GTCTATGGCAGTTCCTATCTCTCTCACTGGAGGCTGATTTTTGAAGAAAGGGGTTGTAGCCTAAAAGATGATGGTGTTAAGTCTTCTGTACCTGTTGACAGCCCTTCCGGGTATCCTGTCAGAGGTGCAGCTTCAGGAGTCAGGACCTAGCCTCGTGAAACCTTCTCAGACTCTGTCCCTCACCTGTTCTGTCACTGGCGACTCCATCACCAGTGGTTACTGGAACTGGATCCGGAAATTCCCAGGGAATAAACTTGAGTACATGGGGTACATAAGCTACAGTGGTAGCACTTACTACAATCCATCTCTCAAAAGTCGAATCTCCATCACTCGAGACACATCCAAGAACCAGTACTACCTGCAGTTGAATTCTGTGACTACTGAGGACACAGCCACATATTACTGTGCAAGATGGGACTATGACTGGGGTCAAGGAACCTCAGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG
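To make the role of the keyword arguments more concrete, here is a minimal plain-Python sketch of what header_lines, data_sep, data_col, and sort_dataset_paths plausibly control. ToyCSVDataset is a hypothetical stand-in written for this illustration. It is not NeMo's implementation: it reads everything into memory instead of memory-mapping the files, and it ignores newline_int and workers.

```python
import csv
import os
import tempfile


class ToyCSVDataset:
    """Hypothetical in-memory stand-in for a CSV-backed dataset (not NeMo code)."""

    def __init__(self, dataset_paths, header_lines=1, data_sep=",", data_col=1,
                 sort_dataset_paths=True):
        paths = sorted(dataset_paths) if sort_dataset_paths else list(dataset_paths)
        self.samples = []
        for path in paths:
            with open(path, newline="") as handle:
                rows = list(csv.reader(handle, delimiter=data_sep))
            # Skip the header lines, then keep only the requested column.
            self.samples.extend(row[data_col] for row in rows[header_lines:])

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        return self.samples[index]


# Write two tiny CSV files (made-up contents) so the example runs anywhere.
tmp_dir = tempfile.mkdtemp()
toy_paths = []
toy_files = [
    ("x001.csv", [("b", "GGGA")]),
    ("x000.csv", [("a", "ACTG"), ("c", "TTAG")]),
]
for name, rows in toy_files:
    path = os.path.join(tmp_dir, name)
    with open(path, "w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(["id", "sequence"])  # one header line
        writer.writerows(rows)
    toy_paths.append(path)

toy_dataset = ToyCSVDataset(toy_paths, header_lines=1, data_sep=",", data_col=1)
print(list(toy_dataset))
```

Because the paths are sorted before loading, the rows from x000.csv come first even though x001.csv was listed first, and only the second column (index 1) of each row is returned.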
Testing our new collate function#
Before we inject our collate function into a dataloader, let's first take a look at what it actually does. As we saw previously, it simply replaces every character with ‘A’, which should be visually obvious in the output. To do this, we must also supply a tokenizer. The tokenizer is required by the default language modeling collate function, which tokenizes the input, applies padding, and aligns it for distributed training. We call collate_fn with some dummy data and watch the sequences in the output become AAAA...
from bionemo.data.dataloader.custom_protein_collate import CustomProteinBertCollate
# Some magic to get our NeMo tokenizer, filled with arguments from our config file.
from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
tokenizer = get_nmt_tokenizer(
    library='sentencepiece',
    tokenizer_model='/tokenizers/protein/esm1nv/vocab/protein_sequence_sentencepiece.model',
    vocab_file='/tokenizers/vocab/protein_sequence_sentencepiece.vocab',
    legacy=False,
)
# Extra kwargs are again taken from our config file.
collate_fn = CustomProteinBertCollate(
    tokenizer=tokenizer,
    seq_length=512,
    pad_size_divisible_by_8=True,
    modify_percent=.1,  # Fraction of tokens to mask or perturb
    perturb_percent=.5,  # Fraction of modified tokens to perturb, 1-perturb_percent is masking probability
).collate_fn
collate_fn(['ACTGT', 'ADFASDFA'])
[NeMo I 2023-08-25 18:47:10 tokenizer_utils:191] Getting SentencePiece with model: /tokenizers/protein/esm1nv/vocab/protein_sequence_sentencepiece.model
{'text': tensor([[1, 6, 6, 6, 6, 6, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3],
[1, 6, 6, 4, 6, 6, 6, 6, 6, 2, 3, 3, 3, 3, 3, 3]]),
'types': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),
'is_random': tensor([0, 1]),
'loss_mask': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),
'labels': tensor([[1, 6, 6, 6, 6, 6, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3],
[1, 6, 6, 6, 6, 6, 6, 6, 6, 2, 3, 3, 3, 3, 3, 3]]),
'padding_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]]),
'batch': ['AAAAA', 'AAAAAAAA']}
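The shape of this output can be reproduced with a small plain-Python sketch. toy_collate below is a hypothetical illustration, not the BioNeMo implementation. The token ids (1 = start, 2 = end, 3 = padding, 4 = mask, 6 = 'A') are read off the output above and are assumptions about this particular vocabulary. The sketch only masks tokens; it leaves out the random perturbation controlled by perturb_percent and does not compute types or is_random.

```python
import random

# Token ids inferred from the example output above (assumptions, not an API).
BOS_ID, EOS_ID, PAD_ID, MASK_ID, A_ID = 1, 2, 3, 4, 6


def toy_collate(batch, modify_percent=0.1, multiple=8, seed=0):
    """Hypothetical stand-in for a BERT-style collate function."""
    rng = random.Random(seed)
    # The custom step from this tutorial: every character becomes 'A'.
    batch = ["A" * len(seq) for seq in batch]
    # One token per character, wrapped in start and end tokens.
    labels = [[BOS_ID] + [A_ID] * len(seq) + [EOS_ID] for seq in batch]
    # Pad to the smallest multiple of `multiple` that fits the longest row.
    longest = max(len(row) for row in labels)
    padded_len = -(-longest // multiple) * multiple  # ceiling division
    padding_mask = [[1] * len(row) + [0] * (padded_len - len(row)) for row in labels]
    labels = [row + [PAD_ID] * (padded_len - len(row)) for row in labels]
    # Mask a random subset of residue tokens and flag them in loss_mask.
    text = [row[:] for row in labels]
    loss_mask = [[0] * padded_len for _ in labels]
    for r, seq in enumerate(batch):
        for c in range(1, len(seq) + 1):  # residue positions only
            if rng.random() < modify_percent:
                text[r][c] = MASK_ID
                loss_mask[r][c] = 1
    return {"text": text, "labels": labels, "loss_mask": loss_mask,
            "padding_mask": padding_mask, "batch": batch}


out = toy_collate(["ACTGT", "ADFASDFA"])
print(len(out["text"][0]), out["padding_mask"][0])
```

The longer sequence has 8 residues plus start and end tokens, so 10 tokens, which rounds up to 16. That is why both rows in the real output above are 16 tokens long. The loss_mask marks exactly the positions where text differs from labels, so the loss is only computed on modified tokens.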
DataLoader!#
Lastly, we must construct a dataloader from our collate function and our dataset object. We can then iterate over the result and confirm that the data is changed in the same way as when we called the collate function manually.
from torch.utils.data import DataLoader
print("Before:")
dl = DataLoader(dataset, batch_size=2, shuffle=False)
for i, item in enumerate(dl):
    if i > 10:
        break
    print(item)
print("\n\n\nAfter:")
dl = DataLoader(dataset, batch_size=2, shuffle=False, collate_fn=collate_fn)
for i, item in enumerate(dl):
    if i > 10:
        break
    print(item)
Before:
['GGGAGAGGAGGCCTGTCCTGGATTCGATTCCCAGTTCCTCACATTCAGTCAGCACTGAACACGGACCCCTCACCATGAACTTCGGGCTCAGCTTGATTTTCCTTGTCCTTGTTTTAAAAGGTGTCCAGTGTGAAGTGATGCTGGTGGAGTCTGGGGGAGGCTTAGTGAAGCCTGGAGGGTCCCTGAAACTCTCCTGTGCAGCCTCTGGATTCACTTTCAGTAGCTATGCCATGTCTTGGGTTCGCCAGACTCCGGAGAAGAGGCTGGAGTGGGTCGCAACCATTAGTAGTGGTGGTAGTTACACCTACTATCCAGACAGTGTGAAGGGGCGATTCACCATCTCCAGAGACAATGCCAAGAACACCCTGTACCTGCAAATGAGCAGTCTGAGGTCTGAGGACACGGCCATGTATTACTGTGCAAGACGGGGGAATGATGGTTACTACGAAGACTACTGGGGCCAAGGCACCACTCTCACAGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG', 'GAGCTCTGACAGAGGAGGCCAGTCCTGGAATTGATTCCCAGTTCCTCACGTTCAGTGATGAGCACTGAACACAGACACCTCACCATGAACTTTGGGCTCAGATTGATTTTCCTTGTCCTTACTTTAAAAGGTGTGAAGTGTGAAGTGCAGCTGGTGGAGTCTGGGGGAGGCTTAGTGAAGCCTGGAGGGTCCCTGAAACTCTCCTGTGCAGCCTCTGGATTCGCTTTCAGTAGCTATGACATGTCTTGGGTTCGCCAGACTCCGGAGAAGAGGCTGGAGTGGGTCGCATACATTAGTAGTGGTGGTGGTATCACCTACTATCCAGACACTGTGAAGGGCCGATTCACCATCTCCAGAGACAATGCCAAGAACACCCTGTACCTGCAAATGAGCAGTCTGAAGTCTGAGGACACAGCCATGTATTACTGTGCAAGGCCCCCGGGACGGGGCTACTGGTACTTCGATGTCTGGGGCGCAGGGACCACGGTCACCGTCTCCTCAGCCAAAACAACAGCCCCATCGGTCTATCCACTGGCCCCTGTGTGTGGAGATACAACTGGCTCCTCGGTGACTCTAGGGTGCCTGGTCAAGGATTATT']
['AACATATGTCCAATGTCCTCTCCACAGACACTGAACACACTGACTCTAACCATGGGATGGAGCTGGATCTTTCTCTTCCTCCTGTCAGGAACTGCAGGCGTCCACTCTGAGGTCCAGCTTCAGCAGTCAGGACCTGAGCTGGTGAAACCTGGGGCCTCAGTGAAGATATCCTGCAAGGCTTCTGGATACACATTCACTGACTACAACATGCACTGGGTGAAGCAGAGCCATGGAAAGAGCCTTGAGTGGATTGGATATATTTATCCTTACAATGGTGGTACTGGCTACAACCAGAAGTTCAAGAGCAAGGCCACATTGACTGTAGACAATTCCTCCAGCACAGCCTACATGGAGCTCCGCAGCCTGACATCTGAGGACTCTGCAGTCTATTACTGTGCAAGATGGGGGCTAACTGGTGATGCTATGGACTACTGGGGTCAAGGAACCTCAGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG', 'GACATAACAGCAAGAGAGTGTCCGGTTAGTCTCAAGGAAGACTGAGACACAGTCTTAGATATCATGGAATGGCTGTGGAACTTGCTATTTCTCATGGCAGCAGCTCAAAGTATCCAAGCACAGATCCAGTTGGTGCAGTCTGGACCTGAGCTGAAGAAGCCTGGAGAGACAGTCAGGATCTCCTGCAAGGCTTCTGGGTATACCTTCACAACTGCTGGAATGCAGTGGGTGCAAAAGATGCCAGGAAAGGGTTTGAAGTGGATTGGCTGGATAAACACCCACTCTGGAGTGCCAAAATATGCAGAAGACTTCAAGGGACGGTTTGCCTTCTCTTTGGAAACCTCTGCCAGCACTGCATATTTACAGATAAGCAACCTCAAAAATGAGGACACGGCTACGTATTTCTGTGCGAGATCAGGTTACGACGCCTTTGACTACTGGGGCCAAGGCACCACTCTCACAGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG']
['GGGGAGCATATGATCAGTGTCCTCTCCAAAGTCCTTGAACATAGACTCTAACCATGGAATGGACCTGGGTCTTTCTCTTCCTCCTGTCAGTAACTGCAGGTGTCCACTCCCAGGTTCAGCTGCAGCAGTCTGGAGTTGAGCTGATGAAGCCTGGGGCCTCAGTGAAGATATCCTGCAAGGCTACTGGCTACACACTCAGTAACTACTGGATAGAGTGGGTAAAGCAGAGGCCTGGACATGGCCTTGAGTGGATTGGAGAGATTTTACCTGGAGATGTTATTACTAACTACAATGAGAGGTTCAAGGACAAGGCCACATTCACTGCAGATACATCCTCCAACACAGCCTACATGCAACTCAGCAGCCTGACATCTGAGGATTCTGCCGTCTATTACTGTGCAAGAAGGGTTATTAAGGGGGGGTTTGCTTACTGGGGCCAAGGGACTCTGGTCACTGTCTCTGCAGCCAAAACAACAGCCCCATCGGTCTATCCACTGGCCCCTGTGTGTGGAGATACAACTGGCTCCTCGGTGACTCTAGGATGCCTGGTCAAGG', 'ACATCGCTCTCACTGGAGGCTGATCTCTGAAGATAAGGAGGTGTAGCCTAAAAGATGAGAGTGCTGATTCTTTTGTGGCTGTTCACAGCCTTTCCTGGTATCCTGTCTGATGTGCAGCTTCAGGAGTCGGGACCTGGCCTGGTGAAACCTTCTCAGTCTCTGTCCCTCACCTGCACTGTCACTGGCTACTCAATCACCAGTGATTATGCCTGGAACTGGATCCGGCAGTTTCCAGGAAACAAACTGGAGTGGATGGGCTACATAAGCTACAGTGGTAGCACTAGCTACAACCCATCTCTCAAAAGTCGAATCTCTATCACTCGAGACACATCCAAGAACCAGTTCTTCCTGCAGTTGAATTCTGTGACTACTGAGGACACAGCCACATATTACTGTGCAAGAAGGTACTTCGATGTCTGGGGCGCAGGGACCACGGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG']
['ATCGCTCTCACTGGAGGCTGATCTCTGAAGATAAGGAGGTGTAGCCTAAAAGATGAGAGTGCTGATTCTTTTGTGGCTGTTCACAGCCTTTCCTGGTATCCTGTCTGATGTGCAGCTTCAGGAGTCGGGACCTGGCCTGGTGAAACCTTCTCAGTCTCTGTCCCTCACCTGCACTGTCACTGGCTACTCAATCACCAGTGATTATGCCTGGAACTGGATCCGGCAGTTTCCAGGAAACAAACTGGAGTGGATGGGCTACATAAGCTACAGTGGTAGCACTAGCTACAACCCATCTCTCAAAAGTCGAATCTCTATCACTCGAGACACATCCAAGAACCAGTTCTTCCTGCAGTTGAATTCTGTGACTACTGAGGACACAGCCACATATTACTGTGCAAGGAGGTACTTCGATGTCTGGGGCGCAGGGACCACGGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG', 'GGGGAAAAACATGAGATCACAGTTCTCTCTACAGTTACTGAGCACACAGGAACTCACCATGGGATGGAGCTATATCATCCTCTTTTTGGTAGCAACAGCTACAGGTGTCCACTCCCAGGTCCAACTGCAGCAGCCTGGGGCTGAACTGGTGAAGCCTGGGGCTTCAGTGAAGTTGTCCTGCAAGGCTTCTGGCTACACCTTCACCAGCTACTATATGTACTGGGTGAAGCAGAGGCCTGGACAAGGCCTTGAGTGGATTGGGGGGATTAATCCTAGCAATGGTGGTACTAACTTCAATGAGAAGTTCAAGAGCAAGGCCACACTGACTGTAGACAAATCCTCCAGCACAGCCTACATGCAACTCAGCAGCCTGACATCTGAGGACTCTGCGGTCTATTACTGTACAAGATACGGCCTCTATGCTATGGACTACTGGGGTCAAGGAACCTCAGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG']
['ACTAGTGTGCAGATATGGACAGGCTTACTTCCTCATTGCTGCTGCTGATTGTCCCTGCATATGTCCTGTCCCAGGTTACTCTGAAAGAGTCTGGCCCTGGGATATTGCAGCCCTCCCAGACCCTCAGTCTGACTTGTTCTTTCTCTGGGTTTTCACTGACCACTTCTGGTATGGGTGTGACCTGGATTCGTCAGCCTTCAGGAAAGGGTCTGGAGTGGCTGGCACACATTTACTGGGATAATGACAAGCGCTATAATACATCCCTGAAGAGCCGGCTCACAATCTCCAAGGATACCTCCAGCAACCGGGTATTCCTCAAGATAACCAGTGTGGACACTGCAGATACTGCCACATACTACTGTACTCGAGTTACTACGTTGGTGGGCTACTTTGACCAATGGGGCCAAGGCACCACTCTCACAGTCTCCTCAGCCAAAACGACACCCCCATCTGTCTATCCACTGGCCCCTGGATCTGCTGCCCAAACTAACTCCATGGTGACCCTGGGATGCCTGGTCAAGGG', 'AACACCACCAACAACGACATCGACAATCATTCCCTACACAAAGCTCTTCCGATCTAAACGGGAGAATAGAGTCAATGATTTATTCTTATATGAGGAGAAAAACATGAGATCACAGTTCTCTCTACAGTTACTGAGCACACAGGACCTCACCATGGGATGGAGCTATATCATTTTCTTTTTGGTAGCAACAGCTACAGGTGTCCACTCCCAGGTCCAACTCCAGCAGCCTGGGGCTGAACTGGTGAAGCCTGGGGCTTCAGTGAAGTTGTCCTGCAAGGCTTCTGGCTACACCTTCACCAGCTACTGGATGCACTGGGTGAAGCTGAGGCCTGGACAAGGCTTTGAGTGGATTGGAGAGATTAATCCTAGCAATGGTGGTACTAACTACAATGAGAAGTTCAAGAGAAAGGCCACACTGACTGTAGACAAATCCTCCAGCACAGCCTACATGCAACTCAGCAGCCTGACATCTGAGGACTCTGCGGTCTATTACTGTACAATACGGAATTACTACGGTAGTAGCTACGAGGACTACTGGGGCCAAGGCACCACTCTCACAGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG']
['GTCTATGGCAGTTCCTATCTCTCTCACTGGAGGCTGATTTTTGAAGAAAGGGGTTGTAGCCTAAAAGATGATGGTGTTAAGTCTTCTGTACCTGTTGACAGCCCTTCCGGGTATCCTGTCAGAGGTGCAGCTTCAGGAGTCAGGACCTAGCCTCGTGAAACCTTCTCAGACTCTGTCCCTCACCTGTTCTGTCACTGGCGACTCCATCACCAGTGGTTACTGGAACTGGATCCGGAAATTCCCAGGGAATAAACTTGAGTACATGGGGTACATAAGCTACAGTGGTAGCACTTACTACAATCCATCTCTCAAAAGTCGAATCTCCATCACTCGAGACACATCCAAGAACCAGTACTACCTGCAGTTGAATTCTGTGACTACTGAGGACACAGCCACATATTACTGTGCAAGATGGGACTATGACTGGGGTCAAGGAACCTCAGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG', 'AGTTGCGTCTTTTCTTATATGGGATCCTCTTCTCATAGAGCCTCCATCAGAGCATGGCTGTCTTGGGGCTGCTCTTCTGCCTGGTGACATTCCCAAGCTGTGTCCTATCCCAGGTGCAGCTGAAGCAGTCAGGACCTGGCCTAGTGCAGCCCTCACAGAGCCTGTCCATCACCTGCACAGTCTCTGGTTTCTCATTAACTAGCTATGGTGTACACTGGGTTCGCCAGTCTCCAGGAAAGGGTCTGGAGTGGCTGGGAGTGATATGGAGTGGTGGAAGCACAGACTATAATGCAGCTTTCATATCCAGACTGAGCATCAGCAAGGACAATTCCAAGAGCCAAGTTTTCTTTAAAATGAACAGTCTGCAAGCTAATGACACAGCCATATATTACTGTGCCAGAAATTCGGGGGGGTATGGTAACTACGCCCTTTTTGCTTACTGGGGCCAAGGGACTCTGGTCACTGTCTCTGCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG']
['GGGGAACAGACACACAAACCTGGACTCACAAGTTTTTCTCTTCAGTGACAGACACAGACATAGAACATTCACGATGTACTTGGGACTGAACTATGTATTCATAGTTTTTCTCTTAAATGGTGTCCAGAGTGAAGTGAAGCTTGAGGAGTCTGGAGGAGGCTTGGTGCAACCTGGAGGATCCATGAAACTCTCTTGTGCTGCCTCTGGATTCACTTTTAGTGACGCCTGGATGGACTGGGTCCGCCAGTCTCCAGAGAAGGGGCTTGAGTGGGTTGCTGAAATTAGAAGCAAAGCTAATAATCATGCAACATACTATGCTGAGTCTGTGAAAGGGAGGTTCACCATCTCAAGAGATGATTCCAAAAGTAGTGTCTACCTGCAAATGAACAGCTTAAGAGCTGAAGACACTGGCATTTATTACTGTACCCGGTATGGTAACTGGCGGTACTTCGATGTCTGGGGCGCAGGGACCACGGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG', 'GGAGCTCTGACAGAGGAGGCAGGTCCTGGATTCGATTCCCAGTTCCTCACATTCAGTCAGCACTGAACACGGACCCCTCACCATGAACTTTGTGCTCAGCTTGATTTTCCTTGCCCTCATTTTAAAAGGTGTCCAGTGTGAAGTGCAGCTGGTGGAGTCTGGGGGAGGCTTAGTGAAGCCTGGAGGGTCCCTGAAACTCTCCTGTGCAGCCTCTGGATTCACTTTCAGTAGCTATGCCATGTCTTGGGTTCGCCAGACTCCGGAGAAGAGGCTGGAGTGGGTCGCAACCATTAGTAGTGGTGGTAGTTACACCTACTATCCAGACAGTGTGAAGGGTCGATTCACCATCTCCAGAGACAATGCCAAGAACACCCTGTACCTGCAAATGAGCAGTCTGAGGTCTGAGGACACGGCCATGTATTACTGTGCAAGACTAGACCCAACTGGGAACTATGCTATGGACTACTGGGGTCAAGGAACCTCAGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG']
['ATCTCCTCACTAGAGCCCCCATCAGAGCATGGCTGTCCTGGTGCTGTTCCTCTGCCTGGTTGCATTTCCAAGCTGTGTCCTGTCCCAGGTGCAACTGAAGGAGTCAGGACCTGGCCTGGTGGCGCCCTCACAGAGCCTGTCCATCACTTGCACTGTCTCTGGGTTTTCCTTAACCAGCTATGGTGTACACTGGGTTCGCCAGCCTCCAGGAAAGGGTCTGGAGTGGCTGGGAGTAATATGGGCTGGTGGAATCACAAATTATAATTCGGCTCTCATGTCCAGACTGAGCATCAGCAAAGACAACTCCAAGAGCCAAGTTTTCTTAAAAATGAACAGTCTGCAAACTGTTGACACAGCCATGTACTACTGTGCCAGAGATAGGGCCGGCTACTATGGTAACTACTTTGACTACTGGGGCCAAGGCACCACTCTCACAGTCTCCTCAGCCAAAACGACACCCCCATCTGTCTATCCACTGGCCCCTGGATCTGCTGCCCAAACTAACTCCATGGTGACCCTGGGATGCCTGGTCAAGGG', 'GACATACCAGCAAGGGAGTGACCAGTTTGTCTTAAGGCACCACTGAGCCCAAGTCTTAGACATCATGGATTGGCTGTGGAACTTGCTATTCCTGATGGCAGCTGCCCAAAGTGCCCAAGCACAGATCCAGTTGGTGCAGTCTGGACCTGAGCTGAAGAAGCCTGGAGAGACAGTCAAGATCTCCTGCAAGGCTTCTGGGTATACCTTCACAAACTATGGAATGAACTGGGTGAAGCAGGCTCCAGGAAAGGGTTTAAAGTGGATGGGCTGGATAAACACCTACACTGGAGAGCCAACATATGCTGATGACTTCAAGGGACGGTTTGCCTTCTCTTTGGAAACCTCTGCCAGCACTGCCTATTTGCAGATCAACAACCTCAAAAATGAGGACATGGCTACATATTTCTGTGCAAGAGGGGGTAGTAGCTACAGGGACTGGTACTTCGATGTCTGGGGCGCAGGGACCACGGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG']
['GAGCTCTGACAGAGGAGGCCAGTCCTGGAATTGATTCCCAGTTCCTCACGTTCAGTGATGAGCAGTGAACACAGACCCCTCACCATGAACTTCGGGCTCAGATTGATTTTCCTTGTCCTTACTTTAAAAGGTGTCCAGTGTGACGTGAAGCTGGTGGAGTCTGGGGGAGGCTTAGTGAAGCCTGGAGGGTCCCTGAAACTCTCCTGTGCAGCCTCTGGATTCACTTTCAGTAGCTATACCATGTCTTGGGTTCGCCAGACTCCGGAGAAGAGGCTGGAGTGGGTCGCAACCATTAGTAGTGGTGGTAGTTACACCTACTATCCAGACAGTGTGAAGGGCCGATTCACCATCTCCAGAGACAATGCCAAGAACACCCTGTACCTGCAAATGAGCAGTCTGAAGTCTGAGGACACAGCCATGTATTACTGTACAAGCTCCCACCCTGATTGGGGAGCTATGGACTACTGGGGTCAAGGAACCTCAGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG', 'TGGGGAGCTCTGACAGAGGAGGCCGGTCCTGGATTCGATTCCCAGTTCCTCACATTCAGTCAGCACTGAACACAGACACCTCACCATGAACTTCGGGCTCAGCTTGATTTTCCTTGTCCTTATTTTAAAAGGTGTCCAGTGTGAAGTGCAGCTGGTGGAGTCTGGGGGAGGCTTAGTGAAGCCTGGAGGGTCCCTGAAACTCTCCTGTGCAGCCTCTGGATTCACTTTCAGTAGCTATGCCATGTCTTGGGTTCGCCAGTCTCCAGAGAAGAGGCTGGAGTGGGTCGCAGAAATTAGTAGTGGTGGTAGTTACACCTACTATCCAGACACTGTGACGGGCCGATTCACCATCTCCAGAGACAATGCCAAGAACACCCTGTACCTGGAAATGAGCAGTCTGAGGTCTGAGGACACGGCCATGTATTACTGTGCAAGGGATCAGTCTACTATGATTACGTCGTTTGCTTACTGGGGCCAAGGGACTCTGGTCACTGTCTCTGCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG']
['TCATCTCCTCACTAGAGCCCCCATCAGAGCATGGCTGTCCTGGTGCTGTTCCTCTGCCTGGTTGCATTTCCAAGCTGTGTCCTGTCCCAGGTGCAGCTGAAGGAGTCAGGACCTGGCCTGGTGGCGCCCTCACAGAGCCTGTCCATCACTTGCACTGTCTCTGGGTTTTCATTAACCAGCTATGGTGTACACTGGGTTCGCCAGCCTCCAGGAAAGGGTCTGGAGTGGCTGGGAGTAATATGGGCTGGTGGAAGCACAAATTATAATTCGGCTCTCATGTCCAGACTGAGCATCAGCAAAGACAACTCCAAGAGCCAAGTTTTCTTAAAAATGAACAGTCTGCAAACTGATGACACAGCCATGTACTACTGTGCCAGAGATCTATTCTATGCTATGGACTACTGGGGTCAAGGAACCTCAGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG', 'TATTTTTCTTATATGGGGATCCTCTTCTCATAGAGCCTCCATCAGAGCATGGCTGTCCTGGTGCTGCTCTTCTGCCTGGTGACATTCCCAAGCTGTGTCCTATCCCAGGTGCAGCTGAAGCAGTCAGGACCTGGCCTAGTGCAGCCCTCACAGAGCCTGTCCATCACCTGCACAGTCTCTGGTTTCTCATTAACTAGCTATGGTGTACACTGGGTTCGCCAGCCTCCAGGAAAGGGTCTGGAGTGGCTGGGAGTGATATGGAGTGGTGGAAGCACAGACTATAATGCTGCTTTCATATCCAGACTGAGCATCAGCAAGGACAACTCCAAGAGCCAAGTTTTCTTTAAAATGAACAGTCTGCAAGCTGATGACACAGCCATATACTACTGTGCCAGAAAGGCCCCCTATGCTATGGACTACTGGGGTCAAGGAACCTCAGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG']
['CTGAGCTTTCTTATATGGGGAGCTCTGACAGAGGAGGCCTGTCCTGGATTCGATTCCCAGTTCCTCACATTCAGTCAGCACTGAACACGGACCCCTCACCATGAACTTCGGGCTCAGCTTGATTTTCCTTGTCCTTGTTTTAAAAGGTGTCCAGTGTGAAGTGATGCTGGTGGAGTCTGGGGGAGGCTTAGTGAAGCCTGGAGGGTCCCTGAAACTCTCCTGTGCAGCCTCTGGATTCACTTTCAGTAGCTATGCCATGTCTTGGGTTCGCCAGACTCCGGAGAAGAGGCTGGAGTGGGTCGCAACCATTAGTAGTGGTGGTAGTTACACCTACTATCCAGACAGTGTGAAGGGGCGATTCACCATCTCCAGAGACAATGCCAAGAACACCCTGTACCTGCAAATGAGCAGTCTGAGGTCTGAGGACACGGCCATGTATTACTGTGCAAGAGGAGGGGACTATGGTAACGGGTACTTCGATGTCTGGGGCGCAGGGACCACGGTCACCGTCTCCTCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG', 'ACAGTCATTGAAAACACTGACTCTAATCATGGAATGTAACTGGATACTTCCTTTTATTCTGTCGGTAATTTCAGGGGTCTACTCAGAGGTTCAGCTCCAGCAGTCTGGGACTGTGCTGGCAAGGCCTGGGGCTTCCGTGAAGATGTCCTGCAAGGCTTCTGGCTACAGCTTTACCAGCTACTGGATGCACTGGGTAAAACAGAGGCCTGGACAGGGTCTAGAATGGATTGGTGCTATTTATCCTGGAAATAGTGATACTAGCTACAACCAGAAGTTCAAGGGCAAGGCCAAACTGACTGCAGTCACATCCGCCAGCACTGCCTACATGGAGCTCAGCAGCCTGACAAATGAGGACTCTGCGGTCTATTACTGTACCCTTATGATTACGACGACGGTTTTTGCTTACTGGGGCCAAGGGACTCTGGTCACTGTCTCTGCAGAGAGTCAGTCCTTCCCAAATGTCTTCCCCCTCGTCTCCTGCGAGAGCCCCCTGTCTGATAAGAATCTGGTGGCCATGGGCTGCCTGG']
After:
{'text': tensor([[ 1, 6, 13, ..., 6, 6, 6],
[16, 6, 6, ..., 4, 9, 6]]), 'types': tensor([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]]), 'is_random': tensor([0, 1]), 'loss_mask': tensor([[0, 0, 1, ..., 0, 0, 0],
[1, 0, 0, ..., 1, 1, 0]]), 'labels': tensor([[1, 6, 6, ..., 6, 6, 6],
[1, 6, 6, ..., 6, 6, 6]]), 'padding_mask': tensor([[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1]]), 'batch}
{'text': tensor([[1, 6, 6, ..., 6, 6, 6],
[1, 6, 6, ..., 6, 6, 6]]), 'types': tensor([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]]), 'is_random': tensor([0, 1]), 'loss_mask': tensor([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]]), 'labels': tensor([[1, 6, 6, ..., 6, 6, 6],
[1, 6, 6, ..., 6, 6, 6]]), 'padding_mask': tensor([[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1]]), 'batch}
{'text': tensor([[1, 6, 6, ..., 6, 6, 6],
[1, 6, 6, ..., 6, 6, 6]]), 'types': tensor([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]]), 'is_random': tensor([0, 1]), 'loss_mask': tensor([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]]), 'labels': tensor([[1, 6, 6, ..., 6, 6, 6],
[1, 6, 6, ..., 6, 6, 6]]), 'padding_mask': tensor([[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1]]), 'batch}
{'text': tensor([[ 1, 25, 6, ..., 6, 6, 6],
[ 1, 6, 6, ..., 6, 6, 6]]), 'types': tensor([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]]), 'is_random': tensor([0, 1]), 'loss_mask': tensor([[0, 1, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]]), 'labels': tensor([[1, 6, 6, ..., 6, 6, 6],
[1, 6, 6, ..., 6, 6, 6]]), 'padding_mask': tensor([[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1]]), 'batch}
{'text': tensor([[1, 6, 6, ..., 6, 4, 6],
[1, 6, 4, ..., 6, 6, 6]]), 'types': tensor([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]]), 'is_random': tensor([0, 1]), 'loss_mask': tensor([[0, 0, 0, ..., 0, 1, 0],
[0, 0, 1, ..., 0, 0, 0]]), 'labels': tensor([[1, 6, 6, ..., 6, 6, 6],
[1, 6, 6, ..., 6, 6, 6]]), 'padding_mask': tensor([[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1]]), 'batch': ['AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA', 'AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA']}
{'text': tensor([[ 1, 6, 26, ..., 6, 6, 6],
[ 1, 6, 4, ..., 6, 6, 6]]), 'types': tensor([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]]), 'is_random': tensor([0, 1]), 'loss_mask': tensor([[0, 0, 1, ..., 0, 0, 0],
[0, 0, 1, ..., 0, 0, 0]]), 'labels': tensor([[1, 6, 6, ..., 6, 6, 6],
[1, 6, 6, ..., 6, 6, 6]]), 'padding_mask': tensor([[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1]]), 'batch}
{'text': tensor([[ 1, 6, 6, ..., 22, 6, 6],
[ 4, 6, 6, ..., 6, 6, 6]]), 'types': tensor([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]]), 'is_random': tensor([0, 1]), 'loss_mask': tensor([[0, 0, 0, ..., 1, 0, 0],
[1, 0, 0, ..., 0, 0, 0]]), 'labels': tensor([[1, 6, 6, ..., 6, 6, 6],
[1, 6, 6, ..., 6, 6, 6]]), 'padding_mask': tensor([[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1]]), 'batch}
{'text': tensor([[ 1, 4, 6, ..., 6, 6, 6],
[ 1, 26, 6, ..., 6, 6, 6]]), 'types': tensor([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]]), 'is_random': tensor([0, 1]), 'loss_mask': tensor([[0, 1, 0, ..., 0, 0, 0],
[0, 1, 0, ..., 0, 0, 0]]), 'labels': tensor([[1, 6, 6, ..., 6, 6, 6],
[1, 6, 6, ..., 6, 6, 6]]), 'padding_mask': tensor([[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1]]), 'batch}
{'text': tensor([[1, 6, 6, ..., 6, 6, 6],
[1, 6, 6, ..., 6, 6, 6]]), 'types': tensor([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]]), 'is_random': tensor([0, 1]), 'loss_mask': tensor([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]]), 'labels': tensor([[1, 6, 6, ..., 6, 6, 6],
[1, 6, 6, ..., 6, 6, 6]]), 'padding_mask': tensor([[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1]]), 'batch}
{'text': tensor([[1, 6, 6, ..., 6, 6, 6],
[1, 6, 6, ..., 6, 6, 6]]), 'types': tensor([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]]), 'is_random': tensor([0, 1]), 'loss_mask': tensor([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]]), 'labels': tensor([[1, 6, 6, ..., 6, 6, 6],
[1, 6, 6, ..., 6, 6, 6]]), 'padding_mask': tensor([[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1]]), 'batch}
{'text': tensor([[ 1, 4, 6, ..., 6, 6, 6],
[ 1, 6, 6, ..., 6, 19, 6]]), 'types': tensor([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]]), 'is_random': tensor([0, 1]), 'loss_mask': tensor([[0, 1, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 1, 0]]), 'labels': tensor([[1, 6, 6, ..., 6, 6, 6],
[1, 6, 6, ..., 6, 6, 6]]), 'padding_mask': tensor([[1, 1, 1, ..., 1, 1, 1],
[1, 1, 1, ..., 1, 1, 1]]), 'batch}
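The Before and After outputs differ only because of the collate_fn argument. As a rough mental model, a dataloader with shuffle=False groups consecutive samples into lists of batch_size items and hands each list to the collate function. The plain-Python sketch below illustrates that idea. iter_batches and all_a are hypothetical names invented here, not part of PyTorch or BioNeMo, and the real DataLoader also handles shuffling, worker processes, and tensor conversion.

```python
def iter_batches(dataset, batch_size, collate_fn=None):
    """Yield batches in order, loosely mimicking a dataloader without shuffling."""
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        samples = [dataset[i] for i in range(start, stop)]
        # Without a collate function the raw samples are returned as a list.
        yield collate_fn(samples) if collate_fn is not None else samples


def all_a(samples):
    """Toy collate function: replace every character with 'A'."""
    return {"batch": ["A" * len(s) for s in samples]}


toy_sequences = ["ACTG", "GGGAT", "TTAGGC"]  # made-up data
before = list(iter_batches(toy_sequences, batch_size=2))
after = list(iter_batches(toy_sequences, batch_size=2, collate_fn=all_a))
print(before)
print(after)
```

The last batch holds a single sample because three items do not divide evenly into batches of two. The collate function sees each batch as a plain list of sequences, which is the same interface that collate_fn(['ACTGT', 'ADFASDFA']) used in the manual test.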
Conclusion and further reading.#
This concludes our tutorial on including custom data in the BioNeMo framework. Throughout these tutorials, we described how to manually update a model with a new dataset and how those changes propagate throughout the framework. Check out the other Dataset classes and tokenizers in NeMo to learn about further customization.