Merge remote-tracking branch 'upstream/main' into feature-external-db-reconnect

This commit is contained in:
perf3ct
2024-06-16 09:03:57 -07:00
206 changed files with 21988 additions and 7009 deletions

View File

@@ -17,13 +17,12 @@ from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel
from pydantic import BaseModel
import uuid
import requests
import hashlib
from pathlib import Path
import json
from constants import ERROR_MESSAGES
from utils.utils import (
decode_token,
@@ -41,10 +40,15 @@ from config import (
WHISPER_MODEL_DIR,
WHISPER_MODEL_AUTO_UPDATE,
DEVICE_TYPE,
AUDIO_OPENAI_API_BASE_URL,
AUDIO_OPENAI_API_KEY,
AUDIO_OPENAI_API_MODEL,
AUDIO_OPENAI_API_VOICE,
AUDIO_STT_OPENAI_API_BASE_URL,
AUDIO_STT_OPENAI_API_KEY,
AUDIO_TTS_OPENAI_API_BASE_URL,
AUDIO_TTS_OPENAI_API_KEY,
AUDIO_STT_ENGINE,
AUDIO_STT_MODEL,
AUDIO_TTS_ENGINE,
AUDIO_TTS_MODEL,
AUDIO_TTS_VOICE,
AppConfig,
)
@@ -61,10 +65,17 @@ app.add_middleware(
)
app.state.config = AppConfig()
app.state.config.OPENAI_API_BASE_URL = AUDIO_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = AUDIO_OPENAI_API_KEY
app.state.config.OPENAI_API_MODEL = AUDIO_OPENAI_API_MODEL
app.state.config.OPENAI_API_VOICE = AUDIO_OPENAI_API_VOICE
app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL
app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY
app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
app.state.config.STT_MODEL = AUDIO_STT_MODEL
app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
app.state.config.TTS_MODEL = AUDIO_TTS_MODEL
app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
# setting device type for whisper model
whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
@@ -74,41 +85,101 @@ SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
class OpenAIConfigUpdateForm(BaseModel):
url: str
key: str
model: str
speaker: str
class TTSConfigForm(BaseModel):
OPENAI_API_BASE_URL: str
OPENAI_API_KEY: str
ENGINE: str
MODEL: str
VOICE: str
class STTConfigForm(BaseModel):
OPENAI_API_BASE_URL: str
OPENAI_API_KEY: str
ENGINE: str
MODEL: str
class AudioConfigUpdateForm(BaseModel):
tts: TTSConfigForm
stt: STTConfigForm
from pydub import AudioSegment
from pydub.utils import mediainfo
def is_mp4_audio(file_path):
"""Check if the given file is an MP4 audio file."""
if not os.path.isfile(file_path):
print(f"File not found: {file_path}")
return False
info = mediainfo(file_path)
if (
info.get("codec_name") == "aac"
and info.get("codec_type") == "audio"
and info.get("codec_tag_string") == "mp4a"
):
return True
return False
def convert_mp4_to_wav(file_path, output_path):
"""Convert MP4 audio file to WAV format."""
audio = AudioSegment.from_file(file_path, format="mp4")
audio.export(output_path, format="wav")
print(f"Converted {file_path} to {output_path}")
@app.get("/config")
async def get_openai_config(user=Depends(get_admin_user)):
async def get_audio_config(user=Depends(get_admin_user)):
return {
"OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
"OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL,
"OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE,
"tts": {
"OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
"ENGINE": app.state.config.TTS_ENGINE,
"MODEL": app.state.config.TTS_MODEL,
"VOICE": app.state.config.TTS_VOICE,
},
"stt": {
"OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
"ENGINE": app.state.config.STT_ENGINE,
"MODEL": app.state.config.STT_MODEL,
},
}
@app.post("/config/update")
async def update_openai_config(
form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user)
async def update_audio_config(
form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
):
if form_data.key == "":
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
app.state.config.TTS_ENGINE = form_data.tts.ENGINE
app.state.config.TTS_MODEL = form_data.tts.MODEL
app.state.config.TTS_VOICE = form_data.tts.VOICE
app.state.config.OPENAI_API_BASE_URL = form_data.url
app.state.config.OPENAI_API_KEY = form_data.key
app.state.config.OPENAI_API_MODEL = form_data.model
app.state.config.OPENAI_API_VOICE = form_data.speaker
app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
app.state.config.STT_ENGINE = form_data.stt.ENGINE
app.state.config.STT_MODEL = form_data.stt.MODEL
return {
"status": True,
"OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
"OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL,
"OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE,
"tts": {
"OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
"ENGINE": app.state.config.TTS_ENGINE,
"MODEL": app.state.config.TTS_MODEL,
"VOICE": app.state.config.TTS_VOICE,
},
"stt": {
"OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
"ENGINE": app.state.config.STT_ENGINE,
"MODEL": app.state.config.STT_MODEL,
},
}
@@ -125,13 +196,21 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return FileResponse(file_path)
headers = {}
headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}"
headers["Content-Type"] = "application/json"
try:
body = body.decode("utf-8")
body = json.loads(body)
body["model"] = app.state.config.TTS_MODEL
body = json.dumps(body).encode("utf-8")
except Exception as e:
pass
r = None
try:
r = requests.post(
url=f"{app.state.config.OPENAI_API_BASE_URL}/audio/speech",
url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
data=body,
headers=headers,
stream=True,
@@ -181,41 +260,110 @@ def transcribe(
)
try:
filename = file.filename
file_path = f"{UPLOAD_DIR}/{filename}"
ext = file.filename.split(".")[-1]
id = uuid.uuid4()
filename = f"{id}.{ext}"
file_dir = f"{CACHE_DIR}/audio/transcriptions"
os.makedirs(file_dir, exist_ok=True)
file_path = f"{file_dir}/{filename}"
print(filename)
contents = file.file.read()
with open(file_path, "wb") as f:
f.write(contents)
f.close()
whisper_kwargs = {
"model_size_or_path": WHISPER_MODEL,
"device": whisper_device_type,
"compute_type": "int8",
"download_root": WHISPER_MODEL_DIR,
"local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
}
if app.state.config.STT_ENGINE == "":
whisper_kwargs = {
"model_size_or_path": WHISPER_MODEL,
"device": whisper_device_type,
"compute_type": "int8",
"download_root": WHISPER_MODEL_DIR,
"local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
}
log.debug(f"whisper_kwargs: {whisper_kwargs}")
log.debug(f"whisper_kwargs: {whisper_kwargs}")
try:
model = WhisperModel(**whisper_kwargs)
except:
log.warning(
"WhisperModel initialization failed, attempting download with local_files_only=False"
try:
model = WhisperModel(**whisper_kwargs)
except:
log.warning(
"WhisperModel initialization failed, attempting download with local_files_only=False"
)
whisper_kwargs["local_files_only"] = False
model = WhisperModel(**whisper_kwargs)
segments, info = model.transcribe(file_path, beam_size=5)
log.info(
"Detected language '%s' with probability %f"
% (info.language, info.language_probability)
)
whisper_kwargs["local_files_only"] = False
model = WhisperModel(**whisper_kwargs)
segments, info = model.transcribe(file_path, beam_size=5)
log.info(
"Detected language '%s' with probability %f"
% (info.language, info.language_probability)
)
transcript = "".join([segment.text for segment in list(segments)])
transcript = "".join([segment.text for segment in list(segments)])
data = {"text": transcript.strip()}
return {"text": transcript.strip()}
# save the transcript to a json file
transcript_file = f"{file_dir}/{id}.json"
with open(transcript_file, "w") as f:
json.dump(data, f)
print(data)
return data
elif app.state.config.STT_ENGINE == "openai":
if is_mp4_audio(file_path):
print("is_mp4_audio")
os.rename(file_path, file_path.replace(".wav", ".mp4"))
# Convert MP4 audio file to WAV format
convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}
files = {"file": (filename, open(file_path, "rb"))}
data = {"model": "whisper-1"}
print(files, data)
r = None
try:
r = requests.post(
url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
headers=headers,
files=files,
data=data,
)
r.raise_for_status()
data = r.json()
# save the transcript to a json file
transcript_file = f"{file_dir}/{id}.json"
with open(transcript_file, "w") as f:
json.dump(data, f)
print(data)
return data
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']['message']}"
except:
error_detail = f"External: {e}"
raise HTTPException(
status_code=r.status_code if r != None else 500,
detail=error_detail,
)
except Exception as e:
log.exception(e)

View File

@@ -29,6 +29,8 @@ import time
from urllib.parse import urlparse
from typing import Optional, List, Union
from starlette.background import BackgroundTask
from apps.webui.models.models import Models
from apps.webui.models.users import Users
from constants import ERROR_MESSAGES
@@ -39,8 +41,6 @@ from utils.utils import (
get_admin_user,
)
from utils.models import get_model_id_from_custom_model_id
from config import (
SRC_LOG_LEVELS,
@@ -75,9 +75,6 @@ app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.MODELS = {}
REQUEST_POOL = []
# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
# Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin,
# least connections, or least response time for better resource utilization and performance optimization.
@@ -132,20 +129,10 @@ async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin
return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
@app.get("/cancel/{request_id}")
async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)):
if user:
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
return True
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
async def fetch_url(url):
timeout = aiohttp.ClientTimeout(total=5)
try:
async with aiohttp.ClientSession(timeout=timeout) as session:
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(url) as response:
return await response.json()
except Exception as e:
@@ -154,6 +141,45 @@ async def fetch_url(url):
return None
async def cleanup_response(
response: Optional[aiohttp.ClientResponse],
session: Optional[aiohttp.ClientSession],
):
if response:
response.close()
if session:
await session.close()
async def post_streaming_url(url: str, payload: str):
r = None
try:
session = aiohttp.ClientSession(trust_env=True)
r = await session.post(url, data=payload)
r.raise_for_status()
return StreamingResponse(
r.content,
status_code=r.status,
headers=dict(r.headers),
background=BackgroundTask(cleanup_response, response=r, session=session),
)
except Exception as e:
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = await r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status if r else 500,
detail=error_detail,
)
def merge_models_lists(model_lists):
merged_models = {}
@@ -246,54 +272,57 @@ async def get_ollama_tags(
@app.get("/api/version")
@app.get("/api/version/{url_idx}")
async def get_ollama_versions(url_idx: Optional[int] = None):
if app.state.config.ENABLE_OLLAMA_API:
if url_idx == None:
if url_idx == None:
# returns lowest version
tasks = [
fetch_url(f"{url}/api/version")
for url in app.state.config.OLLAMA_BASE_URLS
]
responses = await asyncio.gather(*tasks)
responses = list(filter(lambda x: x is not None, responses))
# returns lowest version
tasks = [
fetch_url(f"{url}/api/version") for url in app.state.config.OLLAMA_BASE_URLS
]
responses = await asyncio.gather(*tasks)
responses = list(filter(lambda x: x is not None, responses))
if len(responses) > 0:
lowest_version = min(
responses,
key=lambda x: tuple(
map(int, re.sub(r"^v|-.*", "", x["version"]).split("."))
),
)
if len(responses) > 0:
lowest_version = min(
responses,
key=lambda x: tuple(
map(int, re.sub(r"^v|-.*", "", x["version"]).split("."))
),
)
return {"version": lowest_version["version"]}
return {"version": lowest_version["version"]}
else:
raise HTTPException(
status_code=500,
detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
)
else:
raise HTTPException(
status_code=500,
detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
)
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
r = None
try:
r = requests.request(method="GET", url=f"{url}/api/version")
r.raise_for_status()
return r.json()
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
else:
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
r = None
try:
r = requests.request(method="GET", url=f"{url}/api/version")
r.raise_for_status()
return r.json()
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
return {"version": False}
class ModelNameForm(BaseModel):
@@ -313,65 +342,7 @@ async def pull_model(
# Admin should be able to pull models from any source
payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
def get_request():
nonlocal url
nonlocal r
request_id = str(uuid.uuid4())
try:
REQUEST_POOL.append(request_id)
def stream_content():
try:
yield json.dumps({"id": request_id, "done": False}) + "\n"
for chunk in r.iter_content(chunk_size=8192):
if request_id in REQUEST_POOL:
yield chunk
else:
log.warning("User: canceled request")
break
finally:
if hasattr(r, "close"):
r.close()
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
r = requests.request(
method="POST",
url=f"{url}/api/pull",
data=json.dumps(payload),
stream=True,
)
r.raise_for_status()
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
return await post_streaming_url(f"{url}/api/pull", json.dumps(payload))
class PushModelForm(BaseModel):
@@ -399,50 +370,9 @@ async def push_model(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.debug(f"url: {url}")
r = None
def get_request():
nonlocal url
nonlocal r
try:
def stream_content():
for chunk in r.iter_content(chunk_size=8192):
yield chunk
r = requests.request(
method="POST",
url=f"{url}/api/push",
data=form_data.model_dump_json(exclude_none=True).encode(),
)
r.raise_for_status()
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
return await post_streaming_url(
f"{url}/api/push", form_data.model_dump_json(exclude_none=True).encode()
)
class CreateModelForm(BaseModel):
@@ -461,53 +391,9 @@ async def create_model(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
r = None
def get_request():
nonlocal url
nonlocal r
try:
def stream_content():
for chunk in r.iter_content(chunk_size=8192):
yield chunk
r = requests.request(
method="POST",
url=f"{url}/api/create",
data=form_data.model_dump_json(exclude_none=True).encode(),
stream=True,
)
r.raise_for_status()
log.debug(f"r: {r}")
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
return await post_streaming_url(
f"{url}/api/create", form_data.model_dump_json(exclude_none=True).encode()
)
class CopyModelForm(BaseModel):
@@ -797,66 +683,9 @@ async def generate_completion(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
r = None
def get_request():
nonlocal form_data
nonlocal r
request_id = str(uuid.uuid4())
try:
REQUEST_POOL.append(request_id)
def stream_content():
try:
if form_data.stream:
yield json.dumps({"id": request_id, "done": False}) + "\n"
for chunk in r.iter_content(chunk_size=8192):
if request_id in REQUEST_POOL:
yield chunk
else:
log.warning("User: canceled request")
break
finally:
if hasattr(r, "close"):
r.close()
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
r = requests.request(
method="POST",
url=f"{url}/api/generate",
data=form_data.model_dump_json(exclude_none=True).encode(),
stream=True,
)
r.raise_for_status()
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
return await post_streaming_url(
f"{url}/api/generate", form_data.model_dump_json(exclude_none=True).encode()
)
class ChatMessage(BaseModel):
@@ -897,7 +726,6 @@ async def generate_chat_completion(
model_info = Models.get_model_by_id(model_id)
if model_info:
print(model_info)
if model_info.base_model_id:
payload["model"] = model_info.base_model_id
@@ -906,44 +734,77 @@ async def generate_chat_completion(
if model_info.params:
payload["options"] = {}
payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
payload["options"]["mirostat_eta"] = model_info.params.get(
"mirostat_eta", None
)
payload["options"]["mirostat_tau"] = model_info.params.get(
"mirostat_tau", None
)
payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
if model_info.params.get("mirostat", None):
payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
payload["options"]["repeat_last_n"] = model_info.params.get(
"repeat_last_n", None
)
payload["options"]["repeat_penalty"] = model_info.params.get(
"frequency_penalty", None
)
if model_info.params.get("mirostat_eta", None):
payload["options"]["mirostat_eta"] = model_info.params.get(
"mirostat_eta", None
)
payload["options"]["temperature"] = model_info.params.get(
"temperature", None
)
payload["options"]["seed"] = model_info.params.get("seed", None)
if model_info.params.get("mirostat_tau", None):
payload["options"]["stop"] = (
[
bytes(stop, "utf-8").decode("unicode_escape")
for stop in model_info.params["stop"]
]
if model_info.params.get("stop", None)
else None
)
payload["options"]["mirostat_tau"] = model_info.params.get(
"mirostat_tau", None
)
payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
if model_info.params.get("num_ctx", None):
payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
payload["options"]["num_predict"] = model_info.params.get(
"max_tokens", None
)
payload["options"]["top_k"] = model_info.params.get("top_k", None)
if model_info.params.get("repeat_last_n", None):
payload["options"]["repeat_last_n"] = model_info.params.get(
"repeat_last_n", None
)
payload["options"]["top_p"] = model_info.params.get("top_p", None)
if model_info.params.get("frequency_penalty", None):
payload["options"]["repeat_penalty"] = model_info.params.get(
"frequency_penalty", None
)
if model_info.params.get("temperature", None) is not None:
payload["options"]["temperature"] = model_info.params.get(
"temperature", None
)
if model_info.params.get("seed", None):
payload["options"]["seed"] = model_info.params.get("seed", None)
if model_info.params.get("stop", None):
payload["options"]["stop"] = (
[
bytes(stop, "utf-8").decode("unicode_escape")
for stop in model_info.params["stop"]
]
if model_info.params.get("stop", None)
else None
)
if model_info.params.get("tfs_z", None):
payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
if model_info.params.get("max_tokens", None):
payload["options"]["num_predict"] = model_info.params.get(
"max_tokens", None
)
if model_info.params.get("top_k", None):
payload["options"]["top_k"] = model_info.params.get("top_k", None)
if model_info.params.get("top_p", None):
payload["options"]["top_p"] = model_info.params.get("top_p", None)
if model_info.params.get("use_mmap", None):
payload["options"]["use_mmap"] = model_info.params.get("use_mmap", None)
if model_info.params.get("use_mlock", None):
payload["options"]["use_mlock"] = model_info.params.get(
"use_mlock", None
)
if model_info.params.get("num_thread", None):
payload["options"]["num_thread"] = model_info.params.get(
"num_thread", None
)
if model_info.params.get("system", None):
# Check if the payload already has a system message
@@ -981,73 +842,18 @@ async def generate_chat_completion(
print(payload)
r = None
def get_request():
nonlocal payload
nonlocal r
request_id = str(uuid.uuid4())
try:
REQUEST_POOL.append(request_id)
def stream_content():
try:
if payload.get("stream", None):
yield json.dumps({"id": request_id, "done": False}) + "\n"
for chunk in r.iter_content(chunk_size=8192):
if request_id in REQUEST_POOL:
yield chunk
else:
log.warning("User: canceled request")
break
finally:
if hasattr(r, "close"):
r.close()
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
r = requests.request(
method="POST",
url=f"{url}/api/chat",
data=json.dumps(payload),
stream=True,
)
r.raise_for_status()
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
log.exception(e)
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
return await post_streaming_url(f"{url}/api/chat", json.dumps(payload))
# TODO: we should update this part once Ollama supports other types
class OpenAIChatMessageContent(BaseModel):
type: str
model_config = ConfigDict(extra="allow")
class OpenAIChatMessage(BaseModel):
role: str
content: str
content: Union[str, OpenAIChatMessageContent]
model_config = ConfigDict(extra="allow")
@@ -1075,7 +881,6 @@ async def generate_openai_chat_completion(
model_info = Models.get_model_by_id(model_id)
if model_info:
print(model_info)
if model_info.base_model_id:
payload["model"] = model_info.base_model_id
@@ -1132,68 +937,7 @@ async def generate_openai_chat_completion(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
r = None
def get_request():
nonlocal payload
nonlocal r
request_id = str(uuid.uuid4())
try:
REQUEST_POOL.append(request_id)
def stream_content():
try:
if payload.get("stream"):
yield json.dumps(
{"request_id": request_id, "done": False}
) + "\n"
for chunk in r.iter_content(chunk_size=8192):
if request_id in REQUEST_POOL:
yield chunk
else:
log.warning("User: canceled request")
break
finally:
if hasattr(r, "close"):
r.close()
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
r = requests.request(
method="POST",
url=f"{url}/v1/chat/completions",
data=json.dumps(payload),
stream=True,
)
r.raise_for_status()
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
return await post_streaming_url(f"{url}/v1/chat/completions", json.dumps(payload))
@app.get("/v1/models")
@@ -1305,7 +1049,7 @@ async def download_file_stream(
timeout = aiohttp.ClientTimeout(total=600) # Set the timeout
async with aiohttp.ClientSession(timeout=timeout) as session:
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(file_url, headers=headers) as response:
total_size = int(response.headers.get("content-length", 0)) + current_size
@@ -1522,7 +1266,7 @@ async def deprecated_proxy(
if path == "generate":
data = json.loads(body.decode("utf-8"))
if not ("stream" in data and data["stream"] == False):
if data.get("stream", True):
yield json.dumps({"id": request_id, "done": False}) + "\n"
elif path == "chat":

View File

@@ -9,6 +9,7 @@ import json
import logging
from pydantic import BaseModel
from starlette.background import BackgroundTask
from apps.webui.models.models import Models
from apps.webui.models.users import Users
@@ -185,7 +186,7 @@ async def fetch_url(url, key):
timeout = aiohttp.ClientTimeout(total=5)
try:
headers = {"Authorization": f"Bearer {key}"}
async with aiohttp.ClientSession(timeout=timeout) as session:
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(url, headers=headers) as response:
return await response.json()
except Exception as e:
@@ -194,6 +195,16 @@ async def fetch_url(url, key):
return None
async def cleanup_response(
response: Optional[aiohttp.ClientResponse],
session: Optional[aiohttp.ClientSession],
):
if response:
response.close()
if session:
await session.close()
def merge_models_lists(model_lists):
log.debug(f"merge_models_lists {model_lists}")
merged_list = []
@@ -228,6 +239,27 @@ async def get_all_models(raw: bool = False):
) or not app.state.config.ENABLE_OPENAI_API:
models = {"data": []}
else:
# Check if API KEYS length is same than API URLS length
if len(app.state.config.OPENAI_API_KEYS) != len(
app.state.config.OPENAI_API_BASE_URLS
):
# if there are more keys than urls, remove the extra keys
if len(app.state.config.OPENAI_API_KEYS) > len(
app.state.config.OPENAI_API_BASE_URLS
):
app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[
: len(app.state.config.OPENAI_API_BASE_URLS)
]
# if there are more urls than keys, add empty keys
else:
app.state.config.OPENAI_API_KEYS += [
""
for _ in range(
len(app.state.config.OPENAI_API_BASE_URLS)
- len(app.state.config.OPENAI_API_KEYS)
)
]
tasks = [
fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
@@ -313,113 +345,155 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
)
@app.post("/chat/completions")
@app.post("/chat/completions/{url_idx}")
async def generate_chat_completion(
form_data: dict,
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
):
idx = 0
payload = {**form_data}
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)
if model_info:
if model_info.base_model_id:
payload["model"] = model_info.base_model_id
model_info.params = model_info.params.model_dump()
if model_info.params:
if model_info.params.get("temperature", None) is not None:
payload["temperature"] = float(model_info.params.get("temperature"))
if model_info.params.get("top_p", None):
payload["top_p"] = int(model_info.params.get("top_p", None))
if model_info.params.get("max_tokens", None):
payload["max_tokens"] = int(model_info.params.get("max_tokens", None))
if model_info.params.get("frequency_penalty", None):
payload["frequency_penalty"] = int(
model_info.params.get("frequency_penalty", None)
)
if model_info.params.get("seed", None):
payload["seed"] = model_info.params.get("seed", None)
if model_info.params.get("stop", None):
payload["stop"] = (
[
bytes(stop, "utf-8").decode("unicode_escape")
for stop in model_info.params["stop"]
]
if model_info.params.get("stop", None)
else None
)
if model_info.params.get("system", None):
# Check if the payload already has a system message
# If not, add a system message to the payload
if payload.get("messages"):
for message in payload["messages"]:
if message.get("role") == "system":
message["content"] = (
model_info.params.get("system", None) + message["content"]
)
break
else:
payload["messages"].insert(
0,
{
"role": "system",
"content": model_info.params.get("system", None),
},
)
else:
pass
model = app.state.MODELS[payload.get("model")]
idx = model["urlIdx"]
if "pipeline" in model and model.get("pipeline"):
payload["user"] = {"name": user.name, "id": user.id}
# Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
# This is a workaround until OpenAI fixes the issue with this model
if payload.get("model") == "gpt-4-vision-preview":
if "max_tokens" not in payload:
payload["max_tokens"] = 4000
log.debug("Modified payload:", payload)
# Convert the modified body back to JSON
payload = json.dumps(payload)
print(payload)
url = app.state.config.OPENAI_API_BASE_URLS[idx]
key = app.state.config.OPENAI_API_KEYS[idx]
print(payload)
headers = {}
headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
r = None
session = None
streaming = False
try:
session = aiohttp.ClientSession(trust_env=True)
r = await session.request(
method="POST",
url=f"{url}/chat/completions",
data=payload,
headers=headers,
)
r.raise_for_status()
# Check if response is SSE
if "text/event-stream" in r.headers.get("Content-Type", ""):
streaming = True
return StreamingResponse(
r.content,
status_code=r.status,
headers=dict(r.headers),
background=BackgroundTask(
cleanup_response, response=r, session=session
),
)
else:
response_data = await r.json()
return response_data
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = await r.json()
print(res)
if "error" in res:
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
except:
error_detail = f"External: {e}"
raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
finally:
if not streaming and session:
if r:
r.close()
await session.close()
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
idx = 0
body = await request.body()
# TODO: Remove below after gpt-4-vision fix from Open AI
# Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)
payload = None
try:
if "chat/completions" in path:
body = body.decode("utf-8")
body = json.loads(body)
payload = {**body}
model_id = body.get("model")
model_info = Models.get_model_by_id(model_id)
if model_info:
print(model_info)
if model_info.base_model_id:
payload["model"] = model_info.base_model_id
model_info.params = model_info.params.model_dump()
if model_info.params:
if model_info.params.get("temperature", None):
payload["temperature"] = int(
model_info.params.get("temperature")
)
if model_info.params.get("top_p", None):
payload["top_p"] = int(model_info.params.get("top_p", None))
if model_info.params.get("max_tokens", None):
payload["max_tokens"] = int(
model_info.params.get("max_tokens", None)
)
if model_info.params.get("frequency_penalty", None):
payload["frequency_penalty"] = int(
model_info.params.get("frequency_penalty", None)
)
if model_info.params.get("seed", None):
payload["seed"] = model_info.params.get("seed", None)
if model_info.params.get("stop", None):
payload["stop"] = (
[
bytes(stop, "utf-8").decode("unicode_escape")
for stop in model_info.params["stop"]
]
if model_info.params.get("stop", None)
else None
)
if model_info.params.get("system", None):
# Check if the payload already has a system message
# If not, add a system message to the payload
if payload.get("messages"):
for message in payload["messages"]:
if message.get("role") == "system":
message["content"] = (
model_info.params.get("system", None)
+ message["content"]
)
break
else:
payload["messages"].insert(
0,
{
"role": "system",
"content": model_info.params.get("system", None),
},
)
else:
pass
model = app.state.MODELS[payload.get("model")]
idx = model["urlIdx"]
if "pipeline" in model and model.get("pipeline"):
payload["user"] = {"name": user.name, "id": user.id}
payload["title"] = (
True
if payload["stream"] == False and payload["max_tokens"] == 50
else False
)
# Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
# This is a workaround until OpenAI fixes the issue with this model
if payload.get("model") == "gpt-4-vision-preview":
if "max_tokens" not in payload:
payload["max_tokens"] = 4000
log.debug("Modified payload:", payload)
# Convert the modified body back to JSON
payload = json.dumps(payload)
except json.JSONDecodeError as e:
log.error("Error loading request body into a dictionary:", e)
print(payload)
url = app.state.config.OPENAI_API_BASE_URLS[idx]
key = app.state.config.OPENAI_API_KEYS[idx]
@@ -431,40 +505,48 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
headers["Content-Type"] = "application/json"
r = None
session = None
streaming = False
try:
r = requests.request(
session = aiohttp.ClientSession(trust_env=True)
r = await session.request(
method=request.method,
url=target_url,
data=payload if payload else body,
data=body,
headers=headers,
stream=True,
)
r.raise_for_status()
# Check if response is SSE
if "text/event-stream" in r.headers.get("Content-Type", ""):
streaming = True
return StreamingResponse(
r.iter_content(chunk_size=8192),
status_code=r.status_code,
r.content,
status_code=r.status,
headers=dict(r.headers),
background=BackgroundTask(
cleanup_response, response=r, session=session
),
)
else:
response_data = r.json()
response_data = await r.json()
return response_data
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
res = await r.json()
print(res)
if "error" in res:
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
except:
error_detail = f"External: {e}"
raise HTTPException(
status_code=r.status_code if r else 500, detail=error_detail
)
raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
finally:
if not streaming and session:
if r:
r.close()
await session.close()

View File

@@ -8,12 +8,15 @@ from fastapi import (
Form,
)
from fastapi.middleware.cors import CORSMiddleware
import requests
import os, shutil, logging, re
from datetime import datetime
from pathlib import Path
from typing import List, Union, Sequence
from typing import List, Union, Sequence, Iterator, Any
from chromadb.utils.batch_utils import create_batches
from langchain_core.documents import Document
from langchain_community.document_loaders import (
WebBaseLoader,
@@ -30,6 +33,7 @@ from langchain_community.document_loaders import (
UnstructuredExcelLoader,
UnstructuredPowerPointLoader,
YoutubeLoader,
OutlookMessageLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -59,9 +63,17 @@ from apps.rag.utils import (
query_doc_with_hybrid_search,
query_collection,
query_collection_with_hybrid_search,
search_web,
)
from apps.rag.search.brave import search_brave
from apps.rag.search.google_pse import search_google_pse
from apps.rag.search.main import SearchResult
from apps.rag.search.searxng import search_searxng
from apps.rag.search.serper import search_serper
from apps.rag.search.serpstack import search_serpstack
from apps.rag.search.serply import search_serply
from apps.rag.search.duckduckgo import search_duckduckgo
from utils.misc import (
calculate_sha256,
calculate_sha256_string,
@@ -71,6 +83,7 @@ from utils.misc import (
from utils.utils import get_current_user, get_admin_user
from config import (
AppConfig,
ENV,
SRC_LOG_LEVELS,
UPLOAD_DIR,
@@ -96,8 +109,19 @@ from config import (
RAG_TEMPLATE,
ENABLE_RAG_LOCAL_WEB_FETCH,
YOUTUBE_LOADER_LANGUAGE,
ENABLE_RAG_WEB_SEARCH,
RAG_WEB_SEARCH_ENGINE,
SEARXNG_QUERY_URL,
GOOGLE_PSE_API_KEY,
GOOGLE_PSE_ENGINE_ID,
BRAVE_SEARCH_API_KEY,
SERPSTACK_API_KEY,
SERPSTACK_HTTPS,
SERPER_API_KEY,
SERPLY_API_KEY,
RAG_WEB_SEARCH_RESULT_COUNT,
RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
AppConfig,
RAG_EMBEDDING_OPENAI_BATCH_SIZE,
)
from constants import ERROR_MESSAGES
@@ -122,6 +146,7 @@ app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SIZE
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
@@ -136,6 +161,21 @@ app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
app.state.YOUTUBE_LOADER_TRANSLATION = None
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
app.state.config.SERPER_API_KEY = SERPER_API_KEY
app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
def update_embedding_model(
embedding_model: str,
update_model: bool = False,
@@ -181,6 +221,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
)
origins = ["*"]
@@ -217,6 +258,7 @@ async def get_status():
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"reranking_model": app.state.config.RAG_RERANKING_MODEL,
"openai_batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
}
@@ -229,6 +271,7 @@ async def get_embedding_config(user=Depends(get_admin_user)):
"openai_config": {
"url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.config.OPENAI_API_KEY,
"batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
},
}
@@ -244,6 +287,7 @@ async def get_reraanking_config(user=Depends(get_admin_user)):
class OpenAIConfigForm(BaseModel):
url: str
key: str
batch_size: Optional[int] = None
class EmbeddingModelUpdateForm(BaseModel):
@@ -264,9 +308,14 @@ async def update_embedding_config(
app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
if form_data.openai_config != None:
if form_data.openai_config is not None:
app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
app.state.config.OPENAI_API_KEY = form_data.openai_config.key
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = (
form_data.openai_config.batch_size
if form_data.openai_config.batch_size
else 1
)
update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
@@ -276,6 +325,7 @@ async def update_embedding_config(
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
)
return {
@@ -285,6 +335,7 @@ async def update_embedding_config(
"openai_config": {
"url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.config.OPENAI_API_KEY,
"batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
},
}
except Exception as e:
@@ -332,11 +383,27 @@ async def get_rag_config(user=Depends(get_admin_user)):
"chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
},
"web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"youtube": {
"language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
},
"web": {
"ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"search": {
"enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
"engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
"searxng_query_url": app.state.config.SEARXNG_QUERY_URL,
"google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY,
"google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID,
"brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY,
"serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
"serpstack_https": app.state.config.SERPSTACK_HTTPS,
"serper_api_key": app.state.config.SERPER_API_KEY,
"serply_api_key": app.state.config.SERPLY_API_KEY,
"result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
"concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
},
},
}
@@ -350,11 +417,31 @@ class YoutubeLoaderConfig(BaseModel):
translation: Optional[str] = None
class WebSearchConfig(BaseModel):
enabled: bool
engine: Optional[str] = None
searxng_query_url: Optional[str] = None
google_pse_api_key: Optional[str] = None
google_pse_engine_id: Optional[str] = None
brave_search_api_key: Optional[str] = None
serpstack_api_key: Optional[str] = None
serpstack_https: Optional[bool] = None
serper_api_key: Optional[str] = None
serply_api_key: Optional[str] = None
result_count: Optional[int] = None
concurrent_requests: Optional[int] = None
class WebConfig(BaseModel):
search: WebSearchConfig
web_loader_ssl_verification: Optional[bool] = None
class ConfigUpdateForm(BaseModel):
pdf_extract_images: Optional[bool] = None
chunk: Optional[ChunkParamUpdateForm] = None
web_loader_ssl_verification: Optional[bool] = None
youtube: Optional[YoutubeLoaderConfig] = None
web: Optional[WebConfig] = None
@app.post("/config/update")
@@ -365,35 +452,37 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
else app.state.config.PDF_EXTRACT_IMAGES
)
app.state.config.CHUNK_SIZE = (
form_data.chunk.chunk_size
if form_data.chunk is not None
else app.state.config.CHUNK_SIZE
)
if form_data.chunk is not None:
app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size
app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
app.state.config.CHUNK_OVERLAP = (
form_data.chunk.chunk_overlap
if form_data.chunk is not None
else app.state.config.CHUNK_OVERLAP
)
if form_data.youtube is not None:
app.state.config.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language
app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
form_data.web_loader_ssl_verification
if form_data.web_loader_ssl_verification != None
else app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
if form_data.web is not None:
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
form_data.web.web_loader_ssl_verification
)
app.state.config.YOUTUBE_LOADER_LANGUAGE = (
form_data.youtube.language
if form_data.youtube is not None
else app.state.config.YOUTUBE_LOADER_LANGUAGE
)
app.state.YOUTUBE_LOADER_TRANSLATION = (
form_data.youtube.translation
if form_data.youtube is not None
else app.state.YOUTUBE_LOADER_TRANSLATION
)
app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enabled
app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine
app.state.config.SEARXNG_QUERY_URL = form_data.web.search.searxng_query_url
app.state.config.GOOGLE_PSE_API_KEY = form_data.web.search.google_pse_api_key
app.state.config.GOOGLE_PSE_ENGINE_ID = (
form_data.web.search.google_pse_engine_id
)
app.state.config.BRAVE_SEARCH_API_KEY = (
form_data.web.search.brave_search_api_key
)
app.state.config.SERPSTACK_API_KEY = form_data.web.search.serpstack_api_key
app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https
app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key
app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
form_data.web.search.concurrent_requests
)
return {
"status": True,
@@ -402,11 +491,27 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
"chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
},
"web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"youtube": {
"language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
},
"web": {
"ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"search": {
"enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
"engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
"searxng_query_url": app.state.config.SEARXNG_QUERY_URL,
"google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY,
"google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID,
"brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY,
"serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
"serpstack_https": app.state.config.SERPSTACK_HTTPS,
"serper_api_key": app.state.config.SERPER_API_KEY,
"serply_api_key": app.state.config.SERPLY_API_KEY,
"result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
"concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
},
},
}
@@ -599,7 +704,7 @@ def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True):
# Check if the URL is valid
if not validate_url(url):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
return WebBaseLoader(
return SafeWebBaseLoader(
url,
verify_ssl=verify_ssl,
requests_per_second=RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
@@ -642,17 +747,107 @@ def resolve_hostname(hostname):
return ipv4_addresses, ipv6_addresses
def search_web(engine: str, query: str) -> list[SearchResult]:
"""Search the web using a search engine and return the results as a list of SearchResult objects.
Will look for a search engine API key in environment variables in the following order:
- SEARXNG_QUERY_URL
- GOOGLE_PSE_API_KEY + GOOGLE_PSE_ENGINE_ID
- BRAVE_SEARCH_API_KEY
- SERPSTACK_API_KEY
- SERPER_API_KEY
- SERPLY_API_KEY
Args:
query (str): The query to search for
"""
# TODO: add playwright to search the web
if engine == "searxng":
if app.state.config.SEARXNG_QUERY_URL:
return search_searxng(
app.state.config.SEARXNG_QUERY_URL,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
)
else:
raise Exception("No SEARXNG_QUERY_URL found in environment variables")
elif engine == "google_pse":
if (
app.state.config.GOOGLE_PSE_API_KEY
and app.state.config.GOOGLE_PSE_ENGINE_ID
):
return search_google_pse(
app.state.config.GOOGLE_PSE_API_KEY,
app.state.config.GOOGLE_PSE_ENGINE_ID,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
)
else:
raise Exception(
"No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables"
)
elif engine == "brave":
if app.state.config.BRAVE_SEARCH_API_KEY:
return search_brave(
app.state.config.BRAVE_SEARCH_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
)
else:
raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
elif engine == "serpstack":
if app.state.config.SERPSTACK_API_KEY:
return search_serpstack(
app.state.config.SERPSTACK_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
https_enabled=app.state.config.SERPSTACK_HTTPS,
)
else:
raise Exception("No SERPSTACK_API_KEY found in environment variables")
elif engine == "serper":
if app.state.config.SERPER_API_KEY:
return search_serper(
app.state.config.SERPER_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
)
else:
raise Exception("No SERPER_API_KEY found in environment variables")
elif engine == "serply":
if app.state.config.SERPLY_API_KEY:
return search_serply(
app.state.config.SERPLY_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
)
else:
raise Exception("No SERPLY_API_KEY found in environment variables")
elif engine == "duckduckgo":
return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT)
else:
raise Exception("No search engine API key found in environment variables")
@app.post("/web/search")
def store_web_search(form_data: SearchForm, user=Depends(get_current_user)):
try:
try:
web_results = search_web(form_data.query)
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.WEB_SEARCH_ERROR,
)
logging.info(
f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}"
)
web_results = search_web(
app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query
)
except Exception as e:
log.exception(e)
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
)
try:
urls = [result.link for result in web_results]
loader = get_web_loader(urls)
data = loader.load()
@@ -710,6 +905,13 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
texts = [doc.page_content for doc in docs]
metadatas = [doc.metadata for doc in docs]
# ChromaDB does not like datetime formats
# for meta-data so convert them to string.
for metadata in metadatas:
for key, value in metadata.items():
if isinstance(value, datetime):
metadata[key] = str(value)
try:
if overwrite:
for collection in CHROMA_CLIENT.list_collections():
@@ -725,6 +927,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
)
embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
@@ -795,6 +998,7 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
"swift",
"vue",
"svelte",
"msg",
]
if file_ext == "pdf":
@@ -829,6 +1033,8 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
] or file_ext in ["ppt", "pptx"]:
loader = UnstructuredPowerPointLoader(file_path)
elif file_ext == "msg":
loader = OutlookMessageLoader(file_path)
elif file_ext in known_source_ext or (
file_content_type and file_content_type.find("text/") >= 0
):
@@ -994,6 +1200,30 @@ def reset_vector_db(user=Depends(get_admin_user)):
CHROMA_CLIENT.reset()
@app.get("/reset/uploads")
def reset_upload_dir(user=Depends(get_admin_user)) -> bool:
folder = f"{UPLOAD_DIR}"
try:
# Check if the directory exists
if os.path.exists(folder):
# Iterate over all the files and directories in the specified directory
for filename in os.listdir(folder):
file_path = os.path.join(folder, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path) # Remove the file or link
elif os.path.isdir(file_path):
shutil.rmtree(file_path) # Remove the directory
except Exception as e:
print(f"Failed to delete {file_path}. Reason: {e}")
else:
print(f"The directory {folder} does not exist")
except Exception as e:
print(f"Failed to process the directory {folder}. Reason: {e}")
return True
@app.get("/reset")
def reset(user=Depends(get_admin_user)) -> bool:
folder = f"{UPLOAD_DIR}"
@@ -1015,6 +1245,33 @@ def reset(user=Depends(get_admin_user)) -> bool:
return True
class SafeWebBaseLoader(WebBaseLoader):
"""WebBaseLoader with enhanced error handling for URLs."""
def lazy_load(self) -> Iterator[Document]:
"""Lazy load text from the url(s) in web_path with error handling."""
for path in self.web_paths:
try:
soup = self._scrape(path, bs_kwargs=self.bs_kwargs)
text = soup.get_text(**self.bs_get_text_kwargs)
# Build metadata
metadata = {"source": path}
if title := soup.find("title"):
metadata["title"] = title.get_text()
if description := soup.find("meta", attrs={"name": "description"}):
metadata["description"] = description.get(
"content", "No description found."
)
if html := soup.find("html"):
metadata["language"] = html.get("lang", "No language found.")
yield Document(page_content=text, metadata=metadata)
except Exception as e:
# Log the error and continue with the next URL
log.error(f"Error loading {path}: {e}")
if ENV == "dev":
@app.get("/ef")

