diff --git a/clearml_serving/serving/model_request_processor.py b/clearml_serving/serving/model_request_processor.py
index 44eaf9e..0f6bfa8 100644
--- a/clearml_serving/serving/model_request_processor.py
+++ b/clearml_serving/serving/model_request_processor.py
@@ -1241,7 +1241,7 @@ class ModelRequestProcessor(object):
                     if processor.is_process_async \
                     else processor.chat_completion(preprocessed, state, stats_collect_fn)
             else:
-                raise ValueError(f"wrong url_type: expected 'completions' and 'chat/completions', got {url_type}")
+                raise ValueError(f"wrong url_type: expected 'completions' and 'chat/completions', got {serve_type}")
             # noinspection PyUnresolvedReferences
             return_value = await processor.postprocess(processed, state, stats_collect_fn) \
                 if processor.is_postprocess_async \
diff --git a/clearml_serving/serving/preprocess_service.py b/clearml_serving/serving/preprocess_service.py
index e065144..d29f5f8 100644
--- a/clearml_serving/serving/preprocess_service.py
+++ b/clearml_serving/serving/preprocess_service.py
@@ -19,7 +19,7 @@ class BasePreprocessRequest(object):
     __preprocessing_lookup = {}
     __preprocessing_modules = set()
     _grpc_env_conf_prefix = "CLEARML_GRPC_"
-    _default_serving_base_url = "http://127.0.0.1:8080/clearml/"
+    _default_serving_base_url = "http://127.0.0.1:8080/serve/"
     _server_config = {}  # externally configured by the serving inference service
     _timeout = None  # timeout in seconds for the entire request, set in __init__
     is_preprocess_async = False
@@ -292,7 +292,7 @@ class TritonPreprocessRequest(BasePreprocessRequest):
 
         self._grpc_stub = {}
 
