mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
refac: mv backend files to /open_webui dir
This commit is contained in:
BIN
backend/open_webui/utils/logo.png
Normal file
BIN
backend/open_webui/utils/logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 6.0 KiB |
404
backend/open_webui/utils/misc.py
Normal file
404
backend/open_webui/utils/misc.py
Normal file
@@ -0,0 +1,404 @@
|
||||
import hashlib
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
from open_webui.utils.task import prompt_template
|
||||
|
||||
|
||||
def get_last_user_message_item(messages: list[dict]) -> Optional[dict]:
    """Return the most recent message whose role is 'user', or None."""
    user_messages = (m for m in reversed(messages) if m["role"] == "user")
    return next(user_messages, None)
|
||||
|
||||
|
||||
def get_content_from_message(message: dict) -> Optional[str]:
    """Extract plain text from a message.

    Handles both plain-string content and multi-part list content (the first
    'text' part wins); returns None when a list has no text part.
    """
    content = message["content"]
    if not isinstance(content, list):
        return content
    for part in content:
        if part["type"] == "text":
            return part["text"]
    return None
|
||||
|
||||
|
||||
def get_last_user_message(messages: list[dict]) -> Optional[str]:
    """Return the text of the most recent user message, or None if there is none."""
    last = get_last_user_message_item(messages)
    return get_content_from_message(last) if last is not None else None
|
||||
|
||||
|
||||
def get_last_assistant_message(messages: list[dict]) -> Optional[str]:
    """Return the text of the most recent assistant message, or None."""
    assistant = next(
        (m for m in reversed(messages) if m["role"] == "assistant"), None
    )
    return get_content_from_message(assistant) if assistant is not None else None
|
||||
|
||||
|
||||
def get_system_message(messages: list[dict]) -> Optional[dict]:
    """Return the first message whose role is 'system', or None if absent."""
    return next((m for m in messages if m["role"] == "system"), None)
|
||||
|
||||
|
||||
def remove_system_message(messages: list[dict]) -> list[dict]:
    """Return a new list with every system-role message filtered out."""
    return list(filter(lambda m: m["role"] != "system", messages))
|
||||
|
||||
|
||||
def pop_system_message(messages: list[dict]) -> tuple[Optional[dict], list[dict]]:
    """Return (first system message or None, messages with all system messages removed)."""
    system = get_system_message(messages)
    remaining = remove_system_message(messages)
    return system, remaining
|
||||
|
||||
|
||||
def prepend_to_first_user_message_content(
    content: str, messages: list[dict]
) -> list[dict]:
    """Prepend `content` (with a newline) to the first user message, in place.

    Handles both plain-string content and multi-part list content (every
    'text' part of the first user message is prefixed). Returns the same list.
    """
    for msg in messages:
        if msg["role"] != "user":
            continue
        if isinstance(msg["content"], list):
            for part in msg["content"]:
                if part["type"] == "text":
                    part["text"] = f"{content}\n{part['text']}"
        else:
            msg["content"] = f"{content}\n{msg['content']}"
        break
    return messages
|
||||
|
||||
|
||||
def add_or_update_system_message(content: str, messages: list[dict]):
    """Ensure `content` leads the system prompt of `messages` (mutated in place).

    When the list already starts with a system message, `content` is prepended
    to its existing content; otherwise a new system message is inserted at
    index 0. Returns the updated list.
    """
    starts_with_system = bool(messages) and messages[0].get("role") == "system"
    if starts_with_system:
        existing = messages[0]["content"]
        messages[0]["content"] = f"{content}\n{existing}"
    else:
        messages.insert(0, {"role": "system", "content": content})
    return messages
|
||||
|
||||
|
||||
def openai_chat_message_template(model: str):
    """Build the shared skeleton of an OpenAI-style chat response payload.

    The id combines the model name with a fresh UUID; 'created' is the current
    Unix timestamp.
    """
    return {
        "id": f"{model}-{uuid.uuid4()}",
        "created": int(time.time()),
        "model": model,
        "choices": [{"index": 0, "logprobs": None, "finish_reason": None}],
    }
|
||||
|
||||
|
||||
def openai_chat_chunk_message_template(model: str, message: str) -> dict:
    """Build a streaming chunk ('chat.completion.chunk') payload for `message`."""
    chunk = openai_chat_message_template(model)
    chunk["object"] = "chat.completion.chunk"
    chunk["choices"][0]["delta"] = {"content": message}
    return chunk
|
||||
|
||||
|
||||
def openai_chat_completion_message_template(model: str, message: str) -> dict:
    """Build a final (non-streaming) 'chat.completion' payload for `message`."""
    completion = openai_chat_message_template(model)
    completion["object"] = "chat.completion"
    completion["choices"][0]["message"] = {"content": message, "role": "assistant"}
    completion["choices"][0]["finish_reason"] = "stop"
    return completion
|
||||
|
||||
|
||||
# inplace function: form_data is modified
def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> dict:
    """Render the model's configured system prompt and merge it into
    form_data["messages"] as the leading system message.

    Mutates and returns form_data; a missing/empty "system" param is a no-op.
    """
    system = params.get("system", None)
    if not system:
        return form_data

    # Fill the template's user placeholders when a user object is available.
    template_params = (
        {
            "user_name": user.name,
            "user_location": user.info.get("location") if user.info else None,
        }
        if user
        else {}
    )
    rendered = prompt_template(system, **template_params)
    form_data["messages"] = add_or_update_system_message(
        rendered, form_data.get("messages", [])
    )
    return form_data
|
||||
|
||||
|
||||
# inplace function: form_data is modified
def apply_model_params_to_body(
    params: dict, form_data: dict, mappings: dict[str, Callable]
) -> dict:
    """Copy each mapped key from `params` into `form_data`, cast by its mapping.

    Keys that are absent or None in `params` are skipped. Mutates and returns
    form_data.
    """
    if not params:
        return form_data

    for key, cast in mappings.items():
        value = params.get(key)
        if value is not None:
            form_data[key] = cast(value)
    return form_data
|
||||
|
||||
|
||||
# inplace function: form_data is modified
def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
    """Apply OpenAI-compatible sampling parameters from `params` onto `form_data`.

    Mutates and returns form_data. Missing/None parameters are skipped.
    """
    mappings = {
        "temperature": float,
        # BUG FIX: top_p and frequency_penalty are fractional values in the
        # OpenAI API (e.g. 0.9, 1.1); the previous int() cast silently
        # truncated them (0.9 -> 0).
        "top_p": float,
        "max_tokens": int,
        "frequency_penalty": float,
        "seed": lambda x: x,
        # Decode escape sequences (e.g. "\\n") embedded in stop strings.
        "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
    }
    return apply_model_params_to_body(params, form_data, mappings)
|
||||
|
||||
|
||||
def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
    """Apply Ollama-specific options from `params` onto `form_data` in place.

    Most options pass through under their own name; the OpenAI-style names
    max_tokens and frequency_penalty are renamed to Ollama's num_predict and
    repeat_penalty. Mutates and returns form_data.
    """
    # Options copied through verbatim under their own name.
    passthrough_opts = (
        "temperature",
        "top_p",
        "seed",
        "mirostat",
        "mirostat_eta",
        "mirostat_tau",
        "num_ctx",
        "num_batch",
        "num_keep",
        "repeat_last_n",
        "tfs_z",
        "top_k",
        "min_p",
        "use_mmap",
        "use_mlock",
        "num_thread",
        "num_gpu",
    )
    identity = lambda x: x
    form_data = apply_model_params_to_body(
        params, form_data, {opt: identity for opt in passthrough_opts}
    )

    # OpenAI-style parameter names that Ollama spells differently.
    renamed = {
        "max_tokens": "num_predict",
        "frequency_penalty": "repeat_penalty",
    }
    for src, dst in renamed.items():
        value = params.get(src, None)
        if value is not None:
            form_data[dst] = value

    return form_data
|
||||
|
||||
|
||||
def get_gravatar_url(email):
    """Return the Gravatar avatar URL for `email` ('mp' mystery-person fallback).

    Gravatar requires the address trimmed and lower-cased before hashing;
    the hash is SHA-256 of the normalized address.
    """
    normalized = str(email).strip().lower()
    digest = hashlib.sha256(normalized.encode()).hexdigest()
    return f"https://www.gravatar.com/avatar/{digest}?d=mp"
|
||||
|
||||
|
||||
def calculate_sha256(file):
    """Return the SHA-256 hex digest of a binary file object.

    Reads in 8 KiB chunks so arbitrarily large files are hashed without
    loading them fully into memory.
    """
    digest = hashlib.sha256()
    while chunk := file.read(8192):
        digest.update(chunk)
    return digest.hexdigest()
|
||||
|
||||
|
||||
def calculate_sha256_string(string):
    """Return the SHA-256 hex digest of the UTF-8 encoding of `string`."""
    return hashlib.sha256(string.encode("utf-8")).hexdigest()
|
||||
|
||||
|
||||
def validate_email_format(email: str) -> bool:
    """Loosely validate an email address.

    Addresses ending in '@localhost' are always accepted; otherwise the
    address must match the minimal something@something.tld shape.
    """
    if email.endswith("@localhost"):
        return True
    return re.match(r"[^@]+@[^@]+\.[^@]+", email) is not None
|
||||
|
||||
|
||||
def sanitize_filename(file_name):
    """Normalize a filename: lower-case it, drop characters that are neither
    word characters nor whitespace, and collapse whitespace runs into dashes.
    """
    lowered = file_name.lower()
    stripped = re.sub(r"[^\w\s]", "", lowered)
    return re.sub(r"\s+", "-", stripped)
|
||||
|
||||
|
||||
def extract_folders_after_data_docs(path):
    """Return cumulative folder paths found after a '.../data/docs/' prefix.

    E.g. 'data/docs/a/b/file.txt' -> ['a', 'a/b']. The final path component
    (the filename) is excluded; an empty list is returned when the path does
    not contain 'data' followed by 'docs'.
    """
    parts = Path(path).parts

    # Locate the component right after 'docs' (which must come after 'data').
    try:
        docs_index = parts.index("docs", parts.index("data") + 1) + 1
    except ValueError:
        return []

    folders = parts[docs_index:-1]
    return ["/".join(folders[: n + 1]) for n in range(len(folders))]
|
||||
|
||||
|
||||
def parse_duration(duration: str) -> Optional[timedelta]:
    """Parse a duration string like '1d12h30m' into a timedelta.

    '-1' and '0' mean "no expiry" and yield None. Raises ValueError when the
    string contains no number/unit pairs. Units: ms, s, m, h, d, w.
    """
    if duration in ("-1", "0"):
        return None

    # Each match is (number, fractional-part, unit).
    pairs = re.findall(r"(-?\d+(\.\d+)?)(ms|s|m|h|d|w)", duration)
    if not pairs:
        raise ValueError("Invalid duration string")

    unit_kwargs = {
        "ms": "milliseconds",
        "s": "seconds",
        "m": "minutes",
        "h": "hours",
        "d": "days",
        "w": "weeks",
    }
    total = timedelta()
    for number, _, unit in pairs:
        total += timedelta(**{unit_kwargs[unit]: float(number)})
    return total
|
||||
|
||||
|
||||
def parse_ollama_modelfile(model_text):
    """Parse the text of an Ollama Modelfile into a dict.

    Returns {"base_model_id": <FROM target or None>, "params": {...}} where
    params may contain "template", "stop", "adapter", "system", "messages"
    and any recognized PARAMETER values. All directives are matched
    case-insensitively.
    """
    # Target Python type for each supported PARAMETER key; used to cast the
    # raw string captured from the file.
    parameters_meta = {
        "mirostat": int,
        "mirostat_eta": float,
        "mirostat_tau": float,
        "num_ctx": int,
        "repeat_last_n": int,
        "repeat_penalty": float,
        "temperature": float,
        "seed": int,
        "tfs_z": float,
        "num_predict": int,
        "top_k": int,
        "top_p": float,
        "num_keep": int,
        "typical_p": float,
        "presence_penalty": float,
        "frequency_penalty": float,
        "penalize_newline": bool,
        "numa": bool,
        "num_batch": int,
        "num_gpu": int,
        "main_gpu": int,
        "low_vram": bool,
        "f16_kv": bool,
        "vocab_only": bool,
        "use_mmap": bool,
        "use_mlock": bool,
        "num_thread": int,
    }

    data = {"base_model_id": None, "params": {}}

    # Parse base model: "FROM <id>".
    # NOTE(review): \w+ stops at ':' and '/', so tagged ids like "llama3:8b"
    # capture only "llama3" — confirm that is intended.
    base_model_match = re.search(
        r"^FROM\s+(\w+)", model_text, re.MULTILINE | re.IGNORECASE
    )
    if base_model_match:
        data["base_model_id"] = base_model_match.group(1)

    # Parse template: TEMPLATE """...""" (multi-line). This assignment
    # replaces the params dict wholesale — safe only because params is still
    # empty at this point.
    template_match = re.search(
        r'TEMPLATE\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE
    )
    if template_match:
        data["params"] = {"template": template_match.group(1).strip()}

    # Parse stop sequences: every 'PARAMETER stop "..."' line, kept in order.
    stops = re.findall(r'PARAMETER stop "(.*?)"', model_text, re.IGNORECASE)
    if stops:
        data["params"]["stop"] = stops

    # Parse the remaining typed parameters; only the first occurrence of each
    # key in the file is used.
    for param, param_type in parameters_meta.items():
        param_match = re.search(rf"PARAMETER {param} (.+)", model_text, re.IGNORECASE)
        if param_match:
            value = param_match.group(1)

            try:
                if param_type is int:
                    value = int(value)
                elif param_type is float:
                    value = float(value)
                elif param_type is bool:
                    # Anything other than (case-insensitive) "true" is False.
                    value = value.lower() == "true"
            except Exception as e:
                # Unparseable value: report it and skip this parameter.
                print(e)
                continue

            data["params"][param] = value

    # Parse adapter: path of a (LoRA) adapter referenced by the Modelfile.
    adapter_match = re.search(r"ADAPTER (.+)", model_text, re.IGNORECASE)
    if adapter_match:
        data["params"]["adapter"] = adapter_match.group(1)

    # Parse system prompt: the triple-quoted multi-line form takes priority
    # over the single-line form.
    system_desc_match = re.search(
        r'SYSTEM\s+"""(.+?)"""', model_text, re.DOTALL | re.IGNORECASE
    )
    system_desc_match_single = re.search(
        r"SYSTEM\s+([^\n]+)", model_text, re.IGNORECASE
    )

    if system_desc_match:
        data["params"]["system"] = system_desc_match.group(1).strip()
    elif system_desc_match_single:
        data["params"]["system"] = system_desc_match_single.group(1).strip()

    # Parse seeded conversation: MESSAGE <role> <content> lines, in order.
    messages = []
    message_matches = re.findall(r"MESSAGE (\w+) (.+)", model_text, re.IGNORECASE)
    for role, content in message_matches:
        messages.append({"role": role, "content": content})

    if messages:
        data["params"]["messages"] = messages

    return data
