commit a07ff56c50
parent fe5519e0a2

    wip
@@ -1009,9 +1009,12 @@ async def get_body_and_model_and_user(request, models):
 
 class ChatCompletionMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
-        if not request.method == "POST" and any(
-            endpoint in request.url.path
-            for endpoint in ["/ollama/api/chat", "/chat/completions"]
+        if not (
+            request.method == "POST"
+            and any(
+                endpoint in request.url.path
+                for endpoint in ["/ollama/api/chat", "/chat/completions"]
+            )
         ):
             return await call_next(request)
         log.debug(f"request.url.path: {request.url.path}")
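
Note on the condition rewrite: in Python a comparison binds tighter than "not", and "not" binds tighter than "and", so the old guard parsed as "(not request.method == 'POST') and any(...)". The middleware body therefore ran for every POST, chat-related or not, and the early return only fired for non-POST requests to chat paths. A minimal sketch of the difference (hypothetical helper names, not part of the codebase):

    def old_guard(method: str, path: str) -> bool:
        """Old condition: True means 'skip the middleware'."""
        return not method == "POST" and any(
            endpoint in path for endpoint in ["/ollama/api/chat", "/chat/completions"]
        )

    def new_guard(method: str, path: str) -> bool:
        """Fixed condition: skip unless it is a POST to a chat endpoint."""
        return not (
            method == "POST"
            and any(
                endpoint in path
                for endpoint in ["/ollama/api/chat", "/chat/completions"]
            )
        )

    # A POST to an unrelated endpoint: the old guard keeps it inside the
    # middleware body; the fixed guard lets it pass straight through.
    assert old_guard("POST", "/api/v1/files") is False
    assert new_guard("POST", "/api/v1/files") is True

The same fix is applied to PipelineMiddleware in the next hunk.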
@@ -1214,9 +1217,12 @@ app.add_middleware(ChatCompletionMiddleware)
 
 class PipelineMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
-        if not request.method == "POST" and any(
-            endpoint in request.url.path
-            for endpoint in ["/ollama/api/chat", "/chat/completions"]
+        if not (
+            request.method == "POST"
+            and any(
+                endpoint in request.url.path
+                for endpoint in ["/ollama/api/chat", "/chat/completions"]
+            )
         ):
             return await call_next(request)
 
@@ -1664,17 +1670,17 @@ async def generate_function_chat_completion(form_data, user, models: dict = {}):
     return openai_chat_completion_message_template(form_data["model"], message)
 
 
-async def get_all_base_models():
+async def get_all_base_models(request):
     function_models = []
     openai_models = []
     ollama_models = []
 
     if app.state.config.ENABLE_OPENAI_API:
-        openai_models = await openai.get_all_models()
+        openai_models = await openai.get_all_models(request)
        openai_models = openai_models["data"]
 
     if app.state.config.ENABLE_OLLAMA_API:
-        ollama_models = await ollama.get_all_models()
+        ollama_models = await ollama.get_all_models(request)
         ollama_models = [
             {
                 "id": model["model"],
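
The remaining hunks thread the incoming Request down the model-listing call chain so the openai/ollama helpers can read configuration from request.app.state rather than from module-level globals. A minimal sketch of the pattern, with hypothetical names:

    from fastapi import FastAPI, Request

    app = FastAPI()
    app.state.config = {"ENABLE_OPENAI_API": True}

    async def fetch_openai_models(request: Request) -> list[dict]:
        # The helper reads per-app config off the Request it was handed,
        # instead of reaching for a global `app` object.
        if not request.app.state.config["ENABLE_OPENAI_API"]:
            return []
        return [{"id": "gpt-4o"}]  # stand-in for a real upstream fetch

    @app.get("/api/models")
    async def list_models(request: Request):
        return await fetch_openai_models(request)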
@@ -1729,8 +1735,8 @@ async def get_all_base_models():
 
 
 @cached(ttl=3)
-async def get_all_models():
-    models = await get_all_base_models()
+async def get_all_models(request):
+    models = await get_all_base_models(request)
 
     # If there are no models, return an empty list
     if len([model for model in models if not model.get("arena", False)]) == 0:
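
One caveat with this hunk, assuming @cached here is aiocache's decorator: by default the cache key is derived from the call arguments, and a fresh Request object stringifies differently on every call (its repr includes the object address), so adding the request parameter can silently turn the 3-second cache into a miss every time. A key_builder that ignores the argument would keep the TTL effective (sketch, not part of this commit):

    from aiocache import cached

    # key_builder receives (func, *args, **kwargs); returning a constant
    # string makes every call share one cache entry for the full TTL.
    @cached(ttl=3, key_builder=lambda func, request: "get_all_models")
    async def get_all_models(request):
        ...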
@@ -1859,8 +1865,8 @@ async def get_all_models():
 
 
 @app.get("/api/models")
-async def get_models(user=Depends(get_verified_user)):
-    models = await get_all_models()
+async def get_models(request: Request, user=Depends(get_verified_user)):
+    models = await get_all_models(request)
 
     # Filter out filter pipelines
     models = [
@@ -2042,7 +2048,7 @@ async def generate_chat_completions(
 async def chat_completed(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
-    model_list = await get_all_models()
+    model_list = await get_all_models(request)
     models = {model["id"]: model for model in model_list}
 
     data = form_data
@@ -245,41 +245,6 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
 
 
-def merge_models_lists(model_lists):
-    log.debug(f"merge_models_lists {model_lists}")
-    merged_list = []
-
-    for idx, models in enumerate(model_lists):
-        if models is not None and "error" not in models:
-            merged_list.extend(
-                [
-                    {
-                        **model,
-                        "name": model.get("name", model["id"]),
-                        "owned_by": "openai",
-                        "openai": model,
-                        "urlIdx": idx,
-                    }
-                    for model in models
-                    if "api.openai.com"
-                    not in request.app.state.config.OPENAI_API_BASE_URLS[idx]
-                    or not any(
-                        name in model["id"]
-                        for name in [
-                            "babbage",
-                            "dall-e",
-                            "davinci",
-                            "embedding",
-                            "tts",
-                            "whisper",
-                        ]
-                    )
-                ]
-            )
-
-    return merged_list
-
-
 async def get_all_models_responses(request: Request) -> list:
     if not request.app.state.config.ENABLE_OPENAI_API:
         return []
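
The function removed here is not dropped but relocated (see the last hunk below). Its filter keeps a model when it is not served from api.openai.com, or when its id matches none of the known non-chat model families. A standalone sketch of that predicate:

    NON_CHAT = ["babbage", "dall-e", "davinci", "embedding", "tts", "whisper"]

    def keep(model_id: str, base_url: str) -> bool:
        # Keep everything from self-hosted/proxy endpoints; for api.openai.com,
        # drop ids that belong to a non-chat family.
        return "api.openai.com" not in base_url or not any(
            name in model_id for name in NON_CHAT
        )

    assert keep("gpt-4o", "https://api.openai.com/v1")
    assert not keep("text-embedding-3-small", "https://api.openai.com/v1")
    assert keep("text-embedding-3-small", "http://localhost:11434/v1")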
@@ -379,7 +344,7 @@ async def get_all_models(request: Request) -> dict[str, list]:
     if not request.app.state.config.ENABLE_OPENAI_API:
         return {"data": []}
 
-    responses = await get_all_models_responses()
+    responses = await get_all_models_responses(request)
 
     def extract_data(response):
         if response and "data" in response:
@@ -388,6 +353,40 @@ async def get_all_models(request: Request) -> dict[str, list]:
                 return response
         return None
 
+    def merge_models_lists(model_lists):
+        log.debug(f"merge_models_lists {model_lists}")
+        merged_list = []
+
+        for idx, models in enumerate(model_lists):
+            if models is not None and "error" not in models:
+                merged_list.extend(
+                    [
+                        {
+                            **model,
+                            "name": model.get("name", model["id"]),
+                            "owned_by": "openai",
+                            "openai": model,
+                            "urlIdx": idx,
+                        }
+                        for model in models
+                        if "api.openai.com"
+                        not in request.app.state.config.OPENAI_API_BASE_URLS[idx]
+                        or not any(
+                            name in model["id"]
+                            for name in [
+                                "babbage",
+                                "dall-e",
+                                "davinci",
+                                "embedding",
+                                "tts",
+                                "whisper",
+                            ]
+                        )
+                    ]
+                )
+
+        return merged_list
+
     models = {"data": merge_models_lists(map(extract_data, responses))}
     log.debug(f"models: {models}")
 
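
The relocation also fixes a latent scoping problem visible in the removed hunk: the old module-level merge_models_lists referenced `request` without taking it as a parameter, so it could only work if a `request` name happened to be visible at call time. Nested inside get_all_models(request), the helper now captures `request` as a closure variable. A minimal sketch of the pattern:

    async def get_all_models(request):
        def merge_models_lists(model_lists):
            # `request` is captured from the enclosing function's scope,
            # so no extra parameter (or global) is needed.
            base_urls = request.app.state.config.OPENAI_API_BASE_URLS
            ...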