revert some old changes

IlyaMescheryakov1402 2025-02-27 23:13:47 +03:00
parent 5b73bdf085
commit f51bf2e081
2 changed files with 334 additions and 27 deletions


@@ -1241,7 +1241,7 @@ class ModelRequestProcessor(object):
if processor.is_process_async \
else processor.chat_completion(preprocessed, state, stats_collect_fn)
else:
raise ValueError(f"wrong url_type: expected 'completions' and 'chat/completions', got {url_type}")
raise ValueError(f"wrong url_type: expected 'completions' and 'chat/completions', got {serve_type}")
# noinspection PyUnresolvedReferences
return_value = await processor.postprocess(processed, state, stats_collect_fn) \
if processor.is_postprocess_async \


@@ -19,7 +19,7 @@ class BasePreprocessRequest(object):
__preprocessing_lookup = {}
__preprocessing_modules = set()
_grpc_env_conf_prefix = "CLEARML_GRPC_"
_default_serving_base_url = "http://127.0.0.1:8080/clearml/"
_default_serving_base_url = "http://127.0.0.1:8080/serve/"
_server_config = {} # externally configured by the serving inference service
_timeout = None # timeout in seconds for the entire request, set in __init__
is_preprocess_async = False
@@ -292,7 +292,7 @@ class TritonPreprocessRequest(BasePreprocessRequest):
self._grpc_stub = {}
async def chat_completion(
async def process(
self,
data: Any,
state: dict,
@@ -428,28 +428,74 @@ class TritonPreprocessRequest(BasePreprocessRequest):
return output_results[0] if index == 1 else output_results
@BasePreprocessRequest.register_engine("sklearn", modules=["joblib", "sklearn"])
class SKLearnPreprocessRequest(BasePreprocessRequest):
def __init__(self, model_endpoint: ModelEndpoint, task: Task = None):
super(SKLearnPreprocessRequest, self).__init__(
model_endpoint=model_endpoint, task=task)
if self._model is None:
# get model
import joblib # noqa
self._model = joblib.load(filename=self._get_local_model_file())
def process(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
"""
The actual processing function.
We run the model in this context
"""
return self._model.predict(data)
@BasePreprocessRequest.register_engine("xgboost", modules=["xgboost"])
class XGBoostPreprocessRequest(BasePreprocessRequest):
def __init__(self, model_endpoint: ModelEndpoint, task: Task = None):
super(XGBoostPreprocessRequest, self).__init__(
model_endpoint=model_endpoint, task=task)
if self._model is None:
# get model
import xgboost # noqa
self._model = xgboost.Booster()
self._model.load_model(self._get_local_model_file())
def process(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
"""
The actual processing function.
We run the model in this context
"""
return self._model.predict(data)
@BasePreprocessRequest.register_engine("lightgbm", modules=["lightgbm"])
class LightGBMPreprocessRequest(BasePreprocessRequest):
def __init__(self, model_endpoint: ModelEndpoint, task: Task = None):
super(LightGBMPreprocessRequest, self).__init__(
model_endpoint=model_endpoint, task=task)
if self._model is None:
# get model
import lightgbm # noqa
self._model = lightgbm.Booster(model_file=self._get_local_model_file())
def process(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
"""
The actual processing function.
We run the model in this context
"""
return self._model.predict(data)
@BasePreprocessRequest.register_engine("custom")
class CustomPreprocessRequest(BasePreprocessRequest):
def __init__(self, model_endpoint: ModelEndpoint, task: Task = None):
super(CustomPreprocessRequest, self).__init__(
model_endpoint=model_endpoint, task=task)
def completion(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
def process(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
"""
The actual processing function.
We run the process in this context
"""
if self._preprocess is not None and hasattr(self._preprocess, 'completion'):
return self._preprocess.completion(data, state, collect_custom_statistics_fn)
return None
def chat_completion(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
"""
The actual processing function.
We run the process in this context
"""
if self._preprocess is not None and hasattr(self._preprocess, 'chat_completion'):
return self._preprocess.chat_completion(data, state, collect_custom_statistics_fn)
if self._preprocess is not None and hasattr(self._preprocess, 'process'):
return self._preprocess.process(data, state, collect_custom_statistics_fn)
return None
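# Illustrative sketch (assumption, not part of this commit): the "custom" engine above only invokes
# hooks it finds via hasattr() on the user-supplied preprocess object, so a minimal user-side module
# needs nothing more than an object exposing a matching `process` method, e.g.:
class ExampleCustomPreprocess(object):
    def process(self, data, state, collect_custom_statistics_fn=None):
        # echo the parsed request body back unchanged; real code would transform it for the model
        return data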
@@ -530,22 +576,13 @@ class CustomAsyncPreprocessRequest(BasePreprocessRequest):
return await self._preprocess.postprocess(data, state, collect_custom_statistics_fn)
return data
async def completion(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
async def process(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
"""
The actual processing function.
We run the process in this context
"""
if self._preprocess is not None and hasattr(self._preprocess, 'completion'):
return await self._preprocess.completion(data, state, collect_custom_statistics_fn)
return None
async def chat_completion(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
"""
The actual processing function.
We run the process in this context
"""
if self._preprocess is not None and hasattr(self._preprocess, 'chat_completion'):
return await self._preprocess.chat_completion(data, state, collect_custom_statistics_fn)
if self._preprocess is not None and hasattr(self._preprocess, 'process'):
return await self._preprocess.process(data, state, collect_custom_statistics_fn)
return None
@staticmethod
@@ -559,3 +596,273 @@ class CustomAsyncPreprocessRequest(BasePreprocessRequest):
if not return_value.ok:
return None
return return_value.json()
@BasePreprocessRequest.register_engine("vllm")
class VllmPreprocessRequest(BasePreprocessRequest):
import prometheus_client
from typing import Any, Union, Optional, Callable
from fastapi.responses import JSONResponse, StreamingResponse
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
ErrorResponse
)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.entrypoints.openai.serving_engine import LoRAModulePath, PromptAdapterPath
logger = init_logger(__name__)
REMOVE_WEB_ADDITIONAL_PROMPTS = True
if VllmPreprocessRequest.asyncio_to_thread is None:
from asyncio import to_thread as asyncio_to_thread
VllmPreprocessRequest.asyncio_to_thread = asyncio_to_thread
def remove_extra_system_prompts(messages: list) -> list:
"""
Removes all 'system' prompts except the last one.
:param messages: List of message dicts with 'role' and 'content'.
:return: Modified list of messages with only the last 'system' prompt preserved.
"""
# Collect the indices of the leading 'system' messages
system_messages_indices = []
for i, msg in enumerate(messages):
if msg["role"] == "system":
system_messages_indices.append(i)
else:
break
# If there is more than one leading system message, drop all of them except the last one
if len(system_messages_indices) > 1:
last_system_index = system_messages_indices[-1]
# Remove every system message except the last one
messages = [msg for i, msg in enumerate(messages) if msg["role"] != "system" or i == last_system_index]
return messages
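# Illustrative example (not part of this commit): only the last leading "system" prompt survives
_demo_messages = [{"role": "system", "content": "v1"}, {"role": "system", "content": "v2"}, {"role": "user", "content": "hi"}]
assert remove_extra_system_prompts(_demo_messages) == [{"role": "system", "content": "v2"}, {"role": "user", "content": "hi"}]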
class CustomRequest:
def __init__(self, headers: Optional[dict] = None):
self.headers = headers
async def is_disconnected(self):
return False
def __init__(self, model_endpoint: ModelEndpoint, task: Task = None):
super(VllmPreprocessRequest, self).__init__(
model_endpoint=model_endpoint, task=task)
def is_port_in_use(port: int) -> bool:
import socket
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
return s.connect_ex(('localhost', port)) == 0
if not is_port_in_use(8000):
prometheus_client.start_http_server(8000)
local_file_name = self._get_local_model_file()  # assumption: the downloaded model artifact contains model/ and tokenizer/ subfolders
vllm_engine_config = {
"model": f"{local_file_name}/model",
"tokenizer": f"{local_file_name}/tokenizer",
"disable_log_requests": True,
"disable_log_stats": False,
"gpu_memory_utilization": 0.9,
"quantization": None,
"enforce_eager": True,
"served_model_name": "ai_operator_hyp22v4"
}
vllm_model_config = {
"lora_modules": None, # [LoRAModulePath(name=a, path=b)]
"prompt_adapters": None, # [PromptAdapterPath(name=a, path=b)]
"response_role": "assistant",
"chat_template": None,
"return_tokens_as_token_ids": False,
"max_log_len": None
}
self.engine_args = AsyncEngineArgs(**vllm_engine_config)
self.async_engine_client = AsyncLLMEngine.from_engine_args(self.engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
model_config = self.async_engine_client.engine.get_model_config()
request_logger = RequestLogger(max_log_len=vllm_model_config["max_log_len"])
self.openai_serving_chat = OpenAIServingChat(
self.async_engine_client,
model_config,
served_model_names=[vllm_engine_config["served_model_name"]],
response_role=vllm_model_config["response_role"],
lora_modules=vllm_model_config["lora_modules"],
prompt_adapters=vllm_model_config["prompt_adapters"],
request_logger=request_logger,
chat_template=vllm_model_config["chat_template"],
return_tokens_as_token_ids=vllm_model_config["return_tokens_as_token_ids"]
)
self.openai_serving_completion = OpenAIServingCompletion(
self.async_engine_client,
model_config,
served_model_names=[vllm_engine_config["served_model_name"]],
lora_modules=vllm_model_config["lora_modules"],
prompt_adapters=vllm_model_config["prompt_adapters"],
request_logger=request_logger,
return_tokens_as_token_ids=vllm_model_config["return_tokens_as_token_ids"]
)
self.openai_serving_embedding = OpenAIServingEmbedding(
self.async_engine_client,
model_config,
served_model_names=[vllm_engine_config["served_model_name"]],
request_logger=request_logger
)
self.openai_serving_tokenization = OpenAIServingTokenization(
self.async_engine_client,
model_config,
served_model_names=[vllm_engine_config["served_model_name"]],
lora_modules=vllm_model_config["lora_modules"],
request_logger=request_logger,
chat_template=vllm_model_config["chat_template"]
)
# override `send_request` method with the async version
self._preprocess.__class__.send_request = VllmPreprocessRequest._preprocess_send_request
async def preprocess(
self,
request: dict,
state: dict,
collect_custom_statistics_fn: Callable[[dict], None] = None,
) -> Optional[Any]:
"""
Raise exception to report an error
Return value will be passed to serving engine
:param request: dictionary as received from the RestAPI
:param state: Use state dict to store data passed to the post-processing function call.
Usage example:
>>> def preprocess(..., state):
state['preprocess_aux_data'] = [1,2,3]
>>> def postprocess(..., state):
print(state['preprocess_aux_data'])
:param collect_custom_statistics_fn: Optional, allows sending a custom set of key/values
to the statistics collector service
Usage example:
>>> print(request)
{"x0": 1, "x1": 2}
>>> collect_custom_statistics_fn({"x0": 1, "x1": 2})
:return: Object to be passed directly to the model inference
"""
if self._preprocess is not None and hasattr(self._preprocess, 'preprocess'):
return await self._preprocess.preprocess(request, state, collect_custom_statistics_fn)
return request
async def postprocess(
self,
data: Any,
state: dict,
collect_custom_statistics_fn: Callable[[dict], None] = None
) -> Optional[dict]:
"""
Raise exception to report an error
Return value will be passed to serving engine
:param data: object as received from the inference model function
:param state: Use state dict to store data passed to the post-processing function call.
Usage example:
>>> def preprocess(..., state):
state['preprocess_aux_data'] = [1,2,3]
>>> def postprocess(..., state):
print(state['preprocess_aux_data'])
:param collect_custom_statistics_fn: Optional, allows sending a custom set of key/values
to the statistics collector service
Usage example:
>>> collect_custom_statistics_fn({"y": 1})
:return: Dictionary passed directly as the returned result of the RestAPI
"""
if self._preprocess is not None and hasattr(self._preprocess, 'postprocess'):
return await self._preprocess.postprocess(data, state, collect_custom_statistics_fn)
return data
async def completions(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
"""
The actual processing function.
We run the process in this context
"""
if REMOVE_WEB_ADDITIONAL_PROMPTS:
if "messages" in body:
body["messages"] = remove_extra_system_prompts(body["messages"])
raw_request = CustomRequest(
headers = {
"traceparent": None,
"tracestate": None
}
)
request = CompletionRequest(**data)
logger.info(f"Received completion request: {request}")
generator = await self.openai_serving_completion.create_completion(
request=request,
raw_request=raw_request
)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), status_code=generator.code)
if request.stream:
return StreamingResponse(content=generator, media_type="text/event-stream")
else:
return JSONResponse(content=generator.model_dump())
async def chat_completions(self, data: Any, state: dict, collect_custom_statistics_fn: Callable[[dict], None] = None) -> Any:
"""
The actual processing function.
We run the process in this context
"""
# if self._preprocess is not None and hasattr(self._preprocess, 'chat_completion'):
# return await self._preprocess.chat_completion(data, state, collect_custom_statistics_fn)
# return None
if REMOVE_WEB_ADDITIONAL_PROMPTS:
if "messages" in body:
body["messages"] = remove_extra_system_prompts(body["messages"])
request = ChatCompletionRequest(**body)
logger.info(f"Received chat completion request: {request}")
generator = await self.openai_serving_chat.create_chat_completion(
request=request, raw_request=None
)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(), status_code=generator.code)
if request.stream:
return StreamingResponse(content=generator, media_type="text/event-stream")
else:
assert isinstance(generator, ChatCompletionResponse)
return JSONResponse(content=generator.model_dump())
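# Illustrative payload (assumption: a standard OpenAI-style chat body) as handled by chat_completions()
# above; "model" should match the served_model_name configured in __init__.
example_chat_body = {"model": "ai_operator_hyp22v4", "messages": [{"role": "user", "content": "Hello"}], "stream": False}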
@staticmethod
async def _preprocess_send_request(_, endpoint: str, version: str = None, data: dict = None) -> Optional[dict]:
endpoint = "{}/{}".format(endpoint.strip("/"), version.strip("/")) if version else endpoint.strip("/")
base_url = BasePreprocessRequest.get_server_config().get("base_serving_url")
base_url = (base_url or BasePreprocessRequest._default_serving_base_url).strip("/")
url = "{}/{}".format(base_url, endpoint.strip("/"))
return_value = await CustomAsyncPreprocessRequest.asyncio_to_thread(
request_post, url, json=data, timeout=BasePreprocessRequest._timeout)
if not return_value.ok:
return None
return return_value.json()
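# Illustrative usage (assumption): after the override in __init__ above, user preprocess code can call
#     result = self.send_request(endpoint="another_endpoint", version="1", data={"x": 1})
# which POSTs the JSON body to "<base_serving_url>/another_endpoint/1" and returns the parsed JSON
# reply, or None if the HTTP response is not OK.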