mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Merge branch 'open-webui:main' into next
This commit is contained in:
@@ -23,6 +23,7 @@ from open_webui.env import (
|
||||
TRUSTED_SIGNATURE_KEY,
|
||||
STATIC_DIR,
|
||||
SRC_LOG_LEVELS,
|
||||
WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
|
||||
)
|
||||
|
||||
from fastapi import BackgroundTasks, Depends, HTTPException, Request, Response, status
|
||||
@@ -157,6 +158,7 @@ def get_http_authorization_cred(auth_header: Optional[str]):
|
||||
|
||||
def get_current_user(
|
||||
request: Request,
|
||||
response: Response,
|
||||
background_tasks: BackgroundTasks,
|
||||
auth_token: HTTPAuthorizationCredentials = Depends(bearer_security),
|
||||
):
|
||||
@@ -225,6 +227,19 @@ def get_current_user(
|
||||
detail=ERROR_MESSAGES.INVALID_TOKEN,
|
||||
)
|
||||
else:
|
||||
if WEBUI_AUTH_TRUSTED_EMAIL_HEADER:
|
||||
trusted_email = request.headers.get(WEBUI_AUTH_TRUSTED_EMAIL_HEADER)
|
||||
if trusted_email and user.email != trusted_email:
|
||||
# Delete the token cookie
|
||||
response.delete_cookie("token")
|
||||
# Delete OAuth token if present
|
||||
if request.cookies.get("oauth_id_token"):
|
||||
response.delete_cookie("oauth_id_token")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User mismatch. Please sign in again.",
|
||||
)
|
||||
|
||||
# Add user info to current span
|
||||
current_span = trace.get_current_span()
|
||||
if current_span:
|
||||
|
||||
@@ -320,12 +320,7 @@ async def chat_completed(request: Request, form_data: dict, user: Any):
|
||||
extra_params = {
|
||||
"__event_emitter__": get_event_emitter(metadata),
|
||||
"__event_call__": get_event_call(metadata),
|
||||
"__user__": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
|
||||
"__metadata__": metadata,
|
||||
"__request__": request,
|
||||
"__model__": model,
|
||||
@@ -424,12 +419,7 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
|
||||
params[key] = value
|
||||
|
||||
if "__user__" in sig.parameters:
|
||||
__user__ = {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
}
|
||||
__user__ = (user.model_dump() if isinstance(user, UserModel) else {},)
|
||||
|
||||
try:
|
||||
if hasattr(function_module, "UserValves"):
|
||||
|
||||
90
backend/open_webui/utils/embeddings.py
Normal file
90
backend/open_webui/utils/embeddings.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import random
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from fastapi import Request
|
||||
from open_webui.models.users import UserModel
|
||||
from open_webui.models.models import Models
|
||||
from open_webui.utils.models import check_model_access
|
||||
from open_webui.env import SRC_LOG_LEVELS, GLOBAL_LOG_LEVEL, BYPASS_MODEL_ACCESS_CONTROL
|
||||
|
||||
from open_webui.routers.openai import embeddings as openai_embeddings
|
||||
from open_webui.routers.ollama import (
|
||||
embeddings as ollama_embeddings,
|
||||
GenerateEmbeddingsForm,
|
||||
)
|
||||
|
||||
|
||||
from open_webui.utils.payload import convert_embedding_payload_openai_to_ollama
|
||||
from open_webui.utils.response import convert_embedding_response_ollama_to_openai
|
||||
|
||||
logging.basicConfig(stream=sys.stdout, level=GLOBAL_LOG_LEVEL)
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MAIN"])
|
||||
|
||||
|
||||
async def generate_embeddings(
    request: Request,
    form_data: dict,
    user: UserModel,
    bypass_filter: bool = False,
):
    """
    Route an OpenAI-style embeddings request to the backend that owns the model.

    Models whose ``owned_by`` field is ``"ollama"`` are served through the
    Ollama router (payload and response are converted on the way in and out);
    every other model is forwarded unchanged to the OpenAI-compatible router.

    Args:
        request (Request): The FastAPI request; ``request.state`` may carry
            ``metadata``, ``direct`` and ``model``, and
            ``request.app.state.MODELS`` holds the known models.
        form_data (dict): The request payload. Its ``"metadata"`` entry is
            updated in place when ``request.state.metadata`` exists.
        user (UserModel): The authenticated user.
        bypass_filter (bool): If True, skips the model access check
            (default False). Forced to True when
            ``BYPASS_MODEL_ACCESS_CONTROL`` is set.

    Returns:
        dict: The embeddings response in OpenAI-compatible form.

    Raises:
        Exception: If ``form_data["model"]`` is not among the available models.
    """
    # The global switch wins over whatever the caller passed in.
    bypass_filter = True if BYPASS_MODEL_ACCESS_CONTROL else bypass_filter

    state = request.state

    # Merge request-level metadata into the payload; request.state values take
    # precedence over keys that are already present in form_data["metadata"].
    if hasattr(state, "metadata"):
        if "metadata" in form_data:
            form_data["metadata"] = {**form_data["metadata"], **state.metadata}
        else:
            form_data["metadata"] = state.metadata

    # A "direct" request only ever sees the single model attached to the request.
    is_direct = getattr(state, "direct", False)
    if is_direct and hasattr(state, "model"):
        models = {state.model["id"]: state.model}
    else:
        models = request.app.state.MODELS

    model_id = form_data.get("model")
    if model_id not in models:
        raise Exception("Model not found")
    model = models[model_id]

    # Access filtering applies only to non-direct requests from plain users.
    if not is_direct and not bypass_filter and user.role == "user":
        check_model_access(user, model)

    # Default: OpenAI or compatible backend
    if model.get("owned_by") != "ollama":
        return await openai_embeddings(
            request=request,
            form_data=form_data,
            user=user,
        )

    # Ollama backend: translate the payload, call Ollama, translate the reply.
    ollama_payload = convert_embedding_payload_openai_to_ollama(form_data)
    ollama_response = await ollama_embeddings(
        request=request,
        form_data=GenerateEmbeddingsForm(**ollama_payload),
        user=user,
    )
    return convert_embedding_response_ollama_to_openai(ollama_response)
|
||||
@@ -32,11 +32,17 @@ from open_webui.socket.main import (
|
||||
from open_webui.routers.tasks import (
|
||||
generate_queries,
|
||||
generate_title,
|
||||
generate_follow_ups,
|
||||
generate_image_prompt,
|
||||
generate_chat_tags,
|
||||
)
|
||||
from open_webui.routers.retrieval import process_web_search, SearchForm
|
||||
from open_webui.routers.images import image_generations, GenerateImageForm
|
||||
from open_webui.routers.images import (
|
||||
load_b64_image_data,
|
||||
image_generations,
|
||||
GenerateImageForm,
|
||||
upload_image,
|
||||
)
|
||||
from open_webui.routers.pipelines import (
|
||||
process_pipeline_inlet_filter,
|
||||
process_pipeline_outlet_filter,
|
||||
@@ -692,13 +698,8 @@ def apply_params_to_form_data(form_data, model):
|
||||
params = deep_update(params, custom_params)
|
||||
|
||||
if model.get("ollama"):
|
||||
# Ollama specific parameters
|
||||
form_data["options"] = params
|
||||
|
||||
if "format" in params:
|
||||
form_data["format"] = params["format"]
|
||||
|
||||
if "keep_alive" in params:
|
||||
form_data["keep_alive"] = params["keep_alive"]
|
||||
else:
|
||||
if isinstance(params, dict):
|
||||
for key, value in params.items():
|
||||
@@ -726,12 +727,7 @@ async def process_chat_payload(request, form_data, user, metadata, model):
|
||||
extra_params = {
|
||||
"__event_emitter__": event_emitter,
|
||||
"__event_call__": event_call,
|
||||
"__user__": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
|
||||
"__metadata__": metadata,
|
||||
"__request__": request,
|
||||
"__model__": model,
|
||||
@@ -1061,6 +1057,59 @@ async def process_chat_response(
|
||||
)
|
||||
|
||||
if tasks and messages:
|
||||
if (
|
||||
TASKS.FOLLOW_UP_GENERATION in tasks
|
||||
and tasks[TASKS.FOLLOW_UP_GENERATION]
|
||||
):
|
||||
res = await generate_follow_ups(
|
||||
request,
|
||||
{
|
||||
"model": message["model"],
|
||||
"messages": messages,
|
||||
"message_id": metadata["message_id"],
|
||||
"chat_id": metadata["chat_id"],
|
||||
},
|
||||
user,
|
||||
)
|
||||
|
||||
if res and isinstance(res, dict):
|
||||
if len(res.get("choices", [])) == 1:
|
||||
follow_ups_string = (
|
||||
res.get("choices", [])[0]
|
||||
.get("message", {})
|
||||
.get("content", "")
|
||||
)
|
||||
else:
|
||||
follow_ups_string = ""
|
||||
|
||||
follow_ups_string = follow_ups_string[
|
||||
follow_ups_string.find("{") : follow_ups_string.rfind("}")
|
||||
+ 1
|
||||
]
|
||||
|
||||
try:
|
||||
follow_ups = json.loads(follow_ups_string).get(
|
||||
"follow_ups", []
|
||||
)
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
metadata["chat_id"],
|
||||
metadata["message_id"],
|
||||
{
|
||||
"followUps": follow_ups,
|
||||
},
|
||||
)
|
||||
|
||||
await event_emitter(
|
||||
{
|
||||
"type": "chat:message:follow_ups",
|
||||
"data": {
|
||||
"follow_ups": follow_ups,
|
||||
},
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
if TASKS.TITLE_GENERATION in tasks:
|
||||
if tasks[TASKS.TITLE_GENERATION]:
|
||||
res = await generate_title(
|
||||
@@ -1286,12 +1335,7 @@ async def process_chat_response(
|
||||
extra_params = {
|
||||
"__event_emitter__": event_emitter,
|
||||
"__event_call__": event_caller,
|
||||
"__user__": {
|
||||
"id": user.id,
|
||||
"email": user.email,
|
||||
"name": user.name,
|
||||
"role": user.role,
|
||||
},
|
||||
"__user__": user.model_dump() if isinstance(user, UserModel) else {},
|
||||
"__metadata__": metadata,
|
||||
"__request__": request,
|
||||
"__model__": model,
|
||||
@@ -1835,9 +1879,11 @@ async def process_chat_response(
|
||||
|
||||
value = delta.get("content")
|
||||
|
||||
reasoning_content = delta.get(
|
||||
"reasoning_content"
|
||||
) or delta.get("reasoning")
|
||||
reasoning_content = (
|
||||
delta.get("reasoning_content")
|
||||
or delta.get("reasoning")
|
||||
or delta.get("thinking")
|
||||
)
|
||||
if reasoning_content:
|
||||
if (
|
||||
not content_blocks
|
||||
@@ -2230,28 +2276,21 @@ async def process_chat_response(
|
||||
stdoutLines = stdout.split("\n")
|
||||
for idx, line in enumerate(stdoutLines):
|
||||
if "data:image/png;base64" in line:
|
||||
id = str(uuid4())
|
||||
|
||||
# ensure the path exists
|
||||
os.makedirs(
|
||||
os.path.join(CACHE_DIR, "images"),
|
||||
exist_ok=True,
|
||||
image_url = ""
|
||||
# Extract base64 image data from the line
|
||||
image_data, content_type = (
|
||||
load_b64_image_data(line)
|
||||
)
|
||||
|
||||
image_path = os.path.join(
|
||||
CACHE_DIR,
|
||||
f"images/{id}.png",
|
||||
)
|
||||
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(
|
||||
base64.b64decode(
|
||||
line.split(",")[1]
|
||||
)
|
||||
if image_data is not None:
|
||||
image_url = upload_image(
|
||||
request,
|
||||
image_data,
|
||||
content_type,
|
||||
metadata,
|
||||
user,
|
||||
)
|
||||
|
||||
stdoutLines[idx] = (
|
||||
f""
|
||||
f""
|
||||
)
|
||||
|
||||
output["stdout"] = "\n".join(stdoutLines)
|
||||
@@ -2262,30 +2301,22 @@ async def process_chat_response(
|
||||
resultLines = result.split("\n")
|
||||
for idx, line in enumerate(resultLines):
|
||||
if "data:image/png;base64" in line:
|
||||
id = str(uuid4())
|
||||
|
||||
# ensure the path exists
|
||||
os.makedirs(
|
||||
os.path.join(CACHE_DIR, "images"),
|
||||
exist_ok=True,
|
||||
image_url = ""
|
||||
# Extract base64 image data from the line
|
||||
image_data, content_type = (
|
||||
load_b64_image_data(line)
|
||||
)
|
||||
|
||||
image_path = os.path.join(
|
||||
CACHE_DIR,
|
||||
f"images/{id}.png",
|
||||
)
|
||||
|
||||
with open(image_path, "wb") as f:
|
||||
f.write(
|
||||
base64.b64decode(
|
||||
line.split(",")[1]
|
||||
)
|
||||
if image_data is not None:
|
||||
image_url = upload_image(
|
||||
request,
|
||||
image_data,
|
||||
content_type,
|
||||
metadata,
|
||||
user,
|
||||
)
|
||||
|
||||
resultLines[idx] = (
|
||||
f""
|
||||
f""
|
||||
)
|
||||
|
||||
output["result"] = "\n".join(resultLines)
|
||||
except Exception as e:
|
||||
output = str(e)
|
||||
@@ -2394,8 +2425,8 @@ async def process_chat_response(
|
||||
await response.background()
|
||||
|
||||
# background_tasks.add_task(post_response_handler, response, events)
|
||||
task_id, _ = create_task(
|
||||
post_response_handler(response, events), id=metadata["chat_id"]
|
||||
task_id, _ = await create_task(
|
||||
request, post_response_handler(response, events), id=metadata["chat_id"]
|
||||
)
|
||||
return {"status": True, "task_id": task_id}
|
||||
|
||||
|
||||
@@ -208,6 +208,7 @@ def openai_chat_message_template(model: str):
|
||||
def openai_chat_chunk_message_template(
|
||||
model: str,
|
||||
content: Optional[str] = None,
|
||||
reasoning_content: Optional[str] = None,
|
||||
tool_calls: Optional[list[dict]] = None,
|
||||
usage: Optional[dict] = None,
|
||||
) -> dict:
|
||||
@@ -220,6 +221,9 @@ def openai_chat_chunk_message_template(
|
||||
if content:
|
||||
template["choices"][0]["delta"]["content"] = content
|
||||
|
||||
if reasoning_content:
|
||||
template["choices"][0]["delta"]["reasoning_content"] = reasoning_content
|
||||
|
||||
if tool_calls:
|
||||
template["choices"][0]["delta"]["tool_calls"] = tool_calls
|
||||
|
||||
@@ -234,6 +238,7 @@ def openai_chat_chunk_message_template(
|
||||
def openai_chat_completion_message_template(
|
||||
model: str,
|
||||
message: Optional[str] = None,
|
||||
reasoning_content: Optional[str] = None,
|
||||
tool_calls: Optional[list[dict]] = None,
|
||||
usage: Optional[dict] = None,
|
||||
) -> dict:
|
||||
@@ -241,8 +246,9 @@ def openai_chat_completion_message_template(
|
||||
template["object"] = "chat.completion"
|
||||
if message is not None:
|
||||
template["choices"][0]["message"] = {
|
||||
"content": message,
|
||||
"role": "assistant",
|
||||
"content": message,
|
||||
**({"reasoning_content": reasoning_content} if reasoning_content else {}),
|
||||
**({"tool_calls": tool_calls} if tool_calls else {}),
|
||||
}
|
||||
|
||||
|
||||
@@ -538,7 +538,7 @@ class OAuthManager:
|
||||
# Redirect back to the frontend with the JWT token
|
||||
|
||||
redirect_base_url = request.app.state.config.WEBUI_URL or request.base_url
|
||||
if redirect_base_url.endswith("/"):
|
||||
if isinstance(redirect_base_url, str) and redirect_base_url.endswith("/"):
|
||||
redirect_base_url = redirect_base_url[:-1]
|
||||
redirect_url = f"{redirect_base_url}/auth#token={jwt_token}"
|
||||
|
||||
|
||||
@@ -175,16 +175,32 @@ def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
|
||||
"num_thread": int,
|
||||
}
|
||||
|
||||
# Extract keep_alive from options if it exists
|
||||
if "options" in form_data and "keep_alive" in form_data["options"]:
|
||||
form_data["keep_alive"] = form_data["options"]["keep_alive"]
|
||||
del form_data["options"]["keep_alive"]
|
||||
def parse_json(value: str) -> dict:
    """
    Decode ``value`` as JSON and return the decoded object.

    If decoding raises for any reason, ``value`` itself is returned unchanged.
    """
    try:
        decoded = json.loads(value)
    except Exception:
        # Not valid JSON (or not a string at all): pass the input through.
        return value
    return decoded
|
||||
|
||||
if "options" in form_data and "format" in form_data["options"]:
|
||||
form_data["format"] = form_data["options"]["format"]
|
||||
del form_data["options"]["format"]
|
||||
ollama_root_params = {
|
||||
"format": lambda x: parse_json(x),
|
||||
"keep_alive": lambda x: parse_json(x),
|
||||
"think": bool,
|
||||
}
|
||||
|
||||
return apply_model_params_to_body(params, form_data, mappings)
|
||||
for key, value in ollama_root_params.items():
|
||||
if (param := params.get(key, None)) is not None:
|
||||
# Copy the parameter to new name then delete it, to prevent Ollama warning of invalid option provided
|
||||
form_data[key] = value(param)
|
||||
del params[key]
|
||||
|
||||
# Unlike OpenAI, Ollama does not support params directly in the body
|
||||
form_data["options"] = apply_model_params_to_body(
|
||||
params, (form_data.get("options", {}) or {}), mappings
|
||||
)
|
||||
return form_data
|
||||
|
||||
|
||||
def convert_messages_openai_to_ollama(messages: list[dict]) -> list[dict]:
|
||||
@@ -279,36 +295,48 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
|
||||
openai_payload.get("messages")
|
||||
)
|
||||
ollama_payload["stream"] = openai_payload.get("stream", False)
|
||||
|
||||
if "tools" in openai_payload:
|
||||
ollama_payload["tools"] = openai_payload["tools"]
|
||||
|
||||
if "format" in openai_payload:
|
||||
ollama_payload["format"] = openai_payload["format"]
|
||||
|
||||
# If there are advanced parameters in the payload, format them in Ollama's options field
|
||||
if openai_payload.get("options"):
|
||||
ollama_payload["options"] = openai_payload["options"]
|
||||
ollama_options = openai_payload["options"]
|
||||
|
||||
def parse_json(value: str) -> dict:
    """
    Try to interpret ``value`` as a JSON document.

    Returns the decoded object on success; when ``json.loads`` raises, the
    original ``value`` is returned as-is.
    """
    try:
        parsed = json.loads(value)
    except Exception:
        # Fall back to the raw input when it cannot be decoded.
        return value
    return parsed
|
||||
|
||||
ollama_root_params = {
|
||||
"format": lambda x: parse_json(x),
|
||||
"keep_alive": lambda x: parse_json(x),
|
||||
"think": bool,
|
||||
}
|
||||
|
||||
# Ollama's options field can contain parameters that should be at the root level.
|
||||
for key, value in ollama_root_params.items():
|
||||
if (param := ollama_options.get(key, None)) is not None:
|
||||
# Copy the parameter to new name then delete it, to prevent Ollama warning of invalid option provided
|
||||
ollama_payload[key] = value(param)
|
||||
del ollama_options[key]
|
||||
|
||||
# Re-Mapping OpenAI's `max_tokens` -> Ollama's `num_predict`
|
||||
if "max_tokens" in ollama_options:
|
||||
ollama_options["num_predict"] = ollama_options["max_tokens"]
|
||||
del ollama_options[
|
||||
"max_tokens"
|
||||
] # To prevent Ollama warning of invalid option provided
|
||||
del ollama_options["max_tokens"]
|
||||
|
||||
# Ollama lacks a "system" prompt option. It has to be provided as a direct parameter, so we copy it down.
|
||||
# Comment: Not sure why this is needed, but we'll keep it for compatibility.
|
||||
if "system" in ollama_options:
|
||||
ollama_payload["system"] = ollama_options["system"]
|
||||
del ollama_options[
|
||||
"system"
|
||||
] # To prevent Ollama warning of invalid option provided
|
||||
del ollama_options["system"]
|
||||
|
||||
# Extract keep_alive from options if it exists
|
||||
if "keep_alive" in ollama_options:
|
||||
ollama_payload["keep_alive"] = ollama_options["keep_alive"]
|
||||
del ollama_options["keep_alive"]
|
||||
ollama_payload["options"] = ollama_options
|
||||
|
||||
# If there is the "stop" parameter in the openai_payload, remap it to the ollama_payload.options
|
||||
if "stop" in openai_payload:
|
||||
@@ -329,3 +357,32 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
|
||||
ollama_payload["format"] = format
|
||||
|
||||
return ollama_payload
|
||||
|
||||
|
||||
def convert_embedding_payload_openai_to_ollama(openai_payload: dict) -> dict:
    """
    Build an Ollama embeddings payload from an OpenAI-style embeddings payload.

    The result always carries ``model``, ``input`` (a list) and ``prompt``
    (a single string). A list ``input`` is kept as the list and joined with
    newlines for ``prompt``; any other value is wrapped in a one-element list
    and stringified for ``prompt``.

    Args:
        openai_payload (dict): The original payload designed for OpenAI API usage.

    Returns:
        dict: A payload compatible with the Ollama API embeddings endpoint.
    """
    raw_input = openai_payload.get("input")
    is_batch = isinstance(raw_input, list)

    ollama_payload = {
        "model": openai_payload.get("model"),
        "input": raw_input if is_batch else [raw_input],
        "prompt": "\n".join(map(str, raw_input)) if is_batch else str(raw_input),
    }

    # Forward the optional fields only when the caller supplied them.
    ollama_payload.update(
        (key, openai_payload[key])
        for key in ("options", "truncate", "keep_alive")
        if key in openai_payload
    )
    return ollama_payload
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import socketio
|
||||
import redis
|
||||
from redis import asyncio as aioredis
|
||||
from urllib.parse import urlparse
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def parse_redis_service_url(redis_url):
|
||||
@@ -18,23 +17,46 @@ def parse_redis_service_url(redis_url):
|
||||
}
|
||||
|
||||
|
||||
def get_redis_connection(redis_url, redis_sentinels, decode_responses=True):
|
||||
if redis_sentinels:
|
||||
redis_config = parse_redis_service_url(redis_url)
|
||||
sentinel = redis.sentinel.Sentinel(
|
||||
redis_sentinels,
|
||||
port=redis_config["port"],
|
||||
db=redis_config["db"],
|
||||
username=redis_config["username"],
|
||||
password=redis_config["password"],
|
||||
decode_responses=decode_responses,
|
||||
)
|
||||
def get_redis_connection(
|
||||
redis_url, redis_sentinels, async_mode=False, decode_responses=True
|
||||
):
|
||||
if async_mode:
|
||||
import redis.asyncio as redis
|
||||
|
||||
# Get a master connection from Sentinel
|
||||
return sentinel.master_for(redis_config["service"])
|
||||
# If using sentinel in async mode
|
||||
if redis_sentinels:
|
||||
redis_config = parse_redis_service_url(redis_url)
|
||||
sentinel = redis.sentinel.Sentinel(
|
||||
redis_sentinels,
|
||||
port=redis_config["port"],
|
||||
db=redis_config["db"],
|
||||
username=redis_config["username"],
|
||||
password=redis_config["password"],
|
||||
decode_responses=decode_responses,
|
||||
)
|
||||
return sentinel.master_for(redis_config["service"])
|
||||
elif redis_url:
|
||||
return redis.from_url(redis_url, decode_responses=decode_responses)
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
# Standard Redis connection
|
||||
return redis.Redis.from_url(redis_url, decode_responses=decode_responses)
|
||||
import redis
|
||||
|
||||
if redis_sentinels:
|
||||
redis_config = parse_redis_service_url(redis_url)
|
||||
sentinel = redis.sentinel.Sentinel(
|
||||
redis_sentinels,
|
||||
port=redis_config["port"],
|
||||
db=redis_config["db"],
|
||||
username=redis_config["username"],
|
||||
password=redis_config["password"],
|
||||
decode_responses=decode_responses,
|
||||
)
|
||||
return sentinel.master_for(redis_config["service"])
|
||||
elif redis_url:
|
||||
return redis.Redis.from_url(redis_url, decode_responses=decode_responses)
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def get_sentinels_from_env(sentinel_hosts_env, sentinel_port_env):
|
||||
|
||||
@@ -83,6 +83,7 @@ def convert_ollama_usage_to_openai(data: dict) -> dict:
|
||||
def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
|
||||
model = ollama_response.get("model", "ollama")
|
||||
message_content = ollama_response.get("message", {}).get("content", "")
|
||||
reasoning_content = ollama_response.get("message", {}).get("thinking", None)
|
||||
tool_calls = ollama_response.get("message", {}).get("tool_calls", None)
|
||||
openai_tool_calls = None
|
||||
|
||||
@@ -94,7 +95,7 @@ def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
|
||||
usage = convert_ollama_usage_to_openai(data)
|
||||
|
||||
response = openai_chat_completion_message_template(
|
||||
model, message_content, openai_tool_calls, usage
|
||||
model, message_content, reasoning_content, openai_tool_calls, usage
|
||||
)
|
||||
return response
|
||||
|
||||
@@ -105,6 +106,7 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
|
||||
|
||||
model = data.get("model", "ollama")
|
||||
message_content = data.get("message", {}).get("content", None)
|
||||
reasoning_content = data.get("message", {}).get("thinking", None)
|
||||
tool_calls = data.get("message", {}).get("tool_calls", None)
|
||||
openai_tool_calls = None
|
||||
|
||||
@@ -118,10 +120,71 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
|
||||
usage = convert_ollama_usage_to_openai(data)
|
||||
|
||||
data = openai_chat_chunk_message_template(
|
||||
model, message_content, openai_tool_calls, usage
|
||||
model, message_content, reasoning_content, openai_tool_calls, usage
|
||||
)
|
||||
|
||||
line = f"data: {json.dumps(data)}\n\n"
|
||||
yield line
|
||||
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
|
||||
def convert_embedding_response_ollama_to_openai(response) -> dict:
    """
    Convert the response from Ollama embeddings endpoint to the OpenAI-compatible format.

    Args:
        response (dict): The response from the Ollama API. Recognized shapes:
            - single output: {"embedding": [...], "model": "..."}
            - batch output with bare vectors (what Ollama's batch endpoint
              returns): {"embeddings": [[...], [...]], "model": "..."}
            - batch output with dict entries:
              {"embeddings": [{"embedding": [...], "index": 0}, ...], "model": "..."}

    Returns:
        dict: Response adapted to OpenAI's embeddings API format.
            e.g. {
                "object": "list",
                "data": [
                    {"object": "embedding", "embedding": [...], "index": 0},
                    ...
                ],
                "model": "...",
            }
        Anything that is not a dict, or a dict in none of the shapes above
        (including one that already has an OpenAI-style "data" list), is
        returned unchanged.
    """
    if not isinstance(response, dict):
        # Fallback: return as is if unrecognized
        return response

    # Ollama batch-style output
    if "embeddings" in response:
        openai_data = []
        for position, entry in enumerate(response["embeddings"]):
            if isinstance(entry, dict):
                vector = entry.get("embedding")
                index = entry.get("index", position)
            else:
                # Bare vector: the entry itself is the embedding, and its
                # position in the list is its index. Calling .get() on it
                # (as the previous code did) would raise AttributeError.
                vector = entry
                index = position
            openai_data.append(
                {
                    "object": "embedding",
                    "embedding": vector,
                    "index": index,
                }
            )
        return {
            "object": "list",
            "data": openai_data,
            "model": response.get("model"),
        }

    # Ollama single output
    if "embedding" in response:
        return {
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "embedding": response["embedding"],
                    "index": 0,
                }
            ],
            "model": response.get("model"),
        }

    # Already OpenAI-compatible ("data" list) or unrecognized: return as is.
    return response
|
||||
|
||||
@@ -207,6 +207,24 @@ def title_generation_template(
|
||||
return template
|
||||
|
||||
|
||||
def follow_up_generation_template(
    template: str, messages: list[dict], user: Optional[dict] = None
) -> str:
    """
    Fill in the follow-up generation prompt template.

    The last user message (via ``get_last_user_message``) is substituted
    through ``replace_prompt_variable``, the conversation through
    ``replace_messages_variable``, and the result is passed to
    ``prompt_template``. When ``user`` is truthy, its ``"name"`` and
    ``"location"`` entries are passed along as ``user_name`` and
    ``user_location``; otherwise no user keywords are supplied.
    """
    last_user_prompt = get_last_user_message(messages)
    template = replace_prompt_variable(template, last_user_prompt)
    template = replace_messages_variable(template, messages)

    user_fields = {}
    if user:
        user_fields = {
            "user_name": user.get("name"),
            "user_location": user.get("location"),
        }

    return prompt_template(template, **user_fields)
|
||||
|
||||
|
||||
def tags_generation_template(
|
||||
template: str, messages: list[dict], user: Optional[dict] = None
|
||||
) -> str:
|
||||
|
||||
Reference in New Issue
Block a user