Merge remote-tracking branch 'upstream/main' into feature-external-db-reconnect

This commit is contained in:
perf3ct
2024-06-16 09:03:57 -07:00
206 changed files with 21988 additions and 7009 deletions

View File

@@ -17,13 +17,12 @@ from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel
from pydantic import BaseModel
import uuid
import requests
import hashlib
from pathlib import Path
import json
from constants import ERROR_MESSAGES
from utils.utils import (
decode_token,
@@ -41,10 +40,15 @@ from config import (
WHISPER_MODEL_DIR,
WHISPER_MODEL_AUTO_UPDATE,
DEVICE_TYPE,
AUDIO_OPENAI_API_BASE_URL,
AUDIO_OPENAI_API_KEY,
AUDIO_OPENAI_API_MODEL,
AUDIO_OPENAI_API_VOICE,
AUDIO_STT_OPENAI_API_BASE_URL,
AUDIO_STT_OPENAI_API_KEY,
AUDIO_TTS_OPENAI_API_BASE_URL,
AUDIO_TTS_OPENAI_API_KEY,
AUDIO_STT_ENGINE,
AUDIO_STT_MODEL,
AUDIO_TTS_ENGINE,
AUDIO_TTS_MODEL,
AUDIO_TTS_VOICE,
AppConfig,
)
@@ -61,10 +65,17 @@ app.add_middleware(
)
app.state.config = AppConfig()
app.state.config.OPENAI_API_BASE_URL = AUDIO_OPENAI_API_BASE_URL
app.state.config.OPENAI_API_KEY = AUDIO_OPENAI_API_KEY
app.state.config.OPENAI_API_MODEL = AUDIO_OPENAI_API_MODEL
app.state.config.OPENAI_API_VOICE = AUDIO_OPENAI_API_VOICE
app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL
app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY
app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
app.state.config.STT_MODEL = AUDIO_STT_MODEL
app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
app.state.config.TTS_MODEL = AUDIO_TTS_MODEL
app.state.config.TTS_VOICE = AUDIO_TTS_VOICE
# setting device type for whisper model
whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
@@ -74,41 +85,101 @@ SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
class OpenAIConfigUpdateForm(BaseModel):
url: str
key: str
model: str
speaker: str
class TTSConfigForm(BaseModel):
OPENAI_API_BASE_URL: str
OPENAI_API_KEY: str
ENGINE: str
MODEL: str
VOICE: str
class STTConfigForm(BaseModel):
OPENAI_API_BASE_URL: str
OPENAI_API_KEY: str
ENGINE: str
MODEL: str
class AudioConfigUpdateForm(BaseModel):
tts: TTSConfigForm
stt: STTConfigForm
from pydub import AudioSegment
from pydub.utils import mediainfo
def is_mp4_audio(file_path):
"""Check if the given file is an MP4 audio file."""
if not os.path.isfile(file_path):
print(f"File not found: {file_path}")
return False
info = mediainfo(file_path)
if (
info.get("codec_name") == "aac"
and info.get("codec_type") == "audio"
and info.get("codec_tag_string") == "mp4a"
):
return True
return False
def convert_mp4_to_wav(file_path, output_path):
"""Convert MP4 audio file to WAV format."""
audio = AudioSegment.from_file(file_path, format="mp4")
audio.export(output_path, format="wav")
print(f"Converted {file_path} to {output_path}")
@app.get("/config")
async def get_openai_config(user=Depends(get_admin_user)):
async def get_audio_config(user=Depends(get_admin_user)):
return {
"OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
"OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL,
"OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE,
"tts": {
"OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
"ENGINE": app.state.config.TTS_ENGINE,
"MODEL": app.state.config.TTS_MODEL,
"VOICE": app.state.config.TTS_VOICE,
},
"stt": {
"OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
"ENGINE": app.state.config.STT_ENGINE,
"MODEL": app.state.config.STT_MODEL,
},
}
@app.post("/config/update")
async def update_openai_config(
form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user)
async def update_audio_config(
form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
):
if form_data.key == "":
raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
app.state.config.TTS_ENGINE = form_data.tts.ENGINE
app.state.config.TTS_MODEL = form_data.tts.MODEL
app.state.config.TTS_VOICE = form_data.tts.VOICE
app.state.config.OPENAI_API_BASE_URL = form_data.url
app.state.config.OPENAI_API_KEY = form_data.key
app.state.config.OPENAI_API_MODEL = form_data.model
app.state.config.OPENAI_API_VOICE = form_data.speaker
app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
app.state.config.STT_ENGINE = form_data.stt.ENGINE
app.state.config.STT_MODEL = form_data.stt.MODEL
return {
"status": True,
"OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
"OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL,
"OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE,
"tts": {
"OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
"ENGINE": app.state.config.TTS_ENGINE,
"MODEL": app.state.config.TTS_MODEL,
"VOICE": app.state.config.TTS_VOICE,
},
"stt": {
"OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
"OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
"ENGINE": app.state.config.STT_ENGINE,
"MODEL": app.state.config.STT_MODEL,
},
}
@@ -125,13 +196,21 @@ async def speech(request: Request, user=Depends(get_verified_user)):
return FileResponse(file_path)
headers = {}
headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}"
headers["Content-Type"] = "application/json"
try:
body = body.decode("utf-8")
body = json.loads(body)
body["model"] = app.state.config.TTS_MODEL
body = json.dumps(body).encode("utf-8")
except Exception as e:
pass
r = None
try:
r = requests.post(
url=f"{app.state.config.OPENAI_API_BASE_URL}/audio/speech",
url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
data=body,
headers=headers,
stream=True,
@@ -181,41 +260,110 @@ def transcribe(
)
try:
filename = file.filename
file_path = f"{UPLOAD_DIR}/{filename}"
ext = file.filename.split(".")[-1]
id = uuid.uuid4()
filename = f"{id}.{ext}"
file_dir = f"{CACHE_DIR}/audio/transcriptions"
os.makedirs(file_dir, exist_ok=True)
file_path = f"{file_dir}/{filename}"
print(filename)
contents = file.file.read()
with open(file_path, "wb") as f:
f.write(contents)
f.close()
whisper_kwargs = {
"model_size_or_path": WHISPER_MODEL,
"device": whisper_device_type,
"compute_type": "int8",
"download_root": WHISPER_MODEL_DIR,
"local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
}
if app.state.config.STT_ENGINE == "":
whisper_kwargs = {
"model_size_or_path": WHISPER_MODEL,
"device": whisper_device_type,
"compute_type": "int8",
"download_root": WHISPER_MODEL_DIR,
"local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
}
log.debug(f"whisper_kwargs: {whisper_kwargs}")
log.debug(f"whisper_kwargs: {whisper_kwargs}")
try:
model = WhisperModel(**whisper_kwargs)
except:
log.warning(
"WhisperModel initialization failed, attempting download with local_files_only=False"
try:
model = WhisperModel(**whisper_kwargs)
except:
log.warning(
"WhisperModel initialization failed, attempting download with local_files_only=False"
)
whisper_kwargs["local_files_only"] = False
model = WhisperModel(**whisper_kwargs)
segments, info = model.transcribe(file_path, beam_size=5)
log.info(
"Detected language '%s' with probability %f"
% (info.language, info.language_probability)
)
whisper_kwargs["local_files_only"] = False
model = WhisperModel(**whisper_kwargs)
segments, info = model.transcribe(file_path, beam_size=5)
log.info(
"Detected language '%s' with probability %f"
% (info.language, info.language_probability)
)
transcript = "".join([segment.text for segment in list(segments)])
transcript = "".join([segment.text for segment in list(segments)])
data = {"text": transcript.strip()}
return {"text": transcript.strip()}
# save the transcript to a json file
transcript_file = f"{file_dir}/{id}.json"
with open(transcript_file, "w") as f:
json.dump(data, f)
print(data)
return data
elif app.state.config.STT_ENGINE == "openai":
if is_mp4_audio(file_path):
print("is_mp4_audio")
os.rename(file_path, file_path.replace(".wav", ".mp4"))
# Convert MP4 audio file to WAV format
convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}
files = {"file": (filename, open(file_path, "rb"))}
data = {"model": "whisper-1"}
print(files, data)
r = None
try:
r = requests.post(
url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
headers=headers,
files=files,
data=data,
)
r.raise_for_status()
data = r.json()
# save the transcript to a json file
transcript_file = f"{file_dir}/{id}.json"
with open(transcript_file, "w") as f:
json.dump(data, f)
print(data)
return data
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"External: {res['error']['message']}"
except:
error_detail = f"External: {e}"
raise HTTPException(
status_code=r.status_code if r != None else 500,
detail=error_detail,
)
except Exception as e:
log.exception(e)

View File

@@ -29,6 +29,8 @@ import time
from urllib.parse import urlparse
from typing import Optional, List, Union
from starlette.background import BackgroundTask
from apps.webui.models.models import Models
from apps.webui.models.users import Users
from constants import ERROR_MESSAGES
@@ -39,8 +41,6 @@ from utils.utils import (
get_admin_user,
)
from utils.models import get_model_id_from_custom_model_id
from config import (
SRC_LOG_LEVELS,
@@ -75,9 +75,6 @@ app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.MODELS = {}
REQUEST_POOL = []
# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
# Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin,
# least connections, or least response time for better resource utilization and performance optimization.
@@ -132,20 +129,10 @@ async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin
return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
@app.get("/cancel/{request_id}")
async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)):
if user:
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
return True
else:
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
async def fetch_url(url):
timeout = aiohttp.ClientTimeout(total=5)
try:
async with aiohttp.ClientSession(timeout=timeout) as session:
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(url) as response:
return await response.json()
except Exception as e:
@@ -154,6 +141,45 @@ async def fetch_url(url):
return None
async def cleanup_response(
response: Optional[aiohttp.ClientResponse],
session: Optional[aiohttp.ClientSession],
):
if response:
response.close()
if session:
await session.close()
async def post_streaming_url(url: str, payload: str):
r = None
try:
session = aiohttp.ClientSession(trust_env=True)
r = await session.post(url, data=payload)
r.raise_for_status()
return StreamingResponse(
r.content,
status_code=r.status,
headers=dict(r.headers),
background=BackgroundTask(cleanup_response, response=r, session=session),
)
except Exception as e:
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = await r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status if r else 500,
detail=error_detail,
)
def merge_models_lists(model_lists):
merged_models = {}
@@ -246,54 +272,57 @@ async def get_ollama_tags(
@app.get("/api/version")
@app.get("/api/version/{url_idx}")
async def get_ollama_versions(url_idx: Optional[int] = None):
if app.state.config.ENABLE_OLLAMA_API:
if url_idx == None:
if url_idx == None:
# returns lowest version
tasks = [
fetch_url(f"{url}/api/version")
for url in app.state.config.OLLAMA_BASE_URLS
]
responses = await asyncio.gather(*tasks)
responses = list(filter(lambda x: x is not None, responses))
# returns lowest version
tasks = [
fetch_url(f"{url}/api/version") for url in app.state.config.OLLAMA_BASE_URLS
]
responses = await asyncio.gather(*tasks)
responses = list(filter(lambda x: x is not None, responses))
if len(responses) > 0:
lowest_version = min(
responses,
key=lambda x: tuple(
map(int, re.sub(r"^v|-.*", "", x["version"]).split("."))
),
)
if len(responses) > 0:
lowest_version = min(
responses,
key=lambda x: tuple(
map(int, re.sub(r"^v|-.*", "", x["version"]).split("."))
),
)
return {"version": lowest_version["version"]}
return {"version": lowest_version["version"]}
else:
raise HTTPException(
status_code=500,
detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
)
else:
raise HTTPException(
status_code=500,
detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
)
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
r = None
try:
r = requests.request(method="GET", url=f"{url}/api/version")
r.raise_for_status()
return r.json()
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
else:
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
r = None
try:
r = requests.request(method="GET", url=f"{url}/api/version")
r.raise_for_status()
return r.json()
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
return {"version": False}
class ModelNameForm(BaseModel):
@@ -313,65 +342,7 @@ async def pull_model(
# Admin should be able to pull models from any source
payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
def get_request():
nonlocal url
nonlocal r
request_id = str(uuid.uuid4())
try:
REQUEST_POOL.append(request_id)
def stream_content():
try:
yield json.dumps({"id": request_id, "done": False}) + "\n"
for chunk in r.iter_content(chunk_size=8192):
if request_id in REQUEST_POOL:
yield chunk
else:
log.warning("User: canceled request")
break
finally:
if hasattr(r, "close"):
r.close()
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
r = requests.request(
method="POST",
url=f"{url}/api/pull",
data=json.dumps(payload),
stream=True,
)
r.raise_for_status()
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
return await post_streaming_url(f"{url}/api/pull", json.dumps(payload))
class PushModelForm(BaseModel):
@@ -399,50 +370,9 @@ async def push_model(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.debug(f"url: {url}")
r = None
def get_request():
nonlocal url
nonlocal r
try:
def stream_content():
for chunk in r.iter_content(chunk_size=8192):
yield chunk
r = requests.request(
method="POST",
url=f"{url}/api/push",
data=form_data.model_dump_json(exclude_none=True).encode(),
)
r.raise_for_status()
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
return await post_streaming_url(
f"{url}/api/push", form_data.model_dump_json(exclude_none=True).encode()
)
class CreateModelForm(BaseModel):
@@ -461,53 +391,9 @@ async def create_model(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
r = None
def get_request():
nonlocal url
nonlocal r
try:
def stream_content():
for chunk in r.iter_content(chunk_size=8192):
yield chunk
r = requests.request(
method="POST",
url=f"{url}/api/create",
data=form_data.model_dump_json(exclude_none=True).encode(),
stream=True,
)
r.raise_for_status()
log.debug(f"r: {r}")
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
return await post_streaming_url(
f"{url}/api/create", form_data.model_dump_json(exclude_none=True).encode()
)
class CopyModelForm(BaseModel):
@@ -797,66 +683,9 @@ async def generate_completion(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
r = None
def get_request():
nonlocal form_data
nonlocal r
request_id = str(uuid.uuid4())
try:
REQUEST_POOL.append(request_id)
def stream_content():
try:
if form_data.stream:
yield json.dumps({"id": request_id, "done": False}) + "\n"
for chunk in r.iter_content(chunk_size=8192):
if request_id in REQUEST_POOL:
yield chunk
else:
log.warning("User: canceled request")
break
finally:
if hasattr(r, "close"):
r.close()
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
r = requests.request(
method="POST",
url=f"{url}/api/generate",
data=form_data.model_dump_json(exclude_none=True).encode(),
stream=True,
)
r.raise_for_status()
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
return await post_streaming_url(
f"{url}/api/generate", form_data.model_dump_json(exclude_none=True).encode()
)
class ChatMessage(BaseModel):
@@ -897,7 +726,6 @@ async def generate_chat_completion(
model_info = Models.get_model_by_id(model_id)
if model_info:
print(model_info)
if model_info.base_model_id:
payload["model"] = model_info.base_model_id
@@ -906,44 +734,77 @@ async def generate_chat_completion(
if model_info.params:
payload["options"] = {}
payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
payload["options"]["mirostat_eta"] = model_info.params.get(
"mirostat_eta", None
)
payload["options"]["mirostat_tau"] = model_info.params.get(
"mirostat_tau", None
)
payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
if model_info.params.get("mirostat", None):
payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
payload["options"]["repeat_last_n"] = model_info.params.get(
"repeat_last_n", None
)
payload["options"]["repeat_penalty"] = model_info.params.get(
"frequency_penalty", None
)
if model_info.params.get("mirostat_eta", None):
payload["options"]["mirostat_eta"] = model_info.params.get(
"mirostat_eta", None
)
payload["options"]["temperature"] = model_info.params.get(
"temperature", None
)
payload["options"]["seed"] = model_info.params.get("seed", None)
if model_info.params.get("mirostat_tau", None):
payload["options"]["stop"] = (
[
bytes(stop, "utf-8").decode("unicode_escape")
for stop in model_info.params["stop"]
]
if model_info.params.get("stop", None)
else None
)
payload["options"]["mirostat_tau"] = model_info.params.get(
"mirostat_tau", None
)
payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
if model_info.params.get("num_ctx", None):
payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
payload["options"]["num_predict"] = model_info.params.get(
"max_tokens", None
)
payload["options"]["top_k"] = model_info.params.get("top_k", None)
if model_info.params.get("repeat_last_n", None):
payload["options"]["repeat_last_n"] = model_info.params.get(
"repeat_last_n", None
)
payload["options"]["top_p"] = model_info.params.get("top_p", None)
if model_info.params.get("frequency_penalty", None):
payload["options"]["repeat_penalty"] = model_info.params.get(
"frequency_penalty", None
)
if model_info.params.get("temperature", None) is not None:
payload["options"]["temperature"] = model_info.params.get(
"temperature", None
)
if model_info.params.get("seed", None):
payload["options"]["seed"] = model_info.params.get("seed", None)
if model_info.params.get("stop", None):
payload["options"]["stop"] = (
[
bytes(stop, "utf-8").decode("unicode_escape")
for stop in model_info.params["stop"]
]
if model_info.params.get("stop", None)
else None
)
if model_info.params.get("tfs_z", None):
payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
if model_info.params.get("max_tokens", None):
payload["options"]["num_predict"] = model_info.params.get(
"max_tokens", None
)
if model_info.params.get("top_k", None):
payload["options"]["top_k"] = model_info.params.get("top_k", None)
if model_info.params.get("top_p", None):
payload["options"]["top_p"] = model_info.params.get("top_p", None)
if model_info.params.get("use_mmap", None):
payload["options"]["use_mmap"] = model_info.params.get("use_mmap", None)
if model_info.params.get("use_mlock", None):
payload["options"]["use_mlock"] = model_info.params.get(
"use_mlock", None
)
if model_info.params.get("num_thread", None):
payload["options"]["num_thread"] = model_info.params.get(
"num_thread", None
)
if model_info.params.get("system", None):
# Check if the payload already has a system message
@@ -981,73 +842,18 @@ async def generate_chat_completion(
print(payload)
r = None
def get_request():
nonlocal payload
nonlocal r
request_id = str(uuid.uuid4())
try:
REQUEST_POOL.append(request_id)
def stream_content():
try:
if payload.get("stream", None):
yield json.dumps({"id": request_id, "done": False}) + "\n"
for chunk in r.iter_content(chunk_size=8192):
if request_id in REQUEST_POOL:
yield chunk
else:
log.warning("User: canceled request")
break
finally:
if hasattr(r, "close"):
r.close()
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
r = requests.request(
method="POST",
url=f"{url}/api/chat",
data=json.dumps(payload),
stream=True,
)
r.raise_for_status()
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
log.exception(e)
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
return await post_streaming_url(f"{url}/api/chat", json.dumps(payload))
# TODO: we should update this part once Ollama supports other types
class OpenAIChatMessageContent(BaseModel):
type: str
model_config = ConfigDict(extra="allow")
class OpenAIChatMessage(BaseModel):
role: str
content: str
content: Union[str, OpenAIChatMessageContent]
model_config = ConfigDict(extra="allow")
@@ -1075,7 +881,6 @@ async def generate_openai_chat_completion(
model_info = Models.get_model_by_id(model_id)
if model_info:
print(model_info)
if model_info.base_model_id:
payload["model"] = model_info.base_model_id
@@ -1132,68 +937,7 @@ async def generate_openai_chat_completion(
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}")
r = None
def get_request():
nonlocal payload
nonlocal r
request_id = str(uuid.uuid4())
try:
REQUEST_POOL.append(request_id)
def stream_content():
try:
if payload.get("stream"):
yield json.dumps(
{"request_id": request_id, "done": False}
) + "\n"
for chunk in r.iter_content(chunk_size=8192):
if request_id in REQUEST_POOL:
yield chunk
else:
log.warning("User: canceled request")
break
finally:
if hasattr(r, "close"):
r.close()
if request_id in REQUEST_POOL:
REQUEST_POOL.remove(request_id)
r = requests.request(
method="POST",
url=f"{url}/v1/chat/completions",
data=json.dumps(payload),
stream=True,
)
r.raise_for_status()
return StreamingResponse(
stream_content(),
status_code=r.status_code,
headers=dict(r.headers),
)
except Exception as e:
raise e
try:
return await run_in_threadpool(get_request)
except Exception as e:
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
if "error" in res:
error_detail = f"Ollama: {res['error']}"
except:
error_detail = f"Ollama: {e}"
raise HTTPException(
status_code=r.status_code if r else 500,
detail=error_detail,
)
return await post_streaming_url(f"{url}/v1/chat/completions", json.dumps(payload))
@app.get("/v1/models")
@@ -1305,7 +1049,7 @@ async def download_file_stream(
timeout = aiohttp.ClientTimeout(total=600) # Set the timeout
async with aiohttp.ClientSession(timeout=timeout) as session:
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(file_url, headers=headers) as response:
total_size = int(response.headers.get("content-length", 0)) + current_size
@@ -1522,7 +1266,7 @@ async def deprecated_proxy(
if path == "generate":
data = json.loads(body.decode("utf-8"))
if not ("stream" in data and data["stream"] == False):
if data.get("stream", True):
yield json.dumps({"id": request_id, "done": False}) + "\n"
elif path == "chat":

View File

@@ -9,6 +9,7 @@ import json
import logging
from pydantic import BaseModel
from starlette.background import BackgroundTask
from apps.webui.models.models import Models
from apps.webui.models.users import Users
@@ -185,7 +186,7 @@ async def fetch_url(url, key):
timeout = aiohttp.ClientTimeout(total=5)
try:
headers = {"Authorization": f"Bearer {key}"}
async with aiohttp.ClientSession(timeout=timeout) as session:
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
async with session.get(url, headers=headers) as response:
return await response.json()
except Exception as e:
@@ -194,6 +195,16 @@ async def fetch_url(url, key):
return None
async def cleanup_response(
response: Optional[aiohttp.ClientResponse],
session: Optional[aiohttp.ClientSession],
):
if response:
response.close()
if session:
await session.close()
def merge_models_lists(model_lists):
log.debug(f"merge_models_lists {model_lists}")
merged_list = []
@@ -228,6 +239,27 @@ async def get_all_models(raw: bool = False):
) or not app.state.config.ENABLE_OPENAI_API:
models = {"data": []}
else:
# Check if API KEYS length is same than API URLS length
if len(app.state.config.OPENAI_API_KEYS) != len(
app.state.config.OPENAI_API_BASE_URLS
):
# if there are more keys than urls, remove the extra keys
if len(app.state.config.OPENAI_API_KEYS) > len(
app.state.config.OPENAI_API_BASE_URLS
):
app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[
: len(app.state.config.OPENAI_API_BASE_URLS)
]
# if there are more urls than keys, add empty keys
else:
app.state.config.OPENAI_API_KEYS += [
""
for _ in range(
len(app.state.config.OPENAI_API_BASE_URLS)
- len(app.state.config.OPENAI_API_KEYS)
)
]
tasks = [
fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
@@ -313,113 +345,155 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
)
@app.post("/chat/completions")
@app.post("/chat/completions/{url_idx}")
async def generate_chat_completion(
form_data: dict,
url_idx: Optional[int] = None,
user=Depends(get_verified_user),
):
idx = 0
payload = {**form_data}
model_id = form_data.get("model")
model_info = Models.get_model_by_id(model_id)
if model_info:
if model_info.base_model_id:
payload["model"] = model_info.base_model_id
model_info.params = model_info.params.model_dump()
if model_info.params:
if model_info.params.get("temperature", None) is not None:
payload["temperature"] = float(model_info.params.get("temperature"))
if model_info.params.get("top_p", None):
payload["top_p"] = int(model_info.params.get("top_p", None))
if model_info.params.get("max_tokens", None):
payload["max_tokens"] = int(model_info.params.get("max_tokens", None))
if model_info.params.get("frequency_penalty", None):
payload["frequency_penalty"] = int(
model_info.params.get("frequency_penalty", None)
)
if model_info.params.get("seed", None):
payload["seed"] = model_info.params.get("seed", None)
if model_info.params.get("stop", None):
payload["stop"] = (
[
bytes(stop, "utf-8").decode("unicode_escape")
for stop in model_info.params["stop"]
]
if model_info.params.get("stop", None)
else None
)
if model_info.params.get("system", None):
# Check if the payload already has a system message
# If not, add a system message to the payload
if payload.get("messages"):
for message in payload["messages"]:
if message.get("role") == "system":
message["content"] = (
model_info.params.get("system", None) + message["content"]
)
break
else:
payload["messages"].insert(
0,
{
"role": "system",
"content": model_info.params.get("system", None),
},
)
else:
pass
model = app.state.MODELS[payload.get("model")]
idx = model["urlIdx"]
if "pipeline" in model and model.get("pipeline"):
payload["user"] = {"name": user.name, "id": user.id}
# Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
# This is a workaround until OpenAI fixes the issue with this model
if payload.get("model") == "gpt-4-vision-preview":
if "max_tokens" not in payload:
payload["max_tokens"] = 4000
log.debug("Modified payload:", payload)
# Convert the modified body back to JSON
payload = json.dumps(payload)
print(payload)
url = app.state.config.OPENAI_API_BASE_URLS[idx]
key = app.state.config.OPENAI_API_KEYS[idx]
print(payload)
headers = {}
headers["Authorization"] = f"Bearer {key}"
headers["Content-Type"] = "application/json"
r = None
session = None
streaming = False
try:
session = aiohttp.ClientSession(trust_env=True)
r = await session.request(
method="POST",
url=f"{url}/chat/completions",
data=payload,
headers=headers,
)
r.raise_for_status()
# Check if response is SSE
if "text/event-stream" in r.headers.get("Content-Type", ""):
streaming = True
return StreamingResponse(
r.content,
status_code=r.status,
headers=dict(r.headers),
background=BackgroundTask(
cleanup_response, response=r, session=session
),
)
else:
response_data = await r.json()
return response_data
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = await r.json()
print(res)
if "error" in res:
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
except:
error_detail = f"External: {e}"
raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
finally:
if not streaming and session:
if r:
r.close()
await session.close()
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
idx = 0
body = await request.body()
# TODO: Remove below after gpt-4-vision fix from Open AI
# Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)
payload = None
try:
if "chat/completions" in path:
body = body.decode("utf-8")
body = json.loads(body)
payload = {**body}
model_id = body.get("model")
model_info = Models.get_model_by_id(model_id)
if model_info:
print(model_info)
if model_info.base_model_id:
payload["model"] = model_info.base_model_id
model_info.params = model_info.params.model_dump()
if model_info.params:
if model_info.params.get("temperature", None):
payload["temperature"] = int(
model_info.params.get("temperature")
)
if model_info.params.get("top_p", None):
payload["top_p"] = int(model_info.params.get("top_p", None))
if model_info.params.get("max_tokens", None):
payload["max_tokens"] = int(
model_info.params.get("max_tokens", None)
)
if model_info.params.get("frequency_penalty", None):
payload["frequency_penalty"] = int(
model_info.params.get("frequency_penalty", None)
)
if model_info.params.get("seed", None):
payload["seed"] = model_info.params.get("seed", None)
if model_info.params.get("stop", None):
payload["stop"] = (
[
bytes(stop, "utf-8").decode("unicode_escape")
for stop in model_info.params["stop"]
]
if model_info.params.get("stop", None)
else None
)
if model_info.params.get("system", None):
# Check if the payload already has a system message
# If not, add a system message to the payload
if payload.get("messages"):
for message in payload["messages"]:
if message.get("role") == "system":
message["content"] = (
model_info.params.get("system", None)
+ message["content"]
)
break
else:
payload["messages"].insert(
0,
{
"role": "system",
"content": model_info.params.get("system", None),
},
)
else:
pass
model = app.state.MODELS[payload.get("model")]
idx = model["urlIdx"]
if "pipeline" in model and model.get("pipeline"):
payload["user"] = {"name": user.name, "id": user.id}
payload["title"] = (
True
if payload["stream"] == False and payload["max_tokens"] == 50
else False
)
# Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
# This is a workaround until OpenAI fixes the issue with this model
if payload.get("model") == "gpt-4-vision-preview":
if "max_tokens" not in payload:
payload["max_tokens"] = 4000
log.debug("Modified payload:", payload)
# Convert the modified body back to JSON
payload = json.dumps(payload)
except json.JSONDecodeError as e:
log.error("Error loading request body into a dictionary:", e)
print(payload)
url = app.state.config.OPENAI_API_BASE_URLS[idx]
key = app.state.config.OPENAI_API_KEYS[idx]
@@ -431,40 +505,48 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
headers["Content-Type"] = "application/json"
r = None
session = None
streaming = False
try:
r = requests.request(
session = aiohttp.ClientSession(trust_env=True)
r = await session.request(
method=request.method,
url=target_url,
data=payload if payload else body,
data=body,
headers=headers,
stream=True,
)
r.raise_for_status()
# Check if response is SSE
if "text/event-stream" in r.headers.get("Content-Type", ""):
streaming = True
return StreamingResponse(
r.iter_content(chunk_size=8192),
status_code=r.status_code,
r.content,
status_code=r.status,
headers=dict(r.headers),
background=BackgroundTask(
cleanup_response, response=r, session=session
),
)
else:
response_data = r.json()
response_data = await r.json()
return response_data
except Exception as e:
log.exception(e)
error_detail = "Open WebUI: Server Connection Error"
if r is not None:
try:
res = r.json()
res = await r.json()
print(res)
if "error" in res:
error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
except:
error_detail = f"External: {e}"
raise HTTPException(
status_code=r.status_code if r else 500, detail=error_detail
)
raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
finally:
if not streaming and session:
if r:
r.close()
await session.close()

View File

@@ -8,12 +8,15 @@ from fastapi import (
Form,
)
from fastapi.middleware.cors import CORSMiddleware
import requests
import os, shutil, logging, re
from datetime import datetime
from pathlib import Path
from typing import List, Union, Sequence
from typing import List, Union, Sequence, Iterator, Any
from chromadb.utils.batch_utils import create_batches
from langchain_core.documents import Document
from langchain_community.document_loaders import (
WebBaseLoader,
@@ -30,6 +33,7 @@ from langchain_community.document_loaders import (
UnstructuredExcelLoader,
UnstructuredPowerPointLoader,
YoutubeLoader,
OutlookMessageLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -59,9 +63,17 @@ from apps.rag.utils import (
query_doc_with_hybrid_search,
query_collection,
query_collection_with_hybrid_search,
search_web,
)
from apps.rag.search.brave import search_brave
from apps.rag.search.google_pse import search_google_pse
from apps.rag.search.main import SearchResult
from apps.rag.search.searxng import search_searxng
from apps.rag.search.serper import search_serper
from apps.rag.search.serpstack import search_serpstack
from apps.rag.search.serply import search_serply
from apps.rag.search.duckduckgo import search_duckduckgo
from utils.misc import (
calculate_sha256,
calculate_sha256_string,
@@ -71,6 +83,7 @@ from utils.misc import (
from utils.utils import get_current_user, get_admin_user
from config import (
AppConfig,
ENV,
SRC_LOG_LEVELS,
UPLOAD_DIR,
@@ -96,8 +109,19 @@ from config import (
RAG_TEMPLATE,
ENABLE_RAG_LOCAL_WEB_FETCH,
YOUTUBE_LOADER_LANGUAGE,
ENABLE_RAG_WEB_SEARCH,
RAG_WEB_SEARCH_ENGINE,
SEARXNG_QUERY_URL,
GOOGLE_PSE_API_KEY,
GOOGLE_PSE_ENGINE_ID,
BRAVE_SEARCH_API_KEY,
SERPSTACK_API_KEY,
SERPSTACK_HTTPS,
SERPER_API_KEY,
SERPLY_API_KEY,
RAG_WEB_SEARCH_RESULT_COUNT,
RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
AppConfig,
RAG_EMBEDDING_OPENAI_BATCH_SIZE,
)
from constants import ERROR_MESSAGES
@@ -122,6 +146,7 @@ app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SIZE
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
@@ -136,6 +161,21 @@ app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
app.state.YOUTUBE_LOADER_TRANSLATION = None
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
app.state.config.SERPER_API_KEY = SERPER_API_KEY
app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
def update_embedding_model(
embedding_model: str,
update_model: bool = False,
@@ -181,6 +221,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
)
origins = ["*"]
@@ -217,6 +258,7 @@ async def get_status():
"embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
"embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
"reranking_model": app.state.config.RAG_RERANKING_MODEL,
"openai_batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
}
@@ -229,6 +271,7 @@ async def get_embedding_config(user=Depends(get_admin_user)):
"openai_config": {
"url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.config.OPENAI_API_KEY,
"batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
},
}
@@ -244,6 +287,7 @@ async def get_reraanking_config(user=Depends(get_admin_user)):
class OpenAIConfigForm(BaseModel):
url: str
key: str
batch_size: Optional[int] = None
class EmbeddingModelUpdateForm(BaseModel):
@@ -264,9 +308,14 @@ async def update_embedding_config(
app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model
if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
if form_data.openai_config != None:
if form_data.openai_config is not None:
app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
app.state.config.OPENAI_API_KEY = form_data.openai_config.key
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = (
form_data.openai_config.batch_size
if form_data.openai_config.batch_size
else 1
)
update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
@@ -276,6 +325,7 @@ async def update_embedding_config(
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
)
return {
@@ -285,6 +335,7 @@ async def update_embedding_config(
"openai_config": {
"url": app.state.config.OPENAI_API_BASE_URL,
"key": app.state.config.OPENAI_API_KEY,
"batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
},
}
except Exception as e:
@@ -332,11 +383,27 @@ async def get_rag_config(user=Depends(get_admin_user)):
"chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
},
"web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"youtube": {
"language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
},
"web": {
"ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"search": {
"enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
"engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
"searxng_query_url": app.state.config.SEARXNG_QUERY_URL,
"google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY,
"google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID,
"brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY,
"serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
"serpstack_https": app.state.config.SERPSTACK_HTTPS,
"serper_api_key": app.state.config.SERPER_API_KEY,
"serply_api_key": app.state.config.SERPLY_API_KEY,
"result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
"concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
},
},
}
@@ -350,11 +417,31 @@ class YoutubeLoaderConfig(BaseModel):
translation: Optional[str] = None
class WebSearchConfig(BaseModel):
enabled: bool
engine: Optional[str] = None
searxng_query_url: Optional[str] = None
google_pse_api_key: Optional[str] = None
google_pse_engine_id: Optional[str] = None
brave_search_api_key: Optional[str] = None
serpstack_api_key: Optional[str] = None
serpstack_https: Optional[bool] = None
serper_api_key: Optional[str] = None
serply_api_key: Optional[str] = None
result_count: Optional[int] = None
concurrent_requests: Optional[int] = None
class WebConfig(BaseModel):
search: WebSearchConfig
web_loader_ssl_verification: Optional[bool] = None
class ConfigUpdateForm(BaseModel):
pdf_extract_images: Optional[bool] = None
chunk: Optional[ChunkParamUpdateForm] = None
web_loader_ssl_verification: Optional[bool] = None
youtube: Optional[YoutubeLoaderConfig] = None
web: Optional[WebConfig] = None
@app.post("/config/update")
@@ -365,35 +452,37 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
else app.state.config.PDF_EXTRACT_IMAGES
)
app.state.config.CHUNK_SIZE = (
form_data.chunk.chunk_size
if form_data.chunk is not None
else app.state.config.CHUNK_SIZE
)
if form_data.chunk is not None:
app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size
app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap
app.state.config.CHUNK_OVERLAP = (
form_data.chunk.chunk_overlap
if form_data.chunk is not None
else app.state.config.CHUNK_OVERLAP
)
if form_data.youtube is not None:
app.state.config.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language
app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
form_data.web_loader_ssl_verification
if form_data.web_loader_ssl_verification != None
else app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
)
if form_data.web is not None:
app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
form_data.web.web_loader_ssl_verification
)
app.state.config.YOUTUBE_LOADER_LANGUAGE = (
form_data.youtube.language
if form_data.youtube is not None
else app.state.config.YOUTUBE_LOADER_LANGUAGE
)
app.state.YOUTUBE_LOADER_TRANSLATION = (
form_data.youtube.translation
if form_data.youtube is not None
else app.state.YOUTUBE_LOADER_TRANSLATION
)
app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enabled
app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine
app.state.config.SEARXNG_QUERY_URL = form_data.web.search.searxng_query_url
app.state.config.GOOGLE_PSE_API_KEY = form_data.web.search.google_pse_api_key
app.state.config.GOOGLE_PSE_ENGINE_ID = (
form_data.web.search.google_pse_engine_id
)
app.state.config.BRAVE_SEARCH_API_KEY = (
form_data.web.search.brave_search_api_key
)
app.state.config.SERPSTACK_API_KEY = form_data.web.search.serpstack_api_key
app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https
app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key
app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count
app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
form_data.web.search.concurrent_requests
)
return {
"status": True,
@@ -402,11 +491,27 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
"chunk_size": app.state.config.CHUNK_SIZE,
"chunk_overlap": app.state.config.CHUNK_OVERLAP,
},
"web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"youtube": {
"language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
"translation": app.state.YOUTUBE_LOADER_TRANSLATION,
},
"web": {
"ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
"search": {
"enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
"engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
"searxng_query_url": app.state.config.SEARXNG_QUERY_URL,
"google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY,
"google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID,
"brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY,
"serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
"serpstack_https": app.state.config.SERPSTACK_HTTPS,
"serper_api_key": app.state.config.SERPER_API_KEY,
"serply_api_key": app.state.config.SERPLY_API_KEY,
"result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
"concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
},
},
}
@@ -599,7 +704,7 @@ def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True):
# Check if the URL is valid
if not validate_url(url):
raise ValueError(ERROR_MESSAGES.INVALID_URL)
return WebBaseLoader(
return SafeWebBaseLoader(
url,
verify_ssl=verify_ssl,
requests_per_second=RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
@@ -642,17 +747,107 @@ def resolve_hostname(hostname):
return ipv4_addresses, ipv6_addresses
def search_web(engine: str, query: str) -> list[SearchResult]:
"""Search the web using a search engine and return the results as a list of SearchResult objects.
Will look for a search engine API key in environment variables in the following order:
- SEARXNG_QUERY_URL
- GOOGLE_PSE_API_KEY + GOOGLE_PSE_ENGINE_ID
- BRAVE_SEARCH_API_KEY
- SERPSTACK_API_KEY
- SERPER_API_KEY
- SERPLY_API_KEY
Args:
query (str): The query to search for
"""
# TODO: add playwright to search the web
if engine == "searxng":
if app.state.config.SEARXNG_QUERY_URL:
return search_searxng(
app.state.config.SEARXNG_QUERY_URL,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
)
else:
raise Exception("No SEARXNG_QUERY_URL found in environment variables")
elif engine == "google_pse":
if (
app.state.config.GOOGLE_PSE_API_KEY
and app.state.config.GOOGLE_PSE_ENGINE_ID
):
return search_google_pse(
app.state.config.GOOGLE_PSE_API_KEY,
app.state.config.GOOGLE_PSE_ENGINE_ID,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
)
else:
raise Exception(
"No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables"
)
elif engine == "brave":
if app.state.config.BRAVE_SEARCH_API_KEY:
return search_brave(
app.state.config.BRAVE_SEARCH_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
)
else:
raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
elif engine == "serpstack":
if app.state.config.SERPSTACK_API_KEY:
return search_serpstack(
app.state.config.SERPSTACK_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
https_enabled=app.state.config.SERPSTACK_HTTPS,
)
else:
raise Exception("No SERPSTACK_API_KEY found in environment variables")
elif engine == "serper":
if app.state.config.SERPER_API_KEY:
return search_serper(
app.state.config.SERPER_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
)
else:
raise Exception("No SERPER_API_KEY found in environment variables")
elif engine == "serply":
if app.state.config.SERPLY_API_KEY:
return search_serply(
app.state.config.SERPLY_API_KEY,
query,
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
)
else:
raise Exception("No SERPLY_API_KEY found in environment variables")
elif engine == "duckduckgo":
return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT)
else:
raise Exception("No search engine API key found in environment variables")
@app.post("/web/search")
def store_web_search(form_data: SearchForm, user=Depends(get_current_user)):
try:
try:
web_results = search_web(form_data.query)
except Exception as e:
log.exception(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.WEB_SEARCH_ERROR,
)
logging.info(
f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}"
)
web_results = search_web(
app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query
)
except Exception as e:
log.exception(e)
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
)
try:
urls = [result.link for result in web_results]
loader = get_web_loader(urls)
data = loader.load()
@@ -710,6 +905,13 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
texts = [doc.page_content for doc in docs]
metadatas = [doc.metadata for doc in docs]
# ChromaDB does not like datetime formats
# for meta-data so convert them to string.
for metadata in metadatas:
for key, value in metadata.items():
if isinstance(value, datetime):
metadata[key] = str(value)
try:
if overwrite:
for collection in CHROMA_CLIENT.list_collections():
@@ -725,6 +927,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
app.state.sentence_transformer_ef,
app.state.config.OPENAI_API_KEY,
app.state.config.OPENAI_API_BASE_URL,
app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
)
embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
@@ -795,6 +998,7 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
"swift",
"vue",
"svelte",
"msg",
]
if file_ext == "pdf":
@@ -829,6 +1033,8 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
] or file_ext in ["ppt", "pptx"]:
loader = UnstructuredPowerPointLoader(file_path)
elif file_ext == "msg":
loader = OutlookMessageLoader(file_path)
elif file_ext in known_source_ext or (
file_content_type and file_content_type.find("text/") >= 0
):
@@ -994,6 +1200,30 @@ def reset_vector_db(user=Depends(get_admin_user)):
CHROMA_CLIENT.reset()
@app.get("/reset/uploads")
def reset_upload_dir(user=Depends(get_admin_user)) -> bool:
folder = f"{UPLOAD_DIR}"
try:
# Check if the directory exists
if os.path.exists(folder):
# Iterate over all the files and directories in the specified directory
for filename in os.listdir(folder):
file_path = os.path.join(folder, filename)
try:
if os.path.isfile(file_path) or os.path.islink(file_path):
os.unlink(file_path) # Remove the file or link
elif os.path.isdir(file_path):
shutil.rmtree(file_path) # Remove the directory
except Exception as e:
print(f"Failed to delete {file_path}. Reason: {e}")
else:
print(f"The directory {folder} does not exist")
except Exception as e:
print(f"Failed to process the directory {folder}. Reason: {e}")
return True
@app.get("/reset")
def reset(user=Depends(get_admin_user)) -> bool:
folder = f"{UPLOAD_DIR}"
@@ -1015,6 +1245,33 @@ def reset(user=Depends(get_admin_user)) -> bool:
return True
class SafeWebBaseLoader(WebBaseLoader):
"""WebBaseLoader with enhanced error handling for URLs."""
def lazy_load(self) -> Iterator[Document]:
"""Lazy load text from the url(s) in web_path with error handling."""
for path in self.web_paths:
try:
soup = self._scrape(path, bs_kwargs=self.bs_kwargs)
text = soup.get_text(**self.bs_get_text_kwargs)
# Build metadata
metadata = {"source": path}
if title := soup.find("title"):
metadata["title"] = title.get_text()
if description := soup.find("meta", attrs={"name": "description"}):
metadata["description"] = description.get(
"content", "No description found."
)
if html := soup.find("html"):
metadata["language"] = html.get("lang", "No language found.")
yield Document(page_content=text, metadata=metadata)
except Exception as e:
# Log the error and continue with the next URL
log.error(f"Error loading {path}: {e}")
if ENV == "dev":
@app.get("/ef")

View File

@@ -3,13 +3,13 @@ import logging
import requests
from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_brave(api_key: str, query: str) -> list[SearchResult]:
def search_brave(api_key: str, query: str, count: int) -> list[SearchResult]:
"""Search using Brave's Search API and return the results as a list of SearchResult objects.
Args:
@@ -22,7 +22,7 @@ def search_brave(api_key: str, query: str) -> list[SearchResult]:
"Accept-Encoding": "gzip",
"X-Subscription-Token": api_key,
}
params = {"q": query, "count": RAG_WEB_SEARCH_RESULT_COUNT}
params = {"q": query, "count": count}
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
@@ -33,5 +33,5 @@ def search_brave(api_key: str, query: str) -> list[SearchResult]:
SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("snippet")
)
for result in results[:RAG_WEB_SEARCH_RESULT_COUNT]
for result in results[:count]
]

View File

@@ -0,0 +1,46 @@
import logging
from apps.rag.search.main import SearchResult
from duckduckgo_search import DDGS
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_duckduckgo(query: str, count: int) -> list[SearchResult]:
"""
Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects.
Args:
query (str): The query to search for
count (int): The number of results to return
Returns:
List[SearchResult]: A list of search results
"""
# Use the DDGS context manager to create a DDGS object
with DDGS() as ddgs:
# Use the ddgs.text() method to perform the search
ddgs_gen = ddgs.text(
query, safesearch="moderate", max_results=count, backend="api"
)
# Check if there are search results
if ddgs_gen:
# Convert the search results into a list
search_results = [r for r in ddgs_gen]
# Create an empty list to store the SearchResult objects
results = []
# Iterate over each search result
for result in search_results:
# Create a SearchResult object and append it to the results list
results.append(
SearchResult(
link=result["href"],
title=result.get("title"),
snippet=result.get("body"),
)
)
print(results)
# Return the list of search results
return results

View File

@@ -4,14 +4,14 @@ import logging
import requests
from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_google_pse(
api_key: str, search_engine_id: str, query: str
api_key: str, search_engine_id: str, query: str, count: int
) -> list[SearchResult]:
"""Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
@@ -27,7 +27,7 @@ def search_google_pse(
"cx": search_engine_id,
"q": query,
"key": api_key,
"num": RAG_WEB_SEARCH_RESULT_COUNT,
"num": count,
}
response = requests.request("GET", url, headers=headers, params=params)

View File

@@ -1,28 +1,68 @@
import logging
import requests
from typing import List
from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_searxng(query_url: str, query: str) -> list[SearchResult]:
"""Search a SearXNG instance for a query and return the results as a list of SearchResult objects.
def search_searxng(
query_url: str, query: str, count: int, **kwargs
) -> List[SearchResult]:
"""
Search a SearXNG instance for a given query and return the results as a list of SearchResult objects.
The function allows passing additional parameters such as language or time_range to tailor the search result.
Args:
query_url (str): The URL of the SearXNG instance to search. Must contain "<query>" as a placeholder
query (str): The query to search for
"""
url = query_url.replace("<query>", query)
if "&format=json" not in url:
url += "&format=json"
log.debug(f"searching {url}")
query_url (str): The base URL of the SearXNG server.
query (str): The search term or question to find in the SearXNG database.
count (int): The maximum number of results to retrieve from the search.
r = requests.get(
url,
Keyword Args:
language (str): Language filter for the search results; e.g., "en-US". Defaults to an empty string.
safesearch (int): Safe search filter for safer web results; 0 = off, 1 = moderate, 2 = strict. Defaults to 1 (moderate).
time_range (str): Time range for filtering results by date; e.g., "2023-04-05..today" or "all-time". Defaults to ''.
categories: (Optional[List[str]]): Specific categories within which the search should be performed, defaulting to an empty string if not provided.
Returns:
List[SearchResult]: A list of SearchResults sorted by relevance score in descending order.
Raise:
requests.exceptions.RequestException: If a request error occurs during the search process.
"""
# Default values for optional parameters are provided as empty strings or None when not specified.
language = kwargs.get("language", "en-US")
safesearch = kwargs.get("safesearch", "1")
time_range = kwargs.get("time_range", "")
categories = "".join(kwargs.get("categories", []))
params = {
"q": query,
"format": "json",
"pageno": 1,
"safesearch": safesearch,
"language": language,
"time_range": time_range,
"categories": categories,
"theme": "simple",
"image_proxy": 0,
}
# Legacy query format
if "<query>" in query_url:
# Strip all query parameters from the URL
query_url = query_url.split("?")[0]
log.debug(f"searching {query_url}")
response = requests.get(
query_url,
headers={
"User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
"Accept": "text/html",
@@ -30,15 +70,17 @@ def search_searxng(query_url: str, query: str) -> list[SearchResult]:
"Accept-Language": "en-US,en;q=0.5",
"Connection": "keep-alive",
},
params=params,
)
r.raise_for_status()
json_response = r.json()
response.raise_for_status() # Raise an exception for HTTP errors.
json_response = response.json()
results = json_response.get("results", [])
sorted_results = sorted(results, key=lambda x: x.get("score", 0), reverse=True)
return [
SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("content")
)
for result in sorted_results[:RAG_WEB_SEARCH_RESULT_COUNT]
for result in sorted_results[:count]
]