View File

@@ -3,13 +3,13 @@ import logging
import requests
from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_brave(api_key: str, query: str) -> list[SearchResult]:
def search_brave(api_key: str, query: str, count: int) -> list[SearchResult]:
"""Search using Brave's Search API and return the results as a list of SearchResult objects.
Args:
@@ -22,7 +22,7 @@ def search_brave(api_key: str, query: str) -> list[SearchResult]:
"Accept-Encoding": "gzip",
"X-Subscription-Token": api_key,
}
params = {"q": query, "count": RAG_WEB_SEARCH_RESULT_COUNT}
params = {"q": query, "count": count}
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
@@ -33,5 +33,5 @@ def search_brave(api_key: str, query: str) -> list[SearchResult]:
SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("snippet")
)
for result in results[:RAG_WEB_SEARCH_RESULT_COUNT]
for result in results[:count]
]

View File

@@ -0,0 +1,46 @@
import logging
from apps.rag.search.main import SearchResult
from duckduckgo_search import DDGS
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_duckduckgo(query: str, count: int) -> list[SearchResult]:
"""
Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects.
Args:
query (str): The query to search for
count (int): The number of results to return
Returns:
List[SearchResult]: A list of search results
"""
# Use the DDGS context manager to create a DDGS object
with DDGS() as ddgs:
# Use the ddgs.text() method to perform the search
ddgs_gen = ddgs.text(
query, safesearch="moderate", max_results=count, backend="api"
)
# Check if there are search results
if ddgs_gen:
# Convert the search results into a list
search_results = [r for r in ddgs_gen]
# Create an empty list to store the SearchResult objects
results = []
# Iterate over each search result
for result in search_results:
# Create a SearchResult object and append it to the results list
results.append(
SearchResult(
link=result["href"],
title=result.get("title"),
snippet=result.get("body"),
)
)
print(results)
# Return the list of search results
return results

View File

@@ -4,14 +4,14 @@ import logging
import requests
from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_google_pse(
api_key: str, search_engine_id: str, query: str
api_key: str, search_engine_id: str, query: str, count: int
) -> list[SearchResult]:
"""Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
@@ -27,7 +27,7 @@ def search_google_pse(
"cx": search_engine_id,
"q": query,
"key": api_key,
"num": RAG_WEB_SEARCH_RESULT_COUNT,
"num": count,
}
response = requests.request("GET", url, headers=headers, params=params)

