Source code for nemo_automodel.checkpoint._backports._fsspec_filesystem

# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# taken from https://github.com/pytorch/pytorch/blob/156b28e62a225214a5685088a848c5efe6a4b95e/torch/distributed/checkpoint/_fsspec_filesystem.py # pylint: disable=line-too-long
# pylint: disable=missing-class-docstring, missing-function-docstring

import io
import os
from collections.abc import Generator, Sequence
from contextlib import contextmanager
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Union

from fsspec.core import url_to_fs
from torch.distributed.checkpoint._extension import StreamTransformExtension

from nemo_automodel.checkpoint._backports.filesystem import (
    FileSystemBase,
    FileSystemReader,
    FileSystemWriter,
    SerializationFormat,
)


if TYPE_CHECKING:
    from fsspec import AbstractFileSystem


__all__ = [
    "FsspecWriter",
    "FsspecReader",
]


[docs] class FileSystem(FileSystemBase): def __init__(self) -> None: self.fs: Optional[AbstractFileSystem] = None
[docs] @contextmanager def create_stream(self, path: Union[str, os.PathLike], mode: str) -> Generator[io.IOBase, None, None]: assert self.fs is not None path = os.fspath(path) # fsspec does not support concurrent transactions, and not all # AbstractFileSystem have working rollback implementations, so # just manually delete the file if necessary on errors. with self.fs.open(path, mode) as stream: try: yield stream except: # noqa: B001,E722 if any(ch in mode for ch in "w+a"): # cleanup file if not read-only try: self.rm_file(path) except: # noqa: B001,E722 pass raise
[docs] def concat_path(self, path: Union[str, os.PathLike], suffix: str) -> Union[str, os.PathLike]: return os.path.join(path, suffix)
[docs] def init_path(self, path: Union[str, os.PathLike], **kwargs) -> Union[str, os.PathLike]: self.fs, _ = url_to_fs(path, **kwargs) return path
[docs] def rename(self, path: Union[str, os.PathLike], new_path: Union[str, os.PathLike]) -> None: self.fs.rename(path, new_path)
[docs] def mkdir(self, path: Union[str, os.PathLike]) -> None: self.fs.makedirs(path, exist_ok=True)
[docs] @classmethod def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: if isinstance(checkpoint_id, Path): return False try: url_to_fs(checkpoint_id) except ValueError: return False return True
[docs] def exists(self, path: Union[str, os.PathLike]) -> bool: return self.fs.exists(path)
[docs] def rm_file(self, path: Union[str, os.PathLike]) -> None: self.fs.rm(path)
[docs] def ls(self, path: Union[str, os.PathLike]) -> list[str]: # setting detail to False explictly to keep the list[str] return type, # instead of the list[Dict] return type when detail=True return self.fs.ls(path, detail=False)
# TODO: add the dcp.async_save mixin
[docs] class FsspecWriter(FileSystemWriter): """ Basic implementation of StorageWriter using FFspec. This implementation makes the following assumptions and simplifications: * The checkpoint path is an empty or non-existing directory. * File creation is atomic The checkpoint consist of one file per write request plus a `.metadata` file with the serialized metadata. """ def __init__( self, path: Union[str, os.PathLike], single_file_per_rank: bool = True, sync_files: bool = True, thread_count: int = 1, per_thread_copy_ahead: int = 10_000_000, overwrite: bool = True, _extensions: Optional[Sequence[StreamTransformExtension]] = None, serialization_format: SerializationFormat = SerializationFormat.TORCH_SAVE, **kwargs, ) -> None: """ Initialize the writer pointing to `path`. Args: path: directory where the checkpoint will be written to. single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Default to True. sync_files : force files to be synced to permanent storage. Default to True. thread_count: Number of IO threads to use to write. Default to 1. per_thread_copy_ahead: How many bytes to copy from the GPU ahead of saving then. Default 10Mb. overwrite: Whether to allow overwriting existing checkpoints. Defaults to True. _extensions: Extensions to apply to output streams (EXPERIMENTAL) N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure. """ super().__init__( path, single_file_per_rank, sync_files, thread_count, per_thread_copy_ahead, overwrite=overwrite, _extensions=_extensions, serialization_format=serialization_format, ) self.fs = FileSystem() self.path = self.fs.init_path(path, **kwargs)
[docs] @classmethod def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: return FileSystem.validate_checkpoint_id(checkpoint_id)
[docs] class FsspecReader(FileSystemReader): def __init__(self, path: Union[str, os.PathLike], **kwargs) -> None: super().__init__(path) self.fs = FileSystem() self.path = self.fs.init_path(path, **kwargs)
[docs] @classmethod def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool: return FileSystem.validate_checkpoint_id(checkpoint_id)