#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary

"""Standalone Model Context Protocol server for NVIDIA Nemotron Speech NIMs.

This server is intended to be downloaded from the NVIDIA Speech documentation
and hosted by a customer near their self-hosted Speech NIM containers. It uses a
YAML config file to discover which ASR, TTS, and NMT NIM endpoints are available
and exposes a small MCP control plane for agents.

Realtime audio is not transported through MCP. Realtime tools return the native
Riva WebSocket URL and protocol instructions so the agent can stream media
directly to the Speech NIM realtime endpoint.
"""

from __future__ import annotations

import argparse
import base64
import hmac
import json
import logging
import os
import uuid
import wave
from dataclasses import dataclass, field
from http import HTTPStatus
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple
from urllib.parse import urlparse

try:
    import yaml
except ImportError as exc:  # pragma: no cover - exercised by users without deps
    raise SystemExit(
        "PyYAML is required. Install dependencies with: "
        "python3 -m pip install -r requirements.txt"
    ) from exc

# Module-wide logger; the name is fixed rather than derived from __name__.
LOGGER = logging.getLogger("riva_speech_mcp")
# Value written into the "jsonrpc" member of every JSON-RPC envelope built in this module.
JSONRPC_VERSION = "2.0"
# NOTE(review): presumably the MCP revision advertised during initialize — confirm against the handler.
MCP_PROTOCOL_VERSION = "2025-03-26"
# Default for ServerConfig.max_audio_bytes (100 MiB).
DEFAULT_MAX_AUDIO_BYTES = 100 * 1024 * 1024
# Default for ServerConfig.max_request_bytes (150 MiB).
DEFAULT_MAX_REQUEST_BYTES = 150 * 1024 * 1024


class JsonRpcError(Exception):
    """Exception that carries the members of a JSON-RPC error object.

    Attributes:
        code: Numeric JSON-RPC error code.
        message: Error text; also passed to ``Exception`` as the message.
        data: Optional extra payload attached to the error.
    """

    def __init__(self, code: int, message: str, data: Optional[Any] = None):
        super().__init__(message)
        self.data = data
        self.message = message
        self.code = code


@dataclass
class BackendConfig:
    """Settings for one Speech NIM backend as read from the config file."""

    # Required fields.
    id: str
    modalities: List[str]
    grpc_uri: str
    # Optional descriptive and realtime settings.
    description: str = ""
    realtime_url: Optional[str] = None
    public_realtime_url: Optional[str] = None
    # TLS options handed to the gRPC auth helper.
    use_ssl: bool = False
    ssl_root_cert: Optional[str] = None
    ssl_client_cert: Optional[str] = None
    ssl_client_key: Optional[str] = None
    # Extra gRPC metadata and per-backend request defaults.
    metadata: Dict[str, str] = field(default_factory=dict)
    default_language_code: Optional[str] = None
    default_model: Optional[str] = None

    def supports(self, modality: str) -> bool:
        """Return True when ``modality`` matches a configured modality, ignoring case."""
        wanted = modality.lower()
        configured = [entry.lower() for entry in self.modalities]
        return wanted in configured

    def stream_url_base(self) -> Optional[str]:
        """Return the realtime URL without trailing slashes, or None when unset.

        ``public_realtime_url`` wins over ``realtime_url`` when both are set.
        """
        url = self.public_realtime_url or self.realtime_url
        if not url:
            return None
        trimmed = url.rstrip("/")
        return trimmed if trimmed else None


@dataclass
class ServerConfig:
    """Server-level settings populated from the ``server`` section of the config file."""

    # NOTE(review): presumably the server name/title reported to MCP clients — confirm in the initialize handler.
    name: str = "nemotron-speech-mcp"
    title: str = "Nemotron Speech MCP"
    # NOTE(review): looks like the name of the environment variable holding the bearer token — confirm in the HTTP handler.
    bearer_token_env: str = "RIVA_MCP_BEARER_TOKEN"
    # NOTE(review): presumably the allowed Origin header values — usage is not visible here.
    allow_origins: List[str] = field(default_factory=list)
    # Audio loading refuses audio_path / file:// sources unless this is true.
    allow_local_files: bool = False
    # Upper bound, in bytes, on the audio payload accepted for transcription.
    max_audio_bytes: int = DEFAULT_MAX_AUDIO_BYTES
    # NOTE(review): presumably the HTTP request body limit — enforcement is not visible here.
    max_request_bytes: int = DEFAULT_MAX_REQUEST_BYTES
    # Size of each audio chunk sent in streaming recognition requests.
    streaming_chunk_bytes: int = 64 * 1024
    # Timeout, in seconds, passed to the gRPC calls made by the backend client.
    grpc_timeout: float = 30.0


@dataclass
class AppConfig:
    """Top-level configuration: server settings plus the list of backends."""

    # Settings from the ``server`` section.
    server: ServerConfig
    # Backends in the order they appear in the config file.
    backends: List[BackendConfig]


def _json_dumps(data: Any, pretty: bool = False) -> str:
    """Serialize ``data`` to JSON with sorted keys.

    ``pretty`` selects a two-space indented layout; otherwise the most compact
    separators are used.
    """
    layout: Dict[str, Any] = {"indent": 2} if pretty else {"separators": (",", ":")}
    return json.dumps(data, sort_keys=True, **layout)


def _text_content(text: str) -> Dict[str, str]:
    """Wrap ``text`` in a content item of type ``"text"``."""
    return dict(type="text", text=text)


def _json_tool_result(data: Any, text: Optional[str] = None, is_error: bool = False) -> Dict[str, Any]:
    """Build a tool result carrying ``data`` both as text and as structured content.

    When ``text`` is None the text content is the pretty-printed JSON form of
    ``data``; otherwise ``text`` is used as given.
    """
    rendered = _json_dumps(data, pretty=True) if text is None else text
    result: Dict[str, Any] = {"content": [_text_content(rendered)]}
    result["structuredContent"] = data
    result["isError"] = is_error
    return result


def _json_rpc_response(request_id: Any, result: Any) -> Dict[str, Any]:
    """Build a JSON-RPC success envelope for ``request_id`` carrying ``result``."""
    envelope: Dict[str, Any] = dict(jsonrpc=JSONRPC_VERSION, id=request_id, result=result)
    return envelope


def _json_rpc_error(
    request_id: Any, code: int, message: str, data: Optional[Any] = None
) -> Dict[str, Any]:
    """Build a JSON-RPC error envelope; ``data`` is included only when not None."""
    detail: Dict[str, Any] = dict(code=code, message=message)
    if data is not None:
        detail.update(data=data)
    return dict(jsonrpc=JSONRPC_VERSION, id=request_id, error=detail)


def _expand_env(value: Any) -> Any:
    """Recursively expand environment variables in every string of ``value``.

    Dicts and lists are rebuilt with their items expanded (dict keys are left
    as they are); any other non-string value is returned unchanged.
    """
    if isinstance(value, dict):
        return {name: _expand_env(entry) for name, entry in value.items()}
    if isinstance(value, list):
        return list(map(_expand_env, value))
    return os.path.expandvars(value) if isinstance(value, str) else value


def _pcm_to_wav(pcm: bytes, sample_rate: int, channels: int = 1, sample_width: int = 2) -> bytes:
    """Wrap raw PCM frames in a WAV container and return the container bytes.

    ``sample_width`` is in bytes per sample; the defaults describe mono
    16-bit audio.
    """
    container = BytesIO()
    with wave.open(container, "wb") as writer:
        # The frame count starts at 0; the header is written from the frames actually stored.
        writer.setparams((channels, sample_width, sample_rate, 0, "NONE", "not compressed"))
        writer.writeframes(pcm)
    return container.getvalue()