View File

@@ -1,28 +1,68 @@
import logging
import requests
from typing import List
from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_searxng(query_url: str, query: str) -> list[SearchResult]:
"""Search a SearXNG instance for a query and return the results as a list of SearchResult objects.
def search_searxng(
query_url: str, query: str, count: int, **kwargs
) -> List[SearchResult]:
"""
Search a SearXNG instance for a given query and return the results as a list of SearchResult objects.
The function allows passing additional parameters such as language or time_range to tailor the search result.
Args:
query_url (str): The URL of the SearXNG instance to search. Must contain "<query>" as a placeholder
query (str): The query to search for
"""
url = query_url.replace("<query>", query)
if "&format=json" not in url:
url += "&format=json"
log.debug(f"searching {url}")
query_url (str): The base URL of the SearXNG server.
query (str): The search term or question to find in the SearXNG database.
count (int): The maximum number of results to retrieve from the search.
r = requests.get(
url,
Keyword Args:
language (str): Language filter for the search results; e.g., "en-US". Defaults to an empty string.
safesearch (int): Safe search filter for safer web results; 0 = off, 1 = moderate, 2 = strict. Defaults to 1 (moderate).
time_range (str): Time range for filtering results by date; e.g., "2023-04-05..today" or "all-time". Defaults to ''.
categories: (Optional[List[str]]): Specific categories within which the search should be performed, defaulting to an empty string if not provided.
Returns:
List[SearchResult]: A list of SearchResults sorted by relevance score in descending order.
Raise:
requests.exceptions.RequestException: If a request error occurs during the search process.
"""
# Default values for optional parameters are provided as empty strings or None when not specified.
language = kwargs.get("language", "en-US")
safesearch = kwargs.get("safesearch", "1")
time_range = kwargs.get("time_range", "")
categories = "".join(kwargs.get("categories", []))
params = {
"q": query,
"format": "json",
"pageno": 1,
"safesearch": safesearch,
"language": language,
"time_range": time_range,
"categories": categories,
"theme": "simple",
"image_proxy": 0,
}
# Legacy query format
if "<query>" in query_url:
# Strip all query parameters from the URL
query_url = query_url.split("?")[0]
log.debug(f"searching {query_url}")
response = requests.get(
query_url,
headers={
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
"Accept": "text/html",
@@ -30,15 +70,17 @@ def search_searxng(query_url: str, query: str) -> list[SearchResult]:
"Accept-Language": "en-US,en;q=0.5",
"Connection": "keep-alive",
},
params=params,
)
r.raise_for_status()
json_response = r.json()
response.raise_for_status() # Raise an exception for HTTP errors.
json_response = response.json()
results = json_response.get("results", [])
sorted_results = sorted(results, key=lambda x: x.get("score", 0), reverse=True)
return [
SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("content")
)
for result in sorted_results[:RAG_WEB_SEARCH_RESULT_COUNT]
for result in sorted_results[:count]
]