|
||||
108
backend/open_webui/utils/schemas.py
Normal file
108
backend/open_webui/utils/schemas.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from ast import literal_eval
|
||||
from typing import Any, Literal, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
|
||||
def json_schema_to_model(tool_dict: dict[str, Any]) -> Type[BaseModel]:
    """Build a Pydantic model class from a tool's JSON-schema description.

    Args:
        tool_dict: Tool spec containing "name" and a JSON-schema "parameters"
            object (with "properties" and optional "required").

    Returns:
        A dynamically created Pydantic BaseModel subclass named after the tool.
    """
    schema = tool_dict["parameters"]
    required = schema.get("required", [])

    # One (type, Field) pair per schema property.
    fields = {
        field_name: json_schema_to_pydantic_field(field_name, prop, required)
        for field_name, prop in schema.get("properties", {}).items()
    }

    return create_model(tool_dict["name"], **fields)
|
||||
|
||||
|
||||
def json_schema_to_pydantic_field(
    name: str, json_schema: dict[str, Any], required: list[str]
) -> Any:
    """Convert one JSON-schema property into a (type, Field) pair for create_model.

    Args:
        name: The field name.
        json_schema: The JSON schema property.
        required: Names of required fields; members get `...` (no default),
            everything else defaults to None.

    Returns:
        A (python_type, pydantic.Field) tuple.
    """
    field_type = json_schema_to_pydantic_type(json_schema)
    field_info = Field(
        description=json_schema.get("description"),
        examples=json_schema.get("examples"),
        default=... if name in required else None,
    )
    return (field_type, field_info)
|
||||
|
||||
|
||||
def json_schema_to_pydantic_type(json_schema: dict[str, Any]) -> Any:
    """
    Converts a JSON schema type to a Pydantic type.

    Args:
        json_schema: The JSON schema to convert.

    Returns:
        A Pydantic type.
    """

    type_ = json_schema.get("type")

    # Both JSON-schema spellings ("string", "integer", ...) and Python
    # spellings ("str", "int", ...) are accepted for each scalar type.
    if type_ == "string" or type_ == "str":
        return str
    elif type_ == "integer" or type_ == "int":
        return int
    elif type_ == "number" or type_ == "float":
        return float
    elif type_ == "boolean" or type_ == "bool":
        return bool
    elif type_ == "array" or type_ == "list":
        # Recurse into "items" to type the elements; untyped arrays fall back
        # to a bare list.
        items_schema = json_schema.get("items")
        if items_schema:
            item_type = json_schema_to_pydantic_type(items_schema)
            return list[item_type]
        else:
            return list
    elif type_ == "object":
        # Handle nested models (mutual recursion with json_schema_to_model);
        # objects without declared properties become a bare dict.
        properties = json_schema.get("properties")
        if properties:
            nested_model = json_schema_to_model(json_schema)
            return nested_model
        else:
            return dict
    elif type_ == "null":
        return Optional[Any]  # Use Optional[Any] for nullable fields
    elif type_ == "literal":
        # NOTE(review): literal_eval expects source text — this assumes
        # "enum" holds the string form of a literal tuple/list; it raises if
        # "enum" is missing or already a Python list. Confirm against callers.
        return Literal[literal_eval(json_schema.get("enum"))]
    else:
        raise ValueError(f"Unsupported JSON schema type: {type_}")
|
||||
162
backend/open_webui/utils/task.py
Normal file
162
backend/open_webui/utils/task.py
Normal file
@@ -0,0 +1,162 @@
|
||||
import math
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def prompt_template(
    template: str, user_name: Optional[str] = None, user_location: Optional[str] = None
) -> str:
    """Substitute date/time and user placeholders in a prompt template.

    Replaces {{CURRENT_DATE}}, {{CURRENT_TIME}}, {{CURRENT_DATETIME}},
    {{USER_NAME}} and {{USER_LOCATION}}; the user fields fall back to
    "Unknown" when not provided.
    """
    now = datetime.now()
    date_str = now.strftime("%Y-%m-%d")
    time_str = now.strftime("%I:%M:%S %p")

    substitutions = {
        "{{CURRENT_DATE}}": date_str,
        "{{CURRENT_TIME}}": time_str,
        "{{CURRENT_DATETIME}}": f"{date_str} {time_str}",
        "{{USER_NAME}}": user_name if user_name else "Unknown",
        "{{USER_LOCATION}}": user_location if user_location else "Unknown",
    }
    for placeholder, value in substitutions.items():
        template = template.replace(placeholder, value)
    return template
