refac: tools & rag

This commit is contained in:
Timothy J. Baek 2024-06-11 01:10:24 -07:00
parent fc46532955
commit bd5a8567ef
2 changed files with 27 additions and 35 deletions

View File

@ -236,10 +236,9 @@ def get_embedding_function(
return lambda query: generate_multiple(query, func)
def rag_messages(
def get_rag_context(
docs,
messages,
template,
embedding_function,
k,
reranking_function,
@ -318,16 +317,7 @@ def rag_messages(
context_string = context_string.strip()
ra_content = rag_template(
template=template,
context=context_string,
query=query,
)
log.debug(f"ra_content: {ra_content}")
messages = add_or_update_system_message(ra_content, messages)
return messages, citations
return context_string, citations
def get_model_path(model: str, update_model: bool = False):

View File

@ -64,7 +64,7 @@ from utils.task import (
)
from utils.misc import get_last_user_message, add_or_update_system_message
from apps.rag.utils import rag_messages, rag_template
from apps.rag.utils import get_rag_context, rag_template
from config import (
CONFIG_DATA,
@ -248,6 +248,10 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
# Parse string to JSON
data = json.loads(body_str) if body_str else {}
user = get_current_user(
get_http_authorization_cred(request.headers.get("Authorization"))
)
# Remove the citations from the body
return_citations = data.get("citations", False)
if "citations" in data:
@ -276,13 +280,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
):
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if "tool_ids" in data:
user = get_current_user(
get_http_authorization_cred(request.headers.get("Authorization"))
)
prompt = get_last_user_message(data["messages"])
context = ""
context = ""
# If tool_ids field is present, call the functions
if "tool_ids" in data:
prompt = get_last_user_message(data["messages"])
for tool_id in data["tool_ids"]:
print(tool_id)
response = await get_function_call_response(
@ -295,37 +297,37 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
if response:
context += ("\n" if context != "" else "") + response
if context != "":
system_prompt = rag_template(
rag_app.state.config.RAG_TEMPLATE, context, prompt
)
print(system_prompt)
data["messages"] = add_or_update_system_message(
f"\n{system_prompt}", data["messages"]
)
del data["tool_ids"]
# If docs field is present, generate RAG completions
if "docs" in data:
data = {**data}
data["messages"], citations = rag_messages(
rag_context, citations = get_rag_context(
docs=data["docs"],
messages=data["messages"],
template=rag_app.state.config.RAG_TEMPLATE,
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
k=rag_app.state.config.TOP_K,
reranking_function=rag_app.state.sentence_transformer_rf,
r=rag_app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
)
if rag_context:
context += ("\n" if context != "" else "") + rag_context
del data["docs"]
log.debug(
f"data['messages']: {data['messages']}, citations: {citations}"
log.debug(f"rag_context: {rag_context}, citations: {citations}")
if context != "":
system_prompt = rag_template(
rag_app.state.config.RAG_TEMPLATE, context, prompt
)
print(system_prompt)
data["messages"] = add_or_update_system_message(
f"\n{system_prompt}", data["messages"]
)
modified_body_bytes = json.dumps(data).encode("utf-8")