mirror of
https://github.com/open-webui/open-webui
synced 2025-01-18 00:30:51 +00:00
Merge pull request #4295 from michaelpoluektov/refactor-tools
refactor: Refactor OpenAI API to use helper functions, silence LSP/linter warnings
This commit is contained in:
commit
91851114e4
@ -1,6 +1,6 @@
|
||||
from fastapi import FastAPI, Request, Response, HTTPException, Depends
|
||||
from fastapi import FastAPI, Request, HTTPException, Depends
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
|
||||
from fastapi.responses import StreamingResponse, FileResponse
|
||||
|
||||
import requests
|
||||
import aiohttp
|
||||
@ -12,16 +12,12 @@ from pydantic import BaseModel
|
||||
from starlette.background import BackgroundTask
|
||||
|
||||
from apps.webui.models.models import Models
|
||||
from apps.webui.models.users import Users
|
||||
from constants import ERROR_MESSAGES
|
||||
from utils.utils import (
|
||||
decode_token,
|
||||
get_verified_user,
|
||||
get_verified_user,
|
||||
get_admin_user,
|
||||
)
|
||||
from utils.task import prompt_template
|
||||
from utils.misc import add_or_update_system_message
|
||||
from utils.misc import apply_model_params_to_body, apply_model_system_prompt_to_body
|
||||
|
||||
from config import (
|
||||
SRC_LOG_LEVELS,
|
||||
@ -34,7 +30,7 @@ from config import (
|
||||
MODEL_FILTER_LIST,
|
||||
AppConfig,
|
||||
)
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Literal, overload
|
||||
|
||||
|
||||
import hashlib
|
||||
@ -69,8 +65,6 @@ app.state.MODELS = {}
|
||||
async def check_url(request: Request, call_next):
|
||||
if len(app.state.MODELS) == 0:
|
||||
await get_all_models()
|
||||
else:
|
||||
pass
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
@ -175,7 +169,7 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
res = r.json()
|
||||
if "error" in res:
|
||||
error_detail = f"External: {res['error']}"
|
||||
except:
|
||||
except Exception:
|
||||
error_detail = f"External: {e}"
|
||||
|
||||
raise HTTPException(
|
||||
@ -234,64 +228,68 @@ def merge_models_lists(model_lists):
|
||||
return merged_list
|
||||
|
||||
|
||||
async def get_all_models(raw: bool = False):
|
||||
def is_openai_api_disabled():
|
||||
api_keys = app.state.config.OPENAI_API_KEYS
|
||||
no_keys = len(api_keys) == 1 and api_keys[0] == ""
|
||||
return no_keys or not app.state.config.ENABLE_OPENAI_API
|
||||
|
||||
|
||||
async def get_all_models_raw() -> list:
|
||||
if is_openai_api_disabled():
|
||||
return []
|
||||
|
||||
# Check if API KEYS length is same than API URLS length
|
||||
num_urls = len(app.state.config.OPENAI_API_BASE_URLS)
|
||||
num_keys = len(app.state.config.OPENAI_API_KEYS)
|
||||
|
||||
if num_keys != num_urls:
|
||||
# if there are more keys than urls, remove the extra keys
|
||||
if num_keys > num_urls:
|
||||
new_keys = app.state.config.OPENAI_API_KEYS[:num_urls]
|
||||
app.state.config.OPENAI_API_KEYS = new_keys
|
||||
# if there are more urls than keys, add empty keys
|
||||
else:
|
||||
app.state.config.OPENAI_API_KEYS += [""] * (num_urls - num_keys)
|
||||
|
||||
tasks = [
|
||||
fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
|
||||
for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
|
||||
]
|
||||
|
||||
responses = await asyncio.gather(*tasks)
|
||||
log.debug(f"get_all_models:responses() {responses}")
|
||||
|
||||
return responses
|
||||
|
||||
|
||||
@overload
|
||||
async def get_all_models(raw: Literal[True]) -> list: ...
|
||||
|
||||
|
||||
@overload
|
||||
async def get_all_models(raw: Literal[False] = False) -> dict[str, list]: ...
|
||||
|
||||
|
||||
async def get_all_models(raw=False) -> dict[str, list] | list:
|
||||
log.info("get_all_models()")
|
||||
if is_openai_api_disabled():
|
||||
return [] if raw else {"data": []}
|
||||
|
||||
if (
|
||||
len(app.state.config.OPENAI_API_KEYS) == 1
|
||||
and app.state.config.OPENAI_API_KEYS[0] == ""
|
||||
) or not app.state.config.ENABLE_OPENAI_API:
|
||||
models = {"data": []}
|
||||
else:
|
||||
# Check if API KEYS length is same than API URLS length
|
||||
if len(app.state.config.OPENAI_API_KEYS) != len(
|
||||
app.state.config.OPENAI_API_BASE_URLS
|
||||
):
|
||||
# if there are more keys than urls, remove the extra keys
|
||||
if len(app.state.config.OPENAI_API_KEYS) > len(
|
||||
app.state.config.OPENAI_API_BASE_URLS
|
||||
):
|
||||
app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[
|
||||
: len(app.state.config.OPENAI_API_BASE_URLS)
|
||||
]
|
||||
# if there are more urls than keys, add empty keys
|
||||
else:
|
||||
app.state.config.OPENAI_API_KEYS += [
|
||||
""
|
||||
for _ in range(
|
||||
len(app.state.config.OPENAI_API_BASE_URLS)
|
||||
- len(app.state.config.OPENAI_API_KEYS)
|
||||
)
|
||||
]
|
||||
responses = await get_all_models_raw()
|
||||
if raw:
|
||||
return responses
|
||||
|
||||
tasks = [
|
||||
fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
|
||||
for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
|
||||
]
|
||||
def extract_data(response):
|
||||
if response and "data" in response:
|
||||
return response["data"]
|
||||
if isinstance(response, list):
|
||||
return response
|
||||
return None
|
||||
|
||||
responses = await asyncio.gather(*tasks)
|
||||
log.debug(f"get_all_models:responses() {responses}")
|
||||
models = {"data": merge_models_lists(map(extract_data, responses))}
|
||||
|
||||
if raw:
|
||||
return responses
|
||||
|
||||
models = {
|
||||
"data": merge_models_lists(
|
||||
list(
|
||||
map(
|
||||
lambda response: (
|
||||
response["data"]
|
||||
if (response and "data" in response)
|
||||
else (response if isinstance(response, list) else None)
|
||||
),
|
||||
responses,
|
||||
)
|
||||
)
|
||||
)
|
||||
}
|
||||
|
||||
log.debug(f"models: {models}")
|
||||
app.state.MODELS = {model["id"]: model for model in models["data"]}
|
||||
log.debug(f"models: {models}")
|
||||
app.state.MODELS = {model["id"]: model for model in models["data"]}
|
||||
|
||||
return models
|
||||
|
||||
@ -299,7 +297,7 @@ async def get_all_models(raw: bool = False):
|
||||
@app.get("/models")
|
||||
@app.get("/models/{url_idx}")
|
||||
async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
|
||||
if url_idx == None:
|
||||
if url_idx is None:
|
||||
models = await get_all_models()
|
||||
if app.state.config.ENABLE_MODEL_FILTER:
|
||||
if user.role == "user":
|
||||
@ -340,7 +338,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us
|
||||
res = r.json()
|
||||
if "error" in res:
|
||||
error_detail = f"External: {res['error']}"
|
||||
except:
|
||||
except Exception:
|
||||
error_detail = f"External: {e}"
|
||||
|
||||
raise HTTPException(
|
||||
@ -358,8 +356,7 @@ async def generate_chat_completion(
|
||||
):
|
||||
idx = 0
|
||||
payload = {**form_data}
|
||||
if "metadata" in payload:
|
||||
del payload["metadata"]
|
||||
payload.pop("metadata")
|
||||
|
||||
model_id = form_data.get("model")
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
@ -368,70 +365,9 @@ async def generate_chat_completion(
|
||||
if model_info.base_model_id:
|
||||
payload["model"] = model_info.base_model_id
|
||||
|
||||
model_info.params = model_info.params.model_dump()
|
||||
|
||||
if model_info.params:
|
||||
if (
|
||||
model_info.params.get("temperature", None) is not None
|
||||
and payload.get("temperature") is None
|
||||
):
|
||||
payload["temperature"] = float(model_info.params.get("temperature"))
|
||||
|
||||
if model_info.params.get("top_p", None) and payload.get("top_p") is None:
|
||||
payload["top_p"] = int(model_info.params.get("top_p", None))
|
||||
|
||||
if (
|
||||
model_info.params.get("max_tokens", None)
|
||||
and payload.get("max_tokens") is None
|
||||
):
|
||||
payload["max_tokens"] = int(model_info.params.get("max_tokens", None))
|
||||
|
||||
if (
|
||||
model_info.params.get("frequency_penalty", None)
|
||||
and payload.get("frequency_penalty") is None
|
||||
):
|
||||
payload["frequency_penalty"] = int(
|
||||
model_info.params.get("frequency_penalty", None)
|
||||
)
|
||||
|
||||
if (
|
||||
model_info.params.get("seed", None) is not None
|
||||
and payload.get("seed") is None
|
||||
):
|
||||
payload["seed"] = model_info.params.get("seed", None)
|
||||
|
||||
if model_info.params.get("stop", None) and payload.get("stop") is None:
|
||||
payload["stop"] = (
|
||||
[
|
||||
bytes(stop, "utf-8").decode("unicode_escape")
|
||||
for stop in model_info.params["stop"]
|
||||
]
|
||||
if model_info.params.get("stop", None)
|
||||
else None
|
||||
)
|
||||
|
||||
system = model_info.params.get("system", None)
|
||||
if system:
|
||||
system = prompt_template(
|
||||
system,
|
||||
**(
|
||||
{
|
||||
"user_name": user.name,
|
||||
"user_location": (
|
||||
user.info.get("location") if user.info else None
|
||||
),
|
||||
}
|
||||
if user
|
||||
else {}
|
||||
),
|
||||
)
|
||||
if payload.get("messages"):
|
||||
payload["messages"] = add_or_update_system_message(
|
||||
system, payload["messages"]
|
||||
)
|
||||
|
||||
else:
|
||||
pass
|
||||
params = model_info.params.model_dump()
|
||||
payload = apply_model_params_to_body(params, payload)
|
||||
payload = apply_model_system_prompt_to_body(params, payload, user)
|
||||
|
||||
model = app.state.MODELS[payload.get("model")]
|
||||
idx = model["urlIdx"]
|
||||
@ -444,13 +380,6 @@ async def generate_chat_completion(
|
||||
"role": user.role,
|
||||
}
|
||||
|
||||
# Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
|
||||
# This is a workaround until OpenAI fixes the issue with this model
|
||||
if payload.get("model") == "gpt-4-vision-preview":
|
||||
if "max_tokens" not in payload:
|
||||
payload["max_tokens"] = 4000
|
||||
log.debug("Modified payload:", payload)
|
||||
|
||||
# Convert the modified body back to JSON
|
||||
payload = json.dumps(payload)
|
||||
|
||||
@ -506,7 +435,7 @@ async def generate_chat_completion(
|
||||
print(res)
|
||||
if "error" in res:
|
||||
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
|
||||
except:
|
||||
except Exception:
|
||||
error_detail = f"External: {e}"
|
||||
raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
|
||||
finally:
|
||||
@ -569,7 +498,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
||||
print(res)
|
||||
if "error" in res:
|
||||
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
|
||||
except:
|
||||
except Exception:
|
||||
error_detail = f"External: {e}"
|
||||
raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
|
||||
finally:
|
||||
|
@ -44,23 +44,26 @@ async def user_join(sid, data):
|
||||
print("user-join", sid, data)
|
||||
|
||||
auth = data["auth"] if "auth" in data else None
|
||||
if not auth or "token" not in auth:
|
||||
return
|
||||
|
||||
if auth and "token" in auth:
|
||||
data = decode_token(auth["token"])
|
||||
data = decode_token(auth["token"])
|
||||
if data is None or "id" not in data:
|
||||
return
|
||||
|
||||
if data is not None and "id" in data:
|
||||
user = Users.get_user_by_id(data["id"])
|
||||
user = Users.get_user_by_id(data["id"])
|
||||
if not user:
|
||||
return
|
||||
|
||||
if user:
|
||||
SESSION_POOL[sid] = user.id
|
||||
if user.id in USER_POOL:
|
||||
USER_POOL[user.id].append(sid)
|
||||
else:
|
||||
USER_POOL[user.id] = [sid]
|
||||
SESSION_POOL[sid] = user.id
|
||||
if user.id in USER_POOL:
|
||||
USER_POOL[user.id].append(sid)
|
||||
else:
|
||||
USER_POOL[user.id] = [sid]
|
||||
|
||||
print(f"user {user.name}({user.id}) connected with session ID {sid}")
|
||||
print(f"user {user.name}({user.id}) connected with session ID {sid}")
|
||||
|
||||
await sio.emit("user-count", {"count": len(set(USER_POOL))})
|
||||
await sio.emit("user-count", {"count": len(set(USER_POOL))})
|
||||
|
||||
|
||||
@sio.on("user-count")
|
||||
|
@ -22,9 +22,9 @@ from apps.webui.utils import load_function_module_by_id
|
||||
from utils.misc import (
|
||||
openai_chat_chunk_message_template,
|
||||
openai_chat_completion_message_template,
|
||||
add_or_update_system_message,
|
||||
apply_model_params_to_body,
|
||||
apply_model_system_prompt_to_body,
|
||||
)
|
||||
from utils.task import prompt_template
|
||||
|
||||
|
||||
from config import (
|
||||
@ -269,47 +269,6 @@ def get_function_params(function_module, form_data, user, extra_params={}):
|
||||
return params
|
||||
|
||||
|
||||
# inplace function: form_data is modified
|
||||
def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
|
||||
if not params:
|
||||
return form_data
|
||||
|
||||
mappings = {
|
||||
"temperature": float,
|
||||
"top_p": int,
|
||||
"max_tokens": int,
|
||||
"frequency_penalty": int,
|
||||
"seed": lambda x: x,
|
||||
"stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
|
||||
}
|
||||
|
||||
for key, cast_func in mappings.items():
|
||||
if (value := params.get(key)) is not None:
|
||||
form_data[key] = cast_func(value)
|
||||
|
||||
return form_data
|
||||
|
||||
|
||||
# inplace function: form_data is modified
|
||||
def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
|
||||
system = params.get("system", None)
|
||||
if not system:
|
||||
return form_data
|
||||
|
||||
if user:
|
||||
template_params = {
|
||||
"user_name": user.name,
|
||||
"user_location": user.info.get("location") if user.info else None,
|
||||
}
|
||||
else:
|
||||
template_params = {}
|
||||
system = prompt_template(system, **template_params)
|
||||
form_data["messages"] = add_or_update_system_message(
|
||||
system, form_data.get("messages", [])
|
||||
)
|
||||
return form_data
|
||||
|
||||
|
||||
async def generate_function_chat_completion(form_data, user):
|
||||
model_id = form_data.get("model")
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
|
@ -1,12 +1,8 @@
|
||||
from fastapi import Depends, FastAPI, HTTPException, status, Request
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Union, Optional
|
||||
from fastapi import Depends, HTTPException, status, Request
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
import json
|
||||
|
||||
from apps.webui.models.users import Users
|
||||
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
|
||||
from apps.webui.utils import load_toolkit_module_by_id
|
||||
|
||||
@ -14,7 +10,6 @@ from utils.utils import get_admin_user, get_verified_user
|
||||
from utils.tools import get_tools_specs
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
from importlib import util
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
@ -69,7 +64,7 @@ async def create_new_toolkit(
|
||||
form_data.id = form_data.id.lower()
|
||||
|
||||
toolkit = Tools.get_tool_by_id(form_data.id)
|
||||
if toolkit == None:
|
||||
if toolkit is None:
|
||||
toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
|
||||
try:
|
||||
with open(toolkit_path, "w") as tool_file:
|
||||
@ -98,7 +93,7 @@ async def create_new_toolkit(
|
||||
print(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
detail=ERROR_MESSAGES.DEFAULT(str(e)),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
@ -170,7 +165,7 @@ async def update_toolkit_by_id(
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
detail=ERROR_MESSAGES.DEFAULT(str(e)),
|
||||
)
|
||||
|
||||
|
||||
@ -210,7 +205,7 @@ async def get_toolkit_valves_by_id(id: str, user=Depends(get_admin_user)):
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
detail=ERROR_MESSAGES.DEFAULT(str(e)),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
@ -233,7 +228,7 @@ async def get_toolkit_valves_spec_by_id(
|
||||
if id in request.app.state.TOOLS:
|
||||
toolkit_module = request.app.state.TOOLS[id]
|
||||
else:
|
||||
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
|
||||
toolkit_module, _ = load_toolkit_module_by_id(id)
|
||||
request.app.state.TOOLS[id] = toolkit_module
|
||||
|
||||
if hasattr(toolkit_module, "Valves"):
|
||||
@ -261,7 +256,7 @@ async def update_toolkit_valves_by_id(
|
||||
if id in request.app.state.TOOLS:
|
||||
toolkit_module = request.app.state.TOOLS[id]
|
||||
else:
|
||||
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
|
||||
toolkit_module, _ = load_toolkit_module_by_id(id)
|
||||
request.app.state.TOOLS[id] = toolkit_module
|
||||
|
||||
if hasattr(toolkit_module, "Valves"):
|
||||
@ -276,7 +271,7 @@ async def update_toolkit_valves_by_id(
|
||||
print(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
detail=ERROR_MESSAGES.DEFAULT(str(e)),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
@ -306,7 +301,7 @@ async def get_toolkit_user_valves_by_id(id: str, user=Depends(get_verified_user)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
detail=ERROR_MESSAGES.DEFAULT(str(e)),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
@ -324,7 +319,7 @@ async def get_toolkit_user_valves_spec_by_id(
|
||||
if id in request.app.state.TOOLS:
|
||||
toolkit_module = request.app.state.TOOLS[id]
|
||||
else:
|
||||
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
|
||||
toolkit_module, _ = load_toolkit_module_by_id(id)
|
||||
request.app.state.TOOLS[id] = toolkit_module
|
||||
|
||||
if hasattr(toolkit_module, "UserValves"):
|
||||
@ -348,7 +343,7 @@ async def update_toolkit_user_valves_by_id(
|
||||
if id in request.app.state.TOOLS:
|
||||
toolkit_module = request.app.state.TOOLS[id]
|
||||
else:
|
||||
toolkit_module, frontmatter = load_toolkit_module_by_id(id)
|
||||
toolkit_module, _ = load_toolkit_module_by_id(id)
|
||||
request.app.state.TOOLS[id] = toolkit_module
|
||||
|
||||
if hasattr(toolkit_module, "UserValves"):
|
||||
@ -365,7 +360,7 @@ async def update_toolkit_user_valves_by_id(
|
||||
print(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
detail=ERROR_MESSAGES.DEFAULT(str(e)),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
|
@ -957,7 +957,7 @@ async def get_all_models():
|
||||
|
||||
custom_models = Models.get_all_models()
|
||||
for custom_model in custom_models:
|
||||
if custom_model.base_model_id == None:
|
||||
if custom_model.base_model_id is None:
|
||||
for model in models:
|
||||
if (
|
||||
custom_model.id == model["id"]
|
||||
@ -1656,13 +1656,13 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
|
||||
|
||||
@app.get("/api/pipelines/list")
|
||||
async def get_pipelines_list(user=Depends(get_admin_user)):
|
||||
responses = await get_openai_models(raw=True)
|
||||
responses = await get_openai_models(raw = True)
|
||||
|
||||
print(responses)
|
||||
urlIdxs = [
|
||||
idx
|
||||
for idx, response in enumerate(responses)
|
||||
if response != None and "pipelines" in response
|
||||
if response is not None and "pipelines" in response
|
||||
]
|
||||
|
||||
return {
|
||||
@ -1723,7 +1723,7 @@ async def upload_pipeline(
|
||||
res = r.json()
|
||||
if "detail" in res:
|
||||
detail = res["detail"]
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
raise HTTPException(
|
||||
@ -1769,7 +1769,7 @@ async def add_pipeline(form_data: AddPipelineForm, user=Depends(get_admin_user))
|
||||
res = r.json()
|
||||
if "detail" in res:
|
||||
detail = res["detail"]
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
raise HTTPException(
|
||||
@ -1811,7 +1811,7 @@ async def delete_pipeline(form_data: DeletePipelineForm, user=Depends(get_admin_
|
||||
res = r.json()
|
||||
if "detail" in res:
|
||||
detail = res["detail"]
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
raise HTTPException(
|
||||
@ -1844,7 +1844,7 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use
|
||||
res = r.json()
|
||||
if "detail" in res:
|
||||
detail = res["detail"]
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
raise HTTPException(
|
||||
@ -1859,7 +1859,6 @@ async def get_pipeline_valves(
|
||||
pipeline_id: str,
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
models = await get_all_models()
|
||||
r = None
|
||||
try:
|
||||
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
@ -1898,8 +1897,6 @@ async def get_pipeline_valves_spec(
|
||||
pipeline_id: str,
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
models = await get_all_models()
|
||||
|
||||
r = None
|
||||
try:
|
||||
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
@ -1922,7 +1919,7 @@ async def get_pipeline_valves_spec(
|
||||
res = r.json()
|
||||
if "detail" in res:
|
||||
detail = res["detail"]
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
raise HTTPException(
|
||||
@ -1938,8 +1935,6 @@ async def update_pipeline_valves(
|
||||
form_data: dict,
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
models = await get_all_models()
|
||||
|
||||
r = None
|
||||
try:
|
||||
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
@ -1967,7 +1962,7 @@ async def update_pipeline_valves(
|
||||
res = r.json()
|
||||
if "detail" in res:
|
||||
detail = res["detail"]
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
raise HTTPException(
|
||||
@ -2068,7 +2063,7 @@ async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
|
||||
|
||||
|
||||
@app.get("/api/version")
|
||||
async def get_app_config():
|
||||
async def get_app_version():
|
||||
return {
|
||||
"version": VERSION,
|
||||
}
|
||||
@ -2091,7 +2086,7 @@ async def get_app_latest_release_version():
|
||||
latest_version = data["tag_name"]
|
||||
|
||||
return {"current": VERSION, "latest": latest_version[1:]}
|
||||
except aiohttp.ClientError as e:
|
||||
except aiohttp.ClientError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail=ERROR_MESSAGES.RATE_LIMIT_EXCEEDED,
|
||||
|
@ -6,6 +6,8 @@ from typing import Optional, List, Tuple
|
||||
import uuid
|
||||
import time
|
||||
|
||||
from utils.task import prompt_template
|
||||
|
||||
|
||||
def get_last_user_message_item(messages: List[dict]) -> Optional[dict]:
|
||||
for message in reversed(messages):
|
||||
@ -112,6 +114,47 @@ def openai_chat_completion_message_template(model: str, message: str) -> dict:
|
||||
return template
|
||||
|
||||
|
||||
# inplace function: form_data is modified
|
||||
def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
|
||||
system = params.get("system", None)
|
||||
if not system:
|
||||
return form_data
|
||||
|
||||
if user:
|
||||
template_params = {
|
||||
"user_name": user.name,
|
||||
"user_location": user.info.get("location") if user.info else None,
|
||||
}
|
||||
else:
|
||||
template_params = {}
|
||||
system = prompt_template(system, **template_params)
|
||||
form_data["messages"] = add_or_update_system_message(
|
||||
system, form_data.get("messages", [])
|
||||
)
|
||||
return form_data
|
||||
|
||||
|
||||
# inplace function: form_data is modified
|
||||
def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
|
||||
if not params:
|
||||
return form_data
|
||||
|
||||
mappings = {
|
||||
"temperature": float,
|
||||
"top_p": int,
|
||||
"max_tokens": int,
|
||||
"frequency_penalty": int,
|
||||
"seed": lambda x: x,
|
||||
"stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
|
||||
}
|
||||
|
||||
for key, cast_func in mappings.items():
|
||||
if (value := params.get(key)) is not None:
|
||||
form_data[key] = cast_func(value)
|
||||
|
||||
return form_data
|
||||
|
||||
|
||||
def get_gravatar_url(email):
|
||||
# Trim leading and trailing whitespace from
|
||||
# an email address and force all characters
|
||||
|
@ -6,7 +6,7 @@ from typing import Optional
|
||||
|
||||
|
||||
def prompt_template(
|
||||
template: str, user_name: str = None, user_location: str = None
|
||||
template: str, user_name: Optional[str] = None, user_location: Optional[str] = None
|
||||
) -> str:
|
||||
# Get the current date
|
||||
current_date = datetime.now()
|
||||
@ -83,7 +83,6 @@ def title_generation_template(
|
||||
def search_query_generation_template(
|
||||
template: str, prompt: str, user: Optional[dict] = None
|
||||
) -> str:
|
||||
|
||||
def replacement_function(match):
|
||||
full_match = match.group(0)
|
||||
start_length = match.group(1)
|
||||
|
@ -1,15 +1,12 @@
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from fastapi import HTTPException, status, Depends, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from apps.webui.models.users import Users
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Union, Optional
|
||||
from constants import ERROR_MESSAGES
|
||||
from passlib.context import CryptContext
|
||||
from datetime import datetime, timedelta
|
||||
import requests
|
||||
import jwt
|
||||
import uuid
|
||||
import logging
|
||||
@ -54,7 +51,7 @@ def decode_token(token: str) -> Optional[dict]:
|
||||
try:
|
||||
decoded = jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM])
|
||||
return decoded
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@ -71,7 +68,7 @@ def get_http_authorization_cred(auth_header: str):
|
||||
try:
|
||||
scheme, credentials = auth_header.split(" ")
|
||||
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
|
||||
except:
|
||||
except Exception:
|
||||
raise ValueError(ERROR_MESSAGES.INVALID_TOKEN)
|
||||
|
||||
|
||||
@ -96,7 +93,7 @@ def get_current_user(
|
||||
|
||||
# auth by jwt token
|
||||
data = decode_token(token)
|
||||
if data != None and "id" in data:
|
||||
if data is not None and "id" in data:
|
||||
user = Users.get_user_by_id(data["id"])
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
|
Loading…
Reference in New Issue
Block a user