Source code for nemo_deploy.nlp.megatronllm_deployable_ray

#!/usr/bin/env python3
# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import time
from typing import Any, Dict

import numpy as np
import ray
import torch
from fastapi import FastAPI, HTTPException
from ray import serve

from ..ray_utils import find_available_port
from .megatronllm_deployable import MegatronLLMDeployableNemo2

LOGGER = logging.getLogger("NeMo")

app = FastAPI()


@ray.remote(num_gpus=1)
class ModelWorker:
    """Ray actor that loads and runs inference on a shard of the model.

    Each ModelWorker is responsible for a specific rank in the model parallel setup.
    """

    def __init__(
        self,
        nemo_checkpoint_filepath: str,
        rank: int,
        world_size: int,
        tensor_model_parallel_size: int,
        pipeline_model_parallel_size: int,
        context_parallel_size: int,
        expert_model_parallel_size: int,
        master_port: str,
        replica_id: int = 0,
        enable_cuda_graphs: bool = False,
        enable_flash_decode: bool = False,
        legacy_ckpt: bool = False,
    ):
        # Use replica-specific environment variables to avoid conflicts
        os.environ["MASTER_PORT"] = master_port
        os.environ["MASTER_ADDR"] = ray._private.services.get_node_ip_address()
        os.environ["RANK"] = str(rank)
        os.environ["WORLD_SIZE"] = str(world_size)
        os.environ["LOCAL_RANK"] = str(rank % torch.cuda.device_count())

        # Set a unique process group name for each replica to avoid conflicts
        os.environ["TORCH_DISTRIBUTED_GROUP_NAME"] = f"replica_{replica_id}"

        # Use INFO level logging only for important initialization steps
        if rank == 0:  # Only log from rank 0 to reduce noise
            LOGGER.info(f"Replica {replica_id} - Initializing workers for world_size={world_size}")
            LOGGER.info(f"Replica {replica_id} - MASTER_PORT: {os.environ['MASTER_PORT']}")
            LOGGER.info(f"Replica {replica_id} - MASTER_ADDR: {os.environ['MASTER_ADDR']}")

        try:
            self.model = MegatronLLMDeployableNemo2(
                nemo_checkpoint_filepath=nemo_checkpoint_filepath,
                num_devices=world_size,
                num_nodes=world_size // torch.cuda.device_count(),
                tensor_model_parallel_size=tensor_model_parallel_size,
                pipeline_model_parallel_size=pipeline_model_parallel_size,
                expert_model_parallel_size=expert_model_parallel_size,
                context_parallel_size=context_parallel_size,
                enable_cuda_graphs=enable_cuda_graphs,
                enable_flash_decode=enable_flash_decode,
                legacy_ckpt=legacy_ckpt,
            )
            if rank != 0:
                self.model.generate_other_ranks()
        except Exception as e:
            LOGGER.error(f"Replica {replica_id} - Failed to initialize model for rank {rank}: {str(e)}")
            raise

    def infer(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Run inference on the model shard."""
        return self.model.ray_infer_fn(inputs)
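

# Illustrative sketch (not part of the original module): a single ModelWorker can be
# created directly for a one-GPU smoke test. This assumes Ray has already been
# initialized and that "/path/to/model.nemo" is a placeholder for a real checkpoint;
# the exact keys accepted by `infer` are defined by `MegatronLLMDeployableNemo2.ray_infer_fn`.
#
#     worker = ModelWorker.remote(
#         nemo_checkpoint_filepath="/path/to/model.nemo",
#         rank=0,
#         world_size=1,
#         tensor_model_parallel_size=1,
#         pipeline_model_parallel_size=1,
#         context_parallel_size=1,
#         expert_model_parallel_size=1,
#         master_port="29500",
#     )
#     result = ray.get(worker.infer.remote({"prompts": ["Hello"], "max_length": 32}))
#     print(result.get("sentences"))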


@serve.deployment(
    num_replicas=1,
    ray_actor_options={"num_cpus": 8},
    max_ongoing_requests=32,
)
@serve.ingress(app)
class MegatronRayDeployable:
    """A Ray Serve deployment for distributed Megatron LLM models.

    This class coordinates model parallelism across multiple GPUs and nodes,
    with each shard handled by a separate Ray actor.
    """

    def __init__(
        self,
        nemo_checkpoint_filepath: str,
        num_gpus: int = 1,
        num_nodes: int = 1,
        tensor_model_parallel_size: int = 1,
        pipeline_model_parallel_size: int = 1,
        context_parallel_size: int = 1,
        expert_model_parallel_size: int = 1,
        model_id: str = "nemo-model",
        enable_cuda_graphs: bool = False,
        enable_flash_decode: bool = False,
        legacy_ckpt: bool = False,
    ):
        """Initialize the distributed Megatron LLM model deployment.

        Args:
            nemo_checkpoint_filepath (str): Path to the .nemo checkpoint file.
            num_gpus (int): Number of GPUs to use per replica.
            num_nodes (int): Number of nodes to use for deployment.
            tensor_model_parallel_size (int): Size of tensor model parallelism.
            pipeline_model_parallel_size (int): Size of pipeline model parallelism.
            context_parallel_size (int): Size of context parallelism.
            expert_model_parallel_size (int): Size of expert model parallelism.
            model_id (str): Identifier for the model in API responses.
            enable_cuda_graphs (bool): Whether to enable CUDA graphs for faster inference.
            enable_flash_decode (bool): Whether to enable Flash Attention decode.
            legacy_ckpt (bool): Whether to use legacy checkpoint format. Defaults to False.
        """
        try:
            self.model_id = model_id
            world_size = num_gpus * num_nodes

            # Validate parallelism configuration
            total_parallel_size = tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
            if total_parallel_size != world_size:
                raise ValueError(
                    f"Total parallelism size ({total_parallel_size}) must equal total GPUs per replica ({world_size})"
                )

            # Generate a unique replica ID based on the actor handle
            replica_id = abs(hash(str(self))) % 10000

            # Pre-allocate the master port to avoid race conditions between workers.
            # Use a replica-specific port range to avoid conflicts between replicas.
            base_port = 29500 + (replica_id % 100) * 100
            master_port = str(find_available_port(base_port, ray._private.services.get_node_ip_address()))
            LOGGER.info(f"Replica {replica_id} - Pre-allocated master port: {master_port}")

            # Create workers with proper synchronization for distributed initialization.
            # Rank 0 must be created first as it acts as the master in PyTorch distributed.
            worker_futures = []

            # Create the rank 0 worker first
            rank_0_worker = ModelWorker.remote(
                nemo_checkpoint_filepath=nemo_checkpoint_filepath,
                rank=0,
                world_size=world_size,
                tensor_model_parallel_size=tensor_model_parallel_size,
                pipeline_model_parallel_size=pipeline_model_parallel_size,
                context_parallel_size=context_parallel_size,
                expert_model_parallel_size=expert_model_parallel_size,
                master_port=master_port,
                replica_id=replica_id,
                enable_cuda_graphs=enable_cuda_graphs,
                enable_flash_decode=enable_flash_decode,
                legacy_ckpt=legacy_ckpt,
            )
            worker_futures.append(rank_0_worker)

            # Wait for rank 0 to start before creating other workers.
            # This ensures the master node is ready for distributed initialization.
            LOGGER.info(f"Replica {replica_id} - Waiting for rank 0 to initialize...")
            time.sleep(1)  # Give rank 0 time to start the distributed backend

            # Create the remaining workers in parallel
            for rank in range(1, world_size):
                worker = ModelWorker.remote(
                    nemo_checkpoint_filepath=nemo_checkpoint_filepath,
                    rank=rank,
                    world_size=world_size,
                    tensor_model_parallel_size=tensor_model_parallel_size,
                    pipeline_model_parallel_size=pipeline_model_parallel_size,
                    context_parallel_size=context_parallel_size,
                    expert_model_parallel_size=expert_model_parallel_size,
                    master_port=master_port,
                    replica_id=replica_id,
                    enable_cuda_graphs=enable_cuda_graphs,
                    enable_flash_decode=enable_flash_decode,
                    legacy_ckpt=legacy_ckpt,
                )
                worker_futures.append(worker)

            # Store the worker handles; the actors finish initializing asynchronously
            self.workers = worker_futures
            LOGGER.info(f"Replica {replica_id} - All {world_size} workers created successfully")

            # Primary worker for coordinating inference
            self.primary_worker = self.workers[0]

            LOGGER.info(f"Replica {replica_id} - Initialized {world_size} model workers across {num_nodes} nodes")
        except Exception as e:
            LOGGER.error(f"Error initializing distributed model deployment: {str(e)}")
            raise

    @app.post("/v1/completions/")
    async def completions(self, request: Dict[Any, Any]):
        """Handle text completion requests."""
        try:
            if "prompt" in request:
                request["prompts"] = [request["prompt"]]
            temperature = request.get("temperature", 0.0)
            top_p = request.get("top_p", 0.0)
            if temperature == 0.0 and top_p == 0.0:
                LOGGER.warning("Both temperature and top_p are 0. Setting top_k to 1 to ensure greedy sampling.")
                request["top_k"] = 1.0

            # Prepare inference inputs with proper parameter mapping
            inference_inputs = {
                "prompts": request.get("prompts", []),
                "max_length": request.get("max_tokens", 256),
                "temperature": request.get("temperature", 1.0),
                "top_k": request.get("top_k", 0),
                "top_p": request.get("top_p", 0.0),
                "compute_logprob": True if request.get("logprobs") == 1 else False,
                "apply_chat_template": False,
            }

            # Run inference on the primary worker; rank 0 coordinates the remaining ranks
            results = ray.get(self.primary_worker.infer.remote(inference_inputs))

            # Extract generated texts from results
            generated_texts = results.get("sentences", [])

            # Approximate token counts via whitespace splitting
            prompt_tokens = sum(len(p.split()) for p in request.get("prompts", []))
            completion_tokens = sum(len(r.split()) for r in generated_texts)
            total_tokens = prompt_tokens + completion_tokens

            # Convert numpy arrays to Python lists for JSON serialization
            log_probs_data = results.get("log_probs", None)
            if log_probs_data is not None and isinstance(log_probs_data, np.ndarray):
                log_probs_data = log_probs_data.tolist()

            output = {
                "id": f"cmpl-{int(time.time())}",
                "object": "text_completion",
                "created": int(time.time()),
                "model": self.model_id,
                "choices": [
                    {
                        "text": " ".join(generated_texts),
                        "index": 0,
                        "logprobs": (
                            {
                                "token_logprobs": log_probs_data,
                                "top_logprobs": log_probs_data,
                            }
                            if log_probs_data is not None
                            else None
                        ),
                        "finish_reason": (
                            "length"
                            if generated_texts and len(generated_texts[0]) >= request.get("max_tokens", 256)
                            else "stop"
                        ),
                    }
                ],
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": total_tokens,
                },
            }
            return output
        except Exception as e:
            LOGGER.error(f"Error during inference: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Error during inference: {str(e)}")

    @app.post("/v1/chat/completions/")
    async def chat_completions(self, request: Dict[Any, Any]):
        """Handle chat completion requests."""
        try:
            # Extract parameters from the request dictionary
            messages = request.get("messages", [])

            # Prepare inference parameters.
            # For chat templates, the entire messages list is passed as a single prompt
            # so that apply_chat_template receives the full conversation context.
            inference_inputs = {
                "prompts": [messages],  # Wrap messages in a list so apply_chat_template gets the full conversation
                "max_length": request.get("max_tokens", 256),
                "temperature": request.get("temperature", 1.0),
                "top_k": request.get("top_k", 0),
                "top_p": request.get("top_p", 0.0),
                "compute_logprob": True if request.get("logprobs") == 1 else False,
                "apply_chat_template": request.get("apply_chat_template", True),
            }

            # Run inference on the primary worker
            results = ray.get(self.primary_worker.infer.remote(inference_inputs))

            # Extract generated texts from results
            generated_texts = results["sentences"]

            # Approximate token counts via whitespace splitting
            prompt_tokens = sum(len(str(msg).split()) for msg in messages)
            completion_tokens = sum(len(r.split()) for r in generated_texts)
            total_tokens = prompt_tokens + completion_tokens

            # Convert numpy arrays to Python lists for JSON serialization
            log_probs_data = results.get("log_probs", None)
            if log_probs_data is not None and isinstance(log_probs_data, np.ndarray):
                log_probs_data = log_probs_data.tolist()

            output = {
                "id": f"chatcmpl-{int(time.time())}",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": self.model_id,
                "choices": [
                    {
                        "message": {
                            "role": "assistant",
                            "content": generated_texts[0] if generated_texts else "",
                        },
                        "index": 0,
                        "logprobs": (
                            {
                                "token_logprobs": log_probs_data,
                                "top_logprobs": log_probs_data,
                            }
                            if log_probs_data is not None
                            else None
                        ),
                        "finish_reason": (
                            "length"
                            if generated_texts and len(generated_texts[0]) >= inference_inputs["max_length"]
                            else "stop"
                        ),
                    }
                ],
                "usage": {
                    "prompt_tokens": prompt_tokens,
                    "completion_tokens": completion_tokens,
                    "total_tokens": total_tokens,
                },
            }
            return output
        except Exception as e:
            LOGGER.error(f"Error during chat completion: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Error during chat completion: {str(e)}")

    @app.get("/v1/models")
    async def list_models(self):
        """List available models."""
        return {
            "data": [{"id": self.model_id, "object": "model", "created": int(time.time())}],
            "object": "list",
        }

    @app.get("/v1/health")
    async def health_check(self):
        """Health check endpoint."""
        return {"status": "healthy"}
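

# Illustrative deployment sketch (not part of the original module): one way to serve
# this deployment and query its OpenAI-style endpoints. Assumes a running Ray cluster,
# the default Ray Serve HTTP port (8000), and a placeholder checkpoint path; adjust
# route_prefix, resources, and parallelism sizes for your setup.
#
#     from ray import serve
#     import requests
#
#     serve.run(
#         MegatronRayDeployable.bind(
#             nemo_checkpoint_filepath="/path/to/model.nemo",
#             num_gpus=1,
#             tensor_model_parallel_size=1,
#         ),
#         route_prefix="/",
#     )
#     resp = requests.post(
#         "http://localhost:8000/v1/completions/",
#         json={"prompt": "The capital of France is", "max_tokens": 16},
#     )
#     print(resp.json()["choices"][0]["text"])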