-    async def chat_completion(
+    async def process(
             self,
             data: Any,
             state: dict,
@@ -428,28 +428,74 @@ class TritonPreprocessRequest(BasePreprocessRequest):
         return output_results[0] if index == 1 else output_results
 
 
+@BasePreprocessRequest.register_engine("sklearn", modules=["joblib", "sklearn"])
+class SKLearnPreprocessRequest(BasePreprocessRequest):
+    def __init__(self, model_endpoint: ModelEndpoint, task: Task = None):
+        super(SKLearnPreprocessRequest, self).__init__(
+            model_endpoint=model_endpoint, task=task)
+        if self._model is None:
+            # get model
+            import joblib  # noqa
+            self._model = joblib.load(filename=self._get_local_model_file())
+
+    def process(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
+        """
+        The actual processing function.
+        We run the model in this context
+        """
+        return self._model.predict(data)
+
+
+@BasePreprocessRequest.register_engine("xgboost", modules=["xgboost"])
+class XGBoostPreprocessRequest(BasePreprocessRequest):
+    def __init__(self, model_endpoint: ModelEndpoint, task: Task = None):
+        super(XGBoostPreprocessRequest, self).__init__(
+            model_endpoint=model_endpoint, task=task)
+        if self._model is None:
+            # get model
+            import xgboost  # noqa
+            self._model = xgboost.Booster()
+            self._model.load_model(self._get_local_model_file())
+
+    def process(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
+        """
+        The actual processing function.
+        We run the model in this context
+        """
+        return self._model.predict(data)
+
+
+@BasePreprocessRequest.register_engine("lightgbm", modules=["lightgbm"])
+class LightGBMPreprocessRequest(BasePreprocessRequest):
+    def __init__(self, model_endpoint: ModelEndpoint, task: Task = None):
+        super(LightGBMPreprocessRequest, self).__init__(
+            model_endpoint=model_endpoint, task=task)
+        if self._model is None:
+            # get model
+            import lightgbm  # noqa
+            self._model = lightgbm.Booster(model_file=self._get_local_model_file())
+
+    def process(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
+        """
+        The actual processing function.
+        We run the model in this context
+        """
+        return self._model.predict(data)
+
+
 @BasePreprocessRequest.register_engine("custom")
 class CustomPreprocessRequest(BasePreprocessRequest):
     def __init__(self, model_endpoint: ModelEndpoint, task: Task = None):
         super(CustomPreprocessRequest, self).__init__(
             model_endpoint=model_endpoint, task=task)
 
-    def completion(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
+    def process(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
         """
         The actual processing function.
         We run the process in this context
         """
-        if self._preprocess is not None and hasattr(self._preprocess, 'completion'):
-            return self._preprocess.completion(data, state, collect_custom_statistics_fn)
-        return None
-
-    def chat_completion(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
-        """
-        The actual processing function.
-        We run the process in this context
-        """
-        if self._preprocess is not None and hasattr(self._preprocess, 'chat_completion'):
-            return self._preprocess.chat_completion(data, state, collect_custom_statistics_fn)
+        if self._preprocess is not None and hasattr(self._preprocess, 'process'):
+            return self._preprocess.process(data, state, collect_custom_statistics_fn)
         return None
 
 
@@ -530,22 +576,13 @@ class CustomAsyncPreprocessRequest(BasePreprocessRequest):
             return await self._preprocess.postprocess(data, state, collect_custom_statistics_fn)
         return data
 
-    async def completion(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
+    async def process(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
         """
         The actual processing function.
         We run the process in this context
         """
-        if self._preprocess is not None and hasattr(self._preprocess, 'completion'):
-            return await self._preprocess.completion(data, state, collect_custom_statistics_fn)
-        return None
-
-    async def chat_completion(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
-        """
-        The actual processing function.
-        We run the process in this context
-        """
-        if self._preprocess is not None and hasattr(self._preprocess, 'chat_completion'):
-            return await self._preprocess.chat_completion(data, state, collect_custom_statistics_fn)
+        if self._preprocess is not None and hasattr(self._preprocess, 'process'):
+            return await self._preprocess.process(data, state, collect_custom_statistics_fn)
         return None
 
     @staticmethod
@@ -559,3 +596,273 @@ class CustomAsyncPreprocessRequest(BasePreprocessRequest):
         if not return_value.ok:
             return None
         return return_value.json()
+
+
+import prometheus_client
+
+from typing import Any, Union, Optional, Callable
+
+from fastapi.responses import JSONResponse, StreamingResponse
+
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.entrypoints.logger import RequestLogger
+# yapf conflicts with isort for this block
+# yapf: disable
+from vllm.entrypoints.openai.protocol import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+    CompletionRequest,
+    ErrorResponse
+)
+
+# yapf: enable
+from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
+from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
+from vllm.logger import init_logger
+from vllm.usage.usage_lib import UsageContext
+from vllm.entrypoints.openai.serving_engine import LoRAModulePath, PromptAdapterPath
+
+logger = init_logger(__name__)
+
+REMOVE_WEB_ADDITIONAL_PROMPTS = True
+
+
+def remove_extra_system_prompts(messages: list) -> list:
+    """
+    Remove all leading 'system' prompts except the last one.
+
+    :param messages: List of message dicts with 'role' and 'content'.
+    :return: Modified list of messages with only the last 'system' prompt preserved.
+ """ + # Фильтруем только системные сообщения + system_messages_indices = [] + for i, msg in enumerate(messages): + if msg["role"] == "system": + system_messages_indices.append(i) + else: + break + + # Если есть больше одного системного сообщения, удалим все, кроме последнего + if len(system_messages_indices) > 1: + last_system_index = system_messages_indices[-1] + # Удаляем все системные сообщения, кроме последнего + messages = [msg for i, msg in enumerate(messages) if msg["role"] != "system" or i == last_system_index] + + return messages + + class CustomRequest: + def __init__(self, headers: Optional[dict] = None): + self.headers = headers + + async def is_disconnected(self): + return False + + def __init__(self, model_endpoint: ModelEndpoint, task: Task = None): + super(VllmPreprocessRequest, self).__init__( + model_endpoint=model_endpoint, task=task) + + def is_port_in_use(port: int) -> bool: + import socket + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return s.connect_ex(('localhost', port)) == 0 + if not is_port_in_use(8000): + prometheus_client.start_http_server(8000) + + vllm_engine_config = { + "model":f"{local_file_name}/model", + "tokenizer":f"{local_file_name}/tokenizer", + "disable_log_requests": True, + "disable_log_stats": False, + "gpu_memory_utilization": 0.9, + "quantization": None, + "enforce_eager": True, + "served_model_name": "ai_operator_hyp22v4" + } + vllm_model_config = { + "lora_modules": None, # [LoRAModulePath(name=a, path=b)] + "prompt_adapters": None, # [PromptAdapterPath(name=a, path=b)] + "response_role": "assistant", + "chat_template": None, + "return_tokens_as_token_ids": False, + "max_log_len": None + } + + self.engine_args = AsyncEngineArgs(**vllm_engine_config) + self.async_engine_client = AsyncLLMEngine.from_engine_args(self.engine_args, usage_context=UsageContext.OPENAI_API_SERVER) + + + model_config = self.async_engine_client.engine.get_model_config() + + request_logger = RequestLogger(max_log_len=vllm_model_config["max_log_len"]) + + self.openai_serving_chat = OpenAIServingChat( + self.async_engine_client, + model_config, + served_model_names=[vllm_engine_config["served_model_name"]], + response_role=vllm_model_config["response_role"], + lora_modules=vllm_model_config["lora_modules"], + prompt_adapters=vllm_model_config["prompt_adapters"], + request_logger=request_logger, + chat_template=vllm_model_config["chat_template"], + return_tokens_as_token_ids=vllm_model_config["return_tokens_as_token_ids"] + ) + self.openai_serving_completion = OpenAIServingCompletion( + self.async_engine_client, + model_config, + served_model_names=[vllm_engine_config["served_model_name"]], + lora_modules=vllm_model_config["lora_modules"], + prompt_adapters=vllm_model_config["prompt_adapters"], + request_logger=request_logger, + return_tokens_as_token_ids=vllm_model_config["return_tokens_as_token_ids"] + ) + self.openai_serving_embedding = OpenAIServingEmbedding( + self.async_engine_client, + model_config, + served_model_names=[vllm_engine_config["served_model_name"]], + request_logger=request_logger + ) + self.openai_serving_tokenization = OpenAIServingTokenization( + self.async_engine_client, + model_config, + served_model_names=[vllm_engine_config["served_model_name"]], + lora_modules=vllm_model_config["lora_modules"], + request_logger=request_logger, + chat_template=vllm_model_config["chat_template"] + ) + # override `send_request` method with the async version + self._preprocess.__class__.send_request = 
+
+    async def preprocess(
+            self,
+            request: dict,
+            state: dict,
+            collect_custom_statistics_fn: Callable[[dict], None] = None,
+    ) -> Optional[Any]:
+        """
+        Raise exception to report an error
+        Return value will be passed to serving engine
+
+        :param request: dictionary as received from the RestAPI
+        :param state: Use state dict to store data passed to the post-processing function call.
+            Usage example:
+            >>> def preprocess(..., state):
+                    state['preprocess_aux_data'] = [1,2,3]
+            >>> def postprocess(..., state):
+                    print(state['preprocess_aux_data'])
+        :param collect_custom_statistics_fn: Optional, allows sending a custom set of key/values
+            to the statistics collector service
+
+            Usage example:
+            >>> print(request)
+            {"x0": 1, "x1": 2}
+            >>> collect_custom_statistics_fn({"x0": 1, "x1": 2})
+
+        :return: Object to be passed directly to the model inference
+        """
+        if self._preprocess is not None and hasattr(self._preprocess, 'preprocess'):
+            return await self._preprocess.preprocess(request, state, collect_custom_statistics_fn)
+        return request
+
+    async def postprocess(
+            self,
+            data: Any,
+            state: dict,
+            collect_custom_statistics_fn: Callable[[dict], None] = None
+    ) -> Optional[dict]:
+        """
+        Raise exception to report an error
+        Return value will be passed to serving engine
+
+        :param data: object as received from the inference model function
+        :param state: Use state dict to store data passed to the post-processing function call.
+            Usage example:
+            >>> def preprocess(..., state):
+                    state['preprocess_aux_data'] = [1,2,3]
+            >>> def postprocess(..., state):
+                    print(state['preprocess_aux_data'])
+        :param collect_custom_statistics_fn: Optional, allows sending a custom set of key/values
+            to the statistics collector service
+
+            Usage example:
+            >>> collect_custom_statistics_fn({"y": 1})
+
+        :return: Dictionary passed directly as the returned result of the RestAPI
+        """
+        if self._preprocess is not None and hasattr(self._preprocess, 'postprocess'):
+            return await self._preprocess.postprocess(data, state, collect_custom_statistics_fn)
+        return data
+
+    async def completions(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
+        """
+        The actual processing function.
+        We run the process in this context
+        """
+        if REMOVE_WEB_ADDITIONAL_PROMPTS:
+            if "messages" in data:
+                data["messages"] = remove_extra_system_prompts(data["messages"])
+
+        raw_request = self.CustomRequest(
+            headers={
+                "traceparent": None,
+                "tracestate": None
+            }
+        )
+        request = CompletionRequest(**data)
+        logger.info(f"Received completion request: {request}")
+        generator = await self.openai_serving_completion.create_completion(
+            request=request,
+            raw_request=raw_request
+        )
+        if isinstance(generator, ErrorResponse):
+            return JSONResponse(content=generator.model_dump(), status_code=generator.code)
+        if request.stream:
+            return StreamingResponse(content=generator, media_type="text/event-stream")
+        else:
+            return JSONResponse(content=generator.model_dump())
+
+    async def chat_completions(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
+        """
+        The actual processing function.
+        We run the process in this context
+        """
+        if REMOVE_WEB_ADDITIONAL_PROMPTS:
+            if "messages" in data:
+                data["messages"] = remove_extra_system_prompts(data["messages"])
+
+        request = ChatCompletionRequest(**data)
+        logger.info(f"Received chat completion request: {request}")
+        generator = await self.openai_serving_chat.create_chat_completion(
+            request=request, raw_request=None
+        )
+        if isinstance(generator, ErrorResponse):
+            return JSONResponse(content=generator.model_dump(), status_code=generator.code)
+        if request.stream:
+            return StreamingResponse(content=generator, media_type="text/event-stream")
+        else:
+            assert isinstance(generator, ChatCompletionResponse)
+            return JSONResponse(content=generator.model_dump())
+
+    @staticmethod
+    async def _preprocess_send_request(_, endpoint: str, version: str = None, data: dict = None) -> Optional[dict]:
+        endpoint = "{}/{}".format(endpoint.strip("/"), version.strip("/")) if version else endpoint.strip("/")
+        base_url = BasePreprocessRequest.get_server_config().get("base_serving_url")
+        base_url = (base_url or BasePreprocessRequest._default_serving_base_url).strip("/")
+        url = "{}/{}".format(base_url, endpoint.strip("/"))
+        return_value = await VllmPreprocessRequest.asyncio_to_thread(
+            request_post, url, json=data, timeout=BasePreprocessRequest._timeout)
+        if not return_value.ok:
+            return None
+        return return_value.json()