Timothy Jaeryang Baek 2024-12-11 19:52:46 -08:00
parent 772f5ccd60
commit fe5519e0a2
5 changed files with 236 additions and 212 deletions

open_webui/main.py

@@ -75,6 +75,11 @@ from open_webui.routers.retrieval import (
     get_ef,
     get_rf,
 )
+from open_webui.routers.pipelines import (
+    process_pipeline_inlet_filter,
+    process_pipeline_outlet_filter,
+)
 from open_webui.retrieval.utils import get_sources_from_files

@@ -290,6 +295,7 @@ from open_webui.utils.response import (
 )
 from open_webui.utils.task import (
+    get_task_model_id,
     rag_template,
     tools_function_calling_generation_template,
 )

@@ -662,6 +668,9 @@ app.state.MODELS = {}
 ##################################

+async def chat_completion_filter_functions_handler(body, model, extra_params):
+    skip_files = None
+
 def get_filter_function_ids(model):
     def get_priority(function_id):
         function = Functions.get_function_by_id(function_id)

@@ -670,7 +679,9 @@ def get_filter_function_ids(model):
             return (function.valves if function.valves else {}).get("priority", 0)
         return 0

-    filter_ids = [function.id for function in Functions.get_global_filter_functions()]
+    filter_ids = [
+        function.id for function in Functions.get_global_filter_functions()
+    ]
     if "info" in model and "meta" in model["info"]:
         filter_ids.extend(model["info"]["meta"].get("filterIds", []))
     filter_ids = list(set(filter_ids))

@@ -687,10 +698,6 @@ def get_filter_function_ids(model):
     filter_ids.sort(key=get_priority)
     return filter_ids

-async def chat_completion_filter_functions_handler(body, model, extra_params):
-    skip_files = None
-
     filter_ids = get_filter_function_ids(model)
     for filter_id in filter_ids:
         filter = Functions.get_function_by_id(filter_id)

@@ -791,22 +798,6 @@ async def get_content_from_response(response) -> Optional[str]:
     return content

-def get_task_model_id(
-    default_model_id: str, task_model: str, task_model_external: str, models
-) -> str:
-    # Set the task model
-    task_model_id = default_model_id
-
-    # Check if the user has a custom task model and use that model
-    if models[task_model_id]["owned_by"] == "ollama":
-        if task_model and task_model in models:
-            task_model_id = task_model
-    else:
-        if task_model_external and task_model_external in models:
-            task_model_id = task_model_external
-
-    return task_model_id
-
 async def chat_completion_tools_handler(
     body: dict, user: UserModel, models, extra_params: dict
 ) -> tuple[dict, dict]:

@@ -857,7 +848,7 @@ async def chat_completion_tools_handler(
     )

     try:
-        payload = filter_pipeline(payload, user, models)
+        payload = process_pipeline_inlet_filter(request, payload, user, models)
     except Exception as e:
         raise e
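
Note: filter_pipeline becomes process_pipeline_inlet_filter, which takes the incoming Request so it can resolve per-app config without importing a module-level app. A minimal sketch of the lookup it performs (the helper name here is illustrative; the attribute paths match the new helper in routers/pipelines.py below):

    from fastapi import Request

    def read_pipeline_upstream(request: Request, url_idx: int) -> tuple[str, str]:
        # each filter's upstream URL/key now comes off the request's app state
        url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
        key = request.app.state.config.OPENAI_API_KEYS[url_idx]
        return url, key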

@@ -1153,7 +1144,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 if prompt is None:
                     raise Exception("No user message found")
                 if (
-                    retrieval_app.state.config.RELEVANCE_THRESHOLD == 0
+                    app.state.config.RELEVANCE_THRESHOLD == 0
                     and context_string.strip() == ""
                 ):
                     log.debug(

@@ -1164,16 +1155,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 # TODO: replace with add_or_update_system_message
                 if model["owned_by"] == "ollama":
                     body["messages"] = prepend_to_first_user_message_content(
-                        rag_template(
-                            retrieval_app.state.config.RAG_TEMPLATE, context_string, prompt
-                        ),
+                        rag_template(app.state.config.RAG_TEMPLATE, context_string, prompt),
                         body["messages"],
                     )
                 else:
                     body["messages"] = add_or_update_system_message(
-                        rag_template(
-                            retrieval_app.state.config.RAG_TEMPLATE, context_string, prompt
-                        ),
+                        rag_template(app.state.config.RAG_TEMPLATE, context_string, prompt),
                         body["messages"],
                     )

@@ -1225,77 +1212,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
 app.add_middleware(ChatCompletionMiddleware)

-##################################
-#
-# Pipeline Middleware
-#
-##################################
-
-def get_sorted_filters(model_id, models):
-    filters = [
-        model
-        for model in models.values()
-        if "pipeline" in model
-        and "type" in model["pipeline"]
-        and model["pipeline"]["type"] == "filter"
-        and (
-            model["pipeline"]["pipelines"] == ["*"]
-            or any(
-                model_id == target_model_id
-                for target_model_id in model["pipeline"]["pipelines"]
-            )
-        )
-    ]
-    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
-    return sorted_filters
-
-def filter_pipeline(payload, user, models):
-    user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
-    model_id = payload["model"]
-
-    sorted_filters = get_sorted_filters(model_id, models)
-    model = models[model_id]
-
-    if "pipeline" in model:
-        sorted_filters.append(model)
-
-    for filter in sorted_filters:
-        r = None
-        try:
-            urlIdx = filter["urlIdx"]
-            url = app.state.config.OPENAI_API_BASE_URLS[urlIdx]
-            key = app.state.config.OPENAI_API_KEYS[urlIdx]
-
-            if key == "":
-                continue
-
-            headers = {"Authorization": f"Bearer {key}"}
-            r = requests.post(
-                f"{url}/{filter['id']}/filter/inlet",
-                headers=headers,
-                json={
-                    "user": user,
-                    "body": payload,
-                },
-            )
-
-            r.raise_for_status()
-            payload = r.json()
-        except Exception as e:
-            # Handle connection error here
-            print(f"Connection error: {e}")
-
-            if r is not None:
-                res = r.json()
-                if "detail" in res:
-                    raise Exception(r.status_code, res["detail"])
-
-    return payload
-
 class PipelineMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
         if not request.method == "POST" and any(

@@ -1335,11 +1251,11 @@ class PipelineMiddleware(BaseHTTPMiddleware):
                 content={"detail": e.detail},
             )

-        model_list = await get_all_models()
-        models = {model["id"]: model for model in model_list}
+        await get_all_models()
+        models = app.state.MODELS

         try:
-            data = filter_pipeline(data, user, models)
+            data = process_pipeline_inlet_filter(request, data, user, models)
         except Exception as e:
             if len(e.args) > 1:
                 return JSONResponse(
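
Note: the middleware no longer rebuilds a dict from the return value; it calls get_all_models() for its refresh side effect and then reads the registry cached on app.state.MODELS. The assumed caching contract, sketched (the fetcher is a stand-in for the real aggregation across backends):

    from fastapi import FastAPI

    app = FastAPI()

    async def fetch_models_from_backends() -> list[dict]:
        # stand-in for the real Ollama/OpenAI model aggregation
        return [{"id": "llama3", "owned_by": "ollama"}]

    async def get_all_models() -> list[dict]:
        model_list = await fetch_models_from_backends()
        # cache as {id: model} so later readers can use app.state.MODELS
        app.state.MODELS = {model["id"]: model for model in model_list}
        return model_list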

@@ -1447,8 +1363,8 @@ app.include_router(ollama.router, prefix="/ollama", tags=["ollama"])
 app.include_router(openai.router, prefix="/openai", tags=["openai"])

-app.include_router(pipelines.router, prefix="/pipelines", tags=["pipelines"])
-app.include_router(tasks.router, prefix="/tasks", tags=["tasks"])
+app.include_router(pipelines.router, prefix="/api/pipelines", tags=["pipelines"])
+app.include_router(tasks.router, prefix="/api/tasks", tags=["tasks"])

 app.include_router(images.router, prefix="/api/v1/images", tags=["images"])
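
Note: both routers move under /api here, while the decorators inside routers/pipelines.py (third file below) drop their hard-coded /api/pipelines segment, so the public URLs keep their old shape. A minimal illustration of the mount/route split:

    from fastapi import APIRouter, FastAPI

    router = APIRouter()

    @router.get("/list")  # relative path; was "/api/pipelines/list" in the decorator
    async def get_pipelines_list():
        return {"data": []}

    app = FastAPI()
    app.include_router(router, prefix="/api/pipelines")  # prefix supplied once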

@@ -2105,7 +2021,6 @@ async def generate_chat_completions(
     if model["owned_by"] == "ollama":
         # Using /ollama/api/chat endpoint
         form_data = convert_payload_openai_to_ollama(form_data)
-        form_data = GenerateChatCompletionForm(**form_data)
         response = await generate_ollama_chat_completion(
             form_data=form_data, user=user, bypass_filter=bypass_filter
         )

@@ -2124,7 +2039,9 @@ async def generate_chat_completions(
 @app.post("/api/chat/completed")
-async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
+async def chat_completed(
+    request: Request, form_data: dict, user=Depends(get_verified_user)
+):
     model_list = await get_all_models()
     models = {model["id"]: model for model in model_list}

@@ -2137,53 +2054,14 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
         )
     model = models[model_id]

-    sorted_filters = get_sorted_filters(model_id, models)
-    if "pipeline" in model:
-        sorted_filters = [model] + sorted_filters
-
-    for filter in sorted_filters:
-        r = None
-        try:
-            urlIdx = filter["urlIdx"]
-            url = app.state.config.OPENAI_API_BASE_URLS[urlIdx]
-            key = app.state.config.OPENAI_API_KEYS[urlIdx]
-
-            if key != "":
-                headers = {"Authorization": f"Bearer {key}"}
-                r = requests.post(
-                    f"{url}/{filter['id']}/filter/outlet",
-                    headers=headers,
-                    json={
-                        "user": {
-                            "id": user.id,
-                            "name": user.name,
-                            "email": user.email,
-                            "role": user.role,
-                        },
-                        "body": data,
-                    },
-                )
-
-                r.raise_for_status()
-                data = r.json()
-        except Exception as e:
-            # Handle connection error here
-            print(f"Connection error: {e}")
-
-            if r is not None:
-                try:
-                    res = r.json()
-                    if "detail" in res:
-                        return JSONResponse(
-                            status_code=r.status_code,
-                            content=res,
-                        )
-                except Exception:
-                    pass
-
-            else:
-                pass
+    try:
+        data = process_pipeline_outlet_filter(request, data, user, models)
+    except Exception as e:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=str(e),
+        )

     __event_emitter__ = get_event_emitter(
         {
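
Note: the inline outlet loop collapses into one call to the shared helper from routers/pipelines.py, and chat_completed gains a request parameter (hunk above) precisely so it can pass it through; any filter failure surfaces as a 400. The same shape as a standalone sketch (the wrapper name is illustrative):

    from fastapi import HTTPException, Request, status
    from open_webui.routers.pipelines import process_pipeline_outlet_filter

    async def run_outlet(request: Request, data: dict, user, models) -> dict:
        try:
            return process_pipeline_outlet_filter(request, data, user, models)
        except Exception as e:
            # raising (not returning) lets FastAPI render a JSON 400 response
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)
            )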

@@ -2455,8 +2333,8 @@ async def get_app_config(request: Request):
                 "enable_login_form": app.state.config.ENABLE_LOGIN_FORM,
                 **(
                     {
-                        "enable_web_search": retrieval_app.state.config.ENABLE_RAG_WEB_SEARCH,
-                        "enable_image_generation": images_app.state.config.ENABLED,
+                        "enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH,
+                        "enable_image_generation": app.state.config.ENABLE_IMAGE_GENERATION,
                         "enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING,
                         "enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING,
                         "enable_admin_export": ENABLE_ADMIN_EXPORT,

@@ -2472,17 +2350,17 @@ async def get_app_config(request: Request):
             "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
             "audio": {
                 "tts": {
-                    "engine": audio_app.state.config.TTS_ENGINE,
-                    "voice": audio_app.state.config.TTS_VOICE,
-                    "split_on": audio_app.state.config.TTS_SPLIT_ON,
+                    "engine": app.state.config.TTS_ENGINE,
+                    "voice": app.state.config.TTS_VOICE,
+                    "split_on": app.state.config.TTS_SPLIT_ON,
                 },
                 "stt": {
-                    "engine": audio_app.state.config.STT_ENGINE,
+                    "engine": app.state.config.STT_ENGINE,
                 },
             },
             "file": {
-                "max_size": retrieval_app.state.config.FILE_MAX_SIZE,
-                "max_count": retrieval_app.state.config.FILE_MAX_COUNT,
+                "max_size": app.state.config.FILE_MAX_SIZE,
+                "max_count": app.state.config.FILE_MAX_COUNT,
             },
             "permissions": {**app.state.config.USER_PERMISSIONS},
         }

open_webui/routers/ollama.py

@@ -941,7 +941,7 @@ async def get_ollama_url(request: Request, model: str, url_idx: Optional[int] =
 @router.post("/api/chat/{url_idx}")
 async def generate_chat_completion(
     request: Request,
-    form_data: GenerateChatCompletionForm,
+    form_data: dict,
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
     bypass_filter: Optional[bool] = False,

@@ -949,6 +949,15 @@ async def generate_chat_completion(
     if BYPASS_MODEL_ACCESS_CONTROL:
         bypass_filter = True

+    try:
+        form_data = GenerateChatCompletionForm(**form_data)
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(
+            status_code=400,
+            detail=str(e),
+        )
+
     payload = {**form_data.model_dump(exclude_none=True)}
     if "metadata" in payload:
         del payload["metadata"]
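
Note: accepting form_data as a plain dict and validating inside the handler lets callers such as generate_chat_completions in main.py (which no longer pre-builds the form, per the hunk above) pass dicts straight through, and malformed input now yields a 400 with a readable message rather than FastAPI's automatic 422. The pattern in isolation (the two-field form is a stand-in, not the real schema):

    from fastapi import HTTPException
    from pydantic import BaseModel

    class ChatForm(BaseModel):  # stand-in for GenerateChatCompletionForm
        model: str
        messages: list

    def validate_form(form_data: dict) -> ChatForm:
        try:
            return ChatForm(**form_data)
        except Exception as e:
            raise HTTPException(status_code=400, detail=str(e))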

open_webui/routers/pipelines.py

@@ -30,6 +30,130 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MAIN"])

+##################################
+#
+# Pipeline Middleware
+#
+##################################
+
+def get_sorted_filters(model_id, models):
+    filters = [
+        model
+        for model in models.values()
+        if "pipeline" in model
+        and "type" in model["pipeline"]
+        and model["pipeline"]["type"] == "filter"
+        and (
+            model["pipeline"]["pipelines"] == ["*"]
+            or any(
+                model_id == target_model_id
+                for target_model_id in model["pipeline"]["pipelines"]
+            )
+        )
+    ]
+    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
+    return sorted_filters
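
Note: a pipeline qualifies as a filter for a model when it targets every model (["*"]) or names that model explicitly, and lower "priority" values run first. An illustrative registry (the dict shapes are assumed from the code above):

    from open_webui.routers.pipelines import get_sorted_filters

    models = {
        "guard": {"pipeline": {"type": "filter", "pipelines": ["*"], "priority": 10}},
        "redact": {"pipeline": {"type": "filter", "pipelines": ["gpt-4"], "priority": 0}},
        "gpt-4": {},  # no "pipeline" key, so never selected as a filter
    }
    sorted_filters = get_sorted_filters("gpt-4", models)
    assert [f["pipeline"]["priority"] for f in sorted_filters] == [0, 10]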

+def process_pipeline_inlet_filter(request, payload, user, models):
+    user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
+    model_id = payload["model"]
+
+    sorted_filters = get_sorted_filters(model_id, models)
+    model = models[model_id]
+
+    if "pipeline" in model:
+        sorted_filters.append(model)
+
+    for filter in sorted_filters:
+        r = None
+        try:
+            urlIdx = filter["urlIdx"]
+            url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
+            key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
+
+            if key == "":
+                continue
+
+            headers = {"Authorization": f"Bearer {key}"}
+            r = requests.post(
+                f"{url}/{filter['id']}/filter/inlet",
+                headers=headers,
+                json={
+                    "user": user,
+                    "body": payload,
+                },
+            )
+
+            r.raise_for_status()
+            payload = r.json()
+        except Exception as e:
+            # Handle connection error here
+            print(f"Connection error: {e}")
+
+            if r is not None:
+                res = r.json()
+                if "detail" in res:
+                    raise Exception(r.status_code, res["detail"])
+
+    return payload
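
Note: the inlet contract is a POST to {url}/{id}/filter/inlet with a {"user", "body"} JSON payload, and the response body becomes the new payload for the next filter. A hypothetical pipeline-server endpoint that satisfies it:

    from fastapi import FastAPI

    pipeline_app = FastAPI()  # hypothetical pipeline server

    @pipeline_app.post("/{pipeline_id}/filter/inlet")
    async def inlet(pipeline_id: str, form: dict):
        payload = form["body"]
        # mutate and return; Open WebUI adopts whatever comes back
        payload.setdefault("metadata", {})["filtered_by"] = pipeline_id
        return payload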

+def process_pipeline_outlet_filter(request, payload, user, models):
+    user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
+    model_id = payload["model"]
+
+    sorted_filters = get_sorted_filters(model_id, models)
+    model = models[model_id]
+
+    if "pipeline" in model:
+        sorted_filters = [model] + sorted_filters
+
+    for filter in sorted_filters:
+        r = None
+        try:
+            urlIdx = filter["urlIdx"]
+            url = request.app.state.config.OPENAI_API_BASE_URLS[urlIdx]
+            key = request.app.state.config.OPENAI_API_KEYS[urlIdx]
+
+            if key != "":
+                r = requests.post(
+                    f"{url}/{filter['id']}/filter/outlet",
+                    headers={"Authorization": f"Bearer {key}"},
+                    json={
+                        # `user` was flattened to a dict above, so index by key
+                        "user": {
+                            "id": user["id"],
+                            "name": user["name"],
+                            "email": user["email"],
+                            "role": user["role"],
+                        },
+                        "body": payload,
+                    },
+                )
+
+                r.raise_for_status()
+                payload = r.json()
+        except Exception as e:
+            # Handle connection error here
+            print(f"Connection error: {e}")
+
+            if r is not None:
+                try:
+                    res = r.json()
+                except Exception:
+                    res = None
+                if res and "detail" in res:
+                    raise Exception(r.status_code, res["detail"])
+
+    return payload
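
Note: the outlet pass mirrors the inlet but runs the model's own pipeline first and posts the completed-chat payload to /filter/outlet. The hypothetical counterpart endpoint:

    from fastapi import FastAPI

    pipeline_app = FastAPI()  # hypothetical pipeline server

    @pipeline_app.post("/{pipeline_id}/filter/outlet")
    async def outlet(pipeline_id: str, form: dict):
        # form == {"user": {...}, "body": completed-chat payload}
        payload = form["body"]
        # e.g. post-process the assistant reply before it is stored
        return payload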

 ##################################
 #
 # Pipelines Endpoints

@@ -39,7 +163,7 @@ log.setLevel(SRC_LOG_LEVELS["MAIN"])
 router = APIRouter()

-@router.get("/api/pipelines/list")
+@router.get("/list")
 async def get_pipelines_list(request: Request, user=Depends(get_admin_user)):
     responses = await get_all_models_responses(request)
     log.debug(f"get_pipelines_list: get_openai_models_responses returned {responses}")

@@ -61,7 +185,7 @@ async def get_pipelines_list(request: Request, user=Depends(get_admin_user)):
     }

-@router.post("/api/pipelines/upload")
+@router.post("/upload")
 async def upload_pipeline(
     request: Request,
     urlIdx: int = Form(...),

@@ -131,7 +255,7 @@ class AddPipelineForm(BaseModel):
     urlIdx: int

-@router.post("/api/pipelines/add")
+@router.post("/add")
 async def add_pipeline(
     request: Request, form_data: AddPipelineForm, user=Depends(get_admin_user)
 ):

@@ -176,7 +300,7 @@ class DeletePipelineForm(BaseModel):
     urlIdx: int

-@router.delete("/api/pipelines/delete")
+@router.delete("/delete")
 async def delete_pipeline(
     request: Request, form_data: DeletePipelineForm, user=Depends(get_admin_user)
 ):

@@ -216,7 +340,7 @@ async def delete_pipeline(
     )

-@router.get("/api/pipelines")
+@router.get("/")
 async def get_pipelines(
     request: Request, urlIdx: Optional[int] = None, user=Depends(get_admin_user)
 ):

@@ -250,7 +374,7 @@ async def get_pipelines(
     )

-@router.get("/api/pipelines/{pipeline_id}/valves")
+@router.get("/{pipeline_id}/valves")
 async def get_pipeline_valves(
     request: Request,
     urlIdx: Optional[int],

@@ -289,7 +413,7 @@ async def get_pipeline_valves(
     )

-@router.get("/api/pipelines/{pipeline_id}/valves/spec")
+@router.get("/{pipeline_id}/valves/spec")
 async def get_pipeline_valves_spec(
     request: Request,
     urlIdx: Optional[int],

@@ -329,7 +453,7 @@ async def get_pipeline_valves_spec(
     )

-@router.post("/api/pipelines/{pipeline_id}/valves/update")
+@router.post("/{pipeline_id}/valves/update")
 async def update_pipeline_valves(
     request: Request,
     urlIdx: Optional[int],
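
Note: with the router mounted at /api/pipelines in main.py, each relative decorator above resolves to its old absolute path. A quick check with FastAPI's test client (the handler body is illustrative):

    from fastapi import APIRouter, FastAPI
    from fastapi.testclient import TestClient

    router = APIRouter()

    @router.get("/{pipeline_id}/valves")  # declared relative
    async def get_pipeline_valves(pipeline_id: str):
        return {"pipeline_id": pipeline_id, "valves": {}}

    app = FastAPI()
    app.include_router(router, prefix="/api/pipelines")

    client = TestClient(app)
    assert client.get("/api/pipelines/demo/valves").json()["pipeline_id"] == "demo"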

open_webui/routers/tasks.py

@@ -1,6 +1,7 @@
 from fastapi import APIRouter, Depends, HTTPException, Response, status, Request
+from fastapi.responses import JSONResponse, RedirectResponse
 from pydantic import BaseModel
-from starlette.responses import FileResponse
 from typing import Optional

 import logging

@@ -16,6 +17,9 @@ from open_webui.utils.task import (
 from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.constants import TASKS
+from open_webui.routers.pipelines import process_pipeline_inlet_filter
+from open_webui.utils.task import get_task_model_id
 from open_webui.config import (
     DEFAULT_QUERY_GENERATION_PROMPT_TEMPLATE,
     DEFAULT_AUTOCOMPLETE_GENERATION_PROMPT_TEMPLATE,

@@ -121,9 +125,7 @@ async def update_task_config(
 async def generate_title(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
-    model_list = await get_all_models()
-    models = {model["id"]: model for model in model_list}
+    models = request.app.state.MODELS

     model_id = form_data["model"]
     if model_id not in models:

@@ -191,7 +193,7 @@ Artificial Intelligence in Healthcare
     # Handle pipeline filters
     try:
-        payload = filter_pipeline(payload, user, models)
+        payload = process_pipeline_inlet_filter(request, payload, user, models)
     except Exception as e:
         if len(e.args) > 1:
             return JSONResponse(
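
Note: every task endpoint in this file switches from rebuilding the model dict per call to reading the registry cached on app.state, then routes the payload through process_pipeline_inlet_filter. The shared lookup shape, sketched (the helper name is illustrative):

    from fastapi import HTTPException, Request, status

    def resolve_model(request: Request, model_id: str) -> dict:
        models = request.app.state.MODELS  # populated by get_all_models()
        if model_id not in models:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail=f"Model not found: {model_id}",
            )
        return models[model_id]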

@@ -220,8 +222,7 @@ async def generate_chat_tags(
             content={"detail": "Tags generation is disabled"},
         )

-    model_list = await get_all_models()
-    models = {model["id"]: model for model in model_list}
+    models = request.app.state.MODELS

     model_id = form_data["model"]
     if model_id not in models:

@@ -281,7 +282,7 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] }
     # Handle pipeline filters
     try:
-        payload = filter_pipeline(payload, user, models)
+        payload = process_pipeline_inlet_filter(request, payload, user, models)
     except Exception as e:
         if len(e.args) > 1:
             return JSONResponse(

@@ -318,8 +319,7 @@ async def generate_queries(
             detail=f"Query generation is disabled",
         )

-    model_list = await get_all_models()
-    models = {model["id"]: model for model in model_list}
+    models = request.app.state.MODELS

     model_id = form_data["model"]
     if model_id not in models:

@@ -363,7 +363,7 @@ async def generate_queries(
     # Handle pipeline filters
     try:
-        payload = filter_pipeline(payload, user, models)
+        payload = process_pipeline_inlet_filter(request, payload, user, models)
     except Exception as e:
         if len(e.args) > 1:
             return JSONResponse(

@@ -405,8 +405,7 @@ async def generate_autocompletion(
             detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
         )

-    model_list = await get_all_models()
-    models = {model["id"]: model for model in model_list}
+    models = request.app.state.MODELS

     model_id = form_data["model"]
     if model_id not in models:

@@ -450,7 +449,7 @@ async def generate_autocompletion(
     # Handle pipeline filters
     try:
-        payload = filter_pipeline(payload, user, models)
+        payload = process_pipeline_inlet_filter(request, payload, user, models)
     except Exception as e:
         if len(e.args) > 1:
             return JSONResponse(

@@ -473,8 +472,7 @@ async def generate_emoji(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
-    model_list = await get_all_models()
-    models = {model["id"]: model for model in model_list}
+    models = request.app.state.MODELS

     model_id = form_data["model"]
     if model_id not in models:

@@ -525,7 +523,7 @@ Message: """{{prompt}}"""
     # Handle pipeline filters
     try:
-        payload = filter_pipeline(payload, user, models)
+        payload = process_pipeline_inlet_filter(request, payload, user, models)
     except Exception as e:
         if len(e.args) > 1:
             return JSONResponse(

@@ -548,10 +546,9 @@ async def generate_moa_response(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
-    model_list = await get_all_models()
-    models = {model["id"]: model for model in model_list}
+    models = request.app.state.MODELS

     model_id = form_data["model"]
     if model_id not in models:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,

@@ -593,7 +590,7 @@ Responses from models: {{responses}}"""
     }

     try:
-        payload = filter_pipeline(payload, user, models)
+        payload = process_pipeline_inlet_filter(request, payload, user, models)
     except Exception as e:
         if len(e.args) > 1:
             return JSONResponse(

open_webui/utils/task.py

@@ -16,6 +16,22 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])

+def get_task_model_id(
+    default_model_id: str, task_model: str, task_model_external: str, models
+) -> str:
+    # Set the task model
+    task_model_id = default_model_id
+
+    # Check if the user has a custom task model and use that model
+    if models[task_model_id]["owned_by"] == "ollama":
+        if task_model and task_model in models:
+            task_model_id = task_model
+    else:
+        if task_model_external and task_model_external in models:
+            task_model_id = task_model_external
+
+    return task_model_id
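
Note: the helper arrives here unchanged from main.py. An illustrative call with a hypothetical registry: an Ollama-owned default prefers the local task model, and anything else falls back to the external one.

    from open_webui.utils.task import get_task_model_id

    models = {
        "llama3": {"owned_by": "ollama"},
        "llama3-task": {"owned_by": "ollama"},
        "gpt-4o": {"owned_by": "openai"},
    }
    assert get_task_model_id("llama3", "llama3-task", "gpt-4o", models) == "llama3-task"
    assert get_task_model_id("gpt-4o", "llama3-task", "gpt-4o", models) == "gpt-4o"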

 def prompt_template(
     template: str, user_name: Optional[str] = None, user_location: Optional[str] = None
 ) -> str: