refac: apps/openai/main.py and utils

This commit is contained in:
Michael Poluektov
2024-08-03 14:24:26 +01:00
parent 774defd184
commit 12c21fac22
8 changed files with 148 additions and 230 deletions

View File

@@ -6,6 +6,8 @@ from typing import Optional, List, Tuple
import uuid
import time
from utils.task import prompt_template
def get_last_user_message_item(messages: List[dict]) -> Optional[dict]:
for message in reversed(messages):
@@ -111,6 +113,47 @@ def openai_chat_completion_message_template(model: str, message: str):
template["choices"][0]["finish_reason"] = "stop"
# inplace function: form_data is modified
def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
system = params.get("system", None)
if not system:
return form_data
if user:
template_params = {
"user_name": user.name,
"user_location": user.info.get("location") if user.info else None,
}
else:
template_params = {}
system = prompt_template(system, **template_params)
form_data["messages"] = add_or_update_system_message(
system, form_data.get("messages", [])
)
return form_data
# inplace function: form_data is modified
def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
if not params:
return form_data
mappings = {
"temperature": float,
"top_p": int,
"max_tokens": int,
"frequency_penalty": int,
"seed": lambda x: x,
"stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
}
for key, cast_func in mappings.items():
if (value := params.get(key)) is not None:
form_data[key] = cast_func(value)
return form_data
def get_gravatar_url(email):
# Trim leading and trailing whitespace from
# an email address and force all characters

View File

@@ -6,7 +6,7 @@ from typing import Optional
def prompt_template(
template: str, user_name: str = None, user_location: str = None
template: str, user_name: Optional[str] = None, user_location: Optional[str] = None
) -> str:
# Get the current date
current_date = datetime.now()
@@ -83,7 +83,6 @@ def title_generation_template(
def search_query_generation_template(
template: str, prompt: str, user: Optional[dict] = None
) -> str:
def replacement_function(match):
full_match = match.group(0)
start_length = match.group(1)

View File

@@ -1,15 +1,12 @@
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from fastapi import HTTPException, status, Depends, Request
from sqlalchemy.orm import Session
from apps.webui.models.users import Users
from pydantic import BaseModel
from typing import Union, Optional
from constants import ERROR_MESSAGES
from passlib.context import CryptContext
from datetime import datetime, timedelta
import requests
import jwt
import uuid
import logging
@@ -54,7 +51,7 @@ def decode_token(token: str) -> Optional[dict]:
try:
decoded = jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM])
return decoded
except Exception as e:
except Exception:
return None
@@ -71,7 +68,7 @@ def get_http_authorization_cred(auth_header: str):
try:
scheme, credentials = auth_header.split(" ")
return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
except:
except Exception:
raise ValueError(ERROR_MESSAGES.INVALID_TOKEN)
@@ -96,7 +93,7 @@ def get_current_user(
# auth by jwt token
data = decode_token(token)
if data != None and "id" in data:
if data is not None and "id" in data:
user = Users.get_user_by_id(data["id"])
if user is None:
raise HTTPException(