2024-06-18 13:03:31 +00:00
|
|
|
import logging
import os
import time
from urllib.parse import urlencode

import docker
import pytest
from docker import DockerClient
from fastapi.testclient import TestClient
from pytest_docker.plugin import get_docker_ip
from sqlalchemy import create_engine, text
|
|
|
|
|
2024-06-24 11:06:15 +00:00
|
|
|
|
2024-06-18 13:03:31 +00:00
|
|
|
# Module-level logger used by the setup/teardown helpers in this file.
log = logging.getLogger(name=__name__)
|
|
|
|
|
|
|
|
|
|
|
|
def get_fast_api_client():
    """Create a ``TestClient`` wrapping the ``app`` object from ``main``.

    ``main`` is imported inside the function rather than at module level.
    The caller in this file sets ``DATABASE_URL`` first and notes that the
    import must come after that.

    The client is entered as a context manager and exited again before it is
    returned. NOTE(review): presumably this means the app's startup and
    shutdown hooks have both already run when the caller receives the
    client. Confirm that later requests do not depend on lifespan state.
    """
    from main import app

    with TestClient(app) as client:
        pass
    return client
|
|
|
|
|
|
|
|
|
|
|
|
class AbstractIntegrationTest:
    """Base class for integration tests that build request URLs from a prefix.

    Subclasses set ``BASE_PATH`` and use :meth:`create_url` to build request
    paths. The ``setup_*``/``teardown_*`` hooks are no-ops here so that
    subclasses can always call ``super()``.
    """

    # URL prefix of the endpoints under test; subclasses must override it.
    BASE_PATH = None

    def create_url(self, path="", query_params=None):
        """Join ``BASE_PATH`` and ``path`` and append an encoded query string.

        Empty and whitespace-only segments are dropped. As a result, leading,
        trailing and doubled slashes collapse, and the result has no leading
        slash.

        Args:
            path: Path below ``BASE_PATH``.
            query_params: Optional mapping of query parameters. Keys and
                values are converted with ``str`` and percent-encoded.

        Returns:
            The relative URL, e.g. ``"api/v1/chats/abc?q=1"``.

        Raises:
            Exception: If ``BASE_PATH`` has not been set.
        """
        if self.BASE_PATH is None:
            raise Exception("BASE_PATH is not set")

        def _segments(raw):
            # Split on "/" and keep only non-blank, stripped segments.
            return [part.strip() for part in raw.split("/") if part.strip()]

        url = "/".join(_segments(self.BASE_PATH) + _segments(path))
        if query_params:
            # urlencode escapes reserved characters (&, =, +, #, space, ...)
            # that would otherwise break or be misread in the query string.
            url += "?" + urlencode(query_params)
        return url

    @classmethod
    def setup_class(cls):
        """Hook run once before the tests of a class; no-op by default."""

    def setup_method(self):
        """Hook run before each test method; no-op by default."""

    @classmethod
    def teardown_class(cls):
        """Hook run once after the tests of a class; no-op by default."""

    def teardown_method(self):
        """Hook run after each test method; no-op by default."""
