From c7a9b5ccfab963803916e63df68058a3938d511b Mon Sep 17 00:00:00 2001
From: "Timothy J. Baek"
Date: Mon, 1 Jul 2024 19:33:58 -0700
Subject: [PATCH] refac: chat completion middleware

---
 backend/apps/rag/utils.py           |  14 +-
 backend/main.py                     | 501 ++++++++++++++++------------
 src/lib/components/chat/Chat.svelte |  12 +-
 3 files changed, 304 insertions(+), 223 deletions(-)

diff --git a/backend/apps/rag/utils.py b/backend/apps/rag/utils.py
index 3a3dad4a2..fde89b069 100644
--- a/backend/apps/rag/utils.py
+++ b/backend/apps/rag/utils.py
@@ -294,14 +294,16 @@ def get_rag_context(
 
         extracted_collections.extend(collection_names)
 
-    context_string = ""
-
+    contexts = []
     citations = []
+
     for context in relevant_contexts:
         try:
             if "documents" in context:
-                context_string += "\n\n".join(
-                    [text for text in context["documents"][0] if text is not None]
+                contexts.append(
+                    "\n\n".join(
+                        [text for text in context["documents"][0] if text is not None]
+                    )
                 )
 
             if "metadatas" in context:
@@ -315,9 +317,7 @@ def get_rag_context(
         except Exception as e:
             log.exception(e)
 
-    context_string = context_string.strip()
-
-    return context_string, citations
+    return contexts, citations
 
 
 def get_model_path(model: str, update_model: bool = False):
diff --git a/backend/main.py b/backend/main.py
index e1172f026..00e08676b 100644
--- a/backend/main.py
+++ b/backend/main.py
@@ -213,7 +213,7 @@ origins = ["*"]
 
 
 async def get_function_call_response(
-    messages, files, tool_id, template, task_model_id, user
+    messages, files, tool_id, template, task_model_id, user, model
 ):
     tool = Tools.get_tool_by_id(tool_id)
     tools_specs = json.dumps(tool.specs, indent=2)
@@ -373,233 +373,308 @@ async def get_function_call_response(
     return None, None, False
 
 
+def get_task_model_id(default_model_id):
+    # Set the task model
+    task_model_id = default_model_id
+    # Check if the user has a custom task model and use that model
+    if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
+        if (
+            app.state.config.TASK_MODEL
+            and app.state.config.TASK_MODEL in app.state.MODELS
+        ):
+            task_model_id = app.state.config.TASK_MODEL
+    else:
+        if (
+            app.state.config.TASK_MODEL_EXTERNAL
+            and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
+        ):
+            task_model_id = app.state.config.TASK_MODEL_EXTERNAL
+
+    return task_model_id
+
+
+def get_filter_function_ids(model):
+    def get_priority(function_id):
+        function = Functions.get_function_by_id(function_id)
+        if function is not None and hasattr(function, "valves"):
+            return (function.valves if function.valves else {}).get("priority", 0)
+        return 0
+
+    filter_ids = [function.id for function in Functions.get_global_filter_functions()]
+    if "info" in model and "meta" in model["info"]:
+        filter_ids.extend(model["info"]["meta"].get("filterIds", []))
+    filter_ids = list(set(filter_ids))
+
+    enabled_filter_ids = [
+        function.id
+        for function in Functions.get_functions_by_type("filter", active_only=True)
+    ]
+
+    filter_ids = [
+        filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
+    ]
+
+    filter_ids.sort(key=get_priority)
+    return filter_ids
+
+
+async def chat_completion_functions_handler(body, model, user):
+    skip_files = None
+
+    filter_ids = get_filter_function_ids(model)
+    for filter_id in filter_ids:
+        filter = Functions.get_function_by_id(filter_id)
+        if filter:
+            if filter_id in webui_app.state.FUNCTIONS:
+                function_module = webui_app.state.FUNCTIONS[filter_id]
+            else:
+                function_module, function_type, frontmatter = (
+                    load_function_module_by_id(filter_id)
+                )
+                webui_app.state.FUNCTIONS[filter_id] = function_module
+
+            # Check if the function has a file_handler variable
+            if hasattr(function_module, "file_handler"):
+                skip_files = function_module.file_handler
+
+            if hasattr(function_module, "valves") and hasattr(
+                function_module, "Valves"
+            ):
+                valves = Functions.get_function_valves_by_id(filter_id)
+                function_module.valves = function_module.Valves(
+                    **(valves if valves else {})
+                )
+
+            try:
+                if hasattr(function_module, "inlet"):
+                    inlet = function_module.inlet
+
+                    # Get the signature of the function
+                    sig = inspect.signature(inlet)
+                    params = {"body": body}
+
+                    if "__user__" in sig.parameters:
+                        __user__ = {
+                            "id": user.id,
+                            "email": user.email,
+                            "name": user.name,
+                            "role": user.role,
+                        }
+
+                        try:
+                            if hasattr(function_module, "UserValves"):
+                                __user__["valves"] = function_module.UserValves(
+                                    **Functions.get_user_valves_by_id_and_user_id(
+                                        filter_id, user.id
+                                    )
+                                )
+                        except Exception as e:
+                            print(e)
+
+                        params = {**params, "__user__": __user__}
+
+                    if "__id__" in sig.parameters:
+                        params = {
+                            **params,
+                            "__id__": filter_id,
+                        }
+
+                    if "__model__" in sig.parameters:
+                        params = {
+                            **params,
+                            "__model__": model,
+                        }
+
+                    if inspect.iscoroutinefunction(inlet):
+                        body = await inlet(**params)
+                    else:
+                        body = inlet(**params)
+
+            except Exception as e:
+                print(f"Error: {e}")
+                raise e
+
+    if skip_files:
+        if "files" in body:
+            del body["files"]
+
+    return body, {}
+
+
+async def chat_completion_tools_handler(body, model, user):
+    skip_files = None
+
+    contexts = []
+    citations = None
+
+    task_model_id = get_task_model_id(body["model"])
+
+    # If tool_ids field is present, call the functions
+    if "tool_ids" in body:
+        print(body["tool_ids"])
+        for tool_id in body["tool_ids"]:
+            print(tool_id)
+            try:
+                response, citation, file_handler = await get_function_call_response(
+                    messages=body["messages"],
+                    files=body.get("files", []),
+                    tool_id=tool_id,
+                    template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
+                    task_model_id=task_model_id,
+                    user=user,
+                    model=model,
+                )
+
+                print(file_handler)
+                if isinstance(response, str):
+                    contexts.append(response)
+
+                if citation:
+                    if citations is None:
+                        citations = [citation]
+                    else:
+                        citations.append(citation)
+
+                if file_handler:
+                    skip_files = True
+
+            except Exception as e:
+                print(f"Error: {e}")
+        del body["tool_ids"]
+        print(f"tool_contexts: {contexts}")
+
+    if skip_files:
+        if "files" in body:
+            del body["files"]
+
+    return body, {
+        **({"contexts": contexts} if contexts is not None else {}),
+        **({"citations": citations} if citations is not None else {}),
+    }
+
+
+async def chat_completion_files_handler(body):
+    contexts = []
+    citations = None
+
+    if "files" in body:
+        files = body["files"]
+        del body["files"]
+
+        contexts, citations = get_rag_context(
+            files=files,
+            messages=body["messages"],
+            embedding_function=rag_app.state.EMBEDDING_FUNCTION,
+            k=rag_app.state.config.TOP_K,
+            reranking_function=rag_app.state.sentence_transformer_rf,
+            r=rag_app.state.config.RELEVANCE_THRESHOLD,
+            hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
+        )
+
+        log.debug(f"rag_contexts: {contexts}, citations: {citations}")
+
+    return body, {
+        **({"contexts": contexts} if contexts is not None else {}),
+        **({"citations": citations} if citations is not None else {}),
+    }
+
+
+async def get_body_and_model_and_user(request):
+    # Read the original request body
+    body = await request.body()
+    body_str = body.decode("utf-8")
+    body = json.loads(body_str) if body_str else {}
+
+    model_id = body["model"]
+    if model_id not in app.state.MODELS:
+        raise Exception("Model not found")
+    model = app.state.MODELS[model_id]
+
+    user = get_current_user(
+        request,
+        get_http_authorization_cred(request.headers.get("Authorization")),
+    )
+
+    return body, model, user
+
+
 class ChatCompletionMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
-        data_items = []
-
-        show_citations = False
-        citations = []
-
         if request.method == "POST" and any(
             endpoint in request.url.path
             for endpoint in ["/ollama/api/chat", "/chat/completions"]
         ):
             log.debug(f"request.url.path: {request.url.path}")
 
-            # Read the original request body
-            body = await request.body()
-            body_str = body.decode("utf-8")
-            data = json.loads(body_str) if body_str else {}
-
-            user = get_current_user(
-                request,
-                get_http_authorization_cred(request.headers.get("Authorization")),
-            )
-
-            # Flag to skip RAG completions if file_handler is present in tools/functions
-            skip_files = False
-            if data.get("citations"):
-                show_citations = True
-                del data["citations"]
-
-            model_id = data["model"]
-            if model_id not in app.state.MODELS:
-                raise HTTPException(
-                    status_code=status.HTTP_404_NOT_FOUND,
-                    detail="Model not found",
+            try:
+                body, model, user = await get_body_and_model_and_user(request)
+            except Exception as e:
+                return JSONResponse(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    content={"detail": str(e)},
                 )
-            model = app.state.MODELS[model_id]
 
-            def get_priority(function_id):
-                function = Functions.get_function_by_id(function_id)
-                if function is not None and hasattr(function, "valves"):
-                    return (function.valves if function.valves else {}).get(
-                        "priority", 0
-                    )
-                return 0
+            # Extract chat_id and message_id from the request body
+            chat_id = None
+            if "chat_id" in body:
+                chat_id = body["chat_id"]
+                del body["chat_id"]
+            message_id = None
+            if "id" in body:
+                message_id = body["id"]
+                del body["id"]
 
-            filter_ids = [
-                function.id for function in Functions.get_global_filter_functions()
-            ]
-            if "info" in model and "meta" in model["info"]:
-                filter_ids.extend(model["info"]["meta"].get("filterIds", []))
-            filter_ids = list(set(filter_ids))
+            # Initialize data_items to store additional data to be sent to the client
+            data_items = []
 
-            enabled_filter_ids = [
-                function.id
-                for function in Functions.get_functions_by_type(
-                    "filter", active_only=True
+            # Initialize contexts and citations
+            contexts = []
+            citations = []
+
+            print(body)
+
+            try:
+                body, flags = await chat_completion_functions_handler(body, model, user)
+            except Exception as e:
+                return JSONResponse(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    content={"detail": str(e)},
                 )
-            ]
 
-            filter_ids = [
-                filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
-            ]
-            filter_ids.sort(key=get_priority)
+            try:
+                body, flags = await chat_completion_tools_handler(body, model, user)
 
-            for filter_id in filter_ids:
-                filter = Functions.get_function_by_id(filter_id)
-                if filter:
-                    if filter_id in webui_app.state.FUNCTIONS:
-                        function_module = webui_app.state.FUNCTIONS[filter_id]
-                    else:
-                        function_module, function_type, frontmatter = (
-                            load_function_module_by_id(filter_id)
-                        )
-                        webui_app.state.FUNCTIONS[filter_id] = function_module
+                contexts.extend(flags.get("contexts", []))
+                citations.extend(flags.get("citations", []))
+            except Exception as e:
+                print(e)
+                pass
 
-                    # Check if the function has a file_handler variable
-                    if hasattr(function_module, "file_handler"):
-                        skip_files = function_module.file_handler
+            try:
+                body, flags = await chat_completion_files_handler(body)
 
-                    if hasattr(function_module, "valves") and hasattr(
-                        function_module, "Valves"
-                    ):
-                        valves = Functions.get_function_valves_by_id(filter_id)
-                        function_module.valves = function_module.Valves(
-                            **(valves if valves else {})
-                        )
+                contexts.extend(flags.get("contexts", []))
+                citations.extend(flags.get("citations", []))
+            except Exception as e:
+                print(e)
+                pass
 
-                    try:
-                        if hasattr(function_module, "inlet"):
-                            inlet = function_module.inlet
+            # If context is not empty, insert it into the messages
+            if len(contexts) > 0:
+                context_string = "\n".join(contexts).strip()
+                prompt = get_last_user_message(body["messages"])
+                body["messages"] = add_or_update_system_message(
+                    rag_template(
+                        rag_app.state.config.RAG_TEMPLATE, context_string, prompt
+                    ),
+                    body["messages"],
+                )
 
-                            # Get the signature of the function
-                            sig = inspect.signature(inlet)
-                            params = {"body": data}
-
-                            if "__user__" in sig.parameters:
-                                __user__ = {
-                                    "id": user.id,
-                                    "email": user.email,
-                                    "name": user.name,
-                                    "role": user.role,
-                                }
-
-                                try:
-                                    if hasattr(function_module, "UserValves"):
-                                        __user__["valves"] = function_module.UserValves(
-                                            **Functions.get_user_valves_by_id_and_user_id(
-                                                filter_id, user.id
-                                            )
-                                        )
-                                except Exception as e:
-                                    print(e)
-
-                                params = {**params, "__user__": __user__}
-
-                            if "__id__" in sig.parameters:
-                                params = {
-                                    **params,
-                                    "__id__": filter_id,
-                                }
-
-                            if "__model__" in sig.parameters:
-                                params = {
-                                    **params,
-                                    "__model__": model,
-                                }
-
-                            if inspect.iscoroutinefunction(inlet):
-                                data = await inlet(**params)
-                            else:
-                                data = inlet(**params)
-
-                    except Exception as e:
-                        print(f"Error: {e}")
-                        return JSONResponse(
-                            status_code=status.HTTP_400_BAD_REQUEST,
-                            content={"detail": str(e)},
-                        )
-
-            # Set the task model
-            task_model_id = data["model"]
-            # Check if the user has a custom task model and use that model
-            if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
-                if (
-                    app.state.config.TASK_MODEL
-                    and app.state.config.TASK_MODEL in app.state.MODELS
-                ):
-                    task_model_id = app.state.config.TASK_MODEL
-            else:
-                if (
-                    app.state.config.TASK_MODEL_EXTERNAL
-                    and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
-                ):
-                    task_model_id = app.state.config.TASK_MODEL_EXTERNAL
-
-            prompt = get_last_user_message(data["messages"])
-            context = ""
-
-            # If tool_ids field is present, call the functions
-            if "tool_ids" in data:
-                print(data["tool_ids"])
-                for tool_id in data["tool_ids"]:
-                    print(tool_id)
-                    try:
-                        response, citation, file_handler = (
-                            await get_function_call_response(
-                                messages=data["messages"],
-                                files=data.get("files", []),
-                                tool_id=tool_id,
-                                template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
-                                task_model_id=task_model_id,
-                                user=user,
-                            )
-                        )
-
-                        print(file_handler)
-                        if isinstance(response, str):
-                            context += ("\n" if context != "" else "") + response
-
-                        if citation:
-                            citations.append(citation)
-                            show_citations = True
-
-                        if file_handler:
-                            skip_files = True
-
-                    except Exception as e:
-                        print(f"Error: {e}")
-                del data["tool_ids"]
-
-                print(f"tool_context: {context}")
-
-            # If files field is present, generate RAG completions
-            # If skip_files is True, skip the RAG completions
-            if "files" in data:
-                if not skip_files:
-                    data = {**data}
-                    rag_context, rag_citations = get_rag_context(
-                        files=data["files"],
-                        messages=data["messages"],
-                        embedding_function=rag_app.state.EMBEDDING_FUNCTION,
-                        k=rag_app.state.config.TOP_K,
-                        reranking_function=rag_app.state.sentence_transformer_rf,
-                        r=rag_app.state.config.RELEVANCE_THRESHOLD,
-                        hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
-                    )
-                    if rag_context:
-                        context += ("\n" if context != "" else "") + rag_context
-
-                    log.debug(f"rag_context: {rag_context}, citations: {citations}")
-
-                    if rag_citations:
-                        citations.extend(rag_citations)
-
-                del data["files"]
-
-            if show_citations and len(citations) > 0:
+            # If there are citations, add them to the data_items
+            if len(citations) > 0:
                 data_items.append({"citations": citations})
 
-            if context != "":
-                system_prompt = rag_template(
-                    rag_app.state.config.RAG_TEMPLATE, context, prompt
-                )
-                print(system_prompt)
-                data["messages"] = add_or_update_system_message(
-                    system_prompt, data["messages"]
-                )
-
-            modified_body_bytes = json.dumps(data).encode("utf-8")
+            modified_body_bytes = json.dumps(body).encode("utf-8")
             # Replace the request body with the modified one
             request._body = modified_body_bytes
             # Set custom header to ensure content-length matches new body length
@@ -721,9 +796,6 @@ def filter_pipeline(payload, user):
             pass
 
     if "pipeline" not in app.state.MODELS[model_id]:
-        if "chat_id" in payload:
-            del payload["chat_id"]
-
         if "title" in payload:
             del payload["title"]
 
@@ -1225,6 +1297,9 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
             content={"detail": e.args[1]},
         )
 
+    if "chat_id" in payload:
+        del payload["chat_id"]
+
     return await generate_chat_completions(form_data=payload, user=user)
 
 
@@ -1285,6 +1360,9 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
             content={"detail": e.args[1]},
         )
 
+    if "chat_id" in payload:
+        del payload["chat_id"]
+
     return await generate_chat_completions(form_data=payload, user=user)
 
 
@@ -1349,6 +1427,9 @@ Message: """{{prompt}}"""
             content={"detail": e.args[1]},
         )
 
+    if "chat_id" in payload:
+        del payload["chat_id"]
+
     return await generate_chat_completions(form_data=payload, user=user)
 
 
diff --git a/src/lib/components/chat/Chat.svelte b/src/lib/components/chat/Chat.svelte
index 1f8fd9827..056432d42 100644
--- a/src/lib/components/chat/Chat.svelte
+++ b/src/lib/components/chat/Chat.svelte
@@ -665,6 +665,7 @@
 		await tick();
 
 		const [res, controller] = await generateChatCompletion(localStorage.token, {
+			stream: true,
 			model: model.id,
 			messages: messagesBody,
 			options: {
@@ -682,8 +683,8 @@
 			keep_alive: $settings.keepAlive ?? undefined,
 			tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
 			files: files.length > 0 ? files : undefined,
-			citations: files.length > 0 ? true : undefined,
-			chat_id: $chatId
+			chat_id: $chatId,
+			id: responseMessageId
 		});
 
 		if (res && res.ok) {
@@ -912,8 +913,8 @@
 		const [res, controller] = await generateOpenAIChatCompletion(
 			localStorage.token,
 			{
-				model: model.id,
 				stream: true,
+				model: model.id,
 				stream_options:
 					model.info?.meta?.capabilities?.usage ?? false
 						? {
@@ -983,9 +984,8 @@
 				max_tokens: $settings?.params?.max_tokens ?? undefined,
 				tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
 				files: files.length > 0 ? files : undefined,
-				citations: files.length > 0 ? true : undefined,
-
-				chat_id: $chatId
+				chat_id: $chatId,
+				id: responseMessageId
 			},
 			`${WEBUI_BASE_URL}/api`
 		);