View File

@@ -4,13 +4,13 @@ import logging
import requests
from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_serper(api_key: str, query: str) -> list[SearchResult]:
def search_serper(api_key: str, query: str, count: int) -> list[SearchResult]:
"""Search using serper.dev's API and return the results as a list of SearchResult objects.
Args:
@@ -35,5 +35,5 @@ def search_serper(api_key: str, query: str) -> list[SearchResult]:
title=result.get("title"),
snippet=result.get("description"),
)
for result in results[:RAG_WEB_SEARCH_RESULT_COUNT]
for result in results[:count]
]

View File

@@ -0,0 +1,68 @@
import json
import logging
import requests
from urllib.parse import urlencode
from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_serply(
api_key: str,
query: str,
count: int,
hl: str = "us",
limit: int = 10,
device_type: str = "desktop",
proxy_location: str = "US",
) -> list[SearchResult]:
"""Search using serper.dev's API and return the results as a list of SearchResult objects.
Args:
api_key (str): A serply.io API key
query (str): The query to search for
hl (str): Host Language code to display results in (reference https://developers.google.com/custom-search/docs/xml_results?hl=en#wsInterfaceLanguages)
limit (int): The maximum number of results to return [10-100, defaults to 10]
"""
log.info("Searching with Serply")
url = "https://api.serply.io/v1/search/"
query_payload = {
"q": query,
"language": "en",
"num": limit,
"gl": proxy_location.upper(),
"hl": hl.lower(),
}
url = f"{url}{urlencode(query_payload)}"
headers = {
"X-API-KEY": api_key,
"X-User-Agent": device_type,
"User-Agent": "open-webui",
"X-Proxy-Location": proxy_location,
}
response = requests.request("GET", url, headers=headers)
response.raise_for_status()
json_response = response.json()
log.info(f"results from serply search: {json_response}")
results = sorted(
json_response.get("results", []), key=lambda x: x.get("realPosition", 0)
)
return [
SearchResult(
link=result["link"],
title=result.get("title"),
snippet=result.get("description"),
)
for result in results[:count]
]

View File

@@ -4,14 +4,14 @@ import logging
import requests
from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_serpstack(
api_key: str, query: str, https_enabled: bool = True
api_key: str, query: str, count: int, https_enabled: bool = True
) -> list[SearchResult]:
"""Search using serpstack.com's and return the results as a list of SearchResult objects.
@@ -39,5 +39,5 @@ def search_serpstack(
SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("snippet")
)
for result in results[:RAG_WEB_SEARCH_RESULT_COUNT]
for result in results[:count]
]

