minor refac

This commit is contained in:
Michael Poluektov 2024-08-14 20:40:10 +01:00
parent fdc89cbcee
commit 4042219b3e

View File

@ -378,9 +378,8 @@ async def chat_completion_inlets_handler(body, model, extra_params):
print(f"Error: {e}") print(f"Error: {e}")
raise e raise e
if skip_files: if skip_files and "files" in body:
if "files" in body: del body["files"]
del body["files"]
return body, {} return body, {}
@ -431,12 +430,17 @@ def get_configured_tools(
) )
for spec in toolkit.specs: for spec in toolkit.specs:
# TODO: Fix hack for OpenAI API
for val in spec.get("parameters", {}).get("properties", {}).values():
if val["type"] == "str":
val["type"] = "string"
name = spec["name"] name = spec["name"]
callable = getattr(module, name) callable = getattr(module, name)
# convert to function that takes only model params and inserts custom params # convert to function that takes only model params and inserts custom params
custom_callable = get_tool_with_custom_params(callable, extra_params) custom_callable = get_tool_with_custom_params(callable, extra_params)
# TODO: This needs to be a pydantic model
tool_dict = { tool_dict = {
"spec": spec, "spec": spec,
"citation": has_citation, "citation": has_citation,
@ -444,6 +448,7 @@ def get_configured_tools(
"toolkit_id": tool_id, "toolkit_id": tool_id,
"callable": custom_callable, "callable": custom_callable,
} }
# TODO: if collision, prepend toolkit name
if name in tools: if name in tools:
log.warning(f"Tool {name} already exists in another toolkit!") log.warning(f"Tool {name} already exists in another toolkit!")
log.warning(f"Collision between {toolkit} and {tool_id}.") log.warning(f"Collision between {toolkit} and {tool_id}.")
@ -533,9 +538,9 @@ async def chat_completion_tools_handler(
return body, {"contexts": contexts, "citations": citations} return body, {"contexts": contexts, "citations": citations}
async def chat_completion_files_handler(body): async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
contexts = [] contexts = []
citations = None citations = []
if files := body.pop("files", None): if files := body.pop("files", None):
contexts, citations = get_rag_context( contexts, citations = get_rag_context(
@ -550,10 +555,7 @@ async def chat_completion_files_handler(body):
log.debug(f"rag_contexts: {contexts}, citations: {citations}") log.debug(f"rag_contexts: {contexts}, citations: {citations}")
return body, { return body, {"contexts": contexts, "citations": citations}
**({"contexts": contexts} if contexts is not None else {}),
**({"citations": citations} if citations is not None else {}),
}
def is_chat_completion_request(request): def is_chat_completion_request(request):
@ -618,16 +620,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
contexts.extend(flags.get("contexts", [])) contexts.extend(flags.get("contexts", []))
citations.extend(flags.get("citations", [])) citations.extend(flags.get("citations", []))
except Exception as e: except Exception as e:
print(e) log.exception(e)
pass
try: try:
body, flags = await chat_completion_files_handler(body) body, flags = await chat_completion_files_handler(body)
contexts.extend(flags.get("contexts", [])) contexts.extend(flags.get("contexts", []))
citations.extend(flags.get("citations", [])) citations.extend(flags.get("citations", []))
except Exception as e: except Exception as e:
print(e) log.exception(e)
pass
# If context is not empty, insert it into the messages # If context is not empty, insert it into the messages
if len(contexts) > 0: if len(contexts) > 0: