ai4med.common package

class BuildContext

Bases: dlmed.common.ctx.BaseContext

BuildContext contains information generated during graph construction. Graph building components can use the data in the BuildContext to share information with one another.

Returns

A BuildContext object

KEY_DATA_PROP = '_data_property'
KEY_GRAPH = '_graph'
KEY_IS_TRAIN = '_is_train'
KEY_LABEL_INPUT = '_label_input'
KEY_LEARNING_RATE = '_learning_rate'
KEY_LOSS = '_loss'
KEY_MODEL_INPUT = '_model_input'
KEY_MODEL_LOSS = '_model_loss'
KEY_MODEL_OUTPUT = '_model_output'
KEY_MULTI_GPU = '_multi_gpu'
KEY_PRINT_TENSORS = '_print_tensors'
KEY_STEP_TENSORS = '_step_tensors'
KEY_SUMMARY_TENSORS = '_summary_tensors'
KEY_TF_CONFIG = '_tf_config'
KEY_TF_SESSION = '_tf_session'
KEY_TRAIN_DATA_SOURCE = '_train_data_source'
KEY_UPDATE_OPS = '_update_ops'
KEY_VALIDATION_DATA_SOURCE = '_validation_data_source'
KEY_VALIDATION_OUTPUT_TENSORS = '_validation_output_tensors'
PROP_DUMP_DATA_ON_TRANSFORM_ERROR = '_dump_data_on_trans_err'
PROP_TRAINING_IMPOSSIBLE = '_training_impossible'
get(key: str, default=None)

Fetch the item with the specified key. If the item does not exist, return the specified default value.

Parameters
  • key (str) – key of the item to be fetched

  • default – default value to be returned if the item does not exist

Returns: value of the item, or the default value

static is_reserved(key: str)

Check whether the specified key is reserved. In general, key names starting with an underscore are reserved for framework use. Custom graph building components should avoid using reserved names.

Parameters

key (str) – name of the key

Returns: whether the key is reserved.

must_get(key)

Fetch the item with the specified key. If the item does not exist, an exception is raised.

Parameters

key – key of the item to be fetched

Returns: value of the item

Raises: KeyError if the specified item does not exist

must_set(key: str, value)

Forcefully set the value of the item with the specified key, regardless of whether the key is reserved.

Parameters
  • key (str) – key of the item to be set

  • value – value of the item to be set

set(key: str, value, check_conflict=False)

Set the value of the item with the specified key. If the key is not a string or is reserved, an exception is raised.

Parameters
  • key (str) – key of the item to be set

  • value – value of the item to be set

  • check_conflict – whether to check key conflict

Raises: AssertionError

set_print_tensor(key: str, value)
set_summary_tensor(key: str, value)
set_val_output_tensor(key: str, value)
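The key rules above (reserved underscore names, set versus must_set, get versus must_get) can be illustrated with a standalone sketch. It does not import ai4med: SketchBuildContext is a hypothetical stand-in that only mirrors the documented behaviour, and the meaning given to check_conflict (rejecting a key that already holds a value) is an assumption, since the reference does not define it.

```python
class SketchBuildContext:
    """Standalone sketch of BuildContext's documented key rules; not the ai4med class."""

    def __init__(self):
        self._items = {}

    @staticmethod
    def is_reserved(key: str) -> bool:
        # Documented convention: names starting with an underscore are reserved.
        return key.startswith("_")

    def set(self, key, value, check_conflict=False):
        # Non-string or reserved keys are rejected with AssertionError.
        if not isinstance(key, str):
            raise AssertionError(f"key must be a str, got {type(key).__name__}")
        if self.is_reserved(key):
            raise AssertionError(f"key {key!r} is reserved")
        # Assumption: a "conflict" means the key already holds a value.
        if check_conflict and key in self._items:
            raise AssertionError(f"key {key!r} is already set")
        self._items[key] = value

    def must_set(self, key, value):
        # Skips the reserved-name check, as documented for must_set.
        self._items[key] = value

    def get(self, key, default=None):
        return self._items.get(key, default)

    def must_get(self, key):
        # A plain dict lookup raises KeyError for a missing key.
        return self._items[key]


ctx = SketchBuildContext()
ctx.set("my_component.scale", 0.5)
ctx.must_set("_learning_rate", 1e-4)  # framework-style reserved key
try:
    ctx.set("_loss", 0.0)
except AssertionError as err:
    print("rejected:", err)  # rejected: key '_loss' is reserved
```

In this sketch, custom components use non-underscore keys with set, while only framework code would reach for must_set.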

This module defines commonly used constants.

class ActivationFunc

Bases: object

Commonly used activation function names.

LINEAR = 'linear'
SIGMOID = 'sigmoid'
SOFTMAX = 'softmax'
class DataElementKey

Bases: object

Data Element keys

ID = 'ID'
IMAGE = 'image'
IMAGE_SHAPE_FORMAT = 'image.shape_format'
LABEL = 'label'
WEIGHT = 'weight'
class DataListKey

Bases: object

Data List keys

TEST = 'test'
TRAIN = 'train'
VALIDATION = 'validation'
class DataPropKey

Bases: object

Data Property keys

CHANNELS = 'channels'
CLASS_WEIGHTS = 'class_weights'
DATA_FORMAT = 'data_format'
LABEL_CHANNELS = 'label_channels'
LABEL_FORMAT = 'label_format'
TASK = 'task'
class DataSampling

Bases: object

Names of common data sampling types

AUTOMATIC = 'automatic'
CLASS = 'class'
ELEMENT = 'element'
class FormalTensorNames

Bases: object

Globally unique formal names of certain tensors that may need to be identified in a different context (e.g. inference)

INPUT_LABEL = 'NV_LABEL_INPUT'
IS_TRAINING = 'NV_IS_TRAINING'
LOSS = 'NV_LOSS'
LR = 'NV_LEARNING_RATE'
MODEL_INPUT = 'NV_MODEL_INPUT'
MODEL_OUTPUT = 'NV_MODEL_OUTPUT'
OP_MINIMIZE = 'NV_OP_MINIMIZE'
class ImageOps

Bases: object

Names of common image interpolation methods.

INTERPOLATION_LINEAR = 'linear'
INTERPOLATION_NEAREST = 'nearest'
class ImageProperty

Bases: object

Key names for image properties.

AFFINE = 'affine'
AS_CANONICAL = 'as_canonical'
BACKGROUND_INDEX = 'background_index'
DATA = 'data'
DIRECTION = 'direction'
FILENAME = 'file_name'
FILE_EXT = 'file_ext'
FORMAT = 'file_format'
NIFTI_FORMAT = 'nii'
ORIGIN = 'origin'
ORIGINAL_AFFINE = 'original_affine'
ORIGINAL_SHAPE = 'original_shape'
ORIGINAL_SHAPE_FORMAT = 'original_shape_format'
SHAPE_FORMAT = 'shape_format'
SPACING = 'spacing'
class Task

Bases: object

Defines deep learning (DL) task names.

CLASSIFICATION = 'classification'
SEGMENTATION = 'segmentation'
VALID_TASKS = ['segmentation', 'classification']
static is_classification(task: str)

Check whether the task is classification

Parameters

task (str) – name of the task

Returns: whether it is classification task

static is_segmentation(task: str)

Check whether the task is segmentation

Parameters

task (str) – name of the task

Returns: whether it is segmentation task

static is_valid_task(task: str)

Check whether the specified task is valid

Parameters

task (str) – task name to be checked

Returns: boolean, whether the specified task is valid

class ValidationInputDataKeys

Bases: object

Data keys for validation input dict.

IMAGE = 'image'
IS_TRAINING = 'is_training'
class ValidationOutputDataKeys

Bases: object

Data keys for validation output dict

LABEL = 'label'
MODEL = 'model'
class DataFormat

Bases: object

Common data formats and convenience functions.

CHANNELS_FIRST = 'channels_first'
CHANNELS_LAST = 'channels_last'
GRAYSCALE = 'grayscale'
VALID_FORMATS = ['channels_first', 'channels_last', 'grayscale']
static is_channels_first(fmt: str)

Check whether the specified format is channels first.

Parameters

fmt (str) – data format to be checked

Returns: whether the specified format is channels first

static is_channels_last(fmt: str)

Check whether the specified format is channels last.

Parameters

fmt (str) – data format to be checked

Returns: whether the specified format is channels last

static is_grayscale(fmt: str)

Check whether the specified format is gray scale.

Parameters

fmt (str) – data format to be checked

Returns: whether the specified format is gray scale.

static is_valid_format(fmt: str)

Check whether the specified format is valid

Parameters

fmt (str) – data format to be checked

Returns: whether the specified format is valid

class DataProperty

Bases: object

DataProperty specifies the properties of training data.

Returns

A DataProperty object

determine_image_shape(dynamic_input_shape=False)
determine_label_shape()
get_crop_size()
get_data_format()
get_label_format()
get_number_of_channels()

Get the number of channels of the sample data

Returns: number of channels of the sample data

get_number_of_data_dims()

Get the number of data dimensions

Returns: number of dimensions of the sample data

get_number_of_label_channels()

Get the number of channels of the label data

Returns: number of channels or None

get_task()

Get the training task

Returns: the training task

is_channels_first()
is_classification_task()

Check whether the task is classification

Returns: whether the task is classification

is_segmentation_task()

Check whether the task is segmentation

Returns: whether the task is segmentation

set_crop_size(crop_size)
set_data_format(data_format)
set_label_format(label_format_vector)
set_number_of_channels(num: int)

Set the number of channels of the sample data

Parameters

num (int) – number of channels

set_number_of_data_dims(n: int)

Set the number of data dimensions

Parameters

n (int) – number of data dims

set_number_of_label_channels(num: int)

Set the number of channels of the label data

Parameters

num (int) – number of channels

set_task(task: str)

Set the training task.

Parameters

task (str) – value of the task

to_dict()
class GraphComponent

Bases: abc.ABC

This class defines the interface of graph building components.

abstract build(build_ctx: ai4med.common.build_ctx.BuildContext)

A graph building component must implement this method to build parts of the computation graph. While building, this method can use any objects in the build_ctx and put objects into the build_ctx. Therefore, different graph building components can use the build_ctx to share objects.

Parameters

build_ctx (BuildContext) – the build context.

Returns: tensor(s) built
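A minimal sketch of the sharing pattern described for build: one component publishes an object into the build context and a later component consumes it. This does not import ai4med: SketchGraphComponent mirrors only the documented abstract build method, a plain dict stands in for BuildContext, Python lists stand in for tensors, and ScaleInput and SumInput are hypothetical components.

```python
from abc import ABC, abstractmethod


class SketchGraphComponent(ABC):
    """Stand-in for the documented GraphComponent interface; not the ai4med class."""

    @abstractmethod
    def build(self, build_ctx):
        """Build part of the graph, reading from and writing to build_ctx."""


class ScaleInput(SketchGraphComponent):
    # Hypothetical component: reads a shared value and publishes a derived one.
    def __init__(self, factor):
        self.factor = factor

    def build(self, build_ctx):
        scaled = [x * self.factor for x in build_ctx["raw_input"]]
        build_ctx["scaled_input"] = scaled
        return scaled


class SumInput(SketchGraphComponent):
    # Hypothetical downstream component that consumes what ScaleInput shared.
    def build(self, build_ctx):
        total = sum(build_ctx["scaled_input"])
        build_ctx["total"] = total
        return total


ctx = {"raw_input": [1, 2, 3]}  # a plain dict stands in for BuildContext
for component in (ScaleInput(2), SumInput()):
    component.build(ctx)
print(ctx["total"])  # 12
```

The components never reference each other directly; the build order plus the shared keys is the only coupling, which is the point of passing build_ctx to every component. The shared keys avoid a leading underscore, following the reserved-name convention.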

This module defines label format related classes and convenience functions.

class LabelFormatInfo(total_num_labels: int, num_binary_labels: int, multi_class_labels: list, multi_class_indices, binary_indices: list)

Bases: object

Defines information about a label format.

Parameters
  • total_num_labels (int) – total number of labels

  • num_binary_labels (int) – number of binary labels

  • multi_class_labels (list) – List of multi-class labels.

  • multi_class_indices – vector containing the indices of multi-class labels

  • binary_indices (list) – vector containing the indices of binary labels

Returns

A LabelFormatInfo object

validate_label_format_vector(label_format_vector)

Validates the specified label format vector.

Parameters

label_format_vector – the label format to be validated

Returns: error message if the format is invalid; otherwise None

class MedicalImage(data: numpy.ndarray, shape_format: ai4med.common.shape_format.ShapeFormat, props=None)

Bases: object

Defines a standard structure for medical images.

Parameters
  • data – image data as a numpy array

  • shape_format – shape format of the data

  • props – a dict of image properties

Returns

A MedicalImage object

static from_dict(src_values: dict, data_key='data')

Creates a MedicalImage object from data in a dict.

Parameters
  • src_values (dict) – dict that contains source values

  • data_key (str) – key for the image data

Returns: a MedicalImage object

Raises: AssertionError if image data or shape format does not exist in the source value dict, or data and shape format are inconsistent.

get_batch_size()

Get batch size of the image

Returns: batch size, or 0 if the image is not batched.

get_data()

Get the image data

Returns: image data

get_number_of_channels()

Get the number of channels of the image

Returns: number of channels, or 0 if no channel info.

get_properties()

Get all properties of the medical image

Returns: dict of image properties

get_property(key: str, default=None)

Get value of the specified property

Parameters
  • key (str) – key of the property

  • default – default value if the property does not exist.

Returns: the value of the property, or the default value if property does not exist.

get_shape_format()

Get the shape format

Returns: shape format of the image

get_spatial_shape()

Get the spatial shape of the medical image

Returns: a tuple representing the spatial shape of the image
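The batch size, channel count and spatial shape above can all be read off the data array once the shape format is known. The following standalone sketch uses plain functions rather than the ai4med MedicalImage class; the function names echo the documented methods, but treating "consistent" as "one array dimension per format letter" is an assumption.

```python
import numpy as np

SPATIAL_AXES = "DHW"


def validate_data_and_shape_format(data: np.ndarray, shape_format: str) -> None:
    # Assumption: consistency means one array dimension per format letter.
    if not isinstance(data, np.ndarray):
        raise AssertionError("data must be a numpy ndarray")
    if data.ndim != len(shape_format):
        raise AssertionError(
            f"{data.ndim}-D data does not match shape format {shape_format!r}")


def get_batch_size(data: np.ndarray, shape_format: str) -> int:
    # Documented behaviour: 0 if the image is not batched.
    return data.shape[0] if shape_format.startswith("N") else 0


def get_number_of_channels(data: np.ndarray, shape_format: str) -> int:
    # Documented behaviour: 0 if there is no channel info.
    return data.shape[shape_format.index("C")] if "C" in shape_format else 0


def get_spatial_shape(data: np.ndarray, shape_format: str) -> tuple:
    # Keep only the sizes of the D, H and W axes, in order.
    return tuple(size for size, axis in zip(data.shape, shape_format)
                 if axis in SPATIAL_AXES)


volume = np.zeros((2, 1, 8, 16, 16), dtype=np.float32)
validate_data_and_shape_format(volume, "NCDHW")
print(get_batch_size(volume, "NCDHW"),
      get_number_of_channels(volume, "NCDHW"),
      get_spatial_shape(volume, "NCDHW"))  # 2 1 (8, 16, 16)
```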

new_image(data, shape_format)

Create a new image from the specified data and shape format, and copy all properties from this image.

Parameters
  • data – data for the new image

  • shape_format – shape format for the new image

Returns: a MedicalImage object

Raises: AssertionError if the specified data and shape format are invalid or inconsistent

remove_property(key: str)

Remove specified property from the medical image. Note: this method does nothing if the specified property does not exist.

Parameters

key (str) – key of the property to be removed

set_data(data: numpy.ndarray, shape_format: ai4med.common.shape_format.ShapeFormat)

Set data and shape format of the medical image

Parameters
  • data – data to be set

  • shape_format – shape format to be set

Raises: AssertionError if data and shape format are invalid or inconsistent

set_property(key: str, value)

Set an image property

Parameters
  • key (str) – key of the property

  • value – value of the property

to_dict(data_key='data')

Create a dict for the image. All properties are included in the resulting dict. The image data entry uses the specified data_key; the shape format entry uses the predefined property key.

Parameters

data_key (str) – the dict key to be used for the image data

Returns: dict that contains image data, shape format, and all properties

static validate_data_and_shape_format(data, shape_format)

Validate the specified data and shape format to determine whether they are consistent.

Parameters
  • data – the image data

  • shape_format – shape format

Raises: AssertionError if data and shape format are invalid or inconsistent

class MetricContext(data_prop: ai4med.common.data_prop.DataProperty, src_dict=None)

Bases: dict

Defines processing context for metric computation. It is a dict that contains computed and transformed data during training. Metric components perform metric computation based on items in the metric context.

Parameters
  • data_prop (DataProperty) – the data property of sample and label data

  • src_dict (dict) – source values for the metric context

Returns: a MetricContext object

dump(prefix='\t')

Prints the content of the metric context

Parameters

prefix (str) – prefix to each line of output

get_data_property()

Get the data property of the sample and label data

Returns: a DataProperty object, or None if the data property is not available

must_get_data_property()

Try to get the data property of the sample and label data.

Returns: a DataProperty object

Raises: KeyError if the data property is not available

class PlaceholderSpec(name: str, shape, source_dtype: str, target_dtype: str, data_key: str, default_value=None)

Bases: object

Defines placeholder specification for custom inputs to the training graph.

Parameters
  • name (str) – the name of the input. The name is used as the key to the placeholder tensor in the Build Context.

  • shape (tuple, list) – spatial shape of the input

  • source_dtype (str) – source data type. Source is where the data comes from (e.g. str for file name)

  • target_dtype (str) – target data type. Target is the result after transformation (e.g. float32)

  • data_key (str) – the tag that identifies the data in Transform Context

  • default_value – if not None, the default value of the input. It must be of a scalar type or np.ndarray.

class ShapeFormat(fmt)

Bases: str

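ShapeFormat defines the meaning of each dimension of the data in a MedicalImage. Image data is a numpy ndarray; without a shape format, it is impossible to know what each dimension means. Each letter of a format names one axis: N (batch), C (channel), D (depth), H (height) and W (width).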

NOTE: ShapeFormat objects are immutable.

ALL_FORMATS = ['DHW', 'DHWC', 'CDHW', 'NDHW', 'NDHWC', 'NCDHW', 'HW', 'HWC', 'CHW', 'NHW', 'NHWC', 'NCHW']
BATCHED_FORMATS = ['NHW', 'NHWC', 'NCHW', 'NDHW', 'NDHWC', 'NCDHW']
CDHW = 'CDHW'
CHANNELS_FIRST_FORMATS = ['CDHW', 'NCDHW', 'CHW', 'NCHW']
CHANNELS_LAST_FORMATS = ['DHWC', 'NDHWC', 'HWC', 'NHWC']
CHW = 'CHW'
DHW = 'DHW'
DHWC = 'DHWC'
GRAYSCALE_FORMATS = ['DHW', 'NDHW', 'HW', 'NHW']
HW = 'HW'
HWC = 'HWC'
NCDHW = 'NCDHW'
NCHW = 'NCHW'
NDHW = 'NDHW'
NDHWC = 'NDHWC'
NHW = 'NHW'
NHWC = 'NHWC'
THREE_D_FORMATS = ['DHW', 'DHWC', 'CDHW', 'NDHW', 'NDHWC', 'NCDHW']
TWO_D_FORMATS = ['HW', 'HWC', 'CHW', 'NHW', 'NHWC', 'NCHW']
get_channel_axis()

Get the channel axis number

Returns: the channel axis if the format is channeled, or None if not.

get_data_format()

Determines the data format (channels first/last, gray scale) of the shape format

Returns: a data format

get_number_of_dims()

Get the number of dimensions

Returns: number of dimensions

get_number_of_spatial_dims()

Get the number of spatial dimensions

Returns: number of spatial dimensions (2 or 3).

get_spatial_axis()

Get the start and end of the spatial axes

Returns: start, end of spatial axes

is_2d()

Determines whether the format is two-dimensional.

Returns: bool, whether the format is two-dimensional

is_3d()

Determines whether the format is three-dimensional.

Returns: bool, whether the format is three-dimensional

is_batched()

Determines whether the format is batched

Returns: bool, whether the format is batched

is_channeled()

Determines whether the format is channeled

Returns: bool, whether the format is channeled

is_channels_first()

Determines whether the format is channels first.

Returns: bool, whether the format is channels first

is_channels_last()

Determines whether the format is channels last.

Returns: bool, whether the format is channels last

is_grayscale()

Determines whether the format is gray scale.

Returns: bool, whether the format is gray scale

static is_valid_format(fmt)

Checks whether the specified format is valid.

Parameters

fmt – the format to be checked

Returns: bool, whether the format is valid

to_batched()

Create a batched shape format.

Returns: a batched ShapeFormat
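Because every format is a string of axis letters, most of the queries above follow directly from the letters. The following is a standalone sketch, not the ai4med implementation: SketchShapeFormat is a hypothetical str subclass (str is immutable, matching the note above), and returning the first and last spatial axis indices, both inclusive, from get_spatial_axis is an assumption, since the reference does not say whether "end" is inclusive.

```python
class SketchShapeFormat(str):
    """Standalone sketch of ShapeFormat-style queries; not the ai4med implementation."""

    ALL_FORMATS = ['DHW', 'DHWC', 'CDHW', 'NDHW', 'NDHWC', 'NCDHW',
                   'HW', 'HWC', 'CHW', 'NHW', 'NHWC', 'NCHW']
    SPATIAL_AXES = "DHW"

    def __new__(cls, fmt):
        # Reject anything outside the documented ALL_FORMATS list.
        if fmt not in cls.ALL_FORMATS:
            raise AssertionError(f"invalid shape format: {fmt!r}")
        return super().__new__(cls, fmt)

    def is_batched(self):
        return self.startswith("N")

    def is_channeled(self):
        return "C" in self

    def is_channels_last(self):
        return self.endswith("C")

    def is_channels_first(self):
        return self.is_channeled() and not self.is_channels_last()

    def is_grayscale(self):
        return not self.is_channeled()

    def get_channel_axis(self):
        # Index of the C axis, or None for gray scale formats.
        return self.index("C") if self.is_channeled() else None

    def get_number_of_spatial_dims(self):
        return sum(axis in self.SPATIAL_AXES for axis in self)

    def get_spatial_axis(self):
        # Assumption: (first, last) spatial axis indices, both inclusive.
        axes = [i for i, axis in enumerate(self) if axis in self.SPATIAL_AXES]
        return axes[0], axes[-1]

    def to_batched(self):
        # Each batched entry in ALL_FORMATS is its unbatched form prefixed with N.
        return self if self.is_batched() else SketchShapeFormat("N" + self)


fmt = SketchShapeFormat("CDHW")
print(fmt.get_channel_axis(), fmt.get_spatial_axis(), fmt.to_batched())
# 0 (1, 3) NCDHW
```

Note that batching shifts the channel axis: "CDHW" has its channel axis at 0, while "NCDHW" has it at 1.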

class StdShapeFormat

Bases: object

Defines standard shape formats supported in this framework.

CDHW = 'CDHW'
CHW = 'CHW'
DHW = 'DHW'
DHWC = 'DHWC'
HW = 'HW'
HWC = 'HWC'
NCDHW = 'NCDHW'
NCHW = 'NCHW'
NDHW = 'NDHW'
NDHWC = 'NDHWC'
NHW = 'NHW'
NHWC = 'NHWC'
get_format(data_format: str, num_data_dims: int, batched: bool)

Return a shape format based on the specified data format, spatial data dims, and whether the format should be batched or not.

Parameters
  • data_format (str) – data format of the shape format

  • num_data_dims (int) – number of spatial data dims

  • batched (bool) – whether the shape format should be batched

Returns: a shape format

Raises: AssertionError if any of the specified args is invalid
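The mapping performed by get_format can be sketched as composing a string from three choices. This is a standalone function, not the ai4med method, and the composition rule (C before the spatial axes for channels first, after them for channels last, an N prefix when batched) is an assumption inferred from the StdShapeFormat constants listed above.

```python
def get_format(data_format: str, num_data_dims: int, batched: bool) -> str:
    """Standalone sketch of StdShapeFormat.get_format; the rule is an assumption."""
    if num_data_dims not in (2, 3):
        raise AssertionError(f"num_data_dims must be 2 or 3, got {num_data_dims}")
    spatial = "DHW" if num_data_dims == 3 else "HW"
    if data_format == "channels_first":
        fmt = "C" + spatial
    elif data_format == "channels_last":
        fmt = spatial + "C"
    elif data_format == "grayscale":
        fmt = spatial
    else:
        raise AssertionError(f"invalid data format: {data_format!r}")
    # Batched formats prepend the N (batch) axis.
    return "N" + fmt if batched else fmt


print(get_format("channels_first", 3, True))  # NCDHW
```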

class TrainContext(task='segmentation')

Bases: dlmed.common.ctx.BaseContext

TrainContext contains contextual data generated during fitting, which other modules can use to perform their functions (e.g. adaptively adjusting the learning rate or logging training stats). Functions that take a train_ctx parameter can read these attributes; for example, the current epoch is available as train_ctx.current_epoch.

Many of the attributes are used internally by the Fitter; ai4med.components.handlers.stats_handler.StatsHandler provides a good example of how the other attributes can be used.

my_rank
num_devices
task
multi_gpu (bool)
running_in_ngc
graph
tf_config
session
model_log_dir
summary_writer
train_start_time
best_validation_metric
best_validation_epoch
build_ctx
stop_immediately (bool)
training_impossible (bool)
epoch_range
current_epoch
epoch_start_time
epoch_end_time
iteration_start_time
iteration_end_time
iteration_output
total_epochs
total_iterations
current_iteration
run_interrupted (bool)
initial_learning_rate
current_learning_rate
current_train_loss
current_validation_metric
current_train_output
total_steps
current_step
step_start_time
step_end_time
train_prints_dict
train_summaries_dict
validation_prints_dict
validation_summaries_dict
validation_start_time
validation_end_time
validation_step_start_time
validation_step_end_time
global_round
epoch_of_start_time
iter_of_start_time
next_start_epoch
next_start_iter
total_executed_steps
total_executed_epochs
placeholder_values
fl_init_validation_metric

TrainContext is initialized with the task (segmentation or classification)

Parameters

task (str) – type of the task (segmentation or classification).

ask_to_stop_immediately()
get_train_stats()
is_stop_training_asked()

Check whether any module has asked to stop training. The Fitter calls this method to determine whether training should be stopped.

Returns: boolean, whether stopping training has been asked.

set_placeholder_value(name: str, value)
class EventType

Bases: object

END_EPOCH = 'endEpoch'
END_ITER = 'endIter'
END_TRAIN = 'endTrain'
END_VAL = 'endVal'
END_VAL_ITER = 'endValIter'
KEY_METRIC_COMPUTED = 'keyMetricComputed'
NEW_BEST_MODEL = 'newBestModel'
START_EPOCH = 'startEpoch'
START_ITER = 'startIter'
START_TRAIN = 'startTrain'
START_VAL = 'startVal'
START_VAL_ITER = 'startValIter'
class TrainHandler

Bases: object

end_epoch(ctx: ai4med.common.train_ctx.TrainContext)
end_iteration(ctx: ai4med.common.train_ctx.TrainContext)
end_train(ctx: ai4med.common.train_ctx.TrainContext)
end_val_iteration(ctx: ai4med.common.train_ctx.TrainContext)
end_val_metric(ctx: ai4med.common.train_ctx.TrainContext)
end_validation(ctx: ai4med.common.train_ctx.TrainContext)
handle_event(event: str, ctx: ai4med.common.train_ctx.TrainContext)

The default event handler, which calls the predefined method corresponding to each event type.

Parameters
  • event – event to be handled

  • ctx – context for cross-component data sharing

key_metric_computed(ctx: ai4med.common.train_ctx.TrainContext)
new_best_model(ctx: ai4med.common.train_ctx.TrainContext)
start_epoch(ctx: ai4med.common.train_ctx.TrainContext)
start_iteration(ctx: ai4med.common.train_ctx.TrainContext)
start_train(ctx: ai4med.common.train_ctx.TrainContext)
start_val_iteration(ctx: ai4med.common.train_ctx.TrainContext)
start_val_metric(ctx: ai4med.common.train_ctx.TrainContext)
start_validation(ctx: ai4med.common.train_ctx.TrainContext)
fire_event(event: str, handlers: list, ctx: ai4med.common.train_ctx.TrainContext)

Fires the specified event and invokes the list of handlers.

Parameters
  • event – the event to be fired

  • handlers – handlers to be invoked

  • ctx – context for cross-component data sharing

Returns:
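The handler and event mechanism above can be sketched without ai4med. SketchTrainHandler, EpochRecorder and this fire_event are hypothetical stand-ins, a types.SimpleNamespace stands in for TrainContext, and the event-to-method mapping is an assumption based on matching the EventType values to the method names. start_val_metric and end_val_metric are not dispatched here because no EventType constant above names them.

```python
from types import SimpleNamespace


class SketchTrainHandler:
    """Standalone sketch of TrainHandler-style dispatch; not the ai4med class."""

    # Assumption: each EventType value maps to the hook with the matching name.
    _DISPATCH = {
        'startTrain': 'start_train',
        'startEpoch': 'start_epoch',
        'startIter': 'start_iteration',
        'endIter': 'end_iteration',
        'startVal': 'start_validation',
        'startValIter': 'start_val_iteration',
        'endValIter': 'end_val_iteration',
        'endVal': 'end_validation',
        'keyMetricComputed': 'key_metric_computed',
        'newBestModel': 'new_best_model',
        'endEpoch': 'end_epoch',
        'endTrain': 'end_train',
    }

    def handle_event(self, event, ctx):
        # Unknown events are ignored in this sketch.
        method_name = self._DISPATCH.get(event)
        if method_name is not None:
            getattr(self, method_name)(ctx)

    # Default hooks do nothing; subclasses override the ones they need.
    def start_train(self, ctx): pass
    def start_epoch(self, ctx): pass
    def start_iteration(self, ctx): pass
    def end_iteration(self, ctx): pass
    def start_validation(self, ctx): pass
    def start_val_iteration(self, ctx): pass
    def end_val_iteration(self, ctx): pass
    def end_validation(self, ctx): pass
    def key_metric_computed(self, ctx): pass
    def new_best_model(self, ctx): pass
    def end_epoch(self, ctx): pass
    def end_train(self, ctx): pass


def fire_event(event, handlers, ctx):
    # Invoke every handler in list order.
    for handler in handlers:
        handler.handle_event(event, ctx)


class EpochRecorder(SketchTrainHandler):
    # Hypothetical handler that records epoch boundaries.
    def __init__(self):
        self.seen = []

    def start_epoch(self, ctx):
        self.seen.append(("start", ctx.current_epoch))

    def end_epoch(self, ctx):
        self.seen.append(("end", ctx.current_epoch))


recorder = EpochRecorder()
train_ctx = SimpleNamespace(current_epoch=1)  # stands in for TrainContext
for event in ('startEpoch', 'startIter', 'endIter', 'endEpoch'):
    fire_event(event, [recorder], train_ctx)
print(recorder.seen)  # [('start', 1), ('end', 1)]
```

A handler only overrides the hooks it cares about; the iteration events in the example reach the no-op defaults and leave no trace.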

read_json(file_name: str)
write_json(output_file_name: str, stats: dict)
class TransformContext(src_dict=None)

Bases: dict

Defines the context for data transformations. During training, data items are processed by a chain of transformations, which turn the data items into the form that can be used for training computation. The transforms are called one by one in the order they are defined in the chain. The TransformContext holds the current values for the data items to be processed.

Each item in the transform context is called a field, which has key and value. A field could have 0 or more sub-fields, which specify the properties of the field.

Parameters

src_dict (dict) – the dict that provides initial data items for the transform context

Returns: a TransformContext object

dump(prefix='\t')

Prints the content of the transform context.

Parameters

prefix (str) – prefix string for each line of output

get_image(field: str, allow_wildcard_shape_format=True)

Get the value of the specified field as a MedicalImage object. This is a convenience method that collects all sub-fields of the field and then makes a MedicalImage object from these values.

Parameters
  • field (str) – name of the field

  • allow_wildcard_shape_format (bool) – whether to use the wildcard shape format as the image's shape format in case the image's shape format is not explicitly defined.

Returns: a MedicalImage object

Raises: AssertionError if no field data is found, or shape format cannot be determined, or image data and shape format are inconsistent.

get_label_format()

Get the label format from the context

Returns: the label format

Raises: KeyError if the label format is not available

get_subfield(field: str, subfield: str, default=None)

Get the value of the specified sub-field

Parameters
  • field (str) – name of the field

  • subfield (str) – name of the sub-field

  • default – default value to be returned if the sub-field does not exist

Returns: value of the sub-field, or default value if the sub-field does not exist

Raises: AssertionError if any of the arg values is invalid

get_whole_field(field: str)

Get values of field and all sub-fields

Parameters

field (str) – name of the field

Returns: a tuple (field_value, sub_fields), where field_value is the value of the field and sub_fields is a dict of the sub-field values

get_wildcard_subfield(subfield: str, default=None)

Get the value of a wildcard sub-field.

Parameters
  • subfield (str) – name of the sub-field

  • default – value to return if the specified sub-field does not exist

Returns: value of the sub-field or specified default if the sub-field does not exist

must_get_subfield(field: str, subfield: str)

Get the value of the specified sub-field, and raise an exception if the sub-field does not exist.

Parameters
  • field (str) – name of the field

  • subfield (str) – name of the sub-field

Returns: value of the sub-field

Raises: AssertionError if any of the arg values is invalid; KeyError if the specified field or sub-field does not exist.

set_image(field: str, img: ai4med.common.medical_image.MedicalImage)

Set the specified MedicalImage object into the context. Note: this is a convenience method that sets the data as the value of the field, and sets shape format and all other image properties as sub-fields of the field.

Parameters
  • field (str) – name of the field

  • img (MedicalImage) – value of the field

Raises: AssertionError if any of the arg values is invalid

set_label_format(label_format)

Set the value of label format (for classification) into the transform context.

Parameters

label_format (list, tuple) – the label format vector

set_subfield(field: str, subfield: str, value)

Set the value of a field’s sub-field

Parameters
  • field (str) – name of the field

  • subfield (str) – name of the sub-field

  • value – value of the sub-field

Raises: AssertionError if any of the arg values is invalid

set_wildcard_subfield(subfield: str, value)

Set a wildcard sub-field. A wildcard sub-field does not belong to any specific field. However, it can be used as the default value for any field when appropriate.

Parameters
  • subfield (str) – name of the sub-field

  • value – value of the sub-field

setdefault(k[, d]) → D.get(k,d), also set D[k]=d if k not in D
update([E, ]**F) → None. Update D from dict/iterable E and F.

If E is present and has a .keys() method, then does: for k in E: D[k] = E[k] If E is present and lacks a .keys() method, then does: for k, v in E: D[k] = v In either case, this is followed by: for k in F: D[k] = F[k]

© Copyright 2020, NVIDIA. Last updated on Feb 2, 2023.