PyTorch Plugin API reference
class nvidia.dali.plugin.pytorch.DALIClassificationIterator(pipelines, size)
DALI iterator for classification tasks in PyTorch. It returns two outputs, data and label, as PyTorch Tensors.
Calling DALIClassificationIterator(pipelines, size) is equivalent to calling DALIGenericIterator(pipelines, ["data", "label"], size).
Parameters:
- pipelines (list of nvidia.dali.pipeline.Pipeline) – List of pipelines to use.
- size (int) – Epoch size, i.e. the number of samples in one epoch.
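Because the two calls are equivalent, the classification iterator can be thought of as the generic iterator with a fixed output map. The sketch below models only that relationship, using two hypothetical pure-Python stand-ins (GenericIteratorModel and ClassificationIteratorModel) that record their arguments instead of running pipelines. They are not DALI classes, and the plain strings passed as pipelines are placeholders for nvidia.dali.pipeline.Pipeline objects.

```python
from typing import Any, List, Sequence


class GenericIteratorModel:
    """Hypothetical stand-in for DALIGenericIterator; it only records its arguments."""

    def __init__(self, pipelines: Sequence[Any], output_map: Sequence[str], size: int):
        self.pipelines: List[Any] = list(pipelines)
        self.output_map: List[str] = list(output_map)
        self.size = size


class ClassificationIteratorModel(GenericIteratorModel):
    """Hypothetical stand-in for DALIClassificationIterator."""

    def __init__(self, pipelines: Sequence[Any], size: int):
        # Same as the generic iterator, with the output map fixed to data + label.
        super().__init__(pipelines, ["data", "label"], size)


# Usage: both constructions end up in the same state.
cls_it = ClassificationIteratorModel(["pipe0", "pipe1"], size=1000)
gen_it = GenericIteratorModel(["pipe0", "pipe1"], ["data", "label"], size=1000)
print(vars(cls_it) == vars(gen_it))  # True
```

A subclass that fixes one constructor argument is a natural way to express "X is equivalent to Y with these arguments"; whether DALI's own source is structured this way is not something this reference states.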
class nvidia.dali.plugin.pytorch.DALIGenericIterator(pipelines, output_map, size)
General DALI iterator for PyTorch. It can return any number of outputs from the DALI pipeline as PyTorch Tensors.
Parameters:
- pipelines (list of nvidia.dali.pipeline.Pipeline) – List of pipelines to use.
- output_map (list of str) – List of strings, each either “data” or “label”, that maps each output of the DALI pipeline, in order, to the corresponding type of tensor.
- size (int) – Epoch size, i.e. the number of samples in one epoch.
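The entries of output_map are matched to the pipeline outputs by position. As an illustration, the hypothetical helper split_by_output_map below sorts a list of outputs into data and label groups. It is a simplified model of the documented mapping, not DALI's implementation; the length check and the error on unknown names are assumptions of this sketch.

```python
from typing import Any, List, Sequence, Tuple


def split_by_output_map(
    outputs: Sequence[Any], output_map: Sequence[str]
) -> Tuple[List[Any], List[Any]]:
    """Hypothetical helper: group positional pipeline outputs by their output_map entry."""
    if len(outputs) != len(output_map):
        # Assumption of this sketch: one map entry per pipeline output.
        raise ValueError("output_map must have one entry per pipeline output")
    data: List[Any] = []
    labels: List[Any] = []
    for out, name in zip(outputs, output_map):
        if name == "data":
            data.append(out)
        elif name == "label":
            labels.append(out)
        else:
            # The reference only allows "data" and "label".
            raise ValueError(f"unsupported output_map entry: {name!r}")
    return data, labels


# Usage: a pipeline with two image-like outputs and one label output.
data, labels = split_by_output_map(["images", "crops", "labels"], ["data", "data", "label"])
print(data, labels)  # ['images', 'crops'] ['labels']
```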
next()
Returns the next batch of data.
reset()
Resets the iterator after a full epoch. DALI iterators do not support resetting before the end of an epoch and ignore such requests.
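The epoch contract of next() and reset() can be modelled without DALI. EpochIteratorModel below is a hypothetical stand-in: size counts samples (an assumption in line with size being the epoch size), each call to next() consumes one batch, iteration stops once size samples have been returned, and a reset() before the end of the epoch is ignored with a warning. The exact counting and warning behaviour of DALI's own iterators may differ.

```python
import warnings


class EpochIteratorModel:
    """Hypothetical model of the next()/reset() contract; batches are just indices."""

    def __init__(self, size: int, batch_size: int):
        self.size = size              # samples per epoch
        self.batch_size = batch_size  # samples consumed per next() call
        self._counter = 0             # samples returned so far in this epoch

    def __iter__(self):
        return self

    def __next__(self) -> int:
        if self._counter >= self.size:
            raise StopIteration  # epoch is over until reset() is called
        batch_index = self._counter // self.batch_size
        self._counter += self.batch_size
        return batch_index

    next = __next__  # mirror the documented next() method name

    def reset(self) -> None:
        if self._counter >= self.size:
            self._counter = 0  # full epoch consumed: start the next one
        else:
            # Mid-epoch reset requests are ignored, as the reference describes.
            warnings.warn("reset() before the end of the epoch is ignored")


# Usage: two epochs of 10 samples in batches of 4.
it = EpochIteratorModel(size=10, batch_size=4)
for epoch in range(2):
    print(epoch, list(it))  # in this model the last batch overshoots size
    it.reset()
```

With a real DALI iterator the same loop shape applies: consume batches until the epoch ends, then call reset() before starting the next epoch.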
nvidia.dali.plugin.pytorch.feed_ndarray(dali_tensor, arr)
Copies the contents of a DALI tensor into a PyTorch Tensor.
Parameters:
- dali_tensor (nvidia.dali.backend.TensorCPU or nvidia.dali.backend.TensorGPU) – Tensor from which to copy.
- arr (torch.Tensor) – Destination of the copy.
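Since feed_ndarray copies into arr rather than returning a new tensor, arr presumably has to be allocated beforehand with a size and dtype that match dali_tensor (with the real plugin, for example via torch.empty). The sketch below models that contract in pure Python: FakeDaliTensor and feed_into are hypothetical stand-ins, and a stdlib array plays the role of the torch.Tensor destination.

```python
from array import array
from typing import Iterable, Sequence


class FakeDaliTensor:
    """Hypothetical stand-in for a DALI CPU tensor: flat float32 data plus a shape."""

    def __init__(self, values: Iterable[float], shape: Sequence[int]):
        self.data = array("f", values)
        self._shape = list(shape)

    def shape(self):
        return list(self._shape)


def feed_into(src: FakeDaliTensor, dst: array) -> None:
    """Hypothetical model of feed_ndarray: copy src into the preallocated dst in place."""
    numel = 1
    for dim in src.shape():
        numel *= dim
    if len(dst) != numel or dst.typecode != src.data.typecode:
        raise ValueError("destination must be preallocated with matching size and type")
    dst[:] = src.data  # overwrite the existing buffer; no new object is created


# Usage: allocate the destination first, then copy into it.
src = FakeDaliTensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=(2, 3))
dst = array("f", [0.0] * 6)
feed_into(src, dst)
print(dst.tolist())
```

Checking size and type up front mirrors why preallocation matters: a raw copy into a buffer of the wrong size or element type would either fail or corrupt memory.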