mirror of
https://github.com/open-webui/open-webui
synced 2024-11-21 23:57:51 +00:00
refac
This commit is contained in:
parent
931e03bd9e
commit
932de8f1e2
@ -68,25 +68,12 @@ app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
|
||||
app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
|
||||
app.state.config.OLLAMA_API_CONFIGS = OLLAMA_API_CONFIGS
|
||||
|
||||
app.state.MODELS = {}
|
||||
|
||||
|
||||
# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
|
||||
# Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin,
|
||||
# least connections, or least response time for better resource utilization and performance optimization.
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def check_url(request: Request, call_next):
|
||||
if len(app.state.MODELS) == 0:
|
||||
await get_all_models()
|
||||
else:
|
||||
pass
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
@app.head("/")
|
||||
@app.get("/")
|
||||
async def get_status():
|
||||
@ -321,8 +308,6 @@ async def get_all_models():
|
||||
else:
|
||||
models = {"models": []}
|
||||
|
||||
app.state.MODELS = {model["model"]: model for model in models["models"]}
|
||||
|
||||
return models
|
||||
|
||||
|
||||
@ -470,8 +455,11 @@ async def push_model(
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
if url_idx is None:
|
||||
if form_data.name in app.state.MODELS:
|
||||
url_idx = app.state.MODELS[form_data.name]["urls"][0]
|
||||
model_list = await get_all_models()
|
||||
models = {model["model"]: model for model in model_list["models"]}
|
||||
|
||||
if form_data.name in models:
|
||||
url_idx = models[form_data.name]["urls"][0]
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@ -520,8 +508,11 @@ async def copy_model(
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
if url_idx is None:
|
||||
if form_data.source in app.state.MODELS:
|
||||
url_idx = app.state.MODELS[form_data.source]["urls"][0]
|
||||
model_list = await get_all_models()
|
||||
models = {model["model"]: model for model in model_list["models"]}
|
||||
|
||||
if form_data.source in models:
|
||||
url_idx = models[form_data.source]["urls"][0]
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@ -576,8 +567,11 @@ async def delete_model(
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
if url_idx is None:
|
||||
if form_data.name in app.state.MODELS:
|
||||
url_idx = app.state.MODELS[form_data.name]["urls"][0]
|
||||
model_list = await get_all_models()
|
||||
models = {model["model"]: model for model in model_list["models"]}
|
||||
|
||||
if form_data.name in models:
|
||||
url_idx = models[form_data.name]["urls"][0]
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@ -625,13 +619,16 @@ async def delete_model(
|
||||
|
||||
@app.post("/api/show")
|
||||
async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)):
|
||||
if form_data.name not in app.state.MODELS:
|
||||
model_list = await get_all_models()
|
||||
models = {model["model"]: model for model in model_list["models"]}
|
||||
|
||||
if form_data.name not in models:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
|
||||
)
|
||||
|
||||
url_idx = random.choice(app.state.MODELS[form_data.name]["urls"])
|
||||
url_idx = random.choice(models[form_data.name]["urls"])
|
||||
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
log.info(f"url: {url}")
|
||||
|
||||
@ -701,23 +698,26 @@ async def generate_embeddings(
|
||||
url_idx: Optional[int] = None,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
return generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)
|
||||
return await generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)
|
||||
|
||||
|
||||
def generate_ollama_embeddings(
|
||||
async def generate_ollama_embeddings(
|
||||
form_data: GenerateEmbeddingsForm,
|
||||
url_idx: Optional[int] = None,
|
||||
):
|
||||
log.info(f"generate_ollama_embeddings {form_data}")
|
||||
|
||||
if url_idx is None:
|
||||
model_list = await get_all_models()
|
||||
models = {model["model"]: model for model in model_list["models"]}
|
||||
|
||||
model = form_data.model
|
||||
|
||||
if ":" not in model:
|
||||
model = f"{model}:latest"
|
||||
|
||||
if model in app.state.MODELS:
|
||||
url_idx = random.choice(app.state.MODELS[model]["urls"])
|
||||
if model in models:
|
||||
url_idx = random.choice(models[model]["urls"])
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@ -768,20 +768,23 @@ def generate_ollama_embeddings(
|
||||
)
|
||||
|
||||
|
||||
def generate_ollama_batch_embeddings(
|
||||
async def generate_ollama_batch_embeddings(
|
||||
form_data: GenerateEmbedForm,
|
||||
url_idx: Optional[int] = None,
|
||||
):
|
||||
log.info(f"generate_ollama_batch_embeddings {form_data}")
|
||||
|
||||
if url_idx is None:
|
||||
model_list = await get_all_models()
|
||||
models = {model["model"]: model for model in model_list["models"]}
|
||||
|
||||
model = form_data.model
|
||||
|
||||
if ":" not in model:
|
||||
model = f"{model}:latest"
|
||||
|
||||
if model in app.state.MODELS:
|
||||
url_idx = random.choice(app.state.MODELS[model]["urls"])
|
||||
if model in models:
|
||||
url_idx = random.choice(models[model]["urls"])
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@ -851,13 +854,16 @@ async def generate_completion(
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
if url_idx is None:
|
||||
model_list = await get_all_models()
|
||||
models = {model["model"]: model for model in model_list["models"]}
|
||||
|
||||
model = form_data.model
|
||||
|
||||
if ":" not in model:
|
||||
model = f"{model}:latest"
|
||||
|
||||
if model in app.state.MODELS:
|
||||
url_idx = random.choice(app.state.MODELS[model]["urls"])
|
||||
if model in models:
|
||||
url_idx = random.choice(models[model]["urls"])
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@ -892,14 +898,17 @@ class GenerateChatCompletionForm(BaseModel):
|
||||
keep_alive: Optional[Union[int, str]] = None
|
||||
|
||||
|
||||
def get_ollama_url(url_idx: Optional[int], model: str):
|
||||
async def get_ollama_url(url_idx: Optional[int], model: str):
|
||||
if url_idx is None:
|
||||
if model not in app.state.MODELS:
|
||||
model_list = await get_all_models()
|
||||
models = {model["model"]: model for model in model_list["models"]}
|
||||
|
||||
if model not in models:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
|
||||
)
|
||||
url_idx = random.choice(app.state.MODELS[model]["urls"])
|
||||
url_idx = random.choice(models[model]["urls"])
|
||||
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
return url
|
||||
|
||||
@ -948,7 +957,7 @@ async def generate_chat_completion(
|
||||
if ":" not in payload["model"]:
|
||||
payload["model"] = f"{payload['model']}:latest"
|
||||
|
||||
url = get_ollama_url(url_idx, payload["model"])
|
||||
url = await get_ollama_url(url_idx, payload["model"])
|
||||
log.info(f"url: {url}")
|
||||
log.debug(f"generate_chat_completion() - 2.payload = {payload}")
|
||||
|
||||
@ -1030,7 +1039,7 @@ async def generate_openai_chat_completion(
|
||||
if ":" not in payload["model"]:
|
||||
payload["model"] = f"{payload['model']}:latest"
|
||||
|
||||
url = get_ollama_url(url_idx, payload["model"])
|
||||
url = await get_ollama_url(url_idx, payload["model"])
|
||||
log.info(f"url: {url}")
|
||||
|
||||
api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
|
||||
|
@ -36,7 +36,7 @@ from open_webui.utils.payload import (
|
||||
apply_model_system_prompt_to_body,
|
||||
)
|
||||
|
||||
from open_webui.utils.utils import get_admin_user, get_verified_user
|
||||
from open_webui.utils.utils import get_admin_user, get_verified_user, has_access
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["OPENAI"])
|
||||
@ -64,17 +64,6 @@ app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
|
||||
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
|
||||
app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS
|
||||
|
||||
app.state.MODELS = {}
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def check_url(request: Request, call_next):
|
||||
if len(app.state.MODELS) == 0:
|
||||
await get_all_models()
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
|
||||
@app.get("/config")
|
||||
async def get_config(user=Depends(get_admin_user)):
|
||||
@ -259,7 +248,7 @@ def merge_models_lists(model_lists):
|
||||
return merged_list
|
||||
|
||||
|
||||
async def get_all_models_raw() -> list:
|
||||
async def get_all_models_responses() -> list:
|
||||
if not app.state.config.ENABLE_OPENAI_API:
|
||||
return []
|
||||
|
||||
@ -330,22 +319,13 @@ async def get_all_models_raw() -> list:
|
||||
return responses
|
||||
|
||||
|
||||
@overload
|
||||
async def get_all_models(raw: Literal[True]) -> list: ...
|
||||
|
||||
|
||||
@overload
|
||||
async def get_all_models(raw: Literal[False] = False) -> dict[str, list]: ...
|
||||
|
||||
|
||||
async def get_all_models(raw=False) -> dict[str, list] | list:
|
||||
async def get_all_models() -> dict[str, list]:
|
||||
log.info("get_all_models()")
|
||||
if not app.state.config.ENABLE_OPENAI_API:
|
||||
return [] if raw else {"data": []}
|
||||
|
||||
responses = await get_all_models_raw()
|
||||
if raw:
|
||||
return responses
|
||||
if not app.state.config.ENABLE_OPENAI_API:
|
||||
return {"data": []}
|
||||
|
||||
responses = await get_all_models_responses()
|
||||
|
||||
def extract_data(response):
|
||||
if response and "data" in response:
|
||||
@ -355,9 +335,7 @@ async def get_all_models(raw=False) -> dict[str, list] | list:
|
||||
return None
|
||||
|
||||
models = {"data": merge_models_lists(map(extract_data, responses))}
|
||||
|
||||
log.debug(f"models: {models}")
|
||||
app.state.MODELS = {model["id"]: model for model in models["data"]}
|
||||
|
||||
return models
|
||||
|
||||
@ -365,21 +343,12 @@ async def get_all_models(raw=False) -> dict[str, list] | list:
|
||||
@app.get("/models")
|
||||
@app.get("/models/{url_idx}")
|
||||
async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
|
||||
models = {
|
||||
"data": [],
|
||||
}
|
||||
|
||||
if url_idx is None:
|
||||
models = await get_all_models()
|
||||
|
||||
# TODO: Check User Group and Filter Models
|
||||
# if app.state.config.ENABLE_MODEL_FILTER:
|
||||
# if user.role == "user":
|
||||
# models["data"] = list(
|
||||
# filter(
|
||||
# lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
|
||||
# models["data"],
|
||||
# )
|
||||
# )
|
||||
# return models
|
||||
|
||||
return models
|
||||
else:
|
||||
url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
|
||||
key = app.state.config.OPENAI_API_KEYS[url_idx]
|
||||
@ -387,6 +356,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us
|
||||
headers = {}
|
||||
headers["Authorization"] = f"Bearer {key}"
|
||||
headers["Content-Type"] = "application/json"
|
||||
|
||||
if ENABLE_FORWARD_USER_INFO_HEADERS:
|
||||
headers["X-OpenWebUI-User-Name"] = user.name
|
||||
headers["X-OpenWebUI-User-Id"] = user.id
|
||||
@ -428,8 +398,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us
|
||||
)
|
||||
]
|
||||
|
||||
return response_data
|
||||
|
||||
models = response_data
|
||||
except aiohttp.ClientError as e:
|
||||
# ClientError covers all aiohttp requests issues
|
||||
log.exception(f"Client error: {str(e)}")
|
||||
@ -443,6 +412,22 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us
|
||||
error_detail = f"Unexpected error: {str(e)}"
|
||||
raise HTTPException(status_code=500, detail=error_detail)
|
||||
|
||||
if user.role == "user":
|
||||
# Filter models based on user access control
|
||||
filtered_models = []
|
||||
for model in models.get("data", []):
|
||||
model_info = Models.get_model_by_id(model["id"])
|
||||
if model_info:
|
||||
if has_access(
|
||||
user.id, type="read", access_control=model_info.access_control
|
||||
):
|
||||
filtered_models.append(model)
|
||||
else:
|
||||
filtered_models.append(model)
|
||||
models["data"] = filtered_models
|
||||
|
||||
return models
|
||||
|
||||
|
||||
class ConnectionVerificationForm(BaseModel):
|
||||
url: str
|
||||
@ -502,18 +487,9 @@ async def generate_chat_completion(
|
||||
del payload["metadata"]
|
||||
|
||||
model_id = form_data.get("model")
|
||||
|
||||
# TODO: Check User Group and Filter Models
|
||||
# if not bypass_filter:
|
||||
# if app.state.config.ENABLE_MODEL_FILTER:
|
||||
# if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
|
||||
# raise HTTPException(
|
||||
# status_code=403,
|
||||
# detail="Model not found",
|
||||
# )
|
||||
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
|
||||
# Check model info and override the payload
|
||||
if model_info:
|
||||
if model_info.base_model_id:
|
||||
payload["model"] = model_info.base_model_id
|
||||
@ -522,12 +498,33 @@ async def generate_chat_completion(
|
||||
payload = apply_model_params_to_body_openai(params, payload)
|
||||
payload = apply_model_system_prompt_to_body(params, payload, user)
|
||||
|
||||
try:
|
||||
model = app.state.MODELS[payload.get("model")]
|
||||
idx = model["urlIdx"]
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=404, detail="Model not found")
|
||||
# Check if user has access to the model
|
||||
if user.role == "user" and not has_access(
|
||||
user.id, type="read", access_control=model_info.access_control
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
# Attemp to get urlIdx from the model
|
||||
models = await get_all_models()
|
||||
|
||||
# Find the model from the list
|
||||
model = next(
|
||||
(model for model in models["data"] if model["id"] == payload.get("model")),
|
||||
None,
|
||||
)
|
||||
|
||||
if model:
|
||||
idx = model["urlIdx"]
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
# Get the API config for the model
|
||||
api_config = app.state.config.OPENAI_API_CONFIGS.get(
|
||||
app.state.config.OPENAI_API_BASE_URLS[idx], {}
|
||||
)
|
||||
@ -536,6 +533,7 @@ async def generate_chat_completion(
|
||||
if prefix_id:
|
||||
payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
|
||||
|
||||
# Add user info to the payload if the model is a pipeline
|
||||
if "pipeline" in model and model.get("pipeline"):
|
||||
payload["user"] = {
|
||||
"name": user.name,
|
||||
@ -546,8 +544,9 @@ async def generate_chat_completion(
|
||||
|
||||
url = app.state.config.OPENAI_API_BASE_URLS[idx]
|
||||
key = app.state.config.OPENAI_API_KEYS[idx]
|
||||
is_o1 = payload["model"].lower().startswith("o1-")
|
||||
|
||||
# Fix: O1 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens"
|
||||
is_o1 = payload["model"].lower().startswith("o1-")
|
||||
# Change max_completion_tokens to max_tokens (Backward compatible)
|
||||
if "api.openai.com" not in url and not is_o1:
|
||||
if "max_completion_tokens" in payload:
|
||||
|
@ -3,6 +3,7 @@ import os
|
||||
import uuid
|
||||
from typing import Optional, Union
|
||||
|
||||
import asyncio
|
||||
import requests
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
@ -291,7 +292,13 @@ def get_embedding_function(
|
||||
if embedding_engine == "":
|
||||
return lambda query: embedding_function.encode(query).tolist()
|
||||
elif embedding_engine in ["ollama", "openai"]:
|
||||
func = lambda query: generate_embeddings(
|
||||
|
||||
# Wrapper to run the async generate_embeddings synchronously.
|
||||
def sync_generate_embeddings(*args, **kwargs):
|
||||
return asyncio.run(generate_embeddings(*args, **kwargs))
|
||||
|
||||
# Semantic expectation from the original version (using sync wrapper).
|
||||
func = lambda query: sync_generate_embeddings(
|
||||
engine=embedding_engine,
|
||||
model=embedding_model,
|
||||
text=query,
|
||||
@ -469,7 +476,7 @@ def get_model_path(model: str, update_model: bool = False):
|
||||
return model
|
||||
|
||||
|
||||
def generate_openai_batch_embeddings(
|
||||
async def generate_openai_batch_embeddings(
|
||||
model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
|
||||
) -> Optional[list[list[float]]]:
|
||||
try:
|
||||
@ -492,14 +499,16 @@ def generate_openai_batch_embeddings(
|
||||
return None
|
||||
|
||||
|
||||
def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
|
||||
async def generate_embeddings(
|
||||
engine: str, model: str, text: Union[str, list[str]], **kwargs
|
||||
):
|
||||
if engine == "ollama":
|
||||
if isinstance(text, list):
|
||||
embeddings = generate_ollama_batch_embeddings(
|
||||
embeddings = await generate_ollama_batch_embeddings(
|
||||
GenerateEmbedForm(**{"model": model, "input": text})
|
||||
)
|
||||
else:
|
||||
embeddings = generate_ollama_batch_embeddings(
|
||||
embeddings = await generate_ollama_batch_embeddings(
|
||||
GenerateEmbedForm(**{"model": model, "input": [text]})
|
||||
)
|
||||
return (
|
||||
@ -512,9 +521,9 @@ def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **
|
||||
url = kwargs.get("url", "https://api.openai.com/v1")
|
||||
|
||||
if isinstance(text, list):
|
||||
embeddings = generate_openai_batch_embeddings(model, text, key, url)
|
||||
embeddings = await generate_openai_batch_embeddings(model, text, key, url)
|
||||
else:
|
||||
embeddings = generate_openai_batch_embeddings(model, [text], key, url)
|
||||
embeddings = await generate_openai_batch_embeddings(model, [text], key, url)
|
||||
|
||||
return embeddings[0] if isinstance(text, str) else embeddings
|
||||
|
||||
|
@ -142,7 +142,6 @@ app.state.config.LDAP_USE_TLS = LDAP_USE_TLS
|
||||
app.state.config.LDAP_CA_CERT_FILE = LDAP_CA_CERT_FILE
|
||||
app.state.config.LDAP_CIPHERS = LDAP_CIPHERS
|
||||
|
||||
app.state.MODELS = {}
|
||||
app.state.TOOLS = {}
|
||||
app.state.FUNCTIONS = {}
|
||||
|
||||
@ -369,7 +368,7 @@ def get_function_params(function_module, form_data, user, extra_params=None):
|
||||
return params
|
||||
|
||||
|
||||
async def generate_function_chat_completion(form_data, user):
|
||||
async def generate_function_chat_completion(form_data, user, models: dict = {}):
|
||||
model_id = form_data.get("model")
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
|
||||
@ -412,7 +411,7 @@ async def generate_function_chat_completion(form_data, user):
|
||||
user,
|
||||
{
|
||||
**extra_params,
|
||||
"__model__": app.state.MODELS[form_data["model"]],
|
||||
"__model__": models.get(form_data["model"], None),
|
||||
"__messages__": form_data["messages"],
|
||||
"__files__": files,
|
||||
},
|
||||
|
@ -11,6 +11,7 @@ import random
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional
|
||||
|
||||
from aiocache import cached
|
||||
import aiohttp
|
||||
import requests
|
||||
from fastapi import (
|
||||
@ -45,6 +46,7 @@ from open_webui.apps.openai.main import (
|
||||
app as openai_app,
|
||||
generate_chat_completion as generate_openai_chat_completion,
|
||||
get_all_models as get_openai_models,
|
||||
get_all_models_responses as get_openai_models_responses,
|
||||
)
|
||||
from open_webui.apps.retrieval.main import app as retrieval_app
|
||||
from open_webui.apps.retrieval.utils import get_rag_context, rag_template
|
||||
@ -132,6 +134,7 @@ from open_webui.utils.utils import (
|
||||
get_current_user,
|
||||
get_http_authorization_cred,
|
||||
get_verified_user,
|
||||
has_access,
|
||||
)
|
||||
|
||||
if SAFE_MODE:
|
||||
@ -196,20 +199,22 @@ app.state.config.WEBHOOK_URL = WEBHOOK_URL
|
||||
|
||||
app.state.config.TASK_MODEL = TASK_MODEL
|
||||
app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
|
||||
|
||||
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION
|
||||
app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
|
||||
app.state.config.ENABLE_SEARCH_QUERY = ENABLE_SEARCH_QUERY
|
||||
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
|
||||
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
||||
)
|
||||
app.state.config.ENABLE_SEARCH_QUERY = ENABLE_SEARCH_QUERY
|
||||
|
||||
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
|
||||
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
||||
)
|
||||
|
||||
app.state.MODELS = {}
|
||||
|
||||
|
||||
##################################
|
||||
#
|
||||
# ChatCompletion Middleware
|
||||
@ -217,26 +222,6 @@ app.state.MODELS = {}
|
||||
##################################
|
||||
|
||||
|
||||
def get_task_model_id(default_model_id):
|
||||
# Set the task model
|
||||
task_model_id = default_model_id
|
||||
# Check if the user has a custom task model and use that model
|
||||
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
|
||||
if (
|
||||
app.state.config.TASK_MODEL
|
||||
and app.state.config.TASK_MODEL in app.state.MODELS
|
||||
):
|
||||
task_model_id = app.state.config.TASK_MODEL
|
||||
else:
|
||||
if (
|
||||
app.state.config.TASK_MODEL_EXTERNAL
|
||||
and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
|
||||
):
|
||||
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
||||
|
||||
return task_model_id
|
||||
|
||||
|
||||
def get_filter_function_ids(model):
|
||||
def get_priority(function_id):
|
||||
function = Functions.get_function_by_id(function_id)
|
||||
@ -366,8 +351,24 @@ async def get_content_from_response(response) -> Optional[str]:
|
||||
return content
|
||||
|
||||
|
||||
def get_task_model_id(
|
||||
default_model_id: str, task_model: str, task_model_external: str, models
|
||||
) -> str:
|
||||
# Set the task model
|
||||
task_model_id = default_model_id
|
||||
# Check if the user has a custom task model and use that model
|
||||
if models[task_model_id]["owned_by"] == "ollama":
|
||||
if task_model and task_model in models:
|
||||
task_model_id = task_model
|
||||
else:
|
||||
if task_model_external and task_model_external in models:
|
||||
task_model_id = task_model_external
|
||||
|
||||
return task_model_id
|
||||
|
||||
|
||||
async def chat_completion_tools_handler(
|
||||
body: dict, user: UserModel, extra_params: dict
|
||||
body: dict, user: UserModel, models, extra_params: dict
|
||||
) -> tuple[dict, dict]:
|
||||
# If tool_ids field is present, call the functions
|
||||
metadata = body.get("metadata", {})
|
||||
@ -381,14 +382,19 @@ async def chat_completion_tools_handler(
|
||||
contexts = []
|
||||
citations = []
|
||||
|
||||
task_model_id = get_task_model_id(body["model"])
|
||||
task_model_id = get_task_model_id(
|
||||
body["model"],
|
||||
app.state.config.TASK_MODEL,
|
||||
app.state.config.TASK_MODEL_EXTERNAL,
|
||||
models,
|
||||
)
|
||||
tools = get_tools(
|
||||
webui_app,
|
||||
tool_ids,
|
||||
user,
|
||||
{
|
||||
**extra_params,
|
||||
"__model__": app.state.MODELS[task_model_id],
|
||||
"__model__": models[task_model_id],
|
||||
"__messages__": body["messages"],
|
||||
"__files__": metadata.get("files", []),
|
||||
},
|
||||
@ -412,7 +418,7 @@ async def chat_completion_tools_handler(
|
||||
)
|
||||
|
||||
try:
|
||||
payload = filter_pipeline(payload, user)
|
||||
payload = filter_pipeline(payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
@ -513,16 +519,16 @@ def is_chat_completion_request(request):
|
||||
)
|
||||
|
||||
|
||||
async def get_body_and_model_and_user(request):
|
||||
async def get_body_and_model_and_user(request, models):
|
||||
# Read the original request body
|
||||
body = await request.body()
|
||||
body_str = body.decode("utf-8")
|
||||
body = json.loads(body_str) if body_str else {}
|
||||
|
||||
model_id = body["model"]
|
||||
if model_id not in app.state.MODELS:
|
||||
if model_id not in models:
|
||||
raise Exception("Model not found")
|
||||
model = app.state.MODELS[model_id]
|
||||
model = models[model_id]
|
||||
|
||||
user = get_current_user(
|
||||
request,
|
||||
@ -538,14 +544,27 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
return await call_next(request)
|
||||
log.debug(f"request.url.path: {request.url.path}")
|
||||
|
||||
model_list = await get_all_models()
|
||||
models = {model["id"]: model for model in model_list}
|
||||
|
||||
try:
|
||||
body, model, user = await get_body_and_model_and_user(request)
|
||||
body, model, user = await get_body_and_model_and_user(request, models)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
|
||||
model_info = Models.get_model_by_id(model["id"])
|
||||
if user.role == "user":
|
||||
if model_info and not has_access(
|
||||
user.id, type="read", access_control=model_info.access_control
|
||||
):
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
content={"detail": "User does not have access to the model"},
|
||||
)
|
||||
|
||||
metadata = {
|
||||
"chat_id": body.pop("chat_id", None),
|
||||
"message_id": body.pop("id", None),
|
||||
@ -582,15 +601,20 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
|
||||
tool_ids = body.pop("tool_ids", None)
|
||||
files = body.pop("files", None)
|
||||
|
||||
metadata = {
|
||||
**metadata,
|
||||
"tool_ids": body.pop("tool_ids", None),
|
||||
"files": body.pop("files", None),
|
||||
"tool_ids": tool_ids,
|
||||
"files": files,
|
||||
}
|
||||
body["metadata"] = metadata
|
||||
|
||||
try:
|
||||
body, flags = await chat_completion_tools_handler(body, user, extra_params)
|
||||
body, flags = await chat_completion_tools_handler(
|
||||
body, user, models, extra_params
|
||||
)
|
||||
contexts.extend(flags.get("contexts", []))
|
||||
citations.extend(flags.get("citations", []))
|
||||
except Exception as e:
|
||||
@ -687,10 +711,10 @@ app.add_middleware(ChatCompletionMiddleware)
|
||||
##################################
|
||||
|
||||
|
||||
def get_sorted_filters(model_id):
|
||||
def get_sorted_filters(model_id, models):
|
||||
filters = [
|
||||
model
|
||||
for model in app.state.MODELS.values()
|
||||
for model in models.values()
|
||||
if "pipeline" in model
|
||||
and "type" in model["pipeline"]
|
||||
and model["pipeline"]["type"] == "filter"
|
||||
@ -706,12 +730,12 @@ def get_sorted_filters(model_id):
|
||||
return sorted_filters
|
||||
|
||||
|
||||
def filter_pipeline(payload, user):
|
||||
def filter_pipeline(payload, user, models):
|
||||
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
|
||||
model_id = payload["model"]
|
||||
sorted_filters = get_sorted_filters(model_id)
|
||||
|
||||
model = app.state.MODELS[model_id]
|
||||
sorted_filters = get_sorted_filters(model_id, models)
|
||||
model = models[model_id]
|
||||
|
||||
if "pipeline" in model:
|
||||
sorted_filters.append(model)
|
||||
@ -782,8 +806,11 @@ class PipelineMiddleware(BaseHTTPMiddleware):
|
||||
content={"detail": "Not authenticated"},
|
||||
)
|
||||
|
||||
model_list = await get_all_models()
|
||||
models = {model["id"]: model for model in model_list}
|
||||
|
||||
try:
|
||||
data = filter_pipeline(data, user)
|
||||
data = filter_pipeline(data, user, models)
|
||||
except Exception as e:
|
||||
if len(e.args) > 1:
|
||||
return JSONResponse(
|
||||
@ -862,16 +889,10 @@ async def commit_session_after_request(request: Request, call_next):
|
||||
|
||||
@app.middleware("http")
|
||||
async def check_url(request: Request, call_next):
|
||||
if len(app.state.MODELS) == 0:
|
||||
await get_all_models()
|
||||
else:
|
||||
pass
|
||||
|
||||
start_time = int(time.time())
|
||||
response = await call_next(request)
|
||||
process_time = int(time.time()) - start_time
|
||||
response.headers["X-Process-Time"] = str(process_time)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@ -911,10 +932,10 @@ app.mount("/retrieval/api/v1", retrieval_app)
|
||||
|
||||
app.mount("/api/v1", webui_app)
|
||||
|
||||
|
||||
webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION
|
||||
|
||||
|
||||
@cached(ttl=1)
|
||||
async def get_all_base_models():
|
||||
open_webui_models = []
|
||||
openai_models = []
|
||||
@ -944,6 +965,7 @@ async def get_all_base_models():
|
||||
return models
|
||||
|
||||
|
||||
@cached(ttl=1)
|
||||
async def get_all_models():
|
||||
models = await get_all_base_models()
|
||||
|
||||
@ -1065,9 +1087,6 @@ async def get_all_models():
|
||||
|
||||
function_module = get_function_module_by_id(action_id)
|
||||
model["actions"].extend(get_action_items_from_module(function_module))
|
||||
|
||||
app.state.MODELS = {model["id"]: model for model in models}
|
||||
webui_app.state.MODELS = app.state.MODELS
|
||||
return models
|
||||
|
||||
|
||||
@ -1082,16 +1101,19 @@ async def get_models(user=Depends(get_verified_user)):
|
||||
if "pipeline" not in model or model["pipeline"].get("type", None) != "filter"
|
||||
]
|
||||
|
||||
# TODO: Check User Group and Filter Models
|
||||
# if app.state.config.ENABLE_MODEL_FILTER:
|
||||
# if user.role == "user":
|
||||
# models = list(
|
||||
# filter(
|
||||
# lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
|
||||
# models,
|
||||
# )
|
||||
# )
|
||||
# return {"data": models}
|
||||
# Filter out models that the user does not have access to
|
||||
if user.role == "user":
|
||||
filtered_models = []
|
||||
for model in models:
|
||||
model_info = Models.get_model_by_id(model["id"])
|
||||
if model_info:
|
||||
if has_access(
|
||||
user.id, type="read", access_control=model_info.access_control
|
||||
):
|
||||
filtered_models.append(model)
|
||||
else:
|
||||
filtered_models.append(model)
|
||||
models = filtered_models
|
||||
|
||||
return {"data": models}
|
||||
|
||||
@ -1106,24 +1128,27 @@ async def get_base_models(user=Depends(get_admin_user)):
|
||||
async def generate_chat_completions(
|
||||
form_data: dict, user=Depends(get_verified_user), bypass_filter: bool = False
|
||||
):
|
||||
model_id = form_data["model"]
|
||||
model_list = await get_all_models()
|
||||
models = {model["id"]: model for model in model_list}
|
||||
|
||||
if model_id not in app.state.MODELS:
|
||||
model_id = form_data["model"]
|
||||
if model_id not in models:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
# TODO: Check User Group and Filter Models
|
||||
# if not bypass_filter:
|
||||
# if app.state.config.ENABLE_MODEL_FILTER:
|
||||
# if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
|
||||
# raise HTTPException(
|
||||
# status_code=status.HTTP_403_FORBIDDEN,
|
||||
# detail="Model not found",
|
||||
# )
|
||||
|
||||
model = app.state.MODELS[model_id]
|
||||
model = models[model_id]
|
||||
# Check if user has access to the model
|
||||
if user.role == "user":
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
if not has_access(
|
||||
user.id, type="read", access_control=model_info.access_control
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
if model["owned_by"] == "arena":
|
||||
model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
|
||||
@ -1174,7 +1199,9 @@ async def generate_chat_completions(
|
||||
|
||||
if model.get("pipe"):
|
||||
# Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter
|
||||
return await generate_function_chat_completion(form_data, user=user)
|
||||
return await generate_function_chat_completion(
|
||||
form_data, user=user, models=models
|
||||
)
|
||||
if model["owned_by"] == "ollama":
|
||||
# Using /ollama/api/chat endpoint
|
||||
form_data = convert_payload_openai_to_ollama(form_data)
|
||||
@ -1198,16 +1225,20 @@ async def generate_chat_completions(
|
||||
|
||||
@app.post("/api/chat/completed")
|
||||
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
|
||||
|
||||
model_list = await get_all_models()
|
||||
models = {model["id"]: model for model in model_list}
|
||||
|
||||
data = form_data
|
||||
model_id = data["model"]
|
||||
if model_id not in app.state.MODELS:
|
||||
if model_id not in models:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
)
|
||||
model = app.state.MODELS[model_id]
|
||||
|
||||
sorted_filters = get_sorted_filters(model_id)
|
||||
model = models[model_id]
|
||||
sorted_filters = get_sorted_filters(model_id, models)
|
||||
if "pipeline" in model:
|
||||
sorted_filters = [model] + sorted_filters
|
||||
|
||||
@ -1382,14 +1413,18 @@ async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified
|
||||
detail="Action not found",
|
||||
)
|
||||
|
||||
model_list = await get_all_models()
|
||||
models = {model["id"]: model for model in model_list}
|
||||
|
||||
data = form_data
|
||||
model_id = data["model"]
|
||||
if model_id not in app.state.MODELS:
|
||||
|
||||
if model_id not in models:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
)
|
||||
model = app.state.MODELS[model_id]
|
||||
model = models[model_id]
|
||||
|
||||
__event_emitter__ = get_event_emitter(
|
||||
{
|
||||
@ -1543,8 +1578,11 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
|
||||
async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
||||
print("generate_title")
|
||||
|
||||
model_list = await get_all_models()
|
||||
models = {model["id"]: model for model in model_list}
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in app.state.MODELS:
|
||||
if model_id not in models:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
@ -1552,10 +1590,16 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
||||
|
||||
# Check if the user has a custom task model
|
||||
# If the user has a custom task model, use that model
|
||||
task_model_id = get_task_model_id(model_id)
|
||||
task_model_id = get_task_model_id(
|
||||
model_id,
|
||||
app.state.config.TASK_MODEL,
|
||||
app.state.config.TASK_MODEL_EXTERNAL,
|
||||
models,
|
||||
)
|
||||
|
||||
print(task_model_id)
|
||||
|
||||
model = app.state.MODELS[task_model_id]
|
||||
model = models[task_model_id]
|
||||
|
||||
if app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "":
|
||||
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
@ -1589,7 +1633,7 @@ Artificial Intelligence in Healthcare
|
||||
"stream": False,
|
||||
**(
|
||||
{"max_tokens": 50}
|
||||
if app.state.MODELS[task_model_id]["owned_by"] == "ollama"
|
||||
if models[task_model_id]["owned_by"] == "ollama"
|
||||
else {
|
||||
"max_completion_tokens": 50,
|
||||
}
|
||||
@ -1601,7 +1645,7 @@ Artificial Intelligence in Healthcare
|
||||
|
||||
# Handle pipeline filters
|
||||
try:
|
||||
payload = filter_pipeline(payload, user)
|
||||
payload = filter_pipeline(payload, user, models)
|
||||
except Exception as e:
|
||||
if len(e.args) > 1:
|
||||
return JSONResponse(
|
||||
@ -1628,8 +1672,11 @@ async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)):
|
||||
content={"detail": "Tags generation is disabled"},
|
||||
)
|
||||
|
||||
model_list = await get_all_models()
|
||||
models = {model["id"]: model for model in model_list}
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in app.state.MODELS:
|
||||
if model_id not in models:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
@ -1637,7 +1684,12 @@ async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)):
|
||||
|
||||
# Check if the user has a custom task model
|
||||
# If the user has a custom task model, use that model
|
||||
task_model_id = get_task_model_id(model_id)
|
||||
task_model_id = get_task_model_id(
|
||||
model_id,
|
||||
app.state.config.TASK_MODEL,
|
||||
app.state.config.TASK_MODEL_EXTERNAL,
|
||||
models,
|
||||
)
|
||||
print(task_model_id)
|
||||
|
||||
if app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "":
|
||||
@ -1675,7 +1727,7 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] }
|
||||
|
||||
# Handle pipeline filters
|
||||
try:
|
||||
payload = filter_pipeline(payload, user)
|
||||
payload = filter_pipeline(payload, user, models)
|
||||
except Exception as e:
|
||||
if len(e.args) > 1:
|
||||
return JSONResponse(
|
||||
@ -1702,8 +1754,11 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
|
||||
detail=f"Search query generation is disabled",
|
||||
)
|
||||
|
||||
model_list = await get_all_models()
|
||||
models = {model["id"]: model for model in model_list}
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in app.state.MODELS:
|
||||
if model_id not in models:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
@ -1711,10 +1766,15 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
|
||||
|
||||
# Check if the user has a custom task model
|
||||
# If the user has a custom task model, use that model
|
||||
task_model_id = get_task_model_id(model_id)
|
||||
task_model_id = get_task_model_id(
|
||||
model_id,
|
||||
app.state.config.TASK_MODEL,
|
||||
app.state.config.TASK_MODEL_EXTERNAL,
|
||||
models,
|
||||
)
|
||||
print(task_model_id)
|
||||
|
||||
model = app.state.MODELS[task_model_id]
|
||||
model = models[task_model_id]
|
||||
|
||||
if app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE != "":
|
||||
template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
||||
@ -1741,7 +1801,7 @@ Search Query:"""
|
||||
"stream": False,
|
||||
**(
|
||||
{"max_tokens": 30}
|
||||
if app.state.MODELS[task_model_id]["owned_by"] == "ollama"
|
||||
if models[task_model_id]["owned_by"] == "ollama"
|
||||
else {
|
||||
"max_completion_tokens": 30,
|
||||
}
|
||||
@ -1752,7 +1812,7 @@ Search Query:"""
|
||||
|
||||
# Handle pipeline filters
|
||||
try:
|
||||
payload = filter_pipeline(payload, user)
|
||||
payload = filter_pipeline(payload, user, models)
|
||||
except Exception as e:
|
||||
if len(e.args) > 1:
|
||||
return JSONResponse(
|
||||
@ -1774,8 +1834,11 @@ Search Query:"""
|
||||
async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
|
||||
print("generate_emoji")
|
||||
|
||||
model_list = await get_all_models()
|
||||
models = {model["id"]: model for model in model_list}
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in app.state.MODELS:
|
||||
if model_id not in models:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
@ -1783,10 +1846,15 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
|
||||
|
||||
# Check if the user has a custom task model
|
||||
# If the user has a custom task model, use that model
|
||||
task_model_id = get_task_model_id(model_id)
|
||||
task_model_id = get_task_model_id(
|
||||
model_id,
|
||||
app.state.config.TASK_MODEL,
|
||||
app.state.config.TASK_MODEL_EXTERNAL,
|
||||
models,
|
||||
)
|
||||
print(task_model_id)
|
||||
|
||||
model = app.state.MODELS[task_model_id]
|
||||
model = models[task_model_id]
|
||||
|
||||
template = '''
|
||||
Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
|
||||
@ -1808,7 +1876,7 @@ Message: """{{prompt}}"""
|
||||
"stream": False,
|
||||
**(
|
||||
{"max_tokens": 4}
|
||||
if app.state.MODELS[task_model_id]["owned_by"] == "ollama"
|
||||
if models[task_model_id]["owned_by"] == "ollama"
|
||||
else {
|
||||
"max_completion_tokens": 4,
|
||||
}
|
||||
@ -1820,7 +1888,7 @@ Message: """{{prompt}}"""
|
||||
|
||||
# Handle pipeline filters
|
||||
try:
|
||||
payload = filter_pipeline(payload, user)
|
||||
payload = filter_pipeline(payload, user, models)
|
||||
except Exception as e:
|
||||
if len(e.args) > 1:
|
||||
return JSONResponse(
|
||||
@ -1842,8 +1910,11 @@ Message: """{{prompt}}"""
|
||||
async def generate_moa_response(form_data: dict, user=Depends(get_verified_user)):
|
||||
print("generate_moa_response")
|
||||
|
||||
model_list = await get_all_models()
|
||||
models = {model["id"]: model for model in model_list}
|
||||
|
||||
model_id = form_data["model"]
|
||||
if model_id not in app.state.MODELS:
|
||||
if model_id not in models:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
@ -1851,10 +1922,15 @@ async def generate_moa_response(form_data: dict, user=Depends(get_verified_user)
|
||||
|
||||
# Check if the user has a custom task model
|
||||
# If the user has a custom task model, use that model
|
||||
task_model_id = get_task_model_id(model_id)
|
||||
task_model_id = get_task_model_id(
|
||||
model_id,
|
||||
app.state.config.TASK_MODEL,
|
||||
app.state.config.TASK_MODEL_EXTERNAL,
|
||||
models,
|
||||
)
|
||||
print(task_model_id)
|
||||
|
||||
model = app.state.MODELS[task_model_id]
|
||||
model = models[task_model_id]
|
||||
|
||||
template = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}"
|
||||
|
||||
@ -1881,7 +1957,7 @@ Responses from models: {{responses}}"""
|
||||
log.debug(payload)
|
||||
|
||||
try:
|
||||
payload = filter_pipeline(payload, user)
|
||||
payload = filter_pipeline(payload, user, models)
|
||||
except Exception as e:
|
||||
if len(e.args) > 1:
|
||||
return JSONResponse(
|
||||
@ -1911,7 +1987,7 @@ Responses from models: {{responses}}"""
|
||||
|
||||
@app.get("/api/pipelines/list")
|
||||
async def get_pipelines_list(user=Depends(get_admin_user)):
|
||||
responses = await get_openai_models(raw=True)
|
||||
responses = await get_openai_models_responses()
|
||||
|
||||
print(responses)
|
||||
urlIdxs = [
|
||||
|
@ -192,15 +192,16 @@ def has_permission(
|
||||
|
||||
def has_access(
|
||||
user_id: str,
|
||||
action: str = "write",
|
||||
type: str = "write",
|
||||
access_control: Optional[dict] = None,
|
||||
) -> bool:
|
||||
print("user_id", user_id, "type", type, "access_control", access_control)
|
||||
if access_control is None:
|
||||
return action == "read"
|
||||
return type == "read"
|
||||
|
||||
user_groups = Groups.get_groups_by_member_id(user_id)
|
||||
user_group_ids = [group.id for group in user_groups]
|
||||
permission_access = access_control.get(action, {})
|
||||
permission_access = access_control.get(type, {})
|
||||
permitted_group_ids = permission_access.get("group_ids", [])
|
||||
permitted_user_ids = permission_access.get("user_ids", [])
|
||||
|
||||
|
@ -13,6 +13,7 @@ passlib[bcrypt]==1.7.4
|
||||
requests==2.32.3
|
||||
aiohttp==3.10.8
|
||||
async-timeout
|
||||
aiocache
|
||||
|
||||
sqlalchemy==2.0.32
|
||||
alembic==1.13.2
|
||||
|
@ -21,6 +21,7 @@ dependencies = [
|
||||
"requests==2.32.3",
|
||||
"aiohttp==3.10.8",
|
||||
"async-timeout",
|
||||
"aiocache",
|
||||
|
||||
"sqlalchemy==2.0.32",
|
||||
"alembic==1.13.2",
|
||||
|
@ -71,7 +71,7 @@
|
||||
const upsertModelHandler = async (model) => {
|
||||
model.base_model_id = null;
|
||||
|
||||
if (models.find((m) => m.id === model.id)) {
|
||||
if (workspaceModels.find((m) => m.id === model.id)) {
|
||||
await updateModelById(localStorage.token, model.id, model).catch((error) => {
|
||||
return null;
|
||||
});
|
||||
|
Loading…
Reference in New Issue
Block a user