Timothy Jaeryang Baek 2024-11-16 04:41:07 -08:00
parent 931e03bd9e
commit 932de8f1e2
9 changed files with 309 additions and 214 deletions

View File

@@ -68,25 +68,12 @@ app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
 app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
 app.state.config.OLLAMA_API_CONFIGS = OLLAMA_API_CONFIGS

-app.state.MODELS = {}
-
-
 # TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
 # Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin,
 # least connections, or least response time for better resource utilization and performance optimization.

-@app.middleware("http")
-async def check_url(request: Request, call_next):
-    if len(app.state.MODELS) == 0:
-        await get_all_models()
-    else:
-        pass
-
-    response = await call_next(request)
-    return response
-
-
 @app.head("/")
 @app.get("/")
 async def get_status():
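The TODO above suggests weighted round-robin as one alternative to `random.choice`. A minimal sketch of that idea, assuming a hypothetical `OLLAMA_URL_WEIGHTS` mapping (URL index to relative weight) that is not part of this commit:

```python
import itertools
import random

# Hypothetical config: url_idx -> relative weight (not part of this commit).
OLLAMA_URL_WEIGHTS = {0: 3, 1: 1}

def make_weighted_round_robin(weights: dict[int, int]):
    # Expand each index by its weight, shuffle once to avoid bursts against
    # the same backend, then cycle through the schedule forever.
    schedule = [idx for idx, w in weights.items() for _ in range(w)]
    random.shuffle(schedule)
    return itertools.cycle(schedule)

url_picker = make_weighted_round_robin(OLLAMA_URL_WEIGHTS)
url_idx = next(url_picker)  # drop-in replacement for random.choice(...)
```

Least-connections or least-response-time variants would additionally need per-URL in-flight counters or latency tracking.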
@@ -321,8 +308,6 @@ async def get_all_models():
     else:
         models = {"models": []}

-    app.state.MODELS = {model["model"]: model for model in models["models"]}
-
     return models

@@ -470,8 +455,11 @@ async def push_model(
     user=Depends(get_admin_user),
 ):
     if url_idx is None:
-        if form_data.name in app.state.MODELS:
-            url_idx = app.state.MODELS[form_data.name]["urls"][0]
+        model_list = await get_all_models()
+        models = {model["model"]: model for model in model_list["models"]}
+
+        if form_data.name in models:
+            url_idx = models[form_data.name]["urls"][0]
         else:
             raise HTTPException(
                 status_code=400,

@@ -520,8 +508,11 @@ async def copy_model(
     user=Depends(get_admin_user),
 ):
     if url_idx is None:
-        if form_data.source in app.state.MODELS:
-            url_idx = app.state.MODELS[form_data.source]["urls"][0]
+        model_list = await get_all_models()
+        models = {model["model"]: model for model in model_list["models"]}
+
+        if form_data.source in models:
+            url_idx = models[form_data.source]["urls"][0]
         else:
             raise HTTPException(
                 status_code=400,

@@ -576,8 +567,11 @@ async def delete_model(
     user=Depends(get_admin_user),
 ):
     if url_idx is None:
-        if form_data.name in app.state.MODELS:
-            url_idx = app.state.MODELS[form_data.name]["urls"][0]
+        model_list = await get_all_models()
+        models = {model["model"]: model for model in model_list["models"]}
+
+        if form_data.name in models:
+            url_idx = models[form_data.name]["urls"][0]
         else:
             raise HTTPException(
                 status_code=400,
@@ -625,13 +619,16 @@ async def delete_model(
 @app.post("/api/show")
 async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)):
-    if form_data.name not in app.state.MODELS:
+    model_list = await get_all_models()
+    models = {model["model"]: model for model in model_list["models"]}
+
+    if form_data.name not in models:
         raise HTTPException(
             status_code=400,
             detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
         )

-    url_idx = random.choice(app.state.MODELS[form_data.name]["urls"])
+    url_idx = random.choice(models[form_data.name]["urls"])
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
@@ -701,23 +698,26 @@ async def generate_embeddings(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    return generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)
+    return await generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)


-def generate_ollama_embeddings(
+async def generate_ollama_embeddings(
     form_data: GenerateEmbeddingsForm,
     url_idx: Optional[int] = None,
 ):
     log.info(f"generate_ollama_embeddings {form_data}")

     if url_idx is None:
+        model_list = await get_all_models()
+        models = {model["model"]: model for model in model_list["models"]}
+
         model = form_data.model

         if ":" not in model:
             model = f"{model}:latest"

-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
+        if model in models:
+            url_idx = random.choice(models[model]["urls"])
         else:
             raise HTTPException(
                 status_code=400,

@@ -768,20 +768,23 @@ def generate_ollama_embeddings(
             )


-def generate_ollama_batch_embeddings(
+async def generate_ollama_batch_embeddings(
     form_data: GenerateEmbedForm,
     url_idx: Optional[int] = None,
 ):
     log.info(f"generate_ollama_batch_embeddings {form_data}")

     if url_idx is None:
+        model_list = await get_all_models()
+        models = {model["model"]: model for model in model_list["models"]}
+
         model = form_data.model

         if ":" not in model:
             model = f"{model}:latest"

-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
+        if model in models:
+            url_idx = random.choice(models[model]["urls"])
         else:
             raise HTTPException(
                 status_code=400,

@@ -851,13 +854,16 @@ async def generate_completion(
     user=Depends(get_verified_user),
 ):
     if url_idx is None:
+        model_list = await get_all_models()
+        models = {model["model"]: model for model in model_list["models"]}
+
         model = form_data.model

         if ":" not in model:
             model = f"{model}:latest"

-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
+        if model in models:
+            url_idx = random.choice(models[model]["urls"])
         else:
             raise HTTPException(
                 status_code=400,

@@ -892,14 +898,17 @@ class GenerateChatCompletionForm(BaseModel):
     keep_alive: Optional[Union[int, str]] = None


-def get_ollama_url(url_idx: Optional[int], model: str):
+async def get_ollama_url(url_idx: Optional[int], model: str):
     if url_idx is None:
-        if model not in app.state.MODELS:
+        model_list = await get_all_models()
+        models = {model["model"]: model for model in model_list["models"]}
+
+        if model not in models:
             raise HTTPException(
                 status_code=400,
                 detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
             )
-        url_idx = random.choice(app.state.MODELS[model]["urls"])
+        url_idx = random.choice(models[model]["urls"])

     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     return url

@@ -948,7 +957,7 @@ async def generate_chat_completion(
     if ":" not in payload["model"]:
         payload["model"] = f"{payload['model']}:latest"

-    url = get_ollama_url(url_idx, payload["model"])
+    url = await get_ollama_url(url_idx, payload["model"])
     log.info(f"url: {url}")
     log.debug(f"generate_chat_completion() - 2.payload = {payload}")

@@ -1030,7 +1039,7 @@ async def generate_openai_chat_completion(
     if ":" not in payload["model"]:
         payload["model"] = f"{payload['model']}:latest"

-    url = get_ollama_url(url_idx, payload["model"])
+    url = await get_ollama_url(url_idx, payload["model"])
     log.info(f"url: {url}")

     api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})

View File

@@ -36,7 +36,7 @@ from open_webui.utils.payload import (
     apply_model_system_prompt_to_body,
 )

-from open_webui.utils.utils import get_admin_user, get_verified_user
+from open_webui.utils.utils import get_admin_user, get_verified_user, has_access

 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["OPENAI"])

@@ -64,17 +64,6 @@ app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
 app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
 app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS

-app.state.MODELS = {}
-
-
-@app.middleware("http")
-async def check_url(request: Request, call_next):
-    if len(app.state.MODELS) == 0:
-        await get_all_models()
-
-    response = await call_next(request)
-    return response
-

 @app.get("/config")
 async def get_config(user=Depends(get_admin_user)):

@@ -259,7 +248,7 @@ def merge_models_lists(model_lists):
     return merged_list


-async def get_all_models_raw() -> list:
+async def get_all_models_responses() -> list:
     if not app.state.config.ENABLE_OPENAI_API:
         return []

@@ -330,22 +319,13 @@ async def get_all_models_raw() -> list:
     return responses


-@overload
-async def get_all_models(raw: Literal[True]) -> list: ...
-
-
-@overload
-async def get_all_models(raw: Literal[False] = False) -> dict[str, list]: ...
-
-
-async def get_all_models(raw=False) -> dict[str, list] | list:
+async def get_all_models() -> dict[str, list]:
     log.info("get_all_models()")
-    if not app.state.config.ENABLE_OPENAI_API:
-        return [] if raw else {"data": []}

-    responses = await get_all_models_raw()
-    if raw:
-        return responses
+    if not app.state.config.ENABLE_OPENAI_API:
+        return {"data": []}
+
+    responses = await get_all_models_responses()

     def extract_data(response):
         if response and "data" in response:
@@ -355,9 +335,7 @@ async def get_all_models(raw=False) -> dict[str, list] | list:
         return None

     models = {"data": merge_models_lists(map(extract_data, responses))}
     log.debug(f"models: {models}")
-
-    app.state.MODELS = {model["id"]: model for model in models["data"]}
-
     return models

@@ -365,21 +343,12 @@ async def get_all_models(raw=False) -> dict[str, list] | list:
 @app.get("/models")
 @app.get("/models/{url_idx}")
 async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
+    models = {
+        "data": [],
+    }
+
     if url_idx is None:
         models = await get_all_models()
-
-        # TODO: Check User Group and Filter Models
-        # if app.state.config.ENABLE_MODEL_FILTER:
-        #     if user.role == "user":
-        #         models["data"] = list(
-        #             filter(
-        #                 lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
-        #                 models["data"],
-        #             )
-        #         )
-        #         return models
-
-        return models
     else:
         url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
         key = app.state.config.OPENAI_API_KEYS[url_idx]

@@ -387,6 +356,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
         headers = {}
         headers["Authorization"] = f"Bearer {key}"
         headers["Content-Type"] = "application/json"
+
         if ENABLE_FORWARD_USER_INFO_HEADERS:
             headers["X-OpenWebUI-User-Name"] = user.name
             headers["X-OpenWebUI-User-Id"] = user.id

@@ -428,8 +398,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
                     )
                 ]

-            return response_data
+            models = response_data
         except aiohttp.ClientError as e:
             # ClientError covers all aiohttp requests issues
             log.exception(f"Client error: {str(e)}")

@@ -443,6 +412,22 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
             error_detail = f"Unexpected error: {str(e)}"
             raise HTTPException(status_code=500, detail=error_detail)

+    if user.role == "user":
+        # Filter models based on user access control
+        filtered_models = []
+        for model in models.get("data", []):
+            model_info = Models.get_model_by_id(model["id"])
+            if model_info:
+                if has_access(
+                    user.id, type="read", access_control=model_info.access_control
+                ):
+                    filtered_models.append(model)
+            else:
+                filtered_models.append(model)
+        models["data"] = filtered_models
+
+    return models


 class ConnectionVerificationForm(BaseModel):
     url: str
@@ -502,18 +487,9 @@ async def generate_chat_completion(
         del payload["metadata"]

     model_id = form_data.get("model")
-
-    # TODO: Check User Group and Filter Models
-    # if not bypass_filter:
-    #     if app.state.config.ENABLE_MODEL_FILTER:
-    #         if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
-    #             raise HTTPException(
-    #                 status_code=403,
-    #                 detail="Model not found",
-    #             )
-
     model_info = Models.get_model_by_id(model_id)

-    # Check model info and override the payload
     if model_info:
         if model_info.base_model_id:
             payload["model"] = model_info.base_model_id

@@ -522,12 +498,33 @@ async def generate_chat_completion(
         payload = apply_model_params_to_body_openai(params, payload)
         payload = apply_model_system_prompt_to_body(params, payload, user)

-    try:
-        model = app.state.MODELS[payload.get("model")]
-        idx = model["urlIdx"]
-    except Exception as e:
-        raise HTTPException(status_code=404, detail="Model not found")
+        # Check if user has access to the model
+        if user.role == "user" and not has_access(
+            user.id, type="read", access_control=model_info.access_control
+        ):
+            raise HTTPException(
+                status_code=403,
+                detail="Model not found",
+            )
+
+    # Attempt to get urlIdx from the model
+    models = await get_all_models()
+
+    # Find the model from the list
+    model = next(
+        (model for model in models["data"] if model["id"] == payload.get("model")),
+        None,
+    )
+
+    if model:
+        idx = model["urlIdx"]
+    else:
+        raise HTTPException(
+            status_code=404,
+            detail="Model not found",
+        )
+
+    # Get the API config for the model
     api_config = app.state.config.OPENAI_API_CONFIGS.get(
         app.state.config.OPENAI_API_BASE_URLS[idx], {}
     )

@@ -536,6 +533,7 @@ async def generate_chat_completion(
     if prefix_id:
         payload["model"] = payload["model"].replace(f"{prefix_id}.", "")

+    # Add user info to the payload if the model is a pipeline
     if "pipeline" in model and model.get("pipeline"):
         payload["user"] = {
             "name": user.name,

@@ -546,8 +544,9 @@ async def generate_chat_completion(
     url = app.state.config.OPENAI_API_BASE_URLS[idx]
     key = app.state.config.OPENAI_API_KEYS[idx]

-    is_o1 = payload["model"].lower().startswith("o1-")
-
+    # Fix: o1 does not support the "max_tokens" parameter; change "max_tokens" to "max_completion_tokens"
+    is_o1 = payload["model"].lower().startswith("o1-")
+
     # Change max_completion_tokens to max_tokens (Backward compatible)
     if "api.openai.com" not in url and not is_o1:
         if "max_completion_tokens" in payload:

View File

@@ -3,6 +3,7 @@ import os
 import uuid
 from typing import Optional, Union

+import asyncio
 import requests
 from huggingface_hub import snapshot_download
@@ -291,7 +292,13 @@ def get_embedding_function(
     if embedding_engine == "":
         return lambda query: embedding_function.encode(query).tolist()
     elif embedding_engine in ["ollama", "openai"]:
-        func = lambda query: generate_embeddings(
+        # Wrapper to run the now-async generate_embeddings from sync code,
+        # preserving the synchronous call semantics of the original version.
+        def sync_generate_embeddings(*args, **kwargs):
+            return asyncio.run(generate_embeddings(*args, **kwargs))
+
+        func = lambda query: sync_generate_embeddings(
             engine=embedding_engine,
             model=embedding_model,
             text=query,
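One caveat worth noting about the wrapper above: `asyncio.run` raises `RuntimeError` when called from a thread that already has a running event loop, which is the usual situation inside async FastAPI handlers. A defensive variant (a sketch, not part of this commit) falls back to a worker thread in that case:

```python
import asyncio
import concurrent.futures

def run_coroutine_blocking(coro):
    """Run a coroutine to completion from sync code, even under a live loop."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop running in this thread; asyncio.run is safe.
        return asyncio.run(coro)
    # A loop is already running: run the coroutine in a fresh thread/loop.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()
```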
@@ -469,7 +476,7 @@ def get_model_path(model: str, update_model: bool = False):
     return model


-def generate_openai_batch_embeddings(
+async def generate_openai_batch_embeddings(
     model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
 ) -> Optional[list[list[float]]]:
     try:

@@ -492,14 +499,16 @@ def generate_openai_batch_embeddings(
         return None


-def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
+async def generate_embeddings(
+    engine: str, model: str, text: Union[str, list[str]], **kwargs
+):
     if engine == "ollama":
         if isinstance(text, list):
-            embeddings = generate_ollama_batch_embeddings(
+            embeddings = await generate_ollama_batch_embeddings(
                 GenerateEmbedForm(**{"model": model, "input": text})
             )
         else:
-            embeddings = generate_ollama_batch_embeddings(
+            embeddings = await generate_ollama_batch_embeddings(
                 GenerateEmbedForm(**{"model": model, "input": [text]})
             )
         return (

@@ -512,9 +521,9 @@ def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
         url = kwargs.get("url", "https://api.openai.com/v1")

         if isinstance(text, list):
-            embeddings = generate_openai_batch_embeddings(model, text, key, url)
+            embeddings = await generate_openai_batch_embeddings(model, text, key, url)
         else:
-            embeddings = generate_openai_batch_embeddings(model, [text], key, url)
+            embeddings = await generate_openai_batch_embeddings(model, [text], key, url)

         return embeddings[0] if isinstance(text, str) else embeddings

View File

@@ -142,7 +142,6 @@ app.state.config.LDAP_USE_TLS = LDAP_USE_TLS
 app.state.config.LDAP_CA_CERT_FILE = LDAP_CA_CERT_FILE
 app.state.config.LDAP_CIPHERS = LDAP_CIPHERS

-app.state.MODELS = {}
 app.state.TOOLS = {}
 app.state.FUNCTIONS = {}

@@ -369,7 +368,7 @@ def get_function_params(function_module, form_data, user, extra_params=None):
     return params


-async def generate_function_chat_completion(form_data, user):
+async def generate_function_chat_completion(form_data, user, models: dict = {}):
     model_id = form_data.get("model")
     model_info = Models.get_model_by_id(model_id)

@@ -412,7 +411,7 @@ async def generate_function_chat_completion(form_data, user):
         user,
         {
             **extra_params,
-            "__model__": app.state.MODELS[form_data["model"]],
+            "__model__": models.get(form_data["model"], None),
             "__messages__": form_data["messages"],
             "__files__": files,
         },

View File

@@ -11,6 +11,7 @@ import random
 from contextlib import asynccontextmanager
 from typing import Optional

+from aiocache import cached
 import aiohttp
 import requests
 from fastapi import (

@@ -45,6 +46,7 @@ from open_webui.apps.openai.main import (
     app as openai_app,
     generate_chat_completion as generate_openai_chat_completion,
     get_all_models as get_openai_models,
+    get_all_models_responses as get_openai_models_responses,
 )
 from open_webui.apps.retrieval.main import app as retrieval_app
 from open_webui.apps.retrieval.utils import get_rag_context, rag_template

@@ -132,6 +134,7 @@ from open_webui.utils.utils import (
     get_current_user,
     get_http_authorization_cred,
     get_verified_user,
+    has_access,
 )

 if SAFE_MODE:
@@ -196,20 +199,22 @@ app.state.config.WEBHOOK_URL = WEBHOOK_URL
 app.state.config.TASK_MODEL = TASK_MODEL
 app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL

 app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
+app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE
 app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION
-app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE

+app.state.config.ENABLE_SEARCH_QUERY = ENABLE_SEARCH_QUERY
 app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
     SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
 )
-app.state.config.ENABLE_SEARCH_QUERY = ENABLE_SEARCH_QUERY

 app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
     TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
 )

-app.state.MODELS = {}
-

 ##################################
 #
 # ChatCompletion Middleware

@@ -217,26 +222,6 @@ app.state.MODELS = {}
 ##################################


-def get_task_model_id(default_model_id):
-    # Set the task model
-    task_model_id = default_model_id
-
-    # Check if the user has a custom task model and use that model
-    if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
-        if (
-            app.state.config.TASK_MODEL
-            and app.state.config.TASK_MODEL in app.state.MODELS
-        ):
-            task_model_id = app.state.config.TASK_MODEL
-    else:
-        if (
-            app.state.config.TASK_MODEL_EXTERNAL
-            and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
-        ):
-            task_model_id = app.state.config.TASK_MODEL_EXTERNAL
-
-    return task_model_id
-
-
 def get_filter_function_ids(model):
     def get_priority(function_id):
         function = Functions.get_function_by_id(function_id)
@@ -366,8 +351,24 @@ async def get_content_from_response(response) -> Optional[str]:
     return content


+def get_task_model_id(
+    default_model_id: str, task_model: str, task_model_external: str, models
+) -> str:
+    # Set the task model
+    task_model_id = default_model_id
+
+    # Check if the user has a custom task model and use that model
+    if models[task_model_id]["owned_by"] == "ollama":
+        if task_model and task_model in models:
+            task_model_id = task_model
+    else:
+        if task_model_external and task_model_external in models:
+            task_model_id = task_model_external
+
+    return task_model_id
+
+
 async def chat_completion_tools_handler(
-    body: dict, user: UserModel, extra_params: dict
+    body: dict, user: UserModel, models, extra_params: dict
 ) -> tuple[dict, dict]:
     # If tool_ids field is present, call the functions
     metadata = body.get("metadata", {})
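With the app-state lookup removed, `get_task_model_id` is now a pure function of its arguments, which makes it easy to exercise in isolation. A usage sketch against the function added above, with invented model entries:

```python
models = {
    "llama3:latest": {"owned_by": "ollama"},
    "gpt-4o": {"owned_by": "openai"},
}

# An Ollama-owned default prefers TASK_MODEL; anything else prefers TASK_MODEL_EXTERNAL.
assert get_task_model_id("llama3:latest", "llama3:latest", "gpt-4o", models) == "llama3:latest"
assert get_task_model_id("gpt-4o", "llama3:latest", "gpt-4o", models) == "gpt-4o"
```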
@@ -381,14 +382,19 @@ async def chat_completion_tools_handler(
     contexts = []
     citations = []

-    task_model_id = get_task_model_id(body["model"])
+    task_model_id = get_task_model_id(
+        body["model"],
+        app.state.config.TASK_MODEL,
+        app.state.config.TASK_MODEL_EXTERNAL,
+        models,
+    )
+
     tools = get_tools(
         webui_app,
         tool_ids,
         user,
         {
             **extra_params,
-            "__model__": app.state.MODELS[task_model_id],
+            "__model__": models[task_model_id],
             "__messages__": body["messages"],
             "__files__": metadata.get("files", []),
         },

@@ -412,7 +418,7 @@ async def chat_completion_tools_handler(
     )

     try:
-        payload = filter_pipeline(payload, user)
+        payload = filter_pipeline(payload, user, models)
     except Exception as e:
         raise e

@@ -513,16 +519,16 @@ def is_chat_completion_request(request):
     )


-async def get_body_and_model_and_user(request):
+async def get_body_and_model_and_user(request, models):
     # Read the original request body
     body = await request.body()
     body_str = body.decode("utf-8")
     body = json.loads(body_str) if body_str else {}

     model_id = body["model"]
-    if model_id not in app.state.MODELS:
+    if model_id not in models:
         raise Exception("Model not found")
-    model = app.state.MODELS[model_id]
+    model = models[model_id]

     user = get_current_user(
         request,
@@ -538,14 +544,27 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             return await call_next(request)
         log.debug(f"request.url.path: {request.url.path}")

+        model_list = await get_all_models()
+        models = {model["id"]: model for model in model_list}
+
         try:
-            body, model, user = await get_body_and_model_and_user(request)
+            body, model, user = await get_body_and_model_and_user(request, models)
         except Exception as e:
             return JSONResponse(
                 status_code=status.HTTP_400_BAD_REQUEST,
                 content={"detail": str(e)},
             )

+        model_info = Models.get_model_by_id(model["id"])
+        if user.role == "user":
+            if model_info and not has_access(
+                user.id, type="read", access_control=model_info.access_control
+            ):
+                return JSONResponse(
+                    status_code=status.HTTP_403_FORBIDDEN,
+                    content={"detail": "User does not have access to the model"},
+                )
+
         metadata = {
             "chat_id": body.pop("chat_id", None),
             "message_id": body.pop("id", None),

@@ -582,15 +601,20 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 content={"detail": str(e)},
             )

+        tool_ids = body.pop("tool_ids", None)
+        files = body.pop("files", None)
+
         metadata = {
             **metadata,
-            "tool_ids": body.pop("tool_ids", None),
-            "files": body.pop("files", None),
+            "tool_ids": tool_ids,
+            "files": files,
         }

         body["metadata"] = metadata

         try:
-            body, flags = await chat_completion_tools_handler(body, user, extra_params)
+            body, flags = await chat_completion_tools_handler(
+                body, user, models, extra_params
+            )
             contexts.extend(flags.get("contexts", []))
             citations.extend(flags.get("citations", []))
         except Exception as e:
@@ -687,10 +711,10 @@ app.add_middleware(ChatCompletionMiddleware)
 ##################################


-def get_sorted_filters(model_id):
+def get_sorted_filters(model_id, models):
     filters = [
         model
-        for model in app.state.MODELS.values()
+        for model in models.values()
         if "pipeline" in model
         and "type" in model["pipeline"]
         and model["pipeline"]["type"] == "filter"

@@ -706,12 +730,12 @@ def get_sorted_filters(model_id):
     return sorted_filters


-def filter_pipeline(payload, user):
+def filter_pipeline(payload, user, models):
     user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
     model_id = payload["model"]
-    sorted_filters = get_sorted_filters(model_id)
-    model = app.state.MODELS[model_id]
+
+    sorted_filters = get_sorted_filters(model_id, models)
+    model = models[model_id]

     if "pipeline" in model:
         sorted_filters.append(model)

@@ -782,8 +806,11 @@ class PipelineMiddleware(BaseHTTPMiddleware):
                 content={"detail": "Not authenticated"},
             )

+        model_list = await get_all_models()
+        models = {model["id"]: model for model in model_list}
+
         try:
-            data = filter_pipeline(data, user)
+            data = filter_pipeline(data, user, models)
         except Exception as e:
             if len(e.args) > 1:
                 return JSONResponse(
@@ -862,16 +889,10 @@ async def commit_session_after_request(request: Request, call_next):
 @app.middleware("http")
 async def check_url(request: Request, call_next):
-    if len(app.state.MODELS) == 0:
-        await get_all_models()
-    else:
-        pass
-
     start_time = int(time.time())
     response = await call_next(request)
     process_time = int(time.time()) - start_time
     response.headers["X-Process-Time"] = str(process_time)

     return response
@@ -911,10 +932,10 @@ app.mount("/retrieval/api/v1", retrieval_app)
 app.mount("/api/v1", webui_app)

 webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION


+@cached(ttl=1)
 async def get_all_base_models():
     open_webui_models = []
     openai_models = []

@@ -944,6 +965,7 @@ async def get_all_base_models():
     return models


+@cached(ttl=1)
 async def get_all_models():
     models = await get_all_base_models()
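`@cached(ttl=1)` comes from `aiocache` (added to the dependencies in this commit); it memoizes the coroutine's result in an in-process cache for one second, so the many per-request `await get_all_models()` calls introduced above collapse into at most one rebuild per second. A minimal standalone illustration of the semantics:

```python
import asyncio
from aiocache import cached

calls = 0

@cached(ttl=1)  # results are served from memory for 1 second
async def fetch():
    global calls
    calls += 1
    return calls

async def main():
    print(await fetch(), await fetch())  # -> 1 1 (second call hits the cache)
    await asyncio.sleep(1.1)
    print(await fetch())                 # -> 2   (ttl expired, body re-runs)

asyncio.run(main())
```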
@@ -1065,9 +1087,6 @@ async def get_all_models():
         function_module = get_function_module_by_id(action_id)
         model["actions"].extend(get_action_items_from_module(function_module))

-    app.state.MODELS = {model["id"]: model for model in models}
-    webui_app.state.MODELS = app.state.MODELS
-
     return models

@@ -1082,16 +1101,19 @@ async def get_models(user=Depends(get_verified_user)):
         if "pipeline" not in model or model["pipeline"].get("type", None) != "filter"
     ]

-    # TODO: Check User Group and Filter Models
-    # if app.state.config.ENABLE_MODEL_FILTER:
-    #     if user.role == "user":
-    #         models = list(
-    #             filter(
-    #                 lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
-    #                 models,
-    #             )
-    #         )
-    # return {"data": models}
+    # Filter out models that the user does not have access to
+    if user.role == "user":
+        filtered_models = []
+        for model in models:
+            model_info = Models.get_model_by_id(model["id"])
+            if model_info:
+                if has_access(
+                    user.id, type="read", access_control=model_info.access_control
+                ):
+                    filtered_models.append(model)
+            else:
+                filtered_models.append(model)
+        models = filtered_models

     return {"data": models}
@@ -1106,24 +1128,27 @@ async def get_base_models(user=Depends(get_admin_user)):
 async def generate_chat_completions(
     form_data: dict, user=Depends(get_verified_user), bypass_filter: bool = False
 ):
-    model_id = form_data["model"]
+    model_list = await get_all_models()
+    models = {model["id"]: model for model in model_list}

-    if model_id not in app.state.MODELS:
+    model_id = form_data["model"]
+    if model_id not in models:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
             detail="Model not found",
         )

-    # TODO: Check User Group and Filter Models
-    # if not bypass_filter:
-    #     if app.state.config.ENABLE_MODEL_FILTER:
-    #         if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
-    #             raise HTTPException(
-    #                 status_code=status.HTTP_403_FORBIDDEN,
-    #                 detail="Model not found",
-    #             )
-
-    model = app.state.MODELS[model_id]
+    model = models[model_id]
+
+    # Check if user has access to the model
+    if user.role == "user":
+        model_info = Models.get_model_by_id(model_id)
+        if not has_access(
+            user.id, type="read", access_control=model_info.access_control
+        ):
+            raise HTTPException(
+                status_code=403,
+                detail="Model not found",
+            )

     if model["owned_by"] == "arena":
         model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
@@ -1174,7 +1199,9 @@ async def generate_chat_completions(
     if model.get("pipe"):
         # Below does not require bypass_filter because this is the only route that uses this function, and it is already bypassing the filter
-        return await generate_function_chat_completion(form_data, user=user)
+        return await generate_function_chat_completion(
+            form_data, user=user, models=models
+        )
     if model["owned_by"] == "ollama":
         # Using /ollama/api/chat endpoint
         form_data = convert_payload_openai_to_ollama(form_data)

@@ -1198,16 +1225,20 @@ async def generate_chat_completions(
 @app.post("/api/chat/completed")
 async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
+    model_list = await get_all_models()
+    models = {model["id"]: model for model in model_list}
+
     data = form_data
     model_id = data["model"]
-    if model_id not in app.state.MODELS:
+    if model_id not in models:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
             detail="Model not found",
         )
-    model = app.state.MODELS[model_id]

-    sorted_filters = get_sorted_filters(model_id)
+    model = models[model_id]
+    sorted_filters = get_sorted_filters(model_id, models)

     if "pipeline" in model:
         sorted_filters = [model] + sorted_filters
@@ -1382,14 +1413,18 @@ async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified_user)):
             detail="Action not found",
         )

+    model_list = await get_all_models()
+    models = {model["id"]: model for model in model_list}
+
     data = form_data
     model_id = data["model"]
-    if model_id not in app.state.MODELS:
+
+    if model_id not in models:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
             detail="Model not found",
         )
-    model = app.state.MODELS[model_id]
+    model = models[model_id]

     __event_emitter__ = get_event_emitter(
         {
@@ -1543,8 +1578,11 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_user)):
 async def generate_title(form_data: dict, user=Depends(get_verified_user)):
     print("generate_title")

+    model_list = await get_all_models()
+    models = {model["id"]: model for model in model_list}
+
     model_id = form_data["model"]
-    if model_id not in app.state.MODELS:
+    if model_id not in models:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
             detail="Model not found",

@@ -1552,10 +1590,16 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
     # Check if the user has a custom task model
     # If the user has a custom task model, use that model
-    task_model_id = get_task_model_id(model_id)
+    task_model_id = get_task_model_id(
+        model_id,
+        app.state.config.TASK_MODEL,
+        app.state.config.TASK_MODEL_EXTERNAL,
+        models,
+    )
     print(task_model_id)

-    model = app.state.MODELS[task_model_id]
+    model = models[task_model_id]

     if app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "":
         template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
@@ -1589,7 +1633,7 @@ Artificial Intelligence in Healthcare
         "stream": False,
         **(
             {"max_tokens": 50}
-            if app.state.MODELS[task_model_id]["owned_by"] == "ollama"
+            if models[task_model_id]["owned_by"] == "ollama"
             else {
                 "max_completion_tokens": 50,
             }

@@ -1601,7 +1645,7 @@ Artificial Intelligence in Healthcare
     # Handle pipeline filters
     try:
-        payload = filter_pipeline(payload, user)
+        payload = filter_pipeline(payload, user, models)
     except Exception as e:
         if len(e.args) > 1:
             return JSONResponse(
@@ -1628,8 +1672,11 @@ async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)):
             content={"detail": "Tags generation is disabled"},
         )

+    model_list = await get_all_models()
+    models = {model["id"]: model for model in model_list}
+
     model_id = form_data["model"]
-    if model_id not in app.state.MODELS:
+    if model_id not in models:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
             detail="Model not found",

@@ -1637,7 +1684,12 @@ async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)):
     # Check if the user has a custom task model
     # If the user has a custom task model, use that model
-    task_model_id = get_task_model_id(model_id)
+    task_model_id = get_task_model_id(
+        model_id,
+        app.state.config.TASK_MODEL,
+        app.state.config.TASK_MODEL_EXTERNAL,
+        models,
+    )
     print(task_model_id)

     if app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "":
@@ -1675,7 +1727,7 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] }
     # Handle pipeline filters
     try:
-        payload = filter_pipeline(payload, user)
+        payload = filter_pipeline(payload, user, models)
     except Exception as e:
         if len(e.args) > 1:
             return JSONResponse(

@@ -1702,8 +1754,11 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
             detail=f"Search query generation is disabled",
         )

+    model_list = await get_all_models()
+    models = {model["id"]: model for model in model_list}
+
     model_id = form_data["model"]
-    if model_id not in app.state.MODELS:
+    if model_id not in models:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
             detail="Model not found",

@@ -1711,10 +1766,15 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
     # Check if the user has a custom task model
     # If the user has a custom task model, use that model
-    task_model_id = get_task_model_id(model_id)
+    task_model_id = get_task_model_id(
+        model_id,
+        app.state.config.TASK_MODEL,
+        app.state.config.TASK_MODEL_EXTERNAL,
+        models,
+    )
     print(task_model_id)

-    model = app.state.MODELS[task_model_id]
+    model = models[task_model_id]

     if app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE != "":
         template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
@@ -1741,7 +1801,7 @@ Search Query:"""
         "stream": False,
         **(
             {"max_tokens": 30}
-            if app.state.MODELS[task_model_id]["owned_by"] == "ollama"
+            if models[task_model_id]["owned_by"] == "ollama"
             else {
                 "max_completion_tokens": 30,
             }

@@ -1752,7 +1812,7 @@ Search Query:"""
     # Handle pipeline filters
     try:
-        payload = filter_pipeline(payload, user)
+        payload = filter_pipeline(payload, user, models)
     except Exception as e:
         if len(e.args) > 1:
             return JSONResponse(
@@ -1774,8 +1834,11 @@ Search Query:"""
 async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
     print("generate_emoji")

+    model_list = await get_all_models()
+    models = {model["id"]: model for model in model_list}
+
     model_id = form_data["model"]
-    if model_id not in app.state.MODELS:
+    if model_id not in models:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
             detail="Model not found",

@@ -1783,10 +1846,15 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
     # Check if the user has a custom task model
     # If the user has a custom task model, use that model
-    task_model_id = get_task_model_id(model_id)
+    task_model_id = get_task_model_id(
+        model_id,
+        app.state.config.TASK_MODEL,
+        app.state.config.TASK_MODEL_EXTERNAL,
+        models,
+    )
     print(task_model_id)

-    model = app.state.MODELS[task_model_id]
+    model = models[task_model_id]

     template = '''
 Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
@@ -1808,7 +1876,7 @@ Message: """{{prompt}}"""
         "stream": False,
         **(
             {"max_tokens": 4}
-            if app.state.MODELS[task_model_id]["owned_by"] == "ollama"
+            if models[task_model_id]["owned_by"] == "ollama"
             else {
                 "max_completion_tokens": 4,
             }

@@ -1820,7 +1888,7 @@ Message: """{{prompt}}"""
     # Handle pipeline filters
     try:
-        payload = filter_pipeline(payload, user)
+        payload = filter_pipeline(payload, user, models)
     except Exception as e:
         if len(e.args) > 1:
             return JSONResponse(
@@ -1842,8 +1910,11 @@ Message: """{{prompt}}"""
 async def generate_moa_response(form_data: dict, user=Depends(get_verified_user)):
     print("generate_moa_response")

+    model_list = await get_all_models()
+    models = {model["id"]: model for model in model_list}
+
     model_id = form_data["model"]
-    if model_id not in app.state.MODELS:
+    if model_id not in models:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
             detail="Model not found",

@@ -1851,10 +1922,15 @@ async def generate_moa_response(form_data: dict, user=Depends(get_verified_user)):
     # Check if the user has a custom task model
     # If the user has a custom task model, use that model
-    task_model_id = get_task_model_id(model_id)
+    task_model_id = get_task_model_id(
+        model_id,
+        app.state.config.TASK_MODEL,
+        app.state.config.TASK_MODEL_EXTERNAL,
+        models,
+    )
     print(task_model_id)

-    model = app.state.MODELS[task_model_id]
+    model = models[task_model_id]

     template = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}"

@@ -1881,7 +1957,7 @@ Responses from models: {{responses}}"""
     log.debug(payload)

     try:
-        payload = filter_pipeline(payload, user)
+        payload = filter_pipeline(payload, user, models)
     except Exception as e:
         if len(e.args) > 1:
             return JSONResponse(

@@ -1911,7 +1987,7 @@ Responses from models: {{responses}}"""
 @app.get("/api/pipelines/list")
 async def get_pipelines_list(user=Depends(get_admin_user)):
-    responses = await get_openai_models(raw=True)
+    responses = await get_openai_models_responses()

     print(responses)
     urlIdxs = [

View File

@@ -192,15 +192,16 @@ def has_permission(
 def has_access(
     user_id: str,
-    action: str = "write",
+    type: str = "write",
     access_control: Optional[dict] = None,
 ) -> bool:
+    print("user_id", user_id, "type", type, "access_control", access_control)
     if access_control is None:
-        return action == "read"
+        return type == "read"

     user_groups = Groups.get_groups_by_member_id(user_id)
     user_group_ids = [group.id for group in user_groups]
-    permission_access = access_control.get(action, {})
+    permission_access = access_control.get(type, {})
     permitted_group_ids = permission_access.get("group_ids", [])
     permitted_user_ids = permission_access.get("user_ids", [])
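For reference, the `access_control` dict this function reads is keyed by permission type, each entry listing the permitted group and user ids. A sketch of the expected shape (field names taken from the code above; ids invented for illustration):

```python
access_control = {
    "read": {
        "group_ids": ["group-abc"],
        "user_ids": ["user-123"],
    },
    "write": {
        "group_ids": [],
        "user_ids": ["user-123"],
    },
}

# access_control=None means "public for read":
#   has_access("anyone", type="read", access_control=None)  -> True
# Otherwise the user id, or one of their group ids, must be listed.
```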

View File

@@ -13,6 +13,7 @@ passlib[bcrypt]==1.7.4
 requests==2.32.3
 aiohttp==3.10.8
 async-timeout
+aiocache

 sqlalchemy==2.0.32
 alembic==1.13.2

View File

@@ -21,6 +21,7 @@ dependencies = [
     "requests==2.32.3",
     "aiohttp==3.10.8",
     "async-timeout",
+    "aiocache",

     "sqlalchemy==2.0.32",
     "alembic==1.13.2",

View File

@@ -71,7 +71,7 @@
 	const upsertModelHandler = async (model) => {
 		model.base_model_id = null;

-		if (models.find((m) => m.id === model.id)) {
+		if (workspaceModels.find((m) => m.id === model.id)) {
 			await updateModelById(localStorage.token, model.id, model).catch((error) => {
 				return null;
 			});