Framework-specific API
- PyTorch
  - Linear
  - GroupedLinear
  - LayerNorm
  - RMSNorm
  - LayerNormLinear
  - LayerNormMLP
  - DotProductAttention
  - MultiheadAttention
  - TransformerLayer
  - InferenceParams
  - CudaRNGStatesTracker
  - fp8_autocast()
  - fp8_model_init()
  - checkpoint()
  - onnx_export()
  - make_graphed_callables()
  - get_cpu_offload_context()
  - moe_permute()
  - moe_unpermute()
- JAX
- Paddle