mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Rename migration script
Support refresh flag in debug image samples Remove silent_dequeue_fail param to prevent status change in case task wasn't queued Add organizations.get_user_companies Fix reset should also reset active_duration Add api_version to server.info
This commit is contained in:
parent
618a0b9473
commit
3272d0f31f
@ -60,6 +60,7 @@ class TaskMetricVariant(Base):
|
|||||||
class GetDebugImageSampleRequest(TaskMetricVariant):
|
class GetDebugImageSampleRequest(TaskMetricVariant):
|
||||||
iteration: Optional[int] = IntField()
|
iteration: Optional[int] = IntField()
|
||||||
scroll_id: Optional[str] = StringField()
|
scroll_id: Optional[str] = StringField()
|
||||||
|
refresh: bool = BoolField(default=False)
|
||||||
|
|
||||||
|
|
||||||
class NextDebugImageSampleRequest(Base):
|
class NextDebugImageSampleRequest(Base):
|
||||||
|
@ -211,6 +211,7 @@ class DebugSampleHistory:
|
|||||||
metric: str,
|
metric: str,
|
||||||
variant: str,
|
variant: str,
|
||||||
iteration: Optional[int] = None,
|
iteration: Optional[int] = None,
|
||||||
|
refresh: bool = False,
|
||||||
state_id: str = None,
|
state_id: str = None,
|
||||||
) -> DebugSampleHistoryResult:
|
) -> DebugSampleHistoryResult:
|
||||||
"""
|
"""
|
||||||
@ -225,15 +226,7 @@ class DebugSampleHistory:
|
|||||||
def init_state(state_: DebugSampleHistoryState):
|
def init_state(state_: DebugSampleHistoryState):
|
||||||
state_.task = task
|
state_.task = task
|
||||||
state_.metric = metric
|
state_.metric = metric
|
||||||
variant_iterations = self._get_variant_iterations(
|
self._reset_variant_states(es_index, state=state_)
|
||||||
es_index=es_index, task=task, metric=metric
|
|
||||||
)
|
|
||||||
state_.variant_states = [
|
|
||||||
VariantState(
|
|
||||||
name=var_name, min_iteration=min_iter, max_iteration=max_iter
|
|
||||||
)
|
|
||||||
for var_name, min_iter, max_iter in variant_iterations
|
|
||||||
]
|
|
||||||
|
|
||||||
def validate_state(state_: DebugSampleHistoryState):
|
def validate_state(state_: DebugSampleHistoryState):
|
||||||
if state_.task != task or state_.metric != metric:
|
if state_.task != task or state_.metric != metric:
|
||||||
@ -241,6 +234,8 @@ class DebugSampleHistory:
|
|||||||
"Task and metric stored in the state do not match the passed ones",
|
"Task and metric stored in the state do not match the passed ones",
|
||||||
scroll_id=state_.id,
|
scroll_id=state_.id,
|
||||||
)
|
)
|
||||||
|
if refresh:
|
||||||
|
self._reset_variant_states(es_index, state=state_)
|
||||||
|
|
||||||
state: DebugSampleHistoryState
|
state: DebugSampleHistoryState
|
||||||
with self.cache_manager.get_or_create_state(
|
with self.cache_manager.get_or_create_state(
|
||||||
@ -291,6 +286,17 @@ class DebugSampleHistory:
|
|||||||
)
|
)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
def _reset_variant_states(self, es_index, state: DebugSampleHistoryState):
|
||||||
|
variant_iterations = self._get_variant_iterations(
|
||||||
|
es_index=es_index, task=state.task, metric=state.metric
|
||||||
|
)
|
||||||
|
state.variant_states = [
|
||||||
|
VariantState(
|
||||||
|
name=var_name, min_iteration=min_iter, max_iteration=max_iter
|
||||||
|
)
|
||||||
|
for var_name, min_iter, max_iter in variant_iterations
|
||||||
|
]
|
||||||
|
|
||||||
def _get_variant_iterations(
|
def _get_variant_iterations(
|
||||||
self,
|
self,
|
||||||
es_index: str,
|
es_index: str,
|
||||||
|
@ -651,9 +651,8 @@ class TaskBLL:
|
|||||||
company_id: str,
|
company_id: str,
|
||||||
status_message: str,
|
status_message: str,
|
||||||
status_reason: str,
|
status_reason: str,
|
||||||
silent_dequeue_fail=False,
|
|
||||||
):
|
):
|
||||||
cls.dequeue(task, company_id, silent_dequeue_fail)
|
cls.dequeue(task, company_id)
|
||||||
|
|
||||||
return ChangeStatusRequest(
|
return ChangeStatusRequest(
|
||||||
task=task,
|
task=task,
|
||||||
|
@ -5,19 +5,25 @@ from pymongo.database import Database
|
|||||||
|
|
||||||
|
|
||||||
def _add_active_duration(db: Database):
|
def _add_active_duration(db: Database):
|
||||||
active_duration = "active_duration"
|
active_duration_key = "active_duration"
|
||||||
query = {active_duration: {"$eq": None}}
|
query = {"$or": [{active_duration_key: {"$eq": None}}, {active_duration_key: {"$eq": 0}}]}
|
||||||
collection = db["task"]
|
collection = db["task"]
|
||||||
for doc in collection.find(
|
for doc in collection.find(
|
||||||
filter=query, projection=[active_duration, "status", "started", "completed"]
|
filter=query, projection=[active_duration_key, "status", "started", "completed"]
|
||||||
):
|
):
|
||||||
started = doc.get("started")
|
started = doc.get("started")
|
||||||
completed = doc.get("completed")
|
completed = doc.get("completed")
|
||||||
running = doc.get("status") == "running"
|
running = doc.get("status") == "running"
|
||||||
if started and doc.get(active_duration) is None:
|
active_duration_value = doc.get(active_duration_key)
|
||||||
|
if active_duration_value == 0:
|
||||||
collection.update_one(
|
collection.update_one(
|
||||||
{"_id": doc["_id"]},
|
{"_id": doc["_id"]},
|
||||||
{"$set": {active_duration: _get_active_duration(completed, running, started)}},
|
{"$set": {active_duration_key: None}},
|
||||||
|
)
|
||||||
|
elif started and active_duration_value is None:
|
||||||
|
collection.update_one(
|
||||||
|
{"_id": doc["_id"]},
|
||||||
|
{"$set": {active_duration_key: _get_active_duration(completed, running, started)}},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -439,6 +439,10 @@
|
|||||||
description: "The iteration to bring debug image from. If not specified then the latest reported image is retrieved"
|
description: "The iteration to bring debug image from. If not specified then the latest reported image is retrieved"
|
||||||
type: integer
|
type: integer
|
||||||
}
|
}
|
||||||
|
refresh {
|
||||||
|
description: "If set then scroll state will be refreshed to reflect the latest changes in the debug images"
|
||||||
|
type: boolean
|
||||||
|
}
|
||||||
scroll_id {
|
scroll_id {
|
||||||
type: string
|
type: string
|
||||||
description: "Scroll ID from the previous call to get_debug_image_sample or empty"
|
description: "Scroll ID from the previous call to get_debug_image_sample or empty"
|
||||||
|
@ -45,4 +45,61 @@ get_tags {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
get_user_companies {
|
||||||
|
"2.12" {
|
||||||
|
description: "Get details for all companies associated with the current user"
|
||||||
|
request {
|
||||||
|
type: object
|
||||||
|
properties {}
|
||||||
|
additionalProperties: false
|
||||||
|
}
|
||||||
|
response {
|
||||||
|
type: object
|
||||||
|
properties {
|
||||||
|
companies {
|
||||||
|
description: "List of company information entries. First company is the user's own company"
|
||||||
|
type: array
|
||||||
|
items {
|
||||||
|
type: object
|
||||||
|
properties {
|
||||||
|
id {
|
||||||
|
description: "Company ID"
|
||||||
|
type: string
|
||||||
|
}
|
||||||
|
name {
|
||||||
|
description: "Company name"
|
||||||
|
type: string
|
||||||
|
}
|
||||||
|
allocated {
|
||||||
|
description: "Number of users allocated for company"
|
||||||
|
type: integer
|
||||||
|
}
|
||||||
|
owners {
|
||||||
|
description: "Company owners"
|
||||||
|
type: array
|
||||||
|
items {
|
||||||
|
type: object
|
||||||
|
properties {
|
||||||
|
id {
|
||||||
|
description: "User ID"
|
||||||
|
type: string
|
||||||
|
}
|
||||||
|
name {
|
||||||
|
description: "User Name"
|
||||||
|
type: string
|
||||||
|
}
|
||||||
|
avatar {
|
||||||
|
description: "User avatar (URL or base64-encoded data)"
|
||||||
|
type: string
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
@ -77,6 +77,21 @@ info {
|
|||||||
description: "Server UID"
|
description: "Server UID"
|
||||||
type: string
|
type: string
|
||||||
}
|
}
|
||||||
|
api_version {
|
||||||
|
description: "Max API version supported"
|
||||||
|
type: string
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"2.12": ${info."2.8"} {
|
||||||
|
response {
|
||||||
|
type: object
|
||||||
|
properties {
|
||||||
|
api_version {
|
||||||
|
description: "Max API version supported"
|
||||||
|
type: string
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -643,6 +643,7 @@ def get_debug_image_sample(call, company_id, request: GetDebugImageSampleRequest
|
|||||||
metric=request.metric,
|
metric=request.metric,
|
||||||
variant=request.variant,
|
variant=request.variant,
|
||||||
iteration=request.iteration,
|
iteration=request.iteration,
|
||||||
|
refresh=request.refresh,
|
||||||
state_id=request.scroll_id,
|
state_id=request.scroll_id,
|
||||||
)
|
)
|
||||||
call.result.data = attr.asdict(res, recurse=False)
|
call.result.data = attr.asdict(res, recurse=False)
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
from operator import itemgetter
|
||||||
|
|
||||||
from apiserver.apimodels.organization import TagsRequest
|
from apiserver.apimodels.organization import TagsRequest
|
||||||
from apiserver.bll.organization import OrgBLL, Tags
|
from apiserver.bll.organization import OrgBLL, Tags
|
||||||
|
from apiserver.database.model import User
|
||||||
from apiserver.service_repo import endpoint, APICall
|
from apiserver.service_repo import endpoint, APICall
|
||||||
from apiserver.services.utils import get_tags_filter_dictionary, get_tags_response
|
from apiserver.services.utils import get_tags_filter_dictionary, get_tags_response
|
||||||
|
|
||||||
@ -20,3 +22,26 @@ def get_tags(call: APICall, company, request: TagsRequest):
|
|||||||
ret[field] |= vals
|
ret[field] |= vals
|
||||||
|
|
||||||
call.result.data = get_tags_response(ret)
|
call.result.data = get_tags_response(ret)
|
||||||
|
|
||||||
|
|
||||||
|
@endpoint("organization.get_user_companies")
|
||||||
|
def get_user_companies(call: APICall, company_id: str, _):
|
||||||
|
users = [
|
||||||
|
{
|
||||||
|
"id": u.id,
|
||||||
|
"name": u.name,
|
||||||
|
"avatar": u.avatar,
|
||||||
|
}
|
||||||
|
for u in User.objects(company=company_id).only("avatar", "name", "company")
|
||||||
|
]
|
||||||
|
|
||||||
|
call.result.data = {
|
||||||
|
"companies": [
|
||||||
|
{
|
||||||
|
"id": company_id,
|
||||||
|
"name": call.identity.company_name,
|
||||||
|
"allocated": len(users),
|
||||||
|
"owners": sorted(users, key=itemgetter("name")),
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
@ -67,6 +67,12 @@ def info_2_8(call: APICall):
|
|||||||
call.result.data["uid"] = Settings.get_by_key(SettingKeys.server__uuid)
|
call.result.data["uid"] = Settings.get_by_key(SettingKeys.server__uuid)
|
||||||
|
|
||||||
|
|
||||||
|
@endpoint("server.info", min_version="2.12")
|
||||||
|
def info_2_8(call: APICall):
|
||||||
|
info(call)
|
||||||
|
call.result.data["api_version"] = str(ServiceRepo.max_endpoint_version())
|
||||||
|
|
||||||
|
|
||||||
@endpoint(
|
@endpoint(
|
||||||
"server.report_stats_option",
|
"server.report_stats_option",
|
||||||
request_data_model=ReportStatsOptionRequest,
|
request_data_model=ReportStatsOptionRequest,
|
||||||
|
@ -879,7 +879,7 @@ def reset(call: APICall, company_id, request: ResetRequest):
|
|||||||
force=force,
|
force=force,
|
||||||
status_reason="reset",
|
status_reason="reset",
|
||||||
status_message="reset",
|
status_message="reset",
|
||||||
).execute(started=None, completed=None, published=None, **updates)
|
).execute(started=None, completed=None, published=None, active_duration=None, **updates)
|
||||||
)
|
)
|
||||||
|
|
||||||
# do not return artifacts since they are not serializable
|
# do not return artifacts since they are not serializable
|
||||||
@ -904,13 +904,16 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
|
|||||||
only=("id", "execution", "status", "project", "system_tags"),
|
only=("id", "execution", "status", "project", "system_tags"),
|
||||||
)
|
)
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
TaskBLL.dequeue_and_change_status(
|
try:
|
||||||
task,
|
TaskBLL.dequeue_and_change_status(
|
||||||
company_id,
|
task,
|
||||||
request.status_message,
|
company_id,
|
||||||
request.status_reason,
|
request.status_message,
|
||||||
silent_dequeue_fail=True,
|
request.status_reason,
|
||||||
)
|
)
|
||||||
|
except APIError:
|
||||||
|
# dequeue may fail if the task was not enqueued
|
||||||
|
pass
|
||||||
task.update(
|
task.update(
|
||||||
status_message=request.status_message,
|
status_message=request.status_message,
|
||||||
status_reason=request.status_reason,
|
status_reason=request.status_reason,
|
||||||
|
17
apiserver/tests/automated/test_organization.py
Normal file
17
apiserver/tests/automated/test_organization.py
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
from apiserver.tests.automated import TestService
|
||||||
|
|
||||||
|
|
||||||
|
class TestOrganization(TestService):
|
||||||
|
def setUp(self, version="2.12"):
|
||||||
|
super().setUp(version=version)
|
||||||
|
|
||||||
|
def test_get_user_companies(self):
|
||||||
|
company = self.api.organization.get_user_companies().companies[0]
|
||||||
|
self.assertEqual(len(company.owners), company.allocated)
|
||||||
|
users = company.owners
|
||||||
|
self.assertTrue(users)
|
||||||
|
self.assertTrue(u1.name < u2.name for u1, u2 in zip(users, users[1:]))
|
||||||
|
for user in company.owners:
|
||||||
|
self.assertTrue(user.id)
|
||||||
|
self.assertTrue(user.name)
|
||||||
|
self.assertIn("avatar", user)
|
Loading…
Reference in New Issue
Block a user