bridge.diffusion.conversion.flux.flux_bridge

Module Contents

Classes

FluxBridge

Megatron Bridge for FLUX model.

SplitRowParallelMapping

API

class bridge.diffusion.conversion.flux.flux_bridge.FluxBridge

Bases: megatron.bridge.models.conversion.model_bridge.MegatronModelBridge

Megatron Bridge for FLUX model.

Users do not normally instantiate this bridge directly; it is used through AutoBridge.

Example

    from megatron.bridge import AutoBridge

    bridge = AutoBridge.from_hf_pretrained("black-forest-labs/FLUX.1-dev")
    provider = bridge.to_megatron_provider()

provider_bridge(
hf_pretrained: megatron.bridge.diffusion.conversion.flux.flux_hf_pretrained.PreTrainedFlux,
) → megatron.bridge.diffusion.models.flux.flux_provider.FluxProvider

maybe_modify_loaded_hf_weight(
hf_param: str | dict[str, str],
hf_state_dict: Mapping[str, torch.Tensor],
) → torch.Tensor

Load weights from the HuggingFace state dict. Subclasses can override this function to preprocess the HF weights before conversion, for example by renaming certain parameters to avoid mapping conflicts or by dequantizing the weights.

Note that the state dict is loaded lazily, so the weights are actually read inside this function, when hf_state_dict.__getitem__ is called.

Parameters:
  • hf_param – The parameter name or dictionary of parameter names to load.

  • hf_state_dict – The HuggingFace state dictionary.

Returns:

The loaded weights.
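The lazy-loading behaviour described above can be illustrated with a small stand-in. This is a minimal sketch under stated assumptions, not the library's implementation: `LazyStateDict` and `load_hf_weight` are hypothetical names, plain lists stand in for `torch.Tensor`, and the handling of the dictionary form of `hf_param` (returning a dict of loaded values under the same keys) is an assumption.

```python
from collections.abc import Mapping


class LazyStateDict(Mapping):
    """Hypothetical stand-in for a lazily loaded HF state dict."""

    def __init__(self, loaders):
        self._loaders = loaders  # name -> zero-argument callable producing the weight
        self.loaded = []         # records which names were actually materialized

    def __getitem__(self, name):
        # The weight is only produced here, at lookup time.
        self.loaded.append(name)
        return self._loaders[name]()

    def __iter__(self):
        return iter(self._loaders)

    def __len__(self):
        return len(self._loaders)


def load_hf_weight(hf_param, hf_state_dict):
    """Hypothetical hook: fetch one weight, or a dict of weights, by name."""
    if isinstance(hf_param, str):
        return hf_state_dict[hf_param]
    # Assumed dict form: {role: hf_name}; the result is keyed by role.
    return {role: hf_state_dict[name] for role, name in hf_param.items()}


state = LazyStateDict({
    "a.weight": lambda: [1.0, 2.0],
    "b.weight": lambda: [3.0, 4.0],
})
single = load_hf_weight("a.weight", state)  # only "a.weight" is materialized
```

An override would typically wrap a lookup like this one and then transform the result (rename, dequantize, and so on) before returning it.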

maybe_modify_converted_hf_weight(
task: megatron.bridge.models.conversion.model_bridge.WeightConversionTask,
converted_weights_dict: Dict[str, torch.Tensor],
hf_state_dict: Mapping[str, torch.Tensor],
) → Dict[str, torch.Tensor]

Merge split proj_out weight_1 and weight_2 back into a single HF ‘weight’ for export.

On load, the HF proj_out.weight is split into weight_1 (linear_fc2) and weight_2 (linear_proj). On export, the two parts must be concatenated back as [weight_2, weight_1] along dim=1 to match the HF format.
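The split and merge round trip can be sketched as follows. This is an illustration only: `split_proj_out`, `merge_proj_out`, and `proj_in_features` are hypothetical names, nested lists stand in for 2-D tensors so the sketch runs without dependencies, and the split point is an assumption inferred from the [weight_2, weight_1] merge order.

```python
def split_proj_out(weight, proj_in_features):
    """Split each row of an [out, in] matrix into (weight_1, weight_2).

    Assumption: the first `proj_in_features` input columns belong to
    weight_2 (linear_proj) and the remaining columns to weight_1 (linear_fc2).
    """
    weight_2 = [row[:proj_in_features] for row in weight]
    weight_1 = [row[proj_in_features:] for row in weight]
    return weight_1, weight_2


def merge_proj_out(weight_1, weight_2):
    """Concatenate as [weight_2, weight_1] along dim=1 (the column axis)."""
    # With tensors this corresponds to torch.cat([weight_2, weight_1], dim=1).
    return [r2 + r1 for r1, r2 in zip(weight_1, weight_2)]


# Toy HF-style weight with 2 output rows and 5 input columns.
hf_weight = [
    [1, 2, 3, 4, 5],
    [6, 7, 8, 9, 10],
]
weight_1, weight_2 = split_proj_out(hf_weight, proj_in_features=2)
merged = merge_proj_out(weight_1, weight_2)  # equal to hf_weight again
```

The order matters: concatenating as [weight_1, weight_2] instead would produce a matrix of the right shape but with the input columns permuted.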

mapping_registry() → megatron.bridge.models.conversion.mapping_registry.MegatronMappingRegistry

Return the MegatronMappingRegistry containing the parameter mappings from HF to Megatron format.

Returns:

Registry of parameter mappings

Return type:

MegatronMappingRegistry

class bridge.diffusion.conversion.flux.flux_bridge.SplitRowParallelMapping(megatron_param: str, hf_param: str)

Bases: megatron.bridge.models.conversion.param_mapping.RowParallelMapping

Initialization