Protect against multiple connections to the update server from different processes

Code cleanup
This commit is contained in:
allegroai 2022-02-13 20:12:12 +02:00
parent afdc56f37c
commit b9996e2c1a
6 changed files with 42 additions and 164 deletions

View File

@ -10,7 +10,6 @@ from typing import (
from redis import StrictRedis
from apiserver.config_repo import config
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
from apiserver.redis_manager import redman
from apiserver.utilities.dicts import nested_get
@ -240,50 +239,3 @@ class ProjectQueries:
result = Task.aggregate(pipeline)
return [r["metrics"][0] for r in result]
@classmethod
def get_model_metadata_keys(
    cls,
    company_id,
    project_ids: Sequence[str],
    include_subprojects: bool,
    page: int = 0,
    page_size: int = 500,
) -> Tuple[int, int, Sequence[dict]]:
    """
    Return one page of the distinct metadata keys used by models in the
    given projects.

    :param company_id: company whose models are searched
    :param project_ids: restrict to these projects (None means no restriction)
    :param include_subprojects: also match models in child projects
    :param page: zero-based page index (clamped to >= 0)
    :param page_size: keys per page (clamped to >= 1)
    :return: (total distinct keys, keys remaining after this page, keys on this page)
    """
    page, page_size = max(0, page), max(1, page_size)
    # Only models that actually carry metadata are of interest
    match_stage = {
        "$match": {
            **cls._get_company_constraint(company_id),
            **cls._get_project_constraint(project_ids, include_subprojects),
            "metadata": {"$exists": True, "$ne": []},
        }
    }
    pipeline = [
        match_stage,
        {"$project": {"metadata": 1}},
        {"$unwind": "$metadata"},
        # Deduplicate keys, then page through them in sorted order
        {"$group": {"_id": "$metadata.key"}},
        {"$sort": {"_id": 1}},
        {"$skip": page * page_size},
        {"$limit": page_size},
        {
            "$group": {
                "_id": 1,
                "total": {"$sum": 1},
                "results": {"$push": "$$ROOT"},
            }
        },
    ]
    doc = next(Model.aggregate(pipeline), None)
    if not doc:
        return 0, 0, []
    total = int(doc.get("total", -1))
    keys = [entry.get("_id") for entry in doc.get("results", [])]
    remaining = max(0, total - (len(keys) + page * page_size))
    return total, remaining, keys

View File

@ -226,12 +226,6 @@ create_credentials {
}
}
}
"999.0": ${create_credentials."2.1"} {
request.properties.label {
type: string
description: Optional credentials label
}
}
}
get_credentials {

View File

@ -929,55 +929,6 @@ get_hyper_parameters {
}
}
}
get_model_metadata_keys {
"999.0" {
description: """Get a list of all metadata keys used in models within the given project."""
request {
type: object
required: [project]
properties {
project {
description: "Project ID"
type: string
}
include_subprojects {
description: "If set to 'true' and the project field is set then the result includes metadata keys from the subproject models"
type: boolean
default: true
}
page {
description: "Page number"
default: 0
type: integer
}
page_size {
description: "Page size"
default: 500
type: integer
}
}
}
response {
type: object
properties {
keys {
description: "A list of model metadata keys"
type: array
items {type: string}
}
remaining {
description: "Remaining results"
type: integer
}
total {
description: "Total number of results"
type: integer
}
}
}
}
}
get_task_tags {
"2.8" {
description: "Get user and system tags used for the tasks under the specified projects"

View File

@ -275,23 +275,6 @@ def get_unique_metric_variants(
call.result.data = {"metrics": metrics}
@endpoint("projects.get_model_metadata_keys",)
def get_model_metadata_keys(call: APICall, company_id: str, request: GetParamsRequest):
    """Serve projects.get_model_metadata_keys: one page of the distinct
    metadata keys used by models under the requested project."""
    # No project in the request means "search across all projects"
    project_filter = [request.project] if request.project else None
    total, remaining, keys = project_queries.get_model_metadata_keys(
        company_id,
        project_ids=project_filter,
        include_subprojects=request.include_subprojects,
        page=request.page,
        page_size=request.page_size,
    )
    call.result.data = dict(total=total, remaining=remaining, keys=keys)
@endpoint(
"projects.get_hyper_parameters",
min_version="2.9",

View File

@ -26,23 +26,6 @@ class TestQueueAndModelMetadata(TestService):
self.api.models.edit(model=model_id, metadata=[self.meta1[0]])
self._assertMeta(service=service, entity=entity, _id=model_id, meta=self.meta1)
def test_project_meta_query(self):
    """Model metadata keys are aggregated across the project tree, and
    include_subprojects=False restricts the result to the queried scope."""
    self._temp_model("TestMetadata", metadata=self.meta1)
    project = self.temp_project(name="MetaParent")
    self._temp_model(
        "TestMetadata2",
        project=project,
        metadata=[
            {"key": "test_key", "type": "str", "value": "test_value"},
            {"key": "test_key2", "type": "str", "value": "test_value"},
        ],
    )
    res = self.api.projects.get_model_metadata_keys()
    self.assertTrue({"test_key", "test_key2"}.issubset(set(res["keys"])))
    res = self.api.projects.get_model_metadata_keys(include_subprojects=False)
    # assertIn/assertNotIn instead of assertTrue/assertFalse on a containment
    # expression: same pass/fail semantics, far more informative failure output.
    self.assertIn("test_key", res["keys"])
    self.assertNotIn("test_key2", res["keys"])
def _test_meta_operations(
self, service: APIClient.Service, entity: str, _id: str,
):

View File

@ -1,4 +1,5 @@
import os
from datetime import timedelta, datetime
from threading import Thread
from time import sleep
from typing import Optional
@ -10,6 +11,7 @@ from semantic_version import Version
from apiserver.config_repo import config
from apiserver.config.info import get_version
from apiserver.database.model.settings import Settings
from apiserver.redis_manager import redman
from apiserver.utilities.threads_manager import ThreadsManager
log = config.logger(__name__)
@ -17,6 +19,8 @@ log = config.logger(__name__)
class CheckUpdatesThread(Thread):
_enabled = bool(config.get("apiserver.check_for_updates.enabled", True))
_lock_name = "check_updates"
_redis = redman.connection("apiserver")
@attr.s(auto_attribs=True)
class _VersionResponse:
@ -29,6 +33,19 @@ class CheckUpdatesThread(Thread):
target=self._check_updates, daemon=True
)
@property
def update_interval(self):
    """Interval between update checks: the configured number of seconds
    (default one day), floored at five minutes."""
    configured_sec = float(
        config.get(
            "apiserver.check_for_updates.check_interval_sec", 60 * 60 * 24,
        )
    )
    return timedelta(seconds=max(configured_sec, 60 * 5))
def start(self) -> None:
if not self._enabled:
log.info("Checking for updates is disabled")
@ -37,12 +54,13 @@ class CheckUpdatesThread(Thread):
@property
def component_name(self) -> str:
return config.get("apiserver.check_for_updates.component_name", "clearml-server")
return config.get(
"apiserver.check_for_updates.component_name", "clearml-server"
)
def _check_new_version_available(self) -> Optional[_VersionResponse]:
url = config.get(
"apiserver.check_for_updates.url",
"https://updates.clear.ml/updates",
"apiserver.check_for_updates.url", "https://updates.clear.ml/updates",
)
uid = Settings.get_by_key("server.uuid")
@ -81,34 +99,31 @@ class CheckUpdatesThread(Thread):
)
def _check_updates(self):
update_interval_sec = max(
float(
config.get(
"apiserver.check_for_updates.check_interval_sec",
60 * 60 * 24,
)
),
60 * 5,
)
while not ThreadsManager.terminating:
# noinspection PyBroadException
try:
response = self._check_new_version_available()
if response:
if response.patch_upgrade:
log.info(
f"{self.component_name.upper()} new package available: upgrade to v{response.version} "
f"is recommended!\nRelease Notes:\n{os.linesep.join(response.description)}"
)
else:
log.info(
f"{self.component_name.upper()} new version available: upgrade to v{response.version}"
f" is recommended!"
)
except Exception:
log.exception("Failed obtaining updates")
if self._redis.set(
self._lock_name,
value=datetime.utcnow().isoformat(),
ex=self.update_interval - timedelta(seconds=60),
nx=True,
):
response = self._check_new_version_available()
if response:
if response.patch_upgrade:
log.info(
f"{self.component_name.upper()} new package available: upgrade to v{response.version} "
f"is recommended!\nRelease Notes:\n{os.linesep.join(response.description)}"
)
else:
log.info(
f"{self.component_name.upper()} new version available: upgrade to v{response.version}"
f" is recommended!"
)
except Exception as ex:
log.exception("Failed obtaining updates: " + str(ex))
sleep(update_interval_sec)
sleep(self.update_interval.total_seconds())
check_updates_thread = CheckUpdatesThread()