add openai_serving and openai_serving_models

IlyaMescheryakov1402 2025-03-09 15:12:05 +03:00
parent 428be76642
commit cadd48f672
6 changed files with 99 additions and 67 deletions

View File

@@ -4,11 +4,13 @@ import traceback
 import gzip
 import asyncio
-from fastapi import FastAPI, Request, Response, APIRouter, HTTPException
+from fastapi import FastAPI, Request, Response, APIRouter, HTTPException, Depends
 from fastapi.routing import APIRoute
 from fastapi.responses import PlainTextResponse
 from grpc.aio import AioRpcError
+from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest
 from starlette.background import BackgroundTask
 from typing import Optional, Dict, Any, Callable, Union
@@ -194,16 +196,27 @@ async def base_serve_model(
     return return_value

-@router.post("/openai/v1/{endpoint_type:path}")
-@router.post("/openai/v1/{endpoint_type:path}/")
+async def validate_json_request(raw_request: Request):
+    content_type = raw_request.headers.get("content-type", "").lower()
+    media_type = content_type.split(";", maxsplit=1)[0]
+    if media_type != "application/json":
+        raise HTTPException(
+            status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
+            detail="Unsupported Media Type: Only 'application/json' is allowed"
+        )
+
+@router.post("/openai/v1/{endpoint_type:path}", dependencies=[Depends(validate_json_request)])
+@router.post("/openai/v1/{endpoint_type:path}/", dependencies=[Depends(validate_json_request)])
 async def openai_serve_model(
     endpoint_type: str,
-    request: Union[bytes, Dict[Any, Any]] = None
+    request: Union[CompletionRequest, ChatCompletionRequest],
+    raw_request: Request
 ):
+    combined_request = {"request": request, "raw_request": raw_request}
     return_value = await process_with_exceptions(
         base_url=request.get("model", None),
         version=None,
-        request_body=request,
+        request_body=combined_request,
         serve_type=endpoint_type
     )
     return return_value
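For reference, a minimal self-contained sketch (not the actual ClearML Serving app) of how a Depends-based guard like validate_json_request behaves: FastAPI resolves route-level dependencies before the endpoint body is validated, so a non-JSON Content-Type is rejected with HTTP 415 before openai_serve_model runs. The app, route handler, and payloads below are illustrative assumptions; the test client is starlette's httpx-based TestClient.

from http import HTTPStatus

from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.testclient import TestClient

app = FastAPI()

async def validate_json_request(raw_request: Request):
    # Same guard as in the diff: only 'application/json' bodies are accepted.
    content_type = raw_request.headers.get("content-type", "").lower()
    media_type = content_type.split(";", maxsplit=1)[0]
    if media_type != "application/json":
        raise HTTPException(
            status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
            detail="Unsupported Media Type: Only 'application/json' is allowed"
        )

@app.post("/openai/v1/{endpoint_type:path}", dependencies=[Depends(validate_json_request)])
async def openai_serve_model(endpoint_type: str, request: dict):
    # Stand-in handler; the real endpoint forwards to process_with_exceptions().
    return {"endpoint_type": endpoint_type, "model": request.get("model")}

client = TestClient(app)
print(client.post("/openai/v1/chat/completions", json={"model": "test_vllm"}).status_code)  # 200
print(client.post("/openai/v1/chat/completions", content="not json",
                  headers={"content-type": "text/plain"}).status_code)                      # 415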

View File

@@ -1234,12 +1234,12 @@ class ModelRequestProcessor(object):
             # noinspection PyUnresolvedReferences
             processed = await processor.completions(preprocessed, state, stats_collect_fn) \
                 if processor.is_process_async \
-                else processor.completion(preprocessed, state, stats_collect_fn)
+                else processor.completions(preprocessed, state, stats_collect_fn)
         elif serve_type == "chat/completions":
             # noinspection PyUnresolvedReferences
             processed = await processor.chat_completions(preprocessed, state, stats_collect_fn) \
                 if processor.is_process_async \
-                else processor.chat_completion(preprocessed, state, stats_collect_fn)
+                else processor.chat_completions(preprocessed, state, stats_collect_fn)
         else:
             raise ValueError(f"wrong url_type: expected 'process', 'completions' or 'chat/completions', got {serve_type}")
         # noinspection PyUnresolvedReferences
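This hunk fixes the synchronous fallback: the sync branch previously called processor.completion / processor.chat_completion, singular method names that do not exist on the preprocessor, so only the async path worked. Below is a hypothetical stand-alone sketch of the conditional-await dispatch pattern (all names invented for illustration); both branches now resolve to the same method name.

import asyncio

class SyncProcessor:
    is_process_async = False

    def completions(self, preprocessed, state, stats_collect_fn):
        return {"echo": preprocessed}

class AsyncProcessor:
    is_process_async = True

    async def completions(self, preprocessed, state, stats_collect_fn):
        return {"echo": preprocessed}

async def dispatch(processor, preprocessed, state=None, stats_collect_fn=None):
    # Same shape as the fixed branch: await when the processor is async,
    # plain call otherwise, using the same method name in both cases.
    return await processor.completions(preprocessed, state, stats_collect_fn) \
        if processor.is_process_async \
        else processor.completions(preprocessed, state, stats_collect_fn)

print(asyncio.run(dispatch(AsyncProcessor(), {"prompt": "hi"})))  # {'echo': {'prompt': 'hi'}}
print(asyncio.run(dispatch(SyncProcessor(), {"prompt": "hi"})))   # {'echo': {'prompt': 'hi'}}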

View File

@@ -613,6 +613,9 @@ class CustomAsyncPreprocessRequest(BasePreprocessRequest):
 @BasePreprocessRequest.register_engine("vllm", modules=["vllm", "fastapi"])
 class VllmPreprocessRequest(BasePreprocessRequest):
+    is_preprocess_async = True
+    is_process_async = True
+    is_postprocess_async = True
     asyncio_to_thread = None
     _vllm = None
     _fastapi = None
@@ -729,25 +732,18 @@ class VllmPreprocessRequest(BasePreprocessRequest):
         The actual processing function.
         We run the process in this context
         """
-        raw_request = CustomRequest(
-            headers = {
-                "traceparent": None,
-                "tracestate": None
-            }
-        )
-        request = self._vllm["completion_request"](**data)
-        self.logger.info(f"Received chat completion request: {request}")
-        generator = await self._model["openai_serving_completion"].create_completion(
-            request=request,
-            raw_request=raw_request
-        )
+        request, raw_request = data["request"], data["raw_request"]
+        handler = self._model["openai_serving_completion"]
+        if handler is None:
+            return self._model["openai_serving"].create_error_response(message="The model does not support Completions API")
+        # request = self._vllm["completion_request"](**data)
+        # self.logger.info(f"Received chat completion request: {request}")
+        generator = await handler.create_completion(request=request, raw_request=raw_request)
         if isinstance(generator, self._vllm["error_response"]):
             return self._fastapi["json_response"](content=generator.model_dump(), status_code=generator.code)
-        if request.stream:
-            return self._fastapi["streaming_response"](content=generator, media_type="text/event-stream")
-        else:
+        elif isinstance(generator, self._vllm["chat_completion_response"]):
             return self._fastapi["json_response"](content=generator.model_dump())
+        return self._fastapi["streaming_response"](content=generator, media_type="text/event-stream")

     async def chat_completions(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
@@ -755,19 +751,19 @@ class VllmPreprocessRequest(BasePreprocessRequest):
         The actual processing function.
         We run the process in this context
         """
-        request = self._vllm["chat_completion_request"](**data)
-        self.logger.info(f"Received chat completion request: {request}")
-        generator = await self._model["openai_serving_chat"].create_chat_completion(
-            request=request, raw_request=None
-        )
+        request, raw_request = data["request"], data["raw_request"]
+        handler = self._model["openai_serving_chat"]  # analog of chat(raw_request) in https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/openai/api_server.py#L405
+        if handler is None:
+            return self._model["openai_serving"].create_error_response(message="The model does not support Chat Completions API")
+        # request = self._vllm["chat_completion_request"](**data)
+        # self.logger.info(f"Received chat completion request: {request}")
+        generator = await handler.create_chat_completion(request=request, raw_request=raw_request)
         if isinstance(generator, self._vllm["error_response"]):
             return self._fastapi["json_response"](content=generator.model_dump(), status_code=generator.code)
-        if request.stream:
-            return self._fastapi["streaming_response"](content=generator, media_type="text/event-stream")
-        else:
-            assert isinstance(generator, self._vllm["chat_completion_response"])
+        elif isinstance(generator, self._vllm["chat_completion_response"]):
             return self._fastapi["json_response"](content=generator.model_dump())
+        return self._fastapi["streaming_response"](content=generator, media_type="text/event-stream")

     @staticmethod
     async def _preprocess_send_request(_, endpoint: str, version: str = None, data: dict = None) -> Optional[dict]:
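In the updated handlers above, data is the combined dict built by the FastAPI endpoint ({"request": ..., "raw_request": ...}), and the object returned by create_completion / create_chat_completion is mapped to an HTTP response by type: an error becomes a JSON response with its status code, a fully materialized response becomes plain JSON, and anything else is streamed as server-sent events. Below is a hypothetical, self-contained sketch of that dispatch; the Fake* classes stand in for vLLM's ErrorResponse and ChatCompletionResponse, which the real code reaches through self._vllm[...].

from dataclasses import dataclass
from typing import Any, Union

from fastapi.responses import JSONResponse, StreamingResponse

@dataclass
class FakeErrorResponse:  # stand-in for vllm's ErrorResponse
    message: str
    code: int = 400

    def model_dump(self):
        return {"error": self.message}

@dataclass
class FakeChatCompletionResponse:  # stand-in for vllm's ChatCompletionResponse
    content: str

    def model_dump(self):
        return {"choices": [{"message": {"content": self.content}}]}

def to_http_response(generator: Union[FakeErrorResponse, FakeChatCompletionResponse, Any]):
    # Error -> JSON with the error's status code.
    if isinstance(generator, FakeErrorResponse):
        return JSONResponse(content=generator.model_dump(), status_code=generator.code)
    # Complete (non-streaming) response -> plain JSON.
    elif isinstance(generator, FakeChatCompletionResponse):
        return JSONResponse(content=generator.model_dump())
    # Otherwise assume an async generator of SSE chunks.
    return StreamingResponse(content=generator, media_type="text/event-stream")

print(to_http_response(FakeErrorResponse("model not found", code=404)).status_code)  # 404
print(to_http_response(FakeChatCompletionResponse("hello")).status_code)             # 200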

View File

@@ -19,4 +19,4 @@ requests>=2.31.0
 kafka-python>=2.0.2,<2.1
 lz4>=4.0.0,<5
 prometheus_client==0.21.1
-vllm==0.5.4
+vllm==0.7.3

View File

@@ -26,4 +26,4 @@ scrape_configs:
     scrape_interval: 5s
     static_configs:
-      - targets: ['clearml-serving-inference']
+      - targets: ['clearml-serving-inference:8000']

View File

@@ -1,8 +1,10 @@
 """Hugginface preprocessing module for ClearML Serving."""
-from typing import Any, Optional
+from typing import Any, Optional, List, Callable, Union
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.logger import RequestLogger
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
@@ -19,7 +21,13 @@ class Preprocess:
         self.model_endpoint = None

     def load(self, local_file_name: str) -> Optional[Any]:  # noqa
-        vllm_engine_config = {
+
+        @dataclass
+        class BaseModelPath:
+            name: str
+            model_path: str
+
+        self.vllm_engine_config = {
             "model": local_file_name,
             "tokenizer": local_file_name,
             "disable_log_requests": True,
@@ -28,9 +36,10 @@ class Preprocess:
             "quantization": None,
             "enforce_eager": True,
             "served_model_name": "test_vllm",
-            "dtype": "float16"
+            "dtype": "float16",
+            "max_model_len": 8192
         }
-        vllm_model_config = {
+        self.vllm_model_config = {
             "lora_modules": None,  # [LoRAModulePath(name=a, path=b)]
             "prompt_adapters": None,  # [PromptAdapterPath(name=a, path=b)]
             "response_role": "assistant",
@@ -39,43 +48,57 @@ class Preprocess:
             "max_log_len": None
         }
         self._model = {}
-        engine_args = AsyncEngineArgs(**vllm_engine_config)
-        async_engine_client = AsyncLLMEngine.from_engine_args(engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
-        model_config = async_engine_client.engine.get_model_config()
-        request_logger = RequestLogger(max_log_len=vllm_model_config["max_log_len"])
+        self.engine_args = AsyncEngineArgs(**self.vllm_engine_config)
+        async_engine_client = AsyncLLMEngine.from_engine_args(self.engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
+        self.model_config = async_engine_client.engine.get_model_config()
+        request_logger = RequestLogger(max_log_len=self.vllm_model_config["max_log_len"])
+        self._model["openai_serving_models"] = OpenAIServingModels(
+            async_engine_client,
+            self.model_config,
+            [BaseModelPath(name=self.vllm_engine_config["served_model_name"], model_path=self.vllm_engine_config["model"])],
+            lora_modules=self.vllm_model_config["lora_modules"],
+            prompt_adapters=self.vllm_model_config["prompt_adapters"],
+        )
+        self._model["openai_serving"] = OpenAIServing(
+            async_engine_client,
+            self.model_config,
+            self._model["openai_serving_models"],
+            request_logger=request_logger,
+            return_tokens_as_token_ids=self.vllm_model_config["return_tokens_as_token_ids"]
+        )
         self._model["openai_serving_chat"] = OpenAIServingChat(
             async_engine_client,
-            model_config,
-            served_model_names=[vllm_engine_config["served_model_name"]],
-            response_role=vllm_model_config["response_role"],
-            lora_modules=vllm_model_config["lora_modules"],
-            prompt_adapters=vllm_model_config["prompt_adapters"],
+            self.model_config,
+            served_model_names=[self.vllm_engine_config["served_model_name"]],
+            response_role=self.vllm_model_config["response_role"],
+            lora_modules=self.vllm_model_config["lora_modules"],
+            prompt_adapters=self.vllm_model_config["prompt_adapters"],
             request_logger=request_logger,
-            chat_template=vllm_model_config["chat_template"],
-            return_tokens_as_token_ids=vllm_model_config["return_tokens_as_token_ids"]
-        )
+            chat_template=self.vllm_model_config["chat_template"],
+            return_tokens_as_token_ids=self.vllm_model_config["return_tokens_as_token_ids"]
+        ) if self.model_config.runner_type == "generate" else None
         self._model["openai_serving_completion"] = OpenAIServingCompletion(
             async_engine_client,
-            model_config,
-            served_model_names=[vllm_engine_config["served_model_name"]],
-            lora_modules=vllm_model_config["lora_modules"],
-            prompt_adapters=vllm_model_config["prompt_adapters"],
+            self.model_config,
+            served_model_names=[self.vllm_engine_config["served_model_name"]],
+            lora_modules=self.vllm_model_config["lora_modules"],
+            prompt_adapters=self.vllm_model_config["prompt_adapters"],
             request_logger=request_logger,
-            return_tokens_as_token_ids=vllm_model_config["return_tokens_as_token_ids"]
-        )
+            return_tokens_as_token_ids=self.vllm_model_config["return_tokens_as_token_ids"]
+        ) if self.model_config.runner_type == "generate" else None
         self._model["openai_serving_embedding"] = OpenAIServingEmbedding(
             async_engine_client,
-            model_config,
-            served_model_names=[vllm_engine_config["served_model_name"]],
+            self.model_config,
+            served_model_names=[self.vllm_engine_config["served_model_name"]],
             request_logger=request_logger
-        )
+        ) if self.model_config.task == "embed" else None
         self._model["openai_serving_tokenization"] = OpenAIServingTokenization(
             async_engine_client,
-            model_config,
-            served_model_names=[vllm_engine_config["served_model_name"]],
-            lora_modules=vllm_model_config["lora_modules"],
+            self.model_config,
+            served_model_names=[self.vllm_engine_config["served_model_name"]],
+            lora_modules=self.vllm_model_config["lora_modules"],
             request_logger=request_logger,
-            chat_template=vllm_model_config["chat_template"]
+            chat_template=self.vllm_model_config["chat_template"]
         )
         return self._model
@@ -91,12 +114,12 @@ class Preprocess:
         messages = [msg for i, msg in enumerate(messages) if msg["role"] != "system" or i == last_system_index]
         return messages

-    def preprocess(
+    async def preprocess(
         self,
         body: Union[bytes, dict],
         state: dict,
         collect_custom_statistics_fn: Optional[Callable[[dict], None]],
     ) -> Any:  # noqa
-        if "messages" in body:
-            body["messages"] = self.remove_extra_system_prompts(body["messages"])
+        if "messages" in body["request"]:
+            body["request"]["messages"] = self.remove_extra_system_prompts(body["request"]["messages"])
         return body
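With these changes, the serving endpoint exposes OpenAI-compatible routes under /openai/v1/... that accept only JSON and forward CompletionRequest / ChatCompletionRequest payloads to vLLM 0.7.3. Below is a hypothetical client call for illustration; the host, port, URL prefix, and the "test_vllm" model name (taken from vllm_engine_config above) are assumptions that depend on the actual deployment.

import requests

# Assumed base URL of the clearml-serving inference container; adjust for your deployment.
BASE_URL = "http://127.0.0.1:8080/serve"

resp = requests.post(
    f"{BASE_URL}/openai/v1/chat/completions",
    json={
        "model": "test_vllm",
        "messages": [{"role": "user", "content": "Say hello"}],
        "stream": False,
    },
    # Any other content type is rejected with HTTP 415 by validate_json_request.
    headers={"Content-Type": "application/json"},
    timeout=60,
)
print(resp.status_code)
print(resp.json())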