mirror of
https://github.com/open-webui/open-webui
synced 2025-03-22 04:58:27 +00:00
refac: chat completion middleware
This commit is contained in:
parent
b62d2a9b28
commit
c7a9b5ccfa
@ -294,14 +294,16 @@ def get_rag_context(
|
|||||||
|
|
||||||
extracted_collections.extend(collection_names)
|
extracted_collections.extend(collection_names)
|
||||||
|
|
||||||
context_string = ""
|
contexts = []
|
||||||
|
|
||||||
citations = []
|
citations = []
|
||||||
|
|
||||||
for context in relevant_contexts:
|
for context in relevant_contexts:
|
||||||
try:
|
try:
|
||||||
if "documents" in context:
|
if "documents" in context:
|
||||||
context_string += "\n\n".join(
|
contexts.append(
|
||||||
[text for text in context["documents"][0] if text is not None]
|
"\n\n".join(
|
||||||
|
[text for text in context["documents"][0] if text is not None]
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if "metadatas" in context:
|
if "metadatas" in context:
|
||||||
@ -315,9 +317,7 @@ def get_rag_context(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception(e)
|
log.exception(e)
|
||||||
|
|
||||||
context_string = context_string.strip()
|
return contexts, citations
|
||||||
|
|
||||||
return context_string, citations
|
|
||||||
|
|
||||||
|
|
||||||
def get_model_path(model: str, update_model: bool = False):
|
def get_model_path(model: str, update_model: bool = False):
|
||||||
|
501
backend/main.py
501
backend/main.py
@ -213,7 +213,7 @@ origins = ["*"]
|
|||||||
|
|
||||||
|
|
||||||
async def get_function_call_response(
|
async def get_function_call_response(
|
||||||
messages, files, tool_id, template, task_model_id, user
|
messages, files, tool_id, template, task_model_id, user, model
|
||||||
):
|
):
|
||||||
tool = Tools.get_tool_by_id(tool_id)
|
tool = Tools.get_tool_by_id(tool_id)
|
||||||
tools_specs = json.dumps(tool.specs, indent=2)
|
tools_specs = json.dumps(tool.specs, indent=2)
|
||||||
@ -373,233 +373,308 @@ async def get_function_call_response(
|
|||||||
return None, None, False
|
return None, None, False
|
||||||
|
|
||||||
|
|
||||||
|
def get_task_model_id(default_model_id):
    """Choose the model id used for auxiliary task calls.

    The result is ``default_model_id`` unless a task model is configured and
    present in ``app.state.MODELS``. Which setting is consulted depends on
    the owner of the default model: ``TASK_MODEL`` when it is owned by
    "ollama", ``TASK_MODEL_EXTERNAL`` otherwise.
    """
    models = app.state.MODELS
    config = app.state.config

    # Select the configured override that matches the default model's owner.
    if models[default_model_id]["owned_by"] == "ollama":
        override = config.TASK_MODEL
    else:
        override = config.TASK_MODEL_EXTERNAL

    # Only honor the override when it is set and refers to a known model.
    if override and override in models:
        return override
    return default_model_id
|
||||||
|
|
||||||
|
|
||||||
|
def get_filter_function_ids(model):
    """Return ids of the enabled filter functions that apply to ``model``.

    Candidates are the global filter functions plus any ids listed under
    ``model["info"]["meta"]["filterIds"]``. Duplicates are dropped, ids that
    are not among the active "filter" functions are removed, and the result
    is sorted by each function's ``priority`` valve (default 0).

    Filters sharing the same priority keep their first-seen order (global
    filters first, then model-specific ones), so the result is deterministic.
    """

    def get_priority(function_id):
        function = Functions.get_function_by_id(function_id)
        if function is not None and hasattr(function, "valves"):
            return (function.valves if function.valves else {}).get("priority", 0)
        return 0

    candidate_ids = [
        function.id for function in Functions.get_global_filter_functions()
    ]
    if "info" in model and "meta" in model["info"]:
        candidate_ids.extend(model["info"]["meta"].get("filterIds", []))

    # A set gives O(1) membership checks for the filtering step below.
    enabled_filter_ids = {
        function.id
        for function in Functions.get_functions_by_type("filter", active_only=True)
    }

    # dict.fromkeys de-duplicates while preserving insertion order; going
    # through set() would leave equal-priority filters in hash-dependent order
    # because the sort below is stable.
    filter_ids = [
        filter_id
        for filter_id in dict.fromkeys(candidate_ids)
        if filter_id in enabled_filter_ids
    ]

    filter_ids.sort(key=get_priority)
    return filter_ids
|
||||||
|
|
||||||
|
|
||||||
|
async def chat_completion_functions_handler(body, model, user):
    """Pass ``body`` through the ``inlet`` hook of each applicable filter.

    For every id returned by ``get_filter_function_ids(model)`` the function
    module is taken from ``webui_app.state.FUNCTIONS`` (loaded and cached on
    first use), its valves are refreshed, and its ``inlet`` (sync or async)
    is called with ``body`` plus whichever of ``__user__``, ``__id__`` and
    ``__model__`` its signature declares. Each inlet's return value becomes
    the new ``body``.

    If the last module exposing ``file_handler`` has a truthy value for it,
    ``body["files"]`` is removed before returning.

    Returns:
        A ``(body, flags)`` tuple; ``flags`` is always an empty dict here.

    Raises:
        Any exception raised while preparing or running an inlet is printed
        and re-raised.
    """
    skip_files = None

    for filter_id in get_filter_function_ids(model):
        if not Functions.get_function_by_id(filter_id):
            continue

        # Load the module on first use and keep it cached for later requests.
        if filter_id not in webui_app.state.FUNCTIONS:
            function_module, function_type, frontmatter = (
                load_function_module_by_id(filter_id)
            )
            webui_app.state.FUNCTIONS[filter_id] = function_module
        else:
            function_module = webui_app.state.FUNCTIONS[filter_id]

        # Check if the function has a file_handler variable
        if hasattr(function_module, "file_handler"):
            skip_files = function_module.file_handler

        has_valves = hasattr(function_module, "valves")
        if has_valves and hasattr(function_module, "Valves"):
            stored_valves = Functions.get_function_valves_by_id(filter_id)
            function_module.valves = function_module.Valves(**(stored_valves or {}))

        try:
            if hasattr(function_module, "inlet"):
                inlet = function_module.inlet

                # Only pass the optional arguments the inlet actually declares.
                accepted = inspect.signature(inlet).parameters
                params = {"body": body}

                if "__user__" in accepted:
                    user_info = {
                        "id": user.id,
                        "email": user.email,
                        "name": user.name,
                        "role": user.role,
                    }

                    # Failing to build user valves is reported but not fatal.
                    try:
                        if hasattr(function_module, "UserValves"):
                            user_valves = Functions.get_user_valves_by_id_and_user_id(
                                filter_id, user.id
                            )
                            user_info["valves"] = function_module.UserValves(
                                **user_valves
                            )
                    except Exception as e:
                        print(e)

                    params["__user__"] = user_info

                if "__id__" in accepted:
                    params["__id__"] = filter_id

                if "__model__" in accepted:
                    params["__model__"] = model

                if inspect.iscoroutinefunction(inlet):
                    body = await inlet(**params)
                else:
                    body = inlet(**params)

        except Exception as e:
            print(f"Error: {e}")
            raise e

    if skip_files and "files" in body:
        del body["files"]

    return body, {}
|
||||||
|
|
||||||
|
|
||||||
|
async def chat_completion_tools_handler(body, model, user):
    """Call each tool listed in ``body["tool_ids"]`` and collect the results.

    Every tool id is passed to ``get_function_call_response`` together with
    the chat messages, attached files, the tool-calling prompt template, the
    task model id and the current user/model. String responses are gathered
    as contexts and truthy citations as citations. A failure in one tool is
    printed and does not stop the remaining tools.

    ``tool_ids`` is removed from ``body`` once processed. When any tool
    reports a file handler, ``body["files"]`` is removed as well.

    Returns:
        A ``(body, flags)`` tuple. ``flags`` always contains ``"contexts"``
        (a list, possibly empty) and contains ``"citations"`` only when at
        least one citation was produced.
    """
    skip_files = None

    contexts = []
    citations = None

    task_model_id = get_task_model_id(body["model"])

    # If tool_ids field is present, call the functions
    if "tool_ids" in body:
        print(body["tool_ids"])
        for tool_id in body["tool_ids"]:
            print(tool_id)
            try:
                response, citation, file_handler = await get_function_call_response(
                    messages=body["messages"],
                    files=body.get("files", []),
                    tool_id=tool_id,
                    template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
                    task_model_id=task_model_id,
                    user=user,
                    model=model,
                )

                print(file_handler)
                if isinstance(response, str):
                    contexts.append(response)

                if citation:
                    # The list is created lazily so "no citations" stays None.
                    if citations is None:
                        citations = []
                    citations.append(citation)

                if file_handler:
                    skip_files = True

            except Exception as e:
                print(f"Error: {e}")
        del body["tool_ids"]
        print(f"tool_contexts: {contexts}")

    if skip_files and "files" in body:
        del body["files"]

    flags = {}
    if contexts is not None:
        flags["contexts"] = contexts
    if citations is not None:
        flags["citations"] = citations
    return body, flags
|
||||||
|
|
||||||
|
|
||||||
|
async def chat_completion_files_handler(body):
    """Build RAG contexts and citations from the files attached to ``body``.

    When ``body`` has a ``"files"`` entry it is removed and passed, with the
    chat messages and the RAG app's retrieval settings, to
    ``get_rag_context``. Without files the contexts stay an empty list and
    no citations are reported.

    Returns:
        A ``(body, flags)`` tuple where ``flags`` holds ``"contexts"`` and
        ``"citations"`` for whichever of the two is not ``None``.
    """
    contexts = []
    citations = None

    if "files" in body:
        # The files are consumed here and must not be forwarded downstream.
        files = body.pop("files")

        contexts, citations = get_rag_context(
            files=files,
            messages=body["messages"],
            embedding_function=rag_app.state.EMBEDDING_FUNCTION,
            k=rag_app.state.config.TOP_K,
            reranking_function=rag_app.state.sentence_transformer_rf,
            r=rag_app.state.config.RELEVANCE_THRESHOLD,
            hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
        )

        log.debug(f"rag_contexts: {contexts}, citations: {citations}")

    flags = {}
    if contexts is not None:
        flags["contexts"] = contexts
    if citations is not None:
        flags["citations"] = citations
    return body, flags
|
||||||
|
|
||||||
|
|
||||||
|
async def get_body_and_model_and_user(request):
    """Parse a chat request into its JSON body, target model and user.

    Args:
        request: The incoming HTTP request; its raw body is read and decoded
            as UTF-8 JSON (an empty body is treated as ``{}``).

    Returns:
        A ``(body, model, user)`` tuple: the decoded body dict, the entry of
        ``app.state.MODELS`` named by ``body["model"]``, and the user resolved
        from the request's ``Authorization`` header.

    Raises:
        KeyError: If the body has no ``"model"`` field.
        Exception: With message "Model not found" if the requested model id
            is not in ``app.state.MODELS``.
    """
    # Read the original request body
    raw_body = await request.body()
    body_str = raw_body.decode("utf-8")
    body = json.loads(body_str) if body_str else {}

    model_id = body["model"]
    if model_id not in app.state.MODELS:
        # Raising a bare string is a TypeError in Python 3 and would hide the
        # intended message from callers that report str(e).
        raise Exception("Model not found")
    model = app.state.MODELS[model_id]

    user = get_current_user(
        request,
        get_http_authorization_cred(request.headers.get("Authorization")),
    )

    return body, model, user
|
||||||
|
|
||||||
|
|
||||||
class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||||
async def dispatch(self, request: Request, call_next):
|
async def dispatch(self, request: Request, call_next):
|
||||||
data_items = []
|
|
||||||
|
|
||||||
show_citations = False
|
|
||||||
citations = []
|
|
||||||
|
|
||||||
if request.method == "POST" and any(
|
if request.method == "POST" and any(
|
||||||
endpoint in request.url.path
|
endpoint in request.url.path
|
||||||
for endpoint in ["/ollama/api/chat", "/chat/completions"]
|
for endpoint in ["/ollama/api/chat", "/chat/completions"]
|
||||||
):
|
):
|
||||||
log.debug(f"request.url.path: {request.url.path}")
|
log.debug(f"request.url.path: {request.url.path}")
|
||||||
|
|
||||||
# Read the original request body
|
try:
|
||||||
body = await request.body()
|
body, model, user = await get_body_and_model_and_user(request)
|
||||||
body_str = body.decode("utf-8")
|
except Exception as e:
|
||||||
data = json.loads(body_str) if body_str else {}
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
user = get_current_user(
|
content={"detail": str(e)},
|
||||||
request,
|
|
||||||
get_http_authorization_cred(request.headers.get("Authorization")),
|
|
||||||
)
|
|
||||||
# Flag to skip RAG completions if file_handler is present in tools/functions
|
|
||||||
skip_files = False
|
|
||||||
if data.get("citations"):
|
|
||||||
show_citations = True
|
|
||||||
del data["citations"]
|
|
||||||
|
|
||||||
model_id = data["model"]
|
|
||||||
if model_id not in app.state.MODELS:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="Model not found",
|
|
||||||
)
|
)
|
||||||
model = app.state.MODELS[model_id]
|
|
||||||
|
|
||||||
def get_priority(function_id):
|
# Extract chat_id and message_id from the request body
|
||||||
function = Functions.get_function_by_id(function_id)
|
chat_id = None
|
||||||
if function is not None and hasattr(function, "valves"):
|
if "chat_id" in body:
|
||||||
return (function.valves if function.valves else {}).get(
|
chat_id = body["chat_id"]
|
||||||
"priority", 0
|
del body["chat_id"]
|
||||||
)
|
message_id = None
|
||||||
return 0
|
if "id" in body:
|
||||||
|
message_id = body["id"]
|
||||||
|
del body["id"]
|
||||||
|
|
||||||
filter_ids = [
|
# Initialize data_items to store additional data to be sent to the client
|
||||||
function.id for function in Functions.get_global_filter_functions()
|
data_items = []
|
||||||
]
|
|
||||||
if "info" in model and "meta" in model["info"]:
|
|
||||||
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
|
||||||
filter_ids = list(set(filter_ids))
|
|
||||||
|
|
||||||
enabled_filter_ids = [
|
# Initialize context, and citations
|
||||||
function.id
|
contexts = []
|
||||||
for function in Functions.get_functions_by_type(
|
citations = []
|
||||||
"filter", active_only=True
|
|
||||||
|
print(body)
|
||||||
|
|
||||||
|
try:
|
||||||
|
body, flags = await chat_completion_functions_handler(body, model, user)
|
||||||
|
except Exception as e:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
content={"detail": str(e)},
|
||||||
)
|
)
|
||||||
]
|
|
||||||
filter_ids = [
|
|
||||||
filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
|
|
||||||
]
|
|
||||||
|
|
||||||
filter_ids.sort(key=get_priority)
|
try:
|
||||||
for filter_id in filter_ids:
|
body, flags = await chat_completion_tools_handler(body, model, user)
|
||||||
filter = Functions.get_function_by_id(filter_id)
|
|
||||||
if filter:
|
|
||||||
if filter_id in webui_app.state.FUNCTIONS:
|
|
||||||
function_module = webui_app.state.FUNCTIONS[filter_id]
|
|
||||||
else:
|
|
||||||
function_module, function_type, frontmatter = (
|
|
||||||
load_function_module_by_id(filter_id)
|
|
||||||
)
|
|
||||||
webui_app.state.FUNCTIONS[filter_id] = function_module
|
|
||||||
|
|
||||||
# Check if the function has a file_handler variable
|
contexts.extend(flags.get("contexts", []))
|
||||||
if hasattr(function_module, "file_handler"):
|
citations.extend(flags.get("citations", []))
|
||||||
skip_files = function_module.file_handler
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
pass
|
||||||
|
|
||||||
if hasattr(function_module, "valves") and hasattr(
|
try:
|
||||||
function_module, "Valves"
|
body, flags = await chat_completion_files_handler(body)
|
||||||
):
|
|
||||||
valves = Functions.get_function_valves_by_id(filter_id)
|
|
||||||
function_module.valves = function_module.Valves(
|
|
||||||
**(valves if valves else {})
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
contexts.extend(flags.get("contexts", []))
|
||||||
if hasattr(function_module, "inlet"):
|
citations.extend(flags.get("citations", []))
|
||||||
inlet = function_module.inlet
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
pass
|
||||||
|
|
||||||
# Get the signature of the function
|
# If context is not empty, insert it into the messages
|
||||||
sig = inspect.signature(inlet)
|
if len(contexts) > 0:
|
||||||
params = {"body": data}
|
context_string = "/n".join(contexts).strip()
|
||||||
|
prompt = get_last_user_message(body["messages"])
|
||||||
|
body["messages"] = add_or_update_system_message(
|
||||||
|
rag_template(
|
||||||
|
rag_app.state.config.RAG_TEMPLATE, context_string, prompt
|
||||||
|
),
|
||||||
|
body["messages"],
|
||||||
|
)
|
||||||
|
|
||||||
if "__user__" in sig.parameters:
|
# If there are citations, add them to the data_items
|
||||||
__user__ = {
|
if len(citations) > 0:
|
||||||
"id": user.id,
|
|
||||||
"email": user.email,
|
|
||||||
"name": user.name,
|
|
||||||
"role": user.role,
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
if hasattr(function_module, "UserValves"):
|
|
||||||
__user__["valves"] = function_module.UserValves(
|
|
||||||
**Functions.get_user_valves_by_id_and_user_id(
|
|
||||||
filter_id, user.id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
print(e)
|
|
||||||
|
|
||||||
params = {**params, "__user__": __user__}
|
|
||||||
|
|
||||||
if "__id__" in sig.parameters:
|
|
||||||
params = {
|
|
||||||
**params,
|
|
||||||
"__id__": filter_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
if "__model__" in sig.parameters:
|
|
||||||
params = {
|
|
||||||
**params,
|
|
||||||
"__model__": model,
|
|
||||||
}
|
|
||||||
|
|
||||||
if inspect.iscoroutinefunction(inlet):
|
|
||||||
data = await inlet(**params)
|
|
||||||
else:
|
|
||||||
data = inlet(**params)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error: {e}")
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
content={"detail": str(e)},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set the task model
|
|
||||||
task_model_id = data["model"]
|
|
||||||
# Check if the user has a custom task model and use that model
|
|
||||||
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
|
|
||||||
if (
|
|
||||||
app.state.config.TASK_MODEL
|
|
||||||
and app.state.config.TASK_MODEL in app.state.MODELS
|
|
||||||
):
|
|
||||||
task_model_id = app.state.config.TASK_MODEL
|
|
||||||
else:
|
|
||||||
if (
|
|
||||||
app.state.config.TASK_MODEL_EXTERNAL
|
|
||||||
and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
|
|
||||||
):
|
|
||||||
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
|
||||||
|
|
||||||
prompt = get_last_user_message(data["messages"])
|
|
||||||
context = ""
|
|
||||||
|
|
||||||
# If tool_ids field is present, call the functions
|
|
||||||
if "tool_ids" in data:
|
|
||||||
print(data["tool_ids"])
|
|
||||||
for tool_id in data["tool_ids"]:
|
|
||||||
print(tool_id)
|
|
||||||
try:
|
|
||||||
response, citation, file_handler = (
|
|
||||||
await get_function_call_response(
|
|
||||||
messages=data["messages"],
|
|
||||||
files=data.get("files", []),
|
|
||||||
tool_id=tool_id,
|
|
||||||
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
|
||||||
task_model_id=task_model_id,
|
|
||||||
user=user,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
print(file_handler)
|
|
||||||
if isinstance(response, str):
|
|
||||||
context += ("\n" if context != "" else "") + response
|
|
||||||
|
|
||||||
if citation:
|
|
||||||
citations.append(citation)
|
|
||||||
show_citations = True
|
|
||||||
|
|
||||||
if file_handler:
|
|
||||||
skip_files = True
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error: {e}")
|
|
||||||
del data["tool_ids"]
|
|
||||||
|
|
||||||
print(f"tool_context: {context}")
|
|
||||||
|
|
||||||
# If files field is present, generate RAG completions
|
|
||||||
# If skip_files is True, skip the RAG completions
|
|
||||||
if "files" in data:
|
|
||||||
if not skip_files:
|
|
||||||
data = {**data}
|
|
||||||
rag_context, rag_citations = get_rag_context(
|
|
||||||
files=data["files"],
|
|
||||||
messages=data["messages"],
|
|
||||||
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
|
|
||||||
k=rag_app.state.config.TOP_K,
|
|
||||||
reranking_function=rag_app.state.sentence_transformer_rf,
|
|
||||||
r=rag_app.state.config.RELEVANCE_THRESHOLD,
|
|
||||||
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
|
||||||
)
|
|
||||||
if rag_context:
|
|
||||||
context += ("\n" if context != "" else "") + rag_context
|
|
||||||
|
|
||||||
log.debug(f"rag_context: {rag_context}, citations: {citations}")
|
|
||||||
|
|
||||||
if rag_citations:
|
|
||||||
citations.extend(rag_citations)
|
|
||||||
|
|
||||||
del data["files"]
|
|
||||||
|
|
||||||
if show_citations and len(citations) > 0:
|
|
||||||
data_items.append({"citations": citations})
|
data_items.append({"citations": citations})
|
||||||
|
|
||||||
if context != "":
|
modified_body_bytes = json.dumps(body).encode("utf-8")
|
||||||
system_prompt = rag_template(
|
|
||||||
rag_app.state.config.RAG_TEMPLATE, context, prompt
|
|
||||||
)
|
|
||||||
print(system_prompt)
|
|
||||||
data["messages"] = add_or_update_system_message(
|
|
||||||
system_prompt, data["messages"]
|
|
||||||
)
|
|
||||||
|
|
||||||
modified_body_bytes = json.dumps(data).encode("utf-8")
|
|
||||||
# Replace the request body with the modified one
|
# Replace the request body with the modified one
|
||||||
request._body = modified_body_bytes
|
request._body = modified_body_bytes
|
||||||
# Set custom header to ensure content-length matches new body length
|
# Set custom header to ensure content-length matches new body length
|
||||||
@ -721,9 +796,6 @@ def filter_pipeline(payload, user):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
if "pipeline" not in app.state.MODELS[model_id]:
|
if "pipeline" not in app.state.MODELS[model_id]:
|
||||||
if "chat_id" in payload:
|
|
||||||
del payload["chat_id"]
|
|
||||||
|
|
||||||
if "title" in payload:
|
if "title" in payload:
|
||||||
del payload["title"]
|
del payload["title"]
|
||||||
|
|
||||||
@ -1225,6 +1297,9 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
|||||||
content={"detail": e.args[1]},
|
content={"detail": e.args[1]},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if "chat_id" in payload:
|
||||||
|
del payload["chat_id"]
|
||||||
|
|
||||||
return await generate_chat_completions(form_data=payload, user=user)
|
return await generate_chat_completions(form_data=payload, user=user)
|
||||||
|
|
||||||
|
|
||||||
@ -1285,6 +1360,9 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
|
|||||||
content={"detail": e.args[1]},
|
content={"detail": e.args[1]},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if "chat_id" in payload:
|
||||||
|
del payload["chat_id"]
|
||||||
|
|
||||||
return await generate_chat_completions(form_data=payload, user=user)
|
return await generate_chat_completions(form_data=payload, user=user)
|
||||||
|
|
||||||
|
|
||||||
@ -1349,6 +1427,9 @@ Message: """{{prompt}}"""
|
|||||||
content={"detail": e.args[1]},
|
content={"detail": e.args[1]},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if "chat_id" in payload:
|
||||||
|
del payload["chat_id"]
|
||||||
|
|
||||||
return await generate_chat_completions(form_data=payload, user=user)
|
return await generate_chat_completions(form_data=payload, user=user)
|
||||||
|
|
||||||
|
|
||||||
|
@ -665,6 +665,7 @@
|
|||||||
await tick();
|
await tick();
|
||||||
|
|
||||||
const [res, controller] = await generateChatCompletion(localStorage.token, {
|
const [res, controller] = await generateChatCompletion(localStorage.token, {
|
||||||
|
stream: true,
|
||||||
model: model.id,
|
model: model.id,
|
||||||
messages: messagesBody,
|
messages: messagesBody,
|
||||||
options: {
|
options: {
|
||||||
@ -682,8 +683,8 @@
|
|||||||
keep_alive: $settings.keepAlive ?? undefined,
|
keep_alive: $settings.keepAlive ?? undefined,
|
||||||
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
|
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
|
||||||
files: files.length > 0 ? files : undefined,
|
files: files.length > 0 ? files : undefined,
|
||||||
citations: files.length > 0 ? true : undefined,
|
chat_id: $chatId,
|
||||||
chat_id: $chatId
|
id: responseMessageId
|
||||||
});
|
});
|
||||||
|
|
||||||
if (res && res.ok) {
|
if (res && res.ok) {
|
||||||
@ -912,8 +913,8 @@
|
|||||||
const [res, controller] = await generateOpenAIChatCompletion(
|
const [res, controller] = await generateOpenAIChatCompletion(
|
||||||
localStorage.token,
|
localStorage.token,
|
||||||
{
|
{
|
||||||
model: model.id,
|
|
||||||
stream: true,
|
stream: true,
|
||||||
|
model: model.id,
|
||||||
stream_options:
|
stream_options:
|
||||||
model.info?.meta?.capabilities?.usage ?? false
|
model.info?.meta?.capabilities?.usage ?? false
|
||||||
? {
|
? {
|
||||||
@ -983,9 +984,8 @@
|
|||||||
max_tokens: $settings?.params?.max_tokens ?? undefined,
|
max_tokens: $settings?.params?.max_tokens ?? undefined,
|
||||||
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
|
tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
|
||||||
files: files.length > 0 ? files : undefined,
|
files: files.length > 0 ? files : undefined,
|
||||||
citations: files.length > 0 ? true : undefined,
|
chat_id: $chatId,
|
||||||
|
id: responseMessageId
|
||||||
chat_id: $chatId
|
|
||||||
},
|
},
|
||||||
`${WEBUI_BASE_URL}/api`
|
`${WEBUI_BASE_URL}/api`
|
||||||
);
|
);
|
||||||
|
Loading…
Reference in New Issue
Block a user