Merge remote-tracking branch 'upstream/main' into feature-external-db-reconnect
@@ -17,13 +17,12 @@ from fastapi.middleware.cors import CORSMiddleware
from faster_whisper import WhisperModel
from pydantic import BaseModel

import uuid
import requests
import hashlib
from pathlib import Path
import json

from constants import ERROR_MESSAGES
from utils.utils import (
    decode_token,
@@ -41,10 +40,15 @@ from config import (
    WHISPER_MODEL_DIR,
    WHISPER_MODEL_AUTO_UPDATE,
    DEVICE_TYPE,
-    AUDIO_OPENAI_API_BASE_URL,
-    AUDIO_OPENAI_API_KEY,
-    AUDIO_OPENAI_API_MODEL,
-    AUDIO_OPENAI_API_VOICE,
+    AUDIO_STT_OPENAI_API_BASE_URL,
+    AUDIO_STT_OPENAI_API_KEY,
+    AUDIO_TTS_OPENAI_API_BASE_URL,
+    AUDIO_TTS_OPENAI_API_KEY,
+    AUDIO_STT_ENGINE,
+    AUDIO_STT_MODEL,
+    AUDIO_TTS_ENGINE,
+    AUDIO_TTS_MODEL,
+    AUDIO_TTS_VOICE,
    AppConfig,
)
@@ -61,10 +65,17 @@ app.add_middleware(
)

app.state.config = AppConfig()
-app.state.config.OPENAI_API_BASE_URL = AUDIO_OPENAI_API_BASE_URL
-app.state.config.OPENAI_API_KEY = AUDIO_OPENAI_API_KEY
-app.state.config.OPENAI_API_MODEL = AUDIO_OPENAI_API_MODEL
-app.state.config.OPENAI_API_VOICE = AUDIO_OPENAI_API_VOICE
+
+app.state.config.STT_OPENAI_API_BASE_URL = AUDIO_STT_OPENAI_API_BASE_URL
+app.state.config.STT_OPENAI_API_KEY = AUDIO_STT_OPENAI_API_KEY
+app.state.config.STT_ENGINE = AUDIO_STT_ENGINE
+app.state.config.STT_MODEL = AUDIO_STT_MODEL
+
+app.state.config.TTS_OPENAI_API_BASE_URL = AUDIO_TTS_OPENAI_API_BASE_URL
+app.state.config.TTS_OPENAI_API_KEY = AUDIO_TTS_OPENAI_API_KEY
+app.state.config.TTS_ENGINE = AUDIO_TTS_ENGINE
+app.state.config.TTS_MODEL = AUDIO_TTS_MODEL
+app.state.config.TTS_VOICE = AUDIO_TTS_VOICE

# setting device type for whisper model
whisper_device_type = DEVICE_TYPE if DEVICE_TYPE and DEVICE_TYPE == "cuda" else "cpu"
@@ -74,41 +85,101 @@ SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)


-class OpenAIConfigUpdateForm(BaseModel):
-    url: str
-    key: str
-    model: str
-    speaker: str
+class TTSConfigForm(BaseModel):
+    OPENAI_API_BASE_URL: str
+    OPENAI_API_KEY: str
+    ENGINE: str
+    MODEL: str
+    VOICE: str
+
+
+class STTConfigForm(BaseModel):
+    OPENAI_API_BASE_URL: str
+    OPENAI_API_KEY: str
+    ENGINE: str
+    MODEL: str
+
+
+class AudioConfigUpdateForm(BaseModel):
+    tts: TTSConfigForm
+    stt: STTConfigForm
+
+
+from pydub import AudioSegment
+from pydub.utils import mediainfo
+
+
+def is_mp4_audio(file_path):
+    """Check if the given file is an MP4 audio file."""
+    if not os.path.isfile(file_path):
+        print(f"File not found: {file_path}")
+        return False
+
+    info = mediainfo(file_path)
+    if (
+        info.get("codec_name") == "aac"
+        and info.get("codec_type") == "audio"
+        and info.get("codec_tag_string") == "mp4a"
+    ):
+        return True
+    return False
+
+
+def convert_mp4_to_wav(file_path, output_path):
+    """Convert MP4 audio file to WAV format."""
+    audio = AudioSegment.from_file(file_path, format="mp4")
+    audio.export(output_path, format="wav")
+    print(f"Converted {file_path} to {output_path}")


@app.get("/config")
-async def get_openai_config(user=Depends(get_admin_user)):
+async def get_audio_config(user=Depends(get_admin_user)):
    return {
-        "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
-        "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
-        "OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL,
-        "OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE,
+        "tts": {
+            "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
+            "ENGINE": app.state.config.TTS_ENGINE,
+            "MODEL": app.state.config.TTS_MODEL,
+            "VOICE": app.state.config.TTS_VOICE,
+        },
+        "stt": {
+            "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
+            "ENGINE": app.state.config.STT_ENGINE,
+            "MODEL": app.state.config.STT_MODEL,
+        },
    }


@app.post("/config/update")
-async def update_openai_config(
-    form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user)
+async def update_audio_config(
+    form_data: AudioConfigUpdateForm, user=Depends(get_admin_user)
):
-    if form_data.key == "":
-        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
+    app.state.config.TTS_OPENAI_API_BASE_URL = form_data.tts.OPENAI_API_BASE_URL
+    app.state.config.TTS_OPENAI_API_KEY = form_data.tts.OPENAI_API_KEY
+    app.state.config.TTS_ENGINE = form_data.tts.ENGINE
+    app.state.config.TTS_MODEL = form_data.tts.MODEL
+    app.state.config.TTS_VOICE = form_data.tts.VOICE

-    app.state.config.OPENAI_API_BASE_URL = form_data.url
-    app.state.config.OPENAI_API_KEY = form_data.key
-    app.state.config.OPENAI_API_MODEL = form_data.model
-    app.state.config.OPENAI_API_VOICE = form_data.speaker
+    app.state.config.STT_OPENAI_API_BASE_URL = form_data.stt.OPENAI_API_BASE_URL
+    app.state.config.STT_OPENAI_API_KEY = form_data.stt.OPENAI_API_KEY
+    app.state.config.STT_ENGINE = form_data.stt.ENGINE
+    app.state.config.STT_MODEL = form_data.stt.MODEL

    return {
        "status": True,
-        "OPENAI_API_BASE_URL": app.state.config.OPENAI_API_BASE_URL,
-        "OPENAI_API_KEY": app.state.config.OPENAI_API_KEY,
-        "OPENAI_API_MODEL": app.state.config.OPENAI_API_MODEL,
-        "OPENAI_API_VOICE": app.state.config.OPENAI_API_VOICE,
+        "tts": {
+            "OPENAI_API_BASE_URL": app.state.config.TTS_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": app.state.config.TTS_OPENAI_API_KEY,
+            "ENGINE": app.state.config.TTS_ENGINE,
+            "MODEL": app.state.config.TTS_MODEL,
+            "VOICE": app.state.config.TTS_VOICE,
+        },
+        "stt": {
+            "OPENAI_API_BASE_URL": app.state.config.STT_OPENAI_API_BASE_URL,
+            "OPENAI_API_KEY": app.state.config.STT_OPENAI_API_KEY,
+            "ENGINE": app.state.config.STT_ENGINE,
+            "MODEL": app.state.config.STT_MODEL,
+        },
    }
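The config endpoints above now exchange a nested tts/stt document instead of the old flat url/key/model/speaker form. A minimal client sketch of the new update call follows; the /audio/api/v1 mount point and the admin token are assumptions for illustration, not part of this diff.

import requests  # illustrative client, not part of the commit

ADMIN_TOKEN = "<admin-jwt>"  # assumption: an admin token for this instance

payload = {
    "tts": {
        "OPENAI_API_BASE_URL": "https://api.openai.com/v1",
        "OPENAI_API_KEY": "sk-...",
        "ENGINE": "openai",
        "MODEL": "tts-1",
        "VOICE": "alloy",
    },
    "stt": {
        "OPENAI_API_BASE_URL": "https://api.openai.com/v1",
        "OPENAI_API_KEY": "sk-...",
        "ENGINE": "openai",
        "MODEL": "whisper-1",
    },
}

# "/audio/api/v1" is an assumption about how the sub-app is mounted.
r = requests.post(
    "http://localhost:8080/audio/api/v1/config/update",
    json=payload,
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
)
print(r.json())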
@@ -125,13 +196,21 @@ async def speech(request: Request, user=Depends(get_verified_user)):
        return FileResponse(file_path)

    headers = {}
-    headers["Authorization"] = f"Bearer {app.state.config.OPENAI_API_KEY}"
+    headers["Authorization"] = f"Bearer {app.state.config.TTS_OPENAI_API_KEY}"
    headers["Content-Type"] = "application/json"

+    try:
+        body = body.decode("utf-8")
+        body = json.loads(body)
+        body["model"] = app.state.config.TTS_MODEL
+        body = json.dumps(body).encode("utf-8")
+    except Exception as e:
+        pass

    r = None
    try:
        r = requests.post(
-            url=f"{app.state.config.OPENAI_API_BASE_URL}/audio/speech",
+            url=f"{app.state.config.TTS_OPENAI_API_BASE_URL}/audio/speech",
            data=body,
            headers=headers,
            stream=True,
@@ -181,41 +260,110 @@ def transcribe(
        )

    try:
        filename = file.filename
-        file_path = f"{UPLOAD_DIR}/{filename}"
+        ext = file.filename.split(".")[-1]
+
+        id = uuid.uuid4()
+        filename = f"{id}.{ext}"
+
+        file_dir = f"{CACHE_DIR}/audio/transcriptions"
+        os.makedirs(file_dir, exist_ok=True)
+        file_path = f"{file_dir}/{filename}"
+
+        print(filename)

        contents = file.file.read()
        with open(file_path, "wb") as f:
            f.write(contents)
            f.close()

-        whisper_kwargs = {
-            "model_size_or_path": WHISPER_MODEL,
-            "device": whisper_device_type,
-            "compute_type": "int8",
-            "download_root": WHISPER_MODEL_DIR,
-            "local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
-        }
+        if app.state.config.STT_ENGINE == "":
+            whisper_kwargs = {
+                "model_size_or_path": WHISPER_MODEL,
+                "device": whisper_device_type,
+                "compute_type": "int8",
+                "download_root": WHISPER_MODEL_DIR,
+                "local_files_only": not WHISPER_MODEL_AUTO_UPDATE,
+            }

-        log.debug(f"whisper_kwargs: {whisper_kwargs}")
+            log.debug(f"whisper_kwargs: {whisper_kwargs}")

-        try:
-            model = WhisperModel(**whisper_kwargs)
-        except:
-            log.warning(
-                "WhisperModel initialization failed, attempting download with local_files_only=False"
-            )
-            whisper_kwargs["local_files_only"] = False
-            model = WhisperModel(**whisper_kwargs)
+            try:
+                model = WhisperModel(**whisper_kwargs)
+            except:
+                log.warning(
+                    "WhisperModel initialization failed, attempting download with local_files_only=False"
+                )
+                whisper_kwargs["local_files_only"] = False
+                model = WhisperModel(**whisper_kwargs)

-        segments, info = model.transcribe(file_path, beam_size=5)
-        log.info(
-            "Detected language '%s' with probability %f"
-            % (info.language, info.language_probability)
-        )
+            segments, info = model.transcribe(file_path, beam_size=5)
+            log.info(
+                "Detected language '%s' with probability %f"
+                % (info.language, info.language_probability)
+            )

-        transcript = "".join([segment.text for segment in list(segments)])
+            transcript = "".join([segment.text for segment in list(segments)])

-        return {"text": transcript.strip()}
+            data = {"text": transcript.strip()}
+
+            # save the transcript to a json file
+            transcript_file = f"{file_dir}/{id}.json"
+            with open(transcript_file, "w") as f:
+                json.dump(data, f)
+
+            print(data)
+
+            return data
+
+        elif app.state.config.STT_ENGINE == "openai":
+            if is_mp4_audio(file_path):
+                print("is_mp4_audio")
+                os.rename(file_path, file_path.replace(".wav", ".mp4"))
+                # Convert MP4 audio file to WAV format
+                convert_mp4_to_wav(file_path.replace(".wav", ".mp4"), file_path)
+
+            headers = {"Authorization": f"Bearer {app.state.config.STT_OPENAI_API_KEY}"}
+
+            files = {"file": (filename, open(file_path, "rb"))}
+            data = {"model": "whisper-1"}
+
+            print(files, data)
+
+            r = None
+            try:
+                r = requests.post(
+                    url=f"{app.state.config.STT_OPENAI_API_BASE_URL}/audio/transcriptions",
+                    headers=headers,
+                    files=files,
+                    data=data,
+                )
+
+                r.raise_for_status()
+
+                data = r.json()
+
+                # save the transcript to a json file
+                transcript_file = f"{file_dir}/{id}.json"
+                with open(transcript_file, "w") as f:
+                    json.dump(data, f)
+
+                print(data)
+                return data
+            except Exception as e:
+                log.exception(e)
+                error_detail = "Open WebUI: Server Connection Error"
+                if r is not None:
+                    try:
+                        res = r.json()
+                        if "error" in res:
+                            error_detail = f"External: {res['error']['message']}"
+                    except:
+                        error_detail = f"External: {e}"
+
+                raise HTTPException(
+                    status_code=r.status_code if r != None else 500,
+                    detail=error_detail,
+                )

    except Exception as e:
        log.exception(e)
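With the reworked transcribe handler, uploads are renamed to a UUID inside CACHE_DIR/audio/transcriptions and the transcript is mirrored to a sibling <id>.json file. A hedged sketch of calling the endpoint from a client; the /audio/api/v1 mount point and the bearer token are assumptions, not shown in this diff:

import requests  # illustrative client, not part of the commit

with open("sample.wav", "rb") as f:
    r = requests.post(
        "http://localhost:8080/audio/api/v1/transcriptions",  # assumed mount point
        headers={"Authorization": "Bearer <token>"},  # placeholder token
        files={"file": ("sample.wav", f, "audio/wav")},
    )
print(r.json())  # e.g. {"text": "..."}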
@@ -29,6 +29,8 @@ import time
from urllib.parse import urlparse
from typing import Optional, List, Union

+from starlette.background import BackgroundTask
+
from apps.webui.models.models import Models
from apps.webui.models.users import Users
from constants import ERROR_MESSAGES
@@ -39,8 +41,6 @@ from utils.utils import (
    get_admin_user,
)

-from utils.models import get_model_id_from_custom_model_id
-
from config import (
    SRC_LOG_LEVELS,
@@ -75,9 +75,6 @@ app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
app.state.MODELS = {}


-REQUEST_POOL = []
-
-
# TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
# Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin,
# least connections, or least response time for better resource utilization and performance optimization.
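As a sketch of the weighted round-robin idea the TODO mentions (illustrative only; the code currently selects a backend with random.choice):

import itertools
import random

def weighted_round_robin(urls, weights):
    # Repeat each URL in proportion to its weight, so a weight-2 backend
    # receives roughly twice the requests of a weight-1 backend.
    pool = [url for url, w in zip(urls, weights) for _ in range(w)]
    random.shuffle(pool)
    return itertools.cycle(pool)

backend = weighted_round_robin(["http://a:11434", "http://b:11434"], [2, 1])
print(next(backend), next(backend), next(backend))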
@@ -132,20 +129,10 @@ async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin
    return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}


-@app.get("/cancel/{request_id}")
-async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)):
-    if user:
-        if request_id in REQUEST_POOL:
-            REQUEST_POOL.remove(request_id)
-        return True
-    else:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
-
-
async def fetch_url(url):
    timeout = aiohttp.ClientTimeout(total=5)
    try:
-        async with aiohttp.ClientSession(timeout=timeout) as session:
+        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.get(url) as response:
                return await response.json()
    except Exception as e:
@@ -154,6 +141,45 @@ async def fetch_url(url):
        return None


+async def cleanup_response(
+    response: Optional[aiohttp.ClientResponse],
+    session: Optional[aiohttp.ClientSession],
+):
+    if response:
+        response.close()
+    if session:
+        await session.close()
+
+
+async def post_streaming_url(url: str, payload: str):
+    r = None
+    try:
+        session = aiohttp.ClientSession(trust_env=True)
+        r = await session.post(url, data=payload)
+        r.raise_for_status()
+
+        return StreamingResponse(
+            r.content,
+            status_code=r.status,
+            headers=dict(r.headers),
+            background=BackgroundTask(cleanup_response, response=r, session=session),
+        )
+    except Exception as e:
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = await r.json()
+                if "error" in res:
+                    error_detail = f"Ollama: {res['error']}"
+            except:
+                error_detail = f"Ollama: {e}"
+
+        raise HTTPException(
+            status_code=r.status if r else 500,
+            detail=error_detail,
+        )
+
+
def merge_models_lists(model_lists):
    merged_models = {}
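post_streaming_url ties the aiohttp session's lifetime to the response: Starlette invokes the BackgroundTask only after the last chunk of r.content has been sent, so the session stays open for the whole stream and is closed afterwards. A self-contained sketch of the same pattern (the upstream URL is illustrative):

import aiohttp
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from starlette.background import BackgroundTask

app = FastAPI()

async def close_upstream(response: aiohttp.ClientResponse, session: aiohttp.ClientSession):
    # Runs after the streamed body has been fully delivered to the client.
    response.close()
    await session.close()

@app.get("/relay")
async def relay():
    session = aiohttp.ClientSession(trust_env=True)  # trust_env also honors HTTP(S)_PROXY
    r = await session.get("http://localhost:11434/api/tags")  # illustrative upstream
    return StreamingResponse(
        r.content,
        status_code=r.status,
        background=BackgroundTask(close_upstream, response=r, session=session),
    )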
@@ -246,54 +272,57 @@ async def get_ollama_tags(
@app.get("/api/version")
@app.get("/api/version/{url_idx}")
async def get_ollama_versions(url_idx: Optional[int] = None):
-    if url_idx == None:
-        # returns lowest version
-        tasks = [
-            fetch_url(f"{url}/api/version") for url in app.state.config.OLLAMA_BASE_URLS
-        ]
-        responses = await asyncio.gather(*tasks)
-        responses = list(filter(lambda x: x is not None, responses))
-
-        if len(responses) > 0:
-            lowest_version = min(
-                responses,
-                key=lambda x: tuple(
-                    map(int, re.sub(r"^v|-.*", "", x["version"]).split("."))
-                ),
-            )
-
-            return {"version": lowest_version["version"]}
-        else:
-            raise HTTPException(
-                status_code=500,
-                detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
-            )
-    else:
-        url = app.state.config.OLLAMA_BASE_URLS[url_idx]
-
-        r = None
-        try:
-            r = requests.request(method="GET", url=f"{url}/api/version")
-            r.raise_for_status()
-
-            return r.json()
-        except Exception as e:
-            log.exception(e)
-            error_detail = "Open WebUI: Server Connection Error"
-            if r is not None:
-                try:
-                    res = r.json()
-                    if "error" in res:
-                        error_detail = f"Ollama: {res['error']}"
-                except:
-                    error_detail = f"Ollama: {e}"
-
-            raise HTTPException(
-                status_code=r.status_code if r else 500,
-                detail=error_detail,
-            )
+    if app.state.config.ENABLE_OLLAMA_API:
+        if url_idx == None:
+            # returns lowest version
+            tasks = [
+                fetch_url(f"{url}/api/version")
+                for url in app.state.config.OLLAMA_BASE_URLS
+            ]
+            responses = await asyncio.gather(*tasks)
+            responses = list(filter(lambda x: x is not None, responses))
+
+            if len(responses) > 0:
+                lowest_version = min(
+                    responses,
+                    key=lambda x: tuple(
+                        map(int, re.sub(r"^v|-.*", "", x["version"]).split("."))
+                    ),
+                )
+
+                return {"version": lowest_version["version"]}
+            else:
+                raise HTTPException(
+                    status_code=500,
+                    detail=ERROR_MESSAGES.OLLAMA_NOT_FOUND,
+                )
+        else:
+            url = app.state.config.OLLAMA_BASE_URLS[url_idx]
+
+            r = None
+            try:
+                r = requests.request(method="GET", url=f"{url}/api/version")
+                r.raise_for_status()
+
+                return r.json()
+            except Exception as e:
+                log.exception(e)
+                error_detail = "Open WebUI: Server Connection Error"
+                if r is not None:
+                    try:
+                        res = r.json()
+                        if "error" in res:
+                            error_detail = f"Ollama: {res['error']}"
+                    except:
+                        error_detail = f"Ollama: {e}"
+
+                raise HTTPException(
+                    status_code=r.status_code if r else 500,
+                    detail=error_detail,
+                )
+    else:
+        return {"version": False}


class ModelNameForm(BaseModel):
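The min() key in the hunk above normalizes version strings before comparing them numerically: re.sub(r"^v|-.*", "", v) strips a leading "v" and any "-suffix", and the tuple of ints compares component-wise instead of lexically. A quick standalone check:

import re

def version_key(v: str) -> tuple:
    # "v0.1.32-rc1" -> "0.1.32" -> (0, 1, 32)
    return tuple(map(int, re.sub(r"^v|-.*", "", v).split(".")))

versions = ["v0.1.32", "0.1.9", "0.1.32-rc1"]
print(min(versions, key=version_key))  # -> "0.1.9", not "0.1.32"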
@@ -313,65 +342,7 @@ async def pull_model(
    # Admin should be able to pull models from any source
    payload = {**form_data.model_dump(exclude_none=True), "insecure": True}

-    def get_request():
-        nonlocal url
-        nonlocal r
-
-        request_id = str(uuid.uuid4())
-        try:
-            REQUEST_POOL.append(request_id)
-
-            def stream_content():
-                try:
-                    yield json.dumps({"id": request_id, "done": False}) + "\n"
-
-                    for chunk in r.iter_content(chunk_size=8192):
-                        if request_id in REQUEST_POOL:
-                            yield chunk
-                        else:
-                            log.warning("User: canceled request")
-                            break
-                finally:
-                    if hasattr(r, "close"):
-                        r.close()
-                    if request_id in REQUEST_POOL:
-                        REQUEST_POOL.remove(request_id)
-
-            r = requests.request(
-                method="POST",
-                url=f"{url}/api/pull",
-                data=json.dumps(payload),
-                stream=True,
-            )
-
-            r.raise_for_status()
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return await post_streaming_url(f"{url}/api/pull", json.dumps(payload))


class PushModelForm(BaseModel):
@@ -399,50 +370,9 @@ async def push_model(
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.debug(f"url: {url}")

-    r = None
-
-    def get_request():
-        nonlocal url
-        nonlocal r
-        try:
-
-            def stream_content():
-                for chunk in r.iter_content(chunk_size=8192):
-                    yield chunk
-
-            r = requests.request(
-                method="POST",
-                url=f"{url}/api/push",
-                data=form_data.model_dump_json(exclude_none=True).encode(),
-            )
-
-            r.raise_for_status()
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return await post_streaming_url(
+        f"{url}/api/push", form_data.model_dump_json(exclude_none=True).encode()
+    )


class CreateModelForm(BaseModel):
@@ -461,53 +391,9 @@ async def create_model(
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

-    r = None
-
-    def get_request():
-        nonlocal url
-        nonlocal r
-        try:
-
-            def stream_content():
-                for chunk in r.iter_content(chunk_size=8192):
-                    yield chunk
-
-            r = requests.request(
-                method="POST",
-                url=f"{url}/api/create",
-                data=form_data.model_dump_json(exclude_none=True).encode(),
-                stream=True,
-            )
-
-            r.raise_for_status()
-
-            log.debug(f"r: {r}")
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return await post_streaming_url(
+        f"{url}/api/create", form_data.model_dump_json(exclude_none=True).encode()
+    )


class CopyModelForm(BaseModel):
@@ -797,66 +683,9 @@ async def generate_completion(
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

-    r = None
-
-    def get_request():
-        nonlocal form_data
-        nonlocal r
-
-        request_id = str(uuid.uuid4())
-        try:
-            REQUEST_POOL.append(request_id)
-
-            def stream_content():
-                try:
-                    if form_data.stream:
-                        yield json.dumps({"id": request_id, "done": False}) + "\n"
-
-                    for chunk in r.iter_content(chunk_size=8192):
-                        if request_id in REQUEST_POOL:
-                            yield chunk
-                        else:
-                            log.warning("User: canceled request")
-                            break
-                finally:
-                    if hasattr(r, "close"):
-                        r.close()
-                    if request_id in REQUEST_POOL:
-                        REQUEST_POOL.remove(request_id)
-
-            r = requests.request(
-                method="POST",
-                url=f"{url}/api/generate",
-                data=form_data.model_dump_json(exclude_none=True).encode(),
-                stream=True,
-            )
-
-            r.raise_for_status()
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-    except Exception as e:
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return await post_streaming_url(
+        f"{url}/api/generate", form_data.model_dump_json(exclude_none=True).encode()
+    )


class ChatMessage(BaseModel):
@@ -897,7 +726,6 @@ async def generate_chat_completion(
    model_info = Models.get_model_by_id(model_id)

    if model_info:
-        print(model_info)
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id
@@ -906,44 +734,77 @@ async def generate_chat_completion(
        if model_info.params:
            payload["options"] = {}

-            payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
-            payload["options"]["mirostat_eta"] = model_info.params.get(
-                "mirostat_eta", None
-            )
-            payload["options"]["mirostat_tau"] = model_info.params.get(
-                "mirostat_tau", None
-            )
-            payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
-
-            payload["options"]["repeat_last_n"] = model_info.params.get(
-                "repeat_last_n", None
-            )
-            payload["options"]["repeat_penalty"] = model_info.params.get(
-                "frequency_penalty", None
-            )
-
-            payload["options"]["temperature"] = model_info.params.get(
-                "temperature", None
-            )
-            payload["options"]["seed"] = model_info.params.get("seed", None)
-
-            payload["options"]["stop"] = (
-                [
-                    bytes(stop, "utf-8").decode("unicode_escape")
-                    for stop in model_info.params["stop"]
-                ]
-                if model_info.params.get("stop", None)
-                else None
-            )
-
-            payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
-
-            payload["options"]["num_predict"] = model_info.params.get(
-                "max_tokens", None
-            )
-            payload["options"]["top_k"] = model_info.params.get("top_k", None)
-
-            payload["options"]["top_p"] = model_info.params.get("top_p", None)
+            if model_info.params.get("mirostat", None):
+                payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
+
+            if model_info.params.get("mirostat_eta", None):
+                payload["options"]["mirostat_eta"] = model_info.params.get(
+                    "mirostat_eta", None
+                )
+
+            if model_info.params.get("mirostat_tau", None):
+                payload["options"]["mirostat_tau"] = model_info.params.get(
+                    "mirostat_tau", None
+                )
+
+            if model_info.params.get("num_ctx", None):
+                payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
+
+            if model_info.params.get("repeat_last_n", None):
+                payload["options"]["repeat_last_n"] = model_info.params.get(
+                    "repeat_last_n", None
+                )
+
+            if model_info.params.get("frequency_penalty", None):
+                payload["options"]["repeat_penalty"] = model_info.params.get(
+                    "frequency_penalty", None
+                )
+
+            if model_info.params.get("temperature", None) is not None:
+                payload["options"]["temperature"] = model_info.params.get(
+                    "temperature", None
+                )
+
+            if model_info.params.get("seed", None):
+                payload["options"]["seed"] = model_info.params.get("seed", None)
+
+            if model_info.params.get("stop", None):
+                payload["options"]["stop"] = (
+                    [
+                        bytes(stop, "utf-8").decode("unicode_escape")
+                        for stop in model_info.params["stop"]
+                    ]
+                    if model_info.params.get("stop", None)
+                    else None
+                )
+
+            if model_info.params.get("tfs_z", None):
+                payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
+
+            if model_info.params.get("max_tokens", None):
+                payload["options"]["num_predict"] = model_info.params.get(
+                    "max_tokens", None
+                )
+
+            if model_info.params.get("top_k", None):
+                payload["options"]["top_k"] = model_info.params.get("top_k", None)
+
+            if model_info.params.get("top_p", None):
+                payload["options"]["top_p"] = model_info.params.get("top_p", None)
+
+            if model_info.params.get("use_mmap", None):
+                payload["options"]["use_mmap"] = model_info.params.get("use_mmap", None)
+
+            if model_info.params.get("use_mlock", None):
+                payload["options"]["use_mlock"] = model_info.params.get(
+                    "use_mlock", None
+                )
+
+            if model_info.params.get("num_thread", None):
+                payload["options"]["num_thread"] = model_info.params.get(
+                    "num_thread", None
+                )

        if model_info.params.get("system", None):
            # Check if the payload already has a system message
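Two of the guards above also translate OpenAI-style parameter names into their Ollama option equivalents: max_tokens becomes num_predict and frequency_penalty becomes repeat_penalty. The same translation in isolation (an illustrative sketch, not code from this commit):

# Hypothetical helper mirroring the inline guards above.
OPENAI_TO_OLLAMA = {
    "max_tokens": "num_predict",
    "frequency_penalty": "repeat_penalty",
}

def to_ollama_options(params: dict) -> dict:
    options = {}
    for key, value in params.items():
        if value is None:
            continue
        options[OPENAI_TO_OLLAMA.get(key, key)] = value
    return options

print(to_ollama_options({"max_tokens": 128, "temperature": 0.7}))
# -> {'num_predict': 128, 'temperature': 0.7}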
@@ -981,73 +842,18 @@ async def generate_chat_completion(

    print(payload)

-    r = None
-
-    def get_request():
-        nonlocal payload
-        nonlocal r
-
-        request_id = str(uuid.uuid4())
-        try:
-            REQUEST_POOL.append(request_id)
-
-            def stream_content():
-                try:
-                    if payload.get("stream", None):
-                        yield json.dumps({"id": request_id, "done": False}) + "\n"
-
-                    for chunk in r.iter_content(chunk_size=8192):
-                        if request_id in REQUEST_POOL:
-                            yield chunk
-                        else:
-                            log.warning("User: canceled request")
-                            break
-                finally:
-                    if hasattr(r, "close"):
-                        r.close()
-                    if request_id in REQUEST_POOL:
-                        REQUEST_POOL.remove(request_id)
-
-            r = requests.request(
-                method="POST",
-                url=f"{url}/api/chat",
-                data=json.dumps(payload),
-                stream=True,
-            )
-
-            r.raise_for_status()
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            log.exception(e)
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-    except Exception as e:
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return await post_streaming_url(f"{url}/api/chat", json.dumps(payload))


+# TODO: we should update this part once Ollama supports other types
+class OpenAIChatMessageContent(BaseModel):
+    type: str
+    model_config = ConfigDict(extra="allow")
+
+
class OpenAIChatMessage(BaseModel):
    role: str
-    content: str
+    content: Union[str, OpenAIChatMessageContent]

    model_config = ConfigDict(extra="allow")
@@ -1075,7 +881,6 @@ async def generate_openai_chat_completion(
    model_info = Models.get_model_by_id(model_id)

    if model_info:
-        print(model_info)
        if model_info.base_model_id:
            payload["model"] = model_info.base_model_id
@@ -1132,68 +937,7 @@ async def generate_openai_chat_completion(
    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
    log.info(f"url: {url}")

-    r = None
-
-    def get_request():
-        nonlocal payload
-        nonlocal r
-
-        request_id = str(uuid.uuid4())
-        try:
-            REQUEST_POOL.append(request_id)
-
-            def stream_content():
-                try:
-                    if payload.get("stream"):
-                        yield json.dumps(
-                            {"request_id": request_id, "done": False}
-                        ) + "\n"
-
-                    for chunk in r.iter_content(chunk_size=8192):
-                        if request_id in REQUEST_POOL:
-                            yield chunk
-                        else:
-                            log.warning("User: canceled request")
-                            break
-                finally:
-                    if hasattr(r, "close"):
-                        r.close()
-                    if request_id in REQUEST_POOL:
-                        REQUEST_POOL.remove(request_id)
-
-            r = requests.request(
-                method="POST",
-                url=f"{url}/v1/chat/completions",
-                data=json.dumps(payload),
-                stream=True,
-            )
-
-            r.raise_for_status()
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-    except Exception as e:
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return await post_streaming_url(f"{url}/v1/chat/completions", json.dumps(payload))


@app.get("/v1/models")
@@ -1305,7 +1049,7 @@ async def download_file_stream(

    timeout = aiohttp.ClientTimeout(total=600)  # Set the timeout

-    async with aiohttp.ClientSession(timeout=timeout) as session:
+    async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
        async with session.get(file_url, headers=headers) as response:
            total_size = int(response.headers.get("content-length", 0)) + current_size
@@ -1522,7 +1266,7 @@ async def deprecated_proxy(
                    if path == "generate":
                        data = json.loads(body.decode("utf-8"))

-                        if not ("stream" in data and data["stream"] == False):
+                        if data.get("stream", True):
                            yield json.dumps({"id": request_id, "done": False}) + "\n"

                    elif path == "chat":
@@ -9,6 +9,7 @@ import json
import logging

from pydantic import BaseModel
+from starlette.background import BackgroundTask

from apps.webui.models.models import Models
from apps.webui.models.users import Users
@@ -185,7 +186,7 @@ async def fetch_url(url, key):
    timeout = aiohttp.ClientTimeout(total=5)
    try:
        headers = {"Authorization": f"Bearer {key}"}
-        async with aiohttp.ClientSession(timeout=timeout) as session:
+        async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
            async with session.get(url, headers=headers) as response:
                return await response.json()
    except Exception as e:
@@ -194,6 +195,16 @@ async def fetch_url(url, key):
        return None


+async def cleanup_response(
+    response: Optional[aiohttp.ClientResponse],
+    session: Optional[aiohttp.ClientSession],
+):
+    if response:
+        response.close()
+    if session:
+        await session.close()
+
+
def merge_models_lists(model_lists):
    log.debug(f"merge_models_lists {model_lists}")
    merged_list = []
@@ -228,6 +239,27 @@ async def get_all_models(raw: bool = False):
    ) or not app.state.config.ENABLE_OPENAI_API:
        models = {"data": []}
    else:
+        # Check if API KEYS length is same than API URLS length
+        if len(app.state.config.OPENAI_API_KEYS) != len(
+            app.state.config.OPENAI_API_BASE_URLS
+        ):
+            # if there are more keys than urls, remove the extra keys
+            if len(app.state.config.OPENAI_API_KEYS) > len(
+                app.state.config.OPENAI_API_BASE_URLS
+            ):
+                app.state.config.OPENAI_API_KEYS = app.state.config.OPENAI_API_KEYS[
+                    : len(app.state.config.OPENAI_API_BASE_URLS)
+                ]
+            # if there are more urls than keys, add empty keys
+            else:
+                app.state.config.OPENAI_API_KEYS += [
+                    ""
+                    for _ in range(
+                        len(app.state.config.OPENAI_API_BASE_URLS)
+                        - len(app.state.config.OPENAI_API_KEYS)
+                    )
+                ]
+
        tasks = [
            fetch_url(f"{url}/models", app.state.config.OPENAI_API_KEYS[idx])
            for idx, url in enumerate(app.state.config.OPENAI_API_BASE_URLS)
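The reconciliation above keeps OPENAI_API_KEYS and OPENAI_API_BASE_URLS the same length so that keys[idx] always pairs with urls[idx]: surplus keys are truncated, missing keys are padded with empty strings. The invariant in isolation:

def reconcile(keys: list[str], urls: list[str]) -> list[str]:
    # Truncate surplus keys, pad missing ones with "".
    if len(keys) > len(urls):
        return keys[: len(urls)]
    return keys + [""] * (len(urls) - len(keys))

print(reconcile(["k1", "k2", "k3"], ["u1", "u2"]))  # -> ['k1', 'k2']
print(reconcile(["k1"], ["u1", "u2"]))              # -> ['k1', '']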
@@ -313,113 +345,155 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
    )


+@app.post("/chat/completions")
+@app.post("/chat/completions/{url_idx}")
+async def generate_chat_completion(
+    form_data: dict,
+    url_idx: Optional[int] = None,
+    user=Depends(get_verified_user),
+):
+    idx = 0
+    payload = {**form_data}
+
+    model_id = form_data.get("model")
+    model_info = Models.get_model_by_id(model_id)
+
+    if model_info:
+        if model_info.base_model_id:
+            payload["model"] = model_info.base_model_id
+
+        model_info.params = model_info.params.model_dump()
+
+        if model_info.params:
+            if model_info.params.get("temperature", None) is not None:
+                payload["temperature"] = float(model_info.params.get("temperature"))
+
+            if model_info.params.get("top_p", None):
+                payload["top_p"] = int(model_info.params.get("top_p", None))
+
+            if model_info.params.get("max_tokens", None):
+                payload["max_tokens"] = int(model_info.params.get("max_tokens", None))
+
+            if model_info.params.get("frequency_penalty", None):
+                payload["frequency_penalty"] = int(
+                    model_info.params.get("frequency_penalty", None)
+                )
+
+            if model_info.params.get("seed", None):
+                payload["seed"] = model_info.params.get("seed", None)
+
+            if model_info.params.get("stop", None):
+                payload["stop"] = (
+                    [
+                        bytes(stop, "utf-8").decode("unicode_escape")
+                        for stop in model_info.params["stop"]
+                    ]
+                    if model_info.params.get("stop", None)
+                    else None
+                )
+
+            if model_info.params.get("system", None):
+                # Check if the payload already has a system message
+                # If not, add a system message to the payload
+                if payload.get("messages"):
+                    for message in payload["messages"]:
+                        if message.get("role") == "system":
+                            message["content"] = (
+                                model_info.params.get("system", None) + message["content"]
+                            )
+                            break
+                    else:
+                        payload["messages"].insert(
+                            0,
+                            {
+                                "role": "system",
+                                "content": model_info.params.get("system", None),
+                            },
+                        )
+    else:
+        pass
+
+    model = app.state.MODELS[payload.get("model")]
+    idx = model["urlIdx"]
+
+    if "pipeline" in model and model.get("pipeline"):
+        payload["user"] = {"name": user.name, "id": user.id}
+
+    # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
+    # This is a workaround until OpenAI fixes the issue with this model
+    if payload.get("model") == "gpt-4-vision-preview":
+        if "max_tokens" not in payload:
+            payload["max_tokens"] = 4000
+        log.debug("Modified payload:", payload)
+
+    # Convert the modified body back to JSON
+    payload = json.dumps(payload)
+
+    print(payload)
+
+    url = app.state.config.OPENAI_API_BASE_URLS[idx]
+    key = app.state.config.OPENAI_API_KEYS[idx]
+
+    print(payload)
+
+    headers = {}
+    headers["Authorization"] = f"Bearer {key}"
+    headers["Content-Type"] = "application/json"
+
+    r = None
+    session = None
+    streaming = False
+
+    try:
+        session = aiohttp.ClientSession(trust_env=True)
+        r = await session.request(
+            method="POST",
+            url=f"{url}/chat/completions",
+            data=payload,
+            headers=headers,
+        )
+
+        r.raise_for_status()
+
+        # Check if response is SSE
+        if "text/event-stream" in r.headers.get("Content-Type", ""):
+            streaming = True
+            return StreamingResponse(
+                r.content,
+                status_code=r.status,
+                headers=dict(r.headers),
+                background=BackgroundTask(
+                    cleanup_response, response=r, session=session
+                ),
+            )
+        else:
+            response_data = await r.json()
+            return response_data
+    except Exception as e:
+        log.exception(e)
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = await r.json()
+                print(res)
+                if "error" in res:
+                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
+            except:
+                error_detail = f"External: {e}"
+        raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
+    finally:
+        if not streaming and session:
+            if r:
+                r.close()
+            await session.close()


@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    idx = 0

    body = await request.body()
-    # TODO: Remove below after gpt-4-vision fix from Open AI
-    # Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)
-
-    payload = None
-
-    try:
-        if "chat/completions" in path:
-            body = body.decode("utf-8")
-            body = json.loads(body)
-
-            payload = {**body}
-
-            model_id = body.get("model")
-            model_info = Models.get_model_by_id(model_id)
-
-            if model_info:
-                print(model_info)
-                if model_info.base_model_id:
-                    payload["model"] = model_info.base_model_id
-
-                model_info.params = model_info.params.model_dump()
-
-                if model_info.params:
-                    if model_info.params.get("temperature", None):
-                        payload["temperature"] = int(
-                            model_info.params.get("temperature")
-                        )
-
-                    if model_info.params.get("top_p", None):
-                        payload["top_p"] = int(model_info.params.get("top_p", None))
-
-                    if model_info.params.get("max_tokens", None):
-                        payload["max_tokens"] = int(
-                            model_info.params.get("max_tokens", None)
-                        )
-
-                    if model_info.params.get("frequency_penalty", None):
-                        payload["frequency_penalty"] = int(
-                            model_info.params.get("frequency_penalty", None)
-                        )
-
-                    if model_info.params.get("seed", None):
-                        payload["seed"] = model_info.params.get("seed", None)
-
-                    if model_info.params.get("stop", None):
-                        payload["stop"] = (
-                            [
-                                bytes(stop, "utf-8").decode("unicode_escape")
-                                for stop in model_info.params["stop"]
-                            ]
-                            if model_info.params.get("stop", None)
-                            else None
-                        )
-
-                    if model_info.params.get("system", None):
-                        # Check if the payload already has a system message
-                        # If not, add a system message to the payload
-                        if payload.get("messages"):
-                            for message in payload["messages"]:
-                                if message.get("role") == "system":
-                                    message["content"] = (
-                                        model_info.params.get("system", None)
-                                        + message["content"]
-                                    )
-                                    break
-                            else:
-                                payload["messages"].insert(
-                                    0,
-                                    {
-                                        "role": "system",
-                                        "content": model_info.params.get("system", None),
-                                    },
-                                )
-            else:
-                pass
-
-            model = app.state.MODELS[payload.get("model")]
-
-            idx = model["urlIdx"]
-
-            if "pipeline" in model and model.get("pipeline"):
-                payload["user"] = {"name": user.name, "id": user.id}
-                payload["title"] = (
-                    True
-                    if payload["stream"] == False and payload["max_tokens"] == 50
-                    else False
-                )
-
-            # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
-            # This is a workaround until OpenAI fixes the issue with this model
-            if payload.get("model") == "gpt-4-vision-preview":
-                if "max_tokens" not in payload:
-                    payload["max_tokens"] = 4000
-                log.debug("Modified payload:", payload)
-
-            # Convert the modified body back to JSON
-            payload = json.dumps(payload)
-
-    except json.JSONDecodeError as e:
-        log.error("Error loading request body into a dictionary:", e)
-
-    print(payload)

    url = app.state.config.OPENAI_API_BASE_URLS[idx]
    key = app.state.config.OPENAI_API_KEYS[idx]
@@ -431,40 +505,48 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
    headers["Content-Type"] = "application/json"

    r = None
+    session = None
+    streaming = False

    try:
-        r = requests.request(
+        session = aiohttp.ClientSession(trust_env=True)
+        r = await session.request(
            method=request.method,
            url=target_url,
-            data=payload if payload else body,
+            data=body,
            headers=headers,
-            stream=True,
        )

        r.raise_for_status()

        # Check if response is SSE
        if "text/event-stream" in r.headers.get("Content-Type", ""):
+            streaming = True
            return StreamingResponse(
-                r.iter_content(chunk_size=8192),
-                status_code=r.status_code,
+                r.content,
+                status_code=r.status,
                headers=dict(r.headers),
+                background=BackgroundTask(
+                    cleanup_response, response=r, session=session
+                ),
            )
        else:
-            response_data = r.json()
+            response_data = await r.json()
            return response_data
    except Exception as e:
        log.exception(e)
        error_detail = "Open WebUI: Server Connection Error"
        if r is not None:
            try:
-                res = r.json()
+                res = await r.json()
                print(res)
                if "error" in res:
                    error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
            except:
                error_detail = f"External: {e}"

-        raise HTTPException(
-            status_code=r.status_code if r else 500, detail=error_detail
-        )
+        raise HTTPException(status_code=r.status if r else 500, detail=error_detail)
+    finally:
+        if not streaming and session:
+            if r:
+                r.close()
+            await session.close()
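The recurring trust_env=True change makes aiohttp read proxy settings from the environment (HTTP_PROXY, HTTPS_PROXY, NO_PROXY), which requests does by default but aiohttp does not. A minimal demonstration (requires network access; the URL is illustrative):

import asyncio
import aiohttp

async def main():
    # Without trust_env=True, HTTP_PROXY / HTTPS_PROXY / NO_PROXY are ignored.
    async with aiohttp.ClientSession(trust_env=True) as session:
        async with session.get("https://api.openai.com/v1/models") as resp:
            print(resp.status)

asyncio.run(main())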
@@ -8,12 +8,15 @@ from fastapi import (
    Form,
)
from fastapi.middleware.cors import CORSMiddleware
import requests
import os, shutil, logging, re
+from datetime import datetime

from pathlib import Path
-from typing import List, Union, Sequence
+from typing import List, Union, Sequence, Iterator, Any

from chromadb.utils.batch_utils import create_batches
+from langchain_core.documents import Document

from langchain_community.document_loaders import (
    WebBaseLoader,
@@ -30,6 +33,7 @@ from langchain_community.document_loaders import (
    UnstructuredExcelLoader,
    UnstructuredPowerPointLoader,
    YoutubeLoader,
+    OutlookMessageLoader,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -59,9 +63,17 @@ from apps.rag.utils import (
    query_doc_with_hybrid_search,
    query_collection,
    query_collection_with_hybrid_search,
-    search_web,
)

+from apps.rag.search.brave import search_brave
+from apps.rag.search.google_pse import search_google_pse
+from apps.rag.search.main import SearchResult
+from apps.rag.search.searxng import search_searxng
+from apps.rag.search.serper import search_serper
+from apps.rag.search.serpstack import search_serpstack
+from apps.rag.search.serply import search_serply
+from apps.rag.search.duckduckgo import search_duckduckgo
+
from utils.misc import (
    calculate_sha256,
    calculate_sha256_string,
@@ -71,6 +83,7 @@ from utils.misc import (
from utils.utils import get_current_user, get_admin_user

from config import (
+    AppConfig,
    ENV,
    SRC_LOG_LEVELS,
    UPLOAD_DIR,
@@ -96,8 +109,19 @@ from config import (
    RAG_TEMPLATE,
    ENABLE_RAG_LOCAL_WEB_FETCH,
    YOUTUBE_LOADER_LANGUAGE,
+    ENABLE_RAG_WEB_SEARCH,
+    RAG_WEB_SEARCH_ENGINE,
+    SEARXNG_QUERY_URL,
+    GOOGLE_PSE_API_KEY,
+    GOOGLE_PSE_ENGINE_ID,
+    BRAVE_SEARCH_API_KEY,
+    SERPSTACK_API_KEY,
+    SERPSTACK_HTTPS,
+    SERPER_API_KEY,
+    SERPLY_API_KEY,
+    RAG_WEB_SEARCH_RESULT_COUNT,
+    RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
-    AppConfig,
+    RAG_EMBEDDING_OPENAI_BATCH_SIZE,
)

from constants import ERROR_MESSAGES
@@ -122,6 +146,7 @@ app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
+app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SIZE
app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
@@ -136,6 +161,21 @@ app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
app.state.YOUTUBE_LOADER_TRANSLATION = None


+app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
+app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
+
+app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
+app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
+app.state.config.GOOGLE_PSE_ENGINE_ID = GOOGLE_PSE_ENGINE_ID
+app.state.config.BRAVE_SEARCH_API_KEY = BRAVE_SEARCH_API_KEY
+app.state.config.SERPSTACK_API_KEY = SERPSTACK_API_KEY
+app.state.config.SERPSTACK_HTTPS = SERPSTACK_HTTPS
+app.state.config.SERPER_API_KEY = SERPER_API_KEY
+app.state.config.SERPLY_API_KEY = SERPLY_API_KEY
+app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = RAG_WEB_SEARCH_RESULT_COUNT
+app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = RAG_WEB_SEARCH_CONCURRENT_REQUESTS
+
+
def update_embedding_model(
    embedding_model: str,
    update_model: bool = False,
@@ -181,6 +221,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
    app.state.sentence_transformer_ef,
    app.state.config.OPENAI_API_KEY,
    app.state.config.OPENAI_API_BASE_URL,
+    app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
)

origins = ["*"]
@@ -217,6 +258,7 @@ async def get_status():
        "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
        "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
        "reranking_model": app.state.config.RAG_RERANKING_MODEL,
+        "openai_batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
    }
@@ -229,6 +271,7 @@ async def get_embedding_config(user=Depends(get_admin_user)):
        "openai_config": {
            "url": app.state.config.OPENAI_API_BASE_URL,
            "key": app.state.config.OPENAI_API_KEY,
+            "batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
        },
    }
@@ -244,6 +287,7 @@ async def get_reraanking_config(user=Depends(get_admin_user)):
class OpenAIConfigForm(BaseModel):
    url: str
    key: str
+    batch_size: Optional[int] = None


class EmbeddingModelUpdateForm(BaseModel):
@@ -264,9 +308,14 @@ async def update_embedding_config(
        app.state.config.RAG_EMBEDDING_MODEL = form_data.embedding_model

        if app.state.config.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
-            if form_data.openai_config != None:
+            if form_data.openai_config is not None:
                app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
                app.state.config.OPENAI_API_KEY = form_data.openai_config.key
+                app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = (
+                    form_data.openai_config.batch_size
+                    if form_data.openai_config.batch_size
+                    else 1
+                )

        update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
@@ -276,6 +325,7 @@ async def update_embedding_config(
            app.state.sentence_transformer_ef,
            app.state.config.OPENAI_API_KEY,
            app.state.config.OPENAI_API_BASE_URL,
+            app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
        )

        return {
@@ -285,6 +335,7 @@ async def update_embedding_config(
            "openai_config": {
                "url": app.state.config.OPENAI_API_BASE_URL,
                "key": app.state.config.OPENAI_API_KEY,
+                "batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
            },
        }
    except Exception as e:
@@ -332,11 +383,27 @@ async def get_rag_config(user=Depends(get_admin_user)):
            "chunk_size": app.state.config.CHUNK_SIZE,
            "chunk_overlap": app.state.config.CHUNK_OVERLAP,
        },
-        "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        "youtube": {
            "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
            "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
        },
+        "web": {
+            "ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
+            "search": {
+                "enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
+                "engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
+                "searxng_query_url": app.state.config.SEARXNG_QUERY_URL,
+                "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY,
+                "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID,
+                "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY,
+                "serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
+                "serpstack_https": app.state.config.SERPSTACK_HTTPS,
+                "serper_api_key": app.state.config.SERPER_API_KEY,
+                "serply_api_key": app.state.config.SERPLY_API_KEY,
+                "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
+                "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
+            },
+        },
    }
@@ -350,11 +417,31 @@ class YoutubeLoaderConfig(BaseModel):
    translation: Optional[str] = None


+class WebSearchConfig(BaseModel):
+    enabled: bool
+    engine: Optional[str] = None
+    searxng_query_url: Optional[str] = None
+    google_pse_api_key: Optional[str] = None
+    google_pse_engine_id: Optional[str] = None
+    brave_search_api_key: Optional[str] = None
+    serpstack_api_key: Optional[str] = None
+    serpstack_https: Optional[bool] = None
+    serper_api_key: Optional[str] = None
+    serply_api_key: Optional[str] = None
+    result_count: Optional[int] = None
+    concurrent_requests: Optional[int] = None
+
+
+class WebConfig(BaseModel):
+    search: WebSearchConfig
+    web_loader_ssl_verification: Optional[bool] = None
+
+
class ConfigUpdateForm(BaseModel):
    pdf_extract_images: Optional[bool] = None
    chunk: Optional[ChunkParamUpdateForm] = None
-    web_loader_ssl_verification: Optional[bool] = None
    youtube: Optional[YoutubeLoaderConfig] = None
+    web: Optional[WebConfig] = None


@app.post("/config/update")
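Given the new WebConfig and WebSearchConfig models, a /config/update body that enables SearXNG-backed web search might look like the sketch below (a Python dict for illustration; values are placeholders, and omitted optional fields default to None):

payload = {
    "web": {
        "web_loader_ssl_verification": True,
        "search": {
            "enabled": True,
            "engine": "searxng",
            "searxng_query_url": "http://searxng:8080/search?q=<query>",
            "result_count": 3,
            "concurrent_requests": 10,
        },
    },
}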
@@ -365,35 +452,37 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
        else app.state.config.PDF_EXTRACT_IMAGES
    )

    app.state.config.CHUNK_SIZE = (
        form_data.chunk.chunk_size
        if form_data.chunk is not None
        else app.state.config.CHUNK_SIZE
    )
    if form_data.chunk is not None:
        app.state.config.CHUNK_SIZE = form_data.chunk.chunk_size
        app.state.config.CHUNK_OVERLAP = form_data.chunk.chunk_overlap

    app.state.config.CHUNK_OVERLAP = (
        form_data.chunk.chunk_overlap
        if form_data.chunk is not None
        else app.state.config.CHUNK_OVERLAP
    )
    if form_data.youtube is not None:
        app.state.config.YOUTUBE_LOADER_LANGUAGE = form_data.youtube.language
        app.state.YOUTUBE_LOADER_TRANSLATION = form_data.youtube.translation

    app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
        form_data.web_loader_ssl_verification
        if form_data.web_loader_ssl_verification != None
        else app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION
    )
    if form_data.web is not None:
        app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
            form_data.web.web_loader_ssl_verification
        )

    app.state.config.YOUTUBE_LOADER_LANGUAGE = (
        form_data.youtube.language
        if form_data.youtube is not None
        else app.state.config.YOUTUBE_LOADER_LANGUAGE
    )

    app.state.YOUTUBE_LOADER_TRANSLATION = (
        form_data.youtube.translation
        if form_data.youtube is not None
        else app.state.YOUTUBE_LOADER_TRANSLATION
    )
        app.state.config.ENABLE_RAG_WEB_SEARCH = form_data.web.search.enabled
        app.state.config.RAG_WEB_SEARCH_ENGINE = form_data.web.search.engine
        app.state.config.SEARXNG_QUERY_URL = form_data.web.search.searxng_query_url
        app.state.config.GOOGLE_PSE_API_KEY = form_data.web.search.google_pse_api_key
        app.state.config.GOOGLE_PSE_ENGINE_ID = (
            form_data.web.search.google_pse_engine_id
        )
        app.state.config.BRAVE_SEARCH_API_KEY = (
            form_data.web.search.brave_search_api_key
        )
        app.state.config.SERPSTACK_API_KEY = form_data.web.search.serpstack_api_key
        app.state.config.SERPSTACK_HTTPS = form_data.web.search.serpstack_https
        app.state.config.SERPER_API_KEY = form_data.web.search.serper_api_key
        app.state.config.SERPLY_API_KEY = form_data.web.search.serply_api_key
        app.state.config.RAG_WEB_SEARCH_RESULT_COUNT = form_data.web.search.result_count
        app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS = (
            form_data.web.search.concurrent_requests
        )

    return {
        "status": True,
@@ -402,11 +491,27 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
            "chunk_size": app.state.config.CHUNK_SIZE,
            "chunk_overlap": app.state.config.CHUNK_OVERLAP,
        },
        "web_loader_ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
        "youtube": {
            "language": app.state.config.YOUTUBE_LOADER_LANGUAGE,
            "translation": app.state.YOUTUBE_LOADER_TRANSLATION,
        },
        "web": {
            "ssl_verification": app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
            "search": {
                "enabled": app.state.config.ENABLE_RAG_WEB_SEARCH,
                "engine": app.state.config.RAG_WEB_SEARCH_ENGINE,
                "searxng_query_url": app.state.config.SEARXNG_QUERY_URL,
                "google_pse_api_key": app.state.config.GOOGLE_PSE_API_KEY,
                "google_pse_engine_id": app.state.config.GOOGLE_PSE_ENGINE_ID,
                "brave_search_api_key": app.state.config.BRAVE_SEARCH_API_KEY,
                "serpstack_api_key": app.state.config.SERPSTACK_API_KEY,
                "serpstack_https": app.state.config.SERPSTACK_HTTPS,
                "serper_api_key": app.state.config.SERPER_API_KEY,
                "serply_api_key": app.state.config.SERPLY_API_KEY,
                "result_count": app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
                "concurrent_requests": app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
            },
        },
    }
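
For illustration only — a minimal sketch (not part of this diff) of a POST request against the updated config endpoint under the ConfigUpdateForm schema above; the base URL, route prefix, and token are hypothetical:

    import requests

    # Enable web search via DuckDuckGo (which needs no API key); other fields keep their defaults.
    payload = {
        "web": {
            "search": {"enabled": True, "engine": "duckduckgo", "result_count": 5},
            "web_loader_ssl_verification": True,
        }
    }
    r = requests.post(
        "http://localhost:8080/rag/api/v1/config/update",  # hypothetical base URL
        headers={"Authorization": "Bearer <admin-token>"},  # hypothetical admin token
        json=payload,
    )
    print(r.json()["web"]["search"]["engine"])  # expected: "duckduckgo"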
@@ -599,7 +704,7 @@ def get_web_loader(url: Union[str, Sequence[str]], verify_ssl: bool = True):
    # Check if the URL is valid
    if not validate_url(url):
        raise ValueError(ERROR_MESSAGES.INVALID_URL)
    return WebBaseLoader(
    return SafeWebBaseLoader(
        url,
        verify_ssl=verify_ssl,
        requests_per_second=RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
@@ -642,17 +747,107 @@ def resolve_hostname(hostname):
    return ipv4_addresses, ipv6_addresses


def search_web(engine: str, query: str) -> list[SearchResult]:
    """Search the web using a search engine and return the results as a list of SearchResult objects.
    Will look for a search engine API key in environment variables in the following order:
    - SEARXNG_QUERY_URL
    - GOOGLE_PSE_API_KEY + GOOGLE_PSE_ENGINE_ID
    - BRAVE_SEARCH_API_KEY
    - SERPSTACK_API_KEY
    - SERPER_API_KEY
    - SERPLY_API_KEY

    Args:
        query (str): The query to search for
    """

    # TODO: add playwright to search the web
    if engine == "searxng":
        if app.state.config.SEARXNG_QUERY_URL:
            return search_searxng(
                app.state.config.SEARXNG_QUERY_URL,
                query,
                app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
            )
        else:
            raise Exception("No SEARXNG_QUERY_URL found in environment variables")
    elif engine == "google_pse":
        if (
            app.state.config.GOOGLE_PSE_API_KEY
            and app.state.config.GOOGLE_PSE_ENGINE_ID
        ):
            return search_google_pse(
                app.state.config.GOOGLE_PSE_API_KEY,
                app.state.config.GOOGLE_PSE_ENGINE_ID,
                query,
                app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
            )
        else:
            raise Exception(
                "No GOOGLE_PSE_API_KEY or GOOGLE_PSE_ENGINE_ID found in environment variables"
            )
    elif engine == "brave":
        if app.state.config.BRAVE_SEARCH_API_KEY:
            return search_brave(
                app.state.config.BRAVE_SEARCH_API_KEY,
                query,
                app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
            )
        else:
            raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
    elif engine == "serpstack":
        if app.state.config.SERPSTACK_API_KEY:
            return search_serpstack(
                app.state.config.SERPSTACK_API_KEY,
                query,
                app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
                https_enabled=app.state.config.SERPSTACK_HTTPS,
            )
        else:
            raise Exception("No SERPSTACK_API_KEY found in environment variables")
    elif engine == "serper":
        if app.state.config.SERPER_API_KEY:
            return search_serper(
                app.state.config.SERPER_API_KEY,
                query,
                app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
            )
        else:
            raise Exception("No SERPER_API_KEY found in environment variables")
    elif engine == "serply":
        if app.state.config.SERPLY_API_KEY:
            return search_serply(
                app.state.config.SERPLY_API_KEY,
                query,
                app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
            )
        else:
            raise Exception("No SERPLY_API_KEY found in environment variables")
    elif engine == "duckduckgo":
        return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT)
    else:
        raise Exception("No search engine API key found in environment variables")
@app.post("/web/search")
|
||||
def store_web_search(form_data: SearchForm, user=Depends(get_current_user)):
|
||||
try:
|
||||
try:
|
||||
web_results = search_web(form_data.query)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.WEB_SEARCH_ERROR,
|
||||
)
|
||||
logging.info(
|
||||
f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}"
|
||||
)
|
||||
web_results = search_web(
|
||||
app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
|
||||
print(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.WEB_SEARCH_ERROR(e),
|
||||
)
|
||||
|
||||
try:
|
||||
urls = [result.link for result in web_results]
|
||||
loader = get_web_loader(urls)
|
||||
data = loader.load()
|
||||
@@ -710,6 +905,13 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
    texts = [doc.page_content for doc in docs]
    metadatas = [doc.metadata for doc in docs]

    # ChromaDB does not like datetime formats
    # for meta-data so convert them to string.
    for metadata in metadatas:
        for key, value in metadata.items():
            if isinstance(value, datetime):
                metadata[key] = str(value)

    try:
        if overwrite:
            for collection in CHROMA_CLIENT.list_collections():
@@ -725,6 +927,7 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
            app.state.sentence_transformer_ef,
            app.state.config.OPENAI_API_KEY,
            app.state.config.OPENAI_API_BASE_URL,
            app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
        )

        embedding_texts = list(map(lambda x: x.replace("\n", " "), texts))
@@ -795,6 +998,7 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
        "swift",
        "vue",
        "svelte",
        "msg",
    ]

    if file_ext == "pdf":
@@ -829,6 +1033,8 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
        "application/vnd.openxmlformats-officedocument.presentationml.presentation",
    ] or file_ext in ["ppt", "pptx"]:
        loader = UnstructuredPowerPointLoader(file_path)
    elif file_ext == "msg":
        loader = OutlookMessageLoader(file_path)
    elif file_ext in known_source_ext or (
        file_content_type and file_content_type.find("text/") >= 0
    ):
@@ -994,6 +1200,30 @@ def reset_vector_db(user=Depends(get_admin_user)):
    CHROMA_CLIENT.reset()


@app.get("/reset/uploads")
def reset_upload_dir(user=Depends(get_admin_user)) -> bool:
    folder = f"{UPLOAD_DIR}"
    try:
        # Check if the directory exists
        if os.path.exists(folder):
            # Iterate over all the files and directories in the specified directory
            for filename in os.listdir(folder):
                file_path = os.path.join(folder, filename)
                try:
                    if os.path.isfile(file_path) or os.path.islink(file_path):
                        os.unlink(file_path)  # Remove the file or link
                    elif os.path.isdir(file_path):
                        shutil.rmtree(file_path)  # Remove the directory
                except Exception as e:
                    print(f"Failed to delete {file_path}. Reason: {e}")
        else:
            print(f"The directory {folder} does not exist")
    except Exception as e:
        print(f"Failed to process the directory {folder}. Reason: {e}")

    return True


@app.get("/reset")
def reset(user=Depends(get_admin_user)) -> bool:
    folder = f"{UPLOAD_DIR}"
@@ -1015,6 +1245,33 @@ def reset(user=Depends(get_admin_user)) -> bool:
    return True

class SafeWebBaseLoader(WebBaseLoader):
    """WebBaseLoader with enhanced error handling for URLs."""

    def lazy_load(self) -> Iterator[Document]:
        """Lazy load text from the url(s) in web_path with error handling."""
        for path in self.web_paths:
            try:
                soup = self._scrape(path, bs_kwargs=self.bs_kwargs)
                text = soup.get_text(**self.bs_get_text_kwargs)

                # Build metadata
                metadata = {"source": path}
                if title := soup.find("title"):
                    metadata["title"] = title.get_text()
                if description := soup.find("meta", attrs={"name": "description"}):
                    metadata["description"] = description.get(
                        "content", "No description found."
                    )
                if html := soup.find("html"):
                    metadata["language"] = html.get("lang", "No language found.")

                yield Document(page_content=text, metadata=metadata)
            except Exception as e:
                # Log the error and continue with the next URL
                log.error(f"Error loading {path}: {e}")
if ENV == "dev":
|
||||
|
||||
@app.get("/ef")
|
||||
|
||||
@@ -3,13 +3,13 @@ import logging
import requests

from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
from config import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def search_brave(api_key: str, query: str) -> list[SearchResult]:
def search_brave(api_key: str, query: str, count: int) -> list[SearchResult]:
    """Search using Brave's Search API and return the results as a list of SearchResult objects.

    Args:
@@ -22,7 +22,7 @@ def search_brave(api_key: str, query: str) -> list[SearchResult]:
        "Accept-Encoding": "gzip",
        "X-Subscription-Token": api_key,
    }
    params = {"q": query, "count": RAG_WEB_SEARCH_RESULT_COUNT}
    params = {"q": query, "count": count}

    response = requests.get(url, headers=headers, params=params)
    response.raise_for_status()
@@ -33,5 +33,5 @@ def search_brave(api_key: str, query: str) -> list[SearchResult]:
        SearchResult(
            link=result["url"], title=result.get("title"), snippet=result.get("snippet")
        )
        for result in results[:RAG_WEB_SEARCH_RESULT_COUNT]
        for result in results[:count]
    ]

46 backend/apps/rag/search/duckduckgo.py Normal file
@@ -0,0 +1,46 @@
import logging

from apps.rag.search.main import SearchResult
from duckduckgo_search import DDGS
from config import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def search_duckduckgo(query: str, count: int) -> list[SearchResult]:
    """
    Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects.
    Args:
        query (str): The query to search for
        count (int): The number of results to return

    Returns:
        List[SearchResult]: A list of search results
    """
    # Use the DDGS context manager to create a DDGS object
    with DDGS() as ddgs:
        # Use the ddgs.text() method to perform the search
        ddgs_gen = ddgs.text(
            query, safesearch="moderate", max_results=count, backend="api"
        )
        # Check if there are search results
        if ddgs_gen:
            # Convert the search results into a list
            search_results = [r for r in ddgs_gen]

    # Create an empty list to store the SearchResult objects
    results = []
    # Iterate over each search result
    for result in search_results:
        # Create a SearchResult object and append it to the results list
        results.append(
            SearchResult(
                link=result["href"],
                title=result.get("title"),
                snippet=result.get("body"),
            )
        )
    print(results)
    # Return the list of search results
    return results
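
For illustration, a minimal sketch (not part of this diff): since DuckDuckGo requires no API key, the new function can be exercised directly; the query string is hypothetical.

    results = search_duckduckgo("open-webui github", 3)
    for r in results:
        print(r.link, r.snippet)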
@@ -4,14 +4,14 @@ import logging
import requests

from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
from config import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def search_google_pse(
    api_key: str, search_engine_id: str, query: str
    api_key: str, search_engine_id: str, query: str, count: int
) -> list[SearchResult]:
    """Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.

@@ -27,7 +27,7 @@ def search_google_pse(
        "cx": search_engine_id,
        "q": query,
        "key": api_key,
        "num": RAG_WEB_SEARCH_RESULT_COUNT,
        "num": count,
    }

    response = requests.request("GET", url, headers=headers, params=params)

@@ -1,28 +1,68 @@
import logging

import requests

from typing import List

from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
from config import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def search_searxng(query_url: str, query: str) -> list[SearchResult]:
    """Search a SearXNG instance for a query and return the results as a list of SearchResult objects.
def search_searxng(
    query_url: str, query: str, count: int, **kwargs
) -> List[SearchResult]:
    """
    Search a SearXNG instance for a given query and return the results as a list of SearchResult objects.

    The function allows passing additional parameters such as language or time_range to tailor the search result.

    Args:
        query_url (str): The URL of the SearXNG instance to search. Must contain "<query>" as a placeholder
        query (str): The query to search for
    """
    url = query_url.replace("<query>", query)
    if "&format=json" not in url:
        url += "&format=json"
    log.debug(f"searching {url}")
        query_url (str): The base URL of the SearXNG server.
        query (str): The search term or question to find in the SearXNG database.
        count (int): The maximum number of results to retrieve from the search.

    r = requests.get(
        url,
    Keyword Args:
        language (str): Language filter for the search results; e.g., "en-US". Defaults to an empty string.
        safesearch (int): Safe search filter for safer web results; 0 = off, 1 = moderate, 2 = strict. Defaults to 1 (moderate).
        time_range (str): Time range for filtering results by date; e.g., "2023-04-05..today" or "all-time". Defaults to ''.
        categories (Optional[List[str]]): Specific categories within which the search should be performed, defaulting to an empty string if not provided.

    Returns:
        List[SearchResult]: A list of SearchResults sorted by relevance score in descending order.

    Raises:
        requests.exceptions.RequestException: If a request error occurs during the search process.
    """

    # Default values for optional parameters are provided as empty strings or None when not specified.
    language = kwargs.get("language", "en-US")
    safesearch = kwargs.get("safesearch", "1")
    time_range = kwargs.get("time_range", "")
    categories = "".join(kwargs.get("categories", []))

    params = {
        "q": query,
        "format": "json",
        "pageno": 1,
        "safesearch": safesearch,
        "language": language,
        "time_range": time_range,
        "categories": categories,
        "theme": "simple",
        "image_proxy": 0,
    }

    # Legacy query format
    if "<query>" in query_url:
        # Strip all query parameters from the URL
        query_url = query_url.split("?")[0]

    log.debug(f"searching {query_url}")

    response = requests.get(
        query_url,
        headers={
            "User-Agent": "Open WebUI (https://github.com/open-webui/open-webui) RAG Bot",
            "Accept": "text/html",
@@ -30,15 +70,17 @@ def search_searxng(query_url: str, query: str) -> list[SearchResult]:
            "Accept-Language": "en-US,en;q=0.5",
            "Connection": "keep-alive",
        },
        params=params,
    )
    r.raise_for_status()

    json_response = r.json()
    response.raise_for_status()  # Raise an exception for HTTP errors.

    json_response = response.json()
    results = json_response.get("results", [])
    sorted_results = sorted(results, key=lambda x: x.get("score", 0), reverse=True)
    return [
        SearchResult(
            link=result["url"], title=result.get("title"), snippet=result.get("content")
        )
        for result in sorted_results[:RAG_WEB_SEARCH_RESULT_COUNT]
        for result in sorted_results[:count]
    ]
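
For illustration, a minimal sketch (not part of this diff) of calling the rewritten function: the keyword arguments documented above are passed straight through to the SearXNG query parameters; the instance URL is hypothetical.

    results = search_searxng(
        "https://searxng.example.com/search",  # hypothetical SearXNG instance
        "open webui",
        5,
        language="en-US",
        categories=["general"],
    )
    print([r.link for r in results])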
@@ -4,13 +4,13 @@ import logging
import requests

from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
from config import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def search_serper(api_key: str, query: str) -> list[SearchResult]:
def search_serper(api_key: str, query: str, count: int) -> list[SearchResult]:
    """Search using serper.dev's API and return the results as a list of SearchResult objects.

    Args:
@@ -35,5 +35,5 @@ def search_serper(api_key: str, query: str) -> list[SearchResult]:
            title=result.get("title"),
            snippet=result.get("description"),
        )
        for result in results[:RAG_WEB_SEARCH_RESULT_COUNT]
        for result in results[:count]
    ]

68 backend/apps/rag/search/serply.py Normal file
@@ -0,0 +1,68 @@
import json
import logging

import requests
from urllib.parse import urlencode

from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def search_serply(
    api_key: str,
    query: str,
    count: int,
    hl: str = "us",
    limit: int = 10,
    device_type: str = "desktop",
    proxy_location: str = "US",
) -> list[SearchResult]:
    """Search using serply.io's API and return the results as a list of SearchResult objects.

    Args:
        api_key (str): A serply.io API key
        query (str): The query to search for
        hl (str): Host Language code to display results in (reference https://developers.google.com/custom-search/docs/xml_results?hl=en#wsInterfaceLanguages)
        limit (int): The maximum number of results to return [10-100, defaults to 10]
    """
    log.info("Searching with Serply")

    url = "https://api.serply.io/v1/search/"

    query_payload = {
        "q": query,
        "language": "en",
        "num": limit,
        "gl": proxy_location.upper(),
        "hl": hl.lower(),
    }

    url = f"{url}{urlencode(query_payload)}"
    headers = {
        "X-API-KEY": api_key,
        "X-User-Agent": device_type,
        "User-Agent": "open-webui",
        "X-Proxy-Location": proxy_location,
    }

    response = requests.request("GET", url, headers=headers)
    response.raise_for_status()

    json_response = response.json()
    log.info(f"results from serply search: {json_response}")

    results = sorted(
        json_response.get("results", []), key=lambda x: x.get("realPosition", 0)
    )

    return [
        SearchResult(
            link=result["link"],
            title=result.get("title"),
            snippet=result.get("description"),
        )
        for result in results[:count]
    ]
@@ -4,14 +4,14 @@ import logging
import requests

from apps.rag.search.main import SearchResult
from config import SRC_LOG_LEVELS, RAG_WEB_SEARCH_RESULT_COUNT
from config import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])


def search_serpstack(
    api_key: str, query: str, https_enabled: bool = True
    api_key: str, query: str, count: int, https_enabled: bool = True
) -> list[SearchResult]:
    """Search using serpstack.com's and return the results as a list of SearchResult objects.

@@ -39,5 +39,5 @@ def search_serpstack(
        SearchResult(
            link=result["url"], title=result.get("title"), snippet=result.get("snippet")
        )
        for result in results[:RAG_WEB_SEARCH_RESULT_COUNT]
        for result in results[:count]
    ]

206 backend/apps/rag/search/testdata/serply.json vendored Normal file
@@ -0,0 +1,206 @@
{
  "ads": [],
  "ads_count": 0,
  "answers": [],
  "results": [
    {
      "title": "Apple",
      "link": "https://www.apple.com/",
      "description": "Discover the innovative world of Apple and shop everything iPhone, iPad, Apple Watch, Mac, and Apple TV, plus explore accessories, entertainment, ...",
      "additional_links": [
        {
          "text": "AppleApplehttps://www.apple.com",
          "href": "https://www.apple.com/"
        }
      ],
      "cite": {},
      "subdomains": [
        {
          "title": "Support",
          "link": "https://support.apple.com/",
          "description": "SupportContact - iPhone Support - Billing and Subscriptions - Apple Repair"
        },
        {
          "title": "Store",
          "link": "https://www.apple.com/store",
          "description": "StoreShop iPhone - Shop iPad - App Store - Shop Mac - ..."
        },
        {
          "title": "Mac",
          "link": "https://www.apple.com/mac/",
          "description": "MacMacBook Air - MacBook Pro - iMac - Compare Mac models - Mac mini"
        },
        {
          "title": "iPad",
          "link": "https://www.apple.com/ipad/",
          "description": "iPadShop iPad - iPad Pro - iPad Air - Compare iPad models - ..."
        },
        {
          "title": "Watch",
          "link": "https://www.apple.com/watch/",
          "description": "WatchShop Apple Watch - Series 9 - SE - Ultra 2 - Nike - Hermès - ..."
        }
      ],
      "realPosition": 1
    },
    {
      "title": "Apple",
      "link": "https://www.apple.com/",
      "description": "Discover the innovative world of Apple and shop everything iPhone, iPad, Apple Watch, Mac, and Apple TV, plus explore accessories, entertainment, ...",
      "additional_links": [
        {
          "text": "AppleApplehttps://www.apple.com",
          "href": "https://www.apple.com/"
        }
      ],
      "cite": {},
      "realPosition": 2
    },
    {
      "title": "Apple Inc.",
      "link": "https://en.wikipedia.org/wiki/Apple_Inc.",
      "description": "Apple Inc. (formerly Apple Computer, Inc.) is an American multinational corporation and technology company headquartered in Cupertino, California, ...",
      "additional_links": [
        {
          "text": "Apple Inc.Wikipediahttps://en.wikipedia.org › wiki › Apple_Inc",
          "href": "https://en.wikipedia.org/wiki/Apple_Inc."
        },
        {
          "text": "",
          "href": "https://en.wikipedia.org/wiki/Apple_Inc."
        },
        {
          "text": "History",
          "href": "https://en.wikipedia.org/wiki/History_of_Apple_Inc."
        },
        {
          "text": "List of Apple products",
          "href": "https://en.wikipedia.org/wiki/List_of_Apple_products"
        },
        {
          "text": "Litigation involving Apple Inc.",
          "href": "https://en.wikipedia.org/wiki/Litigation_involving_Apple_Inc."
        },
        {
          "text": "Apple Park",
          "href": "https://en.wikipedia.org/wiki/Apple_Park"
        }
      ],
      "cite": {
        "domain": "https://en.wikipedia.org › wiki › Apple_Inc",
        "span": " › wiki › Apple_Inc"
      },
      "realPosition": 3
    },
    {
      "title": "Apple Inc. (AAPL) Company Profile & Facts",
      "link": "https://finance.yahoo.com/quote/AAPL/profile/",
      "description": "Apple Inc. designs, manufactures, and markets smartphones, personal computers, tablets, wearables, and accessories worldwide. The company offers iPhone, a line ...",
      "additional_links": [
        {
          "text": "Apple Inc. (AAPL) Company Profile & FactsYahoo Financehttps://finance.yahoo.com › quote › AAPL › profile",
          "href": "https://finance.yahoo.com/quote/AAPL/profile/"
        }
      ],
      "cite": {
        "domain": "https://finance.yahoo.com › quote › AAPL › profile",
        "span": " › quote › AAPL › profile"
      },
      "realPosition": 4
    },
    {
      "title": "Apple Inc - Company Profile and News",
      "link": "https://www.bloomberg.com/profile/company/AAPL:US",
      "description": "Apple Inc. Apple Inc. designs, manufactures, and markets smartphones, personal computers, tablets, wearables and accessories, and sells a variety of related ...",
      "additional_links": [
        {
          "text": "Apple Inc - Company Profile and NewsBloomberghttps://www.bloomberg.com › company › AAPL:US",
          "href": "https://www.bloomberg.com/profile/company/AAPL:US"
        },
        {
          "text": "",
          "href": "https://www.bloomberg.com/profile/company/AAPL:US"
        }
      ],
      "cite": {
        "domain": "https://www.bloomberg.com › company › AAPL:US",
        "span": " › company › AAPL:US"
      },
      "realPosition": 5
    },
    {
      "title": "Apple Inc. | History, Products, Headquarters, & Facts",
      "link": "https://www.britannica.com/money/Apple-Inc",
      "description": "May 22, 2024 — Apple Inc. is an American multinational technology company that revolutionized the technology sector through its innovation of computer ...",
      "additional_links": [
        {
          "text": "Apple Inc. | History, Products, Headquarters, & FactsBritannicahttps://www.britannica.com › money › Apple-Inc",
          "href": "https://www.britannica.com/money/Apple-Inc"
        },
        {
          "text": "",
          "href": "https://www.britannica.com/money/Apple-Inc"
        }
      ],
      "cite": {
        "domain": "https://www.britannica.com › money › Apple-Inc",
        "span": " › money › Apple-Inc"
      },
      "realPosition": 6
    }
  ],
  "shopping_ads": [],
  "places": [
    {
      "title": "Apple Inc."
    },
    {
      "title": "Apple Inc"
    },
    {
      "title": "Apple Inc"
    }
  ],
  "related_searches": {
    "images": [],
    "text": [
      {
        "title": "apple inc full form",
        "link": "https://www.google.com/search?sca_esv=6b6df170a5c9891b&sca_upv=1&q=Apple+Inc+full+form&sa=X&ved=2ahUKEwjLxuSJwM-GAxUHODQIHYuJBhgQ1QJ6BAhPEAE"
      },
      {
        "title": "apple company history",
        "link": "https://www.google.com/search?sca_esv=6b6df170a5c9891b&sca_upv=1&q=Apple+company+history&sa=X&ved=2ahUKEwjLxuSJwM-GAxUHODQIHYuJBhgQ1QJ6BAhOEAE"
      },
      {
        "title": "apple store",
        "link": "https://www.google.com/search?sca_esv=6b6df170a5c9891b&sca_upv=1&q=Apple+Store&sa=X&ved=2ahUKEwjLxuSJwM-GAxUHODQIHYuJBhgQ1QJ6BAhQEAE"
      },
      {
        "title": "apple id",
        "link": "https://www.google.com/search?sca_esv=6b6df170a5c9891b&sca_upv=1&q=Apple+id&sa=X&ved=2ahUKEwjLxuSJwM-GAxUHODQIHYuJBhgQ1QJ6BAhSEAE"
      },
      {
        "title": "apple inc industry",
        "link": "https://www.google.com/search?sca_esv=6b6df170a5c9891b&sca_upv=1&q=Apple+Inc+industry&sa=X&ved=2ahUKEwjLxuSJwM-GAxUHODQIHYuJBhgQ1QJ6BAhREAE"
      },
      {
        "title": "apple login",
        "link": "https://www.google.com/search?sca_esv=6b6df170a5c9891b&sca_upv=1&q=Apple+login&sa=X&ved=2ahUKEwjLxuSJwM-GAxUHODQIHYuJBhgQ1QJ6BAhTEAE"
      }
    ]
  },
  "image_results": [],
  "carousel": [],
  "total": 2450000000,
  "knowledge_graph": "",
  "related_questions": [
    "What does the Apple Inc do?",
    "Why did Apple change to Apple Inc?",
    "Who owns Apple Inc.?",
    "What is Apple Inc best known for?"
  ],
  "carousel_count": 0,
  "ts": 2.491065263748169,
  "device_type": null
}
@@ -2,7 +2,7 @@ import os
import logging
import requests

from typing import List
from typing import List, Union

from apps.ollama.main import (
    generate_ollama_embeddings,
@@ -20,23 +20,8 @@ from langchain.retrievers import (

from typing import Optional

from apps.rag.search.brave import search_brave
from apps.rag.search.google_pse import search_google_pse
from apps.rag.search.main import SearchResult
from apps.rag.search.searxng import search_searxng
from apps.rag.search.serper import search_serper
from apps.rag.search.serpstack import search_serpstack
from config import (
    SRC_LOG_LEVELS,
    CHROMA_CLIENT,
    SEARXNG_QUERY_URL,
    GOOGLE_PSE_API_KEY,
    GOOGLE_PSE_ENGINE_ID,
    BRAVE_SEARCH_API_KEY,
    SERPSTACK_API_KEY,
    SERPSTACK_HTTPS,
    SERPER_API_KEY,
)
from utils.misc import get_last_user_message, add_or_update_system_message
from config import SRC_LOG_LEVELS, CHROMA_CLIENT

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -214,6 +199,7 @@ def get_embedding_function(
    embedding_function,
    openai_key,
    openai_url,
    batch_size,
):
    if embedding_engine == "":
        return lambda query: embedding_function.encode(query).tolist()
@@ -237,17 +223,22 @@

    def generate_multiple(query, f):
        if isinstance(query, list):
            return [f(q) for q in query]
            if embedding_engine == "openai":
                embeddings = []
                for i in range(0, len(query), batch_size):
                    embeddings.extend(f(query[i : i + batch_size]))
                return embeddings
            else:
                return [f(q) for q in query]
        else:
            return f(query)

    return lambda query: generate_multiple(query, func)
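
For illustration, a minimal standalone sketch (not part of this diff) of the batching pattern used above: a list of inputs is split into batch_size slices and the per-batch results are concatenated in order.

    def batched_apply(f, items, batch_size):
        out = []
        for i in range(0, len(items), batch_size):
            # one call per slice of up to batch_size items
            out.extend(f(items[i : i + batch_size]))
        return out

    # 5 items with batch_size=2 means 3 calls; output order is preserved
    assert batched_apply(lambda xs: [x.upper() for x in xs], list("abcde"), 2) == list("ABCDE")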
def rag_messages(
def get_rag_context(
    docs,
    messages,
    template,
    embedding_function,
    k,
    reranking_function,
@@ -255,31 +246,7 @@
    hybrid_search,
):
    log.debug(f"docs: {docs} {messages} {embedding_function} {reranking_function}")

    last_user_message_idx = None
    for i in range(len(messages) - 1, -1, -1):
        if messages[i]["role"] == "user":
            last_user_message_idx = i
            break

    user_message = messages[last_user_message_idx]

    if isinstance(user_message["content"], list):
        # Handle list content input
        content_type = "list"
        query = ""
        for content_item in user_message["content"]:
            if content_item["type"] == "text":
                query = content_item["text"]
                break
    elif isinstance(user_message["content"], str):
        # Handle text content input
        content_type = "text"
        query = user_message["content"]
    else:
        # Fallback in case the input does not match expected types
        content_type = None
        query = ""
    query = get_last_user_message(messages)

    extracted_collections = []
    relevant_contexts = []
@@ -350,33 +317,7 @@

    context_string = context_string.strip()

    ra_content = rag_template(
        template=template,
        context=context_string,
        query=query,
    )

    log.debug(f"ra_content: {ra_content}")

    if content_type == "list":
        new_content = []
        for content_item in user_message["content"]:
            if content_item["type"] == "text":
                # Update the text item's content with ra_content
                new_content.append({"type": "text", "text": ra_content})
            else:
                # Keep other types of content as they are
                new_content.append(content_item)
        new_user_message = {**user_message, "content": new_content}
    else:
        new_user_message = {
            **user_message,
            "content": ra_content,
        }

    messages[last_user_message_idx] = new_user_message

    return messages, citations
    return context_string, citations


def get_model_path(model: str, update_model: bool = False):
@@ -418,8 +359,22 @@

def generate_openai_embeddings(
    model: str, text: str, key: str, url: str = "https://api.openai.com/v1"
    model: str,
    text: Union[str, list[str]],
    key: str,
    url: str = "https://api.openai.com/v1",
):
    if isinstance(text, list):
        embeddings = generate_openai_batch_embeddings(model, text, key, url)
    else:
        embeddings = generate_openai_batch_embeddings(model, [text], key, url)

    return embeddings[0] if isinstance(text, str) else embeddings


def generate_openai_batch_embeddings(
    model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
) -> Optional[list[list[float]]]:
    try:
        r = requests.post(
            f"{url}/embeddings",
@@ -427,12 +382,12 @@ def generate_openai_embeddings(
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
            },
            json={"input": text, "model": model},
            json={"input": texts, "model": model},
        )
        r.raise_for_status()
        data = r.json()
        if "data" in data:
            return data["data"][0]["embedding"]
            return [elem["embedding"] for elem in data["data"]]
        else:
            raise "Something went wrong :/"
    except Exception as e:
@@ -536,31 +491,3 @@ class RerankCompressor(BaseDocumentCompressor):
        )
        final_results.append(doc)
        return final_results


def search_web(query: str) -> list[SearchResult]:
    """Search the web using a search engine and return the results as a list of SearchResult objects.
    Will look for a search engine API key in environment variables in the following order:
    - SEARXNG_QUERY_URL
    - GOOGLE_PSE_API_KEY + GOOGLE_PSE_ENGINE_ID
    - BRAVE_SEARCH_API_KEY
    - SERPSTACK_API_KEY
    - SERPER_API_KEY

    Args:
        query (str): The query to search for
    """

    # TODO: add playwright to search the web
    if SEARXNG_QUERY_URL:
        return search_searxng(SEARXNG_QUERY_URL, query)
    elif GOOGLE_PSE_API_KEY and GOOGLE_PSE_ENGINE_ID:
        return search_google_pse(GOOGLE_PSE_API_KEY, GOOGLE_PSE_ENGINE_ID, query)
    elif BRAVE_SEARCH_API_KEY:
        return search_brave(BRAVE_SEARCH_API_KEY, query)
    elif SERPSTACK_API_KEY:
        return search_serpstack(SERPSTACK_API_KEY, query, https_enabled=SERPSTACK_HTTPS)
    elif SERPER_API_KEY:
        return search_serper(SERPER_API_KEY, query)
    else:
        raise Exception("No search engine API key found in environment variables")

139 backend/apps/socket/main.py Normal file
@@ -0,0 +1,139 @@
import socketio
import asyncio


from apps.webui.models.users import Users
from utils.utils import decode_token

sio = socketio.AsyncServer(cors_allowed_origins=[], async_mode="asgi")
app = socketio.ASGIApp(sio, socketio_path="/ws/socket.io")

# Dictionary to maintain the user pool

SESSION_POOL = {}
USER_POOL = {}
USAGE_POOL = {}
# Timeout duration in seconds
TIMEOUT_DURATION = 3


@sio.event
async def connect(sid, environ, auth):
    user = None
    if auth and "token" in auth:
        data = decode_token(auth["token"])

        if data is not None and "id" in data:
            user = Users.get_user_by_id(data["id"])

        if user:
            SESSION_POOL[sid] = user.id
            if user.id in USER_POOL:
                USER_POOL[user.id].append(sid)
            else:
                USER_POOL[user.id] = [sid]

            print(f"user {user.name}({user.id}) connected with session ID {sid}")

            await sio.emit("user-count", {"count": len(set(USER_POOL))})
            await sio.emit("usage", {"models": get_models_in_use()})


@sio.on("user-join")
async def user_join(sid, data):
    print("user-join", sid, data)

    auth = data["auth"] if "auth" in data else None

    if auth and "token" in auth:
        data = decode_token(auth["token"])

        if data is not None and "id" in data:
            user = Users.get_user_by_id(data["id"])

            if user:

                SESSION_POOL[sid] = user.id
                if user.id in USER_POOL:
                    USER_POOL[user.id].append(sid)
                else:
                    USER_POOL[user.id] = [sid]

                print(f"user {user.name}({user.id}) connected with session ID {sid}")

                await sio.emit("user-count", {"count": len(set(USER_POOL))})


@sio.on("user-count")
async def user_count(sid):
    await sio.emit("user-count", {"count": len(set(USER_POOL))})


def get_models_in_use():
    # Aggregate all models in use
    models_in_use = []
    for model_id, data in USAGE_POOL.items():
        models_in_use.append(model_id)

    return models_in_use


@sio.on("usage")
async def usage(sid, data):

    model_id = data["model"]

    # Cancel previous callback if there is one
    if model_id in USAGE_POOL:
        USAGE_POOL[model_id]["callback"].cancel()

    # Store the new usage data and task

    if model_id in USAGE_POOL:
        USAGE_POOL[model_id]["sids"].append(sid)
        USAGE_POOL[model_id]["sids"] = list(set(USAGE_POOL[model_id]["sids"]))

    else:
        USAGE_POOL[model_id] = {"sids": [sid]}

    # Schedule a task to remove the usage data after TIMEOUT_DURATION
    USAGE_POOL[model_id]["callback"] = asyncio.create_task(
        remove_after_timeout(sid, model_id)
    )

    # Broadcast the usage data to all clients
    await sio.emit("usage", {"models": get_models_in_use()})


async def remove_after_timeout(sid, model_id):
    try:
        await asyncio.sleep(TIMEOUT_DURATION)
        if model_id in USAGE_POOL:
            print(USAGE_POOL[model_id]["sids"])
            USAGE_POOL[model_id]["sids"].remove(sid)
            USAGE_POOL[model_id]["sids"] = list(set(USAGE_POOL[model_id]["sids"]))

            if len(USAGE_POOL[model_id]["sids"]) == 0:
                del USAGE_POOL[model_id]

            # Broadcast the usage data to all clients
            await sio.emit("usage", {"models": get_models_in_use()})
    except asyncio.CancelledError:
        # Task was cancelled due to new 'usage' event
        pass


@sio.event
async def disconnect(sid):
    if sid in SESSION_POOL:
        user_id = SESSION_POOL[sid]
        del SESSION_POOL[sid]

        USER_POOL[user_id].remove(sid)

        if len(USER_POOL[user_id]) == 0:
            del USER_POOL[user_id]

        await sio.emit("user-count", {"count": len(USER_POOL)})
    else:
        print(f"Unknown session ID {sid} disconnected")
61 backend/apps/webui/internal/migrations/012_add_tools.py Normal file
@@ -0,0 +1,61 @@
"""Peewee migrations -- 012_add_tools.py.

Some examples (model - class or model name)::

    > Model = migrator.orm['table_name']            # Return model in current state by name
    > Model = migrator.ModelClass                   # Return model in current state by name

    > migrator.sql(sql)                             # Run custom SQL
    > migrator.run(func, *args, **kwargs)           # Run python function with the given args
    > migrator.create_model(Model)                  # Create a model (could be used as decorator)
    > migrator.remove_model(model, cascade=True)    # Remove a model
    > migrator.add_fields(model, **fields)          # Add fields to a model
    > migrator.change_fields(model, **fields)       # Change fields
    > migrator.remove_fields(model, *field_names, cascade=True)
    > migrator.rename_field(model, old_field_name, new_field_name)
    > migrator.rename_table(model, new_table_name)
    > migrator.add_index(model, *col_names, unique=False)
    > migrator.add_not_null(model, *field_names)
    > migrator.add_default(model, field_name, default)
    > migrator.add_constraint(model, name, sql)
    > migrator.drop_index(model, *col_names)
    > migrator.drop_not_null(model, *field_names)
    > migrator.drop_constraints(model, *constraints)

"""

from contextlib import suppress

import peewee as pw
from peewee_migrate import Migrator


with suppress(ImportError):
    import playhouse.postgres_ext as pw_pext


def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your migrations here."""

    @migrator.create_model
    class Tool(pw.Model):
        id = pw.TextField(unique=True)
        user_id = pw.TextField()

        name = pw.TextField()
        content = pw.TextField()
        specs = pw.TextField()

        meta = pw.TextField()

        created_at = pw.BigIntegerField(null=False)
        updated_at = pw.BigIntegerField(null=False)

        class Meta:
            table_name = "tool"


def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
    """Write your rollback migrations here."""

    migrator.remove_model("tool")

@@ -6,6 +6,7 @@ from apps.webui.routers import (
    users,
    chats,
    documents,
    tools,
    models,
    prompts,
    configs,
@@ -14,6 +15,8 @@ from apps.webui.routers import (
)
from config import (
    WEBUI_BUILD_HASH,
    SHOW_ADMIN_DETAILS,
    ADMIN_EMAIL,
    WEBUI_AUTH,
    DEFAULT_MODELS,
    DEFAULT_PROMPT_SUGGESTIONS,
@@ -24,8 +27,8 @@ from config import (
    WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
    JWT_EXPIRES_IN,
    WEBUI_BANNERS,
    AppConfig,
    ENABLE_COMMUNITY_SHARING,
    AppConfig,
)

app = FastAPI()
@@ -36,6 +39,12 @@ app.state.config = AppConfig()

app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER


app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
app.state.config.ADMIN_EMAIL = ADMIN_EMAIL


app.state.config.DEFAULT_MODELS = DEFAULT_MODELS
app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
@@ -47,7 +56,7 @@ app.state.config.BANNERS = WEBUI_BANNERS
app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING

app.state.MODELS = {}
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
app.state.TOOLS = {}


app.add_middleware(
@@ -63,6 +72,7 @@ app.include_router(users.router, prefix="/users", tags=["users"])
app.include_router(chats.router, prefix="/chats", tags=["chats"])

app.include_router(documents.router, prefix="/documents", tags=["documents"])
app.include_router(tools.router, prefix="/tools", tags=["tools"])
app.include_router(models.router, prefix="/models", tags=["models"])
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
app.include_router(memories.router, prefix="/memories", tags=["memories"])

@@ -298,6 +298,15 @@ class ChatTable:
            # .limit(limit).offset(skip)
        ]

    def get_archived_chats_by_user_id(self, user_id: str) -> List[ChatModel]:
        return [
            ChatModel(**model_to_dict(chat))
            for chat in Chat.select()
            .where(Chat.archived == True)
            .where(Chat.user_id == user_id)
            .order_by(Chat.updated_at.desc())
        ]

    def delete_chat_by_id(self, id: str) -> bool:
        try:
            query = Chat.delete().where((Chat.id == id))

132 backend/apps/webui/models/tools.py Normal file
@@ -0,0 +1,132 @@
from pydantic import BaseModel
from peewee import *
from playhouse.shortcuts import model_to_dict
from typing import List, Union, Optional
import time
import logging
from apps.webui.internal.db import DB, JSONField

import json

from config import SRC_LOG_LEVELS

log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["MODELS"])

####################
# Tools DB Schema
####################


class Tool(Model):
    id = CharField(unique=True)
    user_id = CharField()
    name = TextField()
    content = TextField()
    specs = JSONField()
    meta = JSONField()
    updated_at = BigIntegerField()
    created_at = BigIntegerField()

    class Meta:
        database = DB


class ToolMeta(BaseModel):
    description: Optional[str] = None


class ToolModel(BaseModel):
    id: str
    user_id: str
    name: str
    content: str
    specs: List[dict]
    meta: ToolMeta
    updated_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch


####################
# Forms
####################


class ToolResponse(BaseModel):
    id: str
    user_id: str
    name: str
    meta: ToolMeta
    updated_at: int  # timestamp in epoch
    created_at: int  # timestamp in epoch


class ToolForm(BaseModel):
    id: str
    name: str
    content: str
    meta: ToolMeta


class ToolsTable:
    def __init__(self, db):
        self.db = db
        self.db.create_tables([Tool])

    def insert_new_tool(
        self, user_id: str, form_data: ToolForm, specs: List[dict]
    ) -> Optional[ToolModel]:
        tool = ToolModel(
            **{
                **form_data.model_dump(),
                "specs": specs,
                "user_id": user_id,
                "updated_at": int(time.time()),
                "created_at": int(time.time()),
            }
        )

        try:
            result = Tool.create(**tool.model_dump())
            if result:
                return tool
            else:
                return None
        except Exception as e:
            print(f"Error creating tool: {e}")
            return None

    def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
        try:
            tool = Tool.get(Tool.id == id)
            return ToolModel(**model_to_dict(tool))
        except:
            return None

    def get_tools(self) -> List[ToolModel]:
        return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()]

    def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
        try:
            query = Tool.update(
                **updated,
                updated_at=int(time.time()),
            ).where(Tool.id == id)
            query.execute()

            tool = Tool.get(Tool.id == id)
            return ToolModel(**model_to_dict(tool))
        except:
            return None

    def delete_tool_by_id(self, id: str) -> bool:
        try:
            query = Tool.delete().where((Tool.id == id))
            query.execute()  # Remove the rows, return number of rows removed.

            return True
        except:
            return False


Tools = ToolsTable(DB)
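
For illustration, a minimal sketch (not part of this diff) of inserting and reading back a tool through the ToolsTable singleton; the id, content, and specs values are hypothetical and assume an initialized database.

    form = ToolForm(
        id="demo_tool",
        name="Demo Tool",
        content="def run():\n    return 'ok'",
        meta=ToolMeta(description="example toolkit"),
    )
    tool = Tools.insert_new_tool("user-1", form, specs=[{"name": "run", "parameters": {}}])
    assert tool is not None
    assert Tools.get_tool_by_id("demo_tool").name == "Demo Tool"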
@@ -269,73 +269,88 @@ async def add_user(form_data: AddUserForm, user=Depends(get_admin_user)):
        raise HTTPException(500, detail=ERROR_MESSAGES.DEFAULT(err))


############################
# GetAdminDetails
############################


@router.get("/admin/details")
async def get_admin_details(request: Request, user=Depends(get_current_user)):
    if request.app.state.config.SHOW_ADMIN_DETAILS:
        admin_email = request.app.state.config.ADMIN_EMAIL
        admin_name = None

        print(admin_email, admin_name)

        if admin_email:
            admin = Users.get_user_by_email(admin_email)
            if admin:
                admin_name = admin.name
        else:
            admin = Users.get_first_user()
            if admin:
                admin_email = admin.email
                admin_name = admin.name

        return {
            "name": admin_name,
            "email": admin_email,
        }
    else:
        raise HTTPException(400, detail=ERROR_MESSAGES.ACTION_PROHIBITED)


############################
# ToggleSignUp
############################


@router.get("/signup/enabled", response_model=bool)
async def get_sign_up_status(request: Request, user=Depends(get_admin_user)):
    return request.app.state.config.ENABLE_SIGNUP
@router.get("/admin/config")
async def get_admin_config(request: Request, user=Depends(get_admin_user)):
    return {
        "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
        "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
        "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
        "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
        "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
    }


@router.get("/signup/enabled/toggle", response_model=bool)
async def toggle_sign_up(request: Request, user=Depends(get_admin_user)):
    request.app.state.config.ENABLE_SIGNUP = not request.app.state.config.ENABLE_SIGNUP
    return request.app.state.config.ENABLE_SIGNUP
class AdminConfig(BaseModel):
    SHOW_ADMIN_DETAILS: bool
    ENABLE_SIGNUP: bool
    DEFAULT_USER_ROLE: str
    JWT_EXPIRES_IN: str
    ENABLE_COMMUNITY_SHARING: bool


############################
# Default User Role
############################


@router.get("/signup/user/role")
async def get_default_user_role(request: Request, user=Depends(get_admin_user)):
    return request.app.state.config.DEFAULT_USER_ROLE


class UpdateRoleForm(BaseModel):
    role: str


@router.post("/signup/user/role")
async def update_default_user_role(
    request: Request, form_data: UpdateRoleForm, user=Depends(get_admin_user)
@router.post("/admin/config")
async def update_admin_config(
    request: Request, form_data: AdminConfig, user=Depends(get_admin_user)
):
    if form_data.role in ["pending", "user", "admin"]:
        request.app.state.config.DEFAULT_USER_ROLE = form_data.role
    return request.app.state.config.DEFAULT_USER_ROLE
    request.app.state.config.SHOW_ADMIN_DETAILS = form_data.SHOW_ADMIN_DETAILS
    request.app.state.config.ENABLE_SIGNUP = form_data.ENABLE_SIGNUP

    if form_data.DEFAULT_USER_ROLE in ["pending", "user", "admin"]:
        request.app.state.config.DEFAULT_USER_ROLE = form_data.DEFAULT_USER_ROLE

############################
# JWT Expiration
############################


@router.get("/token/expires")
async def get_token_expires_duration(request: Request, user=Depends(get_admin_user)):
    return request.app.state.config.JWT_EXPIRES_IN


class UpdateJWTExpiresDurationForm(BaseModel):
    duration: str


@router.post("/token/expires/update")
async def update_token_expires_duration(
    request: Request,
    form_data: UpdateJWTExpiresDurationForm,
    user=Depends(get_admin_user),
):
    pattern = r"^(-1|0|(-?\d+(\.\d+)?)(ms|s|m|h|d|w))$"

    # Check if the input string matches the pattern
    if re.match(pattern, form_data.duration):
        request.app.state.config.JWT_EXPIRES_IN = form_data.duration
        return request.app.state.config.JWT_EXPIRES_IN
    else:
        return request.app.state.config.JWT_EXPIRES_IN
    if re.match(pattern, form_data.JWT_EXPIRES_IN):
        request.app.state.config.JWT_EXPIRES_IN = form_data.JWT_EXPIRES_IN

    request.app.state.config.ENABLE_COMMUNITY_SHARING = (
        form_data.ENABLE_COMMUNITY_SHARING
    )

    return {
        "SHOW_ADMIN_DETAILS": request.app.state.config.SHOW_ADMIN_DETAILS,
        "ENABLE_SIGNUP": request.app.state.config.ENABLE_SIGNUP,
        "DEFAULT_USER_ROLE": request.app.state.config.DEFAULT_USER_ROLE,
        "JWT_EXPIRES_IN": request.app.state.config.JWT_EXPIRES_IN,
        "ENABLE_COMMUNITY_SHARING": request.app.state.config.ENABLE_COMMUNITY_SHARING,
    }


############################

@@ -113,6 +113,19 @@ async def get_user_chats(user=Depends(get_current_user)):
    ]


############################
# GetArchivedChats
############################


@router.get("/all/archived", response_model=List[ChatResponse])
async def get_user_chats(user=Depends(get_current_user)):
    return [
        ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
        for chat in Chats.get_archived_chats_by_user_id(user.id)
    ]


############################
# GetAllChatsInDB
############################
@@ -148,7 +161,7 @@ async def get_archived_session_user_chat_list(
############################


@router.post("/archive/all", response_model=List[ChatTitleIdResponse])
@router.post("/archive/all", response_model=bool)
async def archive_all_chats(user=Depends(get_current_user)):
    return Chats.archive_all_chats_by_user_id(user.id)

@@ -288,6 +301,32 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_
    return result


############################
# CloneChat
############################


@router.get("/{id}/clone", response_model=Optional[ChatResponse])
async def clone_chat_by_id(id: str, user=Depends(get_current_user)):
    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
    if chat:

        chat_body = json.loads(chat.chat)
        updated_chat = {
            **chat_body,
            "originalChatId": chat.id,
            "branchPointMessageId": chat_body["history"]["currentId"],
            "title": f"Clone of {chat.title}",
        }

        chat = Chats.insert_new_chat(user.id, ChatForm(**{"chat": updated_chat}))
        return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
    else:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.DEFAULT()
        )


############################
# ArchiveChat
############################

@@ -73,7 +73,7 @@ async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)):
############################


@router.get("/name/{name}", response_model=Optional[DocumentResponse])
@router.get("/doc", response_model=Optional[DocumentResponse])
async def get_doc_by_name(name: str, user=Depends(get_current_user)):
    doc = Documents.get_doc_by_name(name)

@@ -105,7 +105,7 @@ class TagDocumentForm(BaseModel):
    tags: List[dict]


@router.post("/name/{name}/tags", response_model=Optional[DocumentResponse])
@router.post("/doc/tags", response_model=Optional[DocumentResponse])
async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_user)):
    doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags})

@@ -128,7 +128,7 @@ async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_u
############################


@router.post("/name/{name}/update", response_model=Optional[DocumentResponse])
@router.post("/doc/update", response_model=Optional[DocumentResponse])
async def update_doc_by_name(
    name: str, form_data: DocumentUpdateForm, user=Depends(get_admin_user)
):
@@ -152,7 +152,7 @@ async def update_doc_by_name(
############################


@router.delete("/name/{name}/delete", response_model=bool)
@router.delete("/doc/delete", response_model=bool)
async def delete_doc_by_name(name: str, user=Depends(get_admin_user)):
    result = Documents.delete_doc_by_name(name)
    return result

backend/apps/webui/routers/tools.py (new file)
@@ -0,0 +1,183 @@
from fastapi import Depends, FastAPI, HTTPException, status, Request
from datetime import datetime, timedelta
from typing import List, Union, Optional

from fastapi import APIRouter
from pydantic import BaseModel
import json

from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
from apps.webui.utils import load_toolkit_module_by_id

from utils.utils import get_current_user, get_admin_user
from utils.tools import get_tools_specs
from constants import ERROR_MESSAGES

from importlib import util
import os

from config import DATA_DIR


TOOLS_DIR = f"{DATA_DIR}/tools"
os.makedirs(TOOLS_DIR, exist_ok=True)


router = APIRouter()

############################
# GetToolkits
############################


@router.get("/", response_model=List[ToolResponse])
async def get_toolkits(user=Depends(get_current_user)):
    toolkits = [toolkit for toolkit in Tools.get_tools()]
    return toolkits


############################
# ExportToolKits
############################


@router.get("/export", response_model=List[ToolModel])
async def get_toolkits(user=Depends(get_admin_user)):
    toolkits = [toolkit for toolkit in Tools.get_tools()]
    return toolkits


############################
# CreateNewToolKit
############################


@router.post("/create", response_model=Optional[ToolResponse])
async def create_new_toolkit(
    request: Request, form_data: ToolForm, user=Depends(get_admin_user)
):
    if not form_data.id.isidentifier():
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Only alphanumeric characters and underscores are allowed in the id",
        )

    form_data.id = form_data.id.lower()

    toolkit = Tools.get_tool_by_id(form_data.id)
    if toolkit is None:
        toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
        try:
            with open(toolkit_path, "w") as tool_file:
                tool_file.write(form_data.content)

            toolkit_module = load_toolkit_module_by_id(form_data.id)

            TOOLS = request.app.state.TOOLS
            TOOLS[form_data.id] = toolkit_module

            specs = get_tools_specs(TOOLS[form_data.id])
            toolkit = Tools.insert_new_tool(user.id, form_data, specs)

            if toolkit:
                return toolkit
            else:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=ERROR_MESSAGES.DEFAULT("Error creating toolkit"),
                )
        except Exception as e:
            print(e)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT(e),
            )
    else:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.ID_TAKEN,
        )


############################
# GetToolkitById
############################


@router.get("/id/{id}", response_model=Optional[ToolModel])
async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
    toolkit = Tools.get_tool_by_id(id)

    if toolkit:
        return toolkit
    else:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )


############################
# UpdateToolkitById
############################


@router.post("/id/{id}/update", response_model=Optional[ToolModel])
async def update_toolkit_by_id(
    request: Request, id: str, form_data: ToolForm, user=Depends(get_admin_user)
):
    toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")

    try:
        with open(toolkit_path, "w") as tool_file:
            tool_file.write(form_data.content)

        toolkit_module = load_toolkit_module_by_id(id)

        TOOLS = request.app.state.TOOLS
        TOOLS[id] = toolkit_module

        specs = get_tools_specs(TOOLS[id])

        updated = {
            **form_data.model_dump(exclude={"id"}),
            "specs": specs,
        }

        print(updated)
        toolkit = Tools.update_tool_by_id(id, updated)

        if toolkit:
            return toolkit
        else:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=ERROR_MESSAGES.DEFAULT("Error updating toolkit"),
            )

    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )


############################
# DeleteToolkitById
############################


@router.delete("/id/{id}/delete", response_model=bool)
async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)):
    result = Tools.delete_tool_by_id(id)

    if result:
        TOOLS = request.app.state.TOOLS
        if id in TOOLS:
            del TOOLS[id]

        # delete the toolkit file
        toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
        os.remove(toolkit_path)

    return result
@@ -19,7 +19,12 @@ from apps.webui.models.users import (
from apps.webui.models.auths import Auths
from apps.webui.models.chats import Chats

from utils.utils import get_verified_user, get_password_hash, get_admin_user
from utils.utils import (
    get_verified_user,
    get_password_hash,
    get_current_user,
    get_admin_user,
)
from constants import ERROR_MESSAGES

from config import SRC_LOG_LEVELS

@@ -7,6 +7,8 @@ from pydantic import BaseModel

from fpdf import FPDF
import markdown
import black


from apps.webui.internal.db import DB
from utils.utils import get_admin_user
@@ -26,6 +28,21 @@ async def get_gravatar(
    return get_gravatar_url(email)


class CodeFormatRequest(BaseModel):
    code: str


@router.post("/code/format")
async def format_code(request: CodeFormatRequest):
    try:
        formatted_code = black.format_str(request.code, mode=black.Mode())
        return {"code": formatted_code}
    except black.NothingChanged:
        return {"code": request.code}
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))
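# Illustrative request (the mount prefix for this router is an assumption, not
# shown in this diff):
#
#     curl -X POST http://localhost:8080/api/v1/utils/code/format \
#          -H "Content-Type: application/json" -d '{"code": "x=1;y=2"}'
#
# black.format_str returns the code reformatted to Black style, e.g. "x = 1\ny = 2\n".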


class MarkdownForm(BaseModel):
    md: str

@@ -107,3 +124,12 @@ async def download_db(user=Depends(get_admin_user)):
        media_type="application/octet-stream",
        filename="webui.db",
    )


@router.get("/litellm/config")
async def download_litellm_config_yaml(user=Depends(get_admin_user)):
    return FileResponse(
        f"{DATA_DIR}/litellm/config.yaml",
        media_type="application/octet-stream",
        filename="config.yaml",
    )

backend/apps/webui/utils.py (new file)
@@ -0,0 +1,23 @@
from importlib import util
import os

from config import TOOLS_DIR


def load_toolkit_module_by_id(toolkit_id):
    toolkit_path = os.path.join(TOOLS_DIR, f"{toolkit_id}.py")
    spec = util.spec_from_file_location(toolkit_id, toolkit_path)
    module = util.module_from_spec(spec)

    try:
        spec.loader.exec_module(module)
        print(f"Loaded module: {module.__name__}")
        if hasattr(module, "Tools"):
            return module.Tools()
        else:
            raise Exception("No Tools class found")
    except Exception as e:
        print(f"Error loading module: {toolkit_id}")
        # Move the file to the error folder
        os.rename(toolkit_path, f"{toolkit_path}.error")
        raise e
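# Minimal usage sketch (illustrative, not from the commit): a toolkit file that
# this loader accepts only needs a top-level `Tools` class, e.g. TOOLS_DIR/echo.py:
#
#     class Tools:
#         def echo(self, message: str) -> str:
#             """Return the message unchanged."""
#             return message
#
# `load_toolkit_module_by_id("echo")` then returns a `Tools()` instance; a file
# without a `Tools` class is renamed to `<path>.error` and the error re-raised.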
@@ -180,6 +180,17 @@ WEBUI_BUILD_HASH = os.environ.get("WEBUI_BUILD_HASH", "dev-build")
DATA_DIR = Path(os.getenv("DATA_DIR", BACKEND_DIR / "data")).resolve()
FRONTEND_BUILD_DIR = Path(os.getenv("FRONTEND_BUILD_DIR", BASE_DIR / "build")).resolve()

RESET_CONFIG_ON_START = (
    os.environ.get("RESET_CONFIG_ON_START", "False").lower() == "true"
)
if RESET_CONFIG_ON_START:
    try:
        os.remove(f"{DATA_DIR}/config.json")
        with open(f"{DATA_DIR}/config.json", "w") as f:
            f.write("{}")
    except:
        pass

try:
    CONFIG_DATA = json.loads((DATA_DIR / "config.json").read_text())
except:
@@ -295,7 +306,11 @@ STATIC_DIR = Path(os.getenv("STATIC_DIR", BACKEND_DIR / "static")).resolve()

frontend_favicon = FRONTEND_BUILD_DIR / "favicon.png"
if frontend_favicon.exists():
    shutil.copyfile(frontend_favicon, STATIC_DIR / "favicon.png")
    try:
        shutil.copyfile(frontend_favicon, STATIC_DIR / "favicon.png")
    except Exception as e:
        logging.error(f"An error occurred: {e}")

else:
    logging.warning(f"Frontend favicon not found at {frontend_favicon}")

@@ -353,6 +368,14 @@ DOCS_DIR = os.getenv("DOCS_DIR", f"{DATA_DIR}/docs")
Path(DOCS_DIR).mkdir(parents=True, exist_ok=True)


####################################
# Tools DIR
####################################

TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)


####################################
# LITELLM_CONFIG
####################################
@@ -590,6 +613,92 @@ WEBUI_BANNERS = PersistentConfig(
    [BannerModel(**banner) for banner in json.loads("[]")],
)


SHOW_ADMIN_DETAILS = PersistentConfig(
    "SHOW_ADMIN_DETAILS",
    "auth.admin.show",
    os.environ.get("SHOW_ADMIN_DETAILS", "true").lower() == "true",
)

ADMIN_EMAIL = PersistentConfig(
    "ADMIN_EMAIL",
    "auth.admin.email",
    os.environ.get("ADMIN_EMAIL", None),
)


####################################
# TASKS
####################################


TASK_MODEL = PersistentConfig(
    "TASK_MODEL",
    "task.model.default",
    os.environ.get("TASK_MODEL", ""),
)

TASK_MODEL_EXTERNAL = PersistentConfig(
    "TASK_MODEL_EXTERNAL",
    "task.model.external",
    os.environ.get("TASK_MODEL_EXTERNAL", ""),
)

TITLE_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
    "TITLE_GENERATION_PROMPT_TEMPLATE",
    "task.title.prompt_template",
    os.environ.get(
        "TITLE_GENERATION_PROMPT_TEMPLATE",
        """Here is the query:
{{prompt:middletruncate:8000}}

Create a concise, 3-5 word phrase with an emoji as a title for the previous query. Suitable Emojis for the summary can be used to enhance understanding but avoid quotation marks or special formatting. RESPOND ONLY WITH THE TITLE TEXT.

Examples of titles:
📉 Stock Market Trends
🍪 Perfect Chocolate Chip Recipe
Evolution of Music Streaming
Remote Work Productivity Tips
Artificial Intelligence in Healthcare
🎮 Video Game Development Insights""",
    ),
)
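# The {{prompt:middletruncate:8000}} placeholder is expanded elsewhere (utils/task.py,
# not fully shown here); a plausible middle-truncation, offered only as a sketch,
# keeps the head and tail of an over-long prompt:
#
#     def middletruncate(prompt: str, length: int) -> str:
#         if len(prompt) <= length:
#             return prompt
#         half = length // 2
#         return prompt[:half] + "..." + prompt[-(length - half):]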


SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = PersistentConfig(
    "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
    "task.search.prompt_template",
    os.environ.get(
        "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE",
        """You are tasked with generating web search queries. Give me an appropriate query to answer my question for google search. Answer with only the query. Today is {{CURRENT_DATE}}.

Question:
{{prompt:end:4000}}""",
    ),
)

SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig(
    "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
    "task.search.prompt_length_threshold",
    int(
        os.environ.get(
            "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
            100,
        )
    ),
)

TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
    "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
    "task.tools.prompt_template",
    os.environ.get(
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
        """Tools: {{TOOLS}}
If a function tool doesn't match the query, return an empty string. Else, pick a function tool, fill in the parameters from the function tool's schema, and return it in the format { "name": \"functionName\", "parameters": { "key": "value" } }. Only pick a function if the user asks. Only return the object. Do not return any other text.""",
    ),
)


####################################
# WEBUI_SECRET_KEY
####################################
@@ -672,6 +781,12 @@ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
    os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
)

RAG_EMBEDDING_OPENAI_BATCH_SIZE = PersistentConfig(
    "RAG_EMBEDDING_OPENAI_BATCH_SIZE",
    "rag.embedding_openai_batch_size",
    os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", 1),
)

RAG_RERANKING_MODEL = PersistentConfig(
    "RAG_RERANKING_MODEL",
    "rag.reranking_model",
@@ -766,28 +881,81 @@ YOUTUBE_LOADER_LANGUAGE = PersistentConfig(
    os.getenv("YOUTUBE_LOADER_LANGUAGE", "en").split(","),
)

SEARXNG_QUERY_URL = os.getenv("SEARXNG_QUERY_URL", "")
GOOGLE_PSE_API_KEY = os.getenv("GOOGLE_PSE_API_KEY", "")
GOOGLE_PSE_ENGINE_ID = os.getenv("GOOGLE_PSE_ENGINE_ID", "")
BRAVE_SEARCH_API_KEY = os.getenv("BRAVE_SEARCH_API_KEY", "")
SERPSTACK_API_KEY = os.getenv("SERPSTACK_API_KEY", "")
SERPSTACK_HTTPS = os.getenv("SERPSTACK_HTTPS", "True").lower() == "true"
SERPER_API_KEY = os.getenv("SERPER_API_KEY", "")


RAG_WEB_SEARCH_ENABLED = (
    SEARXNG_QUERY_URL != ""
    or (GOOGLE_PSE_API_KEY != "" and GOOGLE_PSE_ENGINE_ID != "")
    or BRAVE_SEARCH_API_KEY != ""
    or SERPSTACK_API_KEY != ""
    or SERPER_API_KEY != ""
ENABLE_RAG_WEB_SEARCH = PersistentConfig(
    "ENABLE_RAG_WEB_SEARCH",
    "rag.web.search.enable",
    os.getenv("ENABLE_RAG_WEB_SEARCH", "False").lower() == "true",
)

RAG_WEB_SEARCH_RESULT_COUNT = int(os.getenv("RAG_WEB_SEARCH_RESULT_COUNT", "3"))
RAG_WEB_SEARCH_CONCURRENT_REQUESTS = int(
    os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")
RAG_WEB_SEARCH_ENGINE = PersistentConfig(
    "RAG_WEB_SEARCH_ENGINE",
    "rag.web.search.engine",
    os.getenv("RAG_WEB_SEARCH_ENGINE", ""),
)

SEARXNG_QUERY_URL = PersistentConfig(
    "SEARXNG_QUERY_URL",
    "rag.web.search.searxng_query_url",
    os.getenv("SEARXNG_QUERY_URL", ""),
)

GOOGLE_PSE_API_KEY = PersistentConfig(
    "GOOGLE_PSE_API_KEY",
    "rag.web.search.google_pse_api_key",
    os.getenv("GOOGLE_PSE_API_KEY", ""),
)

GOOGLE_PSE_ENGINE_ID = PersistentConfig(
    "GOOGLE_PSE_ENGINE_ID",
    "rag.web.search.google_pse_engine_id",
    os.getenv("GOOGLE_PSE_ENGINE_ID", ""),
)

BRAVE_SEARCH_API_KEY = PersistentConfig(
    "BRAVE_SEARCH_API_KEY",
    "rag.web.search.brave_search_api_key",
    os.getenv("BRAVE_SEARCH_API_KEY", ""),
)

SERPSTACK_API_KEY = PersistentConfig(
    "SERPSTACK_API_KEY",
    "rag.web.search.serpstack_api_key",
    os.getenv("SERPSTACK_API_KEY", ""),
)

SERPSTACK_HTTPS = PersistentConfig(
    "SERPSTACK_HTTPS",
    "rag.web.search.serpstack_https",
    os.getenv("SERPSTACK_HTTPS", "True").lower() == "true",
)

SERPER_API_KEY = PersistentConfig(
    "SERPER_API_KEY",
    "rag.web.search.serper_api_key",
    os.getenv("SERPER_API_KEY", ""),
)

SERPLY_API_KEY = PersistentConfig(
    "SERPLY_API_KEY",
    "rag.web.search.serply_api_key",
    os.getenv("SERPLY_API_KEY", ""),
)


RAG_WEB_SEARCH_RESULT_COUNT = PersistentConfig(
    "RAG_WEB_SEARCH_RESULT_COUNT",
    "rag.web.search.result_count",
    int(os.getenv("RAG_WEB_SEARCH_RESULT_COUNT", "3")),
)

RAG_WEB_SEARCH_CONCURRENT_REQUESTS = PersistentConfig(
    "RAG_WEB_SEARCH_CONCURRENT_REQUESTS",
    "rag.web.search.concurrent_requests",
    int(os.getenv("RAG_WEB_SEARCH_CONCURRENT_REQUESTS", "10")),
)
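# Example environment configuration for web search (illustrative; the accepted
# engine names and the SearXNG URL shape are assumptions, not shown in this hunk):
#
#     ENABLE_RAG_WEB_SEARCH=true
#     RAG_WEB_SEARCH_ENGINE=searxng
#     SEARXNG_QUERY_URL=http://searxng:8080/search?q=<query>
#     RAG_WEB_SEARCH_RESULT_COUNT=3
#     RAG_WEB_SEARCH_CONCURRENT_REQUESTS=10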


####################################
# Transcribe
####################################
@@ -855,25 +1023,59 @@ IMAGE_GENERATION_MODEL = PersistentConfig(
# Audio
####################################

AUDIO_OPENAI_API_BASE_URL = PersistentConfig(
    "AUDIO_OPENAI_API_BASE_URL",
    "audio.openai.api_base_url",
    os.getenv("AUDIO_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
AUDIO_STT_OPENAI_API_BASE_URL = PersistentConfig(
    "AUDIO_STT_OPENAI_API_BASE_URL",
    "audio.stt.openai.api_base_url",
    os.getenv("AUDIO_STT_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
AUDIO_OPENAI_API_KEY = PersistentConfig(
    "AUDIO_OPENAI_API_KEY",
    "audio.openai.api_key",
    os.getenv("AUDIO_OPENAI_API_KEY", OPENAI_API_KEY),

AUDIO_STT_OPENAI_API_KEY = PersistentConfig(
    "AUDIO_STT_OPENAI_API_KEY",
    "audio.stt.openai.api_key",
    os.getenv("AUDIO_STT_OPENAI_API_KEY", OPENAI_API_KEY),
)
AUDIO_OPENAI_API_MODEL = PersistentConfig(
    "AUDIO_OPENAI_API_MODEL",
    "audio.openai.api_model",
    os.getenv("AUDIO_OPENAI_API_MODEL", "tts-1"),

AUDIO_STT_ENGINE = PersistentConfig(
    "AUDIO_STT_ENGINE",
    "audio.stt.engine",
    os.getenv("AUDIO_STT_ENGINE", ""),
)
AUDIO_OPENAI_API_VOICE = PersistentConfig(
    "AUDIO_OPENAI_API_VOICE",
    "audio.openai.api_voice",
    os.getenv("AUDIO_OPENAI_API_VOICE", "alloy"),

AUDIO_STT_MODEL = PersistentConfig(
    "AUDIO_STT_MODEL",
    "audio.stt.model",
    os.getenv("AUDIO_STT_MODEL", "whisper-1"),
)

AUDIO_TTS_OPENAI_API_BASE_URL = PersistentConfig(
    "AUDIO_TTS_OPENAI_API_BASE_URL",
    "audio.tts.openai.api_base_url",
    os.getenv("AUDIO_TTS_OPENAI_API_BASE_URL", OPENAI_API_BASE_URL),
)
AUDIO_TTS_OPENAI_API_KEY = PersistentConfig(
    "AUDIO_TTS_OPENAI_API_KEY",
    "audio.tts.openai.api_key",
    os.getenv("AUDIO_TTS_OPENAI_API_KEY", OPENAI_API_KEY),
)


AUDIO_TTS_ENGINE = PersistentConfig(
    "AUDIO_TTS_ENGINE",
    "audio.tts.engine",
    os.getenv("AUDIO_TTS_ENGINE", ""),
)


AUDIO_TTS_MODEL = PersistentConfig(
    "AUDIO_TTS_MODEL",
    "audio.tts.model",
    os.getenv("AUDIO_TTS_MODEL", "tts-1"),
)

AUDIO_TTS_VOICE = PersistentConfig(
    "AUDIO_TTS_VOICE",
    "audio.tts.voice",
    os.getenv("AUDIO_TTS_VOICE", "alloy"),
)


@@ -32,6 +32,7 @@ class ERROR_MESSAGES(str, Enum):
    COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string."
    FILE_EXISTS = "Uh-oh! This file is already registered. Please choose another file."

    ID_TAKEN = "Uh-oh! This id is already registered. Please choose another id string."
    MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. Please choose another model id string."

    NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string."
@@ -82,5 +83,9 @@ class ERROR_MESSAGES(str, Enum):
    )

    WEB_SEARCH_ERROR = (
        "Oops! Something went wrong while searching the web. Please try again later."
        lambda err="": f"{err if err else 'Oops! Something went wrong while searching the web.'}"
    )

    OLLAMA_API_DISABLED = (
        "The Ollama API is disabled. Please enable it to use this feature."
    )

backend/main.py
@@ -9,8 +9,12 @@ import logging
import aiohttp
import requests
import mimetypes
import shutil
import os
import inspect
import asyncio

from fastapi import FastAPI, Request, Depends, status
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
from fastapi import HTTPException
@@ -20,26 +24,48 @@ from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import StreamingResponse, Response

from apps.ollama.main import app as ollama_app, get_all_models as get_ollama_models
from apps.openai.main import app as openai_app, get_all_models as get_openai_models

from apps.socket.main import app as socket_app
from apps.ollama.main import (
    app as ollama_app,
    OpenAIChatCompletionForm,
    get_all_models as get_ollama_models,
    generate_openai_chat_completion as generate_ollama_chat_completion,
)
from apps.openai.main import (
    app as openai_app,
    get_all_models as get_openai_models,
    generate_chat_completion as generate_openai_chat_completion,
)

from apps.audio.main import app as audio_app
from apps.images.main import app as images_app
from apps.rag.main import app as rag_app
from apps.webui.main import app as webui_app

import asyncio

from pydantic import BaseModel
from typing import List, Optional

from apps.webui.models.models import Models, ModelModel
from apps.webui.models.tools import Tools
from apps.webui.utils import load_toolkit_module_by_id


from utils.utils import (
    get_admin_user,
    get_verified_user,
    get_current_user,
    get_http_authorization_cred,
)
from apps.rag.utils import rag_messages
from utils.task import (
    title_generation_template,
    search_query_generation_template,
    tools_function_calling_generation_template,
)
from utils.misc import get_last_user_message, add_or_update_system_message

from apps.rag.utils import get_rag_context, rag_template

from config import (
    CONFIG_DATA,
@@ -60,9 +86,14 @@ from config import (
    SRC_LOG_LEVELS,
    WEBHOOK_URL,
    ENABLE_ADMIN_EXPORT,
    RAG_WEB_SEARCH_ENABLED,
    AppConfig,
    WEBUI_BUILD_HASH,
    TASK_MODEL,
    TASK_MODEL_EXTERNAL,
    TITLE_GENERATION_PROMPT_TEMPLATE,
    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
    SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
    AppConfig,
)
from constants import ERROR_MESSAGES

@@ -116,27 +147,133 @@ app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST


app.state.config.WEBHOOK_URL = WEBHOOK_URL


app.state.config.TASK_MODEL = TASK_MODEL
app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
)
app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
    SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
)
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
)

app.state.MODELS = {}

origins = ["*"]

# Custom middleware to add security headers
# class SecurityHeadersMiddleware(BaseHTTPMiddleware):
#     async def dispatch(self, request: Request, call_next):
#         response: Response = await call_next(request)
#         response.headers["Cross-Origin-Opener-Policy"] = "same-origin"
#         response.headers["Cross-Origin-Embedder-Policy"] = "require-corp"
#         return response

async def get_function_call_response(messages, tool_id, template, task_model_id, user):
    tool = Tools.get_tool_by_id(tool_id)
    tools_specs = json.dumps(tool.specs, indent=2)
    content = tools_function_calling_generation_template(template, tools_specs)

    user_message = get_last_user_message(messages)
    prompt = (
        "History:\n"
        + "\n".join(
            [
                f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
                for message in messages[::-1][:4]
            ]
        )
        + f"\nQuery: {user_message}"
    )

    print(prompt)

    payload = {
        "model": task_model_id,
        "messages": [
            {"role": "system", "content": content},
            {"role": "user", "content": f"Query: {prompt}"},
        ],
        "stream": False,
    }

    try:
        payload = filter_pipeline(payload, user)
    except Exception as e:
        raise e

    model = app.state.MODELS[task_model_id]

    response = None
    try:
        if model["owned_by"] == "ollama":
            response = await generate_ollama_chat_completion(
                OpenAIChatCompletionForm(**payload), user=user
            )
        else:
            response = await generate_openai_chat_completion(payload, user=user)

        content = None

        if hasattr(response, "body_iterator"):
            async for chunk in response.body_iterator:
                data = json.loads(chunk.decode("utf-8"))
                content = data["choices"][0]["message"]["content"]

            # Cleanup any remaining background tasks if necessary
            if response.background is not None:
                await response.background()
        else:
            content = response["choices"][0]["message"]["content"]

        # Parse the function response
        if content is not None:
            print(f"content: {content}")
            result = json.loads(content)
            print(result)

            # Call the function
            if "name" in result:
                if tool_id in webui_app.state.TOOLS:
                    toolkit_module = webui_app.state.TOOLS[tool_id]
                else:
                    toolkit_module = load_toolkit_module_by_id(tool_id)
                    webui_app.state.TOOLS[tool_id] = toolkit_module

                function = getattr(toolkit_module, result["name"])
                function_result = None
                try:
                    # Get the signature of the function
                    sig = inspect.signature(function)
                    # Check if '__user__' is a parameter of the function
                    if "__user__" in sig.parameters:
                        # Call the function with the '__user__' parameter included
                        function_result = function(
                            **{
                                **result["parameters"],
                                "__user__": {
                                    "id": user.id,
                                    "email": user.email,
                                    "name": user.name,
                                    "role": user.role,
                                },
                            }
                        )
                    else:
                        # Call the function without modifying the parameters
                        function_result = function(**result["parameters"])
                except Exception as e:
                    print(e)

                # Add the function result to the system prompt
                if function_result:
                    return function_result
    except Exception as e:
        print(f"Error: {e}")

    return None
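# For reference (derived from the prompt template and the parsing above, not new
# logic): the task model is expected to reply with a bare JSON object such as
#
#     {"name": "echo", "parameters": {"message": "hello"}}
#
# which json.loads() turns into `result`; the named toolkit method is then called
# with those parameters, plus `__user__` when the function declares it.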


# app.add_middleware(SecurityHeadersMiddleware)


class RAGMiddleware(BaseHTTPMiddleware):
class ChatCompletionMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        return_citations = False

@@ -153,35 +290,98 @@ class RAGMiddleware(BaseHTTPMiddleware):
            # Parse string to JSON
            data = json.loads(body_str) if body_str else {}

            user = get_current_user(
                get_http_authorization_cred(request.headers.get("Authorization"))
            )

            # Remove the citations from the body
            return_citations = data.get("citations", False)
            if "citations" in data:
                del data["citations"]

            # Example: Add a new key-value pair or modify existing ones
            # data["modified"] = True  # Example modification
            # Set the task model
            task_model_id = data["model"]
            if task_model_id not in app.state.MODELS:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="Model not found",
                )

            # Check if the user has a custom task model
            # If the user has a custom task model, use that model
            if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
                if (
                    app.state.config.TASK_MODEL
                    and app.state.config.TASK_MODEL in app.state.MODELS
                ):
                    task_model_id = app.state.config.TASK_MODEL
            else:
                if (
                    app.state.config.TASK_MODEL_EXTERNAL
                    and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
                ):
                    task_model_id = app.state.config.TASK_MODEL_EXTERNAL

            prompt = get_last_user_message(data["messages"])
            context = ""

            # If tool_ids field is present, call the functions
            if "tool_ids" in data:
                print(data["tool_ids"])
                for tool_id in data["tool_ids"]:
                    print(tool_id)
                    try:
                        response = await get_function_call_response(
                            messages=data["messages"],
                            tool_id=tool_id,
                            template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
                            task_model_id=task_model_id,
                            user=user,
                        )

                        if response:
                            context += ("\n" if context != "" else "") + response
                    except Exception as e:
                        print(f"Error: {e}")
                del data["tool_ids"]

            print(f"tool_context: {context}")

            # If docs field is present, generate RAG completions
            if "docs" in data:
                data = {**data}
                data["messages"], citations = rag_messages(
                rag_context, citations = get_rag_context(
                    docs=data["docs"],
                    messages=data["messages"],
                    template=rag_app.state.config.RAG_TEMPLATE,
                    embedding_function=rag_app.state.EMBEDDING_FUNCTION,
                    k=rag_app.state.config.TOP_K,
                    reranking_function=rag_app.state.sentence_transformer_rf,
                    r=rag_app.state.config.RELEVANCE_THRESHOLD,
                    hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
                )

                if rag_context:
                    context += ("\n" if context != "" else "") + rag_context

                del data["docs"]

                log.debug(
                    f"data['messages']: {data['messages']}, citations: {citations}"
                log.debug(f"rag_context: {rag_context}, citations: {citations}")

            if context != "":
                system_prompt = rag_template(
                    rag_app.state.config.RAG_TEMPLATE, context, prompt
                )

                print(system_prompt)

                data["messages"] = add_or_update_system_message(
                    f"\n{system_prompt}", data["messages"]
                )

            modified_body_bytes = json.dumps(data).encode("utf-8")

            # Replace the request body with the modified one
            request._body = modified_body_bytes

            # Set custom header to ensure content-length matches new body length
            request.headers.__dict__["_list"] = [
                (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
@@ -224,7 +424,77 @@ class RAGMiddleware(BaseHTTPMiddleware):
            yield data


app.add_middleware(RAGMiddleware)
app.add_middleware(ChatCompletionMiddleware)


def filter_pipeline(payload, user):
    user = {"id": user.id, "name": user.name, "role": user.role}
    model_id = payload["model"]
    filters = [
        model
        for model in app.state.MODELS.values()
        if "pipeline" in model
        and "type" in model["pipeline"]
        and model["pipeline"]["type"] == "filter"
        and (
            model["pipeline"]["pipelines"] == ["*"]
            or any(
                model_id == target_model_id
                for target_model_id in model["pipeline"]["pipelines"]
            )
        )
    ]
    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])

    model = app.state.MODELS[model_id]

    if "pipeline" in model:
        sorted_filters.append(model)

    for filter in sorted_filters:
        r = None
        try:
            urlIdx = filter["urlIdx"]

            url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
            key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

            if key != "":
                headers = {"Authorization": f"Bearer {key}"}
                r = requests.post(
                    f"{url}/{filter['id']}/filter/inlet",
                    headers=headers,
                    json={
                        "user": user,
                        "body": payload,
                    },
                )

                r.raise_for_status()
                payload = r.json()
        except Exception as e:
            # Handle connection error here
            print(f"Connection error: {e}")

            if r is not None:
                try:
                    res = r.json()
                except Exception:
                    # The response body may be empty or non-JSON
                    res = None
                if res is not None and "detail" in res:
                    raise Exception(r.status_code, res["detail"])

            else:
                pass

    if "pipeline" not in app.state.MODELS[model_id]:
        if "chat_id" in payload:
            del payload["chat_id"]

        if "title" in payload:
            del payload["title"]

    return payload
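# Shape of the inlet call made above (for reference; the pipelines server itself
# is outside this diff): POST {url}/{filter_id}/filter/inlet with a JSON body of
#
#     {"user": {"id": "...", "name": "...", "role": "..."}, "body": {...payload...}}
#
# whose response body replaces `payload` for the next filter in priority order.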


class PipelineMiddleware(BaseHTTPMiddleware):
@@ -242,76 +512,17 @@ class PipelineMiddleware(BaseHTTPMiddleware):
            # Parse string to JSON
            data = json.loads(body_str) if body_str else {}

            model_id = data["model"]
            filters = [
                model
                for model in app.state.MODELS.values()
                if "pipeline" in model
                and "type" in model["pipeline"]
                and model["pipeline"]["type"] == "filter"
                and (
                    model["pipeline"]["pipelines"] == ["*"]
                    or any(
                        model_id == target_model_id
                        for target_model_id in model["pipeline"]["pipelines"]
                    )
            user = get_current_user(
                get_http_authorization_cred(request.headers.get("Authorization"))
            )

            try:
                data = filter_pipeline(data, user)
            except Exception as e:
                return JSONResponse(
                    status_code=e.args[0],
                    content={"detail": e.args[1]},
                )
            ]
            sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])

            user = None
            if len(sorted_filters) > 0:
                try:
                    user = get_current_user(
                        get_http_authorization_cred(
                            request.headers.get("Authorization")
                        )
                    )
                    user = {"id": user.id, "name": user.name, "role": user.role}
                except:
                    pass

            for filter in sorted_filters:
                r = None
                try:
                    urlIdx = filter["urlIdx"]

                    url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
                    key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

                    if key != "":
                        headers = {"Authorization": f"Bearer {key}"}
                        r = requests.post(
                            f"{url}/{filter['id']}/filter/inlet",
                            headers=headers,
                            json={
                                "user": user,
                                "body": data,
                            },
                        )

                        r.raise_for_status()
                        data = r.json()
                except Exception as e:
                    # Handle connection error here
                    print(f"Connection error: {e}")

                    if r is not None:
                        try:
                            res = r.json()
                            if "detail" in res:
                                return JSONResponse(
                                    status_code=r.status_code,
                                    content=res,
                                )
                        except:
                            pass

                    else:
                        pass

            if "chat_id" in data:
                del data["chat_id"]

            modified_body_bytes = json.dumps(data).encode("utf-8")
            # Replace the request body with the modified one
@@ -368,6 +579,9 @@ async def update_embedding_function(request: Request, call_next):
    return response


app.mount("/ws", socket_app)


app.mount("/ollama", ollama_app)
app.mount("/openai", openai_app)

@@ -469,6 +683,237 @@ async def get_models(user=Depends(get_verified_user)):
    return {"data": models}


@app.get("/api/task/config")
async def get_task_config(user=Depends(get_verified_user)):
    return {
        "TASK_MODEL": app.state.config.TASK_MODEL,
        "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
        "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
        "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
        "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
    }


class TaskConfigForm(BaseModel):
    TASK_MODEL: Optional[str]
    TASK_MODEL_EXTERNAL: Optional[str]
    TITLE_GENERATION_PROMPT_TEMPLATE: str
    SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
    SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD: int
    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str


@app.post("/api/task/config/update")
async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_user)):
    app.state.config.TASK_MODEL = form_data.TASK_MODEL
    app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
    app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
        form_data.TITLE_GENERATION_PROMPT_TEMPLATE
    )
    app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
        form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
    )
    app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
        form_data.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
    )
    app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
        form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
    )

    return {
        "TASK_MODEL": app.state.config.TASK_MODEL,
        "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
        "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
        "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
        "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
    }


@app.post("/api/task/title/completions")
async def generate_title(form_data: dict, user=Depends(get_verified_user)):
    print("generate_title")

    model_id = form_data["model"]
    if model_id not in app.state.MODELS:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    if app.state.MODELS[model_id]["owned_by"] == "ollama":
        if app.state.config.TASK_MODEL:
            task_model_id = app.state.config.TASK_MODEL
            if task_model_id in app.state.MODELS:
                model_id = task_model_id
    else:
        if app.state.config.TASK_MODEL_EXTERNAL:
            task_model_id = app.state.config.TASK_MODEL_EXTERNAL
            if task_model_id in app.state.MODELS:
                model_id = task_model_id

    print(model_id)
    model = app.state.MODELS[model_id]

    template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE

    content = title_generation_template(
        template, form_data["prompt"], user.model_dump()
    )

    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "max_tokens": 50,
        "chat_id": form_data.get("chat_id", None),
        "title": True,
    }

    print(payload)

    try:
        payload = filter_pipeline(payload, user)
    except Exception as e:
        return JSONResponse(
            status_code=e.args[0],
            content={"detail": e.args[1]},
        )

    if model["owned_by"] == "ollama":
        return await generate_ollama_chat_completion(
            OpenAIChatCompletionForm(**payload), user=user
        )
    else:
        return await generate_openai_chat_completion(payload, user=user)


@app.post("/api/task/query/completions")
async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
    print("generate_search_query")

    if len(form_data["prompt"]) < app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Skip search query generation for short prompts (< {app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD} characters)",
        )

    model_id = form_data["model"]
    if model_id not in app.state.MODELS:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    if app.state.MODELS[model_id]["owned_by"] == "ollama":
        if app.state.config.TASK_MODEL:
            task_model_id = app.state.config.TASK_MODEL
            if task_model_id in app.state.MODELS:
                model_id = task_model_id
    else:
        if app.state.config.TASK_MODEL_EXTERNAL:
            task_model_id = app.state.config.TASK_MODEL_EXTERNAL
            if task_model_id in app.state.MODELS:
                model_id = task_model_id

    print(model_id)
    model = app.state.MODELS[model_id]

    template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE

    content = search_query_generation_template(
        template, form_data["prompt"], user.model_dump()
    )

    payload = {
        "model": model_id,
        "messages": [{"role": "user", "content": content}],
        "stream": False,
        "max_tokens": 30,
    }

    print(payload)

    try:
        payload = filter_pipeline(payload, user)
    except Exception as e:
        return JSONResponse(
            status_code=e.args[0],
            content={"detail": e.args[1]},
        )

    if model["owned_by"] == "ollama":
        return await generate_ollama_chat_completion(
            OpenAIChatCompletionForm(**payload), user=user
        )
    else:
        return await generate_openai_chat_completion(payload, user=user)


@app.post("/api/task/tools/completions")
async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)):
    print("get_tools_function_calling")

    model_id = form_data["model"]
    if model_id not in app.state.MODELS:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    # Check if the user has a custom task model
    # If the user has a custom task model, use that model
    if app.state.MODELS[model_id]["owned_by"] == "ollama":
        if app.state.config.TASK_MODEL:
            task_model_id = app.state.config.TASK_MODEL
            if task_model_id in app.state.MODELS:
                model_id = task_model_id
    else:
        if app.state.config.TASK_MODEL_EXTERNAL:
            task_model_id = app.state.config.TASK_MODEL_EXTERNAL
            if task_model_id in app.state.MODELS:
                model_id = task_model_id

    print(model_id)
    template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE

    try:
        context = await get_function_call_response(
            form_data["messages"], form_data["tool_id"], template, model_id, user
        )
        return context
    except Exception as e:
        return JSONResponse(
            status_code=e.args[0],
            content={"detail": e.args[1]},
        )


@app.post("/api/chat/completions")
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
    model_id = form_data["model"]
    if model_id not in app.state.MODELS:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Model not found",
        )

    model = app.state.MODELS[model_id]
    print(model)

    if model["owned_by"] == "ollama":
        return await generate_ollama_chat_completion(
            OpenAIChatCompletionForm(**form_data), user=user
        )
    else:
        return await generate_openai_chat_completion(form_data, user=user)


@app.post("/api/chat/completed")
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
    data = form_data
@@ -490,6 +935,13 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
    ]
    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])

    print(model_id)

    if model_id in app.state.MODELS:
        model = app.state.MODELS[model_id]
        if "pipeline" in model:
            sorted_filters = [model] + sorted_filters

    for filter in sorted_filters:
        r = None
        try:
@@ -537,7 +989,11 @@ async def get_pipelines_list(user=Depends(get_admin_user)):
    responses = await get_openai_models(raw=True)

    print(responses)
    urlIdxs = [idx for idx, response in enumerate(responses) if "pipelines" in response]
    urlIdxs = [
        idx
        for idx, response in enumerate(responses)
        if response is not None and "pipelines" in response
    ]

    return {
        "data": [
@@ -550,6 +1006,63 @@ async def get_pipelines_list(user=Depends(get_admin_user)):
    }


@app.post("/api/pipelines/upload")
async def upload_pipeline(
    urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user)
):
    print("upload_pipeline", urlIdx, file.filename)
    # Check if the uploaded file is a python file
    if not file.filename.endswith(".py"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Only Python (.py) files are allowed.",
        )

    upload_folder = f"{CACHE_DIR}/pipelines"
    os.makedirs(upload_folder, exist_ok=True)
    file_path = os.path.join(upload_folder, file.filename)

    try:
        # Save the uploaded file
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)

        url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
        key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]

        headers = {"Authorization": f"Bearer {key}"}

        with open(file_path, "rb") as f:
            files = {"file": f}
            r = requests.post(f"{url}/pipelines/upload", headers=headers, files=files)

        r.raise_for_status()
        data = r.json()

        return {**data}
    except Exception as e:
        # Handle connection error here
        print(f"Connection error: {e}")

        detail = "Pipeline not found"
        if r is not None:
            try:
                res = r.json()
                if "detail" in res:
                    detail = res["detail"]
            except:
                pass

        raise HTTPException(
            status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
            detail=detail,
        )
    finally:
        # Ensure the file is deleted after the upload is completed or on failure
        if os.path.exists(file_path):
            os.remove(file_path)


class AddPipelineForm(BaseModel):
    url: str
    urlIdx: int
@@ -811,11 +1324,20 @@ async def get_app_config():
            "auth": WEBUI_AUTH,
            "auth_trusted_header": bool(webui_app.state.AUTH_TRUSTED_EMAIL_HEADER),
            "enable_signup": webui_app.state.config.ENABLE_SIGNUP,
            "enable_web_search": RAG_WEB_SEARCH_ENABLED,
            "enable_web_search": rag_app.state.config.ENABLE_RAG_WEB_SEARCH,
            "enable_image_generation": images_app.state.config.ENABLED,
            "enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING,
            "enable_admin_export": ENABLE_ADMIN_EXPORT,
        },
        "audio": {
            "tts": {
                "engine": audio_app.state.config.TTS_ENGINE,
                "voice": audio_app.state.config.TTS_VOICE,
            },
            "stt": {
                "engine": audio_app.state.config.STT_ENGINE,
            },
        },
    }


@@ -860,23 +1382,7 @@ class UrlForm(BaseModel):
async def update_webhook_url(form_data: UrlForm, user=Depends(get_admin_user)):
    app.state.config.WEBHOOK_URL = form_data.url
    webui_app.state.WEBHOOK_URL = app.state.config.WEBHOOK_URL

    return {
        "url": app.state.config.WEBHOOK_URL,
    }


@app.get("/api/community_sharing", response_model=bool)
async def get_community_sharing_status(request: Request, user=Depends(get_admin_user)):
    return webui_app.state.config.ENABLE_COMMUNITY_SHARING


@app.get("/api/community_sharing/toggle", response_model=bool)
async def toggle_community_sharing(request: Request, user=Depends(get_admin_user)):
    webui_app.state.config.ENABLE_COMMUNITY_SHARING = (
        not webui_app.state.config.ENABLE_COMMUNITY_SHARING
    )
    return webui_app.state.config.ENABLE_COMMUNITY_SHARING
    return {"url": app.state.config.WEBHOOK_URL}


@app.get("/api/version")
@@ -894,7 +1400,7 @@ async def get_app_changelog():
@app.get("/api/version/updates")
async def get_app_latest_release_version():
    try:
        async with aiohttp.ClientSession() as session:
        async with aiohttp.ClientSession(trust_env=True) as session:
            async with session.get(
                "https://api.github.com/repos/open-webui/open-webui/releases/latest"
            ) as response:

@@ -56,4 +56,8 @@ PyJWT[crypto]==2.8.0
black==24.4.2
langfuse==2.33.0
youtube-transcript-api==0.6.2
pytube==15.0.0
pytube==15.0.0

extract_msg
pydub
duckduckgo-search~=6.1.5
@@ -20,12 +20,12 @@ if test "$WEBUI_SECRET_KEY $WEBUI_JWT_SECRET_KEY" = " "; then
  WEBUI_SECRET_KEY=$(cat "$KEY_FILE")
fi

if [ "$USE_OLLAMA_DOCKER" = "true" ]; then
if [[ "${USE_OLLAMA_DOCKER,,}" == "true" ]]; then
    echo "USE_OLLAMA is set to true, starting ollama serve."
    ollama serve &
fi

if [ "$USE_CUDA_DOCKER" = "true" ]; then
if [[ "${USE_CUDA_DOCKER,,}" == "true" ]]; then
    echo "CUDA is enabled, appending LD_LIBRARY_PATH to include torch/cudnn & cublas libraries."
    export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/lib/python3.11/site-packages/torch/lib:/usr/local/lib/python3.11/site-packages/nvidia/cudnn/lib"
fi

@@ -8,6 +8,7 @@ cd /d "%SCRIPT_DIR%" || exit /b

SET "KEY_FILE=.webui_secret_key"
IF "%PORT%"=="" SET PORT=8080
IF "%HOST%"=="" SET HOST=0.0.0.0
SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%"
SET "WEBUI_JWT_SECRET_KEY=%WEBUI_JWT_SECRET_KEY%"

@@ -29,4 +30,4 @@ IF "%WEBUI_SECRET_KEY%%WEBUI_JWT_SECRET_KEY%" == " " (

:: Execute uvicorn
SET "WEBUI_SECRET_KEY=%WEBUI_SECRET_KEY%"
uvicorn main:app --host 0.0.0.0 --port "%PORT%" --forwarded-allow-ips '*'
uvicorn main:app --host "%HOST%" --port "%PORT%" --forwarded-allow-ips '*'

@@ -3,7 +3,48 @@ import hashlib
import json
import re
from datetime import timedelta
from typing import Optional
from typing import Optional, List


def get_last_user_message(messages: List[dict]) -> str:
    for message in reversed(messages):
        if message["role"] == "user":
            if isinstance(message["content"], list):
                for item in message["content"]:
                    if item["type"] == "text":
                        return item["text"]
            return message["content"]
    return None
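# Behaviour sketch (illustrative): plain-string and multi-part contents both work.
#
#     get_last_user_message([{"role": "user", "content": "hi"}])  # -> "hi"
#     get_last_user_message(
#         [{"role": "user", "content": [{"type": "text", "text": "hi"}]}]
#     )  # -> "hi"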


def get_last_assistant_message(messages: List[dict]) -> str:
    for message in reversed(messages):
        if message["role"] == "assistant":
            if isinstance(message["content"], list):
                for item in message["content"]:
                    if item["type"] == "text":
                        return item["text"]
            return message["content"]
    return None


def add_or_update_system_message(content: str, messages: List[dict]):
    """
    Adds a new system message at the beginning of the messages list
    or updates the existing system message at the beginning.

    :param content: The message content to be added or prepended.
    :param messages: The list of message dictionaries.
    :return: The updated list of message dictionaries.
    """

    if messages and messages[0].get("role") == "system":
        # Prepend the new content to the existing system message
        messages[0]["content"] = f"{content}\n{messages[0]['content']}"
    else:
        # Insert at the beginning
        messages.insert(0, {"role": "system", "content": content})

    return messages
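# Behaviour sketch (illustrative):
#
#     add_or_update_system_message("B", [{"role": "system", "content": "A"}])
#     # -> [{"role": "system", "content": "B\nA"}]
#     add_or_update_system_message("S", [{"role": "user", "content": "hi"}])
#     # -> [{"role": "system", "content": "S"}, {"role": "user", "content": "hi"}]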
|
||||
|
||||
|
||||
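# Illustrative usage of the helpers above (a sketch, not part of the diff):
# the accessors handle both plain-string and multimodal list-style content.
messages = [
    {"role": "system", "content": "Be brief."},
    {
        "role": "user",
        "content": [
            {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}},
            {"type": "text", "text": "What is in this picture?"},
        ],
    },
]
assert get_last_user_message(messages) == "What is in this picture?"
add_or_update_system_message("Reply in French.", messages)
assert messages[0]["content"] == "Reply in French.\nBe brief."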
def get_gravatar_url(email):
@@ -123,11 +164,25 @@ def parse_ollama_modelfile(model_text):
        "repeat_penalty": float,
        "temperature": float,
        "seed": int,
        "stop": str,
        "tfs_z": float,
        "num_predict": int,
        "top_k": int,
        "top_p": float,
        "num_keep": int,
        "typical_p": float,
        "presence_penalty": float,
        "frequency_penalty": float,
        "penalize_newline": bool,
        "numa": bool,
        "num_batch": int,
        "num_gpu": int,
        "main_gpu": int,
        "low_vram": bool,
        "f16_kv": bool,
        "vocab_only": bool,
        "use_mmap": bool,
        "use_mlock": bool,
        "num_thread": int,
    }

    data = {"base_model_id": None, "params": {}}
@@ -156,10 +211,18 @@ def parse_ollama_modelfile(model_text):
        param_match = re.search(rf"PARAMETER {param} (.+)", model_text, re.IGNORECASE)
        if param_match:
            value = param_match.group(1)
            if param_type == int:
                value = int(value)
            elif param_type == float:
                value = float(value)

            try:
                if param_type == int:
                    value = int(value)
                elif param_type == float:
                    value = float(value)
                elif param_type == bool:
                    value = value.lower() == "true"
            except Exception as e:
                print(e)
                continue

            data["params"][param] = value

    # Parse adapter
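# A quick illustration of the guarded coercion above (a sketch, not part of
# the diff): malformed PARAMETER values are now skipped instead of crashing
# the whole modelfile parse.
def coerce(value: str, param_type: type):
    try:
        if param_type == int:
            return int(value)
        elif param_type == float:
            return float(value)
        elif param_type == bool:
            return value.lower() == "true"
    except Exception as e:
        print(e)  # e.g. "invalid literal for int() with base 10: 'oops'"
        return None
    return value

assert coerce("0.8", float) == 0.8
assert coerce("true", bool) is True
assert coerce("oops", int) is None  # logged and skipped, no exception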
@@ -171,8 +234,14 @@ def parse_ollama_modelfile(model_text):
    system_desc_match = re.search(
        r'SYSTEM\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE
    )
    system_desc_match_single = re.search(
        r"SYSTEM\s+([^\n]+)", model_text, re.IGNORECASE
    )

    if system_desc_match:
        data["params"]["system"] = system_desc_match.group(1).strip()
    elif system_desc_match_single:
        data["params"]["system"] = system_desc_match_single.group(1).strip()

    # Parse messages
    messages = []
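# The two SYSTEM prompt styles now recognized (illustrative, not part of the
# diff; the modelfile contents are made up): a triple-quoted block that may
# span lines, and a single-line fallback.
import re

modelfile = 'FROM llama3\nSYSTEM """You are a helpful assistant."""'
match = re.search(r'SYSTEM\s+"""(.+?)"""', modelfile, re.DOTALL | re.IGNORECASE)
assert match.group(1).strip() == "You are a helpful assistant."

modelfile_single = "FROM llama3\nSYSTEM You are a helpful assistant."
match_single = re.search(r"SYSTEM\s+([^\n]+)", modelfile_single, re.IGNORECASE)
assert match_single.group(1).strip() == "You are a helpful assistant."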
@@ -1,10 +0,0 @@
from apps.webui.models.models import Models, ModelModel, ModelForm, ModelResponse


def get_model_id_from_custom_model_id(id: str):
    model = Models.get_model_by_id(id)

    if model:
        return model.id
    else:
        return id
backend/utils/task.py (Normal file, +117 lines)
@@ -0,0 +1,117 @@
import re
import math

from datetime import datetime
from typing import Optional


def prompt_template(
    template: str,
    user_name: Optional[str] = None,
    current_location: Optional[str] = None,
) -> str:
    # Get the current date
    current_date = datetime.now()

    # Format the date to YYYY-MM-DD
    formatted_date = current_date.strftime("%Y-%m-%d")

    # Replace {{CURRENT_DATE}} in the template with the formatted date
    template = template.replace("{{CURRENT_DATE}}", formatted_date)

    if user_name:
        # Replace {{USER_NAME}} in the template with the user's name
        template = template.replace("{{USER_NAME}}", user_name)

    if current_location:
        # Replace {{CURRENT_LOCATION}} in the template with the current location
        template = template.replace("{{CURRENT_LOCATION}}", current_location)

    return template
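# Example usage of prompt_template (a sketch, not part of the diff); the
# injected date depends on when this runs, so only the prefix is asserted.
result = prompt_template(
    "Hello {{USER_NAME}}, today is {{CURRENT_DATE}}.", user_name="Ada"
)
assert result.startswith("Hello Ada, today is ")
# {{CURRENT_LOCATION}} placeholders are left untouched when no location is given.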
def title_generation_template(
    template: str, prompt: str, user: Optional[dict] = None
) -> str:
    def replacement_function(match):
        full_match = match.group(0)
        start_length = match.group(1)
        end_length = match.group(2)
        middle_length = match.group(3)

        if full_match == "{{prompt}}":
            return prompt
        elif start_length is not None:
            return prompt[: int(start_length)]
        elif end_length is not None:
            return prompt[-int(end_length) :]
        elif middle_length is not None:
            middle_length = int(middle_length)
            if len(prompt) <= middle_length:
                return prompt
            start = prompt[: math.ceil(middle_length / 2)]
            end = prompt[-math.floor(middle_length / 2) :]
            return f"{start}...{end}"
        return ""

    template = re.sub(
        r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}",
        replacement_function,
        template,
    )

    template = prompt_template(
        template,
        **(
            {"user_name": user.get("name"), "current_location": user.get("location")}
            if user
            else {}
        ),
    )

    return template
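# How the truncation placeholders behave (a sketch, not part of the diff):
# middletruncate keeps the first ceil(n/2) and last floor(n/2) characters.
title_prompt = title_generation_template(
    "Create a concise title for: {{prompt:middletruncate:8}}", "abcdefghijkl"
)
assert title_prompt == "Create a concise title for: abcd...ijkl"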
def search_query_generation_template(
    template: str, prompt: str, user: Optional[dict] = None
) -> str:
    def replacement_function(match):
        full_match = match.group(0)
        start_length = match.group(1)
        end_length = match.group(2)
        middle_length = match.group(3)

        if full_match == "{{prompt}}":
            return prompt
        elif start_length is not None:
            return prompt[: int(start_length)]
        elif end_length is not None:
            return prompt[-int(end_length) :]
        elif middle_length is not None:
            middle_length = int(middle_length)
            if len(prompt) <= middle_length:
                return prompt
            start = prompt[: math.ceil(middle_length / 2)]
            end = prompt[-math.floor(middle_length / 2) :]
            return f"{start}...{end}"
        return ""

    template = re.sub(
        r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}",
        replacement_function,
        template,
    )

    template = prompt_template(
        template,
        **(
            {"user_name": user.get("name"), "current_location": user.get("location")}
            if user
            else {}
        ),
    )
    return template
def tools_function_calling_generation_template(template: str, tools_specs: str) -> str:
    template = template.replace("{{TOOLS}}", tools_specs)
    return template
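# tools_function_calling_generation_template simply inlines the serialized
# tool specs (a sketch, not part of the diff; the tool name is made up):
filled = tools_function_calling_generation_template(
    "You may call these tools: {{TOOLS}}", '[{"name": "get_weather"}]'
)
assert filled == 'You may call these tools: [{"name": "get_weather"}]'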
backend/utils/tools.py (Normal file, +73 lines)
@@ -0,0 +1,73 @@
import inspect
from typing import get_type_hints, List, Dict, Any


def doc_to_dict(docstring):
    lines = docstring.split("\n")
    # Guard against single-line docstrings (including the function-name
    # fallback used by the caller), which would otherwise raise IndexError.
    description = lines[1].strip() if len(lines) > 1 else lines[0].strip()
    param_dict = {}

    for line in lines:
        if ":param" in line:
            line = line.replace(":param", "").strip()
            param, desc = line.split(":", 1)
            param_dict[param.strip()] = desc.strip()
    ret_dict = {"description": description, "params": param_dict}
    return ret_dict
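# Illustrative input/output for doc_to_dict (a sketch, not part of the diff):
# it expects the reST-style docstrings used by tool functions; get_weather
# here is a hypothetical tool.
def get_weather(city: str) -> str:
    """
    Get the current weather for a city.
    :param city: name of the city to look up
    """
    return "sunny"  # hypothetical tool body

parsed = doc_to_dict(get_weather.__doc__)
assert parsed == {
    "description": "Get the current weather for a city.",
    "params": {"city": "name of the city to look up"},
}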
def get_tools_specs(tools) -> List[dict]:
    function_list = [
        {"name": func, "function": getattr(tools, func)}
        for func in dir(tools)
        if callable(getattr(tools, func)) and not func.startswith("__")
    ]

    specs = []
    for function_item in function_list:
        function_name = function_item["name"]
        function = function_item["function"]

        function_doc = doc_to_dict(function.__doc__ or function_name)
        specs.append(
            {
                "name": function_name,
                # TODO: multi-line desc?
                "description": function_doc.get("description", function_name),
                "parameters": {
                    "type": "object",
                    "properties": {
                        param_name: {
                            "type": param_annotation.__name__.lower(),
                            **(
                                {
                                    "enum": (
                                        str(param_annotation.__args__)
                                        if hasattr(param_annotation, "__args__")
                                        else None
                                    )
                                }
                                if hasattr(param_annotation, "__args__")
                                else {}
                            ),
                            "description": function_doc.get("params", {}).get(
                                param_name, param_name
                            ),
                        }
                        for param_name, param_annotation in get_type_hints(
                            function
                        ).items()
                        if param_name != "return" and param_name != "__user__"
                    },
                    "required": [
                        name
                        for name, param in inspect.signature(
                            function
                        ).parameters.items()
                        if param.default is param.empty
                    ],
                },
            }
        )

    return specs
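# Given a tools object exposing the hypothetical get_weather above,
# get_tools_specs yields OpenAI-style function specs, roughly (a sketch,
# not part of the diff):
#
#   [{
#       "name": "get_weather",
#       "description": "Get the current weather for a city.",
#       "parameters": {
#           "type": "object",
#           "properties": {
#               "city": {"type": "str",
#                        "description": "name of the city to look up"},
#           },
#           "required": ["city"],
#       },
#   }]
#
# Note the "type" values come from param_annotation.__name__.lower(), so they
# are Python type names ("str", "int"), not JSON Schema names.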