This commit is contained in:
Timothy J. Baek 2024-06-22 14:08:23 -07:00
parent df71d7c63b
commit de367e488d

View File

@ -399,7 +399,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
# Get the signature of the function # Get the signature of the function
sig = inspect.signature(inlet) sig = inspect.signature(inlet)
param = {"body": data} params = {"body": data}
if "__user__" in sig.parameters: if "__user__" in sig.parameters:
__user__ = { __user__ = {
@ -424,15 +424,15 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
params = {**params, "__user__": __user__} params = {**params, "__user__": __user__}
if "__id__" in sig.parameters: if "__id__" in sig.parameters:
param = { params = {
**param, **params,
"__id__": filter_id, "__id__": filter_id,
} }
if inspect.iscoroutinefunction(inlet): if inspect.iscoroutinefunction(inlet):
data = await inlet(**param) data = await inlet(**params)
else: else:
data = inlet(**param) data = inlet(**params)
except Exception as e: except Exception as e:
print(f"Error: {e}") print(f"Error: {e}")
@ -962,17 +962,17 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
try: try:
if inspect.iscoroutinefunction(pipe): if inspect.iscoroutinefunction(pipe):
res = await pipe(**param) res = await pipe(**params)
else: else:
res = pipe(**param) res = pipe(**params)
except Exception as e: except Exception as e:
print(f"Error: {e}") print(f"Error: {e}")
return {"error": {"detail": str(e)}} return {"error": {"detail": str(e)}}
if inspect.iscoroutinefunction(pipe): if inspect.iscoroutinefunction(pipe):
res = await pipe(**param) res = await pipe(**params)
else: else:
res = pipe(**param) res = pipe(**params)
if isinstance(res, dict): if isinstance(res, dict):
return res return res
@ -1104,7 +1104,7 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
# Get the signature of the function # Get the signature of the function
sig = inspect.signature(outlet) sig = inspect.signature(outlet)
param = {"body": data} params = {"body": data}
if "__user__" in sig.parameters: if "__user__" in sig.parameters:
__user__ = { __user__ = {
@ -1127,15 +1127,15 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
params = {**params, "__user__": __user__} params = {**params, "__user__": __user__}
if "__id__" in sig.parameters: if "__id__" in sig.parameters:
param = { params = {
**param, **params,
"__id__": filter_id, "__id__": filter_id,
} }
if inspect.iscoroutinefunction(outlet): if inspect.iscoroutinefunction(outlet):
data = await outlet(**param) data = await outlet(**params)
else: else:
data = outlet(**param) data = outlet(**params)
except Exception as e: except Exception as e:
print(f"Error: {e}") print(f"Error: {e}")