mirror of
https://github.com/open-webui/open-webui
synced 2025-06-26 18:26:48 +00:00
Merge remote-tracking branch 'upstream/dev' into feat/backend-web-search
This commit is contained in:
@@ -18,8 +18,9 @@ import requests
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from typing import Optional, List
|
||||
|
||||
from apps.web.models.models import Models
|
||||
from utils.utils import get_verified_user, get_current_user, get_admin_user
|
||||
from config import SRC_LOG_LEVELS, ENV
|
||||
from config import SRC_LOG_LEVELS
|
||||
from constants import MESSAGES
|
||||
|
||||
import os
|
||||
@@ -77,7 +78,7 @@ with open(LITELLM_CONFIG_DIR, "r") as file:
|
||||
|
||||
app.state.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER.value
|
||||
app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST.value
|
||||
|
||||
app.state.MODEL_CONFIG = Models.get_all_models()
|
||||
|
||||
app.state.ENABLE = ENABLE_LITELLM
|
||||
app.state.CONFIG = litellm_config
|
||||
@@ -261,6 +262,14 @@ async def get_models(user=Depends(get_current_user)):
|
||||
"object": "model",
|
||||
"created": int(time.time()),
|
||||
"owned_by": "openai",
|
||||
"custom_info": next(
|
||||
(
|
||||
item
|
||||
for item in app.state.MODEL_CONFIG
|
||||
if item.id == model["model_name"]
|
||||
),
|
||||
None,
|
||||
),
|
||||
}
|
||||
for model in app.state.CONFIG["model_list"]
|
||||
],
|
||||
|
||||
@@ -29,7 +29,7 @@ import time
|
||||
from urllib.parse import urlparse
|
||||
from typing import Optional, List, Union
|
||||
|
||||
|
||||
from apps.web.models.models import Models
|
||||
from apps.web.models.users import Users
|
||||
from constants import ERROR_MESSAGES
|
||||
from utils.utils import (
|
||||
@@ -39,6 +39,8 @@ from utils.utils import (
|
||||
get_admin_user,
|
||||
)
|
||||
|
||||
from utils.models import get_model_id_from_custom_model_id
|
||||
|
||||
|
||||
from config import (
|
||||
SRC_LOG_LEVELS,
|
||||
@@ -68,7 +70,6 @@ app.state.config = AppConfig()
|
||||
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
|
||||
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
||||
|
||||
|
||||
app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
|
||||
app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
|
||||
app.state.MODELS = {}
|
||||
@@ -875,14 +876,93 @@ async def generate_chat_completion(
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
|
||||
log.debug(
|
||||
"form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
|
||||
form_data.model_dump_json(exclude_none=True).encode()
|
||||
)
|
||||
)
|
||||
|
||||
payload = {
|
||||
**form_data.model_dump(exclude_none=True),
|
||||
}
|
||||
|
||||
model_id = form_data.model
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
|
||||
if model_info:
|
||||
print(model_info)
|
||||
if model_info.base_model_id:
|
||||
payload["model"] = model_info.base_model_id
|
||||
|
||||
model_info.params = model_info.params.model_dump()
|
||||
|
||||
if model_info.params:
|
||||
payload["options"] = {}
|
||||
|
||||
payload["options"]["mirostat"] = model_info.params.get("mirostat", None)
|
||||
payload["options"]["mirostat_eta"] = model_info.params.get(
|
||||
"mirostat_eta", None
|
||||
)
|
||||
payload["options"]["mirostat_tau"] = model_info.params.get(
|
||||
"mirostat_tau", None
|
||||
)
|
||||
payload["options"]["num_ctx"] = model_info.params.get("num_ctx", None)
|
||||
|
||||
payload["options"]["repeat_last_n"] = model_info.params.get(
|
||||
"repeat_last_n", None
|
||||
)
|
||||
payload["options"]["repeat_penalty"] = model_info.params.get(
|
||||
"frequency_penalty", None
|
||||
)
|
||||
|
||||
payload["options"]["temperature"] = model_info.params.get(
|
||||
"temperature", None
|
||||
)
|
||||
payload["options"]["seed"] = model_info.params.get("seed", None)
|
||||
|
||||
payload["options"]["stop"] = (
|
||||
[
|
||||
bytes(stop, "utf-8").decode("unicode_escape")
|
||||
for stop in model_info.params["stop"]
|
||||
]
|
||||
if model_info.params.get("stop", None)
|
||||
else None
|
||||
)
|
||||
|
||||
payload["options"]["tfs_z"] = model_info.params.get("tfs_z", None)
|
||||
|
||||
payload["options"]["num_predict"] = model_info.params.get(
|
||||
"max_tokens", None
|
||||
)
|
||||
payload["options"]["top_k"] = model_info.params.get("top_k", None)
|
||||
|
||||
payload["options"]["top_p"] = model_info.params.get("top_p", None)
|
||||
|
||||
if model_info.params.get("system", None):
|
||||
# Check if the payload already has a system message
|
||||
# If not, add a system message to the payload
|
||||
if payload.get("messages"):
|
||||
for message in payload["messages"]:
|
||||
if message.get("role") == "system":
|
||||
message["content"] = (
|
||||
model_info.params.get("system", None) + message["content"]
|
||||
)
|
||||
break
|
||||
else:
|
||||
payload["messages"].insert(
|
||||
0,
|
||||
{
|
||||
"role": "system",
|
||||
"content": model_info.params.get("system", None),
|
||||
},
|
||||
)
|
||||
|
||||
if url_idx == None:
|
||||
model = form_data.model
|
||||
if ":" not in payload["model"]:
|
||||
payload["model"] = f"{payload['model']}:latest"
|
||||
|
||||
if ":" not in model:
|
||||
model = f"{model}:latest"
|
||||
|
||||
if model in app.state.MODELS:
|
||||
url_idx = random.choice(app.state.MODELS[model]["urls"])
|
||||
if payload["model"] in app.state.MODELS:
|
||||
url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@@ -892,16 +972,12 @@ async def generate_chat_completion(
|
||||
url = app.state.config.OLLAMA_BASE_URLS[url_idx]
|
||||
log.info(f"url: {url}")
|
||||
|
||||
print(payload)
|
||||
|
||||
r = None
|
||||
|
||||
log.debug(
|
||||
"form_data.model_dump_json(exclude_none=True).encode(): {0} ".format(
|
||||
form_data.model_dump_json(exclude_none=True).encode()
|
||||
)
|
||||
)
|
||||
|
||||
def get_request():
|
||||
nonlocal form_data
|
||||
nonlocal payload
|
||||
nonlocal r
|
||||
|
||||
request_id = str(uuid.uuid4())
|
||||
@@ -910,7 +986,7 @@ async def generate_chat_completion(
|
||||
|
||||
def stream_content():
|
||||
try:
|
||||
if form_data.stream:
|
||||
if payload.get("stream", None):
|
||||
yield json.dumps({"id": request_id, "done": False}) + "\n"
|
||||
|
||||
for chunk in r.iter_content(chunk_size=8192):
|
||||
@@ -928,7 +1004,7 @@ async def generate_chat_completion(
|
||||
r = requests.request(
|
||||
method="POST",
|
||||
url=f"{url}/api/chat",
|
||||
data=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
data=json.dumps(payload),
|
||||
stream=True,
|
||||
)
|
||||
|
||||
@@ -984,14 +1060,62 @@ async def generate_openai_chat_completion(
|
||||
user=Depends(get_verified_user),
|
||||
):
|
||||
|
||||
payload = {
|
||||
**form_data.model_dump(exclude_none=True),
|
||||
}
|
||||
|
||||
model_id = form_data.model
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
|
||||
if model_info:
|
||||
print(model_info)
|
||||
if model_info.base_model_id:
|
||||
payload["model"] = model_info.base_model_id
|
||||
|
||||
model_info.params = model_info.params.model_dump()
|
||||
|
||||
if model_info.params:
|
||||
payload["temperature"] = model_info.params.get("temperature", None)
|
||||
payload["top_p"] = model_info.params.get("top_p", None)
|
||||
payload["max_tokens"] = model_info.params.get("max_tokens", None)
|
||||
payload["frequency_penalty"] = model_info.params.get(
|
||||
"frequency_penalty", None
|
||||
)
|
||||
payload["seed"] = model_info.params.get("seed", None)
|
||||
payload["stop"] = (
|
||||
[
|
||||
bytes(stop, "utf-8").decode("unicode_escape")
|
||||
for stop in model_info.params["stop"]
|
||||
]
|
||||
if model_info.params.get("stop", None)
|
||||
else None
|
||||
)
|
||||
|
||||
if model_info.params.get("system", None):
|
||||
# Check if the payload already has a system message
|
||||
# If not, add a system message to the payload
|
||||
if payload.get("messages"):
|
||||
for message in payload["messages"]:
|
||||
if message.get("role") == "system":
|
||||
message["content"] = (
|
||||
model_info.params.get("system", None) + message["content"]
|
||||
)
|
||||
break
|
||||
else:
|
||||
payload["messages"].insert(
|
||||
0,
|
||||
{
|
||||
"role": "system",
|
||||
"content": model_info.params.get("system", None),
|
||||
},
|
||||
)
|
||||
|
||||
if url_idx == None:
|
||||
model = form_data.model
|
||||
if ":" not in payload["model"]:
|
||||
payload["model"] = f"{payload['model']}:latest"
|
||||
|
||||
if ":" not in model:
|
||||
model = f"{model}:latest"
|
||||
|
||||
if model in app.state.MODELS:
|
||||
url_idx = random.choice(app.state.MODELS[model]["urls"])
|
||||
if payload["model"] in app.state.MODELS:
|
||||
url_idx = random.choice(app.state.MODELS[payload["model"]]["urls"])
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@@ -1004,7 +1128,7 @@ async def generate_openai_chat_completion(
|
||||
r = None
|
||||
|
||||
def get_request():
|
||||
nonlocal form_data
|
||||
nonlocal payload
|
||||
nonlocal r
|
||||
|
||||
request_id = str(uuid.uuid4())
|
||||
@@ -1013,7 +1137,7 @@ async def generate_openai_chat_completion(
|
||||
|
||||
def stream_content():
|
||||
try:
|
||||
if form_data.stream:
|
||||
if payload.get("stream"):
|
||||
yield json.dumps(
|
||||
{"request_id": request_id, "done": False}
|
||||
) + "\n"
|
||||
@@ -1033,7 +1157,7 @@ async def generate_openai_chat_completion(
|
||||
r = requests.request(
|
||||
method="POST",
|
||||
url=f"{url}/v1/chat/completions",
|
||||
data=form_data.model_dump_json(exclude_none=True).encode(),
|
||||
data=json.dumps(payload),
|
||||
stream=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@ import logging
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
from apps.web.models.models import Models
|
||||
from apps.web.models.users import Users
|
||||
from constants import ERROR_MESSAGES
|
||||
from utils.utils import (
|
||||
@@ -53,7 +53,6 @@ app.state.config = AppConfig()
|
||||
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
|
||||
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
||||
|
||||
|
||||
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
|
||||
app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
|
||||
app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
|
||||
@@ -206,7 +205,13 @@ def merge_models_lists(model_lists):
|
||||
if models is not None and "error" not in models:
|
||||
merged_list.extend(
|
||||
[
|
||||
{**model, "urlIdx": idx}
|
||||
{
|
||||
**model,
|
||||
"name": model.get("name", model["id"]),
|
||||
"owned_by": "openai",
|
||||
"openai": model,
|
||||
"urlIdx": idx,
|
||||
}
|
||||
for model in models
|
||||
if "api.openai.com"
|
||||
not in app.state.config.OPENAI_API_BASE_URLS[idx]
|
||||
@@ -252,7 +257,7 @@ async def get_all_models():
|
||||
log.info(f"models: {models}")
|
||||
app.state.MODELS = {model["id"]: model for model in models["data"]}
|
||||
|
||||
return models
|
||||
return models
|
||||
|
||||
|
||||
@app.get("/models")
|
||||
@@ -306,44 +311,97 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
|
||||
@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
|
||||
async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
||||
idx = 0
|
||||
pipeline = False
|
||||
|
||||
body = await request.body()
|
||||
# TODO: Remove below after gpt-4-vision fix from Open AI
|
||||
# Try to decode the body of the request from bytes to a UTF-8 string (Require add max_token to fix gpt-4-vision)
|
||||
|
||||
payload = None
|
||||
|
||||
try:
|
||||
body = body.decode("utf-8")
|
||||
body = json.loads(body)
|
||||
if "chat/completions" in path:
|
||||
body = body.decode("utf-8")
|
||||
body = json.loads(body)
|
||||
|
||||
model = app.state.MODELS[body.get("model")]
|
||||
payload = {**body}
|
||||
|
||||
idx = model["urlIdx"]
|
||||
model_id = body.get("model")
|
||||
model_info = Models.get_model_by_id(model_id)
|
||||
|
||||
if "pipeline" in model:
|
||||
pipeline = model.get("pipeline")
|
||||
if model_info:
|
||||
print(model_info)
|
||||
if model_info.base_model_id:
|
||||
payload["model"] = model_info.base_model_id
|
||||
|
||||
if pipeline:
|
||||
body["user"] = {"name": user.name, "id": user.id}
|
||||
model_info.params = model_info.params.model_dump()
|
||||
|
||||
# Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
|
||||
# This is a workaround until OpenAI fixes the issue with this model
|
||||
if body.get("model") == "gpt-4-vision-preview":
|
||||
if "max_tokens" not in body:
|
||||
body["max_tokens"] = 4000
|
||||
log.debug("Modified body_dict:", body)
|
||||
if model_info.params:
|
||||
payload["temperature"] = model_info.params.get("temperature", None)
|
||||
payload["top_p"] = model_info.params.get("top_p", None)
|
||||
payload["max_tokens"] = model_info.params.get("max_tokens", None)
|
||||
payload["frequency_penalty"] = model_info.params.get(
|
||||
"frequency_penalty", None
|
||||
)
|
||||
payload["seed"] = model_info.params.get("seed", None)
|
||||
payload["stop"] = (
|
||||
[
|
||||
bytes(stop, "utf-8").decode("unicode_escape")
|
||||
for stop in model_info.params["stop"]
|
||||
]
|
||||
if model_info.params.get("stop", None)
|
||||
else None
|
||||
)
|
||||
|
||||
# Fix for ChatGPT calls failing because the num_ctx key is in body
|
||||
if "num_ctx" in body:
|
||||
# If 'num_ctx' is in the dictionary, delete it
|
||||
# Leaving it there generates an error with the
|
||||
# OpenAI API (Feb 2024)
|
||||
del body["num_ctx"]
|
||||
if model_info.params.get("system", None):
|
||||
# Check if the payload already has a system message
|
||||
# If not, add a system message to the payload
|
||||
if payload.get("messages"):
|
||||
for message in payload["messages"]:
|
||||
if message.get("role") == "system":
|
||||
message["content"] = (
|
||||
model_info.params.get("system", None)
|
||||
+ message["content"]
|
||||
)
|
||||
break
|
||||
else:
|
||||
payload["messages"].insert(
|
||||
0,
|
||||
{
|
||||
"role": "system",
|
||||
"content": model_info.params.get("system", None),
|
||||
},
|
||||
)
|
||||
else:
|
||||
pass
|
||||
|
||||
print(app.state.MODELS)
|
||||
model = app.state.MODELS[payload.get("model")]
|
||||
|
||||
idx = model["urlIdx"]
|
||||
|
||||
if "pipeline" in model and model.get("pipeline"):
|
||||
payload["user"] = {"name": user.name, "id": user.id}
|
||||
payload["title"] = (
|
||||
True
|
||||
if payload["stream"] == False and payload["max_tokens"] == 50
|
||||
else False
|
||||
)
|
||||
|
||||
# Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
|
||||
# This is a workaround until OpenAI fixes the issue with this model
|
||||
if payload.get("model") == "gpt-4-vision-preview":
|
||||
if "max_tokens" not in payload:
|
||||
payload["max_tokens"] = 4000
|
||||
log.debug("Modified payload:", payload)
|
||||
|
||||
# Convert the modified body back to JSON
|
||||
payload = json.dumps(payload)
|
||||
|
||||
# Convert the modified body back to JSON
|
||||
body = json.dumps(body)
|
||||
except json.JSONDecodeError as e:
|
||||
log.error("Error loading request body into a dictionary:", e)
|
||||
|
||||
print(payload)
|
||||
|
||||
url = app.state.config.OPENAI_API_BASE_URLS[idx]
|
||||
key = app.state.config.OPENAI_API_KEYS[idx]
|
||||
|
||||
@@ -362,7 +420,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
|
||||
r = requests.request(
|
||||
method=request.method,
|
||||
url=target_url,
|
||||
data=body,
|
||||
data=payload if payload else body,
|
||||
headers=headers,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import json
|
||||
|
||||
from peewee import *
|
||||
from peewee_migrate import Router
|
||||
from playhouse.db_url import connect
|
||||
@@ -8,6 +10,16 @@ import logging
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["DB"])
|
||||
|
||||
|
||||
class JSONField(TextField):
|
||||
def db_value(self, value):
|
||||
return json.dumps(value)
|
||||
|
||||
def python_value(self, value):
|
||||
if value is not None:
|
||||
return json.loads(value)
|
||||
|
||||
|
||||
# Check if the file exists
|
||||
if os.path.exists(f"{DATA_DIR}/ollama.db"):
|
||||
# Rename the file
|
||||
|
||||
61
backend/apps/web/internal/migrations/009_add_models.py
Normal file
61
backend/apps/web/internal/migrations/009_add_models.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""Peewee migrations -- 009_add_models.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
@migrator.create_model
|
||||
class Model(pw.Model):
|
||||
id = pw.TextField(unique=True)
|
||||
user_id = pw.TextField()
|
||||
base_model_id = pw.TextField(null=True)
|
||||
|
||||
name = pw.TextField()
|
||||
|
||||
meta = pw.TextField()
|
||||
params = pw.TextField()
|
||||
|
||||
created_at = pw.BigIntegerField(null=False)
|
||||
updated_at = pw.BigIntegerField(null=False)
|
||||
|
||||
class Meta:
|
||||
table_name = "model"
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
migrator.remove_model("model")
|
||||
@@ -0,0 +1,130 @@
|
||||
"""Peewee migrations -- 009_add_models.py.
|
||||
|
||||
Some examples (model - class or model name)::
|
||||
|
||||
> Model = migrator.orm['table_name'] # Return model in current state by name
|
||||
> Model = migrator.ModelClass # Return model in current state by name
|
||||
|
||||
> migrator.sql(sql) # Run custom SQL
|
||||
> migrator.run(func, *args, **kwargs) # Run python function with the given args
|
||||
> migrator.create_model(Model) # Create a model (could be used as decorator)
|
||||
> migrator.remove_model(model, cascade=True) # Remove a model
|
||||
> migrator.add_fields(model, **fields) # Add fields to a model
|
||||
> migrator.change_fields(model, **fields) # Change fields
|
||||
> migrator.remove_fields(model, *field_names, cascade=True)
|
||||
> migrator.rename_field(model, old_field_name, new_field_name)
|
||||
> migrator.rename_table(model, new_table_name)
|
||||
> migrator.add_index(model, *col_names, unique=False)
|
||||
> migrator.add_not_null(model, *field_names)
|
||||
> migrator.add_default(model, field_name, default)
|
||||
> migrator.add_constraint(model, name, sql)
|
||||
> migrator.drop_index(model, *col_names)
|
||||
> migrator.drop_not_null(model, *field_names)
|
||||
> migrator.drop_constraints(model, *constraints)
|
||||
|
||||
"""
|
||||
|
||||
from contextlib import suppress
|
||||
|
||||
import peewee as pw
|
||||
from peewee_migrate import Migrator
|
||||
import json
|
||||
|
||||
from utils.misc import parse_ollama_modelfile
|
||||
|
||||
with suppress(ImportError):
|
||||
import playhouse.postgres_ext as pw_pext
|
||||
|
||||
|
||||
def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your migrations here."""
|
||||
|
||||
# Fetch data from 'modelfile' table and insert into 'model' table
|
||||
migrate_modelfile_to_model(migrator, database)
|
||||
# Drop the 'modelfile' table
|
||||
migrator.remove_model("modelfile")
|
||||
|
||||
|
||||
def migrate_modelfile_to_model(migrator: Migrator, database: pw.Database):
|
||||
ModelFile = migrator.orm["modelfile"]
|
||||
Model = migrator.orm["model"]
|
||||
|
||||
modelfiles = ModelFile.select()
|
||||
|
||||
for modelfile in modelfiles:
|
||||
# Extract and transform data in Python
|
||||
|
||||
modelfile.modelfile = json.loads(modelfile.modelfile)
|
||||
meta = json.dumps(
|
||||
{
|
||||
"description": modelfile.modelfile.get("desc"),
|
||||
"profile_image_url": modelfile.modelfile.get("imageUrl"),
|
||||
"ollama": {"modelfile": modelfile.modelfile.get("content")},
|
||||
"suggestion_prompts": modelfile.modelfile.get("suggestionPrompts"),
|
||||
"categories": modelfile.modelfile.get("categories"),
|
||||
"user": {**modelfile.modelfile.get("user", {}), "community": True},
|
||||
}
|
||||
)
|
||||
|
||||
info = parse_ollama_modelfile(modelfile.modelfile.get("content"))
|
||||
|
||||
# Insert the processed data into the 'model' table
|
||||
Model.create(
|
||||
id=f"ollama-{modelfile.tag_name}",
|
||||
user_id=modelfile.user_id,
|
||||
base_model_id=info.get("base_model_id"),
|
||||
name=modelfile.modelfile.get("title"),
|
||||
meta=meta,
|
||||
params=json.dumps(info.get("params", {})),
|
||||
created_at=modelfile.timestamp,
|
||||
updated_at=modelfile.timestamp,
|
||||
)
|
||||
|
||||
|
||||
def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
|
||||
"""Write your rollback migrations here."""
|
||||
|
||||
recreate_modelfile_table(migrator, database)
|
||||
move_data_back_to_modelfile(migrator, database)
|
||||
migrator.remove_model("model")
|
||||
|
||||
|
||||
def recreate_modelfile_table(migrator: Migrator, database: pw.Database):
|
||||
query = """
|
||||
CREATE TABLE IF NOT EXISTS modelfile (
|
||||
user_id TEXT,
|
||||
tag_name TEXT,
|
||||
modelfile JSON,
|
||||
timestamp BIGINT
|
||||
)
|
||||
"""
|
||||
migrator.sql(query)
|
||||
|
||||
|
||||
def move_data_back_to_modelfile(migrator: Migrator, database: pw.Database):
|
||||
Model = migrator.orm["model"]
|
||||
Modelfile = migrator.orm["modelfile"]
|
||||
|
||||
models = Model.select()
|
||||
|
||||
for model in models:
|
||||
# Extract and transform data in Python
|
||||
meta = json.loads(model.meta)
|
||||
|
||||
modelfile_data = {
|
||||
"title": model.name,
|
||||
"desc": meta.get("description"),
|
||||
"imageUrl": meta.get("profile_image_url"),
|
||||
"content": meta.get("ollama", {}).get("modelfile"),
|
||||
"suggestionPrompts": meta.get("suggestion_prompts"),
|
||||
"categories": meta.get("categories"),
|
||||
"user": {k: v for k, v in meta.get("user", {}).items() if k != "community"},
|
||||
}
|
||||
|
||||
# Insert the processed data back into the 'modelfile' table
|
||||
Modelfile.create(
|
||||
user_id=model.user_id,
|
||||
tag_name=model.id,
|
||||
modelfile=modelfile_data,
|
||||
timestamp=model.created_at,
|
||||
)
|
||||
@@ -6,7 +6,7 @@ from apps.web.routers import (
|
||||
users,
|
||||
chats,
|
||||
documents,
|
||||
modelfiles,
|
||||
models,
|
||||
prompts,
|
||||
configs,
|
||||
memories,
|
||||
@@ -40,6 +40,9 @@ app.state.config.DEFAULT_PROMPT_SUGGESTIONS = DEFAULT_PROMPT_SUGGESTIONS
|
||||
app.state.config.DEFAULT_USER_ROLE = DEFAULT_USER_ROLE
|
||||
app.state.config.USER_PERMISSIONS = USER_PERMISSIONS
|
||||
app.state.config.WEBHOOK_URL = WEBHOOK_URL
|
||||
|
||||
|
||||
app.state.MODELS = {}
|
||||
app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
|
||||
|
||||
|
||||
@@ -56,11 +59,10 @@ app.include_router(users.router, prefix="/users", tags=["users"])
|
||||
app.include_router(chats.router, prefix="/chats", tags=["chats"])
|
||||
|
||||
app.include_router(documents.router, prefix="/documents", tags=["documents"])
|
||||
app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"])
|
||||
app.include_router(models.router, prefix="/models", tags=["models"])
|
||||
app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
|
||||
app.include_router(memories.router, prefix="/memories", tags=["memories"])
|
||||
|
||||
|
||||
app.include_router(configs.router, prefix="/configs", tags=["configs"])
|
||||
app.include_router(utils.router, prefix="/utils", tags=["utils"])
|
||||
|
||||
|
||||
@@ -1,3 +1,11 @@
|
||||
################################################################################
|
||||
# DEPRECATION NOTICE #
|
||||
# #
|
||||
# This file has been deprecated since version 0.2.0. #
|
||||
# #
|
||||
################################################################################
|
||||
|
||||
|
||||
from pydantic import BaseModel
|
||||
from peewee import *
|
||||
from playhouse.shortcuts import model_to_dict
|
||||
|
||||
179
backend/apps/web/models/models.py
Normal file
179
backend/apps/web/models/models.py
Normal file
@@ -0,0 +1,179 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import peewee as pw
|
||||
from peewee import *
|
||||
|
||||
from playhouse.shortcuts import model_to_dict
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from apps.web.internal.db import DB, JSONField
|
||||
|
||||
from typing import List, Union, Optional
|
||||
from config import SRC_LOG_LEVELS
|
||||
|
||||
import time
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
log.setLevel(SRC_LOG_LEVELS["MODELS"])
|
||||
|
||||
|
||||
####################
|
||||
# Models DB Schema
|
||||
####################
|
||||
|
||||
|
||||
# ModelParams is a model for the data stored in the params field of the Model table
|
||||
class ModelParams(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
pass
|
||||
|
||||
|
||||
# ModelMeta is a model for the data stored in the meta field of the Model table
|
||||
class ModelMeta(BaseModel):
|
||||
profile_image_url: Optional[str] = "/favicon.png"
|
||||
|
||||
description: Optional[str] = None
|
||||
"""
|
||||
User-facing description of the model.
|
||||
"""
|
||||
|
||||
capabilities: Optional[dict] = None
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class Model(pw.Model):
|
||||
id = pw.TextField(unique=True)
|
||||
"""
|
||||
The model's id as used in the API. If set to an existing model, it will override the model.
|
||||
"""
|
||||
user_id = pw.TextField()
|
||||
|
||||
base_model_id = pw.TextField(null=True)
|
||||
"""
|
||||
An optional pointer to the actual model that should be used when proxying requests.
|
||||
"""
|
||||
|
||||
name = pw.TextField()
|
||||
"""
|
||||
The human-readable display name of the model.
|
||||
"""
|
||||
|
||||
params = JSONField()
|
||||
"""
|
||||
Holds a JSON encoded blob of parameters, see `ModelParams`.
|
||||
"""
|
||||
|
||||
meta = JSONField()
|
||||
"""
|
||||
Holds a JSON encoded blob of metadata, see `ModelMeta`.
|
||||
"""
|
||||
|
||||
updated_at = BigIntegerField()
|
||||
created_at = BigIntegerField()
|
||||
|
||||
class Meta:
|
||||
database = DB
|
||||
|
||||
|
||||
class ModelModel(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
base_model_id: Optional[str] = None
|
||||
|
||||
name: str
|
||||
params: ModelParams
|
||||
meta: ModelMeta
|
||||
|
||||
updated_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
|
||||
####################
|
||||
# Forms
|
||||
####################
|
||||
|
||||
|
||||
class ModelResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
meta: ModelMeta
|
||||
updated_at: int # timestamp in epoch
|
||||
created_at: int # timestamp in epoch
|
||||
|
||||
|
||||
class ModelForm(BaseModel):
|
||||
id: str
|
||||
base_model_id: Optional[str] = None
|
||||
name: str
|
||||
meta: ModelMeta
|
||||
params: ModelParams
|
||||
|
||||
|
||||
class ModelsTable:
|
||||
def __init__(
|
||||
self,
|
||||
db: pw.SqliteDatabase | pw.PostgresqlDatabase,
|
||||
):
|
||||
self.db = db
|
||||
self.db.create_tables([Model])
|
||||
|
||||
def insert_new_model(
|
||||
self, form_data: ModelForm, user_id: str
|
||||
) -> Optional[ModelModel]:
|
||||
model = ModelModel(
|
||||
**{
|
||||
**form_data.model_dump(),
|
||||
"user_id": user_id,
|
||||
"created_at": int(time.time()),
|
||||
"updated_at": int(time.time()),
|
||||
}
|
||||
)
|
||||
try:
|
||||
result = Model.create(**model.model_dump())
|
||||
|
||||
if result:
|
||||
return model
|
||||
else:
|
||||
return None
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
def get_all_models(self) -> List[ModelModel]:
|
||||
return [ModelModel(**model_to_dict(model)) for model in Model.select()]
|
||||
|
||||
def get_model_by_id(self, id: str) -> Optional[ModelModel]:
|
||||
try:
|
||||
model = Model.get(Model.id == id)
|
||||
return ModelModel(**model_to_dict(model))
|
||||
except:
|
||||
return None
|
||||
|
||||
def update_model_by_id(self, id: str, model: ModelForm) -> Optional[ModelModel]:
|
||||
try:
|
||||
# update only the fields that are present in the model
|
||||
query = Model.update(**model.model_dump()).where(Model.id == id)
|
||||
query.execute()
|
||||
|
||||
model = Model.get(Model.id == id)
|
||||
return ModelModel(**model_to_dict(model))
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
return None
|
||||
|
||||
def delete_model_by_id(self, id: str) -> bool:
|
||||
try:
|
||||
query = Model.delete().where(Model.id == id)
|
||||
query.execute()
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
Models = ModelsTable(DB)
|
||||
@@ -1,124 +0,0 @@
|
||||
from fastapi import Depends, FastAPI, HTTPException, status
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Union, Optional
|
||||
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
import json
|
||||
from apps.web.models.modelfiles import (
|
||||
Modelfiles,
|
||||
ModelfileForm,
|
||||
ModelfileTagNameForm,
|
||||
ModelfileUpdateForm,
|
||||
ModelfileResponse,
|
||||
)
|
||||
|
||||
from utils.utils import get_current_user, get_admin_user
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
############################
|
||||
# GetModelfiles
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/", response_model=List[ModelfileResponse])
|
||||
async def get_modelfiles(
|
||||
skip: int = 0, limit: int = 50, user=Depends(get_current_user)
|
||||
):
|
||||
return Modelfiles.get_modelfiles(skip, limit)
|
||||
|
||||
|
||||
############################
|
||||
# CreateNewModelfile
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/create", response_model=Optional[ModelfileResponse])
|
||||
async def create_new_modelfile(form_data: ModelfileForm, user=Depends(get_admin_user)):
|
||||
modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
|
||||
|
||||
if modelfile:
|
||||
return ModelfileResponse(
|
||||
**{
|
||||
**modelfile.model_dump(),
|
||||
"modelfile": json.loads(modelfile.modelfile),
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.DEFAULT(),
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetModelfileByTagName
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/", response_model=Optional[ModelfileResponse])
|
||||
async def get_modelfile_by_tag_name(
|
||||
form_data: ModelfileTagNameForm, user=Depends(get_current_user)
|
||||
):
|
||||
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
|
||||
|
||||
if modelfile:
|
||||
return ModelfileResponse(
|
||||
**{
|
||||
**modelfile.model_dump(),
|
||||
"modelfile": json.loads(modelfile.modelfile),
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# UpdateModelfileByTagName
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/update", response_model=Optional[ModelfileResponse])
|
||||
async def update_modelfile_by_tag_name(
|
||||
form_data: ModelfileUpdateForm, user=Depends(get_admin_user)
|
||||
):
|
||||
modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
|
||||
if modelfile:
|
||||
updated_modelfile = {
|
||||
**json.loads(modelfile.modelfile),
|
||||
**form_data.modelfile,
|
||||
}
|
||||
|
||||
modelfile = Modelfiles.update_modelfile_by_tag_name(
|
||||
form_data.tag_name, updated_modelfile
|
||||
)
|
||||
|
||||
return ModelfileResponse(
|
||||
**{
|
||||
**modelfile.model_dump(),
|
||||
"modelfile": json.loads(modelfile.modelfile),
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# DeleteModelfileByTagName
|
||||
############################
|
||||
|
||||
|
||||
@router.delete("/delete", response_model=bool)
|
||||
async def delete_modelfile_by_tag_name(
|
||||
form_data: ModelfileTagNameForm, user=Depends(get_admin_user)
|
||||
):
|
||||
result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
|
||||
return result
|
||||
108
backend/apps/web/routers/models.py
Normal file
108
backend/apps/web/routers/models.py
Normal file
@@ -0,0 +1,108 @@
|
||||
from fastapi import Depends, FastAPI, HTTPException, status, Request
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Union, Optional
|
||||
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
import json
|
||||
from apps.web.models.models import Models, ModelModel, ModelForm, ModelResponse
|
||||
|
||||
from utils.utils import get_verified_user, get_admin_user
|
||||
from constants import ERROR_MESSAGES
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
###########################
|
||||
# getModels
|
||||
###########################
|
||||
|
||||
|
||||
@router.get("/", response_model=List[ModelResponse])
|
||||
async def get_models(user=Depends(get_verified_user)):
|
||||
return Models.get_all_models()
|
||||
|
||||
|
||||
############################
|
||||
# AddNewModel
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/add", response_model=Optional[ModelModel])
|
||||
async def add_new_model(
|
||||
request: Request, form_data: ModelForm, user=Depends(get_admin_user)
|
||||
):
|
||||
if form_data.id in request.app.state.MODELS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.MODEL_ID_TAKEN,
|
||||
)
|
||||
else:
|
||||
model = Models.insert_new_model(form_data, user.id)
|
||||
|
||||
if model:
|
||||
return model
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.DEFAULT(),
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# GetModelById
|
||||
############################
|
||||
|
||||
|
||||
@router.get("/{id}", response_model=Optional[ModelModel])
|
||||
async def get_model_by_id(id: str, user=Depends(get_verified_user)):
|
||||
model = Models.get_model_by_id(id)
|
||||
|
||||
if model:
|
||||
return model
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.NOT_FOUND,
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# UpdateModelById
|
||||
############################
|
||||
|
||||
|
||||
@router.post("/{id}/update", response_model=Optional[ModelModel])
|
||||
async def update_model_by_id(
|
||||
request: Request, id: str, form_data: ModelForm, user=Depends(get_admin_user)
|
||||
):
|
||||
model = Models.get_model_by_id(id)
|
||||
if model:
|
||||
model = Models.update_model_by_id(id, form_data)
|
||||
return model
|
||||
else:
|
||||
if form_data.id in request.app.state.MODELS:
|
||||
model = Models.insert_new_model(form_data, user.id)
|
||||
print(model)
|
||||
if model:
|
||||
return model
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.DEFAULT(),
|
||||
)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=ERROR_MESSAGES.DEFAULT(),
|
||||
)
|
||||
|
||||
|
||||
############################
|
||||
# DeleteModelById
|
||||
############################
|
||||
|
||||
|
||||
@router.delete("/{id}/delete", response_model=bool)
|
||||
async def delete_model_by_id(id: str, user=Depends(get_admin_user)):
|
||||
result = Models.delete_model_by_id(id)
|
||||
return result
|
||||
Reference in New Issue
Block a user