core.full_cuda_graph#

Full-iteration CUDA graph for training.

Module Contents#

Classes#

StaticBufferLoader

Load data to static buffers.

FullCudaGraphWrapper

Wrapper class that enables full-iteration CUDA graphs.

Functions#

get_shared_capture_stream

Return one torch.cuda.Stream for all full-iter and optimizer graph captures.

get_shared_graph_pool

Return a process-wide handle so all call sites share one graph memory pool.

get_graph_pool

Return graph pool handle for full-iter/optimizer graph capture.

copy_tensors_in_struct

Copy src to new tensors.

clone_tensors_in_struct

Copy src to pre-existing tensors in tgt.

Data#

API#

core.full_cuda_graph.logger#

‘getLogger(…)’

core.full_cuda_graph._shared_graph_pool#

None

core.full_cuda_graph._shared_capture_stream#

None

core.full_cuda_graph.get_shared_capture_stream()#

Return one torch.cuda.Stream for all full-iter and optimizer graph captures.

Call after the target CUDA device is selected.

core.full_cuda_graph.get_shared_graph_pool()#

Return a process-wide handle so all call sites share one graph memory pool.

torch.cuda.graph_pool_handle() returns a new pool each time; this lazy singleton ensures that, for example, full-iteration and optimizer captures reuse the same pool.

core.full_cuda_graph.get_graph_pool(use_single_mempool)#

Return graph pool handle for full-iter/optimizer graph capture.

When use_single_mempool is True, train/eval and optimizer captures reuse one process-wide pool. Otherwise, each capture call gets a new pool handle.
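The pool-selection behavior described above can be sketched without a GPU. This is a minimal illustration, not the module's actual code: `new_pool_handle` is a hypothetical stand-in for `torch.cuda.graph_pool_handle()`, and the function bodies are assumptions reconstructed from the descriptions on this page.

```python
import itertools

# Hypothetical stand-in for torch.cuda.graph_pool_handle(): every call
# returns a fresh, distinct handle, as described above.
_handle_ids = itertools.count()


def new_pool_handle():
    return ("pool", next(_handle_ids))


# Module-level cache, created on first use (lazy singleton).
_shared_graph_pool = None


def get_shared_graph_pool():
    """Create one handle lazily and return the same one on every later call."""
    global _shared_graph_pool
    if _shared_graph_pool is None:
        _shared_graph_pool = new_pool_handle()
    return _shared_graph_pool


def get_graph_pool(use_single_mempool):
    """Return the shared handle when requested, otherwise a fresh one per call."""
    if use_single_mempool:
        return get_shared_graph_pool()
    return new_pool_handle()
```

With `use_single_mempool=True`, repeated calls return the same handle, so separate captures would draw from one memory pool; with `False`, each call yields a distinct handle.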

core.full_cuda_graph.copy_tensors_in_struct(src)#

Copy src to new tensors.

core.full_cuda_graph.clone_tensors_in_struct(tgt, src)#

Copy src to pre-existing tensors in tgt.
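The two helpers above differ in whether the destination already exists. The sketch below illustrates that distinction on nested tuples, lists, and dicts. It is an assumption-laden illustration: `bytearray` stands in for a tensor, the names `copy_struct` and `copy_into_struct` are hypothetical, and passing non-buffer leaves through unchanged is a choice made for this example only.

```python
def copy_struct(src):
    """Return a structure of the same shape whose buffer leaves are new objects."""
    if isinstance(src, tuple):
        return tuple(copy_struct(i) for i in src)
    if isinstance(src, list):
        return [copy_struct(i) for i in src]
    if isinstance(src, dict):
        return {k: copy_struct(v) for k, v in src.items()}
    if isinstance(src, bytearray):
        return bytearray(src)  # new buffer with the same contents
    return src  # assumption: non-buffer leaves are passed through unchanged


def copy_into_struct(tgt, src):
    """Overwrite the existing buffers in tgt with the contents of src."""
    if isinstance(src, (tuple, list)):
        for t, s in zip(tgt, src):
            copy_into_struct(t, s)
    elif isinstance(src, dict):
        for k in src:
            copy_into_struct(tgt[k], src[k])
    elif isinstance(src, bytearray):
        tgt[:] = src  # in-place write; tgt keeps its identity


batch = {"tokens": bytearray(b"\x01\x02"), "pair": (bytearray(b"\x03"), [bytearray(b"\x04")])}
static = copy_struct(batch)          # allocate once
held = static["pair"][1][0]          # reference a captured graph might hold
next_batch = {"tokens": bytearray(b"\x09\x08"), "pair": (bytearray(b"\x07"), [bytearray(b"\x06")])}
copy_into_struct(static, next_batch)  # refresh contents without reallocating
```

The in-place variant matters for graph replay because a captured graph refers to fixed memory: new data has to be written into the same buffers rather than into freshly allocated ones.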

class core.full_cuda_graph.StaticBufferLoader#

Load data to static buffers.

Initialization

static_buffers: dict#

None

__call__(inputs, stage, microbatch)#
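A static-buffer loader of this shape can be sketched as follows. This is a hedged illustration rather than the class's real implementation: `SketchBufferLoader` is a hypothetical name, `bytearray` stands in for a device tensor, and the layout (one list of per-microbatch buffer dicts per stage, with stages named "training" and "validation") is an assumption.

```python
class SketchBufferLoader:
    # Shared across instances: stage name -> list of per-microbatch buffer dicts.
    static_buffers = {"training": [], "validation": []}

    def __call__(self, inputs, stage, microbatch):
        buffers = SketchBufferLoader.static_buffers[stage]
        if microbatch == len(buffers):
            # First time this microbatch index is seen: allocate its buffers.
            buffers.append({k: bytearray(len(v)) for k, v in inputs.items()})
        for k, v in inputs.items():
            buffers[microbatch][k][:] = v  # overwrite in place
        return buffers[microbatch]


loader = SketchBufferLoader()
first = loader({"tokens": b"\x01\x02"}, "training", 0)
second = loader({"tokens": b"\x05\x06"}, "training", 0)  # same buffers, new data
other = loader({"tokens": b"\x07\x08"}, "training", 1)   # new microbatch slot
```

Each (stage, microbatch) pair keeps its own buffers, so every call for the same pair returns the same objects with refreshed contents.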

class core.full_cuda_graph.FullCudaGraphWrapper(forward_backward_func, cuda_graph_warmup_steps=1, use_single_mempool=False)#

Wrapper class that enables full-iteration CUDA graphs.

Initialization

curr_iteration#

None

cuda_graph#

None

result#

None

data_read(data_iterator, model, training, num_microbatches)#

Read all microbatch inputs from the dataloader and copy them to static buffers.

__call__(*args, **kwargs)#

curr_iter(stage)#

Return current training/validation iteration.

next_iter(stage)#

Increment current training/validation iteration.

reset_cuda_graph(stage=None)#

Reset CUDA graph.
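The per-stage counters and the `cuda_graph_warmup_steps` parameter suggest a warmup, capture, replay life cycle. The sketch below models that life cycle in plain Python under stated assumptions: `SketchGraphWrapper` is a hypothetical class, the "graph" is a stored callable rather than a real `torch.cuda.CUDAGraph`, the stage is passed explicitly, and the exact order of warmup and capture in the real wrapper is not confirmed by this page.

```python
class SketchGraphWrapper:
    def __init__(self, forward_backward_func, cuda_graph_warmup_steps=1):
        self.forward_backward_func = forward_backward_func
        self.cuda_graph_warmup_steps = cuda_graph_warmup_steps
        # Separate state for each stage.
        self.curr_iteration = {"training": 0, "validation": 0}
        self.cuda_graph = {"training": None, "validation": None}
        self.result = {"training": None, "validation": None}

    def curr_iter(self, stage):
        return self.curr_iteration[stage]

    def next_iter(self, stage):
        self.curr_iteration[stage] += 1

    def reset_cuda_graph(self, stage=None):
        # Assumption: resetting drops the graph and restarts warmup.
        for s in ([stage] if stage is not None else list(self.cuda_graph)):
            self.cuda_graph[s] = None
            self.result[s] = None
            self.curr_iteration[s] = 0

    def __call__(self, stage, *args, **kwargs):
        if self.curr_iter(stage) < self.cuda_graph_warmup_steps:
            mode = "eager"  # warmup: run normally, nothing is recorded
            self.result[stage] = self.forward_backward_func(*args, **kwargs)
        elif self.cuda_graph[stage] is None:
            mode = "capture"  # record the work once (stand-in for graph capture)
            self.result[stage] = self.forward_backward_func(*args, **kwargs)
            self.cuda_graph[stage] = lambda: self.forward_backward_func(*args, **kwargs)
        else:
            mode = "replay"  # reuse the recorded work
            self.result[stage] = self.cuda_graph[stage]()
        self.next_iter(stage)
        return mode, self.result[stage]


wrapper = SketchGraphWrapper(lambda: "loss", cuda_graph_warmup_steps=1)
modes = [wrapper("training")[0] for _ in range(3)]
```

In this model, training and validation progress independently, and calling `reset_cuda_graph("training")` sends only the training stage back to warmup.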