Source code for nemo.collections.nlp.data.token_classification.punctuation_capitalization_tarred_dataset

# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import json
import multiprocessing as mp
import os
import pickle
import re
import shutil
import tempfile
from collections import deque
from pathlib import Path
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, Union

import numpy as np
import torch
import webdataset as wds
from joblib import Parallel, delayed
from omegaconf import DictConfig
from torch.utils.data import IterableDataset

from nemo.collections.common.tokenizers import TokenizerSpec
from nemo.collections.nlp.data.token_classification.punctuation_capitalization_dataset import (
    LABEL_ID_DIR_FOR_NEMO_CHECKPOINT,
    BertPunctuationCapitalizationDataset,
    Progress,
    create_label_ids,
    create_masks_and_segment_ids,
    load_label_ids,
    raise_not_equal_labels_error,
)
from nemo.collections.nlp.modules.common.tokenizer_utils import get_tokenizer
from nemo.core.neural_types import ChannelType, LabelsType, MaskType, NeuralType
from nemo.utils import logging

NUMBER_RE = "(0|[1-9][0-9]*)"
TAR_FRAGMENT_TMPL_IN_PROGRESS = "fragment{fragment_idx}.{file_idx}.tar"
TAR_FRAGMENT_TMPL_FINISHED = "fragment{fragment_idx}.num_batches{num_batches}.{file_idx}.tar"
TAR_FRAGMENT_TMPL_TO_REPACK = "fragment{fragment_idx}.num_batches{num_batches}.{file_idx}.tar.to_repack"
TAR_FRAGMENT_PATTERN_IN_PROGRESS = re.compile(f"fragment{NUMBER_RE}.{NUMBER_RE}.tar$")
TAR_FRAGMENT_PATTERN_FINISHED = re.compile(f"fragment{NUMBER_RE}.num_batches{NUMBER_RE}.{NUMBER_RE}.tar$")
TAR_FRAGMENT_PATTERN_TO_REPACK = re.compile(f"fragment{NUMBER_RE}.num_batches{NUMBER_RE}.{NUMBER_RE}.tar.to_repack$")
NOT_ALLOWED_CHARACTERS_IN_FILE_NAME = re.compile("[^a-zA-Z0-9_.-]")
REPLACE_NOT_ALLOWED_CHARACTERS_IN_FILE_NAME = re.compile("-*[^a-zA-Z0-9_.-]+-*")

DATASET_PARAMETERS_TMPL = "{prefix}.tokens{tokens_in_batch}.max_seq_length{max_seq_length}.{tokenizer}"
TAR_FINAL_TMPL = ".batches{num_batches}.{ctr}.tar"

PROGRESS_REPORT_PERIOD = 10 ** 4

METADATA_PUNCT_LABEL_VOCAB_KEY = 'punct_label_vocab_file'
METADATA_CAPIT_LABEL_VOCAB_KEY = 'capit_label_vocab_file'
DEFAULT_PUNCT_LABEL_VOCAB_FILE_NAME = 'punct_label_vocab.csv'
DEFAULT_CAPIT_LABEL_VOCAB_FILE_NAME = 'capit_label_vocab.csv'


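# Illustrative sketch (not part of the original module): how the name templates above compose into final tar file
# and metadata file names. All parameter values below are hypothetical.
def _example_tarred_dataset_file_names() -> Tuple[str, str]:
    ds_params_str = DATASET_PARAMETERS_TMPL.format(
        prefix="punctuation_capitalization", tokens_in_batch=1024, max_seq_length=128, tokenizer="bert-base-uncased"
    )
    # -> 'punctuation_capitalization.tokens1024.max_seq_length128.bert-base-uncased'
    tar_file_name = ds_params_str + TAR_FINAL_TMPL.format(num_batches=5, ctr=0)
    # -> 'punctuation_capitalization.tokens1024.max_seq_length128.bert-base-uncased.batches5.0.tar'
    metadata_file_name = 'metadata.' + ds_params_str + '.json'
    return tar_file_name, metadata_file_name
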
def count_lines_and_get_fragment_starting_positions(
    file_name: Path, lines_per_dataset_fragment: int
) -> Tuple[int, List[int]]:
    """
    Returns number of lines in a file and indices of fragment starting bytes.

    Args:
        file_name: a path to a text or label file
        lines_per_dataset_fragment: number of lines in a dataset fragment. The last fragment can contain fewer lines

    Returns:
        num_lines: number of lines in a file
        start_bytes: indices of fragment starting bytes
    """
    pos = [0]
    with file_name.open() as f:
        i = 0
        line = f.readline()
        while line:
            i += 1
            if i % lines_per_dataset_fragment == 0:
                pos.append(f.tell())
            line = f.readline()
    return i, pos[:-1] if i % lines_per_dataset_fragment == 0 else pos


def get_fragment_start_bytes(
    text_file: Path, labels_file: Path, lines_per_dataset_fragment: int
) -> Tuple[int, List[int], List[int]]:
    """
    A function for calculating borders of dataset fragments. The function is used to split ``text_file`` and
    ``labels_file`` for processing them in parallel.

    Args:
        text_file: a path to a dataset source file
        labels_file: a path to a dataset label file
        lines_per_dataset_fragment: a number of lines in one fragment

    Returns:
        num_lines: total number of elements in the dataset (number of lines in ``text_file`` and ``labels_file``)
        text_start_bytes: indices of the first bytes of fragments in ``text_file``
        label_start_bytes: indices of the first bytes of fragments in ``labels_file``
    """
    logging.info(
        f"Counting lines in files {text_file} and {labels_file} and creating segment borders. This may take "
        f"considerable time. An 86 GB file with 1.27 billion lines was processed in 7 minutes."
    )
    result = Parallel(n_jobs=2)(
        delayed(count_lines_and_get_fragment_starting_positions)(file_name, lines_per_dataset_fragment)
        for file_name in [text_file, labels_file]
    )
    if result[0][0] != result[1][0]:
        raise ValueError(
            f"Text file {text_file} and label file {labels_file} contain different number of lines. Number of lines "
            f"in text file: {result[0][0]}, number of lines in label file: {result[1][0]}."
        )
    num_lines = result[0][0]
    text_start_bytes, label_start_bytes = result[0][1], result[1][1]
    assert len(text_start_bytes) == len(label_start_bytes)
    return num_lines, text_start_bytes, label_start_bytes


def process_fragment(
    text_file: Path,
    labels_file: Path,
    output_dir: Path,
    text_start_pos: int,
    label_start_pos: int,
    lines_per_dataset_fragment: int,
    max_seq_length: int,
    tokens_in_batch: int,
    num_batches_per_tarfile: int,
    tokenizer_name: str,
    tokenizer_model: Optional[Path],
    vocab_file: Optional[Path],
    merges_file: Optional[Path],
    special_tokens: Dict[str, str],
    use_fast_tokenizer: Optional[bool],
    pad_label: str,
    punct_label_ids: Dict[str, int],
    capit_label_ids: Dict[str, int],
    fragment_idx: int,
    tokenization_progress_queue: mp.Queue,
    batch_mark_up_progress_queue: mp.Queue,
    batch_building_progress_queue: mp.Queue,
    writing_to_tar_progress_queue: mp.Queue,
) -> None:
    tokenizer = get_tokenizer(
        tokenizer_name,
        tokenizer_model=None if tokenizer_model is None else str(tokenizer_model),
        vocab_file=None if vocab_file is None else str(vocab_file),
        merges_file=None if merges_file is None else str(merges_file),
        special_tokens=special_tokens,
        use_fast=use_fast_tokenizer,
    )
    tmp_text: Optional[str] = None
    tmp_labels: Optional[str] = None
    try:
        otfd, tmp_text = tempfile.mkstemp(suffix='.txt', prefix=f'text_{fragment_idx}_', dir=output_dir, text=True)
        olfd, tmp_labels = tempfile.mkstemp(suffix='.txt', prefix=f'labels_{fragment_idx}_', dir=output_dir, text=True)
        with text_file.open() as tf, labels_file.open() as lf, os.fdopen(otfd, 'w') as otf, os.fdopen(
            olfd, 'w'
        ) as olf:
            tf.seek(text_start_pos)
            lf.seek(label_start_pos)
            for _ in range(lines_per_dataset_fragment):
                text_line = tf.readline()
                if not text_line:
                    break
                otf.write(text_line)
                olf.write(lf.readline())
        dataset = BertPunctuationCapitalizationDataset(
            tmp_text,
            tmp_labels,
            max_seq_length,
            tokenizer,
            tokens_in_batch=tokens_in_batch,
            pad_label=pad_label,
            punct_label_ids=punct_label_ids,
            capit_label_ids=capit_label_ids,
            n_jobs=0,
            use_cache=False,
            add_masks_and_segment_ids_to_batch=False,
            verbose=False,
            tokenization_progress_queue=tokenization_progress_queue,
            batch_mark_up_progress_queue=batch_mark_up_progress_queue,
            batch_building_progress_queue=batch_building_progress_queue,
        )
    finally:
        if tmp_text is not None and os.path.exists(tmp_text):
            os.remove(tmp_text)
        if tmp_labels is not None and os.path.exists(tmp_labels):
            os.remove(tmp_labels)
    dataset.features_pkl.unlink()
    tar_ctr = 0
    current_file_name = output_dir / TAR_FRAGMENT_TMPL_IN_PROGRESS.format(fragment_idx=fragment_idx, file_idx=tar_ctr)
    current_num_batches = 0
    sink = wds.TarWriter(str(current_file_name))
    progress_made = 0
    for batch_i, batch in enumerate(dataset):
        sink.write({"__key__": f"fragment-{fragment_idx}-batch-{batch_i}", "batch.pyd": batch})
        current_num_batches += 1
        progress_made += len(batch['input_ids'])
        if current_num_batches % num_batches_per_tarfile == 0:
            sink.close()
            current_file_name.rename(
                output_dir
                / TAR_FRAGMENT_TMPL_FINISHED.format(
                    fragment_idx=fragment_idx, num_batches=current_num_batches, file_idx=tar_ctr
                )
            )
            writing_to_tar_progress_queue.put(progress_made)
            progress_made = 0
            tar_ctr += 1
            current_file_name = output_dir / TAR_FRAGMENT_TMPL_IN_PROGRESS.format(
                fragment_idx=fragment_idx, file_idx=tar_ctr
            )
            current_num_batches = 0
            sink = wds.TarWriter(str(current_file_name))
    sink.close()
    writing_to_tar_progress_queue.put(progress_made)
    if progress_made > 0:
        new_file_name = output_dir / TAR_FRAGMENT_TMPL_TO_REPACK.format(
            fragment_idx=fragment_idx, num_batches=current_num_batches, file_idx=tar_ctr
        )
        current_file_name.rename(new_file_name)
    else:
        current_file_name.unlink()
    if fragment_idx == 0:
        punct_label_ids_file, capit_label_ids_file = dataset.save_labels_and_get_file_paths(
            DEFAULT_PUNCT_LABEL_VOCAB_FILE_NAME, DEFAULT_CAPIT_LABEL_VOCAB_FILE_NAME
        )
        punct_label_ids_file.rename(output_dir / DEFAULT_PUNCT_LABEL_VOCAB_FILE_NAME)
        capit_label_ids_file.rename(output_dir / DEFAULT_CAPIT_LABEL_VOCAB_FILE_NAME)
        shutil.rmtree(punct_label_ids_file.parent)


def remove_unexpected_files_and_dirs(output_dir: Path, output_file_tmpl: str, metadata_file_name: Path) -> None:
    """
    This function removes all files whose names could collide with files created during tarred dataset preparation.

    Args:
        output_dir: a path to directory where removal is performed
        output_file_tmpl: a format string for a name of final tar file. Must include fields ``ctr`` for number of the
            file and ``num_batches`` for number of batches in the file.
        metadata_file_name: a metadata file name
    """
    if not output_dir.is_dir():
        return
    tar_final_pattern = re.compile(output_file_tmpl.format(ctr=NUMBER_RE, num_batches=NUMBER_RE))
    unexpected_tar_files = [
        path
        for path in output_dir.iterdir()
        if any(
            [
                p.match(path.name) is not None
                for p in [
                    TAR_FRAGMENT_PATTERN_IN_PROGRESS,
                    TAR_FRAGMENT_PATTERN_FINISHED,
                    TAR_FRAGMENT_PATTERN_TO_REPACK,
                    tar_final_pattern,
                ]
            ]
        )
    ]
    if unexpected_tar_files:
        logging.warning(
            f"Found {len(unexpected_tar_files)} unexpected tar files in the output directory {output_dir}. "
            f"All of them are going to be removed. The files match one of 4 patterns: "
            f"'{TAR_FRAGMENT_PATTERN_IN_PROGRESS.pattern}', '{TAR_FRAGMENT_PATTERN_FINISHED.pattern}', "
            f"'{TAR_FRAGMENT_PATTERN_TO_REPACK.pattern}', '{tar_final_pattern.pattern}'. The first unexpected files: "
            f"{', '.join([str(f) for f in unexpected_tar_files[:3]])}."
        )
        for fn in unexpected_tar_files:
            fn.unlink()
    if metadata_file_name.exists():
        logging.warning(f"Found metadata file {metadata_file_name}. It is going to be removed.")
        metadata_file_name.unlink()
    punct_label_ids = output_dir / DEFAULT_PUNCT_LABEL_VOCAB_FILE_NAME
    capit_label_ids = output_dir / DEFAULT_CAPIT_LABEL_VOCAB_FILE_NAME
    if punct_label_ids.exists():
        logging.warning(f"Found unexpected punctuation label file {punct_label_ids}. It is going to be removed.")
        punct_label_ids.unlink()
    if capit_label_ids.exists():
        logging.warning(f"Found unexpected capitalization label file {capit_label_ids}. It is going to be removed.")
        capit_label_ids.unlink()


def collect_unique_labels_from_fragment(
    labels_file: Path, start_pos: int, lines_per_dataset_fragment: int, progress_queue: mp.Queue, fragment_idx: int
) -> Tuple[Set[str], Set[str]]:
    """
    Returns a set of unique punctuation labels and a set of unique capitalization labels.

    Args:
        labels_file: a path to a file with labels
        start_pos: an index of the first byte of a fragment in ``labels_file``
        lines_per_dataset_fragment: number of lines in a dataset fragment. The last fragment can contain fewer lines.
        progress_queue: a queue for reporting number of processed lines
        fragment_idx: a processed fragment index

    Returns:
        unique_punct: a set of unique punctuation labels
        unique_capit: a set of unique capitalization labels
    """
    unique_punct, unique_capit = set(), set()
    with labels_file.open() as f:
        f.seek(start_pos)
        progress_report = 0
        for i in range(lines_per_dataset_fragment):
            line = f.readline()
            if not line:
                break
            pairs = line.split()
            if not all([len(p) == 2 for p in pairs]):
                broken_pairs = [i for i, p in enumerate(pairs) if len(p) != 2]
                raise ValueError(
                    f"Found broken labels in line {fragment_idx * lines_per_dataset_fragment + i} of file "
                    f"{labels_file}. Indices of broken label pairs: {broken_pairs}"
                )
            punct, capit = zip(*pairs)
            unique_punct.update(punct)
            unique_capit.update(capit)
            progress_report += 1
            if progress_report >= PROGRESS_REPORT_PERIOD:
                progress_queue.put(progress_report)
                progress_report = 0
        progress_queue.put(progress_report)
    return unique_punct, unique_capit


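# Illustrative example (not part of the original module): each line of a labels file holds one 2-character label per
# word. The first character is the punctuation label and the second one is the capitalization label, which is what
# the `zip(*pairs)` call above relies on. For the source line "hello how are you doing", a labels line could look
# like this (the concrete labels are hypothetical):
#
#     OU OO OO OO ?O
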
def create_label_dictionaries(
    labels_file: Path,
    text_start_bytes: List[int],
    num_lines: int,
    lines_per_dataset_fragment: int,
    pad_label: str,
    n_jobs: int,
) -> Tuple[Dict[str, int], Dict[str, int]]:
    """
    Creates punctuation and capitalization label ids dictionaries based on labels present in ``labels_file``.

    Args:
        labels_file: a path to file with labels
        text_start_bytes: indices of first bytes of fragments in ``labels_file``
        num_lines: total number of lines in ``labels_file``
        lines_per_dataset_fragment: number of lines in dataset fragments. The last fragment can have fewer lines
        pad_label: a label used for padding and for absence of punctuation and capitalization
        n_jobs: a number of fragments processed in parallel

    Returns:
        punct_label_ids: a dictionary with punctuation label ids
        capit_label_ids: a dictionary with capitalization label ids
    """
    with Progress(num_lines, "Creating label dictionary", "line") as progress_queues:
        result = Parallel(n_jobs=min(n_jobs, len(text_start_bytes)))(
            delayed(collect_unique_labels_from_fragment)(
                labels_file, start_pos, lines_per_dataset_fragment, *progress_queues, fragment_idx
            )
            for fragment_idx, start_pos in enumerate(text_start_bytes)
        )
    unique_punct, unique_capit = zip(*result)
    unique_punct = set().union(*unique_punct)
    unique_capit = set().union(*unique_capit)
    return create_label_ids(unique_punct, pad_label), create_label_ids(unique_capit, pad_label)


def check_label_ids(pad_label: str, punct_label_ids: Dict[str, int], capit_label_ids: Dict[str, int]) -> None:
    """
    A function for checking that the pad label has id ``0`` in the ``punct_label_ids`` and ``capit_label_ids``
    dictionaries.

    Args:
        pad_label: a pad label
        punct_label_ids: a dictionary with punctuation label ids
        capit_label_ids: a dictionary with capitalization label ids
    """
    msg = "Parameter `pad_label` has to have id 0 in dictionary `{param_name}` whereas it has id {id_}." + (
        '' if len(pad_label) > 10 else f" pad_label='{pad_label}'"
    )
    if punct_label_ids is not None:
        if punct_label_ids[pad_label] != 0:
            raise ValueError(msg.format(param_name='punct_label_ids', id_=punct_label_ids[pad_label]))
    if capit_label_ids is not None:
        if capit_label_ids[pad_label] != 0:
            raise ValueError(msg.format(param_name='capit_label_ids', id_=capit_label_ids[pad_label]))


def process_error(msg: str, error_class_or_function: Union[Type[Exception], Callable[[str], Any]]) -> None:
    if inspect.isclass(error_class_or_function) and issubclass(error_class_or_function, Exception):
        raise error_class_or_function(msg)
    if callable(error_class_or_function):
        error_class_or_function(msg)
        return
    raise ValueError(
        f"Parameter `error_class_or_function` has to be a subclass of `Exception` or a function. "
        f"Given {type(error_class_or_function)}"
    )


def check_labels_for_being_unique_before_building_label_ids(
    pad_label: str,
    other_labels: List[str],
    pad_label_name: str,
    other_labels_name: str,
    error_class_or_function: Union[Type[Exception], Callable[[str], Any]],
) -> None:
    """
    A function for checking that all labels are unique.

    Args:
        pad_label: a pad label
        other_labels: a list of labels except for the pad label
        pad_label_name: a name of the pad label used in error message
        other_labels_name: a name of other labels used in error message
        error_class_or_function: a class of an exception which is raised if there is a problem with labels.
            Alternatively it can be a function for handling exceptions, for example ``argparse.ArgumentParser.error``.
            Such a function has to take one argument -- error message.
    """
    for i, lbl in enumerate(other_labels):
        if lbl == pad_label:
            msg = f"Label number {i} in parameter `{other_labels_name}` is equal to `{pad_label_name}`."
            process_error(msg, error_class_or_function)
    for i in range(len(other_labels) - 1):
        for lbl in other_labels[i + 1 :]:
            if lbl == other_labels[i]:
                msg = f"Label number {i} occurs at least 2 times in parameter `{other_labels_name}`."
                process_error(msg, error_class_or_function)


def build_label_ids_from_list_of_labels(pad_label: str, other_labels: List[str]) -> Dict[str, int]:
    """
    Builds label ids dictionary from pad label and list of other labels. Used for parsing command line arguments.
    Args:
        pad_label: a pad label
        other_labels: list of labels except for the pad label

    Returns:
        a dictionary with label ids
    """
    check_labels_for_being_unique_before_building_label_ids(
        pad_label, other_labels, 'pad_label', 'other_labels', ValueError
    )
    ids = {pad_label: 0}
    for lbl in other_labels:
        ids[lbl] = len(ids)
    return ids


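# Illustrative sketch (not part of the original module): `build_label_ids_from_list_of_labels` puts the pad label at
# id 0 and numbers the remaining labels in the order they are given.
def _example_build_label_ids() -> None:
    ids = build_label_ids_from_list_of_labels('O', [',', '.', '?'])
    assert ids == {'O': 0, ',': 1, '.': 2, '?': 3}
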
def get_label_dictionaries(
    labels_file: Path,
    start_bytes: List[int],
    num_lines: int,
    lines_per_dataset_fragment: int,
    pad_label: str,
    punct_label_ids: Optional[Dict[str, int]],
    capit_label_ids: Optional[Dict[str, int]],
    punct_label_vocab_file: Optional[Path],
    capit_label_vocab_file: Optional[Path],
    n_jobs: int,
) -> Tuple[Dict[str, int], Dict[str, int]]:
    """
    Return label ids if the label ids are present in parameters ``punct_label_ids``, ``capit_label_ids``,
    ``punct_label_vocab_file``, ``capit_label_vocab_file``. Otherwise, label ids are created using ``labels_file``.

    Args:
        labels_file: a path to file with labels. Labels have to be given in the format described in
            https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/nlp/punctuation_and_capitalization.html#nemo-data-format
        start_bytes: a list of positions in ``labels_file`` at which fragments start. Parameter ``start_bytes`` is used
            for creating labels for several fragments in parallel
        num_lines: total number of lines in ``labels_file``. Parameter ``num_lines`` is used for showing progress of
            label ids collection
        lines_per_dataset_fragment: number of lines in a dataset fragment
        pad_label: a label used for padding and also a neutral label indicating the absence of punctuation and
            capitalization.
            Label ``pad_label`` has to have id ``0`` in parameters ``punct_label_ids``, ``capit_label_ids``,
            ``punct_label_vocab_file``, ``capit_label_vocab_file`` if these parameters are provided.
        punct_label_ids: a dictionary with punctuation label ids. Pad label has to have id ``0``. No more than 1 of
            parameters ``punct_label_ids`` and ``punct_label_vocab_file`` can be provided.
        capit_label_ids: a dictionary with capitalization label ids. Pad label has to have id ``0``. No more than 1 of
            parameters ``capit_label_ids`` and ``capit_label_vocab_file`` can be provided.
        punct_label_vocab_file: a text file with punctuation labels. Every line in the file contains 1 label. Pad label
            has to be in the first line. No more than 1 of parameters ``punct_label_ids`` and
            ``punct_label_vocab_file`` can be provided.
        capit_label_vocab_file: a text file with capitalization labels. Every line in the file contains 1 label. Pad
            label has to be in the first line. No more than 1 of parameters ``capit_label_ids`` and
            ``capit_label_vocab_file`` can be provided.
        n_jobs: a number of fragments processed in parallel

    Returns:
        punct_label_ids: a dictionary with punctuation label ids
        capit_label_ids: a dictionary with capitalization label ids
    """
    if punct_label_ids is not None and punct_label_vocab_file is not None:
        raise ValueError("You can provide at most one of parameters `punct_label_ids` and `punct_label_vocab_file`.")
    if capit_label_ids is not None and capit_label_vocab_file is not None:
        raise ValueError("You can provide at most one of parameters `capit_label_ids` and `capit_label_vocab_file`.")
    if punct_label_ids is None and punct_label_vocab_file is not None:
        punct_label_ids = load_label_ids(punct_label_vocab_file)
    if capit_label_ids is None and capit_label_vocab_file is not None:
        capit_label_ids = load_label_ids(capit_label_vocab_file)
    check_label_ids(pad_label, punct_label_ids, capit_label_ids)
    if punct_label_ids is None or capit_label_ids is None:
        _punct_label_ids, _capit_label_ids = create_label_dictionaries(
            labels_file, start_bytes, num_lines, lines_per_dataset_fragment, pad_label, n_jobs
        )
        if punct_label_ids is None:
            punct_label_ids = _punct_label_ids
        if capit_label_ids is None:
            capit_label_ids = _capit_label_ids
    return punct_label_ids, capit_label_ids


def decode_pyd(key: str, value: bytes) -> Any:
    """
    Used for decoding a batch loaded by ``webdataset`` from tar files.
    Args:
        key: name of a batch
        value: pickled batch

    Returns:
        decoded batch
    """
    return pickle.loads(value)


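# Illustrative sketch (not part of the original module): reading batches back from a single tar shard with
# `webdataset` and `decode_pyd`, mirroring the pattern used in `repack_tar_files_with_not_enough_batches` below.
# The shard path is a hypothetical argument.
def _example_read_tar_shard(tar_file: str) -> None:
    ds = (
        wds.WebDataset(urls=[str(tar_file)], nodesplitter=None)
        .decode(wds.handle_extension('.pyd', decode_pyd))
        .to_tuple('__key__', 'batch.pyd')
    )
    for key, batch in ds:
        # Every `batch` is a dictionary with 'input_ids', 'subtokens_mask', 'punct_labels', and 'capit_labels'.
        print(key, len(batch['input_ids']))
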
def repack_tar_files_with_not_enough_batches(output_dir: Path, num_batches_per_tarfile: int) -> None:
    """
    It is possible that the number of batches in a fragment is not evenly divisible by ``num_batches_per_tarfile``.
    In such a case excess batches are put in a tar file which matches the pattern
    ``fragment(0|[1-9][0-9]*).num_batches(0|[1-9][0-9]*).(0|[1-9][0-9]*).tar.to_repack``. Such files are repacked by
    the ``repack_tar_files_with_not_enough_batches`` function into tar files with exactly ``num_batches_per_tarfile``
    batches each. If there are not enough batches in the repacked files, then up to ``num_batches_per_tarfile - 1``
    remaining batches may be discarded.

    Args:
        output_dir: a path to the output directory which contains files to repack and where new files are saved
        num_batches_per_tarfile: a number of batches in 1 tar file. If the number of batches in files matching the
            pattern ``fragment(0|[1-9][0-9]*).num_batches(0|[1-9][0-9]*).(0|[1-9][0-9]*).tar.to_repack`` is not
            evenly divisible by ``num_batches_per_tarfile``, excess batches are discarded.
    """
    files_to_repack_with_matches = [
        (path, TAR_FRAGMENT_PATTERN_TO_REPACK.match(path.name))
        for path in output_dir.iterdir()
        if TAR_FRAGMENT_PATTERN_TO_REPACK.match(path.name) is not None
    ]
    files_to_repack_with_matches = sorted(files_to_repack_with_matches, key=lambda x: int(x[1].group(3)))
    logging.info(f"Found {len(files_to_repack_with_matches)} files for repacking.")
    files_to_repack_with_matches = deque(files_to_repack_with_matches)
    total_batches_in_repacked_files = 0
    initial_number_of_files_to_repack = len(files_to_repack_with_matches)
    pop_file_ds = None
    new_file_sink = None
    new_file_num_batches = 0
    while files_to_repack_with_matches:
        assert pop_file_ds is None or new_file_sink is None
        if new_file_sink is None:
            # `append_file` is a file whose content will serve as the start of a new tar file. The content of
            # `append_file` is copied into `new_file`, and then the content of other files needing repacking is
            # appended to `new_file`.
            append_file, match = files_to_repack_with_matches.popleft()
            new_file = append_file.parent / TAR_FRAGMENT_TMPL_FINISHED.format(
                fragment_idx=match.group(1), num_batches=num_batches_per_tarfile, file_idx=match.group(3)
            )
            new_file_sink = wds.TarWriter(str(new_file))
            append_ds_to_rewrite = (
                wds.WebDataset(urls=[str(append_file)], nodesplitter=None)
                .decode(wds.handle_extension('.pyd', decode_pyd))
                .to_tuple('__key__', 'batch.pyd')
            )
            for key, batch in iter(append_ds_to_rewrite):
                new_file_sink.write({"__key__": key, "batch.pyd": batch})
                new_file_num_batches += 1
                total_batches_in_repacked_files += 1
            assert total_batches_in_repacked_files < initial_number_of_files_to_repack * num_batches_per_tarfile
            assert new_file_num_batches == int(match.group(2)), (
                f"Number of batches {new_file_num_batches} read from {append_file} is different from the number of "
                f"batches {match.group(2)} encoded in the name of the repacked tar file."
            )
            append_file.unlink()
        if files_to_repack_with_matches and pop_file_ds is None:
            pop_file, _ = files_to_repack_with_matches.pop()
            pop_file_ds = (
                wds.WebDataset(urls=[str(pop_file)], nodesplitter=None)
                .decode(wds.handle_extension('.pyd', decode_pyd))
                .to_tuple('__key__', 'batch.pyd')
            )
            pop_file_ds = iter(pop_file_ds)
        if pop_file_ds is not None and new_file_sink is not None:
            while new_file_num_batches < num_batches_per_tarfile:
                try:
                    key, batch = next(pop_file_ds)
                except StopIteration:
                    pop_file_ds = None
                    pop_file.unlink()
                    break
                new_file_sink.write({"__key__": key, "batch.pyd": batch})
                total_batches_in_repacked_files += 1
                assert total_batches_in_repacked_files < initial_number_of_files_to_repack * num_batches_per_tarfile
                new_file_num_batches += 1
            if new_file_num_batches >= num_batches_per_tarfile:
                assert new_file_num_batches == num_batches_per_tarfile
                new_file_sink.close()
                new_file_sink = None
                new_file_num_batches = 0
    if new_file_sink is not None:
        new_file_sink.close()
        new_file.unlink()
        logging.info(f"Discarded {new_file_num_batches} batches.")
    if pop_file_ds is not None:
        pop_file.unlink()
    logging.info(f"Repacked {total_batches_in_repacked_files} batches from short tar files")


def create_metadata_file(
    output_dir: Path, output_file_tmpl: str, metadata_file_name: Path, num_batches_per_tarfile: int
) -> None:
    """
    Rename tar files according to template ``output_file_tmpl`` and save metadata file.
    Args:
        output_dir: a path to directory which contains initial tar files and where renamed tar files are saved
        output_file_tmpl: a template of a new tar file name
        metadata_file_name: a path to a file into which metadata is going to be saved
        num_batches_per_tarfile: a required number of batches in tar files. Used for checking that present tar files
            have correct number of batches
    """
    metadata = {"num_batches": 0, "tar_files": []}
    for i, fn in enumerate([fn for fn in output_dir.iterdir() if TAR_FRAGMENT_PATTERN_FINISHED.match(fn.name)]):
        nb = int(TAR_FRAGMENT_PATTERN_FINISHED.match(fn.name).group(2))
        assert nb == num_batches_per_tarfile
        new_name = output_dir / output_file_tmpl.format(ctr=i, num_batches=nb)
        fn.rename(new_name)
        metadata['tar_files'].append(new_name.name)
        metadata["num_batches"] += nb
    metadata[METADATA_PUNCT_LABEL_VOCAB_KEY] = DEFAULT_PUNCT_LABEL_VOCAB_FILE_NAME
    metadata[METADATA_CAPIT_LABEL_VOCAB_KEY] = DEFAULT_CAPIT_LABEL_VOCAB_FILE_NAME
    logging.info(f"{metadata['num_batches']} batches are in tarred dataset with metadata file {metadata_file_name}")
    with metadata_file_name.open('w') as f:
        json.dump(metadata, f, indent=2)

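# Illustrative example (not part of the original module): the metadata JSON written by `create_metadata_file` has
# the following shape (file names and counts below are hypothetical):
#
#     {
#       "num_batches": 100,
#       "tar_files": [
#         "punctuation_capitalization.tokens1024.max_seq_length128.bert-base-uncased.batches5.0.tar",
#         "..."
#       ],
#       "punct_label_vocab_file": "punct_label_vocab.csv",
#       "capit_label_vocab_file": "capit_label_vocab.csv"
#     }
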

def check_tar_file_prefix(
    tar_file_prefix: str, error_class_or_function: Union[Type[Exception], Callable[[str], Any]], var_name: str
) -> None:
    not_allowed_characters_in_prefix = NOT_ALLOWED_CHARACTERS_IN_FILE_NAME.findall(tar_file_prefix)
    if not_allowed_characters_in_prefix:
        not_allowed_characters_in_prefix = set(not_allowed_characters_in_prefix)
        msg = (
            f"Found {len(not_allowed_characters_in_prefix)} disallowed characters in `{var_name}`. Only 'A-Z', "
            f"'a-z', '0-9', '_', '-', '.' characters are allowed. Examples of disallowed characters: "
            f"{list(not_allowed_characters_in_prefix)[:10]}. `{var_name}`[:30]={repr(tar_file_prefix)[:30]}."
        )
        process_error(msg, error_class_or_function)


def create_tarred_dataset(
    text_file: Union[os.PathLike, str],
    labels_file: Union[os.PathLike, str],
    output_dir: Union[os.PathLike, str],
    max_seq_length: int,
    tokens_in_batch: int,
    lines_per_dataset_fragment: int,
    num_batches_per_tarfile: int,
    tokenizer_name: str,
    tokenizer_model: Optional[Union[os.PathLike, str]] = None,
    vocab_file: Optional[Union[os.PathLike, str]] = None,
    merges_file: Optional[Union[os.PathLike, str]] = None,
    special_tokens: Optional[Dict[str, str]] = None,
    use_fast_tokenizer: Optional[bool] = False,
    pad_label: str = 'O',
    punct_label_ids: Optional[Dict[str, int]] = None,
    capit_label_ids: Optional[Dict[str, int]] = None,
    punct_label_vocab_file: Optional[Union[os.PathLike, str]] = None,
    capit_label_vocab_file: Optional[Union[os.PathLike, str]] = None,
    tar_file_prefix: Optional[str] = 'punctuation_capitalization',
    n_jobs: Optional[int] = None,
) -> None:
    """
    Creates a tarred dataset from ``text_file`` and ``labels_file``. A tarred dataset allows training on large
    amounts of data without storing it all in memory simultaneously. You may use this function directly or try the
    script `examples/nlp/token_classification/data/create_punctuation_capitalization_tarred_dataset.py
    <https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/token_classification/data/create_punctuation_capitalization_tarred_dataset.py>`_.

    A tarred dataset is a directory which contains a metadata file, tar files with batches, and
    ``punct_label_vocab.csv`` and ``capit_label_vocab.csv`` files.

    The metadata file is a JSON file with 4 items: ``'num_batches'``, ``'tar_files'``, ``'punct_label_vocab_file'``,
    ``'capit_label_vocab_file'``. The item ``'num_batches'`` (``int``) is the total number of batches in the tarred
    dataset. ``'tar_files'`` is a list of paths to tar files relative to the directory containing the metadata file.
    The items ``'punct_label_vocab_file'`` and ``'capit_label_vocab_file'`` are correspondingly paths to punctuation
    and capitalization label vocabulary files. These paths are relative to the directory containing the metadata
    file.

    Every tar file contains objects written using ``webdataset.TarWriter``. Each object is a dictionary with two
    items: ``'__key__'`` and ``'batch.pyd'``. ``'__key__'`` is a name of a batch and ``'batch.pyd'`` is a pickled
    dictionary which contains ``'input_ids'``, ``'subtokens_mask'``, ``'punct_labels'``, ``'capit_labels'``.
    ``'input_ids'`` is an array containing ids of source tokens, ``'subtokens_mask'`` is a boolean array showing
    first tokens in words, ``'punct_labels'`` and ``'capit_labels'`` are arrays with ids of labels.

    The metadata file should be passed to the constructor of :class:`BertPunctuationCapitalizationTarredDataset`,
    and the instance of the class will handle iteration and constructing masks and token types for the BERT model.

    Args:
        text_file (:obj:`Union[os.PathLike, str]`): a path to a file with dataset source. Dataset source is
            lowercased text without punctuation. The number of lines in ``text_file`` has to be equal to the number
            of lines in ``labels_file``.
        labels_file (:obj:`Union[os.PathLike, str]`): a path to a file with labels. Labels are given in the format
            described in :ref:`NeMo Data Format<nemo-data-format-label>`.
        output_dir (:obj:`Union[os.PathLike, str]`): a path to a directory where the metadata file, tar files, and
            ``'punct_label_vocab.csv'`` and ``'capit_label_vocab.csv'`` files are saved.
        max_seq_length (:obj:`int`): maximum number of subtokens in an input sequence. A source sequence which
            contains too many subtokens is clipped to ``max_seq_length - 2`` subtokens and then the [CLS] token is
            prepended to the clipped sequence and the [SEP] token is appended to the clipped sequence. The clipping
            is performed via removal of subtokens in the end of a source sequence.
        tokens_in_batch (:obj:`int`): maximum number of tokens in a batch including [CLS], [SEP], [UNK], and [PAD]
            tokens. Before packing into batches, source sequences are sorted by number of tokens in order to reduce
            the number of pad tokens. So the number of samples in a batch may vary.
        lines_per_dataset_fragment (:obj:`int`): a number of lines processed by one worker during creation of the
            tarred dataset. A worker tokenizes ``lines_per_dataset_fragment`` lines and keeps tokenized text and
            labels in RAM before packing them into batches. Reducing ``lines_per_dataset_fragment`` reduces the
            amount of memory used by this function.
        num_batches_per_tarfile (:obj:`int`): a number of batches saved in a tar file. If you increase
            ``num_batches_per_tarfile``, then there will be fewer tar files in the dataset. There cannot be fewer
            than ``num_batches_per_tarfile`` batches in a tar file, and all excess batches are removed. The maximum
            number of discarded batches is ``num_batches_per_tarfile - 1``.
        tokenizer_name (:obj:`str`): a name of the tokenizer used for tokenization of source sequences. Possible
            options are ``'sentencepiece'``, ``'word'``, ``'char'``, HuggingFace tokenizers. For more options see
            function ``nemo.collections.nlp.modules.common.get_tokenizer``. The tokenizer must have properties
            ``cls_id``, ``pad_id``, ``sep_id``, ``unk_id``.
        tokenizer_model (:obj:`Union[os.PathLike, str]`, `optional`): a path to a tokenizer model required for the
            ``'sentencepiece'`` tokenizer.
        vocab_file (:obj:`Union[os.PathLike, str]`, `optional`): a path to a vocabulary file which can be used in
            ``'word'``, ``'char'``, and HuggingFace tokenizers.
        merges_file (:obj:`Union[os.PathLike, str]`, `optional`): a path to a merges file which can be used in
            HuggingFace tokenizers.
        special_tokens (:obj:`Dict[str, str]`, `optional`): a dictionary with special tokens passed to constructors
            of ``'char'``, ``'word'``, ``'sentencepiece'``, and various HuggingFace tokenizers.
        use_fast_tokenizer (:obj:`bool`, `optional`, defaults to :obj:`False`): whether to use a fast HuggingFace
            tokenizer.
        pad_label (:obj:`str`, `optional`, defaults to :obj:`'O'`): a pad label both for punctuation and
            capitalization. This label is also a neutral label (used for marking words which do not need punctuation
            and capitalization).
        punct_label_ids (:obj:`Dict[str, int]`, `optional`): a dictionary whose keys are punctuation labels and
            whose values are label ids. The pad label ``pad_label`` has to have id ``0``. You can provide at most
            one of parameters ``punct_label_ids`` and ``punct_label_vocab_file``. If neither of parameters
            ``punct_label_ids`` and ``punct_label_vocab_file`` is provided, then punctuation label ids will be
            inferred from the ``labels_file`` file.
        capit_label_ids (:obj:`Dict[str, int]`, `optional`): same as ``punct_label_ids`` for capitalization labels.
        punct_label_vocab_file (:obj:`Union[os.PathLike, str]`, `optional`): a path to a file with punctuation
            labels. These labels include the pad label. The pad label has to be the first label in the file. Each
            label is written on a separate line. Alternatively you can use the ``punct_label_ids`` parameter. If
            neither of parameters ``punct_label_ids`` and ``punct_label_vocab_file`` is provided, then punctuation
            label ids will be inferred from the ``labels_file`` file.
        capit_label_vocab_file (:obj:`Union[os.PathLike, str]`, `optional`): same as ``punct_label_vocab_file`` for
            capitalization labels.
        tar_file_prefix (:obj:`str`, `optional`, defaults to :obj:`'punctuation_capitalization'`): a string from
            which tar file names start. The string can contain only characters ``A-Z``, ``a-z``, ``0-9``, ``_``,
            ``-``, ``.``.
        n_jobs (:obj:`int`, `optional`): a number of workers for creating the tarred dataset. If ``None``, then
            ``n_jobs`` is equal to the number of CPUs.
    """
    check_tar_file_prefix(tar_file_prefix, ValueError, 'tar_file_prefix')
    if n_jobs is None:
        n_jobs = mp.cpu_count()
    text_file, labels_file = Path(text_file).expanduser(), Path(labels_file).expanduser()
    output_dir = Path(output_dir).expanduser()
    ds_params_str = DATASET_PARAMETERS_TMPL.format(
        prefix=tar_file_prefix,
        tokens_in_batch=tokens_in_batch,
        max_seq_length=max_seq_length,
        tokenizer=REPLACE_NOT_ALLOWED_CHARACTERS_IN_FILE_NAME.sub('-', tokenizer_name),
    )
    output_file_tmpl = ds_params_str + TAR_FINAL_TMPL
    metadata_file_name = output_dir / ('metadata.' + ds_params_str + '.json')
    remove_unexpected_files_and_dirs(output_dir, output_file_tmpl, metadata_file_name)
    num_lines, text_start_bytes, label_start_bytes = get_fragment_start_bytes(
        text_file, labels_file, lines_per_dataset_fragment
    )
    if text_start_bytes:
        output_dir.mkdir(parents=True, exist_ok=True)
    else:
        raise ValueError(f"Both {labels_file} and {text_file} are empty. Tarred dataset cannot be created.")
    punct_label_ids, capit_label_ids = get_label_dictionaries(
        labels_file,
        label_start_bytes,
        num_lines,
        lines_per_dataset_fragment,
        pad_label,
        punct_label_ids,
        capit_label_ids,
        punct_label_vocab_file,
        capit_label_vocab_file,
        n_jobs,
    )

    with Progress(
        num_lines, ["Tokenization", "Batch mark up", "Batch building", "Writing tarred dataset"], "query"
    ) as progress_queues:
        Parallel(n_jobs=min(n_jobs, len(text_start_bytes)))(
            delayed(process_fragment)(
                text_file,
                labels_file,
                output_dir,
                text_start_pos,
                label_start_pos,
                lines_per_dataset_fragment,
                max_seq_length,
                tokens_in_batch,
                num_batches_per_tarfile,
                tokenizer_name,
                None if tokenizer_model is None else Path(tokenizer_model).expanduser(),
                None if vocab_file is None else Path(vocab_file).expanduser(),
                None if merges_file is None else Path(merges_file).expanduser(),
                special_tokens,
                use_fast_tokenizer,
                pad_label,
                punct_label_ids,
                capit_label_ids,
                fragment_idx,
                *progress_queues,
            )
            for fragment_idx, (text_start_pos, label_start_pos) in enumerate(zip(text_start_bytes, label_start_bytes))
        )
    repack_tar_files_with_not_enough_batches(output_dir, num_batches_per_tarfile)
    create_metadata_file(output_dir, output_file_tmpl, metadata_file_name, num_batches_per_tarfile)
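

# Illustrative sketch (not part of the original module): a minimal call of `create_tarred_dataset` on hypothetical
# 'text.txt' and 'labels.txt' files; every value below is an assumption chosen for the example.
def _example_create_tarred_dataset() -> None:
    create_tarred_dataset(
        text_file='text.txt',
        labels_file='labels.txt',
        output_dir='tarred_dataset',
        max_seq_length=128,
        tokens_in_batch=1024,
        lines_per_dataset_fragment=10_000,
        num_batches_per_tarfile=5,
        tokenizer_name='bert-base-uncased',
    )
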
class BertPunctuationCapitalizationTarredDataset(IterableDataset):
    """
    A punctuation and capitalization dataset which does not require loading all data into memory simultaneously. A
    tarred dataset is created from text and label files using the script
    `examples/nlp/token_classification/data/create_punctuation_capitalization_tarred_dataset.py
    <https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/token_classification/data/create_punctuation_capitalization_tarred_dataset.py>`_
    or the function
    :func:`~nemo.collections.nlp.data.token_classification.punctuation_capitalization_tarred_dataset.create_tarred_dataset`.

    Args:
        metadata_file (:obj:`Union[os.PathLike, str]`): a path to the tarred dataset metadata file. The metadata
            file and the files referenced in the metadata file are created by
            `examples/nlp/token_classification/data/create_punctuation_capitalization_tarred_dataset.py
            <https://github.com/NVIDIA/NeMo/blob/main/examples/nlp/token_classification/data/create_punctuation_capitalization_tarred_dataset.py>`_.
            The metadata file is a JSON file which contains ``'num_batches'``, ``'tar_files'``,
            ``'punct_label_vocab_file'``, ``'capit_label_vocab_file'`` items. The first item is the total number of
            batches in a dataset, the second is a list of paths to tar files relative to the directory containing
            ``metadata_file``. Items ``'punct_label_vocab_file'`` and ``'capit_label_vocab_file'`` are paths to
            ``.csv`` files which contain unique punctuation and capitalization label vocabularies. Vocabulary file
            paths are relative to the directory containing the ``metadata_file``. Each line in
            ``'punct_label_vocab_file'`` and ``'capit_label_vocab_file'`` contains 1 label. The first lines in
            ``'punct_label_vocab_file'`` and ``'capit_label_vocab_file'`` files are neutral labels which also serve
            as pad labels. Neutral labels for punctuation and capitalization must be equal to the ``pad_label``
            parameter.
        tokenizer (:obj:`TokenizerSpec`): a tokenizer instance used for tokenization of dataset source. A tokenizer
            instance is used for getting ids of [CLS], [PAD], and [SEP] tokens which are used for masks creation.
        pad_label (:obj:`str`): a label that is used for padding and for absence of punctuation or capitalization.
            Used for checking items ``'punct_label_vocab'`` and ``'capit_label_vocab'`` of dictionary in
            ``metadata_file``.
        label_info_save_dir (:obj:`Union[os.PathLike, str]`, `optional`): a path to a directory where label
            vocabularies are copied when method :meth:`save_labels_and_get_file_paths` is called. This parameter is
            useful if the tarred dataset directory is read-only.
        ignore_extra_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): whether to use only the first token
            in a word for loss computation and training. If set to ``True``, then loss will be computed only for the
            first tokens of words.
        ignore_start_end (:obj:`bool`, `optional`, defaults to :obj:`True`): whether to compute loss for [CLS] and
            [SEP] tokens. If set to ``True``, then loss will not be computed for [CLS] and [SEP] tokens.
        world_size (:obj:`int`, `optional`, defaults to :obj:`1`): a number of processes used for model training. It
            is used together with the ``global_rank`` parameter to decide which tar files will be used in the
            current process.
        global_rank (:obj:`int`, `optional`, defaults to :obj:`0`): a number of the current process in the pool of
            workers used for model training. It is used together with the ``world_size`` parameter to decide which
            tar files will be used in the current process.
        shuffle_n (:obj:`int`, `optional`, defaults to :obj:`1`): a number of shuffled batches in a buffer.
            ``shuffle_n`` batches are loaded into memory, shuffled, and then yielded by a dataset instance.
        shard_strategy (:obj:`str`, `optional`, defaults to :obj:`'scatter'`): tarred dataset shard distribution
            strategy chosen as a str value during ddp.

            - ``'scatter'``: The default shard strategy applied by WebDataset, where each node gets a unique set of
              shards, which are permanently pre-allocated and never changed at runtime.
            - ``'replicate'``: Optional shard strategy, where each node gets all of the set of shards available in
              the tarred dataset, which are permanently pre-allocated and never changed at runtime. The benefit of
              replication is that it allows each node to sample data points from the entire dataset independently of
              other nodes, and reduces dependence on the value of ``shuffle_n``.

              .. warning::
                  Replicated strategy allows every node to sample the entire set of available tarfiles, and
                  therefore more than one node may sample the same tarfile, and even sample the same data points!
                  As such, there is no assured guarantee that all samples in the dataset will be sampled at least
                  once during 1 epoch. Scattered strategy, on the other hand, on specific occasions (when the
                  number of shards is not divisible with ``world_size``), will not sample the entire dataset. For
                  these reasons it is not advisable to use tarred datasets as validation or test datasets.
    """

    @property
    def output_types(self) -> Optional[Dict[str, NeuralType]]:
        """Returns neural types of batches yielded by this dataset."""
        return {
            'input_ids': NeuralType(('B', 'T'), ChannelType()),
            'segment_ids': NeuralType(('B', 'T'), ChannelType()),
            'input_mask': NeuralType(('B', 'T'), MaskType()),
            'subtokens_mask': NeuralType(('B', 'T'), MaskType()),
            'loss_mask': NeuralType(('B', 'T'), MaskType()),
            'punct_labels': NeuralType(('B', 'T'), LabelsType()),
            'capit_labels': NeuralType(('B', 'T'), LabelsType()),
        }

    def __init__(
        self,
        metadata_file: Union[os.PathLike, str],
        tokenizer: TokenizerSpec,
        pad_label: str,
        label_info_save_dir: Optional[Union[os.PathLike, str]] = None,
        ignore_extra_tokens: bool = False,
        ignore_start_end: bool = True,
        world_size: int = 1,
        global_rank: int = 0,
        shuffle_n: int = 1,
        shard_strategy: str = "scatter",
    ) -> None:
        super().__init__()
        valid_shard_strategies = ['scatter', 'replicate']
        if shard_strategy not in valid_shard_strategies:
            raise ValueError(
                f"Invalid shard strategy of type {type(shard_strategy)} "
                f"{repr(shard_strategy) if len(repr(shard_strategy)) < 100 else repr(shard_strategy)[:100] + '...'}! "
                f"Allowed values are: {valid_shard_strategies}."
            )
        self.tokenizer = tokenizer
        self.metadata_file = Path(metadata_file).expanduser()
        if label_info_save_dir is None:
            self.for_nemo_ckpt = self.metadata_file.parent / LABEL_ID_DIR_FOR_NEMO_CHECKPOINT
        else:
            self.for_nemo_ckpt = Path(label_info_save_dir).expanduser() / LABEL_ID_DIR_FOR_NEMO_CHECKPOINT
        with open(self.metadata_file) as f:
            self.metadata = json.load(f)
        self.ignore_extra_tokens = ignore_extra_tokens
        self.ignore_start_end = ignore_start_end
        self.tar_files = []
        for file_path in self.metadata['tar_files']:
            file_path = Path(file_path).expanduser()
            if file_path.is_absolute():
                self.tar_files.append(str(file_path))
            else:
                self.tar_files.append(str(self.metadata_file.parent / file_path))
        self.punct_label_vocab_file = self.metadata_file.parent / self.metadata[METADATA_PUNCT_LABEL_VOCAB_KEY]
        self.capit_label_vocab_file = self.metadata_file.parent / self.metadata[METADATA_CAPIT_LABEL_VOCAB_KEY]
        self.punct_label_ids = load_label_ids(self.punct_label_vocab_file)
        self.capit_label_ids = load_label_ids(self.capit_label_vocab_file)
        self.pad_label = pad_label
        self._check_pad_label()
        if shard_strategy == 'scatter':
            logging.info("Tarred dataset shards will be scattered evenly across all nodes.")
            if len(self.tar_files) % world_size != 0:
                logging.warning(
                    f"Number of shards in tarred dataset ({len(self.tar_files)}) is not divisible "
                    f"by number of distributed workers ({world_size}). "
                    f"Some shards will not be used ({len(self.tar_files) % world_size})."
                )
            begin_idx = (len(self.tar_files) // world_size) * global_rank
            end_idx = begin_idx + (len(self.tar_files) // world_size)
            logging.info(
                "Partitioning tarred dataset: process (%d) taking shards [%d, %d)", global_rank, begin_idx, end_idx
            )
            batches_per_tar = self.metadata['num_batches'] // len(self.tar_files)
            self.tar_files = self.tar_files[begin_idx:end_idx]
            self.length = batches_per_tar * len(self.tar_files) * world_size
        elif shard_strategy == 'replicate':
            logging.info("All tarred dataset shards will be replicated across all nodes.")
            self.length = self.metadata['num_batches']
        else:
            raise ValueError(f"Invalid shard strategy! Allowed values are: {valid_shard_strategies}")
        self._dataset = wds.WebDataset(urls=self.tar_files, nodesplitter=None).decode(
            wds.handle_extension('.pyd', decode_pyd)
        )
        if shuffle_n > 0:
            self._dataset.shuffle(shuffle_n)
        else:
            logging.info("WebDataset will not shuffle files within the tar files.")
        self._dataset = self._dataset.to_tuple('__key__', 'batch.pyd').map(f=self._build_sample)

    def _check_pad_label(self) -> None:
        """
        Checks the condition that ``pad_label`` passed to this class constructor has id ``0`` in
        ``self.punct_label_ids`` and ``self.capit_label_ids`` which are loaded from the tarred dataset.
        """
        for label_ids, labels_file, task in [
            (self.punct_label_ids, self.metadata[METADATA_PUNCT_LABEL_VOCAB_KEY], "punctuation"),
            (self.capit_label_ids, self.metadata[METADATA_CAPIT_LABEL_VOCAB_KEY], "capitalization"),
        ]:
            if label_ids[self.pad_label] != 0:
                raise ValueError(
                    f"Pad label '{self.pad_label}' has non zero id {label_ids[self.pad_label]} in {task} "
                    f"ids dictionary loaded from {labels_file}."
                )

    def check_for_label_consistency_with_model_config(
        self,
        punct_label_ids: Optional[Dict[str, int]],
        capit_label_ids: Optional[Dict[str, int]],
        class_labels: DictConfig,
        common_dataset_parameters_config: DictConfig,
    ) -> None:
        """
        Checks that label ids loaded from tarred dataset are identical to those provided in
        ``model.common_dataset_parameters`` :ref:`config<common-dataset-parameters-config-label>` item. In addition,
        this method checks that label ids set in attributes ``punct_label_ids`` and ``capit_label_ids`` of an
        instance of
        :class:`~nemo.collections.nlp.models.token_classification.punctuation_capitalization_model.PunctuationCapitalizationModel`
        are identical to label ids loaded from tarred dataset.

        Args:
            punct_label_ids: a content of ``punct_label_ids`` attribute of an instance of
                :class:`~nemo.collections.nlp.models.token_classification.punctuation_capitalization_model.PunctuationCapitalizationModel`
                in which this tarred dataset is used.
            capit_label_ids: a content of ``capit_label_ids`` attribute of an instance of
                :class:`~nemo.collections.nlp.models.token_classification.punctuation_capitalization_model.PunctuationCapitalizationModel`
                in which this tarred dataset is used.
            class_labels: a config item ``model.class_labels``. See more in description of
                :ref:`class labels config<class-labels-config-label>`.
            common_dataset_parameters_config: a config item ``model.common_dataset_parameters``. See more in
                description of :ref:`common dataset parameters config<common-dataset-parameters-config-label>`.
        """
        tarred_dataset_label_desc_tmpl = (
            f'{{label_type}} labels loaded from tarred dataset with metadata file {self.metadata_file}'
        )
        if punct_label_ids is not None:
            if punct_label_ids != self.punct_label_ids:
                raise_not_equal_labels_error(
                    first_labels=self.punct_label_ids,
                    second_labels=punct_label_ids,
                    first_labels_desc=tarred_dataset_label_desc_tmpl.format(label_type='Punctuation'),
                    second_labels_desc="Punctuation labels stored in an attribute "
                    "`PunctuationCapitalizationModel.punct_label_ids`",
                )
        if capit_label_ids is not None:
            if capit_label_ids != self.capit_label_ids:
                raise_not_equal_labels_error(
                    first_labels=self.capit_label_ids,
                    second_labels=capit_label_ids,
                    first_labels_desc=tarred_dataset_label_desc_tmpl.format(label_type='Capitalization'),
                    second_labels_desc="Capitalization labels stored in an attribute "
                    "`PunctuationCapitalizationModel.capit_label_ids`",
                )
        if common_dataset_parameters_config.punct_label_ids is not None:
            cfg_punct_label_ids = dict(common_dataset_parameters_config.punct_label_ids)
            if cfg_punct_label_ids != self.punct_label_ids:
                raise_not_equal_labels_error(
                    first_labels=self.punct_label_ids,
                    second_labels=cfg_punct_label_ids,
                    first_labels_desc=tarred_dataset_label_desc_tmpl.format(label_type='Punctuation'),
                    second_labels_desc='Punctuation labels stored in a config field '
                    '`model.common_dataset_parameters.punct_label_ids`',
                )
        if common_dataset_parameters_config.capit_label_ids is not None:
            cfg_capit_label_ids = dict(common_dataset_parameters_config.capit_label_ids)
            if cfg_capit_label_ids != self.capit_label_ids:
                raise_not_equal_labels_error(
                    first_labels=self.capit_label_ids,
                    second_labels=cfg_capit_label_ids,
                    first_labels_desc=tarred_dataset_label_desc_tmpl.format(label_type='Capitalization'),
                    second_labels_desc='Capitalization labels stored in a config field '
                    '`model.common_dataset_parameters.capit_label_ids`',
                )
        if common_dataset_parameters_config.label_vocab_dir is not None:
            label_vocab_dir = Path(common_dataset_parameters_config.label_vocab_dir).expanduser()
            punct_label_vocab_file = label_vocab_dir / class_labels.punct_labels_file
            file_punct_vocab = load_label_ids(punct_label_vocab_file)
            if file_punct_vocab != self.punct_label_ids:
                raise_not_equal_labels_error(
                    first_labels=self.punct_label_ids,
                    second_labels=file_punct_vocab,
                    first_labels_desc=tarred_dataset_label_desc_tmpl.format(label_type='Punctuation'),
                    second_labels_desc=f'labels stored in file {punct_label_vocab_file} passed in '
                    f'`model.common_dataset_parameters.punct_label_vocab_file`',
                )
            capit_label_vocab_file = label_vocab_dir / class_labels.capit_labels_file
            file_capit_vocab = load_label_ids(capit_label_vocab_file)
            if file_capit_vocab != self.capit_label_ids:
                raise_not_equal_labels_error(
                    first_labels=self.capit_label_ids,
                    second_labels=file_capit_vocab,
                    first_labels_desc=tarred_dataset_label_desc_tmpl.format(label_type='Capitalization'),
                    second_labels_desc=f'labels stored in file {capit_label_vocab_file} passed in '
                    f'`model.common_dataset_parameters.capit_label_vocab_file`',
                )

    def save_labels_and_get_file_paths(
        self, punct_labels_file_name: str, capit_labels_file_name: str
    ) -> Tuple[Path, Path]:
        """
        Copies label vocabulary files for punctuation and capitalization into the directory passed in the
        constructor parameter ``label_info_save_dir``. The names of the new files are ``punct_labels_file_name`` and
        ``capit_labels_file_name``.

        The signature of this method and the signature of the method
        :meth:`~nemo.collections.nlp.data.token_classification.BertPunctuationCapitalizationDataset.save_labels_and_get_file_paths`
        must be identical.

        Args:
            punct_labels_file_name (:obj:`str`): a name of punctuation labels file
            capit_labels_file_name (:obj:`str`): a name of capitalization labels file

        Returns:
            :obj:`Tuple[Path, Path]`: a tuple of 2 elements

                - :obj:`pathlib.Path`: a path to the new punctuation label ids file
                - :obj:`pathlib.Path`: a path to the new capitalization label ids file
        """
        self.for_nemo_ckpt.mkdir(parents=True, exist_ok=True)
        punct_label_ids_file = self.for_nemo_ckpt / punct_labels_file_name
        capit_label_ids_file = self.for_nemo_ckpt / capit_labels_file_name
        shutil.copy(str(self.punct_label_vocab_file), str(punct_label_ids_file))
        shutil.copy(str(self.capit_label_vocab_file), str(capit_label_ids_file))
        return punct_label_ids_file, capit_label_ids_file

    def _build_sample(self, batch: Tuple[str, Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
        """
        Takes a batch loaded from the tarred dataset and transforms it for passing to the model. Adds
        ``'segment_ids'``, ``'input_mask'``, ``'loss_mask'`` items to the batch.

        Args:
            batch: a tuple of 2 elements: batch name and a dictionary with ``'input_ids'``, ``'subtokens_mask'``,
                ``'punct_labels'``, ``'capit_labels'``. The batch name is not needed for training and inference and
                is discarded.

        Returns:
            a batch in the form of a dictionary with items:

              - ``'input_ids'``: a ``np.int32`` numpy array of shape ``[Batch, Time]``;
              - ``'subtokens_mask'``: a boolean numpy array of shape ``[Batch, Time]``;
              - ``'punct_labels'``: a ``np.int32`` numpy array of shape ``[Batch, Time]``;
              - ``'capit_labels'``: a ``np.int32`` numpy array of shape ``[Batch, Time]``;
              - ``'segment_ids'``: a ``np.int8`` numpy array of shape ``[Batch, Time]``;
              - ``'input_mask'``: a boolean numpy array of shape ``[Batch, Time]``;
              - ``'loss_mask'``: a boolean numpy array of shape ``[Batch, Time]``.
        """
        _, batch = batch
        batch_segment_ids, batch_input_mask, batch_loss_mask = create_masks_and_segment_ids(
            batch['input_ids'],
            batch['subtokens_mask'],
            self.tokenizer.pad_id,
            self.tokenizer.cls_id,
            self.tokenizer.sep_id,
            self.ignore_start_end,
            self.ignore_extra_tokens,
        )
        batch['segment_ids'] = batch_segment_ids
        batch['input_mask'] = batch_input_mask
        batch['loss_mask'] = batch_loss_mask
        return batch

    def __iter__(self) -> Iterator[Dict[str, np.ndarray]]:
        """
        Constructs an iterator of batches. The values of one batch dictionary are numpy arrays of identical shape
        ``[Batch, Time]``.

        Returns:
            :obj:`Iterator[Dict[str, np.ndarray]]`: an iterator of batches with items:

              - ``'input_ids'``: ``np.int32`` array containing encoded tokens,
              - ``'subtokens_mask'``: ``bool`` array whose elements are ``True`` if they correspond to the first
                token in a word,
              - ``'punct_labels'``: ``np.int32`` array with encoded punctuation labels,
              - ``'capit_labels'``: ``np.int32`` array with encoded capitalization labels,
              - ``'segment_ids'``: ``np.int8`` array filled with zeros (BERT token types in HuggingFace
                terminology),
              - ``'input_mask'``: ``bool`` array whose elements are ``True`` if the corresponding token is not a
                padding token,
              - ``'loss_mask'``: ``bool`` array whose elements are ``True`` if loss is computed for the
                corresponding token. See more in description of constructor parameters ``ignore_start_end``,
                ``ignore_extra_tokens``.
        """
        return self._dataset.__iter__()

    def __len__(self) -> int:
        return self.length

    @staticmethod
    def collate_fn(batches: List[Dict[str, np.ndarray]]) -> Dict[str, torch.Tensor]:
        """
        Returns the zeroth batch of the ``batches`` list passed for collating and casts ``'segment_ids'``,
        ``'punct_labels'``, ``'capit_labels'`` to types supported by
        :class:`~nemo.collections.nlp.models.token_classification.punctuation_capitalization_model.PunctuationCapitalizationModel`.
        All output tensors have shape ``[Batch, Time]``.

        .. warning::
            The ``batch size`` parameter of a PyTorch data loader and sampler has to be ``1``.

        Args:
            batches (:obj:`List[Dict[str, np.ndarray]]`): a list of batches passed for collating

        Returns:
            :obj:`Dict[str, torch.Tensor]`: a batch dictionary with the following items (for a detailed description
            of batch items see method :meth:`__iter__`):

              - ``'input_ids'`` (:obj:`torch.Tensor`): :obj:`torch.int32` tensor,
              - ``'subtokens_mask'`` (:obj:`torch.Tensor`): :obj:`torch.bool` tensor,
              - ``'punct_labels'`` (:obj:`torch.Tensor`): :obj:`torch.int64` tensor,
              - ``'capit_labels'`` (:obj:`torch.Tensor`): :obj:`torch.int64` tensor,
              - ``'segment_ids'`` (:obj:`torch.Tensor`): :obj:`torch.int32` tensor,
              - ``'input_mask'`` (:obj:`torch.Tensor`): :obj:`torch.bool` tensor,
              - ``'loss_mask'`` (:obj:`torch.Tensor`): :obj:`torch.bool` tensor.
        """
        batch = {k: torch.as_tensor(v) for k, v in batches[0].items()}
        batch['segment_ids'] = batch['segment_ids'].int()
        batch['punct_labels'] = batch['punct_labels'].long()
        batch['capit_labels'] = batch['capit_labels'].long()
        return batch
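

# Illustrative sketch (not part of the original module): iterating over a tarred dataset with a PyTorch DataLoader.
# The metadata path and tokenizer name are hypothetical; `batch_size` must be 1 because every element yielded by the
# dataset is already a full batch.
def _example_iterate_tarred_dataset() -> None:
    tokenizer = get_tokenizer('bert-base-uncased')
    dataset = BertPunctuationCapitalizationTarredDataset(
        metadata_file='tarred_dataset/metadata.punctuation_capitalization.tokens1024.max_seq_length128.bert-base-uncased.json',
        tokenizer=tokenizer,
        pad_label='O',
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=dataset.collate_fn)
    for batch in loader:
        print(batch['input_ids'].shape)  # [Batch, Time]
        break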