Rename migration script

Support refresh flag in debug image samples
Remove silent_dequeue_fail param to prevent status change in case task wasn't queued
Add organizations.get_user_companies
Fix reset should also reset active_duration
Add api_version to server.info
This commit is contained in:
allegroai 2021-01-05 18:09:34 +02:00
parent 618a0b9473
commit 3272d0f31f
12 changed files with 164 additions and 24 deletions

View File

@ -60,6 +60,7 @@ class TaskMetricVariant(Base):
class GetDebugImageSampleRequest(TaskMetricVariant): class GetDebugImageSampleRequest(TaskMetricVariant):
iteration: Optional[int] = IntField() iteration: Optional[int] = IntField()
scroll_id: Optional[str] = StringField() scroll_id: Optional[str] = StringField()
refresh: bool = BoolField(default=False)
class NextDebugImageSampleRequest(Base): class NextDebugImageSampleRequest(Base):

View File

@ -211,6 +211,7 @@ class DebugSampleHistory:
metric: str, metric: str,
variant: str, variant: str,
iteration: Optional[int] = None, iteration: Optional[int] = None,
refresh: bool = False,
state_id: str = None, state_id: str = None,
) -> DebugSampleHistoryResult: ) -> DebugSampleHistoryResult:
""" """
@ -225,15 +226,7 @@ class DebugSampleHistory:
def init_state(state_: DebugSampleHistoryState): def init_state(state_: DebugSampleHistoryState):
state_.task = task state_.task = task
state_.metric = metric state_.metric = metric
variant_iterations = self._get_variant_iterations( self._reset_variant_states(es_index, state=state_)
es_index=es_index, task=task, metric=metric
)
state_.variant_states = [
VariantState(
name=var_name, min_iteration=min_iter, max_iteration=max_iter
)
for var_name, min_iter, max_iter in variant_iterations
]
def validate_state(state_: DebugSampleHistoryState): def validate_state(state_: DebugSampleHistoryState):
if state_.task != task or state_.metric != metric: if state_.task != task or state_.metric != metric:
@ -241,6 +234,8 @@ class DebugSampleHistory:
"Task and metric stored in the state do not match the passed ones", "Task and metric stored in the state do not match the passed ones",
scroll_id=state_.id, scroll_id=state_.id,
) )
if refresh:
self._reset_variant_states(es_index, state=state_)
state: DebugSampleHistoryState state: DebugSampleHistoryState
with self.cache_manager.get_or_create_state( with self.cache_manager.get_or_create_state(
@ -291,6 +286,17 @@ class DebugSampleHistory:
) )
return res return res
def _reset_variant_states(self, es_index, state: DebugSampleHistoryState):
variant_iterations = self._get_variant_iterations(
es_index=es_index, task=state.task, metric=state.metric
)
state.variant_states = [
VariantState(
name=var_name, min_iteration=min_iter, max_iteration=max_iter
)
for var_name, min_iter, max_iter in variant_iterations
]
def _get_variant_iterations( def _get_variant_iterations(
self, self,
es_index: str, es_index: str,

View File

@ -651,9 +651,8 @@ class TaskBLL:
company_id: str, company_id: str,
status_message: str, status_message: str,
status_reason: str, status_reason: str,
silent_dequeue_fail=False,
): ):
cls.dequeue(task, company_id, silent_dequeue_fail) cls.dequeue(task, company_id)
return ChangeStatusRequest( return ChangeStatusRequest(
task=task, task=task,

View File

@ -5,19 +5,25 @@ from pymongo.database import Database
def _add_active_duration(db: Database): def _add_active_duration(db: Database):
active_duration = "active_duration" active_duration_key = "active_duration"
query = {active_duration: {"$eq": None}} query = {"$or": [{active_duration_key: {"$eq": None}}, {active_duration_key: {"$eq": 0}}]}
collection = db["task"] collection = db["task"]
for doc in collection.find( for doc in collection.find(
filter=query, projection=[active_duration, "status", "started", "completed"] filter=query, projection=[active_duration_key, "status", "started", "completed"]
): ):
started = doc.get("started") started = doc.get("started")
completed = doc.get("completed") completed = doc.get("completed")
running = doc.get("status") == "running" running = doc.get("status") == "running"
if started and doc.get(active_duration) is None: active_duration_value = doc.get(active_duration_key)
if active_duration_value == 0:
collection.update_one( collection.update_one(
{"_id": doc["_id"]}, {"_id": doc["_id"]},
{"$set": {active_duration: _get_active_duration(completed, running, started)}}, {"$set": {active_duration_key: None}},
)
elif started and active_duration_value is None:
collection.update_one(
{"_id": doc["_id"]},
{"$set": {active_duration_key: _get_active_duration(completed, running, started)}},
) )

View File

@ -439,6 +439,10 @@
description: "The iteration to bring debug image from. If not specified then the latest reported image is retrieved" description: "The iteration to bring debug image from. If not specified then the latest reported image is retrieved"
type: integer type: integer
} }
refresh {
description: "If set then scroll state will be refreshed to reflect the latest changes in the debug images"
type: boolean
}
scroll_id { scroll_id {
type: string type: string
description: "Scroll ID from the previous call to get_debug_image_sample or empty" description: "Scroll ID from the previous call to get_debug_image_sample or empty"

View File

@ -45,4 +45,61 @@ get_tags {
} }
} }
} }
}
get_user_companies {
"2.12" {
description: "Get details for all companies associated with the current user"
request {
type: object
properties {}
additionalProperties: false
}
response {
type: object
properties {
companies {
description: "List of company information entries. First company is the user's own company"
type: array
items {
type: object
properties {
id {
description: "Company ID"
type: string
}
name {
description: "Company name"
type: string
}
allocated {
description: "Number of users allocated for company"
type: integer
}
owners {
description: "Company owners"
type: array
items {
type: object
properties {
id {
description: "User ID"
type: string
}
name {
description: "User Name"
type: string
}
avatar {
description: "User avatar (URL or base64-encoded data)"
type: string
}
}
}
}
}
}
}
}
}
}
} }

