# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Defines helper Plotter class for adding plots to tensorboard summaries
"""
import numpy as np
import scipy
import matplotlib.pyplot as plt
from typing import Dict
class _Plotter:
    """Base class for plotters that save matplotlib figures and log them to TensorBoard.

    Subclasses implement ``__call__`` and return a list of ``(figure, tag)``
    pairs; ``_add_figures`` saves each figure as a PNG and logs it.
    """

    def __call__(self, *args):
        """Build figures from the given data; must be implemented by subclasses."""
        raise NotImplementedError

    def _add_figures(self, group, name, results_dir, writer, step, *args):
        """Try to make plots and write them to the tensorboard summary.

        Calls ``self(*args)`` and, for every returned ``(figure, tag)`` pair,
        saves the figure to ``results_dir + name + "_" + tag + ".png"`` and
        logs it with ``writer.add_figure`` under ``group + "/" + name + "/" + tag``.

        Parameters
        ----------
        group, name : str
            Used to build both the PNG filename and the summary tag.
        results_dir : str
            Prefix concatenated verbatim in front of the filename (it is not
            path-joined, so it should end with a separator).
        writer :
            Object exposing ``add_figure(tag, figure, step, close=...)``
            (presumably a TensorBoard ``SummaryWriter`` — TODO confirm).
        step : int
            Step recorded with each figure.
        *args :
            Forwarded unchanged to ``__call__``.
        """
        # plotting is best-effort: a failing (possibly user-defined) __call__
        # is reported and skipped instead of aborting the caller
        try:
            fs = self(*args)
        except Exception as e:
            print(f"error: {self}.__call__ raised an exception:", str(e))
            # drop figures created before the failure so they do not pile up
            plt.close("all")
            return

        # a user-defined __call__ may return None instead of an empty list
        for f, tag in fs or ():
            f.savefig(
                results_dir + name + "_" + tag + ".png",
                bbox_inches="tight",
                pad_inches=0.1,
            )
            writer.add_figure(group + "/" + name + "/" + tag, f, step, close=True)
        # close any remaining figures only after every figure has been written
        plt.close("all")

    def _interpolate_2D(self, size, invar, *outvars):
        """Interpolate 2D outvar solutions onto a regular mesh.

        Parameters
        ----------
        size : int
            Number of grid points along each axis.
        invar : dict
            Exactly two coordinate arrays; column 0 of each is used
            (values are indexed as ``v[:, 0]``). Dict order gives (x, y).
        *outvars : dict
            Output dicts sampled at the same points as ``invar``; column 0
            of each value is interpolated.

        Returns
        -------
        list
            ``[extent, *interpolated]`` where ``extent`` is
            ``(xmin, xmax, ymin, ymax)`` and each interpolated dict maps the
            same keys to ``(size, size)`` arrays indexed ``[ix, iy]``. Grid
            points outside the convex hull of the samples are NaN
            (``griddata``'s default ``fill_value``).
        """
        # ``import scipy`` alone does not load the ``interpolate`` submodule
        # on older SciPy releases, so import it explicitly
        import scipy.interpolate

        assert len(invar) == 2, "2D interpolation requires exactly two input variables"

        # define regular mesh spanning the bounding box of the samples
        xs = [invar[k][:, 0] for k in invar]
        extent = (xs[0].min(), xs[0].max(), xs[1].min(), xs[1].max())
        xyi = np.meshgrid(
            np.linspace(extent[0], extent[1], size),
            np.linspace(extent[2], extent[3], size),
            indexing="ij",
        )

        # interpolate every variable of every outvar dict onto the mesh
        outvars_interp = []
        for outvar in outvars:
            outvars_interp.append(
                {
                    k: scipy.interpolate.griddata(
                        (xs[0], xs[1]), outvar[k][:, 0], tuple(xyi)
                    )
                    for k in outvar
                }
            )
        return [extent] + outvars_interp
class ValidatorPlotter(_Plotter):
    """Default plotter class for validator.

    For every predicted output variable a figure with three panels is drawn:
    the true values, the predicted values and their difference.
    """

    def __call__(self, invar, true_outvar, pred_outvar):
        """Plot true / predicted / difference panels for 1D or 2D inputs.

        Parameters
        ----------
        invar : dict
            Input coordinates keyed by name; values are indexed as ``v[:, 0]``
            (presumably shape ``(N, 1)`` — TODO confirm).
        true_outvar, pred_outvar : dict
            True and predicted outputs keyed by name, indexed like ``invar``.

        Returns
        -------
        list of (Figure, str)
            One ``(figure, output_name)`` pair per key of ``pred_outvar``, or
            an empty list when there are more than two input dimensions.
        """
        n_dims = len(invar)
        if n_dims > 2:
            print("Default plotter can only handle <=2 input dimensions, passing")
            return []

        if n_dims == 2:
            # scattered 2D samples are resampled onto a 100x100 grid for imshow
            extent, true_outvar, pred_outvar = self._interpolate_2D(
                100, invar, true_outvar, pred_outvar
            )

        axis_names = list(invar.keys())
        figures = []
        for name in pred_outvar:
            fig = plt.figure(figsize=(3 * 5, 4), dpi=100)
            true_vals = true_outvar[name]
            pred_vals = pred_outvar[name]
            panels = (
                ("true", true_vals),
                ("pred", pred_vals),
                ("diff", true_vals - pred_vals),
            )
            for position, (label, values) in enumerate(panels, start=1):
                plt.subplot(1, 3, position)
                if n_dims == 1:
                    plt.plot(invar[axis_names[0]][:, 0], values[:, 0])
                    plt.xlabel(axis_names[0])
                elif n_dims == 2:
                    plt.imshow(values.T, origin="lower", extent=extent)
                    plt.xlabel(axis_names[0])
                    plt.ylabel(axis_names[1])
                    plt.colorbar()
                plt.title(f"{name}_{label}")
            plt.tight_layout()
            figures.append((fig, name))
        return figures
class InferencerPlotter(_Plotter):
    """Default plotter class for inferencer.

    Draws one single-panel figure per output variable: a line plot for 1D
    inputs or an interpolated image for 2D inputs.
    """

    def __call__(self, invar, outvar):
        """Default function for plotting inferencer data.

        Parameters
        ----------
        invar : dict
            Input coordinates keyed by name; values are indexed as ``v[:, 0]``
            (presumably shape ``(N, 1)`` — TODO confirm).
        outvar : dict
            Inferred outputs keyed by name, indexed like ``invar``.

        Returns
        -------
        list of (Figure, str)
            One ``(figure, output_name)`` pair per key of ``outvar``, or an
            empty list when there are more than two input dimensions.
        """
        ndim = len(invar)
        if ndim > 2:
            print("Default plotter can only handle <=2 input dimensions, passing")
            return []

        # interpolate 2D data onto a regular grid so it can be shown with imshow
        if ndim == 2:
            extent, outvar = self._interpolate_2D(100, invar, outvar)

        # make plots
        dims = list(invar.keys())
        fs = []
        for k in outvar:
            f = plt.figure(figsize=(5, 4), dpi=100)
            if ndim == 1:
                # plot this variable's values; ``outvar`` itself is a dict
                plt.plot(invar[dims[0]][:, 0], outvar[k][:, 0])
                plt.xlabel(dims[0])
            elif ndim == 2:
                plt.imshow(outvar[k].T, origin="lower", extent=extent)
                plt.xlabel(dims[0])
                plt.ylabel(dims[1])
                plt.colorbar()
            plt.title(k)
            plt.tight_layout()
            fs.append((f, k))
        return fs
class GridValidatorPlotter(_Plotter):
    """Grid validation plotter for structured data.

    Arrays are indexed as ``[example, 1, *spatial]``; only index 0 of axis 1
    is plotted (presumably the channel axis — TODO confirm). For each of the
    first ``n_examples`` examples a figure is drawn with one column each for
    inputs, true outputs, predictions and their difference, and one row per
    variable.
    """

    def __init__(self, n_examples: int = 1):
        """
        Parameters
        ----------
        n_examples : int, optional
            Maximum number of examples (leading array index) to plot, by default 1.
        """
        self.n_examples = n_examples

    def __call__(
        self,
        invar: Dict[str, np.ndarray],
        true_outvar: Dict[str, np.ndarray],
        pred_outvar: Dict[str, np.ndarray],
    ):
        """Plot inputs, true outputs, predictions and differences per example.

        Returns
        -------
        list of (Figure, str)
            ``(figure, "prediction_<i>")`` pairs, or an empty list when the
            data has more than three spatial dimensions.
        """
        first = next(iter(invar.values()))
        # leading two axes are (example, channel); the rest are spatial
        ndim = first.ndim - 2
        if ndim > 3:
            print("Default plotter can only handle <=3 input dimensions, passing")
            return []

        # get difference
        diff_outvar = {k: true_outvar[k] - pred_outvar[k] for k in true_outvar}

        # never ask for more examples than every array actually holds
        n_available = min(
            v.shape[0] for d in (invar, true_outvar, pred_outvar) for v in d.values()
        )
        fs = []
        for ie in range(min(self.n_examples, n_available)):
            f = self._make_plot(ndim, ie, invar, true_outvar, pred_outvar, diff_outvar)
            fs.append((f, f"prediction_{ie}"))
        return fs

    def _make_plot(self, ndim, ie, invar, true_outvar, pred_outvar, diff_outvar):
        """Build the in/true/pred/diff grid figure for example ``ie``."""
        nrows = max(len(invar), len(true_outvar))
        f = plt.figure(figsize=(4 * 5, nrows * 4), dpi=100)
        for ic, (d, tag) in enumerate(
            zip(
                [invar, true_outvar, pred_outvar, diff_outvar],
                ["in", "true", "pred", "diff"],
            )
        ):
            for ir, k in enumerate(d):
                plt.subplot2grid((nrows, 4), (ir, ic))
                is_image = ndim != 1
                if ndim == 1:
                    plt.plot(d[k][ie, 0, :])
                elif ndim == 2:
                    plt.imshow(d[k][ie, 0, :, :].T, origin="lower")
                else:
                    z = d[k].shape[-1] // 2  # middle Z slice
                    plt.imshow(d[k][ie, 0, :, :, z].T, origin="lower")
                plt.title(f"{k}_{tag}")
                # a colorbar needs a mappable: after a plain line plot
                # plt.colorbar() raises RuntimeError, so only add it for images
                if is_image:
                    plt.colorbar()
        plt.tight_layout()
        return f
class DeepONetValidatorPlotter(_Plotter):
    """DeepONet validation plotter for structured data.

    The first ``invar`` entry is taken as the trunk-net input, indexed
    ``[example, point, coord]``; outputs are indexed ``[example, point]``.
    For each of the first ``n_examples`` examples a figure with true,
    predicted and difference columns (one row per output) is drawn.
    """

    def __init__(self, n_examples: int = 1):
        """
        Parameters
        ----------
        n_examples : int, optional
            Maximum number of examples (leading array index) to plot, by default 1.
        """
        self.n_examples = n_examples

    def __call__(
        self,
        invar: Dict[str, np.ndarray],
        true_outvar: Dict[str, np.ndarray],
        pred_outvar: Dict[str, np.ndarray],
    ):
        """Plot true outputs, predictions and differences per example.

        Returns
        -------
        list of (Figure, str)
            ``(figure, "prediction_<i>")`` pairs, or an empty list when the
            trunk input has more than two coordinates.
        """
        first = next(iter(invar.values()))
        # trailing axis of the trunk-net input holds the point coordinates
        ndim = first.shape[-1]
        # _make_plot only implements 1D (line) and 2D (scatter) layouts
        if ndim > 2:
            print("Default plotter can only handle <=2 input dimensions, passing")
            return []

        # get difference
        diff_outvar = {k: true_outvar[k] - pred_outvar[k] for k in true_outvar}

        # never ask for more examples than every array actually holds
        n_available = min(
            v.shape[0] for d in (invar, true_outvar, pred_outvar) for v in d.values()
        )
        fs = []
        for ie in range(min(self.n_examples, n_available)):
            f = self._make_plot(ndim, ie, invar, true_outvar, pred_outvar, diff_outvar)
            fs.append((f, f"prediction_{ie}"))
        return fs

    def _make_plot(self, ndim, ie, invar, true_outvar, pred_outvar, diff_outvar):
        """Build the true/pred/diff figure for example ``ie``.

        invar: input of trunk net, indexed [example, point, coord].
        outvar: output of DeepONet, indexed [example, point].
        """
        # only outputs are drawn: one row per output, one column per panel
        nrows = max(1, len(true_outvar))
        ncols = 3
        f = plt.figure(figsize=(ncols * 5, nrows * 4), dpi=100)
        invar_data = next(iter(invar.values()))
        for ic, (d, tag) in enumerate(
            zip(
                [true_outvar, pred_outvar, diff_outvar],
                ["true", "pred", "diff"],
            )
        ):
            for ir, k in enumerate(d):
                plt.subplot2grid((nrows, ncols), (ir, ic))
                if ndim == 1:
                    plt.plot(invar_data[ie, :].flatten(), d[k][ie, :])
                elif ndim == 2:
                    # ``origin`` is an imshow option; scatter rejects it
                    plt.scatter(
                        x=invar_data[ie, :, 0],
                        y=invar_data[ie, :, 1],
                        c=d[k][ie, :],
                        s=0.5,
                        cmap="jet",
                    )
                    # colorbar needs the scatter mappable; after a line plot
                    # plt.colorbar() would raise RuntimeError
                    plt.colorbar()
                plt.title(f"{k}_{tag}")
        plt.tight_layout()
        return f