Protect against multiple connects to the update server from different processes

Code cleanup
This commit is contained in:
allegroai 2022-02-13 20:12:12 +02:00
parent afdc56f37c
commit b9996e2c1a
6 changed files with 42 additions and 164 deletions

View File

@ -10,7 +10,6 @@ from typing import (
from redis import StrictRedis from redis import StrictRedis
from apiserver.config_repo import config from apiserver.config_repo import config
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task from apiserver.database.model.task.task import Task
from apiserver.redis_manager import redman from apiserver.redis_manager import redman
from apiserver.utilities.dicts import nested_get from apiserver.utilities.dicts import nested_get
@ -240,50 +239,3 @@ class ProjectQueries:
result = Task.aggregate(pipeline) result = Task.aggregate(pipeline)
return [r["metrics"][0] for r in result] return [r["metrics"][0] for r in result]
@classmethod
def get_model_metadata_keys(
cls,
company_id,
project_ids: Sequence[str],
include_subprojects: bool,
page: int = 0,
page_size: int = 500,
) -> Tuple[int, int, Sequence[dict]]:
page = max(0, page)
page_size = max(1, page_size)
pipeline = [
{
"$match": {
**cls._get_company_constraint(company_id),
**cls._get_project_constraint(project_ids, include_subprojects),
"metadata": {"$exists": True, "$ne": []},
}
},
{"$project": {"metadata": 1}},
{"$unwind": "$metadata"},
{"$group": {"_id": "$metadata.key"}},
{"$sort": {"_id": 1}},
{"$skip": page * page_size},
{"$limit": page_size},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT"},
}
},
]
result = next(Model.aggregate(pipeline), None)
total = 0
remaining = 0
results = []
if result:
total = int(result.get("total", -1))
results = [r.get("_id") for r in result.get("results", [])]
remaining = max(0, total - (len(results) + page * page_size))
return total, remaining, results

View File