View File

@ -77,6 +77,21 @@ info {
description: "Server UID" description: "Server UID"
type: string type: string
} }
api_version {
description: "Max API version supported"
type: string
}
}
}
}
"2.12": ${info."2.8"} {
response {
type: object
properties {
api_version {
description: "Max API version supported"
type: string
}
} }
} }
} }

View File

@ -643,6 +643,7 @@ def get_debug_image_sample(call, company_id, request: GetDebugImageSampleRequest
metric=request.metric, metric=request.metric,
variant=request.variant, variant=request.variant,
iteration=request.iteration, iteration=request.iteration,
refresh=request.refresh,
state_id=request.scroll_id, state_id=request.scroll_id,
) )
call.result.data = attr.asdict(res, recurse=False) call.result.data = attr.asdict(res, recurse=False)

View File

@ -1,7 +1,9 @@
from collections import defaultdict from collections import defaultdict
from operator import itemgetter
from apiserver.apimodels.organization import TagsRequest from apiserver.apimodels.organization import TagsRequest
from apiserver.bll.organization import OrgBLL, Tags from apiserver.bll.organization import OrgBLL, Tags
from apiserver.database.model import User
from apiserver.service_repo import endpoint, APICall from apiserver.service_repo import endpoint, APICall
from apiserver.services.utils import get_tags_filter_dictionary, get_tags_response from apiserver.services.utils import get_tags_filter_dictionary, get_tags_response
@ -20,3 +22,26 @@ def get_tags(call: APICall, company, request: TagsRequest):
ret[field] |= vals ret[field] |= vals
call.result.data = get_tags_response(ret) call.result.data = get_tags_response(ret)
@endpoint("organization.get_user_companies")
def get_user_companies(call: APICall, company_id: str, _):
users = [
{
"id": u.id,
"name": u.name,
"avatar": u.avatar,
}
for u in User.objects(company=company_id).only("avatar", "name", "company")
]
call.result.data = {
"companies": [
{
"id": company_id,
"name": call.identity.company_name,
"allocated": len(users),
"owners": sorted(users, key=itemgetter("name")),
}
]
}

View File

@ -67,6 +67,12 @@ def info_2_8(call: APICall):
call.result.data["uid"] = Settings.get_by_key(SettingKeys.server__uuid) call.result.data["uid"] = Settings.get_by_key(SettingKeys.server__uuid)
@endpoint("server.info", min_version="2.12")
def info_2_8(call: APICall):
info(call)
call.result.data["api_version"] = str(ServiceRepo.max_endpoint_version())
@endpoint( @endpoint(
"server.report_stats_option", "server.report_stats_option",
request_data_model=ReportStatsOptionRequest, request_data_model=ReportStatsOptionRequest,

View File

@ -879,7 +879,7 @@ def reset(call: APICall, company_id, request: ResetRequest):
force=force, force=force,
status_reason="reset", status_reason="reset",
status_message="reset", status_message="reset",
).execute(started=None, completed=None, published=None, **updates) ).execute(started=None, completed=None, published=None, active_duration=None, **updates)
) )
# do not return artifacts since they are not serializable # do not return artifacts since they are not serializable
@ -904,13 +904,16 @@ def archive(call: APICall, company_id, request: ArchiveRequest):
only=("id", "execution", "status", "project", "system_tags"), only=("id", "execution", "status", "project", "system_tags"),
) )
for task in tasks: for task in tasks:
TaskBLL.dequeue_and_change_status( try:
task, TaskBLL.dequeue_and_change_status(
company_id, task,
request.status_message, company_id,
request.status_reason, request.status_message,
silent_dequeue_fail=True, request.status_reason,
) )
except APIError:
# dequeue may fail if the task was not enqueued
pass
task.update( task.update(
status_message=request.status_message, status_message=request.status_message,
status_reason=request.status_reason, status_reason=request.status_reason,

View File

@ -0,0 +1,17 @@
from apiserver.tests.automated import TestService
class TestOrganization(TestService):
def setUp(self, version="2.12"):
super().setUp(version=version)
def test_get_user_companies(self):
company = self.api.organization.get_user_companies().companies[0]
self.assertEqual(len(company.owners), company.allocated)
users = company.owners
self.assertTrue(users)
self.assertTrue(u1.name < u2.name for u1, u2 in zip(users, users[1:]))
for user in company.owners:
self.assertTrue(user.id)
self.assertTrue(user.name)
self.assertIn("avatar", user)