Modulus Utils
- class modulus.utils.capture.StaticCaptureEvaluateNoGrad(*args, **kwargs)
Bases: _StaticCapture
A performance optimization decorator for PyTorch no-grad evaluation.
This class should be used as a decorator on a function that runs the forward pass of the model without computing gradients. This is the recommended method to use for inference and validation.
- Parameters
model (modulus.models.Module) – Modulus Model
logger (Optional[Logger], optional) – Modulus Launch Logger, by default None
use_graphs (bool, optional) – Toggle CUDA graphs if supported by model, by default True
use_amp (bool, optional) – Toggle AMP if supported by model, by default True
cuda_graph_warmup (int, optional) – Number of warmup steps for CUDA graphs, by default 11
amp_type (Union[float16, bfloat16], optional) – Autocast type for AMP, by default torch.float16
label (Optional[str], optional) – Static capture checkpoint label, by default None
- Raises
ValueError – If the model provided is not a modulus.models.Module, i.e., it has no metadata.
Example
>>> import torch
>>> import modulus.models.mlp
>>> from modulus.utils.capture import StaticCaptureEvaluateNoGrad
>>> # Create model
>>> model = modulus.models.mlp.FullyConnected(2, 64, 2)
>>> input = torch.rand(8, 2)
>>> # Create evaluate function with optimization wrapper
>>> @StaticCaptureEvaluateNoGrad(model=model)
... def eval_step(model, invar):
...     predvar = model(invar)
...     return predvar
...
>>> output = eval_step(model, input)
>>> output.size()
torch.Size([8, 2])
Note: Capturing multiple CUDA graphs in a single program can lead to invalid CUDA memory access errors on some systems. If this occurs, prioritize capturing the training graphs.
- class modulus.utils.capture.StaticCaptureTraining(*args, **kwargs)
Bases: _StaticCapture
A performance optimization decorator for PyTorch training functions.
This class should be used as a decorator on a function that computes the forward pass of the neural network and the loss function. The user should only call the defined training step function. This applies optimizations including AMP and CUDA graphs.
- Parameters
model (modulus.models.Module) – Modulus Model
optim (torch.optim) – Optimizer
logger (Optional[Logger], optional) – Modulus Launch Logger, by default None
use_graphs (bool, optional) – Toggle CUDA graphs if supported by model, by default True
use_amp (bool, optional) – Toggle AMP if supported by model, by default True
cuda_graph_warmup (int, optional) – Number of warmup steps for CUDA graphs, by default 11
amp_type (Union[float16, bfloat16], optional) – Autocast type for AMP, by default torch.float16
gradient_clip_norm (Optional[float], optional) – Threshold for gradient clipping
label (Optional[str], optional) – Static capture checkpoint label, by default None
- Raises
ValueError – If the model provided is not a modulus.models.Module, i.e., it has no metadata.
Example
>>> import torch
>>> import modulus.models.mlp
>>> from modulus.utils.capture import StaticCaptureTraining
>>> # Create model
>>> model = modulus.models.mlp.FullyConnected(2, 64, 2)
>>> input = torch.rand(8, 2)
>>> output = torch.rand(8, 2)
>>> # Create optimizer
>>> optim = torch.optim.Adam(model.parameters(), lr=0.001)
>>> # Create training step function with optimization wrapper
>>> @StaticCaptureTraining(model=model, optim=optim)
... def training_step(model, invar, outvar):
...     predvar = model(invar)
...     loss = torch.sum(torch.pow(predvar - outvar, 2))
...     return loss
...
>>> # Sample training loop
>>> for i in range(3):
...     loss = training_step(model, input, output)
...
Note: If AMP is used with a gradient scaler during training, static captures must be checkpointed using state_dict(). By default, this requires static captures to be instantiated in the same order as when they were checkpointed. The label parameter can be used to relax or circumvent this ordering requirement.
Note: Capturing multiple CUDA graphs in a single program can lead to invalid CUDA memory access errors on some systems. If this occurs, prioritize capturing the training graphs.
- class modulus.utils.graphcast.data_utils.StaticData(static_dataset_path: str, latitudes: Tensor, longitudes: Tensor)
Bases: object
Class to load static data from netCDF files. Static data includes land-sea mask, geopotential, and latitude-longitude coordinates.
- Parameters
static_dataset_path (str) – Path to directory containing static data.
latitudes (Tensor) – Tensor with shape (lat,) that includes latitudes.
longitudes (Tensor) – Tensor with shape (lon,) that includes longitudes.
- get() → Tensor
Get all static data.
- Returns
Tensor with shape (1, 5, lat, lon) that includes land-sea mask, geopotential, cosine of latitudes, sine and cosine of longitudes.
- Return type
Tensor
- get_geop(normalize: bool = True) → Tensor
Get geopotential from netCDF file.
- Parameters
normalize (bool, optional) – Whether to normalize the geopotential, by default True
- Returns
Geopotential with shape (1, 1, lat, lon), normalized if normalize is True.
- Return type
Tensor
- get_lat_lon() → Tensor
Computes cosine of latitudes and sine and cosine of longitudes.
- Returns
Tensor with shape (1, 3, lat, lon) that includes cosine of latitudes, sine and cosine of longitudes.
- Return type
Tensor
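For illustration, the three trigonometric channels returned by get_lat_lon can be sketched in plain Python. This is a minimal sketch, assuming latitudes and longitudes are given in degrees and returning nested lists of shape (3, lat, lon) rather than a (1, 3, lat, lon) Tensor; lat_lon_features is a hypothetical name, not part of the Modulus API.

```python
import math
from typing import List

Grid = List[List[float]]


def lat_lon_features(latitudes_deg: List[float], longitudes_deg: List[float]) -> List[Grid]:
    """Hypothetical helper: [cos(lat), sin(lon), cos(lon)], each a lat x lon grid.

    Assumes inputs are in degrees (an assumption, not confirmed by the docs).
    """
    lat_rad = [math.radians(v) for v in latitudes_deg]
    lon_rad = [math.radians(v) for v in longitudes_deg]
    # cos(lat) varies along rows only; sin/cos(lon) vary along columns only
    cos_lat = [[math.cos(la) for _ in lon_rad] for la in lat_rad]
    sin_lon = [[math.sin(lo) for lo in lon_rad] for _ in lat_rad]
    cos_lon = [[math.cos(lo) for lo in lon_rad] for _ in lat_rad]
    return [cos_lat, sin_lon, cos_lon]
```

Using cos(lat) rather than lat itself keeps the feature smooth at the poles, and the sine/cosine pair encodes longitude without a discontinuity at the 0/360 boundary.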
- get_lsm() → Tensor
Get land-sea mask from netCDF file.
- Returns
Land-sea mask with shape (1, 1, lat, lon).
- Return type
Tensor
- class modulus.utils.graphcast.graph.Graph(lat_lon_grid: Tensor, mesh_level: int = 6, multimesh: bool = True, khop_neighbors: int = 0, dtype=torch.float32)
Bases: object
Graph class for creating the graph2mesh, latent mesh, and mesh2grid graphs.
- Parameters
lat_lon_grid (Tensor) – Tensor with shape (lat, lon, 2) that includes the latitudes and longitudes meshgrid.
mesh_level (int, optional) – Level of the latent mesh, by default 6
multimesh (bool, optional) – Whether the latent mesh is a multimesh, by default True. If True, the latent mesh includes the nodes corresponding to the specified mesh_level and incorporates the edges from all mesh levels from 0 up to and including mesh_level.
khop_neighbors (int, optional) – This option is used to retrieve a list of indices for the k-hop neighbors of all mesh nodes. It is applicable when a graph transformer is used as the processor. If set to 0, this list is not computed. If a message passing processor is used, it is forced to 0. By default 0.
dtype (torch.dtype, optional) – Data type of the graph, by default torch.float
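As a rough guide to how mesh_level and multimesh affect graph size, the sketch below counts nodes and edges per level and the edges of the multimesh. It assumes (not stated above) that level 0 is a regular icosahedron, that each refinement splits every triangle into four, that edges from coarser levels stay distinct from the finer ones, and that the graph is made bidirected as with create_graph(to_bidirected=True). The helper names are hypothetical.

```python
from typing import Tuple


def icosphere_counts(level: int) -> Tuple[int, int, int]:
    """Hypothetical helper: (nodes, faces, undirected edges) of an icosahedron
    refined `level` times, each refinement splitting every triangle into 4."""
    faces = 20 * 4**level
    edges = 30 * 4**level
    nodes = 10 * 4**level + 2  # satisfies Euler's formula V - E + F = 2
    return nodes, faces, edges


def multimesh_edge_count(mesh_level: int, bidirected: bool = True) -> int:
    """Hypothetical helper: edges of a multimesh that keeps the nodes of
    `mesh_level` and the edges of every level 0..mesh_level."""
    total = sum(icosphere_counts(k)[2] for k in range(mesh_level + 1))
    # a bidirected graph stores each undirected edge in both directions
    return 2 * total if bidirected else total
```

Under these assumptions, the default mesh_level=6 gives 40,962 mesh nodes and 327,660 directed multimesh edges.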
- create_g2m_graph(verbose: bool = True) → Tensor
Create the graph2mesh graph.
- Parameters
verbose (bool, optional) – verbosity, by default True
- Returns
Graph2mesh graph.
- Return type
DGLGraph
- create_m2g_graph(verbose: bool = True) → Tensor
Create the mesh2grid graph.
- Parameters
verbose (bool, optional) – verbosity, by default True
- Returns
Mesh2grid graph.
- Return type
DGLGraph
- create_mesh_graph(verbose: bool = True) → Tensor
Create the multimesh graph.
- Parameters
verbose (bool, optional) – verbosity, by default True
- Returns
Multimesh graph
- Return type
DGLGraph
- modulus.utils.graphcast.graph_utils.add_edge_features(graph: DGLGraph, pos: Tensor, normalize: bool = True) → DGLGraph
Adds edge features to the graph.
- Parameters
graph (DGLGraph) – The graph to add edge features to.
pos (Tensor) – The node positions.
normalize (bool, optional) – Whether to normalize the edge features, by default True
- Returns
The graph with edge features.
- Return type
DGLGraph
- modulus.utils.graphcast.graph_utils.add_node_features(graph: DGLGraph, pos: Tensor) → DGLGraph
Adds cosine of latitude, sine and cosine of longitude as the node features to the graph.
- Parameters
graph (DGLGraph) – The graph to add node features to.
pos (Tensor) – The node positions.
- Returns
graph – The graph with node features.
- Return type
DGLGraph
- modulus.utils.graphcast.graph_utils.azimuthal_angle(lon: Tensor) → Tensor
Gives the azimuthal angle of a point on the sphere
- Parameters
lon (Tensor) – Tensor of shape (N, ) containing the longitude of the point
- Returns
Tensor of shape (N, ) containing the azimuthal angle
- Return type
Tensor
- modulus.utils.graphcast.graph_utils.cell_to_adj(cells: List[List[int]])
Creates an adjacency matrix in COO format from mesh cells.
- Parameters
cells (List[List[int]]) – List of cells, each cell is a list of 3 vertices
- Returns
src, dst – List of source and destination vertices
- Return type
List[int], List[int]
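One plausible way to build the COO edge lists is to emit the three directed edges of every triangle. The sketch below assumes each cell is listed as [v0, v1, v2] and emits v0->v1, v1->v2, v2->v0; whether cell_to_adj uses exactly this ordering is an assumption, and cell_to_adj_sketch is a hypothetical name.

```python
from typing import List, Tuple


def cell_to_adj_sketch(cells: List[List[int]]) -> Tuple[List[int], List[int]]:
    """Hypothetical sketch: COO (src, dst) lists from triangular cells."""
    src: List[int] = []
    dst: List[int] = []
    for v0, v1, v2 in cells:
        # walk the triangle's boundary: v0 -> v1 -> v2 -> v0
        src.extend([v0, v1, v2])
        dst.extend([v1, v2, v0])
    return src, dst
```

The resulting src and dst lists have the form expected by create_graph, whose to_bidirected option makes the graph bidirectional.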
- modulus.utils.graphcast.graph_utils.create_graph(src: List, dst: List, to_bidirected: bool = True, add_self_loop: bool = False, dtype: dtype = torch.int32) → DGLGraph
Creates a DGL graph from an adjacency matrix in COO format.
- Parameters
src (List) – List of source nodes
dst (List) – List of destination nodes
to_bidirected (bool, optional) – Whether to make the graph bidirectional, by default True
add_self_loop (bool, optional) – Whether to add self loop to the graph, by default False
dtype (torch.dtype, optional) – Graph index data type, by default torch.int32
- Returns
The dgl Graph.
- Return type
DGLGraph
- modulus.utils.graphcast.graph_utils.create_heterograph(src: List, dst: List, labels: str, dtype: dtype = torch.int32, num_nodes_dict: Optional[dict] = None) → DGLGraph
Creates a heterogeneous DGL graph from an adjacency matrix in COO format.
- Parameters
src (List) – List of source nodes
dst (List) – List of destination nodes
labels (str) – Label of the edge type
dtype (torch.dtype, optional) – Graph index data type, by default torch.int32
num_nodes_dict (dict, optional) – number of nodes for some node types, see dgl.heterograph for more information
- Returns
The dgl Graph.
- Return type
DGLGraph
- modulus.utils.graphcast.graph_utils.deg2rad(deg: Tensor) → Tensor
Converts degrees to radians
- Parameters
deg – Tensor of shape (N, ) containing the degrees
- Returns
Tensor of shape (N, ) containing the radians
- Return type
Tensor
- modulus.utils.graphcast.graph_utils.geospatial_rotation(invar: Tensor, theta: Tensor, axis: str, unit: str = 'rad') → Tensor
Rotates coordinates about the given axis using the right-hand rule.
- Parameters
invar (Tensor) – Tensor of shape (N, 3) containing x, y, z coordinates
theta (Tensor) – Tensor of shape (N, ) containing the rotation angle
axis (str) – Axis of rotation
unit (str, optional) – Unit of theta, by default “rad”
- Returns
Tensor of shape (N, 3) containing the rotated x, y, z coordinates
- Return type
Tensor
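Under the right-hand rule, a positive angle turns counterclockwise when viewed from the positive end of the rotation axis looking toward the origin; for example, rotating the x unit vector by 90° about z yields the y unit vector. The sketch below rotates a single point rather than an (N, 3) batch and assumes axis is one of "x", "y", "z" and theta is in radians (the accepted axis values are not listed above); rotate_point is a hypothetical name.

```python
import math
from typing import Sequence, Tuple


def rotate_point(point: Sequence[float], theta: float, axis: str) -> Tuple[float, float, float]:
    """Hypothetical sketch: right-hand-rule rotation of (x, y, z) by theta radians."""
    x, y, z = point
    c, s = math.cos(theta), math.sin(theta)
    if axis == "x":  # y turns toward z
        return (x, c * y - s * z, s * y + c * z)
    if axis == "y":  # z turns toward x
        return (c * x + s * z, y, -s * x + c * z)
    if axis == "z":  # x turns toward y
        return (c * x - s * y, s * x + c * y, z)
    raise ValueError(f"unknown axis: {axis!r}")
```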
- modulus.utils.graphcast.graph_utils.get_face_centroids(vertices: List[Tuple[float, float, float]], faces: List[List[int]]) → List[Tuple[float, float, float]]
Compute the centroids of triangular faces in a graph.
- Parameters
vertices (List[Tuple[float, float, float]]) – A list of tuples representing the coordinates of the vertices.
faces (List[List[int]]) – A list of lists, where each inner list contains three indices representing a triangular face.
- Returns
A list of tuples representing the centroids of the faces.
- Return type
List[Tuple[float, float, float]]
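A minimal sketch of the centroid computation, assuming the centroid is the arithmetic mean of the three vertex coordinates; face_centroids is a hypothetical name. For vertices on a sphere, this plain mean lies slightly inside the sphere, and the sketch does not project it back onto the surface.

```python
from typing import List, Tuple

Vertex = Tuple[float, float, float]


def face_centroids(vertices: List[Vertex], faces: List[List[int]]) -> List[Vertex]:
    """Hypothetical sketch: centroid of each triangle as the mean of its vertices."""
    centroids: List[Vertex] = []
    for i, j, k in faces:
        a, b, c = vertices[i], vertices[j], vertices[k]
        centroids.append((
            (a[0] + b[0] + c[0]) / 3.0,
            (a[1] + b[1] + c[1]) / 3.0,
            (a[2] + b[2] + c[2]) / 3.0,
        ))
    return centroids
```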
- modulus.utils.graphcast.graph_utils.latlon2xyz(latlon: Tensor, radius: float = 1, unit: str = 'deg') → Tensor
Converts latitude/longitude coordinates (in degrees by default; see unit) to Cartesian x, y, z coordinates. Based on: https://stackoverflow.com/questions/1185408. The x-axis passes through (lat, lon) = (0, 0), the y-axis passes through (lat, lon) = (0, 90), and the z-axis passes through the poles.
- Parameters
latlon (Tensor) – Tensor of shape (N, 2) containing latitudes and longitudes
radius (float, optional) – Radius of the sphere, by default 1
unit (str, optional) – Unit of the latlon, by default “deg”
- Returns
Tensor of shape (N, 3) containing x, y, z coordinates
- Return type
Tensor
- modulus.utils.graphcast.graph_utils.max_edge_length(vertices: List[List[float]], source_nodes: List[int], destination_nodes: List[int]) → float
Compute the maximum edge length in a graph.
- Parameters
vertices (List[List[float]]) – A list of vertex coordinates.
source_nodes (List[int]) – A list of indices representing the source nodes of the edges.
destination_nodes (List[int]) – A list of indices representing the destination nodes of the edges.
- Returns
The maximum edge length in the graph.
- Return type
float
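A minimal sketch of this computation, assuming edge length means the straight-line (Euclidean) distance between the endpoint coordinates rather than a great-circle distance; max_edge_length_sketch is a hypothetical name.

```python
import math
from typing import List


def max_edge_length_sketch(
    vertices: List[List[float]],
    source_nodes: List[int],
    destination_nodes: List[int],
) -> float:
    """Hypothetical sketch: largest Euclidean distance between edge endpoints."""
    # math.dist (Python 3.8+) computes the Euclidean distance between two points
    return max(
        math.dist(vertices[s], vertices[d])
        for s, d in zip(source_nodes, destination_nodes)
    )
```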
- modulus.utils.graphcast.graph_utils.polar_angle(lat: Tensor) → Tensor
Gives the polar angle of a point on the sphere
- Parameters
lat (Tensor) – Tensor of shape (N, ) containing the latitude of the point
- Returns
Tensor of shape (N, ) containing the polar angle
- Return type
Tensor
- modulus.utils.graphcast.graph_utils.rad2deg(rad)
Converts radians to degrees
- Parameters
rad – Tensor of shape (N, ) containing the radians
- Returns
Tensor of shape (N, ) containing the degrees
- Return type
Tensor
- modulus.utils.graphcast.graph_utils.xyz2latlon(xyz: Tensor, radius: float = 1, unit: str = 'deg') → Tensor
Converts Cartesian x, y, z coordinates to latitude/longitude (in degrees by default; see unit). Based on: https://stackoverflow.com/questions/1185408. The x-axis passes through (lat, lon) = (0, 0), the y-axis passes through (lat, lon) = (0, 90), and the z-axis passes through the poles.
- Parameters
xyz (Tensor) – Tensor of shape (N, 3) containing x, y, z coordinates
radius (float, optional) – Radius of the sphere, by default 1
unit (str, optional) – Unit of the latlon, by default “deg”
- Returns
Tensor of shape (N, 2) containing latitudes and longitudes
- Return type
Tensor
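One common set of formulas consistent with the axis convention described for latlon2xyz and xyz2latlon is x = r cos(lat) cos(lon), y = r cos(lat) sin(lon), z = r sin(lat). The scalar sketch below applies them in both directions, assuming degrees in and out; the function names are hypothetical, and the Modulus functions operate on (N, 2) and (N, 3) Tensors instead. Longitudes come back in the range [-180, 180].

```python
import math
from typing import Tuple


def latlon_to_xyz(lat_deg: float, lon_deg: float, radius: float = 1.0) -> Tuple[float, float, float]:
    """Hypothetical sketch: x through (0, 0), y through (0, 90), z through the poles."""
    lat, lon = math.radians(lat_deg), math.radians(lon_deg)
    return (
        radius * math.cos(lat) * math.cos(lon),
        radius * math.cos(lat) * math.sin(lon),
        radius * math.sin(lat),
    )


def xyz_to_latlon(x: float, y: float, z: float, radius: float = 1.0) -> Tuple[float, float]:
    """Hypothetical sketch: inverse of latlon_to_xyz for points on the sphere."""
    lat = math.asin(z / radius)  # latitude from the height above the equator
    lon = math.atan2(y, x)       # atan2 picks the correct quadrant
    return math.degrees(lat), math.degrees(lon)
```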
- class modulus.utils.graphcast.loss.CellAreaWeightedLossFunction(area)
Bases: Module
Loss function with cell area weighting.
- Parameters
area (torch.Tensor) – Cell area with shape [H, W].
- forward(invar, outvar)
Implicit forward function which computes the loss given a prediction and the corresponding targets.
- Parameters
invar (torch.Tensor) – prediction of shape [T, C, H, W].
outvar (torch.Tensor) – target values of shape [T, C, H, W].
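A minimal sketch of an area-weighted mean squared error, assuming the squared error is averaged over the T and C dimensions, multiplied by the cell area, and then averaged over the H × W grid. The exact reduction used by CellAreaWeightedLossFunction is an assumption here, nested lists stand in for Tensors, and area_weighted_mse is a hypothetical name.

```python
from typing import List

Grid = List[List[float]]


def area_weighted_mse(pred: List[List[Grid]], target: List[List[Grid]], area: Grid) -> float:
    """Hypothetical sketch: pred/target are [T][C][H][W], area is [H][W]."""
    T, C = len(pred), len(pred[0])
    H, W = len(area), len(area[0])
    total = 0.0
    for h in range(H):
        for w in range(W):
            # squared error at this cell, summed over time steps and channels
            sq = sum(
                (pred[t][c][h][w] - target[t][c][h][w]) ** 2
                for t in range(T)
                for c in range(C)
            )
            # mean over (T, C), weighted by the area of this cell
            total += area[h][w] * sq / (T * C)
    return total / (H * W)
```

Weighting by cell area keeps the many small cells near the poles of a regular lat-lon grid from dominating the loss.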
- class modulus.utils.graphcast.loss.CustomCellAreaWeightedLossAutogradFunction(*args, **kwargs)
Bases: Function
Autograd function for custom loss with cell area weighting.
- static backward(ctx, grad_loss: Tensor)
Backward method of custom loss function with cell area weighting.
- static forward(ctx, invar: Tensor, outvar: Tensor, area: Tensor)
Forward of custom loss function with cell area weighting.
- class modulus.utils.graphcast.loss.CustomCellAreaWeightedLossFunction(area: Tensor)
Bases: CellAreaWeightedLossFunction
Custom loss function with cell area weighting.
- Parameters
area (torch.Tensor) – Cell area with shape [H, W].
- forward(invar: Tensor, outvar: Tensor) → Tensor
Implicit forward function which computes the loss given a prediction and the corresponding targets.
- Parameters
invar (torch.Tensor) – prediction of shape [T, C, H, W].
outvar (torch.Tensor) – target values of shape [T, C, H, W].
- class modulus.utils.graphcast.loss.GraphCastLossFunction(area, channels_list, dataset_metadata_path, time_diff_std_path)
Bases: Module
Loss function as specified in GraphCast.
- Parameters
area (torch.Tensor) – Cell area with shape [H, W].
- assign_atmosphere_weights()
Assigns weights to atmospheric variables.
- assign_surface_weights()
Assigns weights to surface variables.
- assign_variable_weights()
Assigns per-variable, per-pressure-level weights.
- calculate_linear_weights(variables)
Calculate weights for each variable group.
- forward(invar, outvar)
Implicit forward function which computes the loss given a prediction and the corresponding targets.
- Parameters
invar (torch.Tensor) – prediction of shape [T, C, H, W].
outvar (torch.Tensor) – target values of shape [T, C, H, W].
- get_channel_dict(dataset_metadata_path, channels_list)
Gets lists of surface and atmospheric channels.
- get_time_diff_std(time_diff_std_path, channels_list)
Gets the time difference standard deviation.
- parse_variable(variable_list)
Parse variable into its letter and numeric parts.
Parse variable into its letter and numeric parts.
- class modulus.utils.filesystem.Package(root: str, seperator: str = '/')
Bases: object
A generic file system abstraction. Can be used to represent local and remote file systems. Remote files are automatically fetched and stored in the $LOCAL_CACHE or $HOME/.cache/modulus folder. The get method can then be used to access files present.
Presently, Package can be used with the following directories:
Package("/path/to/local/directory") = local file system
Package("s3://bucket/path/to/directory") = object store file system
Package("http://url/path/to/directory") = HTTP file system
Package("ngc://model/<org_id/team_id/model_id>@<version>") = NGC model file system
- Parameters
root (str) – Root directory for file system
seperator (str, optional) – Directory separator. Defaults to “/”.
- get(path: str, recursive: bool = False) → str
Get a local path to the item at path. If path refers to a remote file, it is first downloaded to a local cache at $LOCAL_CACHE or $HOME/.cache/modulus.
Miscellaneous utility classes and functions.
- class modulus.utils.generative.utils.EasyDict
Bases: dict
Convenience class that behaves like a dict but allows access with the attribute syntax.
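A minimal sketch of such a dict subclass; EasyDictSketch is a hypothetical name, and details of the real EasyDict (for example, which exception a missing attribute raises) are assumed here. Because __getattr__ is only called when normal attribute lookup fails, ordinary dict methods such as keys() keep working.

```python
from typing import Any


class EasyDictSketch(dict):
    """Hypothetical sketch: dict whose keys can also be used as attributes."""

    def __getattr__(self, name: str) -> Any:
        # only reached when regular attribute lookup fails
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value  # store attributes as dict items

    def __delattr__(self, name: str) -> None:
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None
```

For example, cfg = EasyDictSketch(lr=0.001) allows both cfg.lr and cfg["lr"].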
- class modulus.utils.generative.utils.InfiniteSampler(dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5)
Bases: Sampler
Sampler for torch.utils.data.DataLoader that loops over the dataset indefinitely, shuffling items as it goes.
- class modulus.utils.generative.utils.StackedRandomGenerator(device, seeds)
Bases: object
Wrapper for torch.Generator that allows specifying a different random seed for each sample in a minibatch.
- modulus.utils.generative.utils.assert_shape(tensor, ref_shape)
Assert that the shape of a tensor matches the given list of integers. None indicates that the size of a dimension is allowed to vary. Performs symbolic assertion when used in torch.jit.trace().
- modulus.utils.generative.utils.call_func_by_name(*args, func_name: Optional[str] = None, **kwargs) → Any
Finds the Python object with the given name and calls it as a function.
- modulus.utils.generative.utils.check_ddp_consistency(module, ignore_regex=None)
Check DistributedDataParallel consistency across processes.
- modulus.utils.generative.utils.constant(value, shape=None, dtype=None, device=None, memory_format=None)
Cached construction of constant tensors.
- modulus.utils.generative.utils.construct_class_by_name(*args, class_name: Optional[str] = None, **kwargs) → Any
Finds the Python class with the given name and constructs it with the given arguments.
- modulus.utils.generative.utils.convert_datetime_to_cftime(time: datetime.datetime, cls=cftime.DatetimeGregorian) → DatetimeGregorian
Convert a Python datetime object to a cftime DatetimeGregorian object.
- modulus.utils.generative.utils.copy_files_and_create_dirs(files: List[Tuple[str, str]]) → None
Takes in a list of tuples of (src, dst) paths and copies files. Will create all necessary directories.
- modulus.utils.generative.utils.copy_params_and_buffers(src_module, dst_module, require_all=False)
Copy parameters and buffers from a source module to a target module.
- modulus.utils.generative.utils.ddp_sync(module, sync)
Context manager for easily enabling/disabling DistributedDataParallel synchronization.
- modulus.utils.generative.utils.format_time(seconds: Union[int, float]) → str
Convert the seconds to a human-readable string with days, hours, minutes and seconds.
- modulus.utils.generative.utils.format_time_brief(seconds: Union[int, float]) → str
Convert the seconds to a brief human-readable string; a more compact variant of format_time().
- modulus.utils.generative.utils.get_dtype_and_ctype(type_obj: Any) → Tuple[dtype, Any]
Given a type name string (or an object having a __name__ attribute), return matching NumPy and ctypes types that have the same size in bytes.
- modulus.utils.generative.utils.get_module_dir_by_obj_name(obj_name: str) → str
Get the directory path of the module containing the given object name.
- modulus.utils.generative.utils.get_module_from_obj_name(obj_name: str) → Tuple[module, str]
Searches for the underlying module behind the name to some Python object. Returns the module and the object name (original name with module part removed).
- modulus.utils.generative.utils.get_obj_by_name(name: str) → Any
Finds the Python object with the given name.
- modulus.utils.generative.utils.get_obj_from_module(module: module, obj_name: str) → Any
Traverses the object name and returns the last (rightmost) Python object.
- modulus.utils.generative.utils.get_top_level_function_name(obj: Any) → str
Return the fully-qualified name of a top-level function.
- modulus.utils.generative.utils.is_top_level_function(obj: Any) → bool
Determine whether the given object is a top-level function, i.e., defined at module scope using ‘def’.
- modulus.utils.generative.utils.list_dir_recursively_with_ignore(dir_path: str, ignores: Optional[List[str]] = None, add_base_to_relative: bool = False) → List[Tuple[str, str]]
List all files recursively in a given directory while ignoring given file and directory names. Returns list of tuples containing both absolute and relative paths.
- modulus.utils.generative.utils.named_params_and_buffers(module)
Get named parameters and buffers of an nn.Module.
- modulus.utils.generative.utils.params_and_buffers(module)
Get parameters and buffers of an nn.Module.
- modulus.utils.generative.utils.parse_int_list(s)
Parse a comma-separated list of numbers or ranges and return a list of ints. Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]
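A minimal re-implementation that reproduces the documented example, assuming a range written as a-b includes both endpoints (as '5-10' expanding to 5 through 10 implies) and that only non-negative integers appear; parse_int_list_sketch is a hypothetical name.

```python
import re
from typing import List

# "a-b" with non-negative integer endpoints (an assumption)
_RANGE_RE = re.compile(r"^(\d+)-(\d+)$")


def parse_int_list_sketch(s: str) -> List[int]:
    """Hypothetical sketch: parse strings such as '1,2,5-10' into a list of ints."""
    result: List[int] = []
    for part in s.split(","):
        match = _RANGE_RE.match(part.strip())
        if match:
            start, stop = int(match.group(1)), int(match.group(2))
            result.extend(range(start, stop + 1))  # inclusive upper bound
        else:
            result.append(int(part))
    return result
```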
- modulus.utils.generative.utils.print_module_summary(module, inputs, max_nesting=3, skip_redundant=True)
Print summary table of module hierarchy.
- modulus.utils.generative.utils.profiled_function(fn)
Function decorator that calls torch.autograd.profiler.record_function().
- modulus.utils.generative.utils.suppress_tracer_warnings()
Context manager to temporarily suppress known warnings in torch.jit.trace(). Note: Cannot use catch_warnings because of https://bugs.python.org/issue29672
- modulus.utils.generative.utils.time_range(start_time: datetime, end_time: datetime, step: timedelta, inclusive: bool = False)
Like the Python range iterator, but with datetimes.
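A minimal generator sketch of a datetime range, assuming a positive step and that inclusive=True additionally yields end_time when it falls exactly on the step grid; time_range_sketch is a hypothetical name.

```python
from datetime import datetime, timedelta
from typing import Iterator


def time_range_sketch(
    start_time: datetime,
    end_time: datetime,
    step: timedelta,
    inclusive: bool = False,
) -> Iterator[datetime]:
    """Hypothetical sketch: start_time, start_time + step, ... up to end_time."""
    t = start_time
    # like range(), the end is excluded unless inclusive=True (assumed semantics)
    while t < end_time or (inclusive and t == end_time):
        yield t
        t += step
```

For example, a 6-hour step from midnight to 18:00 yields 00:00, 06:00 and 12:00, plus 18:00 when inclusive=True.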
- modulus.utils.generative.utils.tuple_product(t: Tuple) → Any
Calculate the product of the tuple elements.
- modulus.utils.neighbor_list.radius_search(points: warp.types.array, queries: warp.types.array, radius: float, grid_dim: int | tuple[int, int, int] = (128, 128, 128), device: str = 'cuda')
Performs a radius search for each query point within a specified radius, using a hash grid for efficient spatial querying.
- Parameters
points – An array of points in space.
queries – An array of query points.
radius – The search radius around each query point.
grid_dim – The dimensions of the hash grid, either as an integer or a tuple of three integers.
device – The device (e.g., ‘cuda’ or ‘cpu’) on which computations are performed.
- Returns
A tuple containing the indices of neighboring points, their distances to the query points, and an offset array for result indexing.
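To make the shape of the results concrete, the brute-force reference below checks every point against every query instead of using a hash grid. It assumes the results are stored CSR-style, with the neighbors of query i at indices[offsets[i]:offsets[i + 1]], and that points exactly at distance radius count as neighbors; both are assumptions about radius_search, which operates on warp arrays. radius_search_bruteforce is a hypothetical name.

```python
import math
from typing import List, Sequence, Tuple

Point = Sequence[float]


def radius_search_bruteforce(
    points: List[Point], queries: List[Point], radius: float
) -> Tuple[List[int], List[float], List[int]]:
    """Hypothetical O(len(points) * len(queries)) reference search."""
    indices: List[int] = []
    distances: List[float] = []
    offsets: List[int] = [0]  # offsets[i]:offsets[i + 1] spans query i's results
    for q in queries:
        for j, p in enumerate(points):
            d = math.dist(p, q)
            if d <= radius:
                indices.append(j)
                distances.append(d)
        offsets.append(len(indices))
    return indices, distances, offsets
```

A hash grid avoids this quadratic cost by only comparing each query against points in nearby grid cells, which is why grid_dim matters for performance.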
- modulus.utils.insolation.insolation(dates, lat, lon, scale=1.0, daily=False, enforce_2d=False, clip_zero=True)
Calculate the approximate solar insolation for given dates.
For an example reference, see: https://brian-rose.github.io/ClimateLaboratoryBook/courseware/insolation.html
- Parameters
dates (np.ndarray) – 1d array of datetime or Timestamp objects
lat (np.ndarray) – 1d or 2d array of latitudes
lon (np.ndarray) – 1d or 2d array of longitudes (0-360deg). If 2d, must match the shape of lat.
scale (float, optional) – scaling factor (solar constant)
daily (bool, optional) – if True, return the daily max solar radiation (lat and day of year dependent only)
enforce_2d (bool, optional) – if True and lat/lon are 1-d arrays, turns them into 2d meshes.
clip_zero (bool, optional) – if True, set values below 0 to 0
- Returns
Insolation with dimensions (date, lat, lon).
- Return type
np.ndarray