View File

@@ -4,13 +4,13 @@ import logging
import requests
from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_serper(api_key: str, query: str) -> list[SearchResult]:
def search_serper(api_key: str, query: str, count: int) -> list[SearchResult]:
"""Search using serper.dev's API and return the results as a list of SearchResult objects.
Args:
@@ -35,5 +35,5 @@ def search_serper(api_key: str, query: str) -> list[SearchResult]:
title=result.get("title"),
snippet=result.get("description"),
)
for result in results[:RAG_WEB_SEARCH_RESULT_COUNT]
for result in results[:count]
]

View File

@@ -0,0 +1,68 @@
import json
import logging
import requests
from urllib.parse import urlencode
from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_serply(
api_key: str,
query: str,
count: int,
hl: str = "us",
limit: int = 10,
device_type: str = "desktop",
proxy_location: str = "US",
) -> list[SearchResult]:
"""Search using serper.dev's API and return the results as a list of SearchResult objects.
Args:
api_key (str): A serply.io API key
query (str): The query to search for
hl (str): Host Language code to display results in (reference https://developers.google.com/custom-search/docs/xml_results?hl=en#wsInterfaceLanguages)
limit (int): The maximum number of results to return [10-100, defaults to 10]
"""
log.info("Searching with Serply")
url = "https://api.serply.io/v1/search/"
query_payload = {
"q": query,
"language": "en",
"num": limit,
"gl": proxy_location.upper(),
"hl": hl.lower(),
}
url = f"{url}{urlencode(query_payload)}"
headers = {
"X-API-KEY": api_key,
"X-User-Agent": device_type,
"User-Agent": "open-webui",
"X-Proxy-Location": proxy_location,
}
response = requests.request("GET", url, headers=headers)
response.raise_for_status()
json_response = response.json()
log.info(f"results from serply search: {json_response}")
results = sorted(
json_response.get("results", []), key=lambda x: x.get("realPosition", 0)
)
return [
SearchResult(
link=result["link"],
title=result.get("title"),
snippet=result.get("description"),
)
for result in results[:count]
]

View File

@@ -4,14 +4,14 @@ import logging
import requests
from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
def search_serpstack(
api_key: str, query: str, https_enabled: bool = True
api_key: str, query: str, count: int, https_enabled: bool = True
) -> list[SearchResult]:
"""Search using serpstack.com's and return the results as a list of SearchResult objects.
@@ -39,5 +39,5 @@ def search_serpstack(
SearchResult(
link=result["url"], title=result.get("title"), snippet=result.get("snippet")
)
for result in results[:RAG_WEB_SEARCH_RESULT_COUNT]
for result in results[:count]
]

View File

@@ -0,0 +1,206 @@
{
"ads": [],
"ads_count": 0,
"answers": [],
"results": [
{
"title": "Apple",
"link": "https://www.apple.com/",
"description": "Discover the innovative world of Apple and shop everything iPhone, iPad, Apple Watch, Mac, and Apple TV, plus explore accessories, entertainment, ...",
"additional_links": [
{
"text": "AppleApplehttps://www.apple.com",
"href": "https://www.apple.com/"
}
],
"cite": {},
"subdomains": [
{
"title": "Support",
"link": "https://support.apple.com/",
"description": "SupportContact - iPhone Support - Billing and Subscriptions - Apple Repair"
},
{
"title": "Store",
"link": "https://www.apple.com/store",
"description": "StoreShop iPhone - Shop iPad - App Store - Shop Mac - ..."
},
{
"title": "Mac",
"link": "https://www.apple.com/mac/",
"description": "MacMacBook Air - MacBook Pro - iMac - Compare Mac models - Mac mini"
},
{
"title": "iPad",
"link": "https://www.apple.com/ipad/",
"description": "iPadShop iPad - iPad Pro - iPad Air - Compare iPad models - ..."
},
{
"title": "Watch",
"link": "https://www.apple.com/watch/",
"description": "WatchShop Apple Watch - Series 9 - SE - Ultra 2 - Nike - Hermès - ..."
}
],
"realPosition": 1
},
{
"title": "Apple",
"link": "https://www.apple.com/",
"description": "Discover the innovative world of Apple and shop everything iPhone, iPad, Apple Watch, Mac, and Apple TV, plus explore accessories, entertainment, ...",
"additional_links": [
{
"text": "AppleApplehttps://www.apple.com",
"href": "https://www.apple.com/"
}
],
"cite": {},
"realPosition": 2
},
{
"title": "Apple Inc.",
"link": "https://en.wikipedia.org/wiki/Apple_Inc.",
"description": "Apple Inc. (formerly Apple Computer, Inc.) is an American multinational corporation and technology company headquartered in Cupertino, California, ...",
"additional_links": [
{
"text": "Apple Inc.Wikipediahttps://en.wikipedia.org wiki Apple_Inc",
"href": "https://en.wikipedia.org/wiki/Apple_Inc."
},
{
"text": "",
"href": "https://en.wikipedia.org/wiki/Apple_Inc."
},
{
"text": "History",
"href": "https://en.wikipedia.org/wiki/History_of_Apple_Inc."
},
{
"text": "List of Apple products",
"href": "https://en.wikipedia.org/wiki/List_of_Apple_products"
},
{
"text": "Litigation involving Apple Inc.",
"href": "https://en.wikipedia.org/wiki/Litigation_involving_Apple_Inc."
},
{
"text": "Apple Park",
"href": "https://en.wikipedia.org/wiki/Apple_Park"
}
],
"cite": {
"domain": "https://en.wikipedia.org wiki Apple_Inc",
"span": " wiki Apple_Inc"
},
"realPosition": 3
},
{
"title": "Apple Inc. (AAPL) Company Profile & Facts",
"link": "https://finance.yahoo.com/quote/AAPL/profile/",
"description": "Apple Inc. designs, manufactures, and markets smartphones, personal computers, tablets, wearables, and accessories worldwide. The company offers iPhone, a line ...",
"additional_links": [
{
"text": "Apple Inc. (AAPL) Company Profile & FactsYahoo Financehttps://finance.yahoo.com quote AAPL profile",
"href": "https://finance.yahoo.com/quote/AAPL/profile/"
}
],
"cite": {
"domain": "https://finance.yahoo.com quote AAPL profile",
"span": " quote AAPL profile"
},
"realPosition": 4
},
{
"title": "Apple Inc - Company Profile and News",
"link": "https://www.bloomberg.com/profile/company/AAPL:US",
"description": "Apple Inc. Apple Inc. designs, manufactures, and markets smartphones, personal computers, tablets, wearables and accessories, and sells a variety of related ...",
"additional_links": [
{
"text": "Apple Inc - Company Profile and NewsBloomberghttps://www.bloomberg.com company AAPL:US",
"href": "https://www.bloomberg.com/profile/company/AAPL:US"
},
{
"text": "",
"href": "https://www.bloomberg.com/profile/company/AAPL:US"
}
],
"cite": {
"domain": "https://www.bloomberg.com company AAPL:US",
"span": " company AAPL:US"
},
"realPosition": 5
},
{
"title": "Apple Inc. | History, Products, Headquarters, & Facts",
"link": "https://www.britannica.com/money/Apple-Inc",
"description": "May 22, 2024 — Apple Inc. is an American multinational technology company that revolutionized the technology sector through its innovation of computer ...",
"additional_links": [
{
"text": "Apple Inc. | History, Products, Headquarters, & FactsBritannicahttps://www.britannica.com money Apple-Inc",
"href": "https://www.britannica.com/money/Apple-Inc"
},
{
"text": "",
"href": "https://www.britannica.com/money/Apple-Inc"
}
],
"cite": {
"domain": "https://www.britannica.com money Apple-Inc",
"span": " money Apple-Inc"
},
"realPosition": 6
}
],
"shopping_ads": [],
"places": [
{
"title": "Apple Inc."
},
{
"title": "Apple Inc"
},
{
"title": "Apple Inc"
}
],
"related_searches": {
"images": [],
"text": [
{
"title": "apple inc full form",
"link": "https://www.google.com/search?sca_esv=6b6df170a5c9891b&sca_upv=1&q=Apple+Inc+full+form&sa=X&ved=2ahUKEwjLxuSJwM-GAxUHODQIHYuJBhgQ1QJ6BAhPEAE"
},
{
"title": "apple company history",
"link": "https://www.google.com/search?sca_esv=6b6df170a5c9891b&sca_upv=1&q=Apple+company+history&sa=X&ved=2ahUKEwjLxuSJwM-GAxUHODQIHYuJBhgQ1QJ6BAhOEAE"
},
{
"title": "apple store",
"link": "https://www.google.com/search?sca_esv=6b6df170a5c9891b&sca_upv=1&q=Apple+Store&sa=X&ved=2ahUKEwjLxuSJwM-GAxUHODQIHYuJBhgQ1QJ6BAhQEAE"
},
{
"title": "apple id",
"link": "https://www.google.com/search?sca_esv=6b6df170a5c9891b&sca_upv=1&q=Apple+id&sa=X&ved=2ahUKEwjLxuSJwM-GAxUHODQIHYuJBhgQ1QJ6BAhSEAE"
},
{
"title": "apple inc industry",
"link": "https://www.google.com/search?sca_esv=6b6df170a5c9891b&sca_upv=1&q=Apple+Inc+industry&sa=X&ved=2ahUKEwjLxuSJwM-GAxUHODQIHYuJBhgQ1QJ6BAhREAE"
},
{
"title": "apple login",
"link": "https://www.google.com/search?sca_esv=6b6df170a5c9891b&sca_upv=1&q=Apple+login&sa=X&ved=2ahUKEwjLxuSJwM-GAxUHODQIHYuJBhgQ1QJ6BAhTEAE"
}
]
},
"image_results": [],
"carousel": [],
"total": 2450000000,
"knowledge_graph": "",
"related_questions": [
"What does the Apple Inc do?",
"Why did Apple change to Apple Inc?",
"Who owns Apple Inc.?",
"What is Apple Inc best known for?"
],
"carousel_count": 0,
"ts": 2.491065263748169,
"device_type": null
}

View File

@@ -2,7 +2,7 @@ import os
import logging
import requests
from typing import List
from typing import List, Union
from apps.ollama.main import (
generate_ollama_embeddings,
@@ -20,23 +20,8 @@ from langchain.retrievers import (
from typing import Optional
from apps.rag.search.brave import search_brave
from apps.rag.search.google_pse import search_google_pse
from apps.rag.search.main import SearchResult
from apps.rag.search.searxng import search_searxng
from apps.rag.search.serper import search_serper
from apps.rag.search.serpstack import search_serpstack
from config import (
SRC_LOG_LEVELS,
CHROMA_CLIENT,
SEARXNG_QUERY_URL,
GOOGLE_PSE_API_KEY,
GOOGLE_PSE_ENGINE_ID,
BRAVE_SEARCH_API_KEY,
SERPSTACK_API_KEY,
SERPSTACK_HTTPS,
SERPER_API_KEY,
)
from utils.misc import get_last_user_message, add_or_update_system_message
from config import SRC_LOG_LEVELS, CHROMA_CLIENT
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -214,6 +199,7 @@ def get_embedding_function(
embedding_function,
openai_key,
openai_url,
batch_size,
):
if embedding_engine == "":
return lambda query: embedding_function.encode(query).tolist()
@@ -237,17 +223,22 @@ def get_embedding_function(
def generate_multiple(query, f):
if isinstance(query, list):
return [f(q) for q in query]
if embedding_engine == "openai":
embeddings = []
for i in range(0, len(query), batch_size):
embeddings.extend(f(query[i : i + batch_size]))
return embeddings
else:
return [f(q) for q in query]
else:
return f(query)
return lambda query: generate_multiple(query, func)
def rag_messages(
def get_rag_context(
docs,
messages,
template,
embedding_function,
k,
reranking_function,
@@ -255,31 +246,7 @@ def rag_messages(
hybrid_search,
):
log.debug(f"docs: {docs} {messages} {embedding_function} {reranking_function}")
last_user_message_idx = None
for i in range(len(messages) - 1, -1, -1):
if messages[i]["role"] == "user":
last_user_message_idx = i
break
user_message = messages[last_user_message_idx]
if isinstance(user_message["content"], list):
# Handle list content input
content_type = "list"
query = ""
for content_item in user_message["content"]:
if content_item["type"] == "text":
query = content_item["text"]
break
elif isinstance(user_message["content"], str):
# Handle text content input
content_type = "text"
query = user_message["content"]
else:
# Fallback in case the input does not match expected types
content_type = None
query = ""
query = get_last_user_message(messages)
extracted_collections = []
relevant_contexts = []
@@ -350,33 +317,7 @@ def rag_messages(
context_string = context_string.strip()
ra_content = rag_template(
template=template,
context=context_string,
query=query,
)
log.debug(f"ra_content: {ra_content}")
if content_type == "list":
new_content = []
for content_item in user_message["content"]:
if content_item["type"] == "text":
# Update the text item's content with ra_content
new_content.append({"type": "text", "text": ra_content})
else:
# Keep other types of content as they are
new_content.append(content_item)
new_user_message = {**user_message, "content": new_content}
else:
new_user_message = {
**user_message,
"content": ra_content,
}
messages[last_user_message_idx] = new_user_message
return messages, citations
return context_string, citations
def get_model_path(model: str, update_model: bool = False):
@@ -418,8 +359,22 @@ def get_model_path(model: str, update_model: bool = False):
def generate_openai_embeddings(
model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
model: str,
text: Union[str, list[str]],
key: str,
url: str = "https://api.openai.com/v1",
):
if isinstance(text, list):
embeddings = generate_openai_batch_embeddings(model, text, key, url)
else:
embeddings = generate_openai_batch_embeddings(model, [text], key, url)
return embeddings[0] if isinstance(text, str) else embeddings
def generate_openai_batch_embeddings(
model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
) -> Optional[list[list[float]]]:
try:
r = requests.post(
f"{url}/embeddings",
@@ -427,12 +382,12 @@ def generate_openai_embeddings(
"Content-Type": "application/json",
"Authorization": f"Bearer {key}",
},
json={"input": text, "model": model},
json={"input": texts, "model": model},
)
r.raise_for_status()
data = r.json()
if "data" in data:
return data["data"][0]["embedding"]
return [elem["embedding"] for elem in data["data"]]
else:
raise "Something went wrong :/"
except Exception as e:
@@ -536,31 +491,3 @@ class RerankCompressor(BaseDocumentCompressor):
)
final_results.append(doc)
return final_results
def search_web(query: str) -> list[SearchResult]:
"""Search the web using a search engine and return the results as a list of SearchResult objects.
Will look for a search engine API key in environment variables in the following order:
- SEARXNG_QUERY_URL
- GOOGLE_PSE_API_KEY + GOOGLE_PSE_ENGINE_ID
- BRAVE_SEARCH_API_KEY
- SERPSTACK_API_KEY
- SERPER_API_KEY
Args:
query (str): The query to search for
"""
# TODO: add playwright to search the web
if SEARXNG_QUERY_URL:
return search_searxng(SEARXNG_QUERY_URL, query)
elif GOOGLE_PSE_API_KEY and GOOGLE_PSE_ENGINE_ID:
return search_google_pse(GOOGLE_PSE_API_KEY, GOOGLE_PSE_ENGINE_ID, query)
elif BRAVE_SEARCH_API_KEY:
return search_brave(BRAVE_SEARCH_API_KEY, query)
elif SERPSTACK_API_KEY:
return search_serpstack(SERPSTACK_API_KEY, query, https_enabled=SERPSTACK_HTTPS)
elif SERPER_API_KEY:
return search_serper(SERPER_API_KEY, query)
else:
raise Exception("No search engine API key found in environment variables")

139
backend/apps/socket/main.py Normal file
View File

@@ -0,0 +1,139 @@
import socketio
import asyncio
from apps.webui.models.users import Users
from utils.utils import decode_token
sio = socketio.AsyncServer(cors_allowed_origins=[], async_mode="asgi")
app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io")
# Dictionary to maintain the user pool
SESSION_POOL = {}
USER_POOL = {}
USAGE_POOL = {}
# Timeout duration in seconds
TIMEOUT_DURATION = 3
@sio.event
async def connect(sid, environ, auth):
user = None
if auth and "token" in auth:
data = decode_token(auth["token"])
if data is not None and "id" in data:
user = Users.get_user_by_id(data["id"])
if user:
SESSION_POOL[sid] = user.id
if user.id in USER_POOL:
USER_POOL[user.id].append(sid)
else:
USER_POOL[user.id] = [sid]
print(f"user {user.name}({user.id}) connected with session ID {sid}")
await sio.emit("user-count", {"count": len(set(USER_POOL))})
await sio.emit("usage", {"models": get_models_in_use()})
@sio.on("user-join")
async def user_join(sid, data):
print("user-join", sid, data)
auth = data["auth"] if "auth" in data else None
if auth and "token" in auth:
data = decode_token(auth["token"])
if data is not None and "id" in data:
user = Users.get_user_by_id(data["id"])
if user:
SESSION_POOL[sid] = user.id
if user.id in USER_POOL:
USER_POOL[user.id].append(sid)
else:
USER_POOL[user.id] = [sid]
print(f"user {user.name}({user.id}) connected with session ID {sid}")
await sio.emit("user-count", {"count": len(set(USER_POOL))})
@sio.on("user-count")
async def user_count(sid):
await sio.emit("user-count", {"count": len(set(USER_POOL))})
def get_models_in_use():
# Aggregate all models in use
models_in_use = []
for model_id, data in USAGE_POOL.items():
models_in_use.append(model_id)
return models_in_use
@sio.on("usage")
async def usage(sid, data):
model_id = data["model"]
# Cancel previous callback if there is one
if model_id in USAGE_POOL:
USAGE_POOL[model_id]["callback"].cancel()
# Store the new usage data and task
if model_id in USAGE_POOL:
USAGE_POOL[model_id]["sids"].append(sid)
USAGE_POOL[model_id]["sids"] = list(set(USAGE_POOL[model_id]["sids"]))
else:
USAGE_POOL[model_id] = {"sids": [sid]}
# Schedule a task to remove the usage data after TIMEOUT_DURATION
USAGE_POOL[model_id]["callback"] = asyncio.create_task(
remove_after_timeout(sid, model_id)
)
# Broadcast the usage data to all clients
await sio.emit("usage", {"models": get_models_in_use()})
async def remove_after_timeout(sid, model_id):
try:
await asyncio.sleep(TIMEOUT_DURATION)
if model_id in USAGE_POOL:
print(USAGE_POOL[model_id]["sids"])
USAGE_POOL[model_id]["sids"].remove(sid)
USAGE_POOL[model_id]["sids"] = list(set(USAGE_POOL[model_id]["sids"]))
if len(USAGE_POOL[model_id]["sids"]) == 0:
del USAGE_POOL[model_id]
# Broadcast the usage data to all clients
await sio.emit("usage", {"models": get_models_in_use()})
except asyncio.CancelledError:
# Task was cancelled due to new 'usage' event
pass
@sio.event
async def disconnect(sid):
if sid in SESSION_POOL:
user_id = SESSION_POOL[sid]
del SESSION_POOL[sid]
USER_POOL[user_id].remove(sid)
if len(USER_POOL[user_id]) == 0:
del USER_POOL[user_id]
await sio.emit("user-count", {"count": len(USER_POOL)})
else:
print(f"Unknown session ID {sid} disconnected")

View File

@@ -0,0 +1,61 @@
"""Peewee migrations -- 009_add_models.py.
Some examples (model - class or model name)::
> Model = migrator.orm['table_name'] # Return model in current state by name
> Model = migrator.ModelClass # Return model in current state by name
> migrator.sql(sql) # Run custom SQL
> migrator.run(func, *args, **kwargs) # Run python function with the given args
> migrator.create_model(Model) # Create a model (could be used as decorator)
> migrator.remove_model(model, cascade=True) # Remove a model
> migrator.add_fields(model, **fields) # Add fields to a model
> migrator.change_fields(model, **fields) # Change fields
> migrator.remove_fields(model, *field_names, cascade=True)
> migrator.rename_field(model, old_field_name, new_field_name)
> migrator.rename_table(model, new_table_name)
> migrator.add_index(model, *col_names, unique=False)
> migrator.add_not_null(model, *field_names)
> migrator.add_default(model, field_name, default)
> migrator.add_constraint(model, name, sql)
> migrator.drop_index(model, *col_names)
> migrator.drop_not_null(model, *field_names)
> migrator.drop_constraints(model, *constraints)
"""
from contextlib import suppress
import peewee as pw
from peewee_migrate import Migrator
with suppress(ImportError):
import playhouse.postgres_ext as pw_pext
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your migrations here."""
@migrator.create_model
class Tool(pw.Model):
id = pw.TextField(unique=True)
user_id = pw.TextField()
name = pw.TextField()
content = pw.TextField()
specs = pw.TextField()
meta = pw.TextField()
created_at = pw.BigIntegerField(null=False)
updated_at = pw.BigIntegerField(null=False)
class Meta:
table_name = "tool"
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
"""Write your rollback migrations here."""
migrator.remove_model("tool")

View File

@@ -6,6 +6,7 @@ from apps.webui.routers import (
users,
chats,
documents,
tools,
models,
prompts,
configs,
@@ -14,6 +15,8 @@ from apps.webui.routers import (
)
from config import (
WEBUI_BUILD_HASH,
SHOW_ADMIN_DETAILS,
ADMIN_EMAIL,
WEBUI_AUTH,
DEFAULT_MODELS,
DEFAULT_PROMPT_SUGGESTIONS,
@@ -24,8 +27,8 @@ from config import (
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
JWT_EXPIRES_IN,
WEBUI_BANNERS,
AppConfig,
ENABLE_COMMUNITY_SHARING,
AppConfig,
)
app = FastAPI()
@@ -36,6 +39,12 @@ app.state.config = AppConfig()
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
app.state.config.ADMIN_EMAIL = ADMIN_EMAIL
app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
@@ -47,7 +56,7 @@ app.state.config.BANNERS = WEBUI_BANNERS
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
app.state.MODELS = {}
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.TOOLS = {}
app.add_middleware(
@@ -63,6 +72,7 @@ app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"])
app.include_router(documents.router, prefix="/documents", tags=["documents"])
app.include_router(tools.router, prefix="/tools", tags=["tools"])
app.include_router(models.router, prefix="/models", tags=["models"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(memories.router, prefix="/memories", tags=["memories"])

View File

@@ -298,6 +298,15 @@ class ChatTable:
# .limit(limit).offset(skip)
]
def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
return [
ChatModel(**model_to_dict(chat))
for chat in Chat.select()
.where(Chat.archived == True)
.where(Chat.user_id == user_id)
.order_by(Chat.updated_at.desc())
]
def delete_chat_by_id(self, id: str) -> bool:
try:
query = Chat.delete().where((Chat.id == id))

View File

@@ -0,0 +1,132 @@
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time
import logging
from apps.webui.internal.db import DB, JSONField
import json
from config import SRC_LOG_LEVELS
log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])
####################
# Tools DB Schema
####################
class Tool(Model):
id = CharField(unique=True)
user_id = CharField()
name = TextField()
content = TextField()
specs = JSONField()
meta = JSONField()
updated_at = BigIntegerField()
created_at = BigIntegerField()
class Meta:
database = DB
class ToolMeta(BaseModel):
description: Optional[str] = None
class ToolModel(BaseModel):
id: str
user_id: str
name: str
content: str
specs: List[dict]
meta: ToolMeta
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
####################
# Forms
####################
class ToolResponse(BaseModel):
id: str
user_id: str
name: str
meta: ToolMeta
updated_at: int # timestamp in epoch
created_at: int # timestamp in epoch
class ToolForm(BaseModel):
id: str
name: str
content: str
meta: ToolMeta
class ToolsTable:
def __init__(self, db):
self.db = db
self.db.create_tables([Tool])
def insert_new_tool(
self, user_id: str, form_data: ToolForm, specs: List[dict]
) -> Optional[ToolModel]:
tool = ToolModel(
**{
**form_data.model_dump(),
"specs": specs,
"user_id": user_id,
"updated_at": int(time.time()),
"created_at": int(time.time()),
}
)
try:
result = Tool.create(**tool.model_dump())
if result:
return tool
else:
return None
except Exception as e:
print(f"Error creating tool: {e}")
return None
def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
try:
tool = Tool.get(Tool.id == id)
return ToolModel(**model_to_dict(tool))
except:
return None
def get_tools(self) -> List[ToolModel]:
return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()]
def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
try:
query = Tool.update(
**updated,
updated_at=int(time.time()),
).where(Tool.id == id)
query.execute()
tool = Tool.get(Tool.id == id)
return ToolModel(**model_to_dict(tool))
except:
return None
def delete_tool_by_id(self, id: str) -> bool:
try:
query = Tool.delete().where((Tool.id == id))
query.execute() # Remove the rows, return number of rows removed.
return True
except:
return False
Tools = ToolsTable(DB)

View File

@@ -269,73 +269,88 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))
############################
# GetAdminDetails
############################
@router.get("/admin/details")
async def get_admin_details(request: Request, user=Depends(get_current_user)):
if request.app.state.config.SHOW_ADMIN_DETAILS:
admin_email = request.app.state.config.ADMIN_EMAIL
admin_name = None
print(admin_email, admin_name)
if admin_email:
admin = Users.get_user_by_email(admin_email)
if admin:
admin_name = admin.name
else:
admin = Users.get_first_user()
if admin:
admin_email = admin.email
admin_name = admin.name
return {
"name": admin_name,
"email": admin_email,
}
else:
raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)
############################
# ToggleSignUp
############################
@router.get("/signup/enabled", response_model=bool)
async def get_sign_up_status(request: Request, user=Depends(get_admin_user)):
return request.app.state.config.ENABLE_SIGNUP
@router.get("/admin/config")
async def get_admin_config(request: Request, user=Depends(get_admin_user)):
return {
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
"ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
}
@router.get("/signup/enabled/toggle", response_model=bool)
async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
request.app.state.config.ENABLE_SIGNUP = not request.app.state.config.ENABLE_SIGNUP
return request.app.state.config.ENABLE_SIGNUP
class AdminConfig(BaseModel):
SHOW_ADMIN_DETAILS: bool
ENABLE_SIGNUP: bool
DEFAULT_USER_ROLE: str
JWT_EXPIRES_IN: str
ENABLE_COMMUNITY_SHARING: bool
############################
# Default User Role
############################
@router.get("/signup/user/role")
async def get_default_user_role(request: Request, user=Depends(get_admin_user)):
return request.app.state.config.DEFAULT_USER_ROLE
class UpdateRoleForm(BaseModel):
role: str
@router.post("/signup/user/role")
async def update_default_user_role(
request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user)
@router.post("/admin/config")
async def update_admin_config(
request: Request, form_data: AdminConfig, user=Depends(get_admin_user)
):
if form_data.role in ["pending", "user", "admin"]:
request.app.state.config.DEFAULT_USER_ROLE = form_data.role
return request.app.state.config.DEFAULT_USER_ROLE
request.app.state.config.SHOW_ADMIN_DETAILS = form_data.SHOW_ADMIN_DETAILS
request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP
if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE
############################
# JWT Expiration
############################
@router.get("/token/expires")
async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)):
return request.app.state.config.JWT_EXPIRES_IN
class UpdateJWTExpiresDurationForm(BaseModel):
duration: str
@router.post("/token/expires/update")
async def update_token_expires_duration(
request: Request,
form_data: UpdateJWTExpiresDurationForm,
user=Depends(get_admin_user),
):
pattern = r"^(-1|0|(-?\d+(\.\d+)?)(ms|s|m|h|d|w))$"
# Check if the input string matches the pattern
if re.match(pattern, form_data.duration):
request.app.state.config.JWT_EXPIRES_IN = form_data.duration
return request.app.state.config.JWT_EXPIRES_IN
else:
return request.app.state.config.JWT_EXPIRES_IN
if re.match(pattern, form_data.JWT_EXPIRES_IN):
request.app.state.config.JWT_EXPIRES_IN = form_data.JWT_EXPIRES_IN
request.app.state.config.ENABLE_COMMUNITY_SHARING = (
form_data.ENABLE_COMMUNITY_SHARING
)
return {
"SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
"ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
"DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
"JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
"ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
}
############################

View File

@@ -113,6 +113,19 @@ async def get_user_chats(user=Depends(get_current_user)):
]
############################
# GetArchivedChats
############################
@router.get("/all/archived", response_model=List[ChatResponse])
async def get_user_chats(user=Depends(get_current_user)):
return [
ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
for chat in Chats.get_archived_chats_by_user_id(user.id)
]
############################
# GetAllChatsInDB
############################
@@ -148,7 +161,7 @@ async def get_archived_session_user_chat_list(
############################
@router.post("/archive/all", response_model=List[ChatTitleIdResponse])
@router.post("/archive/all", response_model=bool)
async def archive_all_chats(user=Depends(get_current_user)):
return Chats.archive_all_chats_by_user_id(user.id)
@@ -288,6 +301,32 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_
return result
############################
# CloneChat
############################
@router.get("/{id}/clone", response_model=Optional[ChatResponse])
async def clone_chat_by_id(id: str, user=Depends(get_current_user)):
chat = Chats.get_chat_by_id_and_user_id(id, user.id)
if chat:
chat_body = json.loads(chat.chat)
updated_chat = {
**chat_body,
"originalChatId": chat.id,
"branchPointMessageId": chat_body["history"]["currentId"],
"title": f"Clone of {chat.title}",
}
chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))
return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
)
############################
# ArchiveChat
############################

View File

@@ -73,7 +73,7 @@ async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)):
############################
@router.get("/name/{name}", response_model=Optional[DocumentResponse])
@router.get("/doc", response_model=Optional[DocumentResponse])
async def get_doc_by_name(name: str, user=Depends(get_current_user)):
doc = Documents.get_doc_by_name(name)
@@ -105,7 +105,7 @@ class TagDocumentForm(BaseModel):
tags: List[dict]
@router.post("/name/{name}/tags", response_model=Optional[DocumentResponse])
@router.post("/doc/tags", response_model=Optional[DocumentResponse])
async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_user)):
doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags})
@@ -128,7 +128,7 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_u
############################
@router.post("/name/{name}/update", response_model=Optional[DocumentResponse])
@router.post("/doc/update", response_model=Optional[DocumentResponse])
async def update_doc_by_name(
name: str, form_data: DocumentUpdateForm, user=Depends(get_admin_user)
):
@@ -152,7 +152,7 @@ async def update_doc_by_name(
############################
@router.delete("/name/{name}/delete", response_model=bool)
@router.delete("/doc/delete", response_model=bool)
async def delete_doc_by_name(name: str, user=Depends(get_admin_user)):
result = Documents.delete_doc_by_name(name)
return result

View File

@@ -0,0 +1,183 @@
from fastapi import Depends, FastAPI, HTTPException, status, Request
from datetime import datetime, timedelta
from typing import List, Union, Optional
from fastapi import APIRouter
from pydantic import BaseModel
import json
from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
from apps.webui.utils import load_toolkit_module_by_id
from utils.utils import get_current_user, get_admin_user
from utils.tools import get_tools_specs
from constants import ERROR_MESSAGES
from importlib import util
import os
from config import DATA_DIR
TOOLS_DIR = f"{DATA_DIR}/tools"
os.makedirs(TOOLS_DIR, exist_ok=True)
router = APIRouter()
############################
# GetToolkits
############################
@router.get("/", response_model=List[ToolResponse])
async def get_toolkits(user=Depends(get_current_user)):
toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits
############################
# ExportToolKits
############################
@router.get("/export", response_model=List[ToolModel])
async def get_toolkits(user=Depends(get_admin_user)):
toolkits = [toolkit for toolkit in Tools.get_tools()]
return toolkits
############################
# CreateNewToolKit
############################
@router.post("/create", response_model=Optional[ToolResponse])
async def create_new_toolkit(
request: Request, form_data: ToolForm, user=Depends(get_admin_user)
):
if not form_data.id.isidentifier():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Only alphanumeric characters and underscores are allowed in the id",
)
form_data.id = form_data.id.lower()
toolkit = Tools.get_tool_by_id(form_data.id)
if toolkit == None:
toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
try:
with open(toolkit_path, "w") as tool_file:
tool_file.write(form_data.content)
toolkit_module = load_toolkit_module_by_id(form_data.id)
TOOLS = request.app.state.TOOLS
TOOLS[form_data.id] = toolkit_module
specs = get_tools_specs(TOOLS[form_data.id])
toolkit = Tools.insert_new_tool(user.id, form_data, specs)
if toolkit:
return toolkit
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error creating toolkit"),
)
except Exception as e:
print(e)
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.ID_TAKEN,
)
############################
# GetToolkitById
############################
@router.get("/id/{id}", response_model=Optional[ToolModel])
async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
toolkit = Tools.get_tool_by_id(id)
if toolkit:
return toolkit
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=ERROR_MESSAGES.NOT_FOUND,
)
############################
# UpdateToolkitById
############################
@router.post("/id/{id}/update", response_model=Optional[ToolModel])
async def update_toolkit_by_id(
request: Request, id: str, form_data: ToolForm, user=Depends(get_admin_user)
):
toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
try:
with open(toolkit_path, "w") as tool_file:
tool_file.write(form_data.content)
toolkit_module = load_toolkit_module_by_id(id)
TOOLS = request.app.state.TOOLS
TOOLS[id] = toolkit_module
specs = get_tools_specs(TOOLS[id])
updated = {
**form_data.model_dump(exclude={"id"}),
"specs": specs,
}
print(updated)
toolkit = Tools.update_tool_by_id(id, updated)
if toolkit:
return toolkit
else:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT("Error updating toolkit"),
)
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ERROR_MESSAGES.DEFAULT(e),
)
############################
# DeleteToolkitById
############################
@router.delete("/id/{id}/delete", response_model=bool)
async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)):
result = Tools.delete_tool_by_id(id)
if result:
TOOLS = request.app.state.TOOLS
if id in TOOLS:
del TOOLS[id]
# delete the toolkit file
toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
os.remove(toolkit_path)
return result

View File

@@ -19,7 +19,12 @@ from apps.webui.models.users import (
from apps.webui.models.auths import Auths
from apps.webui.models.chats import Chats
from utils.utils import get_verified_user, get_password_hash, get_admin_user
from utils.utils import (
get_verified_user,
get_password_hash,
get_current_user,
get_admin_user,
)
from constants import ERROR_MESSAGES
from config import SRC_LOG_LEVELS

View File

@@ -7,6 +7,8 @@ from pydantic import BaseModel
from fpdf import FPDF
import markdown
import black
from apps.webui.internal.db import DB
from utils.utils import get_admin_user
@@ -26,6 +28,21 @@ async def get_gravatar(
return get_gravatar_url(email)
class CodeFormatRequest(BaseModel):
code: str
@router.post("/code/format")
async def format_code(request: CodeFormatRequest):
try:
formatted_code = black.format_str(request.code, mode=black.Mode())
return {"code": formatted_code}
except black.NothingChanged:
return {"code": request.code}
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
class MarkdownForm(BaseModel):
md: str
@@ -107,3 +124,12 @@ async def download_db(user=Depends(get_admin_user)):
media_type="application/octet-stream",
filename="webui.db",
)
@router.get("/litellm/config")
async def download_litellm_config_yaml(user=Depends(get_admin_user)):
return FileResponse(
f"{DATA_DIR}/litellm/config.yaml",
media_type="application/octet-stream",
filename="config.yaml",
)

View File

@@ -0,0 +1,23 @@
from importlib import util
import os
from config import TOOLS_DIR
def load_toolkit_module_by_id(toolkit_id):
toolkit_path = os.path.join(TOOLS_DIR, f"{toolkit_id}.py")
spec = util.spec_from_file_location(toolkit_id, toolkit_path)
module = util.module_from_spec(spec)
try:
spec.loader.exec_module(module)
print(f"Loaded module: {module.__name__}")
if hasattr(module, "Tools"):
return module.Tools()
else:
raise Exception("No Tools class found")
except Exception as e:
print(f"Error loading module: {toolkit_id}")
# Move the file to the error folder
os.rename(toolkit_path, f"{toolkit_path}.error")
raise e

View File

@@ -180,6 +180,17 @@ WEBUI_BUILD_HASH = os.environ.get("WEBUI_BUILD_HASH", "dev-build")
DATA_DIR = Path(os.getenv("DATA_DIR", BACKEND_DIR / "data")).resolve()
FRONTEND_BUILD_DIR = Path(os.getenv("FRONTEND_BUILD_DIR", BASE_DIR / "build")).resolve()
RESET_CONFIG_ON_START = (
os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
)
if RESET_CONFIG_ON_START:
try:
os.remove(f"{DATA_DIR}/config.json")
with open(f"{DATA_DIR}/config.json", "w") as f:
f.write("{}")
except:
pass
try:
CONFIG_DATA = json.loads((DATA_DIR / "config.json").read_text())
except:
@@ -295,7 +306,11 @@ STATIC_DIR = Path(os.getenv("STATIC_DIR", BACKEND_DIR / "static")).resolve()
frontend_favicon = FRONTEND_BUILD_DIR / "favicon.png"
if frontend_favicon.exists():
shutil.copyfile(frontend_favicon, STATIC_DIR / "favicon.png")
try:
shutil.copyfile(frontend_favicon, STATIC_DIR / "favicon.png")
except Exception as e:
logging.error(f"An error occurred: {e}")
else:
logging.warning(f"Frontend favicon not found at {frontend_favicon}")
@@ -353,6 +368,14 @@ DOCS_DIR = os.getenv("DOCS_DIR", f"{DATA_DIR}/docs")
Path(DOCS_DIR).mkdir(parents=True, exist_ok=True)
####################################
# Tools DIR
####################################
TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
####################################
# LITELLM_CONFIG
####################################
@@ -590,6 +613,92 @@ WEBUI_BANNERS = PersistentConfig(
[BannerModel(**banner) for banner in json.loads("[]")],
)
SHOW_ADMIN_DETAILS = PersistentConfig(
"SHOW_ADMIN_DETAILS",
"auth.admin.show",
os.environ.get("SHOW_ADMIN_DETAILS", "true").lower() == "true",
)
ADMIN_EMAIL = PersistentConfig(
"ADMIN_EMAIL",
"auth.admin.email",
os.environ.get("ADMIN_EMAIL", None),
)
####################################
# TASKS
####################################
TASK_MODEL = PersistentConfig(
"TASK_MODEL",
"task.model.default",
os.environ.get("TASK_MODEL", ""),
)
TASK_MODEL_EXTERNAL = PersistentConfig(
"TASK_MODEL_EXTERNAL",
"task.model.external",
os.environ.get("TASK_MODEL_EXTERNAL", ""),
)
TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
"TITLE_GENERATION_PROMPT_TEMPLATE",
"task.title.prompt_template",
os.environ.get(
"TITLE_GENERATION_PROMPT_TEMPLATE",
"""Here is the query:
{{prompt:middletruncate:8000}}
Create a concise, 3-5 word phrase with an emoji as a title for the previous query. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.
Examples of titles:
📉 Stock Market Trends
🍪 Perfect Chocolate Chip Recipe
Evolution of Music Streaming
Remote Work Productivity Tips
Artificial Intelligence in Healthcare
🎮 Video Game Development Insights""",
),
)
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
"task.search.prompt_template",
os.environ.get(
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
"""You are tasked with generating web search queries. Give me an appropriate query to answer my question for google search. Answer with only the query. Today is {{CURRENT_DATE}}.
Question:
{{prompt:end:4000}}""",
),
)
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig(
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
"task.search.prompt_length_threshold",
int(
os.environ.get(
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
100,
)
),
)
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
"task.tools.prompt_template",
os.environ.get(
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
"""Tools: {{TOOLS}}
If a function tool doesn't match the query, return an empty string. Else, pick a function tool, fill in the parameters from the function tool's schema, and return it in the format { "name": \"functionName\", "parameters": { "key": "value" } }. Only pick a function if the user asks. Only return the object. Do not return any other text.""",
),
)
####################################
# WEBUI_SECRET_KEY
####################################
@@ -672,6 +781,12 @@ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)
RAG_EMBEDDING_OPENAI_BATCH_SIZE = PersistentConfig(
"RAG_EMBEDDING_OPENAI_BATCH_SIZE",
"rag.embedding_openai_batch_size",
os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", 1),
)
RAG_RERANKING_MODEL = PersistentConfig(
"RAG_RERANKING_MODEL",
"rag.reranking_model",
@@ -766,28 +881,81 @@ YOUTUBE_LOADER_LANGUAGE = PersistentConfig(
os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(","),
)
SEARXNG_QUERY_URL = os.getenv("SEARXNG_QUERY_URL", "")
GOOGLE_PSE_API_KEY = os.getenv("GOOGLE_PSE_API_KEY", "")
GOOGLE_PSE_ENGINE_ID = os.getenv("GOOGLE_PSE_ENGINE_ID", "")
BRAVE_SEARCH_API_KEY = os.getenv("BRAVE_SEARCH_API_KEY", "")
SERPSTACK_API_KEY = os.getenv("SERPSTACK_API_KEY", "")
SERPSTACK_HTTPS = os.getenv("SERPSTACK_HTTPS", "True").lower() == "true"
SERPER_API_KEY = os.getenv("SERPER_API_KEY", "")
RAG_WEB_SEARCH_ENABLED = (
SEARXNG_QUERY_URL != ""
or (GOOGLE_PSE_API_KEY != "" and GOOGLE_PSE_ENGINE_ID != "")
or BRAVE_SEARCH_API_KEY != ""
or SERPSTACK_API_KEY != ""
or SERPER_API_KEY != ""
ENABLE_RAG_WEB_SEARCH = PersistentConfig(
"ENABLE_RAG_WEB_SEARCH",
"rag.web.search.enable",
os.getenv("ENABLE_RAG_WEB_SEARCH", "False").lower() == "true",
)
RAG_WEB_SEARCH_RESULT_COUNT = int(os.getenv("RAG_WEB_SEARCH_RESULT_COUNT", "3"))
RAG_WEB_SEARCH_CONCURRENT_REQUESTS = int(
os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")
RAG_WEB_SEARCH_ENGINE = PersistentConfig(
"RAG_WEB_SEARCH_ENGINE",
"rag.web.search.engine",
os.getenv("RAG_WEB_SEARCH_ENGINE", ""),
)
SEARXNG_QUERY_URL = PersistentConfig(
"SEARXNG_QUERY_URL",
"rag.web.search.searxng_query_url",
os.getenv("SEARXNG_QUERY_URL", ""),
)
GOOGLE_PSE_API_KEY = PersistentConfig(
"GOOGLE_PSE_API_KEY",
"rag.web.search.google_pse_api_key",
os.getenv("GOOGLE_PSE_API_KEY", ""),
)
GOOGLE_PSE_ENGINE_ID = PersistentConfig(
"GOOGLE_PSE_ENGINE_ID",
"rag.web.search.google_pse_engine_id",
os.getenv("GOOGLE_PSE_ENGINE_ID", ""),
)
BRAVE_SEARCH_API_KEY = PersistentConfig(
"BRAVE_SEARCH_API_KEY",
"rag.web.search.brave_search_api_key",
os.getenv("BRAVE_SEARCH_API_KEY", ""),
)
SERPSTACK_API_KEY = PersistentConfig(
"SERPSTACK_API_KEY",
"rag.web.search.serpstack_api_key",
os.getenv("SERPSTACK_API_KEY", ""),
)
SERPSTACK_HTTPS = PersistentConfig(
"SERPSTACK_HTTPS",
"rag.web.search.serpstack_https",
os.getenv("SERPSTACK_HTTPS", "True").lower() == "true",
)
SERPER_API_KEY = PersistentConfig(
"SERPER_API_KEY",
"rag.web.search.serper_api_key",
os.getenv("SERPER_API_KEY", ""),
)
SERPLY_API_KEY = PersistentConfig(
"SERPLY_API_KEY",
"rag.web.search.serply_api_key",
os.getenv("SERPLY_API_KEY", ""),
)
RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig(
"RAG_WEB_SEARCH_RESULT_COUNT",
"rag.web.search.result_count",
int(os.getenv("RAG_WEB_SEARCH_RESULT_COUNT", "3")),
)
RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
"RAG_WEB_SEARCH_CONCURRENT_REQUESTS",
"rag.web.search.concurrent_requests",
int(os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
)
####################################
# Transcribe
####################################
@@ -855,25 +1023,59 @@ IMAGE_GENERATION_MODEL = PersistentConfig(
# Audio
####################################
AUDIO_OPENAI_API_BASE_URL = PersistentConfig(
"AUDIO_OPENAI_API_BASE_URL",
"audio.openai.api_base_url",
os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig(
"AUDIO_STT_OPENAI_API_BASE_URL",
"audio.stt.openai.api_base_url",
os.getenv("AUDIO_STT_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
AUDIO_OPENAI_API_KEY = PersistentConfig(
"AUDIO_OPENAI_API_KEY",
"audio.openai.api_key",
os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY),
AUDIO_STT_OPENAI_API_KEY = PersistentConfig(
"AUDIO_STT_OPENAI_API_KEY",
"audio.stt.openai.api_key",
os.getenv("AUDIO_STT_OPENAI_API_KEY", OPENAI_API_KEY),
)
AUDIO_OPENAI_API_MODEL = PersistentConfig(
"AUDIO_OPENAI_API_MODEL",
"audio.openai.api_model",
os.getenv("AUDIO_OPENAI_API_MODEL", "tts-1"),
AUDIO_STT_ENGINE = PersistentConfig(
"AUDIO_STT_ENGINE",
"audio.stt.engine",
os.getenv("AUDIO_STT_ENGINE", ""),
)
AUDIO_OPENAI_API_VOICE = PersistentConfig(
"AUDIO_OPENAI_API_VOICE",
"audio.openai.api_voice",
os.getenv("AUDIO_OPENAI_API_VOICE", "alloy"),
AUDIO_STT_MODEL = PersistentConfig(
"AUDIO_STT_MODEL",
"audio.stt.model",
os.getenv("AUDIO_STT_MODEL", "whisper-1"),
)
AUDIO_TTS_OPENAI_API_BASE_URL = PersistentConfig(
"AUDIO_TTS_OPENAI_API_BASE_URL",
"audio.tts.openai.api_base_url",
os.getenv("AUDIO_TTS_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
AUDIO_TTS_OPENAI_API_KEY = PersistentConfig(
"AUDIO_TTS_OPENAI_API_KEY",
"audio.tts.openai.api_key",
os.getenv("AUDIO_TTS_OPENAI_API_KEY", OPENAI_API_KEY),
)
AUDIO_TTS_ENGINE = PersistentConfig(
"AUDIO_TTS_ENGINE",
"audio.tts.engine",
os.getenv("AUDIO_TTS_ENGINE", ""),
)
AUDIO_TTS_MODEL = PersistentConfig(
"AUDIO_TTS_MODEL",
"audio.tts.model",
os.getenv("AUDIO_TTS_MODEL", "tts-1"),
)
AUDIO_TTS_VOICE = PersistentConfig(
"AUDIO_TTS_VOICE",
"audio.tts.voice",
os.getenv("AUDIO_TTS_VOICE", "alloy"),
)

View File

@@ -32,6 +32,7 @@ class ERROR_MESSAGES(str, Enum):
COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string."
FILE_EXISTS = "Uh-oh! This file is already registered. Please choose another file."
ID_TAKEN = "Uh-oh! This id is already registered. Please choose another id string."
MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. Please choose another model id string."
NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string."
@@ -82,5 +83,9 @@ class ERROR_MESSAGES(str, Enum):
)
WEB_SEARCH_ERROR = (
"Oops! Something went wrong while searching the web. Please try again later."
lambda err="": f"{err if err else 'Oops! Something went wrong while searching the web.'}"
)
OLLAMA_API_DISABLED = (
"The Ollama API is disabled. Please enable it to use this feature."
)

View File

@@ -9,8 +9,12 @@ import logging
import aiohttp
import requests
import mimetypes
import shutil
import os
import inspect
import asyncio
from fastapi import FastAPI, Request, Depends, status
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
from fastapi import HTTPException
@@ -20,26 +24,48 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import StreamingResponse, Response
from apps.ollama.main import app as ollama_app, get_all_models as get_ollama_models
from apps.openai.main import app as openai_app, get_all_models as get_openai_models
from apps.socket.main import app as socket_app
from apps.ollama.main import (
app as ollama_app,
OpenAIChatCompletionForm,
get_all_models as get_ollama_models,
generate_openai_chat_completion as generate_ollama_chat_completion,
)
from apps.openai.main import (
app as openai_app,
get_all_models as get_openai_models,
generate_chat_completion as generate_openai_chat_completion,
)
from apps.audio.main import app as audio_app
from apps.images.main import app as images_app
from apps.rag.main import app as rag_app
from apps.webui.main import app as webui_app
import asyncio
from pydantic import BaseModel
from typing import List, Optional
from apps.webui.models.models import Models, ModelModel
from apps.webui.models.tools import Tools
from apps.webui.utils import load_toolkit_module_by_id
from utils.utils import (
get_admin_user,
get_verified_user,
get_current_user,
get_http_authorization_cred,
)
from apps.rag.utils import rag_messages
from utils.task import (
title_generation_template,
search_query_generation_template,
tools_function_calling_generation_template,
)
from utils.misc import get_last_user_message, add_or_update_system_message
from apps.rag.utils import get_rag_context, rag_template
from config import (
CONFIG_DATA,
@@ -60,9 +86,14 @@ from config import (
SRC_LOG_LEVELS,
WEBHOOK_URL,
ENABLE_ADMIN_EXPORT,
RAG_WEB_SEARCH_ENABLED,
AppConfig,
WEBUI_BUILD_HASH,
TASK_MODEL,
TASK_MODEL_EXTERNAL,
TITLE_GENERATION_PROMPT_TEMPLATE,
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
AppConfig,
)
from constants import ERROR_MESSAGES
@@ -116,27 +147,133 @@ app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
app.state.config.WEBHOOK_URL = WEBHOOK_URL
app.state.config.TASK_MODEL = TASK_MODEL
app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
)
app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
)
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
)
app.state.MODELS = {}
origins = ["*"]
# Custom middleware to add security headers
# class SecurityHeadersMiddleware(BaseHTTPMiddleware):
# async def dispatch(self, request: Request, call_next):
# response: Response = await call_next(request)
# response.headers["Cross-Origin-Opener-Policy"] = "same-origin"
# response.headers["Cross-Origin-Embedder-Policy"] = "require-corp"
# return response
async def get_function_call_response(messages, tool_id, template, task_model_id, user):
tool = Tools.get_tool_by_id(tool_id)
tools_specs = json.dumps(tool.specs, indent=2)
content = tools_function_calling_generation_template(template, tools_specs)
user_message = get_last_user_message(messages)
prompt = (
"History:\n"
+ "\n".join(
[
f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
for message in messages[::-1][:4]
]
)
+ f"\nQuery: {user_message}"
)
print(prompt)
payload = {
"model": task_model_id,
"messages": [
{"role": "system", "content": content},
{"role": "user", "content": f"Query: {prompt}"},
],
"stream": False,
}
try:
payload = filter_pipeline(payload, user)
except Exception as e:
raise e
model = app.state.MODELS[task_model_id]
response = None
try:
if model["owned_by"] == "ollama":
response = await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**payload), user=user
)
else:
response = await generate_openai_chat_completion(payload, user=user)
content = None
if hasattr(response, "body_iterator"):
async for chunk in response.body_iterator:
data = json.loads(chunk.decode("utf-8"))
content = data["choices"][0]["message"]["content"]
# Cleanup any remaining background tasks if necessary
if response.background is not None:
await response.background()
else:
content = response["choices"][0]["message"]["content"]
# Parse the function response
if content is not None:
print(f"content: {content}")
result = json.loads(content)
print(result)
# Call the function
if "name" in result:
if tool_id in webui_app.state.TOOLS:
toolkit_module = webui_app.state.TOOLS[tool_id]
else:
toolkit_module = load_toolkit_module_by_id(tool_id)
webui_app.state.TOOLS[tool_id] = toolkit_module
function = getattr(toolkit_module, result["name"])
function_result = None
try:
# Get the signature of the function
sig = inspect.signature(function)
# Check if '__user__' is a parameter of the function
if "__user__" in sig.parameters:
# Call the function with the '__user__' parameter included
function_result = function(
**{
**result["parameters"],
"__user__": {
"id": user.id,
"email": user.email,
"name": user.name,
"role": user.role,
},
}
)
else:
# Call the function without modifying the parameters
function_result = function(**result["parameters"])
except Exception as e:
print(e)
# Add the function result to the system prompt
if function_result:
return function_result
except Exception as e:
print(f"Error: {e}")
return None
# app.add_middleware(SecurityHeadersMiddleware)
class RAGMiddleware(BaseHTTPMiddleware):
class ChatCompletionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
return_citations = False
@@ -153,35 +290,98 @@ class RAGMiddleware(BaseHTTPMiddleware):
# Parse string to JSON
data = json.loads(body_str) if body_str else {}
user = get_current_user(
get_http_authorization_cred(request.headers.get("Authorization"))
)
# Remove the citations from the body
return_citations = data.get("citations", False)
if "citations" in data:
del data["citations"]
# Example: Add a new key-value pair or modify existing ones
# data["modified"] = True # Example modification
# Set the task model
task_model_id = data["model"]
if task_model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
if (
app.state.config.TASK_MODEL
and app.state.config.TASK_MODEL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL
else:
if (
app.state.config.TASK_MODEL_EXTERNAL
and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
):
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
prompt = get_last_user_message(data["messages"])
context = ""
# If tool_ids field is present, call the functions
if "tool_ids" in data:
print(data["tool_ids"])
for tool_id in data["tool_ids"]:
print(tool_id)
try:
response = await get_function_call_response(
messages=data["messages"],
tool_id=tool_id,
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
task_model_id=task_model_id,
user=user,
)
if response:
context += ("\n" if context != "" else "") + response
except Exception as e:
print(f"Error: {e}")
del data["tool_ids"]
print(f"tool_context: {context}")
# If docs field is present, generate RAG completions
if "docs" in data:
data = {**data}
data["messages"], citations = rag_messages(
rag_context, citations = get_rag_context(
docs=data["docs"],
messages=data["messages"],
template=rag_app.state.config.RAG_TEMPLATE,
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
k=rag_app.state.config.TOP_K,
reranking_function=rag_app.state.sentence_transformer_rf,
r=rag_app.state.config.RELEVANCE_THRESHOLD,
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
)
if rag_context:
context += ("\n" if context != "" else "") + rag_context
del data["docs"]
log.debug(
f"data['messages']: {data['messages']}, citations: {citations}"
log.debug(f"rag_context: {rag_context}, citations: {citations}")
if context != "":
system_prompt = rag_template(
rag_app.state.config.RAG_TEMPLATE, context, prompt
)
print(system_prompt)
data["messages"] = add_or_update_system_message(
f"\n{system_prompt}", data["messages"]
)
modified_body_bytes = json.dumps(data).encode("utf-8")
# Replace the request body with the modified one
request._body = modified_body_bytes
# Set custom header to ensure content-length matches new body length
request.headers.__dict__["_list"] = [
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
@@ -224,7 +424,77 @@ class RAGMiddleware(BaseHTTPMiddleware):
yield data
app.add_middleware(RAGMiddleware)
app.add_middleware(ChatCompletionMiddleware)
def filter_pipeline(payload, user):
user = {"id": user.id, "name": user.name, "role": user.role}
model_id = payload["model"]
filters = [
model
for model in app.state.MODELS.values()
if "pipeline" in model
and "type" in model["pipeline"]
and model["pipeline"]["type"] == "filter"
and (
model["pipeline"]["pipelines"] == ["*"]
or any(
model_id == target_model_id
for target_model_id in model["pipeline"]["pipelines"]
)
)
]
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
model = app.state.MODELS[model_id]
if "pipeline" in model:
sorted_filters.append(model)
for filter in sorted_filters:
r = None
try:
urlIdx = filter["urlIdx"]
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
if key != "":
headers = {"Authorization": f"Bearer {key}"}
r = requests.post(
f"{url}/{filter['id']}/filter/inlet",
headers=headers,
json={
"user": user,
"body": payload,
},
)
r.raise_for_status()
payload = r.json()
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
if r is not None:
try:
res = r.json()
except:
pass
if "detail" in res:
raise Exception(r.status_code, res["detail"])
else:
pass
if "pipeline" not in app.state.MODELS[model_id]:
if "chat_id" in payload:
del payload["chat_id"]
if "title" in payload:
del payload["title"]
return payload
class PipelineMiddleware(BaseHTTPMiddleware):
@@ -242,76 +512,17 @@ class PipelineMiddleware(BaseHTTPMiddleware):
# Parse string to JSON
data = json.loads(body_str) if body_str else {}
model_id = data["model"]
filters = [
model
for model in app.state.MODELS.values()
if "pipeline" in model
and "type" in model["pipeline"]
and model["pipeline"]["type"] == "filter"
and (
model["pipeline"]["pipelines"] == ["*"]
or any(
model_id == target_model_id
for target_model_id in model["pipeline"]["pipelines"]
)
user = get_current_user(
get_http_authorization_cred(request.headers.get("Authorization"))
)
try:
data = filter_pipeline(data, user)
except Exception as e:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
]
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
user = None
if len(sorted_filters) > 0:
try:
user = get_current_user(
get_http_authorization_cred(
request.headers.get("Authorization")
)
)
user = {"id": user.id, "name": user.name, "role": user.role}
except:
pass
for filter in sorted_filters:
r = None
try:
urlIdx = filter["urlIdx"]
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
if key != "":
headers = {"Authorization": f"Bearer {key}"}
r = requests.post(
f"{url}/{filter['id']}/filter/inlet",
headers=headers,
json={
"user": user,
"body": data,
},
)
r.raise_for_status()
data = r.json()
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
if r is not None:
try:
res = r.json()
if "detail" in res:
return JSONResponse(
status_code=r.status_code,
content=res,
)
except:
pass
else:
pass
if "chat_id" in data:
del data["chat_id"]
modified_body_bytes = json.dumps(data).encode("utf-8")
# Replace the request body with the modified one
@@ -368,6 +579,9 @@ async def update_embedding_function(request: Request, call_next):
return response
app.mount("/ws", socket_app)
app.mount("/ollama", ollama_app)
app.mount("/openai", openai_app)
@@ -469,6 +683,237 @@ async def get_models(user=Depends(get_verified_user)):
return {"data": models}
@app.get("/api/task/config")
async def get_task_config(user=Depends(get_verified_user)):
return {
"TASK_MODEL": app.state.config.TASK_MODEL,
"TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
}
class TaskConfigForm(BaseModel):
TASK_MODEL: Optional[str]
TASK_MODEL_EXTERNAL: Optional[str]
TITLE_GENERATION_PROMPT_TEMPLATE: str
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD: int
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
@app.post("/api/task/config/update")
async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_user)):
app.state.config.TASK_MODEL = form_data.TASK_MODEL
app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
form_data.TITLE_GENERATION_PROMPT_TEMPLATE
)
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
)
app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
form_data.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
)
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
)
return {
"TASK_MODEL": app.state.config.TASK_MODEL,
"TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
"TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
"SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
"TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
}
@app.post("/api/task/title/completions")
async def generate_title(form_data: dict, user=Depends(get_verified_user)):
print("generate_title")
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[model_id]["owned_by"] == "ollama":
if app.state.config.TASK_MODEL:
task_model_id = app.state.config.TASK_MODEL
if task_model_id in app.state.MODELS:
model_id = task_model_id
else:
if app.state.config.TASK_MODEL_EXTERNAL:
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if task_model_id in app.state.MODELS:
model_id = task_model_id
print(model_id)
model = app.state.MODELS[model_id]
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
content = title_generation_template(
template, form_data["prompt"], user.model_dump()
)
payload = {
"model": model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
"max_tokens": 50,
"chat_id": form_data.get("chat_id", None),
"title": True,
}
print(payload)
try:
payload = filter_pipeline(payload, user)
except Exception as e:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**payload), user=user
)
else:
return await generate_openai_chat_completion(payload, user=user)
@app.post("/api/task/query/completions")
async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
print("generate_search_query")
if len(form_data["prompt"]) < app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Skip search query generation for short prompts (< {app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD} characters)",
)
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[model_id]["owned_by"] == "ollama":
if app.state.config.TASK_MODEL:
task_model_id = app.state.config.TASK_MODEL
if task_model_id in app.state.MODELS:
model_id = task_model_id
else:
if app.state.config.TASK_MODEL_EXTERNAL:
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if task_model_id in app.state.MODELS:
model_id = task_model_id
print(model_id)
model = app.state.MODELS[model_id]
template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
content = search_query_generation_template(
template, form_data["prompt"], user.model_dump()
)
payload = {
"model": model_id,
"messages": [{"role": "user", "content": content}],
"stream": False,
"max_tokens": 30,
}
print(payload)
try:
payload = filter_pipeline(payload, user)
except Exception as e:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**payload), user=user
)
else:
return await generate_openai_chat_completion(payload, user=user)
@app.post("/api/task/tools/completions")
async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)):
print("get_tools_function_calling")
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
# Check if the user has a custom task model
# If the user has a custom task model, use that model
if app.state.MODELS[model_id]["owned_by"] == "ollama":
if app.state.config.TASK_MODEL:
task_model_id = app.state.config.TASK_MODEL
if task_model_id in app.state.MODELS:
model_id = task_model_id
else:
if app.state.config.TASK_MODEL_EXTERNAL:
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
if task_model_id in app.state.MODELS:
model_id = task_model_id
print(model_id)
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
try:
context = await get_function_call_response(
form_data["messages"], form_data["tool_id"], template, model_id, user
)
return context
except Exception as e:
return JSONResponse(
status_code=e.args[0],
content={"detail": e.args[1]},
)
@app.post("/api/chat/completions")
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
model_id = form_data["model"]
if model_id not in app.state.MODELS:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="Model not found",
)
model = app.state.MODELS[model_id]
print(model)
if model["owned_by"] == "ollama":
return await generate_ollama_chat_completion(
OpenAIChatCompletionForm(**form_data), user=user
)
else:
return await generate_openai_chat_completion(form_data, user=user)
@app.post("/api/chat/completed")
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
data = form_data
@@ -490,6 +935,13 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
]
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
print(model_id)
if model_id in app.state.MODELS:
model = app.state.MODELS[model_id]
if "pipeline" in model:
sorted_filters = [model] + sorted_filters
for filter in sorted_filters:
r = None
try:
@@ -537,7 +989,11 @@ async def get_pipelines_list(user=Depends(get_admin_user)):
responses = await get_openai_models(raw=True)
print(responses)
urlIdxs = [idx for idx, response in enumerate(responses) if "pipelines" in response]
urlIdxs = [
idx
for idx, response in enumerate(responses)
if response != None and "pipelines" in response
]
return {
"data": [
@@ -550,6 +1006,63 @@ async def get_pipelines_list(user=Depends(get_admin_user)):
}
@app.post("/api/pipelines/upload")
async def upload_pipeline(
urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user)
):
print("upload_pipeline", urlIdx, file.filename)
# Check if the uploaded file is a python file
if not file.filename.endswith(".py"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Only Python (.py) files are allowed.",
)
upload_folder = f"{CACHE_DIR}/pipelines"
os.makedirs(upload_folder, exist_ok=True)
file_path = os.path.join(upload_folder, file.filename)
try:
# Save the uploaded file
with open(file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
headers = {"Authorization": f"Bearer {key}"}
with open(file_path, "rb") as f:
files = {"file": f}
r = requests.post(f"{url}/pipelines/upload", headers=headers, files=files)
r.raise_for_status()
data = r.json()
return {**data}
except Exception as e:
# Handle connection error here
print(f"Connection error: {e}")
detail = "Pipeline not found"
if r is not None:
try:
res = r.json()
if "detail" in res:
detail = res["detail"]
except:
pass
raise HTTPException(
status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
detail=detail,
)
finally:
# Ensure the file is deleted after the upload is completed or on failure
if os.path.exists(file_path):
os.remove(file_path)
class AddPipelineForm(BaseModel):
url: str
urlIdx: int
@@ -811,11 +1324,20 @@ async def get_app_config():
"auth": WEBUI_AUTH,
"auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
"enable_signup": webui_app.state.config.ENABLE_SIGNUP,
"enable_web_search": RAG_WEB_SEARCH_ENABLED,
"enable_web_search": rag_app.state.config.ENABLE_RAG_WEB_SEARCH,
"enable_image_generation": images_app.state.config.ENABLED,
"enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING,
"enable_admin_export": ENABLE_ADMIN_EXPORT,
},
"audio": {
"tts": {
"engine": audio_app.state.config.TTS_ENGINE,
"voice": audio_app.state.config.TTS_VOICE,
},
"stt": {
"engine": audio_app.state.config.STT_ENGINE,
},
},
}
@@ -860,23 +1382,7 @@ class UrlForm(BaseModel):
async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
app.state.config.WEBHOOK_URL = form_data.url
webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL
return {
"url": app.state.config.WEBHOOK_URL,
}
@app.get("/api/community_sharing", response_model=bool)
async def get_community_sharing_status(request: Request, user=Depends(get_admin_user)):
return webui_app.state.config.ENABLE_COMMUNITY_SHARING
@app.get("/api/community_sharing/toggle", response_model=bool)
async def toggle_community_sharing(request: Request, user=Depends(get_admin_user)):
webui_app.state.config.ENABLE_COMMUNITY_SHARING = (
not webui_app.state.config.ENABLE_COMMUNITY_SHARING
)
return webui_app.state.config.ENABLE_COMMUNITY_SHARING
return {"url": app.state.config.WEBHOOK_URL}
@app.get("/api/version")
@@ -894,7 +1400,7 @@ async def get_app_changelog():
@app.get("/api/version/updates")
async def get_app_latest_release_version():
try:
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(
"https://api.github.com/repos/open-webui/open-webui/releases/latest"
) as response:

View File

@@ -56,4 +56,8 @@ PyJWT[crypto]==2.8.0
black==24.4.2
langfuse==2.33.0
youtube-transcript-api==0.6.2
pytube==15.0.0
pytube==15.0.0
extract_msg
pydub
duckduckgo-search~=6.1.5

View File

@@ -20,12 +20,12 @@ if test "$WEBUI_SECRET_KEY $WEBUI_JWT_SECRET_KEY" = " "; then
WEBUI_SECRET_KEY=$(cat "$KEY_FILE")
fi
if [ "$USE_OLLAMA_DOCKER" = "true" ]; then
if [[ "${USE_OLLAMA_DOCKER,,}" == "true" ]]; then
echo "USE_OLLAMA is set to true, starting ollama serve."
ollama serve &
fi
if [ "$USE_CUDA_DOCKER" = "true" ]; then
if [[ "${USE_CUDA_DOCKER,,}" == "true" ]]; then
echo "CUDA is enabled, appending LD_LIBRARY_PATH to include torch/cudnn & cublas libraries."
export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/python3.11/site-packages/torch/lib:/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib"
fi

View File

@@ -8,6 +8,7 @@ cd /d "%SCRIPT_DIR%" || exit /b
SET "KEY_FILE=.webui_secret_key"
IF "%PORT%"=="" SET PORT=8080
IF "%HOST%"=="" SET HOST=0.0.0.0
SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%"
SET "WEBUI_JWT_SECRET_KEY=%WEBUI_JWT_SECRET_KEY%"
@@ -29,4 +30,4 @@ IF "%WEBUI_SECRET_KEY%%WEBUI_JWT_SECRET_KEY%" == " " (
:: Execute uvicorn
SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%"
uvicorn main:app --host 0.0.0.0 --port "%PORT%" --forwarded-allow-ips '*'
uvicorn main:app --host "%HOST%" --port "%PORT%" --forwarded-allow-ips '*'

View File

@@ -3,7 +3,48 @@ import hashlib
import json
import re
from datetime import timedelta
from typing import Optional
from typing import Optional, List
def get_last_user_message(messages: List[dict]) -> str:
for message in reversed(messages):
if message["role"] == "user":
if isinstance(message["content"], list):
for item in message["content"]:
if item["type"] == "text":
return item["text"]
return message["content"]
return None
def get_last_assistant_message(messages: List[dict]) -> str:
for message in reversed(messages):
if message["role"] == "assistant":
if isinstance(message["content"], list):
for item in message["content"]:
if item["type"] == "text":
return item["text"]
return message["content"]
return None
def add_or_update_system_message(content: str, messages: List[dict]):
"""
Adds a new system message at the beginning of the messages list
or updates the existing system message at the beginning.
:param msg: The message to be added or appended.
:param messages: The list of message dictionaries.
:return: The updated list of message dictionaries.
"""
if messages and messages[0].get("role") == "system":
messages[0]["content"] += f"{content}\n{messages[0]['content']}"
else:
# Insert at the beginning
messages.insert(0, {"role": "system", "content": content})
return messages
def get_gravatar_url(email):
@@ -123,11 +164,25 @@ def parse_ollama_modelfile(model_text):
"repeat_penalty": float,
"temperature": float,
"seed": int,
"stop": str,
"tfs_z": float,
"num_predict": int,
"top_k": int,
"top_p": float,
"num_keep": int,
"typical_p": float,
"presence_penalty": float,
"frequency_penalty": float,
"penalize_newline": bool,
"numa": bool,
"num_batch": int,
"num_gpu": int,
"main_gpu": int,
"low_vram": bool,
"f16_kv": bool,
"vocab_only": bool,
"use_mmap": bool,
"use_mlock": bool,
"num_thread": int,
}
data = {"base_model_id": None, "params": {}}
@@ -156,10 +211,18 @@ def parse_ollama_modelfile(model_text):
param_match = re.search(rf"PARAMETER {param} (.+)", model_text, re.IGNORECASE)
if param_match:
value = param_match.group(1)
if param_type == int:
value = int(value)
elif param_type == float:
value = float(value)
try:
if param_type == int:
value = int(value)
elif param_type == float:
value = float(value)
elif param_type == bool:
value = value.lower() == "true"
except Exception as e:
print(e)
continue
data["params"][param] = value
# Parse adapter
@@ -171,8 +234,14 @@ def parse_ollama_modelfile(model_text):
system_desc_match = re.search(
r'SYSTEM\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE
)
system_desc_match_single = re.search(
r"SYSTEM\s+([^\n]+)", model_text, re.IGNORECASE
)
if system_desc_match:
data["params"]["system"] = system_desc_match.group(1).strip()
elif system_desc_match_single:
data["params"]["system"] = system_desc_match_single.group(1).strip()
# Parse messages
messages = []

