# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import fnmatch
import logging
import os
import tarfile
from typing import IO, Union
LOGGER = logging.getLogger("NeMo")

# zarr is an optional dependency. Any failure while importing it (not only
# ImportError) disables zarr-based checkpoint support: BaseStore degrades to a
# plain ``object`` so that ZarrPathStore can still be defined, and HAVE_ZARR
# records which case applies.
try:
    from zarr.storage import BaseStore
except Exception as import_error:
    LOGGER.warning(
        "Cannot import zarr, support for zarr-based checkpoints is not available. "
        f"{type(import_error).__name__}: {import_error}"
    )
    BaseStore = object
    HAVE_ZARR = False
else:
    HAVE_ZARR = True
class TarPath:
    """A class that represents a path inside a TAR archive and behaves like pathlib.Path.

    Expected use is to create a TarPath for the root of the archive first, and then derive
    paths to other files or directories inside the archive like so:

        with TarPath('/path/to/archive.tar') as archive:
            myfile = archive / 'filename.txt'
            if myfile.exists():
                with myfile.open('rb') as f:
                    data = f.read()
                ...

    Only read and enumeration operations are supported.
    """

    def __init__(self, tar: Union[str, os.PathLike, tarfile.TarFile, "TarPath"], *parts):
        """Creates a path inside a TAR archive.

        Args:
            tar: the archive root, given as a filesystem path to the archive (opened here
                and owned by this object), an already-open ``tarfile.TarFile`` (not owned),
                or another TarPath (the new path is resolved relative to it).
            *parts: optional path components inside the archive, joined with ``os.path.join``.

        Raises:
            ValueError: if ``tar`` has an unsupported type.
        """
        self._needs_to_close = False
        # The TarPath this one was derived from, if any. Holding a reference keeps the
        # owning root object (and therefore its open TarFile) alive while derived paths
        # exist; otherwise `TarPath("a.tar") / "x"` would close the archive as soon as
        # the temporary root object is garbage-collected.
        self._parent = None
        self._relpath = ""
        if isinstance(tar, TarPath):
            self._tar = tar._tar
            self._parent = tar
            self._relpath = os.path.join(tar._relpath, *parts)
        elif isinstance(tar, tarfile.TarFile):
            self._tar = tar
            if parts:
                self._relpath = os.path.join(*parts)
        elif isinstance(tar, (str, os.PathLike)):
            self._tar = tarfile.open(tar, "r")
            # Take ownership only after the archive was actually opened, so that
            # __del__ never touches a missing ``_tar`` when the open fails.
            self._needs_to_close = True
            if parts:
                self._relpath = os.path.join(*parts)
        else:
            raise ValueError(f"Unexpected argument type for TarPath: {type(tar).__name__}")

    def __del__(self):
        # getattr: __init__ may have raised before any attribute was assigned.
        if getattr(self, "_needs_to_close", False):
            self._tar.close()

    def __truediv__(self, key) -> "TarPath":
        """Returns the path ``key`` below this one, sharing (and keeping open) the same archive."""
        return TarPath(self, key)

    def __str__(self) -> str:
        # TarFile.name is None for archives opened from an anonymous file object.
        return os.path.join(self._tar.name or "", self._relpath)

    @property
    def tarobject(self):
        """Returns the wrapped tar object."""
        return self._tar

    @property
    def relpath(self):
        """Returns the relative path of the path."""
        return self._relpath

    @property
    def name(self):
        """Returns the name of the path."""
        return os.path.split(self._relpath)[1]

    @property
    def suffix(self):
        """Returns the suffix of the path (e.g. ".txt"), or "" if there is none."""
        name = self.name
        i = name.rfind(".")
        # A leading dot (hidden file) or a trailing dot does not start a suffix.
        if 0 < i < len(name) - 1:
            return name[i:]
        else:
            return ""

    def __enter__(self):
        self._tar.__enter__()
        return self

    def __exit__(self, *args):
        # Closes the underlying archive, which is shared with every derived TarPath.
        return self._tar.__exit__(*args)

    def _getmember(self):
        """Returns the TarInfo for this path, or None if the archive has no such member.

        Depending on how the archive was created, members may be stored either as
        ``relpath`` or as ``./relpath``, so both spellings are tried in that order.
        """
        for candidate in (self._relpath, os.path.join(".", self._relpath)):
            try:
                return self._tar.getmember(candidate)
            except KeyError:
                continue
        return None

    def exists(self):
        """Checks if the path exists."""
        return self._getmember() is not None

    def is_file(self):
        """Checks if the path exists and is a regular file."""
        member = self._getmember()
        return member is not None and member.isreg()

    def is_dir(self):
        """Checks if the path exists as an explicit directory entry in the archive."""
        member = self._getmember()
        return member is not None and member.isdir()

    def open(self, mode: str = "r") -> IO[bytes]:
        """Opens a file in the archive.

        Args:
            mode: "r" or "rb". Both return a binary file object; there is no text mode.

        Returns:
            A readable binary file object for the member.

        Raises:
            NotImplementedError: for any mode other than "r" or "rb".
            FileNotFoundError: if the path does not exist or is not a readable file.
        """
        if mode not in ("r", "rb"):
            raise NotImplementedError(f"TarPath only supports read modes, got {mode!r}")
        member = self._getmember()
        if member is None:
            raise FileNotFoundError(str(self))
        try:
            file = self._tar.extractfile(member)
        except KeyError:
            # extractfile raises KeyError for a link whose target is missing from the archive.
            file = None
        # extractfile returns None for members that are not files (e.g. directories).
        if file is None:
            raise FileNotFoundError(str(self))
        return file

    def _iter_relative_names(self):
        """Yields the names of all archive members below this path, relative to it."""
        prefix = self._relpath + "/" if self._relpath else ""
        for member in self._tar.getmembers():
            # Remove the "./" prefix, if any
            name = member.name[2:] if member.name.startswith("./") else member.name
            # If we're in a subdirectory, make sure the member is too, and strip that prefix
            if not name.startswith(prefix):
                continue
            yield name[len(prefix) :]

    def glob(self, pattern):
        """Returns an iterator over the members below this path whose relative name matches ``pattern``.

        Matching uses ``fnmatch``, where "*" also matches "/", so nested members can match too.
        """
        for name in self._iter_relative_names():
            if fnmatch.fnmatch(name, pattern):
                yield TarPath(self, name)

    def rglob(self, pattern):
        """Returns an iterator over the files in the directory, including subdirectories.

        A member is returned (with its full path) when any trailing part of its relative
        name, e.g. "b/c" or "c" for "a/b/c", matches ``pattern``.
        """
        for name in self._iter_relative_names():
            parts = name.split("/")
            tails = ("/".join(parts[i:]) for i in range(len(parts)))
            if any(fnmatch.fnmatch(tail, pattern) for tail in tails):
                yield TarPath(self, name)

    def iterdir(self):
        """Returns an iterator over the files in the directory (see ``glob`` for matching rules)."""
        return self.glob("*")