|
||||
|
||||
|
||||
def title_generation_template(
    template: str, prompt: str, user: Optional[dict] = None
) -> str:
    """Expand {{prompt...}} placeholders in a title-generation template.

    Supports {{prompt}}, {{prompt:start:N}}, {{prompt:end:N}} and
    {{prompt:middletruncate:N}} (keep both ends, elide the middle), then
    applies prompt_template() with the user's name/location when available.
    """

    def expand(match):
        token = match.group(0)
        start_n, end_n, middle_n = match.group(1), match.group(2), match.group(3)
        if token == "{{prompt}}":
            return prompt
        if start_n is not None:
            return prompt[: int(start_n)]
        if end_n is not None:
            return prompt[-int(end_n) :]
        if middle_n is not None:
            limit = int(middle_n)
            if len(prompt) <= limit:
                return prompt
            head = prompt[: math.ceil(limit / 2)]
            tail = prompt[-math.floor(limit / 2) :]
            return f"{head}...{tail}"
        return ""

    template = re.sub(
        r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}",
        expand,
        template,
    )

    user_fields = (
        {"user_name": user.get("name"), "user_location": user.get("location")}
        if user
        else {}
    )
    return prompt_template(template, **user_fields)
|
||||
|
||||
|
||||
def search_query_generation_template(
    template: str, prompt: str, user: Optional[dict] = None
) -> str:
    """Expand {{prompt...}} placeholders in a search-query-generation template.

    Supports {{prompt}}, {{prompt:start:N}}, {{prompt:end:N}} and
    {{prompt:middletruncate:N}} (keep both ends, elide the middle), then
    applies prompt_template() with the user's name/location when available.
    """

    def expand(match):
        token = match.group(0)
        start_n, end_n, middle_n = match.group(1), match.group(2), match.group(3)
        if token == "{{prompt}}":
            return prompt
        if start_n is not None:
            return prompt[: int(start_n)]
        if end_n is not None:
            return prompt[-int(end_n) :]
        if middle_n is not None:
            limit = int(middle_n)
            if len(prompt) <= limit:
                return prompt
            head = prompt[: math.ceil(limit / 2)]
            tail = prompt[-math.floor(limit / 2) :]
            return f"{head}...{tail}"
        return ""

    template = re.sub(
        r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}",
        expand,
        template,
    )

    user_fields = (
        {"user_name": user.get("name"), "user_location": user.get("location")}
        if user
        else {}
    )
    return prompt_template(template, **user_fields)
|
||||
|
||||
|
||||
def moa_response_generation_template(
    template: str, prompt: str, responses: list[str]
) -> str:
    """Expand {{prompt...}} placeholders, then substitute {{responses}}.

    Each response is wrapped in triple quotes; the wrapped responses are
    joined with blank lines and replace the {{responses}} placeholder.
    """

    def expand(match):
        token = match.group(0)
        start_n, end_n, middle_n = match.group(1), match.group(2), match.group(3)
        if token == "{{prompt}}":
            return prompt
        if start_n is not None:
            return prompt[: int(start_n)]
        if end_n is not None:
            return prompt[-int(end_n) :]
        if middle_n is not None:
            limit = int(middle_n)
            if len(prompt) <= limit:
                return prompt
            head = prompt[: math.ceil(limit / 2)]
            tail = prompt[-math.floor(limit / 2) :]
            return f"{head}...{tail}"
        return ""

    template = re.sub(
        r"{{prompt}}|{{prompt:start:(\d+)}}|{{prompt:end:(\d+)}}|{{prompt:middletruncate:(\d+)}}",
        expand,
        template,
    )

    joined = "\n\n".join(f'"""{r}"""' for r in responses)
    return template.replace("{{responses}}", joined)
|
||||
|
||||
|
||||
def tools_function_calling_generation_template(template: str, tools_specs: str) -> str:
    """Inject the serialized tool specs into the template's {{TOOLS}} placeholder."""
    return template.replace("{{TOOLS}}", tools_specs)
