Quickstart Guide for NVIDIA Earth-2 Correction Diffusion NIM#

Use this documentation to get started with NVIDIA Earth-2 Correction Diffusion (CorrDiff) NIM.

Important

Before you can use this documentation, you must satisfy all prerequisites.

Warning

This quick start guide is designed for use with NVIDIA A100 or H100 GPUs. GPUs with less VRAM or compute may require lowering EARTH2NIM_TARGET_BATCHSIZE or increasing the timeout in the POST request. See the configuration page for different deployment options.

Launching the NIM#

  1. Pull the NIM container with the following command.

Note

The container is ~26GB (uncompressed) and the download time depends on internet connection speeds.

docker pull nvcr.io/nim/nvidia/corrdiff:1.0.0
  2. Run the NIM container with the following command. This command starts the NIM container and exposes port 8000 so that you can interact with the NIM. It also downloads the model to the local filesystem.

Note

The model is ~4GB and the download time depends on internet connection speeds.

export NGC_API_KEY=<NGC API Key>

docker run --rm --runtime=nvidia --gpus all --shm-size 2g \
    -p 8000:8000 \
    -e NGC_API_KEY \
    -t nvcr.io/nim/nvidia/corrdiff:1.0.0

After the NIM is running, you see output similar to the following:

I0829 05:18:42.152983 108 grpc_server.cc:2466] "Started GRPCInferenceService at 0.0.0.0:8001"
I0829 05:18:42.153277 108 http_server.cc:4638] "Started HTTPService at 0.0.0.0:8090"
I0829 05:18:42.222628 108 http_server.cc:320] "Started Metrics Service at 0.0.0.0:8002"

Checking NIM Health#

  1. Open a new terminal, leaving the current terminal open with the launched service.

  2. Wait until the health check endpoint returns {"status":"ready"} before proceeding. This might take a couple of minutes. Use one of the following methods to query the health check.

Bash#

curl -X 'GET' \
    'http://localhost:8000/v1/health/ready' \
    -H 'accept: application/json'

Python#

import requests

r = requests.get("http://localhost:8000/v1/health/ready")
if r.status_code == 200:
   print("NIM is healthy!")
else:
   print("NIM is not ready!")
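
Because the NIM can take a couple of minutes to become ready, a polling loop can be more convenient than a single request. The following is a minimal sketch that uses only the Python standard library. The helper names `probe_ready` and `wait_until_ready`, and the default timeout and interval values, are illustrative assumptions and not part of the NIM or any NVIDIA client library.

```python
import json
import time
import urllib.request


def probe_ready(url: str) -> bool:
    """Return True if the endpoint answers 200 with {"status": "ready"}."""
    try:
        with urllib.request.urlopen(url, timeout=5) as resp:
            body = json.loads(resp.read().decode("utf-8"))
            return (
                resp.status == 200
                and isinstance(body, dict)
                and body.get("status") == "ready"
            )
    except (OSError, ValueError):
        # Connection refused, HTTP error, timeout, or non-JSON body:
        # treat all of these as "not ready yet"
        return False


def wait_until_ready(url, timeout_s=600.0, interval_s=5.0, probe=probe_ready):
    """Poll probe(url) until it returns True or timeout_s elapses.

    The default values are hypothetical; tune them for your hardware.
    """
    deadline = time.monotonic() + timeout_s
    while True:
        if probe(url):
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval_s)
```

To use it against a local deployment, call `wait_until_ready("http://localhost:8000/v1/health/ready")` and proceed only when it returns `True`. The `probe` argument is injectable so that the loop can be exercised without a running server.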

Fetching Input Data#

The CorrDiff NIM launches with a default model profile trained for downscaling over the contiguous United States. The input to this model is GEFS weather data at 0.25-degree (~25km) resolution over the United States, sourced from NOAA. CorrDiff uses this input data to generate downscaled weather fields at a 3km resolution, similar to NOAA’s HRRR model. For additional details, refer to the model card.

To simplify data preparation, we use Earth2Studio to fetch and format the input data as follows:

  1. Fetch a set of select GEFS variables at 0.25-degree resolution on a lat-lon grid.

  2. Fetch a set of regular GEFS variables at 0.5-degree resolution and interpolate them to the 0.25-degree lat-lon grid.

  3. Crop both sets of variables to a bounding box over the United States.

  4. Concatenate both sets of variables as well as an integer field that denotes the forecast lead time the input represents.
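
The cropping step relies on fixed array indices for the bounding box [225, 21, 300, 53] (lon min, lat min, lon max, lat max). The sketch below shows how such indices can be derived. It assumes a global 0.25-degree grid whose latitude runs from 90 down to -90 and whose longitude starts at 0, which is consistent with the 721x1440 grid and the slice indices used in this guide. The helper name `bbox_to_slices` is hypothetical.

```python
def bbox_to_slices(lon_min, lat_min, lon_max, lat_max, res=0.25):
    """Convert a lat-lon bounding box to inclusive row/column slices.

    Assumes rows are ordered north to south (90 to -90 degrees) and
    columns west to east starting at 0 degrees longitude.
    """
    row_start = round((90.0 - lat_max) / res)
    row_stop = round((90.0 - lat_min) / res) + 1  # +1 keeps the edge row
    col_start = round(lon_min / res)
    col_stop = round(lon_max / res) + 1  # +1 keeps the edge column
    return slice(row_start, row_stop), slice(col_start, col_stop)


rows, cols = bbox_to_slices(225, 21, 300, 53)
# Number of grid points along each axis of the cropped region
print(rows.stop - rows.start, cols.stop - cols.start)  # prints "129 301"
```

Under these assumptions the slices come out as 148:277 and 900:1201, giving a 129 x 301 crop.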

Use the script below to download and structure the data:

from datetime import datetime, timedelta

import numpy as np
import torch
from earth2studio.data import GEFS_FX, GEFS_FX_721x1440

GEFS_SELECT_VARIABLES = [
    "u10m",
    "v10m",
    "t2m",
    "r2m",
    "sp",
    "msl",
    "tcwv",
]

GEFS_VARIABLES = [
    "u1000",
    "u925",
    "u850",
    "u700",
    "u500",
    "u250",
    "v1000",
    "v925",
    "v850",
    "v700",
    "v500",
    "v250",
    "z1000",
    "z925",
    "z850",
    "z700",
    "z500",
    "z200",
    "t1000",
    "t925",
    "t850",
    "t700",
    "t500",
    "t100",
    "r1000",
    "r925",
    "r850",
    "r700",
    "r500",
    "r100",
]

ds_gefs = GEFS_FX(cache=True)
ds_gefs_select = GEFS_FX_721x1440(cache=True, product="gec00")


def fetch_input_gefs(
    time: datetime, lead_time: timedelta, content_dtype: str = "float32"
):
    """Fetch input GEFS data and place it into a single NumPy array.

    Parameters
    ----------
    time : datetime
        Time stamp to fetch
    lead_time : timedelta
        Lead time to fetch
    content_dtype : str
        Name of the NumPy dtype to cast the array to

    Returns
    -------
    np.ndarray
        Input array of shape (1, 1, 38, 129, 301)
    """
    dtype = np.dtype(getattr(np, content_dtype))
    # Fetch high-res select GEFS input data
    select_data = ds_gefs_select(time, lead_time, GEFS_SELECT_VARIABLES)
    select_data = select_data.values
    # Crop to bounding box [225, 21, 300, 53]
    select_data = select_data[:, 0, :, 148:277, 900:1201].astype(dtype)
    assert select_data.shape == (1, len(GEFS_SELECT_VARIABLES), 129, 301)

    # Fetch GEFS input data
    pressure_data = ds_gefs(time, lead_time, GEFS_VARIABLES)
    # Interpolate to 0.25 grid
    pressure_data = torch.nn.functional.interpolate(
        torch.Tensor(pressure_data.values),
        (len(GEFS_VARIABLES), 721, 1440),
        mode="nearest",
    )
    pressure_data = pressure_data.numpy()
    # Crop to bounding box [225, 21, 300, 53]
    pressure_data = pressure_data[:, 0, :, 148:277, 900:1201].astype(dtype)
    assert pressure_data.shape == (1, len(GEFS_VARIABLES), 129, 301)

    # Create lead time field
    lead_hour = int(lead_time.total_seconds() // (3 * 60 * 60)) * np.ones(
        (1, 1, 129, 301)
    ).astype(dtype)

    input_data = np.concatenate([select_data, pressure_data, lead_hour], axis=1)[None]
    return input_data 


input_array = fetch_input_gefs(datetime(2023, 1, 1), timedelta(hours=0))
np.save("corrdiff_inputs.npy", input_array)

Running the script will generate the NumPy array corrdiff_inputs.npy, which is now ready to use with the model.
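
Before sending the array to the NIM, it can help to confirm that its shape matches what the script produces: 7 select variables, 30 pressure-level variables, and 1 lead-time field on a 129 x 301 grid, with two leading singleton dimensions. The following is a small sketch of such a check. The constant and function names are hypothetical, and the expected shape is derived from the script in this guide rather than from the API specification.

```python
N_SELECT = 7      # len(GEFS_SELECT_VARIABLES) in the script
N_PRESSURE = 30   # len(GEFS_VARIABLES) in the script
N_LEAD = 1        # single lead-time field
GRID = (129, 301)  # cropped lat x lon grid


def expected_input_shape():
    """Shape produced by fetch_input_gefs in this guide."""
    return (1, 1, N_SELECT + N_PRESSURE + N_LEAD, *GRID)


def check_input_shape(shape):
    """Raise ValueError if `shape` differs from the expected input shape."""
    expected = expected_input_shape()
    if tuple(shape) != expected:
        raise ValueError(f"Expected shape {expected}, got {tuple(shape)}")
    return expected


print(expected_input_shape())  # prints "(1, 1, 38, 129, 301)"
```

With NumPy available, the saved file can be checked with `check_input_shape(np.load("corrdiff_inputs.npy").shape)`.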

Inference Request#

With the pre-processed input for the CorrDiff US model in hand, you can now call the NIM API. The NumPy array is sent to the NIM as a file, along with several other parameters that control the underlying diffusion model. For complete documentation on the API specification of the NIM, visit the API documentation page.

Bash#

curl -X POST \
    -F "input_array=@corrdiff_inputs.npy" \
    -F "samples=2" \
    -F "steps=12" \
    -o output.tar \
    http://localhost:8000/v1/infer

Python#

import requests

url = "http://localhost:8000/v1/infer"
files = {
    "input_array": ("input_array", open("corrdiff_inputs.npy", "rb")),
}
data = {
    "samples": 2,
    "steps": 14,
    "seed": 0,
}
headers = {
    "accept": "application/x-tar",
}
print("Sending post request to NIM")

# Adjust timeout (seconds) below
r = requests.post(url, headers=headers, data=data, files=files, timeout=180)

if r.status_code != 200:
    raise Exception(r.content)
else:
    # Dump response to file
    with open("output.tar", "wb") as tar:
        tar.write(r.content)

Results#

The results are inside a tar file that can be explored using:

tar -tvf output.tar

-rw-r--r-- 0/0        60555392 1970-01-01 00:00 000_000.npy
-rw-r--r-- 0/0        60555392 1970-01-01 00:00 001_000.npy

Tip

The tar archive is populated with NumPy arrays that have the naming convention {sample index}_{batch index}.npy.
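
As an illustration of this naming convention, the sketch below splits a member name into its two integer indices. The function name `parse_member_name` is hypothetical and only the `{sample index}_{batch index}.npy` pattern described above is assumed.

```python
from pathlib import PurePosixPath


def parse_member_name(name):
    """Return (sample_index, batch_index) for a name like "001_000.npy"."""
    stem = PurePosixPath(name).stem  # drop the ".npy" extension
    sample_str, batch_str = stem.split("_", maxsplit=1)
    return int(sample_str), int(batch_str)


print(parse_member_name("001_000.npy"))  # prints "(1, 0)"
```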

Important

For most inference calls, it is important to specify a longer timeout period for a request to this NIM. Refer to the performance data for estimated inference speeds on different GPU models.

Streaming Responses#

The CorrDiff NIM can stream samples back as they are generated when the client requests more samples than the NIM's inference batch size. This is useful when building a pipeline that runs downstream processing on each sample. The following snippet accesses each sample in memory as the NIM generates it.

Note

The environment variable EARTH2NIM_TARGET_BATCHSIZE governs the batch size used inside the NIM. When this value is lower than the requested sample size, streaming the results back can be useful. See the configuration guide for more details on this parameter.

import io
import tarfile
from pathlib import Path

import numpy as np
import requests
import tqdm

url = "http://localhost:8000/v1/infer"
files = {
    "input_array": ("input_array", open("corrdiff_inputs.npy", "rb")),
}
data = {
    "samples": 16,
    "steps": 8,
    "seed": 0,
}
headers = {
    "accept": "application/x-tar",
}

pbar = tqdm.tqdm(range(data["samples"]), desc="CorrDiff samples")
with (
    requests.post(
        url,
        headers=headers,
        files=files,
        data=data,
        timeout=300,
        stream=True,
    )
) as resp:
    resp.raise_for_status()
    with tarfile.open(fileobj=resp.raw, mode="r|") as tar_stream:
        # Loop over file members from tar stream
        for member in tar_stream:
            arr_file = io.BytesIO()
            arr_file.write(tar_stream.extractfile(member).read())
            arr_file.seek(0)
            arr = np.load(arr_file)
            # Array names are in the form <SAMPLE_IDX>_<BATCH_IDX>.npy
            arr_sample_idx, _ = (
                int(x) for x in Path(member.name).stem.split("_", maxsplit=1)
            )

            pbar.write(f"Received data for sample {arr_sample_idx}")
            pbar.write(f"Output numpy {member.name} with shape {arr.shape}")
            pbar.update(1)
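
To try the stream-reading pattern without a running NIM, the tar stream can be simulated in memory. The sketch below is only an offline stand-in: `build_fake_response` and `iter_samples` are hypothetical helpers, and the tiny array shape is made up for illustration and does not reflect the real output shape of the NIM.

```python
import io
import tarfile

import numpy as np


def build_fake_response(n_samples, shape=(1, 1, 2, 4, 4)):
    """Build an in-memory tar of .npy members named <sample>_<batch>.npy."""
    buf = io.BytesIO()
    with tarfile.open(fileobj=buf, mode="w") as tar:
        for i in range(n_samples):
            payload = io.BytesIO()
            np.save(payload, np.full(shape, i, dtype=np.float32))
            info = tarfile.TarInfo(name=f"{i:03d}_000.npy")
            info.size = payload.tell()  # bytes written by np.save
            payload.seek(0)
            tar.addfile(info, payload)
    buf.seek(0)
    return buf


def iter_samples(fileobj):
    """Yield (member name, array) pairs from a tar stream, one at a time."""
    # Mode "r|" reads the archive sequentially, as with resp.raw
    with tarfile.open(fileobj=fileobj, mode="r|") as tar_stream:
        for member in tar_stream:
            raw = tar_stream.extractfile(member).read()
            yield member.name, np.load(io.BytesIO(raw))


received = list(iter_samples(build_fake_response(3)))
for name, arr in received:
    print(name, arr.shape)
```

Because `iter_samples` is a generator, each array can be handed to downstream processing as soon as it is read, which mirrors how the streaming request above is consumed.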

Post Processing#

The final step is to perform some basic visualization of the results. The surface temperature is plotted from both the input and output arrays.

Note

The CorrDiff NIM returns the raw output data array. Additional metadata must be processed client side using information in the model card. The latitude and longitude coordinate values can be obtained from the corrdiff_output_lat.npy and corrdiff_output_lon.npy files in the model package.

import matplotlib.pyplot as plt
import numpy as np
import tarfile
import io

# Output variables: [u10m, v10m, t2m, tp, csnow, cicep, cfrzr, crain]
variable = 2

# Use a name that does not shadow the built-in input()
input_array = np.load("corrdiff_inputs.npy")
fig, ax = plt.subplots(1,1, figsize=(5,3))

ax.imshow(input_array[0,0,variable], cmap="gnuplot")
ax.set_yticklabels([])
ax.set_xticklabels([])

fig.tight_layout()
plt.savefig("input_t2m.png")

with tarfile.open("output.tar") as tar:
    for i, member in enumerate(tar.getmembers()):
        arr_file = io.BytesIO()
        arr_file.write(tar.extractfile(member).read())
        arr_file.seek(0)
        data = np.load(arr_file)

        plt.close("all")
        fig, ax = plt.subplots(1,1, figsize=(5,3))
        ax.set_title(f"Sample: {i}")
        ax.imshow(data[0,0,variable], origin='lower', cmap="gnuplot")
        ax.set_yticklabels([])
        ax.set_xticklabels([])
        fig.tight_layout()
        plt.savefig(f"output_t2m_sample_{i}.png")

Input surface temperature contour at ~25km resolution:

[Figure: CorrDiff input]

Output surface temperature samples downscaled using CorrDiff to 3km resolution:

[Figure: two CorrDiff output samples]