Source code for nemo_deploy.service.rest_model_api
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import requests
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from pydantic_settings import BaseSettings
from nemo_deploy.nlp import NemoQueryLLM
try:
from nemo.utils import logging
except (ImportError, ModuleNotFoundError):
import logging
logging = logging.getLogger(__name__)
[docs]
class TritonSettings(BaseSettings):
    """Triton connection and response settings read from environment variables.

    The values are read once, in ``__init__``, from these variables:

    * ``TRITON_PORT`` (default ``8080``) -> ``triton_service_port``
    * ``TRITON_HTTP_ADDRESS`` (default ``"0.0.0.0"``) -> ``triton_service_ip``
    * ``TRITON_REQUEST_TIMEOUT`` (default ``60``) -> ``triton_request_timeout``
    * ``OPENAI_FORMAT_RESPONSE`` (default ``"False"``) -> ``openai_format_response``
    * ``OUTPUT_GENERATION_LOGITS`` (default ``"False"``) -> ``output_generation_logits``

    The two boolean flags are true only when the variable equals ``"true"``
    (case-insensitive).
    """

    _triton_service_port: int
    _triton_service_ip: str
    # Stored as an int (see __init__), so it is annotated as one.
    _triton_request_timeout: int
    _openai_format_response: bool
    _output_generation_logits: bool

    def __init__(self):
        super(TritonSettings, self).__init__()
        try:
            self._triton_service_port = int(os.environ.get("TRITON_PORT", 8080))
            self._triton_service_ip = os.environ.get("TRITON_HTTP_ADDRESS", "0.0.0.0")
            self._triton_request_timeout = int(os.environ.get("TRITON_REQUEST_TIMEOUT", 60))
            self._openai_format_response = os.environ.get("OPENAI_FORMAT_RESPONSE", "False").lower() == "true"
            self._output_generation_logits = os.environ.get("OUTPUT_GENERATION_LOGITS", "False").lower() == "true"
        except Exception as error:
            # The error is interpolated into the message itself: passing it as an
            # extra positional argument without a placeholder makes the logging
            # call fail to format and the actual error is never shown.
            # Any attribute after the failing assignment stays unset.
            logging.error(
                f"An exception occurred trying to retrieve set args in TritonSettings class. Error: {error}"
            )
            return

    @property
    def triton_service_port(self):
        """Returns the Triton server port taken from ``TRITON_PORT``."""
        return self._triton_service_port

    @property
    def triton_service_ip(self):
        """Returns the Triton server address taken from ``TRITON_HTTP_ADDRESS``."""
        return self._triton_service_ip

    @property
    def triton_request_timeout(self):
        """Returns the integer timeout taken from ``TRITON_REQUEST_TIMEOUT``."""
        return self._triton_request_timeout

    @property
    def openai_format_response(self):
        """Returns the response from Triton server in OpenAI compatible format if set to True."""
        return self._openai_format_response

    @property
    def output_generation_logits(self):
        """Returns the generation logits along with text in Triton server output if set to True."""
        return self._output_generation_logits
# FastAPI application object; the route decorators in this module register their endpoints on it.
app = FastAPI()
# Settings instance created once at import time and read by the endpoint handlers.
triton_settings = TritonSettings()
[docs]
class CompletionRequest(BaseModel):
    """Request body accepted by the ``/v1/completions/`` endpoint.

    Only ``model`` and ``prompt`` are required; every other field has a default.
    """

    # Required fields.
    model: str
    prompt: str

    # Generation length and sampling controls.
    max_tokens: int = 512
    temperature: float = 1.0
    top_p: float = 0.0
    top_k: int = 1

    # NOTE(review): the handler visible in this file does not read the three
    # fields below - confirm whether they are consumed elsewhere.
    stream: bool = False
    stop: str | None = None
    frequency_penalty: float = 1.0
[docs]
@app.get("/v1/health")
def health_check():
    """Liveness endpoint: always answers with a fixed ``{"status": "ok"}`` payload."""
    payload = {"status": "ok"}
    return payload
[docs]
@app.get("/v1/triton_health")
async def check_triton_health():
    """check_triton_health.

    This method exposes endpoint "/v1/triton_health" which can be used to verify if Triton server is accessible
    while running the REST or FastAPI application.
    Verify by running: curl http://service_http_address:service_port/v1/triton_health and the returned status
    should inform if the server is accessible.

    Returns:
        dict: ``{"status": "Triton server is reachable and ready"}`` when the Triton
        ``/v2/health/ready`` endpoint answers with HTTP 200.

    Raises:
        HTTPException: status 503 when the request to Triton fails, or when Triton
        answers with any status code other than 200.
    """
    # Local import keeps this block self-contained; asyncio is standard library.
    import asyncio

    triton_url = (
        f"http://{triton_settings.triton_service_ip}:{str(triton_settings.triton_service_port)}/v2/health/ready"
    )
    logging.info(f"Attempting to connect to Triton server at: {triton_url}")
    try:
        # requests.get is a blocking call; running it in a worker thread keeps this
        # coroutine from stalling the event loop for up to the 5 second timeout.
        response = await asyncio.to_thread(requests.get, triton_url, timeout=5)
    except requests.RequestException as e:
        raise HTTPException(status_code=503, detail=f"Cannot reach Triton server: {str(e)}") from e

    if response.status_code == 200:
        return {"status": "Triton server is reachable and ready"}
    raise HTTPException(status_code=503, detail="Triton server is not ready")
[docs]
@app.post("/v1/completions/")
def completions_v1(request: CompletionRequest):
    """Forward a completion request to the Triton-hosted model and return its output.

    A ``NemoQueryLLM`` client is built for ``request.model`` at the address and port
    taken from ``triton_settings``, and queried with the single prompt in ``request``.

    Args:
        request: The parsed request body.

    Returns:
        The ``query_llm`` result unchanged when ``triton_settings.openai_format_response``
        is true; otherwise ``{"output": output[0][0]}``. If any exception is raised,
        it is logged and ``{"error": "An exception occurred"}`` is returned instead.
    """
    try:
        url = triton_settings.triton_service_ip + ":" + str(triton_settings.triton_service_port)
        nq = NemoQueryLLM(url=url, model_name=request.model)
        output = nq.query_llm(
            prompts=[request.prompt],
            max_output_len=request.max_tokens,
            # NOTE(review): the original comment here was cut off ("when these below
            # params are passed as None"); how query_llm treats None for these
            # sampling parameters is not visible in this file - confirm.
            top_k=request.top_k,
            top_p=request.top_p,
            temperature=request.temperature,
            init_timeout=triton_settings.triton_request_timeout,
            openai_format_response=triton_settings.openai_format_response,
            output_generation_logits=triton_settings.output_generation_logits,
        )
        if triton_settings.openai_format_response:
            return output
        else:
            return {
                "output": output[0][0],
            }
    except Exception as error:
        # The error is interpolated into the message itself: passing it as an extra
        # positional argument without a placeholder makes the logging call fail to
        # format and the actual error is never shown.
        logging.error(f"An exception occurred with the post request to /v1/completions/ endpoint: {error}")
        return {"error": "An exception occurred"}