|
|
|
|
|
|
|
|
|
|
|
|
class AbstractPostgresTest(AbstractIntegrationTest):
    """Integration-test base that runs the app against a throwaway Postgres.

    ``setup_class`` starts a ``postgres:16.2`` Docker container, points
    ``DATABASE_URL`` at it, waits until it accepts connections and then
    creates ``cls.fast_api_client``. ``teardown_class`` force-removes the
    container. ``teardown_method`` truncates the listed application tables
    after every test.
    """

    DOCKER_CONTAINER_NAME = "postgres-test-container-will-get-deleted"
    # Host port the container's 5432 is published on. It is used both when
    # starting the container and when building DATABASE_URL, so the two
    # cannot drift apart.
    DOCKER_HOST_PORT = 8081

    docker_client: DockerClient
    fast_api_client: TestClient

    @classmethod
    def _create_db_url(cls, env_vars_postgres: dict) -> str:
        """Return the SQLAlchemy URL of the test container.

        Args:
            env_vars_postgres: Container environment; must contain
                ``POSTGRES_USER``, ``POSTGRES_PASSWORD`` and ``POSTGRES_DB``.
        """
        host = get_docker_ip()
        user = env_vars_postgres["POSTGRES_USER"]
        pw = env_vars_postgres["POSTGRES_PASSWORD"]
        port = cls.DOCKER_HOST_PORT
        db = env_vars_postgres["POSTGRES_DB"]
        return f"postgresql://{user}:{pw}@{host}:{port}/{db}"

    @classmethod
    def _wait_for_postgres(cls, database_url: str, retries: int = 10) -> bool:
        """Retry until ``config`` imports and a DB connection opens.

        Every attempt's engine is disposed. A failed attempt is logged and
        followed by a 3 s pause.

        Returns:
            True once a connection succeeded, False after ``retries`` failures.
        """
        for _ in range(retries):
            engine = None
            try:
                # Imported inside the retry loop, after DATABASE_URL is set.
                # NOTE(review): the name itself is unused; presumably the
                # import has DB-dependent side effects, which is why its
                # failure is retried. Confirm.
                from config import BACKEND_DIR  # noqa: F401

                engine = create_engine(database_url, pool_pre_ping=True)
                with engine.connect():
                    pass
                log.info("postgres is ready!")
                return True
            except Exception as e:
                log.warning(e)
            finally:
                if engine is not None:
                    engine.dispose()
            time.sleep(3)
        return False

    @classmethod
    def setup_class(cls):
        """Start the Postgres container and create ``cls.fast_api_client``.

        On any failure the environment is torn down again and the test class
        is failed via ``pytest.fail``.
        """
        super().setup_class()
        try:
            env_vars_postgres = {
                "POSTGRES_USER": "user",
                "POSTGRES_PASSWORD": "example",
                "POSTGRES_DB": "openwebui",
            }
            cls.docker_client = docker.from_env()
            cls.docker_client.containers.run(
                "postgres:16.2",
                detach=True,
                environment=env_vars_postgres,
                name=cls.DOCKER_CONTAINER_NAME,
                ports={5432: ("0.0.0.0", cls.DOCKER_HOST_PORT)},
                command="postgres -c log_statement=all",
            )
            time.sleep(0.5)

            database_url = cls._create_db_url(env_vars_postgres)
            os.environ["DATABASE_URL"] = database_url

            if not cls._wait_for_postgres(database_url):
                raise Exception("Could not connect to Postgres")

            # import must be after setting env!
            cls.fast_api_client = get_fast_api_client()
        except Exception as ex:
            log.exception(ex)
            cls.teardown_class()
            pytest.fail(f"Could not setup test environment: {ex}")

    def _check_db_connection(self):
        """Wait until the app's ``Session`` can execute ``SELECT 1``.

        Makes up to 10 attempts, 3 s apart, rolling the session back after
        each failure. Running out of attempts is logged but does not raise.
        """
        from apps.webui.internal.db import Session

        attempts = 10
        for _ in range(attempts):
            try:
                Session.execute(text("SELECT 1"))
                Session.commit()
                return
            except Exception as e:
                Session.rollback()
                log.warning(e)
                time.sleep(3)
        log.error("Database still unreachable after %d attempts", attempts)

    def setup_method(self):
        """Before each test, make sure the database answers queries."""
        super().setup_method()
        self._check_db_connection()

    @classmethod
    def teardown_class(cls) -> None:
        """Force-remove the Postgres container, if one was created.

        This method is also called from ``setup_class``'s error path. It
        therefore tolerates a missing Docker client or container instead of
        raising and hiding the original setup error.
        """
        super().teardown_class()
        docker_client = getattr(cls, "docker_client", None)
        if docker_client is None:
            return
        try:
            docker_client.containers.get(cls.DOCKER_CONTAINER_NAME).remove(force=True)
        except docker.errors.NotFound:
            log.warning("Container %s not found; nothing to remove", cls.DOCKER_CONTAINER_NAME)

    def teardown_method(self):
        """After each test, truncate the application tables listed below."""
        from apps.webui.internal.db import Session

        # End the session's current transaction before truncating.
        # NOTE(review): this commits pending work. An earlier comment here
        # said "rollback everything not yet committed"; confirm which is
        # intended (a rollback would also survive a session left in a failed
        # state).
        Session.commit()

        tables = [
            "auth",
            "chat",
            "chatidtag",
            "document",
            "memory",
            "model",
            "prompt",
            "tag",
            '"user"',  # quoted: user is a reserved word in PostgreSQL
        ]
        # One statement for all tables, so references between them cannot
        # block the truncation.
        Session.execute(text(f"TRUNCATE TABLE {', '.join(tables)}"))
        Session.commit()
|