|
||||
163
backend/open_webui/utils/tools.py
Normal file
163
backend/open_webui/utils/tools.py
Normal file
@@ -0,0 +1,163 @@
|
||||
import inspect
|
||||
import logging
|
||||
from typing import Awaitable, Callable, get_type_hints
|
||||
|
||||
from open_webui.apps.webui.models.tools import Tools
|
||||
from open_webui.apps.webui.models.users import UserModel
|
||||
from open_webui.apps.webui.utils import load_toolkit_module_by_id
|
||||
from open_webui.utils.schemas import json_schema_to_model
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def apply_extra_params_to_tool_function(
    function: Callable, extra_params: dict
) -> Callable[..., Awaitable]:
    """Wrap `function` so matching entries of `extra_params` are always passed.

    Only params that appear in the function's signature are injected; they
    override caller-supplied kwargs of the same name. The wrapper is always
    async, awaiting `function` when it is a coroutine function.
    """
    accepted = inspect.signature(function).parameters
    bound_extras = {k: v for k, v in extra_params.items() if k in accepted}
    awaits_target = inspect.iscoroutinefunction(function)

    async def wrapper(**kwargs):
        merged = kwargs | bound_extras
        if awaits_target:
            return await function(**merged)
        return function(**merged)

    return wrapper
|
||||
|
||||
|
||||
# Mutation on extra_params
def get_tools(
    webui_app, tool_ids: list[str], user: UserModel, extra_params: dict
) -> dict[str, dict]:
    """Load and prepare the callables for the given tool ids.

    Returns a mapping of function name -> tool dict (toolkit_id, callable,
    spec, pydantic_model, file_handler/citation flags). Tool modules are
    cached on webui_app.state.TOOLS; extra_params is mutated (gains "__id__"
    and, when the module defines UserValves, "__user__"["valves"]).
    """
    tools = {}
    for tool_id in tool_ids:
        toolkit = Tools.get_tool_by_id(tool_id)
        if toolkit is None:
            # Unknown tool id — skip silently.
            continue

        # Load the toolkit's Python module, using the app-state cache.
        module = webui_app.state.TOOLS.get(tool_id, None)
        if module is None:
            module, _ = load_toolkit_module_by_id(tool_id)
            webui_app.state.TOOLS[tool_id] = module

        extra_params["__id__"] = tool_id
        # Hydrate admin-configured valves when the module supports them.
        if hasattr(module, "valves") and hasattr(module, "Valves"):
            valves = Tools.get_tool_valves_by_id(tool_id) or {}
            module.valves = module.Valves(**valves)

        # Hydrate per-user valves when the module supports them.
        if hasattr(module, "UserValves"):
            extra_params["__user__"]["valves"] = module.UserValves(  # type: ignore
                **Tools.get_user_valves_by_id_and_user_id(tool_id, user.id)
            )

        for spec in toolkit.specs:
            # TODO: Fix hack for OpenAI API
            # OpenAI expects JSON-schema type names ("string"), not Python's.
            for val in spec.get("parameters", {}).get("properties", {}).values():
                if val["type"] == "str":
                    val["type"] = "string"
            function_name = spec["name"]

            # convert to function that takes only model params and inserts custom params
            # NOTE(review): the local name `callable` shadows the builtin here.
            original_func = getattr(module, function_name)
            callable = apply_extra_params_to_tool_function(original_func, extra_params)
            if hasattr(original_func, "__doc__"):
                callable.__doc__ = original_func.__doc__

            # TODO: This needs to be a pydantic model
            tool_dict = {
                "toolkit_id": tool_id,
                "callable": callable,
                "spec": spec,
                "pydantic_model": json_schema_to_model(spec),
                "file_handler": hasattr(module, "file_handler") and module.file_handler,
                "citation": hasattr(module, "citation") and module.citation,
            }

            # TODO: if collision, prepend toolkit name
            # First toolkit to claim a function name wins; later ones are dropped.
            if function_name in tools:
                log.warning(f"Tool {function_name} already exists in another toolkit!")
                log.warning(f"Collision between {toolkit} and {tool_id}.")
                log.warning(f"Discarding {toolkit}.{function_name}")
            else:
                tools[function_name] = tool_dict
    return tools
|
||||
|
||||
|
||||
def doc_to_dict(docstring):
    """Parse a reST-style docstring into {"description": ..., "params": {...}}.

    The description is taken from the second line when one exists (the first
    line of a triple-quoted docstring is usually empty), falling back to the
    first line for single-line docstrings. Each ':param name: desc' line
    becomes an entry in "params".
    """
    lines = docstring.split("\n")
    # BUG FIX: lines[1] raised IndexError for single-line docstrings —
    # get_tools_specs passes the bare function name when __doc__ is missing.
    description = lines[1].strip() if len(lines) > 1 else lines[0].strip()
    param_dict = {}

    for line in lines:
        if ":param" in line:
            line = line.replace(":param", "").strip()
            param, desc = line.split(":", 1)
            param_dict[param.strip()] = desc.strip()
    return {"description": description, "params": param_dict}
