This commit is contained in:
Timothy J. Baek 2024-08-17 16:41:34 +02:00
parent e71f55e58f
commit 15f3ebba93

View File

@ -218,25 +218,6 @@ origins = ["*"]
##################################
async def get_body_and_model_and_user(request):
    """Parse the request body and resolve the model and user it refers to.

    Args:
        request: Incoming request. Its raw body is awaited and decoded as
            UTF-8 JSON, and its ``Authorization`` header is read.

    Returns:
        A ``(body, model, user)`` tuple: the decoded JSON body (``{}`` when
        the raw body is empty), the ``app.state.MODELS`` entry keyed by
        ``body["model"]``, and the value returned by ``get_current_user``.

    Raises:
        Exception: ``"Model not found"`` when the body carries no ``"model"``
            key or the id is not present in ``app.state.MODELS``.
    """
    # Read the original request body
    raw_body = await request.body()
    body_str = raw_body.decode("utf-8")
    body = json.loads(body_str) if body_str else {}

    # An empty body (or one without "model") used to escape as a bare
    # KeyError whose message is only the key name; report it the same way
    # as an unknown model id instead.
    model_id = body.get("model") if isinstance(body, dict) else None
    if model_id not in app.state.MODELS:
        raise Exception("Model not found")
    model = app.state.MODELS[model_id]

    credentials = get_http_authorization_cred(request.headers.get("Authorization"))
    user = get_current_user(request, credentials)

    return body, model, user
def get_task_model_id(default_model_id):
# Set the task model
task_model_id = default_model_id
@ -283,26 +264,6 @@ def get_filter_function_ids(model):
return filter_ids
def get_tools_function_calling_payload(messages, task_model_id, content):
    """Build the non-streaming request payload for the function-calling task.

    The user prompt embeds up to four of the most recent messages, newest
    first, followed by the value returned by ``get_last_user_message``.

    Args:
        messages: Chat messages; each is indexed by ``"role"`` and
            ``"content"``.
        task_model_id: Value placed under the payload's ``"model"`` key.
        content: Text used as the system message.

    Returns:
        A dict with ``model``, ``messages``, ``stream`` and ``metadata`` keys.
    """
    user_message = get_last_user_message(messages)

    # Walk the conversation backwards and keep at most four entries.
    history_lines = []
    for message in messages[::-1][:4]:
        history_lines.append(
            f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
        )
    history = "\n".join(history_lines)

    prompt = f"History:\n{history}\nQuery: {user_message}"
    system_entry = {"role": "system", "content": content}
    user_entry = {"role": "user", "content": f"Query: {prompt}"}

    payload = {
        "model": task_model_id,
        "messages": [system_entry, user_entry],
        "stream": False,
        "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
    }
    return payload
async def chat_completion_filter_functions_handler(body, model, extra_params):
skip_files = None
@ -369,12 +330,32 @@ async def chat_completion_filter_functions_handler(body, model, extra_params):
return body, {}
def get_tools_function_calling_payload(messages, task_model_id, content):
    """Build the non-streaming request payload for the function-calling task.

    The user prompt embeds up to four of the most recent messages, newest
    first, followed by the value returned by ``get_last_user_message``.

    Args:
        messages: Chat messages; each is indexed by ``"role"`` and
            ``"content"``.
        task_model_id: Value placed under the payload's ``"model"`` key.
        content: Text used as the system message.

    Returns:
        A dict with ``model``, ``messages``, ``stream`` and ``metadata`` keys.
    """
    user_message = get_last_user_message(messages)

    # Walk the conversation backwards and keep at most four entries.
    history_lines = []
    for message in messages[::-1][:4]:
        history_lines.append(
            f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
        )
    history = "\n".join(history_lines)

    prompt = f"History:\n{history}\nQuery: {user_message}"
    system_entry = {"role": "system", "content": content}
    user_entry = {"role": "user", "content": f"Query: {prompt}"}

    payload = {
        "model": task_model_id,
        "messages": [system_entry, user_entry],
        "stream": False,
        "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
    }
    return payload
def apply_extra_params_to_tool_function(
function: Callable, custom_params: dict
function: Callable, extra_params: dict
) -> Callable[..., Awaitable]:
sig = inspect.signature(function)
extra_params = {
key: value for key, value in custom_params.items() if key in sig.parameters
key: value for key, value in extra_params.items() if key in sig.parameters
}
is_coroutine = inspect.iscoroutinefunction(function)
@ -511,27 +492,27 @@ async def chat_completion_tools_handler(
return body, {}
result = json.loads(content)
tool_name = result.get("name", None)
if tool_name not in tools:
tool_function_name = result.get("name", None)
if tool_function_name not in tools:
return body, {}
tool_params = result.get("parameters", {})
toolkit_id = tools[tool_name]["toolkit_id"]
tool_function_params = result.get("parameters", {})
try:
tool_output = await tools[tool_name]["callable"](**tool_params)
tool_output = await tools[tool_function_name]["callable"](**tool_function_params)
except Exception as e:
tool_output = str(e)
if tools[tool_name]["citation"]:
if tools[tool_function_name]["citation"]:
citations.append(
{
"source": {"name": f"TOOL:{toolkit_id}/{tool_name}"},
"source": {"name": f"TOOL:{tools[tool_function_name]["toolkit_id"]}/{tool_function_name}"},
"document": [tool_output],
"metadata": [{"source": tool_name}],
"metadata": [{"source": tool_function_name}],
}
)
if tools[tool_name]["file_handler"]:
if tools[tool_function_name]["file_handler"]:
skip_files = True
if isinstance(tool_output, str):
@ -576,6 +557,25 @@ def is_chat_completion_request(request):
)
async def get_body_and_model_and_user(request):
    """Parse the request body and resolve the model and user it refers to.

    Args:
        request: Incoming request. Its raw body is awaited and decoded as
            UTF-8 JSON, and its ``Authorization`` header is read.

    Returns:
        A ``(body, model, user)`` tuple: the decoded JSON body (``{}`` when
        the raw body is empty), the ``app.state.MODELS`` entry keyed by
        ``body["model"]``, and the value returned by ``get_current_user``.

    Raises:
        Exception: ``"Model not found"`` when the body carries no ``"model"``
            key or the id is not present in ``app.state.MODELS``.
    """
    # Read the original request body
    raw_body = await request.body()
    body_str = raw_body.decode("utf-8")
    body = json.loads(body_str) if body_str else {}

    # An empty body (or one without "model") used to escape as a bare
    # KeyError whose message is only the key name; report it the same way
    # as an unknown model id instead.
    model_id = body.get("model") if isinstance(body, dict) else None
    if model_id not in app.state.MODELS:
        raise Exception("Model not found")
    model = app.state.MODELS[model_id]

    credentials = get_http_authorization_cred(request.headers.get("Authorization"))
    user = get_current_user(request, credentials)

    return body, model, user
class ChatCompletionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
if not is_chat_completion_request(request):