mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
refac: apps/openai/main.py and utils
This commit is contained in:
@@ -6,6 +6,8 @@ from typing import Optional, List, Tuple
|
||||
import uuid
|
||||
import time
|
||||
|
||||
from utils.task import prompt_template
|
||||
|
||||
|
||||
def get_last_user_message_item(messages: List[dict]) -> Optional[dict]:
|
||||
for message in reversed(messages):
|
||||
@@ -111,6 +113,47 @@ def openai_chat_completion_message_template(model: str, message: str):
|
||||
template["choices"][0]["finish_reason"] = "stop"
|
||||
|
||||
|
||||
# inplace function: form_data is modified
|
||||
def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
|
||||
system = params.get("system", None)
|
||||
if not system:
|
||||
return form_data
|
||||
|
||||
if user:
|
||||
template_params = {
|
||||
"user_name": user.name,
|
||||
"user_location": user.info.get("location") if user.info else None,
|
||||
}
|
||||
else:
|
||||
template_params = {}
|
||||
system = prompt_template(system, **template_params)
|
||||
form_data["messages"] = add_or_update_system_message(
|
||||
system, form_data.get("messages", [])
|
||||
)
|
||||
return form_data
|
||||
|
||||
|
||||
# inplace function: form_data is modified
|
||||
def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
|
||||
if not params:
|
||||
return form_data
|
||||
|
||||
mappings = {
|
||||
"temperature": float,
|
||||
"top_p": int,
|
||||
"max_tokens": int,
|
||||
"frequency_penalty": int,
|
||||
"seed": lambda x: x,
|
||||
"stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
|
||||
}
|
||||
|
||||
for key, cast_func in mappings.items():
|
||||
if (value := params.get(key)) is not None:
|
||||
form_data[key] = cast_func(value)
|
||||
|
||||
return form_data
|
||||
|
||||
|
||||
def get_gravatar_url(email):
|
||||
# Trim leading and trailing whitespace from
|
||||
# an email address and force all characters
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import Optional
|
||||
|
||||
|
||||
def prompt_template(
|
||||
template: str, user_name: str = None, user_location: str = None
|
||||
template: str, user_name: Optional[str] = None, user_location: Optional[str] = None
|
||||
) -> str:
|
||||
# Get the current date
|
||||
current_date = datetime.now()
|
||||
@@ -83,7 +83,6 @@ def title_generation_template(
|
||||
def search_query_generation_template(
|
||||
template: str, prompt: str, user: Optional[dict] = None
|
||||
) -> str:
|
||||
|
||||
def replacement_function(match):
|
||||
full_match = match.group(0)
|
||||
start_length = match.group(1)
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from fastapi import HTTPException, status, Depends, Request
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from apps.webui.models.users import Users
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import Union, Optional
|
||||
from constants import ERROR_MESSAGES
|
||||
from passlib.context import CryptContext
|
||||
from datetime import datetime, timedelta
|
||||
import requests
|
||||
import jwt
|
||||
import uuid
|
||||
import logging
|
||||
@@ -54,7 +51,7 @@ def decode_token(token: str) -> Optional[dict]:
|
||||
try:
|
||||
decoded = jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM])
|
||||
return decoded
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
@@ -71,7 +68,7 @@ def get_http_authorization_cred(auth_header: str):
|
||||
try:
|
||||
scheme, credentials = auth_header.split(" ")
|
||||
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
|
||||
except:
|
||||
except Exception:
|
||||
raise ValueError(ERROR_MESSAGES.INVALID_TOKEN)
|
||||
|
||||
|
||||
@@ -96,7 +93,7 @@ def get_current_user(
|
||||
|
||||
# auth by jwt token
|
||||
data = decode_token(token)
|
||||
if data != None and "id" in data:
|
||||
if data is not None and "id" in data:
|
||||
user = Users.get_user_by_id(data["id"])
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
|
||||
Reference in New Issue
Block a user