Timothy Jaeryang Baek 2024-12-11 20:39:55 -08:00
parent d9ffcea764
commit 866c3dff11
3 changed files with 36 additions and 22 deletions

View File

@@ -70,6 +70,15 @@ from open_webui.routers import (
     users,
     utils,
 )
+from open_webui.routers.openai import (
+    generate_chat_completion as generate_openai_chat_completion,
+)
+from open_webui.routers.ollama import (
+    generate_chat_completion as generate_ollama_chat_completion,
+)
 from open_webui.routers.retrieval import (
     get_embedding_function,
     get_ef,
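
Note: both the openai and ollama routers export a coroutine named generate_chat_completion, so main.py pulls them in under distinct aliases. A minimal sketch of the dispatch this enables, assuming the model's owned_by field is what distinguishes the two backends (the wrapper name and keyword-style call below are illustrative, not the file's actual code, and rely on the aliased imports from this hunk):

    async def route_chat_completion(request, form_data, models, user):
        # Illustrative only: pick the backend from the registry entry's "owned_by" field.
        model = models[form_data["model"]]
        if model.get("owned_by") == "ollama":
            return await generate_ollama_chat_completion(
                request=request, form_data=form_data, user=user
            )
        return await generate_openai_chat_completion(
            request=request, form_data=form_data, user=user
        )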
@@ -1019,8 +1028,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             return await call_next(request)
         log.debug(f"request.url.path: {request.url.path}")
-        model_list = await get_all_models()
-        models = {model["id"]: model for model in model_list}
+        await get_all_models(request)
+        models = app.state.MODELS
         try:
             body, model, user = await get_body_and_model_and_user(request, models)
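
Note: this is the pattern repeated throughout the commit. Instead of building a local id-keyed dict from the return value of get_all_models(), callers now pass the request in and read the shared registry that the call refreshes on app.state. A hedged sketch of that assumed contract (fetch_models_from_all_providers is a hypothetical placeholder, not an open-webui function):

    async def get_all_models(request):
        # Assumed shape: refresh the shared registry keyed by model id on app.state,
        # so middleware and endpoints can read app.state.MODELS right after the call.
        model_list = await fetch_models_from_all_providers(request)  # hypothetical helper
        request.app.state.MODELS = {model["id"]: model for model in model_list}
        return model_list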
@@ -1257,7 +1266,7 @@ class PipelineMiddleware(BaseHTTPMiddleware):
                 content={"detail": e.detail},
             )
-        await get_all_models()
+        await get_all_models(request)
         models = app.state.MODELS
         try:
@@ -1924,6 +1933,7 @@ async def get_base_models(request: Request, user=Depends(get_admin_user)):
 @app.post("/api/chat/completions")
 async def generate_chat_completions(
+    request: Request,
     form_data: dict,
     user=Depends(get_verified_user),
     bypass_filter: bool = False,
@@ -1931,8 +1941,7 @@ async def generate_chat_completions(
     if BYPASS_MODEL_ACCESS_CONTROL:
         bypass_filter = True
-    model_list = app.state.MODELS
-    models = {model["id"]: model for model in model_list}
+    models = app.state.MODELS
     model_id = form_data["model"]
     if model_id not in models:
@@ -1981,7 +1990,7 @@ async def generate_chat_completions(
         if model_ids and filter_mode == "exclude":
             model_ids = [
                 model["id"]
-                for model in await get_all_models()
+                for model in await get_all_models(request)
                 if model.get("owned_by") != "arena" and model["id"] not in model_ids
             ]
@@ -1991,7 +2000,7 @@ async def generate_chat_completions(
         else:
             model_ids = [
                 model["id"]
-                for model in await get_all_models()
+                for model in await get_all_models(request)
                 if model.get("owned_by") != "arena"
             ]
         selected_model_id = random.choice(model_ids)
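
Note: these two hunks only thread request into the arena model selection; the selection itself is unchanged. Collapsing the branches shown above into a standalone helper, purely for illustration:

    import random

    def pick_arena_candidate(all_models, model_ids, filter_mode):
        # Drop arena meta-models, honor an optional exclude list, then pick at random.
        if model_ids and filter_mode == "exclude":
            candidates = [
                m["id"]
                for m in all_models
                if m.get("owned_by") != "arena" and m["id"] not in model_ids
            ]
        else:
            candidates = [m["id"] for m in all_models if m.get("owned_by") != "arena"]
        return random.choice(candidates)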
@@ -2028,6 +2037,7 @@ async def generate_chat_completions(
         # Using /ollama/api/chat endpoint
         form_data = convert_payload_openai_to_ollama(form_data)
         response = await generate_ollama_chat_completion(
+            request=request,
             form_data=form_data, user=user, bypass_filter=bypass_filter
         )
         if form_data.stream:
@@ -2040,6 +2050,8 @@ async def generate_chat_completions(
             return convert_response_ollama_to_openai(response)
     else:
         return await generate_openai_chat_completion(
+            request=request,
             form_data, user=user, bypass_filter=bypass_filter
         )
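
Note: both backend calls now receive request as a keyword argument. For the Ollama-owned path, the flow around the lines above is: convert the OpenAI-style payload, call the ollama router, and convert the response back unless it is streamed. A rough sketch of that flow, treating the payload as a plain dict for brevity (the real code works with the converted form object, and the converter and backend names are the ones imported in main.py above):

    async def ollama_path(request, form_data: dict, user, bypass_filter: bool = False):
        # Translate the payload, call the ollama backend with request threaded through,
        # then translate the response back; streaming responses pass through as-is.
        payload = convert_payload_openai_to_ollama(form_data)
        response = await generate_ollama_chat_completion(
            request=request, form_data=payload, user=user, bypass_filter=bypass_filter
        )
        if form_data.get("stream"):
            return response
        return convert_response_ollama_to_openai(response)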
@@ -2048,8 +2060,8 @@ async def generate_chat_completions(
 async def chat_completed(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
-    model_list = await get_all_models(request)
-    models = {model["id"]: model for model in model_list}
+    await get_all_models(request)
+    models = app.state.MODELS
     data = form_data
     model_id = data["model"]
@@ -2183,7 +2195,9 @@ async def chat_completed(
 @app.post("/api/chat/actions/{action_id}")
-async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified_user)):
+async def chat_action(
+    request: Request, action_id: str, form_data: dict, user=Depends(get_verified_user)
+):
     if "." in action_id:
         action_id, sub_action_id = action_id.split(".")
     else:
@@ -2196,8 +2210,8 @@ async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified
             detail="Action not found",
         )
-    model_list = await get_all_models()
-    models = {model["id"]: model for model in model_list}
+    await get_all_models(request)
+    models = app.state.MODELS
     data = form_data
     model_id = data["model"]
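
Note: chat_action gains the same request parameter and the same refresh-then-read model lookup. The action id parsing shown two hunks above, restated as a standalone helper (illustrative; the original unpacks split(".") directly):

    def split_action_id(action_id: str):
        # "toolkit_id.action_name" -> ("toolkit_id", "action_name"); a bare id has no sub-action.
        if "." in action_id:
            action_id, sub_action_id = action_id.split(".", 1)
            return action_id, sub_action_id
        return action_id, None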

View File

@@ -344,7 +344,7 @@ async def get_ollama_tags(
     models = []
     if url_idx is None:
-        models = await get_all_models()
+        models = await get_all_models(request)
     else:
         url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
         key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
@@ -565,7 +565,7 @@ async def copy_model(
     user=Depends(get_admin_user),
 ):
     if url_idx is None:
-        await get_all_models()
+        await get_all_models(request)
         models = request.app.state.OLLAMA_MODELS
         if form_data.source in models:
@@ -620,7 +620,7 @@ async def delete_model(
     user=Depends(get_admin_user),
 ):
     if url_idx is None:
-        await get_all_models()
+        await get_all_models(request)
         models = request.app.state.OLLAMA_MODELS
         if form_data.name in models:
@@ -670,7 +670,7 @@ async def delete_model(
 async def show_model_info(
     request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
 ):
-    await get_all_models()
+    await get_all_models(request)
     models = request.app.state.OLLAMA_MODELS
     if form_data.name not in models:
@@ -734,7 +734,7 @@ async def embed(
     log.info(f"generate_ollama_batch_embeddings {form_data}")
     if url_idx is None:
-        await get_all_models()
+        await get_all_models(request)
         models = request.app.state.OLLAMA_MODELS
         model = form_data.model
@@ -803,7 +803,7 @@ async def embeddings(
     log.info(f"generate_ollama_embeddings {form_data}")
     if url_idx is None:
-        await get_all_models()
+        await get_all_models(request)
         models = request.app.state.OLLAMA_MODELS
         model = form_data.model
@@ -878,8 +878,8 @@ async def generate_completion(
     user=Depends(get_verified_user),
 ):
     if url_idx is None:
-        model_list = await get_all_models()
-        models = {model["model"]: model for model in model_list["models"]}
+        await get_all_models(request)
+        models = request.app.state.OLLAMA_MODELS
         model = form_data.model
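
Note: the ollama router hunks all share the same shape: when no explicit url_idx is supplied, refresh the registry via get_all_models(request) and resolve the backend from request.app.state.OLLAMA_MODELS. A hedged sketch of that resolution, assuming each registry entry carries a list of backend indices under a "urls" key (that key name is an assumption, not confirmed by the diff):

    import random
    from fastapi import HTTPException

    def resolve_ollama_url(request, model_name: str, url_idx=None) -> str:
        # With no explicit index, look the model up in the refreshed registry and
        # pick one of the backends hosting it; otherwise trust the caller's index.
        if url_idx is None:
            models = request.app.state.OLLAMA_MODELS
            if model_name not in models:
                raise HTTPException(status_code=400, detail=f"Model {model_name} not found")
            url_idx = random.choice(models[model_name]["urls"])  # assumed entry shape
        return request.app.state.config.OLLAMA_BASE_URLS[url_idx]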
@@ -1200,7 +1200,7 @@ async def get_openai_models(
     models = []
     if url_idx is None:
-        model_list = await get_all_models()
+        model_list = await get_all_models(request)
         models = [
             {
                 "id": model["model"],

View File

@@ -404,7 +404,7 @@ async def get_models(
     }
     if url_idx is None:
-        models = await get_all_models()
+        models = await get_all_models(request)
     else:
         url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
         key = request.app.state.config.OPENAI_API_KEYS[url_idx]
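
Note: in the openai router the same index selects both halves of a connection's configuration, since OPENAI_API_BASE_URLS and OPENAI_API_KEYS are parallel lists. A small illustrative helper, assuming only what the lines above already show:

    def openai_connection(request, url_idx: int):
        # The same index addresses the parallel URL and key lists.
        url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
        key = request.app.state.config.OPENAI_API_KEYS[url_idx]
        return url, key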