mirror of
https://github.com/open-webui/open-webui
synced 2025-06-23 02:16:52 +00:00
wip
This commit is contained in:
parent
772f5ccd60
commit
fe5519e0a2
@ -75,6 +75,11 @@ from open_webui.routers.retrieval import (
|
|||||||
get_ef,
|
get_ef,
|
||||||
get_rf,
|
get_rf,
|
||||||
)
|
)
|
||||||
|
from open_webui.routers.pipelines import (
|
||||||
|
process_pipeline_inlet_filter,
|
||||||
|
process_pipeline_outlet_filter,
|
||||||
|
)
|
||||||
|
|
||||||
from open_webui.retrieval.utils import get_sources_from_files
|
from open_webui.retrieval.utils import get_sources_from_files
|
||||||
|
|
||||||
|
|
||||||
@ -290,6 +295,7 @@ from open_webui.utils.response import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from open_webui.utils.task import (
|
from open_webui.utils.task import (
|
||||||
|
get_task_model_id,
|
||||||
rag_template,
|
rag_template,
|
||||||
tools_function_calling_generation_template,
|
tools_function_calling_generation_template,
|
||||||
)
|
)
|
||||||
@ -662,7 +668,10 @@ app.state.MODELS = {}
|
|||||||
##################################
|
##################################
|
||||||
|
|
||||||
|
|
||||||
def get_filter_function_ids(model):
|
async def chat_completion_filter_functions_handler(body, model, extra_params):
|
||||||
|
skip_files = None
|
||||||
|
|
||||||
|
def get_filter_function_ids(model):
|
||||||
def get_priority(function_id):
|
def get_priority(function_id):
|
||||||
function = Functions.get_function_by_id(function_id)
|
function = Functions.get_function_by_id(function_id)
|
||||||
if function is not None and hasattr(function, "valves"):
|
if function is not None and hasattr(function, "valves"):
|
||||||
@ -670,7 +679,9 @@ def get_filter_function_ids(model):
|
|||||||
return (function.valves if function.valves else {}).get("priority", 0)
|
return (function.valves if function.valves else {}).get("priority", 0)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
filter_ids = [function.id for function in Functions.get_global_filter_functions()]
|
filter_ids = [
|
||||||
|
function.id for function in Functions.get_global_filter_functions()
|
||||||
|
]
|
||||||
if "info" in model and "meta" in model["info"]:
|
if "info" in model and "meta" in model["info"]:
|
||||||
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
||||||
filter_ids = list(set(filter_ids))
|
filter_ids = list(set(filter_ids))
|
||||||
@ -687,10 +698,6 @@ def get_filter_function_ids(model):
|
|||||||
filter_ids.sort(key=get_priority)
|
filter_ids.sort(key=get_priority)
|
||||||
return filter_ids
|
return filter_ids
|
||||||
|
|
||||||
|
|
||||||
async def chat_completion_filter_functions_handler(body, model, extra_params):
|
|
||||||
skip_files = None
|
|
||||||
|
|
||||||
filter_ids = get_filter_function_ids(model)
|
filter_ids = get_filter_function_ids(model)
|
||||||
for filter_id in filter_ids:
|
for filter_id in filter_ids:
|
||||||
filter = Functions.get_function_by_id(filter_id)
|
filter = Functions.get_function_by_id(filter_id)
|
||||||
@ -791,22 +798,6 @@ async def get_content_from_response(response) -> Optional[str]:
|
|||||||
return content
|
return content
|
||||||
|
|
||||||
|
|
||||||
def get_task_model_id(
|
|
||||||
default_model_id: str, task_model: str, task_model_external: str, models
|
|
||||||
) -> str:
|
|
||||||
# Set the task model
|
|
||||||
task_model_id = default_model_id
|
|
||||||
# Check if the user has a custom task model and use that model
|
|
||||||
if models[task_model_id]["owned_by"] == "ollama":
|
|
||||||
if task_model and task_model in models:
|
|
||||||
task_model_id = task_model
|
|
||||||
else:
|
|
||||||
if task_model_external and task_model_external in models:
|
|
||||||
task_model_id = task_model_external
|
|
||||||
|
|
||||||
return task_model_id
|
|
||||||
|
|
||||||
|
|
||||||
async def chat_completion_tools_handler(
|
async def chat_completion_tools_handler(
|
||||||
body: dict, user: UserModel, models, extra_params: dict
|
body: dict, user: UserModel, models, extra_params: dict
|
||||||
) -> tuple[dict, dict]:
|
) -> tuple[dict, dict]:
|
||||||
@ -857,7 +848,7 @@ async def chat_completion_tools_handler(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
payload = filter_pipeline(payload, user, models)
|
payload = process_pipeline_inlet_filter(request, payload, user, models)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
@ -1153,7 +1144,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
if prompt is None:
|
if prompt is None:
|
||||||
raise Exception("No user message found")
|
raise Exception("No user message found")
|
||||||
if (
|
if (
|
||||||
retrieval_app.state.config.RELEVANCE_THRESHOLD == 0
|
app.state.config.RELEVANCE_THRESHOLD == 0
|
||||||
and context_string.strip() == ""
|
and context_string.strip() == ""
|
||||||
):
|
):
|
||||||
log.debug(
|
log.debug(
|
||||||
@ -1164,16 +1155,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
# TODO: replace with add_or_update_system_message
|
# TODO: replace with add_or_update_system_message
|
||||||
if model["owned_by"] == "ollama":
|
if model["owned_by"] == "ollama":
|
||||||
body["messages"] = prepend_to_first_user_message_content(
|
body["messages"] = prepend_to_first_user_message_content(
|
||||||
rag_template(
|
rag_template(app.state.config.RAG_TEMPLATE, context_string, prompt),
|
||||||
retrieval_app.state.config.RAG_TEMPLATE, context_string, prompt
|
|
||||||
),
|
|
||||||
body["messages"],
|
body["messages"],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
body["messages"] = add_or_update_system_message(
|
body["messages"] = add_or_update_system_message(
|
||||||
rag_template(
|
rag_template(app.state.config.RAG_TEMPLATE, context_string, prompt),
|
||||||
retrieval_app.state.config.RAG_TEMPLATE, context_string, prompt
|
|
||||||
),
|
|
||||||
body["messages"],
|
body["messages"],
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -1225,77 +1212,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
app.add_middleware(ChatCompletionMiddleware)
|
app.add_middleware(ChatCompletionMiddleware)
|
||||||
|
|
||||||
|
|
||||||
##################################
|
|
||||||
#
|
|
||||||
# Pipeline Middleware
|
|
||||||
#
|
|
||||||
##################################
|
|
||||||
|
|
||||||
|
|
||||||
def get_sorted_filters(model_id, models):
|
|
||||||
filters = [
|
|
||||||
model
|
|
||||||
for model in models.values()
|
|
||||||
if "pipeline" in model
|
|
||||||
and "type" in model["pipeline"]
|
|
||||||
and model["pipeline"]["type"] == "filter"
|
|
||||||
and (
|
|
||||||
model["pipeline"]["pipelines"] == ["*"]
|
|
||||||
or any(
|
|
||||||
model_id == target_model_id
|
|
||||||
for target_model_id in model["pipeline"]["pipelines"]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
]
|
|
||||||
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
|
|
||||||
return sorted_filters
|
|
||||||
|
|
||||||
|
|
||||||
def filter_pipeline(payload, user, models):
|
|
||||||
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
|
|
||||||
model_id = payload["model"]
|
|
||||||
|
|
||||||
sorted_filters = get_sorted_filters(model_id, models)
|
|
||||||
model = models[model_id]
|
|
||||||
|
|
||||||
if "pipeline" in model:
|
|
||||||
sorted_filters.append(model)
|
|
||||||
|
|
||||||
for filter in sorted_filters:
|
|
||||||
r = None
|
|
||||||
try:
|
|
||||||
urlIdx = filter["urlIdx"]
|
|
||||||
|
|
||||||
url = app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
|
||||||
key = app.state.config.OPENAI_API_KEYS[urlIdx]
|
|
||||||
|
|
||||||
if key == "":
|
|
||||||
continue
|
|
||||||
|
|
||||||
headers = {"Authorization": f"Bearer {key}"}
|
|
||||||
r = requests.post(
|
|
||||||
f"{url}/{filter['id']}/filter/inlet",
|
|
||||||
headers=headers,
|
|
||||||
json={
|
|
||||||
"user": user,
|
|
||||||
"body": payload,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
r.raise_for_status()
|
|
||||||
payload = r.json()
|
|
||||||
except Exception as e:
|
|
||||||
# Handle connection error here
|
|
||||||
print(f"Connection error: {e}")
|
|
||||||
|
|
||||||
if r is not None:
|
|
||||||
res = r.json()
|
|
||||||
if "detail" in res:
|
|
||||||
raise Exception(r.status_code, res["detail"])
|
|
||||||
|
|
||||||
return payload
|
|
||||||
|
|
||||||
|
|
||||||
class PipelineMiddleware(BaseHTTPMiddleware):
|
class PipelineMiddleware(BaseHTTPMiddleware):
|
||||||
async def dispatch(self, request: Request, call_next):
|
async def dispatch(self, request: Request, call_next):
|
||||||
if not request.method == "POST" and any(
|
if not request.method == "POST" and any(
|
||||||
@ -1335,11 +1251,11 @@ class PipelineMiddleware(BaseHTTPMiddleware):
|
|||||||
content={"detail": e.detail},
|
content={"detail": e.detail},
|
||||||
)
|
)
|
||||||
|
|
||||||
model_list = await get_all_models()
|
await get_all_models()
|
||||||
models = {model["id"]: model for model in model_list}
|
models = app.state.MODELS
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = filter_pipeline(data, user, models)
|
data = process_pipeline_inlet_filter(request, data, user, models)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if len(e.args) > 1:
|
if len(e.args) > 1:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
@ -1447,8 +1363,8 @@ app.include_router(ollama.router, prefix="/ollama", tags=["ollama"])
|
|||||||
app.include_router(openai.router, prefix="/openai", tags=["openai"])
|
app.include_router(openai.router, prefix="/openai", tags=["openai"])
|
||||||
|
|
||||||
|
|
||||||
app.include_router(pipelines.router, prefix="/pipelines", tags=["pipelines"])
|
app.include_router(pipelines.router, prefix="/api/pipelines", tags=["pipelines"])
|
||||||
app.include_router(tasks.router, prefix="/tasks", tags=["tasks"])
|
app.include_router(tasks.router, prefix="/api/tasks", tags=["tasks"])
|
||||||
|
|
||||||
|
|
||||||
app.include_router(images.router, prefix="/api/v1/images", tags=["images"])
|
app.include_router(images.router, prefix="/api/v1/images", tags=["images"])
|
||||||
@ -2105,7 +2021,6 @@ async def generate_chat_completions(
|
|||||||
if model["owned_by"] == "ollama":
|
if model["owned_by"] == "ollama":
|
||||||
# Using /ollama/api/chat endpoint
|
# Using /ollama/api/chat endpoint
|
||||||
form_data = convert_payload_openai_to_ollama(form_data)
|
form_data = convert_payload_openai_to_ollama(form_data)
|
||||||
form_data = GenerateChatCompletionForm(**form_data)
|
|
||||||
response = await generate_ollama_chat_completion(
|
response = await generate_ollama_chat_completion(
|
||||||
form_data=form_data, user=user, bypass_filter=bypass_filter
|
form_data=form_data, user=user, bypass_filter=bypass_filter
|
||||||
)
|
)
|
||||||
@ -2124,7 +2039,9 @@ async def generate_chat_completions(
|
|||||||
|
|
||||||
|
|
||||||
@app.post("/api/chat/completed")
|
@app.post("/api/chat/completed")
|
||||||
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
|
async def chat_completed(
|
||||||
|
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||||
|
):
|
||||||
model_list = await get_all_models()
|
model_list = await get_all_models()
|
||||||
models = {model["id"]: model for model in model_list}
|
models = {model["id"]: model for model in model_list}
|
||||||
|
|
||||||
@ -2137,53 +2054,14 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
|
|||||||
)
|
)
|
||||||
|
|
||||||
model = models[model_id]
|
model = models[model_id]
|
||||||
sorted_filters = get_sorted_filters(model_id, models)
|
|
||||||
if "pipeline" in model:
|
|
||||||
sorted_filters = [model] + sorted_filters
|
|
||||||
|
|
||||||
for filter in sorted_filters:
|
|
||||||
r = None
|
|
||||||
try:
|
try:
|
||||||
urlIdx = filter["urlIdx"]
|
data = process_pipeline_outlet_filter(request, data, user, models)
|
||||||
|
|
||||||
url = app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
|
||||||
key = app.state.config.OPENAI_API_KEYS[urlIdx]
|
|
||||||
|
|
||||||
if key != "":
|
|
||||||
headers = {"Authorization": f"Bearer {key}"}
|
|
||||||
r = requests.post(
|
|
||||||
f"{url}/{filter['id']}/filter/outlet",
|
|
||||||
headers=headers,
|
|
||||||
json={
|
|
||||||
"user": {
|
|
||||||
"id": user.id,
|
|
||||||
"name": user.name,
|
|
||||||
"email": user.email,
|
|
||||||
"role": user.role,
|
|
||||||
},
|
|
||||||
"body": data,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
r.raise_for_status()
|
|
||||||
data = r.json()
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Handle connection error here
|
return HTTPException(
|
||||||
print(f"Connection error: {e}")
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=str(e),
|
||||||
if r is not None:
|
|
||||||
try:
|
|
||||||
res = r.json()
|
|
||||||
if "detail" in res:
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=r.status_code,
|
|
||||||
content=res,
|
|
||||||
)
|
)
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
else:
|
|
||||||
pass
|
|
||||||
|
|
||||||
__event_emitter__ = get_event_emitter(
|
__event_emitter__ = get_event_emitter(
|
||||||
{
|
{
|
||||||
@ -2455,8 +2333,8 @@ async def get_app_config(request: Request):
|
|||||||
"enable_login_form": app.state.config.ENABLE_LOGIN_FORM,
|
"enable_login_form": app.state.config.ENABLE_LOGIN_FORM,
|
||||||
**(
|
**(
|
||||||
{
|
{
|
||||||
"enable_web_search": retrieval_app.state.config.ENABLE_RAG_WEB_SEARCH,
|
"enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH,
|
||||||
"enable_image_generation": images_app.state.config.ENABLED,
|
"enable_image_generation": app.state.config.ENABLE_IMAGE_GENERATION,
|
||||||
"enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING,
|
"enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING,
|
||||||
"enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING,
|
"enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING,
|
||||||
"enable_admin_export": ENABLE_ADMIN_EXPORT,
|
"enable_admin_export": ENABLE_ADMIN_EXPORT,
|
||||||
@ -2472,17 +2350,17 @@ async def get_app_config(request: Request):
|
|||||||
"default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
|
"default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
|
||||||
"audio": {
|
"audio": {
|
||||||
"tts": {
|
"tts": {
|
||||||
"engine": audio_app.state.config.TTS_ENGINE,
|
"engine": app.state.config.TTS_ENGINE,
|
||||||
"voice": audio_app.state.config.TTS_VOICE,
|
"voice": app.state.config.TTS_VOICE,
|
||||||
"split_on": audio_app.state.config.TTS_SPLIT_ON,
|
"split_on": app.state.config.TTS_SPLIT_ON,
|
||||||
},
|
},
|
||||||
"stt": {
|
"stt": {
|
||||||
"engine": audio_app.state.config.STT_ENGINE,
|
"engine": app.state.config.STT_ENGINE,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"file": {
|
"file": {
|
||||||
"max_size": retrieval_app.state.config.FILE_MAX_SIZE,
|
"max_size": app.state.config.FILE_MAX_SIZE,
|
||||||
"max_count": retrieval_app.state.config.FILE_MAX_COUNT,
|
"max_count": app.state.config.FILE_MAX_COUNT,
|
||||||
},
|
},
|
||||||
"permissions": {**app.state.config.USER_PERMISSIONS},
|
"permissions": {**app.state.config.USER_PERMISSIONS},
|
||||||
}
|
}
|
||||||
|
@ -941,7 +941,7 @@ async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] =
|
|||||||
@router.post("/api/chat/{url_idx}")
|
@router.post("/api/chat/{url_idx}")
|
||||||
async def generate_chat_completion(
|
async def generate_chat_completion(
|
||||||
request: Request,
|
request: Request,
|
||||||
form_data: GenerateChatCompletionForm,
|
form_data: dict,
|
||||||
url_idx: Optional[int] = None,
|
url_idx: Optional[int] = None,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
bypass_filter: Optional[bool] = False,
|
bypass_filter: Optional[bool] = False,
|
||||||
@ -949,6 +949,15 @@ async def generate_chat_completion(
|
|||||||
if BYPASS_MODEL_ACCESS_CONTROL:
|
if BYPASS_MODEL_ACCESS_CONTROL:
|
||||||
bypass_filter = True
|
bypass_filter = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
form_data = GenerateChatCompletionForm(**form_data)
|
||||||
|
except Exception as e:
|
||||||
|
log.exception(e)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail=str(e),
|
||||||
|
)
|
||||||
|
|
||||||
payload = {**form_data.model_dump(exclude_none=True)}
|
payload = {**form_data.model_dump(exclude_none=True)}
|
||||||
if "metadata" in payload:
|
if "metadata" in payload:
|
||||||
del payload["metadata"]
|
del payload["metadata"]
|
||||||
|
@ -30,6 +30,130 @@ log = logging.getLogger(__name__)
|
|||||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||||
|
|
||||||
|
|
||||||
|
##################################
|
||||||
|
#
|
||||||
|
# Pipeline Middleware
|
||||||
|
#
|
||||||
|
##################################
|
||||||
|
|
||||||
|
|
||||||
|
def get_sorted_filters(model_id, models):
|
||||||
|
filters = [
|
||||||
|
model
|
||||||
|
for model in models.values()
|
||||||
|
if "pipeline" in model
|
||||||
|
and "type" in model["pipeline"]
|
||||||
|
and model["pipeline"]["type"] == "filter"
|
||||||
|
and (
|
||||||
|
model["pipeline"]["pipelines"] == ["*"]
|
||||||
|
or any(
|
||||||
|
model_id == target_model_id
|
||||||
|
for target_model_id in model["pipeline"]["pipelines"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
|
||||||
|
return sorted_filters
|
||||||
|
|
||||||
|
|
||||||
|
def process_pipeline_inlet_filter(request, payload, user, models):
|
||||||
|
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
|
||||||
|
model_id = payload["model"]
|
||||||
|
|
||||||
|
sorted_filters = get_sorted_filters(model_id, models)
|
||||||
|
model = models[model_id]
|
||||||
|
|
||||||
|
if "pipeline" in model:
|
||||||
|
sorted_filters.append(model)
|
||||||
|
|
||||||
|
for filter in sorted_filters:
|
||||||
|
r = None
|
||||||
|
try:
|
||||||
|
urlIdx = filter["urlIdx"]
|
||||||
|
|
||||||
|
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||||
|
key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||||
|
|
||||||
|
if key == "":
|
||||||
|
continue
|
||||||
|
|
||||||
|
headers = {"Authorization": f"Bearer {key}"}
|
||||||
|
r = requests.post(
|
||||||
|
f"{url}/{filter['id']}/filter/inlet",
|
||||||
|
headers=headers,
|
||||||
|
json={
|
||||||
|
"user": user,
|
||||||
|
"body": payload,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
r.raise_for_status()
|
||||||
|
payload = r.json()
|
||||||
|
except Exception as e:
|
||||||
|
# Handle connection error here
|
||||||
|
print(f"Connection error: {e}")
|
||||||
|
|
||||||
|
if r is not None:
|
||||||
|
res = r.json()
|
||||||
|
if "detail" in res:
|
||||||
|
raise Exception(r.status_code, res["detail"])
|
||||||
|
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
|
def process_pipeline_outlet_filter(request, payload, user, models):
|
||||||
|
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
|
||||||
|
model_id = payload["model"]
|
||||||
|
|
||||||
|
sorted_filters = get_sorted_filters(model_id, models)
|
||||||
|
model = models[model_id]
|
||||||
|
|
||||||
|
if "pipeline" in model:
|
||||||
|
sorted_filters = [model] + sorted_filters
|
||||||
|
|
||||||
|
for filter in sorted_filters:
|
||||||
|
r = None
|
||||||
|
try:
|
||||||
|
urlIdx = filter["urlIdx"]
|
||||||
|
|
||||||
|
url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||||
|
key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||||
|
|
||||||
|
if key != "":
|
||||||
|
r = requests.post(
|
||||||
|
f"{url}/{filter['id']}/filter/outlet",
|
||||||
|
headers={"Authorization": f"Bearer {key}"},
|
||||||
|
json={
|
||||||
|
"user": {
|
||||||
|
"id": user.id,
|
||||||
|
"name": user.name,
|
||||||
|
"email": user.email,
|
||||||
|
"role": user.role,
|
||||||
|
},
|
||||||
|
"body": data,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
r.raise_for_status()
|
||||||
|
data = r.json()
|
||||||
|
except Exception as e:
|
||||||
|
# Handle connection error here
|
||||||
|
print(f"Connection error: {e}")
|
||||||
|
|
||||||
|
if r is not None:
|
||||||
|
try:
|
||||||
|
res = r.json()
|
||||||
|
if "detail" in res:
|
||||||
|
return Exception(r.status_code, res)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
else:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return payload
|
||||||
|
|
||||||
|
|
||||||
##################################
|
##################################
|
||||||
#
|
#
|
||||||
# Pipelines Endpoints
|
# Pipelines Endpoints
|
||||||
@ -39,7 +163,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
|||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
@router.get("/api/pipelines/list")
|
@router.get("/list")
|
||||||
async def get_pipelines_list(request: Request, user=Depends(get_admin_user)):
|
async def get_pipelines_list(request: Request, user=Depends(get_admin_user)):
|
||||||
responses = await get_all_models_responses(request)
|
responses = await get_all_models_responses(request)
|
||||||
log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}")
|
log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}")
|
||||||
@ -61,7 +185,7 @@ async def get_pipelines_list(request: Request, user=Depends(get_admin_user)):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@router.post("/api/pipelines/upload")
|
@router.post("/upload")
|
||||||
async def upload_pipeline(
|
async def upload_pipeline(
|
||||||
request: Request,
|
request: Request,
|
||||||
urlIdx: int = Form(...),
|
urlIdx: int = Form(...),
|
||||||
@ -131,7 +255,7 @@ class AddPipelineForm(BaseModel):
|
|||||||
urlIdx: int
|
urlIdx: int
|
||||||
|
|
||||||
|
|
||||||
@router.post("/api/pipelines/add")
|
@router.post("/add")
|
||||||
async def add_pipeline(
|
async def add_pipeline(
|
||||||
request: Request, form_data: AddPipelineForm, user=Depends(get_admin_user)
|
request: Request, form_data: AddPipelineForm, user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
@ -176,7 +300,7 @@ class DeletePipelineForm(BaseModel):
|
|||||||
urlIdx: int
|
urlIdx: int
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/api/pipelines/delete")
|
@router.delete("/delete")
|
||||||
async def delete_pipeline(
|
async def delete_pipeline(
|
||||||
request: Request, form_data: DeletePipelineForm, user=Depends(get_admin_user)
|
request: Request, form_data: DeletePipelineForm, user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
@ -216,7 +340,7 @@ async def delete_pipeline(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/api/pipelines")
|
@router.get("/")
|
||||||
async def get_pipelines(
|
async def get_pipelines(
|
||||||
request: Request, urlIdx: Optional[int] = None, user=Depends(get_admin_user)
|
request: Request, urlIdx: Optional[int] = None, user=Depends(get_admin_user)
|
||||||
):
|
):
|
||||||
@ -250,7 +374,7 @@ async def get_pipelines(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/api/pipelines/{pipeline_id}/valves")
|
@router.get("/{pipeline_id}/valves")
|
||||||
async def get_pipeline_valves(
|
async def get_pipeline_valves(
|
||||||
request: Request,
|
request: Request,
|
||||||
urlIdx: Optional[int],
|
urlIdx: Optional[int],
|
||||||
@ -289,7 +413,7 @@ async def get_pipeline_valves(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/api/pipelines/{pipeline_id}/valves/spec")
|
@router.get("/{pipeline_id}/valves/spec")
|
||||||
async def get_pipeline_valves_spec(
|
async def get_pipeline_valves_spec(
|
||||||
request: Request,
|
request: Request,
|
||||||
urlIdx: Optional[int],
|
urlIdx: Optional[int],
|
||||||
@ -329,7 +453,7 @@ async def get_pipeline_valves_spec(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/api/pipelines/{pipeline_id}/valves/update")
|
@router.post("/{pipeline_id}/valves/update")
|
||||||
async def update_pipeline_valves(
|
async def update_pipeline_valves(
|
||||||
request: Request,
|
request: Request,
|
||||||
urlIdx: Optional[int],
|
urlIdx: Optional[int],
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException, Response, status, Request
|
from fastapi import APIRouter, Depends, HTTPException, Response, status, Request
|
||||||
|
from fastapi.responses import JSONResponse, RedirectResponse
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from starlette.responses import FileResponse
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
@ -16,6 +17,9 @@ from open_webui.utils.task import (
|
|||||||
from open_webui.utils.auth import get_admin_user, get_verified_user
|
from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||||
from open_webui.constants import TASKS
|
from open_webui.constants import TASKS
|
||||||
|
|
||||||
|
from open_webui.routers.pipelines import process_pipeline_inlet_filter
|
||||||
|
from open_webui.utils.task import get_task_model_id
|
||||||
|
|
||||||
from open_webui.config import (
|
from open_webui.config import (
|
||||||
DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
|
DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
|
||||||
DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
|
DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,
|
||||||
@ -121,9 +125,7 @@ async def update_task_config(
|
|||||||
async def generate_title(
|
async def generate_title(
|
||||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||||
):
|
):
|
||||||
|
models = request.app.state.MODELS
|
||||||
model_list = await get_all_models()
|
|
||||||
models = {model["id"]: model for model in model_list}
|
|
||||||
|
|
||||||
model_id = form_data["model"]
|
model_id = form_data["model"]
|
||||||
if model_id not in models:
|
if model_id not in models:
|
||||||
@ -191,7 +193,7 @@ Artificial Intelligence in Healthcare
|
|||||||
|
|
||||||
# Handle pipeline filters
|
# Handle pipeline filters
|
||||||
try:
|
try:
|
||||||
payload = filter_pipeline(payload, user, models)
|
payload = process_pipeline_inlet_filter(payload, user, models)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if len(e.args) > 1:
|
if len(e.args) > 1:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
@ -220,8 +222,7 @@ async def generate_chat_tags(
|
|||||||
content={"detail": "Tags generation is disabled"},
|
content={"detail": "Tags generation is disabled"},
|
||||||
)
|
)
|
||||||
|
|
||||||
model_list = await get_all_models()
|
models = request.app.state.MODELS
|
||||||
models = {model["id"]: model for model in model_list}
|
|
||||||
|
|
||||||
model_id = form_data["model"]
|
model_id = form_data["model"]
|
||||||
if model_id not in models:
|
if model_id not in models:
|
||||||
@ -281,7 +282,7 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] }
|
|||||||
|
|
||||||
# Handle pipeline filters
|
# Handle pipeline filters
|
||||||
try:
|
try:
|
||||||
payload = filter_pipeline(payload, user, models)
|
payload = process_pipeline_inlet_filter(payload, user, models)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if len(e.args) > 1:
|
if len(e.args) > 1:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
@ -318,8 +319,7 @@ async def generate_queries(
|
|||||||
detail=f"Query generation is disabled",
|
detail=f"Query generation is disabled",
|
||||||
)
|
)
|
||||||
|
|
||||||
model_list = await get_all_models()
|
models = request.app.state.MODELS
|
||||||
models = {model["id"]: model for model in model_list}
|
|
||||||
|
|
||||||
model_id = form_data["model"]
|
model_id = form_data["model"]
|
||||||
if model_id not in models:
|
if model_id not in models:
|
||||||
@ -363,7 +363,7 @@ async def generate_queries(
|
|||||||
|
|
||||||
# Handle pipeline filters
|
# Handle pipeline filters
|
||||||
try:
|
try:
|
||||||
payload = filter_pipeline(payload, user, models)
|
payload = process_pipeline_inlet_filter(payload, user, models)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if len(e.args) > 1:
|
if len(e.args) > 1:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
@ -405,8 +405,7 @@ async def generate_autocompletion(
|
|||||||
detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
|
detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
|
||||||
)
|
)
|
||||||
|
|
||||||
model_list = await get_all_models()
|
models = request.app.state.MODELS
|
||||||
models = {model["id"]: model for model in model_list}
|
|
||||||
|
|
||||||
model_id = form_data["model"]
|
model_id = form_data["model"]
|
||||||
if model_id not in models:
|
if model_id not in models:
|
||||||
@ -450,7 +449,7 @@ async def generate_autocompletion(
|
|||||||
|
|
||||||
# Handle pipeline filters
|
# Handle pipeline filters
|
||||||
try:
|
try:
|
||||||
payload = filter_pipeline(payload, user, models)
|
payload = process_pipeline_inlet_filter(payload, user, models)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if len(e.args) > 1:
|
if len(e.args) > 1:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
@ -473,8 +472,7 @@ async def generate_emoji(
|
|||||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||||
):
|
):
|
||||||
|
|
||||||
model_list = await get_all_models()
|
models = request.app.state.MODELS
|
||||||
models = {model["id"]: model for model in model_list}
|
|
||||||
|
|
||||||
model_id = form_data["model"]
|
model_id = form_data["model"]
|
||||||
if model_id not in models:
|
if model_id not in models:
|
||||||
@ -525,7 +523,7 @@ Message: """{{prompt}}"""
|
|||||||
|
|
||||||
# Handle pipeline filters
|
# Handle pipeline filters
|
||||||
try:
|
try:
|
||||||
payload = filter_pipeline(payload, user, models)
|
payload = process_pipeline_inlet_filter(payload, user, models)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if len(e.args) > 1:
|
if len(e.args) > 1:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
@ -548,10 +546,9 @@ async def generate_moa_response(
|
|||||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||||
):
|
):
|
||||||
|
|
||||||
model_list = await get_all_models()
|
models = request.app.state.MODELS
|
||||||
models = {model["id"]: model for model in model_list}
|
|
||||||
|
|
||||||
model_id = form_data["model"]
|
model_id = form_data["model"]
|
||||||
|
|
||||||
if model_id not in models:
|
if model_id not in models:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
@ -593,7 +590,7 @@ Responses from models: {{responses}}"""
|
|||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
payload = filter_pipeline(payload, user, models)
|
payload = process_pipeline_inlet_filter(payload, user, models)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if len(e.args) > 1:
|
if len(e.args) > 1:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
|
@ -16,6 +16,22 @@ log = logging.getLogger(__name__)
|
|||||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||||
|
|
||||||
|
|
||||||
|
def get_task_model_id(
|
||||||
|
default_model_id: str, task_model: str, task_model_external: str, models
|
||||||
|
) -> str:
|
||||||
|
# Set the task model
|
||||||
|
task_model_id = default_model_id
|
||||||
|
# Check if the user has a custom task model and use that model
|
||||||
|
if models[task_model_id]["owned_by"] == "ollama":
|
||||||
|
if task_model and task_model in models:
|
||||||
|
task_model_id = task_model
|
||||||
|
else:
|
||||||
|
if task_model_external and task_model_external in models:
|
||||||
|
task_model_id = task_model_external
|
||||||
|
|
||||||
|
return task_model_id
|
||||||
|
|
||||||
|
|
||||||
def prompt_template(
|
def prompt_template(
|
||||||
template: str, user_name: Optional[str] = None, user_location: Optional[str] = None
|
template: str, user_name: Optional[str] = None, user_location: Optional[str] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
|
Loading…
Reference in New Issue
Block a user