diff --git a/clearml_serving/serving/main.py b/clearml_serving/serving/main.py
index f158181..e7a94d8 100644
--- a/clearml_serving/serving/main.py
+++ b/clearml_serving/serving/main.py
@@ -4,11 +4,13 @@
 import traceback
 import gzip
 import asyncio
-from fastapi import FastAPI, Request, Response, APIRouter, HTTPException
+from fastapi import FastAPI, Request, Response, APIRouter, HTTPException, Depends
 from fastapi.routing import APIRoute
 from fastapi.responses import PlainTextResponse

 from grpc.aio import AioRpcError
+from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest
+
 from starlette.background import BackgroundTask

 from typing import Optional, Dict, Any, Callable, Union
@@ -194,16 +196,27 @@ async def base_serve_model(
     return return_value


-@router.post("/openai/v1/{endpoint_type:path}")
-@router.post("/openai/v1/{endpoint_type:path}/")
+async def validate_json_request(raw_request: Request):
+    content_type = raw_request.headers.get("content-type", "").lower()
+    media_type = content_type.split(";", maxsplit=1)[0]
+    if media_type != "application/json":
+        raise HTTPException(
+            status_code=HTTPStatus.UNSUPPORTED_MEDIA_TYPE,
+            detail="Unsupported Media Type: Only 'application/json' is allowed"
+        )
+
+@router.post("/openai/v1/{endpoint_type:path}", dependencies=[Depends(validate_json_request)])
+@router.post("/openai/v1/{endpoint_type:path}/", dependencies=[Depends(validate_json_request)])
 async def openai_serve_model(
         endpoint_type: str,
-        request: Union[bytes, Dict[Any, Any]] = None
+        request: Union[CompletionRequest, ChatCompletionRequest],
+        raw_request: Request
 ):
+    combined_request = {"request": request, "raw_request": raw_request}
     return_value = await process_with_exceptions(
         base_url=request.get("model", None),
         version=None,
-        request_body=request,
+        request_body=combined_request,
         serve_type=endpoint_type
     )
     return return_value
diff --git a/clearml_serving/serving/model_request_processor.py b/clearml_serving/serving/model_request_processor.py
index c952ddc..eaa4b49 100644
--- a/clearml_serving/serving/model_request_processor.py
+++ b/clearml_serving/serving/model_request_processor.py
@@ -1234,12 +1234,12 @@ class ModelRequestProcessor(object):
             # noinspection PyUnresolvedReferences
             processed = await processor.completions(preprocessed, state, stats_collect_fn) \
                 if processor.is_process_async \
-                else processor.completion(preprocessed, state, stats_collect_fn)
+                else processor.completions(preprocessed, state, stats_collect_fn)
         elif serve_type == "chat/completions":
             # noinspection PyUnresolvedReferences
             processed = await processor.chat_completions(preprocessed, state, stats_collect_fn) \
                 if processor.is_process_async \
-                else processor.chat_completion(preprocessed, state, stats_collect_fn)
+                else processor.chat_completions(preprocessed, state, stats_collect_fn)
         else:
             raise ValueError(f"wrong url_type: expected 'process', 'completions' or 'chat/completions', got {serve_type}")
         # noinspection PyUnresolvedReferences
diff --git a/clearml_serving/serving/preprocess_service.py b/clearml_serving/serving/preprocess_service.py
index 45946c7..bbf8a1f 100644
--- a/clearml_serving/serving/preprocess_service.py
+++ b/clearml_serving/serving/preprocess_service.py
@@ -613,6 +613,9 @@ class CustomAsyncPreprocessRequest(BasePreprocessRequest):

 @BasePreprocessRequest.register_engine("vllm", modules=["vllm", "fastapi"])
 class VllmPreprocessRequest(BasePreprocessRequest):
+    is_preprocess_async = True
+    is_process_async = True
+    is_postprocess_async = True
     asyncio_to_thread = None
     _vllm = None
     _fastapi = None
@@ -729,25 +732,18 @@ class VllmPreprocessRequest(BasePreprocessRequest):
         The actual processing function.
         We run the process in this context
         """
-
-        raw_request = CustomRequest(
-            headers = {
-                "traceparent": None,
-                "tracestate": None
-            }
-        )
-        request = self._vllm["completion_request"](**data)
-        self.logger.info(f"Received chat completion request: {request}")
-        generator = await self._model["openai_serving_completion"].create_completion(
-            request=request,
-            raw_request=raw_request
-        )
+        request, raw_request = data["request"], data["raw_request"]
+        handler = self._model["openai_serving_completion"]
+        if handler is None:
+            return self._model["openai_serving"].create_error_response(message="The model does not support Completions API")
+        # request = self._vllm["completion_request"](**data)
+        # self.logger.info(f"Received chat completion request: {request}")
+        generator = await handler.create_completion(request=request, raw_request=raw_request)
         if isinstance(generator, self._vllm["error_response"]):
             return self._fastapi["json_response"](content=generator.model_dump(), status_code=generator.code)
-        if request.stream:
-            return self._fastapi["streaming_response"](content=generator, media_type="text/event-stream")
-        else:
+        elif isinstance(generator, self._vllm["chat_completion_response"]):
             return self._fastapi["json_response"](content=generator.model_dump())
+        return self._fastapi["streaming_response"](content=generator, media_type="text/event-stream")


     async def chat_completions(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
@@ -755,19 +751,19 @@
         The actual processing function.
         We run the process in this context
         """
-
-        request = self._vllm["chat_completion_request"](**data)
-        self.logger.info(f"Received chat completion request: {request}")
-        generator = await self._model["openai_serving_chat"].create_chat_completion(
-            request=request, raw_request=None
-        )
+        request, raw_request = data["request"], data["raw_request"]
+        handler = self._model["openai_serving_chat"] # analog of chat(raw_request) in https://github.com/vllm-project/vllm/blob/v0.7.3/vllm/entrypoints/openai/api_server.py#L405
+        if handler is None:
+            return self._model["openai_serving"].create_error_response(message="The model does not support Chat Completions API")
+        # request = self._vllm["chat_completion_request"](**data)
+        # self.logger.info(f"Received chat completion request: {request}")
+        generator = await handler.create_chat_completion(request=request, raw_request=raw_request)
         if isinstance(generator, self._vllm["error_response"]):
             return self._fastapi["json_response"](content=generator.model_dump(), status_code=generator.code)
-        if request.stream:
-            return self._fastapi["streaming_response"](content=generator, media_type="text/event-stream")
-        else:
-            assert isinstance(generator, self._vllm["chat_completion_response"])
+        elif isinstance(generator, self._vllm["chat_completion_response"]):
             return self._fastapi["json_response"](content=generator.model_dump())
+        return self._fastapi["streaming_response"](content=generator, media_type="text/event-stream")
+

     @staticmethod
     async def _preprocess_send_request(_, endpoint: str, version: str = None, data: dict = None) -> Optional[dict]:
diff --git a/clearml_serving/serving/requirements.txt b/clearml_serving/serving/requirements.txt
index 366b19c..922f8e3 100644
--- a/clearml_serving/serving/requirements.txt
+++ b/clearml_serving/serving/requirements.txt
@@ -19,4 +19,4 @@ requests>=2.31.0
 kafka-python>=2.0.2,<2.1
 lz4>=4.0.0,<5
 prometheus_client==0.21.1
-vllm==0.5.4
+vllm==0.7.3
diff --git a/docker/prometheus.yml b/docker/prometheus.yml
index b7aa51e..da47e83 100644
--- a/docker/prometheus.yml
+++ b/docker/prometheus.yml
@@ -26,4 +26,4 @@ scrape_configs:

     scrape_interval: 5s
     static_configs:
-      - targets: ['clearml-serving-inference']
\ No newline at end of file
+      - targets: ['clearml-serving-inference:8000']
\ No newline at end of file
diff --git a/examples/vllm/preprocess.py b/examples/vllm/preprocess.py
index a4191d6..fc8b3aa 100644
--- a/examples/vllm/preprocess.py
+++ b/examples/vllm/preprocess.py
@@ -1,8 +1,10 @@
 """Hugginface preprocessing module for ClearML Serving."""
-from typing import Any, Optional
+from typing import Any, Optional, List, Callable, Union
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.logger import RequestLogger
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
@@ -19,7 +21,13 @@ class Preprocess:
         self.model_endpoint = None

     def load(self, local_file_name: str) -> Optional[Any]: # noqa
-        vllm_engine_config = {
+
+        @dataclass
+        class BaseModelPath:
+            name: str
+            model_path: str
+
+        self.vllm_engine_config = {
             "model": local_file_name,
             "tokenizer": local_file_name,
             "disable_log_requests": True,
@@ -28,9 +36,10 @@ class Preprocess:
             "quantization": None,
             "enforce_eager": True,
             "served_model_name": "test_vllm",
-            "dtype": "float16"
+            "dtype": "float16",
+            "max_model_len": 8192
         }
-        vllm_model_config = {
+        self.vllm_model_config = {
             "lora_modules": None, # [LoRAModulePath(name=a, path=b)]
             "prompt_adapters": None, # [PromptAdapterPath(name=a, path=b)]
             "response_role": "assistant",
@@ -39,43 +48,57 @@ class Preprocess:
             "max_log_len": None
         }
         self._model = {}
-        engine_args = AsyncEngineArgs(**vllm_engine_config)
-        async_engine_client = AsyncLLMEngine.from_engine_args(engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
-        model_config = async_engine_client.engine.get_model_config()
-        request_logger = RequestLogger(max_log_len=vllm_model_config["max_log_len"])
+        self.engine_args = AsyncEngineArgs(**self.vllm_engine_config)
+        async_engine_client = AsyncLLMEngine.from_engine_args(self.engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
+        self.model_config = async_engine_client.engine.get_model_config()
+        request_logger = RequestLogger(max_log_len=self.vllm_model_config["max_log_len"])
+        self._model["openai_serving_models"] = OpenAIServingModels(
+            async_engine_client,
+            self.model_config,
+            [BaseModelPath(name=self.vllm_engine_config["served_model_name"], model_path=self.vllm_engine_config["model"])],
+            lora_modules=self.vllm_model_config["lora_modules"],
+            prompt_adapters=self.vllm_model_config["prompt_adapters"],
+        )
+        self._model["openai_serving"] = OpenAIServing(
+            async_engine_client,
+            self.model_config,
+            self._model["openai_serving_models"],
+            request_logger=request_logger,
+            return_tokens_as_token_ids=self.vllm_model_config["return_tokens_as_token_ids"]
+        )
         self._model["openai_serving_chat"] = OpenAIServingChat(
             async_engine_client,
-            model_config,
-            served_model_names=[vllm_engine_config["served_model_name"]],
-            response_role=vllm_model_config["response_role"],
-            lora_modules=vllm_model_config["lora_modules"],
-            prompt_adapters=vllm_model_config["prompt_adapters"],
+            self.model_config,
+            served_model_names=[self.vllm_engine_config["served_model_name"]],
+            response_role=self.vllm_model_config["response_role"],
+            lora_modules=self.vllm_model_config["lora_modules"],
+            prompt_adapters=self.vllm_model_config["prompt_adapters"],
             request_logger=request_logger,
-            chat_template=vllm_model_config["chat_template"],
-            return_tokens_as_token_ids=vllm_model_config["return_tokens_as_token_ids"]
-        )
+            chat_template=self.vllm_model_config["chat_template"],
+            return_tokens_as_token_ids=self.vllm_model_config["return_tokens_as_token_ids"]
+        ) if self.model_config.runner_type == "generate" else None
         self._model["openai_serving_completion"] = OpenAIServingCompletion(
             async_engine_client,
-            model_config,
-            served_model_names=[vllm_engine_config["served_model_name"]],
-            lora_modules=vllm_model_config["lora_modules"],
-            prompt_adapters=vllm_model_config["prompt_adapters"],
+            self.model_config,
+            served_model_names=[self.vllm_engine_config["served_model_name"]],
+            lora_modules=self.vllm_model_config["lora_modules"],
+            prompt_adapters=self.vllm_model_config["prompt_adapters"],
             request_logger=request_logger,
-            return_tokens_as_token_ids=vllm_model_config["return_tokens_as_token_ids"]
-        )
+            return_tokens_as_token_ids=self.vllm_model_config["return_tokens_as_token_ids"]
+        ) if self.model_config.runner_type == "generate" else None
         self._model["openai_serving_embedding"] = OpenAIServingEmbedding(
             async_engine_client,
-            model_config,
-            served_model_names=[vllm_engine_config["served_model_name"]],
+            self.model_config,
+            served_model_names=[self.vllm_engine_config["served_model_name"]],
             request_logger=request_logger
-        )
+        ) if self.model_config.task == "embed" else None
         self._model["openai_serving_tokenization"] = OpenAIServingTokenization(
             async_engine_client,
-            model_config,
-            served_model_names=[vllm_engine_config["served_model_name"]],
-            lora_modules=vllm_model_config["lora_modules"],
+            self.model_config,
+            served_model_names=[self.vllm_engine_config["served_model_name"]],
+            lora_modules=self.vllm_model_config["lora_modules"],
             request_logger=request_logger,
-            chat_template=vllm_model_config["chat_template"]
+            chat_template=self.vllm_model_config["chat_template"]
         )

         return self._model
@@ -91,12 +114,12 @@ class Preprocess:
         messages = [msg for i, msg in enumerate(messages) if msg["role"] != "system" or i == last_system_index]
         return messages

-    def preprocess(
+    async def preprocess(
         self,
         body: Union[bytes, dict],
         state: dict,
         collect_custom_statistics_fn: Optional[Callable[[dict], None]],
     ) -> Any: # noqa
-        if "messages" in body:
-            body["messages"] = self.remove_extra_system_prompts(body["messages"])
+        if "messages" in body["request"]:
+            body["request"]["messages"] = self.remove_extra_system_prompts(body["request"]["messages"])
         return body
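
Usage sketch (not part of the patch): a minimal client call against the OpenAI-compatible route that main.py now exposes. The model name "test_vllm" comes from examples/vllm/preprocess.py; the host, port, and the "/serve" URL prefix are assumptions about a typical clearml-serving-inference deployment and may need adjusting. The new validate_json_request dependency rejects anything that is not application/json, which requests satisfies via its json= argument.

import requests

BASE_URL = "http://localhost:8080/serve"  # assumed inference address and router prefix

response = requests.post(
    f"{BASE_URL}/openai/v1/chat/completions",
    # sent as application/json, as required by validate_json_request
    json={
        "model": "test_vllm",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ],
        "max_tokens": 64,
    },
    timeout=60,
)
response.raise_for_status()
print(response.json()["choices"][0]["message"]["content"])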
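A second sketch, also hypothetical, for the streaming path: when the request body sets "stream": true, the updated chat_completions handler returns the vLLM generator wrapped in a StreamingResponse with media type text/event-stream, so the client reads OpenAI-style SSE lines. Same assumed base URL and prefix as above.

import json
import requests

BASE_URL = "http://localhost:8080/serve"  # assumed, as in the previous sketch

with requests.post(
    f"{BASE_URL}/openai/v1/chat/completions",
    json={
        "model": "test_vllm",
        "stream": True,
        "messages": [{"role": "user", "content": "Stream a short greeting"}],
    },
    stream=True,
    timeout=60,
) as response:
    response.raise_for_status()
    for line in response.iter_lines(decode_unicode=True):
        # OpenAI-style SSE: lines look like "data: {...}" and finish with "data: [DONE]"
        if not line or not line.startswith("data: "):
            continue
        payload = line[len("data: "):]
        if payload == "[DONE]":
            break
        chunk = json.loads(payload)
        delta = chunk["choices"][0]["delta"].get("content") or ""
        print(delta, end="", flush=True)
print()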