This commit is contained in:
Michael Poluektov 2024-08-06 10:15:29 +01:00
parent a140d319fe
commit 831fe9f509
1 changed files with 24 additions and 39 deletions

View File

@ -1,27 +1,21 @@
from fastapi import ( from fastapi import (
FastAPI, FastAPI,
Request, Request,
Response,
HTTPException, HTTPException,
Depends, Depends,
status,
UploadFile, UploadFile,
File, File,
BackgroundTasks,
) )
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
import os import os
import re import re
import copy
import random import random
import requests import requests
import json import json
import uuid
import aiohttp import aiohttp
import asyncio import asyncio
import logging import logging
@ -32,11 +26,8 @@ from typing import Optional, List, Union
from starlette.background import BackgroundTask from starlette.background import BackgroundTask
from apps.webui.models.models import Models from apps.webui.models.models import Models
from apps.webui.models.users import Users
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from utils.utils import ( from utils.utils import (
decode_token,
get_current_user,
get_verified_user, get_verified_user,
get_admin_user, get_admin_user,
) )
@ -183,7 +174,7 @@ async def post_streaming_url(url: str, payload: str, stream: bool = True):
res = await r.json() res = await r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" error_detail = f"Ollama: {res['error']}"
except: except Exception:
error_detail = f"Ollama: {e}" error_detail = f"Ollama: {e}"
raise HTTPException( raise HTTPException(
@ -238,7 +229,7 @@ async def get_all_models():
async def get_ollama_tags( async def get_ollama_tags(
url_idx: Optional[int] = None, user=Depends(get_verified_user) url_idx: Optional[int] = None, user=Depends(get_verified_user)
): ):
if url_idx == None: if url_idx is None:
models = await get_all_models() models = await get_all_models()
if app.state.config.ENABLE_MODEL_FILTER: if app.state.config.ENABLE_MODEL_FILTER:
@ -269,7 +260,7 @@ async def get_ollama_tags(
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" error_detail = f"Ollama: {res['error']}"
except: except Exception:
error_detail = f"Ollama: {e}" error_detail = f"Ollama: {e}"
raise HTTPException( raise HTTPException(
@ -282,8 +273,7 @@ async def get_ollama_tags(
@app.get("/api/version/{url_idx}") @app.get("/api/version/{url_idx}")
async def get_ollama_versions(url_idx: Optional[int] = None): async def get_ollama_versions(url_idx: Optional[int] = None):
if app.state.config.ENABLE_OLLAMA_API: if app.state.config.ENABLE_OLLAMA_API:
if url_idx == None: if url_idx is None:
# returns lowest version # returns lowest version
tasks = [ tasks = [
fetch_url(f"{url}/api/version") fetch_url(f"{url}/api/version")
@ -323,7 +313,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" error_detail = f"Ollama: {res['error']}"
except: except Exception:
error_detail = f"Ollama: {e}" error_detail = f"Ollama: {e}"
raise HTTPException( raise HTTPException(
@ -367,7 +357,7 @@ async def push_model(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
if url_idx == None: if url_idx is None:
if form_data.name in app.state.MODELS: if form_data.name in app.state.MODELS:
url_idx = app.state.MODELS[form_data.name]["urls"][0] url_idx = app.state.MODELS[form_data.name]["urls"][0]
else: else:
@ -417,7 +407,7 @@ async def copy_model(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
if url_idx == None: if url_idx is None:
if form_data.source in app.state.MODELS: if form_data.source in app.state.MODELS:
url_idx = app.state.MODELS[form_data.source]["urls"][0] url_idx = app.state.MODELS[form_data.source]["urls"][0]
else: else:
@ -448,7 +438,7 @@ async def copy_model(
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" error_detail = f"Ollama: {res['error']}"
except: except Exception:
error_detail = f"Ollama: {e}" error_detail = f"Ollama: {e}"
raise HTTPException( raise HTTPException(
@ -464,7 +454,7 @@ async def delete_model(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
if url_idx == None: if url_idx is None:
if form_data.name in app.state.MODELS: if form_data.name in app.state.MODELS:
url_idx = app.state.MODELS[form_data.name]["urls"][0] url_idx = app.state.MODELS[form_data.name]["urls"][0]
else: else:
@ -495,7 +485,7 @@ async def delete_model(
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" error_detail = f"Ollama: {res['error']}"
except: except Exception:
error_detail = f"Ollama: {e}" error_detail = f"Ollama: {e}"
raise HTTPException( raise HTTPException(
@ -533,7 +523,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" error_detail = f"Ollama: {res['error']}"
except: except Exception:
error_detail = f"Ollama: {e}" error_detail = f"Ollama: {e}"
raise HTTPException( raise HTTPException(
@ -556,7 +546,7 @@ async def generate_embeddings(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
if url_idx == None: if url_idx is None:
model = form_data.model model = form_data.model
if ":" not in model: if ":" not in model:
@ -590,7 +580,7 @@ async def generate_embeddings(
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" error_detail = f"Ollama: {res['error']}"
except: except Exception:
error_detail = f"Ollama: {e}" error_detail = f"Ollama: {e}"
raise HTTPException( raise HTTPException(
@ -603,10 +593,9 @@ def generate_ollama_embeddings(
form_data: GenerateEmbeddingsForm, form_data: GenerateEmbeddingsForm,
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
): ):
log.info(f"generate_ollama_embeddings {form_data}") log.info(f"generate_ollama_embeddings {form_data}")
if url_idx == None: if url_idx is None:
model = form_data.model model = form_data.model
if ":" not in model: if ":" not in model:
@ -638,7 +627,7 @@ def generate_ollama_embeddings(
if "embedding" in data: if "embedding" in data:
return data["embedding"] return data["embedding"]
else: else:
raise "Something went wrong :/" raise Exception("Something went wrong :/")
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
error_detail = "Open WebUI: Server Connection Error" error_detail = "Open WebUI: Server Connection Error"
@ -647,10 +636,10 @@ def generate_ollama_embeddings(
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" error_detail = f"Ollama: {res['error']}"
except: except Exception:
error_detail = f"Ollama: {e}" error_detail = f"Ollama: {e}"
raise error_detail raise Exception(error_detail)
class GenerateCompletionForm(BaseModel): class GenerateCompletionForm(BaseModel):
@ -674,8 +663,7 @@ async def generate_completion(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
if url_idx is None:
if url_idx == None:
model = form_data.model model = form_data.model
if ":" not in model: if ":" not in model:
@ -720,7 +708,6 @@ async def generate_chat_completion(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
log.debug( log.debug(
"form_data.model_dump_json(exclude_none=True).encode(): {0} ".format( "form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
form_data.model_dump_json(exclude_none=True).encode() form_data.model_dump_json(exclude_none=True).encode()
@ -906,7 +893,7 @@ async def generate_chat_completion(
system, payload["messages"] system, payload["messages"]
) )
if url_idx == None: if url_idx is None:
if ":" not in payload["model"]: if ":" not in payload["model"]:
payload["model"] = f"{payload['model']}:latest" payload["model"] = f"{payload['model']}:latest"
@ -1016,7 +1003,7 @@ async def generate_openai_chat_completion(
}, },
) )
if url_idx == None: if url_idx is None:
if ":" not in payload["model"]: if ":" not in payload["model"]:
payload["model"] = f"{payload['model']}:latest" payload["model"] = f"{payload['model']}:latest"
@ -1044,7 +1031,7 @@ async def get_openai_models(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
if url_idx == None: if url_idx is None:
models = await get_all_models() models = await get_all_models()
if app.state.config.ENABLE_MODEL_FILTER: if app.state.config.ENABLE_MODEL_FILTER:
@ -1099,7 +1086,7 @@ async def get_openai_models(
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" error_detail = f"Ollama: {res['error']}"
except: except Exception:
error_detail = f"Ollama: {e}" error_detail = f"Ollama: {e}"
raise HTTPException( raise HTTPException(
@ -1125,7 +1112,6 @@ def parse_huggingface_url(hf_url):
path_components = parsed_url.path.split("/") path_components = parsed_url.path.split("/")
# Extract the desired output # Extract the desired output
user_repo = "/".join(path_components[1:3])
model_file = path_components[-1] model_file = path_components[-1]
return model_file return model_file
@ -1190,7 +1176,6 @@ async def download_model(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
allowed_hosts = ["https://huggingface.co/", "https://github.com/"] allowed_hosts = ["https://huggingface.co/", "https://github.com/"]
if not any(form_data.url.startswith(host) for host in allowed_hosts): if not any(form_data.url.startswith(host) for host in allowed_hosts):
@ -1199,7 +1184,7 @@ async def download_model(
detail="Invalid file_url. Only URLs from allowed hosts are permitted.", detail="Invalid file_url. Only URLs from allowed hosts are permitted.",
) )
if url_idx == None: if url_idx is None:
url_idx = 0 url_idx = 0
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
@ -1222,7 +1207,7 @@ def upload_model(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
if url_idx == None: if url_idx is None:
url_idx = 0 url_idx = 0
ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx] ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]