mirror of
https://github.com/open-webui/open-webui
synced 2024-12-28 14:52:23 +00:00
wip
This commit is contained in:
parent
fe5519e0a2
commit
a07ff56c50
@ -1009,9 +1009,12 @@ async def get_body_and_model_and_user(request, models):
|
||||
|
||||
class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
if not request.method == "POST" and any(
|
||||
endpoint in request.url.path
|
||||
for endpoint in ["/ollama/api/chat", "/chat/completions"]
|
||||
if not (
|
||||
request.method == "POST"
|
||||
and any(
|
||||
endpoint in request.url.path
|
||||
for endpoint in ["/ollama/api/chat", "/chat/completions"]
|
||||
)
|
||||
):
|
||||
return await call_next(request)
|
||||
log.debug(f"request.url.path: {request.url.path}")
|
||||
@ -1214,9 +1217,12 @@ app.add_middleware(ChatCompletionMiddleware)
|
||||
|
||||
class PipelineMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
if not request.method == "POST" and any(
|
||||
endpoint in request.url.path
|
||||
for endpoint in ["/ollama/api/chat", "/chat/completions"]
|
||||
if not (
|
||||
request.method == "POST"
|
||||
and any(
|
||||
endpoint in request.url.path
|
||||
for endpoint in ["/ollama/api/chat", "/chat/completions"]
|
||||
)
|
||||
):
|
||||
return await call_next(request)
|
||||
|
||||
@ -1664,17 +1670,17 @@ async def generate_function_chat_completion(form_data, user, models: dict = {}):
|
||||
return openai_chat_completion_message_template(form_data["model"], message)
|
||||
|
||||
|
||||
async def get_all_base_models():
|
||||
async def get_all_base_models(request):
|
||||
function_models = []
|
||||
openai_models = []
|
||||
ollama_models = []
|
||||
|
||||
if app.state.config.ENABLE_OPENAI_API:
|
||||
openai_models = await openai.get_all_models()
|
||||
openai_models = await openai.get_all_models(request)
|
||||
openai_models = openai_models["data"]
|
||||
|
||||
if app.state.config.ENABLE_OLLAMA_API:
|
||||
ollama_models = await ollama.get_all_models()
|
||||
ollama_models = await ollama.get_all_models(request)
|
||||
ollama_models = [
|
||||
{
|
||||
"id": model["model"],
|
||||
@ -1729,8 +1735,8 @@ async def get_all_base_models():
|
||||
|
||||
|
||||
@cached(ttl=3)
|
||||
async def get_all_models():
|
||||
models = await get_all_base_models()
|
||||
async def get_all_models(request):
|
||||
models = await get_all_base_models(request)
|
||||
|
||||
# If there are no models, return an empty list
|
||||
if len([model for model in models if not model.get("arena", False)]) == 0:
|
||||
@ -1859,8 +1865,8 @@ async def get_all_models():
|
||||
|
||||
|
||||
@app.get("/api/models")
|
||||
async def get_models(user=Depends(get_verified_user)):
|
||||
models = await get_all_models()
|
||||
async def get_models(request: Request, user=Depends(get_verified_user)):
|
||||
models = await get_all_models(request)
|
||||
|
||||
# Filter out filter pipelines
|
||||
models = [
|
||||
@ -2042,7 +2048,7 @@ async def generate_chat_completions(
|
||||
async def chat_completed(
|
||||
request: Request, form_data: dict, user=Depends(get_verified_user)
|
||||
):
|
||||
model_list = await get_all_models()
|
||||
model_list = await get_all_models(request)
|
||||
models = {model["id"]: model for model in model_list}
|
||||
|
||||
data = form_data
|
||||
|
@ -245,41 +245,6 @@ async def speech(request: Request, user=Depends(get_verified_user)):
|
||||
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
|
||||
|
||||
|
||||
def merge_models_lists(model_lists):
|
||||
log.debug(f"merge_models_lists {model_lists}")
|
||||
merged_list = []
|
||||
|
||||
for idx, models in enumerate(model_lists):
|
||||
if models is not None and "error" not in models:
|
||||
merged_list.extend(
|
||||
[
|
||||
{
|
||||
**model,
|
||||
"name": model.get("name", model["id"]),
|
||||
"owned_by": "openai",
|
||||
"openai": model,
|
||||
"urlIdx": idx,
|
||||
}
|
||||
for model in models
|
||||
if "api.openai.com"
|
||||
not in request.app.state.config.OPENAI_API_BASE_URLS[idx]
|
||||
or not any(
|
||||
name in model["id"]
|
||||
for name in [
|
||||
"babbage",
|
||||
"dall-e",
|
||||
"davinci",
|
||||
"embedding",
|
||||
"tts",
|
||||
"whisper",
|
||||
]
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
return merged_list
|
||||
|
||||
|
||||
async def get_all_models_responses(request: Request) -> list:
|
||||
if not request.app.state.config.ENABLE_OPENAI_API:
|
||||
return []
|
||||
@ -379,7 +344,7 @@ async def get_all_models(request: Request) -> dict[str, list]:
|
||||
if not request.app.state.config.ENABLE_OPENAI_API:
|
||||
return {"data": []}
|
||||
|
||||
responses = await get_all_models_responses()
|
||||
responses = await get_all_models_responses(request)
|
||||
|
||||
def extract_data(response):
|
||||
if response and "data" in response:
|
||||
@ -388,6 +353,40 @@ async def get_all_models(request: Request) -> dict[str, list]:
|
||||
return response
|
||||
return None
|
||||
|
||||
def merge_models_lists(model_lists):
|
||||
log.debug(f"merge_models_lists {model_lists}")
|
||||
merged_list = []
|
||||
|
||||
for idx, models in enumerate(model_lists):
|
||||
if models is not None and "error" not in models:
|
||||
merged_list.extend(
|
||||
[
|
||||
{
|
||||
**model,
|
||||
"name": model.get("name", model["id"]),
|
||||
"owned_by": "openai",
|
||||
"openai": model,
|
||||
"urlIdx": idx,
|
||||
}
|
||||
for model in models
|
||||
if "api.openai.com"
|
||||
not in request.app.state.config.OPENAI_API_BASE_URLS[idx]
|
||||
or not any(
|
||||
name in model["id"]
|
||||
for name in [
|
||||
"babbage",
|
||||
"dall-e",
|
||||
"davinci",
|
||||
"embedding",
|
||||
"tts",
|
||||
"whisper",
|
||||
]
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
return merged_list
|
||||
|
||||
models = {"data": merge_models_lists(map(extract_data, responses))}
|
||||
log.debug(f"models: {models}")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user