This commit is contained in:
Timothy Jaeryang Baek
2024-12-11 20:39:55 -08:00
parent d9ffcea764
commit 866c3dff11
3 changed files with 36 additions and 22 deletions

View File

@@ -344,7 +344,7 @@ async def get_ollama_tags(
models = []
if url_idx is None:
models = await get_all_models()
models = await get_all_models(request)
else:
url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
@@ -565,7 +565,7 @@ async def copy_model(
user=Depends(get_admin_user),
):
if url_idx is None:
await get_all_models()
await get_all_models(request)
models = request.app.state.OLLAMA_MODELS
if form_data.source in models:
@@ -620,7 +620,7 @@ async def delete_model(
user=Depends(get_admin_user),
):
if url_idx is None:
await get_all_models()
await get_all_models(request)
models = request.app.state.OLLAMA_MODELS
if form_data.name in models:
@@ -670,7 +670,7 @@ async def delete_model(
async def show_model_info(
request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
):
await get_all_models()
await get_all_models(request)
models = request.app.state.OLLAMA_MODELS
if form_data.name not in models:
@@ -734,7 +734,7 @@ async def embed(
log.info(f"generate_ollama_batch_embeddings {form_data}")
if url_idx is None:
await get_all_models()
await get_all_models(request)
models = request.app.state.OLLAMA_MODELS
model = form_data.model
@@ -803,7 +803,7 @@ async def embeddings(
log.info(f"generate_ollama_embeddings {form_data}")
if url_idx is None:
await get_all_models()
await get_all_models(request)
models = request.app.state.OLLAMA_MODELS
model = form_data.model
@@ -878,8 +878,8 @@ async def generate_completion(
user=Depends(get_verified_user),
):
if url_idx is None:
model_list = await get_all_models()
models = {model["model"]: model for model in model_list["models"]}
await get_all_models(request)
models = request.app.state.OLLAMA_MODELS
model = form_data.model
@@ -1200,7 +1200,7 @@ async def get_openai_models(
models = []
if url_idx is None:
model_list = await get_all_models()
model_list = await get_all_models(request)
models = [
{
"id": model["model"],

View File

@@ -404,7 +404,7 @@ async def get_models(
}
if url_idx is None:
models = await get_all_models()
models = await get_all_models(request)
else:
url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
key = request.app.state.config.OPENAI_API_KEYS[url_idx]