class ZarrPathStore(BaseStore):
    """An implementation of read-only Store for zarr library that works with pathlib.Path or TarPath objects.

    Keys are interpreted as paths relative to the wrapped directory, and the value stored
    under a key is the raw content of the corresponding file.
    """

    def __init__(self, tarpath: TarPath):
        """Wraps ``tarpath`` (a TarPath or pathlib.Path pointing at a zarr directory).

        Raises:
            ImportError: if zarr could not be imported when this module was loaded.
        """
        # An explicit check rather than ``assert``: assertions are stripped under ``python -O``.
        if not HAVE_ZARR:
            raise ImportError("Package zarr>=2.18.2,<3.0.0 is required to use ZarrPathStore")
        self._path = tarpath
        # zarr v2's BaseStore.is_writeable() reads ``_writeable`` (spelled with an "e"),
        # so that is the flag that actually marks the store read-only. ``_writable`` is
        # kept only for backward compatibility with code that may read it.
        self._writeable = False
        self._writable = False
        self._erasable = False

    def __getitem__(self, key):
        """Returns the bytes of the file stored under ``key``.

        Raises:
            KeyError: if ``key`` does not name a readable file. The Mapping protocol, and
                zarr, treat KeyError as "key absent".
        """
        try:
            with (self._path / key).open("rb") as file:
                return file.read()
        except (FileNotFoundError, IsADirectoryError) as err:
            raise KeyError(key) from err

    def __contains__(self, key):
        """Returns True if ``key`` names a file below the wrapped directory."""
        return (self._path / key).is_file()

    def __iter__(self):
        return self.keys()

    def __len__(self):
        return sum(1 for _ in self.keys())

    def __setitem__(self, key, value):
        raise NotImplementedError("ZarrPathStore is read-only")

    def __delitem__(self, key):
        raise NotImplementedError("ZarrPathStore is read-only")

    def keys(self):
        """Returns an iterator over the keys in the store."""
        # NOTE(review): this yields path objects (from ``iterdir()``), not ``str`` keys as
        # the MutableMapping protocol expects, so zarr helpers that list keys may not work
        # with it. Kept unchanged for backward compatibility. Confirm before relying on it.
        return self._path.iterdir()