def _strip_wav_header(audio_bytes: bytes) -> Tuple[bytes, Optional[int]]:
    """Return ``(pcm_bytes, sample_rate)`` for WAV input, else ``(audio_bytes, None)``.

    Input counts as WAV only when it carries the RIFF/WAVE magic and the
    ``wave`` module can parse it; any parse failure returns the input untouched.
    """
    # A payload shorter than 12 bytes cannot match both magic slices.
    has_wav_magic = audio_bytes[:4] == b"RIFF" and audio_bytes[8:12] == b"WAVE"
    if not has_wav_magic:
        return audio_bytes, None
    try:
        with wave.open(BytesIO(audio_bytes), "rb") as reader:
            rate = reader.getframerate()
            frames = reader.readframes(reader.getnframes())
    except Exception:
        # Best effort: malformed containers are passed through unchanged.
        return audio_bytes, None
    return frames, rate


def load_config(path: str) -> AppConfig:
    """Load the YAML file at ``path`` and build the application config.

    String values have environment variables expanded before validation, so a
    boolean option may arrive as text (``use_ssl: ${RIVA_USE_SSL}`` expanding to
    ``"false"``). Such text is parsed explicitly because ``bool()`` treats every
    non-empty string as true, which would silently enable ``allow_local_files``
    or ``use_ssl``.

    Args:
        path: Location of the YAML config file.

    Returns:
        The parsed ``AppConfig`` with backends in file order.

    Raises:
        ValueError: The root, ``server`` section or a backend entry is not a
            mapping; no backend is defined; a backend lacks ``id``,
            ``modalities`` or ``grpc_uri``; or a boolean option holds
            unrecognised text.
        OSError: The file cannot be read.
    """

    def as_bool(value: Any, option: str) -> bool:
        # Interpret booleans that became strings through env expansion.
        if isinstance(value, str):
            token = value.strip().lower()
            if token in {"1", "true", "yes", "on"}:
                return True
            if token in {"", "0", "false", "no", "off"}:
                return False
            raise ValueError("{} must be a boolean, got {!r}".format(option, value))
        return bool(value)

    raw = yaml.safe_load(Path(path).read_text(encoding="utf-8")) or {}
    raw = _expand_env(raw)
    if not isinstance(raw, dict):
        raise ValueError("config root must be a mapping")
    server_raw = raw.get("server") or {}
    if not isinstance(server_raw, dict):
        raise ValueError("server section must be a mapping")
    backends_raw = raw.get("backends") or []
    if not backends_raw:
        raise ValueError("config must define at least one backend")

    server = ServerConfig(
        name=server_raw.get("name", "nemotron-speech-mcp"),
        title=server_raw.get("title", "Nemotron Speech MCP"),
        bearer_token_env=server_raw.get("bearer_token_env", "RIVA_MCP_BEARER_TOKEN"),
        allow_origins=list(server_raw.get("allow_origins") or []),
        allow_local_files=as_bool(server_raw.get("allow_local_files", False), "server.allow_local_files"),
        max_audio_bytes=int(server_raw.get("max_audio_bytes", DEFAULT_MAX_AUDIO_BYTES)),
        max_request_bytes=int(server_raw.get("max_request_bytes", DEFAULT_MAX_REQUEST_BYTES)),
        streaming_chunk_bytes=int(server_raw.get("streaming_chunk_bytes", 64 * 1024)),
        grpc_timeout=float(server_raw.get("grpc_timeout", 30.0)),
    )

    backends: List[BackendConfig] = []
    for item in backends_raw:
        if not isinstance(item, dict):
            raise ValueError("each backend must be a mapping")
        backend_id = str(item.get("id") or "").strip()
        if not backend_id:
            raise ValueError("each backend requires an id")
        # ``modalities:`` with no value parses as None; a bare string is one modality, not characters.
        modalities_raw = item.get("modalities") or []
        if isinstance(modalities_raw, str):
            modalities_raw = [modalities_raw]
        modalities = [str(modality).lower() for modality in modalities_raw]
        if not modalities:
            raise ValueError("backend {} requires at least one modality".format(backend_id))
        grpc_uri = str(item.get("grpc_uri") or "").strip()
        if not grpc_uri:
            raise ValueError("backend {} requires grpc_uri".format(backend_id))
        backends.append(
            BackendConfig(
                id=backend_id,
                description=str(item.get("description") or ""),
                modalities=modalities,
                grpc_uri=grpc_uri,
                realtime_url=item.get("realtime_url"),
                public_realtime_url=item.get("public_realtime_url"),
                use_ssl=as_bool(item.get("use_ssl", False), "backend {} use_ssl".format(backend_id)),
                ssl_root_cert=item.get("ssl_root_cert"),
                ssl_client_cert=item.get("ssl_client_cert"),
                ssl_client_key=item.get("ssl_client_key"),
                metadata=dict(item.get("metadata") or {}),
                default_language_code=item.get("default_language_code"),
                default_model=item.get("default_model"),
            )
        )
    return AppConfig(server=server, backends=backends)


class BackendRegistry:
    """Lookup and selection of configured backends by id and by modality."""

    def __init__(self, config: AppConfig):
        """Index ``config.backends`` by id; a later duplicate id replaces an earlier one."""
        self.config = config
        self.backends = {}
        for backend in config.backends:
            self.backends[backend.id] = backend

    def list_backends(self) -> List[Dict[str, Any]]:
        """Return a description dict for every backend, in config order."""
        return list(map(self._describe_backend, self.config.backends))

    def get_backend(self, backend_id: str) -> BackendConfig:
        """Return the backend registered under ``backend_id``.

        Raises:
            ValueError: No backend has that id.
        """
        found = self.backends.get(backend_id)
        if found is not None:
            return found
        raise ValueError("unknown backend_id: {}".format(backend_id))

    def _describe_backend(self, backend: BackendConfig) -> Dict[str, Any]:
        """Summarise ``backend`` as a plain dict."""
        return dict(
            id=backend.id,
            description=backend.description,
            modalities=backend.modalities,
            grpc_uri=backend.grpc_uri,
            realtime_url=backend.stream_url_base(),
            default_language_code=backend.default_language_code,
            default_model=backend.default_model,
        )

    def select_backend(self, modality: str, backend_id: Optional[str] = None) -> BackendConfig:
        """Pick the backend that should serve ``modality``.

        With ``backend_id`` the named backend is returned if it supports the
        modality. Without it, exactly one configured backend must support the
        modality.

        Raises:
            ValueError: Unknown id, unsupported modality, no candidate, or
                more than one candidate.
        """
        wanted = modality.lower()
        if not backend_id:
            matches = [entry for entry in self.config.backends if entry.supports(wanted)]
            if len(matches) == 1:
                return matches[0]
            if not matches:
                raise ValueError("no configured backend supports {}".format(wanted))
            valid_ids = ", ".join(entry.id for entry in matches)
            raise ValueError(
                "multiple {} backends are configured; pass backend_id. Valid values: {}".format(wanted, valid_ids)
            )

        chosen = self.get_backend(backend_id)
        if chosen.supports(wanted):
            return chosen
        raise ValueError("backend {} does not support {}".format(backend_id, wanted))


