refac: chat completion middleware

Timothy J. Baek 2024-07-01 19:33:58 -07:00
parent b62d2a9b28
commit c7a9b5ccfa
3 changed files with 304 additions and 223 deletions

View File

@@ -294,14 +294,16 @@ def get_rag_context(
extracted_collections.extend(collection_names)
context_string = ""
contexts = []
citations = []
for context in relevant_contexts:
try:
if "documents" in context:
context_string += "\n\n".join(
[text for text in context["documents"][0] if text is not None]
contexts.append(
"\n\n".join(
[text for text in context["documents"][0] if text is not None]
)
)
if "metadatas" in context:
@@ -315,9 +317,7 @@ def get_rag_context(
except Exception as e:
log.exception(e)
context_string = context_string.strip()
return context_string, citations
return contexts, citations
def get_model_path(model: str, update_model: bool = False):
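
The return contract of get_rag_context changes here: instead of one pre-joined string, it now hands back the raw list of per-source context strings and lets the caller decide how to join them. A minimal sketch of the new call site, mirroring the middleware's own call below (file and message values are hypothetical):

# Hypothetical call site for the new contract; values are illustrative.
contexts, citations = get_rag_context(
    files=[{"collection_name": "docs"}],
    messages=[{"role": "user", "content": "What does the doc say?"}],
    embedding_function=rag_app.state.EMBEDDING_FUNCTION,
    k=rag_app.state.config.TOP_K,
    reranking_function=rag_app.state.sentence_transformer_rf,
    r=rag_app.state.config.RELEVANCE_THRESHOLD,
    hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
)
# Joining is now the caller's responsibility (the middleware does this below).
context_string = "\n".join(contexts).strip()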

View File

@@ -213,7 +213,7 @@ origins = ["*"]
async def get_function_call_response(
messages, files, tool_id, template, task_model_id, user
messages, files, tool_id, template, task_model_id, user, model
):
tool = Tools.get_tool_by_id(tool_id)
tools_specs = json.dumps(tool.specs, indent=2)
@@ -373,233 +373,308 @@ async def get_function_call_response(
return None, None, False
def get_task_model_id(default_model_id):
# Set the task model
task_model_id = default_model_id
# Check if the user has a custom task model and use that model
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
if (
app.state.config.TASK_MODEL
and app.state.config.TASK_MODEL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL
else:
if (
app.state.config.TASK_MODEL_EXTERNAL
and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
return task_model_id
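
get_task_model_id pulls the task-model fallback out of the middleware body. A condensed, self-contained restatement of the selection rule (model names and config values below are made up):

# Self-contained restatement of the rule; MODELS/config values are made up.
MODELS = {"llama3": {"owned_by": "ollama"}, "gpt-4o": {"owned_by": "openai"}}
TASK_MODEL = None           # stands in for app.state.config.TASK_MODEL
TASK_MODEL_EXTERNAL = None  # stands in for app.state.config.TASK_MODEL_EXTERNAL

def pick_task_model(default_model_id):
    configured = (
        TASK_MODEL
        if MODELS[default_model_id]["owned_by"] == "ollama"
        else TASK_MODEL_EXTERNAL
    )
    # Use the configured task model only if it is actually registered.
    return configured if configured in MODELS else default_model_id

assert pick_task_model("llama3") == "llama3"  # nothing configured -> use the default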
def get_filter_function_ids(model):
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"):
return (function.valves if function.valves else {}).get("priority", 0)
return 0
filter_ids = [function.id for function in Functions.get_global_filter_functions()]
if "info" in model and "meta" in model["info"]:
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids))
enabled_filter_ids = [
function.id
for function in Functions.get_functions_by_type("filter", active_only=True)
]
filter_ids = [
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
]
filter_ids.sort(key=get_priority)
return filter_ids
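
The filter list is the union of global filters and the model's filterIds, restricted to active filter functions, then sorted by each filter's valves priority so that lower values run first. The sort in isolation, with toy priorities:

# Toy stand-ins for the priority lookup done via Functions.get_function_by_id.
priorities = {"rate_limit": 0, "rewrite": 5, "profanity": 10}
filter_ids = ["profanity", "rewrite", "rate_limit"]
filter_ids.sort(key=lambda fid: priorities.get(fid, 0))
assert filter_ids == ["rate_limit", "rewrite", "profanity"]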
async def chat_completion_functions_handler(body, model, user):
skip_files = None
filter_ids = get_filter_function_ids(model)
for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id)
if filter:
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, function_type, frontmatter = (
load_function_module_by_id(filter_id)
)
webui_app.state.FUNCTIONS[filter_id] = function_module
# Check if the function has a file_handler variable
if hasattr(function_module, "file_handler"):
skip_files = function_module.file_handler
if hasattr(function_module, "valves") and hasattr(
function_module, "Valves"
):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
try:
if hasattr(function_module, "inlet"):
inlet = function_module.inlet
# Get the signature of the function
sig = inspect.signature(inlet)
params = {"body": body}
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if "__id__" in sig.parameters:
params = {
**params,
"__id__": filter_id,
}
if "__model__" in sig.parameters:
params = {
**params,
"__model__": model,
}
if inspect.iscoroutinefunction(inlet):
body = await inlet(**params)
else:
body = inlet(**params)
except Exception as e:
print(f"Error: {e}")
raise e
if skip_files:
if "files" in body:
del body["files"]
return body, {}
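
Filter inlets are called with only the keyword arguments they declare; inspect.signature drives the dispatch, so older filters that accept just body keep working. The same pattern in isolation (the inlet below is a placeholder, not a real filter module):

import inspect

def inlet(body, __user__=None):  # placeholder for a loaded filter module's inlet
    body["handled_for"] = __user__["id"] if __user__ else None
    return body

params = {"body": {"messages": []}}
sig = inspect.signature(inlet)
if "__user__" in sig.parameters:
    params["__user__"] = {"id": "u1", "email": "u@example.com", "name": "U", "role": "user"}
body = inlet(**params)
assert body["handled_for"] == "u1"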
async def chat_completion_tools_handler(body, model, user):
skip_files = None
contexts = []
citations = None
task_model_id = get_task_model_id(body["model"])
# If tool_ids field is present, call the functions
if "tool_ids" in body:
print(body["tool_ids"])
for tool_id in body["tool_ids"]:
print(tool_id)
try:
response, citation, file_handler = await get_function_call_response(
messages=body["messages"],
files=body.get("files", []),
tool_id=tool_id,
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
task_model_id=task_model_id,
user=user,
model=model,
)
print(file_handler)
if isinstance(response, str):
contexts.append(response)
if citation:
if citations is None:
citations = [citation]
else:
citations.append(citation)
if file_handler:
skip_files = True
except Exception as e:
print(f"Error: {e}")
del body["tool_ids"]
print(f"tool_contexts: {contexts}")
if skip_files:
if "files" in body:
del body["files"]
return body, {
**({"contexts": contexts} if contexts is not None else {}),
**({"citations": citations} if citations is not None else {}),
}
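
Each handler now returns the (possibly modified) body plus a flags dict, and the middleware below accumulates whatever contexts and citations the handlers produced. The accumulation step on its own, with made-up flag values:

# Made-up flag dicts standing in for the three handlers' return values.
contexts, citations = [], []
for flags in (
    {},                             # functions handler: no context
    {"contexts": ["tool output"]},  # tools handler
    {"contexts": ["rag chunk"], "citations": [{"source": "doc.pdf"}]},  # files handler
):
    contexts.extend(flags.get("contexts", []))
    citations.extend(flags.get("citations", []))
assert contexts == ["tool output", "rag chunk"] and len(citations) == 1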
async def chat_completion_files_handler(body):
contexts = []
citations = None
if "files" in body:
files = body["files"]
del body["files"]
contexts, citations = get_rag_context(
files=files,
messages=body["messages"],
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
k=rag_app.state.config.TOP_K,
reranking_function=rag_app.state.sentence_transformer_rf,
r=rag_app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
)
log.debug(f"rag_contexts: {contexts}, citations: {citations}")
return body, {
**({"contexts": contexts} if contexts is not None else {}),
**({"citations": citations} if citations is not None else {}),
}
async def get_body_and_model_and_user(request):
# Read the original request body
body = await request.body()
body_str = body.decode("utf-8")
body = json.loads(body_str) if body_str else {}
model_id = body["model"]
if model_id not in app.state.MODELS:
raise "Model not found"
model = app.state.MODELS[model_id]
user = get_current_user(
request,
get_http_authorization_cred(request.headers.get("Authorization")),
)
return body, model, user
class ChatCompletionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
data_items = []
show_citations = False
citations = []
if request.method == "POST" and any(
endpoint in request.url.path
for endpoint in ["/ollama/api/chat", "/chat/completions"]
):
log.debug(f"request.url.path: {request.url.path}")
# Read the original request body
body = await request.body()
body_str = body.decode("utf-8")
data = json.loads(body_str) if body_str else {}
user = get_current_user(
request,
get_http_authorization_cred(request.headers.get("Authorization")),
)
# Flag to skip RAG completions if file_handler is present in tools/functions
skip_files = False
if data.get("citations"):
show_citations = True
del data["citations"]
model_id = data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
try:
body, model, user = await get_body_and_model_and_user(request)
except Exception as e:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
model = app.state.MODELS[model_id]
def get_priority(function_id):
function = Functions.get_function_by_id(function_id)
if function is not None and hasattr(function, "valves"):
return (function.valves if function.valves else {}).get(
"priority", 0
)
return 0
# Extract chat_id and message_id from the request body
chat_id = None
if "chat_id" in body:
chat_id = body["chat_id"]
del body["chat_id"]
message_id = None
if "id" in body:
message_id = body["id"]
del body["id"]
filter_ids = [
function.id for function in Functions.get_global_filter_functions()
]
if "info" in model and "meta" in model["info"]:
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
filter_ids = list(set(filter_ids))
# Initialize data_items to store additional data to be sent to the client
data_items = []
enabled_filter_ids = [
function.id
for function in Functions.get_functions_by_type(
"filter", active_only=True
# Initialize contexts and citations
contexts = []
citations = []
print(body)
try:
body, flags = await chat_completion_functions_handler(body, model, user)
except Exception as e:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
]
filter_ids = [
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
]
filter_ids.sort(key=get_priority)
for filter_id in filter_ids:
filter = Functions.get_function_by_id(filter_id)
if filter:
if filter_id in webui_app.state.FUNCTIONS:
function_module = webui_app.state.FUNCTIONS[filter_id]
else:
function_module, function_type, frontmatter = (
load_function_module_by_id(filter_id)
)
webui_app.state.FUNCTIONS[filter_id] = function_module
try:
body, flags = await chat_completion_tools_handler(body, model, user)
# Check if the function has a file_handler variable
if hasattr(function_module, "file_handler"):
skip_files = function_module.file_handler
contexts.extend(flags.get("contexts", []))
citations.extend(flags.get("citations", []))
except Exception as e:
print(e)
pass
if hasattr(function_module, "valves") and hasattr(
function_module, "Valves"
):
valves = Functions.get_function_valves_by_id(filter_id)
function_module.valves = function_module.Valves(
**(valves if valves else {})
)
try:
body, flags = await chat_completion_files_handler(body)
try:
if hasattr(function_module, "inlet"):
inlet = function_module.inlet
contexts.extend(flags.get("contexts", []))
citations.extend(flags.get("citations", []))
except Exception as e:
print(e)
pass
# Get the signature of the function
sig = inspect.signature(inlet)
params = {"body": data}
# If context is not empty, insert it into the messages
if len(contexts) > 0:
context_string = "/n".join(contexts).strip()
prompt = get_last_user_message(body["messages"])
body["messages"] = add_or_update_system_message(
rag_template(
rag_app.state.config.RAG_TEMPLATE, context_string, prompt
),
body["messages"],
)
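
Once all handlers have run, the joined context is rendered through the RAG template and folded into the system message. add_or_update_system_message itself is not shown in this diff; a rough sketch of the behavior its name implies (inferred, not the actual implementation):

# Inferred behavior only; the real helper lives elsewhere in the codebase.
def add_or_update_system_message(content, messages):
    if messages and messages[0].get("role") == "system":
        messages[0]["content"] = f"{content}\n{messages[0]['content']}"
    else:
        messages.insert(0, {"role": "system", "content": content})
    return messages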
if "__user__" in sig.parameters:
__user__ = {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
}
try:
if hasattr(function_module, "UserValves"):
__user__["valves"] = function_module.UserValves(
**Functions.get_user_valves_by_id_and_user_id(
filter_id, user.id
)
)
except Exception as e:
print(e)
params = {**params, "__user__": __user__}
if "__id__" in sig.parameters:
params = {
**params,
"__id__": filter_id,
}
if "__model__" in sig.parameters:
params = {
**params,
"__model__": model,
}
if inspect.iscoroutinefunction(inlet):
data = await inlet(**params)
else:
data = inlet(**params)
except Exception as e:
print(f"Error: {e}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={"detail": str(e)},
)
# Set the task model
task_model_id = data["model"]
# Check if the user has a custom task model and use that model
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
if (
app.state.config.TASK_MODEL
and app.state.config.TASK_MODEL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL
else:
if (
app.state.config.TASK_MODEL_EXTERNAL
and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
prompt = get_last_user_message(data["messages"])
context = ""
# If tool_ids field is present, call the functions
if "tool_ids" in data:
print(data["tool_ids"])
for tool_id in data["tool_ids"]:
print(tool_id)
try:
response, citation, file_handler = (
await get_function_call_response(
messages=data["messages"],
files=data.get("files", []),
tool_id=tool_id,
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
task_model_id=task_model_id,
user=user,
)
)
print(file_handler)
if isinstance(response, str):
context += ("\n" if context != "" else "") + response
if citation:
citations.append(citation)
show_citations = True
if file_handler:
skip_files = True
except Exception as e:
print(f"Error: {e}")
del data["tool_ids"]
print(f"tool_context: {context}")
# If files field is present, generate RAG completions
# If skip_files is True, skip the RAG completions
if "files" in data:
if not skip_files:
data = {**data}
rag_context, rag_citations = get_rag_context(
files=data["files"],
messages=data["messages"],
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
k=rag_app.state.config.TOP_K,
reranking_function=rag_app.state.sentence_transformer_rf,
r=rag_app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
)
if rag_context:
context += ("\n" if context != "" else "") + rag_context
log.debug(f"rag_context: {rag_context}, citations: {citations}")
if rag_citations:
citations.extend(rag_citations)
del data["files"]
if show_citations and len(citations) > 0:
# If there are citations, add them to the data_items
if len(citations) > 0:
data_items.append({"citations": citations})
if context != "":
system_prompt = rag_template(
rag_app.state.config.RAG_TEMPLATE, context, prompt
)
print(system_prompt)
data["messages"] = add_or_update_system_message(
system_prompt, data["messages"]
)
modified_body_bytes = json.dumps(data).encode("utf-8")
modified_body_bytes = json.dumps(body).encode("utf-8")
# Replace the request body with the modified one
request._body = modified_body_bytes
# Set custom header to ensure content-length matches new body length
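
Starlette caches the request body once read, so assigning request._body is enough for downstream handlers to see the rewritten payload, provided content-length is patched to match. The pattern as a minimal standalone middleware (assuming Starlette; the injected field is a stand-in for the mutations above):

import json
from starlette.middleware.base import BaseHTTPMiddleware

class RewriteBodyMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request, call_next):
        data = json.loads((await request.body()) or b"{}")
        data["injected"] = True  # stand-in for the real body mutations
        raw = json.dumps(data).encode("utf-8")
        request._body = raw  # Starlette hands this cached value to handlers
        # Keep content-length consistent with the rewritten body.
        request.headers.__dict__["_list"] = [
            (b"content-length", str(len(raw)).encode("utf-8")),
            *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
        ]
        return await call_next(request)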
@@ -721,9 +796,6 @@ def filter_pipeline(payload, user):
pass
if "pipeline" not in app.state.MODELS[model_id]:
if "chat_id" in payload:
del payload["chat_id"]
if "title" in payload:
del payload["title"]
@@ -1225,6 +1297,9 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
content={"detail": e.args[1]},
)
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user)
@@ -1285,6 +1360,9 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
content={"detail": e.args[1]},
)
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user)
@@ -1349,6 +1427,9 @@ Message: """{{prompt}}"""
content={"detail": e.args[1]},
)
if "chat_id" in payload:
del payload["chat_id"]
return await generate_chat_completions(form_data=payload, user=user)

View File

@@ -665,6 +665,7 @@
await tick();
const [res, controller] = await generateChatCompletion(localStorage.token, {
stream: true,
model: model.id,
messages: messagesBody,
options: {
@@ -682,8 +683,8 @@
keep_alive: $settings.keepAlive ?? undefined,
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
files: files.length > 0 ? files : undefined,
citations: files.length > 0 ? true : undefined,
chat_id: $chatId
chat_id: $chatId,
id: responseMessageId
});
if (res && res.ok) {
@@ -912,8 +913,8 @@
const [res, controller] = await generateOpenAIChatCompletion(
localStorage.token,
{
model: model.id,
stream: true,
model: model.id,
stream_options:
model.info?.meta?.capabilities?.usage ?? false
? {
@@ -983,9 +984,8 @@
max_tokens: $settings?.params?.max_tokens ?? undefined,
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
files: files.length > 0 ? files : undefined,
citations: files.length > 0 ? true : undefined,
chat_id: $chatId
chat_id: $chatId,
id: responseMessageId
},
`${WEBUI_BASE_URL}/api`
);
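
On the client, both the Ollama and OpenAI request builders now always send chat_id together with the response message id, and drop the citations flag: with the refactored middleware, citations are attached whenever tools or files actually produce them, so the client no longer needs to ask. The resulting request shape, illustrated in Python with hypothetical values:

# Illustrative request body after this change (values are hypothetical).
payload = {
    "stream": True,
    "model": "llama3",
    "messages": [{"role": "user", "content": "hi"}],
    "chat_id": "abc123",   # always sent now
    "id": "resp-msg-1",    # the response message id
    # "citations" flag removed; the server decides based on tool/file output
}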