View File

@@ -1,10 +0,0 @@
from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse
def get_model_id_from_custom_model_id(id: str):
model = Models.get_model_by_id(id)
if model:
return model.id
else:
return id

117
backend/utils/task.py Normal file
View File

@@ -0,0 +1,117 @@
import re
import math
from datetime import datetime
from typing import Optional
def prompt_template(
template: str, user_name: str = None, current_location: str = None
) -> str:
# Get the current date
current_date = datetime.now()
# Format the date to YYYY-MM-DD
formatted_date = current_date.strftime("%Y-%m-%d")
# Replace {{CURRENT_DATE}} in the template with the formatted date
template = template.replace("{{CURRENT_DATE}}", formatted_date)
if user_name:
# Replace {{USER_NAME}} in the template with the user's name
template = template.replace("{{USER_NAME}}", user_name)
if current_location:
# Replace {{CURRENT_LOCATION}} in the template with the current location
template = template.replace("{{CURRENT_LOCATION}}", current_location)
return template
def title_generation_template(
template: str, prompt: str, user: Optional[dict] = None
) -> str:
def replacement_function(match):
full_match = match.group(0)
start_length = match.group(1)
end_length = match.group(2)
middle_length = match.group(3)
if full_match == "{{prompt}}":
return prompt
elif start_length is not None:
return prompt[: int(start_length)]
elif end_length is not None:
return prompt[-int(end_length) :]
elif middle_length is not None:
middle_length = int(middle_length)
if len(prompt) <= middle_length:
return prompt
start = prompt[: math.ceil(middle_length / 2)]
end = prompt[-math.floor(middle_length / 2) :]
return f"{start}...{end}"
return ""
template = re.sub(
r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}",
replacement_function,
template,
)
template = prompt_template(
template,
**(
{"user_name": user.get("name"), "current_location": user.get("location")}
if user
else {}
),
)
return template
def search_query_generation_template(
template: str, prompt: str, user: Optional[dict] = None
) -> str:
def replacement_function(match):
full_match = match.group(0)
start_length = match.group(1)
end_length = match.group(2)
middle_length = match.group(3)
if full_match == "{{prompt}}":
return prompt
elif start_length is not None:
return prompt[: int(start_length)]
elif end_length is not None:
return prompt[-int(end_length) :]
elif middle_length is not None:
middle_length = int(middle_length)
if len(prompt) <= middle_length:
return prompt
start = prompt[: math.ceil(middle_length / 2)]
end = prompt[-math.floor(middle_length / 2) :]
return f"{start}...{end}"
return ""
template = re.sub(
r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}",
replacement_function,
template,
)
template = prompt_template(
template,
**(
{"user_name": user.get("name"), "current_location": user.get("location")}
if user
else {}
),
)
return template
def tools_function_calling_generation_template(template: str, tools_specs: str) -> str:
template = template.replace("{{TOOLS}}", tools_specs)
return template