View File

@@ -0,0 +1,206 @@
{
"ads": [],
"ads_count": 0,
"answers": [],
"results": [
{
"title": "Apple",
"link": "https://www.apple.com/",
"description": "Discover the innovative world of Apple and shop everything iPhone, iPad, Apple Watch, Mac, and Apple TV, plus explore accessories, entertainment, ...",
"additional_links": [
{
"text": "AppleApplehttps://www.apple.com",
"href": "https://www.apple.com/"
}
],
"cite": {},
"subdomains": [
{
"title": "Support",
"link": "https://support.apple.com/",
"description": "SupportContact - iPhone Support - Billing and Subscriptions - Apple Repair"
},
{
"title": "Store",
"link": "https://www.apple.com/store",
"description": "StoreShop iPhone - Shop iPad - App Store - Shop Mac - ..."
},
{
"title": "Mac",
"link": "https://www.apple.com/mac/",
"description": "MacMacBook Air - MacBook Pro - iMac - Compare Mac models - Mac mini"
},
{
"title": "iPad",
"link": "https://www.apple.com/ipad/",
"description": "iPadShop iPad - iPad Pro - iPad Air - Compare iPad models - ..."
},
{
"title": "Watch",
"link": "https://www.apple.com/watch/",
"description": "WatchShop Apple Watch - Series 9 - SE - Ultra 2 - Nike - Hermès - ..."
}
],
"realPosition": 1
},
{
"title": "Apple",
"link": "https://www.apple.com/",
"description": "Discover the innovative world of Apple and shop everything iPhone, iPad, Apple Watch, Mac, and Apple TV, plus explore accessories, entertainment, ...",
"additional_links": [
{
"text": "AppleApplehttps://www.apple.com",
"href": "https://www.apple.com/"
}
],
"cite": {},
"realPosition": 2
},
{
"title": "Apple Inc.",
"link": "https://en.wikipedia.org/wiki/Apple_Inc.",
"description": "Apple Inc. (formerly Apple Computer, Inc.) is an American multinational corporation and technology company headquartered in Cupertino, California, ...",
"additional_links": [
{
"text": "Apple Inc.Wikipediahttps://en.wikipedia.org wiki Apple_Inc",
"href": "https://en.wikipedia.org/wiki/Apple_Inc."
},
{
"text": "",
"href": "https://en.wikipedia.org/wiki/Apple_Inc."
},
{
"text": "History",
"href": "https://en.wikipedia.org/wiki/History_of_Apple_Inc."
},
{
"text": "List of Apple products",
"href": "https://en.wikipedia.org/wiki/List_of_Apple_products"
},
{
"text": "Litigation involving Apple Inc.",
"href": "https://en.wikipedia.org/wiki/Litigation_involving_Apple_Inc."
},
{
"text": "Apple Park",
"href": "https://en.wikipedia.org/wiki/Apple_Park"
}
],
"cite": {
"domain": "https://en.wikipedia.org wiki Apple_Inc",
"span": " wiki Apple_Inc"
},
"realPosition": 3
},
{
"title": "Apple Inc. (AAPL) Company Profile & Facts",
"link": "https://finance.yahoo.com/quote/AAPL/profile/",
"description": "Apple Inc. designs, manufactures, and markets smartphones, personal computers, tablets, wearables, and accessories worldwide. The company offers iPhone, a line ...",
"additional_links": [
{
"text": "Apple Inc. (AAPL) Company Profile & FactsYahoo Financehttps://finance.yahoo.com quote AAPL profile",
"href": "https://finance.yahoo.com/quote/AAPL/profile/"
}
],
"cite": {
"domain": "https://finance.yahoo.com quote AAPL profile",
"span": " quote AAPL profile"
},
"realPosition": 4
},
{
"title": "Apple Inc - Company Profile and News",
"link": "https://www.bloomberg.com/profile/company/AAPL:US",
"description": "Apple Inc. Apple Inc. designs, manufactures, and markets smartphones, personal computers, tablets, wearables and accessories, and sells a variety of related ...",
"additional_links": [
{
"text": "Apple Inc - Company Profile and NewsBloomberghttps://www.bloomberg.com company AAPL:US",
"href": "https://www.bloomberg.com/profile/company/AAPL:US"
},
{
"text": "",
"href": "https://www.bloomberg.com/profile/company/AAPL:US"
}
],
"cite": {
"domain": "https://www.bloomberg.com company AAPL:US",
"span": " company AAPL:US"
},
"realPosition": 5
},
{
"title": "Apple Inc. | History, Products, Headquarters, & Facts",
"link": "https://www.britannica.com/money/Apple-Inc",
"description": "May 22, 2024 — Apple Inc. is an American multinational technology company that revolutionized the technology sector through its innovation of computer ...",
"additional_links": [
{
"text": "Apple Inc. | History, Products, Headquarters, & FactsBritannicahttps://www.britannica.com money Apple-Inc",
"href": "https://www.britannica.com/money/Apple-Inc"
},
{
"text": "",
"href": "https://www.britannica.com/money/Apple-Inc"
}
],
"cite": {
"domain": "https://www.britannica.com money Apple-Inc",
"span": " money Apple-Inc"
},
"realPosition": 6
}
],
"shopping_ads": [],
"places": [
{
"title": "Apple Inc."
},
{
"title": "Apple Inc"
},
{
"title": "Apple Inc"
}
],
"related_searches": {
"images": [],
"text": [
{
"title": "apple inc full form",
"link": "https://www.google.com/search?sca_esv=6b6df170a5c9891b&sca_upv=1&q=Apple+Inc+full+form&sa=X&ved=2ahUKEwjLxuSJwM-GAxUHODQIHYuJBhgQ1QJ6BAhPEAE"
},
{
"title": "apple company history",
"link": "https://www.google.com/search?sca_esv=6b6df170a5c9891b&sca_upv=1&q=Apple+company+history&sa=X&ved=2ahUKEwjLxuSJwM-GAxUHODQIHYuJBhgQ1QJ6BAhOEAE"
},
{
"title": "apple store",
"link": "https://www.google.com/search?sca_esv=6b6df170a5c9891b&sca_upv=1&q=Apple+Store&sa=X&ved=2ahUKEwjLxuSJwM-GAxUHODQIHYuJBhgQ1QJ6BAhQEAE"
},
{
"title": "apple id",
"link": "https://www.google.com/search?sca_esv=6b6df170a5c9891b&sca_upv=1&q=Apple+id&sa=X&ved=2ahUKEwjLxuSJwM-GAxUHODQIHYuJBhgQ1QJ6BAhSEAE"
},
{
"title": "apple inc industry",
"link": "https://www.google.com/search?sca_esv=6b6df170a5c9891b&sca_upv=1&q=Apple+Inc+industry&sa=X&ved=2ahUKEwjLxuSJwM-GAxUHODQIHYuJBhgQ1QJ6BAhREAE"
},
{
"title": "apple login",
"link": "https://www.google.com/search?sca_esv=6b6df170a5c9891b&sca_upv=1&q=Apple+login&sa=X&ved=2ahUKEwjLxuSJwM-GAxUHODQIHYuJBhgQ1QJ6BAhTEAE"
}
]
},
"image_results": [],
"carousel": [],
"total": 2450000000,
"knowledge_graph": "",
"related_questions": [
"What does the Apple Inc do?",
"Why did Apple change to Apple Inc?",
"Who owns Apple Inc.?",
"What is Apple Inc best known for?"
],
"carousel_count": 0,
"ts": 2.491065263748169,
"device_type": null
}