|
||||
|
||||
|
||||
def get_tools_specs(tools) -> list[dict]:
    """Build OpenAI-style function specs for every public callable on `tools`.

    Parameter types come from type hints and descriptions from the ':param'
    lines of each function's docstring (via doc_to_dict). Dunder-named
    parameters are treated as framework-injected and excluded from the spec.
    """
    # Collect public, non-class callables defined on the tools object.
    function_list = [
        {"name": func, "function": getattr(tools, func)}
        for func in dir(tools)
        if callable(getattr(tools, func))
        and not func.startswith("__")
        and not inspect.isclass(getattr(tools, func))
    ]

    specs = []
    for function_item in function_list:
        function_name = function_item["name"]
        function = function_item["function"]

        # Fall back to the bare function name when there is no docstring.
        function_doc = doc_to_dict(function.__doc__ or function_name)
        specs.append(
            {
                "name": function_name,
                # TODO: multi-line desc?
                "description": function_doc.get("description", function_name),
                "parameters": {
                    "type": "object",
                    "properties": {
                        param_name: {
                            # JSON-schema type derived from the annotation's
                            # class name (e.g. str -> "str").
                            "type": param_annotation.__name__.lower(),
                            **(
                                {
                                    # Literal/Union annotations expose
                                    # __args__; surface them as an enum hint.
                                    "enum": (
                                        str(param_annotation.__args__)
                                        if hasattr(param_annotation, "__args__")
                                        else None
                                    )
                                }
                                if hasattr(param_annotation, "__args__")
                                else {}
                            ),
                            "description": function_doc.get("params", {}).get(
                                param_name, param_name
                            ),
                        }
                        for param_name, param_annotation in get_type_hints(
                            function
                        ).items()
                        if param_name != "return"
                        and not (
                            param_name.startswith("__") and param_name.endswith("__")
                        )
                    },
                    # Parameters without defaults are required (dunders excluded).
                    "required": [
                        name
                        for name, param in inspect.signature(
                            function
                        ).parameters.items()
                        if param.default is param.empty
                        and not (name.startswith("__") and name.endswith("__"))
                    ],
                },
            }
        )

    return specs
|
||||
141
backend/open_webui/utils/utils.py
Normal file
141
backend/open_webui/utils/utils.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Optional, Union
|
||||
|
||||
import jwt
|
||||
from open_webui.apps.webui.models.users import Users
|
||||
from open_webui.constants import ERROR_MESSAGES
|
||||
from open_webui.env import WEBUI_SECRET_KEY
|
||||
from fastapi import Depends, HTTPException, Request, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from passlib.context import CryptContext
|
||||
|
||||
logging.getLogger("passlib").setLevel(logging.ERROR)
|
||||
|
||||
|
||||
SESSION_SECRET = WEBUI_SECRET_KEY
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
##############
|
||||
# Auth Utils
|
||||
##############
|
||||
|
||||
bearer_security = HTTPBearer(auto_error=False)
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
def verify_password(plain_password, hashed_password):
    """Check a plaintext password against a stored hash.

    Returns None (not False) when no hash is stored for the user.
    """
    if not hashed_password:
        return None
    return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def get_password_hash(password):
    """Return the bcrypt hash of `password` using the module-level CryptContext."""
    return pwd_context.hash(password)
|
||||
|
||||
|
||||
def create_token(data: dict, expires_delta: Union[timedelta, None] = None) -> str:
    """Encode `data` as an HS256 JWT, optionally adding an 'exp' claim.

    The input dict is copied, so the caller's data is never mutated.
    """
    payload = dict(data)
    if expires_delta:
        payload["exp"] = datetime.now(UTC) + expires_delta
    return jwt.encode(payload, SESSION_SECRET, algorithm=ALGORITHM)
|
||||
|
||||
|
||||
def decode_token(token: str) -> Optional[dict]:
    """Decode and verify a JWT; return its payload, or None when invalid/expired."""
    try:
        return jwt.decode(token, SESSION_SECRET, algorithms=[ALGORITHM])
    except Exception:
        return None
|
||||
|
||||
|
||||
def extract_token_from_auth_header(auth_header: str):
    """Return the token portion of an Authorization header value.

    NOTE(review): this unconditionally drops the first len("Bearer ") == 7
    characters; a header without the scheme loses its prefix.
    """
    scheme_prefix = "Bearer "
    return auth_header[len(scheme_prefix) :]
|
||||
|
||||
|
||||
def create_api_key():
    """Generate a new API key of the form 'sk-' + 32 hex characters."""
    # UUID.hex is the 32-character dashless hex form of a random UUID4.
    return f"sk-{uuid.uuid4().hex}"
