Framework-specific API
- PyTorch
  - Linear
  - GroupedLinear
  - LayerNorm
  - RMSNorm
  - LayerNormLinear
  - LayerNormMLP
  - DotProductAttention
  - MultiheadAttention
  - TransformerLayer
  - CudaRNGStatesTracker
  - fp8_autocast()
  - fp8_model_init()
  - checkpoint()
  - make_graphed_callables()
  - get_cpu_offload_context()
  - moe_permute()
  - moe_permute_with_probs()
  - moe_unpermute()
  - moe_sort_chunks_by_index()
  - moe_sort_chunks_by_index_with_probs()
  - initialize_ub()
  - destroy_ub()
- JAX