mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Merge remote-tracking branch 'oui/dev' into feat_s3_virtual_path
This commit is contained in:
@@ -1583,6 +1583,18 @@ TIKA_SERVER_URL = PersistentConfig(
|
||||
os.getenv("TIKA_SERVER_URL", "http://tika:9998"), # Default for sidecar deployment
|
||||
)
|
||||
|
||||
DOCUMENT_INTELLIGENCE_ENDPOINT = PersistentConfig(
|
||||
"DOCUMENT_INTELLIGENCE_ENDPOINT",
|
||||
"rag.document_intelligence_endpoint",
|
||||
os.getenv("DOCUMENT_INTELLIGENCE_ENDPOINT", ""),
|
||||
)
|
||||
|
||||
DOCUMENT_INTELLIGENCE_KEY = PersistentConfig(
|
||||
"DOCUMENT_INTELLIGENCE_KEY",
|
||||
"rag.document_intelligence_key",
|
||||
os.getenv("DOCUMENT_INTELLIGENCE_KEY", ""),
|
||||
)
|
||||
|
||||
RAG_TOP_K = PersistentConfig(
|
||||
"RAG_TOP_K", "rag.top_k", int(os.environ.get("RAG_TOP_K", "3"))
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@ import logging
|
||||
import sys
|
||||
import inspect
|
||||
import json
|
||||
import asyncio
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import AsyncGenerator, Generator, Iterator
|
||||
@@ -76,11 +77,13 @@ async def get_function_models(request):
|
||||
if hasattr(function_module, "pipes"):
|
||||
sub_pipes = []
|
||||
|
||||
# Check if pipes is a function or a list
|
||||
|
||||
# Handle pipes being a list, sync function, or async function
|
||||
try:
|
||||
if callable(function_module.pipes):
|
||||
sub_pipes = function_module.pipes()
|
||||
if asyncio.iscoroutinefunction(function_module.pipes):
|
||||
sub_pipes = await function_module.pipes()
|
||||
else:
|
||||
sub_pipes = function_module.pipes()
|
||||
else:
|
||||
sub_pipes = function_module.pipes
|
||||
except Exception as e:
|
||||
|
||||
@@ -180,6 +180,8 @@ from open_webui.config import (
|
||||
CHUNK_SIZE,
|
||||
CONTENT_EXTRACTION_ENGINE,
|
||||
TIKA_SERVER_URL,
|
||||
DOCUMENT_INTELLIGENCE_ENDPOINT,
|
||||
DOCUMENT_INTELLIGENCE_KEY,
|
||||
RAG_TOP_K,
|
||||
RAG_TEXT_SPLITTER,
|
||||
TIKTOKEN_ENCODING_NAME,
|
||||
@@ -533,6 +535,8 @@ app.state.config.ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION = (
|
||||
|
||||
app.state.config.CONTENT_EXTRACTION_ENGINE = CONTENT_EXTRACTION_ENGINE
|
||||
app.state.config.TIKA_SERVER_URL = TIKA_SERVER_URL
|
||||
app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = DOCUMENT_INTELLIGENCE_ENDPOINT
|
||||
app.state.config.DOCUMENT_INTELLIGENCE_KEY = DOCUMENT_INTELLIGENCE_KEY
|
||||
|
||||
app.state.config.TEXT_SPLITTER = RAG_TEXT_SPLITTER
|
||||
app.state.config.TIKTOKEN_ENCODING_NAME = TIKTOKEN_ENCODING_NAME
|
||||
|
||||
@@ -4,6 +4,7 @@ import ftfy
|
||||
import sys
|
||||
|
||||
from langchain_community.document_loaders import (
|
||||
AzureAIDocumentIntelligenceLoader,
|
||||
BSHTMLLoader,
|
||||
CSVLoader,
|
||||
Docx2txtLoader,
|
||||
@@ -147,6 +148,27 @@ class Loader:
|
||||
file_path=file_path,
|
||||
mime_type=file_content_type,
|
||||
)
|
||||
elif (
|
||||
self.engine == "document_intelligence"
|
||||
and self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT") != ""
|
||||
and self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY") != ""
|
||||
and (
|
||||
file_ext in ["pdf", "xls", "xlsx", "docx", "ppt", "pptx"]
|
||||
or file_content_type
|
||||
in [
|
||||
"application/vnd.ms-excel",
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
"application/vnd.ms-powerpoint",
|
||||
"application/vnd.openxmlformats-officedocument.presentationml.presentation",
|
||||
]
|
||||
)
|
||||
):
|
||||
loader = AzureAIDocumentIntelligenceLoader(
|
||||
file_path=file_path,
|
||||
api_endpoint=self.kwargs.get("DOCUMENT_INTELLIGENCE_ENDPOINT"),
|
||||
api_key=self.kwargs.get("DOCUMENT_INTELLIGENCE_KEY"),
|
||||
)
|
||||
else:
|
||||
if file_ext == "pdf":
|
||||
loader = PyPDFLoader(
|
||||
|
||||
@@ -680,7 +680,22 @@ def transcription(
|
||||
def get_available_models(request: Request) -> list[dict]:
|
||||
available_models = []
|
||||
if request.app.state.config.TTS_ENGINE == "openai":
|
||||
available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
|
||||
# Use custom endpoint if not using the official OpenAI API URL
|
||||
if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
|
||||
"https://api.openai.com"
|
||||
):
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/models"
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
available_models = data.get("models", [])
|
||||
except Exception as e:
|
||||
log.error(f"Error fetching models from custom endpoint: {str(e)}")
|
||||
available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
|
||||
else:
|
||||
available_models = [{"id": "tts-1"}, {"id": "tts-1-hd"}]
|
||||
elif request.app.state.config.TTS_ENGINE == "elevenlabs":
|
||||
try:
|
||||
response = requests.get(
|
||||
@@ -711,14 +726,37 @@ def get_available_voices(request) -> dict:
|
||||
"""Returns {voice_id: voice_name} dict"""
|
||||
available_voices = {}
|
||||
if request.app.state.config.TTS_ENGINE == "openai":
|
||||
available_voices = {
|
||||
"alloy": "alloy",
|
||||
"echo": "echo",
|
||||
"fable": "fable",
|
||||
"onyx": "onyx",
|
||||
"nova": "nova",
|
||||
"shimmer": "shimmer",
|
||||
}
|
||||
# Use custom endpoint if not using the official OpenAI API URL
|
||||
if not request.app.state.config.TTS_OPENAI_API_BASE_URL.startswith(
|
||||
"https://api.openai.com"
|
||||
):
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{request.app.state.config.TTS_OPENAI_API_BASE_URL}/audio/voices"
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
voices_list = data.get("voices", [])
|
||||
available_voices = {voice["id"]: voice["name"] for voice in voices_list}
|
||||
except Exception as e:
|
||||
log.error(f"Error fetching voices from custom endpoint: {str(e)}")
|
||||
available_voices = {
|
||||
"alloy": "alloy",
|
||||
"echo": "echo",
|
||||
"fable": "fable",
|
||||
"onyx": "onyx",
|
||||
"nova": "nova",
|
||||
"shimmer": "shimmer",
|
||||
}
|
||||
else:
|
||||
available_voices = {
|
||||
"alloy": "alloy",
|
||||
"echo": "echo",
|
||||
"fable": "fable",
|
||||
"onyx": "onyx",
|
||||
"nova": "nova",
|
||||
"shimmer": "shimmer",
|
||||
}
|
||||
elif request.app.state.config.TTS_ENGINE == "elevenlabs":
|
||||
try:
|
||||
available_voices = get_elevenlabs_voices(
|
||||
|
||||
@@ -252,14 +252,6 @@ async def ldap_auth(request: Request, response: Response, form_data: LdapForm):
|
||||
if not user:
|
||||
try:
|
||||
user_count = Users.get_num_users()
|
||||
if (
|
||||
request.app.state.USER_COUNT
|
||||
and user_count >= request.app.state.USER_COUNT
|
||||
):
|
||||
raise HTTPException(
|
||||
status.HTTP_403_FORBIDDEN,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
role = (
|
||||
"admin"
|
||||
@@ -439,11 +431,6 @@ async def signup(request: Request, response: Response, form_data: SignupForm):
|
||||
)
|
||||
|
||||
user_count = Users.get_num_users()
|
||||
if request.app.state.USER_COUNT and user_count >= request.app.state.USER_COUNT:
|
||||
raise HTTPException(
|
||||
status.HTTP_403_FORBIDDEN, detail=ERROR_MESSAGES.ACCESS_PROHIBITED
|
||||
)
|
||||
|
||||
if not validate_email_format(form_data.email.lower()):
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.INVALID_EMAIL_FORMAT
|
||||
|
||||
@@ -356,6 +356,10 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
|
||||
"content_extraction": {
|
||||
"engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
|
||||
"tika_server_url": request.app.state.config.TIKA_SERVER_URL,
|
||||
"document_intelligence_config": {
|
||||
"endpoint": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
|
||||
"key": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
|
||||
},
|
||||
},
|
||||
"chunk": {
|
||||
"text_splitter": request.app.state.config.TEXT_SPLITTER,
|
||||
@@ -399,6 +403,7 @@ async def get_rag_config(request: Request, user=Depends(get_admin_user)):
|
||||
"bing_search_v7_subscription_key": request.app.state.config.BING_SEARCH_V7_SUBSCRIPTION_KEY,
|
||||
"exa_api_key": request.app.state.config.EXA_API_KEY,
|
||||
"result_count": request.app.state.config.RAG_WEB_SEARCH_RESULT_COUNT,
|
||||
"trust_env": request.app.state.config.RAG_WEB_SEARCH_TRUST_ENV,
|
||||
"concurrent_requests": request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
|
||||
"domain_filter_list": request.app.state.config.RAG_WEB_SEARCH_DOMAIN_FILTER_LIST,
|
||||
},
|
||||
@@ -411,9 +416,15 @@ class FileConfig(BaseModel):
|
||||
max_count: Optional[int] = None
|
||||
|
||||
|
||||
class DocumentIntelligenceConfigForm(BaseModel):
|
||||
endpoint: str
|
||||
key: str
|
||||
|
||||
|
||||
class ContentExtractionConfig(BaseModel):
|
||||
engine: str = ""
|
||||
tika_server_url: Optional[str] = None
|
||||
document_intelligence_config: Optional[DocumentIntelligenceConfigForm] = None
|
||||
|
||||
|
||||
class ChunkParamUpdateForm(BaseModel):
|
||||
@@ -501,13 +512,22 @@ async def update_rag_config(
|
||||
request.app.state.config.FILE_MAX_COUNT = form_data.file.max_count
|
||||
|
||||
if form_data.content_extraction is not None:
|
||||
log.info(f"Updating text settings: {form_data.content_extraction}")
|
||||
log.info(
|
||||
f"Updating content extraction: {request.app.state.config.CONTENT_EXTRACTION_ENGINE} to {form_data.content_extraction.engine}"
|
||||
)
|
||||
request.app.state.config.CONTENT_EXTRACTION_ENGINE = (
|
||||
form_data.content_extraction.engine
|
||||
)
|
||||
request.app.state.config.TIKA_SERVER_URL = (
|
||||
form_data.content_extraction.tika_server_url
|
||||
)
|
||||
if form_data.content_extraction.document_intelligence_config is not None:
|
||||
request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT = (
|
||||
form_data.content_extraction.document_intelligence_config.endpoint
|
||||
)
|
||||
request.app.state.config.DOCUMENT_INTELLIGENCE_KEY = (
|
||||
form_data.content_extraction.document_intelligence_config.key
|
||||
)
|
||||
|
||||
if form_data.chunk is not None:
|
||||
request.app.state.config.TEXT_SPLITTER = form_data.chunk.text_splitter
|
||||
@@ -604,6 +624,10 @@ async def update_rag_config(
|
||||
"content_extraction": {
|
||||
"engine": request.app.state.config.CONTENT_EXTRACTION_ENGINE,
|
||||
"tika_server_url": request.app.state.config.TIKA_SERVER_URL,
|
||||
"document_intelligence_config": {
|
||||
"endpoint": request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
|
||||
"key": request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
|
||||
},
|
||||
},
|
||||
"chunk": {
|
||||
"text_splitter": request.app.state.config.TEXT_SPLITTER,
|
||||
@@ -937,6 +961,8 @@ def process_file(
|
||||
engine=request.app.state.config.CONTENT_EXTRACTION_ENGINE,
|
||||
TIKA_SERVER_URL=request.app.state.config.TIKA_SERVER_URL,
|
||||
PDF_EXTRACT_IMAGES=request.app.state.config.PDF_EXTRACT_IMAGES,
|
||||
DOCUMENT_INTELLIGENCE_ENDPOINT=request.app.state.config.DOCUMENT_INTELLIGENCE_ENDPOINT,
|
||||
DOCUMENT_INTELLIGENCE_KEY=request.app.state.config.DOCUMENT_INTELLIGENCE_KEY,
|
||||
)
|
||||
docs = loader.load(
|
||||
file.filename, file.meta.get("content_type"), file_path
|
||||
|
||||
@@ -20,6 +20,10 @@ from open_webui.utils.auth import get_admin_user, get_verified_user
|
||||
from open_webui.constants import TASKS
|
||||
|
||||
from open_webui.routers.pipelines import process_pipeline_inlet_filter
|
||||
from open_webui.utils.filter import (
|
||||
get_sorted_filter_ids,
|
||||
process_filter_functions,
|
||||
)
|
||||
from open_webui.utils.task import get_task_model_id
|
||||
|
||||
from open_webui.config import (
|
||||
@@ -221,6 +225,12 @@ async def generate_title(
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
@@ -290,6 +300,12 @@ async def generate_chat_tags(
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
@@ -356,6 +372,12 @@ async def generate_image_prompt(
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
@@ -433,6 +455,12 @@ async def generate_queries(
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
@@ -514,6 +542,12 @@ async def generate_autocompletion(
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
@@ -584,6 +618,12 @@ async def generate_emoji(
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
@@ -644,6 +684,12 @@ async def generate_moa_response(
|
||||
},
|
||||
}
|
||||
|
||||
# Process the payload through the pipeline
|
||||
try:
|
||||
payload = await process_pipeline_inlet_filter(request, payload, user, models)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
try:
|
||||
return await generate_chat_completion(request, form_data=payload, user=user)
|
||||
except Exception as e:
|
||||
|
||||
@@ -1127,12 +1127,12 @@ async def process_chat_response(
|
||||
|
||||
if reasoning_duration is not None:
|
||||
if raw:
|
||||
content = f'{content}\n<{block["tag"]}>{block["content"]}</{block["tag"]}>\n'
|
||||
content = f'{content}\n<{block["start_tag"]}>{block["content"]}<{block["end_tag"]}>\n'
|
||||
else:
|
||||
content = f'{content}\n<details type="reasoning" done="true" duration="{reasoning_duration}">\n<summary>Thought for {reasoning_duration} seconds</summary>\n{reasoning_display_content}\n</details>\n'
|
||||
else:
|
||||
if raw:
|
||||
content = f'{content}\n<{block["tag"]}>{block["content"]}</{block["tag"]}>\n'
|
||||
content = f'{content}\n<{block["start_tag"]}>{block["content"]}<{block["end_tag"]}>\n'
|
||||
else:
|
||||
content = f'{content}\n<details type="reasoning" done="false">\n<summary>Thinking…</summary>\n{reasoning_display_content}\n</details>\n'
|
||||
|
||||
@@ -1228,9 +1228,9 @@ async def process_chat_response(
|
||||
return attributes
|
||||
|
||||
if content_blocks[-1]["type"] == "text":
|
||||
for tag in tags:
|
||||
for start_tag, end_tag in tags:
|
||||
# Match start tag e.g., <tag> or <tag attr="value">
|
||||
start_tag_pattern = rf"<{tag}(\s.*?)?>"
|
||||
start_tag_pattern = rf"<{re.escape(start_tag)}(\s.*?)?>"
|
||||
match = re.search(start_tag_pattern, content)
|
||||
if match:
|
||||
attr_content = (
|
||||
@@ -1263,7 +1263,8 @@ async def process_chat_response(
|
||||
content_blocks.append(
|
||||
{
|
||||
"type": content_type,
|
||||
"tag": tag,
|
||||
"start_tag": start_tag,
|
||||
"end_tag": end_tag,
|
||||
"attributes": attributes,
|
||||
"content": "",
|
||||
"started_at": time.time(),
|
||||
@@ -1275,9 +1276,10 @@ async def process_chat_response(
|
||||
|
||||
break
|
||||
elif content_blocks[-1]["type"] == content_type:
|
||||
tag = content_blocks[-1]["tag"]
|
||||
start_tag = content_blocks[-1]["start_tag"]
|
||||
end_tag = content_blocks[-1]["end_tag"]
|
||||
# Match end tag e.g., </tag>
|
||||
end_tag_pattern = rf"</{tag}>"
|
||||
end_tag_pattern = rf"<{re.escape(end_tag)}>"
|
||||
|
||||
# Check if the content has the end tag
|
||||
if re.search(end_tag_pattern, content):
|
||||
@@ -1285,7 +1287,7 @@ async def process_chat_response(
|
||||
|
||||
block_content = content_blocks[-1]["content"]
|
||||
# Strip start and end tags from the content
|
||||
start_tag_pattern = rf"<{tag}(.*?)>"
|
||||
start_tag_pattern = rf"<{re.escape(start_tag)}(.*?)>"
|
||||
block_content = re.sub(
|
||||
start_tag_pattern, "", block_content
|
||||
).strip()
|
||||
@@ -1350,7 +1352,7 @@ async def process_chat_response(
|
||||
|
||||
# Clean processed content
|
||||
content = re.sub(
|
||||
rf"<{tag}(.*?)>(.|\n)*?</{tag}>",
|
||||
rf"<{re.escape(start_tag)}(.*?)>(.|\n)*?<{re.escape(end_tag)}>",
|
||||
"",
|
||||
content,
|
||||
flags=re.DOTALL,
|
||||
@@ -1388,19 +1390,28 @@ async def process_chat_response(
|
||||
|
||||
# We might want to disable this by default
|
||||
DETECT_REASONING = True
|
||||
DETECT_SOLUTION = True
|
||||
DETECT_CODE_INTERPRETER = metadata.get("features", {}).get(
|
||||
"code_interpreter", False
|
||||
)
|
||||
|
||||
reasoning_tags = [
|
||||
"think",
|
||||
"thinking",
|
||||
"reason",
|
||||
"reasoning",
|
||||
"thought",
|
||||
"Thought",
|
||||
("think", "/think"),
|
||||
("thinking", "/thinking"),
|
||||
("reason", "/reason"),
|
||||
("reasoning", "/reasoning"),
|
||||
("thought", "/thought"),
|
||||
("Thought", "/Thought"),
|
||||
("|begin_of_thought|", "|end_of_thought|")
|
||||
]
|
||||
|
||||
code_interpreter_tags = [
|
||||
("code_interpreter", "/code_interpreter")
|
||||
]
|
||||
|
||||
solution_tags = [
|
||||
("|begin_of_solution|", "|end_of_solution|")
|
||||
]
|
||||
code_interpreter_tags = ["code_interpreter"]
|
||||
|
||||
try:
|
||||
for event in events:
|
||||
@@ -1533,6 +1544,16 @@ async def process_chat_response(
|
||||
if end:
|
||||
break
|
||||
|
||||
if DETECT_SOLUTION:
|
||||
content, content_blocks, _ = (
|
||||
tag_content_handler(
|
||||
"solution",
|
||||
solution_tags,
|
||||
content,
|
||||
content_blocks,
|
||||
)
|
||||
)
|
||||
|
||||
if ENABLE_REALTIME_CHAT_SAVE:
|
||||
# Save message in the database
|
||||
Chats.upsert_message_to_chat_by_id_and_message_id(
|
||||
|
||||
@@ -146,7 +146,7 @@ class OAuthManager:
|
||||
nested_claims = oauth_claim.split(".")
|
||||
for nested_claim in nested_claims:
|
||||
claim_data = claim_data.get(nested_claim, {})
|
||||
user_oauth_groups = claim_data if isinstance(claim_data, list) else None
|
||||
user_oauth_groups = claim_data if isinstance(claim_data, list) else []
|
||||
|
||||
user_current_groups: list[GroupModel] = Groups.get_groups_by_member_id(user.id)
|
||||
all_available_groups: list[GroupModel] = Groups.get_groups()
|
||||
@@ -315,15 +315,6 @@ class OAuthManager:
|
||||
if not user:
|
||||
user_count = Users.get_num_users()
|
||||
|
||||
if (
|
||||
request.app.state.USER_COUNT
|
||||
and user_count >= request.app.state.USER_COUNT
|
||||
):
|
||||
raise HTTPException(
|
||||
403,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
# If the user does not exist, check if signups are enabled
|
||||
if auth_manager_config.ENABLE_OAUTH_SIGNUP:
|
||||
# Check if an existing user with the same email already exists
|
||||
|
||||
@@ -124,7 +124,7 @@ def convert_messages_openai_to_ollama(messages: list[dict]) -> list[dict]:
|
||||
tool_call_id = message.get("tool_call_id", None)
|
||||
|
||||
# Check if the content is a string (just a simple message)
|
||||
if isinstance(content, str):
|
||||
if isinstance(content, str) and not tool_calls:
|
||||
# If the content is a string, it's pure text
|
||||
new_message["content"] = content
|
||||
|
||||
@@ -230,6 +230,12 @@ def convert_payload_openai_to_ollama(openai_payload: dict) -> dict:
|
||||
"system"
|
||||
] # To prevent Ollama warning of invalid option provided
|
||||
|
||||
# If there is the "stop" parameter in the openai_payload, remap it to the ollama_payload.options
|
||||
if "stop" in openai_payload:
|
||||
ollama_options = ollama_payload.get("options", {})
|
||||
ollama_options["stop"] = openai_payload.get("stop")
|
||||
ollama_payload["options"] = ollama_options
|
||||
|
||||
if "metadata" in openai_payload:
|
||||
ollama_payload["metadata"] = openai_payload["metadata"]
|
||||
|
||||
|
||||
@@ -104,7 +104,7 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
|
||||
data = json.loads(data)
|
||||
|
||||
model = data.get("model", "ollama")
|
||||
message_content = data.get("message", {}).get("content", "")
|
||||
message_content = data.get("message", {}).get("content", None)
|
||||
tool_calls = data.get("message", {}).get("tool_calls", None)
|
||||
openai_tool_calls = None
|
||||
|
||||
@@ -118,7 +118,7 @@ async def convert_streaming_response_ollama_to_openai(ollama_streaming_response)
|
||||
usage = convert_ollama_usage_to_openai(data)
|
||||
|
||||
data = openai_chat_chunk_message_template(
|
||||
model, message_content if not done else None, openai_tool_calls, usage
|
||||
model, message_content, openai_tool_calls, usage
|
||||
)
|
||||
|
||||
line = f"data: {json.dumps(data)}\n\n"
|
||||
|
||||
Reference in New Issue
Block a user