73
backend/utils/tools.py Normal file
View File

@@ -0,0 +1,73 @@
import inspect
from typing import get_type_hints, List, Dict, Any
def doc_to_dict(docstring):
lines = docstring.split("\n")
description = lines[1].strip()
param_dict = {}
for line in lines:
if ":param" in line:
line = line.replace(":param", "").strip()
param, desc = line.split(":", 1)
param_dict[param.strip()] = desc.strip()
ret_dict = {"description": description, "params": param_dict}
return ret_dict
def get_tools_specs(tools) -> List[dict]:
function_list = [
{"name": func, "function": getattr(tools, func)}
for func in dir(tools)
if callable(getattr(tools, func)) and not func.startswith("__")
]
specs = []
for function_item in function_list:
function_name = function_item["name"]
function = function_item["function"]
function_doc = doc_to_dict(function.__doc__ or function_name)
specs.append(
{
"name": function_name,
# TODO: multi-line desc?
"description": function_doc.get("description", function_name),
"parameters": {
"type": "object",
"properties": {
param_name: {
"type": param_annotation.__name__.lower(),
**(
{
"enum": (
str(param_annotation.__args__)
if hasattr(param_annotation, "__args__")
else None
)
}
if hasattr(param_annotation, "__args__")
else {}
),
"description": function_doc.get("params", {}).get(
param_name, param_name
),
}
for param_name, param_annotation in get_type_hints(
function
).items()
if param_name != "return" and param_name != "__user__"
},
"required": [
name
for name, param in inspect.signature(
function
).parameters.items()
if param.default is param.empty
],
},
}
)
return specs