|
||||
|
||||
|
||||
def get_http_authorization_cred(auth_header: str):
    """Split an Authorization header into scheme and credentials.

    Raises ValueError(ERROR_MESSAGES.INVALID_TOKEN) when the header cannot be
    split into a scheme and credentials.
    """
    try:
        # maxsplit=1 keeps the credentials intact even when they contain
        # spaces; the previous bare split(" ") raised on any extra space and
        # rejected otherwise-valid headers.
        scheme, credentials = auth_header.split(" ", 1)
        return HTTPAuthorizationCredentials(scheme=scheme, credentials=credentials)
    except Exception:
        raise ValueError(ERROR_MESSAGES.INVALID_TOKEN)
|
||||
|
||||
|
||||
def get_current_user(
    request: Request,
    auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
):
    """FastAPI dependency that resolves the authenticated user.

    Looks for a bearer token in the Authorization header first, then falls
    back to the 'token' cookie. Tokens starting with 'sk-' are treated as API
    keys; anything else is treated as a JWT. Raises 403 when no token is
    present and 401 when the token is invalid or its user no longer exists.
    """
    token = None

    if auth_token is not None:
        token = auth_token.credentials

    # Fall back to the session cookie when no Authorization header was sent.
    if token is None and "token" in request.cookies:
        token = request.cookies.get("token")

    if token is None:
        raise HTTPException(status_code=403, detail="Not authenticated")

    # auth by api key
    if token.startswith("sk-"):
        return get_current_user_by_api_key(token)

    # auth by jwt token
    data = decode_token(token)
    if data is not None and "id" in data:
        user = Users.get_user_by_id(data["id"])
        if user is None:
            # Token is valid but the referenced user no longer exists.
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=ERROR_MESSAGES.INVALID_TOKEN,
            )
        else:
            # Record activity on every authenticated request.
            Users.update_user_last_active_by_id(user.id)
        return user
    else:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.UNAUTHORIZED,
        )
|
||||
|
||||
|
||||
def get_current_user_by_api_key(api_key: str):
    """Resolve a user from an API key, recording activity; 401 when unknown."""
    user = Users.get_user_by_api_key(api_key)
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=ERROR_MESSAGES.INVALID_TOKEN,
        )
    # Record activity on every authenticated request.
    Users.update_user_last_active_by_id(user.id)
    return user
|
||||
|
||||
|
||||
def get_verified_user(user=Depends(get_current_user)):
    """Dependency allowing only 'user' or 'admin' roles; raises 401 otherwise."""
    if user.role in {"user", "admin"}:
        return user
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
    )
|
||||
|
||||
|
||||
def get_admin_user(user=Depends(get_current_user)):
    """Dependency allowing only the 'admin' role; raises 401 otherwise."""
    if user.role == "admin":
        return user
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
    )
|
||||
55
backend/open_webui/utils/webhook.py
Normal file
55
backend/open_webui/utils/webhook.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
import requests
|
||||
from open_webui.config import WEBUI_FAVICON_URL, WEBUI_NAME
|
||||
from open_webui.env import SRC_LOG_LEVELS, VERSION
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["WEBHOOK"])
|
||||
|
||||
|
||||
def post_webhook(url: str, message: str, event_data: dict) -> bool:
    """POST a notification to a webhook URL, adapting the payload to the target.

    Slack/Google Chat get {"text": ...}, Discord gets {"content": ...},
    Microsoft Teams gets a MessageCard, and anything else receives the raw
    event_data. Returns True on success, False on any failure (logged).
    """
    try:
        payload = {}

        # Slack and Google Chat Webhooks
        if "https://hooks.slack.com" in url or "https://chat.googleapis.com" in url:
            payload["text"] = message
        # Discord Webhooks
        elif "https://discord.com/api/webhooks" in url:
            payload["content"] = message
        # Microsoft Teams Webhooks (legacy MessageCard format)
        elif "webhook.office.com" in url:
            action = event_data.get("action", "undefined")
            # BUG FIX: the fallback must be a JSON *string* — json.loads({})
            # raised TypeError whenever event_data had no "user" entry, so the
            # Teams branch always failed and returned False.
            facts = [
                {"name": name, "value": value}
                for name, value in json.loads(event_data.get("user", "{}")).items()
            ]
            payload = {
                "@type": "MessageCard",
                "@context": "http://schema.org/extensions",
                "themeColor": "0076D7",
                "summary": message,
                "sections": [
                    {
                        "activityTitle": message,
                        "activitySubtitle": f"{WEBUI_NAME} ({VERSION}) - {action}",
                        "activityImage": WEBUI_FAVICON_URL,
                        "facts": facts,
                        "markdown": True,
                    }
                ],
            }
        # Default Payload
        else:
            payload = {**event_data}

        log.debug(f"payload: {payload}")
        # NOTE(review): no timeout is set; an unresponsive endpoint blocks the
        # caller indefinitely — consider requests.post(..., timeout=...).
        r = requests.post(url, json=payload)
        r.raise_for_status()
        log.debug(f"r.text: {r.text}")
        return True
    except Exception as e:
        log.exception(e)
        return False
|
||||
Reference in New Issue
Block a user