mirror of
https://github.com/clearml/clearml-server
synced 2025-04-08 23:14:44 +00:00
Rename migration script
Support refresh flag in debug image samples Remove silent_dequeue_fail param to prevent status change in case task wasn't queued Add organizations.get_user_companies Fix reset should also reset active_duration Add api_version to server.info
This commit is contained in:
parent
618a0b9473
commit
3272d0f31f
@ -60,6 +60,7 @@ class TaskMetricVariant(Base):
|
||||
class GetDebugImageSampleRequest(TaskMetricVariant):
|
||||
iteration: Optional[int] = IntField()
|
||||
scroll_id: Optional[str] = StringField()
|
||||
refresh: bool = BoolField(default=False)
|
||||
|
||||
|
||||
class NextDebugImageSampleRequest(Base):
|
||||
|
@ -211,6 +211,7 @@ class DebugSampleHistory:
|
||||
metric: str,
|
||||
variant: str,
|
||||
iteration: Optional[int] = None,
|
||||
refresh: bool = False,
|
||||
state_id: str = None,
|
||||
) -> DebugSampleHistoryResult:
|
||||
"""
|
||||
@ -225,15 +226,7 @@ class DebugSampleHistory:
|
||||
def init_state(state_: DebugSampleHistoryState):
|
||||
state_.task = task
|
||||
state_.metric = metric
|
||||
variant_iterations = self._get_variant_iterations(
|
||||
es_index=es_index, task=task, metric=metric
|
||||
)
|
||||
state_.variant_states = [
|
||||
VariantState(
|
||||
name=var_name, min_iteration=min_iter, max_iteration=max_iter
|
||||
)
|
||||
for var_name, min_iter, max_iter in variant_iterations
|
||||
]
|
||||
self._reset_variant_states(es_index, state=state_)
|
||||
|
||||
def validate_state(state_: DebugSampleHistoryState):
|
||||
if state_.task != task or state_.metric != metric:
|
||||
@ -241,6 +234,8 @@ class DebugSampleHistory:
|
||||
"Task and metric stored in the state do not match the passed ones",
|
||||
scroll_id=state_.id,
|
||||
)
|
||||
if refresh:
|
||||
self._reset_variant_states(es_index, state=state_)
|
||||
|
||||
state: DebugSampleHistoryState
|
||||
with self.cache_manager.get_or_create_state(
|
||||
@ -291,6 +286,17 @@ class DebugSampleHistory:
|
||||
)
|
||||
return res
|
||||
|
||||
def _reset_variant_states(self, es_index, state: DebugSampleHistoryState):
|
||||
variant_iterations = self._get_variant_iterations(
|
||||
es_index=es_index, task=state.task, metric=state.metric
|
||||
)
|
||||
state.variant_states = [
|
||||
VariantState(
|
||||
name=var_name, min_iteration=min_iter, max_iteration=max_iter
|
||||
)
|
||||
for var_name, min_iter, max_iter in variant_iterations
|
||||
]
|
||||
|
||||
def _get_variant_iterations(
|
||||
self,
|
||||
es_index: str,
|
||||
|
@ -651,9 +651,8 @@ class TaskBLL:
|
||||
company_id: str,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
silent_dequeue_fail=False,
|
||||
):
|
||||
cls.dequeue(task, company_id, silent_dequeue_fail)
|
||||
cls.dequeue(task, company_id)
|
||||
|
||||
return ChangeStatusRequest(
|
||||
task=task,
|
||||
|
@ -5,19 +5,25 @@ from pymongo.database import Database
|
||||
|
||||
|
||||
def _add_active_duration(db: Database):
|
||||
active_duration = "active_duration"
|
||||
query = {active_duration: {"$eq": None}}
|
||||
active_duration_key = "active_duration"
|
||||
query = {"$or": [{active_duration_key: {"$eq": None}}, {active_duration_key: {"$eq": 0}}]}
|
||||
collection = db["task"]
|
||||
for doc in collection.find(
|
||||
filter=query, projection=[active_duration, "status", "started", "completed"]
|
||||
filter=query, projection=[active_duration_key, "status", "started", "completed"]
|
||||
):
|
||||
started = doc.get("started")
|
||||
completed = doc.get("completed")
|
||||
running = doc.get("status") == "running"
|
||||
if started and doc.get(active_duration) is None:
|
||||
active_duration_value = doc.get(active_duration_key)
|
||||
if active_duration_value == 0:
|
||||
collection.update_one(
|
||||
{"_id": doc["_id"]},
|
||||
{"$set": {active_duration: _get_active_duration(completed, running, started)}},
|
||||
{"$set": {active_duration_key: None}},
|
||||
)
|
||||
elif started and active_duration_value is None:
|
||||
collection.update_one(
|
||||
{"_id": doc["_id"]},
|
||||
{"$set": {active_duration_key: _get_active_duration(completed, running, started)}},
|
||||
)
|
||||
|
||||
|
@ -439,6 +439,10 @@
|
||||
description: "The iteration to bring debug image from. If not specified then the latest reported image is retrieved"
|
||||
type: integer
|
||||
}
|
||||
refresh {
|
||||
description: "If set then scroll state will be refreshed to reflect the latest changes in the debug images"
|
||||
type: boolean
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID from the previous call to get_debug_image_sample or empty"
|
||||
|
@ -45,4 +45,61 @@ get_tags {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_user_companies {
|
||||
"2.12" {
|
||||
description: "Get details for all companies associated with the current user"
|
||||
request {
|
||||
type: object
|
||||
properties {}
|
||||
additionalProperties: false
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
companies {
|
||||
description: "List of company information entries. First company is the user's own company"
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
properties {
|
||||
id {
|
||||
description: "Company ID"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Company name"
|
||||
type: string
|
||||
}
|
||||
allocated {
|
||||
description: "Number of users allocated for company"
|
||||
type: integer
|
||||
}
|
||||
owners {
|
||||
description: "Company owners"
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
properties {
|
||||
id {
|
||||
description: "User ID"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "User Name"
|
||||
type: string
|
||||
}
|
||||
avatar {
|
||||
description: "User avatar (URL or base64-encoded data)"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -77,6 +77,21 @@ info {
|
||||
description: "Server UID"
|
||||
type: string
|
||||
}
|
||||
api_version {
|
||||
description: "Max API version supported"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.12": ${info."2.8"} {
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
api_version {
|
||||
description: "Max API version supported"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -643,6 +643,7 @@ def get_debug_image_sample(call, company_id, request: GetDebugImageSampleRequest
|
||||
metric=request.metric,
|
||||
variant=request.variant,
|
||||
iteration=request.iteration,
|
||||
refresh=request.refresh,
|
||||
state_id=request.scroll_id,
|
||||
)
|
||||
call.result.data = attr.asdict(res, recurse=False)
|
||||
|
@ -1,7 +1,9 @@
|
||||
from collections import defaultdict
|
||||
from operator import itemgetter
|
||||
|
||||
from apiserver.apimodels.organization import TagsRequest
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.database.model import User
|
||||
from apiserver.service_repo import endpoint, APICall
|
||||
from apiserver.services.utils import get_tags_filter_dictionary, get_tags_response
|
||||
|
||||
@ -20,3 +22,26 @@ def get_tags(call: APICall, company, request: TagsRequest):
|
||||
ret[field] |= vals
|
||||
|
||||
call.result.data = get_tags_response(ret)
|
||||
|
||||
|
||||
@endpoint("organization.get_user_companies")
|
||||
def get_user_companies(call: APICall, company_id: str, _):
|
||||
users = [
|
||||
{
|
||||
"id": u.id,
|
||||
"name": u.name,
|
||||
"avatar": u.avatar,
|
||||
}
|
||||
for u in User.objects(company=company_id).only("avatar", "name", "company")
|
||||
]
|
||||
|
||||
call.result.data = {
|
||||
"companies": [
|
||||
{
|
||||
"id": company_id,
|
||||
"name": call.identity.company_name,
|
||||
"allocated": len(users),
|
||||
"owners": sorted(users, key=itemgetter("name")),
|
||||
}
|
||||
]
|
||||
}
|
||||
|
@ -67,6 +67,12 @@ def info_2_8(call: APICall):
|
||||
call.result.data["uid"] = Settings.get_by_key(SettingKeys.server__uuid)
|
||||
|
||||
|
||||
@endpoint("server.info", min_version="2.12")
|
||||
def info_2_8(call: APICall):
|
||||
info(call)
|
||||
call.result.data["api_version"] = str(ServiceRepo.max_endpoint_version())
|
||||
|
||||
|
||||
@endpoint(
|
||||
"server.report_stats_option",
|
||||
request_data_model=ReportStatsOptionRequest,
|
||||
|
@ -879,7 +879,7 @@ def reset(call: APICall, company_id, request: ResetRequest):
|
||||
force=force,
|
||||
status_reason="reset",
|
||||
status_message="reset",
|
||||
).execute(started=None, completed=None, published=None, **updates)
|
||||
).execute(started=None, completed=None, published=None, active_duration=None, **updates)
|
||||
)
|
||||
|
||||
# do not return artifacts since they are not serializable
|
||||
@ -904,13 +904,16 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
|
||||
only=("id", "execution", "status", "project", "system_tags"),
|
||||
)
|
||||
for task in tasks:
|
||||
TaskBLL.dequeue_and_change_status(
|
||||
task,
|
||||
company_id,
|
||||
request.status_message,
|
||||
request.status_reason,
|
||||
silent_dequeue_fail=True,
|
||||
)
|
||||
try:
|
||||
TaskBLL.dequeue_and_change_status(
|
||||
task,
|
||||
company_id,
|
||||
request.status_message,
|
||||
request.status_reason,
|
||||
)
|
||||
except APIError:
|
||||
# dequeue may fail if the task was not enqueued
|
||||
pass
|
||||
task.update(
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
|
17
apiserver/tests/automated/test_organization.py
Normal file
17
apiserver/tests/automated/test_organization.py
Normal file
@ -0,0 +1,17 @@
|
||||
from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestOrganization(TestService):
|
||||
def setUp(self, version="2.12"):
|
||||
super().setUp(version=version)
|
||||
|
||||
def test_get_user_companies(self):
|
||||
company = self.api.organization.get_user_companies().companies[0]
|
||||
self.assertEqual(len(company.owners), company.allocated)
|
||||
users = company.owners
|
||||
self.assertTrue(users)
|
||||
self.assertTrue(u1.name < u2.name for u1, u2 in zip(users, users[1:]))
|
||||
for user in company.owners:
|
||||
self.assertTrue(user.id)
|
||||
self.assertTrue(user.name)
|
||||
self.assertIn("avatar", user)
|
Loading…
Reference in New Issue
Block a user