View File

@@ -2,7 +2,7 @@ import os
import logging
import requests
from typing import List
from typing import List, Union
from apps.ollama.main import (
generate_ollama_embeddings,
@@ -20,23 +20,8 @@ from langchain.retrievers import (
from typing import Optional
from apps.rag.search.brave import search_brave
from apps.rag.search.google_pse import search_google_pse
from apps.rag.search.main import SearchResult
from apps.rag.search.searxng import search_searxng
from apps.rag.search.serper import search_serper
from apps.rag.search.serpstack import search_serpstack
from config import (
SRC_LOG_LEVELS,
CHROMA_CLIENT,
SEARXNG_QUERY_URL,
GOOGLE_PSE_API_KEY,
GOOGLE_PSE_ENGINE_ID,
BRAVE_SEARCH_API_KEY,
SERPSTACK_API_KEY,
SERPSTACK_HTTPS,
SERPER_API_KEY,
)
from utils.misc import get_last_user_message, add_or_update_system_message
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -214,6 +199,7 @@ def get_embedding_function(
embedding_function,
openai_key,
openai_url,
batch_size,
):
if embedding_engine == "":
return lambda query: embedding_function.encode(query).tolist()
@@ -237,17 +223,22 @@ def get_embedding_function(
def generate_multiple(query, f):
if isinstance(query, list):
return [f(q) for q in query]
if embedding_engine == "openai":
embeddings = []
for i in range(0, len(query), batch_size):
embeddings.extend(f(query[i : i + batch_size]))
return embeddings
else:
return [f(q) for q in query]
else:
return f(query)
return lambda query: generate_multiple(query, func)
def rag_messages(
def get_rag_context(
docs,
messages,
template,
embedding_function,
k,
reranking_function,
@@ -255,31 +246,7 @@ def rag_messages(
hybrid_search,
):
log.debug(f"docs: {docs} {messages} {embedding_function} {reranking_function}")
last_user_message_idx = None
for i in range(len(messages) - 1, -1, -1):
if messages[i]["role"] == "user":
last_user_message_idx = i
break
user_message = messages[last_user_message_idx]
if isinstance(user_message["content"], list):
# Handle list content input
content_type = "list"
query = ""
for content_item in user_message["content"]:
if content_item["type"] == "text":
query = content_item["text"]
break
elif isinstance(user_message["content"], str):
# Handle text content input
content_type = "text"
query = user_message["content"]
else:
# Fallback in case the input does not match expected types
content_type = None
query = ""
query = get_last_user_message(messages)
extracted_collections = []
relevant_contexts = []
@@ -350,33 +317,7 @@ def rag_messages(
context_string = context_string.strip()
ra_content = rag_template(
template=template,
context=context_string,
query=query,
)
log.debug(f"ra_content: {ra_content}")
if content_type == "list":
new_content = []
for content_item in user_message["content"]:
if content_item["type"] == "text":
# Update the text item's content with ra_content
new_content.append({"type": "text", "text": ra_content})
else:
# Keep other types of content as they are
new_content.append(content_item)
new_user_message = {**user_message, "content": new_content}
else:
new_user_message = {
**user_message,
"content": ra_content,
}
messages[last_user_message_idx] = new_user_message
return messages, citations
return context_string, citations
def get_model_path(model: str, update_model: bool = False):
@@ -418,8 +359,22 @@ def get_model_path(model: str, update_model: bool = False):
def generate_openai_embeddings(
model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
model: str,
text: Union[str, list[str]],
key: str,
url: str = "https://api.openai.com/v1",
):
if isinstance(text, list):
embeddings = generate_openai_batch_embeddings(model, text, key, url)
else:
embeddings = generate_openai_batch_embeddings(model, [text], key, url)
return embeddings[0] if isinstance(text, str) else embeddings
def generate_openai_batch_embeddings(
model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
) -> Optional[list[list[float]]]:
try:
r = requests.post(
f"{url}/embeddings",
@@ -427,12 +382,12 @@ def generate_openai_embeddings(
"Content-Type": "application/json",
"Authorization": f"Bearer {key}",
},
json={"input": text, "model": model},
json={"input": texts, "model": model},
)
r.raise_for_status()
data = r.json()
if "data" in data:
return data["data"][0]["embedding"]
return [elem["embedding"] for elem in data["data"]]
else:
raise "Something went wrong :/"
except Exception as e:
@@ -536,31 +491,3 @@ class RerankCompressor(BaseDocumentCompressor):
)
final_results.append(doc)
return final_results
def search_web(query: str) -> list[SearchResult]:
"""Search the web using a search engine and return the results as a list of SearchResult objects.
Will look for a search engine API key in environment variables in the following order:
- SEARXNG_QUERY_URL
- GOOGLE_PSE_API_KEY + GOOGLE_PSE_ENGINE_ID
- BRAVE_SEARCH_API_KEY
- SERPSTACK_API_KEY
- SERPER_API_KEY
Args:
query (str): The query to search for
"""
# TODO: add playwright to search the web
if SEARXNG_QUERY_URL:
return search_searxng(SEARXNG_QUERY_URL, query)
elif GOOGLE_PSE_API_KEY and GOOGLE_PSE_ENGINE_ID:
return search_google_pse(GOOGLE_PSE_API_KEY, GOOGLE_PSE_ENGINE_ID, query)
elif BRAVE_SEARCH_API_KEY:
return search_brave(BRAVE_SEARCH_API_KEY, query)
elif SERPSTACK_API_KEY:
return search_serpstack(SERPSTACK_API_KEY, query, https_enabled=SERPSTACK_HTTPS)
elif SERPER_API_KEY:
return search_serper(SERPER_API_KEY, query)
else:
raise Exception("No search engine API key found in environment variables")

139
backend/apps/socket/main.py Normal file
View File

@@ -0,0 +1,139 @@
import socketio
import asyncio
from apps.webui.models.users import Users
from utils.utils import decode_token
sio = socketio.AsyncServer(cors_allowed_origins=[], async_mode="asgi")
app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io")
# Dictionary to maintain the user pool
SESSION_POOL = {}
USER_POOL = {}
USAGE_POOL = {}
# Timeout duration in seconds
TIMEOUT_DURATION = 3
@sio.event
async def connect(sid, environ, auth):
user = None
if auth and "token" in auth:
data = decode_token(auth["token"])
if data is not None and "id" in data:
user = Users.get_user_by_id(data["id"])
if user:
SESSION_POOL[sid] = user.id
if user.id in USER_POOL:
USER_POOL[user.id].append(sid)
else:
USER_POOL[user.id] = [sid]
print(f"user {user.name}({user.id}) connected with session ID {sid}")
await sio.emit("user-count", {"count": len(set(USER_POOL))})
await sio.emit("usage", {"models": get_models_in_use()})
@sio.on("user-join")
async def user_join(sid, data):
print("user-join", sid, data)
auth = data["auth"] if "auth" in data else None
if auth and "token" in auth:
data = decode_token(auth["token"])
if data is not None and "id" in data:
user = Users.get_user_by_id(data["id"])
if user:
SESSION_POOL[sid] = user.id
if user.id in USER_POOL:
USER_POOL[user.id].append(sid)
else:
USER_POOL[user.id] = [sid]
print(f"user {user.name}({user.id}) connected with session ID {sid}")
await sio.emit("user-count", {"count": len(set(USER_POOL))})
@sio.on("user-count")
async def user_count(sid):
await sio.emit("user-count", {"count": len(set(USER_POOL))})
def get_models_in_use():
# Aggregate all models in use
models_in_use = []
for model_id, data in USAGE_POOL.items():
models_in_use.append(model_id)
return models_in_use
@sio.on("usage")
async def usage(sid, data):
model_id = data["model"]
# Cancel previous callback if there is one
if model_id in USAGE_POOL:
USAGE_POOL[model_id]["callback"].cancel()
# Store the new usage data and task
if model_id in USAGE_POOL:
USAGE_POOL[model_id]["sids"].append(sid)
USAGE_POOL[model_id]["sids"] = list(set(USAGE_POOL[model_id]["sids"]))
else:
USAGE_POOL[model_id] = {"sids": [sid]}
# Schedule a task to remove the usage data after TIMEOUT_DURATION
USAGE_POOL[model_id]["callback"] = asyncio.create_task(
remove_after_timeout(sid, model_id)
)
# Broadcast the usage data to all clients
await sio.emit("usage", {"models": get_models_in_use()})
async def remove_after_timeout(sid, model_id):
try:
await asyncio.sleep(TIMEOUT_DURATION)
if model_id in USAGE_POOL:
print(USAGE_POOL[model_id]["sids"])
USAGE_POOL[model_id]["sids"].remove(sid)
USAGE_POOL[model_id]["sids"] = list(set(USAGE_POOL[model_id]["sids"]))
if len(USAGE_POOL[model_id]["sids"]) == 0:
del USAGE_POOL[model_id]
# Broadcast the usage data to all clients
await sio.emit("usage", {"models": get_models_in_use()})
except asyncio.CancelledError:
# Task was cancelled due to new 'usage' event
pass
@sio.event
async def disconnect(sid):
if sid in SESSION_POOL:
user_id = SESSION_POOL[sid]
del SESSION_POOL[sid]
USER_POOL[user_id].remove(sid)
if len(USER_POOL[user_id]) == 0:
del USER_POOL[user_id]
await sio.emit("user-count", {"count": len(USER_POOL)})
else:
print(f"Unknown session ID {sid} disconnected")

View File

@@ -0,0 +1,61 @@
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
@migrator.create_model
class Tool(pw.Model):
id = pw.TextField(unique=True)
user_id = pw.TextField()
name = pw.TextField()
content = pw.TextField()
specs = pw.TextField()
meta = pw.TextField()
created_at = pw.BigIntegerField(null=False)
updated_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "tool"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("tool")

View File

@@ -6,6 +6,7 @@ from apps.webui.routers import (
users,
chats,
documents,
tools,
models,
prompts,
configs,
@@ -14,6 +15,8 @@ from apps.webui.routers import (
)
from config import (
WEBUI_BUILD_HASH,
SHOW_ADMIN_DETAILS,
ADMIN_EMAIL,
WEBUI_AUTH,
DEFAULT_MODELS,
DEFAULT_PROMPT_SUGGESTIONS,
@@ -24,8 +27,8 @@ from config import (
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
JWT_EXPIRES_IN,
WEBUI_BANNERS,
AppConfig,
ENABLE_COMMUNITY_SHARING,
AppConfig,
)
app = FastAPI()
@@ -36,6 +39,12 @@ app.state.config = AppConfig()
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
@@ -47,7 +56,7 @@ app.state.config.BANNERS = WEBUI_BANNERS
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.MODELS = {}
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.TOOLS = {}
app.add_middleware(
@@ -63,6 +72,7 @@ app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(documents.router, prefix="/documents", tags=["documents"])
app.include_router(tools.router, prefix="/tools", tags=["tools"])
app.include_router(models.router, prefix="/models", tags=["models"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(memories.router, prefix="/memories", tags=["memories"])

View File

@@ -298,6 +298,15 @@ class ChatTable:
# .limit(limit).offset(skip)
]
def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.archived == True)
.where(Chat.user_id == user_id)
.order_by(Chat.updated_at.desc())
]
def delete_chat_by_id(self, id: str) -> bool:
try:
query = Chat.delete().where((Chat.id == id))

View File

@@ -0,0 +1,132 @@
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time
import logging
from apps.webui.internal.db import DB, JSONField
import json
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Tools DB Schema
####################
class Tool(Model):
id = CharField(unique=True)
user_id = CharField()
name = TextField()
content = TextField()
specs = JSONField()
meta = JSONField()
updated_at = BigIntegerField()
created_at = BigIntegerField()
class Meta:
database = DB
class ToolMeta(BaseModel):
description: Optional[str] = None
class ToolModel(BaseModel):
id: str
user_id: str
name: str
content: str
specs: List[dict]
meta: ToolMeta
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
####################
# Forms
####################
class ToolResponse(BaseModel):
id: str
user_id: str
name: str
meta: ToolMeta
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
class ToolForm(BaseModel):
id: str
name: str
content: str
meta: ToolMeta
class ToolsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Tool])
def insert_new_tool(
self, user_id: str, form_data: ToolForm, specs: List[dict]
) -> Optional[ToolModel]:
tool = ToolModel(
**{
**form_data.model_dump(),
"specs": specs,
"user_id": user_id,
"updated_at": int(time.time()),
"created_at": int(time.time()),
}
)
try:
result = Tool.create(**tool.model_dump())
if result:
return tool
else:
return None
except Exception as e:
print(f"Error creating tool: {e}")
return None
def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
try:
tool = Tool.get(Tool.id == id)
return ToolModel(**model_to_dict(tool))
except:
return None
def get_tools(self) -> List[ToolModel]:
return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()]
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
try:
query = Tool.update(
**updated,
updated_at=int(time.time()),
).where(Tool.id == id)
query.execute()
tool = Tool.get(Tool.id == id)
return ToolModel(**model_to_dict(tool))
except:
return None
def delete_tool_by_id(self, id: str) -> bool:
try:
query = Tool.delete().where((Tool.id == id))
query.execute() # Remove the rows, return number of rows removed.
return True
except:
return False
Tools = ToolsTable(DB)

View File

@@ -269,73 +269,88 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
############################
# GetAdminDetails
############################
@router.get("/admin/details")
async def get_admin_details(request: Request, user=Depends(get_current_user)):
if request.app.state.config.SHOW_ADMIN_DETAILS:
admin_email = request.app.state.config.ADMIN_EMAIL
admin_name = None
print(admin_email, admin_name)
if admin_email:
admin = Users.get_user_by_email(admin_email)
if admin:
admin_name = admin.name
else:
admin = Users.get_first_user()
if admin:
admin_email = admin.email
admin_name = admin.name
return {
"name": admin_name,
"email": admin_email,
}
else:
raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
############################
# ToggleSignUp
############################
@router.get("/signup/enabled", response_model=bool)
async def get_sign_up_status(request: Request, user=Depends(get_admin_user)):
return request.app.state.config.ENABLE_SIGNUP
@router.get("/admin/config")
async def get_admin_config(request: Request, user=Depends(get_admin_user)):
return {
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
"ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
}
@router.get("/signup/enabled/toggle", response_model=bool)
async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
request.app.state.config.ENABLE_SIGNUP = not request.app.state.config.ENABLE_SIGNUP
return request.app.state.config.ENABLE_SIGNUP
class AdminConfig(BaseModel):
SHOW_ADMIN_DETAILS: bool
ENABLE_SIGNUP: bool
DEFAULT_USER_ROLE: str
JWT_EXPIRES_IN: str
ENABLE_COMMUNITY_SHARING: bool
############################
# Default User Role
############################
@router.get("/signup/user/role")
async def get_default_user_role(request: Request, user=Depends(get_admin_user)):
return request.app.state.config.DEFAULT_USER_ROLE
class UpdateRoleForm(BaseModel):
role: str
@router.post("/signup/user/role")
async def update_default_user_role(
request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user)
@router.post("/admin/config")
async def update_admin_config(
request: Request, form_data: AdminConfig, user=Depends(get_admin_user)
):
if form_data.role in ["pending", "user", "admin"]:
request.app.state.config.DEFAULT_USER_ROLE = form_data.role
return request.app.state.config.DEFAULT_USER_ROLE
request.app.state.config.SHOW_ADMIN_DETAILS = form_data.SHOW_ADMIN_DETAILS
request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP
if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE
############################
# JWT Expiration
############################
@router.get("/token/expires")
async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)):
return request.app.state.config.JWT_EXPIRES_IN
class UpdateJWTExpiresDurationForm(BaseModel):
duration: str
@router.post("/token/expires/update")
async def update_token_expires_duration(
request: Request,
form_data: UpdateJWTExpiresDurationForm,
user=Depends(get_admin_user),
):
pattern = r"^(-1|0|(-?\d+(\.\d+)?)(ms|s|m|h|d|w))$"
# Check if the input string matches the pattern
if re.match(pattern, form_data.duration):
request.app.state.config.JWT_EXPIRES_IN = form_data.duration
return request.app.state.config.JWT_EXPIRES_IN
else:
return request.app.state.config.JWT_EXPIRES_IN
if re.match(pattern, form_data.JWT_EXPIRES_IN):
request.app.state.config.JWT_EXPIRES_IN = form_data.JWT_EXPIRES_IN
request.app.state.config.ENABLE_COMMUNITY_SHARING = (
form_data.ENABLE_COMMUNITY_SHARING
)
return {
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
"ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
}
############################

View File

@@ -113,6 +113,19 @@ async def get_user_chats(user=Depends(get_current_user)):
]
############################
# GetArchivedChats
############################
@router.get("/all/archived", response_model=List[ChatResponse])
async def get_user_chats(user=Depends(get_current_user)):
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_archived_chats_by_user_id(user.id)
]
############################
# GetAllChatsInDB
############################
@@ -148,7 +161,7 @@ async def get_archived_session_user_chat_list(
############################
@router.post("/archive/all", response_model=List[ChatTitleIdResponse])
@router.post("/archive/all", response_model=bool)
async def archive_all_chats(user=Depends(get_current_user)):
return Chats.archive_all_chats_by_user_id(user.id)
@@ -288,6 +301,32 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_
return result
############################
# CloneChat
############################
@router.get("/{id}/clone", response_model=Optional[ChatResponse])
async def clone_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
chat_body = json.loads(chat.chat)
updated_chat = {
**chat_body,
"originalChatId": chat.id,
"branchPointMessageId": chat_body["history"]["currentId"],
"title": f"Clone of {chat.title}",
}
chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# ArchiveChat
############################

View File

@@ -73,7 +73,7 @@ async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)):
############################
@router.get("/name/{name}", response_model=Optional[DocumentResponse])
@router.get("/doc", response_model=Optional[DocumentResponse])
async def get_doc_by_name(name: str, user=Depends(get_current_user)):
doc = Documents.get_doc_by_name(name)
@@ -105,7 +105,7 @@ class TagDocumentForm(BaseModel):
tags: List[dict]
@router.post("/name/{name}/tags", response_model=Optional[DocumentResponse])
@router.post("/doc/tags", response_model=Optional[DocumentResponse])
async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_user)):
doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags})
@@ -128,7 +128,7 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_u
############################
@router.post("/name/{name}/update", response_model=Optional[DocumentResponse])
@router.post("/doc/update", response_model=Optional[DocumentResponse])
async def update_doc_by_name(
name: str, form_data: DocumentUpdateForm, user=Depends(get_admin_user)
):
@@ -152,7 +152,7 @@ async def update_doc_by_name(
############################
@router.delete("/name/{name}/delete", response_model=bool)
@router.delete("/doc/delete", response_model=bool)
async def delete_doc_by_name(name: str, user=Depends(get_admin_user)):
result = Documents.delete_doc_by_name(name)
return result

View File

@@ -0,0 +1,183 @@
from fastapi import Depends, FastAPI, HTTPException, status, Request
from datetime import datetime, timedelta
from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
from apps.webui.utils import load_toolkit_module_by_id
from utils.utils import get_current_user, get_admin_user
from utils.tools import get_tools_specs
from constants import ERROR_MESSAGES
from importlib import util
import os
from config import DATA_DIR
TOOLS_DIR = f"{DATA_DIR}/tools"
os.makedirs(TOOLS_DIR, exist_ok=True)
router = APIRouter()
############################
# GetToolkits
############################
@router.get("/", response_model=List[ToolResponse])
async def get_toolkits(user=Depends(get_current_user)):
toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits
############################
# ExportToolKits
############################
@router.get("/export", response_model=List[ToolModel])
async def get_toolkits(user=Depends(get_admin_user)):
toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits
############################
# CreateNewToolKit
############################
@router.post("/create", response_model=Optional[ToolResponse])
async def create_new_toolkit(
request: Request, form_data: ToolForm, user=Depends(get_admin_user)
):
if not form_data.id.isidentifier():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Only alphanumeric characters and underscores are allowed in the id",
)
form_data.id = form_data.id.lower()
toolkit = Tools.get_tool_by_id(form_data.id)
if toolkit == None:
toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
try:
with open(toolkit_path, "w") as tool_file:
tool_file.write(form_data.content)
toolkit_module = load_toolkit_module_by_id(form_data.id)
TOOLS = request.app.state.TOOLS
TOOLS[form_data.id] = toolkit_module
specs = get_tools_specs(TOOLS[form_data.id])
toolkit = Tools.insert_new_tool(user.id, form_data, specs)
if toolkit:
return toolkit
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error creating toolkit"),
)
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ID_TAKEN,
)
############################
# GetToolkitById
############################
@router.get("/id/{id}", response_model=Optional[ToolModel])
async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
return toolkit
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateToolkitById
############################
@router.post("/id/{id}/update", response_model=Optional[ToolModel])
async def update_toolkit_by_id(
request: Request, id: str, form_data: ToolForm, user=Depends(get_admin_user)
):
toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
try:
with open(toolkit_path, "w") as tool_file:
tool_file.write(form_data.content)
toolkit_module = load_toolkit_module_by_id(id)
TOOLS = request.app.state.TOOLS
TOOLS[id] = toolkit_module
specs = get_tools_specs(TOOLS[id])
updated = {
**form_data.model_dump(exclude={"id"}),
"specs": specs,
}
print(updated)
toolkit = Tools.update_tool_by_id(id, updated)
if toolkit:
return toolkit
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error updating toolkit"),
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
############################
# DeleteToolkitById
############################
@router.delete("/id/{id}/delete", response_model=bool)
async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)):
result = Tools.delete_tool_by_id(id)
if result:
TOOLS = request.app.state.TOOLS
if id in TOOLS:
del TOOLS[id]
# delete the toolkit file
toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
os.remove(toolkit_path)
return result

