mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Compare commits
66 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
124684f53f | ||
|
|
455b5d6758 | ||
|
|
c04e2e498b | ||
|
|
da8a45072f | ||
|
|
e1992e2054 | ||
|
|
c17cedd93a | ||
|
|
b6ad8f8790 | ||
|
|
5acc7eebc3 | ||
|
|
941927dfcd | ||
|
|
02933a9c93 | ||
|
|
e537651f29 | ||
|
|
af09fba755 | ||
|
|
04ea9018a3 | ||
|
|
ff7e1be24f | ||
|
|
fc4fd9e61c | ||
|
|
8908c7dcf9 | ||
|
|
b9996e2c1a | ||
|
|
afdc56f37c | ||
|
|
a25cd5dae8 | ||
|
|
447adb9090 | ||
|
|
92fd98d5ad | ||
|
|
c4001b4037 | ||
|
|
970a32287a | ||
|
|
17cd48dada | ||
|
|
ea3b6e955f | ||
|
|
843450bb9b | ||
|
|
e149af58b1 | ||
|
|
604a38035b | ||
|
|
cae38a365b | ||
|
|
e334246b46 | ||
|
|
36e013b40c | ||
|
|
f20cd6536e | ||
|
|
446bd35006 | ||
|
|
a377a7e315 | ||
|
|
3d046ac282 | ||
|
|
a08fa9a0e1 | ||
|
|
5856ed2836 | ||
|
|
d295355d99 | ||
|
|
77350f6119 | ||
|
|
bc2c2ebbfd | ||
|
|
1502e02a1a | ||
|
|
d0e2313a24 | ||
|
|
d8ba1a8ea7 | ||
|
|
ca7937fc4e | ||
|
|
df89bcceef | ||
|
|
cfccbe05c1 | ||
|
|
e352a6a1e7 | ||
|
|
8a3d992aaf | ||
|
|
c37f3d8d5b | ||
|
|
a96870e092 | ||
|
|
6bf1032237 | ||
|
|
3d816c747d | ||
|
|
3f2b96266b | ||
|
|
22b16d12eb | ||
|
|
c55b6f30df | ||
|
|
b7045d3d28 | ||
|
|
e31a404885 | ||
|
|
643588b71a | ||
|
|
a64c4d264d | ||
|
|
567780e188 | ||
|
|
1bc8529d83 | ||
|
|
6b480d7e87 | ||
|
|
083fd315e9 | ||
|
|
ef20e76174 | ||
|
|
8c8910808e | ||
|
|
f6ad379310 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -12,7 +12,6 @@ test-reports
|
||||
.pytest_cache
|
||||
venv
|
||||
*.noseids
|
||||
build
|
||||
*.egg-info
|
||||
.cache
|
||||
.mypy_cache
|
||||
|
||||
62
README.md
62
README.md
@@ -8,28 +8,43 @@
|
||||
[](https://img.shields.io/badge/license-SSPL-green.svg)
|
||||
[](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)
|
||||
[](https://img.shields.io/github/release-pre/allegroai/trains-server.svg)
|
||||
[](https://artifacthub.io/packages/search?repo=allegroai)
|
||||
|
||||
</div>
|
||||
|
||||
---
|
||||
<div align="center">
|
||||
|
||||
**v0.16 Upgrade Notice**
|
||||
**Note regarding Apache Log4j2 Remote Code Execution (RCE) Vulnerability - CVE-2021-44228 - ESA-2021-31**
|
||||
|
||||
</div>
|
||||
|
||||
In v0.16, the Elasticsearch subsystem of ClearML Server has been upgraded from version 5.6 to version 7.6. This change necessitates the migration of the database contents to accommodate the change in index structure across the different versions.
|
||||
According to [ElasticSearch's latest report](https://discuss.elastic.co/t/apache-log4j2-remote-code-execution-rce-vulnerability-cve-2021-44228-esa-2021-31/291476),
|
||||
supported versions of Elasticsearch (6.8.9+, 7.8+) used with recent versions of the JDK (JDK9+) **are not susceptible to either remote code execution or information leakage**
|
||||
due to Elasticsearch’s usage of the Java Security Manager.
|
||||
|
||||
Follow [this procedure](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_es7_migration.html) to migrate existing data.
|
||||
**As the latest version of ClearML Server uses Elasticsearch 7.10+ with JDK15, it is not affected by these vulnerabilities.**
|
||||
|
||||
As a precaution, we've upgraded the ES version to 7.16.2 and added the mitigation recommended by ElasticSearch to our latest [docker-compose.yml](https://github.com/allegroai/clearml-server/blob/cfccbe05c158b75e520581f86e9668291da5c70a/docker/docker-compose.yml#L42) file.
|
||||
|
||||
While previous Elasticsearch versions (5.6.11+, 6.4.0+ and 7.0.0+) used by older ClearML Server versions are only susceptible to the information leakage vulnerability
|
||||
(which in any case **does not permit access to data within the Elasticsearch cluster**),
|
||||
we still recommend upgrading to the latest version of ClearML Server. Alternatively, you can apply the mitigation as implemented in our latest
|
||||
[docker-compose.yml](https://github.com/allegroai/clearml-server/blob/cfccbe05c158b75e520581f86e9668291da5c70a/docker/docker-compose.yml#L42) file.
|
||||
|
||||
**Update 15 December**: A further vulnerability (CVE-2021-45046) was disclosed on December 14th.
|
||||
ElasticSearch's guidance for Elasticsearch remains unchanged by this new vulnerability, thus **not affecting ClearML Server**.
|
||||
|
||||
**Update 22 December**: To keep with ElasticSearch's recommendations, we've upgraded the ES version to the newly released 7.16.2
|
||||
|
||||
---
|
||||
|
||||
### ClearML Server
|
||||
## ClearML Server
|
||||
#### *Formerly known as Trains Server*
|
||||
|
||||
The **ClearML Server** is the backend service infrastructure for [ClearML](https://github.com/allegroai/clearml).
|
||||
It allows multiple users to collaborate and manage their experiments.
|
||||
By default, **ClearML** is set up to work with the **ClearML** demo server, which is open to anyone and resets periodically.
|
||||
**ClearML** offers a [free hosted service](https://app.clear.ml/), which is maintained by **ClearML** and open to anyone.
|
||||
In order to host your own server, you will need to launch the **ClearML Server** and point **ClearML** to it.
|
||||
|
||||
The **ClearML Server** contains the following components:
|
||||
@@ -45,7 +60,7 @@ You can quickly [deploy](#launching-the-clearml-server) your **ClearML Server**
|
||||
## System design
|
||||
|
||||
|
||||

|
||||

|
||||
|
||||
The **ClearML Server** has two supported configurations:
|
||||
- Single IP (domain) with the following open ports
|
||||
@@ -78,20 +93,19 @@ For example, to see if port `8080` is in use:
|
||||
|
||||
Launch The **ClearML Server** in any of the following formats:
|
||||
|
||||
- Pre-built [AWS EC2 AMI](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_aws_ec2_ami.html)
|
||||
- Pre-built [GCP Custom Image](hhttps://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_gcp.html)
|
||||
- Pre-built [AWS EC2 AMI](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_aws_ec2_ami)
|
||||
- Pre-built [GCP Custom Image](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_gcp)
|
||||
- Pre-built Docker Image
|
||||
- [Linux](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_linux_mac.html)
|
||||
- [macOS](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_linux_mac.html)
|
||||
- [Windows 10](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_win.html)
|
||||
- [Linux](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_linux_mac)
|
||||
- [macOS](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_linux_mac)
|
||||
- [Windows 10](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_win)
|
||||
- Kubernetes
|
||||
- [Kubernetes Helm](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_kubernetes_helm.html)
|
||||
- Manual [Kubernetes installation](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_kubernetes.html)
|
||||
- [Kubernetes Helm](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_kubernetes_helm)
|
||||
- Manual [Kubernetes installation](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_kubernetes)
|
||||
|
||||
## Connecting ClearML to your ClearML Server
|
||||
|
||||
By default, the **ClearML** client is set up to work with the [**ClearML** demo server](https://demoapp.demo.clear.ml/).
|
||||
To have the **ClearML** client use your **ClearML Server** instead:
|
||||
In order to set up the **ClearML** client to work with your **ClearML Server**:
|
||||
- Run the `clearml-init` command for an interactive setup.
|
||||
- Or manually edit `~/clearml.conf` file, making sure the server settings (`api_server`, `web_server`, `file_server`) are configured correctly, for example:
|
||||
|
||||
@@ -138,8 +152,8 @@ Do not enqueue training / inference tasks into the `services` queue, as it will
|
||||
|
||||
The **ClearML Server** provides a few additional useful features, which can be manually enabled:
|
||||
|
||||
* [Web login authentication](https://allegro.ai/clearml/docs/deploying_clearml/clearml_server_config/#web-login-authentication)
|
||||
* [Non-responsive experiments watchdog](https://allegro.ai/clearml/docs/deploying_clearml/clearml_server_config/#task_watchdog)
|
||||
* [Web login authentication](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_config#web-login-authentication)
|
||||
* [Non-responsive experiments watchdog](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_config#non-responsive-task-watchdog)
|
||||
|
||||
## Restarting ClearML Server
|
||||
|
||||
@@ -189,14 +203,14 @@ To upgrade your existing **ClearML Server** deployment:
|
||||
```
|
||||
|
||||
1. Configure the ClearML-Agent Services (not supported on Windows installation).
|
||||
If `TRAINS_HOST_IP` is not provided, ClearML-Agent Services will use the external
|
||||
public address of the **ClearML Server**. If `TRAINS_AGENT_GIT_USER` / `TRAINS_AGENT_GIT_PASS` are not provided,
|
||||
If `CLEARML_HOST_IP` is not provided, ClearML-Agent Services will use the external
|
||||
public address of the **ClearML Server**. If `CLEARML_AGENT_GIT_USER` / `CLEARML_AGENT_GIT_PASS` are not provided,
|
||||
the ClearML-Agent Services will not be able to access any private repositories for running service tasks.
|
||||
|
||||
```bash
|
||||
export TRAINS_HOST_IP=server_host_ip_here
|
||||
export TRAINS_AGENT_GIT_USER=git_username_here
|
||||
export TRAINS_AGENT_GIT_PASS=git_password_here
|
||||
export CLEARML_HOST_IP=server_host_ip_here
|
||||
export CLEARML_AGENT_GIT_USER=git_username_here
|
||||
export CLEARML_AGENT_GIT_PASS=git_password_here
|
||||
```
|
||||
|
||||
1. Spin up the docker containers, it will automatically pull the latest **ClearML Server** build
|
||||
@@ -205,12 +219,12 @@ To upgrade your existing **ClearML Server** deployment:
|
||||
docker-compose -f docker-compose.yml up
|
||||
```
|
||||
|
||||
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://allegro.ai/clearml/docs/docs/faq/faq.html).**
|
||||
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://clear.ml/docs/latest/docs/faq/).**
|
||||
|
||||
|
||||
## Community & Support
|
||||
|
||||
If you have any questions, look to the ClearML [FAQ](https://allegro.ai/clearml/docs/docs/faq/faq.html), or
|
||||
If you have any questions, look to the ClearML [FAQ](https://clear.ml/docs/latest/docs/faq), or
|
||||
tag your questions on [stackoverflow](https://stackoverflow.com/questions/tagged/clearml) with '**clearml**' tag.
|
||||
|
||||
For feature requests or bug reports, please use [GitHub issues](https://github.com/allegroai/clearml-server/issues).
|
||||
|
||||
@@ -71,6 +71,8 @@
|
||||
408: ["cannot_update_project_location", "Cannot update project location. Use projects.move instead"]
|
||||
409: ["project_path_exceeds_max", "Project path exceed the maximum allowed depth"]
|
||||
410: ["project_source_and_destination_are_the_same", "Project has the same source and destination paths"]
|
||||
411: ["project_cannot_be_moved_under_itself", "Project can not be moved under itself in the projects hierarchy"]
|
||||
412: ["project_cannot_be_merged_into_its_child", "Project can not be merged into its own child"]
|
||||
|
||||
# Queues
|
||||
701: ["invalid_queue_id", "invalid queue id"]
|
||||
|
||||
@@ -75,11 +75,17 @@ class CreateUserResponse(Base):
|
||||
class Credentials(Base):
|
||||
access_key = StringField(required=True)
|
||||
secret_key = StringField(required=True)
|
||||
label = StringField()
|
||||
|
||||
|
||||
class CredentialsResponse(Credentials):
|
||||
secret_key = StringField()
|
||||
last_used = DateTimeField(default=None)
|
||||
last_used_from = StringField()
|
||||
|
||||
|
||||
class CreateCredentialsRequest(Base):
|
||||
label = StringField()
|
||||
|
||||
|
||||
class CreateCredentialsResponse(Base):
|
||||
|
||||
@@ -2,7 +2,7 @@ from enum import auto
|
||||
from typing import Sequence, Optional
|
||||
|
||||
from jsonmodels import validators
|
||||
from jsonmodels.fields import StringField, BoolField
|
||||
from jsonmodels.fields import StringField, BoolField, EmbeddedField
|
||||
from jsonmodels.models import Base
|
||||
from jsonmodels.validators import Length, Min, Max
|
||||
|
||||
@@ -81,14 +81,34 @@ class LogOrderEnum(StringEnum):
|
||||
desc = auto()
|
||||
|
||||
|
||||
class LogEventsRequest(Base):
|
||||
class TaskEventsRequestBase(Base):
|
||||
task: str = StringField(required=True)
|
||||
batch_size: int = IntField(default=500)
|
||||
|
||||
|
||||
class TaskEventsRequest(TaskEventsRequestBase):
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
event_type: EventType = ActualEnumField(EventType, default=EventType.all)
|
||||
order: Optional[str] = ActualEnumField(LogOrderEnum, default=LogOrderEnum.asc)
|
||||
scroll_id: str = StringField()
|
||||
count_total: bool = BoolField(default=True)
|
||||
|
||||
|
||||
class LogEventsRequest(TaskEventsRequestBase):
|
||||
batch_size: int = IntField(default=5000)
|
||||
navigate_earlier: bool = BoolField(default=True)
|
||||
from_timestamp: Optional[int] = IntField()
|
||||
order: Optional[str] = ActualEnumField(LogOrderEnum)
|
||||
|
||||
|
||||
class ScalarMetricsIterRawRequest(TaskEventsRequestBase):
|
||||
batch_size: int = IntField()
|
||||
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
|
||||
metric: MetricVariants = EmbeddedField(MetricVariants, required=True)
|
||||
count_total: bool = BoolField(default=False)
|
||||
scroll_id: str = StringField()
|
||||
|
||||
|
||||
class IterationEvents(Base):
|
||||
iter: int = IntField()
|
||||
events: Sequence[dict] = ListField(items_types=dict)
|
||||
@@ -115,4 +135,5 @@ class TaskPlotsRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
iters: int = IntField(default=1)
|
||||
scroll_id: str = StringField()
|
||||
no_scroll: bool = BoolField(default=False)
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Sequence
|
||||
|
||||
from jsonmodels import validators
|
||||
from jsonmodels.fields import StringField
|
||||
from jsonmodels.fields import StringField, BoolField
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apiserver.apimodels import ListField
|
||||
@@ -21,3 +21,4 @@ class AddOrUpdateMetadata(Base):
|
||||
metadata: Sequence[MetadataItem] = ListField(
|
||||
[MetadataItem], validators=validators.Length(minimum_value=1)
|
||||
)
|
||||
replace_metadata = BoolField(default=False)
|
||||
|
||||
@@ -30,7 +30,7 @@ class CreateModelRequest(models.Base):
|
||||
ready = fields.BoolField(default=True)
|
||||
ui_cache = DictField()
|
||||
task = fields.StringField()
|
||||
metadata = ListField(items_types=[MetadataItem])
|
||||
metadata = DictField(value_types=[MetadataItem])
|
||||
|
||||
|
||||
class CreateModelResponse(models.Base):
|
||||
|
||||
19
apiserver/apimodels/pipelines.py
Normal file
19
apiserver/apimodels/pipelines.py
Normal file
@@ -0,0 +1,19 @@
|
||||
from jsonmodels import models, fields
|
||||
|
||||
from apiserver.apimodels import ListField
|
||||
|
||||
|
||||
class Arg(models.Base):
|
||||
name = fields.StringField(required=True)
|
||||
value = fields.StringField(required=True)
|
||||
|
||||
|
||||
class StartPipelineRequest(models.Base):
|
||||
task = fields.StringField(required=True)
|
||||
queue = fields.StringField(required=True)
|
||||
args = ListField(Arg)
|
||||
|
||||
|
||||
class StartPipelineResponse(models.Base):
|
||||
pipeline = fields.StringField(required=True)
|
||||
enqueued = fields.BoolField(required=True)
|
||||
@@ -1,6 +1,6 @@
|
||||
from jsonmodels import models, fields
|
||||
|
||||
from apiserver.apimodels import ListField, ActualEnumField
|
||||
from apiserver.apimodels import ListField, ActualEnumField, DictField
|
||||
from apiserver.apimodels.organization import TagsRequest
|
||||
from apiserver.database.model import EntityVisibility
|
||||
|
||||
@@ -27,7 +27,7 @@ class ProjectOrNoneRequest(models.Base):
|
||||
include_subprojects = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class GetHyperParamRequest(ProjectOrNoneRequest):
|
||||
class GetParamsRequest(ProjectOrNoneRequest):
|
||||
page = fields.IntField(default=0)
|
||||
page_size = fields.IntField(default=500)
|
||||
|
||||
@@ -51,8 +51,15 @@ class ProjectHyperparamValuesRequest(MultiProjectRequest):
|
||||
allow_public = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class ProjectModelMetadataValuesRequest(MultiProjectRequest):
|
||||
key = fields.StringField(required=True)
|
||||
allow_public = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class ProjectsGetRequest(models.Base):
|
||||
include_stats = fields.BoolField(default=False)
|
||||
include_stats_filter = DictField()
|
||||
stats_with_children = fields.BoolField(default=True)
|
||||
stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active)
|
||||
non_public = fields.BoolField(default=False)
|
||||
active_users = fields.ListField(str)
|
||||
|
||||
@@ -2,7 +2,7 @@ from jsonmodels import validators
|
||||
from jsonmodels.fields import StringField, IntField, BoolField, FloatField
|
||||
from jsonmodels.models import Base
|
||||
|
||||
from apiserver.apimodels import ListField
|
||||
from apiserver.apimodels import ListField, DictField
|
||||
from apiserver.apimodels.metadata import (
|
||||
MetadataItem,
|
||||
DeleteMetadata,
|
||||
@@ -19,13 +19,18 @@ class CreateRequest(Base):
|
||||
name = StringField(required=True)
|
||||
tags = ListField(items_types=[str])
|
||||
system_tags = ListField(items_types=[str])
|
||||
metadata = ListField(items_types=[MetadataItem])
|
||||
metadata = DictField(value_types=[MetadataItem])
|
||||
|
||||
|
||||
class QueueRequest(Base):
|
||||
queue = StringField(required=True)
|
||||
|
||||
|
||||
class GetNextTaskRequest(QueueRequest):
|
||||
queue = StringField(required=True)
|
||||
get_task_info = BoolField(default=False)
|
||||
|
||||
|
||||
class DeleteRequest(QueueRequest):
|
||||
force = BoolField(default=False)
|
||||
|
||||
@@ -34,7 +39,7 @@ class UpdateRequest(QueueRequest):
|
||||
name = StringField()
|
||||
tags = ListField(items_types=[str])
|
||||
system_tags = ListField(items_types=[str])
|
||||
metadata = ListField(items_types=[MetadataItem])
|
||||
metadata = DictField(value_types=[MetadataItem])
|
||||
|
||||
|
||||
class TaskRequest(QueueRequest):
|
||||
|
||||
@@ -2,7 +2,11 @@ from datetime import datetime
|
||||
|
||||
from apiserver import database
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.auth import GetTokenResponse, CreateUserRequest, Credentials as CredModel
|
||||
from apiserver.apimodels.auth import (
|
||||
GetTokenResponse,
|
||||
CreateUserRequest,
|
||||
Credentials as CredModel,
|
||||
)
|
||||
from apiserver.apimodels.users import CreateRequest as Users_CreateRequest
|
||||
from apiserver.bll.user import UserBLL
|
||||
from apiserver.config_repo import config
|
||||
@@ -145,7 +149,7 @@ class AuthBLL:
|
||||
|
||||
@classmethod
|
||||
def create_credentials(
|
||||
cls, user_id: str, company_id: str, role: str = None
|
||||
cls, user_id: str, company_id: str, role: str = None, label: str = None,
|
||||
) -> CredModel:
|
||||
|
||||
with translate_errors_context():
|
||||
@@ -154,9 +158,11 @@ class AuthBLL:
|
||||
if not user:
|
||||
raise errors.bad_request.InvalidUserId(**query)
|
||||
|
||||
cred = CredModel(access_key=get_client_id(), secret_key=get_secret_key())
|
||||
cred = CredModel(
|
||||
access_key=get_client_id(), secret_key=get_secret_key(), label=label
|
||||
)
|
||||
user.credentials.append(
|
||||
Credentials(key=cred.access_key, secret=cred.secret_key)
|
||||
Credentials(key=cred.access_key, secret=cred.secret_key, label=label)
|
||||
)
|
||||
user.save()
|
||||
|
||||
|
||||
@@ -24,13 +24,13 @@ from apiserver.bll.event.event_common import (
|
||||
MetricVariants,
|
||||
get_metric_variants_condition,
|
||||
)
|
||||
from apiserver.bll.event.events_iterator import EventsIterator, TaskEventsResult
|
||||
from apiserver.bll.util import parallel_chunked_decorator
|
||||
from apiserver.database import utils as dbutils
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.event.debug_images_iterator import DebugImagesIterator
|
||||
from apiserver.bll.event.event_metrics import EventMetrics
|
||||
from apiserver.bll.event.log_events_iterator import LogEventsIterator, TaskEventsResult
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
@@ -73,7 +73,7 @@ class EventBLL(object):
|
||||
self.redis = redis or redman.connection("apiserver")
|
||||
self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
|
||||
self.debug_sample_history = DebugSampleHistory(es=self.es, redis=self.redis)
|
||||
self.log_events_iterator = LogEventsIterator(es=self.es)
|
||||
self.events_iterator = EventsIterator(es=self.es)
|
||||
|
||||
@property
|
||||
def metrics(self) -> EventMetrics:
|
||||
@@ -534,6 +534,7 @@ class EventBLL(object):
|
||||
sort=None,
|
||||
size: int = 500,
|
||||
scroll_id: str = None,
|
||||
no_scroll: bool = False,
|
||||
metric_variants: MetricVariants = None,
|
||||
):
|
||||
if scroll_id == self.empty_scroll:
|
||||
@@ -611,7 +612,7 @@ class EventBLL(object):
|
||||
event_type=event_type,
|
||||
body=es_req,
|
||||
ignore=404,
|
||||
scroll="1h",
|
||||
**({} if no_scroll else {"scroll": "1h"}),
|
||||
)
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
@@ -680,6 +681,7 @@ class EventBLL(object):
|
||||
sort=None,
|
||||
size=500,
|
||||
scroll_id=None,
|
||||
no_scroll=False,
|
||||
) -> TaskEventsResult:
|
||||
if scroll_id == self.empty_scroll:
|
||||
return TaskEventsResult()
|
||||
@@ -740,7 +742,7 @@ class EventBLL(object):
|
||||
event_type=event_type,
|
||||
body=es_req,
|
||||
ignore=404,
|
||||
scroll="1h",
|
||||
**({} if no_scroll else {"scroll": "1h"}),
|
||||
)
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
|
||||
@@ -66,12 +66,19 @@ def delete_company_events(
|
||||
es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
|
||||
) -> dict:
|
||||
es_index = get_index_name(company_id, event_type.value)
|
||||
return es.delete_by_query(index=es_index, body=body, **kwargs)
|
||||
return es.delete_by_query(
|
||||
index=es_index, body=body, conflicts="proceed", **kwargs
|
||||
)
|
||||
|
||||
|
||||
def get_metric_variants_condition(
|
||||
metric_variants: MetricVariants,
|
||||
) -> Sequence:
|
||||
def count_company_events(
|
||||
es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
|
||||
) -> dict:
|
||||
es_index = get_index_name(company_id, event_type.value)
|
||||
return es.count(index=es_index, body=body, **kwargs)
|
||||
|
||||
|
||||
def get_metric_variants_condition(metric_variants: MetricVariants,) -> Sequence:
|
||||
conditions = [
|
||||
{
|
||||
"bool": {
|
||||
|
||||
205
apiserver/bll/event/events_iterator.py
Normal file
205
apiserver/bll/event/events_iterator.py
Normal file
@@ -0,0 +1,205 @@
|
||||
from typing import Optional, Tuple, Sequence, Any
|
||||
|
||||
import attr
|
||||
import jsonmodels.models
|
||||
import jwt
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from apiserver.bll.event.event_common import (
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
EventType,
|
||||
MetricVariants,
|
||||
get_metric_variants_condition,
|
||||
count_company_events,
|
||||
)
|
||||
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class TaskEventsResult:
|
||||
total_events: int = 0
|
||||
next_scroll_id: str = None
|
||||
events: list = attr.Factory(list)
|
||||
|
||||
|
||||
class EventsIterator:
|
||||
def __init__(self, es: Elasticsearch):
|
||||
self.es = es
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
event_type: EventType,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
batch_size: int,
|
||||
navigate_earlier: bool = True,
|
||||
from_key_value: Optional[Any] = None,
|
||||
metric_variants: MetricVariants = None,
|
||||
key: ScalarKeyEnum = ScalarKeyEnum.timestamp,
|
||||
**kwargs,
|
||||
) -> TaskEventsResult:
|
||||
if check_empty_data(self.es, company_id, event_type):
|
||||
return TaskEventsResult()
|
||||
|
||||
from_key_value = kwargs.pop("from_timestamp", from_key_value)
|
||||
|
||||
res = TaskEventsResult()
|
||||
res.events, res.total_events = self._get_events(
|
||||
event_type=event_type,
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
batch_size=batch_size,
|
||||
navigate_earlier=navigate_earlier,
|
||||
from_key_value=from_key_value,
|
||||
metric_variants=metric_variants,
|
||||
key=ScalarKey.resolve(key),
|
||||
)
|
||||
return res
|
||||
|
||||
def count_task_events(
|
||||
self,
|
||||
event_type: EventType,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> int:
|
||||
query, _ = self._get_initial_query_and_must(task_id, metric_variants)
|
||||
es_req = {
|
||||
"query": query,
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "count_task_events"):
|
||||
es_result = count_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
body=es_req,
|
||||
routing=task_id,
|
||||
)
|
||||
|
||||
return es_result["count"]
|
||||
|
||||
def _get_events(
|
||||
self,
|
||||
event_type: EventType,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
batch_size: int,
|
||||
navigate_earlier: bool,
|
||||
key: ScalarKey,
|
||||
from_key_value: Optional[Any],
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> Tuple[Sequence[dict], int]:
|
||||
"""
|
||||
Return up to 'batch size' events starting from the previous key-field value (timestamp or iter) either in the
|
||||
direction of earlier events (navigate_earlier=True) or in the direction of later events.
|
||||
If from_key_field is not set then start either from latest or earliest.
|
||||
For the last key-field value all the events are brought (even if the resulting size exceeds batch_size)
|
||||
so that events with this value will not be lost between the calls.
|
||||
"""
|
||||
query, must = self._get_initial_query_and_must(task_id, metric_variants)
|
||||
|
||||
# retrieve the next batch of events
|
||||
es_req = {
|
||||
"size": batch_size,
|
||||
"query": query,
|
||||
"sort": {key.field: "desc" if navigate_earlier else "asc"},
|
||||
}
|
||||
|
||||
if from_key_value:
|
||||
es_req["search_after"] = [from_key_value]
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
es_result = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
body=es_req,
|
||||
routing=task_id,
|
||||
)
|
||||
hits = es_result["hits"]["hits"]
|
||||
hits_total = es_result["hits"]["total"]["value"]
|
||||
if not hits:
|
||||
return [], hits_total
|
||||
|
||||
events = [hit["_source"] for hit in hits]
|
||||
|
||||
# retrieve the events that match the last event timestamp
|
||||
# but did not make it into the previous call due to batch_size limitation
|
||||
es_req = {
|
||||
"size": 10000,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": must + [{"term": {key.field: events[-1][key.field]}}]
|
||||
}
|
||||
},
|
||||
}
|
||||
es_result = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
body=es_req,
|
||||
routing=task_id,
|
||||
)
|
||||
last_second_hits = es_result["hits"]["hits"]
|
||||
if not last_second_hits or len(last_second_hits) < 2:
|
||||
# if only one element is returned for the last timestamp
|
||||
# then it is already present in the events
|
||||
return events, hits_total
|
||||
|
||||
already_present_ids = set(hit["_id"] for hit in hits)
|
||||
last_second_events = [
|
||||
hit["_source"]
|
||||
for hit in last_second_hits
|
||||
if hit["_id"] not in already_present_ids
|
||||
]
|
||||
|
||||
# return the list merged from original query results +
|
||||
# leftovers from the last timestamp
|
||||
return (
|
||||
[*events, *last_second_events],
|
||||
hits_total,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_initial_query_and_must(
|
||||
task_id: str, metric_variants: MetricVariants = None
|
||||
) -> Tuple[dict, list]:
|
||||
if not metric_variants:
|
||||
must = [{"term": {"task": task_id}}]
|
||||
query = {"term": {"task": task_id}}
|
||||
else:
|
||||
must = [
|
||||
{"term": {"task": task_id}},
|
||||
get_metric_variants_condition(metric_variants),
|
||||
]
|
||||
query = {"bool": {"must": must}}
|
||||
return query, must
|
||||
|
||||
|
||||
class Scroll(jsonmodels.models.Base):
|
||||
def get_scroll_id(self) -> str:
|
||||
return jwt.encode(
|
||||
self.to_struct(),
|
||||
key=config.get(
|
||||
"services.events.events_retrieval.scroll_id_key", "1234567890"
|
||||
),
|
||||
).decode()
|
||||
|
||||
@classmethod
|
||||
def from_scroll_id(cls, scroll_id: str):
|
||||
try:
|
||||
return cls(
|
||||
**jwt.decode(
|
||||
scroll_id,
|
||||
key=config.get(
|
||||
"services.events.events_retrieval.scroll_id_key", "1234567890"
|
||||
),
|
||||
)
|
||||
)
|
||||
except jwt.PyJWTError:
|
||||
raise ValueError("Invalid Scroll ID")
|
||||
@@ -1,127 +0,0 @@
|
||||
from typing import Optional, Tuple, Sequence
|
||||
|
||||
import attr
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from apiserver.bll.event.event_common import (
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
EventType,
|
||||
)
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class TaskEventsResult:
|
||||
total_events: int = 0
|
||||
next_scroll_id: str = None
|
||||
events: list = attr.Factory(list)
|
||||
|
||||
|
||||
class LogEventsIterator:
|
||||
EVENT_TYPE = EventType.task_log
|
||||
|
||||
def __init__(self, es: Elasticsearch):
|
||||
self.es = es
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
batch_size: int,
|
||||
navigate_earlier: bool = True,
|
||||
from_timestamp: Optional[int] = None,
|
||||
) -> TaskEventsResult:
|
||||
if check_empty_data(self.es, company_id, self.EVENT_TYPE):
|
||||
return TaskEventsResult()
|
||||
|
||||
res = TaskEventsResult()
|
||||
res.events, res.total_events = self._get_events(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
batch_size=batch_size,
|
||||
navigate_earlier=navigate_earlier,
|
||||
from_timestamp=from_timestamp,
|
||||
)
|
||||
return res
|
||||
|
||||
def _get_events(
|
||||
self,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
batch_size: int,
|
||||
navigate_earlier: bool,
|
||||
from_timestamp: Optional[int],
|
||||
) -> Tuple[Sequence[dict], int]:
|
||||
"""
|
||||
Return up to 'batch size' events starting from the previous timestamp either in the
|
||||
direction of earlier events (navigate_earlier=True) or in the direction of later events.
|
||||
If last_min_timestamp and last_max_timestamp are not set then start either from latest or earliest.
|
||||
For the last timestamp all the events are brought (even if the resulting size
|
||||
exceeds batch_size) so that this timestamp events will not be lost between the calls.
|
||||
In case any events were received update 'last_min_timestamp' and 'last_max_timestamp'
|
||||
"""
|
||||
|
||||
# retrieve the next batch of events
|
||||
es_req = {
|
||||
"size": batch_size,
|
||||
"query": {"term": {"task": task_id}},
|
||||
"sort": {"timestamp": "desc" if navigate_earlier else "asc"},
|
||||
}
|
||||
|
||||
if from_timestamp:
|
||||
es_req["search_after"] = [from_timestamp]
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
es_result = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.EVENT_TYPE,
|
||||
body=es_req,
|
||||
)
|
||||
hits = es_result["hits"]["hits"]
|
||||
hits_total = es_result["hits"]["total"]["value"]
|
||||
if not hits:
|
||||
return [], hits_total
|
||||
|
||||
events = [hit["_source"] for hit in hits]
|
||||
|
||||
# retrieve the events that match the last event timestamp
|
||||
# but did not make it into the previous call due to batch_size limitation
|
||||
es_req = {
|
||||
"size": 10000,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task_id}},
|
||||
{"term": {"timestamp": events[-1]["timestamp"]}},
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
es_result = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.EVENT_TYPE,
|
||||
body=es_req,
|
||||
)
|
||||
last_second_hits = es_result["hits"]["hits"]
|
||||
if not last_second_hits or len(last_second_hits) < 2:
|
||||
# if only one element is returned for the last timestamp
|
||||
# then it is already present in the events
|
||||
return events, hits_total
|
||||
|
||||
already_present_ids = set(hit["_id"] for hit in hits)
|
||||
last_second_events = [
|
||||
hit["_source"]
|
||||
for hit in last_second_hits
|
||||
if hit["_id"] not in already_present_ids
|
||||
]
|
||||
|
||||
# return the list merged from original query results +
|
||||
# leftovers from the last timestamp
|
||||
return (
|
||||
[*events, *last_second_events],
|
||||
hits_total,
|
||||
)
|
||||
@@ -4,6 +4,8 @@ Module for polymorphism over different types of X axes in scalar aggregations
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import auto
|
||||
|
||||
from typing import Any
|
||||
|
||||
from apiserver.utilities import extract_properties_to_lists
|
||||
from apiserver.utilities.stringenum import StringEnum
|
||||
from apiserver.config_repo import config
|
||||
@@ -96,6 +98,10 @@ class ScalarKey(ABC):
|
||||
"""
|
||||
return int(iter_data[self.bucket_key_key]), iter_data["avg_val"]["value"]
|
||||
|
||||
def cast_value(self, value: Any) -> Any:
|
||||
"""Cast value to appropriate type"""
|
||||
return value
|
||||
|
||||
|
||||
class TimestampKey(ScalarKey):
|
||||
"""
|
||||
@@ -117,6 +123,9 @@ class TimestampKey(ScalarKey):
|
||||
}
|
||||
}
|
||||
|
||||
def cast_value(self, value: Any) -> int:
|
||||
return int(value)
|
||||
|
||||
|
||||
class IterKey(ScalarKey):
|
||||
"""
|
||||
@@ -134,6 +143,9 @@ class IterKey(ScalarKey):
|
||||
}
|
||||
}
|
||||
|
||||
def cast_value(self, value: Any) -> int:
|
||||
return int(value)
|
||||
|
||||
|
||||
class ISOTimeKey(ScalarKey):
|
||||
"""
|
||||
|
||||
@@ -7,6 +7,7 @@ from apiserver.bll.task.utils import deleted_prefix
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task, TaskStatus
|
||||
from .metadata import Metadata
|
||||
|
||||
|
||||
class ModelBLL:
|
||||
|
||||
111
apiserver/bll/model/metadata.py
Normal file
111
apiserver/bll/model/metadata.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from typing import Sequence, Union, Mapping
|
||||
|
||||
from mongoengine import Document
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.metadata import MetadataItem
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.service_repo import APICall
|
||||
from apiserver.utilities.parameter_key_escaper import (
|
||||
ParameterKeyEscaper,
|
||||
mongoengine_safe,
|
||||
)
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class Metadata:
|
||||
@staticmethod
|
||||
def metadata_from_api(
|
||||
api_data: Union[Mapping[str, MetadataItem], Sequence[MetadataItem]]
|
||||
) -> dict:
|
||||
if not api_data:
|
||||
return {}
|
||||
|
||||
if isinstance(api_data, dict):
|
||||
return {
|
||||
ParameterKeyEscaper.escape(k): v.to_struct()
|
||||
for k, v in api_data.items()
|
||||
}
|
||||
|
||||
return {
|
||||
ParameterKeyEscaper.escape(item.key): item.to_struct() for item in api_data
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def edit_metadata(
|
||||
cls,
|
||||
obj: Document,
|
||||
items: Sequence[MetadataItem],
|
||||
replace_metadata: bool,
|
||||
**more_updates,
|
||||
) -> int:
|
||||
with TimingContext("mongo", "edit_metadata"):
|
||||
update_cmds = dict()
|
||||
metadata = cls.metadata_from_api(items)
|
||||
if replace_metadata:
|
||||
update_cmds["set__metadata"] = metadata
|
||||
else:
|
||||
for key, value in metadata.items():
|
||||
update_cmds[f"set__metadata__{mongoengine_safe(key)}"] = value
|
||||
|
||||
return obj.update(**update_cmds, **more_updates)
|
||||
|
||||
@classmethod
|
||||
def delete_metadata(cls, obj: Document, keys: Sequence[str], **more_updates) -> int:
|
||||
with TimingContext("mongo", "delete_metadata"):
|
||||
return obj.update(
|
||||
**{
|
||||
f"unset__metadata__{ParameterKeyEscaper.escape(key)}": 1
|
||||
for key in set(keys)
|
||||
},
|
||||
**more_updates,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _process_path(path: str):
|
||||
"""
|
||||
Frontend does a partial escaping on the path so the all '.' in key names are escaped
|
||||
Need to unescape and apply a full mongo escaping
|
||||
"""
|
||||
parts = path.split(".")
|
||||
if len(parts) < 2 or len(parts) > 3:
|
||||
raise errors.bad_request.ValidationError("invalid field", path=path)
|
||||
return ".".join(
|
||||
ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def escape_paths(cls, paths: Sequence[str]) -> Sequence[str]:
|
||||
for prefix in (
|
||||
"metadata.",
|
||||
"-metadata.",
|
||||
):
|
||||
paths = [
|
||||
cls._process_path(path) if path.startswith(prefix) else path
|
||||
for path in paths
|
||||
]
|
||||
return paths
|
||||
|
||||
@classmethod
|
||||
def escape_query_parameters(cls, call: APICall) -> dict:
|
||||
if not call.data:
|
||||
return call.data
|
||||
|
||||
keys = list(call.data)
|
||||
call_data = {
|
||||
safe_key: call.data[key]
|
||||
for key, safe_key in zip(keys, Metadata.escape_paths(keys))
|
||||
}
|
||||
|
||||
projection = GetMixin.get_projection(call_data)
|
||||
if projection:
|
||||
GetMixin.set_projection(call_data, Metadata.escape_paths(projection))
|
||||
|
||||
ordering = GetMixin.get_ordering(call_data)
|
||||
if ordering:
|
||||
GetMixin.set_ordering(call_data, Metadata.escape_paths(ordering))
|
||||
|
||||
return call_data
|
||||
@@ -1,2 +1,3 @@
|
||||
from .project_bll import ProjectBLL
|
||||
from .project_queries import ProjectQueries
|
||||
from .sub_projects import _ids_with_children as project_ids_with_children
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from functools import reduce
|
||||
from itertools import groupby
|
||||
from operator import itemgetter
|
||||
@@ -14,6 +14,7 @@ from typing import (
|
||||
TypeVar,
|
||||
Callable,
|
||||
Mapping,
|
||||
Any,
|
||||
)
|
||||
|
||||
from mongoengine import Q, Document
|
||||
@@ -22,6 +23,7 @@ from apiserver import database
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model import EntityVisibility, AttributedDocument
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task, TaskStatus, external_task_types
|
||||
@@ -57,10 +59,14 @@ class ProjectBLL:
|
||||
with TimingContext("mongo", "move_project"):
|
||||
if source_id == destination_id:
|
||||
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
|
||||
parent=source_id
|
||||
source=source_id
|
||||
)
|
||||
source = Project.get(company, source_id)
|
||||
destination = Project.get(company, destination_id)
|
||||
if source_id in destination.path:
|
||||
raise errors.bad_request.ProjectCannotBeMergedIntoItsChild(
|
||||
source=source_id, destination=destination_id
|
||||
)
|
||||
|
||||
children = _get_sub_projects(
|
||||
[source.id], _only=("id", "name", "parent", "path")
|
||||
@@ -140,7 +146,14 @@ class ProjectBLL:
|
||||
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
|
||||
location=new_parent.name if new_parent else ""
|
||||
)
|
||||
|
||||
if (
|
||||
new_parent
|
||||
and project_id == new_parent.id
|
||||
or project_id in new_parent.path
|
||||
):
|
||||
raise errors.bad_request.ProjectCannotBeMovedUnderItself(
|
||||
project=project_id, parent=new_parent.id
|
||||
)
|
||||
moved = _reposition_project_with_children(
|
||||
project, children=children, parent=new_parent
|
||||
)
|
||||
@@ -193,6 +206,7 @@ class ProjectBLL:
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
default_output_destination: str = None,
|
||||
parent_creation_params: dict = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create a new project.
|
||||
@@ -215,7 +229,12 @@ class ProjectBLL:
|
||||
created=now,
|
||||
last_update=now,
|
||||
)
|
||||
parent = _ensure_project(company=company, user=user, name=location)
|
||||
parent = _ensure_project(
|
||||
company=company,
|
||||
user=user,
|
||||
name=location,
|
||||
creation_params=parent_creation_params,
|
||||
)
|
||||
_save_under_parent(project=project, parent=parent)
|
||||
if parent:
|
||||
parent.update(last_update=now)
|
||||
@@ -233,13 +252,14 @@ class ProjectBLL:
|
||||
tags: Sequence[str] = None,
|
||||
system_tags: Sequence[str] = None,
|
||||
default_output_destination: str = None,
|
||||
parent_creation_params: dict = None,
|
||||
) -> str:
|
||||
"""
|
||||
Find a project named `project_name` or create a new one.
|
||||
Returns project ID
|
||||
"""
|
||||
if not project_id and not project_name:
|
||||
raise ValueError("project id or name required")
|
||||
raise errors.bad_request.ValidationError("project id or name required")
|
||||
|
||||
if project_id:
|
||||
project = Project.objects(company=company, id=project_id).only("id").first()
|
||||
@@ -260,6 +280,7 @@ class ProjectBLL:
|
||||
tags=tags,
|
||||
system_tags=system_tags,
|
||||
default_output_destination=default_output_destination,
|
||||
parent_creation_params=parent_creation_params,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -295,6 +316,7 @@ class ProjectBLL:
|
||||
return project
|
||||
|
||||
archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}
|
||||
visibility_states = [EntityVisibility.archived, EntityVisibility.active]
|
||||
|
||||
@classmethod
|
||||
def make_projects_get_all_pipelines(
|
||||
@@ -302,6 +324,7 @@ class ProjectBLL:
|
||||
company_id: str,
|
||||
project_ids: Sequence[str],
|
||||
specific_state: Optional[EntityVisibility] = None,
|
||||
filter_: Mapping[str, Any] = None,
|
||||
) -> Tuple[Sequence, Sequence]:
|
||||
archived = EntityVisibility.archived.value
|
||||
|
||||
@@ -325,10 +348,9 @@ class ProjectBLL:
|
||||
status_count_pipeline = [
|
||||
# count tasks per project per status
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"project": {"$in": project_ids},
|
||||
}
|
||||
"$match": cls.get_match_conditions(
|
||||
company=company_id, project_ids=project_ids, filter_=filter_
|
||||
)
|
||||
},
|
||||
ensure_valid_fields(),
|
||||
{
|
||||
@@ -356,6 +378,37 @@ class ProjectBLL:
|
||||
},
|
||||
]
|
||||
|
||||
def completed_after_subquery(additional_cond, time_thresh: datetime):
|
||||
return {
|
||||
# the sum of
|
||||
"$sum": {
|
||||
# for each task
|
||||
"$cond": {
|
||||
# if completed after the time_thresh
|
||||
"if": {
|
||||
"$and": [
|
||||
"$completed",
|
||||
{"$gt": ["$completed", time_thresh]},
|
||||
additional_cond,
|
||||
]
|
||||
},
|
||||
"then": 1,
|
||||
"else": 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def max_started_subquery(condition):
|
||||
return {
|
||||
"$max": {
|
||||
"$cond": {
|
||||
"if": condition,
|
||||
"then": "$started",
|
||||
"else": datetime.min,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def runtime_subquery(additional_cond):
|
||||
return {
|
||||
# the sum of
|
||||
@@ -386,24 +439,36 @@ class ProjectBLL:
|
||||
}
|
||||
|
||||
group_step = {"_id": "$project"}
|
||||
|
||||
for state in EntityVisibility:
|
||||
time_thresh = datetime.utcnow() - timedelta(hours=24)
|
||||
for state in cls.visibility_states:
|
||||
if specific_state and state != specific_state:
|
||||
continue
|
||||
if state == EntityVisibility.active:
|
||||
group_step[state.value] = runtime_subquery(
|
||||
{"$not": cls.archived_tasks_cond}
|
||||
)
|
||||
elif state == EntityVisibility.archived:
|
||||
group_step[state.value] = runtime_subquery(cls.archived_tasks_cond)
|
||||
cond = (
|
||||
cls.archived_tasks_cond
|
||||
if state == EntityVisibility.archived
|
||||
else {"$not": cls.archived_tasks_cond}
|
||||
)
|
||||
group_step[state.value] = runtime_subquery(cond)
|
||||
group_step[f"{state.value}_recently_completed"] = completed_after_subquery(
|
||||
cond, time_thresh=time_thresh
|
||||
)
|
||||
group_step[f"{state.value}_max_task_started"] = max_started_subquery(cond)
|
||||
|
||||
def get_state_filter() -> dict:
|
||||
if not specific_state:
|
||||
return {}
|
||||
if specific_state == EntityVisibility.archived:
|
||||
return {"system_tags": {"$eq": EntityVisibility.archived.value}}
|
||||
return {"system_tags": {"$ne": EntityVisibility.archived.value}}
|
||||
|
||||
runtime_pipeline = [
|
||||
# only count run time for these types of tasks
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"type": {"$in": ["training", "testing", "annotation"]},
|
||||
"project": {"$in": project_ids},
|
||||
**cls.get_match_conditions(
|
||||
company=company_id, project_ids=project_ids, filter_=filter_
|
||||
),
|
||||
**get_state_filter(),
|
||||
}
|
||||
},
|
||||
ensure_valid_fields(),
|
||||
@@ -445,11 +510,17 @@ class ProjectBLL:
|
||||
company: str,
|
||||
project_ids: Sequence[str],
|
||||
specific_state: Optional[EntityVisibility] = None,
|
||||
include_children: bool = True,
|
||||
filter_: Mapping[str, Any] = None,
|
||||
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
|
||||
if not project_ids:
|
||||
return {}, {}
|
||||
|
||||
child_projects = _get_sub_projects(project_ids, _only=("id", "name"))
|
||||
child_projects = (
|
||||
_get_sub_projects(project_ids, _only=("id", "name"))
|
||||
if include_children
|
||||
else {}
|
||||
)
|
||||
project_ids_with_children = set(project_ids) | {
|
||||
c.id for c in itertools.chain.from_iterable(child_projects.values())
|
||||
}
|
||||
@@ -457,6 +528,7 @@ class ProjectBLL:
|
||||
company,
|
||||
project_ids=list(project_ids_with_children),
|
||||
specific_state=specific_state,
|
||||
filter_=filter_,
|
||||
)
|
||||
|
||||
default_counts = dict.fromkeys(get_options(TaskStatus), 0)
|
||||
@@ -483,8 +555,8 @@ class ProjectBLL:
|
||||
) -> Dict[str, dict]:
|
||||
return {
|
||||
section: {
|
||||
status: nested_get(a, (section, status), 0)
|
||||
+ nested_get(b, (section, status), 0)
|
||||
status: nested_get(a, (section, status), default=0)
|
||||
+ nested_get(b, (section, status), default=0)
|
||||
for status in set(a.get(section, {})) | set(b.get(section, {}))
|
||||
}
|
||||
for section in set(a) | set(b)
|
||||
@@ -507,6 +579,8 @@ class ProjectBLL:
|
||||
) -> Dict[str, dict]:
|
||||
return {
|
||||
section: a.get(section, 0) + b.get(section, 0)
|
||||
if not section.endswith("max_task_started")
|
||||
else max(a.get(section) or datetime.min, b.get(section) or datetime.min)
|
||||
for section in set(a) | set(b)
|
||||
}
|
||||
|
||||
@@ -518,15 +592,30 @@ class ProjectBLL:
|
||||
)
|
||||
|
||||
def get_status_counts(project_id, section):
|
||||
project_runtime = runtime.get(project_id, {})
|
||||
project_section_statuses = nested_get(
|
||||
status_count, (project_id, section), default=default_counts
|
||||
)
|
||||
|
||||
def get_time_or_none(value):
|
||||
return value if value != datetime.min else None
|
||||
|
||||
return {
|
||||
"total_runtime": nested_get(runtime, (project_id, section), 0),
|
||||
"status_count": nested_get(
|
||||
status_count, (project_id, section), default_counts
|
||||
"status_count": project_section_statuses,
|
||||
"total_tasks": sum(project_section_statuses.values()),
|
||||
"total_runtime": project_runtime.get(section, 0),
|
||||
"completed_tasks_24h": project_runtime.get(
|
||||
f"{section}_recently_completed", 0
|
||||
),
|
||||
"last_task_run": get_time_or_none(
|
||||
project_runtime.get(f"{section}_max_task_started", datetime.min)
|
||||
),
|
||||
}
|
||||
|
||||
report_for_states = [
|
||||
s for s in EntityVisibility if not specific_state or specific_state == s
|
||||
s
|
||||
for s in cls.visibility_states
|
||||
if not specific_state or specific_state == s
|
||||
]
|
||||
|
||||
stats = {
|
||||
@@ -575,6 +664,30 @@ class ProjectBLL:
|
||||
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def get_project_tags(
|
||||
cls,
|
||||
company_id: str,
|
||||
include_system: bool,
|
||||
projects: Sequence[str] = None,
|
||||
filter_: Dict[str, Sequence[str]] = None,
|
||||
) -> Tuple[Sequence[str], Sequence[str]]:
|
||||
with TimingContext("mongo", "get_tags_from_db"):
|
||||
query = Q(company=company_id)
|
||||
if filter_:
|
||||
for name, vals in filter_.items():
|
||||
if vals:
|
||||
query &= GetMixin.get_list_field_query(name, vals)
|
||||
|
||||
if projects:
|
||||
query &= Q(id__in=_ids_with_children(projects))
|
||||
|
||||
tags = Project.objects(query).distinct("tags")
|
||||
system_tags = (
|
||||
Project.objects(query).distinct("system_tags") if include_system else []
|
||||
)
|
||||
return tags, system_tags
|
||||
|
||||
@classmethod
|
||||
def get_projects_with_active_user(
|
||||
cls,
|
||||
@@ -631,6 +744,7 @@ class ProjectBLL:
|
||||
if include_subprojects:
|
||||
projects = _ids_with_children(projects)
|
||||
query &= Q(project__in=projects)
|
||||
|
||||
if state == EntityVisibility.archived:
|
||||
query &= Q(system_tags__in=[EntityVisibility.archived.value])
|
||||
elif state == EntityVisibility.active:
|
||||
@@ -658,6 +772,7 @@ class ProjectBLL:
|
||||
if project_ids:
|
||||
project_ids = _ids_with_children(project_ids)
|
||||
query &= Q(project__in=project_ids)
|
||||
|
||||
res = Task.objects(query).distinct(field="type")
|
||||
return set(res).intersection(external_task_types)
|
||||
|
||||
@@ -673,8 +788,35 @@ class ProjectBLL:
|
||||
query &= Q(project__in=project_ids)
|
||||
return Model.objects(query).distinct(field="framework")
|
||||
|
||||
@staticmethod
|
||||
def get_match_conditions(
|
||||
company: str, project_ids: Sequence[str], filter_: Mapping[str, Any]
|
||||
):
|
||||
conditions = {
|
||||
"company": {"$in": [None, "", company]},
|
||||
"project": {"$in": project_ids},
|
||||
}
|
||||
if not filter_:
|
||||
return conditions
|
||||
|
||||
for field in ("tags", "system_tags"):
|
||||
field_filter = filter_.get(field)
|
||||
if not field_filter:
|
||||
continue
|
||||
if not isinstance(field_filter, list) or not all(
|
||||
isinstance(t, str) for t in field_filter
|
||||
):
|
||||
raise errors.bad_request.ValidationError(
|
||||
f"List of strings expected for the field: {field}"
|
||||
)
|
||||
conditions[field] = {"$in": field_filter}
|
||||
|
||||
return conditions
|
||||
|
||||
@classmethod
|
||||
def calc_own_contents(cls, company: str, project_ids: Sequence[str]) -> Dict[str, dict]:
|
||||
def calc_own_contents(
|
||||
cls, company: str, project_ids: Sequence[str], filter_: Mapping[str, Any] = None
|
||||
) -> Dict[str, dict]:
|
||||
"""
|
||||
Returns the amount of task/models per requested project
|
||||
Use separate aggregation calls on Task/Model instead of lookup
|
||||
@@ -685,35 +827,21 @@ class ProjectBLL:
|
||||
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company]},
|
||||
"project": {"$in": project_ids},
|
||||
}
|
||||
"$match": cls.get_match_conditions(
|
||||
company=company, project_ids=project_ids, filter_=filter_
|
||||
)
|
||||
},
|
||||
{
|
||||
"$project": {"project": 1}
|
||||
},
|
||||
{
|
||||
"$group": {
|
||||
"_id": "$project",
|
||||
"count": {"$sum": 1},
|
||||
}
|
||||
}
|
||||
{"$project": {"project": 1}},
|
||||
{"$group": {"_id": "$project", "count": {"$sum": 1}}},
|
||||
]
|
||||
|
||||
def get_agrregate_res(cls_: Type[AttributedDocument]) -> dict:
|
||||
return {
|
||||
data["_id"]: data["count"]
|
||||
for data in cls_.aggregate(pipeline)
|
||||
}
|
||||
return {data["_id"]: data["count"] for data in cls_.aggregate(pipeline)}
|
||||
|
||||
with TimingContext("mongo", "get_security_groups"):
|
||||
tasks = get_agrregate_res(Task)
|
||||
models = get_agrregate_res(Model)
|
||||
return {
|
||||
pid: {
|
||||
"own_tasks": tasks.get(pid, 0),
|
||||
"own_models": models.get(pid, 0),
|
||||
}
|
||||
pid: {"own_tasks": tasks.get(pid, 0), "own_models": models.get(pid, 0)}
|
||||
for pid in project_ids
|
||||
}
|
||||
|
||||
370
apiserver/bll/project/project_queries.py
Normal file
370
apiserver/bll/project/project_queries.py
Normal file
@@ -0,0 +1,370 @@
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Sequence,
|
||||
Optional,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
from redis import StrictRedis
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
from .sub_projects import _ids_with_children
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class ProjectQueries:
|
||||
def __init__(self, redis=None):
|
||||
self.redis: StrictRedis = redis or redman.connection("apiserver")
|
||||
|
||||
@staticmethod
|
||||
def _get_project_constraint(
|
||||
project_ids: Sequence[str], include_subprojects: bool
|
||||
) -> dict:
|
||||
"""
|
||||
If passed projects is None means top level projects
|
||||
If passed projects is empty means no project filtering
|
||||
"""
|
||||
if include_subprojects:
|
||||
if not project_ids:
|
||||
return {}
|
||||
project_ids = _ids_with_children(project_ids)
|
||||
|
||||
if project_ids is None:
|
||||
project_ids = [None]
|
||||
if not project_ids:
|
||||
return {}
|
||||
|
||||
return {"project": {"$in": project_ids}}
|
||||
|
||||
@staticmethod
|
||||
def _get_company_constraint(company_id: str, allow_public: bool = True) -> dict:
|
||||
if allow_public:
|
||||
return {"company": {"$in": [None, "", company_id]}}
|
||||
|
||||
return {"company": company_id}
|
||||
|
||||
@classmethod
|
||||
def get_aggregated_project_parameters(
|
||||
cls,
|
||||
company_id,
|
||||
project_ids: Sequence[str],
|
||||
include_subprojects: bool,
|
||||
page: int = 0,
|
||||
page_size: int = 500,
|
||||
) -> Tuple[int, int, Sequence[dict]]:
|
||||
page = max(0, page)
|
||||
page_size = max(1, page_size)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**cls._get_company_constraint(company_id),
|
||||
**cls._get_project_constraint(project_ids, include_subprojects),
|
||||
"hyperparams": {"$exists": True, "$gt": {}},
|
||||
}
|
||||
},
|
||||
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
|
||||
{"$unwind": "$sections"},
|
||||
{
|
||||
"$project": {
|
||||
"section": "$sections.k",
|
||||
"names": {"$objectToArray": "$sections.v"},
|
||||
}
|
||||
},
|
||||
{"$unwind": "$names"},
|
||||
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
|
||||
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
|
||||
{"$skip": page * page_size},
|
||||
{"$limit": page_size},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Task.aggregate(pipeline), None)
|
||||
|
||||
total = 0
|
||||
remaining = 0
|
||||
results = []
|
||||
|
||||
if result:
|
||||
total = int(result.get("total", -1))
|
||||
results = [
|
||||
{
|
||||
"section": ParameterKeyEscaper.unescape(
|
||||
nested_get(r, ("_id", "section"))
|
||||
),
|
||||
"name": ParameterKeyEscaper.unescape(
|
||||
nested_get(r, ("_id", "name"))
|
||||
),
|
||||
}
|
||||
for r in result.get("results", [])
|
||||
]
|
||||
remaining = max(0, total - (len(results) + page * page_size))
|
||||
|
||||
return total, remaining, results
|
||||
|
||||
ParamValues = Tuple[int, Sequence[str]]
|
||||
|
||||
def _get_cached_param_values(
|
||||
self, key: str, last_update: datetime, allowed_delta_sec=0
|
||||
) -> Optional[ParamValues]:
|
||||
try:
|
||||
cached = self.redis.get(key)
|
||||
if not cached:
|
||||
return
|
||||
|
||||
data = json.loads(cached)
|
||||
cached_last_update = datetime.fromtimestamp(data["last_update"])
|
||||
if (last_update - cached_last_update).total_seconds() <= allowed_delta_sec:
|
||||
return data["total"], data["values"]
|
||||
except Exception as ex:
|
||||
log.error(f"Error retrieving params cached values: {str(ex)}")
|
||||
|
||||
def get_task_hyperparam_distinct_values(
|
||||
self,
|
||||
company_id: str,
|
||||
project_ids: Sequence[str],
|
||||
section: str,
|
||||
name: str,
|
||||
include_subprojects: bool,
|
||||
allow_public: bool = True,
|
||||
) -> ParamValues:
|
||||
company_constraint = self._get_company_constraint(company_id, allow_public)
|
||||
project_constraint = self._get_project_constraint(
|
||||
project_ids, include_subprojects
|
||||
)
|
||||
key_path = f"hyperparams.{ParameterKeyEscaper.escape(section)}.{ParameterKeyEscaper.escape(name)}"
|
||||
last_updated_task = (
|
||||
Task.objects(
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
**{f"{key_path.replace('.', '__')}__exists": True},
|
||||
)
|
||||
.only("last_update")
|
||||
.order_by("-last_update")
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
if not last_updated_task:
|
||||
return 0, []
|
||||
|
||||
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
|
||||
last_update = last_updated_task.last_update or datetime.utcnow()
|
||||
cached_res = self._get_cached_param_values(
|
||||
key=redis_key,
|
||||
last_update=last_update,
|
||||
allowed_delta_sec=config.get(
|
||||
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
|
||||
),
|
||||
)
|
||||
if cached_res:
|
||||
return cached_res
|
||||
|
||||
max_values = config.get("services.tasks.hyperparam_values.max_count", 100)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
key_path: {"$exists": True},
|
||||
}
|
||||
},
|
||||
{"$project": {"value": f"${key_path}.value"}},
|
||||
{"$group": {"_id": "$value"}},
|
||||
{"$sort": {"_id": 1}},
|
||||
{"$limit": max_values},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT._id"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Task.aggregate(pipeline, collation=Task._numeric_locale), None)
|
||||
if not result:
|
||||
return 0, []
|
||||
|
||||
total = int(result.get("total", 0))
|
||||
values = result.get("results", [])
|
||||
|
||||
ttl = config.get("services.tasks.hyperparam_values.cache_ttl_sec", 86400)
|
||||
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
|
||||
self.redis.setex(redis_key, ttl, json.dumps(cached))
|
||||
|
||||
return total, values
|
||||
|
||||
@classmethod
|
||||
def get_unique_metric_variants(
|
||||
cls, company_id, project_ids: Sequence[str], include_subprojects: bool
|
||||
):
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**cls._get_company_constraint(company_id),
|
||||
**cls._get_project_constraint(project_ids, include_subprojects),
|
||||
}
|
||||
},
|
||||
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
|
||||
{"$unwind": "$metrics"},
|
||||
{
|
||||
"$project": {
|
||||
"metric": "$metrics.k",
|
||||
"variants": {"$objectToArray": "$metrics.v"},
|
||||
}
|
||||
},
|
||||
{"$unwind": "$variants"},
|
||||
{
|
||||
"$group": {
|
||||
"_id": {
|
||||
"metric": "$variants.v.metric",
|
||||
"variant": "$variants.v.variant",
|
||||
},
|
||||
"metrics": {
|
||||
"$addToSet": {
|
||||
"metric": "$variants.v.metric",
|
||||
"metric_hash": "$metric",
|
||||
"variant": "$variants.v.variant",
|
||||
"variant_hash": "$variants.k",
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
|
||||
]
|
||||
|
||||
result = Task.aggregate(pipeline)
|
||||
return [r["metrics"][0] for r in result]
|
||||
|
||||
@classmethod
|
||||
def get_model_metadata_keys(
|
||||
cls,
|
||||
company_id,
|
||||
project_ids: Sequence[str],
|
||||
include_subprojects: bool,
|
||||
page: int = 0,
|
||||
page_size: int = 500,
|
||||
) -> Tuple[int, int, Sequence[dict]]:
|
||||
page = max(0, page)
|
||||
page_size = max(1, page_size)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**cls._get_company_constraint(company_id),
|
||||
**cls._get_project_constraint(project_ids, include_subprojects),
|
||||
"metadata": {"$exists": True, "$gt": {}},
|
||||
}
|
||||
},
|
||||
{"$project": {"metadata": {"$objectToArray": "$metadata"}}},
|
||||
{"$unwind": "$metadata"},
|
||||
{"$group": {"_id": "$metadata.k"}},
|
||||
{"$sort": {"_id": 1}},
|
||||
{"$skip": page * page_size},
|
||||
{"$limit": page_size},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Model.aggregate(pipeline), None)
|
||||
|
||||
total = 0
|
||||
remaining = 0
|
||||
results = []
|
||||
|
||||
if result:
|
||||
total = int(result.get("total", -1))
|
||||
results = [
|
||||
ParameterKeyEscaper.unescape(r.get("_id"))
|
||||
for r in result.get("results", [])
|
||||
]
|
||||
remaining = max(0, total - (len(results) + page * page_size))
|
||||
|
||||
return total, remaining, results
|
||||
|
||||
def get_model_metadata_distinct_values(
|
||||
self,
|
||||
company_id: str,
|
||||
project_ids: Sequence[str],
|
||||
key: str,
|
||||
include_subprojects: bool,
|
||||
allow_public: bool = True,
|
||||
) -> ParamValues:
|
||||
company_constraint = self._get_company_constraint(company_id, allow_public)
|
||||
project_constraint = self._get_project_constraint(
|
||||
project_ids, include_subprojects
|
||||
)
|
||||
key_path = f"metadata.{ParameterKeyEscaper.escape(key)}"
|
||||
last_updated_model = (
|
||||
Model.objects(
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
**{f"{key_path.replace('.', '__')}__exists": True},
|
||||
)
|
||||
.only("last_update")
|
||||
.order_by("-last_update")
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
if not last_updated_model:
|
||||
return 0, []
|
||||
|
||||
redis_key = f"modelmetadata_values_{company_id}_{'_'.join(project_ids)}_{key}_{allow_public}"
|
||||
last_update = last_updated_model.last_update or datetime.utcnow()
|
||||
cached_res = self._get_cached_param_values(
|
||||
key=redis_key, last_update=last_update
|
||||
)
|
||||
if cached_res:
|
||||
return cached_res
|
||||
|
||||
max_values = config.get("services.models.metadata_values.max_count", 100)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
key_path: {"$exists": True},
|
||||
}
|
||||
},
|
||||
{"$project": {"value": f"${key_path}.value"}},
|
||||
{"$group": {"_id": "$value"}},
|
||||
{"$sort": {"_id": 1}},
|
||||
{"$limit": max_values},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT._id"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Model.aggregate(pipeline, collation=Model._numeric_locale), None)
|
||||
if not result:
|
||||
return 0, []
|
||||
|
||||
total = int(result.get("total", 0))
|
||||
values = result.get("results", [])
|
||||
|
||||
ttl = config.get("services.models.metadata_values.cache_ttl_sec", 86400)
|
||||
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
|
||||
self.redis.setex(redis_key, ttl, json.dumps(cached))
|
||||
|
||||
return total, values
|
||||
@@ -25,7 +25,9 @@ def _validate_project_name(project_name: str) -> Tuple[str, str]:
|
||||
return name_separator.join(name_parts), name_separator.join(name_parts[:-1])
|
||||
|
||||
|
||||
def _ensure_project(company: str, user: str, name: str) -> Optional[Project]:
|
||||
def _ensure_project(
|
||||
company: str, user: str, name: str, creation_params: dict = None
|
||||
) -> Optional[Project]:
|
||||
"""
|
||||
Makes sure that the project with the given name exists
|
||||
If needed auto-create the project and all the missing projects in the path to it
|
||||
@@ -48,9 +50,9 @@ def _ensure_project(company: str, user: str, name: str) -> Optional[Project]:
|
||||
created=now,
|
||||
last_update=now,
|
||||
name=name,
|
||||
description="",
|
||||
**(creation_params or dict(description="")),
|
||||
)
|
||||
parent = _ensure_project(company, user, location)
|
||||
parent = _ensure_project(company, user, location, creation_params=creation_params)
|
||||
_save_under_parent(project=project, parent=parent)
|
||||
if parent:
|
||||
parent.update(last_update=now)
|
||||
|
||||
@@ -32,7 +32,7 @@ class QueueBLL(object):
|
||||
name: str,
|
||||
tags: Optional[Sequence[str]] = None,
|
||||
system_tags: Optional[Sequence[str]] = None,
|
||||
metadata: Optional[Sequence[dict]] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
) -> Queue:
|
||||
"""Creates a queue"""
|
||||
with translate_errors_context():
|
||||
@@ -126,14 +126,27 @@ class QueueBLL(object):
|
||||
)
|
||||
queue.delete()
|
||||
|
||||
def get_all(self, company_id: str, query_dict: dict) -> Sequence[dict]:
|
||||
def get_all(
|
||||
self,
|
||||
company_id: str,
|
||||
query_dict: dict,
|
||||
ret_params: dict = None,
|
||||
) -> Sequence[dict]:
|
||||
"""Get all the queues according to the query"""
|
||||
with translate_errors_context():
|
||||
return Queue.get_many(
|
||||
company=company_id, parameters=query_dict, query_dict=query_dict
|
||||
company=company_id,
|
||||
parameters=query_dict,
|
||||
query_dict=query_dict,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
|
||||
def get_queue_infos(self, company_id: str, query_dict: dict) -> Sequence[dict]:
|
||||
def get_queue_infos(
|
||||
self,
|
||||
company_id: str,
|
||||
query_dict: dict,
|
||||
ret_params: dict = None,
|
||||
) -> Sequence[dict]:
|
||||
"""
|
||||
Get infos on all the company queues, including queue tasks and workers
|
||||
"""
|
||||
@@ -143,6 +156,7 @@ class QueueBLL(object):
|
||||
company=company_id,
|
||||
query_dict=query_dict,
|
||||
override_projection=projection,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
|
||||
queue_workers = defaultdict(list)
|
||||
@@ -173,13 +187,15 @@ class QueueBLL(object):
|
||||
if any(e.task == task_id for e in queue.entries):
|
||||
raise errors.bad_request.TaskAlreadyQueued(task=task_id)
|
||||
|
||||
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
|
||||
|
||||
entry = Entry(added=datetime.utcnow(), task=task_id)
|
||||
query = dict(id=queue_id, company=company_id)
|
||||
res = Queue.objects(entries__task__ne=task_id, **query).update_one(
|
||||
push__entries=entry, last_update=datetime.utcnow(), upsert=False
|
||||
)
|
||||
|
||||
queue.reload()
|
||||
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
|
||||
|
||||
if not res:
|
||||
raise errors.bad_request.InvalidQueueOrTaskNotQueued(
|
||||
task=task_id, **query
|
||||
@@ -219,7 +235,6 @@ class QueueBLL(object):
|
||||
queue = self.get_queue_with_task(
|
||||
company_id=company_id, queue_id=queue_id, task_id=task_id
|
||||
)
|
||||
self.metrics.log_queue_metrics_to_es(company_id, queues=[queue])
|
||||
|
||||
entries_to_remove = [e for e in queue.entries if e.task == task_id]
|
||||
query = dict(id=queue_id, company=company_id)
|
||||
@@ -227,6 +242,9 @@ class QueueBLL(object):
|
||||
pull_all__entries=entries_to_remove, last_update=datetime.utcnow()
|
||||
)
|
||||
|
||||
queue.reload()
|
||||
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
|
||||
|
||||
return len(entries_to_remove) if res else 0
|
||||
|
||||
def reposition_task(
|
||||
|
||||
@@ -49,6 +49,21 @@ class RedisCacheManager(Generic[T]):
|
||||
def _get_redis_key(self, state_id):
|
||||
return f"{self.state_class}/{state_id}"
|
||||
|
||||
def get_or_create_state_core(
|
||||
self,
|
||||
state_id=None,
|
||||
init_state: Callable[[T], None] = _do_nothing,
|
||||
validate_state: Callable[[T], None] = _do_nothing,
|
||||
) -> T:
|
||||
state = self.get_state(state_id) if state_id else None
|
||||
if state:
|
||||
validate_state(state)
|
||||
else:
|
||||
state = self.state_class(id=database.utils.id())
|
||||
init_state(state)
|
||||
|
||||
return state
|
||||
|
||||
@contextmanager
|
||||
def get_or_create_state(
|
||||
self,
|
||||
@@ -66,12 +81,9 @@ class RedisCacheManager(Generic[T]):
|
||||
:param validate_state: user callback to validate the state if retrieved from cache
|
||||
Should throw an exception if the state is not valid. If not passed then no validation is done
|
||||
"""
|
||||
state = self.get_state(state_id) if state_id else None
|
||||
if state:
|
||||
validate_state(state)
|
||||
else:
|
||||
state = self.state_class(id=database.utils.id())
|
||||
init_state(state)
|
||||
state = self.get_or_create_state_core(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state
|
||||
)
|
||||
|
||||
try:
|
||||
yield state
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
from typing import Collection, Sequence, Tuple, Any, Optional, Dict
|
||||
|
||||
import dpath
|
||||
import six
|
||||
from mongoengine import Q
|
||||
from redis import StrictRedis
|
||||
@@ -14,7 +11,7 @@ from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.tasks import TaskInputModel
|
||||
from apiserver.bll.queue import QueueBLL
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL, project_ids_with_children
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.model import Model
|
||||
@@ -39,7 +36,6 @@ from apiserver.redis_manager import redman
|
||||
from apiserver.service_repo import APICall
|
||||
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
from .artifacts import artifacts_prepare_for_save
|
||||
from .param_utils import params_prepare_for_save
|
||||
from .utils import (
|
||||
@@ -350,54 +346,6 @@ class TaskBLL:
|
||||
if validate_models:
|
||||
cls.validate_input_models(task)
|
||||
|
||||
@staticmethod
|
||||
def get_unique_metric_variants(
|
||||
company_id, project_ids: Sequence[str], include_subprojects: bool
|
||||
):
|
||||
if project_ids:
|
||||
if include_subprojects:
|
||||
project_ids = project_ids_with_children(project_ids)
|
||||
project_constraint = {"project": {"$in": project_ids}}
|
||||
else:
|
||||
project_constraint = {}
|
||||
pipeline = [
|
||||
{
|
||||
"$match": dict(
|
||||
company={"$in": [None, "", company_id]}, **project_constraint,
|
||||
)
|
||||
},
|
||||
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
|
||||
{"$unwind": "$metrics"},
|
||||
{
|
||||
"$project": {
|
||||
"metric": "$metrics.k",
|
||||
"variants": {"$objectToArray": "$metrics.v"},
|
||||
}
|
||||
},
|
||||
{"$unwind": "$variants"},
|
||||
{
|
||||
"$group": {
|
||||
"_id": {
|
||||
"metric": "$variants.v.metric",
|
||||
"variant": "$variants.v.variant",
|
||||
},
|
||||
"metrics": {
|
||||
"$addToSet": {
|
||||
"metric": "$variants.v.metric",
|
||||
"metric_hash": "$metric",
|
||||
"variant": "$variants.v.variant",
|
||||
"variant_hash": "$variants.k",
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
|
||||
]
|
||||
|
||||
with translate_errors_context():
|
||||
result = Task.aggregate(pipeline)
|
||||
return [r["metrics"][0] for r in result]
|
||||
|
||||
@staticmethod
|
||||
def set_last_update(
|
||||
task_ids: Collection[str],
|
||||
@@ -494,173 +442,6 @@ class TaskBLL:
|
||||
**extra_updates,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_aggregated_project_parameters(
|
||||
company_id,
|
||||
project_ids: Sequence[str],
|
||||
include_subprojects: bool,
|
||||
page: int = 0,
|
||||
page_size: int = 500,
|
||||
) -> Tuple[int, int, Sequence[dict]]:
|
||||
if project_ids:
|
||||
if include_subprojects:
|
||||
project_ids = project_ids_with_children(project_ids)
|
||||
project_constraint = {"project": {"$in": project_ids}}
|
||||
else:
|
||||
project_constraint = {}
|
||||
page = max(0, page)
|
||||
page_size = max(1, page_size)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"hyperparams": {"$exists": True, "$gt": {}},
|
||||
**project_constraint,
|
||||
}
|
||||
},
|
||||
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
|
||||
{"$unwind": "$sections"},
|
||||
{
|
||||
"$project": {
|
||||
"section": "$sections.k",
|
||||
"names": {"$objectToArray": "$sections.v"},
|
||||
}
|
||||
},
|
||||
{"$unwind": "$names"},
|
||||
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
|
||||
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
|
||||
{"$skip": page * page_size},
|
||||
{"$limit": page_size},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Task.aggregate(pipeline), None)
|
||||
|
||||
total = 0
|
||||
remaining = 0
|
||||
results = []
|
||||
|
||||
if result:
|
||||
total = int(result.get("total", -1))
|
||||
results = [
|
||||
{
|
||||
"section": ParameterKeyEscaper.unescape(
|
||||
dpath.get(r, "_id/section")
|
||||
),
|
||||
"name": ParameterKeyEscaper.unescape(dpath.get(r, "_id/name")),
|
||||
}
|
||||
for r in result.get("results", [])
|
||||
]
|
||||
remaining = max(0, total - (len(results) + page * page_size))
|
||||
|
||||
return total, remaining, results
|
||||
|
||||
HyperParamValues = Tuple[int, Sequence[str]]
|
||||
|
||||
def _get_cached_hyperparam_values(
|
||||
self, key: str, last_update: datetime
|
||||
) -> Optional[HyperParamValues]:
|
||||
allowed_delta = timedelta(
|
||||
seconds=config.get(
|
||||
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
|
||||
)
|
||||
)
|
||||
try:
|
||||
cached = self.redis.get(key)
|
||||
if not cached:
|
||||
return
|
||||
|
||||
data = json.loads(cached)
|
||||
cached_last_update = datetime.fromtimestamp(data["last_update"])
|
||||
if (last_update - cached_last_update) < allowed_delta:
|
||||
return data["total"], data["values"]
|
||||
except Exception as ex:
|
||||
log.error(f"Error retrieving hyperparam cached values: {str(ex)}")
|
||||
|
||||
def get_hyperparam_distinct_values(
|
||||
self,
|
||||
company_id: str,
|
||||
project_ids: Sequence[str],
|
||||
section: str,
|
||||
name: str,
|
||||
include_subprojects: bool,
|
||||
allow_public: bool = True,
|
||||
) -> HyperParamValues:
|
||||
if allow_public:
|
||||
company_constraint = {"company": {"$in": [None, "", company_id]}}
|
||||
else:
|
||||
company_constraint = {"company": company_id}
|
||||
if project_ids:
|
||||
if include_subprojects:
|
||||
project_ids = project_ids_with_children(project_ids)
|
||||
project_constraint = {"project": {"$in": project_ids}}
|
||||
else:
|
||||
project_constraint = {}
|
||||
|
||||
key_path = f"hyperparams.{ParameterKeyEscaper.escape(section)}.{ParameterKeyEscaper.escape(name)}"
|
||||
last_updated_task = (
|
||||
Task.objects(
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
**{f"{key_path.replace('.', '__')}__exists": True},
|
||||
)
|
||||
.only("last_update")
|
||||
.order_by("-last_update")
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
if not last_updated_task:
|
||||
return 0, []
|
||||
|
||||
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
|
||||
last_update = last_updated_task.last_update or datetime.utcnow()
|
||||
cached_res = self._get_cached_hyperparam_values(
|
||||
key=redis_key, last_update=last_update
|
||||
)
|
||||
if cached_res:
|
||||
return cached_res
|
||||
|
||||
max_values = config.get("services.tasks.hyperparam_values.max_count", 100)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
key_path: {"$exists": True},
|
||||
}
|
||||
},
|
||||
{"$project": {"value": f"${key_path}.value"}},
|
||||
{"$group": {"_id": "$value"}},
|
||||
{"$sort": {"_id": 1}},
|
||||
{"$limit": max_values},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT._id"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Task.aggregate(pipeline, collation=Task._numeric_locale), None)
|
||||
if not result:
|
||||
return 0, []
|
||||
|
||||
total = int(result.get("total", 0))
|
||||
values = result.get("results", [])
|
||||
|
||||
ttl = config.get("services.tasks.hyperparam_values.cache_ttl_sec", 86400)
|
||||
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
|
||||
self.redis.setex(redis_key, ttl, json.dumps(cached))
|
||||
|
||||
return total, values
|
||||
|
||||
@classmethod
|
||||
def dequeue_and_change_status(
|
||||
cls, task: Task, company_id: str, status_message: str, status_reason: str,
|
||||
|
||||
@@ -162,6 +162,8 @@ def delete_task(
|
||||
force: bool,
|
||||
return_file_urls: bool,
|
||||
delete_output_models: bool,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
) -> Tuple[int, Task, CleanupResult]:
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task_id, company_id=company_id, requires_write_access=True
|
||||
@@ -179,6 +181,17 @@ def delete_task(
|
||||
current=task.status,
|
||||
)
|
||||
|
||||
try:
|
||||
TaskBLL.dequeue_and_change_status(
|
||||
task,
|
||||
company_id=company_id,
|
||||
status_message=status_message,
|
||||
status_reason=status_reason,
|
||||
)
|
||||
except APIError:
|
||||
# dequeue may fail if the task was not enqueued
|
||||
pass
|
||||
|
||||
cleanup_res = cleanup_task(
|
||||
task,
|
||||
force=force,
|
||||
@@ -354,6 +367,7 @@ def stop_task(
|
||||
"system_tags",
|
||||
"last_worker",
|
||||
"last_update",
|
||||
"execution.queue",
|
||||
),
|
||||
requires_write_access=True,
|
||||
)
|
||||
|
||||
@@ -113,7 +113,7 @@ class WorkerBLL:
|
||||
res = self.redis.delete(
|
||||
company_id, self._get_worker_key(company_id, user_id, worker)
|
||||
)
|
||||
if not res:
|
||||
if not res and not config.get("apiserver.workers.auto_unregister", False):
|
||||
raise bad_request.WorkerNotRegistered(worker=worker)
|
||||
|
||||
def status_report(
|
||||
@@ -258,7 +258,7 @@ class WorkerBLL:
|
||||
tasks_info = {
|
||||
task.id: task
|
||||
for task in Task.objects(id__in=task_ids).only(
|
||||
"name", "started", "last_iteration"
|
||||
"name", "started", "last_iteration", "active_duration"
|
||||
)
|
||||
}
|
||||
|
||||
@@ -283,11 +283,7 @@ class WorkerBLL:
|
||||
if helper.task_id:
|
||||
task = tasks_info.get(helper.task_id, None)
|
||||
if task:
|
||||
worker.task.running_time = (
|
||||
int((datetime.utcnow() - task.started).total_seconds() * 1000)
|
||||
if task.started
|
||||
else 0
|
||||
)
|
||||
worker.task.running_time = (task.active_duration or 0) * 1000
|
||||
worker.task.last_iteration = task.last_iteration
|
||||
|
||||
update_queue_entries(worker.queue)
|
||||
|
||||
@@ -79,6 +79,8 @@ class BasicConfig:
|
||||
def logger(self, name: str) -> logging.Logger:
|
||||
if Path(name).is_file():
|
||||
name = Path(name).stem
|
||||
if name == "__init__" and Path(name).parent.stem:
|
||||
name = Path(name).parent.stem
|
||||
path = ".".join((self.prefix, name))
|
||||
return logging.getLogger(path)
|
||||
|
||||
|
||||
@@ -79,6 +79,11 @@
|
||||
max_age: 99999999999
|
||||
}
|
||||
|
||||
# provide a cookie domain override per company
|
||||
# cookies_domain_override {
|
||||
# <company-id>: <domain>
|
||||
# }
|
||||
|
||||
# # A list of fixed users
|
||||
# # Note: password may be bcrypt-hashed (generate using `python -c 'import bcrypt; print(bcrypt.hashpw("password", bcrypt.gensalt()))'`)
|
||||
# fixed_users {
|
||||
@@ -107,6 +112,8 @@
|
||||
workers {
|
||||
# Auto-register unknown workers on status reports and other calls
|
||||
auto_register: true
|
||||
# Assume unknow workers have unregistered (i.e. do not raise unregistered error)
|
||||
auto_unregister: true
|
||||
# Timeout in seconds on task status update. If exceeded
|
||||
# then task can be stopped without communicating to the worker
|
||||
task_update_timeout: 600
|
||||
|
||||
4
apiserver/config/default/services/_mongo.conf
Normal file
4
apiserver/config/default/services/_mongo.conf
Normal file
@@ -0,0 +1,4 @@
|
||||
max_page_size: 500
|
||||
|
||||
# expiration time in seconds for the redis scroll states in get_many family of apis
|
||||
scroll_state_expiration_seconds: 600
|
||||
@@ -17,6 +17,10 @@ events_retrieval {
|
||||
|
||||
# the max amount of variants to aggregate on
|
||||
max_variants_count: 100
|
||||
|
||||
max_raw_scalars_size: 200000
|
||||
|
||||
scroll_id_key: "cTN5VEtWEC6QrHvUl0FTx9kNyO0CcCK1p57akxma"
|
||||
}
|
||||
|
||||
# if set then plot str will be checked for the valid json on plot add
|
||||
|
||||
7
apiserver/config/default/services/models.conf
Normal file
7
apiserver/config/default/services/models.conf
Normal file
@@ -0,0 +1,7 @@
|
||||
metadata_values {
|
||||
# maximal amount of distinct model values to retrieve
|
||||
max_count: 100
|
||||
|
||||
# cache ttl sec
|
||||
cache_ttl_sec: 86400
|
||||
}
|
||||
@@ -28,6 +28,8 @@ OVERRIDE_PORT_ENV_KEY = (
|
||||
"MONGODB_SERVICE_PORT",
|
||||
)
|
||||
|
||||
OVERRIDE_CONNECTION_STRING_ENV_KEY = "CLEARML_MONGODB_SERVICE_CONNECTION_STRING"
|
||||
|
||||
|
||||
class DatabaseEntry(models.Base):
|
||||
host = StringField(required=True)
|
||||
@@ -47,13 +49,17 @@ class DatabaseFactory:
|
||||
missing = []
|
||||
log.info("Initializing database connections")
|
||||
|
||||
override_connection_string = getenv(OVERRIDE_CONNECTION_STRING_ENV_KEY)
|
||||
override_hostname = first(map(getenv, OVERRIDE_HOST_ENV_KEY), None)
|
||||
if override_hostname:
|
||||
log.info(f"Using override mongodb host {override_hostname}")
|
||||
|
||||
override_port = first(map(getenv, OVERRIDE_PORT_ENV_KEY), None)
|
||||
if override_port:
|
||||
log.info(f"Using override mongodb port {override_port}")
|
||||
|
||||
if override_connection_string:
|
||||
log.info(f"Using override mongodb connection string {override_connection_string}")
|
||||
else:
|
||||
if override_hostname:
|
||||
log.info(f"Using override mongodb host {override_hostname}")
|
||||
if override_port:
|
||||
log.info(f"Using override mongodb port {override_port}")
|
||||
|
||||
for key, alias in get_items(Database).items():
|
||||
if key not in db_entries:
|
||||
@@ -62,11 +68,13 @@ class DatabaseFactory:
|
||||
|
||||
entry = cls._create_db_entry(alias=alias, settings=db_entries.get(key))
|
||||
|
||||
if override_hostname:
|
||||
entry.host = furl(entry.host).set(host=override_hostname).url
|
||||
|
||||
if override_port:
|
||||
entry.host = furl(entry.host).set(port=override_port).url
|
||||
if override_connection_string:
|
||||
entry.host = override_connection_string
|
||||
else:
|
||||
if override_hostname:
|
||||
entry.host = furl(entry.host).set(host=override_hostname).url
|
||||
if override_port:
|
||||
entry.host = furl(entry.host).set(port=override_port).url
|
||||
|
||||
try:
|
||||
entry.validate()
|
||||
|
||||
@@ -48,7 +48,9 @@ class Credentials(EmbeddedDocument):
|
||||
meta = {"strict": False}
|
||||
key = StringField(required=True)
|
||||
secret = StringField(required=True)
|
||||
label = StringField()
|
||||
last_used = DateTimeField()
|
||||
last_used_from = StringField()
|
||||
|
||||
|
||||
class User(DbModelMixin, AuthDocument):
|
||||
|
||||
@@ -1,26 +1,41 @@
|
||||
import re
|
||||
from collections import namedtuple
|
||||
from functools import reduce
|
||||
from typing import Collection, Sequence, Union, Optional, Type, Tuple, Mapping, Any
|
||||
from functools import reduce, partial
|
||||
from typing import (
|
||||
Collection,
|
||||
Sequence,
|
||||
Union,
|
||||
Optional,
|
||||
Type,
|
||||
Tuple,
|
||||
Mapping,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
)
|
||||
|
||||
from boltons.iterutils import first, bucketize, partition
|
||||
from boltons.iterutils import first, partition
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
from mongoengine import Q, Document, ListField, StringField
|
||||
from mongoengine import Q, Document, ListField, StringField, IntField
|
||||
from pymongo.command_cursor import CommandCursor
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apierrors.base import BaseError
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database import Database
|
||||
from apiserver.database.errors import MakeGetAllQueryError
|
||||
from apiserver.database.projection import project_dict, ProjectionHelper
|
||||
from apiserver.database.props import PropsMixin
|
||||
from apiserver.database.query import RegexQ, RegexWrapper
|
||||
from apiserver.database.query import RegexQ, RegexWrapper, RegexQCombination
|
||||
from apiserver.database.utils import (
|
||||
get_company_or_none_constraint,
|
||||
get_fields_choices,
|
||||
field_does_not_exist,
|
||||
field_exists,
|
||||
)
|
||||
from apiserver.redis_manager import redman
|
||||
|
||||
log = config.logger("dbmodel")
|
||||
|
||||
@@ -70,6 +85,9 @@ class GetMixin(PropsMixin):
|
||||
_ordering_key = "order_by"
|
||||
_search_text_key = "search_text"
|
||||
|
||||
_start_key = "start"
|
||||
_size_key = "size"
|
||||
|
||||
_multi_field_param_sep = "__"
|
||||
_multi_field_param_prefix = {
|
||||
("_any_", "_or_"): lambda a, b: a | b,
|
||||
@@ -77,6 +95,7 @@ class GetMixin(PropsMixin):
|
||||
}
|
||||
MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")
|
||||
|
||||
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
|
||||
_field_collation_overrides = {}
|
||||
|
||||
class QueryParameterOptions(object):
|
||||
@@ -103,46 +122,106 @@ class GetMixin(PropsMixin):
|
||||
|
||||
class ListFieldBucketHelper:
|
||||
op_prefix = "__$"
|
||||
legacy_exclude_prefix = "-"
|
||||
_legacy_exclude_prefix = "-"
|
||||
_legacy_exclude_mongo_op = "nin"
|
||||
|
||||
_default = "in"
|
||||
default_mongo_op = "in"
|
||||
_ops = {
|
||||
# op -> (mongo_op, sticky)
|
||||
"not": ("nin", False),
|
||||
"nop": (default_mongo_op, False),
|
||||
"all": ("all", True),
|
||||
"and": ("all", True),
|
||||
"any": (default_mongo_op, True),
|
||||
"or": (default_mongo_op, True),
|
||||
}
|
||||
_next = _default
|
||||
_sticky = False
|
||||
|
||||
def __init__(self, legacy=False):
|
||||
self._legacy = legacy
|
||||
self._current_op = None
|
||||
self._sticky = False
|
||||
self._support_legacy = legacy
|
||||
self.allow_empty = False
|
||||
|
||||
def key(self, v) -> Optional[str]:
|
||||
def _get_op(self, v: str, translate: bool = False) -> Optional[str]:
|
||||
op = (
|
||||
v[len(self.op_prefix) :] if v and v.startswith(self.op_prefix) else None
|
||||
)
|
||||
if translate:
|
||||
tup = self._ops.get(op, None)
|
||||
return tup[0] if tup else None
|
||||
return op
|
||||
|
||||
def _key(self, v) -> Optional[Union[str, bool]]:
|
||||
if v is None:
|
||||
self._next = self._default
|
||||
return self._default
|
||||
elif self._legacy and v.startswith(self.legacy_exclude_prefix):
|
||||
self._next = self._default
|
||||
return self._ops["not"][0]
|
||||
elif v.startswith(self.op_prefix):
|
||||
self._next, self._sticky = self._ops.get(
|
||||
v[len(self.op_prefix) :], (self._default, self._sticky)
|
||||
)
|
||||
self.allow_empty = True
|
||||
return None
|
||||
|
||||
next_ = self._next
|
||||
if not self._sticky:
|
||||
self._next = self._default
|
||||
op = self._get_op(v)
|
||||
if op is not None:
|
||||
# operator - set state and return None
|
||||
self._current_op, self._sticky = self._ops.get(
|
||||
op, (self.default_mongo_op, self._sticky)
|
||||
)
|
||||
return None
|
||||
elif self._current_op:
|
||||
current_op = self._current_op
|
||||
if not self._sticky:
|
||||
self._current_op = None
|
||||
return current_op
|
||||
elif self._support_legacy and v.startswith(self._legacy_exclude_prefix):
|
||||
self._current_op = None
|
||||
return False
|
||||
|
||||
return next_
|
||||
return self.default_mongo_op
|
||||
|
||||
def value_transform(self, v):
|
||||
if self._legacy and v and v.startswith(self.legacy_exclude_prefix):
|
||||
return v[len(self.legacy_exclude_prefix) :]
|
||||
return v
|
||||
def get_global_op(self, data: Sequence[str]) -> int:
|
||||
op_to_res = {
|
||||
"in": Q.OR,
|
||||
"all": Q.AND,
|
||||
}
|
||||
data = (x for x in data if x is not None)
|
||||
first_op = (
|
||||
self._get_op(next(data, ""), translate=True) or self.default_mongo_op
|
||||
)
|
||||
return op_to_res.get(first_op, self.default_mongo_op)
|
||||
|
||||
def get_actions(self, data: Sequence[str]) -> Dict[str, List[Union[str, None]]]:
|
||||
actions = {}
|
||||
|
||||
for val in data:
|
||||
key = self._key(val)
|
||||
if key is None:
|
||||
continue
|
||||
elif self._support_legacy and key is False:
|
||||
key = self._legacy_exclude_mongo_op
|
||||
val = val[len(self._legacy_exclude_prefix) :]
|
||||
actions.setdefault(key, []).append(val)
|
||||
|
||||
return actions
|
||||
|
||||
get_all_query_options = QueryParameterOptions()
|
||||
|
||||
class GetManyScrollState(ProperDictMixin, Document):
|
||||
meta = {"db_alias": Database.backend, "strict": False}
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
position = IntField(default=0)
|
||||
|
||||
_cache_manager = None
|
||||
|
||||
@classmethod
|
||||
def get_cache_manager(cls):
|
||||
if not cls._cache_manager:
|
||||
cls._cache_manager = RedisCacheManager(
|
||||
state_class=cls.GetManyScrollState,
|
||||
redis=redman.connection("apiserver"),
|
||||
expiration_interval=config.get(
|
||||
"services._mongo.scroll_state_expiration_seconds", 600
|
||||
),
|
||||
)
|
||||
|
||||
return cls._cache_manager
|
||||
|
||||
@classmethod
|
||||
def get(
|
||||
cls: Union["GetMixin", Document],
|
||||
@@ -241,7 +320,9 @@ class GetMixin(PropsMixin):
|
||||
Prepare a query object based on the provided query dictionary and various fields.
|
||||
|
||||
NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows creating queries that span across companies.
|
||||
|
||||
IMPLEMENTATION NOTE: Make sure that inside this function or the functions it depends on RegexQ is always
|
||||
used instead of Q. Otherwise we can and up with some combination that is not processed according to
|
||||
RegexQ rules
|
||||
:param parameters_options: Specifies options for parsing the parameters (see ParametersOptions)
|
||||
:param parameters: Query dictionary (relevant keys are these specified by the various field names parameters).
|
||||
Supported parameters:
|
||||
@@ -278,7 +359,7 @@ class GetMixin(PropsMixin):
|
||||
patterns=opts.fields or [], parameters=parameters
|
||||
).items():
|
||||
if "._" in field or "_." in field:
|
||||
query &= Q(__raw__={field: data})
|
||||
query &= RegexQ(__raw__={field: data})
|
||||
else:
|
||||
dict_query[field.replace(".", "__")] = data
|
||||
|
||||
@@ -312,22 +393,31 @@ class GetMixin(PropsMixin):
|
||||
break
|
||||
if any("._" in f for f in data.fields):
|
||||
q = reduce(
|
||||
lambda a, x: func(a, Q(__raw__={x: {"$regex": data.pattern, "$options": "i"}})),
|
||||
lambda a, x: func(
|
||||
a,
|
||||
RegexQ(
|
||||
__raw__={
|
||||
x: {"$regex": data.pattern, "$options": "i"}
|
||||
}
|
||||
),
|
||||
),
|
||||
data.fields,
|
||||
Q()
|
||||
RegexQ(),
|
||||
)
|
||||
else:
|
||||
regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
|
||||
sep_fields = [f.replace(".", "__") for f in data.fields]
|
||||
q = reduce(
|
||||
lambda a, x: func(a, RegexQ(**{x: regex})), sep_fields, RegexQ()
|
||||
lambda a, x: func(a, RegexQ(**{x: regex})),
|
||||
sep_fields,
|
||||
RegexQ(),
|
||||
)
|
||||
query = query & q
|
||||
|
||||
return query & RegexQ(**dict_query)
|
||||
|
||||
@classmethod
|
||||
def get_range_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
|
||||
def get_range_field_query(cls, field: str, data: Sequence[Optional[str]]) -> RegexQ:
|
||||
"""
|
||||
Return a range query for the provided field. The data should contain min and max values
|
||||
Both intervals are included. For open range queries either min or max can be None
|
||||
@@ -351,14 +441,14 @@ class GetMixin(PropsMixin):
|
||||
if max_val is not None:
|
||||
query[f"{mongoengine_field}__lte"] = max_val
|
||||
|
||||
q = Q(**query)
|
||||
q = RegexQ(**query)
|
||||
if min_val is None:
|
||||
q |= Q(**{mongoengine_field: None})
|
||||
q |= RegexQ(**{mongoengine_field: None})
|
||||
|
||||
return q
|
||||
|
||||
@classmethod
|
||||
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
|
||||
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> RegexQ:
|
||||
"""
|
||||
Get a proper mongoengine Q object that represents an "or" query for the provided values
|
||||
with respect to the given list field, with support for "none of empty" in case a None value
|
||||
@@ -370,30 +460,31 @@ class GetMixin(PropsMixin):
|
||||
"""
|
||||
if not isinstance(data, (list, tuple)):
|
||||
data = [data]
|
||||
# raise MakeGetAllQueryError("expected list", field)
|
||||
|
||||
# TODO: backwards compatibility only for older API versions
|
||||
helper = cls.ListFieldBucketHelper(legacy=True)
|
||||
actions = bucketize(
|
||||
data, key=helper.key, value_transform=helper.value_transform
|
||||
)
|
||||
global_op = helper.get_global_op(data)
|
||||
actions = helper.get_actions(data)
|
||||
|
||||
allow_empty = None in actions.get("in", {})
|
||||
mongoengine_field = field.replace(".", "__")
|
||||
|
||||
q = RegexQ()
|
||||
for action in filter(None, actions):
|
||||
q &= RegexQ(
|
||||
**{f"{mongoengine_field}__{action}": list(set(actions[action]))}
|
||||
)
|
||||
queries = [
|
||||
RegexQ(**{f"{mongoengine_field}__{action}": list(set(actions[action]))})
|
||||
for action in filter(None, actions)
|
||||
]
|
||||
|
||||
if not allow_empty:
|
||||
if not queries:
|
||||
q = RegexQ()
|
||||
else:
|
||||
q = RegexQCombination(operation=global_op, children=queries)
|
||||
|
||||
if not helper.allow_empty:
|
||||
return q
|
||||
|
||||
return (
|
||||
q
|
||||
| Q(**{f"{mongoengine_field}__exists": False})
|
||||
| Q(**{mongoengine_field: []})
|
||||
| RegexQ(**{f"{mongoengine_field}__exists": False})
|
||||
| RegexQ(**{mongoengine_field: []})
|
||||
| RegexQ(**{mongoengine_field: None})
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -421,27 +512,41 @@ class GetMixin(PropsMixin):
|
||||
return order_by
|
||||
|
||||
@classmethod
|
||||
def validate_paging(
|
||||
cls, parameters=None, default_page=None, default_page_size=None
|
||||
):
|
||||
""" Validate and extract paging info from from the provided dictionary. Supports default values. """
|
||||
if parameters is None:
|
||||
parameters = {}
|
||||
default_page = parameters.get("page", default_page)
|
||||
if default_page is None:
|
||||
return None, None
|
||||
default_page_size = parameters.get("page_size", default_page_size)
|
||||
if not default_page_size:
|
||||
raise errors.bad_request.MissingRequiredFields(
|
||||
"page_size is required when page is requested", field="page_size"
|
||||
)
|
||||
elif default_page < 0:
|
||||
def validate_paging(cls, parameters=None, default_page=0, default_page_size=None):
|
||||
"""
|
||||
Validate and extract paging info from from the provided dictionary. Supports default values.
|
||||
If page is specified then it should be non-negative, if page size is specified then it should be positive
|
||||
If page size is specified and page is not then 0 page is assumed
|
||||
If page is specified then page size should be specified too
|
||||
"""
|
||||
parameters = parameters or {}
|
||||
|
||||
start = parameters.get(cls._start_key)
|
||||
if start is not None:
|
||||
return start, cls.validate_scroll_size(parameters)
|
||||
|
||||
max_page_size = config.get("services._mongo.max_page_size", 500)
|
||||
page = parameters.get("page", default_page)
|
||||
if page is not None and page < 0:
|
||||
raise errors.bad_request.ValidationError("page must be >=0", field="page")
|
||||
elif default_page_size < 1:
|
||||
|
||||
page_size = parameters.get("page_size", default_page_size or max_page_size)
|
||||
if page_size is not None and page_size < 1:
|
||||
raise errors.bad_request.ValidationError(
|
||||
"page_size must be >0", field="page_size"
|
||||
)
|
||||
return default_page, default_page_size
|
||||
|
||||
if page_size is not None:
|
||||
page = page or 0
|
||||
page_size = min(page_size, max_page_size)
|
||||
return page * page_size, page_size
|
||||
|
||||
if page is not None:
|
||||
raise errors.bad_request.MissingRequiredFields(
|
||||
"page_size is required when page is requested", field="page_size"
|
||||
)
|
||||
|
||||
return None, None
|
||||
|
||||
@classmethod
|
||||
def get_projection(cls, parameters, override_projection=None, **__):
|
||||
@@ -485,6 +590,54 @@ class GetMixin(PropsMixin):
|
||||
def set_default_ordering(cls, parameters: dict, value: Sequence[str]) -> None:
|
||||
cls.set_ordering(parameters, cls.get_ordering(parameters) or value)
|
||||
|
||||
@classmethod
|
||||
def validate_scroll_size(cls, query_dict: dict) -> int:
|
||||
size = query_dict.get(cls._size_key)
|
||||
if not size or not isinstance(size, int) or size < 1:
|
||||
raise errors.bad_request.ValidationError(
|
||||
"Integer size parameter greater than 1 should be provided when working with scroll"
|
||||
)
|
||||
return size
|
||||
|
||||
@classmethod
|
||||
def get_data_with_scroll_support(
|
||||
cls,
|
||||
query_dict: dict,
|
||||
data_getter: Callable[[], Sequence[dict]],
|
||||
ret_params: dict,
|
||||
) -> Sequence[dict]:
|
||||
"""
|
||||
Retrieves the data by calling the provided data_getter api
|
||||
If scroll parameters are specified then put the query_dict 'start' parameter to the last
|
||||
scroll position and continue retrievals from that position
|
||||
If refresh_scroll is requested then bring once more the data from the beginning
|
||||
till the current scroll position
|
||||
In the end the scroll position is updated and accumulated frames are returned
|
||||
"""
|
||||
query_dict = query_dict or {}
|
||||
state: Optional[cls.GetManyScrollState] = None
|
||||
if "scroll_id" in query_dict:
|
||||
size = cls.validate_scroll_size(query_dict)
|
||||
state = cls.get_cache_manager().get_or_create_state_core(
|
||||
query_dict.get("scroll_id")
|
||||
)
|
||||
if query_dict.get("refresh_scroll"):
|
||||
query_dict[cls._size_key] = max(state.position, size)
|
||||
state.position = 0
|
||||
query_dict[cls._start_key] = state.position
|
||||
|
||||
data = data_getter()
|
||||
if cls._start_key in query_dict:
|
||||
query_dict[cls._start_key] = query_dict[cls._start_key] + len(data)
|
||||
|
||||
if state:
|
||||
state.position = query_dict[cls._start_key]
|
||||
cls.get_cache_manager().set_state(state)
|
||||
if ret_params is not None:
|
||||
ret_params["scroll_id"] = state.id
|
||||
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def get_many_with_join(
|
||||
cls,
|
||||
@@ -495,6 +648,7 @@ class GetMixin(PropsMixin):
|
||||
allow_public=False,
|
||||
override_projection=None,
|
||||
expand_reference_ids=True,
|
||||
ret_params: dict = None,
|
||||
):
|
||||
"""
|
||||
Fetch all documents matching a provided query with support for joining referenced documents according to the
|
||||
@@ -530,6 +684,7 @@ class GetMixin(PropsMixin):
|
||||
query=query,
|
||||
query_options=query_options,
|
||||
allow_public=allow_public,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
|
||||
def projection_func(doc_type, projection, ids):
|
||||
@@ -560,6 +715,7 @@ class GetMixin(PropsMixin):
|
||||
allow_public=False,
|
||||
override_projection: Collection[str] = None,
|
||||
return_dicts=True,
|
||||
ret_params: dict = None,
|
||||
):
|
||||
"""
|
||||
Fetch all documents matching a provided query. Supported several built-in options
|
||||
@@ -605,12 +761,16 @@ class GetMixin(PropsMixin):
|
||||
_query = (q & query) if query else q
|
||||
|
||||
if return_dicts:
|
||||
return cls._get_many_override_none_ordering(
|
||||
data_getter = partial(
|
||||
cls._get_many_override_none_ordering,
|
||||
query=_query,
|
||||
parameters=parameters,
|
||||
override_projection=override_projection,
|
||||
override_collation=override_collation,
|
||||
)
|
||||
return cls.get_data_with_scroll_support(
|
||||
query_dict=query_dict, data_getter=data_getter, ret_params=ret_params,
|
||||
)
|
||||
|
||||
return cls._get_many_no_company(
|
||||
query=_query,
|
||||
@@ -662,7 +822,7 @@ class GetMixin(PropsMixin):
|
||||
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
|
||||
if order_by and not override_collation:
|
||||
override_collation = cls._get_collation_override(order_by[0])
|
||||
page, page_size = cls.validate_paging(parameters=parameters)
|
||||
start, size = cls.validate_paging(parameters=parameters)
|
||||
include, exclude = cls.split_projection(
|
||||
cls.get_projection(parameters, override_projection)
|
||||
)
|
||||
@@ -683,9 +843,9 @@ class GetMixin(PropsMixin):
|
||||
if exclude:
|
||||
qs = qs.exclude(*exclude)
|
||||
|
||||
if page is not None and page_size:
|
||||
if start is not None and size:
|
||||
# add paging
|
||||
qs = qs.skip(page * page_size).limit(page_size)
|
||||
qs = qs.skip(start).limit(size)
|
||||
|
||||
return qs
|
||||
|
||||
@@ -746,7 +906,10 @@ class GetMixin(PropsMixin):
|
||||
parameters = parameters or {}
|
||||
search_text = parameters.get(cls._search_text_key)
|
||||
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
|
||||
page, page_size = cls.validate_paging(parameters=parameters)
|
||||
start, size = cls.validate_paging(parameters=parameters)
|
||||
if size is not None and size <= 0:
|
||||
return []
|
||||
|
||||
include, exclude = cls.split_projection(
|
||||
cls.get_projection(parameters, override_projection)
|
||||
)
|
||||
@@ -778,25 +941,28 @@ class GetMixin(PropsMixin):
|
||||
if exclude:
|
||||
query_sets = [qs.exclude(*exclude) for qs in query_sets]
|
||||
|
||||
if page is None or not page_size:
|
||||
if start is None or not size:
|
||||
return [obj.to_proper_dict(only=include) for qs in query_sets for obj in qs]
|
||||
|
||||
# add paging
|
||||
ret = []
|
||||
start = page * page_size
|
||||
for qs in query_sets:
|
||||
qs_size = qs.count()
|
||||
if qs_size < start:
|
||||
start -= qs_size
|
||||
continue
|
||||
last_set = len(query_sets) - 1
|
||||
for i, qs in enumerate(query_sets):
|
||||
last_size = len(ret)
|
||||
ret.extend(
|
||||
obj.to_proper_dict(only=include)
|
||||
for obj in qs.skip(start).limit(page_size)
|
||||
for obj in (qs.skip(start) if start else qs).limit(size)
|
||||
)
|
||||
if len(ret) >= page_size:
|
||||
added = len(ret) - last_size
|
||||
|
||||
if added > 0:
|
||||
start = 0
|
||||
size = max(0, size - added)
|
||||
elif i != last_set:
|
||||
start -= min(start, qs.count())
|
||||
|
||||
if size <= 0:
|
||||
break
|
||||
start = 0
|
||||
page_size -= len(ret)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
from typing import Sequence
|
||||
|
||||
from mongoengine import (
|
||||
Document,
|
||||
StringField,
|
||||
DateTimeField,
|
||||
BooleanField,
|
||||
EmbeddedDocumentListField,
|
||||
EmbeddedDocumentField,
|
||||
)
|
||||
|
||||
from apiserver.database import Database, strict
|
||||
@@ -13,18 +10,21 @@ from apiserver.database.fields import (
|
||||
StrippedStringField,
|
||||
SafeDictField,
|
||||
SafeSortedListField,
|
||||
SafeMapField,
|
||||
)
|
||||
from apiserver.database.model import DbModelMixin
|
||||
from apiserver.database.model import AttributedDocument
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.database.model.metadata import MetadataItem
|
||||
from apiserver.database.model.model_labels import ModelLabels
|
||||
from apiserver.database.model.company import Company
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.database.model.user import User
|
||||
|
||||
|
||||
class Model(DbModelMixin, Document):
|
||||
class Model(AttributedDocument):
|
||||
_field_collation_overrides = {
|
||||
"metadata.": AttributedDocument._numeric_locale,
|
||||
}
|
||||
|
||||
meta = {
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
@@ -33,8 +33,6 @@ class Model(DbModelMixin, Document):
|
||||
"project",
|
||||
"task",
|
||||
"last_update",
|
||||
"metadata.key",
|
||||
"metadata.type",
|
||||
("company", "framework"),
|
||||
("company", "name"),
|
||||
("company", "user"),
|
||||
@@ -66,6 +64,7 @@ class Model(DbModelMixin, Document):
|
||||
"project",
|
||||
"task",
|
||||
"parent",
|
||||
"metadata.*",
|
||||
),
|
||||
datetime_fields=("last_update",),
|
||||
)
|
||||
@@ -73,8 +72,6 @@ class Model(DbModelMixin, Document):
|
||||
id = StringField(primary_key=True)
|
||||
name = StrippedStringField(user_set_allowed=True, min_length=3)
|
||||
parent = StringField(reference_field="Model", required=False)
|
||||
user = StringField(required=True, reference_field=User)
|
||||
company = StringField(required=True, reference_field=Company)
|
||||
project = StringField(reference_field=Project, user_set_allowed=True)
|
||||
created = DateTimeField(required=True, user_set_allowed=True)
|
||||
task = StringField(reference_field=Task)
|
||||
@@ -91,6 +88,6 @@ class Model(DbModelMixin, Document):
|
||||
default=dict, user_set_allowed=True, exclude_by_default=True
|
||||
)
|
||||
company_origin = StringField(exclude_by_default=True)
|
||||
metadata: Sequence[MetadataItem] = EmbeddedDocumentListField(
|
||||
MetadataItem, default=list, user_set_allowed=True
|
||||
metadata = SafeMapField(
|
||||
field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
|
||||
)
|
||||
|
||||
@@ -11,6 +11,7 @@ class Project(AttributedDocument):
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
pattern_fields=("name", "description"),
|
||||
list_fields=("tags", "system_tags", "id", "parent", "path"),
|
||||
range_fields=("last_update",),
|
||||
)
|
||||
|
||||
meta = {
|
||||
|
||||
@@ -1,16 +1,19 @@
|
||||
from typing import Sequence
|
||||
|
||||
from mongoengine import (
|
||||
Document,
|
||||
EmbeddedDocument,
|
||||
StringField,
|
||||
DateTimeField,
|
||||
EmbeddedDocumentListField,
|
||||
EmbeddedDocumentField,
|
||||
)
|
||||
|
||||
from apiserver.database import Database, strict
|
||||
from apiserver.database.fields import StrippedStringField, SafeSortedListField
|
||||
from apiserver.database.model import DbModelMixin
|
||||
from apiserver.database.fields import (
|
||||
StrippedStringField,
|
||||
SafeSortedListField,
|
||||
SafeMapField,
|
||||
)
|
||||
from apiserver.database.model import DbModelMixin, AttributedDocument
|
||||
from apiserver.database.model.base import ProperDictMixin, GetMixin
|
||||
from apiserver.database.model.company import Company
|
||||
from apiserver.database.model.metadata import MetadataItem
|
||||
@@ -19,23 +22,25 @@ from apiserver.database.model.task.task import Task
|
||||
|
||||
class Entry(EmbeddedDocument, ProperDictMixin):
|
||||
""" Entry representing a task waiting in the queue """
|
||||
|
||||
task = StringField(required=True, reference_field=Task)
|
||||
''' Task ID '''
|
||||
""" Task ID """
|
||||
added = DateTimeField(required=True)
|
||||
''' Added to the queue '''
|
||||
""" Added to the queue """
|
||||
|
||||
|
||||
class Queue(DbModelMixin, Document):
|
||||
_field_collation_overrides = {
|
||||
"metadata.": AttributedDocument._numeric_locale,
|
||||
}
|
||||
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
pattern_fields=("name",),
|
||||
list_fields=("tags", "system_tags", "id"),
|
||||
pattern_fields=("name",), list_fields=("tags", "system_tags", "id", "metadata.*"),
|
||||
)
|
||||
|
||||
meta = {
|
||||
'db_alias': Database.backend,
|
||||
'strict': strict,
|
||||
"indexes": ["metadata.key", "metadata.type"],
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
}
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
@@ -44,10 +49,12 @@ class Queue(DbModelMixin, Document):
|
||||
)
|
||||
company = StringField(required=True, reference_field=Company)
|
||||
created = DateTimeField(required=True)
|
||||
tags = SafeSortedListField(StringField(required=True), default=list, user_set_allowed=True)
|
||||
tags = SafeSortedListField(
|
||||
StringField(required=True), default=list, user_set_allowed=True
|
||||
)
|
||||
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
|
||||
entries = EmbeddedDocumentListField(Entry, default=list)
|
||||
last_update = DateTimeField()
|
||||
metadata: Sequence[MetadataItem] = EmbeddedDocumentListField(
|
||||
MetadataItem, default=list, user_set_allowed=True
|
||||
metadata = SafeMapField(
|
||||
field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
|
||||
)
|
||||
|
||||
@@ -159,11 +159,10 @@ external_task_types = set(get_options(TaskType))
|
||||
|
||||
|
||||
class Task(AttributedDocument):
|
||||
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
|
||||
_field_collation_overrides = {
|
||||
"execution.parameters.": _numeric_locale,
|
||||
"last_metrics.": _numeric_locale,
|
||||
"hyperparams.": _numeric_locale,
|
||||
"execution.parameters.": AttributedDocument._numeric_locale,
|
||||
"last_metrics.": AttributedDocument._numeric_locale,
|
||||
"hyperparams.": AttributedDocument._numeric_locale,
|
||||
}
|
||||
|
||||
meta = {
|
||||
@@ -184,7 +183,10 @@ class Task(AttributedDocument):
|
||||
("company", "type", "system_tags", "status"),
|
||||
("company", "project", "type", "system_tags", "status"),
|
||||
("status", "last_update"), # for maintenance tasks
|
||||
{"fields": ["company", "project"], "collation": _numeric_locale},
|
||||
{
|
||||
"fields": ["company", "project"],
|
||||
"collation": AttributedDocument._numeric_locale,
|
||||
},
|
||||
{
|
||||
"name": "%s.task.main_text_index" % Database.backend,
|
||||
"fields": [
|
||||
|
||||
@@ -4,7 +4,12 @@ from threading import Lock
|
||||
from typing import Sequence
|
||||
|
||||
import six
|
||||
from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
|
||||
from mongoengine import (
|
||||
EmbeddedDocumentField,
|
||||
EmbeddedDocumentListField,
|
||||
EmbeddedDocument,
|
||||
Document,
|
||||
)
|
||||
from mongoengine.base import get_document
|
||||
|
||||
from apiserver.database.fields import (
|
||||
@@ -25,6 +30,13 @@ class PropsMixin(object):
|
||||
__cached_dpath_computed_fields_lock = Lock()
|
||||
__cached_dpath_computed_fields = None
|
||||
|
||||
_document_classes = {}
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
if issubclass(cls, (Document, EmbeddedDocument)):
|
||||
PropsMixin._document_classes[cls._class_name] = cls
|
||||
|
||||
@classmethod
|
||||
def get_fields(cls):
|
||||
if cls.__cached_fields is None:
|
||||
@@ -57,8 +69,14 @@ class PropsMixin(object):
|
||||
def resolve_doc(v):
|
||||
if not isinstance(v, six.string_types):
|
||||
return v
|
||||
if v == 'self':
|
||||
|
||||
if v == "self":
|
||||
return cls_.owner_document
|
||||
|
||||
doc_cls = PropsMixin._document_classes.get(v)
|
||||
if doc_cls:
|
||||
return doc_cls
|
||||
|
||||
return get_document(v)
|
||||
|
||||
fields = {k: resolve_doc(v) for k, v in res.items()}
|
||||
@@ -72,7 +90,7 @@ class PropsMixin(object):
|
||||
).document_type
|
||||
fields.update(
|
||||
{
|
||||
'.'.join((field, subfield)): doc
|
||||
".".join((field, subfield)): doc
|
||||
for subfield, doc in PropsMixin._get_fields_with_attr(
|
||||
embedded_doc_cls, attr
|
||||
).items()
|
||||
@@ -80,10 +98,10 @@ class PropsMixin(object):
|
||||
)
|
||||
|
||||
collect_embedded_docs(EmbeddedDocumentField, lambda x: x)
|
||||
collect_embedded_docs(EmbeddedDocumentListField, attrgetter('field'))
|
||||
collect_embedded_docs(LengthRangeEmbeddedDocumentListField, attrgetter('field'))
|
||||
collect_embedded_docs(UniqueEmbeddedDocumentListField, attrgetter('field'))
|
||||
collect_embedded_docs(EmbeddedDocumentSortedListField, attrgetter('field'))
|
||||
collect_embedded_docs(EmbeddedDocumentListField, attrgetter("field"))
|
||||
collect_embedded_docs(LengthRangeEmbeddedDocumentListField, attrgetter("field"))
|
||||
collect_embedded_docs(UniqueEmbeddedDocumentListField, attrgetter("field"))
|
||||
collect_embedded_docs(EmbeddedDocumentSortedListField, attrgetter("field"))
|
||||
|
||||
return fields
|
||||
|
||||
@@ -94,7 +112,7 @@ class PropsMixin(object):
|
||||
for depth, part in enumerate(parts):
|
||||
if current_cls is None:
|
||||
raise ValueError(
|
||||
'Invalid path (non-document encountered at %s)' % parts[: depth - 1]
|
||||
"Invalid path (non-document encountered at %s)" % parts[: depth - 1]
|
||||
)
|
||||
try:
|
||||
field_name, field = next(
|
||||
@@ -103,7 +121,7 @@ class PropsMixin(object):
|
||||
if k == part
|
||||
)
|
||||
except StopIteration:
|
||||
raise ValueError('Invalid field path %s' % parts[:depth])
|
||||
raise ValueError("Invalid field path %s" % parts[:depth])
|
||||
|
||||
translated_parts.append(part)
|
||||
|
||||
@@ -119,7 +137,7 @@ class PropsMixin(object):
|
||||
),
|
||||
):
|
||||
current_cls = field.field.document_type
|
||||
translated_parts.append('*')
|
||||
translated_parts.append("*")
|
||||
else:
|
||||
current_cls = None
|
||||
|
||||
@@ -128,7 +146,7 @@ class PropsMixin(object):
|
||||
@classmethod
|
||||
def get_reference_fields(cls):
|
||||
if cls.__cached_reference_fields is None:
|
||||
fields = cls._get_fields_with_attr(cls, 'reference_field')
|
||||
fields = cls._get_fields_with_attr(cls, "reference_field")
|
||||
cls.__cached_reference_fields = OrderedDict(sorted(fields.items()))
|
||||
return cls.__cached_reference_fields
|
||||
|
||||
@@ -143,12 +161,12 @@ class PropsMixin(object):
|
||||
@classmethod
|
||||
def get_exclude_fields(cls):
|
||||
if cls.__cached_exclude_fields is None:
|
||||
fields = cls._get_fields_with_attr(cls, 'exclude_by_default')
|
||||
fields = cls._get_fields_with_attr(cls, "exclude_by_default")
|
||||
cls.__cached_exclude_fields = OrderedDict(sorted(fields.items()))
|
||||
return cls.__cached_exclude_fields
|
||||
|
||||
@classmethod
|
||||
def get_dpath_translated_path(cls, path, separator='.'):
|
||||
def get_dpath_translated_path(cls, path, separator="."):
|
||||
if cls.__cached_dpath_computed_fields is None:
|
||||
cls.__cached_dpath_computed_fields = {}
|
||||
if path not in cls.__cached_dpath_computed_fields:
|
||||
|
||||
@@ -5,7 +5,7 @@ Apply elasticsearch mappings to given hosts.
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional, Sequence
|
||||
from typing import Optional, Sequence, Tuple
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
@@ -13,7 +13,7 @@ HERE = Path(__file__).resolve().parent
|
||||
|
||||
|
||||
def apply_mappings_to_cluster(
|
||||
hosts: Sequence, key: Optional[str] = None, es_args: dict = None
|
||||
hosts: Sequence, key: Optional[str] = None, es_args: dict = None, http_auth: Tuple = None
|
||||
):
|
||||
"""Hosts maybe a sequence of strings or dicts in the form {"host": <host>, "port": <port>}"""
|
||||
|
||||
@@ -30,7 +30,7 @@ def apply_mappings_to_cluster(
|
||||
else:
|
||||
files = p.glob("**/*.json")
|
||||
|
||||
es = Elasticsearch(hosts=hosts, **(es_args or {}))
|
||||
es = Elasticsearch(hosts=hosts, http_auth=http_auth, **(es_args or {}))
|
||||
return [_send_template(f) for f in files]
|
||||
|
||||
|
||||
|
||||
@@ -82,7 +82,11 @@ def check_elastic_empty() -> bool:
|
||||
es_logger.addFilter(log_filter)
|
||||
for retry in range(max_retries):
|
||||
try:
|
||||
es = Elasticsearch(hosts=cluster_conf.get("hosts"))
|
||||
es = Elasticsearch(
|
||||
hosts=cluster_conf.get("hosts", None),
|
||||
http_auth=es_factory.get_credentials("events", cluster_conf),
|
||||
**cluster_conf.get("args", {})
|
||||
)
|
||||
return not es.indices.get_template(name="events*")
|
||||
except exceptions.NotFoundError as ex:
|
||||
log.error(ex)
|
||||
@@ -109,5 +113,7 @@ def init_es_data():
|
||||
|
||||
log.info(f"Applying mappings to ES host: {hosts_config}")
|
||||
args = cluster_conf.get("args", {})
|
||||
res = apply_mappings_to_cluster(hosts_config, name, es_args=args)
|
||||
http_auth = es_factory.get_credentials(name)
|
||||
|
||||
res = apply_mappings_to_cluster(hosts_config, name, es_args=args, http_auth=http_auth)
|
||||
log.info(res)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
{
|
||||
"index_patterns": "events-*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
"number_of_shards": 1,
|
||||
"number_of_replicas": 0
|
||||
},
|
||||
"mappings": {
|
||||
"_source": {
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
{
|
||||
"index_patterns": "queue_metrics_*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
"number_of_shards": 1,
|
||||
"number_of_replicas": 0
|
||||
},
|
||||
"mappings": {
|
||||
"_source": {
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
{
|
||||
"index_patterns": "worker_stats_*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
"number_of_shards": 1,
|
||||
"number_of_replicas": 0
|
||||
},
|
||||
"mappings": {
|
||||
"_source": {
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
from os import getenv
|
||||
from typing import Tuple
|
||||
from typing import Tuple, Optional
|
||||
|
||||
from boltons.iterutils import first
|
||||
from elasticsearch import Elasticsearch, Transport
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from apiserver.config_repo import config
|
||||
|
||||
@@ -21,6 +22,10 @@ OVERRIDE_PORT_ENV_KEY = (
|
||||
"ELASTIC_SERVICE_PORT",
|
||||
)
|
||||
|
||||
OVERRIDE_USERNAME_ENV_KEY = ("CLEARML_ELASTIC_SERVICE_USERNAME",)
|
||||
|
||||
OVERRIDE_PASSWORD_ENV_KEY = ("CLEARML_ELASTIC_SERVICE_PASSWORD",)
|
||||
|
||||
OVERRIDE_HOST = first(filter(None, map(getenv, OVERRIDE_HOST_ENV_KEY)))
|
||||
if OVERRIDE_HOST:
|
||||
log.info(f"Using override elastic host {OVERRIDE_HOST}")
|
||||
@@ -29,6 +34,14 @@ OVERRIDE_PORT = first(filter(None, map(getenv, OVERRIDE_PORT_ENV_KEY)))
|
||||
if OVERRIDE_PORT:
|
||||
log.info(f"Using override elastic port {OVERRIDE_PORT}")
|
||||
|
||||
OVERRIDE_USERNAME = first(filter(None, map(getenv, OVERRIDE_USERNAME_ENV_KEY)))
|
||||
if OVERRIDE_USERNAME:
|
||||
log.info(f"Using override elastic username {OVERRIDE_USERNAME}")
|
||||
|
||||
OVERRIDE_PASSWORD = first(filter(None, map(getenv, OVERRIDE_PASSWORD_ENV_KEY)))
|
||||
if OVERRIDE_PASSWORD:
|
||||
log.info("Using override elastic password ********")
|
||||
|
||||
_instances = {}
|
||||
|
||||
|
||||
@@ -48,6 +61,10 @@ class InvalidClusterConfiguration(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class MissingPasswordForElasticUser(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ESFactory:
|
||||
@classmethod
|
||||
def connect(cls, cluster_name):
|
||||
@@ -65,22 +82,45 @@ class ESFactory:
|
||||
if not hosts:
|
||||
raise InvalidClusterConfiguration(cluster_name)
|
||||
|
||||
http_auth = cls.get_credentials(cluster_name)
|
||||
|
||||
args = cluster_config.get("args", {})
|
||||
_instances[cluster_name] = Elasticsearch(
|
||||
hosts=hosts, transport_class=Transport, **args
|
||||
hosts=hosts, http_auth=http_auth, **args
|
||||
)
|
||||
|
||||
return _instances[cluster_name]
|
||||
|
||||
@classmethod
|
||||
def get_credentials(cls, cluster_name: str, cluster_config: dict = None) -> Optional[Tuple[str, str]]:
|
||||
cluster_config = cluster_config or cls.get_cluster_config(cluster_name)
|
||||
if not cluster_config.get("secure", True):
|
||||
return None
|
||||
|
||||
elastic_user = OVERRIDE_USERNAME or config.get("secure.elastic.user", None)
|
||||
if not elastic_user:
|
||||
return None
|
||||
|
||||
elastic_password = OVERRIDE_PASSWORD or config.get(
|
||||
"secure.elastic.password", None
|
||||
)
|
||||
if not elastic_password:
|
||||
raise MissingPasswordForElasticUser(
|
||||
f"cluster={cluster_name}, username={elastic_user}"
|
||||
)
|
||||
|
||||
return elastic_user, elastic_password
|
||||
|
||||
@classmethod
|
||||
def get_all_cluster_names(cls):
|
||||
return list(config.get("hosts.elastic"))
|
||||
|
||||
@classmethod
|
||||
def get_override(cls, cluster_name: str) -> Tuple[str, str]:
|
||||
def get_override_host(cls, cluster_name: str) -> Tuple[str, str]:
|
||||
return OVERRIDE_HOST, OVERRIDE_PORT
|
||||
|
||||
@classmethod
|
||||
@lru_cache()
|
||||
def get_cluster_config(cls, cluster_name):
|
||||
"""
|
||||
Returns cluster config for the specified cluster path
|
||||
@@ -97,7 +137,7 @@ class ESFactory:
|
||||
for entry in cluster_config.get("hosts", []):
|
||||
entry[key] = value
|
||||
|
||||
host, port = cls.get_override(cluster_name)
|
||||
host, port = cls.get_override_host(cluster_name)
|
||||
|
||||
if host:
|
||||
set_host_prop("host", host)
|
||||
|
||||
@@ -21,6 +21,7 @@ from typing import (
|
||||
Union,
|
||||
Mapping,
|
||||
IO,
|
||||
Callable,
|
||||
)
|
||||
from urllib.parse import unquote, urlparse
|
||||
from zipfile import ZipFile, ZIP_BZIP2
|
||||
@@ -54,6 +55,7 @@ from apiserver.database.model.task.task import (
|
||||
from apiserver.database.utils import get_options
|
||||
from apiserver.utilities import json
|
||||
from apiserver.utilities.dicts import nested_get, nested_set, nested_delete
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
|
||||
|
||||
class PrePopulate:
|
||||
@@ -744,6 +746,19 @@ class PrePopulate:
|
||||
module = importlib.import_module(module_name)
|
||||
return getattr(module, class_name)
|
||||
|
||||
@staticmethod
|
||||
def _upgrade_model_data(model_data: dict) -> dict:
|
||||
metadata_key = "metadata"
|
||||
metadata = model_data.get(metadata_key)
|
||||
if isinstance(metadata, list):
|
||||
metadata = {
|
||||
ParameterKeyEscaper.escape(item["key"]): item
|
||||
for item in metadata
|
||||
if isinstance(item, dict) and "key" in item
|
||||
}
|
||||
model_data[metadata_key] = metadata
|
||||
return model_data
|
||||
|
||||
@staticmethod
|
||||
def _upgrade_task_data(task_data: dict) -> dict:
|
||||
"""
|
||||
@@ -828,9 +843,14 @@ class PrePopulate:
|
||||
print(f"Writing {cls_.__name__.lower()}s into database")
|
||||
tasks = []
|
||||
override_project_count = 0
|
||||
data_upgrade_funcs: Mapping[Type, Callable] = {
|
||||
cls.task_cls: cls._upgrade_task_data,
|
||||
cls.model_cls: cls._upgrade_model_data,
|
||||
}
|
||||
for item in cls.json_lines(f):
|
||||
if cls_ == cls.task_cls:
|
||||
item = json.dumps(cls._upgrade_task_data(task_data=json.loads(item)))
|
||||
upgrade_func = data_upgrade_funcs.get(cls_)
|
||||
if upgrade_func:
|
||||
item = json.dumps(upgrade_func(json.loads(item)))
|
||||
|
||||
doc = cls_.from_json(item, created=True)
|
||||
if hasattr(doc, "user"):
|
||||
|
||||
29
apiserver/mongo/migrations/1_3_0.py
Normal file
29
apiserver/mongo/migrations/1_3_0.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from pymongo.collection import Collection
|
||||
from pymongo.database import Database
|
||||
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
from .utils import _drop_all_indices_from_collections
|
||||
|
||||
|
||||
def _convert_metadata(db: Database, name):
|
||||
collection: Collection = db[name]
|
||||
|
||||
metadata_field = "metadata"
|
||||
query = {metadata_field: {"$exists": True, "$type": 4}}
|
||||
for doc in collection.find(filter=query, projection=(metadata_field,)):
|
||||
metadata = {
|
||||
ParameterKeyEscaper.escape(item["key"]): item
|
||||
for item in doc.get(metadata_field, [])
|
||||
if isinstance(item, dict) and "key" in item
|
||||
}
|
||||
collection.update_one(
|
||||
{"_id": doc["_id"]}, {"$set": {"metadata": metadata}},
|
||||
)
|
||||
|
||||
|
||||
def migrate_backend(db: Database):
|
||||
collections = ["model", "queue"]
|
||||
for name in collections:
|
||||
_convert_metadata(db, name)
|
||||
|
||||
_drop_all_indices_from_collections(db, collections)
|
||||
@@ -1,10 +1,8 @@
|
||||
import threading
|
||||
from os import getenv
|
||||
from time import sleep
|
||||
|
||||
from boltons.iterutils import first
|
||||
from redis import StrictRedis
|
||||
from redis.sentinel import Sentinel, SentinelConnectionPool
|
||||
from rediscluster import RedisCluster
|
||||
|
||||
from apiserver.apierrors.errors.server_error import ConfigError, GeneralError
|
||||
from apiserver.config_repo import config
|
||||
@@ -21,6 +19,11 @@ OVERRIDE_PORT_ENV_KEY = (
|
||||
"TRAINS_REDIS_SERVICE_PORT",
|
||||
"REDIS_SERVICE_PORT",
|
||||
)
|
||||
OVERRIDE_PASSWORD_ENV_KEY = (
|
||||
"CLEARML_REDIS_SERVICE_PASSWORD",
|
||||
"TRAINS_REDIS_SERVICE_PASSWORD",
|
||||
"REDIS_SERVICE_PASSWORD",
|
||||
)
|
||||
|
||||
OVERRIDE_HOST = first(filter(None, map(getenv, OVERRIDE_HOST_ENV_KEY)))
|
||||
if OVERRIDE_HOST:
|
||||
@@ -30,99 +33,7 @@ OVERRIDE_PORT = first(filter(None, map(getenv, OVERRIDE_PORT_ENV_KEY)))
|
||||
if OVERRIDE_PORT:
|
||||
log.info(f"Using override redis port {OVERRIDE_PORT}")
|
||||
|
||||
|
||||
class MyPubSubWorkerThread(threading.Thread):
|
||||
def __init__(self, sentinel, on_new_master, msg_sleep_time, daemon=True):
|
||||
super(MyPubSubWorkerThread, self).__init__()
|
||||
self.daemon = daemon
|
||||
self.sentinel = sentinel
|
||||
self.on_new_master = on_new_master
|
||||
self.sentinel_host = sentinel.connection_pool.connection_kwargs["host"]
|
||||
self.msg_sleep_time = msg_sleep_time
|
||||
self._running = False
|
||||
self.pubsub = None
|
||||
|
||||
def subscribe(self):
|
||||
if self.pubsub:
|
||||
try:
|
||||
self.pubsub.unsubscribe()
|
||||
self.pubsub.punsubscribe()
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
self.pubsub = None
|
||||
|
||||
subscriptions = {"+switch-master": self.on_new_master}
|
||||
|
||||
while not self.pubsub or not self.pubsub.subscribed:
|
||||
try:
|
||||
self.pubsub = self.sentinel.pubsub()
|
||||
self.pubsub.subscribe(**subscriptions)
|
||||
except Exception as ex:
|
||||
log.warn(
|
||||
f"Error while subscribing to sentinel at {self.sentinel_host} ({ex.args[0]}) Sleeping and retrying"
|
||||
)
|
||||
sleep(3)
|
||||
log.info(f"Subscribed to sentinel {self.sentinel_host}")
|
||||
|
||||
def run(self):
|
||||
if self._running:
|
||||
return
|
||||
self._running = True
|
||||
|
||||
self.subscribe()
|
||||
|
||||
while self.pubsub.subscribed:
|
||||
try:
|
||||
self.pubsub.get_message(
|
||||
ignore_subscribe_messages=True, timeout=self.msg_sleep_time
|
||||
)
|
||||
except Exception as ex:
|
||||
log.warn(
|
||||
f"Error while getting message from sentinel {self.sentinel_host} ({ex.args[0]}) Resubscribing"
|
||||
)
|
||||
self.subscribe()
|
||||
|
||||
self.pubsub.close()
|
||||
self._running = False
|
||||
|
||||
def stop(self):
|
||||
# stopping simply unsubscribes from all channels and patterns.
|
||||
# the unsubscribe responses that are generated will short circuit
|
||||
# the loop in run(), calling pubsub.close() to clean up the connection
|
||||
self.pubsub.unsubscribe()
|
||||
self.pubsub.punsubscribe()
|
||||
|
||||
|
||||
# todo,future - multi master clusters?
|
||||
class RedisCluster(object):
|
||||
def __init__(self, sentinel_hosts, service_name, **connection_kwargs):
|
||||
self.service_name = service_name
|
||||
self.sentinel = Sentinel(sentinel_hosts, **connection_kwargs)
|
||||
self.master = None
|
||||
self.master_host_port = None
|
||||
self.reconfigure()
|
||||
self.sentinel_threads = {}
|
||||
self.listen()
|
||||
|
||||
def reconfigure(self):
|
||||
try:
|
||||
self.master_host_port = self.sentinel.discover_master(self.service_name)
|
||||
self.master = self.sentinel.master_for(self.service_name)
|
||||
log.info(f"Reconfigured master to {self.master_host_port}")
|
||||
except Exception as ex:
|
||||
log.error(f"Error while reconfiguring. {ex.args[0]}")
|
||||
|
||||
def listen(self):
|
||||
def on_new_master(workerThread):
|
||||
self.reconfigure()
|
||||
|
||||
for sentinel in self.sentinel.sentinels:
|
||||
sentinel_host = sentinel.connection_pool.connection_kwargs["host"]
|
||||
self.sentinel_threads[sentinel_host] = MyPubSubWorkerThread(
|
||||
sentinel, on_new_master, msg_sleep_time=0.001, daemon=True
|
||||
)
|
||||
self.sentinel_threads[sentinel_host].start()
|
||||
OVERRIDE_PASSWORD = first(filter(None, map(getenv, OVERRIDE_PASSWORD_ENV_KEY)))
|
||||
|
||||
|
||||
class RedisManager(object):
|
||||
@@ -131,6 +42,9 @@ class RedisManager(object):
|
||||
for alias, alias_config in redis_config_dict.items():
|
||||
|
||||
alias_config = alias_config.as_plain_ordered_dict()
|
||||
alias_config["password"] = config.get(
|
||||
f"secure.redis.{alias}.password", None
|
||||
)
|
||||
|
||||
is_cluster = alias_config.get("cluster", False)
|
||||
|
||||
@@ -142,34 +56,19 @@ class RedisManager(object):
|
||||
if port:
|
||||
alias_config["port"] = port
|
||||
|
||||
db = alias_config.get("db", 0)
|
||||
password = OVERRIDE_PASSWORD or alias_config.get("password", None)
|
||||
if password:
|
||||
alias_config["password"] = password
|
||||
|
||||
sentinels = alias_config.get("sentinels", None)
|
||||
service_name = alias_config.get("service_name", None)
|
||||
|
||||
if not is_cluster and sentinels:
|
||||
raise ConfigError(
|
||||
"Redis configuration is invalid. mixed regular and cluster mode",
|
||||
alias=alias,
|
||||
)
|
||||
if is_cluster and (not sentinels or not service_name):
|
||||
raise ConfigError(
|
||||
"Redis configuration is invalid. missing sentinels or service_name",
|
||||
alias=alias,
|
||||
)
|
||||
if not is_cluster and (not port or not host):
|
||||
if not port or not host:
|
||||
raise ConfigError(
|
||||
"Redis configuration is invalid. missing port or host", alias=alias
|
||||
)
|
||||
|
||||
if is_cluster:
|
||||
# todo support all redis connection args via sentinel's connection_kwargs
|
||||
del alias_config["sentinels"]
|
||||
del alias_config["cluster"]
|
||||
del alias_config["service_name"]
|
||||
self.aliases[alias] = RedisCluster(
|
||||
sentinels, service_name, **alias_config
|
||||
)
|
||||
del alias_config["db"]
|
||||
self.aliases[alias] = RedisCluster(**alias_config)
|
||||
else:
|
||||
self.aliases[alias] = StrictRedis(**alias_config)
|
||||
|
||||
@@ -177,27 +76,21 @@ class RedisManager(object):
|
||||
obj = self.aliases.get(alias)
|
||||
if not obj:
|
||||
raise GeneralError(f"Invalid Redis alias {alias}")
|
||||
if isinstance(obj, RedisCluster):
|
||||
obj.master.get("health")
|
||||
return obj.master
|
||||
else:
|
||||
obj.get("health")
|
||||
return obj
|
||||
|
||||
obj.get("health")
|
||||
return obj
|
||||
|
||||
def host(self, alias):
|
||||
r = self.connection(alias)
|
||||
pool = r.connection_pool
|
||||
if isinstance(pool, SentinelConnectionPool):
|
||||
connections = pool.connection_kwargs[
|
||||
"connection_pool"
|
||||
]._available_connections
|
||||
if isinstance(r, RedisCluster):
|
||||
connections = first(r.connection_pool._available_connections.values())
|
||||
else:
|
||||
connections = pool._available_connections
|
||||
connections = r.connection_pool._available_connections
|
||||
|
||||
if len(connections) > 0:
|
||||
return connections[0].host
|
||||
else:
|
||||
if not connections:
|
||||
return None
|
||||
|
||||
return connections[0].host
|
||||
|
||||
|
||||
redman = RedisManager(config.get("hosts.redis"))
|
||||
|
||||
@@ -3,7 +3,7 @@ bcrypt>=3.1.4
|
||||
boltons>=19.1.0
|
||||
boto3==1.14.13
|
||||
dpath>=1.4.2,<2.0
|
||||
elasticsearch>=7.0.0,<8.0.0
|
||||
elasticsearch==7.13.3
|
||||
fastjsonschema>=2.8
|
||||
flask-compress>=1.4.0
|
||||
flask-cors>=3.0.5
|
||||
@@ -16,18 +16,19 @@ jinja2==2.11.3
|
||||
jsonmodels>=2.3
|
||||
jsonschema>=2.6.0
|
||||
luqum>=0.10.0
|
||||
mongoengine==0.19.1
|
||||
mongoengine==0.23.1
|
||||
nested_dict>=1.61
|
||||
packaging==20.3
|
||||
psutil>=5.6.5
|
||||
pyhocon>=0.3.35
|
||||
pyjwt<2.0.0
|
||||
pymongo==3.10.1
|
||||
pymongo[srv]==3.12.0
|
||||
python-rapidjson>=0.6.3
|
||||
redis>=2.10.5
|
||||
redis==3.5.3
|
||||
redis-py-cluster>=2.1.3
|
||||
related>=0.7.2
|
||||
requests>=2.13.0
|
||||
semantic_version>=2.8.3,<3
|
||||
six
|
||||
tqdm
|
||||
validators>=0.12.4
|
||||
validators>=0.12.4
|
||||
|
||||
@@ -26,6 +26,10 @@ credentials {
|
||||
type: string
|
||||
description: Credentials secret key
|
||||
}
|
||||
label {
|
||||
type: string
|
||||
description: Optional credentials label
|
||||
}
|
||||
}
|
||||
}
|
||||
batch_operation {
|
||||
|
||||
@@ -15,11 +15,19 @@ _definitions {
|
||||
type: string
|
||||
description: ""
|
||||
}
|
||||
label {
|
||||
type: string
|
||||
description: Optional credentials label
|
||||
}
|
||||
last_used {
|
||||
type: string
|
||||
description: ""
|
||||
format: "date-time"
|
||||
}
|
||||
last_used_from {
|
||||
type: string
|
||||
description: ""
|
||||
}
|
||||
}
|
||||
}
|
||||
role {
|
||||
@@ -222,6 +230,12 @@ create_credentials {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.17": ${create_credentials."2.1"} {
|
||||
request.properties.label {
|
||||
type: string
|
||||
description: Optional credentials label
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
get_credentials {
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -61,14 +61,14 @@ _definitions {
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
description: "User-defined tags list"
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||
type: array
|
||||
items {type: string}
|
||||
description: "System tags. This field is reserved for system use, please don't use it."
|
||||
items { type: string }
|
||||
}
|
||||
framework {
|
||||
description: "Framework on which the model is based. Should be identical to the framework of the task which created the model"
|
||||
@@ -98,9 +98,11 @@ _definitions {
|
||||
additionalProperties: true
|
||||
}
|
||||
metadata {
|
||||
type: array
|
||||
description: "Model metadata"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/metadata_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -199,6 +201,29 @@ get_all_ex {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all_ex."2.13"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previos calls to get_all_ex"
|
||||
}
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of models to retrieve"
|
||||
}
|
||||
}
|
||||
}
|
||||
response.properties.scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all {
|
||||
"2.1" {
|
||||
@@ -302,6 +327,29 @@ get_all {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previos calls to get_all"
|
||||
}
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of models to retrieve"
|
||||
}
|
||||
}
|
||||
}
|
||||
response.properties.scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
|
||||
}
|
||||
}
|
||||
}
|
||||
get_frameworks {
|
||||
"2.8" {
|
||||
@@ -361,7 +409,7 @@ update_for_task {
|
||||
system_tags {
|
||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||
type: array
|
||||
items {type: string}
|
||||
items { type: string }
|
||||
}
|
||||
override_model_id {
|
||||
description: "Override model ID. If provided, this model is updated in the task. Exactly one of override_model_id or uri is required."
|
||||
@@ -427,7 +475,7 @@ create {
|
||||
system_tags {
|
||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||
type: array
|
||||
items {type: string}
|
||||
items { type: string }
|
||||
}
|
||||
framework {
|
||||
description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model."
|
||||
@@ -483,9 +531,11 @@ create {
|
||||
}
|
||||
"2.13": ${create."2.1"} {
|
||||
metadata {
|
||||
type: array
|
||||
description: "Model metadata"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/metadata_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -522,7 +572,7 @@ edit {
|
||||
system_tags {
|
||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||
type: array
|
||||
items {type: string}
|
||||
items { type: string }
|
||||
}
|
||||
framework {
|
||||
description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model."
|
||||
@@ -578,9 +628,11 @@ edit {
|
||||
}
|
||||
"2.13": ${edit."2.1"} {
|
||||
metadata {
|
||||
type: array
|
||||
description: "Model metadata"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/metadata_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -611,7 +663,7 @@ update {
|
||||
system_tags {
|
||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||
type: array
|
||||
items {type: string}
|
||||
items { type: string }
|
||||
}
|
||||
ready {
|
||||
description: "Indication if the model is final and can be used by other tasks Default is false."
|
||||
@@ -661,9 +713,11 @@ update {
|
||||
}
|
||||
"2.13": ${update."2.1"} {
|
||||
metadata {
|
||||
type: array
|
||||
description: "Model metadata"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/metadata_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -672,7 +726,7 @@ publish_many {
|
||||
description: Publish models
|
||||
request {
|
||||
properties {
|
||||
ids.description: "IDs of models to publish"
|
||||
ids.description: "IDs of the models to publish"
|
||||
force_publish_task {
|
||||
description: "Publish the associated tasks (if exist) even if they are not in the 'stopped' state. Optional, the default value is False."
|
||||
type: boolean
|
||||
@@ -733,7 +787,7 @@ archive_many {
|
||||
description: Archive models
|
||||
request {
|
||||
properties {
|
||||
ids.description: "IDs of models to archive"
|
||||
ids.description: "IDs of the models to archive"
|
||||
}
|
||||
}
|
||||
response {
|
||||
@@ -769,10 +823,9 @@ delete_many {
|
||||
description: Delete models
|
||||
request {
|
||||
properties {
|
||||
ids.description: "IDs of models to delete"
|
||||
ids.description: "IDs of the models to delete"
|
||||
force {
|
||||
description: """Force. Required if there are tasks that use the model as an execution model, or if the model's creating task is published.
|
||||
"""
|
||||
description: "Force. Required if there are tasks that use the model as an execution model, or if the model's creating task is published."
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
@@ -929,6 +982,11 @@ add_or_update_metadata {
|
||||
description: "Metadata items to add or update"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
}
|
||||
replace_metadata {
|
||||
description: "If set then the all the metadata items will be replaced with the provided ones. Otherwise only the provided metadata items will be updated or added"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
|
||||
47
apiserver/schema/services/pipelines.conf
Normal file
47
apiserver/schema/services/pipelines.conf
Normal file
@@ -0,0 +1,47 @@
|
||||
_description: "Provides a management API for pipelines in the system."
|
||||
_definitions {
|
||||
}
|
||||
|
||||
start_pipeline {
|
||||
"2.17" {
|
||||
description: "Start a pipeline"
|
||||
request {
|
||||
type: object
|
||||
required: [ task ]
|
||||
properties {
|
||||
task {
|
||||
description: "ID of the task on which the pipeline will be based"
|
||||
type: string
|
||||
}
|
||||
queue {
|
||||
description: "Queue ID in which the created pipeline task will be enqueued"
|
||||
type: string
|
||||
}
|
||||
args {
|
||||
description: "Task arguments, name/value to be placed in the hyperparameters Args section"
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
properties {
|
||||
name: { type: string }
|
||||
value: { type: [string, null] }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
pipeline {
|
||||
description: "ID of the new pipeline task"
|
||||
type: string
|
||||
}
|
||||
enqueued {
|
||||
description: "True if the task was successfuly enqueued"
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -42,15 +42,20 @@ _definitions {
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
last_update {
|
||||
description: "Last update time"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
tags {
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
type: array
|
||||
description: "System tags. This field is reserved for system use, please don't use it."
|
||||
items {type: string}
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
default_output_destination {
|
||||
description: "The default output destination URL for new tasks under this project"
|
||||
@@ -70,6 +75,18 @@ _definitions {
|
||||
description: "Total run time of all tasks in project (in seconds)"
|
||||
type: integer
|
||||
}
|
||||
total_tasks {
|
||||
description: "Number of tasks"
|
||||
type: integer
|
||||
}
|
||||
completed_tasks_24h {
|
||||
description: "Number of tasks completed in the last 24 hours"
|
||||
type: integer
|
||||
}
|
||||
last_task_run {
|
||||
description: "The most recent started time of a task"
|
||||
type: integer
|
||||
}
|
||||
status_count {
|
||||
description: "Status counts"
|
||||
type: object
|
||||
@@ -78,6 +95,10 @@ _definitions {
|
||||
description: "Number of 'created' tasks in project"
|
||||
type: integer
|
||||
}
|
||||
completed {
|
||||
description: "Number of 'completed' tasks in project"
|
||||
type: integer
|
||||
}
|
||||
queued {
|
||||
description: "Number of 'queued' tasks in project"
|
||||
type: integer
|
||||
@@ -152,15 +173,20 @@ _definitions {
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
last_update {
|
||||
description: "Last update time"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
tags {
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
type: array
|
||||
description: "System tags. This field is reserved for system use, please don't use it."
|
||||
items {type: string}
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
default_output_destination {
|
||||
description: "The default output destination URL for new tasks under this project"
|
||||
@@ -294,14 +320,14 @@ create {
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
type: array
|
||||
description: "System tags. This field is reserved for system use, please don't use it."
|
||||
items {type: string}
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
default_output_destination {
|
||||
description: "The default output destination URL for new tasks under this project"
|
||||
@@ -414,7 +440,6 @@ get_all {
|
||||
description: "Projects list"
|
||||
type: array
|
||||
items { "$ref": "#/definitions/projects_get_all_response_single" }
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -430,6 +455,29 @@ get_all {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all."2.13"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previos calls to get_all_ex"
|
||||
}
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of projects to retrieve"
|
||||
}
|
||||
}
|
||||
}
|
||||
response.properties.scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all_ex {
|
||||
internal: true
|
||||
@@ -488,6 +536,49 @@ get_all_ex {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all_ex."2.13"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previos calls to get_all"
|
||||
}
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of projects to retrieve"
|
||||
}
|
||||
}
|
||||
}
|
||||
response.properties.scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
|
||||
}
|
||||
}
|
||||
"2.16": ${get_all_ex."2.15"} {
|
||||
request.properties.stats_with_children {
|
||||
description: "If include_stats flag is set then this flag contols whether the child projects tasks are taken into statistics or not"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
}
|
||||
"2.17": ${get_all_ex."2.16"} {
|
||||
request.properties.include_stats_filter {
|
||||
description: The filter for selecting entities that participate in statistics calculation
|
||||
type: object
|
||||
properties {
|
||||
system_tags {
|
||||
description: The list of allowed system tags
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
update {
|
||||
"2.1" {
|
||||
@@ -504,23 +595,19 @@ update {
|
||||
description: "Project name. Unique within the company."
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "Project description. "
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "Project description"
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
description: "User-defined tags list"
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
items { type: string }
|
||||
}
|
||||
system_tags {
|
||||
description: "System tags list. This field is reserved for system use, please don't use it."
|
||||
type: array
|
||||
description: "System tags. This field is reserved for system use, please don't use it."
|
||||
items {type: string}
|
||||
items { type: string }
|
||||
}
|
||||
default_output_destination {
|
||||
description: "The default output destination URL for new tasks under this project"
|
||||
@@ -658,7 +745,6 @@ delete {
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
response {
|
||||
@@ -791,6 +877,7 @@ get_hyper_parameters {
|
||||
description: """Get a list of all hyper parameter sections and names used in tasks within the given project."""
|
||||
request {
|
||||
type: object
|
||||
required: [project]
|
||||
properties {
|
||||
project {
|
||||
description: "Project ID"
|
||||
@@ -839,7 +926,105 @@ get_hyper_parameters {
|
||||
}
|
||||
}
|
||||
}
|
||||
get_model_metadata_values {
|
||||
"2.17" {
|
||||
description: """Get a list of distinct values for the chosen model metadata key"""
|
||||
request {
|
||||
type: object
|
||||
required: [key]
|
||||
properties {
|
||||
projects {
|
||||
description: "Project IDs"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
key {
|
||||
description: "Metadata key"
|
||||
type: string
|
||||
}
|
||||
allow_public {
|
||||
description: "If set to 'true' then collect values from both company and public models otherwise company modeels only. The default is 'true'"
|
||||
type: boolean
|
||||
}
|
||||
include_subprojects {
|
||||
description: "If set to 'true' and the project field is set then the result includes metadata values from the subproject models"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
total {
|
||||
description: "Total number of distinct values"
|
||||
type: integer
|
||||
}
|
||||
values {
|
||||
description: "The list of the unique values"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_model_metadata_keys {
|
||||
"2.17" {
|
||||
description: """Get a list of all metadata keys used in models within the given project."""
|
||||
request {
|
||||
type: object
|
||||
required: [project]
|
||||
properties {
|
||||
project {
|
||||
description: "Project ID"
|
||||
type: string
|
||||
}
|
||||
include_subprojects {
|
||||
description: "If set to 'true' and the project field is set then the result includes metadate keys from the subproject models"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
|
||||
page {
|
||||
description: "Page number"
|
||||
default: 0
|
||||
type: integer
|
||||
}
|
||||
page_size {
|
||||
description: "Page size"
|
||||
default: 500
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
keys {
|
||||
description: "A list of model keys"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
remaining {
|
||||
description: "Remaining results"
|
||||
type: integer
|
||||
}
|
||||
total {
|
||||
description: "Total number of results"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_project_tags {
|
||||
"2.17" {
|
||||
description: "Get user and system tags used for the specified projects and their children"
|
||||
request = ${_definitions.tags_request}
|
||||
response = ${_definitions.tags_response}
|
||||
}
|
||||
}
|
||||
get_task_tags {
|
||||
"2.8" {
|
||||
description: "Get user and system tags used for the tasks under the specified projects"
|
||||
@@ -847,7 +1032,6 @@ get_task_tags {
|
||||
response = ${_definitions.tags_response}
|
||||
}
|
||||
}
|
||||
|
||||
get_model_tags {
|
||||
"2.8" {
|
||||
description: "Get user and system tags used for the models under the specified projects"
|
||||
@@ -969,4 +1153,4 @@ get_task_parents {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,9 +79,11 @@ _definitions {
|
||||
items { "$ref": "#/definitions/entry" }
|
||||
}
|
||||
metadata {
|
||||
type: array
|
||||
description: "Queue metadata"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/metadata_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -115,6 +117,29 @@ get_by_id {
|
||||
get_all_ex {
|
||||
internal: true
|
||||
"2.4": ${get_all."2.4"}
|
||||
"2.15": ${get_all_ex."2.4"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previos calls to get_all_ex"
|
||||
}
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of queues to retrieve"
|
||||
}
|
||||
}
|
||||
}
|
||||
response.properties.scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all {
|
||||
"2.4" {
|
||||
@@ -178,6 +203,29 @@ get_all {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all."2.4"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previos calls to get_all"
|
||||
}
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of queues to retrieve"
|
||||
}
|
||||
}
|
||||
}
|
||||
response.properties.scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
|
||||
}
|
||||
}
|
||||
}
|
||||
get_default {
|
||||
"2.4" {
|
||||
@@ -235,6 +283,15 @@ create {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${create."2.4"} {
|
||||
metadata {
|
||||
description: "Queue metadata"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/metadata_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
update {
|
||||
"2.4" {
|
||||
@@ -276,7 +333,15 @@ update {
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.13": ${update."2.4"} {
|
||||
metadata {
|
||||
description: "Queue metadata"
|
||||
type: object
|
||||
additionalProperties {
|
||||
"$ref": "#/definitions/metadata_item"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -586,6 +651,11 @@ add_or_update_metadata {
|
||||
description: "Metadata items to add or update"
|
||||
items {"$ref": "#/definitions/metadata_item"}
|
||||
}
|
||||
replace_metadata {
|
||||
description: "If set then the all the metadata items will be replaced with the provided ones. Otherwise only the provided metadata items will be updated or added"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
|
||||
@@ -534,7 +534,7 @@ _definitions {
|
||||
container {
|
||||
description: "Docker container parameters"
|
||||
type: object
|
||||
additionalProperties { type: string }
|
||||
additionalProperties { type: [string, null] }
|
||||
}
|
||||
models {
|
||||
description: "Task models"
|
||||
@@ -685,6 +685,29 @@ get_all_ex {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all_ex."2.13"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previos calls to get_all_ex"
|
||||
}
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of tasks to retrieve"
|
||||
}
|
||||
}
|
||||
}
|
||||
response.properties.scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all {
|
||||
"2.1" {
|
||||
@@ -799,6 +822,29 @@ get_all {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previos calls to get_all"
|
||||
}
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of tasks to retrieve"
|
||||
}
|
||||
}
|
||||
}
|
||||
response.properties.scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
|
||||
}
|
||||
}
|
||||
}
|
||||
get_types {
|
||||
"2.8" {
|
||||
@@ -935,7 +981,7 @@ clone {
|
||||
new_task_container {
|
||||
description: "The docker container properties for the new task. If not provided then taken from the original task"
|
||||
type: object
|
||||
additionalProperties { type: string }
|
||||
additionalProperties { type: [string, null] }
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1113,7 +1159,7 @@ create {
|
||||
container {
|
||||
description: "Docker container parameters"
|
||||
type: object
|
||||
additionalProperties { type: string }
|
||||
additionalProperties { type: [string, null] }
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1202,7 +1248,7 @@ validate {
|
||||
container {
|
||||
description: "Docker container parameters"
|
||||
type: object
|
||||
additionalProperties { type: string }
|
||||
additionalProperties { type: [string, null] }
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1364,7 +1410,7 @@ edit {
|
||||
container {
|
||||
description: "Docker container parameters"
|
||||
type: object
|
||||
additionalProperties { type: string }
|
||||
additionalProperties { type: [string, null] }
|
||||
}
|
||||
runtime {
|
||||
description: "Task runtime mapping"
|
||||
|
||||
@@ -4,7 +4,7 @@ from hashlib import md5
|
||||
from flask import Flask
|
||||
from flask_compress import Compress
|
||||
from flask_cors import CORS
|
||||
from semantic_version import Version
|
||||
from packaging.version import Version
|
||||
|
||||
from apiserver.database import db
|
||||
from apiserver.bll.statistics.stats_reporter import StatisticsReporter
|
||||
@@ -25,6 +25,7 @@ from apiserver.server_init.request_handlers import RequestHandlers
|
||||
from apiserver.service_repo import ServiceRepo
|
||||
from apiserver.sync import distributed_lock
|
||||
from apiserver.updates import check_updates_thread
|
||||
from apiserver.utilities.env import get_bool
|
||||
from apiserver.utilities.threads_manager import ThreadsManager
|
||||
|
||||
log = config.logger(__file__)
|
||||
@@ -46,10 +47,13 @@ class AppSequence:
|
||||
def _attach_request_handlers(self, request_handlers: RequestHandlers):
|
||||
self.app.before_first_request(request_handlers.before_app_first_request)
|
||||
self.app.before_request(request_handlers.before_request)
|
||||
self.app.after_request(request_handlers.after_request)
|
||||
|
||||
def _configure(self):
|
||||
CORS(self.app, **config.get("apiserver.cors"))
|
||||
Compress(self.app)
|
||||
|
||||
if get_bool("CLEARML_COMPRESS_RESP", default=True):
|
||||
Compress(self.app)
|
||||
|
||||
self.app.config["SECRET_KEY"] = config.get(
|
||||
"secure.http.session_secret.apiserver"
|
||||
|
||||
@@ -1,21 +1,24 @@
|
||||
from functools import partial
|
||||
|
||||
from flask import request, Response, redirect
|
||||
from werkzeug.datastructures import ImmutableMultiDict
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from apiserver.apierrors import APIError
|
||||
from apiserver.apierrors.base import BaseError
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.service_repo import ServiceRepo, APICall
|
||||
from apiserver.service_repo.auth import AuthType
|
||||
from apiserver.service_repo.auth import AuthType, Token
|
||||
from apiserver.service_repo.errors import PathParsingError
|
||||
from apiserver.utilities import json
|
||||
from apiserver.utilities.dicts import nested_set
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class RequestHandlers:
|
||||
_request_strip_prefix = config.get("apiserver.request.strip_prefix", None)
|
||||
_server_header = config.get("apiserver.response.headers.server", "clearml")
|
||||
|
||||
def before_app_first_request(self):
|
||||
pass
|
||||
@@ -26,10 +29,13 @@ class RequestHandlers:
|
||||
if "/static/" in request.path:
|
||||
return
|
||||
|
||||
if request.content_encoding:
|
||||
return f"Content encoding is not supported ({request.content_encoding})", 415
|
||||
|
||||
try:
|
||||
call = self._create_api_call(request)
|
||||
load_data_callback = partial(self._load_call_data, req=request)
|
||||
content, content_type = ServiceRepo.handle_call(
|
||||
content, content_type, company = ServiceRepo.handle_call(
|
||||
call, load_data_callback=load_data_callback
|
||||
)
|
||||
|
||||
@@ -51,20 +57,53 @@ class RequestHandlers:
|
||||
|
||||
if call.result.cookies:
|
||||
for key, value in call.result.cookies.items():
|
||||
kwargs = config.get("apiserver.auth.cookies")
|
||||
kwargs = config.get("apiserver.auth.cookies").copy()
|
||||
if value is None:
|
||||
kwargs = kwargs.copy()
|
||||
# Removing a cookie
|
||||
kwargs["max_age"] = 0
|
||||
kwargs["expires"] = 0
|
||||
response.set_cookie(key, "", **kwargs)
|
||||
else:
|
||||
response.set_cookie(key, value, **kwargs)
|
||||
value = ""
|
||||
elif not company:
|
||||
# Setting a cookie, let's try to figure out the company
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
company = Token.decode_identity(value).company
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if company:
|
||||
try:
|
||||
# use no default value to allow setting a null domain as well
|
||||
kwargs["domain"] = config.get(f"apiserver.auth.cookies_domain_override.{company}")
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
response.set_cookie(key, value, **kwargs)
|
||||
|
||||
return response
|
||||
except Exception as ex:
|
||||
log.exception(f"Failed processing request {request.url}: {ex}")
|
||||
return f"Failed processing request {request.url}", 500
|
||||
|
||||
def after_request(self, response):
|
||||
response.headers["server"] = self._server_header
|
||||
return response
|
||||
|
||||
@staticmethod
|
||||
def _apply_multi_dict(body: dict, md: ImmutableMultiDict):
|
||||
def convert_value(v: str):
|
||||
if v.replace(".", "", 1).isdigit():
|
||||
return float(v) if "." in v else int(v)
|
||||
if v in ("true", "True", "TRUE"):
|
||||
return True
|
||||
if v in ("false", "False", "FALSE"):
|
||||
return False
|
||||
return v
|
||||
|
||||
for k, v in md.lists():
|
||||
v = [convert_value(x) for x in v] if (len(v) > 1 or k.endswith("[]")) else convert_value(v[0])
|
||||
nested_set(body, k.rstrip("[]").split("."), v)
|
||||
|
||||
def _update_call_data(self, call, req):
|
||||
""" Use request payload/form to fill call data or batched data """
|
||||
if req.content_type == "application/json-lines":
|
||||
@@ -82,23 +121,12 @@ class RequestHandlers:
|
||||
req.on_json_loading_failed(msg)
|
||||
call.batched_data = items
|
||||
else:
|
||||
json_body = req.get_json(force=True, silent=False) if req.data else None
|
||||
# merge form and args
|
||||
form = req.form.copy()
|
||||
form.update(req.args)
|
||||
form = form.to_dict()
|
||||
# convert string numbers to floats
|
||||
for key in form:
|
||||
if form[key].replace(".", "", 1).isdigit():
|
||||
if "." in form[key]:
|
||||
form[key] = float(form[key])
|
||||
else:
|
||||
form[key] = int(form[key])
|
||||
elif form[key].lower() == "true":
|
||||
form[key] = True
|
||||
elif form[key].lower() == "false":
|
||||
form[key] = False
|
||||
call.data = json_body or form or {}
|
||||
body = (req.get_json(force=True, silent=False) if req.data else None) or {}
|
||||
if req.args:
|
||||
self._apply_multi_dict(body, req.args)
|
||||
if req.form:
|
||||
self._apply_multi_dict(body, req.form)
|
||||
call.data = body
|
||||
|
||||
def _call_or_empty_with_error(self, call, req, msg, code=500, subcode=0):
|
||||
call = call or APICall(
|
||||
|
||||
@@ -95,8 +95,8 @@ class DataContainer(object):
|
||||
@raw_data.setter
|
||||
def raw_data(self, value):
|
||||
assert isinstance(
|
||||
value, string_types + (types.GeneratorType,)
|
||||
), "Raw data must be a string type or generator"
|
||||
value, string_types + (types.GeneratorType, bytes)
|
||||
), "Raw data must be a string type or bytes or generator"
|
||||
self._raw_data = value
|
||||
|
||||
@property
|
||||
@@ -310,6 +310,12 @@ class APICall(DataContainer):
|
||||
_transaction_headers = _get_headers("Trx")
|
||||
""" Transaction ID """
|
||||
|
||||
_redacted_headers = {
|
||||
HEADER_AUTHORIZATION: " ",
|
||||
"Cookie": "=",
|
||||
}
|
||||
""" Headers whose value should be redacted. Maps header name to partition char """
|
||||
|
||||
@property
|
||||
def HEADER_TRANSACTION(self):
|
||||
return self._transaction_headers[0]
|
||||
@@ -389,6 +395,10 @@ class APICall(DataContainer):
|
||||
self._auth_cookie = auth_cookie
|
||||
self._json_flags = {}
|
||||
|
||||
@property
|
||||
def files(self):
|
||||
return self._files
|
||||
|
||||
@property
|
||||
def id(self):
|
||||
return self._id
|
||||
@@ -584,6 +594,10 @@ class APICall(DataContainer):
|
||||
def json_flags(self):
|
||||
return self._json_flags
|
||||
|
||||
@property
|
||||
def extra_meta_fields(self):
|
||||
return {}
|
||||
|
||||
def mark_end(self):
|
||||
self._end_ts = time.time()
|
||||
self._duration = int((self._end_ts - self._start_ts) * 1000)
|
||||
@@ -634,6 +648,7 @@ class APICall(DataContainer):
|
||||
"result_msg": self.result.msg,
|
||||
"error_stack": self.result.traceback if include_stack else None,
|
||||
"error_data": self.result.error_data,
|
||||
**self.extra_meta_fields,
|
||||
},
|
||||
"data": self.result.data,
|
||||
}
|
||||
@@ -668,3 +683,15 @@ class APICall(DataContainer):
|
||||
error_data=error_data,
|
||||
cookies=self._result.cookies,
|
||||
)
|
||||
|
||||
def get_redacted_headers(self):
|
||||
headers = self.headers.copy()
|
||||
if not self.requires_authorization or self.auth:
|
||||
# We won't log the authorization header if call shouldn't be authorized, or if it was successfully
|
||||
# authorized. This means we'll only log authorization header for calls that failed to authorize (hopefully
|
||||
# this will allow us to debug authorization errors).
|
||||
for header, sep in self._redacted_headers.items():
|
||||
if header in headers:
|
||||
prefix, _, redact = headers[header].partition(sep)
|
||||
headers[header] = prefix + sep + f"<{len(redact)} bytes redacted>"
|
||||
return headers
|
||||
|
||||
@@ -51,7 +51,7 @@ def authorize_token(jwt_token, *_, **__):
|
||||
)
|
||||
|
||||
|
||||
def authorize_credentials(auth_data, service, action, call_data_items):
|
||||
def authorize_credentials(auth_data, service, action, call):
|
||||
"""Validate credentials against service/action and request data (dicts).
|
||||
Returns a new basic object (auth payload)
|
||||
"""
|
||||
@@ -100,7 +100,12 @@ def authorize_credentials(auth_data, service, action, call_data_items):
|
||||
if not fixed_user:
|
||||
# In case these are proper credentials, update last used time
|
||||
User.objects(id=user.id, credentials__key=access_key).update(
|
||||
**{"set__credentials__$__last_used": datetime.utcnow()}
|
||||
**{
|
||||
"set__credentials__$__last_used": datetime.utcnow(),
|
||||
"set__credentials__$__last_used_from": call.get_worker(
|
||||
default=call.real_ip
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
with TimingContext("mongo", "company_by_id"):
|
||||
|
||||
@@ -12,6 +12,9 @@ from .payload import Payload
|
||||
token_secret = config.get('secure.auth.token_secret')
|
||||
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class Token(Payload):
|
||||
default_expiration_sec = config.get('apiserver.auth.default_expiration_sec')
|
||||
|
||||
@@ -94,3 +97,14 @@ class Token(Payload):
|
||||
token.exp = now + timedelta(seconds=expiration_sec)
|
||||
|
||||
return token.encode(**extra_payload)
|
||||
|
||||
@classmethod
|
||||
def decode_identity(cls, encoded_token):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
from ..auth import Identity
|
||||
|
||||
decoded = cls.decode(encoded_token, verify=False)
|
||||
return Identity.from_dict(decoded.get("identity", {}))
|
||||
except Exception as ex:
|
||||
log.error(f"Failed parsing identity from encoded token: {ex}")
|
||||
|
||||
@@ -10,6 +10,7 @@ from apiserver.apierrors import APIError, errors
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.utilities.partial_version import PartialVersion
|
||||
from .apicall import APICall
|
||||
from .auth import Identity
|
||||
from .endpoint import Endpoint
|
||||
from .errors import MalformedPathError, InvalidVersionError, CallFailedError
|
||||
from .util import parse_return_stack_on_code
|
||||
@@ -37,7 +38,7 @@ class ServiceRepo(object):
|
||||
"""If the check is set, parsing will fail for endpoint request with the version that is grater than the current
|
||||
maximum """
|
||||
|
||||
_max_version = PartialVersion("2.14")
|
||||
_max_version = PartialVersion("2.17")
|
||||
""" Maximum version number (the highest min_version value across all endpoints) """
|
||||
|
||||
_endpoint_exp = (
|
||||
@@ -233,19 +234,27 @@ class ServiceRepo(object):
|
||||
return subcode in subcode_list
|
||||
|
||||
@classmethod
|
||||
def _get_company(
|
||||
def _get_identity(
|
||||
cls, call: APICall, endpoint: Endpoint = None, ignore_error: bool = False
|
||||
) -> Optional[str]:
|
||||
) -> Optional[Identity]:
|
||||
authorize = endpoint and endpoint.authorize
|
||||
if ignore_error or not authorize:
|
||||
try:
|
||||
return call.identity.company
|
||||
return call.identity
|
||||
except Exception:
|
||||
return None
|
||||
return call.identity.company
|
||||
return call.identity
|
||||
|
||||
@classmethod
|
||||
def _get_company(
|
||||
cls, call: APICall, endpoint: Endpoint = None, ignore_error: bool = False
|
||||
) -> Optional[str]:
|
||||
identity = cls._get_identity(call, endpoint=endpoint, ignore_error=ignore_error)
|
||||
return None if identity is None else identity.company
|
||||
|
||||
@classmethod
|
||||
def handle_call(cls, call: APICall, load_data_callback: Callable = None):
|
||||
company = None
|
||||
try:
|
||||
if call.failed:
|
||||
raise CallFailedError()
|
||||
@@ -316,4 +325,4 @@ class ServiceRepo(object):
|
||||
else:
|
||||
log.error(console_msg)
|
||||
|
||||
return content, content_type
|
||||
return content, content_type, company
|
||||
|
||||
@@ -69,7 +69,7 @@ def validate_auth(endpoint, call):
|
||||
auth = call.authorization or ""
|
||||
auth_type, _, auth_data = auth.partition(" ")
|
||||
authorize_func = get_auth_func(auth_type)
|
||||
call.auth = authorize_func(auth_data, service, action, call.batched_data)
|
||||
call.auth = authorize_func(auth_data, service, action, call)
|
||||
except Exception:
|
||||
if endpoint.authorize:
|
||||
# if endpoint requires authorization, re-raise exception
|
||||
|
||||
@@ -13,6 +13,7 @@ from apiserver.apimodels.auth import (
|
||||
CredentialsResponse,
|
||||
RevokeCredentialsRequest,
|
||||
EditUserReq,
|
||||
CreateCredentialsRequest,
|
||||
)
|
||||
from apiserver.apimodels.base import UpdateResponse
|
||||
from apiserver.bll.auth import AuthBLL
|
||||
@@ -58,9 +59,13 @@ def get_token_for_user(call: APICall, _: str, request: GetTokenForUserRequest):
|
||||
""" Generates a token based on a requested user and company. INTERNAL. """
|
||||
if call.identity.role not in Role.get_system_roles():
|
||||
if call.identity.role != Role.admin and call.identity.user != request.user:
|
||||
raise errors.bad_request.InvalidUserId("cannot generate token for another user")
|
||||
raise errors.bad_request.InvalidUserId(
|
||||
"cannot generate token for another user"
|
||||
)
|
||||
if call.identity.company != request.company:
|
||||
raise errors.bad_request.InvalidId("cannot generate token in another company")
|
||||
raise errors.bad_request.InvalidId(
|
||||
"cannot generate token in another company"
|
||||
)
|
||||
|
||||
call.result.data_model = AuthBLL.get_token_for_user(
|
||||
user_id=request.user,
|
||||
@@ -93,7 +98,10 @@ def validate_token_endpoint(call: APICall, _, __):
|
||||
)
|
||||
def create_user(call: APICall, _, request: CreateUserRequest):
|
||||
""" Create a user from. INTERNAL. """
|
||||
if call.identity.role not in Role.get_system_roles() and request.company != call.identity.company:
|
||||
if (
|
||||
call.identity.role not in Role.get_system_roles()
|
||||
and request.company != call.identity.company
|
||||
):
|
||||
raise errors.bad_request.InvalidId("cannot create user in another company")
|
||||
|
||||
user_id = AuthBLL.create_user(request=request, call=call)
|
||||
@@ -101,7 +109,7 @@ def create_user(call: APICall, _, request: CreateUserRequest):
|
||||
|
||||
|
||||
@endpoint("auth.create_credentials", response_data_model=CreateCredentialsResponse)
|
||||
def create_credentials(call: APICall, _, __):
|
||||
def create_credentials(call: APICall, _, request: CreateCredentialsRequest):
|
||||
if _is_protected_user(call.identity.user):
|
||||
raise errors.bad_request.InvalidUserId("protected identity")
|
||||
|
||||
@@ -109,6 +117,7 @@ def create_credentials(call: APICall, _, __):
|
||||
user_id=call.identity.user,
|
||||
company_id=call.identity.company,
|
||||
role=call.identity.role,
|
||||
label=request.label,
|
||||
)
|
||||
call.result.data_model = CreateCredentialsResponse(credentials=credentials)
|
||||
|
||||
@@ -151,7 +160,12 @@ def get_credentials(call: APICall, _, __):
|
||||
# we return ONLY the key IDs, never the secrets (want a secret? create new credentials)
|
||||
call.result.data_model = GetCredentialsResponse(
|
||||
credentials=[
|
||||
CredentialsResponse(access_key=c.key, last_used=c.last_used)
|
||||
CredentialsResponse(
|
||||
access_key=c.key,
|
||||
last_used=c.last_used,
|
||||
label=c.label,
|
||||
last_used_from=c.last_used_from,
|
||||
)
|
||||
for c in user.credentials
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import itertools
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Optional
|
||||
|
||||
import attr
|
||||
from typing import Sequence, Optional
|
||||
import jsonmodels.fields
|
||||
from boltons.iterutils import bucketize
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.events import (
|
||||
@@ -20,12 +23,17 @@ from apiserver.apimodels.events import (
|
||||
NextDebugImageSampleRequest,
|
||||
MetricVariants as ApiMetrics,
|
||||
TaskPlotsRequest,
|
||||
TaskEventsRequest,
|
||||
ScalarMetricsIterRawRequest,
|
||||
)
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.event.event_common import EventType, MetricVariants
|
||||
from apiserver.bll.event.events_iterator import Scroll
|
||||
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.utilities import json
|
||||
from apiserver.utilities import json, extract_properties_to_lists
|
||||
|
||||
task_bll = TaskBLL()
|
||||
event_bll = EventBLL()
|
||||
@@ -39,7 +47,6 @@ def add(call: APICall, company_id, _):
|
||||
company_id, [data], call.worker, allow_locked_tasks=allow_locked
|
||||
)
|
||||
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
|
||||
call.kpis["events"] = 1
|
||||
|
||||
|
||||
@endpoint("events.add_batch")
|
||||
@@ -50,7 +57,6 @@ def add_batch(call: APICall, company_id, _):
|
||||
|
||||
added, err_count, err_info = event_bll.add_events(company_id, events, call.worker)
|
||||
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
|
||||
call.kpis["events"] = len(events)
|
||||
|
||||
|
||||
@endpoint("events.get_task_log", required_fields=["task"])
|
||||
@@ -113,7 +119,8 @@ def get_task_log(call, company_id, request: LogEventsRequest):
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
)[0]
|
||||
|
||||
res = event_bll.log_events_iterator.get_task_events(
|
||||
res = event_bll.events_iterator.get_task_events(
|
||||
event_type=EventType.task_log,
|
||||
company_id=task.get_index_company(),
|
||||
task_id=task_id,
|
||||
batch_size=request.batch_size,
|
||||
@@ -258,31 +265,94 @@ def vector_metrics_iter_histogram(call, company_id, _):
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.get_task_events", required_fields=["task"])
|
||||
def get_task_events(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
batch_size = call.data.get("batch_size", 500)
|
||||
event_type = call.data.get("event_type")
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
order = call.data.get("order") or "asc"
|
||||
class GetTaskEventsScroll(Scroll):
|
||||
from_key_value = jsonmodels.fields.StringField()
|
||||
total = jsonmodels.fields.IntField()
|
||||
request: TaskEventsRequest = jsonmodels.fields.EmbeddedField(TaskEventsRequest)
|
||||
|
||||
|
||||
def make_response(
|
||||
total: int, returned: int = 0, scroll_id: str = None, **kwargs
|
||||
) -> dict:
|
||||
return {
|
||||
"returned": returned,
|
||||
"total": total,
|
||||
"scroll_id": scroll_id,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
@endpoint("events.get_task_events", request_data_model=TaskEventsRequest)
|
||||
def get_task_events(call, company_id, request: TaskEventsRequest):
|
||||
task_id = request.task
|
||||
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
company_id, task_id, allow_public=True, only=("company",),
|
||||
)[0]
|
||||
result = event_bll.get_task_events(
|
||||
task.get_index_company(),
|
||||
task_id,
|
||||
sort=[{"timestamp": {"order": order}}],
|
||||
event_type=EventType(event_type) if event_type else EventType.all,
|
||||
scroll_id=scroll_id,
|
||||
size=batch_size,
|
||||
|
||||
key = ScalarKeyEnum.iter
|
||||
scalar_key = ScalarKey.resolve(key)
|
||||
|
||||
if not request.scroll_id:
|
||||
from_key_value = None if (request.order == LogOrderEnum.desc) else 0
|
||||
total = None
|
||||
else:
|
||||
try:
|
||||
scroll = GetTaskEventsScroll.from_scroll_id(request.scroll_id)
|
||||
except ValueError:
|
||||
raise errors.bad_request.InvalidScrollId(scroll_id=request.scroll_id)
|
||||
|
||||
if scroll.from_key_value is None:
|
||||
return make_response(
|
||||
scroll_id=request.scroll_id, total=scroll.total, events=[]
|
||||
)
|
||||
|
||||
from_key_value = scalar_key.cast_value(scroll.from_key_value)
|
||||
total = scroll.total
|
||||
|
||||
scroll.request.batch_size = request.batch_size or scroll.request.batch_size
|
||||
request = scroll.request
|
||||
|
||||
navigate_earlier = request.order == LogOrderEnum.desc
|
||||
metric_variants = _get_metric_variants_from_request(request.metrics)
|
||||
|
||||
if request.count_total and total is None:
|
||||
total = event_bll.events_iterator.count_task_events(
|
||||
event_type=request.event_type,
|
||||
company_id=task.company,
|
||||
task_id=task_id,
|
||||
metric_variants=metric_variants,
|
||||
)
|
||||
|
||||
batch_size = min(
|
||||
request.batch_size,
|
||||
int(
|
||||
config.get("services.events.events_retrieval.max_raw_scalars_size", 10_000)
|
||||
),
|
||||
)
|
||||
|
||||
call.result.data = dict(
|
||||
events=result.events,
|
||||
returned=len(result.events),
|
||||
total=result.total_events,
|
||||
scroll_id=result.next_scroll_id,
|
||||
res = event_bll.events_iterator.get_task_events(
|
||||
event_type=request.event_type,
|
||||
company_id=task.company,
|
||||
task_id=task_id,
|
||||
batch_size=batch_size,
|
||||
key=ScalarKeyEnum.iter,
|
||||
navigate_earlier=navigate_earlier,
|
||||
from_key_value=from_key_value,
|
||||
metric_variants=metric_variants,
|
||||
)
|
||||
|
||||
scroll = GetTaskEventsScroll(
|
||||
from_key_value=str(res.events[-1][scalar_key.field]) if res.events else None,
|
||||
total=total,
|
||||
request=request,
|
||||
)
|
||||
|
||||
return make_response(
|
||||
returned=len(res.events),
|
||||
total=total,
|
||||
scroll_id=scroll.get_scroll_id(),
|
||||
events=res.events,
|
||||
)
|
||||
|
||||
|
||||
@@ -291,6 +361,7 @@ def get_scalar_metric_data(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
metric = call.data["metric"]
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
no_scroll = call.data.get("no_scroll", False)
|
||||
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
@@ -302,6 +373,7 @@ def get_scalar_metric_data(call, company_id, _):
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
metric=metric,
|
||||
scroll_id=scroll_id,
|
||||
no_scroll=no_scroll,
|
||||
)
|
||||
|
||||
call.result.data = dict(
|
||||
@@ -424,6 +496,7 @@ def get_multi_task_plots(call, company_id, req_model):
|
||||
task_ids = call.data["tasks"]
|
||||
iters = call.data.get("iters", 1)
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
no_scroll = call.data.get("no_scroll", False)
|
||||
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=call.identity.company,
|
||||
@@ -445,6 +518,7 @@ def get_multi_task_plots(call, company_id, req_model):
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
last_iter_count=iters,
|
||||
scroll_id=scroll_id,
|
||||
no_scroll=no_scroll,
|
||||
)
|
||||
|
||||
tasks = {t.id: t.name for t in tasks}
|
||||
@@ -523,6 +597,7 @@ def get_task_plots(call, company_id, request: TaskPlotsRequest):
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
last_iterations_per_plot=iters,
|
||||
scroll_id=scroll_id,
|
||||
no_scroll=request.no_scroll,
|
||||
metric_variants=_get_metric_variants_from_request(request.metrics),
|
||||
)
|
||||
|
||||
@@ -759,3 +834,105 @@ def _get_top_iter_unique_events(events, max_iters):
|
||||
)
|
||||
unique_events.sort(key=lambda e: e["iter"], reverse=True)
|
||||
return unique_events
|
||||
|
||||
|
||||
class ScalarMetricsIterRawScroll(Scroll):
|
||||
from_key_value = jsonmodels.fields.StringField()
|
||||
total = jsonmodels.fields.IntField()
|
||||
request: ScalarMetricsIterRawRequest = jsonmodels.fields.EmbeddedField(
|
||||
ScalarMetricsIterRawRequest
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.scalar_metrics_iter_raw", min_version="2.16")
|
||||
def scalar_metrics_iter_raw(
|
||||
call: APICall, company_id: str, request: ScalarMetricsIterRawRequest
|
||||
):
|
||||
key = request.key or ScalarKeyEnum.iter
|
||||
scalar_key = ScalarKey.resolve(key)
|
||||
if request.batch_size and request.batch_size < 0:
|
||||
raise errors.bad_request.ValidationError(
|
||||
"batch_size should be non negative number"
|
||||
)
|
||||
|
||||
if not request.scroll_id:
|
||||
from_key_value = None
|
||||
total = None
|
||||
request.batch_size = request.batch_size or 10_000
|
||||
else:
|
||||
try:
|
||||
scroll = ScalarMetricsIterRawScroll.from_scroll_id(request.scroll_id)
|
||||
except ValueError:
|
||||
raise errors.bad_request.InvalidScrollId(scroll_id=request.scroll_id)
|
||||
|
||||
if scroll.from_key_value is None:
|
||||
return make_response(
|
||||
scroll_id=request.scroll_id, total=scroll.total, variants={}
|
||||
)
|
||||
|
||||
from_key_value = scalar_key.cast_value(scroll.from_key_value)
|
||||
total = scroll.total
|
||||
request.batch_size = request.batch_size or scroll.request.batch_size
|
||||
|
||||
task_id = request.task
|
||||
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",),
|
||||
)[0]
|
||||
|
||||
metric_variants = _get_metric_variants_from_request([request.metric])
|
||||
|
||||
if request.count_total and total is None:
|
||||
total = event_bll.events_iterator.count_task_events(
|
||||
event_type=EventType.metrics_scalar,
|
||||
company_id=task.company,
|
||||
task_id=task_id,
|
||||
metric_variants=metric_variants,
|
||||
)
|
||||
|
||||
batch_size = min(
|
||||
request.batch_size,
|
||||
int(
|
||||
config.get("services.events.events_retrieval.max_raw_scalars_size", 200_000)
|
||||
),
|
||||
)
|
||||
|
||||
events = []
|
||||
for iteration in range(0, math.ceil(batch_size / 10_000)):
|
||||
res = event_bll.events_iterator.get_task_events(
|
||||
event_type=EventType.metrics_scalar,
|
||||
company_id=task.company,
|
||||
task_id=task_id,
|
||||
batch_size=min(batch_size, 10_000),
|
||||
navigate_earlier=False,
|
||||
from_key_value=from_key_value,
|
||||
metric_variants=metric_variants,
|
||||
key=key,
|
||||
)
|
||||
if not res.events:
|
||||
break
|
||||
events.extend(res.events)
|
||||
from_key_value = str(events[-1][scalar_key.field])
|
||||
|
||||
key = str(key)
|
||||
variants = {
|
||||
variant: extract_properties_to_lists(
|
||||
["value", scalar_key.field], events, target_keys=["y", key]
|
||||
)
|
||||
for variant, events in bucketize(events, key=itemgetter("variant")).items()
|
||||
}
|
||||
|
||||
call.kpis["events"] = len(events)
|
||||
|
||||
scroll = ScalarMetricsIterRawScroll(
|
||||
from_key_value=str(events[-1][scalar_key.field]) if events else None,
|
||||
total=total,
|
||||
request=request,
|
||||
)
|
||||
|
||||
return make_response(
|
||||
returned=len(events),
|
||||
total=total,
|
||||
scroll_id=scroll.get_scroll_id(),
|
||||
variants=variants,
|
||||
)
|
||||
|
||||
@@ -21,16 +21,14 @@ from apiserver.apimodels.models import (
|
||||
ModelsPublishManyRequest,
|
||||
ModelsDeleteManyRequest,
|
||||
)
|
||||
from apiserver.bll.model import ModelBLL
|
||||
from apiserver.bll.model import ModelBLL, Metadata
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL, project_ids_with_children
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.bll.task.task_operations import publish_task
|
||||
from apiserver.bll.util import run_batch_operation
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model import validate_id
|
||||
from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete
|
||||
from apiserver.database.model.model import Model
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import (
|
||||
@@ -50,8 +48,8 @@ from apiserver.services.utils import (
|
||||
conform_tag_fields,
|
||||
conform_output_tags,
|
||||
ModelsBackwardsCompatibility,
|
||||
validate_metadata,
|
||||
get_metadata_from_api,
|
||||
unescape_metadata,
|
||||
escape_metadata,
|
||||
)
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
@@ -64,19 +62,20 @@ project_bll = ProjectBLL()
|
||||
def get_by_id(call: APICall, company_id, _):
|
||||
model_id = call.data["model"]
|
||||
|
||||
with translate_errors_context():
|
||||
models = Model.get_many(
|
||||
company=company_id,
|
||||
query_dict=call.data,
|
||||
query=Q(id=model_id),
|
||||
allow_public=True,
|
||||
Metadata.escape_query_parameters(call)
|
||||
models = Model.get_many(
|
||||
company=company_id,
|
||||
query_dict=call.data,
|
||||
query=Q(id=model_id),
|
||||
allow_public=True,
|
||||
)
|
||||
if not models:
|
||||
raise errors.bad_request.InvalidModelId(
|
||||
"no such public or company model", id=model_id, company=company_id,
|
||||
)
|
||||
if not models:
|
||||
raise errors.bad_request.InvalidModelId(
|
||||
"no such public or company model", id=model_id, company=company_id,
|
||||
)
|
||||
conform_output_tags(call, models[0])
|
||||
call.result.data = {"model": models[0]}
|
||||
conform_output_tags(call, models[0])
|
||||
unescape_metadata(call, models[0])
|
||||
call.result.data = {"model": models[0]}
|
||||
|
||||
|
||||
@endpoint("models.get_by_task_id", required_fields=["task"])
|
||||
@@ -86,25 +85,25 @@ def get_by_task_id(call: APICall, company_id, _):
|
||||
|
||||
task_id = call.data["task"]
|
||||
|
||||
with translate_errors_context():
|
||||
query = dict(id=task_id, company=company_id)
|
||||
task = Task.get(_only=["models"], **query)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
if not task.models or not task.models.output:
|
||||
raise errors.bad_request.MissingTaskFields(field="models.output")
|
||||
query = dict(id=task_id, company=company_id)
|
||||
task = Task.get(_only=["models"], **query)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
if not task.models or not task.models.output:
|
||||
raise errors.bad_request.MissingTaskFields(field="models.output")
|
||||
|
||||
model_id = task.models.output[-1].model
|
||||
model = Model.objects(
|
||||
Q(id=model_id) & get_company_or_none_constraint(company_id)
|
||||
).first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(
|
||||
"no such public or company model", id=model_id, company=company_id,
|
||||
)
|
||||
model_dict = model.to_proper_dict()
|
||||
conform_output_tags(call, model_dict)
|
||||
call.result.data = {"model": model_dict}
|
||||
model_id = task.models.output[-1].model
|
||||
model = Model.objects(
|
||||
Q(id=model_id) & get_company_or_none_constraint(company_id)
|
||||
).first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(
|
||||
"no such public or company model", id=model_id, company=company_id,
|
||||
)
|
||||
model_dict = model.to_proper_dict()
|
||||
conform_output_tags(call, model_dict)
|
||||
unescape_metadata(call, model_dict)
|
||||
call.result.data = {"model": model_dict}
|
||||
|
||||
|
||||
def _process_include_subprojects(call_data: dict):
|
||||
@@ -121,41 +120,50 @@ def _process_include_subprojects(call_data: dict):
|
||||
@endpoint("models.get_all_ex", required_fields=[])
|
||||
def get_all_ex(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
with translate_errors_context():
|
||||
_process_include_subprojects(call.data)
|
||||
with TimingContext("mongo", "models_get_all_ex"):
|
||||
models = Model.get_many_with_join(
|
||||
company=company_id, query_dict=call.data, allow_public=True
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
call.result.data = {"models": models}
|
||||
_process_include_subprojects(call.data)
|
||||
Metadata.escape_query_parameters(call)
|
||||
with TimingContext("mongo", "models_get_all_ex"):
|
||||
ret_params = {}
|
||||
models = Model.get_many_with_join(
|
||||
company=company_id,
|
||||
query_dict=call.data,
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
unescape_metadata(call, models)
|
||||
call.result.data = {"models": models, **ret_params}
|
||||
|
||||
|
||||
@endpoint("models.get_by_id_ex", required_fields=["id"])
|
||||
def get_by_id_ex(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "models_get_by_id_ex"):
|
||||
models = Model.get_many_with_join(
|
||||
company=company_id, query_dict=call.data, allow_public=True
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
call.result.data = {"models": models}
|
||||
Metadata.escape_query_parameters(call)
|
||||
with TimingContext("mongo", "models_get_by_id_ex"):
|
||||
models = Model.get_many_with_join(
|
||||
company=company_id, query_dict=call.data, allow_public=True
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
unescape_metadata(call, models)
|
||||
call.result.data = {"models": models}
|
||||
|
||||
|
||||
@endpoint("models.get_all", required_fields=[])
|
||||
def get_all(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "models_get_all"):
|
||||
models = Model.get_many(
|
||||
company=company_id,
|
||||
parameters=call.data,
|
||||
query_dict=call.data,
|
||||
allow_public=True,
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
call.result.data = {"models": models}
|
||||
Metadata.escape_query_parameters(call)
|
||||
with TimingContext("mongo", "models_get_all"):
|
||||
ret_params = {}
|
||||
models = Model.get_many(
|
||||
company=company_id,
|
||||
parameters=call.data,
|
||||
query_dict=call.data,
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
unescape_metadata(call, models)
|
||||
call.result.data = {"models": models, **ret_params}
|
||||
|
||||
|
||||
@endpoint("models.get_frameworks", request_data_model=GetFrameworksRequest)
|
||||
@@ -183,15 +191,22 @@ create_fields = {
|
||||
"metadata": list,
|
||||
}
|
||||
|
||||
last_update_fields = ("uri", "framework", "design", "labels", "ready", "metadata")
|
||||
last_update_fields = (
|
||||
"uri",
|
||||
"framework",
|
||||
"design",
|
||||
"labels",
|
||||
"ready",
|
||||
"metadata",
|
||||
"system_tags",
|
||||
"tags",
|
||||
)
|
||||
|
||||
|
||||
def parse_model_fields(call, valid_fields):
|
||||
fields = parse_from_call(call.data, valid_fields, Model.get_fields())
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
metadata = fields.get("metadata")
|
||||
if metadata:
|
||||
validate_metadata(metadata)
|
||||
escape_metadata(fields)
|
||||
return fields
|
||||
|
||||
|
||||
@@ -225,82 +240,80 @@ def update_for_task(call: APICall, company_id, _):
|
||||
"exactly one field is required", fields=("uri", "override_model_id")
|
||||
)
|
||||
|
||||
with translate_errors_context():
|
||||
query = dict(id=task_id, company=company_id)
|
||||
task = Task.get_for_writing(
|
||||
id=task_id,
|
||||
company=company_id,
|
||||
_only=["models", "execution", "name", "status", "project"],
|
||||
)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
|
||||
query = dict(id=task_id, company=company_id)
|
||||
task = Task.get_for_writing(
|
||||
id=task_id,
|
||||
allowed_states = [TaskStatus.created, TaskStatus.in_progress]
|
||||
if task.status not in allowed_states:
|
||||
raise errors.bad_request.InvalidTaskStatus(
|
||||
f"model can only be updated for tasks in the {allowed_states} states",
|
||||
**query,
|
||||
)
|
||||
|
||||
if override_model_id:
|
||||
model = ModelBLL.get_company_model_by_id(
|
||||
company_id=company_id, model_id=override_model_id
|
||||
)
|
||||
else:
|
||||
if "name" not in call.data:
|
||||
# use task name if name not provided
|
||||
call.data["name"] = task.name
|
||||
|
||||
if "comment" not in call.data:
|
||||
call.data["comment"] = f"Created by task `{task.name}` ({task.id})"
|
||||
|
||||
if task.models and task.models.output:
|
||||
# model exists, update
|
||||
model_id = task.models.output[-1].model
|
||||
res = _update_model(call, company_id, model_id=model_id).to_struct()
|
||||
res.update({"id": model_id, "created": False})
|
||||
call.result.data = res
|
||||
return
|
||||
|
||||
# new model, create
|
||||
fields = parse_model_fields(call, create_fields)
|
||||
|
||||
# create and save model
|
||||
now = datetime.utcnow()
|
||||
model = Model(
|
||||
id=database.utils.id(),
|
||||
created=now,
|
||||
last_update=now,
|
||||
user=call.identity.user,
|
||||
company=company_id,
|
||||
_only=["models", "execution", "name", "status", "project"],
|
||||
project=task.project,
|
||||
framework=task.execution.framework,
|
||||
parent=task.models.input[0].model
|
||||
if task.models and task.models.input
|
||||
else None,
|
||||
design=task.execution.model_desc,
|
||||
labels=task.execution.model_labels,
|
||||
ready=(task.status == TaskStatus.published),
|
||||
**fields,
|
||||
)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(**query)
|
||||
model.save()
|
||||
_update_cached_tags(company_id, project=model.project, fields=fields)
|
||||
|
||||
allowed_states = [TaskStatus.created, TaskStatus.in_progress]
|
||||
if task.status not in allowed_states:
|
||||
raise errors.bad_request.InvalidTaskStatus(
|
||||
f"model can only be updated for tasks in the {allowed_states} states",
|
||||
**query,
|
||||
TaskBLL.update_statistics(
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
last_iteration_max=iteration,
|
||||
models__output=[
|
||||
ModelItem(
|
||||
model=model.id,
|
||||
name=TaskModelNames[TaskModelTypes.output],
|
||||
updated=datetime.utcnow(),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
if override_model_id:
|
||||
model = ModelBLL.get_company_model_by_id(
|
||||
company_id=company_id, model_id=override_model_id
|
||||
)
|
||||
else:
|
||||
if "name" not in call.data:
|
||||
# use task name if name not provided
|
||||
call.data["name"] = task.name
|
||||
|
||||
if "comment" not in call.data:
|
||||
call.data["comment"] = f"Created by task `{task.name}` ({task.id})"
|
||||
|
||||
if task.models and task.models.output:
|
||||
# model exists, update
|
||||
model_id = task.models.output[-1].model
|
||||
res = _update_model(call, company_id, model_id=model_id).to_struct()
|
||||
res.update({"id": model_id, "created": False})
|
||||
call.result.data = res
|
||||
return
|
||||
|
||||
# new model, create
|
||||
fields = parse_model_fields(call, create_fields)
|
||||
|
||||
# create and save model
|
||||
now = datetime.utcnow()
|
||||
model = Model(
|
||||
id=database.utils.id(),
|
||||
created=now,
|
||||
last_update=now,
|
||||
user=call.identity.user,
|
||||
company=company_id,
|
||||
project=task.project,
|
||||
framework=task.execution.framework,
|
||||
parent=task.models.input[0].model
|
||||
if task.models and task.models.input
|
||||
else None,
|
||||
design=task.execution.model_desc,
|
||||
labels=task.execution.model_labels,
|
||||
ready=(task.status == TaskStatus.published),
|
||||
**fields,
|
||||
)
|
||||
model.save()
|
||||
_update_cached_tags(company_id, project=model.project, fields=fields)
|
||||
|
||||
TaskBLL.update_statistics(
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
last_iteration_max=iteration,
|
||||
models__output=[
|
||||
ModelItem(
|
||||
model=model.id,
|
||||
name=TaskModelNames[TaskModelTypes.output],
|
||||
updated=datetime.utcnow(),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
call.result.data = {"id": model.id, "created": True}
|
||||
call.result.data = {"id": model.id, "created": True}
|
||||
|
||||
|
||||
@endpoint(
|
||||
@@ -313,36 +326,33 @@ def create(call: APICall, company_id, req_model: CreateModelRequest):
|
||||
if req_model.public:
|
||||
company_id = ""
|
||||
|
||||
with translate_errors_context():
|
||||
project = req_model.project
|
||||
if project:
|
||||
validate_id(Project, company=company_id, project=project)
|
||||
|
||||
project = req_model.project
|
||||
if project:
|
||||
validate_id(Project, company=company_id, project=project)
|
||||
task = req_model.task
|
||||
req_data = req_model.to_struct()
|
||||
if task:
|
||||
validate_task(company_id, req_data)
|
||||
|
||||
task = req_model.task
|
||||
req_data = req_model.to_struct()
|
||||
if task:
|
||||
validate_task(company_id, req_data)
|
||||
fields = filter_fields(Model, req_data)
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
escape_metadata(fields)
|
||||
|
||||
fields = filter_fields(Model, req_data)
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
# create and save model
|
||||
now = datetime.utcnow()
|
||||
model = Model(
|
||||
id=database.utils.id(),
|
||||
user=call.identity.user,
|
||||
company=company_id,
|
||||
created=now,
|
||||
last_update=now,
|
||||
**fields,
|
||||
)
|
||||
model.save()
|
||||
_update_cached_tags(company_id, project=model.project, fields=fields)
|
||||
|
||||
validate_metadata(fields.get("metadata"))
|
||||
|
||||
# create and save model
|
||||
now = datetime.utcnow()
|
||||
model = Model(
|
||||
id=database.utils.id(),
|
||||
user=call.identity.user,
|
||||
company=company_id,
|
||||
created=now,
|
||||
last_update=now,
|
||||
**fields,
|
||||
)
|
||||
model.save()
|
||||
_update_cached_tags(company_id, project=model.project, fields=fields)
|
||||
|
||||
call.result.data_model = CreateModelResponse(id=model.id, created=True)
|
||||
call.result.data_model = CreateModelResponse(id=model.id, created=True)
|
||||
|
||||
|
||||
def prepare_update_fields(call, company_id, fields: dict):
|
||||
@@ -377,6 +387,7 @@ def prepare_update_fields(call, company_id, fields: dict):
|
||||
)
|
||||
|
||||
conform_tag_fields(call, fields, validate=True)
|
||||
escape_metadata(fields)
|
||||
return fields
|
||||
|
||||
|
||||
@@ -388,89 +399,85 @@ def validate_task(company_id, fields: dict):
|
||||
def edit(call: APICall, company_id, _):
|
||||
model_id = call.data["model"]
|
||||
|
||||
with translate_errors_context():
|
||||
model = ModelBLL.get_company_model_by_id(
|
||||
company_id=company_id, model_id=model_id
|
||||
model = ModelBLL.get_company_model_by_id(
|
||||
company_id=company_id, model_id=model_id
|
||||
)
|
||||
|
||||
fields = parse_model_fields(call, create_fields)
|
||||
fields = prepare_update_fields(call, company_id, fields)
|
||||
|
||||
for key in fields:
|
||||
field = getattr(model, key, None)
|
||||
value = fields[key]
|
||||
if (
|
||||
field
|
||||
and isinstance(value, dict)
|
||||
and isinstance(field, EmbeddedDocument)
|
||||
):
|
||||
d = field.to_mongo(use_db_field=False).to_dict()
|
||||
d.update(value)
|
||||
fields[key] = d
|
||||
|
||||
iteration = call.data.get("iteration")
|
||||
task_id = model.task or fields.get("task")
|
||||
if task_id and iteration is not None:
|
||||
TaskBLL.update_statistics(
|
||||
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
|
||||
)
|
||||
|
||||
fields = parse_model_fields(call, create_fields)
|
||||
fields = prepare_update_fields(call, company_id, fields)
|
||||
if fields:
|
||||
if any(uf in fields for uf in last_update_fields):
|
||||
fields.update(last_update=datetime.utcnow())
|
||||
|
||||
for key in fields:
|
||||
field = getattr(model, key, None)
|
||||
value = fields[key]
|
||||
if (
|
||||
field
|
||||
and isinstance(value, dict)
|
||||
and isinstance(field, EmbeddedDocument)
|
||||
):
|
||||
d = field.to_mongo(use_db_field=False).to_dict()
|
||||
d.update(value)
|
||||
fields[key] = d
|
||||
|
||||
iteration = call.data.get("iteration")
|
||||
task_id = model.task or fields.get("task")
|
||||
if task_id and iteration is not None:
|
||||
TaskBLL.update_statistics(
|
||||
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
|
||||
)
|
||||
|
||||
if fields:
|
||||
if any(uf in fields for uf in last_update_fields):
|
||||
fields.update(last_update=datetime.utcnow())
|
||||
|
||||
updated = model.update(upsert=False, **fields)
|
||||
if updated:
|
||||
new_project = fields.get("project", model.project)
|
||||
if new_project != model.project:
|
||||
_reset_cached_tags(
|
||||
company_id, projects=[new_project, model.project]
|
||||
)
|
||||
else:
|
||||
_update_cached_tags(
|
||||
company_id, project=model.project, fields=fields
|
||||
)
|
||||
conform_output_tags(call, fields)
|
||||
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
||||
else:
|
||||
call.result.data_model = UpdateResponse(updated=0)
|
||||
updated = model.update(upsert=False, **fields)
|
||||
if updated:
|
||||
new_project = fields.get("project", model.project)
|
||||
if new_project != model.project:
|
||||
_reset_cached_tags(
|
||||
company_id, projects=[new_project, model.project]
|
||||
)
|
||||
else:
|
||||
_update_cached_tags(
|
||||
company_id, project=model.project, fields=fields
|
||||
)
|
||||
conform_output_tags(call, fields)
|
||||
unescape_metadata(call, fields)
|
||||
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
||||
else:
|
||||
call.result.data_model = UpdateResponse(updated=0)
|
||||
|
||||
|
||||
def _update_model(call: APICall, company_id, model_id=None):
|
||||
model_id = model_id or call.data["model"]
|
||||
|
||||
with translate_errors_context():
|
||||
model = ModelBLL.get_company_model_by_id(
|
||||
company_id=company_id, model_id=model_id
|
||||
model = ModelBLL.get_company_model_by_id(
|
||||
company_id=company_id, model_id=model_id
|
||||
)
|
||||
|
||||
data = prepare_update_fields(call, company_id, call.data)
|
||||
|
||||
task_id = data.get("task")
|
||||
iteration = data.get("iteration")
|
||||
if task_id and iteration is not None:
|
||||
TaskBLL.update_statistics(
|
||||
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
|
||||
)
|
||||
|
||||
data = prepare_update_fields(call, company_id, call.data)
|
||||
updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
|
||||
if updated_count:
|
||||
if any(uf in updated_fields for uf in last_update_fields):
|
||||
model.update(upsert=False, last_update=datetime.utcnow())
|
||||
|
||||
task_id = data.get("task")
|
||||
iteration = data.get("iteration")
|
||||
if task_id and iteration is not None:
|
||||
TaskBLL.update_statistics(
|
||||
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
|
||||
new_project = updated_fields.get("project", model.project)
|
||||
if new_project != model.project:
|
||||
_reset_cached_tags(company_id, projects=[new_project, model.project])
|
||||
else:
|
||||
_update_cached_tags(
|
||||
company_id, project=model.project, fields=updated_fields
|
||||
)
|
||||
|
||||
metadata = data.get("metadata")
|
||||
if metadata:
|
||||
validate_metadata(metadata)
|
||||
|
||||
updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
|
||||
if updated_count:
|
||||
if any(uf in updated_fields for uf in last_update_fields):
|
||||
model.update(upsert=False, last_update=datetime.utcnow())
|
||||
|
||||
new_project = updated_fields.get("project", model.project)
|
||||
if new_project != model.project:
|
||||
_reset_cached_tags(company_id, projects=[new_project, model.project])
|
||||
else:
|
||||
_update_cached_tags(
|
||||
company_id, project=model.project, fields=updated_fields
|
||||
)
|
||||
conform_output_tags(call, updated_fields)
|
||||
return UpdateResponse(updated=updated_count, fields=updated_fields)
|
||||
conform_output_tags(call, updated_fields)
|
||||
unescape_metadata(call, updated_fields)
|
||||
return UpdateResponse(updated=updated_count, fields=updated_fields)
|
||||
|
||||
|
||||
@endpoint(
|
||||
@@ -635,26 +642,25 @@ def add_or_update_metadata(
|
||||
_: APICall, company_id: str, request: AddOrUpdateMetadataRequest
|
||||
):
|
||||
model_id = request.model
|
||||
ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
|
||||
|
||||
updated = metadata_add_or_update(
|
||||
cls=Model, _id=model_id, items=get_metadata_from_api(request.metadata),
|
||||
)
|
||||
if updated:
|
||||
Model.objects(id=model_id).update_one(last_update=datetime.utcnow())
|
||||
|
||||
return {"updated": updated}
|
||||
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
|
||||
return {
|
||||
"updated": Metadata.edit_metadata(
|
||||
model,
|
||||
items=request.metadata,
|
||||
replace_metadata=request.replace_metadata,
|
||||
last_update=datetime.utcnow(),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint("models.delete_metadata", min_version="2.13")
|
||||
def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
|
||||
model_id = request.model
|
||||
ModelBLL.get_company_model_by_id(
|
||||
model = ModelBLL.get_company_model_by_id(
|
||||
company_id=company_id, model_id=model_id, only_fields=("id",)
|
||||
)
|
||||
|
||||
updated = metadata_delete(cls=Model, _id=model_id, keys=request.keys)
|
||||
if updated:
|
||||
Model.objects(id=model_id).update_one(last_update=datetime.utcnow())
|
||||
|
||||
return {"updated": updated}
|
||||
return {
|
||||
"updated": Metadata.delete_metadata(
|
||||
model, keys=request.keys, last_update=datetime.utcnow()
|
||||
)
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ from apiserver.apimodels.organization import TagsRequest
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.database.model import User
|
||||
from apiserver.service_repo import endpoint, APICall
|
||||
from apiserver.services.utils import get_tags_filter_dictionary, get_tags_response
|
||||
from apiserver.services.utils import get_tags_filter_dictionary, sort_tags_response
|
||||
|
||||
org_bll = OrgBLL()
|
||||
|
||||
@@ -21,17 +21,13 @@ def get_tags(call: APICall, company, request: TagsRequest):
|
||||
for field, vals in tags.items():
|
||||
ret[field] |= vals
|
||||
|
||||
call.result.data = get_tags_response(ret)
|
||||
call.result.data = sort_tags_response(ret)
|
||||
|
||||
|
||||
@endpoint("organization.get_user_companies")
|
||||
def get_user_companies(call: APICall, company_id: str, _):
|
||||
users = [
|
||||
{
|
||||
"id": u.id,
|
||||
"name": u.name,
|
||||
"avatar": u.avatar,
|
||||
}
|
||||
{"id": u.id, "name": u.name, "avatar": u.avatar}
|
||||
for u in User.objects(company=company_id).only("avatar", "name", "company")
|
||||
]
|
||||
|
||||
|
||||
68
apiserver/services/pipelines.py
Normal file
68
apiserver/services/pipelines.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import re
|
||||
|
||||
from apiserver.apimodels.pipelines import StartPipelineResponse, StartPipelineRequest
|
||||
from apiserver.bll.organization import OrgBLL
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.bll.task.task_operations import enqueue_task
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
|
||||
org_bll = OrgBLL()
|
||||
project_bll = ProjectBLL()
|
||||
task_bll = TaskBLL()
|
||||
|
||||
|
||||
def _update_task_name(task: Task):
|
||||
if not task or not task.project:
|
||||
return
|
||||
|
||||
project = Project.objects(id=task.project).only("name").first()
|
||||
if not project:
|
||||
return
|
||||
|
||||
_, _, name_prefix = project.name.rpartition("/")
|
||||
name_mask = re.compile(rf"{re.escape(name_prefix)}( #\d+)?$")
|
||||
count = Task.objects(
|
||||
project=task.project, system_tags__in=["pipeline"], name=name_mask
|
||||
).count()
|
||||
new_name = f"{name_prefix} #{count}" if count > 0 else name_prefix
|
||||
task.update(name=new_name)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"pipelines.start_pipeline", response_data_model=StartPipelineResponse,
|
||||
)
|
||||
def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest):
|
||||
hyperparams = None
|
||||
if request.args:
|
||||
hyperparams = {
|
||||
"Args": {
|
||||
str(arg.name): {
|
||||
"section": "Args",
|
||||
"name": str(arg.name),
|
||||
"value": str(arg.value),
|
||||
}
|
||||
for arg in request.args or []
|
||||
}
|
||||
}
|
||||
|
||||
task, _ = task_bll.clone_task(
|
||||
company_id=company_id,
|
||||
user_id=call.identity.user,
|
||||
task_id=request.task,
|
||||
hyperparams=hyperparams,
|
||||
)
|
||||
|
||||
_update_task_name(task)
|
||||
|
||||
queued, res = enqueue_task(
|
||||
task_id=task.id,
|
||||
company_id=company_id,
|
||||
queue_id=request.queue,
|
||||
status_message="Starting pipeline",
|
||||
status_reason="",
|
||||
)
|
||||
|
||||
return StartPipelineResponse(pipeline=task.id, enqueued=bool(queued))
|
||||
@@ -7,7 +7,7 @@ from apiserver.apierrors import errors
|
||||
from apiserver.apierrors.errors.bad_request import InvalidProjectId
|
||||
from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse
|
||||
from apiserver.apimodels.projects import (
|
||||
GetHyperParamRequest,
|
||||
GetParamsRequest,
|
||||
ProjectTagsRequest,
|
||||
ProjectTaskParentsRequest,
|
||||
ProjectHyperparamValuesRequest,
|
||||
@@ -17,14 +17,14 @@ from apiserver.apimodels.projects import (
|
||||
MergeRequest,
|
||||
ProjectOrNoneRequest,
|
||||
ProjectRequest,
|
||||
ProjectModelMetadataValuesRequest,
|
||||
)
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.bll.project import ProjectBLL, ProjectQueries
|
||||
from apiserver.bll.project.project_cleanup import (
|
||||
delete_project,
|
||||
validate_project_delete,
|
||||
)
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.utils import (
|
||||
@@ -36,13 +36,13 @@ from apiserver.services.utils import (
|
||||
conform_tag_fields,
|
||||
conform_output_tags,
|
||||
get_tags_filter_dictionary,
|
||||
get_tags_response,
|
||||
sort_tags_response,
|
||||
)
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
org_bll = OrgBLL()
|
||||
task_bll = TaskBLL()
|
||||
project_bll = ProjectBLL()
|
||||
project_queries = ProjectQueries()
|
||||
|
||||
create_fields = {
|
||||
"name": None,
|
||||
@@ -111,8 +111,12 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
|
||||
_adjust_search_parameters(data, shallow_search=request.shallow_search)
|
||||
|
||||
ret_params = {}
|
||||
projects = Project.get_many_with_join(
|
||||
company=company_id, query_dict=data, allow_public=allow_public,
|
||||
company=company_id,
|
||||
query_dict=data,
|
||||
allow_public=allow_public,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
|
||||
if request.check_own_contents and requested_ids:
|
||||
@@ -121,14 +125,16 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
}
|
||||
if existing_requested_ids:
|
||||
contents = project_bll.calc_own_contents(
|
||||
company=company_id, project_ids=list(existing_requested_ids)
|
||||
company=company_id,
|
||||
project_ids=list(existing_requested_ids),
|
||||
filter_=request.include_stats_filter,
|
||||
)
|
||||
for project in projects:
|
||||
project.update(**contents.get(project["id"], {}))
|
||||
|
||||
conform_output_tags(call, projects)
|
||||
if not request.include_stats:
|
||||
call.result.data = {"projects": projects}
|
||||
call.result.data = {"projects": projects, **ret_params}
|
||||
return
|
||||
|
||||
project_ids = {project["id"] for project in projects}
|
||||
@@ -136,13 +142,15 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
company=company_id,
|
||||
project_ids=list(project_ids),
|
||||
specific_state=request.stats_for_state,
|
||||
include_children=request.stats_with_children,
|
||||
filter_=request.include_stats_filter,
|
||||
)
|
||||
|
||||
for project in projects:
|
||||
project["stats"] = stats[project["id"]]
|
||||
project["sub_projects"] = children[project["id"]]
|
||||
|
||||
call.result.data = {"projects": projects}
|
||||
call.result.data = {"projects": projects, **ret_params}
|
||||
|
||||
|
||||
@endpoint("projects.get_all")
|
||||
@@ -151,15 +159,17 @@ def get_all(call: APICall):
|
||||
data = call.data
|
||||
_adjust_search_parameters(data, shallow_search=data.get("shallow_search", False))
|
||||
with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
|
||||
ret_params = {}
|
||||
projects = Project.get_many(
|
||||
company=call.identity.company,
|
||||
query_dict=data,
|
||||
parameters=data,
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, projects)
|
||||
|
||||
call.result.data = {"projects": projects}
|
||||
call.result.data = {"projects": projects, **ret_params}
|
||||
|
||||
|
||||
@endpoint(
|
||||
@@ -260,7 +270,7 @@ def get_unique_metric_variants(
|
||||
call: APICall, company_id: str, request: ProjectOrNoneRequest
|
||||
):
|
||||
|
||||
metrics = task_bll.get_unique_metric_variants(
|
||||
metrics = project_queries.get_unique_metric_variants(
|
||||
company_id,
|
||||
[request.project] if request.project else None,
|
||||
include_subprojects=request.include_subprojects,
|
||||
@@ -269,14 +279,48 @@ def get_unique_metric_variants(
|
||||
call.result.data = {"metrics": metrics}
|
||||
|
||||
|
||||
@endpoint("projects.get_model_metadata_keys",)
|
||||
def get_model_metadata_keys(call: APICall, company_id: str, request: GetParamsRequest):
|
||||
total, remaining, keys = project_queries.get_model_metadata_keys(
|
||||
company_id,
|
||||
project_ids=[request.project] if request.project else None,
|
||||
include_subprojects=request.include_subprojects,
|
||||
page=request.page,
|
||||
page_size=request.page_size,
|
||||
)
|
||||
|
||||
call.result.data = {
|
||||
"total": total,
|
||||
"remaining": remaining,
|
||||
"keys": keys,
|
||||
}
|
||||
|
||||
|
||||
@endpoint("projects.get_model_metadata_values")
|
||||
def get_model_metadata_values(
|
||||
call: APICall, company_id: str, request: ProjectModelMetadataValuesRequest
|
||||
):
|
||||
total, values = project_queries.get_model_metadata_distinct_values(
|
||||
company_id,
|
||||
project_ids=request.projects,
|
||||
key=request.key,
|
||||
include_subprojects=request.include_subprojects,
|
||||
allow_public=request.allow_public,
|
||||
)
|
||||
call.result.data = {
|
||||
"total": total,
|
||||
"values": values,
|
||||
}
|
||||
|
||||
|
||||
@endpoint(
|
||||
"projects.get_hyper_parameters",
|
||||
min_version="2.9",
|
||||
request_data_model=GetHyperParamRequest,
|
||||
request_data_model=GetParamsRequest,
|
||||
)
|
||||
def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamRequest):
|
||||
def get_hyper_parameters(call: APICall, company_id: str, request: GetParamsRequest):
|
||||
|
||||
total, remaining, parameters = TaskBLL.get_aggregated_project_parameters(
|
||||
total, remaining, parameters = project_queries.get_aggregated_project_parameters(
|
||||
company_id,
|
||||
project_ids=[request.project] if request.project else None,
|
||||
include_subprojects=request.include_subprojects,
|
||||
@@ -299,7 +343,7 @@ def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamR
|
||||
def get_hyperparam_values(
|
||||
call: APICall, company_id: str, request: ProjectHyperparamValuesRequest
|
||||
):
|
||||
total, values = task_bll.get_hyperparam_distinct_values(
|
||||
total, values = project_queries.get_task_hyperparam_distinct_values(
|
||||
company_id,
|
||||
project_ids=request.projects,
|
||||
section=request.section,
|
||||
@@ -313,6 +357,17 @@ def get_hyperparam_values(
|
||||
}
|
||||
|
||||
|
||||
@endpoint("projects.get_project_tags")
|
||||
def get_tags(call: APICall, company, request: ProjectTagsRequest):
|
||||
tags, system_tags = project_bll.get_project_tags(
|
||||
company,
|
||||
include_system=request.include_system,
|
||||
filter_=get_tags_filter_dictionary(request.filter),
|
||||
projects=request.projects,
|
||||
)
|
||||
call.result.data = sort_tags_response({"tags": tags, "system_tags": system_tags})
|
||||
|
||||
|
||||
@endpoint(
|
||||
"projects.get_task_tags", min_version="2.8", request_data_model=ProjectTagsRequest
|
||||
)
|
||||
@@ -324,7 +379,7 @@ def get_tags(call: APICall, company, request: ProjectTagsRequest):
|
||||
filter_=get_tags_filter_dictionary(request.filter),
|
||||
projects=request.projects,
|
||||
)
|
||||
call.result.data = get_tags_response(ret)
|
||||
call.result.data = sort_tags_response(ret)
|
||||
|
||||
|
||||
@endpoint(
|
||||
@@ -338,7 +393,7 @@ def get_tags(call: APICall, company, request: ProjectTagsRequest):
|
||||
filter_=get_tags_filter_dictionary(request.filter),
|
||||
projects=request.projects,
|
||||
)
|
||||
call.result.data = get_tags_response(ret)
|
||||
call.result.data = sort_tags_response(ret)
|
||||
|
||||
|
||||
@endpoint(
|
||||
|
||||
@@ -13,17 +13,19 @@ from apiserver.apimodels.queues import (
|
||||
QueueMetrics,
|
||||
AddOrUpdateMetadataRequest,
|
||||
DeleteMetadataRequest,
|
||||
GetNextTaskRequest,
|
||||
)
|
||||
from apiserver.bll.model import Metadata
|
||||
from apiserver.bll.queue import QueueBLL
|
||||
from apiserver.bll.workers import WorkerBLL
|
||||
from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete
|
||||
from apiserver.database.model.queue import Queue
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.services.utils import (
|
||||
conform_tag_fields,
|
||||
conform_output_tags,
|
||||
conform_tags,
|
||||
get_metadata_from_api,
|
||||
escape_metadata,
|
||||
unescape_metadata,
|
||||
)
|
||||
from apiserver.utilities import extract_properties_to_lists
|
||||
|
||||
@@ -36,6 +38,7 @@ def get_by_id(call: APICall, company_id, req_model: QueueRequest):
|
||||
queue = queue_bll.get_by_id(company_id, req_model.queue)
|
||||
queue_dict = queue.to_proper_dict()
|
||||
conform_output_tags(call, queue_dict)
|
||||
unescape_metadata(call, queue_dict)
|
||||
call.result.data = {"queue": queue_dict}
|
||||
|
||||
|
||||
@@ -48,21 +51,28 @@ def get_by_id(call: APICall):
|
||||
@endpoint("queues.get_all_ex", min_version="2.4")
|
||||
def get_all_ex(call: APICall):
|
||||
conform_tag_fields(call, call.data)
|
||||
ret_params = {}
|
||||
|
||||
Metadata.escape_query_parameters(call)
|
||||
queues = queue_bll.get_queue_infos(
|
||||
company_id=call.identity.company, query_dict=call.data
|
||||
company_id=call.identity.company, query_dict=call.data, ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, queues)
|
||||
|
||||
call.result.data = {"queues": queues}
|
||||
unescape_metadata(call, queues)
|
||||
call.result.data = {"queues": queues, **ret_params}
|
||||
|
||||
|
||||
@endpoint("queues.get_all", min_version="2.4")
|
||||
def get_all(call: APICall):
|
||||
conform_tag_fields(call, call.data)
|
||||
queues = queue_bll.get_all(company_id=call.identity.company, query_dict=call.data)
|
||||
ret_params = {}
|
||||
Metadata.escape_query_parameters(call)
|
||||
queues = queue_bll.get_all(
|
||||
company_id=call.identity.company, query_dict=call.data, ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, queues)
|
||||
|
||||
call.result.data = {"queues": queues}
|
||||
unescape_metadata(call, queues)
|
||||
call.result.data = {"queues": queues, **ret_params}
|
||||
|
||||
|
||||
@endpoint("queues.create", min_version="2.4", request_data_model=CreateRequest)
|
||||
@@ -75,7 +85,7 @@ def create(call: APICall, company_id, request: CreateRequest):
|
||||
name=request.name,
|
||||
tags=tags,
|
||||
system_tags=system_tags,
|
||||
metadata=get_metadata_from_api(request.metadata),
|
||||
metadata=Metadata.metadata_from_api(request.metadata),
|
||||
)
|
||||
call.result.data = {"id": queue.id}
|
||||
|
||||
@@ -89,10 +99,12 @@ def create(call: APICall, company_id, request: CreateRequest):
|
||||
def update(call: APICall, company_id, req_model: UpdateRequest):
|
||||
data = call.data_model_for_partial_update
|
||||
conform_tag_fields(call, data, validate=True)
|
||||
escape_metadata(data)
|
||||
updated, fields = queue_bll.update(
|
||||
company_id=company_id, queue_id=req_model.queue, **data
|
||||
)
|
||||
conform_output_tags(call, fields)
|
||||
unescape_metadata(call, fields)
|
||||
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
||||
|
||||
|
||||
@@ -113,11 +125,19 @@ def add_task(call: APICall, company_id, req_model: TaskRequest):
|
||||
}
|
||||
|
||||
|
||||
@endpoint("queues.get_next_task", min_version="2.4", request_data_model=QueueRequest)
|
||||
def get_next_task(call: APICall, company_id, req_model: QueueRequest):
|
||||
task = queue_bll.get_next_task(company_id=company_id, queue_id=req_model.queue)
|
||||
if task:
|
||||
call.result.data = {"entry": task.to_proper_dict()}
|
||||
@endpoint("queues.get_next_task", request_data_model=GetNextTaskRequest)
|
||||
def get_next_task(call: APICall, company_id, req_model: GetNextTaskRequest):
|
||||
entry = queue_bll.get_next_task(
|
||||
company_id=company_id, queue_id=req_model.queue
|
||||
)
|
||||
if entry:
|
||||
data = {"entry": entry.to_proper_dict()}
|
||||
if req_model.get_task_info:
|
||||
task = Task.objects(id=entry.task).first()
|
||||
if task:
|
||||
data["task_info"] = {"company": task.company, "user": task.user}
|
||||
|
||||
call.result.data = data
|
||||
|
||||
|
||||
@endpoint("queues.remove_task", min_version="2.4", request_data_model=TaskRequest)
|
||||
@@ -237,21 +257,19 @@ def get_queue_metrics(
|
||||
|
||||
@endpoint("queues.add_or_update_metadata", min_version="2.13")
|
||||
def add_or_update_metadata(
|
||||
_: APICall, company_id: str, request: AddOrUpdateMetadataRequest
|
||||
call: APICall, company_id: str, request: AddOrUpdateMetadataRequest
|
||||
):
|
||||
queue_id = request.queue
|
||||
queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
|
||||
|
||||
queue = queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
|
||||
return {
|
||||
"updated": metadata_add_or_update(
|
||||
cls=Queue, _id=queue_id, items=get_metadata_from_api(request.metadata),
|
||||
"updated": Metadata.edit_metadata(
|
||||
queue, items=request.metadata, replace_metadata=request.replace_metadata
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint("queues.delete_metadata", min_version="2.13")
|
||||
def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
|
||||
def delete_metadata(call: APICall, company_id: str, request: DeleteMetadataRequest):
|
||||
queue_id = request.queue
|
||||
queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
|
||||
|
||||
return {"updated": metadata_delete(cls=Queue, _id=queue_id, keys=request.keys)}
|
||||
queue = queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
|
||||
return {"updated": Metadata.delete_metadata(queue, keys=request.keys)}
|
||||
|
||||
@@ -221,11 +221,15 @@ def get_all_ex(call: APICall, company_id, _):
|
||||
|
||||
with TimingContext("mongo", "task_get_all_ex"):
|
||||
_process_include_subprojects(call_data)
|
||||
ret_params = {}
|
||||
tasks = Task.get_many_with_join(
|
||||
company=company_id, query_dict=call_data, allow_public=True,
|
||||
company=company_id,
|
||||
query_dict=call_data,
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
unprepare_from_saved(call, tasks)
|
||||
call.result.data = {"tasks": tasks}
|
||||
call.result.data = {"tasks": tasks, **ret_params}
|
||||
|
||||
|
||||
@endpoint("tasks.get_by_id_ex", required_fields=["id"])
|
||||
@@ -250,14 +254,16 @@ def get_all(call: APICall, company_id, _):
|
||||
call_data = escape_execution_parameters(call)
|
||||
|
||||
with TimingContext("mongo", "task_get_all"):
|
||||
ret_params = {}
|
||||
tasks = Task.get_many(
|
||||
company=company_id,
|
||||
parameters=call_data,
|
||||
query_dict=call_data,
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
unprepare_from_saved(call, tasks)
|
||||
call.result.data = {"tasks": tasks}
|
||||
call.result.data = {"tasks": tasks, **ret_params}
|
||||
|
||||
|
||||
@endpoint("tasks.get_types", request_data_model=GetTypesRequest)
|
||||
@@ -1050,6 +1056,8 @@ def delete(call: APICall, company_id, request: DeleteRequest):
|
||||
force=request.force,
|
||||
return_file_urls=request.return_file_urls,
|
||||
delete_output_models=request.delete_output_models,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
)
|
||||
if deleted:
|
||||
_reset_cached_tags(company_id, projects=[task.project] if task.project else [])
|
||||
@@ -1066,6 +1074,8 @@ def delete_many(call: APICall, company_id, request: DeleteManyRequest):
|
||||
force=request.force,
|
||||
return_file_urls=request.return_file_urls,
|
||||
delete_output_models=request.delete_output_models,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
),
|
||||
ids=request.ids,
|
||||
)
|
||||
@@ -1163,7 +1173,7 @@ def add_or_update_artifacts(
|
||||
company_id=company_id,
|
||||
task_id=request.task,
|
||||
artifacts=request.artifacts,
|
||||
force=request.force,
|
||||
force=True,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1179,7 +1189,7 @@ def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest)
|
||||
company_id=company_id,
|
||||
task_id=request.task,
|
||||
artifact_ids=request.artifacts,
|
||||
force=request.force,
|
||||
force=True,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,6 @@ from datetime import datetime
|
||||
from typing import Union, Sequence, Tuple
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.metadata import MetadataItem as ApiMetadataItem
|
||||
from apiserver.apimodels.organization import Filter
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.database.model.task.task import TaskModelTypes, TaskModelNames
|
||||
@@ -24,7 +23,7 @@ def get_tags_filter_dictionary(input_: Filter) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def get_tags_response(ret: dict) -> dict:
|
||||
def sort_tags_response(ret: dict) -> dict:
|
||||
return {field: sorted(vals) for field, vals in ret.items()}
|
||||
|
||||
|
||||
@@ -222,22 +221,38 @@ class DockerCmdBackwardsCompatibility:
|
||||
nested_set(task, cls.field, docker_cmd)
|
||||
|
||||
|
||||
def validate_metadata(metadata: Sequence[dict]):
|
||||
def escape_metadata(document: dict):
|
||||
"""
|
||||
Escape special characters in metadata keys
|
||||
"""
|
||||
metadata = document.get("metadata")
|
||||
if not metadata:
|
||||
return
|
||||
|
||||
keys = [m.get("key") for m in metadata]
|
||||
unique_keys = set(keys)
|
||||
unique_keys.discard(None)
|
||||
if len(keys) != len(set(keys)):
|
||||
raise errors.bad_request.ValidationError("Metadata keys should be unique")
|
||||
document["metadata"] = {
|
||||
ParameterKeyEscaper.escape(k): v
|
||||
for k, v in metadata.items()
|
||||
}
|
||||
|
||||
|
||||
def get_metadata_from_api(api_metadata: Sequence[ApiMetadataItem]) -> Sequence:
|
||||
if not api_metadata:
|
||||
return api_metadata
|
||||
def unescape_metadata(call: APICall, documents: Union[dict, Sequence[dict]]):
|
||||
"""
|
||||
Unescape special characters in metadata keys
|
||||
"""
|
||||
if isinstance(documents, dict):
|
||||
documents = [documents]
|
||||
|
||||
metadata = [m.to_struct() for m in api_metadata]
|
||||
validate_metadata(metadata)
|
||||
old_client = call.requested_endpoint_version <= PartialVersion("2.16")
|
||||
for doc in documents:
|
||||
if old_client and "metadata" in doc:
|
||||
doc["metadata"] = []
|
||||
continue
|
||||
|
||||
return metadata
|
||||
metadata = doc.get("metadata")
|
||||
if not metadata:
|
||||
continue
|
||||
|
||||
doc["metadata"] = {
|
||||
ParameterKeyEscaper.unescape(k): v
|
||||
for k, v in metadata.items()
|
||||
}
|
||||
|
||||
@@ -71,7 +71,7 @@ class TestService(TestCase, TestServiceInterface):
|
||||
delete_params=delete_params,
|
||||
)
|
||||
|
||||
def setUp(self, version="1.7"):
|
||||
def setUp(self, version="999.0"):
|
||||
self._api = APIClient(base_url=f"http://localhost:8008/v{version}")
|
||||
self._deferred = []
|
||||
self._version = parse(version)
|
||||
|
||||
@@ -38,7 +38,7 @@ class TestEntityOrdering(TestService):
|
||||
self._assertGetTasksWithOrdering(order_by=order_field, page=0, page_size=20)
|
||||
|
||||
field_vals = []
|
||||
page_size = 2
|
||||
page_size = 4
|
||||
num_pages = 5
|
||||
for page in range(num_pages):
|
||||
paged_tasks = self._get_page_tasks(
|
||||
|
||||
@@ -2,9 +2,6 @@ from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestOrganization(TestService):
|
||||
def setUp(self, version="2.12"):
|
||||
super().setUp(version=version)
|
||||
|
||||
def test_get_user_companies(self):
|
||||
company = self.api.organization.get_user_companies().companies[0]
|
||||
self.assertEqual(len(company.owners), company.allocated)
|
||||
|
||||
80
apiserver/tests/automated/test_paging_and_scrolling.py
Normal file
80
apiserver/tests/automated/test_paging_and_scrolling.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import math
|
||||
from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestPagingAndScrolling(TestService):
|
||||
name_prefix = f"Test paging "
|
||||
|
||||
def setUp(self, **kwargs):
|
||||
super().setUp(**kwargs)
|
||||
self.task_ids = self._create_tasks()
|
||||
|
||||
def _create_tasks(self):
|
||||
tasks = [
|
||||
self._temp_task(
|
||||
name=f"{self.name_prefix}{i}",
|
||||
hyperparams={
|
||||
"test": {
|
||||
"param": {
|
||||
"section": "test",
|
||||
"name": "param",
|
||||
"type": "str",
|
||||
"value": str(i),
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
for i in range(18)
|
||||
]
|
||||
return tasks
|
||||
|
||||
def test_paging(self):
|
||||
page_size = 10
|
||||
for page in range(0, math.ceil(len(self.task_ids) / page_size)):
|
||||
start = page * page_size
|
||||
expected_size = min(page_size, len(self.task_ids) - start)
|
||||
tasks = self._get_tasks(page=page, page_size=page_size,).tasks
|
||||
self.assertEqual(len(tasks), expected_size)
|
||||
for i, t in enumerate(tasks):
|
||||
self.assertEqual(t.name, f"{self.name_prefix}{start + i}")
|
||||
|
||||
def test_scrolling(self):
|
||||
page_size = 10
|
||||
scroll_id = None
|
||||
for page in range(0, math.ceil(len(self.task_ids) / page_size)):
|
||||
start = page * page_size
|
||||
expected_size = min(page_size, len(self.task_ids) - start)
|
||||
res = self._get_tasks(size=page_size, scroll_id=scroll_id,)
|
||||
self.assertTrue(res.scroll_id)
|
||||
scroll_id = res.scroll_id
|
||||
tasks = res.tasks
|
||||
self.assertEqual(len(tasks), expected_size)
|
||||
for i, t in enumerate(tasks):
|
||||
self.assertEqual(t.name, f"{self.name_prefix}{start + i}")
|
||||
|
||||
# no more data in this scroll
|
||||
tasks = self._get_tasks(size=page_size, scroll_id=scroll_id,).tasks
|
||||
self.assertFalse(tasks)
|
||||
|
||||
# refresh brings all
|
||||
tasks = self._get_tasks(
|
||||
size=page_size, scroll_id=scroll_id, refresh_scroll=True,
|
||||
).tasks
|
||||
self.assertEqual([t.id for t in tasks], self.task_ids)
|
||||
|
||||
def _get_tasks(self, **page_params):
|
||||
return self.api.tasks.get_all_ex(
|
||||
name="^Test paging ",
|
||||
order_by=["hyperparams.test.param.value"],
|
||||
**page_params,
|
||||
)
|
||||
|
||||
def _temp_task(self, name, **kwargs):
|
||||
return self.create_temp(
|
||||
"tasks",
|
||||
name=name,
|
||||
comment="Test task",
|
||||
type="testing",
|
||||
input=dict(view=dict()),
|
||||
**kwargs,
|
||||
)
|
||||
@@ -4,10 +4,29 @@ from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestProjectTags(TestService):
|
||||
def setUp(self, version="2.12"):
|
||||
super().setUp(version=version)
|
||||
def test_project_own_tags(self):
|
||||
p1_tags = ["Tag 1", "Tag 2"]
|
||||
p1 = self.create_temp(
|
||||
"projects", name="Test project tags1", description="test", tags=p1_tags
|
||||
)
|
||||
p2_tags = ["Tag 1", "Tag 3"]
|
||||
p2 = self.create_temp(
|
||||
"projects",
|
||||
name="Test project tags2",
|
||||
description="test",
|
||||
tags=p2_tags,
|
||||
system_tags=["hidden"],
|
||||
)
|
||||
|
||||
def test_project_tags(self):
|
||||
res = self.api.projects.get_project_tags(projects=[p1, p2])
|
||||
self.assertEqual(set(res.tags), set(p1_tags) | set(p2_tags))
|
||||
|
||||
res = self.api.projects.get_project_tags(
|
||||
projects=[p1, p2], filter={"system_tags": ["__$not", "hidden"]}
|
||||
)
|
||||
self.assertEqual(res.tags, p1_tags)
|
||||
|
||||
def test_project_entities_tags(self):
|
||||
tags_1 = ["Test tag 1", "Test tag 2"]
|
||||
tags_2 = ["Test tag 3", "Test tag 4"]
|
||||
|
||||
|
||||
@@ -1,15 +1,11 @@
|
||||
from functools import partial
|
||||
from typing import Sequence
|
||||
|
||||
from apiserver.tests.api_client import APIClient
|
||||
from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestQueueAndModelMetadata(TestService):
|
||||
def setUp(self, version="2.13"):
|
||||
super().setUp(version=version)
|
||||
|
||||
meta1 = [{"key": "test_key", "type": "str", "value": "test_value"}]
|
||||
meta1 = {"test_key": {"key": "test_key", "type": "str", "value": "test_value"}}
|
||||
|
||||
def test_queue_metas(self):
|
||||
queue_id = self._temp_queue("TestMetadata", metadata=self.meta1)
|
||||
@@ -26,20 +22,51 @@ class TestQueueAndModelMetadata(TestService):
|
||||
)
|
||||
|
||||
model_id = self._temp_model("TestMetadata1")
|
||||
self.api.models.edit(model=model_id, metadata=[self.meta1[0]])
|
||||
self.api.models.edit(model=model_id, metadata=self.meta1)
|
||||
self._assertMeta(service=service, entity=entity, _id=model_id, meta=self.meta1)
|
||||
|
||||
def test_project_meta_query(self):
|
||||
self._temp_model("TestMetadata", metadata=self.meta1)
|
||||
project = self.temp_project(name="MetaParent")
|
||||
test_key = "test_key"
|
||||
test_key2 = "test_key2"
|
||||
test_value = "test_value"
|
||||
test_value2 = "test_value2"
|
||||
model_id = self._temp_model(
|
||||
"TestMetadata2",
|
||||
project=project,
|
||||
metadata={
|
||||
test_key: {"key": test_key, "type": "str", "value": test_value},
|
||||
test_key2: {"key": test_key2, "type": "str", "value": test_value2},
|
||||
},
|
||||
)
|
||||
res = self.api.projects.get_model_metadata_keys()
|
||||
self.assertTrue({test_key, test_key2}.issubset(set(res["keys"])))
|
||||
res = self.api.projects.get_model_metadata_keys(include_subprojects=False)
|
||||
self.assertTrue(test_key in res["keys"])
|
||||
self.assertFalse(test_key2 in res["keys"])
|
||||
|
||||
model = self.api.models.get_all_ex(
|
||||
id=[model_id], only_fields=["metadata.test_key"]
|
||||
).models[0]
|
||||
self.assertTrue(test_key in model.metadata)
|
||||
self.assertFalse(test_key2 in model.metadata)
|
||||
|
||||
res = self.api.projects.get_model_metadata_values(key=test_key)
|
||||
self.assertEqual(res.total, 1)
|
||||
self.assertEqual(res["values"], [test_value])
|
||||
|
||||
def _test_meta_operations(
|
||||
self, service: APIClient.Service, entity: str, _id: str,
|
||||
):
|
||||
assert_meta = partial(self._assertMeta, service=service, entity=entity)
|
||||
assert_meta(_id=_id, meta=self.meta1)
|
||||
|
||||
meta2 = [
|
||||
{"key": "test1", "type": "str", "value": "data1"},
|
||||
{"key": "test2", "type": "str", "value": "data2"},
|
||||
{"key": "test3", "type": "str", "value": "data3"},
|
||||
]
|
||||
meta2 = {
|
||||
"test1": {"key": "test1", "type": "str", "value": "data1"},
|
||||
"test2": {"key": "test2", "type": "str", "value": "data2"},
|
||||
"test3": {"key": "test3", "type": "str", "value": "data3"},
|
||||
}
|
||||
service.update(**{entity: _id, "metadata": meta2})
|
||||
assert_meta(_id=_id, meta=meta2)
|
||||
|
||||
@@ -51,16 +78,17 @@ class TestQueueAndModelMetadata(TestService):
|
||||
]
|
||||
res = service.add_or_update_metadata(**{entity: _id, "metadata": updates})
|
||||
self.assertEqual(res.updated, 1)
|
||||
assert_meta(_id=_id, meta=[meta2[0], *updates])
|
||||
assert_meta(_id=_id, meta={**meta2, **{u["key"]: u for u in updates}})
|
||||
|
||||
res = service.delete_metadata(
|
||||
**{entity: _id, "keys": [f"test{idx}" for idx in range(2, 6)]}
|
||||
)
|
||||
self.assertEqual(res.updated, 1)
|
||||
assert_meta(_id=_id, meta=meta2[:1])
|
||||
# noinspection PyTypeChecker
|
||||
assert_meta(_id=_id, meta=dict(list(meta2.items())[:1]))
|
||||
|
||||
def _assertMeta(
|
||||
self, service: APIClient.Service, entity: str, _id: str, meta: Sequence[dict]
|
||||
self, service: APIClient.Service, entity: str, _id: str, meta: dict
|
||||
):
|
||||
res = service.get_all_ex(id=[_id])[f"{entity}s"][0]
|
||||
self.assertEqual(res.metadata, meta)
|
||||
@@ -72,3 +100,12 @@ class TestQueueAndModelMetadata(TestService):
|
||||
return self.create_temp(
|
||||
"models", uri="file://test", name=name, labels={}, **kwargs
|
||||
)
|
||||
|
||||
def temp_project(self, **kwargs) -> str:
|
||||
self.update_missing(
|
||||
kwargs,
|
||||
name="Test models meta",
|
||||
description="test",
|
||||
delete_params=dict(force=True),
|
||||
)
|
||||
return self.create_temp("projects", **kwargs)
|
||||
|
||||
@@ -12,9 +12,6 @@ from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestSubProjects(TestService):
|
||||
def setUp(self, **kwargs):
|
||||
super().setUp(version="2.13")
|
||||
|
||||
def test_project_aggregations(self):
|
||||
"""This test requires user with user_auth_only... credentials in db"""
|
||||
user2_client = APIClient(
|
||||
@@ -202,7 +199,10 @@ class TestSubProjects(TestService):
|
||||
res1 = next(p for p in res if p.id == project1)
|
||||
self.assertEqual(res1.stats["active"]["status_count"]["created"], 0)
|
||||
self.assertEqual(res1.stats["active"]["status_count"]["stopped"], 2)
|
||||
self.assertEqual(res1.stats["active"]["status_count"]["in_progress"], 0)
|
||||
self.assertEqual(res1.stats["active"]["total_runtime"], 2)
|
||||
self.assertEqual(res1.stats["active"]["completed_tasks_24h"], 2)
|
||||
self.assertEqual(res1.stats["active"]["total_tasks"], 2)
|
||||
self.assertEqual(
|
||||
{sp.name for sp in res1.sub_projects},
|
||||
{
|
||||
@@ -214,7 +214,10 @@ class TestSubProjects(TestService):
|
||||
res2 = next(p for p in res if p.id == project2)
|
||||
self.assertEqual(res2.stats["active"]["status_count"]["created"], 0)
|
||||
self.assertEqual(res2.stats["active"]["status_count"]["stopped"], 0)
|
||||
self.assertEqual(res2.stats["active"]["status_count"]["in_progress"], 0)
|
||||
self.assertEqual(res2.stats["active"]["status_count"]["completed"], 0)
|
||||
self.assertEqual(res2.stats["active"]["total_runtime"], 0)
|
||||
self.assertEqual(res2.stats["active"]["total_tasks"], 0)
|
||||
self.assertEqual(res2.sub_projects, [])
|
||||
|
||||
def _run_tasks(self, *tasks):
|
||||
|
||||
@@ -133,6 +133,32 @@ class TestTags(TestService):
|
||||
).models
|
||||
self.assertFound(model_id, [], models)
|
||||
|
||||
def testQueueTags(self):
|
||||
q_id = self._temp_queue(system_tags=["default"])
|
||||
queues = self.api.queues.get_all_ex(
|
||||
name="Test tags", system_tags=["default"]
|
||||
).queues
|
||||
self.assertFound(q_id, ["default"], queues)
|
||||
|
||||
queues = self.api.queues.get_all_ex(
|
||||
name="Test tags", system_tags=["-default"]
|
||||
).queues
|
||||
self.assertNotFound(q_id, queues)
|
||||
|
||||
self.api.queues.update(queue=q_id, system_tags=[])
|
||||
queues = self.api.queues.get_all_ex(
|
||||
name="Test tags", system_tags=["-default"]
|
||||
).queues
|
||||
self.assertFound(q_id, [], queues)
|
||||
|
||||
# test default queue
|
||||
queues = self.api.queues.get_all(system_tags=["default"]).queues
|
||||
if queues:
|
||||
self.assertEqual(queues[0].id, self.api.queues.get_default().id)
|
||||
else:
|
||||
self.api.queues.update(queue=q_id, system_tags=["default"])
|
||||
self.assertEqual(q_id, self.api.queues.get_default().id)
|
||||
|
||||
def testTaskTags(self):
|
||||
task_id = self._temp_task(
|
||||
name="Test tags", system_tags=["active"]
|
||||
@@ -169,35 +195,11 @@ class TestTags(TestService):
|
||||
task = self.api.tasks.get_by_id(task=task_id).task
|
||||
self.assertEqual(task.status, "stopped")
|
||||
|
||||
def testQueueTags(self):
|
||||
q_id = self._temp_queue(system_tags=["default"])
|
||||
queues = self.api.queues.get_all_ex(
|
||||
name="Test tags", system_tags=["default"]
|
||||
).queues
|
||||
self.assertFound(q_id, ["default"], queues)
|
||||
|
||||
queues = self.api.queues.get_all_ex(
|
||||
name="Test tags", system_tags=["-default"]
|
||||
).queues
|
||||
self.assertNotFound(q_id, queues)
|
||||
|
||||
self.api.queues.update(queue=q_id, system_tags=[])
|
||||
queues = self.api.queues.get_all_ex(
|
||||
name="Test tags", system_tags=["-default"]
|
||||
).queues
|
||||
self.assertFound(q_id, [], queues)
|
||||
|
||||
# test default queue
|
||||
queues = self.api.queues.get_all(system_tags=["default"]).queues
|
||||
if queues:
|
||||
self.assertEqual(queues[0].id, self.api.queues.get_default().id)
|
||||
else:
|
||||
self.api.queues.update(queue=q_id, system_tags=["default"])
|
||||
self.assertEqual(q_id, self.api.queues.get_default().id)
|
||||
|
||||
def assertProjectStats(self, project: AttrDict):
|
||||
self.assertEqual(set(project.stats.keys()), {"active"})
|
||||
self.assertAlmostEqual(project.stats.active.total_runtime, 1, places=0)
|
||||
self.assertEqual(project.stats.active.completed_tasks_24h, 1)
|
||||
self.assertEqual(project.stats.active.total_tasks, 1)
|
||||
for status, count in project.stats.active.status_count.items():
|
||||
self.assertEqual(count, 1 if status == "stopped" else 0)
|
||||
|
||||
|
||||
@@ -82,14 +82,10 @@ class TestTasksArtifacts(TestService):
|
||||
|
||||
# test edit running task
|
||||
self.api.tasks.started(task=task)
|
||||
with self.api.raises(InvalidTaskStatus):
|
||||
self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit)
|
||||
self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit, force=True)
|
||||
self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit)
|
||||
res = self.api.tasks.get_all_ex(id=[task]).tasks[0]
|
||||
self._assertTaskArtifacts(artifacts, res)
|
||||
with self.api.raises(InvalidTaskStatus):
|
||||
self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": artifacts[-1]["key"]}])
|
||||
self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": artifacts[-1]["key"]}], force=True)
|
||||
self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": artifacts[-1]["key"]}])
|
||||
res = self.api.tasks.get_all_ex(id=[task]).tasks[0]
|
||||
self._assertTaskArtifacts(artifacts[0: len(artifacts) - 1], res)
|
||||
|
||||
|
||||
@@ -12,9 +12,6 @@ from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestTaskEvents(TestService):
|
||||
def setUp(self, version="2.9"):
|
||||
super().setUp(version=version)
|
||||
|
||||
def _temp_task(self, name="test task events"):
|
||||
task_input = dict(
|
||||
name=name, type="training", input=dict(mapping={}, view=dict(entries=[])),
|
||||
@@ -257,6 +254,45 @@ class TestTaskEvents(TestService):
|
||||
self.assertEqual(len(task_data["x"]), iterations)
|
||||
self.assertEqual(len(task_data["y"]), iterations)
|
||||
|
||||
def test_task_metric_raw(self):
|
||||
metric = "Metric1"
|
||||
variant = "Variant1"
|
||||
iter_count = 100
|
||||
task = self._temp_task()
|
||||
events = [
|
||||
{
|
||||
**self._create_task_event("training_stats_scalar", task, iteration),
|
||||
"metric": metric,
|
||||
"variant": variant,
|
||||
"value": iteration,
|
||||
}
|
||||
for iteration in range(iter_count)
|
||||
]
|
||||
self.send_batch(events)
|
||||
|
||||
batch_size = 15
|
||||
metric_param = {"metric": metric, "variants": [variant]}
|
||||
res = self.api.events.scalar_metrics_iter_raw(
|
||||
task=task, batch_size=batch_size, metric=metric_param, count_total=True
|
||||
)
|
||||
self.assertEqual(res.total, len(events))
|
||||
self.assertTrue(res.scroll_id)
|
||||
res_iters = []
|
||||
res_ys = []
|
||||
calls = 0
|
||||
while res.returned or calls > 10:
|
||||
calls += 1
|
||||
res_iters.extend(res.variants[variant]["iter"])
|
||||
res_ys.extend(res.variants[variant]["y"])
|
||||
scroll_id = res.scroll_id
|
||||
res = self.api.events.scalar_metrics_iter_raw(
|
||||
task=task, metric=metric_param, scroll_id=scroll_id
|
||||
)
|
||||
|
||||
self.assertEqual(calls, len(events) // batch_size + 1)
|
||||
self.assertEqual(res_iters, [ev["iter"] for ev in events])
|
||||
self.assertEqual(res_ys, [ev["value"] for ev in events])
|
||||
|
||||
def test_task_metric_value_intervals(self):
|
||||
metric = "Metric1"
|
||||
variant = "Variant1"
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
from datetime import timedelta, datetime
|
||||
from threading import Thread
|
||||
from time import sleep
|
||||
from typing import Optional
|
||||
@@ -10,6 +11,7 @@ from semantic_version import Version
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.config.info import get_version
|
||||
from apiserver.database.model.settings import Settings
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.utilities.threads_manager import ThreadsManager
|
||||
|
||||
log = config.logger(__name__)
|
||||
@@ -17,6 +19,8 @@ log = config.logger(__name__)
|
||||
|
||||
class CheckUpdatesThread(Thread):
|
||||
_enabled = bool(config.get("apiserver.check_for_updates.enabled", True))
|
||||
_lock_name = "check_updates"
|
||||
_redis = redman.connection("apiserver")
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class _VersionResponse:
|
||||
@@ -29,6 +33,19 @@ class CheckUpdatesThread(Thread):
|
||||
target=self._check_updates, daemon=True
|
||||
)
|
||||
|
||||
@property
|
||||
def update_interval(self):
|
||||
return timedelta(
|
||||
seconds=max(
|
||||
float(
|
||||
config.get(
|
||||
"apiserver.check_for_updates.check_interval_sec", 60 * 60 * 24,
|
||||
)
|
||||
),
|
||||
60 * 5,
|
||||
)
|
||||
)
|
||||
|
||||
def start(self) -> None:
|
||||
if not self._enabled:
|
||||
log.info("Checking for updates is disabled")
|
||||
@@ -37,12 +54,13 @@ class CheckUpdatesThread(Thread):
|
||||
|
||||
@property
|
||||
def component_name(self) -> str:
|
||||
return config.get("apiserver.check_for_updates.component_name", "clearml-server")
|
||||
return config.get(
|
||||
"apiserver.check_for_updates.component_name", "clearml-server"
|
||||
)
|
||||
|
||||
def _check_new_version_available(self) -> Optional[_VersionResponse]:
|
||||
url = config.get(
|
||||
"apiserver.check_for_updates.url",
|
||||
"https://updates.clear.ml/updates",
|
||||
"apiserver.check_for_updates.url", "https://updates.clear.ml/updates",
|
||||
)
|
||||
|
||||
uid = Settings.get_by_key("server.uuid")
|
||||
@@ -81,34 +99,31 @@ class CheckUpdatesThread(Thread):
|
||||
)
|
||||
|
||||
def _check_updates(self):
|
||||
update_interval_sec = max(
|
||||
float(
|
||||
config.get(
|
||||
"apiserver.check_for_updates.check_interval_sec",
|
||||
60 * 60 * 24,
|
||||
)
|
||||
),
|
||||
60 * 5,
|
||||
)
|
||||
while not ThreadsManager.terminating:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
response = self._check_new_version_available()
|
||||
if response:
|
||||
if response.patch_upgrade:
|
||||
log.info(
|
||||
f"{self.component_name.upper()} new package available: upgrade to v{response.version} "
|
||||
f"is recommended!\nRelease Notes:\n{os.linesep.join(response.description)}"
|
||||
)
|
||||
else:
|
||||
log.info(
|
||||
f"{self.component_name.upper()} new version available: upgrade to v{response.version}"
|
||||
f" is recommended!"
|
||||
)
|
||||
except Exception:
|
||||
log.exception("Failed obtaining updates")
|
||||
if self._redis.set(
|
||||
self._lock_name,
|
||||
value=datetime.utcnow().isoformat(),
|
||||
ex=self.update_interval - timedelta(seconds=60),
|
||||
nx=True,
|
||||
):
|
||||
response = self._check_new_version_available()
|
||||
if response:
|
||||
if response.patch_upgrade:
|
||||
log.info(
|
||||
f"{self.component_name.upper()} new package available: upgrade to v{response.version} "
|
||||
f"is recommended!\nRelease Notes:\n{os.linesep.join(response.description)}"
|
||||
)
|
||||
else:
|
||||
log.info(
|
||||
f"{self.component_name.upper()} new version available: upgrade to v{response.version}"
|
||||
f" is recommended!"
|
||||
)
|
||||
except Exception as ex:
|
||||
log.exception("Failed obtaining updates: " + str(ex))
|
||||
|
||||
sleep(update_interval_sec)
|
||||
sleep(self.update_interval.total_seconds())
|
||||
|
||||
|
||||
check_updates_thread = CheckUpdatesThread()
|
||||
|
||||
@@ -37,15 +37,12 @@ def deep_merge(source: dict, override: dict) -> dict:
|
||||
|
||||
def nested_get(
|
||||
dictionary: Mapping,
|
||||
path: Union[Sequence[str], str],
|
||||
path: Sequence[str],
|
||||
default: Optional[Union[Any, Callable]] = None,
|
||||
) -> Any:
|
||||
if isinstance(path, str):
|
||||
path = [path]
|
||||
|
||||
node = dictionary
|
||||
for key in path:
|
||||
if key not in node:
|
||||
if not node or key not in node:
|
||||
if callable(default):
|
||||
return default()
|
||||
return default
|
||||
|
||||
14
apiserver/utilities/env.py
Normal file
14
apiserver/utilities/env.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from distutils.util import strtobool
|
||||
from os import getenv
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def get_bool(*keys: str, default: bool = None) -> Optional[bool]:
|
||||
try:
|
||||
value = next(env for env in (getenv(key) for key in keys) if env is not None)
|
||||
except StopIteration:
|
||||
return default
|
||||
try:
|
||||
return bool(strtobool(value))
|
||||
except ValueError:
|
||||
return bool(value)
|
||||
@@ -1 +1 @@
|
||||
__version__ = "1.1.0"
|
||||
__version__ = "1.3.0"
|
||||
|
||||
35
docker/build/Dockerfile
Normal file
35
docker/build/Dockerfile
Normal file
@@ -0,0 +1,35 @@
|
||||
FROM centos/nodejs-12-centos7 AS webapp
|
||||
|
||||
USER root
|
||||
WORKDIR /opt
|
||||
|
||||
RUN git clone https://github.com/allegroai/clearml-web.git
|
||||
RUN mv clearml-web /opt/open-webapp
|
||||
COPY --chmod=744 docker/build/internal_files/build_webapp.sh /tmp/internal_files/
|
||||
RUN /bin/bash -c '/tmp/internal_files/build_webapp.sh'
|
||||
|
||||
FROM centos:7 AS staging_image
|
||||
COPY --chmod=744 docker/build/internal_files/entrypoint.sh /opt/clearml/
|
||||
COPY fileserver /opt/clearml/fileserver/
|
||||
COPY apiserver /opt/clearml/apiserver/
|
||||
|
||||
FROM centos:7
|
||||
COPY --from=staging_image /opt/clearml/ /opt/clearml/
|
||||
|
||||
COPY --chmod=744 docker/build/internal_files/final_image_preparation.sh /tmp/internal_files/
|
||||
COPY docker/build/internal_files/clearml.conf.template /tmp/internal_files/
|
||||
RUN /bin/bash -c '/tmp/internal_files/final_image_preparation.sh'
|
||||
|
||||
COPY --from=webapp /opt/open-webapp/build /usr/share/nginx/html
|
||||
|
||||
EXPOSE 8080
|
||||
EXPOSE 8008
|
||||
EXPOSE 8081
|
||||
|
||||
ARG VERSION
|
||||
ARG BUILD
|
||||
ENV CLEARML_SERVER_VERSION=${VERSION}
|
||||
ENV CLEARML_SERVER_BUILD=${BUILD}
|
||||
|
||||
WORKDIR /opt/clearml/
|
||||
ENTRYPOINT ["/opt/clearml/entrypoint.sh"]
|
||||
9
docker/build/internal_files/build_webapp.sh
Normal file
9
docker/build/internal_files/build_webapp.sh
Normal file
@@ -0,0 +1,9 @@
|
||||
#!/usr/bin/env bash
|
||||
set -x
|
||||
set -e
|
||||
|
||||
cd /opt/open-webapp/
|
||||
npm ci --unsafe-perm node-sass
|
||||
|
||||
cd /opt/open-webapp/
|
||||
npm run build
|
||||
97
docker/build/internal_files/clearml.conf.template
Normal file
97
docker/build/internal_files/clearml.conf.template
Normal file
@@ -0,0 +1,97 @@
|
||||
# For more information on configuration, see:
|
||||
# * Official English Documentation: http://nginx.org/en/docs/
|
||||
# * Official Russian Documentation: http://nginx.org/ru/docs/
|
||||
|
||||
user nginx;
|
||||
worker_processes auto;
|
||||
error_log /var/log/nginx/error.log;
|
||||
pid /run/nginx.pid;
|
||||
|
||||
# Load dynamic modules. See /usr/share/doc/nginx/README.dynamic.
|
||||
include /usr/share/nginx/modules/*.conf;
|
||||
|
||||
events {
|
||||
worker_connections 1024;
|
||||
}
|
||||
|
||||
http {
|
||||
log_format main '$remote_addr - $remote_user [$time_local] "$request" '
|
||||
'$status $body_bytes_sent "$http_referer" '
|
||||
'"$http_user_agent" "$http_x_forwarded_for"';
|
||||
|
||||
access_log /var/log/nginx/access.log main;
|
||||
|
||||
sendfile on;
|
||||
tcp_nopush on;
|
||||
tcp_nodelay on;
|
||||
keepalive_timeout 65;
|
||||
types_hash_max_size 2048;
|
||||
|
||||
include /etc/nginx/mime.types;
|
||||
default_type application/octet-stream;
|
||||
|
||||
# Load modular configuration files from the /etc/nginx/conf.d directory.
|
||||
# See http://nginx.org/en/docs/ngx_core_module.html#include
|
||||
# for more information.
|
||||
include /etc/nginx/conf.d/*.conf;
|
||||
|
||||
server {
|
||||
listen 80 default_server;
|
||||
listen [::]:80 default_server;
|
||||
server_name _;
|
||||
root /usr/share/nginx/html;
|
||||
proxy_http_version 1.1;
|
||||
|
||||
# comppression
|
||||
gzip on;
|
||||
gzip_comp_level 9;
|
||||
gzip_http_version 1.0;
|
||||
gzip_min_length 512;
|
||||
gzip_proxied expired no-cache no-store private auth;
|
||||
gzip_types text/plain
|
||||
text/css
|
||||
application/json
|
||||
application/javascript
|
||||
application/x-javascript
|
||||
text/xml application/xml
|
||||
application/xml+rss
|
||||
text/javascript
|
||||
application/x-font-ttf
|
||||
font/woff2
|
||||
image/svg+xml
|
||||
image/x-icon;
|
||||
|
||||
# Load configuration files for the default server block.
|
||||
include /etc/nginx/default.d/*.conf;
|
||||
|
||||
location / {
|
||||
try_files $uri$args $uri$args/ $uri index.html /index.html;
|
||||
}
|
||||
|
||||
location /version.json {
|
||||
add_header Cache-Control 'no-cache';
|
||||
}
|
||||
|
||||
location /api {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header Host $host;
|
||||
proxy_pass ${NGINX_APISERVER_ADDR};
|
||||
rewrite /api/(.*) /$1 break;
|
||||
}
|
||||
|
||||
location /files {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header Host $host;
|
||||
proxy_pass ${NGINX_FILESERVER_ADDR};
|
||||
rewrite /files/(.*) /$1 break;
|
||||
}
|
||||
|
||||
error_page 404 /404.html;
|
||||
location = /40x.html {
|
||||
}
|
||||
|
||||
error_page 500 502 503 504 /50x.html;
|
||||
location = /50x.html {
|
||||
}
|
||||
}
|
||||
}
|
||||
67
docker/build/internal_files/entrypoint.sh
Normal file
67
docker/build/internal_files/entrypoint.sh
Normal file
@@ -0,0 +1,67 @@
|
||||
#!/usr/bin/env bash
|
||||
set -e
|
||||
|
||||
mkdir -p /var/log/clearml
|
||||
|
||||
SERVER_TYPE=$1
|
||||
|
||||
if (( $# < 1 )) ; then
|
||||
echo "The server type was not stated. It should be either apiserver, webserver or fileserver."
|
||||
sleep 60
|
||||
exit 1
|
||||
|
||||
elif [[ ${SERVER_TYPE} == "apiserver" ]]; then
|
||||
cd /opt/clearml/
|
||||
python3 -m apiserver.apierrors_generator
|
||||
|
||||
if [[ -n $CLEARML_USE_GUNICORN ]]; then
|
||||
MAX_REQUESTS=
|
||||
if [[ -n $CLEARML_GUNICORN_MAX_REQUESTS ]]; then
|
||||
MAX_REQUESTS="--max-requests $CLEARML_GUNICORN_MAX_REQUESTS"
|
||||
if [[ -n $CLEARML_GUNICORN_MAX_REQUESTS_JITTER ]]; then
|
||||
MAX_REQUESTS="$MAX_REQUESTS --max-requests-jitter $CLEARML_GUNICORN_MAX_REQUESTS_JITTER"
|
||||
fi
|
||||
fi
|
||||
|
||||
export GUNICORN_CMD_ARGS=${CLEARML_GUNICORN_CMD_ARGS}
|
||||
|
||||
# Note: don't be tempted to "fix" $MAX_REQUESTS with "$MAX_REQUESTS" as this produces an empty arg which fucks up gunicorn
|
||||
gunicorn \
|
||||
-w "${CLEARML_GUNICORN_WORKERS:-8}" \
|
||||
-t "${CLEARML_GUNICORN_TIMEOUT:-600}" --bind="${CLEARML_GUNICORN_BIND:-0.0.0.0:8008}" \
|
||||
$MAX_REQUESTS apiserver.server:app
|
||||
else
|
||||
python3 -m apiserver.server
|
||||
fi
|
||||
|
||||
elif [[ ${SERVER_TYPE} == "webserver" ]]; then
|
||||
|
||||
if [[ "${USER_KEY}" != "" ]] || [[ "${USER_SECRET}" != "" ]] || [[ "${COMPANY_ID}" != "" ]]; then
|
||||
cat << EOF > /usr/share/nginx/html/credentials.json
|
||||
{
|
||||
"userKey": "${USER_KEY}",
|
||||
"userSecret": "${USER_SECRET}",
|
||||
"companyID": "${COMPANY_ID}"
|
||||
}
|
||||
EOF
|
||||
fi
|
||||
|
||||
export NGINX_APISERVER_ADDR=${NGINX_APISERVER_ADDRESS:-http://apiserver:8008}
|
||||
export NGINX_FILESERVER_ADDR=${NGINX_FILESERVER_ADDRESS:-http://fileserver:8081}
|
||||
|
||||
envsubst '${NGINX_APISERVER_ADDR} ${NGINX_FILESERVER_ADDR}' < /etc/nginx/clearml.conf.template > /etc/nginx/nginx.conf
|
||||
|
||||
#start the server
|
||||
/usr/sbin/nginx -g "daemon off;"
|
||||
|
||||
elif [[ ${SERVER_TYPE} == "fileserver" ]]; then
|
||||
cd /opt/clearml/fileserver/
|
||||
if [ "$FILESERVER_USE_GUNICORN" = true ] ; then
|
||||
gunicorn -t 600 --bind=0.0.0.0:8081 fileserver:app
|
||||
else
|
||||
python3 fileserver.py
|
||||
fi
|
||||
|
||||
else
|
||||
echo "Server type ${SERVER_TYPE} is invalid. Please choose either apiserver, webserver or fileserver."
|
||||
fi
|
||||
18
docker/build/internal_files/final_image_preparation.sh
Normal file
18
docker/build/internal_files/final_image_preparation.sh
Normal file
@@ -0,0 +1,18 @@
|
||||
#!/usr/bin/env bash
|
||||
set -o errexit
|
||||
set -o nounset
|
||||
set -o pipefail
|
||||
|
||||
yum update -y
|
||||
yum install -y https://dl.fedoraproject.org/pub/epel/epel-release-latest-7.noarch.rpm
|
||||
yum install -y python36 python36-pip nginx gcc gcc-c++ python3-devel gettext
|
||||
yum -y upgrade
|
||||
python3 -m pip install -r /opt/clearml/fileserver/requirements.txt
|
||||
python3 -m pip install -r /opt/clearml/apiserver/requirements.txt
|
||||
mkdir -p /opt/clearml/log
|
||||
mkdir -p /opt/clearml/config
|
||||
ln -s /dev/stdout /var/log/nginx/access.log
|
||||
ln -s /dev/stderr /var/log/nginx/error.log
|
||||
mv /etc/nginx/nginx.conf /etc/nginx/nginx.conf.orig
|
||||
mv /tmp/internal_files/clearml.conf.template /etc/nginx/clearml.conf.template
|
||||
yum clean all
|
||||
@@ -19,6 +19,7 @@ services:
|
||||
environment:
|
||||
CLEARML_ELASTIC_SERVICE_HOST: elasticsearch
|
||||
CLEARML_ELASTIC_SERVICE_PORT: 9200
|
||||
CLEARML_ELASTIC_SERVICE_PASSWORD: ${ELASTIC_PASSWORD}
|
||||
CLEARML_MONGODB_SERVICE_HOST: mongo
|
||||
CLEARML_MONGODB_SERVICE_PORT: 27017
|
||||
CLEARML_REDIS_SERVICE_HOST: redis
|
||||
@@ -38,7 +39,8 @@ services:
|
||||
- backend
|
||||
container_name: clearml-elastic
|
||||
environment:
|
||||
ES_JAVA_OPTS: -Xms2g -Xmx2g
|
||||
ES_JAVA_OPTS: -Xms2g -Xmx2g -Dlog4j2.formatMsgNoLookups=true
|
||||
ELASTIC_PASSWORD: ${ELASTIC_PASSWORD}
|
||||
bootstrap.memory_lock: "true"
|
||||
cluster.name: clearml
|
||||
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
|
||||
@@ -60,7 +62,7 @@ services:
|
||||
nofile:
|
||||
soft: 65536
|
||||
hard: 65536
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:7.6.2
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:7.16.2
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- c:/opt/clearml/data/elastic_7:/usr/share/elasticsearch/data
|
||||
@@ -87,12 +89,12 @@ services:
|
||||
networks:
|
||||
- backend
|
||||
container_name: clearml-mongo
|
||||
image: mongo:3.6.5
|
||||
image: mongo:4.4.9
|
||||
restart: unless-stopped
|
||||
command: --setParameter internalQueryExecMaxBlockingSortBytes=196100200
|
||||
command: --setParameter internalQueryMaxBlockingSortMemoryUsageBytes=196100200
|
||||
volumes:
|
||||
- c:/opt/clearml/data/mongo/db:/data/db
|
||||
- c:/opt/clearml/data/mongo/configdb:/data/configdb
|
||||
- c:/opt/clearml/data/mongo_4/db:/data/db
|
||||
- c:/opt/clearml/data/mongo_4/configdb:/data/configdb
|
||||
|
||||
redis:
|
||||
networks:
|
||||
|
||||
@@ -19,6 +19,7 @@ services:
|
||||
environment:
|
||||
CLEARML_ELASTIC_SERVICE_HOST: elasticsearch
|
||||
CLEARML_ELASTIC_SERVICE_PORT: 9200
|
||||
CLEARML_ELASTIC_SERVICE_PASSWORD: ${ELASTIC_PASSWORD}
|
||||
CLEARML_MONGODB_SERVICE_HOST: mongo
|
||||
CLEARML_MONGODB_SERVICE_PORT: 27017
|
||||
CLEARML_REDIS_SERVICE_HOST: redis
|
||||
@@ -38,7 +39,8 @@ services:
|
||||
- backend
|
||||
container_name: clearml-elastic
|
||||
environment:
|
||||
ES_JAVA_OPTS: -Xms2g -Xmx2g
|
||||
ES_JAVA_OPTS: -Xms2g -Xmx2g -Dlog4j2.formatMsgNoLookups=true
|
||||
ELASTIC_PASSWORD: ${ELASTIC_PASSWORD}
|
||||
bootstrap.memory_lock: "true"
|
||||
cluster.name: clearml
|
||||
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
|
||||
@@ -60,7 +62,7 @@ services:
|
||||
nofile:
|
||||
soft: 65536
|
||||
hard: 65536
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:7.6.2
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:7.16.2
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- /opt/clearml/data/elastic_7:/usr/share/elasticsearch/data
|
||||
@@ -86,12 +88,12 @@ services:
|
||||
networks:
|
||||
- backend
|
||||
container_name: clearml-mongo
|
||||
image: mongo:3.6.5
|
||||
image: mongo:4.4.9
|
||||
restart: unless-stopped
|
||||
command: --setParameter internalQueryExecMaxBlockingSortBytes=196100200
|
||||
command: --setParameter internalQueryMaxBlockingSortMemoryUsageBytes=196100200
|
||||
volumes:
|
||||
- /opt/clearml/data/mongo/db:/data/db
|
||||
- /opt/clearml/data/mongo/configdb:/data/configdb
|
||||
- /opt/clearml/data/mongo_4/db:/data/db
|
||||
- /opt/clearml/data/mongo_4/configdb:/data/configdb
|
||||
|
||||
redis:
|
||||
networks:
|
||||
@@ -121,7 +123,9 @@ services:
|
||||
- backend
|
||||
container_name: clearml-agent-services
|
||||
image: allegroai/clearml-agent-services:latest
|
||||
restart: unless-stopped
|
||||
deploy:
|
||||
restart_policy:
|
||||
condition: on-failure
|
||||
privileged: true
|
||||
environment:
|
||||
CLEARML_HOST_IP: ${CLEARML_HOST_IP}
|
||||
@@ -132,7 +136,7 @@ services:
|
||||
CLEARML_API_SECRET_KEY: ${CLEARML_API_SECRET_KEY:-}
|
||||
CLEARML_AGENT_GIT_USER: ${CLEARML_AGENT_GIT_USER}
|
||||
CLEARML_AGENT_GIT_PASS: ${CLEARML_AGENT_GIT_PASS}
|
||||
CLEARML_AGENT_UPDATE_VERSION: ${CLEARML_AGENT_UPDATE_VERSION:->=0.17.0}
|
||||
CLEARML_AGENT_UPDATE_VERSION: ${CLEARML_AGENT_UPDATE_VERSION:-">=0.17.0"}
|
||||
CLEARML_AGENT_DEFAULT_BASE_DOCKER: "ubuntu:18.04"
|
||||
AWS_ACCESS_KEY_ID: ${AWS_ACCESS_KEY_ID:-}
|
||||
AWS_SECRET_ACCESS_KEY: ${AWS_SECRET_ACCESS_KEY:-}
|
||||
@@ -142,6 +146,7 @@ services:
|
||||
GOOGLE_APPLICATION_CREDENTIALS: ${GOOGLE_APPLICATION_CREDENTIALS:-}
|
||||
CLEARML_WORKER_ID: "clearml-services"
|
||||
CLEARML_AGENT_DOCKER_HOST_MOUNT: "/opt/clearml/agent:/root/.clearml"
|
||||
SHUTDOWN_IF_NO_ACCESS_KEY: 1
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
- /opt/clearml/agent:/root/.clearml
|
||||
|
||||
BIN
docs/ClearML_Server_Diagram.png
Normal file
BIN
docs/ClearML_Server_Diagram.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 155 KiB |
@@ -1,5 +1,6 @@
|
||||
""" A Simple file server for uploading and downloading files """
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
@@ -12,12 +13,15 @@ from werkzeug.exceptions import NotFound
|
||||
from werkzeug.security import safe_join
|
||||
|
||||
from config import config
|
||||
from utils import get_env_bool
|
||||
|
||||
DEFAULT_UPLOAD_FOLDER = "/mnt/fileserver"
|
||||
|
||||
app = Flask(__name__)
|
||||
CORS(app, **config.get("fileserver.cors"))
|
||||
Compress(app)
|
||||
|
||||
if get_env_bool("CLEARML_COMPRESS_RESP", default=True):
|
||||
Compress(app)
|
||||
|
||||
app.config["UPLOAD_FOLDER"] = first(
|
||||
(os.environ.get(f"{prefix}_UPLOAD_FOLDER") for prefix in ("CLEARML", "TRAINS")),
|
||||
@@ -28,6 +32,20 @@ app.config["SEND_FILE_MAX_AGE_DEFAULT"] = config.get(
|
||||
)
|
||||
|
||||
|
||||
@app.before_request
|
||||
def before_request():
|
||||
if request.content_encoding:
|
||||
return f"Content encoding is not supported ({request.content_encoding})", 415
|
||||
|
||||
|
||||
@app.after_request
|
||||
def after_request(response):
|
||||
response.headers["server"] = config.get(
|
||||
"fileserver.response.headers.server", "clearml"
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@app.route("/", methods=["POST"])
|
||||
def upload():
|
||||
results = []
|
||||
@@ -48,8 +66,15 @@ def upload():
|
||||
@app.route("/<path:path>", methods=["GET"])
|
||||
def download(path):
|
||||
as_attachment = "download" in request.args
|
||||
|
||||
_, encoding = mimetypes.guess_type(os.path.basename(path))
|
||||
mimetype = "application/octet-stream" if encoding == "gzip" else None
|
||||
|
||||
response = send_from_directory(
|
||||
app.config["UPLOAD_FOLDER"], path, as_attachment=as_attachment
|
||||
app.config["UPLOAD_FOLDER"],
|
||||
path,
|
||||
as_attachment=as_attachment,
|
||||
mimetype=mimetype,
|
||||
)
|
||||
if config.get("fileserver.download.disable_browser_caching", False):
|
||||
headers = response.headers
|
||||
@@ -63,12 +88,7 @@ def download(path):
|
||||
|
||||
@app.route("/<path:path>", methods=["DELETE"])
|
||||
def delete(path):
|
||||
real_path = Path(
|
||||
safe_join(
|
||||
os.fspath(app.config["UPLOAD_FOLDER"]),
|
||||
os.fspath(path)
|
||||
)
|
||||
)
|
||||
real_path = Path(safe_join(os.fspath(app.config["UPLOAD_FOLDER"]), os.fspath(path)))
|
||||
if not real_path.exists() or not real_path.is_file():
|
||||
abort(Response(f"File {str(path)} not found", 404))
|
||||
|
||||
|
||||
14
fileserver/utils.py
Normal file
14
fileserver/utils.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from distutils.util import strtobool
|
||||
from os import getenv
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def get_env_bool(*keys: str, default: bool = None) -> Optional[bool]:
|
||||
try:
|
||||
value = next(env for env in (getenv(key) for key in keys) if env is not None)
|
||||
except StopIteration:
|
||||
return default
|
||||
try:
|
||||
return bool(strtobool(value))
|
||||
except ValueError:
|
||||
return bool(value)
|
||||
Reference in New Issue
Block a user