mirror of
https://github.com/open-webui/open-webui
synced 2025-04-16 21:42:50 +00:00
fix
This commit is contained in:
parent
d9ffcea764
commit
866c3dff11
@ -70,6 +70,15 @@ from open_webui.routers import (
|
|||||||
users,
|
users,
|
||||||
utils,
|
utils,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from open_webui.routers.openai import (
|
||||||
|
generate_chat_completion as generate_openai_chat_completion,
|
||||||
|
)
|
||||||
|
|
||||||
|
from open_webui.routers.ollama import (
|
||||||
|
generate_chat_completion as generate_ollama_chat_completion,
|
||||||
|
)
|
||||||
|
|
||||||
from open_webui.routers.retrieval import (
|
from open_webui.routers.retrieval import (
|
||||||
get_embedding_function,
|
get_embedding_function,
|
||||||
get_ef,
|
get_ef,
|
||||||
@ -1019,8 +1028,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
log.debug(f"request.url.path: {request.url.path}")
|
log.debug(f"request.url.path: {request.url.path}")
|
||||||
|
|
||||||
model_list = await get_all_models()
|
await get_all_models(request)
|
||||||
models = {model["id"]: model for model in model_list}
|
models = app.state.MODELS
|
||||||
|
|
||||||
try:
|
try:
|
||||||
body, model, user = await get_body_and_model_and_user(request, models)
|
body, model, user = await get_body_and_model_and_user(request, models)
|
||||||
@ -1257,7 +1266,7 @@ class PipelineMiddleware(BaseHTTPMiddleware):
|
|||||||
content={"detail": e.detail},
|
content={"detail": e.detail},
|
||||||
)
|
)
|
||||||
|
|
||||||
await get_all_models()
|
await get_all_models(request)
|
||||||
models = app.state.MODELS
|
models = app.state.MODELS
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -1924,6 +1933,7 @@ async def get_base_models(request: Request, user=Depends(get_admin_user)):
|
|||||||
|
|
||||||
@app.post("/api/chat/completions")
|
@app.post("/api/chat/completions")
|
||||||
async def generate_chat_completions(
|
async def generate_chat_completions(
|
||||||
|
request: Request,
|
||||||
form_data: dict,
|
form_data: dict,
|
||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
bypass_filter: bool = False,
|
bypass_filter: bool = False,
|
||||||
@ -1931,8 +1941,7 @@ async def generate_chat_completions(
|
|||||||
if BYPASS_MODEL_ACCESS_CONTROL:
|
if BYPASS_MODEL_ACCESS_CONTROL:
|
||||||
bypass_filter = True
|
bypass_filter = True
|
||||||
|
|
||||||
model_list = app.state.MODELS
|
models = app.state.MODELS
|
||||||
models = {model["id"]: model for model in model_list}
|
|
||||||
|
|
||||||
model_id = form_data["model"]
|
model_id = form_data["model"]
|
||||||
if model_id not in models:
|
if model_id not in models:
|
||||||
@ -1981,7 +1990,7 @@ async def generate_chat_completions(
|
|||||||
if model_ids and filter_mode == "exclude":
|
if model_ids and filter_mode == "exclude":
|
||||||
model_ids = [
|
model_ids = [
|
||||||
model["id"]
|
model["id"]
|
||||||
for model in await get_all_models()
|
for model in await get_all_models(request)
|
||||||
if model.get("owned_by") != "arena" and model["id"] not in model_ids
|
if model.get("owned_by") != "arena" and model["id"] not in model_ids
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -1991,7 +2000,7 @@ async def generate_chat_completions(
|
|||||||
else:
|
else:
|
||||||
model_ids = [
|
model_ids = [
|
||||||
model["id"]
|
model["id"]
|
||||||
for model in await get_all_models()
|
for model in await get_all_models(request)
|
||||||
if model.get("owned_by") != "arena"
|
if model.get("owned_by") != "arena"
|
||||||
]
|
]
|
||||||
selected_model_id = random.choice(model_ids)
|
selected_model_id = random.choice(model_ids)
|
||||||
@ -2028,6 +2037,7 @@ async def generate_chat_completions(
|
|||||||
# Using /ollama/api/chat endpoint
|
# Using /ollama/api/chat endpoint
|
||||||
form_data = convert_payload_openai_to_ollama(form_data)
|
form_data = convert_payload_openai_to_ollama(form_data)
|
||||||
response = await generate_ollama_chat_completion(
|
response = await generate_ollama_chat_completion(
|
||||||
|
request=request,
|
||||||
form_data=form_data, user=user, bypass_filter=bypass_filter
|
form_data=form_data, user=user, bypass_filter=bypass_filter
|
||||||
)
|
)
|
||||||
if form_data.stream:
|
if form_data.stream:
|
||||||
@ -2040,6 +2050,8 @@ async def generate_chat_completions(
|
|||||||
return convert_response_ollama_to_openai(response)
|
return convert_response_ollama_to_openai(response)
|
||||||
else:
|
else:
|
||||||
return await generate_openai_chat_completion(
|
return await generate_openai_chat_completion(
|
||||||
|
request=request,
|
||||||
|
|
||||||
form_data, user=user, bypass_filter=bypass_filter
|
form_data, user=user, bypass_filter=bypass_filter
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -2048,8 +2060,8 @@ async def generate_chat_completions(
|
|||||||
async def chat_completed(
|
async def chat_completed(
|
||||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||||
):
|
):
|
||||||
model_list = await get_all_models(request)
|
await get_all_models(request)
|
||||||
models = {model["id"]: model for model in model_list}
|
models = app.state.MODELS
|
||||||
|
|
||||||
data = form_data
|
data = form_data
|
||||||
model_id = data["model"]
|
model_id = data["model"]
|
||||||
@ -2183,7 +2195,9 @@ async def chat_completed(
|
|||||||
|
|
||||||
|
|
||||||
@app.post("/api/chat/actions/{action_id}")
|
@app.post("/api/chat/actions/{action_id}")
|
||||||
async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified_user)):
|
async def chat_action(
|
||||||
|
request: Request, action_id: str, form_data: dict, user=Depends(get_verified_user)
|
||||||
|
):
|
||||||
if "." in action_id:
|
if "." in action_id:
|
||||||
action_id, sub_action_id = action_id.split(".")
|
action_id, sub_action_id = action_id.split(".")
|
||||||
else:
|
else:
|
||||||
@ -2196,8 +2210,8 @@ async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified
|
|||||||
detail="Action not found",
|
detail="Action not found",
|
||||||
)
|
)
|
||||||
|
|
||||||
model_list = await get_all_models()
|
await get_all_models(request)
|
||||||
models = {model["id"]: model for model in model_list}
|
models = app.state.MODELS
|
||||||
|
|
||||||
data = form_data
|
data = form_data
|
||||||
model_id = data["model"]
|
model_id = data["model"]
|
||||||
|
@ -344,7 +344,7 @@ async def get_ollama_tags(
|
|||||||
models = []
|
models = []
|
||||||
|
|
||||||
if url_idx is None:
|
if url_idx is None:
|
||||||
models = await get_all_models()
|
models = await get_all_models(request)
|
||||||
else:
|
else:
|
||||||
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||||
key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
|
key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
|
||||||
@ -565,7 +565,7 @@ async def copy_model(
|
|||||||
user=Depends(get_admin_user),
|
user=Depends(get_admin_user),
|
||||||
):
|
):
|
||||||
if url_idx is None:
|
if url_idx is None:
|
||||||
await get_all_models()
|
await get_all_models(request)
|
||||||
models = request.app.state.OLLAMA_MODELS
|
models = request.app.state.OLLAMA_MODELS
|
||||||
|
|
||||||
if form_data.source in models:
|
if form_data.source in models:
|
||||||
@ -620,7 +620,7 @@ async def delete_model(
|
|||||||
user=Depends(get_admin_user),
|
user=Depends(get_admin_user),
|
||||||
):
|
):
|
||||||
if url_idx is None:
|
if url_idx is None:
|
||||||
await get_all_models()
|
await get_all_models(request)
|
||||||
models = request.app.state.OLLAMA_MODELS
|
models = request.app.state.OLLAMA_MODELS
|
||||||
|
|
||||||
if form_data.name in models:
|
if form_data.name in models:
|
||||||
@ -670,7 +670,7 @@ async def delete_model(
|
|||||||
async def show_model_info(
|
async def show_model_info(
|
||||||
request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
|
request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
|
||||||
):
|
):
|
||||||
await get_all_models()
|
await get_all_models(request)
|
||||||
models = request.app.state.OLLAMA_MODELS
|
models = request.app.state.OLLAMA_MODELS
|
||||||
|
|
||||||
if form_data.name not in models:
|
if form_data.name not in models:
|
||||||
@ -734,7 +734,7 @@ async def embed(
|
|||||||
log.info(f"generate_ollama_batch_embeddings {form_data}")
|
log.info(f"generate_ollama_batch_embeddings {form_data}")
|
||||||
|
|
||||||
if url_idx is None:
|
if url_idx is None:
|
||||||
await get_all_models()
|
await get_all_models(request)
|
||||||
models = request.app.state.OLLAMA_MODELS
|
models = request.app.state.OLLAMA_MODELS
|
||||||
|
|
||||||
model = form_data.model
|
model = form_data.model
|
||||||
@ -803,7 +803,7 @@ async def embeddings(
|
|||||||
log.info(f"generate_ollama_embeddings {form_data}")
|
log.info(f"generate_ollama_embeddings {form_data}")
|
||||||
|
|
||||||
if url_idx is None:
|
if url_idx is None:
|
||||||
await get_all_models()
|
await get_all_models(request)
|
||||||
models = request.app.state.OLLAMA_MODELS
|
models = request.app.state.OLLAMA_MODELS
|
||||||
|
|
||||||
model = form_data.model
|
model = form_data.model
|
||||||
@ -878,8 +878,8 @@ async def generate_completion(
|
|||||||
user=Depends(get_verified_user),
|
user=Depends(get_verified_user),
|
||||||
):
|
):
|
||||||
if url_idx is None:
|
if url_idx is None:
|
||||||
model_list = await get_all_models()
|
await get_all_models(request)
|
||||||
models = {model["model"]: model for model in model_list["models"]}
|
models = request.app.state.OLLAMA_MODELS
|
||||||
|
|
||||||
model = form_data.model
|
model = form_data.model
|
||||||
|
|
||||||
@ -1200,7 +1200,7 @@ async def get_openai_models(
|
|||||||
|
|
||||||
models = []
|
models = []
|
||||||
if url_idx is None:
|
if url_idx is None:
|
||||||
model_list = await get_all_models()
|
model_list = await get_all_models(request)
|
||||||
models = [
|
models = [
|
||||||
{
|
{
|
||||||
"id": model["model"],
|
"id": model["model"],
|
||||||
|
@ -404,7 +404,7 @@ async def get_models(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if url_idx is None:
|
if url_idx is None:
|
||||||
models = await get_all_models()
|
models = await get_all_models(request)
|
||||||
else:
|
else:
|
||||||
url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
|
url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
|
||||||
key = request.app.state.config.OPENAI_API_KEYS[url_idx]
|
key = request.app.state.config.OPENAI_API_KEYS[url_idx]
|
||||||
|
Loading…
Reference in New Issue
Block a user