minor refac

This commit is contained in:
Michael Poluektov 2024-08-14 20:40:10 +01:00
parent fdc89cbcee
commit 4042219b3e

View File

@ -378,9 +378,8 @@ async def chat_completion_inlets_handler(body, model, extra_params):
print(f"Error: {e}")
raise e
if skip_files:
if "files" in body:
del body["files"]
if skip_files and "files" in body:
del body["files"]
return body, {}
@ -431,12 +430,17 @@ def get_configured_tools(
)
for spec in toolkit.specs:
# TODO: Fix hack for OpenAI API
for val in spec.get("parameters", {}).get("properties", {}).values():
if val["type"] == "str":
val["type"] = "string"
name = spec["name"]
callable = getattr(module, name)
# convert to function that takes only model params and inserts custom params
custom_callable = get_tool_with_custom_params(callable, extra_params)
# TODO: This needs to be a pydantic model
tool_dict = {
"spec": spec,
"citation": has_citation,
@ -444,6 +448,7 @@ def get_configured_tools(
"toolkit_id": tool_id,
"callable": custom_callable,
}
# TODO: if collision, prepend toolkit name
if name in tools:
log.warning(f"Tool {name} already exists in another toolkit!")
log.warning(f"Collision between {toolkit} and {tool_id}.")
@ -533,9 +538,9 @@ async def chat_completion_tools_handler(
return body, {"contexts": contexts, "citations": citations}
async def chat_completion_files_handler(body):
async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
contexts = []
citations = None
citations = []
if files := body.pop("files", None):
contexts, citations = get_rag_context(
@ -550,10 +555,7 @@ async def chat_completion_files_handler(body):
log.debug(f"rag_contexts: {contexts}, citations: {citations}")
return body, {
**({"contexts": contexts} if contexts is not None else {}),
**({"citations": citations} if citations is not None else {}),
}
return body, {"contexts": contexts, "citations": citations}
def is_chat_completion_request(request):
@ -618,16 +620,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
contexts.extend(flags.get("contexts", []))
citations.extend(flags.get("citations", []))
except Exception as e:
print(e)
pass
log.exception(e)
try:
body, flags = await chat_completion_files_handler(body)
contexts.extend(flags.get("contexts", []))
citations.extend(flags.get("citations", []))
except Exception as e:
print(e)
pass
log.exception(e)
# If context is not empty, insert it into the messages
if len(contexts) > 0: