This commit is contained in:
Timothy Jaeryang Baek 2024-12-11 20:15:23 -08:00
parent fe5519e0a2
commit a07ff56c50
2 changed files with 55 additions and 50 deletions

View File

@ -1009,9 +1009,12 @@ async def get_body_and_model_and_user(request, models):
class ChatCompletionMiddleware(BaseHTTPMiddleware): class ChatCompletionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next): async def dispatch(self, request: Request, call_next):
if not request.method == "POST" and any( if not (
endpoint in request.url.path request.method == "POST"
for endpoint in ["/ollama/api/chat", "/chat/completions"] and any(
endpoint in request.url.path
for endpoint in ["/ollama/api/chat", "/chat/completions"]
)
): ):
return await call_next(request) return await call_next(request)
log.debug(f"request.url.path: {request.url.path}") log.debug(f"request.url.path: {request.url.path}")
@ -1214,9 +1217,12 @@ app.add_middleware(ChatCompletionMiddleware)
class PipelineMiddleware(BaseHTTPMiddleware): class PipelineMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next): async def dispatch(self, request: Request, call_next):
if not request.method == "POST" and any( if not (
endpoint in request.url.path request.method == "POST"
for endpoint in ["/ollama/api/chat", "/chat/completions"] and any(
endpoint in request.url.path
for endpoint in ["/ollama/api/chat", "/chat/completions"]
)
): ):
return await call_next(request) return await call_next(request)
@ -1664,17 +1670,17 @@ async def generate_function_chat_completion(form_data, user, models: dict = {}):
return openai_chat_completion_message_template(form_data["model"], message) return openai_chat_completion_message_template(form_data["model"], message)
async def get_all_base_models(): async def get_all_base_models(request):
function_models = [] function_models = []
openai_models = [] openai_models = []
ollama_models = [] ollama_models = []
if app.state.config.ENABLE_OPENAI_API: if app.state.config.ENABLE_OPENAI_API:
openai_models = await openai.get_all_models() openai_models = await openai.get_all_models(request)
openai_models = openai_models["data"] openai_models = openai_models["data"]
if app.state.config.ENABLE_OLLAMA_API: if app.state.config.ENABLE_OLLAMA_API:
ollama_models = await ollama.get_all_models() ollama_models = await ollama.get_all_models(request)
ollama_models = [ ollama_models = [
{ {
"id": model["model"], "id": model["model"],
@ -1729,8 +1735,8 @@ async def get_all_base_models():
@cached(ttl=3) @cached(ttl=3)
async def get_all_models(): async def get_all_models(request):
models = await get_all_base_models() models = await get_all_base_models(request)
# If there are no models, return an empty list # If there are no models, return an empty list
if len([model for model in models if not model.get("arena", False)]) == 0: if len([model for model in models if not model.get("arena", False)]) == 0:
@ -1859,8 +1865,8 @@ async def get_all_models():
@app.get("/api/models") @app.get("/api/models")
async def get_models(user=Depends(get_verified_user)): async def get_models(request: Request, user=Depends(get_verified_user)):
models = await get_all_models() models = await get_all_models(request)
# Filter out filter pipelines # Filter out filter pipelines
models = [ models = [
@ -2042,7 +2048,7 @@ async def generate_chat_completions(
async def chat_completed( async def chat_completed(
request: Request, form_data: dict, user=Depends(get_verified_user) request: Request, form_data: dict, user=Depends(get_verified_user)
): ):
model_list = await get_all_models() model_list = await get_all_models(request)
models = {model["id"]: model for model in model_list} models = {model["id"]: model for model in model_list}
data = form_data data = form_data

View File

@ -245,41 +245,6 @@ async def speech(request: Request, user=Depends(get_verified_user)):
raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND) raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
def merge_models_lists(model_lists):
log.debug(f"merge_models_lists {model_lists}")
merged_list = []
for idx, models in enumerate(model_lists):
if models is not None and "error" not in models:
merged_list.extend(
[
{
**model,
"name": model.get("name", model["id"]),
"owned_by": "openai",
"openai": model,
"urlIdx": idx,
}
for model in models
if "api.openai.com"
not in request.app.state.config.OPENAI_API_BASE_URLS[idx]
or not any(
name in model["id"]
for name in [
"babbage",
"dall-e",
"davinci",
"embedding",
"tts",
"whisper",
]
)
]
)
return merged_list
async def get_all_models_responses(request: Request) -> list: async def get_all_models_responses(request: Request) -> list:
if not request.app.state.config.ENABLE_OPENAI_API: if not request.app.state.config.ENABLE_OPENAI_API:
return [] return []
@ -379,7 +344,7 @@ async def get_all_models(request: Request) -> dict[str, list]:
if not request.app.state.config.ENABLE_OPENAI_API: if not request.app.state.config.ENABLE_OPENAI_API:
return {"data": []} return {"data": []}
responses = await get_all_models_responses() responses = await get_all_models_responses(request)
def extract_data(response): def extract_data(response):
if response and "data" in response: if response and "data" in response:
@ -388,6 +353,40 @@ async def get_all_models(request: Request) -> dict[str, list]:
return response return response
return None return None
def merge_models_lists(model_lists):
log.debug(f"merge_models_lists {model_lists}")
merged_list = []
for idx, models in enumerate(model_lists):
if models is not None and "error" not in models:
merged_list.extend(
[
{
**model,
"name": model.get("name", model["id"]),
"owned_by": "openai",
"openai": model,
"urlIdx": idx,
}
for model in models
if "api.openai.com"
not in request.app.state.config.OPENAI_API_BASE_URLS[idx]
or not any(
name in model["id"]
for name in [
"babbage",
"dall-e",
"davinci",
"embedding",
"tts",
"whisper",
]
)
]
)
return merged_list
models = {"data": merge_models_lists(map(extract_data, responses))} models = {"data": merge_models_lists(map(extract_data, responses))}
log.debug(f"models: {models}") log.debug(f"models: {models}")