From e91a49c455580456bf35a674ec96c007ccd8d611 Mon Sep 17 00:00:00 2001 From: Jun Siang Cheah Date: Wed, 24 Apr 2024 18:10:18 +0100 Subject: [PATCH] feat: add support for using postgres for the backend DB --- backend/apps/litellm/main.py | 13 +- backend/apps/web/internal/db.py | 9 +- .../internal/migrations/001_initial_schema.py | 105 ++++++++++++++ .../internal/migrations/005_add_updated_at.py | 53 +++++++ .../006_migrate_timestamps_and_charfields.py | 130 ++++++++++++++++++ backend/apps/web/models/auths.py | 2 +- backend/apps/web/models/chats.py | 6 +- backend/apps/web/models/documents.py | 6 +- backend/apps/web/models/modelfiles.py | 2 +- backend/apps/web/models/prompts.py | 4 +- backend/apps/web/models/tags.py | 2 +- backend/apps/web/models/users.py | 4 +- backend/apps/web/routers/auths.py | 2 + backend/config.py | 7 + backend/requirements.txt | 2 + 15 files changed, 329 insertions(+), 18 deletions(-) create mode 100644 backend/apps/web/internal/migrations/006_migrate_timestamps_and_charfields.py diff --git a/backend/apps/litellm/main.py b/backend/apps/litellm/main.py index 547bd80ee..f17a1bbca 100644 --- a/backend/apps/litellm/main.py +++ b/backend/apps/litellm/main.py @@ -21,6 +21,8 @@ from utils.utils import get_verified_user, get_current_user, get_admin_user from config import SRC_LOG_LEVELS, ENV from constants import MESSAGES +import os + log = logging.getLogger(__name__) log.setLevel(SRC_LOG_LEVELS["LITELLM"]) @@ -62,6 +64,13 @@ app.state.CONFIG = litellm_config # Global variable to store the subprocess reference background_process = None +CONFLICT_ENV_VARS = [ + # Uvicorn uses PORT, so LiteLLM might use it as well + "PORT", + # LiteLLM uses DATABASE_URL for Prisma connections + "DATABASE_URL", +] + async def run_background_process(command): global background_process @@ -70,9 +79,11 @@ async def run_background_process(command): try: # Log the command to be executed log.info(f"Executing command: {command}") + # Filter environment variables known to conflict with litellm + env = {k: v for k, v in os.environ.items() if k not in CONFLICT_ENV_VARS} # Execute the command and create a subprocess process = await asyncio.create_subprocess_exec( - *command, stdout=subprocess.PIPE, stderr=subprocess.PIPE + *command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=env ) background_process = process log.info("Subprocess started successfully.") diff --git a/backend/apps/web/internal/db.py b/backend/apps/web/internal/db.py index fad566ce9..136e3fafc 100644 --- a/backend/apps/web/internal/db.py +++ b/backend/apps/web/internal/db.py @@ -1,6 +1,7 @@ from peewee import * from peewee_migrate import Router -from config import SRC_LOG_LEVELS, DATA_DIR +from playhouse.db_url import connect +from config import SRC_LOG_LEVELS, DATA_DIR, DATABASE_URL import os import logging @@ -11,12 +12,12 @@ log.setLevel(SRC_LOG_LEVELS["DB"]) if os.path.exists(f"{DATA_DIR}/ollama.db"): # Rename the file os.rename(f"{DATA_DIR}/ollama.db", f"{DATA_DIR}/webui.db") - log.info("File renamed successfully.") + log.info("Database migrated from Ollama-WebUI successfully.") else: pass - -DB = SqliteDatabase(f"{DATA_DIR}/webui.db") +DB = connect(DATABASE_URL) +log.info(f"Connected to a {DB.__class__.__name__} database.") router = Router(DB, migrate_dir="apps/web/internal/migrations", logger=log) router.run() DB.connect(reuse_if_open=True) diff --git a/backend/apps/web/internal/migrations/001_initial_schema.py b/backend/apps/web/internal/migrations/001_initial_schema.py index 24ea6d39f..77788fff9 100644 --- a/backend/apps/web/internal/migrations/001_initial_schema.py +++ b/backend/apps/web/internal/migrations/001_initial_schema.py @@ -37,6 +37,18 @@ with suppress(ImportError): def migrate(migrator: Migrator, database: pw.Database, *, fake=False): """Write your migrations here.""" + # We perform different migrations for SQLite and other databases + # This is because SQLite is very loose with enforcing its schema, and trying to migrate other databases like SQLite + # will require per-database SQL queries. + # Instead, we assume that because external DB support was added at a later date, it is safe to assume a newer base + # schema instead of trying to migrate from an older schema. + if isinstance(database, pw.SqliteDatabase): + migrate_sqlite(migrator, database, fake=fake) + else: + migrate_external(migrator, database, fake=fake) + + +def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False): @migrator.create_model class Auth(pw.Model): id = pw.CharField(max_length=255, unique=True) @@ -129,6 +141,99 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False): table_name = "user" +def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False): + @migrator.create_model + class Auth(pw.Model): + id = pw.CharField(max_length=255, unique=True) + email = pw.CharField(max_length=255) + password = pw.TextField() + active = pw.BooleanField() + + class Meta: + table_name = "auth" + + @migrator.create_model + class Chat(pw.Model): + id = pw.CharField(max_length=255, unique=True) + user_id = pw.CharField(max_length=255) + title = pw.TextField() + chat = pw.TextField() + timestamp = pw.BigIntegerField() + + class Meta: + table_name = "chat" + + @migrator.create_model + class ChatIdTag(pw.Model): + id = pw.CharField(max_length=255, unique=True) + tag_name = pw.CharField(max_length=255) + chat_id = pw.CharField(max_length=255) + user_id = pw.CharField(max_length=255) + timestamp = pw.BigIntegerField() + + class Meta: + table_name = "chatidtag" + + @migrator.create_model + class Document(pw.Model): + id = pw.AutoField() + collection_name = pw.CharField(max_length=255, unique=True) + name = pw.CharField(max_length=255, unique=True) + title = pw.TextField() + filename = pw.TextField() + content = pw.TextField(null=True) + user_id = pw.CharField(max_length=255) + timestamp = pw.BigIntegerField() + + class Meta: + table_name = "document" + + @migrator.create_model + class Modelfile(pw.Model): + id = pw.AutoField() + tag_name = pw.CharField(max_length=255, unique=True) + user_id = pw.CharField(max_length=255) + modelfile = pw.TextField() + timestamp = pw.BigIntegerField() + + class Meta: + table_name = "modelfile" + + @migrator.create_model + class Prompt(pw.Model): + id = pw.AutoField() + command = pw.CharField(max_length=255, unique=True) + user_id = pw.CharField(max_length=255) + title = pw.TextField() + content = pw.TextField() + timestamp = pw.BigIntegerField() + + class Meta: + table_name = "prompt" + + @migrator.create_model + class Tag(pw.Model): + id = pw.CharField(max_length=255, unique=True) + name = pw.CharField(max_length=255) + user_id = pw.CharField(max_length=255) + data = pw.TextField(null=True) + + class Meta: + table_name = "tag" + + @migrator.create_model + class User(pw.Model): + id = pw.CharField(max_length=255, unique=True) + name = pw.CharField(max_length=255) + email = pw.CharField(max_length=255) + role = pw.CharField(max_length=255) + profile_image_url = pw.TextField() + timestamp = pw.BigIntegerField() + + class Meta: + table_name = "user" + + def rollback(migrator: Migrator, database: pw.Database, *, fake=False): """Write your rollback migrations here.""" diff --git a/backend/apps/web/internal/migrations/005_add_updated_at.py b/backend/apps/web/internal/migrations/005_add_updated_at.py index 63a023cdb..950866ef0 100644 --- a/backend/apps/web/internal/migrations/005_add_updated_at.py +++ b/backend/apps/web/internal/migrations/005_add_updated_at.py @@ -37,6 +37,13 @@ with suppress(ImportError): def migrate(migrator: Migrator, database: pw.Database, *, fake=False): """Write your migrations here.""" + if isinstance(database, pw.SqliteDatabase): + migrate_sqlite(migrator, database, fake=fake) + else: + migrate_external(migrator, database, fake=fake) + + +def migrate_sqlite(migrator: Migrator, database: pw.Database, *, fake=False): # Adding fields created_at and updated_at to the 'chat' table migrator.add_fields( "chat", @@ -60,9 +67,40 @@ def migrate(migrator: Migrator, database: pw.Database, *, fake=False): ) +def migrate_external(migrator: Migrator, database: pw.Database, *, fake=False): + # Adding fields created_at and updated_at to the 'chat' table + migrator.add_fields( + "chat", + created_at=pw.BigIntegerField(null=True), # Allow null for transition + updated_at=pw.BigIntegerField(null=True), # Allow null for transition + ) + + # Populate the new fields from an existing 'timestamp' field + migrator.sql( + "UPDATE chat SET created_at = timestamp, updated_at = timestamp WHERE timestamp IS NOT NULL" + ) + + # Now that the data has been copied, remove the original 'timestamp' field + migrator.remove_fields("chat", "timestamp") + + # Update the fields to be not null now that they are populated + migrator.change_fields( + "chat", + created_at=pw.BigIntegerField(null=False), + updated_at=pw.BigIntegerField(null=False), + ) + + def rollback(migrator: Migrator, database: pw.Database, *, fake=False): """Write your rollback migrations here.""" + if isinstance(database, pw.SqliteDatabase): + rollback_sqlite(migrator, database, fake=fake) + else: + rollback_external(migrator, database, fake=fake) + + +def rollback_sqlite(migrator: Migrator, database: pw.Database, *, fake=False): # Recreate the timestamp field initially allowing null values for safe transition migrator.add_fields("chat", timestamp=pw.DateTimeField(null=True)) @@ -75,3 +113,18 @@ def rollback(migrator: Migrator, database: pw.Database, *, fake=False): # Finally, alter the timestamp field to not allow nulls if that was the original setting migrator.change_fields("chat", timestamp=pw.DateTimeField(null=False)) + + +def rollback_external(migrator: Migrator, database: pw.Database, *, fake=False): + # Recreate the timestamp field initially allowing null values for safe transition + migrator.add_fields("chat", timestamp=pw.BigIntegerField(null=True)) + + # Copy the earliest created_at date back into the new timestamp field + # This assumes created_at was originally a copy of timestamp + migrator.sql("UPDATE chat SET timestamp = created_at") + + # Remove the created_at and updated_at fields + migrator.remove_fields("chat", "created_at", "updated_at") + + # Finally, alter the timestamp field to not allow nulls if that was the original setting + migrator.change_fields("chat", timestamp=pw.BigIntegerField(null=False)) diff --git a/backend/apps/web/internal/migrations/006_migrate_timestamps_and_charfields.py b/backend/apps/web/internal/migrations/006_migrate_timestamps_and_charfields.py new file mode 100644 index 000000000..caca14d32 --- /dev/null +++ b/backend/apps/web/internal/migrations/006_migrate_timestamps_and_charfields.py @@ -0,0 +1,130 @@ +"""Peewee migrations -- 006_migrate_timestamps_and_charfields.py. + +Some examples (model - class or model name):: + + > Model = migrator.orm['table_name'] # Return model in current state by name + > Model = migrator.ModelClass # Return model in current state by name + + > migrator.sql(sql) # Run custom SQL + > migrator.run(func, *args, **kwargs) # Run python function with the given args + > migrator.create_model(Model) # Create a model (could be used as decorator) + > migrator.remove_model(model, cascade=True) # Remove a model + > migrator.add_fields(model, **fields) # Add fields to a model + > migrator.change_fields(model, **fields) # Change fields + > migrator.remove_fields(model, *field_names, cascade=True) + > migrator.rename_field(model, old_field_name, new_field_name) + > migrator.rename_table(model, new_table_name) + > migrator.add_index(model, *col_names, unique=False) + > migrator.add_not_null(model, *field_names) + > migrator.add_default(model, field_name, default) + > migrator.add_constraint(model, name, sql) + > migrator.drop_index(model, *col_names) + > migrator.drop_not_null(model, *field_names) + > migrator.drop_constraints(model, *constraints) + +""" + +from contextlib import suppress + +import peewee as pw +from peewee_migrate import Migrator + + +with suppress(ImportError): + import playhouse.postgres_ext as pw_pext + + +def migrate(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your migrations here.""" + + # Alter the tables with timestamps + migrator.change_fields( + "chatidtag", + timestamp=pw.BigIntegerField(), + ) + migrator.change_fields( + "document", + timestamp=pw.BigIntegerField(), + ) + migrator.change_fields( + "modelfile", + timestamp=pw.BigIntegerField(), + ) + migrator.change_fields( + "prompt", + timestamp=pw.BigIntegerField(), + ) + migrator.change_fields( + "user", + timestamp=pw.BigIntegerField(), + ) + # Alter the tables with varchar to text where necessary + migrator.change_fields( + "auth", + password=pw.TextField(), + ) + migrator.change_fields( + "chat", + title=pw.TextField(), + ) + migrator.change_fields( + "document", + title=pw.TextField(), + filename=pw.TextField(), + ) + migrator.change_fields( + "prompt", + title=pw.TextField(), + ) + migrator.change_fields( + "user", + profile_image_url=pw.TextField(), + ) + + +def rollback(migrator: Migrator, database: pw.Database, *, fake=False): + """Write your rollback migrations here.""" + + if isinstance(database, pw.SqliteDatabase): + # Alter the tables with timestamps + migrator.change_fields( + "chatidtag", + timestamp=pw.DateField(), + ) + migrator.change_fields( + "document", + timestamp=pw.DateField(), + ) + migrator.change_fields( + "modelfile", + timestamp=pw.DateField(), + ) + migrator.change_fields( + "prompt", + timestamp=pw.DateField(), + ) + migrator.change_fields( + "user", + timestamp=pw.DateField(), + ) + migrator.change_fields( + "auth", + password=pw.CharField(max_length=255), + ) + migrator.change_fields( + "chat", + title=pw.CharField(), + ) + migrator.change_fields( + "document", + title=pw.CharField(), + filename=pw.CharField(), + ) + migrator.change_fields( + "prompt", + title=pw.CharField(), + ) + migrator.change_fields( + "user", + profile_image_url=pw.CharField(), + ) diff --git a/backend/apps/web/models/auths.py b/backend/apps/web/models/auths.py index a97312ff9..9c4e5ffed 100644 --- a/backend/apps/web/models/auths.py +++ b/backend/apps/web/models/auths.py @@ -23,7 +23,7 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) class Auth(Model): id = CharField(unique=True) email = CharField() - password = CharField() + password = TextField() active = BooleanField() class Meta: diff --git a/backend/apps/web/models/chats.py b/backend/apps/web/models/chats.py index ea7fb355d..a2ea7becc 100644 --- a/backend/apps/web/models/chats.py +++ b/backend/apps/web/models/chats.py @@ -17,11 +17,11 @@ from apps.web.internal.db import DB class Chat(Model): id = CharField(unique=True) user_id = CharField() - title = CharField() + title = TextField() chat = TextField() # Save Chat JSON as Text - created_at = DateTimeField() - updated_at = DateTimeField() + created_at = BigIntegerField() + updated_at = BigIntegerField() share_id = CharField(null=True, unique=True) archived = BooleanField(default=False) diff --git a/backend/apps/web/models/documents.py b/backend/apps/web/models/documents.py index 91e721a48..42b99596c 100644 --- a/backend/apps/web/models/documents.py +++ b/backend/apps/web/models/documents.py @@ -25,11 +25,11 @@ log.setLevel(SRC_LOG_LEVELS["MODELS"]) class Document(Model): collection_name = CharField(unique=True) name = CharField(unique=True) - title = CharField() - filename = CharField() + title = TextField() + filename = TextField() content = TextField(null=True) user_id = CharField() - timestamp = DateField() + timestamp = BigIntegerField() class Meta: database = DB diff --git a/backend/apps/web/models/modelfiles.py b/backend/apps/web/models/modelfiles.py index 50439a808..1d60d7c55 100644 --- a/backend/apps/web/models/modelfiles.py +++ b/backend/apps/web/models/modelfiles.py @@ -20,7 +20,7 @@ class Modelfile(Model): tag_name = CharField(unique=True) user_id = CharField() modelfile = TextField() - timestamp = DateField() + timestamp = BigIntegerField() class Meta: database = DB diff --git a/backend/apps/web/models/prompts.py b/backend/apps/web/models/prompts.py index e6b663c04..bc4e3e58b 100644 --- a/backend/apps/web/models/prompts.py +++ b/backend/apps/web/models/prompts.py @@ -19,9 +19,9 @@ import json class Prompt(Model): command = CharField(unique=True) user_id = CharField() - title = CharField() + title = TextField() content = TextField() - timestamp = DateField() + timestamp = BigIntegerField() class Meta: database = DB diff --git a/backend/apps/web/models/tags.py b/backend/apps/web/models/tags.py index 02de5b9d7..d9a967ff7 100644 --- a/backend/apps/web/models/tags.py +++ b/backend/apps/web/models/tags.py @@ -35,7 +35,7 @@ class ChatIdTag(Model): tag_name = CharField() chat_id = CharField() user_id = CharField() - timestamp = DateField() + timestamp = BigIntegerField() class Meta: database = DB diff --git a/backend/apps/web/models/users.py b/backend/apps/web/models/users.py index 7d1e182da..2d228d020 100644 --- a/backend/apps/web/models/users.py +++ b/backend/apps/web/models/users.py @@ -18,8 +18,8 @@ class User(Model): name = CharField() email = CharField() role = CharField() - profile_image_url = CharField() - timestamp = DateField() + profile_image_url = TextField() + timestamp = BigIntegerField() api_key = CharField(null=True, unique=True) class Meta: diff --git a/backend/apps/web/routers/auths.py b/backend/apps/web/routers/auths.py index 89d8c1c8f..321b26034 100644 --- a/backend/apps/web/routers/auths.py +++ b/backend/apps/web/routers/auths.py @@ -1,3 +1,5 @@ +import logging + from fastapi import Request from fastapi import Depends, HTTPException, status diff --git a/backend/config.py b/backend/config.py index 37433d518..5450d4c7d 100644 --- a/backend/config.py +++ b/backend/config.py @@ -534,3 +534,10 @@ LITELLM_PROXY_PORT = int(os.getenv("LITELLM_PROXY_PORT", "14365")) if LITELLM_PROXY_PORT < 0 or LITELLM_PROXY_PORT > 65535: raise ValueError("Invalid port number for LITELLM_PROXY_PORT") LITELLM_PROXY_HOST = os.getenv("LITELLM_PROXY_HOST", "127.0.0.1") + + +#################################### +# Database +#################################### + +DATABASE_URL = os.environ.get("DATABASE_URL", f"sqlite:///{DATA_DIR}/webui.db") diff --git a/backend/requirements.txt b/backend/requirements.txt index 10bcc3b69..336cae17a 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -15,6 +15,8 @@ requests aiohttp peewee peewee-migrate +psycopg2-binary +pymysql bcrypt litellm==1.35.17