Merge pull request #4402 from michaelpoluektov/remove-ollama

refactor: re-use utils in Ollama
This commit is contained in:
Timothy Jaeryang Baek 2024-08-12 00:45:15 +02:00 committed by GitHub
commit 9c2429ff97
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 139 additions and 320 deletions

View File

@ -1,6 +1,5 @@
import asyncio import asyncio
import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client) import websocket # NOTE: websocket-client (https://github.com/websocket-client/websocket-client)
import uuid
import json import json
import urllib.request import urllib.request
import urllib.parse import urllib.parse
@ -398,7 +397,9 @@ async def comfyui_generate_image(
return None return None
try: try:
images = await asyncio.to_thread(get_images, ws, comfyui_prompt, client_id, base_url) images = await asyncio.to_thread(
get_images, ws, comfyui_prompt, client_id, base_url
)
except Exception as e: except Exception as e:
log.exception(f"Error while receiving images: {e}") log.exception(f"Error while receiving images: {e}")
images = None images = None

View File

@ -1,27 +1,21 @@
from fastapi import ( from fastapi import (
FastAPI, FastAPI,
Request, Request,
Response,
HTTPException, HTTPException,
Depends, Depends,
status,
UploadFile, UploadFile,
File, File,
BackgroundTasks,
) )
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
import os import os
import re import re
import copy
import random import random
import requests import requests
import json import json
import uuid
import aiohttp import aiohttp
import asyncio import asyncio
import logging import logging
@ -32,16 +26,11 @@ from typing import Optional, List, Union
from starlette.background import BackgroundTask from starlette.background import BackgroundTask
from apps.webui.models.models import Models from apps.webui.models.models import Models
from apps.webui.models.users import Users
from constants import ERROR_MESSAGES from constants import ERROR_MESSAGES
from utils.utils import ( from utils.utils import (
decode_token,
get_current_user,
get_verified_user, get_verified_user,
get_admin_user, get_admin_user,
) )
from utils.task import prompt_template
from config import ( from config import (
SRC_LOG_LEVELS, SRC_LOG_LEVELS,
@ -53,7 +42,12 @@ from config import (
UPLOAD_DIR, UPLOAD_DIR,
AppConfig, AppConfig,
) )
from utils.misc import calculate_sha256, add_or_update_system_message from utils.misc import (
apply_model_params_to_body_ollama,
calculate_sha256,
apply_model_params_to_body_openai,
apply_model_system_prompt_to_body,
)
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
log.setLevel(SRC_LOG_LEVELS["OLLAMA"]) log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
@ -183,7 +177,7 @@ async def post_streaming_url(url: str, payload: str, stream: bool = True):
res = await r.json() res = await r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" error_detail = f"Ollama: {res['error']}"
except: except Exception:
error_detail = f"Ollama: {e}" error_detail = f"Ollama: {e}"
raise HTTPException( raise HTTPException(
@ -238,7 +232,7 @@ async def get_all_models():
async def get_ollama_tags( async def get_ollama_tags(
url_idx: Optional[int] = None, user=Depends(get_verified_user) url_idx: Optional[int] = None, user=Depends(get_verified_user)
): ):
if url_idx == None: if url_idx is None:
models = await get_all_models() models = await get_all_models()
if app.state.config.ENABLE_MODEL_FILTER: if app.state.config.ENABLE_MODEL_FILTER:
@ -269,7 +263,7 @@ async def get_ollama_tags(
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" error_detail = f"Ollama: {res['error']}"
except: except Exception:
error_detail = f"Ollama: {e}" error_detail = f"Ollama: {e}"
raise HTTPException( raise HTTPException(
@ -282,8 +276,7 @@ async def get_ollama_tags(
@app.get("/api/version/{url_idx}") @app.get("/api/version/{url_idx}")
async def get_ollama_versions(url_idx: Optional[int] = None): async def get_ollama_versions(url_idx: Optional[int] = None):
if app.state.config.ENABLE_OLLAMA_API: if app.state.config.ENABLE_OLLAMA_API:
if url_idx == None: if url_idx is None:
# returns lowest version # returns lowest version
tasks = [ tasks = [
fetch_url(f"{url}/api/version") fetch_url(f"{url}/api/version")
@ -323,7 +316,7 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" error_detail = f"Ollama: {res['error']}"
except: except Exception:
error_detail = f"Ollama: {e}" error_detail = f"Ollama: {e}"
raise HTTPException( raise HTTPException(
@ -346,8 +339,6 @@ async def pull_model(
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
r = None
# Admin should be able to pull models from any source # Admin should be able to pull models from any source
payload = {**form_data.model_dump(exclude_none=True), "insecure": True} payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
@ -367,7 +358,7 @@ async def push_model(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
if url_idx == None: if url_idx is None:
if form_data.name in app.state.MODELS: if form_data.name in app.state.MODELS:
url_idx = app.state.MODELS[form_data.name]["urls"][0] url_idx = app.state.MODELS[form_data.name]["urls"][0]
else: else:
@ -417,7 +408,7 @@ async def copy_model(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
if url_idx == None: if url_idx is None:
if form_data.source in app.state.MODELS: if form_data.source in app.state.MODELS:
url_idx = app.state.MODELS[form_data.source]["urls"][0] url_idx = app.state.MODELS[form_data.source]["urls"][0]
else: else:
@ -428,13 +419,13 @@ async def copy_model(
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
r = requests.request(
method="POST",
url=f"{url}/api/copy",
data=form_data.model_dump_json(exclude_none=True).encode(),
)
try: try:
r = requests.request(
method="POST",
url=f"{url}/api/copy",
data=form_data.model_dump_json(exclude_none=True).encode(),
)
r.raise_for_status() r.raise_for_status()
log.debug(f"r.text: {r.text}") log.debug(f"r.text: {r.text}")
@ -448,7 +439,7 @@ async def copy_model(
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" error_detail = f"Ollama: {res['error']}"
except: except Exception:
error_detail = f"Ollama: {e}" error_detail = f"Ollama: {e}"
raise HTTPException( raise HTTPException(
@ -464,7 +455,7 @@ async def delete_model(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
if url_idx == None: if url_idx is None:
if form_data.name in app.state.MODELS: if form_data.name in app.state.MODELS:
url_idx = app.state.MODELS[form_data.name]["urls"][0] url_idx = app.state.MODELS[form_data.name]["urls"][0]
else: else:
@ -476,12 +467,12 @@ async def delete_model(
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
r = requests.request(
method="DELETE",
url=f"{url}/api/delete",
data=form_data.model_dump_json(exclude_none=True).encode(),
)
try: try:
r = requests.request(
method="DELETE",
url=f"{url}/api/delete",
data=form_data.model_dump_json(exclude_none=True).encode(),
)
r.raise_for_status() r.raise_for_status()
log.debug(f"r.text: {r.text}") log.debug(f"r.text: {r.text}")
@ -495,7 +486,7 @@ async def delete_model(
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" error_detail = f"Ollama: {res['error']}"
except: except Exception:
error_detail = f"Ollama: {e}" error_detail = f"Ollama: {e}"
raise HTTPException( raise HTTPException(
@ -516,12 +507,12 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
r = requests.request(
method="POST",
url=f"{url}/api/show",
data=form_data.model_dump_json(exclude_none=True).encode(),
)
try: try:
r = requests.request(
method="POST",
url=f"{url}/api/show",
data=form_data.model_dump_json(exclude_none=True).encode(),
)
r.raise_for_status() r.raise_for_status()
return r.json() return r.json()
@ -533,7 +524,7 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" error_detail = f"Ollama: {res['error']}"
except: except Exception:
error_detail = f"Ollama: {e}" error_detail = f"Ollama: {e}"
raise HTTPException( raise HTTPException(
@ -556,7 +547,7 @@ async def generate_embeddings(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
if url_idx == None: if url_idx is None:
model = form_data.model model = form_data.model
if ":" not in model: if ":" not in model:
@ -573,12 +564,12 @@ async def generate_embeddings(
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
r = requests.request(
method="POST",
url=f"{url}/api/embeddings",
data=form_data.model_dump_json(exclude_none=True).encode(),
)
try: try:
r = requests.request(
method="POST",
url=f"{url}/api/embeddings",
data=form_data.model_dump_json(exclude_none=True).encode(),
)
r.raise_for_status() r.raise_for_status()
return r.json() return r.json()
@ -590,7 +581,7 @@ async def generate_embeddings(
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" error_detail = f"Ollama: {res['error']}"
except: except Exception:
error_detail = f"Ollama: {e}" error_detail = f"Ollama: {e}"
raise HTTPException( raise HTTPException(
@ -603,10 +594,9 @@ def generate_ollama_embeddings(
form_data: GenerateEmbeddingsForm, form_data: GenerateEmbeddingsForm,
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
): ):
log.info(f"generate_ollama_embeddings {form_data}") log.info(f"generate_ollama_embeddings {form_data}")
if url_idx == None: if url_idx is None:
model = form_data.model model = form_data.model
if ":" not in model: if ":" not in model:
@ -623,12 +613,12 @@ def generate_ollama_embeddings(
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
r = requests.request(
method="POST",
url=f"{url}/api/embeddings",
data=form_data.model_dump_json(exclude_none=True).encode(),
)
try: try:
r = requests.request(
method="POST",
url=f"{url}/api/embeddings",
data=form_data.model_dump_json(exclude_none=True).encode(),
)
r.raise_for_status() r.raise_for_status()
data = r.json() data = r.json()
@ -638,7 +628,7 @@ def generate_ollama_embeddings(
if "embedding" in data: if "embedding" in data:
return data["embedding"] return data["embedding"]
else: else:
raise "Something went wrong :/" raise Exception("Something went wrong :/")
except Exception as e: except Exception as e:
log.exception(e) log.exception(e)
error_detail = "Open WebUI: Server Connection Error" error_detail = "Open WebUI: Server Connection Error"
@ -647,10 +637,10 @@ def generate_ollama_embeddings(
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" error_detail = f"Ollama: {res['error']}"
except: except Exception:
error_detail = f"Ollama: {e}" error_detail = f"Ollama: {e}"
raise error_detail raise Exception(error_detail)
class GenerateCompletionForm(BaseModel): class GenerateCompletionForm(BaseModel):
@ -674,8 +664,7 @@ async def generate_completion(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
if url_idx is None:
if url_idx == None:
model = form_data.model model = form_data.model
if ":" not in model: if ":" not in model:
@ -713,6 +702,18 @@ class GenerateChatCompletionForm(BaseModel):
keep_alive: Optional[Union[int, str]] = None keep_alive: Optional[Union[int, str]] = None
def get_ollama_url(url_idx: Optional[int], model: str):
    """Return the Ollama base URL that a request for `model` should be sent to.

    When `url_idx` is given it is used directly as an index into
    `app.state.config.OLLAMA_BASE_URLS`; `model` is not consulted in that case.
    When `url_idx` is None, one of the indices listed under
    `app.state.MODELS[model]["urls"]` is picked at random.

    Raises:
        HTTPException: status 400 with `ERROR_MESSAGES.MODEL_NOT_FOUND(model)`
            when `url_idx` is None and `model` is not a key of
            `app.state.MODELS`.
    """
    # Caller pinned a backend explicitly: no model lookup is performed.
    if url_idx is not None:
        return app.state.config.OLLAMA_BASE_URLS[url_idx]

    if model not in app.state.MODELS:
        raise HTTPException(
            status_code=400,
            detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
        )

    # Pick randomly among the backends recorded for this model.
    chosen_idx = random.choice(app.state.MODELS[model]["urls"])
    return app.state.config.OLLAMA_BASE_URLS[chosen_idx]
@app.post("/api/chat") @app.post("/api/chat")
@app.post("/api/chat/{url_idx}") @app.post("/api/chat/{url_idx}")
async def generate_chat_completion( async def generate_chat_completion(
@ -720,12 +721,7 @@ async def generate_chat_completion(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
log.debug(f"{form_data.model_dump_json(exclude_none=True).encode()}=")
log.debug(
"form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
form_data.model_dump_json(exclude_none=True).encode()
)
)
payload = { payload = {
**form_data.model_dump(exclude_none=True, exclude=["metadata"]), **form_data.model_dump(exclude_none=True, exclude=["metadata"]),
@ -740,185 +736,21 @@ async def generate_chat_completion(
if model_info.base_model_id: if model_info.base_model_id:
payload["model"] = model_info.base_model_id payload["model"] = model_info.base_model_id
model_info.params = model_info.params.model_dump() params = model_info.params.model_dump()
if model_info.params: if params:
if payload.get("options") is None: if payload.get("options") is None:
payload["options"] = {} payload["options"] = {}
if ( payload["options"] = apply_model_params_to_body_ollama(
model_info.params.get("mirostat", None) params, payload["options"]
and payload["options"].get("mirostat") is None
):
payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
if (
model_info.params.get("mirostat_eta", None)
and payload["options"].get("mirostat_eta") is None
):
payload["options"]["mirostat_eta"] = model_info.params.get(
"mirostat_eta", None
)
if (
model_info.params.get("mirostat_tau", None)
and payload["options"].get("mirostat_tau") is None
):
payload["options"]["mirostat_tau"] = model_info.params.get(
"mirostat_tau", None
)
if (
model_info.params.get("num_ctx", None)
and payload["options"].get("num_ctx") is None
):
payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
if (
model_info.params.get("num_batch", None)
and payload["options"].get("num_batch") is None
):
payload["options"]["num_batch"] = model_info.params.get(
"num_batch", None
)
if (
model_info.params.get("num_keep", None)
and payload["options"].get("num_keep") is None
):
payload["options"]["num_keep"] = model_info.params.get("num_keep", None)
if (
model_info.params.get("repeat_last_n", None)
and payload["options"].get("repeat_last_n") is None
):
payload["options"]["repeat_last_n"] = model_info.params.get(
"repeat_last_n", None
)
if (
model_info.params.get("frequency_penalty", None)
and payload["options"].get("frequency_penalty") is None
):
payload["options"]["repeat_penalty"] = model_info.params.get(
"frequency_penalty", None
)
if (
model_info.params.get("temperature", None) is not None
and payload["options"].get("temperature") is None
):
payload["options"]["temperature"] = model_info.params.get(
"temperature", None
)
if (
model_info.params.get("seed", None) is not None
and payload["options"].get("seed") is None
):
payload["options"]["seed"] = model_info.params.get("seed", None)
if (
model_info.params.get("stop", None)
and payload["options"].get("stop") is None
):
payload["options"]["stop"] = (
[
bytes(stop, "utf-8").decode("unicode_escape")
for stop in model_info.params["stop"]
]
if model_info.params.get("stop", None)
else None
)
if (
model_info.params.get("tfs_z", None)
and payload["options"].get("tfs_z") is None
):
payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
if (
model_info.params.get("max_tokens", None)
and payload["options"].get("max_tokens") is None
):
payload["options"]["num_predict"] = model_info.params.get(
"max_tokens", None
)
if (
model_info.params.get("top_k", None)
and payload["options"].get("top_k") is None
):
payload["options"]["top_k"] = model_info.params.get("top_k", None)
if (
model_info.params.get("top_p", None)
and payload["options"].get("top_p") is None
):
payload["options"]["top_p"] = model_info.params.get("top_p", None)
if (
model_info.params.get("min_p", None)
and payload["options"].get("min_p") is None
):
payload["options"]["min_p"] = model_info.params.get("min_p", None)
if (
model_info.params.get("use_mmap", None)
and payload["options"].get("use_mmap") is None
):
payload["options"]["use_mmap"] = model_info.params.get("use_mmap", None)
if (
model_info.params.get("use_mlock", None)
and payload["options"].get("use_mlock") is None
):
payload["options"]["use_mlock"] = model_info.params.get(
"use_mlock", None
)
if (
model_info.params.get("num_thread", None)
and payload["options"].get("num_thread") is None
):
payload["options"]["num_thread"] = model_info.params.get(
"num_thread", None
)
system = model_info.params.get("system", None)
if system:
system = prompt_template(
system,
**(
{
"user_name": user.name,
"user_location": (
user.info.get("location") if user.info else None
),
}
if user
else {}
),
) )
payload = apply_model_system_prompt_to_body(params, payload, user)
if payload.get("messages"): if ":" not in payload["model"]:
payload["messages"] = add_or_update_system_message( payload["model"] = f"{payload['model']}:latest"
system, payload["messages"]
)
if url_idx == None: url = get_ollama_url(url_idx, payload["model"])
if ":" not in payload["model"]:
payload["model"] = f"{payload['model']}:latest"
if payload["model"] in app.state.MODELS:
url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
else:
raise HTTPException(
status_code=400,
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
)
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
log.debug(payload) log.debug(payload)
@ -952,83 +784,28 @@ async def generate_openai_chat_completion(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
form_data = OpenAIChatCompletionForm(**form_data) completion_form = OpenAIChatCompletionForm(**form_data)
payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])} payload = {**completion_form.model_dump(exclude_none=True, exclude=["metadata"])}
if "metadata" in payload: if "metadata" in payload:
del payload["metadata"] del payload["metadata"]
model_id = form_data.model model_id = completion_form.model
model_info = Models.get_model_by_id(model_id) model_info = Models.get_model_by_id(model_id)
if model_info: if model_info:
if model_info.base_model_id: if model_info.base_model_id:
payload["model"] = model_info.base_model_id payload["model"] = model_info.base_model_id
model_info.params = model_info.params.model_dump() params = model_info.params.model_dump()
if model_info.params: if params:
payload["temperature"] = model_info.params.get("temperature", None) payload = apply_model_params_to_body_openai(params, payload)
payload["top_p"] = model_info.params.get("top_p", None) payload = apply_model_system_prompt_to_body(params, payload, user)
payload["max_tokens"] = model_info.params.get("max_tokens", None)
payload["frequency_penalty"] = model_info.params.get(
"frequency_penalty", None
)
payload["seed"] = model_info.params.get("seed", None)
payload["stop"] = (
[
bytes(stop, "utf-8").decode("unicode_escape")
for stop in model_info.params["stop"]
]
if model_info.params.get("stop", None)
else None
)
system = model_info.params.get("system", None) if ":" not in payload["model"]:
payload["model"] = f"{payload['model']}:latest"
if system: url = get_ollama_url(url_idx, payload["model"])
system = prompt_template(
system,
**(
{
"user_name": user.name,
"user_location": (
user.info.get("location") if user.info else None
),
}
if user
else {}
),
)
# Check if the payload already has a system message
# If not, add a system message to the payload
if payload.get("messages"):
for message in payload["messages"]:
if message.get("role") == "system":
message["content"] = system + message["content"]
break
else:
payload["messages"].insert(
0,
{
"role": "system",
"content": system,
},
)
if url_idx == None:
if ":" not in payload["model"]:
payload["model"] = f"{payload['model']}:latest"
if payload["model"] in app.state.MODELS:
url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
else:
raise HTTPException(
status_code=400,
detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
)
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
log.info(f"url: {url}") log.info(f"url: {url}")
return await post_streaming_url( return await post_streaming_url(
@ -1044,7 +821,7 @@ async def get_openai_models(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_verified_user), user=Depends(get_verified_user),
): ):
if url_idx == None: if url_idx is None:
models = await get_all_models() models = await get_all_models()
if app.state.config.ENABLE_MODEL_FILTER: if app.state.config.ENABLE_MODEL_FILTER:
@ -1099,7 +876,7 @@ async def get_openai_models(
res = r.json() res = r.json()
if "error" in res: if "error" in res:
error_detail = f"Ollama: {res['error']}" error_detail = f"Ollama: {res['error']}"
except: except Exception:
error_detail = f"Ollama: {e}" error_detail = f"Ollama: {e}"
raise HTTPException( raise HTTPException(
@ -1125,7 +902,6 @@ def parse_huggingface_url(hf_url):
path_components = parsed_url.path.split("/") path_components = parsed_url.path.split("/")
# Extract the desired output # Extract the desired output
user_repo = "/".join(path_components[1:3])
model_file = path_components[-1] model_file = path_components[-1]
return model_file return model_file
@ -1190,7 +966,6 @@ async def download_model(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
allowed_hosts = ["https://huggingface.co/", "https://github.com/"] allowed_hosts = ["https://huggingface.co/", "https://github.com/"]
if not any(form_data.url.startswith(host) for host in allowed_hosts): if not any(form_data.url.startswith(host) for host in allowed_hosts):
@ -1199,7 +974,7 @@ async def download_model(
detail="Invalid file_url. Only URLs from allowed hosts are permitted.", detail="Invalid file_url. Only URLs from allowed hosts are permitted.",
) )
if url_idx == None: if url_idx is None:
url_idx = 0 url_idx = 0
url = app.state.config.OLLAMA_BASE_URLS[url_idx] url = app.state.config.OLLAMA_BASE_URLS[url_idx]
@ -1222,7 +997,7 @@ def upload_model(
url_idx: Optional[int] = None, url_idx: Optional[int] = None,
user=Depends(get_admin_user), user=Depends(get_admin_user),
): ):
if url_idx == None: if url_idx is None:
url_idx = 0 url_idx = 0
ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx] ollama_url = app.state.config.OLLAMA_BASE_URLS[url_idx]

View File

@ -17,7 +17,10 @@ from utils.utils import (
get_verified_user, get_verified_user,
get_admin_user, get_admin_user,
) )
from utils.misc import apply_model_params_to_body, apply_model_system_prompt_to_body from utils.misc import (
apply_model_params_to_body_openai,
apply_model_system_prompt_to_body,
)
from config import ( from config import (
SRC_LOG_LEVELS, SRC_LOG_LEVELS,
@ -368,7 +371,7 @@ async def generate_chat_completion(
payload["model"] = model_info.base_model_id payload["model"] = model_info.base_model_id
params = model_info.params.model_dump() params = model_info.params.model_dump()
payload = apply_model_params_to_body(params, payload) payload = apply_model_params_to_body_openai(params, payload)
payload = apply_model_system_prompt_to_body(params, payload, user) payload = apply_model_system_prompt_to_body(params, payload, user)
model = app.state.MODELS[payload.get("model")] model = app.state.MODELS[payload.get("model")]

View File

@ -22,7 +22,7 @@ from apps.webui.utils import load_function_module_by_id
from utils.misc import ( from utils.misc import (
openai_chat_chunk_message_template, openai_chat_chunk_message_template,
openai_chat_completion_message_template, openai_chat_completion_message_template,
apply_model_params_to_body, apply_model_params_to_body_openai,
apply_model_system_prompt_to_body, apply_model_system_prompt_to_body,
) )
@ -291,7 +291,7 @@ async def generate_function_chat_completion(form_data, user):
form_data["model"] = model_info.base_model_id form_data["model"] = model_info.base_model_id
params = model_info.params.model_dump() params = model_info.params.model_dump()
form_data = apply_model_params_to_body(params, form_data) form_data = apply_model_params_to_body_openai(params, form_data)
form_data = apply_model_system_prompt_to_body(params, form_data, user) form_data = apply_model_system_prompt_to_body(params, form_data, user)
pipe_id = get_pipe_id(form_data) pipe_id = get_pipe_id(form_data)

View File

@ -2,7 +2,7 @@ from pathlib import Path
import hashlib import hashlib
import re import re
from datetime import timedelta from datetime import timedelta
from typing import Optional, List, Tuple from typing import Optional, List, Tuple, Callable
import uuid import uuid
import time import time
@ -135,10 +135,21 @@ def apply_model_system_prompt_to_body(params: dict, form_data: dict, user) -> di
# inplace function: form_data is modified # inplace function: form_data is modified
def apply_model_params_to_body(params: dict, form_data: dict) -> dict: def apply_model_params_to_body(
params: dict, form_data: dict, mappings: dict[str, Callable]
) -> dict:
if not params: if not params:
return form_data return form_data
for key, cast_func in mappings.items():
if (value := params.get(key)) is not None:
form_data[key] = cast_func(value)
return form_data
# inplace function: form_data is modified
def apply_model_params_to_body_openai(params: dict, form_data: dict) -> dict:
mappings = { mappings = {
"temperature": float, "temperature": float,
"top_p": int, "top_p": int,
@ -147,10 +158,39 @@ def apply_model_params_to_body(params: dict, form_data: dict) -> dict:
"seed": lambda x: x, "seed": lambda x: x,
"stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x], "stop": lambda x: [bytes(s, "utf-8").decode("unicode_escape") for s in x],
} }
return apply_model_params_to_body(params, form_data, mappings)
for key, cast_func in mappings.items():
if (value := params.get(key)) is not None: def apply_model_params_to_body_ollama(params: dict, form_data: dict) -> dict:
form_data[key] = cast_func(value) opts = [
"temperature",
"top_p",
"seed",
"mirostat",
"mirostat_eta",
"mirostat_tau",
"num_ctx",
"num_batch",
"num_keep",
"repeat_last_n",
"tfs_z",
"top_k",
"min_p",
"use_mmap",
"use_mlock",
"num_thread",
]
mappings = {i: lambda x: x for i in opts}
form_data = apply_model_params_to_body(params, form_data, mappings)
name_differences = {
"max_tokens": "num_predict",
"frequency_penalty": "repeat_penalty",
}
for key, value in name_differences.items():
if (param := params.get(key, None)) is not None:
form_data[value] = param
return form_data return form_data