Merge pull request #3621 from open-webui/dev

dev
This commit is contained in:
Timothy Jaeryang Baek
2024-07-03 20:59:14 -07:00
committed by GitHub
7 changed files with 988 additions and 965 deletions

View File

@@ -39,6 +39,8 @@ from config import (
WEBUI_BANNERS,
ENABLE_COMMUNITY_SHARING,
AppConfig,
OAUTH_USERNAME_CLAIM,
OAUTH_PICTURE_CLAIM,
)
import inspect
@@ -74,6 +76,9 @@ app.state.config.BANNERS = WEBUI_BANNERS
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.config.OAUTH_USERNAME_CLAIM = OAUTH_USERNAME_CLAIM
app.state.config.OAUTH_PICTURE_CLAIM = OAUTH_PICTURE_CLAIM
app.state.MODELS = {}
app.state.TOOLS = {}
app.state.FUNCTIONS = {}

View File

@@ -393,6 +393,18 @@ OAUTH_PROVIDER_NAME = PersistentConfig(
os.environ.get("OAUTH_PROVIDER_NAME", "SSO"),
)
OAUTH_USERNAME_CLAIM = PersistentConfig(
"OAUTH_USERNAME_CLAIM",
"oauth.oidc.username_claim",
os.environ.get("OAUTH_USERNAME_CLAIM", "name"),
)
OAUTH_PICTURE_CLAIM = PersistentConfig(
"OAUTH_USERNAME_CLAIM",
"oauth.oidc.avatar_claim",
os.environ.get("OAUTH_PICTURE_CLAIM", "picture"),
)
def load_oauth_providers():
OAUTH_PROVIDERS.clear()

View File

@@ -89,3 +89,14 @@ class ERROR_MESSAGES(str, Enum):
OLLAMA_API_DISABLED = (
"The Ollama API is disabled. Please enable it to use this feature."
)
class TASKS(str, Enum):
def __str__(self) -> str:
return super().__str__()
DEFAULT = lambda task="": f"{task if task else 'default'}"
TITLE_GENERATION = "Title Generation"
EMOJI_GENERATION = "Emoji Generation"
QUERY_GENERATION = "Query Generation"
FUNCTION_CALLING = "Function Calling"

View File

@@ -131,7 +131,7 @@ from config import (
BACKEND_DIR,
DATABASE_URL,
)
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS
from utils.webhook import post_webhook
if SAFE_MODE:
@@ -327,6 +327,7 @@ async def get_function_call_response(
{"role": "user", "content": f"Query: {prompt}"},
],
"stream": False,
"task": TASKS.FUNCTION_CALLING,
}
try:
@@ -339,7 +340,6 @@ async def get_function_call_response(
response = None
try:
response = await generate_chat_completions(form_data=payload, user=user)
content = None
if hasattr(response, "body_iterator"):
@@ -849,9 +849,6 @@ def filter_pipeline(payload, user):
pass
if "pipeline" not in app.state.MODELS[model_id]:
if "title" in payload:
del payload["title"]
if "task" in payload:
del payload["task"]
@@ -1362,7 +1359,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
"stream": False,
"max_tokens": 50,
"chat_id": form_data.get("chat_id", None),
"title": True,
"task": TASKS.TITLE_GENERATION,
}
log.debug(payload)
@@ -1425,7 +1422,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
"messages": [{"role": "user", "content": content}],
"stream": False,
"max_tokens": 30,
"task": True,
"task": TASKS.QUERY_GENERATION,
}
print(payload)
@@ -1492,7 +1489,7 @@ Message: """{{prompt}}"""
"stream": False,
"max_tokens": 4,
"chat_id": form_data.get("chat_id", None),
"task": True,
"task": TASKS.EMOJI_GENERATION,
}
log.debug(payload)
@@ -2095,7 +2092,8 @@ async def oauth_callback(provider: str, request: Request, response: Response):
if existing_user:
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
picture_url = user_data.get("picture", "")
picture_claim = webui_app.state.config.OAUTH_PICTURE_CLAIM
picture_url = user_data.get(picture_claim, "")
if picture_url:
# Download the profile image into a base64 string
try:
@@ -2115,6 +2113,7 @@ async def oauth_callback(provider: str, request: Request, response: Response):
picture_url = ""
if not picture_url:
picture_url = "/user.png"
username_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM
role = (
"admin"
if Users.get_num_users() == 0
@@ -2125,7 +2124,7 @@ async def oauth_callback(provider: str, request: Request, response: Response):
password=get_password_hash(
str(uuid.uuid4())
), # Random password, not used
name=user_data.get("name", "User"),
name=user_data.get(username_claim, "User"),
profile_image_url=picture_url,
role=role,
oauth_sub=provider_sub,