@ -226,12 +226,6 @@ create_credentials {
} }
} }
} }
"999.0": ${create_credentials."2.1"} {
request.properties.label {
type: string
description: Optional credentials label
}
}
} }
get_credentials { get_credentials {

View File

@ -929,55 +929,6 @@ get_hyper_parameters {
} }
} }
} }
get_model_metadata_keys {
"999.0" {
description: """Get a list of all metadata keys used in models within the given project."""
request {
type: object
required: [project]
properties {
project {
description: "Project ID"
type: string
}
include_subprojects {
description: "If set to 'true' and the project field is set then the result includes metadate keys from the subproject models"
type: boolean
default: true
}
page {
description: "Page number"
default: 0
type: integer
}
page_size {
description: "Page size"
default: 500
type: integer
}
}
}
response {
type: object
properties {
keys {
description: "A list of model keys"
type: array
items {type: string}
}
remaining {
description: "Remaining results"
type: integer
}
total {
description: "Total number of results"
type: integer
}
}
}
}
}
get_task_tags { get_task_tags {
"2.8" { "2.8" {
description: "Get user and system tags used for the tasks under the specified projects" description: "Get user and system tags used for the tasks under the specified projects"

View File

@ -275,23 +275,6 @@ def get_unique_metric_variants(
call.result.data = {"metrics": metrics} call.result.data = {"metrics": metrics}
@endpoint("projects.get_model_metadata_keys",)
def get_model_metadata_keys(call: APICall, company_id: str, request: GetParamsRequest):
total, remaining, keys = project_queries.get_model_metadata_keys(
company_id,
project_ids=[request.project] if request.project else None,
include_subprojects=request.include_subprojects,
page=request.page,
page_size=request.page_size,
)
call.result.data = {
"total": total,
"remaining": remaining,
"keys": keys,
}
@endpoint( @endpoint(
"projects.get_hyper_parameters", "projects.get_hyper_parameters",
min_version="2.9", min_version="2.9",

View File

@ -26,23 +26,6 @@ class TestQueueAndModelMetadata(TestService):
self.api.models.edit(model=model_id, metadata=[self.meta1[0]]) self.api.models.edit(model=model_id, metadata=[self.meta1[0]])
self._assertMeta(service=service, entity=entity, _id=model_id, meta=self.meta1) self._assertMeta(service=service, entity=entity, _id=model_id, meta=self.meta1)
def test_project_meta_query(self):
self._temp_model("TestMetadata", metadata=self.meta1)
project = self.temp_project(name="MetaParent")
self._temp_model(
"TestMetadata2",
project=project,
metadata=[
{"key": "test_key", "type": "str", "value": "test_value"},
{"key": "test_key2", "type": "str", "value": "test_value"},
],
)
res = self.api.projects.get_model_metadata_keys()
self.assertTrue({"test_key", "test_key2"}.issubset(set(res["keys"])))
res = self.api.projects.get_model_metadata_keys(include_subprojects=False)
self.assertTrue("test_key" in res["keys"])
self.assertFalse("test_key2" in res["keys"])
def _test_meta_operations( def _test_meta_operations(
self, service: APIClient.Service, entity: str, _id: str, self, service: APIClient.Service, entity: str, _id: str,
): ):

View File

@ -1,4 +1,5 @@
import os import os
from datetime import timedelta, datetime
from threading import Thread from threading import Thread
from time import sleep from time import sleep
from typing import Optional from typing import Optional
@ -10,6 +11,7 @@ from semantic_version import Version
from apiserver.config_repo import config from apiserver.config_repo import config
from apiserver.config.info import get_version from apiserver.config.info import get_version
from apiserver.database.model.settings import Settings from apiserver.database.model.settings import Settings
from apiserver.redis_manager import redman
from apiserver.utilities.threads_manager import ThreadsManager from apiserver.utilities.threads_manager import ThreadsManager
log = config.logger(__name__) log = config.logger(__name__)
@ -17,6 +19,8 @@ log = config.logger(__name__)
class CheckUpdatesThread(Thread): class CheckUpdatesThread(Thread):
_enabled = bool(config.get("apiserver.check_for_updates.enabled", True)) _enabled = bool(config.get("apiserver.check_for_updates.enabled", True))
_lock_name = "check_updates"
_redis = redman.connection("apiserver")
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
class _VersionResponse: class _VersionResponse:
@ -29,6 +33,19 @@ class CheckUpdatesThread(Thread):
target=self._check_updates, daemon=True target=self._check_updates, daemon=True
) )
@property
def update_interval(self):
return timedelta(
seconds=max(
float(
config.get(
"apiserver.check_for_updates.check_interval_sec", 60 * 60 * 24,
)
),
60 * 5,
)
)
def start(self) -> None: def start(self) -> None:
if not self._enabled: if not self._enabled:
log.info("Checking for updates is disabled") log.info("Checking for updates is disabled")
@ -37,12 +54,13 @@ class CheckUpdatesThread(Thread):
@property @property
def component_name(self) -> str: def component_name(self) -> str:
return config.get("apiserver.check_for_updates.component_name", "clearml-server") return config.get(
"apiserver.check_for_updates.component_name", "clearml-server"
)
def _check_new_version_available(self) -> Optional[_VersionResponse]: def _check_new_version_available(self) -> Optional[_VersionResponse]:
url = config.get( url = config.get(
"apiserver.check_for_updates.url", "apiserver.check_for_updates.url", "https://updates.clear.ml/updates",
"https://updates.clear.ml/updates",
) )
uid = Settings.get_by_key("server.uuid") uid = Settings.get_by_key("server.uuid")
@ -81,34 +99,31 @@ class CheckUpdatesThread(Thread):
) )
def _check_updates(self): def _check_updates(self):
update_interval_sec = max(
float(
config.get(
"apiserver.check_for_updates.check_interval_sec",
60 * 60 * 24,
)
),
60 * 5,
)
while not ThreadsManager.terminating: while not ThreadsManager.terminating:
# noinspection PyBroadException # noinspection PyBroadException
try: try:
response = self._check_new_version_available() if self._redis.set(
if response: self._lock_name,
if response.patch_upgrade: value=datetime.utcnow().isoformat(),
log.info( ex=self.update_interval - timedelta(seconds=60),
f"{self.component_name.upper()} new package available: upgrade to v{response.version} " nx=True,
f"is recommended!\nRelease Notes:\n{os.linesep.join(response.description)}" ):
) response = self._check_new_version_available()
else: if response:
log.info( if response.patch_upgrade:
f"{self.component_name.upper()} new version available: upgrade to v{response.version}" log.info(
f" is recommended!" f"{self.component_name.upper()} new package available: upgrade to v{response.version} "
) f"is recommended!\nRelease Notes:\n{os.linesep.join(response.description)}"
except Exception: )
log.exception("Failed obtaining updates") else:
log.info(
f"{self.component_name.upper()} new version available: upgrade to v{response.version}"
f" is recommended!"
)
except Exception as ex:
log.exception("Failed obtaining updates: " + str(ex))
sleep(update_interval_sec) sleep(self.update_interval.total_seconds())
check_updates_thread = CheckUpdatesThread() check_updates_thread = CheckUpdatesThread()