class RivaBackendClient:
    def __init__(self, registry: BackendRegistry):
        """Keep the registry used to resolve backends and read server settings."""
        self.registry = registry

    def _auth(self, backend: BackendConfig):
        from riva.client import Auth

        return Auth(
            ssl_root_cert=backend.ssl_root_cert,
            ssl_client_cert=backend.ssl_client_cert,
            ssl_client_key=backend.ssl_client_key,
            use_ssl=backend.use_ssl,
            uri=backend.grpc_uri,
            metadata_args=backend.metadata or None,
        )

    def health(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
        import riva.client.proto.health_pb2 as health_pb2
        import riva.client.proto.health_pb2_grpc as health_pb2_grpc

        timeout = self.registry.config.server.grpc_timeout
        results = []
        backend_id = arguments.get("backend_id")
        backends = [self.registry.get_backend(backend_id)] if backend_id else self.registry.config.backends
        for backend in backends:
            auth = self._auth(backend)
            stub = health_pb2_grpc.HealthStub(auth.channel)
            response = stub.Check(health_pb2.HealthCheckRequest(), timeout=timeout)
            results.append(
                {
                    "backend_id": backend.id,
                    "grpc_uri": backend.grpc_uri,
                    "serving": response.status == health_pb2.HealthCheckResponse.ServingStatus.SERVING,
                    "status": response.status,
                }
            )
        return {"backends": results}

    def list_models(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
        modality = str(arguments.get("modality", "all")).lower()
        backend_id = arguments.get("backend_id")
        if modality not in {"all", "asr", "tts", "nmt"}:
            raise ValueError("modality must be one of: all, asr, tts, nmt")

        selected = [self.registry.get_backend(backend_id)] if backend_id else self.registry.config.backends
        result: Dict[str, Any] = {"backends": []}
        for backend in selected:
            entry: Dict[str, Any] = {"backend_id": backend.id, "modalities": {}}
            if modality in {"all", "asr"} and backend.supports("asr"):
                entry["modalities"]["asr"] = self._list_asr_models(backend)
            if modality in {"all", "tts"} and backend.supports("tts"):
                entry["modalities"]["tts"] = self._list_tts_models(backend)
            if modality in {"all", "nmt"} and backend.supports("nmt"):
                entry["modalities"]["nmt"] = self._list_nmt_models(backend)
            if entry["modalities"]:
                result["backends"].append(entry)
        return result

    def _list_asr_models(self, backend: BackendConfig) -> Dict[str, Any]:
        import riva.client.proto.riva_asr_pb2 as riva_asr_pb2
        import riva.client.proto.riva_asr_pb2_grpc as riva_asr_pb2_grpc

        timeout = self.registry.config.server.grpc_timeout
        auth = self._auth(backend)
        stub = riva_asr_pb2_grpc.RivaSpeechRecognitionStub(auth.channel)
        response = stub.GetRivaSpeechRecognitionConfig(riva_asr_pb2.RivaSpeechRecognitionConfigRequest(), timeout=timeout)
        models = []
        for model_config in response.model_config:
            parameters = dict(model_config.parameters)
            models.append(
                {
                    "name": model_config.model_name,
                    "type": parameters.get("type"),
                    "language_code": parameters.get("language_code"),
                    "streaming": parameters.get("streaming"),
                    "parameters": parameters,
                }
            )
        return {"models": models}

    def _list_tts_models(self, backend: BackendConfig) -> Dict[str, Any]:
        import riva.client.proto.riva_tts_pb2 as riva_tts_pb2
        import riva.client.proto.riva_tts_pb2_grpc as riva_tts_pb2_grpc

        timeout = self.registry.config.server.grpc_timeout
        auth = self._auth(backend)
        stub = riva_tts_pb2_grpc.RivaSpeechSynthesisStub(auth.channel)
        response = stub.GetRivaSynthesisConfig(riva_tts_pb2.RivaSynthesisConfigRequest(), timeout=timeout)
        models = []
        for model_config in response.model_config:
            parameters = dict(model_config.parameters)
            subvoices = parameters.get("subvoices", "")
            voices = [v.split(":")[0] for v in subvoices.split(",") if v.strip()]
            models.append(
                {
                    "name": model_config.model_name,
                    "language_code": parameters.get("language_code"),
                    "voices": voices,
                    "parameters": parameters,
                }
            )
        return {"models": models}

    def _list_nmt_models(self, backend: BackendConfig) -> Dict[str, Any]:
        import riva.client.proto.riva_nmt_pb2 as riva_nmt_pb2
        import riva.client.proto.riva_nmt_pb2_grpc as riva_nmt_pb2_grpc

        timeout = self.registry.config.server.grpc_timeout
        auth = self._auth(backend)
        stub = riva_nmt_pb2_grpc.RivaTranslationStub(auth.channel)
        response = stub.ListSupportedLanguagePairs(riva_nmt_pb2.AvailableLanguageRequest(), timeout=timeout)
        return {
            "models": [
                {
                    "name": model_name,
                    "source_languages": list(language_pair.src_lang),
                    "target_languages": list(language_pair.tgt_lang),
                }
                for model_name, language_pair in response.languages.items()
            ]
        }

    def transcribe_file(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
        """Transcribe one audio payload through the selected ASR backend.

        The audio comes from ``_load_audio``. The request is sent either as a
        single ``Recognize`` call (``mode == "offline"``) or as a
        ``StreamingRecognize`` call that feeds the audio in
        ``streaming_chunk_bytes`` pieces. ``mode == "auto"`` (the default)
        resolves to offline when ``_is_offline_model`` returns True and to
        streaming otherwise.

        Returns:
            The dict built by ``_parse_asr_responses`` with ``backend_id`` and
            the resolved ``mode`` added.

        Raises:
            ValueError: From backend selection, audio loading or encoding lookup.
        """
        import riva.client.proto.riva_asr_pb2 as riva_asr_pb2
        import riva.client.proto.riva_asr_pb2_grpc as riva_asr_pb2_grpc
        from riva.client.proto.riva_audio_pb2 import AudioEncoding

        backend = self.registry.select_backend("asr", arguments.get("backend_id"))
        audio_bytes = self._load_audio(arguments)
        # Precedence: explicit argument, then the backend default, then "en-US".
        language_code = arguments.get("language_code") or backend.default_language_code or "en-US"
        encoding_name = arguments.get("encoding", "LINEAR_PCM")
        encoding = self._audio_encoding(AudioEncoding, encoding_name)

        # Strip WAV container when the caller passes a .wav file with LINEAR_PCM encoding.
        # Auto-detect sample rate from the WAV header if the caller did not supply one.
        # Only the sample rate is taken from the header; channel count and sample
        # width are not forwarded to the recognition config.
        sample_rate_hz = int(arguments.get("sample_rate_hz", 0))
        if encoding_name.upper() == "LINEAR_PCM":
            pcm, detected_rate = _strip_wav_header(audio_bytes)
            if detected_rate is not None:
                audio_bytes = pcm
                if not sample_rate_hz:
                    sample_rate_hz = detected_rate
        # Neither the caller nor a WAV header supplied a rate: fall back to 16 kHz.
        if not sample_rate_hz:
            sample_rate_hz = 16000

        model = arguments.get("model") or backend.default_model

        config = riva_asr_pb2.RecognitionConfig(
            encoding=encoding,
            sample_rate_hertz=sample_rate_hz,
            language_code=language_code,
            max_alternatives=int(arguments.get("max_alternatives", 1)),
            enable_automatic_punctuation=bool(arguments.get("enable_automatic_punctuation", True)),
            enable_word_time_offsets=bool(arguments.get("enable_word_time_offsets", False)),
            profanity_filter=bool(arguments.get("profanity_filter", False)),
        )
        # The model field is set only when a model was requested or configured.
        if model:
            config.model = model
        # Custom configuration keys and values are stringified before being stored.
        for key, value in (arguments.get("custom_configuration") or {}).items():
            config.custom_configuration[str(key)] = str(value)

        timeout = self.registry.config.server.grpc_timeout
        auth = self._auth(backend)
        stub = riva_asr_pb2_grpc.RivaSpeechRecognitionStub(auth.channel)

        mode = str(arguments.get("mode", "auto")).lower()
        if mode == "auto":
            mode = "offline" if self._is_offline_model(backend, model) else "streaming"

        # NOTE(review): any mode value other than "offline" takes the streaming
        # branch, including unrecognised ones — confirm this is intended.
        if mode == "offline":
            request = riva_asr_pb2.RecognizeRequest(config=config, audio=audio_bytes)
            response = stub.Recognize(request, metadata=auth.get_auth_metadata(), timeout=timeout)
            # Wrap in a list so _parse_asr_responses can iterate uniformly.
            parsed = self._parse_asr_responses([response])
        else:
            streaming_config = riva_asr_pb2.StreamingRecognitionConfig(
                config=config,
                interim_results=bool(arguments.get("interim_results", False)),
            )

            def requests() -> Iterable[Any]:
                # The first message carries only the config; audio follows in fixed-size chunks.
                yield riva_asr_pb2.StreamingRecognizeRequest(streaming_config=streaming_config)
                chunk_size = self.registry.config.server.streaming_chunk_bytes
                for offset in range(0, len(audio_bytes), chunk_size):
                    yield riva_asr_pb2.StreamingRecognizeRequest(audio_content=audio_bytes[offset : offset + chunk_size])

            # The same grpc_timeout value is passed to the streaming call.
            responses = stub.StreamingRecognize(requests(), metadata=auth.get_auth_metadata(), timeout=timeout)
            parsed = self._parse_asr_responses(responses)

        parsed["backend_id"] = backend.id
        parsed["mode"] = mode
        return parsed

    def _is_offline_model(self, backend: BackendConfig, model: Optional[str]) -> bool:
        """Return True if the selected model is an offline (non-streaming) model."""
        try:
            info = self._list_asr_models(backend)
            for m in info.get("models", []):
                if model and m.get("name") != model:
                    continue
                if m.get("parameters", {}).get("type", "").lower() == "offline":
                    return True
        except Exception:
            pass
        # Fall back to inspecting the model name when metadata lookup fails.
        if model and "offline" in model.lower():
            return True
        return False

    def create_realtime_asr_session(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
        backend = self.registry.select_backend("asr", arguments.get("backend_id"))
        stream_base = backend.stream_url_base()
        if not stream_base:
            raise ValueError("backend {} does not define realtime_url".format(backend.id))
        session_id = "asr_{}".format(uuid.uuid4().hex)
        stream_url = "{}?intent=transcription&session_id={}".format(stream_base, session_id)
        return {
            "backend_id": backend.id,
            "session_id": session_id,
            "protocol": "riva.realtime.websocket",
            "stream_url": stream_url,
            "input_audio_format": arguments.get("input_audio_format", "pcm16"),
            "sample_rate_hz": int(arguments.get("sample_rate_hz", 16000)),
            "chunk_duration_ms": int(arguments.get("chunk_duration_ms", 20)),
            "language_code": arguments.get("language_code") or backend.default_language_code,
            "model": arguments.get("model") or backend.default_model,
            "events": {
                "conversation_created": "conversation.created",
                "session_update": "transcription_session.update",
                "session_updated": "transcription_session.updated",
                "send_audio": "input_audio_buffer.append",
                "commit": "input_audio_buffer.commit",
                "done": "input_audio_buffer.done",
                "committed": "input_audio_buffer.committed",
                "partial_transcript": "conversation.item.input_audio_transcription.delta",
                "final_transcript": "conversation.item.input_audio_transcription.completed",
                "error": "error",
            },
            "message_templates": {
                "optional_session_update": {
                    "type": "transcription_session.update",
                    "session": {
                        "input_audio_format": "pcm16",
                        "input_audio_params": {"sample_rate_hz": int(arguments.get("sample_rate_hz", 16000)), "num_channels": 1},
                        "input_audio_transcription": {
                            "language": "<optional BCP-47 language code>",
                            "model": "<valid model name from riva_list_models>",
                        },
                    },
                },
                "append_audio_chunk": {"type": "input_audio_buffer.append", "event_id": "<uuid>", "audio": "<base64 raw PCM16 chunk>"},
                "commit": {"type": "input_audio_buffer.commit", "event_id": "<uuid>"},
                "done": {"type": "input_audio_buffer.done", "event_id": "<uuid>"},
            },
            "flow": [
                "Call this MCP tool to get stream_url.",
                "Open a WebSocket to stream_url and wait for conversation.created.",
                "Skip transcription_session.update unless you need to override a specific field — the server already configures the session with the language, model, and sample rate from this tool call.",
                "If you do send transcription_session.update, always include the model name from riva_list_models — omitting it resets the model to a server-side default that may not exist and will cause an internal error.",
                "Send input_audio_buffer.append with base64 raw PCM16 chunks, not WAV bytes.",
                "Send input_audio_buffer.commit every 20-30 seconds or at a turn boundary.",
                "Send one final commit at end of stream.",
                "Send input_audio_buffer.done, then continue reading final transcript events.",
                "Use delta events for partial text and completed events for finalized text.",
            ],
            "agent_notes": [
                "MCP is the control plane; realtime media goes directly to stream_url.",
                "For WAV files, decode the WAV and stream raw PCM16 frames.",
                "Do not send RIFF/WAV headers as audio chunks.",
                "Do not send transcription_session.update unless overriding a specific field. If you do, always include a valid model name — omitting it causes the server to fall back to an invalid default model and return StatusCode.INTERNAL.",
                "Do not invent a model name; omit the model or use one from riva_list_models.",
                "Completed events can be segment-level or aggregate-level; avoid duplicates.",
                "ASR models require 16000 Hz audio. TTS synthesizes at 22050 Hz by default — resample before streaming to ASR.",
            ],
        }

    def synthesize(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
        import riva.client.proto.riva_tts_pb2 as riva_tts_pb2
        import riva.client.proto.riva_tts_pb2_grpc as riva_tts_pb2_grpc
        from riva.client.proto.riva_audio_pb2 import AudioEncoding

        backend = self.registry.select_backend("tts", arguments.get("backend_id"))
        text = arguments.get("text")
        if not text or not str(text).strip():
            raise ValueError("text must be non-empty")
        language_code = arguments.get("language_code") or backend.default_language_code or "en-US"
        encoding_name = arguments.get("encoding", "LINEAR_PCM")
        encoding = self._audio_encoding(AudioEncoding, encoding_name)
        request = riva_tts_pb2.SynthesizeSpeechRequest(
            text=text,
            language_code=language_code,
            sample_rate_hz=int(arguments.get("sample_rate_hz", 22050)),
            encoding=encoding,
        )
        voice_name = arguments.get("voice_name") or backend.default_model
        if voice_name:
            request.voice_name = voice_name
        if arguments.get("custom_dictionary"):
            request.custom_dictionary = arguments["custom_dictionary"]

        timeout = self.registry.config.server.grpc_timeout
        auth = self._auth(backend)
        stub = riva_tts_pb2_grpc.RivaSpeechSynthesisStub(auth.channel)
        response = stub.Synthesize(request, metadata=auth.get_auth_metadata(), timeout=timeout)
        sample_rate_hz = int(arguments.get("sample_rate_hz", 22050))
        if encoding_name.upper() in {"OGGOPUS", "OGG_OPUS"}:
            mime_type = "audio/opus"
            audio_out = response.audio
        else:
            mime_type = "audio/wav"
            audio_out = _pcm_to_wav(response.audio, sample_rate=sample_rate_hz)
        return {
            "backend_id": backend.id,
            "mime_type": mime_type,
            "encoding": encoding_name,
            "sample_rate_hz": sample_rate_hz,
            "audio_base64": base64.b64encode(audio_out).decode("ascii"),
            "audio_bytes": len(audio_out),
        }

    def create_realtime_tts_session(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
        backend = self.registry.select_backend("tts", arguments.get("backend_id"))
        stream_base = backend.stream_url_base()
        if not stream_base:
            raise ValueError("backend {} does not define realtime_url".format(backend.id))
        session_id = "tts_{}".format(uuid.uuid4().hex)
        stream_url = "{}?intent=synthesize&session_id={}".format(stream_base, session_id)
        return {
            "backend_id": backend.id,
            "session_id": session_id,
            "protocol": "riva.realtime.websocket",
            "stream_url": stream_url,
            "language_code": arguments.get("language_code") or backend.default_language_code or "en-US",
            "voice_name": arguments.get("voice_name") or backend.default_model,
            "sample_rate_hz": int(arguments.get("sample_rate_hz", 22050)),
            "encoding": arguments.get("encoding", "LINEAR_PCM"),
            "events": {"send_text": "input_text.append", "commit": "input_text.commit", "audio_delta": "response.audio.delta", "completed": "response.audio.done"},
            "instructions": "Stream text directly to stream_url. Do not use MCP as the realtime audio transport.",
        }

    def translate(self, arguments: Dict[str, Any]) -> Dict[str, Any]:
        import riva.client.proto.riva_nmt_pb2 as riva_nmt_pb2
        import riva.client.proto.riva_nmt_pb2_grpc as riva_nmt_pb2_grpc

        backend = self.registry.select_backend("nmt", arguments.get("backend_id"))
        text = arguments.get("text")
        if not text or not str(text).strip():
            raise ValueError("text must be non-empty")
        source_language = arguments.get("source_language")
        target_language = arguments.get("target_language")
        if not source_language or not target_language:
            raise ValueError("source_language and target_language are required")
        request = riva_nmt_pb2.TranslateTextRequest(texts=[text], source_language=source_language, target_language=target_language)
        model = arguments.get("model") or backend.default_model
        if model:
            request.model = model
        dnt_phrases = arguments.get("dnt_phrases") or {}
        for key, value in dnt_phrases.items():
            request.dnt_phrases.append("{}##{}".format(key, value))
        if arguments.get("max_len_variation"):
            request.max_len_variation = arguments["max_len_variation"]

        timeout = self.registry.config.server.grpc_timeout
        auth = self._auth(backend)
        stub = riva_nmt_pb2_grpc.RivaTranslationStub(auth.channel)
        response = stub.TranslateText(request, metadata=auth.get_auth_metadata(), timeout=timeout)
        translations = [{"text": item.text, "language": item.language} for item in response.translations]
        return {"backend_id": backend.id, "translations": translations}

    def _load_audio(self, arguments: Dict[str, Any]) -> bytes:
        if arguments.get("audio_base64"):
            audio_bytes = base64.b64decode(arguments["audio_base64"], validate=True)
        else:
            path = arguments.get("audio_path") or self._file_uri_to_path(arguments.get("audio_uri"))
            if not path:
                raise ValueError("one of audio_base64, audio_path, or file:// audio_uri is required")
            if not self.registry.config.server.allow_local_files:
                raise ValueError("local file transcription is disabled by server config")
            audio_bytes = Path(path).read_bytes()
        if len(audio_bytes) > self.registry.config.server.max_audio_bytes:
            raise ValueError(
                "audio payload is {} bytes, which exceeds max_audio_bytes={}".format(
                    len(audio_bytes), self.registry.config.server.max_audio_bytes
                )
            )
        return audio_bytes

    def _file_uri_to_path(self, uri: Optional[str]) -> Optional[str]:
        if not uri:
            return None
        parsed = urlparse(uri)
        if parsed.scheme != "file":
            raise ValueError("only file:// audio_uri values are supported")
        return parsed.path

    def _audio_encoding(self, enum_cls: Any, encoding_name: str) -> int:
        normalized = str(encoding_name).upper()
        aliases = {"OGG_OPUS": "OGGOPUS"}
        normalized = aliases.get(normalized, normalized)
        if not hasattr(enum_cls, normalized):
            raise ValueError("unsupported audio encoding: {}".format(encoding_name))
        return getattr(enum_cls, normalized)

    def _parse_asr_responses(self, responses: Iterable[Any]) -> Dict[str, Any]:
        final_parts: List[str] = []
        segments: List[Dict[str, Any]] = []
        words: List[Dict[str, Any]] = []
        language_codes: List[str] = []
        for response in responses:
            for result in response.results:
                alternatives = result.alternatives
                if not alternatives:
                    continue
                alternative = alternatives[0]
                transcript = alternative.transcript
                # RecognizeResponse results are always final; StreamingRecognizeResponse results carry is_final explicitly.
                is_final = bool(getattr(result, "is_final", True))
                segment = {"text": transcript, "is_final": is_final, "confidence": getattr(alternative, "confidence", 0.0)}
                if getattr(result, "language_code", ""):
                    segment["language_code"] = result.language_code
                    if result.language_code not in language_codes:
                        language_codes.append(result.language_code)
                segments.append(segment)
                if is_final:
                    final_parts.append(transcript)
                    for word in alternative.words:
                        words.append(
                            {
                                "word": word.word,
                                "start_time": self._duration_to_seconds(word.start_time),
                                "end_time": self._duration_to_seconds(word.end_time),
                                "confidence": getattr(word, "confidence", 0.0),
                            }
                        )
        text = "".join(final_parts)
        if not text and segments:
            text = segments[-1]["text"]
        return {"text": text, "language_codes": language_codes, "segments": segments, "words": words}

    def _duration_to_seconds(self, duration: Any) -> float:
        if isinstance(duration, (int, float)):
            return float(duration)
        return float(duration.seconds) + float(duration.nanos) / 1_000_000_000.0


class RivaMCPApplication:
    def __init__(self, config: AppConfig):
        """Wire the application to its configuration and backend helpers.

        Args:
            config: Application configuration; stored on the instance and
                used to build the backend registry.
        """
        registry = BackendRegistry(config)
        self.config = config
        self.registry = registry
        # The backend client is built on top of the same registry instance.
        self.backend = RivaBackendClient(registry)

    def handle_json_rpc(self, message: Dict[str, Any]) -> Tuple[Optional[Dict[str, Any]], int]:
        """Process one decoded JSON-RPC message and produce the HTTP reply.

        Args:
            message: The decoded JSON body. Anything that is not a dict, or
                that lacks ``"jsonrpc": "2.0"`` or a ``method``, is rejected.

        Returns:
            A ``(response, http_status)`` pair. ``response`` is ``None`` for
            notifications (messages without an ``id``), which are ignored and
            acknowledged with HTTP 202. Invalid requests are answered with
            error ``-32600`` and HTTP 400. All other outcomes, including
            JSON-RPC errors raised while dispatching, use HTTP 200.
        """
        if not isinstance(message, dict):
            return _json_rpc_error(None, -32600, "Invalid Request"), HTTPStatus.BAD_REQUEST
        if message.get("jsonrpc") != JSONRPC_VERSION or "method" not in message:
            return _json_rpc_error(message.get("id"), -32600, "Invalid Request"), HTTPStatus.BAD_REQUEST
        method = message["method"]
        if "id" not in message:
            LOGGER.debug("Ignoring notification %s", method)
            return None, HTTPStatus.ACCEPTED
        request_id = message["id"]
        # Missing or empty params are treated as an empty object.
        params = message.get("params") or {}
        if not isinstance(params, dict):
            # Handlers call params.get(); positional (array) or scalar params
            # would otherwise fail with AttributeError and be reported as an
            # internal error instead of invalid params.
            return _json_rpc_error(request_id, -32602, "Invalid params: params must be an object"), HTTPStatus.OK
        try:
            result = self._dispatch(method, params)
            return _json_rpc_response(request_id, result), HTTPStatus.OK
        except JsonRpcError as exc:
            return _json_rpc_error(request_id, exc.code, exc.message, exc.data), HTTPStatus.OK
        except Exception as exc:
            # Last-resort boundary: log the traceback and report an internal error.
            LOGGER.exception("Unhandled MCP request failure")
            return _json_rpc_error(request_id, -32603, str(exc)), HTTPStatus.OK

    def _dispatch(self, method: str, params: Dict[str, Any]) -> Dict[str, Any]:
        if method == "initialize":
            return self._initialize(params)
        if method == "ping":
            return {}
        if method == "tools/list":
            return {"tools": self._tools()}
        if method == "tools/call":
            return self._call_tool(params)
        if method == "resources/list":
            return {"resources": self._resources()}
        if method == "resources/read":
            return self._read_resource(params)
        if method == "prompts/list":
            return {"prompts": self._prompts()}
        if method == "prompts/get":
            return self._get_prompt(params)
        raise JsonRpcError(-32601, "Method not found: {}".format(method))

    def _initialize(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Build the result payload for the MCP ``initialize`` request.

        Args:
            params: Request parameters; only ``protocolVersion`` is read.

        Returns:
            A dict with the protocol version, the advertised capabilities,
            server identification taken from the configuration, and usage
            instructions for the agent.
        """
        requested = params.get("protocolVersion", MCP_PROTOCOL_VERSION)
        # A single protocol version is supported, so either branch ends up
        # answering with MCP_PROTOCOL_VERSION.
        if requested == MCP_PROTOCOL_VERSION:
            version = requested
        else:
            version = MCP_PROTOCOL_VERSION

        capabilities = {section: {"listChanged": False} for section in ("tools", "resources", "prompts")}
        server_settings = self.config.server
        server_info = {"name": server_settings.name, "title": server_settings.title, "version": "0.1.0"}
        instructions = (
            "Use this MCP server as the agent control plane for configured Nemotron Speech NIM backends. "
            "For realtime ASR or TTS, call the create_realtime_session tool and stream media directly to the "
            "returned Riva WebSocket URL."
        )
        return {
            "protocolVersion": version,
            "capabilities": capabilities,
            "serverInfo": server_info,
            "instructions": instructions,
        }

    def _call_tool(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Execute the tool named in a ``tools/call`` request.

        Args:
            params: Request parameters. ``name`` selects the tool and
                ``arguments`` (an object, defaulting to empty) is passed to it.

        Returns:
            The tool result built by ``_json_tool_result``. For
            ``riva_tts_synthesize`` the content is replaced by an audio item
            plus a short text summary, and the base64 audio is left out of
            ``structuredContent``. Failures raised by the tool handler are
            returned as a tool result flagged with ``is_error=True``.

        Raises:
            JsonRpcError: With code ``-32602`` when ``arguments`` is not an
                object or ``name`` is not a known tool.
        """
        name = params.get("name")
        arguments = params.get("arguments") or {}
        if not isinstance(arguments, dict):
            raise JsonRpcError(-32602, "tools/call arguments must be an object")
        handlers = {
            "riva_health": lambda: self.backend.health(arguments),
            "riva_list_backends": lambda: {"backends": self.registry.list_backends()},
            "riva_list_models": lambda: self.backend.list_models(arguments),
            "riva_asr_transcribe_file": lambda: self.backend.transcribe_file(arguments),
            "riva_asr_create_realtime_session": lambda: self.backend.create_realtime_asr_session(arguments),
            "riva_tts_synthesize": lambda: self.backend.synthesize(arguments),
            "riva_tts_create_realtime_session": lambda: self.backend.create_realtime_tts_session(arguments),
            "riva_nmt_translate": lambda: self.backend.translate(arguments),
        }
        # The isinstance check comes first: an unhashable name (for example a
        # JSON array) would make the dict membership test raise TypeError.
        if not isinstance(name, str) or name not in handlers:
            raise JsonRpcError(-32602, "Unknown tool: {}".format(name))
        try:
            data = handlers[name]()
            if name != "riva_tts_synthesize":
                return _json_tool_result(data)
            # Return audio as a proper MCP audio content item. The base64
            # payload is stripped before the generic result is built so it is
            # not processed twice; it is carried once, in content[0].data.
            # NOTE(review): assumes _json_tool_result derives only "content"
            # and "structuredContent" from its data argument - confirm.
            structured = {k: v for k, v in data.items() if k != "audio_base64"}
            result = _json_tool_result(structured)
            result["structuredContent"] = structured
            result["content"] = [
                {"type": "audio", "data": data["audio_base64"], "mimeType": data["mime_type"]},
                _text_content("Synthesized {} bytes of {} audio.".format(data["audio_bytes"], data["mime_type"])),
            ]
            return result
        except ValueError as exc:
            # ValueError is reported to the caller without logging a traceback.
            return _json_tool_result({"tool": name, "error": str(exc)}, str(exc), is_error=True)
        except Exception as exc:
            LOGGER.exception("Tool %s failed", name)
            return _json_tool_result({"tool": name, "error": str(exc)}, str(exc), is_error=True)

    def _tools(self) -> List[Dict[str, Any]]:
        backend_id = {"type": "string", "description": "Optional configured backend id. Required when multiple matching backends exist."}
        return [
            {"name": "riva_health", "title": "Riva Health", "description": "Check health for one configured backend or all configured backends.", "inputSchema": {"type": "object", "properties": {"backend_id": backend_id}, "additionalProperties": False}},
            {"name": "riva_list_backends", "title": "List Configured Speech NIM Backends", "description": "List the ASR, TTS, and NMT backends configured for this MCP gateway.", "inputSchema": {"type": "object", "properties": {}, "additionalProperties": False}},
            {"name": "riva_list_models", "title": "List Riva Speech Models", "description": "Discover live models exposed by configured Riva Speech NIM backends.", "inputSchema": {"type": "object", "properties": {"backend_id": backend_id, "modality": {"type": "string", "enum": ["all", "asr", "tts", "nmt"], "default": "all"}}, "additionalProperties": False}},
            {"name": "riva_asr_transcribe_file", "title": "Transcribe Audio", "description": "Transcribe a file or audio blob with a configured Riva ASR backend. Use riva_asr_create_realtime_session for live microphone streaming. ASR models require 16000 Hz audio; if you pass a WAV file the server will auto-detect the sample rate. TTS output is 22050 Hz by default and must be resampled before transcription. Use mode=auto (default) to let the server pick the correct gRPC path based on model type.", "inputSchema": {"type": "object", "properties": {"backend_id": backend_id, "audio_base64": {"type": "string", "description": "Base64-encoded raw PCM16 or WAV bytes. WAV headers are stripped automatically."}, "audio_path": {"type": "string"}, "audio_uri": {"type": "string"}, "language_code": {"type": "string", "default": "en-US"}, "model": {"type": "string"}, "mode": {"type": "string", "enum": ["auto", "offline", "streaming"], "default": "auto", "description": "auto selects the gRPC path based on model type metadata: offline models use Recognize (full audio in one request), streaming models use StreamingRecognize. Override only when auto-detection is wrong."}, "sample_rate_hz": {"type": "integer", "default": 16000, "description": "Sample rate in Hz. Inferred from WAV header when omitted and audio_base64 is a WAV file."}, "encoding": {"type": "string", "default": "LINEAR_PCM"}, "enable_automatic_punctuation": {"type": "boolean", "default": True}, "enable_word_time_offsets": {"type": "boolean", "default": False}, "max_alternatives": {"type": "integer", "default": 1}, "profanity_filter": {"type": "boolean", "default": False}, "interim_results": {"type": "boolean", "default": False}, "custom_configuration": {"type": "object", "additionalProperties": {"type": "string"}}}, "additionalProperties": False}},
            {"name": "riva_asr_create_realtime_session", "title": "Create Realtime ASR Session", "description": "Return the Riva realtime WebSocket URL and event flow. Audio chunks are sent directly to the returned WebSocket URL, not through MCP. ASR models require 16000 Hz PCM16 audio. TTS synthesizes at 22050 Hz by default — resample before streaming to ASR. Do NOT send transcription_session.update after connecting unless overriding a specific field; the session is already configured. If you do send it, always include a valid model name or the server will reset to an invalid default and return an internal error.", "inputSchema": {"type": "object", "properties": {"backend_id": backend_id, "language_code": {"type": "string"}, "model": {"type": "string"}, "input_audio_format": {"type": "string", "default": "pcm16"}, "sample_rate_hz": {"type": "integer", "default": 16000, "description": "Must match the actual audio sample rate. ASR models require 16000 Hz."}, "chunk_duration_ms": {"type": "integer", "default": 20}}, "additionalProperties": False}},
            {"name": "riva_tts_synthesize", "title": "Synthesize Speech", "description": "Synthesize short text with a configured Riva TTS backend. Returns a WAV file (audio/wav). TTS output is 22050 Hz by default; ASR requires 16000 Hz — resample if piping TTS output to ASR.", "inputSchema": {"type": "object", "required": ["text"], "properties": {"backend_id": backend_id, "text": {"type": "string"}, "language_code": {"type": "string", "default": "en-US"}, "voice_name": {"type": "string", "description": "Voice in LANG.Name format, e.g. EN-US.Aria or HI-IN.Leo. Call riva_list_models with modality=tts to see all available voices."}, "sample_rate_hz": {"type": "integer", "default": 22050}, "encoding": {"type": "string", "default": "LINEAR_PCM"}, "custom_dictionary": {"type": "string"}}, "additionalProperties": False}},
            {"name": "riva_tts_create_realtime_session", "title": "Create Realtime TTS Session", "description": "Return the Riva realtime WebSocket URL for continuous TTS.", "inputSchema": {"type": "object", "properties": {"backend_id": backend_id, "language_code": {"type": "string", "default": "en-US"}, "voice_name": {"type": "string"}, "sample_rate_hz": {"type": "integer", "default": 22050}, "encoding": {"type": "string", "default": "LINEAR_PCM"}}, "additionalProperties": False}},
            {"name": "riva_nmt_translate", "title": "Translate Text", "description": "Translate text with a configured Riva NMT backend.", "inputSchema": {"type": "object", "required": ["text", "source_language", "target_language"], "properties": {"backend_id": backend_id, "text": {"type": "string"}, "source_language": {"type": "string"}, "target_language": {"type": "string"}, "model": {"type": "string"}, "dnt_phrases": {"type": "object", "additionalProperties": {"type": "string"}}, "max_len_variation": {"type": "string"}}, "additionalProperties": False}},
        ]

    def _resources(self) -> List[Dict[str, Any]]:
        return [
            {"uri": "riva://speech/capabilities", "name": "Nemotron Speech MCP Capabilities", "mimeType": "application/json", "description": "Static description of the MCP control plane and native data planes."},
            {"uri": "riva://speech/backends", "name": "Configured Speech NIM Backends", "mimeType": "application/json", "description": "The ASR, TTS, and NMT NIM endpoints configured for this MCP server."},
            {"uri": "riva://speech/models", "name": "Live Speech NIM Models", "mimeType": "application/json", "description": "Live model inventory queried from configured Speech NIM backends."},
            {"uri": "riva://asr/realtime-protocol", "name": "Realtime ASR Protocol", "mimeType": "application/json", "description": "How agents should discover realtime ASR and stream media natively."},
            {"uri": "riva://tts/realtime-protocol", "name": "Realtime TTS Protocol", "mimeType": "application/json", "description": "How agents should discover realtime TTS and stream media natively."},
        ]

    def _read_resource(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Produce the contents of the resource named by ``params["uri"]``.

        Args:
            params: Request parameters; only ``uri`` is read.

        Returns:
            A dict with a single ``contents`` entry holding the resource data
            serialized as pretty-printed JSON text.

        Raises:
            JsonRpcError: With code ``-32602`` when the URI is not one of the
                known resource URIs.
        """
        uri = params.get("uri")
        # Loaders are lazy so that only the requested resource is computed.
        loaders = (
            (
                "riva://speech/capabilities",
                lambda: {
                    "native_data_plane": ["grpc", "http", "realtime_websocket"],
                    "mcp_role": "agent control plane",
                    "tools": [tool["name"] for tool in self._tools()],
                },
            ),
            ("riva://speech/backends", lambda: {"backends": self.registry.list_backends()}),
            ("riva://speech/models", lambda: self.backend.list_models({"modality": "all"})),
            ("riva://asr/realtime-protocol", lambda: self.backend.create_realtime_asr_session({})),
            ("riva://tts/realtime-protocol", lambda: self.backend.create_realtime_tts_session({})),
        )
        for known_uri, load in loaders:
            if uri == known_uri:
                data = load()
                break
        else:
            raise JsonRpcError(-32602, "Unknown resource URI: {}".format(uri))
        entry = {"uri": uri, "mimeType": "application/json", "text": _json_dumps(data, pretty=True)}
        return {"contents": [entry]}

    def _prompts(self) -> List[Dict[str, Any]]:
        return [
            {"name": "riva_live_transcription_agent", "title": "Live Transcription Agent", "description": "Guide an agent to set up realtime ASR and stream audio natively.", "arguments": [{"name": "backend_id", "required": False}]},
            {"name": "riva_read_aloud_agent", "title": "Read Aloud Agent", "description": "Guide an agent to synthesize final responses with Riva TTS.", "arguments": [{"name": "backend_id", "required": False}]},
        ]

    def _get_prompt(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Return the prompt body for the prompt named by ``params["name"]``.

        Args:
            params: Request parameters; only ``name`` is read.

        Returns:
            A dict whose ``description`` is the prompt text and whose
            ``messages`` list holds one user message with that same text.

        Raises:
            JsonRpcError: With code ``-32602`` when the prompt name is unknown.
        """
        name = params.get("name")
        prompt_texts = (
            (
                "riva_live_transcription_agent",
                "When the user requests live transcription, call riva_asr_create_realtime_session first. Use the returned stream_url, flow, events, and message_templates fields. Stream base64 raw PCM16 chunks over WebSocket, commit every 20-30 seconds or at turn boundaries, then send input_audio_buffer.done. Do not invent model names or send realtime audio through MCP tools.",
            ),
            (
                "riva_read_aloud_agent",
                "When the user asks to hear short text aloud, call riva_tts_synthesize. For continuous speech, call riva_tts_create_realtime_session and stream text through the returned Riva WebSocket URL.",
            ),
        )
        for prompt_name, text in prompt_texts:
            if name == prompt_name:
                message = {"role": "user", "content": _text_content(text)}
                return {"description": text, "messages": [message]}
        raise JsonRpcError(-32602, "Unknown prompt: {}".format(name))


class MCPHTTPServer(ThreadingHTTPServer):
    """Threading HTTP server that carries the MCP application and access settings.

    The values stored here are read by the request handler through
    ``self.server``.
    """

    def __init__(
        self,
        address: Tuple[str, int],
        app: RivaMCPApplication,
        bearer_token: Optional[str],
        allowed_origins: List[str],
    ):
        """Create the server for ``address`` with ``RivaMCPRequestHandler`` as handler.

        Args:
            address: ``(host, port)`` pair passed to the base server.
            app: Application object that processes JSON-RPC messages.
            bearer_token: Token required in the ``Authorization`` header, or
                ``None``/empty to disable the check.
            allowed_origins: Origins accepted by the handler's origin check.
        """
        super().__init__(address, RivaMCPRequestHandler)
        self.app, self.bearer_token = app, bearer_token
        # Kept as a set because the handler only does membership tests on it.
        self.allowed_origins = {*allowed_origins}


class RivaMCPRequestHandler(BaseHTTPRequestHandler):
    server_version = "NemotronSpeechMCP/0.1"

    def do_GET(self) -> None:
        """Serve GET requests.

        ``/health`` answers with a small JSON status document, ``/mcp``
        answers 405 with an ``Allow: POST`` header, and any other path
        answers 404.
        """
        route = urlparse(self.path).path
        if route == "/health":
            self._send_json({"status": "ok", "service": "nemotron-speech-mcp"})
        elif route == "/mcp":
            self.send_response(HTTPStatus.METHOD_NOT_ALLOWED)
            self.send_header("Allow", "POST")
            self.end_headers()
        else:
            self.send_error(HTTPStatus.NOT_FOUND)

    def do_POST(self) -> None:
        """Handle a JSON-RPC request posted to ``/mcp``.

        The request is rejected before the body is read when the path is not
        ``/mcp`` (404), the bearer token check fails (401), the origin check
        fails (403), the content type is not ``application/json`` (415), the
        ``Content-Length`` header is not an integer (411), is negative (400),
        or exceeds the configured ``max_request_bytes`` (413).

        A body that cannot be decoded as UTF-8 JSON yields a ``-32700`` parse
        error, and a JSON array (batch) yields a ``-32600`` error; both use
        HTTP 400. Every other message is handed to the application, whose
        response and status are sent back. When the application returns no
        response body, only the status line and headers are sent.
        """
        parsed = urlparse(self.path)
        if parsed.path != "/mcp":
            self.send_error(HTTPStatus.NOT_FOUND)
            return
        if not self._authorized():
            self.send_error(HTTPStatus.UNAUTHORIZED)
            return
        if not self._origin_allowed():
            self.send_error(HTTPStatus.FORBIDDEN)
            return
        content_type = self.headers.get("Content-Type", "")
        # Ignore parameters such as "; charset=utf-8" when matching the media type.
        if content_type.split(";")[0].strip().lower() != "application/json":
            self.send_error(HTTPStatus.UNSUPPORTED_MEDIA_TYPE, "Content-Type must be application/json")
            return
        try:
            content_length = int(self.headers.get("Content-Length", "0"))
        except ValueError:
            self.send_error(HTTPStatus.LENGTH_REQUIRED)
            return
        if content_length < 0:
            # A negative size would pass the upper-bound check below and make
            # rfile.read() read until EOF, bypassing the request size limit
            # and blocking this thread until the client closes the connection.
            self.send_error(HTTPStatus.BAD_REQUEST, "Content-Length must not be negative")
            return
        max_request_bytes = self.server.app.config.server.max_request_bytes
        if content_length > max_request_bytes:
            self.send_error(HTTPStatus.REQUEST_ENTITY_TOO_LARGE)
            return
        try:
            body = self.rfile.read(content_length)
            message = json.loads(body.decode("utf-8"))
        except Exception:
            # Deliberately broad: any failure to read or decode the body is
            # reported to the client as a JSON-RPC parse error.
            self._send_json(_json_rpc_error(None, -32700, "Parse error"), HTTPStatus.BAD_REQUEST)
            return
        if isinstance(message, list):
            # Same HTTP status as the other Invalid Request (-32600) replies.
            self._send_json(
                _json_rpc_error(None, -32600, "JSON-RPC batches are not supported"),
                HTTPStatus.BAD_REQUEST,
            )
            return
        response, status = self.server.app.handle_json_rpc(message)
        if response is None:
            # No body to send (for example an acknowledged notification).
            self.send_response(status)
            self.send_header("MCP-Protocol-Version", MCP_PROTOCOL_VERSION)
            self.end_headers()
            return
        self._send_json(response, status)

    def log_message(self, fmt: str, *args: Any) -> None:
        """Send the base handler's log lines to the module logger at INFO level.

        Args:
            fmt: ``%``-style format string supplied by the base handler.
            *args: Values interpolated into ``fmt``.
        """
        client = self.address_string()
        rendered = fmt % args
        LOGGER.info("%s - %s", client, rendered)

    def _authorized(self) -> bool:
        token = self.server.bearer_token
        if not token:
            return True
        expected = "Bearer {}".format(token)
        return hmac.compare_digest(self.headers.get("Authorization", ""), expected)

    def _origin_allowed(self) -> bool:
        allowed = self.server.allowed_origins
        if not allowed:
            return True
        origin = self.headers.get("Origin")
        return origin is None or origin in allowed

    def _send_json(self, data: Dict[str, Any], status: int = HTTPStatus.OK) -> None:
        """Serialize ``data`` as JSON and write it as the complete HTTP response.

        Args:
            data: Object to serialize into the response body.
            status: HTTP status code for the response line.
        """
        body = _json_dumps(data).encode("utf-8")
        response_headers = (
            ("Content-Type", "application/json"),
            ("Content-Length", str(len(body))),
            ("MCP-Protocol-Version", MCP_PROTOCOL_VERSION),
        )
        self.send_response(status)
        for header_name, header_value in response_headers:
            self.send_header(header_name, header_value)
        self.end_headers()
        self.wfile.write(body)


def serve(config: AppConfig, host: str, port: int) -> None:
    """Start the MCP HTTP server and block until it stops.

    The bearer token is read from the environment variable named by
    ``config.server.bearer_token_env``; when that variable is unset or empty
    the server runs without bearer token authentication.

    Args:
        config: Application configuration (backends and server settings).
        host: Interface address to bind.
        port: TCP port to listen on.
    """
    token_env = config.server.bearer_token_env
    # os.environ.get() raises TypeError for a non-string key, so the lookup is
    # skipped when no variable name is configured.
    token = os.environ.get(token_env) if token_env else None
    server = MCPHTTPServer((host, port), app=RivaMCPApplication(config), bearer_token=token, allowed_origins=config.server.allow_origins)
    # The context manager closes the listening socket when serve_forever()
    # returns or is interrupted (for example by KeyboardInterrupt).
    with server:
        LOGGER.info("Starting Nemotron Speech MCP server on %s:%s", host, port)
        LOGGER.info("Configured backends: %s", ", ".join(backend.id for backend in config.backends))
        if token:
            LOGGER.info("Bearer token auth is enabled with env %s", token_env)
        else:
            LOGGER.warning("Bearer token auth is disabled; requests to /mcp are not authenticated")
        server.serve_forever()


def main() -> None:
    """Parse command-line options, configure logging, and run the server.

    An unrecognized ``--log-level`` name falls back to ``INFO``.
    """
    parser = argparse.ArgumentParser(description="Standalone Nemotron Speech MCP server")
    parser.add_argument("--config", required=True, help="Path to YAML config file")
    optional_flags = (
        ("--host", {"default": "0.0.0.0"}),
        ("--port", {"type": int, "default": 9900}),
        ("--log-level", {"default": "INFO"}),
    )
    for flag, settings in optional_flags:
        parser.add_argument(flag, **settings)
    cli = parser.parse_args()

    level = getattr(logging, cli.log_level.upper(), logging.INFO)
    logging.basicConfig(level=level)
    serve(load_config(cli.config), cli.host, cli.port)


# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
