diff --git a/backend/config.py b/backend/config.py index 60d0c563c..4dde80eda 100644 --- a/backend/config.py +++ b/backend/config.py @@ -94,7 +94,7 @@ class Config(Base): updated_at = Column(DateTime, nullable=True, onupdate=func.now()) -def load_initial_config(): +def load_json_config(): with open(f"{DATA_DIR}/config.json", "r") as file: return json.load(file) @@ -107,12 +107,14 @@ def save_to_db(data): db.add(new_config) else: existing_config.data = data + existing_config.updated_at = datetime.now() + db.add(existing_config) db.commit() # When initializing, check if config.json exists and migrate it to the database if os.path.exists(f"{DATA_DIR}/config.json"): - data = load_initial_config() + data = load_json_config() save_to_db(data) os.rename(f"{DATA_DIR}/config.json", f"{DATA_DIR}/old_config.json") @@ -125,6 +127,15 @@ def save_config(): log.exception(e) +def get_config(): + with get_db() as db: + config_entry = db.query(Config).order_by(Config.id.desc()).first() + return config_entry.data if config_entry else {} + + +CONFIG_DATA = get_config() + + def get_config_value(config_path: str): path_parts = config_path.split(".") cur_config = CONFIG_DATA @@ -144,7 +155,7 @@ class PersistentConfig(Generic[T]): self.env_name = env_name self.config_path = config_path self.env_value = env_value - self.config_value = self.load_latest_config_value(config_path) + self.config_value = get_config_value(config_path) if self.config_value is not None: log.info(f"'{env_name}' loaded from the latest database entry") self.value = self.config_value @@ -154,43 +165,29 @@ class PersistentConfig(Generic[T]): def __str__(self): return str(self.value) - def load_latest_config_value(self, config_path: str): - with get_db() as db: - config_entry = db.query(Config).order_by(Config.id.desc()).first() - if config_entry: - try: - path_parts = config_path.split(".") - config_value = config_entry.data - for key in path_parts: - config_value = config_value[key] - return config_value - except KeyError: - return None + @property + def __dict__(self): + raise TypeError( + "PersistentConfig object cannot be converted to dict, use config_get or .value instead." + ) + + def __getattribute__(self, item): + if item == "__dict__": + raise TypeError( + "PersistentConfig object cannot be converted to dict, use config_get or .value instead." + ) + return super().__getattribute__(item) def save(self): - if self.env_value == self.value and self.config_value == self.value: - return log.info(f"Saving '{self.env_name}' to the database") path_parts = self.config_path.split(".") - with get_db() as db: - existing_config = db.query(Config).first() - if existing_config: - config = existing_config.data - for key in path_parts[:-1]: - if key not in config: - config[key] = {} - config = config[key] - config[path_parts[-1]] = self.value - else: # This case should not actually occur as there should always be at least one entry - new_data = {} - config = new_data - for key in path_parts[:-1]: - config[key] = {} - config = config[key] - config[path_parts[-1]] = self.value - new_config = Config(data=new_data, version=0) - db.add(new_config) - db.commit() + sub_config = CONFIG_DATA + for key in path_parts[:-1]: + if key not in sub_config: + sub_config[key] = {} + sub_config = sub_config[key] + sub_config[path_parts[-1]] = self.value + save_to_db(CONFIG_DATA) self.config_value = self.value