HealDA — AI-based Data Assimilation on the HEALPix Grid#
🏗️ This recipe is under active construction. 🏗️ Structure and functionality are subject to changes.
HealDA is a stateless assimilation model that produces a single global weather analysis from conventional and satellite observations. It operates on a HEALPix level-6 padded XY grid and outputs ERA5-compatible atmospheric variables.
Setup#
Start by installing PhysicsNeMo (if not already installed) with the healda optional dependency group, along with the packages in requirements.txt. Then, copy this folder (examples/weather/healda) to a system with a GPU available. Also, prepare a dataset that can serve training data according to the protocols outlined in the Generalized Data Loading section below.
Normalization statistics#
Per-sensor observation stats (configs/normalizations/*.csv) and the ERA5 stats table (configs/era5_13_levels_stats.csv) ship with this recipe rather than the installed package, since they are training-set specific. Point the datapipe at them by setting HEALDA_STATS_DIR before importing physicsnemo.experimental.datapipes.healda:
export HEALDA_STATS_DIR=$(pwd)/configs
If unset, sensor configs fall back to zero-mean / unit-std (useful for tests and structural checks); the ERA5 stats loader will raise instead, since there is no sensible default.
Generalized Data Loading#
The physicsnemo.experimental.datapipes.healda package provides a composable data loading pipeline with clear extension points. The architecture separates components into loaders, transforms, datasets, and sampling infrastructure.
Architecture#
ObsERA5Dataset(era5_data, obs_loader, transform)
| Temporal windowing via FrameIndexGenerator
| __getitems__ -> get() per index -> transform.transform()
v
RestartableDistributedSampler (stateful distributed sampling with checkpointing)
|
DataLoader (pin_memory, persistent_workers)
|
prefetch_map(loader, transform.device_transform)
|
Training loop (GPU-ready batch)
Key Protocols#
Custom data sources and transforms plug in via these protocols (see physicsnemo.experimental.datapipes.healda.protocols):
``ObsLoader`` — the observation loading interface:
class MyObsLoader:
async def sel_time(self, times):
"""Return {"obs": [pa.Table, ...]}"""
...
``Transform`` / ``DeviceTransform`` — two-stage batch processing:
class MyTransform:
def transform(self, times, frames):
"""CPU-side: normalize, encode obs, time features."""
...
def device_transform(self, batch, device):
"""GPU-side: move to device, compute obs features."""
...
Provided Implementations#
Component |
Module |
Description |
|---|---|---|
|
|
ERA5 state + observations |
|
|
Parquet obs loader |
|
|
Async ERA5 zarr loader |
|
|
Two-stage transform |
|
|
Stateful distributed sampler |
|
|
CUDA stream prefetching |
All modules above are under physicsnemo.experimental.datapipes.healda.
Writing a Custom Observation Loader#
Implement async def sel_time(times) returning a dict with observation data per timestamp:
class GOESRadianceLoader:
def __init__(self, data_path, channels):
self.data_path = data_path
self.channels = channels
async def sel_time(self, times):
tables = []
for t in times:
table = self._load_goes_radiances(t)
tables.append(table)
return {"obs": tables}
Then pass it to the dataset:
from physicsnemo.experimental.datapipes.healda import (
ObsERA5Dataset,
)
from physicsnemo.experimental.datapipes.healda.transforms.era5_obs import (
ERA5ObsTransform,
)
from physicsnemo.experimental.datapipes.healda.configs.variable_configs import (
VARIABLE_CONFIGS,
)
dataset = ObsERA5Dataset(
era5_data=era5_xr["data"],
obs_loader=GOESRadianceLoader(...),
transform=ERA5ObsTransform(sensors=["goes"], ...),
variable_config=VARIABLE_CONFIGS["era5"],
)
Putting It Together#
A complete training pipeline wires together all the components — dataset, sampler, DataLoader, and GPU prefetch:
import torch
from torch.utils.data import DataLoader
from physicsnemo.experimental.datapipes.healda import (
ObsERA5Dataset,
RestartableDistributedSampler,
identity_collate,
prefetch_map,
)
from physicsnemo.experimental.datapipes.healda.loaders.ufs_obs import (
UFSUnifiedLoader,
)
from physicsnemo.experimental.datapipes.healda.transforms.era5_obs import (
ERA5ObsTransform,
)
from physicsnemo.experimental.datapipes.healda.configs.variable_configs import (
VARIABLE_CONFIGS,
)
sensors = ["atms", "mhs", "conv"]
# 1. Build loaders
obs_loader = UFSUnifiedLoader(
data_path="/path/to/processed_obs",
sensors=sensors,
obs_context_hours=(-21, 3),
)
transform = ERA5ObsTransform(
variable_config=VARIABLE_CONFIGS["era5"],
sensors=sensors,
)
# 2. Build dataset
dataset = ObsERA5Dataset(
era5_data=era5_xr["data"],
obs_loader=obs_loader,
transform=transform,
variable_config=VARIABLE_CONFIGS["era5"],
split="train",
)
# 3. Sampler + DataLoader
sampler = RestartableDistributedSampler(
dataset, rank=rank, num_replicas=world_size,
)
sampler.set_epoch(0)
dataloader = DataLoader(
dataset,
sampler=sampler,
batch_size=2,
num_workers=8,
collate_fn=identity_collate,
pin_memory=True,
persistent_workers=True,
)
# 4. GPU prefetch (hides CPU→GPU transfer behind training)
device = torch.device("cuda")
loader = prefetch_map(
dataloader,
lambda batch: transform.device_transform(batch, device),
queue_size=1,
)
# 5. Training loop — batches arrive GPU-ready
for batch in loader:
loss = model(batch)
...