mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Merge branch 'open-webui:main' into feature-external-db-reconnect
This commit is contained in:
@@ -37,6 +37,10 @@ from config import (
|
||||
ENABLE_IMAGE_GENERATION,
|
||||
AUTOMATIC1111_BASE_URL,
|
||||
COMFYUI_BASE_URL,
|
||||
COMFYUI_CFG_SCALE,
|
||||
COMFYUI_SAMPLER,
|
||||
COMFYUI_SCHEDULER,
|
||||
COMFYUI_SD3,
|
||||
IMAGES_OPENAI_API_BASE_URL,
|
||||
IMAGES_OPENAI_API_KEY,
|
||||
IMAGE_GENERATION_MODEL,
|
||||
@@ -78,6 +82,10 @@ app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
|
||||
|
||||
app.state.config.IMAGE_SIZE = IMAGE_SIZE
|
||||
app.state.config.IMAGE_STEPS = IMAGE_STEPS
|
||||
app.state.config.COMFYUI_CFG_SCALE = COMFYUI_CFG_SCALE
|
||||
app.state.config.COMFYUI_SAMPLER = COMFYUI_SAMPLER
|
||||
app.state.config.COMFYUI_SCHEDULER = COMFYUI_SCHEDULER
|
||||
app.state.config.COMFYUI_SD3 = COMFYUI_SD3
|
||||
|
||||
|
||||
@app.get("/config")
|
||||
@@ -457,6 +465,18 @@ def generate_image(
|
||||
if form_data.negative_prompt is not None:
|
||||
data["negative_prompt"] = form_data.negative_prompt
|
||||
|
||||
if app.state.config.COMFYUI_CFG_SCALE:
|
||||
data["cfg_scale"] = app.state.config.COMFYUI_CFG_SCALE
|
||||
|
||||
if app.state.config.COMFYUI_SAMPLER is not None:
|
||||
data["sampler"] = app.state.config.COMFYUI_SAMPLER
|
||||
|
||||
if app.state.config.COMFYUI_SCHEDULER is not None:
|
||||
data["scheduler"] = app.state.config.COMFYUI_SCHEDULER
|
||||
|
||||
if app.state.config.COMFYUI_SD3 is not None:
|
||||
data["sd3"] = app.state.config.COMFYUI_SD3
|
||||
|
||||
data = ImageGenerationPayload(**data)
|
||||
|
||||
res = comfyui_generate_image(
|
||||
|
||||
@@ -190,6 +190,10 @@ class ImageGenerationPayload(BaseModel):
|
||||
width: int
|
||||
height: int
|
||||
n: int = 1
|
||||
cfg_scale: Optional[float] = None
|
||||
sampler: Optional[str] = None
|
||||
scheduler: Optional[str] = None
|
||||
sd3: Optional[bool] = None
|
||||
|
||||
|
||||
def comfyui_generate_image(
|
||||
@@ -199,6 +203,18 @@ def comfyui_generate_image(
|
||||
|
||||
comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT)
|
||||
|
||||
if payload.cfg_scale:
|
||||
comfyui_prompt["3"]["inputs"]["cfg"] = payload.cfg_scale
|
||||
|
||||
if payload.sampler:
|
||||
comfyui_prompt["3"]["inputs"]["sampler"] = payload.sampler
|
||||
|
||||
if payload.scheduler:
|
||||
comfyui_prompt["3"]["inputs"]["scheduler"] = payload.scheduler
|
||||
|
||||
if payload.sd3:
|
||||
comfyui_prompt["5"]["class_type"] = "EmptySD3LatentImage"
|
||||
|
||||
comfyui_prompt["4"]["inputs"]["ckpt_name"] = model
|
||||
comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n
|
||||
comfyui_prompt["5"]["inputs"]["width"] = payload.width
|
||||
|
||||
@@ -46,6 +46,7 @@ from config import (
|
||||
SRC_LOG_LEVELS,
|
||||
OLLAMA_BASE_URLS,
|
||||
ENABLE_OLLAMA_API,
|
||||
AIOHTTP_CLIENT_TIMEOUT,
|
||||
ENABLE_MODEL_FILTER,
|
||||
MODEL_FILTER_LIST,
|
||||
UPLOAD_DIR,
|
||||
@@ -154,7 +155,9 @@ async def cleanup_response(
|
||||
async def post_streaming_url(url: str, payload: str):
|
||||
r = None
|
||||
try:
|
||||
session = aiohttp.ClientSession(trust_env=True)
|
||||
session = aiohttp.ClientSession(
|
||||
trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
|
||||
)
|
||||
r = await session.post(url, data=payload)
|
||||
r.raise_for_status()
|
||||
|
||||
@@ -751,6 +754,14 @@ async def generate_chat_completion(
|
||||
if model_info.params.get("num_ctx", None):
|
||||
payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
|
||||
|
||||
if model_info.params.get("num_batch", None):
|
||||
payload["options"]["num_batch"] = model_info.params.get(
|
||||
"num_batch", None
|
||||
)
|
||||
|
||||
if model_info.params.get("num_keep", None):
|
||||
payload["options"]["num_keep"] = model_info.params.get("num_keep", None)
|
||||
|
||||
if model_info.params.get("repeat_last_n", None):
|
||||
payload["options"]["repeat_last_n"] = model_info.params.get(
|
||||
"repeat_last_n", None
|
||||
@@ -839,8 +850,7 @@ async def generate_chat_completion(
|
||||
|
||||
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
log.info(f"url: {url}")
|
||||
|
||||
print(payload)
|
||||
log.debug(payload)
|
||||
|
||||
return await post_streaming_url(f"{url}/api/chat", json.dumps(payload))
|
||||
|
||||
|
||||
@@ -430,13 +430,11 @@ async def generate_chat_completion(
|
||||
# Convert the modified body back to JSON
|
||||
payload = json.dumps(payload)
|
||||
|
||||
print(payload)
|
||||
log.debug(payload)
|
||||
|
||||
url = app.state.config.OPENAI_API_BASE_URLS[idx]
|
||||
key = app.state.config.OPENAI_API_KEYS[idx]
|
||||
|
||||
print(payload)
|
||||
|
||||
headers = {}
|
||||
headers["Authorization"] = f"Bearer {key}"
|
||||
headers["Content-Type"] = "application/json"
|
||||
|
||||
@@ -73,6 +73,7 @@ from apps.rag.search.serper import search_serper
|
||||
from apps.rag.search.serpstack import search_serpstack
|
||||
from apps.rag.search.serply import search_serply
|
||||
from apps.rag.search.duckduckgo import search_duckduckgo
|
||||
from apps.rag.search.tavily import search_tavily
|
||||
|
||||
from utils.misc import (
|
||||
calculate_sha256,
|
||||
@@ -119,6 +120,7 @@ from config import (
|
||||
SERPSTACK_HTTPS,
|
||||
SERPER_API_KEY,
|
||||
SERPLY_API_KEY,
|
||||
TAVILY_API_KEY,
|
||||
RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
RAG_EMBEDDING_OPENAI_BATCH_SIZE,
|
||||
@@ -172,6 +174,7 @@ app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
|
||||
app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
|
||||
app.state.config.SERPER_API_KEY = SERPER_API_KEY
|
||||
app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
|
||||
app.state.config.TAVILY_API_KEY = TAVILY_API_KEY
|
||||
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
|
||||
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
|
||||
|
||||
@@ -400,6 +403,7 @@ async def get_rag_config(user=Depends(get_admin_user)):
|
||||
"serpstack_https": app.state.config.SERPSTACK_HTTPS,
|
||||
"serper_api_key": app.state.config.SERPER_API_KEY,
|
||||
"serply_api_key": app.state.config.SERPLY_API_KEY,
|
||||
"tavily_api_key": app.state.config.TAVILY_API_KEY,
|
||||
"result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
"concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
},
|
||||
@@ -428,6 +432,7 @@ class WebSearchConfig(BaseModel):
|
||||
serpstack_https: Optional[bool] = None
|
||||
serper_api_key: Optional[str] = None
|
||||
serply_api_key: Optional[str] = None
|
||||
tavily_api_key: Optional[str] = None
|
||||
result_count: Optional[int] = None
|
||||
concurrent_requests: Optional[int] = None
|
||||
|
||||
@@ -479,6 +484,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
|
||||
app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https
|
||||
app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key
|
||||
app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key
|
||||
app.state.config.TAVILY_API_KEY = form_data.web.search.tavily_api_key
|
||||
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count
|
||||
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
|
||||
form_data.web.search.concurrent_requests
|
||||
@@ -508,6 +514,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
|
||||
"serpstack_https": app.state.config.SERPSTACK_HTTPS,
|
||||
"serper_api_key": app.state.config.SERPER_API_KEY,
|
||||
"serply_api_key": app.state.config.SERPLY_API_KEY,
|
||||
"tavily_api_key": app.state.config.TAVILY_API_KEY,
|
||||
"result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
"concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
},
|
||||
@@ -756,7 +763,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
|
||||
- SERPSTACK_API_KEY
|
||||
- SERPER_API_KEY
|
||||
- SERPLY_API_KEY
|
||||
|
||||
- TAVILY_API_KEY
|
||||
Args:
|
||||
query (str): The query to search for
|
||||
"""
|
||||
@@ -825,6 +832,15 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
|
||||
raise Exception("No SERPLY_API_KEY found in environment variables")
|
||||
elif engine == "duckduckgo":
|
||||
return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT)
|
||||
elif engine == "tavily":
|
||||
if app.state.config.TAVILY_API_KEY:
|
||||
return search_tavily(
|
||||
app.state.config.TAVILY_API_KEY,
|
||||
query,
|
||||
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
)
|
||||
else:
|
||||
raise Exception("No TAVILY_API_KEY found in environment variables")
|
||||
else:
|
||||
raise Exception("No search engine API key found in environment variables")
|
||||
|
||||
|
||||
39
backend/apps/rag/search/tavily.py
Normal file
39
backend/apps/rag/search/tavily.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import logging
|
||||
|
||||
import requests
|
||||
|
||||
from apps.rag.search.main import SearchResult
|
||||
from config import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
def search_tavily(api_key: str, query: str, count: int) -> list[SearchResult]:
|
||||
"""Search using Tavily's Search API and return the results as a list of SearchResult objects.
|
||||
|
||||
Args:
|
||||
api_key (str): A Tavily Search API key
|
||||
query (str): The query to search for
|
||||
|
||||
Returns:
|
||||
List[SearchResult]: A list of search results
|
||||
"""
|
||||
url = "https://api.tavily.com/search"
|
||||
data = {"query": query, "api_key": api_key}
|
||||
|
||||
response = requests.post(url, json=data)
|
||||
response.raise_for_status()
|
||||
|
||||
json_response = response.json()
|
||||
|
||||
raw_search_results = json_response.get("results", [])
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["url"],
|
||||
title=result.get("title", ""),
|
||||
snippet=result.get("content"),
|
||||
)
|
||||
for result in raw_search_results[:count]
|
||||
]
|
||||
48
backend/apps/webui/internal/migrations/013_add_user_info.py
Normal file
48
backend/apps/webui/internal/migrations/013_add_user_info.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Peewee migrations -- 002_add_local_sharing.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
# Adding fields info to the 'user' table
|
||||
migrator.add_fields("user", info=pw.TextField(null=True))
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
# Remove the settings field
|
||||
migrator.remove_fields("user", "info")
|
||||
@@ -25,6 +25,7 @@ from config import (
|
||||
USER_PERMISSIONS,
|
||||
WEBHOOK_URL,
|
||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
||||
WEBUI_AUTH_TRUSTED_NAME_HEADER,
|
||||
JWT_EXPIRES_IN,
|
||||
WEBUI_BANNERS,
|
||||
ENABLE_COMMUNITY_SHARING,
|
||||
@@ -40,6 +41,7 @@ app.state.config = AppConfig()
|
||||
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
|
||||
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
|
||||
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
|
||||
app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
|
||||
|
||||
|
||||
app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
|
||||
|
||||
@@ -65,6 +65,20 @@ class MemoriesTable:
|
||||
else:
|
||||
return None
|
||||
|
||||
def update_memory_by_id(
|
||||
self,
|
||||
id: str,
|
||||
content: str,
|
||||
) -> Optional[MemoryModel]:
|
||||
try:
|
||||
memory = Memory.get(Memory.id == id)
|
||||
memory.content = content
|
||||
memory.updated_at = int(time.time())
|
||||
memory.save()
|
||||
return MemoryModel(**model_to_dict(memory))
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_memories(self) -> List[MemoryModel]:
|
||||
try:
|
||||
memories = Memory.select()
|
||||
|
||||
@@ -26,6 +26,7 @@ class User(Model):
|
||||
|
||||
api_key = CharField(null=True, unique=True)
|
||||
settings = JSONField(null=True)
|
||||
info = JSONField(null=True)
|
||||
|
||||
class Meta:
|
||||
database = DB
|
||||
@@ -50,6 +51,7 @@ class UserModel(BaseModel):
|
||||
|
||||
api_key: Optional[str] = None
|
||||
settings: Optional[UserSettings] = None
|
||||
info: Optional[dict] = None
|
||||
|
||||
|
||||
####################
|
||||
|
||||
@@ -33,7 +33,11 @@ from utils.utils import (
|
||||
from utils.misc import parse_duration, validate_email_format
|
||||
from utils.webhook import post_webhook
|
||||
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
||||
from config import WEBUI_AUTH, WEBUI_AUTH_TRUSTED_EMAIL_HEADER
|
||||
from config import (
|
||||
WEBUI_AUTH,
|
||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
||||
WEBUI_AUTH_TRUSTED_NAME_HEADER,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -110,11 +114,16 @@ async def signin(request: Request, form_data: SigninForm):
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
|
||||
|
||||
trusted_email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower()
|
||||
trusted_name = trusted_email
|
||||
if WEBUI_AUTH_TRUSTED_NAME_HEADER:
|
||||
trusted_name = request.headers.get(
|
||||
WEBUI_AUTH_TRUSTED_NAME_HEADER, trusted_email
|
||||
)
|
||||
if not Users.get_user_by_email(trusted_email.lower()):
|
||||
await signup(
|
||||
request,
|
||||
SignupForm(
|
||||
email=trusted_email, password=str(uuid.uuid4()), name=trusted_email
|
||||
email=trusted_email, password=str(uuid.uuid4()), name=trusted_name
|
||||
),
|
||||
)
|
||||
user = Auths.authenticate_user_by_trusted_header(trusted_email)
|
||||
|
||||
@@ -44,6 +44,10 @@ class AddMemoryForm(BaseModel):
|
||||
content: str
|
||||
|
||||
|
||||
class MemoryUpdateModel(BaseModel):
|
||||
content: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/add", response_model=Optional[MemoryModel])
|
||||
async def add_memory(
|
||||
request: Request, form_data: AddMemoryForm, user=Depends(get_verified_user)
|
||||
@@ -62,6 +66,34 @@ async def add_memory(
|
||||
return memory
|
||||
|
||||
|
||||
@router.post("/{memory_id}/update", response_model=Optional[MemoryModel])
|
||||
async def update_memory_by_id(
|
||||
memory_id: str,
|
||||
request: Request,
|
||||
form_data: MemoryUpdateModel,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
memory = Memories.update_memory_by_id(memory_id, form_data.content)
|
||||
if memory is None:
|
||||
raise HTTPException(status_code=404, detail="Memory not found")
|
||||
|
||||
if form_data.content is not None:
|
||||
memory_embedding = request.app.state.EMBEDDING_FUNCTION(form_data.content)
|
||||
collection = CHROMA_CLIENT.get_or_create_collection(
|
||||
name=f"user-memory-{user.id}"
|
||||
)
|
||||
collection.upsert(
|
||||
documents=[form_data.content],
|
||||
ids=[memory.id],
|
||||
embeddings=[memory_embedding],
|
||||
metadatas=[
|
||||
{"created_at": memory.created_at, "updated_at": memory.updated_at}
|
||||
],
|
||||
)
|
||||
|
||||
return memory
|
||||
|
||||
|
||||
############################
|
||||
# QueryMemory
|
||||
############################
|
||||
|
||||
@@ -115,6 +115,52 @@ async def update_user_settings_by_session_user(
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetUserInfoBySessionUser
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/user/info", response_model=Optional[dict])
|
||||
async def get_user_info_by_session_user(user=Depends(get_verified_user)):
|
||||
user = Users.get_user_by_id(user.id)
|
||||
if user:
|
||||
return user.info
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.USER_NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# UpdateUserInfoBySessionUser
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/user/info/update", response_model=Optional[dict])
|
||||
async def update_user_settings_by_session_user(
|
||||
form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
user = Users.get_user_by_id(user.id)
|
||||
if user:
|
||||
if user.info is None:
|
||||
user.info = {}
|
||||
|
||||
user = Users.update_user_by_id(user.id, {"info": {**user.info, **form_data}})
|
||||
if user:
|
||||
return user.info
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.USER_NOT_FOUND,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.USER_NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetUserById
|
||||
############################
|
||||
|
||||
Reference in New Issue
Block a user