Mirror of https://github.com/clearml/clearml-serving (synced 2025-06-26 18:16:00 +00:00)
revert some old changes
parent 5b73bdf085
commit f51bf2e081
@@ -1241,7 +1241,7 @@ class ModelRequestProcessor(object):
                     if processor.is_process_async \
                     else processor.chat_completion(preprocessed, state, stats_collect_fn)
             else:
-                raise ValueError(f"wrong url_type: expected 'completions' and 'chat/completions', got {url_type}")
+                raise ValueError(f"wrong url_type: expected 'completions' and 'chat/completions', got {serve_type}")
             # noinspection PyUnresolvedReferences
             return_value = await processor.postprocess(processed, state, stats_collect_fn) \
                 if processor.is_postprocess_async \
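For context, here is a minimal sketch of the completions vs. chat/completions dispatch this hunk touches. The function and variable names are illustrative stand-ins, not the repository's actual routing code; the point is that the error message now interpolates `serve_type`, the variable that actually holds the offending value.

async def dispatch(processor, serve_type, preprocessed, state, stats_collect_fn):
    if serve_type == "completions":
        return await processor.completion(preprocessed, state, stats_collect_fn) \
            if processor.is_process_async \
            else processor.completion(preprocessed, state, stats_collect_fn)
    elif serve_type == "chat/completions":
        return await processor.chat_completion(preprocessed, state, stats_collect_fn) \
            if processor.is_process_async \
            else processor.chat_completion(preprocessed, state, stats_collect_fn)
    else:
        # the fix above: report the value from the variable that is actually in scope
        raise ValueError(
            f"wrong url_type: expected 'completions' and 'chat/completions', got {serve_type}"
        )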
@@ -19,7 +19,7 @@ class BasePreprocessRequest(object):
     __preprocessing_lookup = {}
     __preprocessing_modules = set()
     _grpc_env_conf_prefix = "CLEARML_GRPC_"
-    _default_serving_base_url = "http://127.0.0.1:8080/clearml/"
+    _default_serving_base_url = "http://127.0.0.1:8080/serve/"
     _server_config = {}  # externally configured by the serving inference service
     _timeout = None  # timeout in seconds for the entire request, set in __init__
     is_preprocess_async = False
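The default base URL matters because it is the fallback used when no `base_serving_url` is configured. The snippet below is a small self-contained sketch of how that fallback combines with an endpoint path, mirroring the URL composition in `_preprocess_send_request` further down this diff; the endpoint name and version are made-up illustrative values.

_default_serving_base_url = "http://127.0.0.1:8080/serve/"

def build_url(endpoint: str, version: str = None, base_url: str = None) -> str:
    # same strip/format logic as _preprocess_send_request below
    endpoint = "{}/{}".format(endpoint.strip("/"), version.strip("/")) if version else endpoint.strip("/")
    base = (base_url or _default_serving_base_url).strip("/")
    return "{}/{}".format(base, endpoint)

print(build_url("my_model", "1"))   # -> http://127.0.0.1:8080/serve/my_model/1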
@@ -292,7 +292,7 @@ class TritonPreprocessRequest(BasePreprocessRequest):
 
         self._grpc_stub = {}
 
-    async def chat_completion(
+    async def process(
             self,
             data: Any,
             state: dict,
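With the Triton engine's entry point renamed back to `process`, every registered engine again exposes the same method name, so the serving layer can invoke it uniformly. A hedged sketch with stand-in names, mirroring the `is_process_async` ternary shown in the first hunk:

async def run_engine(processor, preprocessed, state, stats_collect_fn):
    # `processor` stands for any *PreprocessRequest instance
    return await processor.process(preprocessed, state, stats_collect_fn) \
        if processor.is_process_async \
        else processor.process(preprocessed, state, stats_collect_fn)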
@@ -428,28 +428,74 @@ class TritonPreprocessRequest(BasePreprocessRequest):
         return output_results[0] if index == 1 else output_results
 
 
+@BasePreprocessRequest.register_engine("sklearn", modules=["joblib", "sklearn"])
+class SKLearnPreprocessRequest(BasePreprocessRequest):
+    def __init__(self, model_endpoint: ModelEndpoint, task: Task = None):
+        super(SKLearnPreprocessRequest, self).__init__(
+            model_endpoint=model_endpoint, task=task)
+        if self._model is None:
+            # get model
+            import joblib  # noqa
+            self._model = joblib.load(filename=self._get_local_model_file())
+
+    def process(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
+        """
+        The actual processing function.
+        We run the model in this context
+        """
+        return self._model.predict(data)
+
+
+@BasePreprocessRequest.register_engine("xgboost", modules=["xgboost"])
+class XGBoostPreprocessRequest(BasePreprocessRequest):
+    def __init__(self, model_endpoint: ModelEndpoint, task: Task = None):
+        super(XGBoostPreprocessRequest, self).__init__(
+            model_endpoint=model_endpoint, task=task)
+        if self._model is None:
+            # get model
+            import xgboost  # noqa
+            self._model = xgboost.Booster()
+            self._model.load_model(self._get_local_model_file())
+
+    def process(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
+        """
+        The actual processing function.
+        We run the model in this context
+        """
+        return self._model.predict(data)
+
+
+@BasePreprocessRequest.register_engine("lightgbm", modules=["lightgbm"])
+class LightGBMPreprocessRequest(BasePreprocessRequest):
+    def __init__(self, model_endpoint: ModelEndpoint, task: Task = None):
+        super(LightGBMPreprocessRequest, self).__init__(
+            model_endpoint=model_endpoint, task=task)
+        if self._model is None:
+            # get model
+            import lightgbm  # noqa
+            self._model = lightgbm.Booster(model_file=self._get_local_model_file())
+
+    def process(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
+        """
+        The actual processing function.
+        We run the model in this context
+        """
+        return self._model.predict(data)
+
+
 @BasePreprocessRequest.register_engine("custom")
 class CustomPreprocessRequest(BasePreprocessRequest):
     def __init__(self, model_endpoint: ModelEndpoint, task: Task = None):
         super(CustomPreprocessRequest, self).__init__(
             model_endpoint=model_endpoint, task=task)
 
-    def completion(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
+    def process(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
         """
         The actual processing function.
         We run the process in this context
         """
-        if self._preprocess is not None and hasattr(self._preprocess, 'completion'):
-            return self._preprocess.completion(data, state, collect_custom_statistics_fn)
-        return None
-
-    def chat_completion(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
-        """
-        The actual processing function.
-        We run the process in this context
-        """
-        if self._preprocess is not None and hasattr(self._preprocess, 'chat_completion'):
-            return self._preprocess.chat_completion(data, state, collect_custom_statistics_fn)
+        if self._preprocess is not None and hasattr(self._preprocess, 'process'):
+            return self._preprocess.process(data, state, collect_custom_statistics_fn)
         return None
 
 
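The restored SKLearn/XGBoost/LightGBM engines all follow the same pattern: load a serialized model from the locally cached model file in `__init__` and call `predict` on the request payload in `process()`. Below is a minimal, self-contained round-trip sketch for the sklearn case; the file name and training data are made up for illustration.

import joblib
import numpy as np
from sklearn.linear_model import LogisticRegression

# train and serialize a tiny model -- stand-in for whatever artifact the endpoint registers
X, y = np.array([[0.0], [1.0], [2.0], [3.0]]), np.array([0, 0, 1, 1])
joblib.dump(LogisticRegression().fit(X, y), "model.pkl")

# what SKLearnPreprocessRequest.__init__ effectively does with the downloaded model file
model = joblib.load(filename="model.pkl")

# what process() returns to the serving engine for a request payload
print(model.predict([[0.2], [2.8]]))   # e.g. [0 1]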
@@ -530,22 +576,13 @@ class CustomAsyncPreprocessRequest(BasePreprocessRequest):
             return await self._preprocess.postprocess(data, state, collect_custom_statistics_fn)
         return data
 
-    async def completion(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
+    async def process(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
         """
         The actual processing function.
         We run the process in this context
         """
-        if self._preprocess is not None and hasattr(self._preprocess, 'completion'):
-            return await self._preprocess.completion(data, state, collect_custom_statistics_fn)
-        return None
-
-    async def chat_completion(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
-        """
-        The actual processing function.
-        We run the process in this context
-        """
-        if self._preprocess is not None and hasattr(self._preprocess, 'chat_completion'):
-            return await self._preprocess.chat_completion(data, state, collect_custom_statistics_fn)
+        if self._preprocess is not None and hasattr(self._preprocess, 'process'):
+            return await self._preprocess.process(data, state, collect_custom_statistics_fn)
         return None
 
     @staticmethod
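With `process` restored as the hook name, the user-supplied preprocess module is again expected to expose async `preprocess` / `process` / `postprocess` methods (rather than `completion` / `chat_completion`), since the engine looks them up via `hasattr()`. A hedged sketch of what such a user-side class could look like; the method bodies are purely illustrative.

from typing import Any, Callable, Optional


class Preprocess(object):
    """Illustrative user-side async preprocess class, discovered via hasattr() checks."""

    async def preprocess(self, body: dict, state: dict,
                         collect_custom_statistics_fn: Optional[Callable[[dict], None]] = None) -> Any:
        # turn the raw request body into model input
        return body.get("x", 0)

    async def process(self, data: Any, state: dict,
                      collect_custom_statistics_fn: Optional[Callable[[dict], None]] = None) -> Any:
        # run the model / custom logic
        return data * 2

    async def postprocess(self, data: Any, state: dict,
                          collect_custom_statistics_fn: Optional[Callable[[dict], None]] = None) -> dict:
        # shape the response returned to the REST API
        return {"y": data}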
@@ -559,3 +596,273 @@ class CustomAsyncPreprocessRequest(BasePreprocessRequest):
         if not return_value.ok:
             return None
         return return_value.json()
+
+
+@BasePreprocessRequest.register_engine("vllm")
+class VllmPreprocessRequest(BasePreprocessRequest):
+    import prometheus_client
+
+    from typing import Any, Union, Optional, Callable
+
+    from fastapi.responses import JSONResponse, StreamingResponse
+
+    from vllm.engine.arg_utils import AsyncEngineArgs
+    from vllm.engine.async_llm_engine import AsyncLLMEngine
+    from vllm.entrypoints.logger import RequestLogger
+    # yapf conflicts with isort for this block
+    # yapf: disable
+    from vllm.entrypoints.openai.protocol import (
+        ChatCompletionRequest,
+        ChatCompletionResponse,
+        CompletionRequest,
+        ErrorResponse
+    )
+
+    # yapf: enable
+    from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
+    from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
+    from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
+    from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
+    from vllm.logger import init_logger
+    from vllm.usage.usage_lib import UsageContext
+    from vllm.entrypoints.openai.serving_engine import LoRAModulePath, PromptAdapterPath
+
+    logger = init_logger(__name__)
+
+    REMOVE_WEB_ADDITIONAL_PROMPTS = True
+
+    if VllmPreprocessRequest.asyncio_to_thread is None:
+        from asyncio import to_thread as asyncio_to_thread
+        VllmPreprocessRequest.asyncio_to_thread = asyncio_to_thread
+
+    def remove_extra_system_prompts(messages: list) -> list:
+        """
+        Removes all 'system' prompts except the last one.
+
+        :param messages: List of message dicts with 'role' and 'content'.
+        :return: Modified list of messages with only the last 'system' prompt preserved.
+        """
+        # Filter only the system messages
+        system_messages_indices = []
+        for i, msg in enumerate(messages):
+            if msg["role"] == "system":
+                system_messages_indices.append(i)
+            else:
+                break
+
+        # If there is more than one system message, remove all but the last one
+        if len(system_messages_indices) > 1:
+            last_system_index = system_messages_indices[-1]
+            # Remove all system messages except the last one
+            messages = [msg for i, msg in enumerate(messages) if msg["role"] != "system" or i == last_system_index]
+
+        return messages
+
+    class CustomRequest:
+        def __init__(self, headers: Optional[dict] = None):
+            self.headers = headers
+
+        async def is_disconnected(self):
+            return False
+
+    def __init__(self, model_endpoint: ModelEndpoint, task: Task = None):
+        super(VllmPreprocessRequest, self).__init__(
+            model_endpoint=model_endpoint, task=task)
+
+        def is_port_in_use(port: int) -> bool:
+            import socket
+            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+                return s.connect_ex(('localhost', port)) == 0
+        if not is_port_in_use(8000):
+            prometheus_client.start_http_server(8000)
+
+        vllm_engine_config = {
+            "model":f"{local_file_name}/model",
+            "tokenizer":f"{local_file_name}/tokenizer",
+            "disable_log_requests": True,
+            "disable_log_stats": False,
+            "gpu_memory_utilization": 0.9,
+            "quantization": None,
+            "enforce_eager": True,
+            "served_model_name": "ai_operator_hyp22v4"
+        }
+        vllm_model_config = {
+            "lora_modules": None,  # [LoRAModulePath(name=a, path=b)]
+            "prompt_adapters": None,  # [PromptAdapterPath(name=a, path=b)]
+            "response_role": "assistant",
+            "chat_template": None,
+            "return_tokens_as_token_ids": False,
+            "max_log_len": None
+        }
+
+        self.engine_args = AsyncEngineArgs(**vllm_engine_config)
+        self.async_engine_client = AsyncLLMEngine.from_engine_args(self.engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
+
+
+        model_config = self.async_engine_client.engine.get_model_config()
+
+        request_logger = RequestLogger(max_log_len=vllm_model_config["max_log_len"])
+
+        self.openai_serving_chat = OpenAIServingChat(
+            self.async_engine_client,
+            model_config,
+            served_model_names=[vllm_engine_config["served_model_name"]],
+            response_role=vllm_model_config["response_role"],
+            lora_modules=vllm_model_config["lora_modules"],
+            prompt_adapters=vllm_model_config["prompt_adapters"],
+            request_logger=request_logger,
+            chat_template=vllm_model_config["chat_template"],
+            return_tokens_as_token_ids=vllm_model_config["return_tokens_as_token_ids"]
+        )
+        self.openai_serving_completion = OpenAIServingCompletion(
+            self.async_engine_client,
+            model_config,
+            served_model_names=[vllm_engine_config["served_model_name"]],
+            lora_modules=vllm_model_config["lora_modules"],
+            prompt_adapters=vllm_model_config["prompt_adapters"],
+            request_logger=request_logger,
+            return_tokens_as_token_ids=vllm_model_config["return_tokens_as_token_ids"]
+        )
+        self.openai_serving_embedding = OpenAIServingEmbedding(
+            self.async_engine_client,
+            model_config,
+            served_model_names=[vllm_engine_config["served_model_name"]],
+            request_logger=request_logger
+        )
+        self.openai_serving_tokenization = OpenAIServingTokenization(
+            self.async_engine_client,
+            model_config,
+            served_model_names=[vllm_engine_config["served_model_name"]],
+            lora_modules=vllm_model_config["lora_modules"],
+            request_logger=request_logger,
+            chat_template=vllm_model_config["chat_template"]
+        )
+        # override `send_request` method with the async version
+        self._preprocess.__class__.send_request = VllmPreprocessRequest._preprocess_send_request
+
+    async def preprocess(
+            self,
+            request: dict,
+            state: dict,
+            collect_custom_statistics_fn: Callable[[dict], None] = None,
+    ) -> Optional[Any]:
+        """
+        Raise exception to report an error
+        Return value will be passed to serving engine
+
+        :param request: dictionary as received from the RestAPI
+        :param state: Use state dict to store data passed to the post-processing function call.
+            Usage example:
+            >>> def preprocess(..., state):
+                    state['preprocess_aux_data'] = [1,2,3]
+            >>> def postprocess(..., state):
+                    print(state['preprocess_aux_data'])
+        :param collect_custom_statistics_fn: Optional, allows to send a custom set of key/values
+            to the statistics collector service
+
+            Usage example:
+            >>> print(request)
+            {"x0": 1, "x1": 2}
+            >>> collect_custom_statistics_fn({"x0": 1, "x1": 2})
+
+        :return: Object to be passed directly to the model inference
+        """
+        if self._preprocess is not None and hasattr(self._preprocess, 'preprocess'):
+            return await self._preprocess.preprocess(request, state, collect_custom_statistics_fn)
+        return request
+
+    async def postprocess(
+            self,
+            data: Any,
+            state: dict,
+            collect_custom_statistics_fn: Callable[[dict], None] = None
+    ) -> Optional[dict]:
+        """
+        Raise exception to report an error
+        Return value will be passed to serving engine
+
+        :param data: object as received from the inference model function
+        :param state: Use state dict to store data passed to the post-processing function call.
+            Usage example:
+            >>> def preprocess(..., state):
+                    state['preprocess_aux_data'] = [1,2,3]
+            >>> def postprocess(..., state):
+                    print(state['preprocess_aux_data'])
+        :param collect_custom_statistics_fn: Optional, allows to send a custom set of key/values
+            to the statistics collector service
+
+            Usage example:
+            >>> collect_custom_statistics_fn({"y": 1})
+
+        :return: Dictionary passed directly as the returned result of the RestAPI
+        """
+        if self._preprocess is not None and hasattr(self._preprocess, 'postprocess'):
+            return await self._preprocess.postprocess(data, state, collect_custom_statistics_fn)
+        return data
+
+
+    async def completions(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
+        """
+        The actual processing function.
+        We run the process in this context
+        """
+        if REMOVE_WEB_ADDITIONAL_PROMPTS:
+            if "messages" in body:
+                body["messages"] = remove_extra_system_prompts(body["messages"])
+
+        raw_request = CustomRequest(
+            headers = {
+                "traceparent": None,
+                "tracestate": None
+            }
+        )
+        request = CompletionRequest(**body)
+        logger.info(f"Received chat completion request: {request}")
+        generator = await self.openai_serving_completion.create_completion(
+            request=request,
+            raw_request=raw_request
+        )
+        if isinstance(generator, ErrorResponse):
+            return JSONResponse(content=generator.model_dump(), status_code=generator.code)
+        if request.stream:
+            return StreamingResponse(content=generator, media_type="text/event-stream")
+        else:
+            return JSONResponse(content=generator.model_dump())
+
+
+    async def chat_completions(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
+        """
+        The actual processing function.
+        We run the process in this context
+        """
+        # if self._preprocess is not None and hasattr(self._preprocess, 'chat_completion'):
+        #     return await self._preprocess.chat_completion(data, state, collect_custom_statistics_fn)
+        # return None
+        if REMOVE_WEB_ADDITIONAL_PROMPTS:
+            if "messages" in body:
+                body["messages"] = remove_extra_system_prompts(body["messages"])
+
+        request = ChatCompletionRequest(**body)
+        logger.info(f"Received chat completion request: {request}")
+        generator = await self.openai_serving_chat.create_chat_completion(
+            request=request, raw_request=None
+        )
+        if isinstance(generator, ErrorResponse):
+            return JSONResponse(content=generator.model_dump(), status_code=generator.code)
+        if request.stream:
+            return StreamingResponse(content=generator, media_type="text/event-stream")
+        else:
+            assert isinstance(generator, ChatCompletionResponse)
+            return JSONResponse(content=generator.model_dump())
+
+    @staticmethod
+    async def _preprocess_send_request(_, endpoint: str, version: str = None, data: dict = None) -> Optional[dict]:
+        endpoint = "{}/{}".format(endpoint.strip("/"), version.strip("/")) if version else endpoint.strip("/")
+        base_url = BasePreprocessRequest.get_server_config().get("base_serving_url")
+        base_url = (base_url or BasePreprocessRequest._default_serving_base_url).strip("/")
+        url = "{}/{}".format(base_url, endpoint.strip("/"))
+        return_value = await CustomAsyncPreprocessRequest.asyncio_to_thread(
+            request_post, url, json=data, timeout=BasePreprocessRequest._timeout)
+        if not return_value.ok:
+            return None
+        return return_value.json()
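To make the intent of `remove_extra_system_prompts` concrete, here is a small standalone usage sketch. The helper is copied from the hunk above so the snippet runs on its own, and the messages are made-up examples: the function scans the leading run of 'system' prompts and, when there is more than one, keeps only the last of them.

def remove_extra_system_prompts(messages: list) -> list:
    # copied from the diff above: keep only the last of the leading 'system' prompts
    system_messages_indices = []
    for i, msg in enumerate(messages):
        if msg["role"] == "system":
            system_messages_indices.append(i)
        else:
            break
    if len(system_messages_indices) > 1:
        last_system_index = system_messages_indices[-1]
        messages = [msg for i, msg in enumerate(messages) if msg["role"] != "system" or i == last_system_index]
    return messages


messages = [
    {"role": "system", "content": "web-ui boilerplate prompt"},
    {"role": "system", "content": "actual system prompt"},
    {"role": "user", "content": "hello"},
]
print(remove_extra_system_prompts(messages))
# [{'role': 'system', 'content': 'actual system prompt'}, {'role': 'user', 'content': 'hello'}]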