NVIDIA Modulus Core (Latest Release)

Source code for modulus.datapipes.climate.climate

# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import json
from abc import ABC, abstractmethod
from datetime import datetime, timedelta
from itertools import chain

import h5py
import netCDF4 as nc
import numpy as np
import torch

try:
    import nvidia.dali as dali
    import nvidia.dali.plugin.pytorch as dali_pth
except ImportError:
    raise ImportError(
        "DALI dataset requires NVIDIA DALI package to be installed. "
        + "The package can be installed at:\n"
        + "https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html"
    )

from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Iterable, List, Mapping, Tuple, Union

from scipy.io import netcdf_file

from modulus.datapipes.climate.utils.invariant import latlon_grid
from modulus.datapipes.climate.utils.zenith_angle import cos_zenith_angle
from modulus.datapipes.datapipe import Datapipe
from modulus.datapipes.meta import DatapipeMetaData
from modulus.launch.logging import PythonLogger

Tensor = torch.Tensor


@dataclass
class MetaData(DatapipeMetaData):
    name: str = "Climate"
    # Optimization
    auto_device: bool = True
    cuda_graphs: bool = True
    # Parallel
    ddp_sharding: bool = True
class ClimateDataSourceSpec:
    """
    A data source specification for ClimateDatapipe.

    HDF5 files should contain the following variable with the corresponding name:
    `fields`: Tensor of shape (num_timesteps, num_channels, height, width),
    containing climate data. The order of the channels should match the order of
    the channels in the statistics files. The statistics files should be `.npy`
    files with the shape (1, num_channels, 1, 1). The names of the variables are
    found in the metadata file given by `metadata_path`.

    NetCDF4 files should contain a variable of shape (num_timesteps, height, width)
    for each variable they provide. Only the variables listed in `variables` will
    be loaded.

    Parameters
    ----------
    data_dir : str
        Directory where climate data is stored
    name : Union[str, None], optional
        The name that is used to label datapipe outputs from this source.
        If None, the datapipe uses the number of the source in sequential order.
    file_type : str
        Type of files to read, supported values are "hdf5" (default) and "netcdf4"
    stats_files : Union[Mapping[str, str], None], optional
        Numpy files containing data statistics for normalization. Supports either
        a channels format, in which case the dict should contain the keys "mean"
        and "std", or a named-variable format, in which case the dict should
        contain the key "norm". If None, no normalization will be used,
        by default None
    metadata_path : Union[str, None], optional for NetCDF4, required for HDF5
        Path to the metadata JSON file for the dataset (usually called data.json).
    channels : Union[List[int], None], optional
        Defines which climate variables to load, if None will use all in HDF5 file,
        by default None
    variables : Union[List[str], None], optional for HDF5 files, mandatory for NetCDF4 files
        List of named variables to load. Variables will be read in the order
        specified by this parameter. Must be used for NetCDF4 files. Supported for
        HDF5 files, in which case it will override `channels`.
    use_cos_zenith : bool, optional
        If True, the cosine zenith angles corresponding to the coordinates of this
        data source will be produced, by default False
    aux_variables : Union[Mapping[str, Callable], None], optional
        A dictionary mapping strings to callables that accept arguments
        (timestamps: numpy.ndarray, latlon: numpy.ndarray). These define any
        auxiliary variables returned from this source.
    num_steps : int, optional
        Number of timesteps to return, by default 1
    stride : int, optional
        Number of steps between input and output variables. For example, if the
        dataset contains data at every 6 hours, a stride of 1 gives a 6-hour
        delta-t and a stride of 2 gives a 12-hour delta-t, by default 1
    backend_kwargs : Union[dict, None], optional
        Keyword arguments passed to the file-reading backend, e.g.
        {"reader": "scipy"} to read NetCDF4 files with the SciPy reader,
        by default None
    """

    def __init__(
        self,
        data_dir: str,
        name: Union[str, None] = None,
        file_type: str = "hdf5",
        stats_files: Union[Mapping[str, str], None] = None,
        metadata_path: Union[str, None] = None,
        channels: Union[List[int], None] = None,
        variables: Union[List[str], None] = None,
        use_cos_zenith: bool = False,
        aux_variables: Union[Mapping[str, Callable], None] = None,
        num_steps: int = 1,
        stride: int = 1,
        backend_kwargs: Union[dict, None] = None,
    ):
        self.data_dir = Path(data_dir)
        self.name = name
        self.file_type = file_type
        self.stats_files = (
            {k: Path(fn) for (k, fn) in stats_files.items()}
            if stats_files is not None
            else None
        )
        self.metadata_path = Path(metadata_path) if metadata_path is not None else None
        self.channels = channels
        self.variables = variables
        self.use_cos_zenith = use_cos_zenith
        self.aux_variables = aux_variables if aux_variables is not None else {}
        self.num_steps = num_steps
        self.stride = stride
        self.backend_kwargs = {} if backend_kwargs is None else backend_kwargs
        self.logger = PythonLogger()

        if file_type == "netcdf4" and not variables:
            raise ValueError("Variables must be specified for a NetCDF4 source.")

        # check root directory exists
        if not self.data_dir.is_dir():
            raise IOError(f"Error, data directory {self.data_dir} does not exist")

        if self.stats_files is None:
            self.logger.warning(
                "Warning, no stats files specified, this will result in no normalisation"
            )
    def dimensions_compatible(self, other) -> bool:
        """
        Basic sanity check to test if two `ClimateDataSourceSpec` are compatible.
        """
        return (
            self.data_shape == other.data_shape
            and self.cropped_data_shape == other.cropped_data_shape
            and self.num_samples_per_year == other.num_samples_per_year
            and self.total_length == other.total_length
            and self.n_years == other.n_years
        )
    def parse_dataset_files(
        self,
        num_samples_per_year: Union[int, None] = None,
        patch_size: Union[Tuple[int, int], None] = None,
    ) -> None:
        """Parses the data directory for valid files and determines training samples

        Parameters
        ----------
        num_samples_per_year : int, optional
            Number of samples taken from each year. If None, all will be used,
            by default None
        patch_size : Union[Tuple[int, int], None], optional
            If specified, crops input and output variables so image dimensions are
            divisible by patch_size, by default None

        Raises
        ------
        ValueError
            If the channels specified or the number of samples per year is not valid
        """
        # get all input data files
        suffix = {"hdf5": "h5", "netcdf4": "nc"}[self.file_type]
        self.data_paths = sorted(self.data_dir.glob(f"*.{suffix}"))
        for data_path in self.data_paths:
            self.logger.info(f"Climate data file found: {data_path}")
        self.n_years = len(self.data_paths)
        self.logger.info(f"Number of years: {self.n_years}")

        # get total number of examples and image shape from the first file,
        # assuming other files have exactly the same format.
        self.logger.info(f"Getting file stats from {self.data_paths[0]}")
        if self.file_type == "hdf5":
            with h5py.File(self.data_paths[0], "r") as f:
                dataset_shape = f["fields"].shape
        else:
            with nc.Dataset(self.data_paths[0], "r") as f:
                var_shape = f[self.variables[0]].shape
                dataset_shape = (var_shape[0], len(self.variables)) + var_shape[1:]

        # truncate the dataset to avoid out-of-range sampling
        data_samples_per_year = dataset_shape[0] - (self.num_steps - 1) * self.stride
        self.data_shape = dataset_shape[2:]

        # interpret list of variables into list of channels or vice versa
        if self.file_type == "hdf5":
            with open(self.metadata_path, "r") as f:
                metadata = json.load(f)
                data_vars = metadata["coords"]["channel"]
            if self.variables is not None:
                self.channels = [data_vars.index(v) for v in self.variables]
            else:
                if self.channels is None:
                    self.variables = data_vars
                else:
                    self.variables = [data_vars[i] for i in self.channels]

        # If channels not provided, use all of them
        if self.channels is None:
            self.channels = list(range(dataset_shape[1]))

        # If num_samples_per_year is None, use all available samples
        if num_samples_per_year is None:
            num_samples_per_year = data_samples_per_year
        self.num_samples_per_year = num_samples_per_year

        # Adjust image shape if patch_size defined
        if patch_size is not None:
            self.cropped_data_shape = tuple(
                s - s % patch_size[i] for i, s in enumerate(self.data_shape)
            )
        else:
            self.cropped_data_shape = self.data_shape
        self.logger.info(f"Input data shape: {self.cropped_data_shape}")

        # Get total length
        self.total_length = self.n_years * self.num_samples_per_year

        # Sanity checks
        if max(self.channels) >= dataset_shape[1]:
            raise ValueError(
                f"Provided channel has indexes greater than the number "
                f"of fields {dataset_shape[1]}"
            )

        if self.num_samples_per_year > data_samples_per_year:
            raise ValueError(
                f"num_samples_per_year ({self.num_samples_per_year}) > number of "
                f"samples available ({data_samples_per_year})!"
            )

        self._load_statistics()

        self.logger.info(f"Number of samples/year: {self.num_samples_per_year}")
        self.logger.info(f"Number of channels available: {dataset_shape[1]}")

    def _load_statistics(self) -> None:
        """Loads climate statistics from pre-computed numpy files

        The statistics files are given by the `stats_files` mapping: either "mean"
        and "std" arrays of shape [1, C, 1, 1], or a single "norm" file containing
        a dict of per-variable statistics.

        Raises
        ------
        IOError
            If statistics files are not found
        ValueError
            If loaded numpy arrays are not of correct size
        """
        # If no stats files we just skip loading the stats
        if self.stats_files is None:
            self.mu = None
            self.sd = None
            return

        # load normalisation values
        if set(self.stats_files) == {"mean", "std"}:
            # use mean and std files
            mean_stat_file = self.stats_files["mean"]
            std_stat_file = self.stats_files["std"]
            if not mean_stat_file.exists():
                raise IOError(f"Mean statistics file {mean_stat_file} not found")
            if not std_stat_file.exists():
                raise IOError(f"Std statistics file {std_stat_file} not found")
            # has shape [1, C, 1, 1]
            self.mu = np.load(str(mean_stat_file))[:, self.channels]
            # has shape [1, C, 1, 1]
            self.sd = np.load(str(std_stat_file))[:, self.channels]
        elif set(self.stats_files) == {"norm"}:
            # use dict formatted file with named variables
            norm_stat_file = self.stats_files["norm"]
            if not norm_stat_file.exists():
                raise IOError(f"Statistics file {norm_stat_file} not found")
            norm = np.load(str(norm_stat_file), allow_pickle=True).item()
            mu = np.array([norm[var]["mean"] for var in self.variables])
            self.mu = mu.reshape((1, len(mu), 1, 1))
            sd = np.array([norm[var]["std"] for var in self.variables])
            self.sd = sd.reshape((1, len(sd), 1, 1))
        else:
            raise ValueError("Invalid statistics file specification")

        if not self.mu.shape == self.sd.shape == (1, len(self.channels), 1, 1):
            raise ValueError("Error, normalisation arrays have wrong shape")
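
As an illustration of the spec API above, here is a minimal, hypothetical sketch of configuring an HDF5 source. The directory, statistics paths, and variable names are placeholders, not files shipped with Modulus; the metadata file is assumed to list the channel names under "coords"/"channel" (e.g. {"coords": {"channel": ["u10m", "v10m", "t2m", "msl"]}}), which is what parse_dataset_files reads to map `variables` to channel indices.

# Hypothetical usage sketch of ClimateDataSourceSpec (all paths and names are
# placeholders). Loads two-step sequences of a few named channels and
# normalizes them with mean/std statistics files.
spec = ClimateDataSourceSpec(
    data_dir="/data/era5/train",            # directory of {year}.h5 files
    name="era5",                            # used to label datapipe outputs
    file_type="hdf5",
    metadata_path="/data/era5/data.json",   # channel names for the HDF5 data
    stats_files={
        "mean": "/data/era5/stats/global_means.npy",
        "std": "/data/era5/stats/global_stds.npy",
    },
    variables=["u10m", "v10m", "t2m"],      # overrides `channels`
    use_cos_zenith=True,
    num_steps=2,
    stride=1,
)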

class ClimateDatapipe(Datapipe):
    """
    A Climate DALI data pipeline. This pipeline loads data from HDF5/NetCDF4 files.
    It can also return additional data such as the solar zenith angle for each time
    step. Additionally, it normalizes the data if a statistics file is provided.
    The pipeline returns a dictionary with the following structure, where {name}
    indicates the name of the data source provided:

    - `state_seq-{name}`: Tensors of shape
      (batch_size, num_steps, num_channels, height, width). This sequence is drawn
      from the data file and normalized if a statistics file is provided.
    - `timestamps-{name}`: Tensors of shape (batch_size, num_steps), containing
      timestamps for each timestep in the sequence.
    - `{aux_variable}-{name}`: Tensors of shape
      (batch_size, num_steps, aux_channels, height, width), containing the
      auxiliary variables returned by each data source
    - `cos_zenith-{name}`: Tensors of shape (batch_size, num_steps, 1, height, width),
      containing the cosine of the solar zenith angle if specified.
    - `{invariant_name}`: Tensors of shape
      (batch_size, invariant_channels, height, width), containing the time-invariant
      data (depending only on spatial coordinates) returned by the datapipe.
      These can include e.g. land-sea mask and geopotential/surface elevation.

    To use this data pipeline, your data directory must be structured as follows:
    ```
    data_dir
    ├── 1980.h5
    ├── 1981.h5
    ├── 1982.h5
    ├── ...
    └── 2020.h5
    ```

    The files are assumed to have no metadata, such as timestamps. Because of this,
    it's important to specify the `dt` parameter and the `start_year` parameter so
    that the pipeline can compute the correct timestamps for each timestep. These
    timestamps are then used to compute the cosine of the solar zenith angle, if
    specified.

    Parameters
    ----------
    sources : Iterable[ClimateDataSourceSpec]
        A list of data specifications defining the sources for the climate variables
    batch_size : int, optional
        Batch size, by default 1
    dt : float, optional
        Time in hours between each timestep in the dataset, by default 6 hr
    start_year : int, optional
        Start year of dataset, by default 1980
    latlon_bounds : Tuple[Tuple[float, float], Tuple[float, float]], optional
        Bounds of latitude and longitude in the data, in the format
        ((lat_start, lat_end), (lon_start, lon_end)). By default ((90, -90), (0, 360)).
    crop_window : Union[Tuple[Tuple[float, float], Tuple[float, float]], None], optional
        The window to crop the data to, in the format ((i0, i1), (j0, j1)) where the
        first spatial dimension will be cropped to i0:i1 and the second to j0:j1.
        If not given, all data will be used.
    invariants : Mapping[str, Callable], optional
        Specifies the time-invariant data (for example latitude and longitude)
        included in the data samples. Should be a dict where the keys are the names
        of the invariants and the values are the corresponding functions. The
        functions need to accept an argument of the shape
        (2, data_shape[0], data_shape[1]) where the first dimension contains
        latitude and longitude in degrees and the other dimensions correspond to
        the shape of data in the data files. For example,
        invariants={"trig_latlon": invariants.LatLon()} will include the sin/cos of
        lat/lon in the output.
    num_samples_per_year : int, optional
        Number of samples taken from each year. If None, all will be used,
        by default None
    shuffle : bool, optional
        Shuffle dataset, by default True
    num_workers : int, optional
        Number of workers, by default 1
    device : Union[str, torch.device], optional
        Device for DALI pipeline to run on, by default cuda
    process_rank : int, optional
        Rank ID of local process, by default 0
    world_size : int, optional
        Number of training processes, by default 1
    """

    def __init__(
        self,
        sources: Iterable[ClimateDataSourceSpec],
        batch_size: int = 1,
        dt: float = 6.0,
        start_year: int = 1980,
        latlon_bounds: Tuple[Tuple[float, float], Tuple[float, float]] = (
            (90, -90),
            (0, 360),
        ),
        crop_window: Union[
            Tuple[Tuple[float, float], Tuple[float, float]], None
        ] = None,
        invariants: Union[Mapping[str, Callable], None] = None,
        num_samples_per_year: Union[int, None] = None,
        shuffle: bool = True,
        num_workers: int = 1,  # TODO: is there a faster good default?
        device: Union[str, torch.device] = "cuda",
        process_rank: int = 0,
        world_size: int = 1,
    ):
        super().__init__(meta=MetaData())

        self.sources = list(sources)
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.dt = dt
        self.start_year = start_year
        self.data_latlon_bounds = latlon_bounds
        self.process_rank = process_rank
        self.world_size = world_size
        self.num_samples_per_year = num_samples_per_year
        self.logger = PythonLogger()

        if invariants is None:
            invariants = {}

        # Determine outputs of pipeline
        self.pipe_outputs = []
        for (i, spec) in enumerate(self.sources):
            name = spec.name if spec.name is not None else i
            self.pipe_outputs += [f"state_seq-{name}", f"timestamps-{name}"]
            self.pipe_outputs.extend(
                f"{aux_var}-{name}" for aux_var in spec.aux_variables
            )
            if spec.use_cos_zenith:
                self.pipe_outputs.append(f"cos_zenith-{name}")
        self.pipe_outputs.extend(invariants.keys())

        # Set up device, needed for pipeline
        if isinstance(device, str):
            device = torch.device(device)
        # Need an index id if cuda
        if device.type == "cuda" and device.index is None:
            device = torch.device("cuda:0")
        self.device = device

        # Load all data files and statistics
        for spec in self.sources:
            spec.parse_dataset_files(num_samples_per_year=num_samples_per_year)
        for (i, spec_i) in enumerate(self.sources):
            for spec_j in self.sources[i + 1 :]:
                if not spec_i.dimensions_compatible(spec_j):
                    raise ValueError("Incompatible data sources")

        self.data_latlon = np.stack(
            latlon_grid(
                bounds=self.data_latlon_bounds, shape=self.sources[0].data_shape
            ),
            axis=0,
        )
        if crop_window is None:
            crop_window = (
                (0, self.sources[0].cropped_data_shape[0]),
                (0, self.sources[0].cropped_data_shape[1]),
            )
        self.crop_window = crop_window
        self.window_latlon = self._crop_to_window(self.data_latlon)
        self.window_latlon_dali = dali.types.Constant(self.window_latlon)

        # load invariants
        self.invariants = {
            var: callback(self.window_latlon) for (var, callback) in invariants.items()
        }

        # Create pipeline
        self.pipe = self._create_pipeline()

    def _source_cls_from_type(self, source_type: str) -> type:
        """Get the external source class based on a string descriptor."""
        return {
            "hdf5": ClimateHDF5DaliExternalSource,
            "netcdf4": ClimateNetCDF4DaliExternalSource,
        }[source_type]

    def _crop_to_window(self, x):
        cw = self.crop_window
        if isinstance(x, dali.pipeline.DataNode):
            # DALI doesn't support ellipsis notation
            return x[:, :, cw[0][0] : cw[0][1], cw[1][0] : cw[1][1]]
        else:
            return x[..., cw[0][0] : cw[0][1], cw[1][0] : cw[1][1]]

    def _source_outputs(self, spec: ClimateDataSourceSpec) -> List:
        """Create DALI outputs for a given data source specification.

        Parameters
        ----------
        spec : ClimateDataSourceSpec
            The data source specification.
        """
        # HDF5/NetCDF source
        source_cls = self._source_cls_from_type(spec.file_type)
        source = source_cls(
            data_paths=spec.data_paths,
            num_samples=spec.total_length,
            channels=spec.channels,
            latlon=self.data_latlon,
            variables=spec.variables,
            aux_variables=spec.aux_variables,
            stride=spec.stride,
            dt=self.dt,
            start_year=self.start_year,
            num_steps=spec.num_steps,
            num_samples_per_year=spec.num_samples_per_year,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            process_rank=self.process_rank,
            world_size=self.world_size,
            backend_kwargs=spec.backend_kwargs,
        )
        # Update length of dataset
        self.total_length = len(source) // self.batch_size

        # Read current batch
        (state_seq, timestamps, *aux) = dali.fn.external_source(
            source,
            num_outputs=source.num_outputs(),
            parallel=True,
            batch=False,
        )

        # Crop
        state_seq = self._crop_to_window(state_seq)
        aux = (self._crop_to_window(x) for x in aux)

        # Normalize
        if spec.stats_files is not None:
            state_seq = dali.fn.normalize(state_seq, mean=spec.mu, stddev=spec.sd)

        # Make output list
        outputs = [state_seq, timestamps, *aux]

        # Get cosine zenith angle
        if spec.use_cos_zenith:
            cos_zenith = dali.fn.cast(
                cos_zenith_angle(timestamps, latlon=self.window_latlon_dali),
                dtype=dali.types.FLOAT,
            )
            outputs.append(cos_zenith)

        return outputs

    def _invariant_outputs(self):
        for inv in self.invariants.values():
            if self.crop_window is not None:
                inv = self._crop_to_window(inv)
            yield dali.types.Constant(inv)

    def _create_pipeline(self) -> dali.Pipeline:
        """Create DALI pipeline

        Returns
        -------
        dali.Pipeline
            Climate DALI pipeline
        """
        pipe = dali.Pipeline(
            batch_size=self.batch_size,
            num_threads=2,
            prefetch_queue_depth=2,
            py_num_workers=self.num_workers,
            device_id=self.device.index,
            py_start_method="spawn",
        )

        with pipe:
            # Concatenate outputs from all sources as well as invariants
            outputs = list(
                chain(
                    *(self._source_outputs(spec) for spec in self.sources),
                    self._invariant_outputs(),
                )
            )

            if self.device.type == "cuda":
                # Move tensors to GPU as external_source won't do that
                outputs = [o.gpu() for o in outputs]

            # Set outputs
            pipe.set_outputs(*outputs)

        return pipe

    def __iter__(self):
        # Reset the pipeline before creating an iterator to enable epochs.
        self.pipe.reset()
        # Create DALI PyTorch iterator.
        return dali_pth.DALIGenericIterator([self.pipe], self.pipe_outputs)

    def __len__(self):
        return self.total_length
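
A minimal sketch of driving the pipeline, reusing the hypothetical `spec` from the earlier example; output keys follow the `{output}-{name}` convention described in the class docstring, and the device, batch size, and key names shown here are illustrative assumptions.

# Hypothetical usage sketch: build the datapipe from one source spec and read
# one batch. Keys follow the "{output}-{name}" convention, e.g. "state_seq-era5".
datapipe = ClimateDatapipe(
    sources=[spec],
    batch_size=2,
    dt=6.0,
    start_year=1980,
    device="cuda",
)
for batch in datapipe:
    data = batch[0]  # DALIGenericIterator yields a list with one dict per pipeline
    state_seq = data["state_seq-era5"]    # (batch_size, num_steps, channels, H, W)
    timestamps = data["timestamps-era5"]  # (batch_size, num_steps)
    cos_zenith = data["cos_zenith-era5"]  # present because use_cos_zenith=True
    break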
class ClimateDaliExternalSource(ABC):
    """DALI source for lazy-loading the HDF5/NetCDF4 climate files

    Parameters
    ----------
    data_paths : Iterable[str]
        Paths of the climate data files
    num_samples : int
        Total number of training samples
    channels : Iterable[int]
        List representing which climate variables to load
    num_steps : int
        Number of timesteps to load
    stride : int
        Number of steps between input and output variables
    dt : float, optional
        Time in hours between each timestep in the dataset, by default 6 hr
    start_year : int, optional
        Start year of dataset, by default 1980
    num_samples_per_year : int
        Number of samples randomly taken from each year
    latlon : np.ndarray
        Array of shape (2, height, width) containing latitude and longitude in degrees
    variables : Union[List[str], None], optional for HDF5 files, mandatory for NetCDF4 files
        List of named variables to load. Variables will be read in the order
        specified by this parameter.
    aux_variables : Union[Mapping[str, Callable], None], optional
        A dictionary mapping strings to callables that accept arguments
        (timestamps: numpy.ndarray, latlon: numpy.ndarray). These define any
        auxiliary variables returned from this source.
    batch_size : int, optional
        Batch size, by default 1
    shuffle : bool, optional
        Shuffle dataset, by default True
    process_rank : int, optional
        Rank ID of local process, by default 0
    world_size : int, optional
        Number of training processes, by default 1
    backend_kwargs : Union[dict, None], optional
        Keyword arguments passed to the file-reading backend, by default None

    Note
    ----
    For more information about the DALI external source operator:
    https://docs.nvidia.com/deeplearning/dali/archives/dali_1_13_0/user-guide/docs/examples/general/data_loading/parallel_external_source.html
    """

    def __init__(
        self,
        data_paths: Iterable[str],
        num_samples: int,
        channels: Iterable[int],
        num_steps: int,
        stride: int,
        dt: float,
        start_year: int,
        num_samples_per_year: int,
        latlon: np.ndarray,
        variables: Union[List[str], None] = None,
        aux_variables: Union[Mapping[str, Callable], None] = None,
        batch_size: int = 1,
        shuffle: bool = True,
        process_rank: int = 0,
        world_size: int = 1,
        backend_kwargs: Union[dict, None] = None,
    ):
        self.data_paths = list(data_paths)
        # Will be populated later once each worker starts running in its own process.
        self.data_files = [None] * len(self.data_paths)
        self.num_samples = num_samples
        self.chans = list(channels)
        self.latlon = latlon
        self.variables = variables
        self.aux_variables = aux_variables if aux_variables is not None else {}
        self.num_steps = num_steps
        self.stride = stride
        self.dt = dt
        self.start_year = start_year
        self.num_samples_per_year = num_samples_per_year
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.backend_kwargs = {} if backend_kwargs is None else backend_kwargs
        self.last_epoch = None

        self.indices = np.arange(num_samples)
        # Shard from indices if running in parallel
        self.indices = np.array_split(self.indices, world_size)[process_rank]

        # Get number of full batches, ignore possible last incomplete batch for now.
        # Also, DALI external source does not support incomplete batches in parallel mode.
        self.num_batches = len(self.indices) // self.batch_size

    @abstractmethod
    def _load_sequence(self, year_idx: int, idx: int) -> np.ndarray:
        """Load data from year index `year_idx` starting at sample index `idx`."""
        pass

    def __call__(
        self, sample_info: dali.types.SampleInfo
    ) -> Tuple[np.ndarray, ...]:
        if sample_info.iteration >= self.num_batches:
            raise StopIteration()

        # Shuffle before the next epoch starts
        if self.shuffle and sample_info.epoch_idx != self.last_epoch:
            # All workers use the same rng seed so the resulting
            # indices are the same across workers
            np.random.default_rng(seed=sample_info.epoch_idx).shuffle(self.indices)
            self.last_epoch = sample_info.epoch_idx

        # Get local indices from global index
        # TODO: This is very hacky, but it works for now
        idx = self.indices[sample_info.idx_in_epoch]
        year_idx = idx // self.num_samples_per_year
        in_idx = idx % self.num_samples_per_year

        state_seq = self._load_sequence(year_idx, in_idx)

        # Load sequence of timestamps
        year = self.start_year + year_idx
        start_time = datetime(year, 1, 1) + timedelta(hours=int(in_idx) * self.dt)
        timestamps = np.array(
            [
                (start_time + timedelta(hours=i * self.stride * self.dt)).timestamp()
                for i in range(self.num_steps)
            ]
        )

        # outputs from auxiliary sources
        aux_outputs = (
            callback(timestamps, self.latlon)
            for callback in self.aux_variables.values()
        )

        return (state_seq, timestamps, *aux_outputs)

    def num_outputs(self):
        return 2 + len(self.aux_variables)

    def __len__(self):
        return len(self.indices)
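
The `aux_variables` callables invoked in `__call__` above receive the per-sample timestamps and the latitude/longitude grid. The sketch below is a hypothetical helper, not part of Modulus, showing the expected (num_steps, aux_channels, height, width) return layout.

# Hypothetical auxiliary-variable callable: accepts (timestamps, latlon) and
# returns an array shaped (num_steps, aux_channels, height, width).
def fractional_day_of_year(timestamps: np.ndarray, latlon: np.ndarray) -> np.ndarray:
    num_steps = timestamps.shape[0]
    height, width = latlon.shape[1], latlon.shape[2]
    out = np.empty((num_steps, 1, height, width), dtype=np.float32)
    for i, ts in enumerate(timestamps):
        t = datetime.fromtimestamp(float(ts))
        out[i, 0] = t.timetuple().tm_yday / 366.0  # broadcast scalar over the grid
    return out

# It could then be passed to a source spec as, e.g.:
# aux_variables={"day_of_year": fractional_day_of_year}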
class ClimateHDF5DaliExternalSource(ClimateDaliExternalSource):
    """DALI source for reading HDF5 formatted climate data files."""

    def _get_data_file(self, year_idx: int) -> h5py.File:
        """Return the opened file for year `year_idx`."""
        if self.data_files[year_idx] is None:
            # This will be called once per worker. Workers are persistent,
            # so there is no need to explicitly close the files - this will be done
            # when the corresponding pipeline/dataset is destroyed.
            # Lazy opening avoids unnecessary file open ops when sharding.
            self.data_files[year_idx] = h5py.File(self.data_paths[year_idx], "r")
        return self.data_files[year_idx]

    def _load_sequence(self, year_idx: int, idx: int) -> np.ndarray:
        # TODO: the data is returned in a weird (time, channels, width, height) shape
        data = self._get_data_file(year_idx)["fields"]
        return data[idx : idx + self.num_steps * self.stride : self.stride, self.chans]
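
To make the expected on-disk layout concrete, here is a hypothetical sketch that writes a tiny HDF5 file in the format this source reads: a single `fields` dataset of shape (num_timesteps, num_channels, height, width). The file name and dimensions are illustrative only.

# Hypothetical sketch: create a minimal HDF5 file compatible with this reader.
# Real datasets would have far more timesteps, channels, and grid points.
with h5py.File("1980.h5", "w") as f:
    f.create_dataset(
        "fields",
        data=np.random.rand(8, 3, 32, 64).astype(np.float32),  # (time, channel, lat, lon)
    )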
class ClimateNetCDF4DaliExternalSource(ClimateDaliExternalSource):
    """DALI source for reading NetCDF4 formatted climate data files."""

    def _get_data_file(self, year_idx: int) -> Union[nc.Dataset, netcdf_file]:
        """Return the opened file for year `year_idx`."""
        if self.data_files[year_idx] is None:
            # This will be called once per worker. Workers are persistent,
            # so there is no need to explicitly close the files - this will be done
            # when the corresponding pipeline/dataset is destroyed.
            # Lazy opening avoids unnecessary file open ops when sharding.
            # NOTE: The SciPy NetCDF reader can be used if the netCDF4 library
            # causes crashes.
            reader = self.backend_kwargs.get("reader", "netcdf4")
            if reader == "scipy":
                self.data_files[year_idx] = netcdf_file(self.data_paths[year_idx])
            elif reader == "netcdf4":
                self.data_files[year_idx] = nc.Dataset(self.data_paths[year_idx], "r")
                self.data_files[year_idx].set_auto_maskandscale(False)
            else:
                raise ValueError(f"Unsupported NetCDF reader: {reader}")
        return self.data_files[year_idx]

    def _load_sequence(self, year_idx: int, idx: int) -> np.ndarray:
        data_file = self._get_data_file(year_idx)
        shape = data_file.variables[self.variables[0]].shape
        shape = (self.num_steps, len(self.variables)) + shape[1:]

        # TODO: this can be optimized to do the NetCDF scale/offset on GPU
        output = np.empty(shape, dtype=np.float32)
        for (i, var) in enumerate(self.variables):
            v = data_file.variables[var]
            output[:, i] = v[
                idx : idx + self.num_steps * self.stride : self.stride
            ].copy()  # .copy() avoids hanging references
            if hasattr(v, "scale_factor"):
                output[:, i] *= v.scale_factor
            if hasattr(v, "add_offset"):
                output[:, i] += v.add_offset

        return output
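
For completeness, a hypothetical NetCDF4 source spec wired to the SciPy reader selected in `_get_data_file` above; the directory and variable names are placeholders, and `variables` is mandatory for NetCDF4 files.

# Hypothetical NetCDF4 source sketch: backend_kwargs selects the SciPy reader
# instead of the netCDF4 library (directory and variable names are placeholders).
nc_spec = ClimateDataSourceSpec(
    data_dir="/data/climate_nc/train",  # directory of {year}.nc files
    name="surface",
    file_type="netcdf4",
    variables=["t2m", "msl"],
    backend_kwargs={"reader": "scipy"},
)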