refac: chat completion middleware

This commit is contained in:
Timothy J. Baek 2024-07-01 19:33:58 -07:00
parent b62d2a9b28
commit c7a9b5ccfa
3 changed files with 304 additions and 223 deletions

View File

@ -294,15 +294,17 @@ def get_rag_context(
extracted_collections.extend(collection_names) extracted_collections.extend(collection_names)
context_string = "" contexts = []
citations = [] citations = []
for context in relevant_contexts: for context in relevant_contexts:
try: try:
if "documents" in context: if "documents" in context:
context_string += "\n\n".join( contexts.append(
"\n\n".join(
[text for text in context["documents"][0] if text is not None] [text for text in context["documents"][0] if text is not None]
) )
)
if "metadatas" in context: if "metadatas" in context:
citations.append( citations.append(
@ -315,9 +317,7 @@ def get_rag_context(
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
context_string = context_string.strip() return contexts, citations
return context_string, citations
def get_model_path(model: str, update_model: bool = False): def get_model_path(model: str, update_model: bool = False):

View File

@ -213,7 +213,7 @@ origins = ["*"]
async def get_function_call_response( async def get_function_call_response(
messages, files, tool_id, template, task_model_id, user messages, files, tool_id, template, task_model_id, user, model
): ):
tool = Tools.get_tool_by_id(tool_id) tool = Tools.get_tool_by_id(tool_id)
tools_specs = json.dumps(tool.specs, indent=2) tools_specs = json.dumps(tool.specs, indent=2)
@ -373,68 +373,55 @@ async def get_function_call_response(
return None, None, False return None, None, False
class ChatCompletionMiddleware(BaseHTTPMiddleware): def get_task_model_id(default_model_id):
async def dispatch(self, request: Request, call_next): # Set the task model
data_items = [] task_model_id = default_model_id
# Check if the user has a custom task model and use that model
show_citations = False if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
citations = [] if (
app.state.config.TASK_MODEL
if request.method == "POST" and any( and app.state.config.TASK_MODEL in app.state.MODELS
endpoint in request.url.path
for endpoint in ["/ollama/api/chat", "/chat/completions"]
): ):
log.debug(f"request.url.path: {request.url.path}") task_model_id = app.state.config.TASK_MODEL
else:
if (
app.state.config.TASK_MODEL_EXTERNAL
and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
# Read the original request body return task_model_id
body = await request.body()
body_str = body.decode("utf-8")
data = json.loads(body_str) if body_str else {}
user = get_current_user(
request,
get_http_authorization_cred(request.headers.get("Authorization")),
)
# Flag to skip RAG completions if file_handler is present in tools/functions
skip_files = False
if data.get("citations"):
show_citations = True
del data["citations"]
model_id = data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
model = app.state.MODELS[model_id]
def get_filter_function_ids(model):
def get_priority(function_id): def get_priority(function_id):
function = Functions.get_function_by_id(function_id) function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"): if function is not None and hasattr(function, "valves"):
return (function.valves if function.valves else {}).get( return (function.valves if function.valves else {}).get("priority", 0)
"priority", 0
)
return 0 return 0
filter_ids = [ filter_ids = [function.id for function in Functions.get_global_filter_functions()]
function.id for function in Functions.get_global_filter_functions()
]
if "info" in model and "meta" in model["info"]: if "info" in model and "meta" in model["info"]:
filter_ids.extend(model["info"]["meta"].get("filterIds", [])) filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids)) filter_ids = list(set(filter_ids))
enabled_filter_ids = [ enabled_filter_ids = [
function.id function.id
for function in Functions.get_functions_by_type( for function in Functions.get_functions_by_type("filter", active_only=True)
"filter", active_only=True
)
] ]
filter_ids = [ filter_ids = [
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
] ]
filter_ids.sort(key=get_priority) filter_ids.sort(key=get_priority)
return filter_ids
async def chat_completion_functions_handler(body, model, user):
skip_files = None
filter_ids = get_filter_function_ids(model)
for filter_id in filter_ids: for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id) filter = Functions.get_function_by_id(filter_id)
if filter: if filter:
@ -464,7 +451,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
# Get the signature of the function # Get the signature of the function
sig = inspect.signature(inlet) sig = inspect.signature(inlet)
params = {"body": data} params = {"body": body}
if "__user__" in sig.parameters: if "__user__" in sig.parameters:
__user__ = { __user__ = {
@ -499,107 +486,195 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
} }
if inspect.iscoroutinefunction(inlet): if inspect.iscoroutinefunction(inlet):
data = await inlet(**params) body = await inlet(**params)
else: else:
data = inlet(**params) body = inlet(**params)
except Exception as e: except Exception as e:
print(f"Error: {e}") print(f"Error: {e}")
return JSONResponse( raise e
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
# Set the task model if skip_files:
task_model_id = data["model"] if "files" in body:
# Check if the user has a custom task model and use that model del body["files"]
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
if (
app.state.config.TASK_MODEL
and app.state.config.TASK_MODEL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL
else:
if (
app.state.config.TASK_MODEL_EXTERNAL
and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
prompt = get_last_user_message(data["messages"]) return body, {}
context = ""
async def chat_completion_tools_handler(body, model, user):
skip_files = None
contexts = []
citations = None
task_model_id = get_task_model_id(body["model"])
# If tool_ids field is present, call the functions # If tool_ids field is present, call the functions
if "tool_ids" in data: if "tool_ids" in body:
print(data["tool_ids"]) print(body["tool_ids"])
for tool_id in data["tool_ids"]: for tool_id in body["tool_ids"]:
print(tool_id) print(tool_id)
try: try:
response, citation, file_handler = ( response, citation, file_handler = await get_function_call_response(
await get_function_call_response( messages=body["messages"],
messages=data["messages"], files=body.get("files", []),
files=data.get("files", []),
tool_id=tool_id, tool_id=tool_id,
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE, template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
task_model_id=task_model_id, task_model_id=task_model_id,
user=user, user=user,
) model=model,
) )
print(file_handler) print(file_handler)
if isinstance(response, str): if isinstance(response, str):
context += ("\n" if context != "" else "") + response contexts.append(response)
if citation: if citation:
if citations is None:
citations = [citation]
else:
citations.append(citation) citations.append(citation)
show_citations = True
if file_handler: if file_handler:
skip_files = True skip_files = True
except Exception as e: except Exception as e:
print(f"Error: {e}") print(f"Error: {e}")
del data["tool_ids"] del body["tool_ids"]
print(f"tool_contexts: {contexts}")
print(f"tool_context: {context}") if skip_files:
if "files" in body:
del body["files"]
# If files field is present, generate RAG completions return body, {
# If skip_files is True, skip the RAG completions **({"contexts": contexts} if contexts is not None else {}),
if "files" in data: **({"citations": citations} if citations is not None else {}),
if not skip_files: }
data = {**data}
rag_context, rag_citations = get_rag_context(
files=data["files"], async def chat_completion_files_handler(body):
messages=data["messages"], contexts = []
citations = None
if "files" in body:
files = body["files"]
del body["files"]
contexts, citations = get_rag_context(
files=files,
messages=body["messages"],
embedding_function=rag_app.state.EMBEDDING_FUNCTION, embedding_function=rag_app.state.EMBEDDING_FUNCTION,
k=rag_app.state.config.TOP_K, k=rag_app.state.config.TOP_K,
reranking_function=rag_app.state.sentence_transformer_rf, reranking_function=rag_app.state.sentence_transformer_rf,
r=rag_app.state.config.RELEVANCE_THRESHOLD, r=rag_app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH, hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
) )
if rag_context:
context += ("\n" if context != "" else "") + rag_context
log.debug(f"rag_context: {rag_context}, citations: {citations}") log.debug(f"rag_contexts: {contexts}, citations: {citations}")
if rag_citations: return body, {
citations.extend(rag_citations) **({"contexts": contexts} if contexts is not None else {}),
**({"citations": citations} if citations is not None else {}),
}
del data["files"]
if show_citations and len(citations) > 0: async def get_body_and_model_and_user(request):
# Read the original request body
body = await request.body()
body_str = body.decode("utf-8")
body = json.loads(body_str) if body_str else {}
model_id = body["model"]
if model_id not in app.state.MODELS:
raise "Model not found"
model = app.state.MODELS[model_id]
user = get_current_user(
request,
get_http_authorization_cred(request.headers.get("Authorization")),
)
return body, model, user
class ChatCompletionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
if request.method == "POST" and any(
endpoint in request.url.path
for endpoint in ["/ollama/api/chat", "/chat/completions"]
):
log.debug(f"request.url.path: {request.url.path}")
try:
body, model, user = await get_body_and_model_and_user(request)
except Exception as e:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
# Extract chat_id and message_id from the request body
chat_id = None
if "chat_id" in body:
chat_id = body["chat_id"]
del body["chat_id"]
message_id = None
if "id" in body:
message_id = body["id"]
del body["id"]
# Initialize data_items to store additional data to be sent to the client
data_items = []
# Initialize context, and citations
contexts = []
citations = []
print(body)
try:
body, flags = await chat_completion_functions_handler(body, model, user)
except Exception as e:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
try:
body, flags = await chat_completion_tools_handler(body, model, user)
contexts.extend(flags.get("contexts", []))
citations.extend(flags.get("citations", []))
except Exception as e:
print(e)
pass
try:
body, flags = await chat_completion_files_handler(body)
contexts.extend(flags.get("contexts", []))
citations.extend(flags.get("citations", []))
except Exception as e:
print(e)
pass
# If context is not empty, insert it into the messages
if len(contexts) > 0:
context_string = "/n".join(contexts).strip()
prompt = get_last_user_message(body["messages"])
body["messages"] = add_or_update_system_message(
rag_template(
rag_app.state.config.RAG_TEMPLATE, context_string, prompt
),
body["messages"],
)
# If there are citations, add them to the data_items
if len(citations) > 0:
data_items.append({"citations": citations}) data_items.append({"citations": citations})
if context != "": modified_body_bytes = json.dumps(body).encode("utf-8")
system_prompt = rag_template(
rag_app.state.config.RAG_TEMPLATE, context, prompt
)
print(system_prompt)
data["messages"] = add_or_update_system_message(
system_prompt, data["messages"]
)
modified_body_bytes = json.dumps(data).encode("utf-8")
# Replace the request body with the modified one # Replace the request body with the modified one
request._body = modified_body_bytes request._body = modified_body_bytes
# Set custom header to ensure content-length matches new body length # Set custom header to ensure content-length matches new body length
@ -721,9 +796,6 @@ def filter_pipeline(payload, user):
pass pass
if "pipeline" not in app.state.MODELS[model_id]: if "pipeline" not in app.state.MODELS[model_id]:
if "chat_id" in payload:
del payload["chat_id"]
if "title" in payload: if "title" in payload:
del payload["title"] del payload["title"]
@ -1225,6 +1297,9 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
content={"detail": e.args[1]}, content={"detail": e.args[1]},
) )
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user) return await generate_chat_completions(form_data=payload, user=user)
@ -1285,6 +1360,9 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
content={"detail": e.args[1]}, content={"detail": e.args[1]},
) )
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user) return await generate_chat_completions(form_data=payload, user=user)
@ -1349,6 +1427,9 @@ Message: """{{prompt}}"""
content={"detail": e.args[1]}, content={"detail": e.args[1]},
) )
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user) return await generate_chat_completions(form_data=payload, user=user)

