From 53296e88918b1efddd2d4b39aeb0a47a41c0cae2 Mon Sep 17 00:00:00 2001 From: allegroai <> Date: Sat, 21 Dec 2019 18:13:05 +0200 Subject: [PATCH] Use a single definitive way to obtain server version and build --- server/config/info.py | 56 +++++++++++++----------------- server/services/server/__init__.py | 5 ++- server/updates.py | 6 ++-- 3 files changed, 30 insertions(+), 37 deletions(-) diff --git a/server/config/info.py b/server/config/info.py index 4569159..8ccf634 100644 --- a/server/config/info.py +++ b/server/config/info.py @@ -1,43 +1,37 @@ from functools import lru_cache -from pathlib import Path from os import getenv +from pathlib import Path +from version import __version__ root = Path(__file__).parent.parent -@lru_cache() -def get_build_number(): - try: - return (root / "BUILD").read_text().strip() - except FileNotFoundError: - return "" - - -@lru_cache() -def get_version(): - try: - return (root / "VERSION").read_text().strip() - except FileNotFoundError: - return "" - - -@lru_cache() -def get_commit_number(): - try: - return (root / "COMMIT").read_text().strip() - except FileNotFoundError: - return "" - - -@lru_cache() -def get_deployment_type() -> str: - value = getenv("TRAINS_SERVER_DEPLOYMENT_TYPE") +def _get(prop_name, env_suffix=None, default=""): + value = getenv(f"TRAINS_SERVER_{env_suffix or prop_name}") if value: return value try: - value = (root / "DEPLOY").read_text().strip() + return (root / prop_name).read_text().strip() except FileNotFoundError: - pass + return default - return value or "manual" + +@lru_cache() +def get_build_number(): + return _get("BUILD") + + +@lru_cache() +def get_version(): + return _get("VERSION", default=__version__) + + +@lru_cache() +def get_commit_number(): + return _get("COMMIT") + + +@lru_cache() +def get_deployment_type() -> str: + return _get("DEPLOY", env_suffix="DEPLOYMENT_TYPE", default="manual") diff --git a/server/services/server/__init__.py b/server/services/server/__init__.py index ecd111a..e4f7e58 100644 --- a/server/services/server/__init__.py +++ b/server/services/server/__init__.py @@ -11,7 +11,6 @@ from database.errors import translate_errors_context from database.model import Company from database.model.company import ReportStatsOption from service_repo import ServiceRepo, APICall, endpoint -from version import __version__ as current_version @endpoint("server.get_stats") @@ -79,7 +78,7 @@ def report_stats(call: APICall, company: str, request: ReportStatsOptionRequest) stats_option = ReportStatsOption( enabled=enabled, enabled_time=datetime.utcnow(), - enabled_version=current_version, + enabled_version=get_version(), enabled_user=call.identity.user, ) updated = query.update(defaults__stats_option=stats_option) @@ -88,7 +87,7 @@ def report_stats(call: APICall, company: str, request: ReportStatsOptionRequest) f"Failed setting report_stats to {enabled}" ) data = stats_option.to_mongo() - data["current_version"] = current_version + data["current_version"] = get_version() result = ReportStatsOptionResponse(**data) call.result.data_model = result diff --git a/server/updates.py b/server/updates.py index 3144d7e..322ca04 100644 --- a/server/updates.py +++ b/server/updates.py @@ -8,8 +8,8 @@ import requests from semantic_version import Version from config import config +from config.info import get_version from database.model.settings import Settings -from version import __version__ as current_version log = config.logger(__name__) @@ -48,7 +48,7 @@ class CheckUpdatesThread(Thread): response = requests.get( url, - json={"versions": {self.component_name: str(current_version)}, "uid": uid}, + json={"versions": {self.component_name: str(get_version())}, "uid": uid}, timeout=float( config.get("apiserver.check_for_updates.request_timeout_sec", 3.0) ), @@ -65,7 +65,7 @@ class CheckUpdatesThread(Thread): if not latest_version: return - cur_version = Version(current_version) + cur_version = Version(get_version()) latest_version = Version(latest_version) if cur_version >= latest_version: return