mirror of
https://github.com/open-webui/open-webui
synced 2025-05-31 11:00:49 +00:00
refac
This commit is contained in:
parent
08cc20cb93
commit
448ca9d836
212
backend/main.py
212
backend/main.py
@ -170,6 +170,13 @@ app.state.MODELS = {}
|
|||||||
origins = ["*"]
|
origins = ["*"]
|
||||||
|
|
||||||
|
|
||||||
|
##################################
|
||||||
|
#
|
||||||
|
# ChatCompletion Middleware
|
||||||
|
#
|
||||||
|
##################################
|
||||||
|
|
||||||
|
|
||||||
async def get_function_call_response(
|
async def get_function_call_response(
|
||||||
messages, files, tool_id, template, task_model_id, user
|
messages, files, tool_id, template, task_model_id, user
|
||||||
):
|
):
|
||||||
@ -469,6 +476,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
|
|
||||||
app.add_middleware(ChatCompletionMiddleware)
|
app.add_middleware(ChatCompletionMiddleware)
|
||||||
|
|
||||||
|
##################################
|
||||||
|
#
|
||||||
|
# Pipeline Middleware
|
||||||
|
#
|
||||||
|
##################################
|
||||||
|
|
||||||
|
|
||||||
def filter_pipeline(payload, user):
|
def filter_pipeline(payload, user):
|
||||||
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
|
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
|
||||||
@ -628,7 +641,6 @@ async def update_embedding_function(request: Request, call_next):
|
|||||||
|
|
||||||
app.mount("/ws", socket_app)
|
app.mount("/ws", socket_app)
|
||||||
|
|
||||||
|
|
||||||
app.mount("/ollama", ollama_app)
|
app.mount("/ollama", ollama_app)
|
||||||
app.mount("/openai", openai_app)
|
app.mount("/openai", openai_app)
|
||||||
|
|
||||||
@ -730,6 +742,104 @@ async def get_models(user=Depends(get_verified_user)):
|
|||||||
return {"data": models}
|
return {"data": models}
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/chat/completions")
|
||||||
|
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
|
||||||
|
model_id = form_data["model"]
|
||||||
|
if model_id not in app.state.MODELS:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="Model not found",
|
||||||
|
)
|
||||||
|
|
||||||
|
model = app.state.MODELS[model_id]
|
||||||
|
print(model)
|
||||||
|
|
||||||
|
if model["owned_by"] == "ollama":
|
||||||
|
return await generate_ollama_chat_completion(form_data, user=user)
|
||||||
|
else:
|
||||||
|
return await generate_openai_chat_completion(form_data, user=user)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/api/chat/completed")
|
||||||
|
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
|
||||||
|
data = form_data
|
||||||
|
model_id = data["model"]
|
||||||
|
|
||||||
|
filters = [
|
||||||
|
model
|
||||||
|
for model in app.state.MODELS.values()
|
||||||
|
if "pipeline" in model
|
||||||
|
and "type" in model["pipeline"]
|
||||||
|
and model["pipeline"]["type"] == "filter"
|
||||||
|
and (
|
||||||
|
model["pipeline"]["pipelines"] == ["*"]
|
||||||
|
or any(
|
||||||
|
model_id == target_model_id
|
||||||
|
for target_model_id in model["pipeline"]["pipelines"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
|
||||||
|
|
||||||
|
print(model_id)
|
||||||
|
|
||||||
|
if model_id in app.state.MODELS:
|
||||||
|
model = app.state.MODELS[model_id]
|
||||||
|
if "pipeline" in model:
|
||||||
|
sorted_filters = [model] + sorted_filters
|
||||||
|
|
||||||
|
for filter in sorted_filters:
|
||||||
|
r = None
|
||||||
|
try:
|
||||||
|
urlIdx = filter["urlIdx"]
|
||||||
|
|
||||||
|
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||||
|
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||||
|
|
||||||
|
if key != "":
|
||||||
|
headers = {"Authorization": f"Bearer {key}"}
|
||||||
|
r = requests.post(
|
||||||
|
f"{url}/{filter['id']}/filter/outlet",
|
||||||
|
headers=headers,
|
||||||
|
json={
|
||||||
|
"user": {"id": user.id, "name": user.name, "role": user.role},
|
||||||
|
"body": data,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
r.raise_for_status()
|
||||||
|
data = r.json()
|
||||||
|
except Exception as e:
|
||||||
|
# Handle connection error here
|
||||||
|
print(f"Connection error: {e}")
|
||||||
|
|
||||||
|
if r is not None:
|
||||||
|
try:
|
||||||
|
res = r.json()
|
||||||
|
if "detail" in res:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=r.status_code,
|
||||||
|
content=res,
|
||||||
|
)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
##################################
|
||||||
|
#
|
||||||
|
# Task Endpoints
|
||||||
|
#
|
||||||
|
##################################
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: Refactor task API endpoints below into a separate file
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/task/config")
|
@app.get("/api/task/config")
|
||||||
async def get_task_config(user=Depends(get_verified_user)):
|
async def get_task_config(user=Depends(get_verified_user)):
|
||||||
return {
|
return {
|
||||||
@ -1015,92 +1125,14 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/chat/completions")
|
##################################
|
||||||
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
|
#
|
||||||
model_id = form_data["model"]
|
# Pipelines Endpoints
|
||||||
if model_id not in app.state.MODELS:
|
#
|
||||||
raise HTTPException(
|
##################################
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="Model not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
model = app.state.MODELS[model_id]
|
|
||||||
print(model)
|
|
||||||
|
|
||||||
if model["owned_by"] == "ollama":
|
|
||||||
return await generate_ollama_chat_completion(form_data, user=user)
|
|
||||||
else:
|
|
||||||
return await generate_openai_chat_completion(form_data, user=user)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/chat/completed")
|
# TODO: Refactor pipelines API endpoints below into a separate file
|
||||||
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
|
|
||||||
data = form_data
|
|
||||||
model_id = data["model"]
|
|
||||||
|
|
||||||
filters = [
|
|
||||||
model
|
|
||||||
for model in app.state.MODELS.values()
|
|
||||||
if "pipeline" in model
|
|
||||||
and "type" in model["pipeline"]
|
|
||||||
and model["pipeline"]["type"] == "filter"
|
|
||||||
and (
|
|
||||||
model["pipeline"]["pipelines"] == ["*"]
|
|
||||||
or any(
|
|
||||||
model_id == target_model_id
|
|
||||||
for target_model_id in model["pipeline"]["pipelines"]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
]
|
|
||||||
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
|
|
||||||
|
|
||||||
print(model_id)
|
|
||||||
|
|
||||||
if model_id in app.state.MODELS:
|
|
||||||
model = app.state.MODELS[model_id]
|
|
||||||
if "pipeline" in model:
|
|
||||||
sorted_filters = [model] + sorted_filters
|
|
||||||
|
|
||||||
for filter in sorted_filters:
|
|
||||||
r = None
|
|
||||||
try:
|
|
||||||
urlIdx = filter["urlIdx"]
|
|
||||||
|
|
||||||
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
|
||||||
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
|
|
||||||
|
|
||||||
if key != "":
|
|
||||||
headers = {"Authorization": f"Bearer {key}"}
|
|
||||||
r = requests.post(
|
|
||||||
f"{url}/{filter['id']}/filter/outlet",
|
|
||||||
headers=headers,
|
|
||||||
json={
|
|
||||||
"user": {"id": user.id, "name": user.name, "role": user.role},
|
|
||||||
"body": data,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
r.raise_for_status()
|
|
||||||
data = r.json()
|
|
||||||
except Exception as e:
|
|
||||||
# Handle connection error here
|
|
||||||
print(f"Connection error: {e}")
|
|
||||||
|
|
||||||
if r is not None:
|
|
||||||
try:
|
|
||||||
res = r.json()
|
|
||||||
if "detail" in res:
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=r.status_code,
|
|
||||||
content=res,
|
|
||||||
)
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/pipelines/list")
|
@app.get("/api/pipelines/list")
|
||||||
@ -1423,6 +1455,13 @@ async def update_pipeline_valves(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
##################################
|
||||||
|
#
|
||||||
|
# Config Endpoints
|
||||||
|
#
|
||||||
|
##################################
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/config")
|
@app.get("/api/config")
|
||||||
async def get_app_config():
|
async def get_app_config():
|
||||||
# Checking and Handling the Absence of 'ui' in CONFIG_DATA
|
# Checking and Handling the Absence of 'ui' in CONFIG_DATA
|
||||||
@ -1486,6 +1525,9 @@ async def update_model_filter_config(
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: webhook endpoint should be under config endpoints
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/webhook")
|
@app.get("/api/webhook")
|
||||||
async def get_webhook_url(user=Depends(get_admin_user)):
|
async def get_webhook_url(user=Depends(get_admin_user)):
|
||||||
return {
|
return {
|
||||||
|
@ -30,9 +30,10 @@
|
|||||||
let boilerplate = `from pydantic import BaseModel
|
let boilerplate = `from pydantic import BaseModel
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
class Filter:
|
class Filter:
|
||||||
class Valves(BaseModel):
|
class Valves(BaseModel):
|
||||||
max_turns: int
|
max_turns: int = 4
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@ -42,7 +43,7 @@ class Filter:
|
|||||||
|
|
||||||
# Initialize 'valves' with specific configurations. Using 'Valves' instance helps encapsulate settings,
|
# Initialize 'valves' with specific configurations. Using 'Valves' instance helps encapsulate settings,
|
||||||
# which ensures settings are managed cohesively and not confused with operational flags like 'file_handler'.
|
# which ensures settings are managed cohesively and not confused with operational flags like 'file_handler'.
|
||||||
self.valves = self.Valves(**{"max_turns": 10})
|
self.valves = self.Valves(**{"max_turns": 2})
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def inlet(self, body: dict, user: Optional[dict] = None) -> dict:
|
def inlet(self, body: dict, user: Optional[dict] = None) -> dict:
|
||||||
|
Loading…
Reference in New Issue
Block a user