def prepend_to_first_user_message_content(
    content: str, messages: List[dict]
) -> List[dict]:
    """Prepend ``content`` to the first user message found in ``messages``.

    Only the first message whose ``role`` is ``"user"`` is touched; later
    user messages and all non-user messages are left alone.

    Two content layouts are handled:

    * plain content: the message content becomes
      ``f"{content}\\n{original}"``;
    * list content (multi-part messages): ``content`` is prepended to the
      first part whose ``type`` is ``"text"``. If the list holds no text
      part at all, a new ``{"type": "text", "text": content}`` part is
      inserted at the front so the injected text is not silently lost.

    Args:
        content: Text to place in front of the user's message.
        messages: Chat messages as dicts with ``role`` and ``content`` keys.

    Returns:
        The same ``messages`` list object. The affected message is modified
        in place; if there is no user message the list is returned untouched.
    """
    for message in messages:
        if message.get("role") != "user":
            continue

        body = message["content"]
        if isinstance(body, list):
            for part in body:
                if isinstance(part, dict) and part.get("type") == "text":
                    part["text"] = f"{content}\n{part['text']}"
                    # Prepend once only: repeating the text in every text
                    # part would duplicate it in the final prompt.
                    break
            else:
                # No text part (e.g. an image-only message): add one rather
                # than dropping the injected content.
                body.insert(0, {"type": "text", "text": content})
        else:
            message["content"] = f"{content}\n{body}"

        # Only the first user message is ever modified.
        break

    return messages