mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Merge remote-tracking branch 'upstream/dev' into feat/oauth
This commit is contained in:
@@ -37,6 +37,10 @@ from config import (
|
||||
ENABLE_IMAGE_GENERATION,
|
||||
AUTOMATIC1111_BASE_URL,
|
||||
COMFYUI_BASE_URL,
|
||||
COMFYUI_CFG_SCALE,
|
||||
COMFYUI_SAMPLER,
|
||||
COMFYUI_SCHEDULER,
|
||||
COMFYUI_SD3,
|
||||
IMAGES_OPENAI_API_BASE_URL,
|
||||
IMAGES_OPENAI_API_KEY,
|
||||
IMAGE_GENERATION_MODEL,
|
||||
@@ -78,6 +82,10 @@ app.state.config.COMFYUI_BASE_URL = COMFYUI_BASE_URL
|
||||
|
||||
app.state.config.IMAGE_SIZE = IMAGE_SIZE
|
||||
app.state.config.IMAGE_STEPS = IMAGE_STEPS
|
||||
app.state.config.COMFYUI_CFG_SCALE = COMFYUI_CFG_SCALE
|
||||
app.state.config.COMFYUI_SAMPLER = COMFYUI_SAMPLER
|
||||
app.state.config.COMFYUI_SCHEDULER = COMFYUI_SCHEDULER
|
||||
app.state.config.COMFYUI_SD3 = COMFYUI_SD3
|
||||
|
||||
|
||||
@app.get("/config")
|
||||
@@ -457,6 +465,18 @@ def generate_image(
|
||||
if form_data.negative_prompt is not None:
|
||||
data["negative_prompt"] = form_data.negative_prompt
|
||||
|
||||
if app.state.config.COMFYUI_CFG_SCALE:
|
||||
data["cfg_scale"] = app.state.config.COMFYUI_CFG_SCALE
|
||||
|
||||
if app.state.config.COMFYUI_SAMPLER is not None:
|
||||
data["sampler"] = app.state.config.COMFYUI_SAMPLER
|
||||
|
||||
if app.state.config.COMFYUI_SCHEDULER is not None:
|
||||
data["scheduler"] = app.state.config.COMFYUI_SCHEDULER
|
||||
|
||||
if app.state.config.COMFYUI_SD3 is not None:
|
||||
data["sd3"] = app.state.config.COMFYUI_SD3
|
||||
|
||||
data = ImageGenerationPayload(**data)
|
||||
|
||||
res = comfyui_generate_image(
|
||||
|
||||
@@ -190,6 +190,10 @@ class ImageGenerationPayload(BaseModel):
|
||||
width: int
|
||||
height: int
|
||||
n: int = 1
|
||||
cfg_scale: Optional[float] = None
|
||||
sampler: Optional[str] = None
|
||||
scheduler: Optional[str] = None
|
||||
sd3: Optional[bool] = None
|
||||
|
||||
|
||||
def comfyui_generate_image(
|
||||
@@ -199,6 +203,18 @@ def comfyui_generate_image(
|
||||
|
||||
comfyui_prompt = json.loads(COMFYUI_DEFAULT_PROMPT)
|
||||
|
||||
if payload.cfg_scale:
|
||||
comfyui_prompt["3"]["inputs"]["cfg"] = payload.cfg_scale
|
||||
|
||||
if payload.sampler:
|
||||
comfyui_prompt["3"]["inputs"]["sampler"] = payload.sampler
|
||||
|
||||
if payload.scheduler:
|
||||
comfyui_prompt["3"]["inputs"]["scheduler"] = payload.scheduler
|
||||
|
||||
if payload.sd3:
|
||||
comfyui_prompt["5"]["class_type"] = "EmptySD3LatentImage"
|
||||
|
||||
comfyui_prompt["4"]["inputs"]["ckpt_name"] = model
|
||||
comfyui_prompt["5"]["inputs"]["batch_size"] = payload.n
|
||||
comfyui_prompt["5"]["inputs"]["width"] = payload.width
|
||||
|
||||
@@ -40,6 +40,7 @@ from utils.utils import (
|
||||
get_verified_user,
|
||||
get_admin_user,
|
||||
)
|
||||
from utils.task import prompt_template
|
||||
|
||||
|
||||
from config import (
|
||||
@@ -52,7 +53,7 @@ from config import (
|
||||
UPLOAD_DIR,
|
||||
AppConfig,
|
||||
)
|
||||
from utils.misc import calculate_sha256
|
||||
from utils.misc import calculate_sha256, add_or_update_system_message
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
|
||||
@@ -199,9 +200,6 @@ def merge_models_lists(model_lists):
|
||||
return list(merged_models.values())
|
||||
|
||||
|
||||
# user=Depends(get_current_user)
|
||||
|
||||
|
||||
async def get_all_models():
|
||||
log.info("get_all_models()")
|
||||
|
||||
@@ -817,24 +815,28 @@ async def generate_chat_completion(
|
||||
"num_thread", None
|
||||
)
|
||||
|
||||
if model_info.params.get("system", None):
|
||||
system = model_info.params.get("system", None)
|
||||
if system:
|
||||
# Check if the payload already has a system message
|
||||
# If not, add a system message to the payload
|
||||
system = prompt_template(
|
||||
system,
|
||||
**(
|
||||
{
|
||||
"user_name": user.name,
|
||||
"user_location": (
|
||||
user.info.get("location") if user.info else None
|
||||
),
|
||||
}
|
||||
if user
|
||||
else {}
|
||||
),
|
||||
)
|
||||
|
||||
if payload.get("messages"):
|
||||
for message in payload["messages"]:
|
||||
if message.get("role") == "system":
|
||||
message["content"] = (
|
||||
model_info.params.get("system", None) + message["content"]
|
||||
)
|
||||
break
|
||||
else:
|
||||
payload["messages"].insert(
|
||||
0,
|
||||
{
|
||||
"role": "system",
|
||||
"content": model_info.params.get("system", None),
|
||||
},
|
||||
)
|
||||
payload["messages"] = add_or_update_system_message(
|
||||
system, payload["messages"]
|
||||
)
|
||||
|
||||
if url_idx == None:
|
||||
if ":" not in payload["model"]:
|
||||
@@ -850,8 +852,7 @@ async def generate_chat_completion(
|
||||
|
||||
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
log.info(f"url: {url}")
|
||||
|
||||
print(payload)
|
||||
log.debug(payload)
|
||||
|
||||
return await post_streaming_url(f"{url}/api/chat", json.dumps(payload))
|
||||
|
||||
@@ -879,10 +880,11 @@ class OpenAIChatCompletionForm(BaseModel):
|
||||
@app.post("/v1/chat/completions")
|
||||
@app.post("/v1/chat/completions/{url_idx}")
|
||||
async def generate_openai_chat_completion(
|
||||
form_data: OpenAIChatCompletionForm,
|
||||
form_data: dict,
|
||||
url_idx: Optional[int] = None,
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
form_data = OpenAIChatCompletionForm(**form_data)
|
||||
|
||||
payload = {
|
||||
**form_data.model_dump(exclude_none=True),
|
||||
@@ -914,22 +916,35 @@ async def generate_openai_chat_completion(
|
||||
else None
|
||||
)
|
||||
|
||||
if model_info.params.get("system", None):
|
||||
system = model_info.params.get("system", None)
|
||||
|
||||
if system:
|
||||
system = prompt_template(
|
||||
system,
|
||||
**(
|
||||
{
|
||||
"user_name": user.name,
|
||||
"user_location": (
|
||||
user.info.get("location") if user.info else None
|
||||
),
|
||||
}
|
||||
if user
|
||||
else {}
|
||||
),
|
||||
)
|
||||
# Check if the payload already has a system message
|
||||
# If not, add a system message to the payload
|
||||
if payload.get("messages"):
|
||||
for message in payload["messages"]:
|
||||
if message.get("role") == "system":
|
||||
message["content"] = (
|
||||
model_info.params.get("system", None) + message["content"]
|
||||
)
|
||||
message["content"] = system + message["content"]
|
||||
break
|
||||
else:
|
||||
payload["messages"].insert(
|
||||
0,
|
||||
{
|
||||
"role": "system",
|
||||
"content": model_info.params.get("system", None),
|
||||
"content": system,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -1095,17 +1110,13 @@ async def download_file_stream(
|
||||
raise "Ollama: Could not create blob, Please try again."
|
||||
|
||||
|
||||
# def number_generator():
|
||||
# for i in range(1, 101):
|
||||
# yield f"data: {i}\n"
|
||||
|
||||
|
||||
# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
|
||||
@app.post("/models/download")
|
||||
@app.post("/models/download/{url_idx}")
|
||||
async def download_model(
|
||||
form_data: UrlForm,
|
||||
url_idx: Optional[int] = None,
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
|
||||
allowed_hosts = ["https://huggingface.co/", "https://github.com/"]
|
||||
@@ -1134,7 +1145,11 @@ async def download_model(
|
||||
|
||||
@app.post("/models/upload")
|
||||
@app.post("/models/upload/{url_idx}")
|
||||
def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
|
||||
def upload_model(
|
||||
file: UploadFile = File(...),
|
||||
url_idx: Optional[int] = None,
|
||||
user=Depends(get_admin_user),
|
||||
):
|
||||
if url_idx == None:
|
||||
url_idx = 0
|
||||
ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
@@ -1197,137 +1212,3 @@ def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
|
||||
yield f"data: {json.dumps(res)}\n\n"
|
||||
|
||||
return StreamingResponse(file_process_stream(), media_type="text/event-stream")
|
||||
|
||||
|
||||
# async def upload_model(file: UploadFile = File(), url_idx: Optional[int] = None):
|
||||
# if url_idx == None:
|
||||
# url_idx = 0
|
||||
# url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
|
||||
# file_location = os.path.join(UPLOAD_DIR, file.filename)
|
||||
# total_size = file.size
|
||||
|
||||
# async def file_upload_generator(file):
|
||||
# print(file)
|
||||
# try:
|
||||
# async with aiofiles.open(file_location, "wb") as f:
|
||||
# completed_size = 0
|
||||
# while True:
|
||||
# chunk = await file.read(1024*1024)
|
||||
# if not chunk:
|
||||
# break
|
||||
# await f.write(chunk)
|
||||
# completed_size += len(chunk)
|
||||
# progress = (completed_size / total_size) * 100
|
||||
|
||||
# print(progress)
|
||||
# yield f'data: {json.dumps({"status": "uploading", "percentage": progress, "total": total_size, "completed": completed_size, "done": False})}\n'
|
||||
# except Exception as e:
|
||||
# print(e)
|
||||
# yield f"data: {json.dumps({'status': 'error', 'message': str(e)})}\n"
|
||||
# finally:
|
||||
# await file.close()
|
||||
# print("done")
|
||||
# yield f'data: {json.dumps({"status": "completed", "percentage": 100, "total": total_size, "completed": completed_size, "done": True})}\n'
|
||||
|
||||
# return StreamingResponse(
|
||||
# file_upload_generator(copy.deepcopy(file)), media_type="text/event-stream"
|
||||
# )
|
||||
|
||||
|
||||
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
|
||||
async def deprecated_proxy(
|
||||
path: str, request: Request, user=Depends(get_verified_user)
|
||||
):
|
||||
url = app.state.config.OLLAMA_BASE_URLS[0]
|
||||
target_url = f"{url}/{path}"
|
||||
|
||||
body = await request.body()
|
||||
headers = dict(request.headers)
|
||||
|
||||
if user.role in ["user", "admin"]:
|
||||
if path in ["pull", "delete", "push", "copy", "create"]:
|
||||
if user.role != "admin":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
headers.pop("host", None)
|
||||
headers.pop("authorization", None)
|
||||
headers.pop("origin", None)
|
||||
headers.pop("referer", None)
|
||||
|
||||
r = None
|
||||
|
||||
def get_request():
|
||||
nonlocal r
|
||||
|
||||
request_id = str(uuid.uuid4())
|
||||
try:
|
||||
REQUEST_POOL.append(request_id)
|
||||
|
||||
def stream_content():
|
||||
try:
|
||||
if path == "generate":
|
||||
data = json.loads(body.decode("utf-8"))
|
||||
|
||||
if data.get("stream", True):
|
||||
yield json.dumps({"id": request_id, "done": False}) + "\n"
|
||||
|
||||
elif path == "chat":
|
||||
yield json.dumps({"id": request_id, "done": False}) + "\n"
|
||||
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
if request_id in REQUEST_POOL:
|
||||
yield chunk
|
||||
else:
|
||||
log.warning("User: canceled request")
|
||||
break
|
||||
finally:
|
||||
if hasattr(r, "close"):
|
||||
r.close()
|
||||
if request_id in REQUEST_POOL:
|
||||
REQUEST_POOL.remove(request_id)
|
||||
|
||||
r = requests.request(
|
||||
method=request.method,
|
||||
url=target_url,
|
||||
data=body,
|
||||
headers=headers,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
|
||||
# r.close()
|
||||
|
||||
return StreamingResponse(
|
||||
stream_content(),
|
||||
status_code=r.status_code,
|
||||
headers=dict(r.headers),
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await run_in_threadpool(get_request)
|
||||
except Exception as e:
|
||||
error_detail = "Open WebUI: Server Connection Error"
|
||||
if r is not None:
|
||||
try:
|
||||
res = r.json()
|
||||
if "error" in res:
|
||||
error_detail = f"Ollama: {res['error']}"
|
||||
except:
|
||||
error_detail = f"Ollama: {e}"
|
||||
|
||||
raise HTTPException(
|
||||
status_code=r.status_code if r else 500,
|
||||
detail=error_detail,
|
||||
)
|
||||
|
||||
@@ -20,6 +20,8 @@ from utils.utils import (
|
||||
get_verified_user,
|
||||
get_admin_user,
|
||||
)
|
||||
from utils.task import prompt_template
|
||||
|
||||
from config import (
|
||||
SRC_LOG_LEVELS,
|
||||
ENABLE_OPENAI_API,
|
||||
@@ -392,22 +394,34 @@ async def generate_chat_completion(
|
||||
else None
|
||||
)
|
||||
|
||||
if model_info.params.get("system", None):
|
||||
system = model_info.params.get("system", None)
|
||||
if system:
|
||||
system = prompt_template(
|
||||
system,
|
||||
**(
|
||||
{
|
||||
"user_name": user.name,
|
||||
"user_location": (
|
||||
user.info.get("location") if user.info else None
|
||||
),
|
||||
}
|
||||
if user
|
||||
else {}
|
||||
),
|
||||
)
|
||||
# Check if the payload already has a system message
|
||||
# If not, add a system message to the payload
|
||||
if payload.get("messages"):
|
||||
for message in payload["messages"]:
|
||||
if message.get("role") == "system":
|
||||
message["content"] = (
|
||||
model_info.params.get("system", None) + message["content"]
|
||||
)
|
||||
message["content"] = system + message["content"]
|
||||
break
|
||||
else:
|
||||
payload["messages"].insert(
|
||||
0,
|
||||
{
|
||||
"role": "system",
|
||||
"content": model_info.params.get("system", None),
|
||||
"content": system,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -418,7 +432,12 @@ async def generate_chat_completion(
|
||||
idx = model["urlIdx"]
|
||||
|
||||
if "pipeline" in model and model.get("pipeline"):
|
||||
payload["user"] = {"name": user.name, "id": user.id}
|
||||
payload["user"] = {
|
||||
"name": user.name,
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"role": user.role,
|
||||
}
|
||||
|
||||
# Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
|
||||
# This is a workaround until OpenAI fixes the issue with this model
|
||||
@@ -430,13 +449,11 @@ async def generate_chat_completion(
|
||||
# Convert the modified body back to JSON
|
||||
payload = json.dumps(payload)
|
||||
|
||||
print(payload)
|
||||
log.debug(payload)
|
||||
|
||||
url = app.state.config.OPENAI_API_BASE_URLS[idx]
|
||||
key = app.state.config.OPENAI_API_KEYS[idx]
|
||||
|
||||
print(payload)
|
||||
|
||||
headers = {}
|
||||
headers["Authorization"] = f"Bearer {key}"
|
||||
headers["Content-Type"] = "application/json"
|
||||
|
||||
@@ -55,6 +55,9 @@ from apps.webui.models.documents import (
|
||||
DocumentForm,
|
||||
DocumentResponse,
|
||||
)
|
||||
from apps.webui.models.files import (
|
||||
Files,
|
||||
)
|
||||
|
||||
from apps.rag.utils import (
|
||||
get_model_path,
|
||||
@@ -112,6 +115,7 @@ from config import (
|
||||
YOUTUBE_LOADER_LANGUAGE,
|
||||
ENABLE_RAG_WEB_SEARCH,
|
||||
RAG_WEB_SEARCH_ENGINE,
|
||||
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
SEARXNG_QUERY_URL,
|
||||
GOOGLE_PSE_API_KEY,
|
||||
GOOGLE_PSE_ENGINE_ID,
|
||||
@@ -165,6 +169,7 @@ app.state.YOUTUBE_LOADER_TRANSLATION = None
|
||||
|
||||
app.state.config.ENABLE_RAG_WEB_SEARCH = ENABLE_RAG_WEB_SEARCH
|
||||
app.state.config.RAG_WEB_SEARCH_ENGINE = RAG_WEB_SEARCH_ENGINE
|
||||
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = RAG_WEB_SEARCH_DOMAIN_FILTER_LIST
|
||||
|
||||
app.state.config.SEARXNG_QUERY_URL = SEARXNG_QUERY_URL
|
||||
app.state.config.GOOGLE_PSE_API_KEY = GOOGLE_PSE_API_KEY
|
||||
@@ -775,6 +780,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
|
||||
app.state.config.SEARXNG_QUERY_URL,
|
||||
query,
|
||||
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
)
|
||||
else:
|
||||
raise Exception("No SEARXNG_QUERY_URL found in environment variables")
|
||||
@@ -788,6 +794,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
|
||||
app.state.config.GOOGLE_PSE_ENGINE_ID,
|
||||
query,
|
||||
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
@@ -799,6 +806,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
|
||||
app.state.config.BRAVE_SEARCH_API_KEY,
|
||||
query,
|
||||
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
)
|
||||
else:
|
||||
raise Exception("No BRAVE_SEARCH_API_KEY found in environment variables")
|
||||
@@ -808,6 +816,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
|
||||
app.state.config.SERPSTACK_API_KEY,
|
||||
query,
|
||||
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
https_enabled=app.state.config.SERPSTACK_HTTPS,
|
||||
)
|
||||
else:
|
||||
@@ -818,6 +827,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
|
||||
app.state.config.SERPER_API_KEY,
|
||||
query,
|
||||
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
)
|
||||
else:
|
||||
raise Exception("No SERPER_API_KEY found in environment variables")
|
||||
@@ -827,11 +837,16 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
|
||||
app.state.config.SERPLY_API_KEY,
|
||||
query,
|
||||
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
)
|
||||
else:
|
||||
raise Exception("No SERPLY_API_KEY found in environment variables")
|
||||
elif engine == "duckduckgo":
|
||||
return search_duckduckgo(query, app.state.config.RAG_WEB_SEARCH_RESULT_COUNT)
|
||||
return search_duckduckgo(
|
||||
query,
|
||||
app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
)
|
||||
elif engine == "tavily":
|
||||
if app.state.config.TAVILY_API_KEY:
|
||||
return search_tavily(
|
||||
@@ -1119,6 +1134,60 @@ def store_doc(
|
||||
)
|
||||
|
||||
|
||||
class ProcessDocForm(BaseModel):
|
||||
file_id: str
|
||||
collection_name: Optional[str] = None
|
||||
|
||||
|
||||
@app.post("/process/doc")
|
||||
def process_doc(
|
||||
form_data: ProcessDocForm,
|
||||
user=Depends(get_current_user),
|
||||
):
|
||||
try:
|
||||
file = Files.get_file_by_id(form_data.file_id)
|
||||
file_path = file.meta.get("path", f"{UPLOAD_DIR}/{file.filename}")
|
||||
|
||||
f = open(file_path, "rb")
|
||||
|
||||
collection_name = form_data.collection_name
|
||||
if collection_name == None:
|
||||
collection_name = calculate_sha256(f)[:63]
|
||||
f.close()
|
||||
|
||||
loader, known_type = get_loader(
|
||||
file.filename, file.meta.get("content_type"), file_path
|
||||
)
|
||||
data = loader.load()
|
||||
|
||||
try:
|
||||
result = store_data_in_vector_db(data, collection_name)
|
||||
|
||||
if result:
|
||||
return {
|
||||
"status": True,
|
||||
"collection_name": collection_name,
|
||||
"known_type": known_type,
|
||||
}
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=e,
|
||||
)
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
if "No pandoc was found" in str(e):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.PANDOC_NOT_INSTALLED,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
|
||||
|
||||
class TextRAGForm(BaseModel):
|
||||
name: str
|
||||
content: str
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
import logging
|
||||
|
||||
from typing import List, Optional
|
||||
import requests
|
||||
|
||||
from apps.rag.search.main import SearchResult
|
||||
from apps.rag.search.main import SearchResult, get_filtered_results
|
||||
from config import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
def search_brave(api_key: str, query: str, count: int) -> list[SearchResult]:
|
||||
def search_brave(
|
||||
api_key: str, query: str, count: int, filter_list: Optional[List[str]] = None
|
||||
) -> list[SearchResult]:
|
||||
"""Search using Brave's Search API and return the results as a list of SearchResult objects.
|
||||
|
||||
Args:
|
||||
@@ -29,6 +31,9 @@ def search_brave(api_key: str, query: str, count: int) -> list[SearchResult]:
|
||||
|
||||
json_response = response.json()
|
||||
results = json_response.get("web", {}).get("results", [])
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["url"], title=result.get("title"), snippet=result.get("snippet")
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
|
||||
from apps.rag.search.main import SearchResult
|
||||
from typing import List, Optional
|
||||
from apps.rag.search.main import SearchResult, get_filtered_results
|
||||
from duckduckgo_search import DDGS
|
||||
from config import SRC_LOG_LEVELS
|
||||
|
||||
@@ -8,7 +8,9 @@ log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
def search_duckduckgo(query: str, count: int) -> list[SearchResult]:
|
||||
def search_duckduckgo(
|
||||
query: str, count: int, filter_list: Optional[List[str]] = None
|
||||
) -> list[SearchResult]:
|
||||
"""
|
||||
Search using DuckDuckGo's Search API and return the results as a list of SearchResult objects.
|
||||
Args:
|
||||
@@ -41,6 +43,7 @@ def search_duckduckgo(query: str, count: int) -> list[SearchResult]:
|
||||
snippet=result.get("body"),
|
||||
)
|
||||
)
|
||||
print(results)
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
# Return the list of search results
|
||||
return results
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from typing import List, Optional
|
||||
import requests
|
||||
|
||||
from apps.rag.search.main import SearchResult
|
||||
from apps.rag.search.main import SearchResult, get_filtered_results
|
||||
from config import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@@ -11,7 +11,11 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
def search_google_pse(
|
||||
api_key: str, search_engine_id: str, query: str, count: int
|
||||
api_key: str,
|
||||
search_engine_id: str,
|
||||
query: str,
|
||||
count: int,
|
||||
filter_list: Optional[List[str]] = None,
|
||||
) -> list[SearchResult]:
|
||||
"""Search using Google's Programmable Search Engine API and return the results as a list of SearchResult objects.
|
||||
|
||||
@@ -35,6 +39,8 @@ def search_google_pse(
|
||||
|
||||
json_response = response.json()
|
||||
results = json_response.get("items", [])
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["link"],
|
||||
|
||||
@@ -1,8 +1,19 @@
|
||||
from typing import Optional
|
||||
|
||||
from urllib.parse import urlparse
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
def get_filtered_results(results, filter_list):
|
||||
if not filter_list:
|
||||
return results
|
||||
filtered_results = []
|
||||
for result in results:
|
||||
domain = urlparse(result["url"]).netloc
|
||||
if any(domain.endswith(filtered_domain) for filtered_domain in filter_list):
|
||||
filtered_results.append(result)
|
||||
return filtered_results
|
||||
|
||||
|
||||
class SearchResult(BaseModel):
|
||||
link: str
|
||||
title: Optional[str]
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import logging
|
||||
import requests
|
||||
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from apps.rag.search.main import SearchResult
|
||||
from apps.rag.search.main import SearchResult, get_filtered_results
|
||||
from config import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@@ -11,7 +11,11 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
def search_searxng(
|
||||
query_url: str, query: str, count: int, **kwargs
|
||||
query_url: str,
|
||||
query: str,
|
||||
count: int,
|
||||
filter_list: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
) -> List[SearchResult]:
|
||||
"""
|
||||
Search a SearXNG instance for a given query and return the results as a list of SearchResult objects.
|
||||
@@ -78,6 +82,8 @@ def search_searxng(
|
||||
json_response = response.json()
|
||||
results = json_response.get("results", [])
|
||||
sorted_results = sorted(results, key=lambda x: x.get("score", 0), reverse=True)
|
||||
if filter_list:
|
||||
sorted_results = get_filtered_results(sorted_results, filter_list)
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["url"], title=result.get("title"), snippet=result.get("content")
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from typing import List, Optional
|
||||
import requests
|
||||
|
||||
from apps.rag.search.main import SearchResult
|
||||
from apps.rag.search.main import SearchResult, get_filtered_results
|
||||
from config import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
def search_serper(api_key: str, query: str, count: int) -> list[SearchResult]:
|
||||
def search_serper(
|
||||
api_key: str, query: str, count: int, filter_list: Optional[List[str]] = None
|
||||
) -> list[SearchResult]:
|
||||
"""Search using serper.dev's API and return the results as a list of SearchResult objects.
|
||||
|
||||
Args:
|
||||
@@ -29,6 +31,8 @@ def search_serper(api_key: str, query: str, count: int) -> list[SearchResult]:
|
||||
results = sorted(
|
||||
json_response.get("organic", []), key=lambda x: x.get("position", 0)
|
||||
)
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["link"],
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from typing import List, Optional
|
||||
import requests
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from apps.rag.search.main import SearchResult
|
||||
from apps.rag.search.main import SearchResult, get_filtered_results
|
||||
from config import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@@ -19,6 +19,7 @@ def search_serply(
|
||||
limit: int = 10,
|
||||
device_type: str = "desktop",
|
||||
proxy_location: str = "US",
|
||||
filter_list: Optional[List[str]] = None,
|
||||
) -> list[SearchResult]:
|
||||
"""Search using serper.dev's API and return the results as a list of SearchResult objects.
|
||||
|
||||
@@ -57,7 +58,8 @@ def search_serply(
|
||||
results = sorted(
|
||||
json_response.get("results", []), key=lambda x: x.get("realPosition", 0)
|
||||
)
|
||||
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["link"],
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
from typing import List, Optional
|
||||
import requests
|
||||
|
||||
from apps.rag.search.main import SearchResult
|
||||
from apps.rag.search.main import SearchResult, get_filtered_results
|
||||
from config import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@@ -11,7 +11,11 @@ log.setLevel(SRC_LOG_LEVELS["RAG"])
|
||||
|
||||
|
||||
def search_serpstack(
|
||||
api_key: str, query: str, count: int, https_enabled: bool = True
|
||||
api_key: str,
|
||||
query: str,
|
||||
count: int,
|
||||
filter_list: Optional[List[str]] = None,
|
||||
https_enabled: bool = True,
|
||||
) -> list[SearchResult]:
|
||||
"""Search using serpstack.com's and return the results as a list of SearchResult objects.
|
||||
|
||||
@@ -35,6 +39,8 @@ def search_serpstack(
|
||||
results = sorted(
|
||||
json_response.get("organic_results", []), key=lambda x: x.get("position", 0)
|
||||
)
|
||||
if filter_list:
|
||||
results = get_filtered_results(results, filter_list)
|
||||
return [
|
||||
SearchResult(
|
||||
link=result["url"], title=result.get("title"), snippet=result.get("snippet")
|
||||
|
||||
@@ -237,7 +237,7 @@ def get_embedding_function(
|
||||
|
||||
|
||||
def get_rag_context(
|
||||
docs,
|
||||
files,
|
||||
messages,
|
||||
embedding_function,
|
||||
k,
|
||||
@@ -245,29 +245,29 @@ def get_rag_context(
|
||||
r,
|
||||
hybrid_search,
|
||||
):
|
||||
log.debug(f"docs: {docs} {messages} {embedding_function} {reranking_function}")
|
||||
log.debug(f"files: {files} {messages} {embedding_function} {reranking_function}")
|
||||
query = get_last_user_message(messages)
|
||||
|
||||
extracted_collections = []
|
||||
relevant_contexts = []
|
||||
|
||||
for doc in docs:
|
||||
for file in files:
|
||||
context = None
|
||||
|
||||
collection_names = (
|
||||
doc["collection_names"]
|
||||
if doc["type"] == "collection"
|
||||
else [doc["collection_name"]]
|
||||
file["collection_names"]
|
||||
if file["type"] == "collection"
|
||||
else [file["collection_name"]]
|
||||
)
|
||||
|
||||
collection_names = set(collection_names).difference(extracted_collections)
|
||||
if not collection_names:
|
||||
log.debug(f"skipping {doc} as it has already been extracted")
|
||||
log.debug(f"skipping {file} as it has already been extracted")
|
||||
continue
|
||||
|
||||
try:
|
||||
if doc["type"] == "text":
|
||||
context = doc["content"]
|
||||
if file["type"] == "text":
|
||||
context = file["content"]
|
||||
else:
|
||||
if hybrid_search:
|
||||
context = query_collection_with_hybrid_search(
|
||||
@@ -290,7 +290,7 @@ def get_rag_context(
|
||||
context = None
|
||||
|
||||
if context:
|
||||
relevant_contexts.append({**context, "source": doc})
|
||||
relevant_contexts.append({**context, "source": file})
|
||||
|
||||
extracted_collections.extend(collection_names)
|
||||
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import os
|
||||
import logging
|
||||
import json
|
||||
|
||||
from peewee import *
|
||||
from peewee_migrate import Router
|
||||
from playhouse.db_url import connect
|
||||
|
||||
from apps.webui.internal.wrappers import register_connection
|
||||
from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL, BACKEND_DIR
|
||||
import os
|
||||
import logging
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["DB"])
|
||||
@@ -28,12 +29,26 @@ if os.path.exists(f"{DATA_DIR}/ollama.db"):
|
||||
else:
|
||||
pass
|
||||
|
||||
DB = connect(DATABASE_URL)
|
||||
log.info(f"Connected to a {DB.__class__.__name__} database.")
|
||||
|
||||
# The `register_connection` function encapsulates the logic for setting up
|
||||
# the database connection based on the connection string, while `connect`
|
||||
# is a Peewee-specific method to manage the connection state and avoid errors
|
||||
# when a connection is already open.
|
||||
try:
|
||||
DB = register_connection(DATABASE_URL)
|
||||
log.info(f"Connected to a {DB.__class__.__name__} database.")
|
||||
except Exception as e:
|
||||
log.error(f"Failed to initialize the database connection: {e}")
|
||||
raise
|
||||
|
||||
router = Router(
|
||||
DB,
|
||||
migrate_dir=BACKEND_DIR / "apps" / "webui" / "internal" / "migrations",
|
||||
logger=log,
|
||||
)
|
||||
router.run()
|
||||
DB.connect(reuse_if_open=True)
|
||||
try:
|
||||
DB.connect(reuse_if_open=True)
|
||||
except OperationalError as e:
|
||||
log.info(f"Failed to connect to database again due to: {e}")
|
||||
pass
|
||||
|
||||
48
backend/apps/webui/internal/migrations/013_add_user_info.py
Normal file
48
backend/apps/webui/internal/migrations/013_add_user_info.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Peewee migrations -- 002_add_local_sharing.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
# Adding fields info to the 'user' table
|
||||
migrator.add_fields("user", info=pw.TextField(null=True))
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
# Remove the settings field
|
||||
migrator.remove_fields("user", "info")
|
||||
55
backend/apps/webui/internal/migrations/014_add_files.py
Normal file
55
backend/apps/webui/internal/migrations/014_add_files.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""Peewee migrations -- 009_add_models.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
@migrator.create_model
|
||||
class File(pw.Model):
|
||||
id = pw.TextField(unique=True)
|
||||
user_id = pw.TextField()
|
||||
filename = pw.TextField()
|
||||
meta = pw.TextField()
|
||||
created_at = pw.BigIntegerField(null=False)
|
||||
|
||||
class Meta:
|
||||
table_name = "file"
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
migrator.remove_model("file")
|
||||
61
backend/apps/webui/internal/migrations/015_add_functions.py
Normal file
61
backend/apps/webui/internal/migrations/015_add_functions.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""Peewee migrations -- 009_add_models.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
@migrator.create_model
|
||||
class Function(pw.Model):
|
||||
id = pw.TextField(unique=True)
|
||||
user_id = pw.TextField()
|
||||
|
||||
name = pw.TextField()
|
||||
type = pw.TextField()
|
||||
|
||||
content = pw.TextField()
|
||||
meta = pw.TextField()
|
||||
|
||||
created_at = pw.BigIntegerField(null=False)
|
||||
updated_at = pw.BigIntegerField(null=False)
|
||||
|
||||
class Meta:
|
||||
table_name = "function"
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
migrator.remove_model("function")
|
||||
72
backend/apps/webui/internal/wrappers.py
Normal file
72
backend/apps/webui/internal/wrappers.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from contextvars import ContextVar
|
||||
from peewee import *
|
||||
from peewee import PostgresqlDatabase, InterfaceError as PeeWeeInterfaceError
|
||||
|
||||
import logging
|
||||
from playhouse.db_url import connect, parse
|
||||
from playhouse.shortcuts import ReconnectMixin
|
||||
|
||||
from config import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["DB"])
|
||||
|
||||
db_state_default = {"closed": None, "conn": None, "ctx": None, "transactions": None}
|
||||
db_state = ContextVar("db_state", default=db_state_default.copy())
|
||||
|
||||
|
||||
class PeeweeConnectionState(object):
|
||||
def __init__(self, **kwargs):
|
||||
super().__setattr__("_state", db_state)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
self._state.get()[name] = value
|
||||
|
||||
def __getattr__(self, name):
|
||||
value = self._state.get()[name]
|
||||
return value
|
||||
|
||||
|
||||
class CustomReconnectMixin(ReconnectMixin):
|
||||
reconnect_errors = (
|
||||
# psycopg2
|
||||
(OperationalError, "termin"),
|
||||
(InterfaceError, "closed"),
|
||||
# peewee
|
||||
(PeeWeeInterfaceError, "closed"),
|
||||
)
|
||||
|
||||
|
||||
class ReconnectingPostgresqlDatabase(CustomReconnectMixin, PostgresqlDatabase):
|
||||
pass
|
||||
|
||||
|
||||
def register_connection(db_url):
|
||||
db = connect(db_url)
|
||||
if isinstance(db, PostgresqlDatabase):
|
||||
# Enable autoconnect for SQLite databases, managed by Peewee
|
||||
db.autoconnect = True
|
||||
db.reuse_if_open = True
|
||||
log.info("Connected to PostgreSQL database")
|
||||
|
||||
# Get the connection details
|
||||
connection = parse(db_url)
|
||||
|
||||
# Use our custom database class that supports reconnection
|
||||
db = ReconnectingPostgresqlDatabase(
|
||||
connection["database"],
|
||||
user=connection["user"],
|
||||
password=connection["password"],
|
||||
host=connection["host"],
|
||||
port=connection["port"],
|
||||
)
|
||||
db.connect(reuse_if_open=True)
|
||||
elif isinstance(db, SqliteDatabase):
|
||||
# Enable autoconnect for SQLite databases, managed by Peewee
|
||||
db.autoconnect = True
|
||||
db.reuse_if_open = True
|
||||
log.info("Connected to SQLite database")
|
||||
else:
|
||||
raise ValueError("Unsupported database connection")
|
||||
return db
|
||||
@@ -14,7 +14,12 @@ from apps.webui.routers import (
|
||||
configs,
|
||||
memories,
|
||||
utils,
|
||||
files,
|
||||
functions,
|
||||
)
|
||||
from apps.webui.models.functions import Functions
|
||||
from apps.webui.utils import load_function_module_by_id
|
||||
|
||||
from config import (
|
||||
WEBUI_BUILD_HASH,
|
||||
SHOW_ADMIN_DETAILS,
|
||||
@@ -27,6 +32,7 @@ from config import (
|
||||
USER_PERMISSIONS,
|
||||
WEBHOOK_URL,
|
||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
||||
WEBUI_AUTH_TRUSTED_NAME_HEADER,
|
||||
JWT_EXPIRES_IN,
|
||||
WEBUI_BANNERS,
|
||||
ENABLE_COMMUNITY_SHARING,
|
||||
@@ -42,6 +48,7 @@ app.state.config = AppConfig()
|
||||
app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
|
||||
app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
|
||||
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
|
||||
app.state.AUTH_TRUSTED_NAME_HEADER = WEBUI_AUTH_TRUSTED_NAME_HEADER
|
||||
|
||||
|
||||
app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
|
||||
@@ -59,7 +66,7 @@ app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
|
||||
|
||||
app.state.MODELS = {}
|
||||
app.state.TOOLS = {}
|
||||
|
||||
app.state.FUNCTIONS = {}
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
@@ -69,17 +76,21 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
app.include_router(configs.router, prefix="/configs", tags=["configs"])
|
||||
app.include_router(auths.router, prefix="/auths", tags=["auths"])
|
||||
app.include_router(users.router, prefix="/users", tags=["users"])
|
||||
app.include_router(chats.router, prefix="/chats", tags=["chats"])
|
||||
|
||||
app.include_router(documents.router, prefix="/documents", tags=["documents"])
|
||||
app.include_router(tools.router, prefix="/tools", tags=["tools"])
|
||||
app.include_router(models.router, prefix="/models", tags=["models"])
|
||||
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
|
||||
app.include_router(memories.router, prefix="/memories", tags=["memories"])
|
||||
|
||||
app.include_router(configs.router, prefix="/configs", tags=["configs"])
|
||||
app.include_router(memories.router, prefix="/memories", tags=["memories"])
|
||||
app.include_router(files.router, prefix="/files", tags=["files"])
|
||||
app.include_router(tools.router, prefix="/tools", tags=["tools"])
|
||||
app.include_router(functions.router, prefix="/functions", tags=["functions"])
|
||||
|
||||
app.include_router(utils.router, prefix="/utils", tags=["utils"])
|
||||
|
||||
|
||||
@@ -91,3 +102,58 @@ async def get_status():
|
||||
"default_models": app.state.config.DEFAULT_MODELS,
|
||||
"default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
|
||||
}
|
||||
|
||||
|
||||
async def get_pipe_models():
|
||||
pipes = Functions.get_functions_by_type("pipe")
|
||||
pipe_models = []
|
||||
|
||||
for pipe in pipes:
|
||||
# Check if function is already loaded
|
||||
if pipe.id not in app.state.FUNCTIONS:
|
||||
function_module, function_type = load_function_module_by_id(pipe.id)
|
||||
app.state.FUNCTIONS[pipe.id] = function_module
|
||||
else:
|
||||
function_module = app.state.FUNCTIONS[pipe.id]
|
||||
|
||||
# Check if function is a manifold
|
||||
if hasattr(function_module, "type"):
|
||||
if function_module.type == "manifold":
|
||||
manifold_pipes = []
|
||||
|
||||
# Check if pipes is a function or a list
|
||||
if callable(function_module.pipes):
|
||||
manifold_pipes = function_module.pipes()
|
||||
else:
|
||||
manifold_pipes = function_module.pipes
|
||||
|
||||
for p in manifold_pipes:
|
||||
manifold_pipe_id = f'{pipe.id}.{p["id"]}'
|
||||
manifold_pipe_name = p["name"]
|
||||
|
||||
if hasattr(function_module, "name"):
|
||||
manifold_pipe_name = f"{pipe.name}{manifold_pipe_name}"
|
||||
|
||||
pipe_models.append(
|
||||
{
|
||||
"id": manifold_pipe_id,
|
||||
"name": manifold_pipe_name,
|
||||
"object": "model",
|
||||
"created": pipe.created_at,
|
||||
"owned_by": "openai",
|
||||
"pipe": {"type": pipe.type},
|
||||
}
|
||||
)
|
||||
else:
|
||||
pipe_models.append(
|
||||
{
|
||||
"id": pipe.id,
|
||||
"name": pipe.name,
|
||||
"object": "model",
|
||||
"created": pipe.created_at,
|
||||
"owned_by": "openai",
|
||||
"pipe": {"type": "pipe"},
|
||||
}
|
||||
)
|
||||
|
||||
return pipe_models
|
||||
|
||||
112
backend/apps/webui/models/files.py
Normal file
112
backend/apps/webui/models/files.py
Normal file
@@ -0,0 +1,112 @@
|
||||
from pydantic import BaseModel
|
||||
from peewee import *
|
||||
from playhouse.shortcuts import model_to_dict
|
||||
from typing import List, Union, Optional
|
||||
import time
|
||||
import logging
|
||||
from apps.webui.internal.db import DB, JSONField
|
||||
|
||||
import json
|
||||
|
||||
from config import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
####################
|
||||
# Files DB Schema
|
||||
####################
|
||||
|
||||
|
||||
class File(Model):
|
||||
id = CharField(unique=True)
|
||||
user_id = CharField()
|
||||
filename = TextField()
|
||||
meta = JSONField()
|
||||
created_at = BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
database = DB
|
||||
|
||||
|
||||
class FileModel(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
filename: str
|
||||
meta: dict
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class FileModelResponse(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
filename: str
|
||||
meta: dict
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
|
||||
class FileForm(BaseModel):
|
||||
id: str
|
||||
filename: str
|
||||
meta: dict = {}
|
||||
|
||||
|
||||
class FilesTable:
|
||||
def __init__(self, db):
|
||||
self.db = db
|
||||
self.db.create_tables([File])
|
||||
|
||||
def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
|
||||
file = FileModel(
|
||||
**{
|
||||
**form_data.model_dump(),
|
||||
"user_id": user_id,
|
||||
"created_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
result = File.create(**file.model_dump())
|
||||
if result:
|
||||
return file
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error creating tool: {e}")
|
||||
return None
|
||||
|
||||
def get_file_by_id(self, id: str) -> Optional[FileModel]:
|
||||
try:
|
||||
file = File.get(File.id == id)
|
||||
return FileModel(**model_to_dict(file))
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_files(self) -> List[FileModel]:
|
||||
return [FileModel(**model_to_dict(file)) for file in File.select()]
|
||||
|
||||
def delete_file_by_id(self, id: str) -> bool:
|
||||
try:
|
||||
query = File.delete().where((File.id == id))
|
||||
query.execute() # Remove the rows, return number of rows removed.
|
||||
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
def delete_all_files(self) -> bool:
|
||||
try:
|
||||
query = File.delete()
|
||||
query.execute() # Remove the rows, return number of rows removed.
|
||||
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
Files = FilesTable(DB)
|
||||
141
backend/apps/webui/models/functions.py
Normal file
141
backend/apps/webui/models/functions.py
Normal file
@@ -0,0 +1,141 @@
|
||||
from pydantic import BaseModel
|
||||
from peewee import *
|
||||
from playhouse.shortcuts import model_to_dict
|
||||
from typing import List, Union, Optional
|
||||
import time
|
||||
import logging
|
||||
from apps.webui.internal.db import DB, JSONField
|
||||
|
||||
import json
|
||||
|
||||
from config import SRC_LOG_LEVELS
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
####################
|
||||
# Functions DB Schema
|
||||
####################
|
||||
|
||||
|
||||
class Function(Model):
|
||||
id = CharField(unique=True)
|
||||
user_id = CharField()
|
||||
name = TextField()
|
||||
type = TextField()
|
||||
content = TextField()
|
||||
meta = JSONField()
|
||||
updated_at = BigIntegerField()
|
||||
created_at = BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
database = DB
|
||||
|
||||
|
||||
class FunctionMeta(BaseModel):
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class FunctionModel(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
name: str
|
||||
type: str
|
||||
content: str
|
||||
meta: FunctionMeta
|
||||
updated_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class FunctionResponse(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
type: str
|
||||
name: str
|
||||
meta: FunctionMeta
|
||||
updated_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
|
||||
class FunctionForm(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
content: str
|
||||
meta: FunctionMeta
|
||||
|
||||
|
||||
class FunctionsTable:
|
||||
def __init__(self, db):
|
||||
self.db = db
|
||||
self.db.create_tables([Function])
|
||||
|
||||
def insert_new_function(
|
||||
self, user_id: str, type: str, form_data: FunctionForm
|
||||
) -> Optional[FunctionModel]:
|
||||
function = FunctionModel(
|
||||
**{
|
||||
**form_data.model_dump(),
|
||||
"user_id": user_id,
|
||||
"type": type,
|
||||
"updated_at": int(time.time()),
|
||||
"created_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
result = Function.create(**function.model_dump())
|
||||
if result:
|
||||
return function
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f"Error creating tool: {e}")
|
||||
return None
|
||||
|
||||
def get_function_by_id(self, id: str) -> Optional[FunctionModel]:
|
||||
try:
|
||||
function = Function.get(Function.id == id)
|
||||
return FunctionModel(**model_to_dict(function))
|
||||
except:
|
||||
return None
|
||||
|
||||
def get_functions(self) -> List[FunctionModel]:
|
||||
return [
|
||||
FunctionModel(**model_to_dict(function)) for function in Function.select()
|
||||
]
|
||||
|
||||
def get_functions_by_type(self, type: str) -> List[FunctionModel]:
|
||||
return [
|
||||
FunctionModel(**model_to_dict(function))
|
||||
for function in Function.select().where(Function.type == type)
|
||||
]
|
||||
|
||||
def update_function_by_id(self, id: str, updated: dict) -> Optional[FunctionModel]:
|
||||
try:
|
||||
query = Function.update(
|
||||
**updated,
|
||||
updated_at=int(time.time()),
|
||||
).where(Function.id == id)
|
||||
query.execute()
|
||||
|
||||
function = Function.get(Function.id == id)
|
||||
return FunctionModel(**model_to_dict(function))
|
||||
except:
|
||||
return None
|
||||
|
||||
def delete_function_by_id(self, id: str) -> bool:
|
||||
try:
|
||||
query = Function.delete().where((Function.id == id))
|
||||
query.execute() # Remove the rows, return number of rows removed.
|
||||
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
Functions = FunctionsTable(DB)
|
||||
@@ -26,6 +26,7 @@ class User(Model):
|
||||
|
||||
api_key = CharField(null=True, unique=True)
|
||||
settings = JSONField(null=True)
|
||||
info = JSONField(null=True)
|
||||
|
||||
oauth_sub = TextField(null=True, unique=True)
|
||||
|
||||
@@ -52,6 +53,7 @@ class UserModel(BaseModel):
|
||||
|
||||
api_key: Optional[str] = None
|
||||
settings: Optional[UserSettings] = None
|
||||
info: Optional[dict] = None
|
||||
|
||||
oauth_sub: Optional[str] = None
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import logging
|
||||
|
||||
from fastapi import Request, UploadFile, File
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.responses import Response
|
||||
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
@@ -35,6 +36,7 @@ from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
||||
from config import (
|
||||
WEBUI_AUTH,
|
||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
||||
WEBUI_AUTH_TRUSTED_NAME_HEADER,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
@@ -45,7 +47,21 @@ router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/", response_model=UserResponse)
|
||||
async def get_session_user(user=Depends(get_current_user)):
|
||||
async def get_session_user(
|
||||
request: Request, response: Response, user=Depends(get_current_user)
|
||||
):
|
||||
token = create_token(
|
||||
data={"id": user.id},
|
||||
expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
|
||||
)
|
||||
|
||||
# Set the cookie token
|
||||
response.set_cookie(
|
||||
key="token",
|
||||
value=token,
|
||||
httponly=True, # Ensures the cookie is not accessible via JavaScript
|
||||
)
|
||||
|
||||
return {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
@@ -106,17 +122,22 @@ async def update_password(
|
||||
|
||||
|
||||
@router.post("/signin", response_model=SigninResponse)
|
||||
async def signin(request: Request, form_data: SigninForm):
|
||||
async def signin(request: Request, response: Response, form_data: SigninForm):
|
||||
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
|
||||
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER not in request.headers:
|
||||
raise HTTPException(400, detail=ERROR_MESSAGES.INVALID_TRUSTED_HEADER)
|
||||
|
||||
trusted_email = request.headers[WEBUI_AUTH_TRUSTED_EMAIL_HEADER].lower()
|
||||
trusted_name = trusted_email
|
||||
if WEBUI_AUTH_TRUSTED_NAME_HEADER:
|
||||
trusted_name = request.headers.get(
|
||||
WEBUI_AUTH_TRUSTED_NAME_HEADER, trusted_email
|
||||
)
|
||||
if not Users.get_user_by_email(trusted_email.lower()):
|
||||
await signup(
|
||||
request,
|
||||
SignupForm(
|
||||
email=trusted_email, password=str(uuid.uuid4()), name=trusted_email
|
||||
email=trusted_email, password=str(uuid.uuid4()), name=trusted_name
|
||||
),
|
||||
)
|
||||
user = Auths.authenticate_user_by_trusted_header(trusted_email)
|
||||
@@ -145,6 +166,13 @@ async def signin(request: Request, form_data: SigninForm):
|
||||
expires_delta=parse_duration(request.app.state.config.JWT_EXPIRES_IN),
|
||||
)
|
||||
|
||||
# Set the cookie token
|
||||
response.set_cookie(
|
||||
key="token",
|
||||
value=token,
|
||||
httponly=True, # Ensures the cookie is not accessible via JavaScript
|
||||
)
|
||||
|
||||
return {
|
||||
"token": token,
|
||||
"token_type": "Bearer",
|
||||
@@ -164,7 +192,7 @@ async def signin(request: Request, form_data: SigninForm):
|
||||
|
||||
|
||||
@router.post("/signup", response_model=SigninResponse)
|
||||
async def signup(request: Request, form_data: SignupForm):
|
||||
async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||
if not request.app.state.config.ENABLE_SIGNUP and WEBUI_AUTH:
|
||||
raise HTTPException(
|
||||
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
|
||||
@@ -200,6 +228,13 @@ async def signup(request: Request, form_data: SignupForm):
|
||||
)
|
||||
# response.set_cookie(key='token', value=token, httponly=True)
|
||||
|
||||
# Set the cookie token
|
||||
response.set_cookie(
|
||||
key="token",
|
||||
value=token,
|
||||
httponly=True, # Ensures the cookie is not accessible via JavaScript
|
||||
)
|
||||
|
||||
if request.app.state.config.WEBHOOK_URL:
|
||||
post_webhook(
|
||||
request.app.state.config.WEBHOOK_URL,
|
||||
|
||||
219
backend/apps/webui/routers/files.py
Normal file
219
backend/apps/webui/routers/files.py
Normal file
@@ -0,0 +1,219 @@
|
||||
from fastapi import (
|
||||
Depends,
|
||||
FastAPI,
|
||||
HTTPException,
|
||||
status,
|
||||
Request,
|
||||
UploadFile,
|
||||
File,
|
||||
Form,
|
||||
)
|
||||
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Union, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
|
||||
|
||||
from pydantic import BaseModel
|
||||
import json
|
||||
|
||||
from apps.webui.models.files import (
|
||||
Files,
|
||||
FileForm,
|
||||
FileModel,
|
||||
FileModelResponse,
|
||||
)
|
||||
from utils.utils import get_verified_user, get_admin_user
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
from importlib import util
|
||||
import os
|
||||
import uuid
|
||||
import os, shutil, logging, re
|
||||
|
||||
|
||||
from config import SRC_LOG_LEVELS, UPLOAD_DIR
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
############################
|
||||
# Upload File
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/")
|
||||
def upload_file(
|
||||
file: UploadFile = File(...),
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
log.info(f"file.content_type: {file.content_type}")
|
||||
try:
|
||||
unsanitized_filename = file.filename
|
||||
filename = os.path.basename(unsanitized_filename)
|
||||
|
||||
# replace filename with uuid
|
||||
id = str(uuid.uuid4())
|
||||
filename = f"{id}_{filename}"
|
||||
file_path = f"{UPLOAD_DIR}/{filename}"
|
||||
|
||||
contents = file.file.read()
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(contents)
|
||||
f.close()
|
||||
|
||||
file = Files.insert_new_file(
|
||||
user.id,
|
||||
FileForm(
|
||||
**{
|
||||
"id": id,
|
||||
"filename": filename,
|
||||
"meta": {
|
||||
"content_type": file.content_type,
|
||||
"size": len(contents),
|
||||
"path": file_path,
|
||||
},
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
if file:
|
||||
return file
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error uploading file"),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
log.exception(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# List Files
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/", response_model=List[FileModel])
|
||||
async def list_files(user=Depends(get_verified_user)):
|
||||
files = Files.get_files()
|
||||
return files
|
||||
|
||||
|
||||
############################
|
||||
# Delete All Files
|
||||
############################
|
||||
|
||||
|
||||
@router.delete("/all")
|
||||
async def delete_all_files(user=Depends(get_admin_user)):
|
||||
result = Files.delete_all_files()
|
||||
|
||||
if result:
|
||||
folder = f"{UPLOAD_DIR}"
|
||||
try:
|
||||
# Check if the directory exists
|
||||
if os.path.exists(folder):
|
||||
# Iterate over all the files and directories in the specified directory
|
||||
for filename in os.listdir(folder):
|
||||
file_path = os.path.join(folder, filename)
|
||||
try:
|
||||
if os.path.isfile(file_path) or os.path.islink(file_path):
|
||||
os.unlink(file_path) # Remove the file or link
|
||||
elif os.path.isdir(file_path):
|
||||
shutil.rmtree(file_path) # Remove the directory
|
||||
except Exception as e:
|
||||
print(f"Failed to delete {file_path}. Reason: {e}")
|
||||
else:
|
||||
print(f"The directory {folder} does not exist")
|
||||
except Exception as e:
|
||||
print(f"Failed to process the directory {folder}. Reason: {e}")
|
||||
|
||||
return {"message": "All files deleted successfully"}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error deleting files"),
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# Get File By Id
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/{id}", response_model=Optional[FileModel])
|
||||
async def get_file_by_id(id: str, user=Depends(get_verified_user)):
|
||||
file = Files.get_file_by_id(id)
|
||||
|
||||
if file:
|
||||
return file
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# Get File Content By Id
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/{id}/content", response_model=Optional[FileModel])
|
||||
async def get_file_content_by_id(id: str, user=Depends(get_verified_user)):
|
||||
file = Files.get_file_by_id(id)
|
||||
|
||||
if file:
|
||||
file_path = Path(file.meta["path"])
|
||||
|
||||
# Check if the file already exists in the cache
|
||||
if file_path.is_file():
|
||||
print(f"file_path: {file_path}")
|
||||
return FileResponse(file_path)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# Delete File By Id
|
||||
############################
|
||||
|
||||
|
||||
@router.delete("/{id}")
|
||||
async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
|
||||
file = Files.get_file_by_id(id)
|
||||
|
||||
if file:
|
||||
result = Files.delete_file_by_id(id)
|
||||
if result:
|
||||
return {"message": "File deleted successfully"}
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error deleting file"),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
180
backend/apps/webui/routers/functions.py
Normal file
180
backend/apps/webui/routers/functions.py
Normal file
@@ -0,0 +1,180 @@
|
||||
from fastapi import Depends, FastAPI, HTTPException, status, Request
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Union, Optional
|
||||
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
import json
|
||||
|
||||
from apps.webui.models.functions import (
|
||||
Functions,
|
||||
FunctionForm,
|
||||
FunctionModel,
|
||||
FunctionResponse,
|
||||
)
|
||||
from apps.webui.utils import load_function_module_by_id
|
||||
from utils.utils import get_verified_user, get_admin_user
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
from importlib import util
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from config import DATA_DIR, CACHE_DIR, FUNCTIONS_DIR
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
############################
|
||||
# GetFunctions
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/", response_model=List[FunctionResponse])
|
||||
async def get_functions(user=Depends(get_verified_user)):
|
||||
return Functions.get_functions()
|
||||
|
||||
|
||||
############################
|
||||
# ExportFunctions
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/export", response_model=List[FunctionModel])
|
||||
async def get_functions(user=Depends(get_admin_user)):
|
||||
return Functions.get_functions()
|
||||
|
||||
|
||||
############################
|
||||
# CreateNewFunction
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/create", response_model=Optional[FunctionResponse])
|
||||
async def create_new_function(
|
||||
request: Request, form_data: FunctionForm, user=Depends(get_admin_user)
|
||||
):
|
||||
if not form_data.id.isidentifier():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Only alphanumeric characters and underscores are allowed in the id",
|
||||
)
|
||||
|
||||
form_data.id = form_data.id.lower()
|
||||
|
||||
function = Functions.get_function_by_id(form_data.id)
|
||||
if function == None:
|
||||
function_path = os.path.join(FUNCTIONS_DIR, f"{form_data.id}.py")
|
||||
try:
|
||||
with open(function_path, "w") as function_file:
|
||||
function_file.write(form_data.content)
|
||||
|
||||
function_module, function_type = load_function_module_by_id(form_data.id)
|
||||
|
||||
FUNCTIONS = request.app.state.FUNCTIONS
|
||||
FUNCTIONS[form_data.id] = function_module
|
||||
|
||||
function = Functions.insert_new_function(user.id, function_type, form_data)
|
||||
|
||||
function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id
|
||||
function_cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if function:
|
||||
return function
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error creating function"),
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.ID_TAKEN,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetFunctionById
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/id/{id}", response_model=Optional[FunctionModel])
|
||||
async def get_function_by_id(id: str, user=Depends(get_admin_user)):
|
||||
function = Functions.get_function_by_id(id)
|
||||
|
||||
if function:
|
||||
return function
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# UpdateFunctionById
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/id/{id}/update", response_model=Optional[FunctionModel])
|
||||
async def update_toolkit_by_id(
|
||||
request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user)
|
||||
):
|
||||
function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
|
||||
|
||||
try:
|
||||
with open(function_path, "w") as function_file:
|
||||
function_file.write(form_data.content)
|
||||
|
||||
function_module, function_type = load_function_module_by_id(id)
|
||||
|
||||
FUNCTIONS = request.app.state.FUNCTIONS
|
||||
FUNCTIONS[id] = function_module
|
||||
|
||||
updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
|
||||
print(updated)
|
||||
|
||||
function = Functions.update_function_by_id(id, updated)
|
||||
|
||||
if function:
|
||||
return function
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.DEFAULT(e),
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# DeleteFunctionById
|
||||
############################
|
||||
|
||||
|
||||
@router.delete("/id/{id}/delete", response_model=bool)
|
||||
async def delete_function_by_id(
|
||||
request: Request, id: str, user=Depends(get_admin_user)
|
||||
):
|
||||
result = Functions.delete_function_by_id(id)
|
||||
|
||||
if result:
|
||||
FUNCTIONS = request.app.state.FUNCTIONS
|
||||
if id in FUNCTIONS:
|
||||
del FUNCTIONS[id]
|
||||
|
||||
# delete the function file
|
||||
function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
|
||||
os.remove(function_path)
|
||||
|
||||
return result
|
||||
@@ -15,8 +15,9 @@ from constants import ERROR_MESSAGES
|
||||
|
||||
from importlib import util
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from config import DATA_DIR
|
||||
from config import DATA_DIR, CACHE_DIR
|
||||
|
||||
|
||||
TOOLS_DIR = f"{DATA_DIR}/tools"
|
||||
@@ -79,6 +80,9 @@ async def create_new_toolkit(
|
||||
specs = get_tools_specs(TOOLS[form_data.id])
|
||||
toolkit = Tools.insert_new_tool(user.id, form_data, specs)
|
||||
|
||||
tool_cache_dir = Path(CACHE_DIR) / "tools" / form_data.id
|
||||
tool_cache_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if toolkit:
|
||||
return toolkit
|
||||
else:
|
||||
|
||||
@@ -115,6 +115,52 @@ async def update_user_settings_by_session_user(
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetUserInfoBySessionUser
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/user/info", response_model=Optional[dict])
|
||||
async def get_user_info_by_session_user(user=Depends(get_verified_user)):
|
||||
user = Users.get_user_by_id(user.id)
|
||||
if user:
|
||||
return user.info
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.USER_NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# UpdateUserInfoBySessionUser
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/user/info/update", response_model=Optional[dict])
|
||||
async def update_user_settings_by_session_user(
|
||||
form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
user = Users.get_user_by_id(user.id)
|
||||
if user:
|
||||
if user.info is None:
|
||||
user.info = {}
|
||||
|
||||
user = Users.update_user_by_id(user.id, {"info": {**user.info, **form_data}})
|
||||
if user:
|
||||
return user.info
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.USER_NOT_FOUND,
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=ERROR_MESSAGES.USER_NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetUserById
|
||||
############################
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from importlib import util
|
||||
import os
|
||||
|
||||
from config import TOOLS_DIR
|
||||
from config import TOOLS_DIR, FUNCTIONS_DIR
|
||||
|
||||
|
||||
def load_toolkit_module_by_id(toolkit_id):
|
||||
@@ -21,3 +21,25 @@ def load_toolkit_module_by_id(toolkit_id):
|
||||
# Move the file to the error folder
|
||||
os.rename(toolkit_path, f"{toolkit_path}.error")
|
||||
raise e
|
||||
|
||||
|
||||
def load_function_module_by_id(function_id):
|
||||
function_path = os.path.join(FUNCTIONS_DIR, f"{function_id}.py")
|
||||
|
||||
spec = util.spec_from_file_location(function_id, function_path)
|
||||
module = util.module_from_spec(spec)
|
||||
|
||||
try:
|
||||
spec.loader.exec_module(module)
|
||||
print(f"Loaded module: {module.__name__}")
|
||||
if hasattr(module, "Pipe"):
|
||||
return module.Pipe(), "pipe"
|
||||
elif hasattr(module, "Filter"):
|
||||
return module.Filter(), "filter"
|
||||
else:
|
||||
raise Exception("No Function class found")
|
||||
except Exception as e:
|
||||
print(f"Error loading module: {function_id}")
|
||||
# Move the file to the error folder
|
||||
os.rename(function_path, f"{function_path}.error")
|
||||
raise e
|
||||
|
||||
@@ -294,6 +294,7 @@ WEBUI_AUTH = os.environ.get("WEBUI_AUTH", "True").lower() == "true"
|
||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER = os.environ.get(
|
||||
"WEBUI_AUTH_TRUSTED_EMAIL_HEADER", None
|
||||
)
|
||||
WEBUI_AUTH_TRUSTED_NAME_HEADER = os.environ.get("WEBUI_AUTH_TRUSTED_NAME_HEADER", None)
|
||||
JWT_EXPIRES_IN = PersistentConfig(
|
||||
"JWT_EXPIRES_IN", "auth.jwt_expiry", os.environ.get("JWT_EXPIRES_IN", "-1")
|
||||
)
|
||||
@@ -505,6 +506,14 @@ TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
|
||||
Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
####################################
|
||||
# Functions DIR
|
||||
####################################
|
||||
|
||||
FUNCTIONS_DIR = os.getenv("FUNCTIONS_DIR", f"{DATA_DIR}/functions")
|
||||
Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
####################################
|
||||
# LITELLM_CONFIG
|
||||
####################################
|
||||
@@ -554,7 +563,14 @@ OLLAMA_API_BASE_URL = os.environ.get(
|
||||
)
|
||||
|
||||
OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "")
|
||||
AIOHTTP_CLIENT_TIMEOUT = int(os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "300"))
|
||||
AIOHTTP_CLIENT_TIMEOUT = os.environ.get("AIOHTTP_CLIENT_TIMEOUT", "300")
|
||||
|
||||
if AIOHTTP_CLIENT_TIMEOUT == "":
|
||||
AIOHTTP_CLIENT_TIMEOUT = None
|
||||
else:
|
||||
AIOHTTP_CLIENT_TIMEOUT = int(AIOHTTP_CLIENT_TIMEOUT)
|
||||
|
||||
|
||||
K8S_FLAG = os.environ.get("K8S_FLAG", "")
|
||||
USE_OLLAMA_DOCKER = os.environ.get("USE_OLLAMA_DOCKER", "false")
|
||||
|
||||
@@ -1034,6 +1050,18 @@ RAG_WEB_SEARCH_ENGINE = PersistentConfig(
|
||||
os.getenv("RAG_WEB_SEARCH_ENGINE", ""),
|
||||
)
|
||||
|
||||
# You can provide a list of your own websites to filter after performing a web search.
|
||||
# This ensures the highest level of safety and reliability of the information sources.
|
||||
RAG_WEB_SEARCH_DOMAIN_FILTER_LIST = PersistentConfig(
|
||||
"RAG_WEB_SEARCH_DOMAIN_FILTER_LIST",
|
||||
"rag.rag.web.search.domain.filter_list",
|
||||
[
|
||||
# "wikipedia.com",
|
||||
# "wikimedia.org",
|
||||
# "wikidata.org",
|
||||
],
|
||||
)
|
||||
|
||||
SEARXNG_QUERY_URL = PersistentConfig(
|
||||
"SEARXNG_QUERY_URL",
|
||||
"rag.web.search.searxng_query_url",
|
||||
@@ -1139,6 +1167,30 @@ COMFYUI_BASE_URL = PersistentConfig(
|
||||
os.getenv("COMFYUI_BASE_URL", ""),
|
||||
)
|
||||
|
||||
COMFYUI_CFG_SCALE = PersistentConfig(
|
||||
"COMFYUI_CFG_SCALE",
|
||||
"image_generation.comfyui.cfg_scale",
|
||||
os.getenv("COMFYUI_CFG_SCALE", ""),
|
||||
)
|
||||
|
||||
COMFYUI_SAMPLER = PersistentConfig(
|
||||
"COMFYUI_SAMPLER",
|
||||
"image_generation.comfyui.sampler",
|
||||
os.getenv("COMFYUI_SAMPLER", ""),
|
||||
)
|
||||
|
||||
COMFYUI_SCHEDULER = PersistentConfig(
|
||||
"COMFYUI_SCHEDULER",
|
||||
"image_generation.comfyui.scheduler",
|
||||
os.getenv("COMFYUI_SCHEDULER", ""),
|
||||
)
|
||||
|
||||
COMFYUI_SD3 = PersistentConfig(
|
||||
"COMFYUI_SD3",
|
||||
"image_generation.comfyui.sd3",
|
||||
os.environ.get("COMFYUI_SD3", "").lower() == "true",
|
||||
)
|
||||
|
||||
IMAGES_OPENAI_API_BASE_URL = PersistentConfig(
|
||||
"IMAGES_OPENAI_API_BASE_URL",
|
||||
"image_generation.openai.api_base_url",
|
||||
|
||||
698
backend/main.py
698
backend/main.py
@@ -15,9 +15,11 @@ import requests
|
||||
import mimetypes
|
||||
import shutil
|
||||
import os
|
||||
import uuid
|
||||
import inspect
|
||||
import asyncio
|
||||
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import JSONResponse
|
||||
@@ -46,16 +48,19 @@ from apps.openai.main import (
|
||||
from apps.audio.main import app as audio_app
|
||||
from apps.images.main import app as images_app
|
||||
from apps.rag.main import app as rag_app
|
||||
from apps.webui.main import app as webui_app
|
||||
from apps.webui.main import app as webui_app, get_pipe_models
|
||||
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Iterator, Generator, Union
|
||||
|
||||
from apps.webui.models.auths import Auths
|
||||
from apps.webui.models.models import Models, ModelModel
|
||||
from apps.webui.models.tools import Tools
|
||||
from apps.webui.models.functions import Functions
|
||||
from apps.webui.models.users import Users
|
||||
|
||||
from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
|
||||
from apps.webui.utils import load_toolkit_module_by_id
|
||||
|
||||
from utils.misc import parse_duration
|
||||
@@ -72,7 +77,11 @@ from utils.task import (
|
||||
search_query_generation_template,
|
||||
tools_function_calling_generation_template,
|
||||
)
|
||||
from utils.misc import get_last_user_message, add_or_update_system_message
|
||||
from utils.misc import (
|
||||
get_last_user_message,
|
||||
add_or_update_system_message,
|
||||
stream_message_template,
|
||||
)
|
||||
|
||||
from apps.rag.utils import get_rag_context, rag_template
|
||||
|
||||
@@ -85,6 +94,7 @@ from config import (
|
||||
VERSION,
|
||||
CHANGELOG,
|
||||
FRONTEND_BUILD_DIR,
|
||||
UPLOAD_DIR,
|
||||
CACHE_DIR,
|
||||
STATIC_DIR,
|
||||
ENABLE_OPENAI_API,
|
||||
@@ -184,7 +194,16 @@ app.state.MODELS = {}
|
||||
origins = ["*"]
|
||||
|
||||
|
||||
async def get_function_call_response(messages, tool_id, template, task_model_id, user):
|
||||
##################################
|
||||
#
|
||||
# ChatCompletion Middleware
|
||||
#
|
||||
##################################
|
||||
|
||||
|
||||
async def get_function_call_response(
|
||||
messages, files, tool_id, template, task_model_id, user
|
||||
):
|
||||
tool = Tools.get_tool_by_id(tool_id)
|
||||
tools_specs = json.dumps(tool.specs, indent=2)
|
||||
content = tools_function_calling_generation_template(template, tools_specs)
|
||||
@@ -222,9 +241,7 @@ async def get_function_call_response(messages, tool_id, template, task_model_id,
|
||||
response = None
|
||||
try:
|
||||
if model["owned_by"] == "ollama":
|
||||
response = await generate_ollama_chat_completion(
|
||||
OpenAIChatCompletionForm(**payload), user=user
|
||||
)
|
||||
response = await generate_ollama_chat_completion(payload, user=user)
|
||||
else:
|
||||
response = await generate_openai_chat_completion(payload, user=user)
|
||||
|
||||
@@ -247,6 +264,7 @@ async def get_function_call_response(messages, tool_id, template, task_model_id,
|
||||
result = json.loads(content)
|
||||
print(result)
|
||||
|
||||
citation = None
|
||||
# Call the function
|
||||
if "name" in result:
|
||||
if tool_id in webui_app.state.TOOLS:
|
||||
@@ -255,76 +273,170 @@ async def get_function_call_response(messages, tool_id, template, task_model_id,
|
||||
toolkit_module = load_toolkit_module_by_id(tool_id)
|
||||
webui_app.state.TOOLS[tool_id] = toolkit_module
|
||||
|
||||
file_handler = False
|
||||
# check if toolkit_module has file_handler self variable
|
||||
if hasattr(toolkit_module, "file_handler"):
|
||||
file_handler = True
|
||||
print("file_handler: ", file_handler)
|
||||
|
||||
function = getattr(toolkit_module, result["name"])
|
||||
function_result = None
|
||||
try:
|
||||
# Get the signature of the function
|
||||
sig = inspect.signature(function)
|
||||
# Check if '__user__' is a parameter of the function
|
||||
params = result["parameters"]
|
||||
|
||||
if "__user__" in sig.parameters:
|
||||
# Call the function with the '__user__' parameter included
|
||||
function_result = function(
|
||||
**{
|
||||
**result["parameters"],
|
||||
"__user__": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
}
|
||||
)
|
||||
params = {
|
||||
**params,
|
||||
"__user__": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
}
|
||||
|
||||
if "__messages__" in sig.parameters:
|
||||
# Call the function with the '__messages__' parameter included
|
||||
params = {
|
||||
**params,
|
||||
"__messages__": messages,
|
||||
}
|
||||
|
||||
if "__files__" in sig.parameters:
|
||||
# Call the function with the '__files__' parameter included
|
||||
params = {
|
||||
**params,
|
||||
"__files__": files,
|
||||
}
|
||||
|
||||
if "__model__" in sig.parameters:
|
||||
# Call the function with the '__model__' parameter included
|
||||
params = {
|
||||
**params,
|
||||
"__model__": model,
|
||||
}
|
||||
|
||||
if "__id__" in sig.parameters:
|
||||
# Call the function with the '__id__' parameter included
|
||||
params = {
|
||||
**params,
|
||||
"__id__": tool_id,
|
||||
}
|
||||
|
||||
if inspect.iscoroutinefunction(function):
|
||||
function_result = await function(**params)
|
||||
else:
|
||||
# Call the function without modifying the parameters
|
||||
function_result = function(**result["parameters"])
|
||||
function_result = function(**params)
|
||||
|
||||
if hasattr(toolkit_module, "citation") and toolkit_module.citation:
|
||||
citation = {
|
||||
"source": {"name": f"TOOL:{tool.name}/{result['name']}"},
|
||||
"document": [function_result],
|
||||
"metadata": [{"source": result["name"]}],
|
||||
}
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
# Add the function result to the system prompt
|
||||
if function_result:
|
||||
return function_result
|
||||
if function_result is not None:
|
||||
return function_result, citation, file_handler
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
|
||||
return None
|
||||
return None, None, False
|
||||
|
||||
|
||||
class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
return_citations = False
|
||||
data_items = []
|
||||
|
||||
if request.method == "POST" and (
|
||||
"/ollama/api/chat" in request.url.path
|
||||
or "/chat/completions" in request.url.path
|
||||
show_citations = False
|
||||
citations = []
|
||||
|
||||
if request.method == "POST" and any(
|
||||
endpoint in request.url.path
|
||||
for endpoint in ["/ollama/api/chat", "/chat/completions"]
|
||||
):
|
||||
log.debug(f"request.url.path: {request.url.path}")
|
||||
|
||||
# Read the original request body
|
||||
body = await request.body()
|
||||
# Decode body to string
|
||||
body_str = body.decode("utf-8")
|
||||
# Parse string to JSON
|
||||
data = json.loads(body_str) if body_str else {}
|
||||
|
||||
user = get_current_user(
|
||||
get_http_authorization_cred(request.headers.get("Authorization"))
|
||||
request,
|
||||
get_http_authorization_cred(request.headers.get("Authorization")),
|
||||
)
|
||||
|
||||
# Remove the citations from the body
|
||||
return_citations = data.get("citations", False)
|
||||
if "citations" in data:
|
||||
# Flag to skip RAG completions if file_handler is present in tools/functions
|
||||
skip_files = False
|
||||
if data.get("citations"):
|
||||
show_citations = True
|
||||
del data["citations"]
|
||||
|
||||
# Set the task model
|
||||
task_model_id = data["model"]
|
||||
if task_model_id not in app.state.MODELS:
|
||||
model_id = data["model"]
|
||||
if model_id not in app.state.MODELS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
)
|
||||
model = app.state.MODELS[model_id]
|
||||
|
||||
# Check if the user has a custom task model
|
||||
# If the user has a custom task model, use that model
|
||||
# Check if the model has any filters
|
||||
if "info" in model and "meta" in model["info"]:
|
||||
for filter_id in model["info"]["meta"].get("filterIds", []):
|
||||
filter = Functions.get_function_by_id(filter_id)
|
||||
if filter:
|
||||
if filter_id in webui_app.state.FUNCTIONS:
|
||||
function_module = webui_app.state.FUNCTIONS[filter_id]
|
||||
else:
|
||||
function_module, function_type = load_function_module_by_id(
|
||||
filter_id
|
||||
)
|
||||
webui_app.state.FUNCTIONS[filter_id] = function_module
|
||||
|
||||
# Check if the function has a file_handler variable
|
||||
if hasattr(function_module, "file_handler"):
|
||||
skip_files = function_module.file_handler
|
||||
|
||||
try:
|
||||
if hasattr(function_module, "inlet"):
|
||||
inlet = function_module.inlet
|
||||
|
||||
if inspect.iscoroutinefunction(inlet):
|
||||
data = await inlet(
|
||||
data,
|
||||
{
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
)
|
||||
else:
|
||||
data = inlet(
|
||||
data,
|
||||
{
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
|
||||
# Set the task model
|
||||
task_model_id = data["model"]
|
||||
# Check if the user has a custom task model and use that model
|
||||
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
|
||||
if (
|
||||
app.state.config.TASK_MODEL
|
||||
@@ -347,55 +459,71 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
for tool_id in data["tool_ids"]:
|
||||
print(tool_id)
|
||||
try:
|
||||
response = await get_function_call_response(
|
||||
messages=data["messages"],
|
||||
tool_id=tool_id,
|
||||
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||
task_model_id=task_model_id,
|
||||
user=user,
|
||||
response, citation, file_handler = (
|
||||
await get_function_call_response(
|
||||
messages=data["messages"],
|
||||
files=data.get("files", []),
|
||||
tool_id=tool_id,
|
||||
template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
||||
task_model_id=task_model_id,
|
||||
user=user,
|
||||
)
|
||||
)
|
||||
|
||||
if response:
|
||||
print(file_handler)
|
||||
if isinstance(response, str):
|
||||
context += ("\n" if context != "" else "") + response
|
||||
|
||||
if citation:
|
||||
citations.append(citation)
|
||||
show_citations = True
|
||||
|
||||
if file_handler:
|
||||
skip_files = True
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
del data["tool_ids"]
|
||||
|
||||
print(f"tool_context: {context}")
|
||||
|
||||
# If docs field is present, generate RAG completions
|
||||
if "docs" in data:
|
||||
data = {**data}
|
||||
rag_context, citations = get_rag_context(
|
||||
docs=data["docs"],
|
||||
messages=data["messages"],
|
||||
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
|
||||
k=rag_app.state.config.TOP_K,
|
||||
reranking_function=rag_app.state.sentence_transformer_rf,
|
||||
r=rag_app.state.config.RELEVANCE_THRESHOLD,
|
||||
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
||||
)
|
||||
# If files field is present, generate RAG completions
|
||||
# If skip_files is True, skip the RAG completions
|
||||
if "files" in data:
|
||||
if not skip_files:
|
||||
data = {**data}
|
||||
rag_context, rag_citations = get_rag_context(
|
||||
files=data["files"],
|
||||
messages=data["messages"],
|
||||
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
|
||||
k=rag_app.state.config.TOP_K,
|
||||
reranking_function=rag_app.state.sentence_transformer_rf,
|
||||
r=rag_app.state.config.RELEVANCE_THRESHOLD,
|
||||
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
||||
)
|
||||
if rag_context:
|
||||
context += ("\n" if context != "" else "") + rag_context
|
||||
|
||||
if rag_context:
|
||||
context += ("\n" if context != "" else "") + rag_context
|
||||
log.debug(f"rag_context: {rag_context}, citations: {citations}")
|
||||
|
||||
del data["docs"]
|
||||
if rag_citations:
|
||||
citations.extend(rag_citations)
|
||||
|
||||
log.debug(f"rag_context: {rag_context}, citations: {citations}")
|
||||
del data["files"]
|
||||
|
||||
if show_citations and len(citations) > 0:
|
||||
data_items.append({"citations": citations})
|
||||
|
||||
if context != "":
|
||||
system_prompt = rag_template(
|
||||
rag_app.state.config.RAG_TEMPLATE, context, prompt
|
||||
)
|
||||
|
||||
print(system_prompt)
|
||||
|
||||
data["messages"] = add_or_update_system_message(
|
||||
f"\n{system_prompt}", data["messages"]
|
||||
system_prompt, data["messages"]
|
||||
)
|
||||
|
||||
modified_body_bytes = json.dumps(data).encode("utf-8")
|
||||
|
||||
# Replace the request body with the modified one
|
||||
request._body = modified_body_bytes
|
||||
# Set custom header to ensure content-length matches new body length
|
||||
@@ -408,43 +536,54 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
],
|
||||
]
|
||||
|
||||
response = await call_next(request)
|
||||
|
||||
if return_citations:
|
||||
# Inject the citations into the response
|
||||
response = await call_next(request)
|
||||
if isinstance(response, StreamingResponse):
|
||||
# If it's a streaming response, inject it as SSE event or NDJSON line
|
||||
content_type = response.headers.get("Content-Type")
|
||||
if "text/event-stream" in content_type:
|
||||
return StreamingResponse(
|
||||
self.openai_stream_wrapper(response.body_iterator, citations),
|
||||
self.openai_stream_wrapper(response.body_iterator, data_items),
|
||||
)
|
||||
if "application/x-ndjson" in content_type:
|
||||
return StreamingResponse(
|
||||
self.ollama_stream_wrapper(response.body_iterator, citations),
|
||||
self.ollama_stream_wrapper(response.body_iterator, data_items),
|
||||
)
|
||||
else:
|
||||
return response
|
||||
|
||||
# If it's not a chat completion request, just pass it through
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
async def _receive(self, body: bytes):
|
||||
return {"type": "http.request", "body": body, "more_body": False}
|
||||
|
||||
async def openai_stream_wrapper(self, original_generator, citations):
|
||||
yield f"data: {json.dumps({'citations': citations})}\n\n"
|
||||
async def openai_stream_wrapper(self, original_generator, data_items):
|
||||
for item in data_items:
|
||||
yield f"data: {json.dumps(item)}\n\n"
|
||||
|
||||
async for data in original_generator:
|
||||
yield data
|
||||
|
||||
async def ollama_stream_wrapper(self, original_generator, citations):
|
||||
yield f"{json.dumps({'citations': citations})}\n"
|
||||
async def ollama_stream_wrapper(self, original_generator, data_items):
|
||||
for item in data_items:
|
||||
yield f"{json.dumps(item)}\n"
|
||||
|
||||
async for data in original_generator:
|
||||
yield data
|
||||
|
||||
|
||||
app.add_middleware(ChatCompletionMiddleware)
|
||||
|
||||
##################################
|
||||
#
|
||||
# Pipeline Middleware
|
||||
#
|
||||
##################################
|
||||
|
||||
|
||||
def filter_pipeline(payload, user):
|
||||
user = {"id": user.id, "name": user.name, "role": user.role}
|
||||
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
|
||||
model_id = payload["model"]
|
||||
filters = [
|
||||
model
|
||||
@@ -532,7 +671,8 @@ class PipelineMiddleware(BaseHTTPMiddleware):
|
||||
data = json.loads(body_str) if body_str else {}
|
||||
|
||||
user = get_current_user(
|
||||
get_http_authorization_cred(request.headers.get("Authorization"))
|
||||
request,
|
||||
get_http_authorization_cred(request.headers.get("Authorization")),
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -600,7 +740,6 @@ async def update_embedding_function(request: Request, call_next):
|
||||
|
||||
app.mount("/ws", socket_app)
|
||||
|
||||
|
||||
app.mount("/ollama", ollama_app)
|
||||
app.mount("/openai", openai_app)
|
||||
|
||||
@@ -614,17 +753,18 @@ webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
|
||||
|
||||
|
||||
async def get_all_models():
|
||||
pipe_models = []
|
||||
openai_models = []
|
||||
ollama_models = []
|
||||
|
||||
pipe_models = await get_pipe_models()
|
||||
|
||||
if app.state.config.ENABLE_OPENAI_API:
|
||||
openai_models = await get_openai_models()
|
||||
|
||||
openai_models = openai_models["data"]
|
||||
|
||||
if app.state.config.ENABLE_OLLAMA_API:
|
||||
ollama_models = await get_ollama_models()
|
||||
|
||||
ollama_models = [
|
||||
{
|
||||
"id": model["model"],
|
||||
@@ -637,9 +777,9 @@ async def get_all_models():
|
||||
for model in ollama_models["models"]
|
||||
]
|
||||
|
||||
models = openai_models + ollama_models
|
||||
custom_models = Models.get_all_models()
|
||||
models = pipe_models + openai_models + ollama_models
|
||||
|
||||
custom_models = Models.get_all_models()
|
||||
for custom_model in custom_models:
|
||||
if custom_model.base_model_id == None:
|
||||
for model in models:
|
||||
@@ -702,6 +842,253 @@ async def get_models(user=Depends(get_verified_user)):
|
||||
return {"data": models}
|
||||
|
||||
|
||||
@app.post("/api/chat/completions")
|
||||
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
|
||||
model_id = form_data["model"]
|
||||
if model_id not in app.state.MODELS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
model = app.state.MODELS[model_id]
|
||||
print(model)
|
||||
|
||||
pipe = model.get("pipe")
|
||||
if pipe:
|
||||
form_data["user"] = {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
}
|
||||
|
||||
async def job():
|
||||
pipe_id = form_data["model"]
|
||||
if "." in pipe_id:
|
||||
pipe_id, sub_pipe_id = pipe_id.split(".", 1)
|
||||
print(pipe_id)
|
||||
|
||||
pipe = webui_app.state.FUNCTIONS[pipe_id].pipe
|
||||
if form_data["stream"]:
|
||||
|
||||
async def stream_content():
|
||||
if inspect.iscoroutinefunction(pipe):
|
||||
res = await pipe(body=form_data)
|
||||
else:
|
||||
res = pipe(body=form_data)
|
||||
|
||||
if isinstance(res, str):
|
||||
message = stream_message_template(form_data["model"], res)
|
||||
yield f"data: {json.dumps(message)}\n\n"
|
||||
|
||||
if isinstance(res, Iterator):
|
||||
for line in res:
|
||||
if isinstance(line, BaseModel):
|
||||
line = line.model_dump_json()
|
||||
line = f"data: {line}"
|
||||
try:
|
||||
line = line.decode("utf-8")
|
||||
except:
|
||||
pass
|
||||
|
||||
if line.startswith("data:"):
|
||||
yield f"{line}\n\n"
|
||||
else:
|
||||
line = stream_message_template(form_data["model"], line)
|
||||
yield f"data: {json.dumps(line)}\n\n"
|
||||
|
||||
if isinstance(res, str) or isinstance(res, Generator):
|
||||
finish_message = {
|
||||
"id": f"{form_data['model']}-{str(uuid.uuid4())}",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": form_data["model"],
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {},
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
yield f"data: {json.dumps(finish_message)}\n\n"
|
||||
yield f"data: [DONE]"
|
||||
|
||||
return StreamingResponse(
|
||||
stream_content(), media_type="text/event-stream"
|
||||
)
|
||||
else:
|
||||
if inspect.iscoroutinefunction(pipe):
|
||||
res = await pipe(body=form_data)
|
||||
else:
|
||||
res = pipe(body=form_data)
|
||||
|
||||
if isinstance(res, dict):
|
||||
return res
|
||||
elif isinstance(res, BaseModel):
|
||||
return res.model_dump()
|
||||
else:
|
||||
message = ""
|
||||
if isinstance(res, str):
|
||||
message = res
|
||||
if isinstance(res, Generator):
|
||||
for stream in res:
|
||||
message = f"{message}{stream}"
|
||||
|
||||
return {
|
||||
"id": f"{form_data['model']}-{str(uuid.uuid4())}",
|
||||
"object": "chat.completion",
|
||||
"created": int(time.time()),
|
||||
"model": form_data["model"],
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": message,
|
||||
},
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
return await job()
|
||||
if model["owned_by"] == "ollama":
|
||||
return await generate_ollama_chat_completion(form_data, user=user)
|
||||
else:
|
||||
return await generate_openai_chat_completion(form_data, user=user)
|
||||
|
||||
|
||||
@app.post("/api/chat/completed")
|
||||
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
|
||||
data = form_data
|
||||
model_id = data["model"]
|
||||
if model_id not in app.state.MODELS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
)
|
||||
model = app.state.MODELS[model_id]
|
||||
|
||||
filters = [
|
||||
model
|
||||
for model in app.state.MODELS.values()
|
||||
if "pipeline" in model
|
||||
and "type" in model["pipeline"]
|
||||
and model["pipeline"]["type"] == "filter"
|
||||
and (
|
||||
model["pipeline"]["pipelines"] == ["*"]
|
||||
or any(
|
||||
model_id == target_model_id
|
||||
for target_model_id in model["pipeline"]["pipelines"]
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
|
||||
if "pipeline" in model:
|
||||
sorted_filters = [model] + sorted_filters
|
||||
|
||||
for filter in sorted_filters:
|
||||
r = None
|
||||
try:
|
||||
urlIdx = filter["urlIdx"]
|
||||
|
||||
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||
|
||||
if key != "":
|
||||
headers = {"Authorization": f"Bearer {key}"}
|
||||
r = requests.post(
|
||||
f"{url}/{filter['id']}/filter/outlet",
|
||||
headers=headers,
|
||||
json={
|
||||
"user": {"id": user.id, "name": user.name, "role": user.role},
|
||||
"body": data,
|
||||
},
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
|
||||
if r is not None:
|
||||
try:
|
||||
res = r.json()
|
||||
if "detail" in res:
|
||||
return JSONResponse(
|
||||
status_code=r.status_code,
|
||||
content=res,
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
else:
|
||||
pass
|
||||
|
||||
# Check if the model has any filters
|
||||
if "info" in model and "meta" in model["info"]:
|
||||
for filter_id in model["info"]["meta"].get("filterIds", []):
|
||||
filter = Functions.get_function_by_id(filter_id)
|
||||
if filter:
|
||||
if filter_id in webui_app.state.FUNCTIONS:
|
||||
function_module = webui_app.state.FUNCTIONS[filter_id]
|
||||
else:
|
||||
function_module, function_type = load_function_module_by_id(
|
||||
filter_id
|
||||
)
|
||||
webui_app.state.FUNCTIONS[filter_id] = function_module
|
||||
|
||||
try:
|
||||
if hasattr(function_module, "outlet"):
|
||||
outlet = function_module.outlet
|
||||
if inspect.iscoroutinefunction(outlet):
|
||||
data = await outlet(
|
||||
data,
|
||||
{
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
)
|
||||
else:
|
||||
data = outlet(
|
||||
data,
|
||||
{
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={"detail": str(e)},
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
##################################
|
||||
#
|
||||
# Task Endpoints
|
||||
#
|
||||
##################################
|
||||
|
||||
|
||||
# TODO: Refactor task API endpoints below into a separate file
|
||||
|
||||
|
||||
@app.get("/api/task/config")
|
||||
async def get_task_config(user=Depends(get_verified_user)):
|
||||
return {
|
||||
@@ -780,7 +1167,12 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
||||
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
content = title_generation_template(
|
||||
template, form_data["prompt"], user.model_dump()
|
||||
template,
|
||||
form_data["prompt"],
|
||||
{
|
||||
"name": user.name,
|
||||
"location": user.info.get("location") if user.info else None,
|
||||
},
|
||||
)
|
||||
|
||||
payload = {
|
||||
@@ -792,7 +1184,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
||||
"title": True,
|
||||
}
|
||||
|
||||
print(payload)
|
||||
log.debug(payload)
|
||||
|
||||
try:
|
||||
payload = filter_pipeline(payload, user)
|
||||
@@ -803,9 +1195,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
||||
)
|
||||
|
||||
if model["owned_by"] == "ollama":
|
||||
return await generate_ollama_chat_completion(
|
||||
OpenAIChatCompletionForm(**payload), user=user
|
||||
)
|
||||
return await generate_ollama_chat_completion(payload, user=user)
|
||||
else:
|
||||
return await generate_openai_chat_completion(payload, user=user)
|
||||
|
||||
@@ -846,7 +1236,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
|
||||
template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
||||
|
||||
content = search_query_generation_template(
|
||||
template, form_data["prompt"], user.model_dump()
|
||||
template, form_data["prompt"], {"name": user.name}
|
||||
)
|
||||
|
||||
payload = {
|
||||
@@ -868,9 +1258,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
|
||||
)
|
||||
|
||||
if model["owned_by"] == "ollama":
|
||||
return await generate_ollama_chat_completion(
|
||||
OpenAIChatCompletionForm(**payload), user=user
|
||||
)
|
||||
return await generate_ollama_chat_completion(payload, user=user)
|
||||
else:
|
||||
return await generate_openai_chat_completion(payload, user=user)
|
||||
|
||||
@@ -909,7 +1297,12 @@ Message: """{{prompt}}"""
|
||||
'''
|
||||
|
||||
content = title_generation_template(
|
||||
template, form_data["prompt"], user.model_dump()
|
||||
template,
|
||||
form_data["prompt"],
|
||||
{
|
||||
"name": user.name,
|
||||
"location": user.info.get("location") if user.info else None,
|
||||
},
|
||||
)
|
||||
|
||||
payload = {
|
||||
@@ -921,7 +1314,7 @@ Message: """{{prompt}}"""
|
||||
"task": True,
|
||||
}
|
||||
|
||||
print(payload)
|
||||
log.debug(payload)
|
||||
|
||||
try:
|
||||
payload = filter_pipeline(payload, user)
|
||||
@@ -932,9 +1325,7 @@ Message: """{{prompt}}"""
|
||||
)
|
||||
|
||||
if model["owned_by"] == "ollama":
|
||||
return await generate_ollama_chat_completion(
|
||||
OpenAIChatCompletionForm(**payload), user=user
|
||||
)
|
||||
return await generate_ollama_chat_completion(payload, user=user)
|
||||
else:
|
||||
return await generate_openai_chat_completion(payload, user=user)
|
||||
|
||||
@@ -967,8 +1358,13 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
|
||||
template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
||||
|
||||
try:
|
||||
context = await get_function_call_response(
|
||||
form_data["messages"], form_data["tool_id"], template, model_id, user
|
||||
context, citation, file_handler = await get_function_call_response(
|
||||
form_data["messages"],
|
||||
form_data.get("files", []),
|
||||
form_data["tool_id"],
|
||||
template,
|
||||
model_id,
|
||||
user,
|
||||
)
|
||||
return context
|
||||
except Exception as e:
|
||||
@@ -978,94 +1374,14 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
|
||||
)
|
||||
|
||||
|
||||
@app.post("/api/chat/completions")
|
||||
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
|
||||
model_id = form_data["model"]
|
||||
if model_id not in app.state.MODELS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Model not found",
|
||||
)
|
||||
|
||||
model = app.state.MODELS[model_id]
|
||||
print(model)
|
||||
|
||||
if model["owned_by"] == "ollama":
|
||||
return await generate_ollama_chat_completion(
|
||||
OpenAIChatCompletionForm(**form_data), user=user
|
||||
)
|
||||
else:
|
||||
return await generate_openai_chat_completion(form_data, user=user)
|
||||
##################################
|
||||
#
|
||||
# Pipelines Endpoints
|
||||
#
|
||||
##################################
|
||||
|
||||
|
||||
@app.post("/api/chat/completed")
|
||||
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
|
||||
data = form_data
|
||||
model_id = data["model"]
|
||||
|
||||
filters = [
|
||||
model
|
||||
for model in app.state.MODELS.values()
|
||||
if "pipeline" in model
|
||||
and "type" in model["pipeline"]
|
||||
and model["pipeline"]["type"] == "filter"
|
||||
and (
|
||||
model["pipeline"]["pipelines"] == ["*"]
|
||||
or any(
|
||||
model_id == target_model_id
|
||||
for target_model_id in model["pipeline"]["pipelines"]
|
||||
)
|
||||
)
|
||||
]
|
||||
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
|
||||
|
||||
print(model_id)
|
||||
|
||||
if model_id in app.state.MODELS:
|
||||
model = app.state.MODELS[model_id]
|
||||
if "pipeline" in model:
|
||||
sorted_filters = [model] + sorted_filters
|
||||
|
||||
for filter in sorted_filters:
|
||||
r = None
|
||||
try:
|
||||
urlIdx = filter["urlIdx"]
|
||||
|
||||
url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
||||
key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
|
||||
|
||||
if key != "":
|
||||
headers = {"Authorization": f"Bearer {key}"}
|
||||
r = requests.post(
|
||||
f"{url}/{filter['id']}/filter/outlet",
|
||||
headers=headers,
|
||||
json={
|
||||
"user": {"id": user.id, "name": user.name, "role": user.role},
|
||||
"body": data,
|
||||
},
|
||||
)
|
||||
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
except Exception as e:
|
||||
# Handle connection error here
|
||||
print(f"Connection error: {e}")
|
||||
|
||||
if r is not None:
|
||||
try:
|
||||
res = r.json()
|
||||
if "detail" in res:
|
||||
return JSONResponse(
|
||||
status_code=r.status_code,
|
||||
content=res,
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
else:
|
||||
pass
|
||||
|
||||
return data
|
||||
# TODO: Refactor pipelines API endpoints below into a separate file
|
||||
|
||||
|
||||
@app.get("/api/pipelines/list")
|
||||
@@ -1388,6 +1704,13 @@ async def update_pipeline_valves(
|
||||
)
|
||||
|
||||
|
||||
##################################
|
||||
#
|
||||
# Config Endpoints
|
||||
#
|
||||
##################################
|
||||
|
||||
|
||||
@app.get("/api/config")
|
||||
async def get_app_config():
|
||||
# Checking and Handling the Absence of 'ui' in CONFIG_DATA
|
||||
@@ -1457,6 +1780,9 @@ async def update_model_filter_config(
|
||||
}
|
||||
|
||||
|
||||
# TODO: webhook endpoint should be under config endpoints
|
||||
|
||||
|
||||
@app.get("/api/webhook")
|
||||
async def get_webhook_url(user=Depends(get_admin_user)):
|
||||
return {
|
||||
|
||||
@@ -3,7 +3,9 @@ import hashlib
|
||||
import json
|
||||
import re
|
||||
from datetime import timedelta
|
||||
from typing import Optional, List
|
||||
from typing import Optional, List, Tuple
|
||||
import uuid
|
||||
import time
|
||||
|
||||
|
||||
def get_last_user_message(messages: List[dict]) -> str:
|
||||
@@ -28,6 +30,21 @@ def get_last_assistant_message(messages: List[dict]) -> str:
|
||||
return None
|
||||
|
||||
|
||||
def get_system_message(messages: List[dict]) -> dict:
|
||||
for message in messages:
|
||||
if message["role"] == "system":
|
||||
return message
|
||||
return None
|
||||
|
||||
|
||||
def remove_system_message(messages: List[dict]) -> List[dict]:
|
||||
return [message for message in messages if message["role"] != "system"]
|
||||
|
||||
|
||||
def pop_system_message(messages: List[dict]) -> Tuple[dict, List[dict]]:
|
||||
return get_system_message(messages), remove_system_message(messages)
|
||||
|
||||
|
||||
def add_or_update_system_message(content: str, messages: List[dict]):
|
||||
"""
|
||||
Adds a new system message at the beginning of the messages list
|
||||
@@ -47,6 +64,23 @@ def add_or_update_system_message(content: str, messages: List[dict]):
|
||||
return messages
|
||||
|
||||
|
||||
def stream_message_template(model: str, message: str):
|
||||
return {
|
||||
"id": f"{model}-{str(uuid.uuid4())}",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": int(time.time()),
|
||||
"model": model,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"delta": {"content": message},
|
||||
"logprobs": None,
|
||||
"finish_reason": None,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def get_gravatar_url(email):
|
||||
# Trim leading and trailing whitespace from
|
||||
# an email address and force all characters
|
||||
|
||||
@@ -6,24 +6,34 @@ from typing import Optional
|
||||
|
||||
|
||||
def prompt_template(
|
||||
template: str, user_name: str = None, current_location: str = None
|
||||
template: str, user_name: str = None, user_location: str = None
|
||||
) -> str:
|
||||
# Get the current date
|
||||
current_date = datetime.now()
|
||||
|
||||
# Format the date to YYYY-MM-DD
|
||||
formatted_date = current_date.strftime("%Y-%m-%d")
|
||||
formatted_time = current_date.strftime("%I:%M:%S %p")
|
||||
|
||||
# Replace {{CURRENT_DATE}} in the template with the formatted date
|
||||
template = template.replace("{{CURRENT_DATE}}", formatted_date)
|
||||
template = template.replace("{{CURRENT_TIME}}", formatted_time)
|
||||
template = template.replace(
|
||||
"{{CURRENT_DATETIME}}", f"{formatted_date} {formatted_time}"
|
||||
)
|
||||
|
||||
if user_name:
|
||||
# Replace {{USER_NAME}} in the template with the user's name
|
||||
template = template.replace("{{USER_NAME}}", user_name)
|
||||
else:
|
||||
# Replace {{USER_NAME}} in the template with "Unknown"
|
||||
template = template.replace("{{USER_NAME}}", "Unknown")
|
||||
|
||||
if current_location:
|
||||
# Replace {{CURRENT_LOCATION}} in the template with the current location
|
||||
template = template.replace("{{CURRENT_LOCATION}}", current_location)
|
||||
if user_location:
|
||||
# Replace {{USER_LOCATION}} in the template with the current location
|
||||
template = template.replace("{{USER_LOCATION}}", user_location)
|
||||
else:
|
||||
# Replace {{USER_LOCATION}} in the template with "Unknown"
|
||||
template = template.replace("{{USER_LOCATION}}", "Unknown")
|
||||
|
||||
return template
|
||||
|
||||
@@ -61,7 +71,7 @@ def title_generation_template(
|
||||
template = prompt_template(
|
||||
template,
|
||||
**(
|
||||
{"user_name": user.get("name"), "current_location": user.get("location")}
|
||||
{"user_name": user.get("name"), "user_location": user.get("location")}
|
||||
if user
|
||||
else {}
|
||||
),
|
||||
@@ -104,7 +114,7 @@ def search_query_generation_template(
|
||||
template = prompt_template(
|
||||
template,
|
||||
**(
|
||||
{"user_name": user.get("name"), "current_location": user.get("location")}
|
||||
{"user_name": user.get("name"), "user_location": user.get("location")}
|
||||
if user
|
||||
else {}
|
||||
),
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from fastapi import HTTPException, status, Depends
|
||||
from fastapi import HTTPException, status, Depends, Request
|
||||
|
||||
from apps.webui.models.users import Users
|
||||
|
||||
@@ -24,7 +24,7 @@ ALGORITHM = "HS256"
|
||||
# Auth Utils
|
||||
##############
|
||||
|
||||
bearer_security = HTTPBearer()
|
||||
bearer_security = HTTPBearer(auto_error=False)
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
@@ -75,13 +75,26 @@ def get_http_authorization_cred(auth_header: str):
|
||||
|
||||
|
||||
def get_current_user(
|
||||
request: Request,
|
||||
auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
|
||||
):
|
||||
token = None
|
||||
|
||||
if auth_token is not None:
|
||||
token = auth_token.credentials
|
||||
|
||||
if token is None and "token" in request.cookies:
|
||||
token = request.cookies.get("token")
|
||||
|
||||
if token is None:
|
||||
raise HTTPException(status_code=403, detail="Not authenticated")
|
||||
|
||||
# auth by api key
|
||||
if auth_token.credentials.startswith("sk-"):
|
||||
return get_current_user_by_api_key(auth_token.credentials)
|
||||
if token.startswith("sk-"):
|
||||
return get_current_user_by_api_key(token)
|
||||
|
||||
# auth by jwt token
|
||||
data = decode_token(auth_token.credentials)
|
||||
data = decode_token(token)
|
||||
if data != None and "id" in data:
|
||||
user = Users.get_user_by_id(data["id"])
|
||||
if user is None:
|
||||
|
||||
Reference in New Issue
Block a user