Fix python < 3.10 support

Fix custom_async engine
Suppress warning
This commit is contained in:
allegroai 2024-02-27 09:45:32 +02:00
parent bca162810b
commit 8df521b949

View File

@ -2,6 +2,7 @@ import os
import sys
import threading
import traceback
import warnings
from pathlib import Path
from typing import Optional, Any, Callable, List
@ -247,18 +248,20 @@ class BasePreprocessRequest(object):
@BasePreprocessRequest.register_engine("triton", modules=["grpc", "tritonclient"])
class TritonPreprocessRequest(BasePreprocessRequest):
_content_lookup = {
getattr(np, 'int', int): 'int_contents',
np.uint8: 'uint_contents',
np.int8: 'int_contents',
np.int64: 'int64_contents',
np.uint64: 'uint64_contents',
np.int32: 'int_contents',
np.uint: 'uint_contents',
getattr(np, 'bool', bool): 'bool_contents',
np.float32: 'fp32_contents',
np.float64: 'fp64_contents',
}
with warnings.catch_warnings():
warnings.simplefilter(action='ignore', category=FutureWarning)
_content_lookup = {
getattr(np, 'int', int): 'int_contents',
np.uint8: 'uint_contents',
np.int8: 'int_contents',
np.int64: 'int64_contents',
np.uint64: 'uint64_contents',
np.int32: 'int_contents',
np.uint: 'uint_contents',
getattr(np, 'bool', bool): 'bool_contents',
np.float32: 'fp32_contents',
np.float64: 'fp64_contents',
}
_default_grpc_address = "127.0.0.1:8001"
_default_grpc_compression = False
_ext_grpc = None
@ -501,10 +504,17 @@ class CustomAsyncPreprocessRequest(BasePreprocessRequest):
# Stage flags, all enabled for this engine. Presumably they tell the serving code that
# preprocess / process / postprocess are coroutines that must be awaited - TODO confirm against the caller.
is_preprocess_async = True
is_process_async = True
is_postprocess_async = True
# Class-level cache for the `asyncio.to_thread` callable: starts as None and is filled in
# by `__init__` the first time an instance is created.
asyncio_to_thread = None
def __init__(self, model_endpoint: ModelEndpoint, task: Task = None):
    """
    Initialize the async custom engine and patch the user code `send_request` with its async version.

    :param model_endpoint: endpoint description, forwarded unchanged to the base class constructor
    :param task: optional Task, forwarded unchanged to the base class constructor
    """
    super(CustomAsyncPreprocessRequest, self).__init__(
        model_endpoint=model_endpoint, task=task)
    # Resolve the thread off-loading helper lazily (first instance only) and cache it on the class,
    # so merely importing this module never requires `asyncio.to_thread` to exist.
    if CustomAsyncPreprocessRequest.asyncio_to_thread is None:
        try:
            from asyncio import to_thread as asyncio_to_thread
        except ImportError:
            # `asyncio.to_thread` was only added in Python 3.9; on older interpreters
            # run the blocking call on the event loop's default executor instead.
            import asyncio
            import functools

            async def asyncio_to_thread(func, *args, **kwargs):
                loop = asyncio.get_event_loop()
                return await loop.run_in_executor(None, functools.partial(func, *args, **kwargs))
        CustomAsyncPreprocessRequest.asyncio_to_thread = asyncio_to_thread
    # override `send_request` method with the async version
    # (skipped when no user code object was loaded: `None.__class__` is NoneType and cannot be patched)
    if self._preprocess is not None:
        self._preprocess.__class__.send_request = CustomAsyncPreprocessRequest._preprocess_send_request
async def preprocess(
self,
@ -574,3 +584,15 @@ class CustomAsyncPreprocessRequest(BasePreprocessRequest):
if self._preprocess is not None and hasattr(self._preprocess, 'process'):
return await self._preprocess.process(data, state, collect_custom_statistics_fn)
return None
@staticmethod
async def _preprocess_send_request(_, endpoint: str, version: str = None, data: dict = None) -> Optional[dict]:
    """
    Async replacement for `send_request`: POST `data` as JSON to another serving endpoint.

    The blocking `request_post` call is awaited through the cached
    `CustomAsyncPreprocessRequest.asyncio_to_thread` helper.

    :param _: ignored first positional argument (the object the method ends up bound to)
    :param endpoint: endpoint path, leading/trailing slashes are stripped
    :param version: optional version, appended to the endpoint path when truthy
    :param data: payload passed as the `json` argument of the POST request
    :return: decoded JSON of the response, or None when the response is not `ok`
    """
    route = endpoint.strip("/")
    if version:
        route = "{}/{}".format(route, version.strip("/"))

    # configured serving url wins, otherwise use the class level default
    server_root = BasePreprocessRequest.get_server_config().get("base_serving_url")
    if not server_root:
        server_root = BasePreprocessRequest._default_serving_base_url
    server_root = server_root.strip("/")

    target = "{}/{}".format(server_root, route.strip("/"))
    response = await CustomAsyncPreprocessRequest.asyncio_to_thread(
        request_post, target, json=data, timeout=BasePreprocessRequest._timeout)
    return response.json() if response.ok else None