View File

@@ -19,7 +19,12 @@ from apps.webui.models.users import (
from apps.webui.models.auths import Auths
from apps.webui.models.chats import Chats
from utils.utils import get_verified_user, get_password_hash, get_admin_user
from utils.utils import (
get_verified_user,
get_password_hash,
get_current_user,
get_admin_user,
)
from constants import ERROR_MESSAGES
from config import SRC_LOG_LEVELS

View File

@@ -7,6 +7,8 @@ from pydantic import BaseModel
from fpdf import FPDF
import markdown
import black
from apps.webui.internal.db import DB
from utils.utils import get_admin_user
@@ -26,6 +28,21 @@ async def get_gravatar(
return get_gravatar_url(email)
class CodeFormatRequest(BaseModel):
code: str
@router.post("/code/format")
async def format_code(request: CodeFormatRequest):
try:
formatted_code = black.format_str(request.code, mode=black.Mode())
return {"code": formatted_code}
except black.NothingChanged:
return {"code": request.code}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
class MarkdownForm(BaseModel):
md: str
@@ -107,3 +124,12 @@ async def download_db(user=Depends(get_admin_user)):
media_type="application/octet-stream",
filename="webui.db",
)
@router.get("/litellm/config")
async def download_litellm_config_yaml(user=Depends(get_admin_user)):
return FileResponse(
f"{DATA_DIR}/litellm/config.yaml",
media_type="application/octet-stream",
filename="config.yaml",
)

View File

@@ -0,0 +1,23 @@
from importlib import util
import os
from config import TOOLS_DIR
def load_toolkit_module_by_id(toolkit_id):
toolkit_path = os.path.join(TOOLS_DIR, f"{toolkit_id}.py")
spec = util.spec_from_file_location(toolkit_id, toolkit_path)
module = util.module_from_spec(spec)
try:
spec.loader.exec_module(module)
print(f"Loaded module: {module.__name__}")
if hasattr(module, "Tools"):
return module.Tools()
else:
raise Exception("No Tools class found")
except Exception as e:
print(f"Error loading module: {toolkit_id}")
# Move the file to the error folder
os.rename(toolkit_path, f"{toolkit_path}.error")
raise e