Modulus Utils
- class modulus.utils.capture.StaticCaptureEvaluateNoGrad(*args, **kwargs)[source]
Bases: _StaticCapture
A performance optimization decorator for PyTorch no-grad evaluation.
This class should be initialized as a decorator on a function that runs the forward pass of a model without gradient calculations. This is the recommended approach for inference and validation.
- Parameters
model (modulus.models.Module) – Modulus Model
logger (Union[Logger, None], optional) – Modulus Launch Logger, by default None
use_graphs (bool, optional) – Toggle CUDA graphs if supported by model, by default True
use_amp (bool, optional) – Toggle AMP if supported by model, by default True
cuda_graph_warmup (int, optional) – Number of warmup steps for CUDA graphs, by default 11
amp_type (Union[float16, bfloat16], optional) – Auto casting type for AMP, by default torch.float16
label (Optional[str, None], optional) – Static capture checkpoint label, by default None
- Raises
ValueError – If the model provided is not a modulus.models.Module, i.e., it has no metadata.
Example
>>> import torch
>>> import modulus.models.mlp
>>> from modulus.utils.capture import StaticCaptureEvaluateNoGrad
>>> # Create model
>>> model = modulus.models.mlp.FullyConnected(2, 64, 2)
>>> input = torch.rand(8, 2)
>>> # Create evaluate function with optimization wrapper
>>> @StaticCaptureEvaluateNoGrad(model=model)
... def eval_step(model, invar):
...     predvar = model(invar)
...     return predvar
...
>>> output = eval_step(model, input)
>>> output.size()
torch.Size([8, 2])
Note: Capturing multiple CUDA graphs in a single program can lead to invalid CUDA memory access errors on some systems. Prioritize capturing training graphs when this occurs.
- class modulus.utils.capture.StaticCaptureTraining(*args, **kwargs)[source]
Bases: _StaticCapture
A performance optimization decorator for PyTorch training functions.
This class should be initialized as a decorator on a function that computes the forward pass of the neural network and the loss function. The user should only call the defined training step function. This applies optimizations including AMP and CUDA graphs.
- Parameters
model (modulus.models.Module) – Modulus Model
optim (torch.optim) – Optimizer
logger (Union[Logger, None], optional) – Modulus Launch Logger, by default None
use_graphs (bool, optional) – Toggle CUDA graphs if supported by model, by default True
use_amp (bool, optional) – Toggle AMP if supported by model, by default True
cuda_graph_warmup (int, optional) – Number of warmup steps for CUDA graphs, by default 11
amp_type (Union[float16, bfloat16], optional) – Auto casting type for AMP, by default torch.float16
label (Optional[str, None], optional) – Static capture checkpoint label, by default None
- Raises
ValueError – If the model provided is not a modulus.models.Module, i.e., it has no metadata.
Example
>>> import torch
>>> import modulus.models.mlp
>>> from modulus.utils.capture import StaticCaptureTraining
>>> # Create model
>>> model = modulus.models.mlp.FullyConnected(2, 64, 2)
>>> input = torch.rand(8, 2)
>>> output = torch.rand(8, 2)
>>> # Create optimizer
>>> optim = torch.optim.Adam(model.parameters(), lr=0.001)
>>> # Create training step function with optimization wrapper
>>> @StaticCaptureTraining(model=model, optim=optim)
... def training_step(model, invar, outvar):
...     predvar = model(invar)
...     loss = torch.sum(torch.pow(predvar - outvar, 2))
...     return loss
...
>>> # Sample training loop
>>> for i in range(3):
...     loss = training_step(model, input, output)
...
Note: Static captures must be checkpointed during training using state_dict() if AMP is being used with a gradient scaler. By default, this requires static captures to be instantiated in the same order as when they were checkpointed. The label parameter can be used to relax this ordering requirement.
Note: Capturing multiple CUDA graphs in a single program can lead to invalid CUDA memory access errors on some systems. Prioritize capturing training graphs when this occurs.
- class modulus.utils.graphcast.data_utils.StaticData(static_dataset_path: str, latitudes: Tensor, longitudes: Tensor)[source]
Bases: object
Class to load static data from netCDF files. Static data includes land-sea mask, geopotential, and latitude-longitude coordinates.
- Parameters
static_dataset_path (str) – Path to directory containing static data.
latitudes (Tensor) – Tensor with shape (lat,) that includes latitudes.
longitudes (Tensor) – Tensor with shape (lon,) that includes longitudes.
- get() → Tensor[source]
Get all static data.
- Returns
Tensor with shape (1, 5, lat, lon) that includes land-sea mask, geopotential, cosine of latitudes, sine and cosine of longitudes.
- Return type
Tensor
- get_geop(normalize: bool = True) → Tensor[source]
Get geopotential from netCDF file.
- Parameters
normalize (bool, optional) – Whether to normalize the geopotential, by default True
- Returns
Normalized geopotential with shape (1, 1, lat, lon).
- Return type
Tensor
- get_lat_lon() → Tensor[source]
Computes cosine of latitudes and sine and cosine of longitudes.
- Returns
Tensor with shape (1, 3, lat, lon) that includes cosine of latitudes, sine and cosine of longitudes.
- Return type
Tensor
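As a rough illustration of the three channels described for get_lat_lon, the sketch below builds cosine-of-latitude, sine-of-longitude, and cosine-of-longitude grids with plain Python lists standing in for tensors. It is not the library's implementation: the helper name lat_lon_features, degree-valued inputs, and the channel ordering are assumptions made for this example.

```python
import math

def lat_lon_features(latitudes, longitudes):
    """Hypothetical helper (not part of Modulus).

    Returns three channels [cos(lat), sin(lon), cos(lon)], each a
    lat x lon grid of floats. Inputs are assumed to be in degrees.
    """
    cos_lat = [[math.cos(math.radians(lat)) for _ in longitudes] for lat in latitudes]
    sin_lon = [[math.sin(math.radians(lon)) for lon in longitudes] for _ in latitudes]
    cos_lon = [[math.cos(math.radians(lon)) for lon in longitudes] for _ in latitudes]
    return [cos_lat, sin_lon, cos_lon]

features = lat_lon_features([-90.0, 0.0, 90.0], [0.0, 90.0, 180.0, 270.0])
# Channel count, number of latitudes, number of longitudes
print(len(features), len(features[0]), len(features[0][0]))
```

Adding a leading batch dimension to this (3, lat, lon) layout would give the (1, 3, lat, lon) shape quoted above.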
- get_lsm() → Tensor[source]
Get land-sea mask from netCDF file.
- Returns
Land-sea mask with shape (1, 1, lat, lon).
- Return type
Tensor
- class modulus.utils.graphcast.graph.Graph(icospheres_path: str, lat_lon_grid: Tensor, dtype=torch.float32)[source]
Bases: object
Graph class for creating the graph2mesh, multimesh, and mesh2graph graphs.
- Parameters
icospheres_path (str) – Path to the icospheres json file. If the file does not exist, it will try to generate it using PyMesh.
lat_lon_grid (Tensor) – Tensor with shape (lat, lon, 2) that includes the latitudes and longitudes meshgrid.
dtype (torch.dtype, optional) – Data type of the graph, by default torch.float32
- create_g2m_graph(verbose: bool = True) → Tensor[source]
Create the graph2mesh graph.
- Parameters
verbose (bool, optional) – verbosity, by default True
- Returns
Graph2mesh graph.
- Return type
DGLGraph
- create_m2g_graph(verbose: bool = True) → Tensor[source]
Create the mesh2grid graph.
- Parameters
verbose (bool, optional) – verbosity, by default True
- Returns
Mesh2grid graph.
- Return type
DGLGraph
- create_mesh_graph(verbose: bool = True) → Tensor[source]
Create the multimesh graph.
- Parameters
verbose (bool, optional) – verbosity, by default True
- Returns
Multimesh graph.
- Return type
DGLGraph
- modulus.utils.graphcast.graph_utils.add_edge_features(graph: DGLGraph, pos: Tensor, normalize: bool = True) → DGLGraph[source]
Adds edge features to the graph.
- Parameters
graph (DGLGraph) – The graph to add edge features to.
pos (Tensor) – The node positions.
normalize (bool, optional) – Whether to normalize the edge features, by default True
- Returns
The graph with edge features.
- Return type
DGLGraph
- modulus.utils.graphcast.graph_utils.add_node_features(graph: DGLGraph, pos: Tensor) → DGLGraph[source]
Adds cosine of latitude, sine and cosine of longitude as the node features to the graph.
- Parameters
graph (DGLGraph) – The graph to add node features to.
pos (Tensor) – The node positions.
- Returns
graph – The graph with node features.
- Return type
DGLGraph
- modulus.utils.graphcast.graph_utils.azimuthal_angle(lon: Tensor) → Tensor[source]
Gives the azimuthal angle of a point on the sphere.
- Parameters
lon (Tensor) – Tensor of shape (N, ) containing the longitude of the point
- Returns
Tensor of shape (N, ) containing the azimuthal angle
- Return type
Tensor
- modulus.utils.graphcast.graph_utils.cell_to_adj(cells: List[List[int]])[source]
Creates an adjacency matrix in COO format from mesh cells.
- Parameters
cells (List[List[int]]) – List of cells, each cell is a list of 3 vertices
- Returns
src, dst – List of source and destination vertices
- Return type
List[int], List[int]
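To make the COO output concrete, here is a minimal sketch of one way to turn triangular cells into source and destination lists. It assumes each cell contributes one directed edge per side, walking the vertices in the order given; the function name cells_to_coo is hypothetical, and the edge ordering used by the library itself may differ.

```python
from typing import List, Tuple

def cells_to_coo(cells: List[List[int]]) -> Tuple[List[int], List[int]]:
    """Hypothetical stand-in: one directed edge per side of each cell."""
    src: List[int] = []
    dst: List[int] = []
    for cell in cells:
        n = len(cell)
        for i in range(n):
            # Edge from vertex i to the next vertex, wrapping around at the end
            src.append(cell[i])
            dst.append(cell[(i + 1) % n])
    return src, dst

src, dst = cells_to_coo([[0, 1, 2], [0, 2, 3]])
print(src)  # [0, 1, 2, 0, 2, 3]
print(dst)  # [1, 2, 0, 2, 3, 0]
```

The edges here are directed; in this sketch a shared side such as (0, 2) appears once per cell that contains it, in whichever direction that cell lists it.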
- modulus.utils.graphcast.graph_utils.create_graph(src: List, dst: List, to_bidirected: bool = True, add_self_loop: bool = False, dtype: dtype = torch.int32) → DGLGraph[source]
Creates a DGL graph from an adjacency matrix in COO format.
- Parameters
src (List) – List of source nodes
dst (List) – List of destination nodes
to_bidirected (bool, optional) – Whether to make the graph bidirectional, by default True
add_self_loop (bool, optional) – Whether to add self loop to the graph, by default False
dtype (torch.dtype, optional) – Graph index data type, by default torch.int32
- Returns
The dgl Graph.
- Return type
DGLGraph
- modulus.utils.graphcast.graph_utils.create_heterograph(src: List, dst: List, labels: str, dtype: dtype = torch.int32) → DGLGraph[source]
Creates a heterogeneous DGL graph from an adjacency matrix in COO format.
- Parameters
src (List) – List of source nodes
dst (List) – List of destination nodes
labels (str) – Label of the edge type
dtype (torch.dtype, optional) – Graph index data type, by default torch.int32
- Returns
The dgl Graph.
- Return type
DGLGraph
- modulus.utils.graphcast.graph_utils.deg2rad(deg: Tensor) → Tensor[source]
Converts degrees to radians
- Parameters
deg – Tensor of shape (N, ) containing the degrees
- Returns
Tensor of shape (N, ) containing the radians
- Return type
Tensor
- modulus.utils.graphcast.graph_utils.geospatial_rotation(invar: Tensor, theta: Tensor, axis: str, unit: str = 'rad') → Tensor[source]
Rotation using right hand rule
- Parameters
invar (Tensor) – Tensor of shape (N, 3) containing x, y, z coordinates
theta (Tensor) – Tensor of shape (N, ) containing the rotation angle
axis (str) – Axis of rotation
unit (str, optional) – Unit of the theta, by default “rad”
- Returns
Tensor of shape (N, 3) containing the rotated x, y, z coordinates
- Return type
Tensor
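For readers unfamiliar with right-hand-rule rotations, the following sketch rotates a single (x, y, z) point about one coordinate axis using the standard rotation matrices. It is an illustration only: the helper name rotate_point, the accepted axis values "x", "y", "z", and the "deg" unit option are assumptions, and the library function operates on (N, 3) tensors rather than single points.

```python
import math

def rotate_point(point, theta, axis, unit="rad"):
    """Hypothetical helper: rotate one (x, y, z) point about a coordinate axis.

    A positive theta rotates counterclockwise when looking down the
    axis toward the origin (right-hand rule).
    """
    if unit == "deg":
        theta = math.radians(theta)
    elif unit != "rad":
        raise ValueError("unit must be 'rad' or 'deg'")
    x, y, z = point
    c, s = math.cos(theta), math.sin(theta)
    if axis == "x":
        return (x, c * y - s * z, s * y + c * z)
    if axis == "y":
        return (c * x + s * z, y, -s * x + c * z)
    if axis == "z":
        return (c * x - s * y, s * x + c * y, z)
    raise ValueError("axis must be 'x', 'y', or 'z'")

# A quarter turn about z carries the x-axis onto the y-axis (up to rounding)
print(rotate_point((1.0, 0.0, 0.0), 90.0, "z", unit="deg"))
```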
- modulus.utils.graphcast.graph_utils.get_edge_len(edge_src: Tensor, edge_dst: Tensor, axis: int = 1)[source]
Returns the length of the edge.
- Parameters
edge_src (Tensor) – Tensor of shape (N, 3) containing the source of the edge
edge_dst (Tensor) – Tensor of shape (N, 3) containing the destination of the edge
axis (int, optional) – Axis along which the norm is computed, by default 1
- Returns
Tensor of shape (N, ) containing the length of the edge
- Return type
Tensor
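The computation can be pictured as a row-wise distance between two point lists. The sketch below assumes the norm in question is the Euclidean (L2) norm, which the description above does not state explicitly; edge_lengths is a hypothetical name, and plain tuples stand in for (N, 3) tensors.

```python
import math

def edge_lengths(edge_src, edge_dst):
    """Euclidean distance between matching rows of two lists of 3D points."""
    return [math.dist(p, q) for p, q in zip(edge_src, edge_dst)]

src_points = [(0.0, 0.0, 0.0), (1.0, 1.0, 1.0)]
dst_points = [(3.0, 4.0, 0.0), (1.0, 1.0, 1.0)]
print(edge_lengths(src_points, dst_points))  # [5.0, 0.0]
```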
- modulus.utils.graphcast.graph_utils.latlon2xyz(latlon: Tensor, radius: float = 1, unit: str = 'deg') → Tensor[source]
Converts latlon in degrees to xyz. Based on: https://stackoverflow.com/questions/1185408
- The x-axis goes through long,lat (0,0);
- The y-axis goes through (0,90);
- The z-axis goes through the poles.
- Parameters
latlon (Tensor) – Tensor of shape (N, 2) containing latitudes and longitudes
radius (float, optional) – Radius of the sphere, by default 1
unit (str, optional) – Unit of the latlon, by default “deg”
- Returns
Tensor of shape (N, 3) containing x, y, z coordinates
- Return type
Tensor
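The axis convention above corresponds to the usual spherical-to-Cartesian formulas. The sketch below applies them to a single point and is only an illustration under stated assumptions: latlon_to_xyz is a hypothetical scalar helper, and "(0,90)" is read as latitude 0, longitude 90.

```python
import math

def latlon_to_xyz(lat, lon, radius=1.0):
    """Hypothetical scalar version; lat and lon are in degrees."""
    lat_r = math.radians(lat)
    lon_r = math.radians(lon)
    x = radius * math.cos(lat_r) * math.cos(lon_r)  # through lat 0, lon 0
    y = radius * math.cos(lat_r) * math.sin(lon_r)  # through lat 0, lon 90
    z = radius * math.sin(lat_r)                    # through the poles
    return x, y, z

for lat, lon in [(0.0, 0.0), (0.0, 90.0), (90.0, 0.0)]:
    # Round to hide floating-point residue such as 6e-17
    print((lat, lon), tuple(round(v, 6) for v in latlon_to_xyz(lat, lon)))
```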
- modulus.utils.graphcast.graph_utils.polar_angle(lat: Tensor) → Tensor[source]
Gives the polar angle of a point on the sphere
- Parameters
lat (Tensor) – Tensor of shape (N, ) containing the latitude of the point
- Returns
Tensor of shape (N, ) containing the polar angle
- Return type
Tensor
- modulus.utils.graphcast.graph_utils.rad2deg(rad)[source]
Converts radians to degrees
- Parameters
rad – Tensor of shape (N, ) containing the radians
- Returns
Tensor of shape (N, ) containing the degrees
- Return type
Tensor
- modulus.utils.graphcast.graph_utils.xyz2latlon(xyz: Tensor, radius: float = 1, unit: str = 'deg') → Tensor[source]
Converts xyz to latlon in degrees. Based on: https://stackoverflow.com/questions/1185408
- The x-axis goes through long,lat (0,0);
- The y-axis goes through (0,90);
- The z-axis goes through the poles.
- Parameters
xyz (Tensor) – Tensor of shape (N, 3) containing x, y, z coordinates
radius (float, optional) – Radius of the sphere, by default 1
unit (str, optional) – Unit of the latlon, by default “deg”
- Returns
Tensor of shape (N, 2) containing latitudes and longitudes
- Return type
Tensor
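Under the same axis convention, the inverse mapping can be sketched with an arcsine for latitude and a two-argument arctangent for longitude. This is a single-point illustration, not the library's code: xyz_to_latlon is a hypothetical name, and the clamp before asin is an added safeguard against rounding error.

```python
import math

def xyz_to_latlon(x, y, z, radius=1.0):
    """Hypothetical scalar version; returns (lat, lon) in degrees."""
    # Clamp so tiny rounding errors cannot push the ratio outside [-1, 1]
    ratio = max(-1.0, min(1.0, z / radius))
    lat = math.degrees(math.asin(ratio))
    lon = math.degrees(math.atan2(y, x))  # in the range [-180, 180]
    return lat, lon

print(xyz_to_latlon(0.0, 1.0, 0.0))
```

At the poles x and y are both zero, so longitude is undefined; atan2(0, 0) returns 0, which this sketch simply passes through.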
- modulus.utils.graphcast.icospheres.generate_and_save_icospheres(save_path: str = 'icospheres.json', level: int = 6) → None[source]
Generate icospheres from level 0 to 6 (inclusive) and save them to a json file.
- Parameters
save_path (str) – Path to save the json file.
- class modulus.utils.graphcast.loss.CellAreaWeightedLossFunction(area)[source]
Bases: Module
Loss function with cell area weighting.
- Parameters
area (torch.Tensor) – Cell area with shape [H, W].
- forward(invar, outvar)[source]
Implicit forward function which computes the loss given a prediction and the corresponding targets.
- Parameters
invar (torch.Tensor) – prediction of shape [T, C, H, W].
outvar (torch.Tensor) – target values of shape [T, C, H, W].
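The exact reduction used by this loss is not spelled out here, so the sketch below shows one plausible reading: an area-weighted mean squared error, where the squared error is averaged over T and C at each grid cell, scaled by that cell's area, and then averaged over the grid. The function name area_weighted_mse, the choice of squared error, and the normalization are all assumptions; nested Python lists stand in for tensors so the arithmetic stays visible.

```python
def area_weighted_mse(pred, target, area):
    """Hypothetical area-weighted MSE.

    pred, target: nested lists indexed [t][c][h][w]
    area: nested list indexed [h][w]
    """
    n_t, n_c = len(pred), len(pred[0])
    n_h, n_w = len(area), len(area[0])
    total = 0.0
    for h in range(n_h):
        for w in range(n_w):
            # Mean squared error over the T and C dimensions at this cell
            cell_mse = sum(
                (pred[t][c][h][w] - target[t][c][h][w]) ** 2
                for t in range(n_t)
                for c in range(n_c)
            ) / (n_t * n_c)
            total += area[h][w] * cell_mse
    # Average the weighted errors over all grid cells
    return total / (n_h * n_w)

pred = [[[[1.0, 2.0]]]]    # T=1, C=1, H=1, W=2
target = [[[[0.0, 0.0]]]]
area = [[1.0, 0.5]]
print(area_weighted_mse(pred, target, area))  # 1.5
```

In the usage above, the second cell has four times the squared error of the first but half the weight, so it contributes 2.0 to the sum against 1.0 from the first cell.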
- class modulus.utils.graphcast.loss.CustomCellAreaWeightedLossAutogradFunction(*args, **kwargs)[source]
Bases: Function
Autograd function for custom loss with cell area weighting.
- static backward(ctx, grad_loss: Tensor)[source]
Backward method of custom loss function with cell area weighting.
- static forward(ctx, invar: Tensor, outvar: Tensor, area: Tensor)[source]
Forward of custom loss function with cell area weighting.
- class modulus.utils.graphcast.loss.CustomCellAreaWeightedLossFunction(area: Tensor)[source]
Bases: CellAreaWeightedLossFunction
Custom loss function with cell area weighting.
- Parameters
area (torch.Tensor) – Cell area with shape [H, W].
- forward(invar: Tensor, outvar: Tensor) → Tensor[source]
Implicit forward function which computes the loss given a prediction and the corresponding targets.
- Parameters
invar (torch.Tensor) – prediction of shape [T, C, H, W].
outvar (torch.Tensor) – target values of shape [T, C, H, W].
- class modulus.utils.filesystem.Package(root: str, seperator: str = '/')[source]
Bases: object
A generic file system abstraction. Can be used to represent local and remote file systems. Remote files are automatically fetched and stored in the $LOCAL_CACHE or $HOME/.cache/modulus folder. The get method can then be used to access files present.
Presently one can use Package with the following directories:
- Package("/path/to/local/directory") = local file system
- Package("s3://bucket/path/to/directory") = object store file system
- Package("http://url/path/to/directory") = http file system
- Package("ngc://model/<org_id/team_id/model_id>@<version>") = ngc model file system
- Parameters
root (str) – Root directory for file system
seperator (str, optional) – Directory separator, by default "/"
- get(path: str, recursive: bool = False) → str[source]
Get a local path to the item at path. path might be a remote file, in which case it is first downloaded to a local cache at $LOCAL_CACHE or $HOME/.cache/modulus.