refac: tools & rag

This commit is contained in:
Timothy J. Baek 2024-06-11 01:10:24 -07:00
parent fc46532955
commit bd5a8567ef
2 changed files with 27 additions and 35 deletions

View File

@ -236,10 +236,9 @@ def get_embedding_function(
return lambda query: generate_multiple(query, func) return lambda query: generate_multiple(query, func)
def rag_messages( def get_rag_context(
docs, docs,
messages, messages,
template,
embedding_function, embedding_function,
k, k,
reranking_function, reranking_function,
@ -318,16 +317,7 @@ def rag_messages(
context_string = context_string.strip() context_string = context_string.strip()
ra_content = rag_template( return context_string, citations
template=template,
context=context_string,
query=query,
)
log.debug(f"ra_content: {ra_content}")
messages = add_or_update_system_message(ra_content, messages)
return messages, citations
def get_model_path(model: str, update_model: bool = False): def get_model_path(model: str, update_model: bool = False):

View File

@ -64,7 +64,7 @@ from utils.task import (
) )
from utils.misc import get_last_user_message, add_or_update_system_message from utils.misc import get_last_user_message, add_or_update_system_message
from apps.rag.utils import rag_messages, rag_template from apps.rag.utils import get_rag_context, rag_template
from config import ( from config import (
CONFIG_DATA, CONFIG_DATA,
@ -248,6 +248,10 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
# Parse string to JSON # Parse string to JSON
data = json.loads(body_str) if body_str else {} data = json.loads(body_str) if body_str else {}
user = get_current_user(
get_http_authorization_cred(request.headers.get("Authorization"))
)
# Remove the citations from the body # Remove the citations from the body
return_citations = data.get("citations", False) return_citations = data.get("citations", False)
if "citations" in data: if "citations" in data:
@ -276,13 +280,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
): ):
task_model_id = app.state.config.TASK_MODEL_EXTERNAL task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if "tool_ids" in data:
user = get_current_user(
get_http_authorization_cred(request.headers.get("Authorization"))
)
prompt = get_last_user_message(data["messages"])
context = "" context = ""
# If tool_ids field is present, call the functions
if "tool_ids" in data:
prompt = get_last_user_message(data["messages"])
for tool_id in data["tool_ids"]: for tool_id in data["tool_ids"]:
print(tool_id) print(tool_id)
response = await get_function_call_response( response = await get_function_call_response(
@ -295,6 +297,27 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
if response: if response:
context += ("\n" if context != "" else "") + response context += ("\n" if context != "" else "") + response
del data["tool_ids"]
# If docs field is present, generate RAG completions
if "docs" in data:
data = {**data}
rag_context, citations = get_rag_context(
docs=data["docs"],
messages=data["messages"],
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
k=rag_app.state.config.TOP_K,
reranking_function=rag_app.state.sentence_transformer_rf,
r=rag_app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
)
if rag_context:
context += ("\n" if context != "" else "") + rag_context
del data["docs"]
log.debug(f"rag_context: {rag_context}, citations: {citations}")
if context != "": if context != "":
system_prompt = rag_template( system_prompt = rag_template(
@ -307,27 +330,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
f"\n{system_prompt}", data["messages"] f"\n{system_prompt}", data["messages"]
) )
del data["tool_ids"]
# If docs field is present, generate RAG completions
if "docs" in data:
data = {**data}
data["messages"], citations = rag_messages(
docs=data["docs"],
messages=data["messages"],
template=rag_app.state.config.RAG_TEMPLATE,
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
k=rag_app.state.config.TOP_K,
reranking_function=rag_app.state.sentence_transformer_rf,
r=rag_app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
)
del data["docs"]
log.debug(
f"data['messages']: {data['messages']}, citations: {citations}"
)
modified_body_bytes = json.dumps(data).encode("utf-8") modified_body_bytes = json.dumps(data).encode("utf-8")
# Replace the request body with the modified one # Replace the request body with the modified one