Fix python < 3.10 support

Fix custom_async engine
Suppress warning
This commit is contained in:
allegroai 2024-02-27 09:45:32 +02:00
parent bca162810b
commit 8df521b949

View File

@ -2,6 +2,7 @@ import os
import sys import sys
import threading import threading
import traceback import traceback
import warnings
from pathlib import Path from pathlib import Path
from typing import Optional, Any, Callable, List from typing import Optional, Any, Callable, List
@ -247,6 +248,8 @@ class BasePreprocessRequest(object):
@BasePreprocessRequest.register_engine("triton", modules=["grpc", "tritonclient"]) @BasePreprocessRequest.register_engine("triton", modules=["grpc", "tritonclient"])
class TritonPreprocessRequest(BasePreprocessRequest): class TritonPreprocessRequest(BasePreprocessRequest):
with warnings.catch_warnings():
warnings.simplefilter(action='ignore', category=FutureWarning)
_content_lookup = { _content_lookup = {
getattr(np, 'int', int): 'int_contents', getattr(np, 'int', int): 'int_contents',
np.uint8: 'uint_contents', np.uint8: 'uint_contents',
@ -501,10 +504,17 @@ class CustomAsyncPreprocessRequest(BasePreprocessRequest):
is_preprocess_async = True is_preprocess_async = True
is_process_async = True is_process_async = True
is_postprocess_async = True is_postprocess_async = True
asyncio_to_thread = None
def __init__(self, model_endpoint: ModelEndpoint, task: Task = None): def __init__(self, model_endpoint: ModelEndpoint, task: Task = None):
super(CustomAsyncPreprocessRequest, self).__init__( super(CustomAsyncPreprocessRequest, self).__init__(
model_endpoint=model_endpoint, task=task) model_endpoint=model_endpoint, task=task)
# load asyncio only when needed, basically python < 3.10 does not supported to_thread
if CustomAsyncPreprocessRequest.asyncio_to_thread is None:
from asyncio import to_thread as asyncio_to_thread
CustomAsyncPreprocessRequest.asyncio_to_thread = asyncio_to_thread
# override `send_request` method with the async version
self._preprocess.__class__.send_request = CustomAsyncPreprocessRequest._preprocess_send_request
async def preprocess( async def preprocess(
self, self,
@ -574,3 +584,15 @@ class CustomAsyncPreprocessRequest(BasePreprocessRequest):
if self._preprocess is not None and hasattr(self._preprocess, 'process'): if self._preprocess is not None and hasattr(self._preprocess, 'process'):
return await self._preprocess.process(data, state, collect_custom_statistics_fn) return await self._preprocess.process(data, state, collect_custom_statistics_fn)
return None return None
@staticmethod
async def _preprocess_send_request(_, endpoint: str, version: str = None, data: dict = None) -> Optional[dict]:
endpoint = "{}/{}".format(endpoint.strip("/"), version.strip("/")) if version else endpoint.strip("/")
base_url = BasePreprocessRequest.get_server_config().get("base_serving_url")
base_url = (base_url or BasePreprocessRequest._default_serving_base_url).strip("/")
url = "{}/{}".format(base_url, endpoint.strip("/"))
return_value = await CustomAsyncPreprocessRequest.asyncio_to_thread(
request_post, url, json=data, timeout=BasePreprocessRequest._timeout)
if not return_value.ok:
return None
return return_value.json()