View File

@ -665,6 +665,7 @@
await tick(); await tick();
const [res, controller] = await generateChatCompletion(localStorage.token, { const [res, controller] = await generateChatCompletion(localStorage.token, {
stream: true,
model: model.id, model: model.id,
messages: messagesBody, messages: messagesBody,
options: { options: {
@ -682,8 +683,8 @@
keep_alive: $settings.keepAlive ?? undefined, keep_alive: $settings.keepAlive ?? undefined,
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
files: files.length > 0 ? files : undefined, files: files.length > 0 ? files : undefined,
citations: files.length > 0 ? true : undefined, chat_id: $chatId,
chat_id: $chatId id: responseMessageId
}); });
if (res && res.ok) { if (res && res.ok) {
@ -912,8 +913,8 @@
const [res, controller] = await generateOpenAIChatCompletion( const [res, controller] = await generateOpenAIChatCompletion(
localStorage.token, localStorage.token,
{ {
model: model.id,
stream: true, stream: true,
model: model.id,
stream_options: stream_options:
model.info?.meta?.capabilities?.usage ?? false model.info?.meta?.capabilities?.usage ?? false
? { ? {
@ -983,9 +984,8 @@
max_tokens: $settings?.params?.max_tokens ?? undefined, max_tokens: $settings?.params?.max_tokens ?? undefined,
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined, tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
files: files.length > 0 ? files : undefined, files: files.length > 0 ? files : undefined,
citations: files.length > 0 ? true : undefined, chat_id: $chatId,
id: responseMessageId
chat_id: $chatId
}, },
`${WEBUI_BASE_URL}/api` `${WEBUI_BASE_URL}/api`
); );