Compare commits

66 Commits

Author SHA1 Message Date
allegroai
124684f53f Version bump to v1.3.0 2022-03-15 16:34:35 +02:00
allegroai
455b5d6758 Fix pre-populate to convert model metadata from the old format 2022-03-15 16:30:14 +02:00
allegroai
c04e2e498b Support credentials label and last_used_from fields 2022-03-15 16:29:37 +02:00
allegroai
da8a45072f Add pipelines support 2022-03-15 16:28:59 +02:00
allegroai
e1992e2054 Fix queue metrics calculation 2022-03-15 16:28:49 +02:00
allegroai
c17cedd93a Support disabling response compression in fileserver 2022-03-15 16:27:31 +02:00
allegroai
b6ad8f8790 Add support for worker auto-unregister (instead of raising an error) 2022-03-15 16:25:14 +02:00
allegroai
5acc7eebc3 Set API version to 2.17 2022-03-15 16:22:51 +02:00
allegroai
941927dfcd Return fixed fileserver header 2022-03-15 16:21:52 +02:00
allegroai
02933a9c93 Support disabling response compression
Return fixed server header
2022-03-15 16:21:14 +02:00
allegroai
e537651f29 Better support for assets upload/download 2022-03-15 16:19:52 +02:00
allegroai
af09fba755 Add metadata dict support for models, queues
Add more info for projects
2022-03-15 16:18:57 +02:00
Reuben Morais
04ea9018a3 Add missing g++ dep to server build (#111) 2022-02-21 22:14:22 +02:00
allegroai
ff7e1be24f Updated docker-compose files for v1.2.0 2022-02-14 15:27:23 +02:00
allegroai
fc4fd9e61c Version bump to v1.2.0 2022-02-14 15:26:27 +02:00
allegroai
8908c7dcf9 Update driver requirements
Refactor ES initialization
2022-02-13 20:27:12 +02:00
allegroai
b9996e2c1a Protect against multiple connects to the update server from different processes
Code cleanup
2022-02-13 20:12:12 +02:00
allegroai
afdc56f37c Use task active duration for worker task running time 2022-02-13 20:01:47 +02:00
allegroai
a25cd5dae8 Fix version conflicts when deleting task events cause an error 2022-02-13 20:01:25 +02:00
allegroai
447adb9090 Add support for credentials label
Support no_scroll in events.get_task_plots
Support better project stats
Fix Redis required on mongodb initialization
Update tests
2022-02-13 19:59:58 +02:00
allegroai
92fd98d5ad Add support for lists and nested fields in URL args and form 2022-02-13 19:52:05 +02:00
allegroai
c4001b4037 Add Redis cluster support
Fix for lru_cache usage
2022-02-13 19:48:26 +02:00
allegroai
970a32287a Add Redis password support 2022-02-13 19:37:52 +02:00
allegroai
17cd48dada Add support for override cookie domains
Support for community invitation alarms
Remove duplicate property
Add query optimizations
2022-02-13 19:35:35 +02:00
allegroai
ea3b6e955f Optimize nested_get() 2022-02-13 19:32:22 +02:00
allegroai
843450bb9b Fix add_or_update_artifacts should always be allowed on in_progress tasks
Fix delete_artifacts should always be allowed on in_progress tasks
Fix query code
2022-02-13 19:31:54 +02:00
allegroai
e149af58b1 Support for additional mata data in api call response 2022-02-13 19:30:36 +02:00
allegroai
604a38035b Add organization.update_company_name
Fix unit-tests
2022-02-13 19:29:46 +02:00
allegroai
cae38a365b Fix base query building
Fix schema
Improve events.scalar_metrics_iter_raw implementation
2022-02-13 19:28:23 +02:00
allegroai
e334246b46 Add support for project stats with children flag 2022-02-13 19:26:47 +02:00
allegroai
36e013b40c Add support for events.scalar_metrics_iter_raw 2022-02-13 19:26:03 +02:00
allegroai
f20cd6536e Add scroll support to *.get_* 2022-02-13 19:23:29 +02:00
allegroai
446bd35006 Refactor debug images response, model ORM 2022-02-13 19:21:07 +02:00
allegroai
a377a7e315 Support status_message and status_reason in tasks.delete 2022-02-13 19:20:31 +02:00
allegroai
3d046ac282 Fix project should not be merged into itself 2022-02-13 19:18:08 +02:00
allegroai
a08fa9a0e1 Add missing API Errors 2022-02-13 19:16:58 +02:00
allegroai
5856ed2836 Update Model.last_update on changes to tags and system tags 2022-02-13 19:15:37 +02:00
allegroai
d295355d99 Better logger name if called from __init__.py 2022-02-13 19:15:10 +02:00
pollfly
77350f6119 Fix link (#104) 2022-01-27 12:15:55 +02:00
Niels ten Boom
bc2c2ebbfd Add connection string functionality for MongoDB access (#102) 2022-01-08 12:07:59 +02:00
allegroai
1502e02a1a Update ES version to 7.16.2 2021-12-22 13:53:34 +02:00
allegroai
d0e2313a24 Update README regarding CVE-2021-45046 2021-12-15 15:51:18 +02:00
allegroai
d8ba1a8ea7 Fix README 2021-12-14 15:52:53 +02:00
allegroai
ca7937fc4e Fix README 2021-12-14 15:50:30 +02:00
allegroai
df89bcceef Update README with a note regarding Apache Log4j2 Remote Code Execution (RCE) Vulnerability - CVE-2021-44228 - ESA-2021-31 2021-12-14 15:48:54 +02:00
allegroai
cfccbe05c1 Add precautionary mitigation for Apache Log4j2 Remote Code Execution (RCE) Vulnerability - CVE-2021-44228 - ESA-2021-31 2021-12-14 15:15:11 +02:00
Théo Mathieu
e352a6a1e7 Fix elasticsearch authentication when initializing (#98) 2021-12-05 09:55:06 +02:00
Théo Mathieu
8a3d992aaf Support MongoDB SRV endpoints (#96) 2021-12-02 10:07:33 +02:00
allegroai
c37f3d8d5b Fix set() not supported in ConfigTree()
Add user/pass config support
2021-11-15 18:33:49 +02:00
allegroai
a96870e092 Add admonition in case only username or password were provided 2021-11-15 15:19:07 +02:00
allegroai
6bf1032237 Rename back to docker-compose.yml 2021-11-15 15:13:09 +02:00
Weixiao Huang
3d816c747d Add ES http_auth credentials support (#93)
Also update ES and MongoDB versions and fix nginx configuration bug

Co-authored-by: huangweixiao <huangweixiao@megvii.com>
2021-11-15 15:01:27 +02:00
Jake Henning
3f2b96266b Merge pull request #91 from valeriano-manassero/fix-dockerfile-chmod
Fix chmod for file copy in Dockerfile
2021-10-19 11:35:00 +03:00
Valeriano Manassero
22b16d12eb fix chmod for file copy 2021-10-19 09:06:51 +02:00
allegroai
c55b6f30df Add Dockerfile 2021-10-18 16:52:17 +03:00
allegroai
b7045d3d28 Fix docker-compose escaping 2021-10-18 16:49:51 +03:00
Jake Henning
e31a404885 Remove README mentions of demo server (#90) 2021-10-10 16:32:51 +03:00
Revital
643588b71a edit README mention of demo server 2021-10-10 11:27:44 +03:00
Jake Henning
a64c4d264d Merge pull request #82 from IgorKasianenko/IgorKasianenko-patch-1
Fix typo TRAINS > CLEARML for env variables in README
2021-08-12 11:47:20 +03:00
Igor Kasianenko
567780e188 Fix typo TRAINS > CLEARML for env variables 2021-08-11 16:21:02 +03:00
allegroai
1bc8529d83 Version bump 2021-08-05 16:46:29 +03:00
allegroai
6b480d7e87 Fix file server GET response for gzipped data-files contains Content-Encoding: gz header, causing clients to automatically decompress the file 2021-08-05 16:46:25 +03:00
allegroai
083fd315e9 Fix server error when running with non-migrated v0.16 ElasticSearch data 2021-08-05 16:46:05 +03:00
Jake Henning
ef20e76174 Update README with artifact.io badge 2021-07-27 19:53:41 +03:00
Jake Henning
8c8910808e Merge pull request #80 from pollfly/master
Fix README links
2021-07-27 12:58:45 +03:00
Revital
f6ad379310 link to clear.ml docs in readme, add image 2021-07-27 12:54:41 +03:00
98 changed files with 4729 additions and 2455 deletions

1
.gitignore vendored
View File

@@ -12,7 +12,6 @@ test-reports
.pytest_cache
venv
*.noseids
build
*.egg-info
.cache
.mypy_cache

View File

@@ -8,28 +8,43 @@
[![GitHub license](https://img.shields.io/badge/license-SSPL-green.svg)](https://img.shields.io/badge/license-SSPL-green.svg)
[![Python versions](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)
[![GitHub version](https://img.shields.io/github/release-pre/allegroai/trains-server.svg)](https://img.shields.io/github/release-pre/allegroai/trains-server.svg)
[![Artifact Hub](https://img.shields.io/endpoint?url=https://artifacthub.io/badge/repository/allegroai)](https://artifacthub.io/packages/search?repo=allegroai)
</div>
---
<div align="center">
**v0.16 Upgrade Notice**
**Note regarding Apache Log4j2 Remote Code Execution (RCE) Vulnerability - CVE-2021-44228 - ESA-2021-31**
</div>
In v0.16, the Elasticsearch subsystem of ClearML Server has been upgraded from version 5.6 to version 7.6. This change necessitates the migration of the database contents to accommodate the change in index structure across the different versions.
According to [ElasticSearch's latest report](https://discuss.elastic.co/t/apache-log4j2-remote-code-execution-rce-vulnerability-cve-2021-44228-esa-2021-31/291476),
supported versions of Elasticsearch (6.8.9+, 7.8+) used with recent versions of the JDK (JDK9+) **are not susceptible to either remote code execution or information leakage**
due to Elasticsearchs usage of the Java Security Manager.
Follow [this procedure](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_es7_migration.html) to migrate existing data.
**As the latest version of ClearML Server uses Elasticsearch 7.10+ with JDK15, it is not affected by these vulnerabilities.**
As a precaution, we've upgraded the ES version to 7.16.2 and added the mitigation recommended by ElasticSearch to our latest [docker-compose.yml](https://github.com/allegroai/clearml-server/blob/cfccbe05c158b75e520581f86e9668291da5c70a/docker/docker-compose.yml#L42) file.
While previous Elasticsearch versions (5.6.11+, 6.4.0+ and 7.0.0+) used by older ClearML Server versions are only susceptible to the information leakage vulnerability
(which in any case **does not permit access to data within the Elasticsearch cluster**),
we still recommend upgrading to the latest version of ClearML Server. Alternatively, you can apply the mitigation as implemented in our latest
[docker-compose.yml](https://github.com/allegroai/clearml-server/blob/cfccbe05c158b75e520581f86e9668291da5c70a/docker/docker-compose.yml#L42) file.
**Update 15 December**: A further vulnerability (CVE-2021-45046) was disclosed on December 14th.
ElasticSearch's guidance for Elasticsearch remains unchanged by this new vulnerability, thus **not affecting ClearML Server**.
**Update 22 December**: To keep with ElasticSearch's recommendations, we've upgraded the ES version to the newly released 7.16.2
---
### ClearML Server
## ClearML Server
#### *Formerly known as Trains Server*
The **ClearML Server** is the backend service infrastructure for [ClearML](https://github.com/allegroai/clearml).
It allows multiple users to collaborate and manage their experiments.
By default, **ClearML** is set up to work with the **ClearML** demo server, which is open to anyone and resets periodically.
**ClearML** offers a [free hosted service](https://app.clear.ml/), which is maintained by **ClearML** and open to anyone.
In order to host your own server, you will need to launch the **ClearML Server** and point **ClearML** to it.
The **ClearML Server** contains the following components:
@@ -45,7 +60,7 @@ You can quickly [deploy](#launching-the-clearml-server) your **ClearML Server**
## System design
![Alt Text](https://allegro.ai/clearml/docs/_images/ClearML_Server_Diagram.png)
![Alt Text](docs/ClearML_Server_Diagram.png)
The **ClearML Server** has two supported configurations:
- Single IP (domain) with the following open ports
@@ -78,20 +93,19 @@ For example, to see if port `8080` is in use:
Launch The **ClearML Server** in any of the following formats:
- Pre-built [AWS EC2 AMI](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_aws_ec2_ami.html)
- Pre-built [GCP Custom Image](hhttps://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_gcp.html)
- Pre-built [AWS EC2 AMI](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_aws_ec2_ami)
- Pre-built [GCP Custom Image](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_gcp)
- Pre-built Docker Image
- [Linux](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_linux_mac.html)
- [macOS](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_linux_mac.html)
- [Windows 10](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_win.html)
- [Linux](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_linux_mac)
- [macOS](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_linux_mac)
- [Windows 10](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_win)
- Kubernetes
- [Kubernetes Helm](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_kubernetes_helm.html)
- Manual [Kubernetes installation](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_kubernetes.html)
- [Kubernetes Helm](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_kubernetes_helm)
- Manual [Kubernetes installation](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_kubernetes)
## Connecting ClearML to your ClearML Server
By default, the **ClearML** client is set up to work with the [**ClearML** demo server](https://demoapp.demo.clear.ml/).
To have the **ClearML** client use your **ClearML Server** instead:
In order to set up the **ClearML** client to work with your **ClearML Server**:
- Run the `clearml-init` command for an interactive setup.
- Or manually edit `~/clearml.conf` file, making sure the server settings (`api_server`, `web_server`, `file_server`) are configured correctly, for example:
@@ -138,8 +152,8 @@ Do not enqueue training / inference tasks into the `services` queue, as it will
The **ClearML Server** provides a few additional useful features, which can be manually enabled:
* [Web login authentication](https://allegro.ai/clearml/docs/deploying_clearml/clearml_server_config/#web-login-authentication)
* [Non-responsive experiments watchdog](https://allegro.ai/clearml/docs/deploying_clearml/clearml_server_config/#task_watchdog)
* [Web login authentication](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_config#web-login-authentication)
* [Non-responsive experiments watchdog](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_config#non-responsive-task-watchdog)
## Restarting ClearML Server
@@ -189,14 +203,14 @@ To upgrade your existing **ClearML Server** deployment:
```
1. Configure the ClearML-Agent Services (not supported on Windows installation).
If `TRAINS_HOST_IP` is not provided, ClearML-Agent Services will use the external
public address of the **ClearML Server**. If `TRAINS_AGENT_GIT_USER` / `TRAINS_AGENT_GIT_PASS` are not provided,
If `CLEARML_HOST_IP` is not provided, ClearML-Agent Services will use the external
public address of the **ClearML Server**. If `CLEARML_AGENT_GIT_USER` / `CLEARML_AGENT_GIT_PASS` are not provided,
the ClearML-Agent Services will not be able to access any private repositories for running service tasks.
```bash
export TRAINS_HOST_IP=server_host_ip_here
export TRAINS_AGENT_GIT_USER=git_username_here
export TRAINS_AGENT_GIT_PASS=git_password_here
export CLEARML_HOST_IP=server_host_ip_here
export CLEARML_AGENT_GIT_USER=git_username_here
export CLEARML_AGENT_GIT_PASS=git_password_here
```
1. Spin up the docker containers, it will automatically pull the latest **ClearML Server** build
@@ -205,12 +219,12 @@ To upgrade your existing **ClearML Server** deployment:
docker-compose -f docker-compose.yml up
```
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://allegro.ai/clearml/docs/docs/faq/faq.html).**
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://clear.ml/docs/latest/docs/faq/).**
## Community & Support
If you have any questions, look to the ClearML [FAQ](https://allegro.ai/clearml/docs/docs/faq/faq.html), or
If you have any questions, look to the ClearML [FAQ](https://clear.ml/docs/latest/docs/faq), or
tag your questions on [stackoverflow](https://stackoverflow.com/questions/tagged/clearml) with '**clearml**' tag.
For feature requests or bug reports, please use [GitHub issues](https://github.com/allegroai/clearml-server/issues).

View File

@@ -71,6 +71,8 @@
408: ["cannot_update_project_location", "Cannot update project location. Use projects.move instead"]
409: ["project_path_exceeds_max", "Project path exceed the maximum allowed depth"]
410: ["project_source_and_destination_are_the_same", "Project has the same source and destination paths"]
411: ["project_cannot_be_moved_under_itself", "Project can not be moved under itself in the projects hierarchy"]
412: ["project_cannot_be_merged_into_its_child", "Project can not be merged into its own child"]
# Queues
701: ["invalid_queue_id", "invalid queue id"]

View File

@@ -75,11 +75,17 @@ class CreateUserResponse(Base):
class Credentials(Base):
access_key = StringField(required=True)
secret_key = StringField(required=True)
label = StringField()
class CredentialsResponse(Credentials):
secret_key = StringField()
last_used = DateTimeField(default=None)
last_used_from = StringField()
class CreateCredentialsRequest(Base):
label = StringField()
class CreateCredentialsResponse(Base):

View File

@@ -2,7 +2,7 @@ from enum import auto
from typing import Sequence, Optional
from jsonmodels import validators
from jsonmodels.fields import StringField, BoolField
from jsonmodels.fields import StringField, BoolField, EmbeddedField
from jsonmodels.models import Base
from jsonmodels.validators import Length, Min, Max
@@ -81,14 +81,34 @@ class LogOrderEnum(StringEnum):
desc = auto()
class LogEventsRequest(Base):
class TaskEventsRequestBase(Base):
task: str = StringField(required=True)
batch_size: int = IntField(default=500)
class TaskEventsRequest(TaskEventsRequestBase):
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
event_type: EventType = ActualEnumField(EventType, default=EventType.all)
order: Optional[str] = ActualEnumField(LogOrderEnum, default=LogOrderEnum.asc)
scroll_id: str = StringField()
count_total: bool = BoolField(default=True)
class LogEventsRequest(TaskEventsRequestBase):
batch_size: int = IntField(default=5000)
navigate_earlier: bool = BoolField(default=True)
from_timestamp: Optional[int] = IntField()
order: Optional[str] = ActualEnumField(LogOrderEnum)
class ScalarMetricsIterRawRequest(TaskEventsRequestBase):
batch_size: int = IntField()
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
metric: MetricVariants = EmbeddedField(MetricVariants, required=True)
count_total: bool = BoolField(default=False)
scroll_id: str = StringField()
class IterationEvents(Base):
iter: int = IntField()
events: Sequence[dict] = ListField(items_types=dict)
@@ -115,4 +135,5 @@ class TaskPlotsRequest(Base):
task: str = StringField(required=True)
iters: int = IntField(default=1)
scroll_id: str = StringField()
no_scroll: bool = BoolField(default=False)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)

View File

@@ -1,7 +1,7 @@
from typing import Sequence
from jsonmodels import validators
from jsonmodels.fields import StringField
from jsonmodels.fields import StringField, BoolField
from jsonmodels.models import Base
from apiserver.apimodels import ListField
@@ -21,3 +21,4 @@ class AddOrUpdateMetadata(Base):
metadata: Sequence[MetadataItem] = ListField(
[MetadataItem], validators=validators.Length(minimum_value=1)
)
replace_metadata = BoolField(default=False)

View File

@@ -30,7 +30,7 @@ class CreateModelRequest(models.Base):
ready = fields.BoolField(default=True)
ui_cache = DictField()
task = fields.StringField()
metadata = ListField(items_types=[MetadataItem])
metadata = DictField(value_types=[MetadataItem])
class CreateModelResponse(models.Base):

View File

@@ -0,0 +1,19 @@
from jsonmodels import models, fields
from apiserver.apimodels import ListField
class Arg(models.Base):
name = fields.StringField(required=True)
value = fields.StringField(required=True)
class StartPipelineRequest(models.Base):
task = fields.StringField(required=True)
queue = fields.StringField(required=True)
args = ListField(Arg)
class StartPipelineResponse(models.Base):
pipeline = fields.StringField(required=True)
enqueued = fields.BoolField(required=True)

View File

@@ -1,6 +1,6 @@
from jsonmodels import models, fields
from apiserver.apimodels import ListField, ActualEnumField
from apiserver.apimodels import ListField, ActualEnumField, DictField
from apiserver.apimodels.organization import TagsRequest
from apiserver.database.model import EntityVisibility
@@ -27,7 +27,7 @@ class ProjectOrNoneRequest(models.Base):
include_subprojects = fields.BoolField(default=True)
class GetHyperParamRequest(ProjectOrNoneRequest):
class GetParamsRequest(ProjectOrNoneRequest):
page = fields.IntField(default=0)
page_size = fields.IntField(default=500)
@@ -51,8 +51,15 @@ class ProjectHyperparamValuesRequest(MultiProjectRequest):
allow_public = fields.BoolField(default=True)
class ProjectModelMetadataValuesRequest(MultiProjectRequest):
key = fields.StringField(required=True)
allow_public = fields.BoolField(default=True)
class ProjectsGetRequest(models.Base):
include_stats = fields.BoolField(default=False)
include_stats_filter = DictField()
stats_with_children = fields.BoolField(default=True)
stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active)
non_public = fields.BoolField(default=False)
active_users = fields.ListField(str)

View File

@@ -2,7 +2,7 @@ from jsonmodels import validators
from jsonmodels.fields import StringField, IntField, BoolField, FloatField
from jsonmodels.models import Base
from apiserver.apimodels import ListField
from apiserver.apimodels import ListField, DictField
from apiserver.apimodels.metadata import (
MetadataItem,
DeleteMetadata,
@@ -19,13 +19,18 @@ class CreateRequest(Base):
name = StringField(required=True)
tags = ListField(items_types=[str])
system_tags = ListField(items_types=[str])
metadata = ListField(items_types=[MetadataItem])
metadata = DictField(value_types=[MetadataItem])
class QueueRequest(Base):
queue = StringField(required=True)
class GetNextTaskRequest(QueueRequest):
queue = StringField(required=True)
get_task_info = BoolField(default=False)
class DeleteRequest(QueueRequest):
force = BoolField(default=False)
@@ -34,7 +39,7 @@ class UpdateRequest(QueueRequest):
name = StringField()
tags = ListField(items_types=[str])
system_tags = ListField(items_types=[str])
metadata = ListField(items_types=[MetadataItem])
metadata = DictField(value_types=[MetadataItem])
class TaskRequest(QueueRequest):

View File

@@ -2,7 +2,11 @@ from datetime import datetime
from apiserver import database
from apiserver.apierrors import errors
from apiserver.apimodels.auth import GetTokenResponse, CreateUserRequest, Credentials as CredModel
from apiserver.apimodels.auth import (
GetTokenResponse,
CreateUserRequest,
Credentials as CredModel,
)
from apiserver.apimodels.users import CreateRequest as Users_CreateRequest
from apiserver.bll.user import UserBLL
from apiserver.config_repo import config
@@ -145,7 +149,7 @@ class AuthBLL:
@classmethod
def create_credentials(
cls, user_id: str, company_id: str, role: str = None
cls, user_id: str, company_id: str, role: str = None, label: str = None,
) -> CredModel:
with translate_errors_context():
@@ -154,9 +158,11 @@ class AuthBLL:
if not user:
raise errors.bad_request.InvalidUserId(**query)
cred = CredModel(access_key=get_client_id(), secret_key=get_secret_key())
cred = CredModel(
access_key=get_client_id(), secret_key=get_secret_key(), label=label
)
user.credentials.append(
Credentials(key=cred.access_key, secret=cred.secret_key)
Credentials(key=cred.access_key, secret=cred.secret_key, label=label)
)
user.save()

View File

@@ -24,13 +24,13 @@ from apiserver.bll.event.event_common import (
MetricVariants,
get_metric_variants_condition,
)
from apiserver.bll.event.events_iterator import EventsIterator, TaskEventsResult
from apiserver.bll.util import parallel_chunked_decorator
from apiserver.database import utils as dbutils
from apiserver.es_factory import es_factory
from apiserver.apierrors import errors
from apiserver.bll.event.debug_images_iterator import DebugImagesIterator
from apiserver.bll.event.event_metrics import EventMetrics
from apiserver.bll.event.log_events_iterator import LogEventsIterator, TaskEventsResult
from apiserver.bll.task import TaskBLL
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
@@ -73,7 +73,7 @@ class EventBLL(object):
self.redis = redis or redman.connection("apiserver")
self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
self.debug_sample_history = DebugSampleHistory(es=self.es, redis=self.redis)
self.log_events_iterator = LogEventsIterator(es=self.es)
self.events_iterator = EventsIterator(es=self.es)
@property
def metrics(self) -> EventMetrics:
@@ -534,6 +534,7 @@ class EventBLL(object):
sort=None,
size: int = 500,
scroll_id: str = None,
no_scroll: bool = False,
metric_variants: MetricVariants = None,
):
if scroll_id == self.empty_scroll:
@@ -611,7 +612,7 @@ class EventBLL(object):
event_type=event_type,
body=es_req,
ignore=404,
scroll="1h",
**({} if no_scroll else {"scroll": "1h"}),
)
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
@@ -680,6 +681,7 @@ class EventBLL(object):
sort=None,
size=500,
scroll_id=None,
no_scroll=False,
) -> TaskEventsResult:
if scroll_id == self.empty_scroll:
return TaskEventsResult()
@@ -740,7 +742,7 @@ class EventBLL(object):
event_type=event_type,
body=es_req,
ignore=404,
scroll="1h",
**({} if no_scroll else {"scroll": "1h"}),
)
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)

View File

@@ -66,12 +66,19 @@ def delete_company_events(
es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
) -> dict:
es_index = get_index_name(company_id, event_type.value)
return es.delete_by_query(index=es_index, body=body, **kwargs)
return es.delete_by_query(
index=es_index, body=body, conflicts="proceed", **kwargs
)
def get_metric_variants_condition(
metric_variants: MetricVariants,
) -> Sequence:
def count_company_events(
es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
) -> dict:
es_index = get_index_name(company_id, event_type.value)
return es.count(index=es_index, body=body, **kwargs)
def get_metric_variants_condition(metric_variants: MetricVariants,) -> Sequence:
conditions = [
{
"bool": {

View File

@@ -0,0 +1,205 @@
from typing import Optional, Tuple, Sequence, Any
import attr
import jsonmodels.models
import jwt
from elasticsearch import Elasticsearch
from apiserver.bll.event.event_common import (
check_empty_data,
search_company_events,
EventType,
MetricVariants,
get_metric_variants_condition,
count_company_events,
)
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
@attr.s(auto_attribs=True)
class TaskEventsResult:
total_events: int = 0
next_scroll_id: str = None
events: list = attr.Factory(list)
class EventsIterator:
def __init__(self, es: Elasticsearch):
self.es = es
def get_task_events(
self,
event_type: EventType,
company_id: str,
task_id: str,
batch_size: int,
navigate_earlier: bool = True,
from_key_value: Optional[Any] = None,
metric_variants: MetricVariants = None,
key: ScalarKeyEnum = ScalarKeyEnum.timestamp,
**kwargs,
) -> TaskEventsResult:
if check_empty_data(self.es, company_id, event_type):
return TaskEventsResult()
from_key_value = kwargs.pop("from_timestamp", from_key_value)
res = TaskEventsResult()
res.events, res.total_events = self._get_events(
event_type=event_type,
company_id=company_id,
task_id=task_id,
batch_size=batch_size,
navigate_earlier=navigate_earlier,
from_key_value=from_key_value,
metric_variants=metric_variants,
key=ScalarKey.resolve(key),
)
return res
def count_task_events(
self,
event_type: EventType,
company_id: str,
task_id: str,
metric_variants: MetricVariants = None,
) -> int:
query, _ = self._get_initial_query_and_must(task_id, metric_variants)
es_req = {
"query": query,
}
with translate_errors_context(), TimingContext("es", "count_task_events"):
es_result = count_company_events(
self.es,
company_id=company_id,
event_type=event_type,
body=es_req,
routing=task_id,
)
return es_result["count"]
def _get_events(
self,
event_type: EventType,
company_id: str,
task_id: str,
batch_size: int,
navigate_earlier: bool,
key: ScalarKey,
from_key_value: Optional[Any],
metric_variants: MetricVariants = None,
) -> Tuple[Sequence[dict], int]:
"""
Return up to 'batch size' events starting from the previous key-field value (timestamp or iter) either in the
direction of earlier events (navigate_earlier=True) or in the direction of later events.
If from_key_field is not set then start either from latest or earliest.
For the last key-field value all the events are brought (even if the resulting size exceeds batch_size)
so that events with this value will not be lost between the calls.
"""
query, must = self._get_initial_query_and_must(task_id, metric_variants)
# retrieve the next batch of events
es_req = {
"size": batch_size,
"query": query,
"sort": {key.field: "desc" if navigate_earlier else "asc"},
}
if from_key_value:
es_req["search_after"] = [from_key_value]
with translate_errors_context(), TimingContext("es", "get_task_events"):
es_result = search_company_events(
self.es,
company_id=company_id,
event_type=event_type,
body=es_req,
routing=task_id,
)
hits = es_result["hits"]["hits"]
hits_total = es_result["hits"]["total"]["value"]
if not hits:
return [], hits_total
events = [hit["_source"] for hit in hits]
# retrieve the events that match the last event timestamp
# but did not make it into the previous call due to batch_size limitation
es_req = {
"size": 10000,
"query": {
"bool": {
"must": must + [{"term": {key.field: events[-1][key.field]}}]
}
},
}
es_result = search_company_events(
self.es,
company_id=company_id,
event_type=event_type,
body=es_req,
routing=task_id,
)
last_second_hits = es_result["hits"]["hits"]
if not last_second_hits or len(last_second_hits) < 2:
# if only one element is returned for the last timestamp
# then it is already present in the events
return events, hits_total
already_present_ids = set(hit["_id"] for hit in hits)
last_second_events = [
hit["_source"]
for hit in last_second_hits
if hit["_id"] not in already_present_ids
]
# return the list merged from original query results +
# leftovers from the last timestamp
return (
[*events, *last_second_events],
hits_total,
)
@staticmethod
def _get_initial_query_and_must(
task_id: str, metric_variants: MetricVariants = None
) -> Tuple[dict, list]:
if not metric_variants:
must = [{"term": {"task": task_id}}]
query = {"term": {"task": task_id}}
else:
must = [
{"term": {"task": task_id}},
get_metric_variants_condition(metric_variants),
]
query = {"bool": {"must": must}}
return query, must
class Scroll(jsonmodels.models.Base):
def get_scroll_id(self) -> str:
return jwt.encode(
self.to_struct(),
key=config.get(
"services.events.events_retrieval.scroll_id_key", "1234567890"
),
).decode()
@classmethod
def from_scroll_id(cls, scroll_id: str):
try:
return cls(
**jwt.decode(
scroll_id,
key=config.get(
"services.events.events_retrieval.scroll_id_key", "1234567890"
),
)
)
except jwt.PyJWTError:
raise ValueError("Invalid Scroll ID")

View File

@@ -1,127 +0,0 @@
from typing import Optional, Tuple, Sequence
import attr
from elasticsearch import Elasticsearch
from apiserver.bll.event.event_common import (
check_empty_data,
search_company_events,
EventType,
)
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
@attr.s(auto_attribs=True)
class TaskEventsResult:
total_events: int = 0
next_scroll_id: str = None
events: list = attr.Factory(list)
class LogEventsIterator:
EVENT_TYPE = EventType.task_log
def __init__(self, es: Elasticsearch):
self.es = es
def get_task_events(
self,
company_id: str,
task_id: str,
batch_size: int,
navigate_earlier: bool = True,
from_timestamp: Optional[int] = None,
) -> TaskEventsResult:
if check_empty_data(self.es, company_id, self.EVENT_TYPE):
return TaskEventsResult()
res = TaskEventsResult()
res.events, res.total_events = self._get_events(
company_id=company_id,
task_id=task_id,
batch_size=batch_size,
navigate_earlier=navigate_earlier,
from_timestamp=from_timestamp,
)
return res
def _get_events(
self,
company_id: str,
task_id: str,
batch_size: int,
navigate_earlier: bool,
from_timestamp: Optional[int],
) -> Tuple[Sequence[dict], int]:
"""
Return up to 'batch size' events starting from the previous timestamp either in the
direction of earlier events (navigate_earlier=True) or in the direction of later events.
If last_min_timestamp and last_max_timestamp are not set then start either from latest or earliest.
For the last timestamp all the events are brought (even if the resulting size
exceeds batch_size) so that this timestamp events will not be lost between the calls.
In case any events were received update 'last_min_timestamp' and 'last_max_timestamp'
"""
# retrieve the next batch of events
es_req = {
"size": batch_size,
"query": {"term": {"task": task_id}},
"sort": {"timestamp": "desc" if navigate_earlier else "asc"},
}
if from_timestamp:
es_req["search_after"] = [from_timestamp]
with translate_errors_context(), TimingContext("es", "get_task_events"):
es_result = search_company_events(
self.es,
company_id=company_id,
event_type=self.EVENT_TYPE,
body=es_req,
)
hits = es_result["hits"]["hits"]
hits_total = es_result["hits"]["total"]["value"]
if not hits:
return [], hits_total
events = [hit["_source"] for hit in hits]
# retrieve the events that match the last event timestamp
# but did not make it into the previous call due to batch_size limitation
es_req = {
"size": 10000,
"query": {
"bool": {
"must": [
{"term": {"task": task_id}},
{"term": {"timestamp": events[-1]["timestamp"]}},
]
}
},
}
es_result = search_company_events(
self.es,
company_id=company_id,
event_type=self.EVENT_TYPE,
body=es_req,
)
last_second_hits = es_result["hits"]["hits"]
if not last_second_hits or len(last_second_hits) < 2:
# if only one element is returned for the last timestamp
# then it is already present in the events
return events, hits_total
already_present_ids = set(hit["_id"] for hit in hits)
last_second_events = [
hit["_source"]
for hit in last_second_hits
if hit["_id"] not in already_present_ids
]
# return the list merged from original query results +
# leftovers from the last timestamp
return (
[*events, *last_second_events],
hits_total,
)

View File

@@ -4,6 +4,8 @@ Module for polymorphism over different types of X axes in scalar aggregations
from abc import ABC, abstractmethod
from enum import auto
from typing import Any
from apiserver.utilities import extract_properties_to_lists
from apiserver.utilities.stringenum import StringEnum
from apiserver.config_repo import config
@@ -96,6 +98,10 @@ class ScalarKey(ABC):
"""
return int(iter_data[self.bucket_key_key]), iter_data["avg_val"]["value"]
def cast_value(self, value: Any) -> Any:
"""Cast value to appropriate type"""
return value
class TimestampKey(ScalarKey):
"""
@@ -117,6 +123,9 @@ class TimestampKey(ScalarKey):
}
}
def cast_value(self, value: Any) -> int:
return int(value)
class IterKey(ScalarKey):
"""
@@ -134,6 +143,9 @@ class IterKey(ScalarKey):
}
}
def cast_value(self, value: Any) -> int:
return int(value)
class ISOTimeKey(ScalarKey):
"""

View File

@@ -7,6 +7,7 @@ from apiserver.bll.task.utils import deleted_prefix
from apiserver.database.model import EntityVisibility
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task, TaskStatus
from .metadata import Metadata
class ModelBLL:

View File

@@ -0,0 +1,111 @@
from typing import Sequence, Union, Mapping
from mongoengine import Document
from apiserver.apierrors import errors
from apiserver.apimodels.metadata import MetadataItem
from apiserver.database.model.base import GetMixin
from apiserver.service_repo import APICall
from apiserver.utilities.parameter_key_escaper import (
ParameterKeyEscaper,
mongoengine_safe,
)
from apiserver.config_repo import config
from apiserver.timing_context import TimingContext
log = config.logger(__file__)
class Metadata:
@staticmethod
def metadata_from_api(
api_data: Union[Mapping[str, MetadataItem], Sequence[MetadataItem]]
) -> dict:
if not api_data:
return {}
if isinstance(api_data, dict):
return {
ParameterKeyEscaper.escape(k): v.to_struct()
for k, v in api_data.items()
}
return {
ParameterKeyEscaper.escape(item.key): item.to_struct() for item in api_data
}
@classmethod
def edit_metadata(
cls,
obj: Document,
items: Sequence[MetadataItem],
replace_metadata: bool,
**more_updates,
) -> int:
with TimingContext("mongo", "edit_metadata"):
update_cmds = dict()
metadata = cls.metadata_from_api(items)
if replace_metadata:
update_cmds["set__metadata"] = metadata
else:
for key, value in metadata.items():
update_cmds[f"set__metadata__{mongoengine_safe(key)}"] = value
return obj.update(**update_cmds, **more_updates)
@classmethod
def delete_metadata(cls, obj: Document, keys: Sequence[str], **more_updates) -> int:
with TimingContext("mongo", "delete_metadata"):
return obj.update(
**{
f"unset__metadata__{ParameterKeyEscaper.escape(key)}": 1
for key in set(keys)
},
**more_updates,
)
@staticmethod
def _process_path(path: str):
"""
Frontend does a partial escaping on the path so the all '.' in key names are escaped
Need to unescape and apply a full mongo escaping
"""
parts = path.split(".")
if len(parts) < 2 or len(parts) > 3:
raise errors.bad_request.ValidationError("invalid field", path=path)
return ".".join(
ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
)
@classmethod
def escape_paths(cls, paths: Sequence[str]) -> Sequence[str]:
for prefix in (
"metadata.",
"-metadata.",
):
paths = [
cls._process_path(path) if path.startswith(prefix) else path
for path in paths
]
return paths
@classmethod
def escape_query_parameters(cls, call: APICall) -> dict:
if not call.data:
return call.data
keys = list(call.data)
call_data = {
safe_key: call.data[key]
for key, safe_key in zip(keys, Metadata.escape_paths(keys))
}
projection = GetMixin.get_projection(call_data)
if projection:
GetMixin.set_projection(call_data, Metadata.escape_paths(projection))
ordering = GetMixin.get_ordering(call_data)
if ordering:
GetMixin.set_ordering(call_data, Metadata.escape_paths(ordering))
return call_data

View File

@@ -1,2 +1,3 @@
from .project_bll import ProjectBLL
from .project_queries import ProjectQueries
from .sub_projects import _ids_with_children as project_ids_with_children

View File

@@ -1,6 +1,6 @@
import itertools
from collections import defaultdict
from datetime import datetime
from datetime import datetime, timedelta
from functools import reduce
from itertools import groupby
from operator import itemgetter
@@ -14,6 +14,7 @@ from typing import (
TypeVar,
Callable,
Mapping,
Any,
)
from mongoengine import Q, Document
@@ -22,6 +23,7 @@ from apiserver import database
from apiserver.apierrors import errors
from apiserver.config_repo import config
from apiserver.database.model import EntityVisibility, AttributedDocument
from apiserver.database.model.base import GetMixin
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task, TaskStatus, external_task_types
@@ -57,10 +59,14 @@ class ProjectBLL:
with TimingContext("mongo", "move_project"):
if source_id == destination_id:
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
parent=source_id
source=source_id
)
source = Project.get(company, source_id)
destination = Project.get(company, destination_id)
if source_id in destination.path:
raise errors.bad_request.ProjectCannotBeMergedIntoItsChild(
source=source_id, destination=destination_id
)
children = _get_sub_projects(
[source.id], _only=("id", "name", "parent", "path")
@@ -140,7 +146,14 @@ class ProjectBLL:
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
location=new_parent.name if new_parent else ""
)
if (
new_parent
and project_id == new_parent.id
or project_id in new_parent.path
):
raise errors.bad_request.ProjectCannotBeMovedUnderItself(
project=project_id, parent=new_parent.id
)
moved = _reposition_project_with_children(
project, children=children, parent=new_parent
)
@@ -193,6 +206,7 @@ class ProjectBLL:
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
default_output_destination: str = None,
parent_creation_params: dict = None,
) -> str:
"""
Create a new project.
@@ -215,7 +229,12 @@ class ProjectBLL:
created=now,
last_update=now,
)
parent = _ensure_project(company=company, user=user, name=location)
parent = _ensure_project(
company=company,
user=user,
name=location,
creation_params=parent_creation_params,
)
_save_under_parent(project=project, parent=parent)
if parent:
parent.update(last_update=now)
@@ -233,13 +252,14 @@ class ProjectBLL:
tags: Sequence[str] = None,
system_tags: Sequence[str] = None,
default_output_destination: str = None,
parent_creation_params: dict = None,
) -> str:
"""
Find a project named `project_name` or create a new one.
Returns project ID
"""
if not project_id and not project_name:
raise ValueError("project id or name required")
raise errors.bad_request.ValidationError("project id or name required")
if project_id:
project = Project.objects(company=company, id=project_id).only("id").first()
@@ -260,6 +280,7 @@ class ProjectBLL:
tags=tags,
system_tags=system_tags,
default_output_destination=default_output_destination,
parent_creation_params=parent_creation_params,
)
@classmethod
@@ -295,6 +316,7 @@ class ProjectBLL:
return project
archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}
visibility_states = [EntityVisibility.archived, EntityVisibility.active]
@classmethod
def make_projects_get_all_pipelines(
@@ -302,6 +324,7 @@ class ProjectBLL:
company_id: str,
project_ids: Sequence[str],
specific_state: Optional[EntityVisibility] = None,
filter_: Mapping[str, Any] = None,
) -> Tuple[Sequence, Sequence]:
archived = EntityVisibility.archived.value
@@ -325,10 +348,9 @@ class ProjectBLL:
status_count_pipeline = [
# count tasks per project per status
{
"$match": {
"company": {"$in": [None, "", company_id]},
"project": {"$in": project_ids},
}
"$match": cls.get_match_conditions(
company=company_id, project_ids=project_ids, filter_=filter_
)
},
ensure_valid_fields(),
{
@@ -356,6 +378,37 @@ class ProjectBLL:
},
]
def completed_after_subquery(additional_cond, time_thresh: datetime):
return {
# the sum of
"$sum": {
# for each task
"$cond": {
# if completed after the time_thresh
"if": {
"$and": [
"$completed",
{"$gt": ["$completed", time_thresh]},
additional_cond,
]
},
"then": 1,
"else": 0,
}
}
}
def max_started_subquery(condition):
return {
"$max": {
"$cond": {
"if": condition,
"then": "$started",
"else": datetime.min,
}
}
}
def runtime_subquery(additional_cond):
return {
# the sum of
@@ -386,24 +439,36 @@ class ProjectBLL:
}
group_step = {"_id": "$project"}
for state in EntityVisibility:
time_thresh = datetime.utcnow() - timedelta(hours=24)
for state in cls.visibility_states:
if specific_state and state != specific_state:
continue
if state == EntityVisibility.active:
group_step[state.value] = runtime_subquery(
{"$not": cls.archived_tasks_cond}
)
elif state == EntityVisibility.archived:
group_step[state.value] = runtime_subquery(cls.archived_tasks_cond)
cond = (
cls.archived_tasks_cond
if state == EntityVisibility.archived
else {"$not": cls.archived_tasks_cond}
)
group_step[state.value] = runtime_subquery(cond)
group_step[f"{state.value}_recently_completed"] = completed_after_subquery(
cond, time_thresh=time_thresh
)
group_step[f"{state.value}_max_task_started"] = max_started_subquery(cond)
def get_state_filter() -> dict:
if not specific_state:
return {}
if specific_state == EntityVisibility.archived:
return {"system_tags": {"$eq": EntityVisibility.archived.value}}
return {"system_tags": {"$ne": EntityVisibility.archived.value}}
runtime_pipeline = [
# only count run time for these types of tasks
{
"$match": {
"company": {"$in": [None, "", company_id]},
"type": {"$in": ["training", "testing", "annotation"]},
"project": {"$in": project_ids},
**cls.get_match_conditions(
company=company_id, project_ids=project_ids, filter_=filter_
),
**get_state_filter(),
}
},
ensure_valid_fields(),
@@ -445,11 +510,17 @@ class ProjectBLL:
company: str,
project_ids: Sequence[str],
specific_state: Optional[EntityVisibility] = None,
include_children: bool = True,
filter_: Mapping[str, Any] = None,
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
if not project_ids:
return {}, {}
child_projects = _get_sub_projects(project_ids, _only=("id", "name"))
child_projects = (
_get_sub_projects(project_ids, _only=("id", "name"))
if include_children
else {}
)
project_ids_with_children = set(project_ids) | {
c.id for c in itertools.chain.from_iterable(child_projects.values())
}
@@ -457,6 +528,7 @@ class ProjectBLL:
company,
project_ids=list(project_ids_with_children),
specific_state=specific_state,
filter_=filter_,
)
default_counts = dict.fromkeys(get_options(TaskStatus), 0)
@@ -483,8 +555,8 @@ class ProjectBLL:
) -> Dict[str, dict]:
return {
section: {
status: nested_get(a, (section, status), 0)
+ nested_get(b, (section, status), 0)
status: nested_get(a, (section, status), default=0)
+ nested_get(b, (section, status), default=0)
for status in set(a.get(section, {})) | set(b.get(section, {}))
}
for section in set(a) | set(b)
@@ -507,6 +579,8 @@ class ProjectBLL:
) -> Dict[str, dict]:
return {
section: a.get(section, 0) + b.get(section, 0)
if not section.endswith("max_task_started")
else max(a.get(section) or datetime.min, b.get(section) or datetime.min)
for section in set(a) | set(b)
}
@@ -518,15 +592,30 @@ class ProjectBLL:
)
def get_status_counts(project_id, section):
project_runtime = runtime.get(project_id, {})
project_section_statuses = nested_get(
status_count, (project_id, section), default=default_counts
)
def get_time_or_none(value):
return value if value != datetime.min else None
return {
"total_runtime": nested_get(runtime, (project_id, section), 0),
"status_count": nested_get(
status_count, (project_id, section), default_counts
"status_count": project_section_statuses,
"total_tasks": sum(project_section_statuses.values()),
"total_runtime": project_runtime.get(section, 0),
"completed_tasks_24h": project_runtime.get(
f"{section}_recently_completed", 0
),
"last_task_run": get_time_or_none(
project_runtime.get(f"{section}_max_task_started", datetime.min)
),
}
report_for_states = [
s for s in EntityVisibility if not specific_state or specific_state == s
s
for s in cls.visibility_states
if not specific_state or specific_state == s
]
stats = {
@@ -575,6 +664,30 @@ class ProjectBLL:
return res
@classmethod
def get_project_tags(
cls,
company_id: str,
include_system: bool,
projects: Sequence[str] = None,
filter_: Dict[str, Sequence[str]] = None,
) -> Tuple[Sequence[str], Sequence[str]]:
with TimingContext("mongo", "get_tags_from_db"):
query = Q(company=company_id)
if filter_:
for name, vals in filter_.items():
if vals:
query &= GetMixin.get_list_field_query(name, vals)
if projects:
query &= Q(id__in=_ids_with_children(projects))
tags = Project.objects(query).distinct("tags")
system_tags = (
Project.objects(query).distinct("system_tags") if include_system else []
)
return tags, system_tags
@classmethod
def get_projects_with_active_user(
cls,
@@ -631,6 +744,7 @@ class ProjectBLL:
if include_subprojects:
projects = _ids_with_children(projects)
query &= Q(project__in=projects)
if state == EntityVisibility.archived:
query &= Q(system_tags__in=[EntityVisibility.archived.value])
elif state == EntityVisibility.active:
@@ -658,6 +772,7 @@ class ProjectBLL:
if project_ids:
project_ids = _ids_with_children(project_ids)
query &= Q(project__in=project_ids)
res = Task.objects(query).distinct(field="type")
return set(res).intersection(external_task_types)
@@ -673,8 +788,35 @@ class ProjectBLL:
query &= Q(project__in=project_ids)
return Model.objects(query).distinct(field="framework")
@staticmethod
def get_match_conditions(
company: str, project_ids: Sequence[str], filter_: Mapping[str, Any]
):
conditions = {
"company": {"$in": [None, "", company]},
"project": {"$in": project_ids},
}
if not filter_:
return conditions
for field in ("tags", "system_tags"):
field_filter = filter_.get(field)
if not field_filter:
continue
if not isinstance(field_filter, list) or not all(
isinstance(t, str) for t in field_filter
):
raise errors.bad_request.ValidationError(
f"List of strings expected for the field: {field}"
)
conditions[field] = {"$in": field_filter}
return conditions
@classmethod
def calc_own_contents(cls, company: str, project_ids: Sequence[str]) -> Dict[str, dict]:
def calc_own_contents(
cls, company: str, project_ids: Sequence[str], filter_: Mapping[str, Any] = None
) -> Dict[str, dict]:
"""
Returns the amount of task/models per requested project
Use separate aggregation calls on Task/Model instead of lookup
@@ -685,35 +827,21 @@ class ProjectBLL:
pipeline = [
{
"$match": {
"company": {"$in": [None, "", company]},
"project": {"$in": project_ids},
}
"$match": cls.get_match_conditions(
company=company, project_ids=project_ids, filter_=filter_
)
},
{
"$project": {"project": 1}
},
{
"$group": {
"_id": "$project",
"count": {"$sum": 1},
}
}
{"$project": {"project": 1}},
{"$group": {"_id": "$project", "count": {"$sum": 1}}},
]
def get_agrregate_res(cls_: Type[AttributedDocument]) -> dict:
return {
data["_id"]: data["count"]
for data in cls_.aggregate(pipeline)
}
return {data["_id"]: data["count"] for data in cls_.aggregate(pipeline)}
with TimingContext("mongo", "get_security_groups"):
tasks = get_agrregate_res(Task)
models = get_agrregate_res(Model)
return {
pid: {
"own_tasks": tasks.get(pid, 0),
"own_models": models.get(pid, 0),
}
pid: {"own_tasks": tasks.get(pid, 0), "own_models": models.get(pid, 0)}
for pid in project_ids
}

View File

@@ -0,0 +1,370 @@
import json
from collections import OrderedDict
from datetime import datetime
from typing import (
Sequence,
Optional,
Tuple,
)
from redis import StrictRedis
from apiserver.config_repo import config
from apiserver.database.model.model import Model
from apiserver.database.model.task.task import Task
from apiserver.redis_manager import redman
from apiserver.utilities.dicts import nested_get
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
from .sub_projects import _ids_with_children
log = config.logger(__file__)
class ProjectQueries:
def __init__(self, redis=None):
self.redis: StrictRedis = redis or redman.connection("apiserver")
@staticmethod
def _get_project_constraint(
project_ids: Sequence[str], include_subprojects: bool
) -> dict:
"""
If passed projects is None means top level projects
If passed projects is empty means no project filtering
"""
if include_subprojects:
if not project_ids:
return {}
project_ids = _ids_with_children(project_ids)
if project_ids is None:
project_ids = [None]
if not project_ids:
return {}
return {"project": {"$in": project_ids}}
@staticmethod
def _get_company_constraint(company_id: str, allow_public: bool = True) -> dict:
if allow_public:
return {"company": {"$in": [None, "", company_id]}}
return {"company": company_id}
@classmethod
def get_aggregated_project_parameters(
cls,
company_id,
project_ids: Sequence[str],
include_subprojects: bool,
page: int = 0,
page_size: int = 500,
) -> Tuple[int, int, Sequence[dict]]:
page = max(0, page)
page_size = max(1, page_size)
pipeline = [
{
"$match": {
**cls._get_company_constraint(company_id),
**cls._get_project_constraint(project_ids, include_subprojects),
"hyperparams": {"$exists": True, "$gt": {}},
}
},
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
{"$unwind": "$sections"},
{
"$project": {
"section": "$sections.k",
"names": {"$objectToArray": "$sections.v"},
}
},
{"$unwind": "$names"},
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
{"$skip": page * page_size},
{"$limit": page_size},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT"},
}
},
]
result = next(Task.aggregate(pipeline), None)
total = 0
remaining = 0
results = []
if result:
total = int(result.get("total", -1))
results = [
{
"section": ParameterKeyEscaper.unescape(
nested_get(r, ("_id", "section"))
),
"name": ParameterKeyEscaper.unescape(
nested_get(r, ("_id", "name"))
),
}
for r in result.get("results", [])
]
remaining = max(0, total - (len(results) + page * page_size))
return total, remaining, results
ParamValues = Tuple[int, Sequence[str]]
def _get_cached_param_values(
self, key: str, last_update: datetime, allowed_delta_sec=0
) -> Optional[ParamValues]:
try:
cached = self.redis.get(key)
if not cached:
return
data = json.loads(cached)
cached_last_update = datetime.fromtimestamp(data["last_update"])
if (last_update - cached_last_update).total_seconds() <= allowed_delta_sec:
return data["total"], data["values"]
except Exception as ex:
log.error(f"Error retrieving params cached values: {str(ex)}")
def get_task_hyperparam_distinct_values(
self,
company_id: str,
project_ids: Sequence[str],
section: str,
name: str,
include_subprojects: bool,
allow_public: bool = True,
) -> ParamValues:
company_constraint = self._get_company_constraint(company_id, allow_public)
project_constraint = self._get_project_constraint(
project_ids, include_subprojects
)
key_path = f"hyperparams.{ParameterKeyEscaper.escape(section)}.{ParameterKeyEscaper.escape(name)}"
last_updated_task = (
Task.objects(
**company_constraint,
**project_constraint,
**{f"{key_path.replace('.', '__')}__exists": True},
)
.only("last_update")
.order_by("-last_update")
.limit(1)
.first()
)
if not last_updated_task:
return 0, []
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
last_update = last_updated_task.last_update or datetime.utcnow()
cached_res = self._get_cached_param_values(
key=redis_key,
last_update=last_update,
allowed_delta_sec=config.get(
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
),
)
if cached_res:
return cached_res
max_values = config.get("services.tasks.hyperparam_values.max_count", 100)
pipeline = [
{
"$match": {
**company_constraint,
**project_constraint,
key_path: {"$exists": True},
}
},
{"$project": {"value": f"${key_path}.value"}},
{"$group": {"_id": "$value"}},
{"$sort": {"_id": 1}},
{"$limit": max_values},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT._id"},
}
},
]
result = next(Task.aggregate(pipeline, collation=Task._numeric_locale), None)
if not result:
return 0, []
total = int(result.get("total", 0))
values = result.get("results", [])
ttl = config.get("services.tasks.hyperparam_values.cache_ttl_sec", 86400)
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
self.redis.setex(redis_key, ttl, json.dumps(cached))
return total, values
@classmethod
def get_unique_metric_variants(
cls, company_id, project_ids: Sequence[str], include_subprojects: bool
):
pipeline = [
{
"$match": {
**cls._get_company_constraint(company_id),
**cls._get_project_constraint(project_ids, include_subprojects),
}
},
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
{"$unwind": "$metrics"},
{
"$project": {
"metric": "$metrics.k",
"variants": {"$objectToArray": "$metrics.v"},
}
},
{"$unwind": "$variants"},
{
"$group": {
"_id": {
"metric": "$variants.v.metric",
"variant": "$variants.v.variant",
},
"metrics": {
"$addToSet": {
"metric": "$variants.v.metric",
"metric_hash": "$metric",
"variant": "$variants.v.variant",
"variant_hash": "$variants.k",
}
},
}
},
{"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
]
result = Task.aggregate(pipeline)
return [r["metrics"][0] for r in result]
@classmethod
def get_model_metadata_keys(
cls,
company_id,
project_ids: Sequence[str],
include_subprojects: bool,
page: int = 0,
page_size: int = 500,
) -> Tuple[int, int, Sequence[dict]]:
page = max(0, page)
page_size = max(1, page_size)
pipeline = [
{
"$match": {
**cls._get_company_constraint(company_id),
**cls._get_project_constraint(project_ids, include_subprojects),
"metadata": {"$exists": True, "$gt": {}},
}
},
{"$project": {"metadata": {"$objectToArray": "$metadata"}}},
{"$unwind": "$metadata"},
{"$group": {"_id": "$metadata.k"}},
{"$sort": {"_id": 1}},
{"$skip": page * page_size},
{"$limit": page_size},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT"},
}
},
]
result = next(Model.aggregate(pipeline), None)
total = 0
remaining = 0
results = []
if result:
total = int(result.get("total", -1))
results = [
ParameterKeyEscaper.unescape(r.get("_id"))
for r in result.get("results", [])
]
remaining = max(0, total - (len(results) + page * page_size))
return total, remaining, results
def get_model_metadata_distinct_values(
self,
company_id: str,
project_ids: Sequence[str],
key: str,
include_subprojects: bool,
allow_public: bool = True,
) -> ParamValues:
company_constraint = self._get_company_constraint(company_id, allow_public)
project_constraint = self._get_project_constraint(
project_ids, include_subprojects
)
key_path = f"metadata.{ParameterKeyEscaper.escape(key)}"
last_updated_model = (
Model.objects(
**company_constraint,
**project_constraint,
**{f"{key_path.replace('.', '__')}__exists": True},
)
.only("last_update")
.order_by("-last_update")
.limit(1)
.first()
)
if not last_updated_model:
return 0, []
redis_key = f"modelmetadata_values_{company_id}_{'_'.join(project_ids)}_{key}_{allow_public}"
last_update = last_updated_model.last_update or datetime.utcnow()
cached_res = self._get_cached_param_values(
key=redis_key, last_update=last_update
)
if cached_res:
return cached_res
max_values = config.get("services.models.metadata_values.max_count", 100)
pipeline = [
{
"$match": {
**company_constraint,
**project_constraint,
key_path: {"$exists": True},
}
},
{"$project": {"value": f"${key_path}.value"}},
{"$group": {"_id": "$value"}},
{"$sort": {"_id": 1}},
{"$limit": max_values},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT._id"},
}
},
]
result = next(Model.aggregate(pipeline, collation=Model._numeric_locale), None)
if not result:
return 0, []
total = int(result.get("total", 0))
values = result.get("results", [])
ttl = config.get("services.models.metadata_values.cache_ttl_sec", 86400)
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
self.redis.setex(redis_key, ttl, json.dumps(cached))
return total, values

View File

@@ -25,7 +25,9 @@ def _validate_project_name(project_name: str) -> Tuple[str, str]:
return name_separator.join(name_parts), name_separator.join(name_parts[:-1])
def _ensure_project(company: str, user: str, name: str) -> Optional[Project]:
def _ensure_project(
company: str, user: str, name: str, creation_params: dict = None
) -> Optional[Project]:
"""
Makes sure that the project with the given name exists
If needed auto-create the project and all the missing projects in the path to it
@@ -48,9 +50,9 @@ def _ensure_project(company: str, user: str, name: str) -> Optional[Project]:
created=now,
last_update=now,
name=name,
description="",
**(creation_params or dict(description="")),
)
parent = _ensure_project(company, user, location)
parent = _ensure_project(company, user, location, creation_params=creation_params)
_save_under_parent(project=project, parent=parent)
if parent:
parent.update(last_update=now)

View File

@@ -32,7 +32,7 @@ class QueueBLL(object):
name: str,
tags: Optional[Sequence[str]] = None,
system_tags: Optional[Sequence[str]] = None,
metadata: Optional[Sequence[dict]] = None,
metadata: Optional[dict] = None,
) -> Queue:
"""Creates a queue"""
with translate_errors_context():
@@ -126,14 +126,27 @@ class QueueBLL(object):
)
queue.delete()
def get_all(self, company_id: str, query_dict: dict) -> Sequence[dict]:
def get_all(
self,
company_id: str,
query_dict: dict,
ret_params: dict = None,
) -> Sequence[dict]:
"""Get all the queues according to the query"""
with translate_errors_context():
return Queue.get_many(
company=company_id, parameters=query_dict, query_dict=query_dict
company=company_id,
parameters=query_dict,
query_dict=query_dict,
ret_params=ret_params,
)
def get_queue_infos(self, company_id: str, query_dict: dict) -> Sequence[dict]:
def get_queue_infos(
self,
company_id: str,
query_dict: dict,
ret_params: dict = None,
) -> Sequence[dict]:
"""
Get infos on all the company queues, including queue tasks and workers
"""
@@ -143,6 +156,7 @@ class QueueBLL(object):
company=company_id,
query_dict=query_dict,
override_projection=projection,
ret_params=ret_params,
)
queue_workers = defaultdict(list)
@@ -173,13 +187,15 @@ class QueueBLL(object):
if any(e.task == task_id for e in queue.entries):
raise errors.bad_request.TaskAlreadyQueued(task=task_id)
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
entry = Entry(added=datetime.utcnow(), task=task_id)
query = dict(id=queue_id, company=company_id)
res = Queue.objects(entries__task__ne=task_id, **query).update_one(
push__entries=entry, last_update=datetime.utcnow(), upsert=False
)
queue.reload()
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
if not res:
raise errors.bad_request.InvalidQueueOrTaskNotQueued(
task=task_id, **query
@@ -219,7 +235,6 @@ class QueueBLL(object):
queue = self.get_queue_with_task(
company_id=company_id, queue_id=queue_id, task_id=task_id
)
self.metrics.log_queue_metrics_to_es(company_id, queues=[queue])
entries_to_remove = [e for e in queue.entries if e.task == task_id]
query = dict(id=queue_id, company=company_id)
@@ -227,6 +242,9 @@ class QueueBLL(object):
pull_all__entries=entries_to_remove, last_update=datetime.utcnow()
)
queue.reload()
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
return len(entries_to_remove) if res else 0
def reposition_task(

View File

@@ -49,6 +49,21 @@ class RedisCacheManager(Generic[T]):
def _get_redis_key(self, state_id):
return f"{self.state_class}/{state_id}"
def get_or_create_state_core(
self,
state_id=None,
init_state: Callable[[T], None] = _do_nothing,
validate_state: Callable[[T], None] = _do_nothing,
) -> T:
state = self.get_state(state_id) if state_id else None
if state:
validate_state(state)
else:
state = self.state_class(id=database.utils.id())
init_state(state)
return state
@contextmanager
def get_or_create_state(
self,
@@ -66,12 +81,9 @@ class RedisCacheManager(Generic[T]):
:param validate_state: user callback to validate the state if retrieved from cache
Should throw an exception if the state is not valid. If not passed then no validation is done
"""
state = self.get_state(state_id) if state_id else None
if state:
validate_state(state)
else:
state = self.state_class(id=database.utils.id())
init_state(state)
state = self.get_or_create_state_core(
state_id=state_id, init_state=init_state, validate_state=validate_state
)
try:
yield state

View File

@@ -1,9 +1,6 @@
import json
from collections import OrderedDict
from datetime import datetime, timedelta
from datetime import datetime
from typing import Collection, Sequence, Tuple, Any, Optional, Dict
import dpath
import six
from mongoengine import Q
from redis import StrictRedis
@@ -14,7 +11,7 @@ from apiserver.apierrors import errors
from apiserver.apimodels.tasks import TaskInputModel
from apiserver.bll.queue import QueueBLL
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL, project_ids_with_children
from apiserver.bll.project import ProjectBLL
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.model import Model
@@ -39,7 +36,6 @@ from apiserver.redis_manager import redman
from apiserver.service_repo import APICall
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
from apiserver.timing_context import TimingContext
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
from .artifacts import artifacts_prepare_for_save
from .param_utils import params_prepare_for_save
from .utils import (
@@ -350,54 +346,6 @@ class TaskBLL:
if validate_models:
cls.validate_input_models(task)
@staticmethod
def get_unique_metric_variants(
company_id, project_ids: Sequence[str], include_subprojects: bool
):
if project_ids:
if include_subprojects:
project_ids = project_ids_with_children(project_ids)
project_constraint = {"project": {"$in": project_ids}}
else:
project_constraint = {}
pipeline = [
{
"$match": dict(
company={"$in": [None, "", company_id]}, **project_constraint,
)
},
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
{"$unwind": "$metrics"},
{
"$project": {
"metric": "$metrics.k",
"variants": {"$objectToArray": "$metrics.v"},
}
},
{"$unwind": "$variants"},
{
"$group": {
"_id": {
"metric": "$variants.v.metric",
"variant": "$variants.v.variant",
},
"metrics": {
"$addToSet": {
"metric": "$variants.v.metric",
"metric_hash": "$metric",
"variant": "$variants.v.variant",
"variant_hash": "$variants.k",
}
},
}
},
{"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
]
with translate_errors_context():
result = Task.aggregate(pipeline)
return [r["metrics"][0] for r in result]
@staticmethod
def set_last_update(
task_ids: Collection[str],
@@ -494,173 +442,6 @@ class TaskBLL:
**extra_updates,
)
@staticmethod
def get_aggregated_project_parameters(
company_id,
project_ids: Sequence[str],
include_subprojects: bool,
page: int = 0,
page_size: int = 500,
) -> Tuple[int, int, Sequence[dict]]:
if project_ids:
if include_subprojects:
project_ids = project_ids_with_children(project_ids)
project_constraint = {"project": {"$in": project_ids}}
else:
project_constraint = {}
page = max(0, page)
page_size = max(1, page_size)
pipeline = [
{
"$match": {
"company": {"$in": [None, "", company_id]},
"hyperparams": {"$exists": True, "$gt": {}},
**project_constraint,
}
},
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
{"$unwind": "$sections"},
{
"$project": {
"section": "$sections.k",
"names": {"$objectToArray": "$sections.v"},
}
},
{"$unwind": "$names"},
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
{"$skip": page * page_size},
{"$limit": page_size},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT"},
}
},
]
result = next(Task.aggregate(pipeline), None)
total = 0
remaining = 0
results = []
if result:
total = int(result.get("total", -1))
results = [
{
"section": ParameterKeyEscaper.unescape(
dpath.get(r, "_id/section")
),
"name": ParameterKeyEscaper.unescape(dpath.get(r, "_id/name")),
}
for r in result.get("results", [])
]
remaining = max(0, total - (len(results) + page * page_size))
return total, remaining, results
HyperParamValues = Tuple[int, Sequence[str]]
def _get_cached_hyperparam_values(
self, key: str, last_update: datetime
) -> Optional[HyperParamValues]:
allowed_delta = timedelta(
seconds=config.get(
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
)
)
try:
cached = self.redis.get(key)
if not cached:
return
data = json.loads(cached)
cached_last_update = datetime.fromtimestamp(data["last_update"])
if (last_update - cached_last_update) < allowed_delta:
return data["total"], data["values"]
except Exception as ex:
log.error(f"Error retrieving hyperparam cached values: {str(ex)}")
def get_hyperparam_distinct_values(
self,
company_id: str,
project_ids: Sequence[str],
section: str,
name: str,
include_subprojects: bool,
allow_public: bool = True,
) -> HyperParamValues:
if allow_public:
company_constraint = {"company": {"$in": [None, "", company_id]}}
else:
company_constraint = {"company": company_id}
if project_ids:
if include_subprojects:
project_ids = project_ids_with_children(project_ids)
project_constraint = {"project": {"$in": project_ids}}
else:
project_constraint = {}
key_path = f"hyperparams.{ParameterKeyEscaper.escape(section)}.{ParameterKeyEscaper.escape(name)}"
last_updated_task = (
Task.objects(
**company_constraint,
**project_constraint,
**{f"{key_path.replace('.', '__')}__exists": True},
)
.only("last_update")
.order_by("-last_update")
.limit(1)
.first()
)
if not last_updated_task:
return 0, []
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
last_update = last_updated_task.last_update or datetime.utcnow()
cached_res = self._get_cached_hyperparam_values(
key=redis_key, last_update=last_update
)
if cached_res:
return cached_res
max_values = config.get("services.tasks.hyperparam_values.max_count", 100)
pipeline = [
{
"$match": {
**company_constraint,
**project_constraint,
key_path: {"$exists": True},
}
},
{"$project": {"value": f"${key_path}.value"}},
{"$group": {"_id": "$value"}},
{"$sort": {"_id": 1}},
{"$limit": max_values},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT._id"},
}
},
]
result = next(Task.aggregate(pipeline, collation=Task._numeric_locale), None)
if not result:
return 0, []
total = int(result.get("total", 0))
values = result.get("results", [])
ttl = config.get("services.tasks.hyperparam_values.cache_ttl_sec", 86400)
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
self.redis.setex(redis_key, ttl, json.dumps(cached))
return total, values
@classmethod
def dequeue_and_change_status(
cls, task: Task, company_id: str, status_message: str, status_reason: str,

View File

@@ -162,6 +162,8 @@ def delete_task(
force: bool,
return_file_urls: bool,
delete_output_models: bool,
status_message: str,
status_reason: str,
) -> Tuple[int, Task, CleanupResult]:
task = TaskBLL.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
@@ -179,6 +181,17 @@ def delete_task(
current=task.status,
)
try:
TaskBLL.dequeue_and_change_status(
task,
company_id=company_id,
status_message=status_message,
status_reason=status_reason,
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
cleanup_res = cleanup_task(
task,
force=force,
@@ -354,6 +367,7 @@ def stop_task(
"system_tags",
"last_worker",
"last_update",
"execution.queue",
),
requires_write_access=True,
)

View File

@@ -113,7 +113,7 @@ class WorkerBLL:
res = self.redis.delete(
company_id, self._get_worker_key(company_id, user_id, worker)
)
if not res:
if not res and not config.get("apiserver.workers.auto_unregister", False):
raise bad_request.WorkerNotRegistered(worker=worker)
def status_report(
@@ -258,7 +258,7 @@ class WorkerBLL:
tasks_info = {
task.id: task
for task in Task.objects(id__in=task_ids).only(
"name", "started", "last_iteration"
"name", "started", "last_iteration", "active_duration"
)
}
@@ -283,11 +283,7 @@ class WorkerBLL:
if helper.task_id:
task = tasks_info.get(helper.task_id, None)
if task:
worker.task.running_time = (
int((datetime.utcnow() - task.started).total_seconds() * 1000)
if task.started
else 0
)
worker.task.running_time = (task.active_duration or 0) * 1000
worker.task.last_iteration = task.last_iteration
update_queue_entries(worker.queue)

View File

@@ -79,6 +79,8 @@ class BasicConfig:
def logger(self, name: str) -> logging.Logger:
if Path(name).is_file():
name = Path(name).stem
if name == "__init__" and Path(name).parent.stem:
name = Path(name).parent.stem
path = ".".join((self.prefix, name))
return logging.getLogger(path)

View File

@@ -79,6 +79,11 @@
max_age: 99999999999
}
# provide a cookie domain override per company
# cookies_domain_override {
# <company-id>: <domain>
# }
# # A list of fixed users
# # Note: password may be bcrypt-hashed (generate using `python -c 'import bcrypt; print(bcrypt.hashpw("password", bcrypt.gensalt()))'`)
# fixed_users {
@@ -107,6 +112,8 @@
workers {
# Auto-register unknown workers on status reports and other calls
auto_register: true
# Assume unknow workers have unregistered (i.e. do not raise unregistered error)
auto_unregister: true
# Timeout in seconds on task status update. If exceeded
# then task can be stopped without communicating to the worker
task_update_timeout: 600

View File

@@ -0,0 +1,4 @@
max_page_size: 500
# expiration time in seconds for the redis scroll states in get_many family of apis
scroll_state_expiration_seconds: 600

View File

@@ -17,6 +17,10 @@ events_retrieval {
# the max amount of variants to aggregate on
max_variants_count: 100
max_raw_scalars_size: 200000
scroll_id_key: "cTN5VEtWEC6QrHvUl0FTx9kNyO0CcCK1p57akxma"
}
# if set then plot str will be checked for the valid json on plot add

View File

@@ -0,0 +1,7 @@
metadata_values {
# maximal amount of distinct model values to retrieve
max_count: 100
# cache ttl sec
cache_ttl_sec: 86400
}

View File

@@ -28,6 +28,8 @@ OVERRIDE_PORT_ENV_KEY = (
"MONGODB_SERVICE_PORT",
)
OVERRIDE_CONNECTION_STRING_ENV_KEY = "CLEARML_MONGODB_SERVICE_CONNECTION_STRING"
class DatabaseEntry(models.Base):
host = StringField(required=True)
@@ -47,13 +49,17 @@ class DatabaseFactory:
missing = []
log.info("Initializing database connections")
override_connection_string = getenv(OVERRIDE_CONNECTION_STRING_ENV_KEY)
override_hostname = first(map(getenv, OVERRIDE_HOST_ENV_KEY), None)
if override_hostname:
log.info(f"Using override mongodb host {override_hostname}")
override_port = first(map(getenv, OVERRIDE_PORT_ENV_KEY), None)
if override_port:
log.info(f"Using override mongodb port {override_port}")
if override_connection_string:
log.info(f"Using override mongodb connection string {override_connection_string}")
else:
if override_hostname:
log.info(f"Using override mongodb host {override_hostname}")
if override_port:
log.info(f"Using override mongodb port {override_port}")
for key, alias in get_items(Database).items():
if key not in db_entries:
@@ -62,11 +68,13 @@ class DatabaseFactory:
entry = cls._create_db_entry(alias=alias, settings=db_entries.get(key))
if override_hostname:
entry.host = furl(entry.host).set(host=override_hostname).url
if override_port:
entry.host = furl(entry.host).set(port=override_port).url
if override_connection_string:
entry.host = override_connection_string
else:
if override_hostname:
entry.host = furl(entry.host).set(host=override_hostname).url
if override_port:
entry.host = furl(entry.host).set(port=override_port).url
try:
entry.validate()

View File

@@ -48,7 +48,9 @@ class Credentials(EmbeddedDocument):
meta = {"strict": False}
key = StringField(required=True)
secret = StringField(required=True)
label = StringField()
last_used = DateTimeField()
last_used_from = StringField()
class User(DbModelMixin, AuthDocument):

View File

@@ -1,26 +1,41 @@
import re
from collections import namedtuple
from functools import reduce
from typing import Collection, Sequence, Union, Optional, Type, Tuple, Mapping, Any
from functools import reduce, partial
from typing import (
Collection,
Sequence,
Union,
Optional,
Type,
Tuple,
Mapping,
Any,
Callable,
Dict,
List,
)
from boltons.iterutils import first, bucketize, partition
from boltons.iterutils import first, partition
from dateutil.parser import parse as parse_datetime
from mongoengine import Q, Document, ListField, StringField
from mongoengine import Q, Document, ListField, StringField, IntField
from pymongo.command_cursor import CommandCursor
from apiserver.apierrors import errors
from apiserver.apierrors.base import BaseError
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.config_repo import config
from apiserver.database import Database
from apiserver.database.errors import MakeGetAllQueryError
from apiserver.database.projection import project_dict, ProjectionHelper
from apiserver.database.props import PropsMixin
from apiserver.database.query import RegexQ, RegexWrapper
from apiserver.database.query import RegexQ, RegexWrapper, RegexQCombination
from apiserver.database.utils import (
get_company_or_none_constraint,
get_fields_choices,
field_does_not_exist,
field_exists,
)
from apiserver.redis_manager import redman
log = config.logger("dbmodel")
@@ -70,6 +85,9 @@ class GetMixin(PropsMixin):
_ordering_key = "order_by"
_search_text_key = "search_text"
_start_key = "start"
_size_key = "size"
_multi_field_param_sep = "__"
_multi_field_param_prefix = {
("_any_", "_or_"): lambda a, b: a | b,
@@ -77,6 +95,7 @@ class GetMixin(PropsMixin):
}
MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
_field_collation_overrides = {}
class QueryParameterOptions(object):
@@ -103,46 +122,106 @@ class GetMixin(PropsMixin):
class ListFieldBucketHelper:
op_prefix = "__$"
legacy_exclude_prefix = "-"
_legacy_exclude_prefix = "-"
_legacy_exclude_mongo_op = "nin"
_default = "in"
default_mongo_op = "in"
_ops = {
# op -> (mongo_op, sticky)
"not": ("nin", False),
"nop": (default_mongo_op, False),
"all": ("all", True),
"and": ("all", True),
"any": (default_mongo_op, True),
"or": (default_mongo_op, True),
}
_next = _default
_sticky = False
def __init__(self, legacy=False):
self._legacy = legacy
self._current_op = None
self._sticky = False
self._support_legacy = legacy
self.allow_empty = False
def key(self, v) -> Optional[str]:
def _get_op(self, v: str, translate: bool = False) -> Optional[str]:
op = (
v[len(self.op_prefix) :] if v and v.startswith(self.op_prefix) else None
)
if translate:
tup = self._ops.get(op, None)
return tup[0] if tup else None
return op
def _key(self, v) -> Optional[Union[str, bool]]:
if v is None:
self._next = self._default
return self._default
elif self._legacy and v.startswith(self.legacy_exclude_prefix):
self._next = self._default
return self._ops["not"][0]
elif v.startswith(self.op_prefix):
self._next, self._sticky = self._ops.get(
v[len(self.op_prefix) :], (self._default, self._sticky)
)
self.allow_empty = True
return None
next_ = self._next
if not self._sticky:
self._next = self._default
op = self._get_op(v)
if op is not None:
# operator - set state and return None
self._current_op, self._sticky = self._ops.get(
op, (self.default_mongo_op, self._sticky)
)
return None
elif self._current_op:
current_op = self._current_op
if not self._sticky:
self._current_op = None
return current_op
elif self._support_legacy and v.startswith(self._legacy_exclude_prefix):
self._current_op = None
return False
return next_
return self.default_mongo_op
def value_transform(self, v):
if self._legacy and v and v.startswith(self.legacy_exclude_prefix):
return v[len(self.legacy_exclude_prefix) :]
return v
def get_global_op(self, data: Sequence[str]) -> int:
op_to_res = {
"in": Q.OR,
"all": Q.AND,
}
data = (x for x in data if x is not None)
first_op = (
self._get_op(next(data, ""), translate=True) or self.default_mongo_op
)
return op_to_res.get(first_op, self.default_mongo_op)
def get_actions(self, data: Sequence[str]) -> Dict[str, List[Union[str, None]]]:
actions = {}
for val in data:
key = self._key(val)
if key is None:
continue
elif self._support_legacy and key is False:
key = self._legacy_exclude_mongo_op
val = val[len(self._legacy_exclude_prefix) :]
actions.setdefault(key, []).append(val)
return actions
get_all_query_options = QueryParameterOptions()
class GetManyScrollState(ProperDictMixin, Document):
meta = {"db_alias": Database.backend, "strict": False}
id = StringField(primary_key=True)
position = IntField(default=0)
_cache_manager = None
@classmethod
def get_cache_manager(cls):
if not cls._cache_manager:
cls._cache_manager = RedisCacheManager(
state_class=cls.GetManyScrollState,
redis=redman.connection("apiserver"),
expiration_interval=config.get(
"services._mongo.scroll_state_expiration_seconds", 600
),
)
return cls._cache_manager
@classmethod
def get(
cls: Union["GetMixin", Document],
@@ -241,7 +320,9 @@ class GetMixin(PropsMixin):
Prepare a query object based on the provided query dictionary and various fields.
NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows creating queries that span across companies.
IMPLEMENTATION NOTE: Make sure that inside this function or the functions it depends on RegexQ is always
used instead of Q. Otherwise we can and up with some combination that is not processed according to
RegexQ rules
:param parameters_options: Specifies options for parsing the parameters (see ParametersOptions)
:param parameters: Query dictionary (relevant keys are these specified by the various field names parameters).
Supported parameters:
@@ -278,7 +359,7 @@ class GetMixin(PropsMixin):
patterns=opts.fields or [], parameters=parameters
).items():
if "._" in field or "_." in field:
query &= Q(__raw__={field: data})
query &= RegexQ(__raw__={field: data})
else:
dict_query[field.replace(".", "__")] = data
@@ -312,22 +393,31 @@ class GetMixin(PropsMixin):
break
if any("._" in f for f in data.fields):
q = reduce(
lambda a, x: func(a, Q(__raw__={x: {"$regex": data.pattern, "$options": "i"}})),
lambda a, x: func(
a,
RegexQ(
__raw__={
x: {"$regex": data.pattern, "$options": "i"}
}
),
),
data.fields,
Q()
RegexQ(),
)
else:
regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
sep_fields = [f.replace(".", "__") for f in data.fields]
q = reduce(
lambda a, x: func(a, RegexQ(**{x: regex})), sep_fields, RegexQ()
lambda a, x: func(a, RegexQ(**{x: regex})),
sep_fields,
RegexQ(),
)
query = query & q
return query & RegexQ(**dict_query)
@classmethod
def get_range_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
def get_range_field_query(cls, field: str, data: Sequence[Optional[str]]) -> RegexQ:
"""
Return a range query for the provided field. The data should contain min and max values
Both intervals are included. For open range queries either min or max can be None
@@ -351,14 +441,14 @@ class GetMixin(PropsMixin):
if max_val is not None:
query[f"{mongoengine_field}__lte"] = max_val
q = Q(**query)
q = RegexQ(**query)
if min_val is None:
q |= Q(**{mongoengine_field: None})
q |= RegexQ(**{mongoengine_field: None})
return q
@classmethod
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> RegexQ:
"""
Get a proper mongoengine Q object that represents an "or" query for the provided values
with respect to the given list field, with support for "none of empty" in case a None value
@@ -370,30 +460,31 @@ class GetMixin(PropsMixin):
"""
if not isinstance(data, (list, tuple)):
data = [data]
# raise MakeGetAllQueryError("expected list", field)
# TODO: backwards compatibility only for older API versions
helper = cls.ListFieldBucketHelper(legacy=True)
actions = bucketize(
data, key=helper.key, value_transform=helper.value_transform
)
global_op = helper.get_global_op(data)
actions = helper.get_actions(data)
allow_empty = None in actions.get("in", {})
mongoengine_field = field.replace(".", "__")
q = RegexQ()
for action in filter(None, actions):
q &= RegexQ(
**{f"{mongoengine_field}__{action}": list(set(actions[action]))}
)
queries = [
RegexQ(**{f"{mongoengine_field}__{action}": list(set(actions[action]))})
for action in filter(None, actions)
]
if not allow_empty:
if not queries:
q = RegexQ()
else:
q = RegexQCombination(operation=global_op, children=queries)
if not helper.allow_empty:
return q
return (
q
| Q(**{f"{mongoengine_field}__exists": False})
| Q(**{mongoengine_field: []})
| RegexQ(**{f"{mongoengine_field}__exists": False})
| RegexQ(**{mongoengine_field: []})
| RegexQ(**{mongoengine_field: None})
)
@classmethod
@@ -421,27 +512,41 @@ class GetMixin(PropsMixin):
return order_by
@classmethod
def validate_paging(
cls, parameters=None, default_page=None, default_page_size=None
):
""" Validate and extract paging info from from the provided dictionary. Supports default values. """
if parameters is None:
parameters = {}
default_page = parameters.get("page", default_page)
if default_page is None:
return None, None
default_page_size = parameters.get("page_size", default_page_size)
if not default_page_size:
raise errors.bad_request.MissingRequiredFields(
"page_size is required when page is requested", field="page_size"
)
elif default_page < 0:
def validate_paging(cls, parameters=None, default_page=0, default_page_size=None):
"""
Validate and extract paging info from from the provided dictionary. Supports default values.
If page is specified then it should be non-negative, if page size is specified then it should be positive
If page size is specified and page is not then 0 page is assumed
If page is specified then page size should be specified too
"""
parameters = parameters or {}
start = parameters.get(cls._start_key)
if start is not None:
return start, cls.validate_scroll_size(parameters)
max_page_size = config.get("services._mongo.max_page_size", 500)
page = parameters.get("page", default_page)
if page is not None and page < 0:
raise errors.bad_request.ValidationError("page must be >=0", field="page")
elif default_page_size < 1:
page_size = parameters.get("page_size", default_page_size or max_page_size)
if page_size is not None and page_size < 1:
raise errors.bad_request.ValidationError(
"page_size must be >0", field="page_size"
)
return default_page, default_page_size
if page_size is not None:
page = page or 0
page_size = min(page_size, max_page_size)
return page * page_size, page_size
if page is not None:
raise errors.bad_request.MissingRequiredFields(
"page_size is required when page is requested", field="page_size"
)
return None, None
@classmethod
def get_projection(cls, parameters, override_projection=None, **__):
@@ -485,6 +590,54 @@ class GetMixin(PropsMixin):
def set_default_ordering(cls, parameters: dict, value: Sequence[str]) -> None:
cls.set_ordering(parameters, cls.get_ordering(parameters) or value)
@classmethod
def validate_scroll_size(cls, query_dict: dict) -> int:
size = query_dict.get(cls._size_key)
if not size or not isinstance(size, int) or size < 1:
raise errors.bad_request.ValidationError(
"Integer size parameter greater than 1 should be provided when working with scroll"
)
return size
@classmethod
def get_data_with_scroll_support(
cls,
query_dict: dict,
data_getter: Callable[[], Sequence[dict]],
ret_params: dict,
) -> Sequence[dict]:
"""
Retrieves the data by calling the provided data_getter api
If scroll parameters are specified then put the query_dict 'start' parameter to the last
scroll position and continue retrievals from that position
If refresh_scroll is requested then bring once more the data from the beginning
till the current scroll position
In the end the scroll position is updated and accumulated frames are returned
"""
query_dict = query_dict or {}
state: Optional[cls.GetManyScrollState] = None
if "scroll_id" in query_dict:
size = cls.validate_scroll_size(query_dict)
state = cls.get_cache_manager().get_or_create_state_core(
query_dict.get("scroll_id")
)
if query_dict.get("refresh_scroll"):
query_dict[cls._size_key] = max(state.position, size)
state.position = 0
query_dict[cls._start_key] = state.position
data = data_getter()
if cls._start_key in query_dict:
query_dict[cls._start_key] = query_dict[cls._start_key] + len(data)
if state:
state.position = query_dict[cls._start_key]
cls.get_cache_manager().set_state(state)
if ret_params is not None:
ret_params["scroll_id"] = state.id
return data
@classmethod
def get_many_with_join(
cls,
@@ -495,6 +648,7 @@ class GetMixin(PropsMixin):
allow_public=False,
override_projection=None,
expand_reference_ids=True,
ret_params: dict = None,
):
"""
Fetch all documents matching a provided query with support for joining referenced documents according to the
@@ -530,6 +684,7 @@ class GetMixin(PropsMixin):
query=query,
query_options=query_options,
allow_public=allow_public,
ret_params=ret_params,
)
def projection_func(doc_type, projection, ids):
@@ -560,6 +715,7 @@ class GetMixin(PropsMixin):
allow_public=False,
override_projection: Collection[str] = None,
return_dicts=True,
ret_params: dict = None,
):
"""
Fetch all documents matching a provided query. Supported several built-in options
@@ -605,12 +761,16 @@ class GetMixin(PropsMixin):
_query = (q & query) if query else q
if return_dicts:
return cls._get_many_override_none_ordering(
data_getter = partial(
cls._get_many_override_none_ordering,
query=_query,
parameters=parameters,
override_projection=override_projection,
override_collation=override_collation,
)
return cls.get_data_with_scroll_support(
query_dict=query_dict, data_getter=data_getter, ret_params=ret_params,
)
return cls._get_many_no_company(
query=_query,
@@ -662,7 +822,7 @@ class GetMixin(PropsMixin):
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
if order_by and not override_collation:
override_collation = cls._get_collation_override(order_by[0])
page, page_size = cls.validate_paging(parameters=parameters)
start, size = cls.validate_paging(parameters=parameters)
include, exclude = cls.split_projection(
cls.get_projection(parameters, override_projection)
)
@@ -683,9 +843,9 @@ class GetMixin(PropsMixin):
if exclude:
qs = qs.exclude(*exclude)
if page is not None and page_size:
if start is not None and size:
# add paging
qs = qs.skip(page * page_size).limit(page_size)
qs = qs.skip(start).limit(size)
return qs
@@ -746,7 +906,10 @@ class GetMixin(PropsMixin):
parameters = parameters or {}
search_text = parameters.get(cls._search_text_key)
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
page, page_size = cls.validate_paging(parameters=parameters)
start, size = cls.validate_paging(parameters=parameters)
if size is not None and size <= 0:
return []
include, exclude = cls.split_projection(
cls.get_projection(parameters, override_projection)
)
@@ -778,25 +941,28 @@ class GetMixin(PropsMixin):
if exclude:
query_sets = [qs.exclude(*exclude) for qs in query_sets]
if page is None or not page_size:
if start is None or not size:
return [obj.to_proper_dict(only=include) for qs in query_sets for obj in qs]
# add paging
ret = []
start = page * page_size
for qs in query_sets:
qs_size = qs.count()
if qs_size < start:
start -= qs_size
continue
last_set = len(query_sets) - 1
for i, qs in enumerate(query_sets):
last_size = len(ret)
ret.extend(
obj.to_proper_dict(only=include)
for obj in qs.skip(start).limit(page_size)
for obj in (qs.skip(start) if start else qs).limit(size)
)
if len(ret) >= page_size:
added = len(ret) - last_size
if added > 0:
start = 0
size = max(0, size - added)
elif i != last_set:
start -= min(start, qs.count())
if size <= 0:
break
start = 0
page_size -= len(ret)
return ret

View File

@@ -1,11 +1,8 @@
from typing import Sequence
from mongoengine import (
Document,
StringField,
DateTimeField,
BooleanField,
EmbeddedDocumentListField,
EmbeddedDocumentField,
)
from apiserver.database import Database, strict
@@ -13,18 +10,21 @@ from apiserver.database.fields import (
StrippedStringField,
SafeDictField,
SafeSortedListField,
SafeMapField,
)
from apiserver.database.model import DbModelMixin
from apiserver.database.model import AttributedDocument
from apiserver.database.model.base import GetMixin
from apiserver.database.model.metadata import MetadataItem
from apiserver.database.model.model_labels import ModelLabels
from apiserver.database.model.company import Company
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task
from apiserver.database.model.user import User
class Model(DbModelMixin, Document):
class Model(AttributedDocument):
_field_collation_overrides = {
"metadata.": AttributedDocument._numeric_locale,
}
meta = {
"db_alias": Database.backend,
"strict": strict,
@@ -33,8 +33,6 @@ class Model(DbModelMixin, Document):
"project",
"task",
"last_update",
"metadata.key",
"metadata.type",
("company", "framework"),
("company", "name"),
("company", "user"),
@@ -66,6 +64,7 @@ class Model(DbModelMixin, Document):
"project",
"task",
"parent",
"metadata.*",
),
datetime_fields=("last_update",),
)
@@ -73,8 +72,6 @@ class Model(DbModelMixin, Document):
id = StringField(primary_key=True)
name = StrippedStringField(user_set_allowed=True, min_length=3)
parent = StringField(reference_field="Model", required=False)
user = StringField(required=True, reference_field=User)
company = StringField(required=True, reference_field=Company)
project = StringField(reference_field=Project, user_set_allowed=True)
created = DateTimeField(required=True, user_set_allowed=True)
task = StringField(reference_field=Task)
@@ -91,6 +88,6 @@ class Model(DbModelMixin, Document):
default=dict, user_set_allowed=True, exclude_by_default=True
)
company_origin = StringField(exclude_by_default=True)
metadata: Sequence[MetadataItem] = EmbeddedDocumentListField(
MetadataItem, default=list, user_set_allowed=True
metadata = SafeMapField(
field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
)

View File

@@ -11,6 +11,7 @@ class Project(AttributedDocument):
get_all_query_options = GetMixin.QueryParameterOptions(
pattern_fields=("name", "description"),
list_fields=("tags", "system_tags", "id", "parent", "path"),
range_fields=("last_update",),
)
meta = {

View File

@@ -1,16 +1,19 @@
from typing import Sequence
from mongoengine import (
Document,
EmbeddedDocument,
StringField,
DateTimeField,
EmbeddedDocumentListField,
EmbeddedDocumentField,
)
from apiserver.database import Database, strict
from apiserver.database.fields import StrippedStringField, SafeSortedListField
from apiserver.database.model import DbModelMixin
from apiserver.database.fields import (
StrippedStringField,
SafeSortedListField,
SafeMapField,
)
from apiserver.database.model import DbModelMixin, AttributedDocument
from apiserver.database.model.base import ProperDictMixin, GetMixin
from apiserver.database.model.company import Company
from apiserver.database.model.metadata import MetadataItem
@@ -19,23 +22,25 @@ from apiserver.database.model.task.task import Task
class Entry(EmbeddedDocument, ProperDictMixin):
""" Entry representing a task waiting in the queue """
task = StringField(required=True, reference_field=Task)
''' Task ID '''
""" Task ID """
added = DateTimeField(required=True)
''' Added to the queue '''
""" Added to the queue """
class Queue(DbModelMixin, Document):
_field_collation_overrides = {
"metadata.": AttributedDocument._numeric_locale,
}
get_all_query_options = GetMixin.QueryParameterOptions(
pattern_fields=("name",),
list_fields=("tags", "system_tags", "id"),
pattern_fields=("name",), list_fields=("tags", "system_tags", "id", "metadata.*"),
)
meta = {
'db_alias': Database.backend,
'strict': strict,
"indexes": ["metadata.key", "metadata.type"],
"db_alias": Database.backend,
"strict": strict,
}
id = StringField(primary_key=True)
@@ -44,10 +49,12 @@ class Queue(DbModelMixin, Document):
)
company = StringField(required=True, reference_field=Company)
created = DateTimeField(required=True)
tags = SafeSortedListField(StringField(required=True), default=list, user_set_allowed=True)
tags = SafeSortedListField(
StringField(required=True), default=list, user_set_allowed=True
)
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
entries = EmbeddedDocumentListField(Entry, default=list)
last_update = DateTimeField()
metadata: Sequence[MetadataItem] = EmbeddedDocumentListField(
MetadataItem, default=list, user_set_allowed=True
metadata = SafeMapField(
field=EmbeddedDocumentField(MetadataItem), user_set_allowed=True
)

View File

@@ -159,11 +159,10 @@ external_task_types = set(get_options(TaskType))
class Task(AttributedDocument):
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
_field_collation_overrides = {
"execution.parameters.": _numeric_locale,
"last_metrics.": _numeric_locale,
"hyperparams.": _numeric_locale,
"execution.parameters.": AttributedDocument._numeric_locale,
"last_metrics.": AttributedDocument._numeric_locale,
"hyperparams.": AttributedDocument._numeric_locale,
}
meta = {
@@ -184,7 +183,10 @@ class Task(AttributedDocument):
("company", "type", "system_tags", "status"),
("company", "project", "type", "system_tags", "status"),
("status", "last_update"), # for maintenance tasks
{"fields": ["company", "project"], "collation": _numeric_locale},
{
"fields": ["company", "project"],
"collation": AttributedDocument._numeric_locale,
},
{
"name": "%s.task.main_text_index" % Database.backend,
"fields": [

View File

@@ -4,7 +4,12 @@ from threading import Lock
from typing import Sequence
import six
from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
from mongoengine import (
EmbeddedDocumentField,
EmbeddedDocumentListField,
EmbeddedDocument,
Document,
)
from mongoengine.base import get_document
from apiserver.database.fields import (
@@ -25,6 +30,13 @@ class PropsMixin(object):
__cached_dpath_computed_fields_lock = Lock()
__cached_dpath_computed_fields = None
_document_classes = {}
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if issubclass(cls, (Document, EmbeddedDocument)):
PropsMixin._document_classes[cls._class_name] = cls
@classmethod
def get_fields(cls):
if cls.__cached_fields is None:
@@ -57,8 +69,14 @@ class PropsMixin(object):
def resolve_doc(v):
if not isinstance(v, six.string_types):
return v
if v == 'self':
if v == "self":
return cls_.owner_document
doc_cls = PropsMixin._document_classes.get(v)
if doc_cls:
return doc_cls
return get_document(v)
fields = {k: resolve_doc(v) for k, v in res.items()}
@@ -72,7 +90,7 @@ class PropsMixin(object):
).document_type
fields.update(
{
'.'.join((field, subfield)): doc
".".join((field, subfield)): doc
for subfield, doc in PropsMixin._get_fields_with_attr(
embedded_doc_cls, attr
).items()
@@ -80,10 +98,10 @@ class PropsMixin(object):
)
collect_embedded_docs(EmbeddedDocumentField, lambda x: x)
collect_embedded_docs(EmbeddedDocumentListField, attrgetter('field'))
collect_embedded_docs(LengthRangeEmbeddedDocumentListField, attrgetter('field'))
collect_embedded_docs(UniqueEmbeddedDocumentListField, attrgetter('field'))
collect_embedded_docs(EmbeddedDocumentSortedListField, attrgetter('field'))
collect_embedded_docs(EmbeddedDocumentListField, attrgetter("field"))
collect_embedded_docs(LengthRangeEmbeddedDocumentListField, attrgetter("field"))
collect_embedded_docs(UniqueEmbeddedDocumentListField, attrgetter("field"))
collect_embedded_docs(EmbeddedDocumentSortedListField, attrgetter("field"))
return fields
@@ -94,7 +112,7 @@ class PropsMixin(object):
for depth, part in enumerate(parts):
if current_cls is None:
raise ValueError(
'Invalid path (non-document encountered at %s)' % parts[: depth - 1]
"Invalid path (non-document encountered at %s)" % parts[: depth - 1]
)
try:
field_name, field = next(
@@ -103,7 +121,7 @@ class PropsMixin(object):
if k == part
)
except StopIteration:
raise ValueError('Invalid field path %s' % parts[:depth])
raise ValueError("Invalid field path %s" % parts[:depth])
translated_parts.append(part)
@@ -119,7 +137,7 @@ class PropsMixin(object):
),
):
current_cls = field.field.document_type
translated_parts.append('*')
translated_parts.append("*")
else:
current_cls = None
@@ -128,7 +146,7 @@ class PropsMixin(object):
@classmethod
def get_reference_fields(cls):
if cls.__cached_reference_fields is None:
fields = cls._get_fields_with_attr(cls, 'reference_field')
fields = cls._get_fields_with_attr(cls, "reference_field")
cls.__cached_reference_fields = OrderedDict(sorted(fields.items()))
return cls.__cached_reference_fields
@@ -143,12 +161,12 @@ class PropsMixin(object):
@classmethod
def get_exclude_fields(cls):
if cls.__cached_exclude_fields is None:
fields = cls._get_fields_with_attr(cls, 'exclude_by_default')
fields = cls._get_fields_with_attr(cls, "exclude_by_default")
cls.__cached_exclude_fields = OrderedDict(sorted(fields.items()))
return cls.__cached_exclude_fields
@classmethod
def get_dpath_translated_path(cls, path, separator='.'):
def get_dpath_translated_path(cls, path, separator="."):
if cls.__cached_dpath_computed_fields is None:
cls.__cached_dpath_computed_fields = {}
if path not in cls.__cached_dpath_computed_fields:

View File

@@ -5,7 +5,7 @@ Apply elasticsearch mappings to given hosts.
import argparse
import json
from pathlib import Path
from typing import Optional, Sequence
from typing import Optional, Sequence, Tuple
from elasticsearch import Elasticsearch
@@ -13,7 +13,7 @@ HERE = Path(__file__).resolve().parent
def apply_mappings_to_cluster(
hosts: Sequence, key: Optional[str] = None, es_args: dict = None
hosts: Sequence, key: Optional[str] = None, es_args: dict = None, http_auth: Tuple = None
):
"""Hosts maybe a sequence of strings or dicts in the form {"host": <host>, "port": <port>}"""
@@ -30,7 +30,7 @@ def apply_mappings_to_cluster(
else:
files = p.glob("**/*.json")
es = Elasticsearch(hosts=hosts, **(es_args or {}))
es = Elasticsearch(hosts=hosts, http_auth=http_auth, **(es_args or {}))
return [_send_template(f) for f in files]

View File

@@ -82,7 +82,11 @@ def check_elastic_empty() -> bool:
es_logger.addFilter(log_filter)
for retry in range(max_retries):
try:
es = Elasticsearch(hosts=cluster_conf.get("hosts"))
es = Elasticsearch(
hosts=cluster_conf.get("hosts", None),
http_auth=es_factory.get_credentials("events", cluster_conf),
**cluster_conf.get("args", {})
)
return not es.indices.get_template(name="events*")
except exceptions.NotFoundError as ex:
log.error(ex)
@@ -109,5 +113,7 @@ def init_es_data():
log.info(f"Applying mappings to ES host: {hosts_config}")
args = cluster_conf.get("args", {})
res = apply_mappings_to_cluster(hosts_config, name, es_args=args)
http_auth = es_factory.get_credentials(name)
res = apply_mappings_to_cluster(hosts_config, name, es_args=args, http_auth=http_auth)
log.info(res)

View File

@@ -1,7 +1,8 @@
{
"index_patterns": "events-*",
"settings": {
"number_of_shards": 1
"number_of_shards": 1,
"number_of_replicas": 0
},
"mappings": {
"_source": {

View File

@@ -1,7 +1,8 @@
{
"index_patterns": "queue_metrics_*",
"settings": {
"number_of_shards": 1
"number_of_shards": 1,
"number_of_replicas": 0
},
"mappings": {
"_source": {

View File

@@ -1,7 +1,8 @@
{
"index_patterns": "worker_stats_*",
"settings": {
"number_of_shards": 1
"number_of_shards": 1,
"number_of_replicas": 0
},
"mappings": {
"_source": {

View File

@@ -1,9 +1,10 @@
from datetime import datetime
from functools import lru_cache
from os import getenv
from typing import Tuple
from typing import Tuple, Optional
from boltons.iterutils import first
from elasticsearch import Elasticsearch, Transport
from elasticsearch import Elasticsearch
from apiserver.config_repo import config
@@ -21,6 +22,10 @@ OVERRIDE_PORT_ENV_KEY = (
"ELASTIC_SERVICE_PORT",
)
OVERRIDE_USERNAME_ENV_KEY = ("CLEARML_ELASTIC_SERVICE_USERNAME",)
OVERRIDE_PASSWORD_ENV_KEY = ("CLEARML_ELASTIC_SERVICE_PASSWORD",)
OVERRIDE_HOST = first(filter(None, map(getenv, OVERRIDE_HOST_ENV_KEY)))
if OVERRIDE_HOST:
log.info(f"Using override elastic host {OVERRIDE_HOST}")
@@ -29,6 +34,14 @@ OVERRIDE_PORT = first(filter(None, map(getenv, OVERRIDE_PORT_ENV_KEY)))
if OVERRIDE_PORT:
log.info(f"Using override elastic port {OVERRIDE_PORT}")
OVERRIDE_USERNAME = first(filter(None, map(getenv, OVERRIDE_USERNAME_ENV_KEY)))
if OVERRIDE_USERNAME:
log.info(f"Using override elastic username {OVERRIDE_USERNAME}")
OVERRIDE_PASSWORD = first(filter(None, map(getenv, OVERRIDE_PASSWORD_ENV_KEY)))
if OVERRIDE_PASSWORD:
log.info("Using override elastic password ********")
_instances = {}
@@ -48,6 +61,10 @@ class InvalidClusterConfiguration(Exception):
pass
class MissingPasswordForElasticUser(Exception):
pass
class ESFactory:
@classmethod
def connect(cls, cluster_name):
@@ -65,22 +82,45 @@ class ESFactory:
if not hosts:
raise InvalidClusterConfiguration(cluster_name)
http_auth = cls.get_credentials(cluster_name)
args = cluster_config.get("args", {})
_instances[cluster_name] = Elasticsearch(
hosts=hosts, transport_class=Transport, **args
hosts=hosts, http_auth=http_auth, **args
)
return _instances[cluster_name]
@classmethod
def get_credentials(cls, cluster_name: str, cluster_config: dict = None) -> Optional[Tuple[str, str]]:
cluster_config = cluster_config or cls.get_cluster_config(cluster_name)
if not cluster_config.get("secure", True):
return None
elastic_user = OVERRIDE_USERNAME or config.get("secure.elastic.user", None)
if not elastic_user:
return None
elastic_password = OVERRIDE_PASSWORD or config.get(
"secure.elastic.password", None
)
if not elastic_password:
raise MissingPasswordForElasticUser(
f"cluster={cluster_name}, username={elastic_user}"
)
return elastic_user, elastic_password
@classmethod
def get_all_cluster_names(cls):
return list(config.get("hosts.elastic"))
@classmethod
def get_override(cls, cluster_name: str) -> Tuple[str, str]:
def get_override_host(cls, cluster_name: str) -> Tuple[str, str]:
return OVERRIDE_HOST, OVERRIDE_PORT
@classmethod
@lru_cache()
def get_cluster_config(cls, cluster_name):
"""
Returns cluster config for the specified cluster path
@@ -97,7 +137,7 @@ class ESFactory:
for entry in cluster_config.get("hosts", []):
entry[key] = value
host, port = cls.get_override(cluster_name)
host, port = cls.get_override_host(cluster_name)
if host:
set_host_prop("host", host)

View File

@@ -21,6 +21,7 @@ from typing import (
Union,
Mapping,
IO,
Callable,
)
from urllib.parse import unquote, urlparse
from zipfile import ZipFile, ZIP_BZIP2
@@ -54,6 +55,7 @@ from apiserver.database.model.task.task import (
from apiserver.database.utils import get_options
from apiserver.utilities import json
from apiserver.utilities.dicts import nested_get, nested_set, nested_delete
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
class PrePopulate:
@@ -744,6 +746,19 @@ class PrePopulate:
module = importlib.import_module(module_name)
return getattr(module, class_name)
@staticmethod
def _upgrade_model_data(model_data: dict) -> dict:
metadata_key = "metadata"
metadata = model_data.get(metadata_key)
if isinstance(metadata, list):
metadata = {
ParameterKeyEscaper.escape(item["key"]): item
for item in metadata
if isinstance(item, dict) and "key" in item
}
model_data[metadata_key] = metadata
return model_data
@staticmethod
def _upgrade_task_data(task_data: dict) -> dict:
"""
@@ -828,9 +843,14 @@ class PrePopulate:
print(f"Writing {cls_.__name__.lower()}s into database")
tasks = []
override_project_count = 0
data_upgrade_funcs: Mapping[Type, Callable] = {
cls.task_cls: cls._upgrade_task_data,
cls.model_cls: cls._upgrade_model_data,
}
for item in cls.json_lines(f):
if cls_ == cls.task_cls:
item = json.dumps(cls._upgrade_task_data(task_data=json.loads(item)))
upgrade_func = data_upgrade_funcs.get(cls_)
if upgrade_func:
item = json.dumps(upgrade_func(json.loads(item)))
doc = cls_.from_json(item, created=True)
if hasattr(doc, "user"):

View File

@@ -0,0 +1,29 @@
from pymongo.collection import Collection
from pymongo.database import Database
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
from .utils import _drop_all_indices_from_collections
def _convert_metadata(db: Database, name):
collection: Collection = db[name]
metadata_field = "metadata"
query = {metadata_field: {"$exists": True, "$type": 4}}
for doc in collection.find(filter=query, projection=(metadata_field,)):
metadata = {
ParameterKeyEscaper.escape(item["key"]): item
for item in doc.get(metadata_field, [])
if isinstance(item, dict) and "key" in item
}
collection.update_one(
{"_id": doc["_id"]}, {"$set": {"metadata": metadata}},
)
def migrate_backend(db: Database):
collections = ["model", "queue"]
for name in collections:
_convert_metadata(db, name)
_drop_all_indices_from_collections(db, collections)

View File

@@ -1,10 +1,8 @@
import threading
from os import getenv
from time import sleep
from boltons.iterutils import first
from redis import StrictRedis
from redis.sentinel import Sentinel, SentinelConnectionPool
from rediscluster import RedisCluster
from apiserver.apierrors.errors.server_error import ConfigError, GeneralError
from apiserver.config_repo import config
@@ -21,6 +19,11 @@ OVERRIDE_PORT_ENV_KEY = (
"TRAINS_REDIS_SERVICE_PORT",
"REDIS_SERVICE_PORT",
)
OVERRIDE_PASSWORD_ENV_KEY = (
"CLEARML_REDIS_SERVICE_PASSWORD",
"TRAINS_REDIS_SERVICE_PASSWORD",
"REDIS_SERVICE_PASSWORD",
)
OVERRIDE_HOST = first(filter(None, map(getenv, OVERRIDE_HOST_ENV_KEY)))
if OVERRIDE_HOST:
@@ -30,99 +33,7 @@ OVERRIDE_PORT = first(filter(None, map(getenv, OVERRIDE_PORT_ENV_KEY)))
if OVERRIDE_PORT:
log.info(f"Using override redis port {OVERRIDE_PORT}")
class MyPubSubWorkerThread(threading.Thread):
def __init__(self, sentinel, on_new_master, msg_sleep_time, daemon=True):
super(MyPubSubWorkerThread, self).__init__()
self.daemon = daemon
self.sentinel = sentinel
self.on_new_master = on_new_master
self.sentinel_host = sentinel.connection_pool.connection_kwargs["host"]
self.msg_sleep_time = msg_sleep_time
self._running = False
self.pubsub = None
def subscribe(self):
if self.pubsub:
try:
self.pubsub.unsubscribe()
self.pubsub.punsubscribe()
except Exception:
pass
finally:
self.pubsub = None
subscriptions = {"+switch-master": self.on_new_master}
while not self.pubsub or not self.pubsub.subscribed:
try:
self.pubsub = self.sentinel.pubsub()
self.pubsub.subscribe(**subscriptions)
except Exception as ex:
log.warn(
f"Error while subscribing to sentinel at {self.sentinel_host} ({ex.args[0]}) Sleeping and retrying"
)
sleep(3)
log.info(f"Subscribed to sentinel {self.sentinel_host}")
def run(self):
if self._running:
return
self._running = True
self.subscribe()
while self.pubsub.subscribed:
try:
self.pubsub.get_message(
ignore_subscribe_messages=True, timeout=self.msg_sleep_time
)
except Exception as ex:
log.warn(
f"Error while getting message from sentinel {self.sentinel_host} ({ex.args[0]}) Resubscribing"
)
self.subscribe()
self.pubsub.close()
self._running = False
def stop(self):
# stopping simply unsubscribes from all channels and patterns.
# the unsubscribe responses that are generated will short circuit
# the loop in run(), calling pubsub.close() to clean up the connection
self.pubsub.unsubscribe()
self.pubsub.punsubscribe()
# todo,future - multi master clusters?
class RedisCluster(object):
def __init__(self, sentinel_hosts, service_name, **connection_kwargs):
self.service_name = service_name
self.sentinel = Sentinel(sentinel_hosts, **connection_kwargs)
self.master = None
self.master_host_port = None
self.reconfigure()
self.sentinel_threads = {}
self.listen()
def reconfigure(self):
try:
self.master_host_port = self.sentinel.discover_master(self.service_name)
self.master = self.sentinel.master_for(self.service_name)
log.info(f"Reconfigured master to {self.master_host_port}")
except Exception as ex:
log.error(f"Error while reconfiguring. {ex.args[0]}")
def listen(self):
def on_new_master(workerThread):
self.reconfigure()
for sentinel in self.sentinel.sentinels:
sentinel_host = sentinel.connection_pool.connection_kwargs["host"]
self.sentinel_threads[sentinel_host] = MyPubSubWorkerThread(
sentinel, on_new_master, msg_sleep_time=0.001, daemon=True
)
self.sentinel_threads[sentinel_host].start()
OVERRIDE_PASSWORD = first(filter(None, map(getenv, OVERRIDE_PASSWORD_ENV_KEY)))
class RedisManager(object):
@@ -131,6 +42,9 @@ class RedisManager(object):
for alias, alias_config in redis_config_dict.items():
alias_config = alias_config.as_plain_ordered_dict()
alias_config["password"] = config.get(
f"secure.redis.{alias}.password", None
)
is_cluster = alias_config.get("cluster", False)
@@ -142,34 +56,19 @@ class RedisManager(object):
if port:
alias_config["port"] = port
db = alias_config.get("db", 0)
password = OVERRIDE_PASSWORD or alias_config.get("password", None)
if password:
alias_config["password"] = password
sentinels = alias_config.get("sentinels", None)
service_name = alias_config.get("service_name", None)
if not is_cluster and sentinels:
raise ConfigError(
"Redis configuration is invalid. mixed regular and cluster mode",
alias=alias,
)
if is_cluster and (not sentinels or not service_name):
raise ConfigError(
"Redis configuration is invalid. missing sentinels or service_name",
alias=alias,
)
if not is_cluster and (not port or not host):
if not port or not host:
raise ConfigError(
"Redis configuration is invalid. missing port or host", alias=alias
)
if is_cluster:
# todo support all redis connection args via sentinel's connection_kwargs
del alias_config["sentinels"]
del alias_config["cluster"]
del alias_config["service_name"]
self.aliases[alias] = RedisCluster(
sentinels, service_name, **alias_config
)
del alias_config["db"]
self.aliases[alias] = RedisCluster(**alias_config)
else:
self.aliases[alias] = StrictRedis(**alias_config)
@@ -177,27 +76,21 @@ class RedisManager(object):
obj = self.aliases.get(alias)
if not obj:
raise GeneralError(f"Invalid Redis alias {alias}")
if isinstance(obj, RedisCluster):
obj.master.get("health")
return obj.master
else:
obj.get("health")
return obj
obj.get("health")
return obj
def host(self, alias):
r = self.connection(alias)
pool = r.connection_pool
if isinstance(pool, SentinelConnectionPool):
connections = pool.connection_kwargs[
"connection_pool"
]._available_connections
if isinstance(r, RedisCluster):
connections = first(r.connection_pool._available_connections.values())
else:
connections = pool._available_connections
connections = r.connection_pool._available_connections
if len(connections) > 0:
return connections[0].host
else:
if not connections:
return None
return connections[0].host
redman = RedisManager(config.get("hosts.redis"))

View File

@@ -3,7 +3,7 @@ bcrypt>=3.1.4
boltons>=19.1.0
boto3==1.14.13
dpath>=1.4.2,<2.0
elasticsearch>=7.0.0,<8.0.0
elasticsearch==7.13.3
fastjsonschema>=2.8
flask-compress>=1.4.0
flask-cors>=3.0.5
@@ -16,18 +16,19 @@ jinja2==2.11.3
jsonmodels>=2.3
jsonschema>=2.6.0
luqum>=0.10.0
mongoengine==0.19.1
mongoengine==0.23.1
nested_dict>=1.61
packaging==20.3
psutil>=5.6.5
pyhocon>=0.3.35
pyjwt<2.0.0
pymongo==3.10.1
pymongo[srv]==3.12.0
python-rapidjson>=0.6.3
redis>=2.10.5
redis==3.5.3
redis-py-cluster>=2.1.3
related>=0.7.2
requests>=2.13.0
semantic_version>=2.8.3,<3
six
tqdm
validators>=0.12.4
validators>=0.12.4

View File

@@ -26,6 +26,10 @@ credentials {
type: string
description: Credentials secret key
}
label {
type: string
description: Optional credentials label
}
}
}
batch_operation {

View File

@@ -15,11 +15,19 @@ _definitions {
type: string
description: ""
}
label {
type: string
description: Optional credentials label
}
last_used {
type: string
description: ""
format: "date-time"
}
last_used_from {
type: string
description: ""
}
}
}
role {
@@ -222,6 +230,12 @@ create_credentials {
}
}
}
"2.17": ${create_credentials."2.1"} {
request.properties.label {
type: string
description: Optional credentials label
}
}
}
get_credentials {

File diff suppressed because it is too large Load Diff

View File

@@ -61,14 +61,14 @@ _definitions {
type: string
}
tags {
description: "User-defined tags list"
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
items {type: string}
description: "System tags. This field is reserved for system use, please don't use it."
items { type: string }
}
framework {
description: "Framework on which the model is based. Should be identical to the framework of the task which created the model"
@@ -98,9 +98,11 @@ _definitions {
additionalProperties: true
}
metadata {
type: array
description: "Model metadata"
items {"$ref": "#/definitions/metadata_item"}
type: object
additionalProperties {
"$ref": "#/definitions/metadata_item"
}
}
}
}
@@ -199,6 +201,29 @@ get_all_ex {
}
}
}
"2.15": ${get_all_ex."2.13"} {
request {
properties {
scroll_id {
type: string
description: "Scroll ID returned from the previos calls to get_all_ex"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of models to retrieve"
}
}
}
response.properties.scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
}
}
}
get_all {
"2.1" {
@@ -302,6 +327,29 @@ get_all {
}
}
}
"2.15": ${get_all."2.1"} {
request {
properties {
scroll_id {
type: string
description: "Scroll ID returned from the previos calls to get_all"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of models to retrieve"
}
}
}
response.properties.scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
}
}
}
get_frameworks {
"2.8" {
@@ -361,7 +409,7 @@ update_for_task {
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
items {type: string}
items { type: string }
}
override_model_id {
description: "Override model ID. If provided, this model is updated in the task. Exactly one of override_model_id or uri is required."
@@ -427,7 +475,7 @@ create {
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
items {type: string}
items { type: string }
}
framework {
description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model."
@@ -483,9 +531,11 @@ create {
}
"2.13": ${create."2.1"} {
metadata {
type: array
description: "Model metadata"
items {"$ref": "#/definitions/metadata_item"}
type: object
additionalProperties {
"$ref": "#/definitions/metadata_item"
}
}
}
}
@@ -522,7 +572,7 @@ edit {
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
items {type: string}
items { type: string }
}
framework {
description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model."
@@ -578,9 +628,11 @@ edit {
}
"2.13": ${edit."2.1"} {
metadata {
type: array
description: "Model metadata"
items {"$ref": "#/definitions/metadata_item"}
type: object
additionalProperties {
"$ref": "#/definitions/metadata_item"
}
}
}
}
@@ -611,7 +663,7 @@ update {
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
items {type: string}
items { type: string }
}
ready {
description: "Indication if the model is final and can be used by other tasks Default is false."
@@ -661,9 +713,11 @@ update {
}
"2.13": ${update."2.1"} {
metadata {
type: array
description: "Model metadata"
items {"$ref": "#/definitions/metadata_item"}
type: object
additionalProperties {
"$ref": "#/definitions/metadata_item"
}
}
}
}
@@ -672,7 +726,7 @@ publish_many {
description: Publish models
request {
properties {
ids.description: "IDs of models to publish"
ids.description: "IDs of the models to publish"
force_publish_task {
description: "Publish the associated tasks (if exist) even if they are not in the 'stopped' state. Optional, the default value is False."
type: boolean
@@ -733,7 +787,7 @@ archive_many {
description: Archive models
request {
properties {
ids.description: "IDs of models to archive"
ids.description: "IDs of the models to archive"
}
}
response {
@@ -769,10 +823,9 @@ delete_many {
description: Delete models
request {
properties {
ids.description: "IDs of models to delete"
ids.description: "IDs of the models to delete"
force {
description: """Force. Required if there are tasks that use the model as an execution model, or if the model's creating task is published.
"""
description: "Force. Required if there are tasks that use the model as an execution model, or if the model's creating task is published."
type: boolean
}
}
@@ -929,6 +982,11 @@ add_or_update_metadata {
description: "Metadata items to add or update"
items {"$ref": "#/definitions/metadata_item"}
}
replace_metadata {
description: "If set then the all the metadata items will be replaced with the provided ones. Otherwise only the provided metadata items will be updated or added"
type: boolean
default: false
}
}
}
response {

View File

@@ -0,0 +1,47 @@
_description: "Provides a management API for pipelines in the system."
_definitions {
}
start_pipeline {
"2.17" {
description: "Start a pipeline"
request {
type: object
required: [ task ]
properties {
task {
description: "ID of the task on which the pipeline will be based"
type: string
}
queue {
description: "Queue ID in which the created pipeline task will be enqueued"
type: string
}
args {
description: "Task arguments, name/value to be placed in the hyperparameters Args section"
type: array
items {
type: object
properties {
name: { type: string }
value: { type: [string, null] }
}
}
}
}
}
response {
type: object
properties {
pipeline {
description: "ID of the new pipeline task"
type: string
}
enqueued {
description: "True if the task was successfuly enqueued"
type: boolean
}
}
}
}
}

View File

@@ -42,15 +42,20 @@ _definitions {
type: string
format: "date-time"
}
last_update {
description: "Last update time"
type: string
format: "date-time"
}
tags {
type: array
description: "User-defined tags"
type: array
items { type: string }
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please don't use it."
items {type: string}
type: array
items { type: string }
}
default_output_destination {
description: "The default output destination URL for new tasks under this project"
@@ -70,6 +75,18 @@ _definitions {
description: "Total run time of all tasks in project (in seconds)"
type: integer
}
total_tasks {
description: "Number of tasks"
type: integer
}
completed_tasks_24h {
description: "Number of tasks completed in the last 24 hours"
type: integer
}
last_task_run {
description: "The most recent started time of a task"
type: integer
}
status_count {
description: "Status counts"
type: object
@@ -78,6 +95,10 @@ _definitions {
description: "Number of 'created' tasks in project"
type: integer
}
completed {
description: "Number of 'completed' tasks in project"
type: integer
}
queued {
description: "Number of 'queued' tasks in project"
type: integer
@@ -152,15 +173,20 @@ _definitions {
type: string
format: "date-time"
}
last_update {
description: "Last update time"
type: string
format: "date-time"
}
tags {
type: array
description: "User-defined tags"
type: array
items { type: string }
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please don't use it."
items {type: string}
type: array
items { type: string }
}
default_output_destination {
description: "The default output destination URL for new tasks under this project"
@@ -294,14 +320,14 @@ create {
type: string
}
tags {
type: array
description: "User-defined tags"
type: array
items { type: string }
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please don't use it."
items {type: string}
type: array
items { type: string }
}
default_output_destination {
description: "The default output destination URL for new tasks under this project"
@@ -414,7 +440,6 @@ get_all {
description: "Projects list"
type: array
items { "$ref": "#/definitions/projects_get_all_response_single" }
}
}
}
@@ -430,6 +455,29 @@ get_all {
}
}
}
"2.15": ${get_all."2.13"} {
request {
properties {
scroll_id {
type: string
description: "Scroll ID returned from the previos calls to get_all_ex"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of projects to retrieve"
}
}
}
response.properties.scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
}
}
}
get_all_ex {
internal: true
@@ -488,6 +536,49 @@ get_all_ex {
}
}
}
"2.15": ${get_all_ex."2.13"} {
request {
properties {
scroll_id {
type: string
description: "Scroll ID returned from the previos calls to get_all"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of projects to retrieve"
}
}
}
response.properties.scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
}
}
"2.16": ${get_all_ex."2.15"} {
request.properties.stats_with_children {
description: "If include_stats flag is set then this flag contols whether the child projects tasks are taken into statistics or not"
type: boolean
default: true
}
}
"2.17": ${get_all_ex."2.16"} {
request.properties.include_stats_filter {
description: The filter for selecting entities that participate in statistics calculation
type: object
properties {
system_tags {
description: The list of allowed system tags
type: array
items { type: string }
}
}
}
}
}
update {
"2.1" {
@@ -504,23 +595,19 @@ update {
description: "Project name. Unique within the company."
type: string
}
description {
description: "Project description. "
type: string
}
description {
description: "Project description"
type: string
}
tags {
description: "User-defined tags list"
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
description: "System tags. This field is reserved for system use, please don't use it."
items {type: string}
items { type: string }
}
default_output_destination {
description: "The default output destination URL for new tasks under this project"
@@ -658,7 +745,6 @@ delete {
type: boolean
default: false
}
}
}
response {
@@ -791,6 +877,7 @@ get_hyper_parameters {
description: """Get a list of all hyper parameter sections and names used in tasks within the given project."""
request {
type: object
required: [project]
properties {
project {
description: "Project ID"
@@ -839,7 +926,105 @@ get_hyper_parameters {
}
}
}
get_model_metadata_values {
"2.17" {
description: """Get a list of distinct values for the chosen model metadata key"""
request {
type: object
required: [key]
properties {
projects {
description: "Project IDs"
type: array
items {type: string}
}
key {
description: "Metadata key"
type: string
}
allow_public {
description: "If set to 'true' then collect values from both company and public models otherwise company modeels only. The default is 'true'"
type: boolean
}
include_subprojects {
description: "If set to 'true' and the project field is set then the result includes metadata values from the subproject models"
type: boolean
default: true
}
}
}
response {
type: object
properties {
total {
description: "Total number of distinct values"
type: integer
}
values {
description: "The list of the unique values"
type: array
items {type: string}
}
}
}
}
}
get_model_metadata_keys {
"2.17" {
description: """Get a list of all metadata keys used in models within the given project."""
request {
type: object
required: [project]
properties {
project {
description: "Project ID"
type: string
}
include_subprojects {
description: "If set to 'true' and the project field is set then the result includes metadate keys from the subproject models"
type: boolean
default: true
}
page {
description: "Page number"
default: 0
type: integer
}
page_size {
description: "Page size"
default: 500
type: integer
}
}
}
response {
type: object
properties {
keys {
description: "A list of model keys"
type: array
items {type: string}
}
remaining {
description: "Remaining results"
type: integer
}
total {
description: "Total number of results"
type: integer
}
}
}
}
}
get_project_tags {
"2.17" {
description: "Get user and system tags used for the specified projects and their children"
request = ${_definitions.tags_request}
response = ${_definitions.tags_response}
}
}
get_task_tags {
"2.8" {
description: "Get user and system tags used for the tasks under the specified projects"
@@ -847,7 +1032,6 @@ get_task_tags {
response = ${_definitions.tags_response}
}
}
get_model_tags {
"2.8" {
description: "Get user and system tags used for the models under the specified projects"
@@ -969,4 +1153,4 @@ get_task_parents {
}
}
}
}
}

View File

@@ -79,9 +79,11 @@ _definitions {
items { "$ref": "#/definitions/entry" }
}
metadata {
type: array
description: "Queue metadata"
items {"$ref": "#/definitions/metadata_item"}
type: object
additionalProperties {
"$ref": "#/definitions/metadata_item"
}
}
}
}
@@ -115,6 +117,29 @@ get_by_id {
get_all_ex {
internal: true
"2.4": ${get_all."2.4"}
"2.15": ${get_all_ex."2.4"} {
request {
properties {
scroll_id {
type: string
description: "Scroll ID returned from the previos calls to get_all_ex"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of queues to retrieve"
}
}
}
response.properties.scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
}
}
}
get_all {
"2.4" {
@@ -178,6 +203,29 @@ get_all {
}
}
}
"2.15": ${get_all."2.4"} {
request {
properties {
scroll_id {
type: string
description: "Scroll ID returned from the previos calls to get_all"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of queues to retrieve"
}
}
}
response.properties.scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
}
}
}
get_default {
"2.4" {
@@ -235,6 +283,15 @@ create {
}
}
}
"2.13": ${create."2.4"} {
metadata {
description: "Queue metadata"
type: object
additionalProperties {
"$ref": "#/definitions/metadata_item"
}
}
}
}
update {
"2.4" {
@@ -276,7 +333,15 @@ update {
type: object
additionalProperties: true
}
}
}
}
"2.13": ${update."2.4"} {
metadata {
description: "Queue metadata"
type: object
additionalProperties {
"$ref": "#/definitions/metadata_item"
}
}
}
@@ -586,6 +651,11 @@ add_or_update_metadata {
description: "Metadata items to add or update"
items {"$ref": "#/definitions/metadata_item"}
}
replace_metadata {
description: "If set then the all the metadata items will be replaced with the provided ones. Otherwise only the provided metadata items will be updated or added"
type: boolean
default: false
}
}
}
response {

View File

@@ -534,7 +534,7 @@ _definitions {
container {
description: "Docker container parameters"
type: object
additionalProperties { type: string }
additionalProperties { type: [string, null] }
}
models {
description: "Task models"
@@ -685,6 +685,29 @@ get_all_ex {
}
}
}
"2.15": ${get_all_ex."2.13"} {
request {
properties {
scroll_id {
type: string
description: "Scroll ID returned from the previos calls to get_all_ex"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of tasks to retrieve"
}
}
}
response.properties.scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
}
}
}
get_all {
"2.1" {
@@ -799,6 +822,29 @@ get_all {
}
}
}
"2.15": ${get_all."2.1"} {
request {
properties {
scroll_id {
type: string
description: "Scroll ID returned from the previos calls to get_all"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of tasks to retrieve"
}
}
}
response.properties.scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
}
}
}
get_types {
"2.8" {
@@ -935,7 +981,7 @@ clone {
new_task_container {
description: "The docker container properties for the new task. If not provided then taken from the original task"
type: object
additionalProperties { type: string }
additionalProperties { type: [string, null] }
}
}
}
@@ -1113,7 +1159,7 @@ create {
container {
description: "Docker container parameters"
type: object
additionalProperties { type: string }
additionalProperties { type: [string, null] }
}
}
}
@@ -1202,7 +1248,7 @@ validate {
container {
description: "Docker container parameters"
type: object
additionalProperties { type: string }
additionalProperties { type: [string, null] }
}
}
}
@@ -1364,7 +1410,7 @@ edit {
container {
description: "Docker container parameters"
type: object
additionalProperties { type: string }
additionalProperties { type: [string, null] }
}
runtime {
description: "Task runtime mapping"

View File

@@ -4,7 +4,7 @@ from hashlib import md5
from flask import Flask
from flask_compress import Compress
from flask_cors import CORS
from semantic_version import Version
from packaging.version import Version
from apiserver.database import db
from apiserver.bll.statistics.stats_reporter import StatisticsReporter
@@ -25,6 +25,7 @@ from apiserver.server_init.request_handlers import RequestHandlers
from apiserver.service_repo import ServiceRepo
from apiserver.sync import distributed_lock
from apiserver.updates import check_updates_thread
from apiserver.utilities.env import get_bool
from apiserver.utilities.threads_manager import ThreadsManager
log = config.logger(__file__)
@@ -46,10 +47,13 @@ class AppSequence:
def _attach_request_handlers(self, request_handlers: RequestHandlers):
self.app.before_first_request(request_handlers.before_app_first_request)
self.app.before_request(request_handlers.before_request)
self.app.after_request(request_handlers.after_request)
def _configure(self):
CORS(self.app, **config.get("apiserver.cors"))
Compress(self.app)
if get_bool("CLEARML_COMPRESS_RESP", default=True):
Compress(self.app)
self.app.config["SECRET_KEY"] = config.get(
"secure.http.session_secret.apiserver"

View File

@@ -1,21 +1,24 @@
from functools import partial
from flask import request, Response, redirect
from werkzeug.datastructures import ImmutableMultiDict
from werkzeug.exceptions import BadRequest
from apiserver.apierrors import APIError
from apiserver.apierrors.base import BaseError
from apiserver.config_repo import config
from apiserver.service_repo import ServiceRepo, APICall
from apiserver.service_repo.auth import AuthType
from apiserver.service_repo.auth import AuthType, Token
from apiserver.service_repo.errors import PathParsingError
from apiserver.utilities import json
from apiserver.utilities.dicts import nested_set
log = config.logger(__file__)
class RequestHandlers:
_request_strip_prefix = config.get("apiserver.request.strip_prefix", None)
_server_header = config.get("apiserver.response.headers.server", "clearml")
def before_app_first_request(self):
pass
@@ -26,10 +29,13 @@ class RequestHandlers:
if "/static/" in request.path:
return
if request.content_encoding:
return f"Content encoding is not supported ({request.content_encoding})", 415
try:
call = self._create_api_call(request)
load_data_callback = partial(self._load_call_data, req=request)
content, content_type = ServiceRepo.handle_call(
content, content_type, company = ServiceRepo.handle_call(
call, load_data_callback=load_data_callback
)
@@ -51,20 +57,53 @@ class RequestHandlers:
if call.result.cookies:
for key, value in call.result.cookies.items():
kwargs = config.get("apiserver.auth.cookies")
kwargs = config.get("apiserver.auth.cookies").copy()
if value is None:
kwargs = kwargs.copy()
# Removing a cookie
kwargs["max_age"] = 0
kwargs["expires"] = 0
response.set_cookie(key, "", **kwargs)
else:
response.set_cookie(key, value, **kwargs)
value = ""
elif not company:
# Setting a cookie, let's try to figure out the company
# noinspection PyBroadException
try:
company = Token.decode_identity(value).company
except Exception:
pass
if company:
try:
# use no default value to allow setting a null domain as well
kwargs["domain"] = config.get(f"apiserver.auth.cookies_domain_override.{company}")
except KeyError:
pass
response.set_cookie(key, value, **kwargs)
return response
except Exception as ex:
log.exception(f"Failed processing request {request.url}: {ex}")
return f"Failed processing request {request.url}", 500
def after_request(self, response):
response.headers["server"] = self._server_header
return response
@staticmethod
def _apply_multi_dict(body: dict, md: ImmutableMultiDict):
def convert_value(v: str):
if v.replace(".", "", 1).isdigit():
return float(v) if "." in v else int(v)
if v in ("true", "True", "TRUE"):
return True
if v in ("false", "False", "FALSE"):
return False
return v
for k, v in md.lists():
v = [convert_value(x) for x in v] if (len(v) > 1 or k.endswith("[]")) else convert_value(v[0])
nested_set(body, k.rstrip("[]").split("."), v)
def _update_call_data(self, call, req):
""" Use request payload/form to fill call data or batched data """
if req.content_type == "application/json-lines":
@@ -82,23 +121,12 @@ class RequestHandlers:
req.on_json_loading_failed(msg)
call.batched_data = items
else:
json_body = req.get_json(force=True, silent=False) if req.data else None
# merge form and args
form = req.form.copy()
form.update(req.args)
form = form.to_dict()
# convert string numbers to floats
for key in form:
if form[key].replace(".", "", 1).isdigit():
if "." in form[key]:
form[key] = float(form[key])
else:
form[key] = int(form[key])
elif form[key].lower() == "true":
form[key] = True
elif form[key].lower() == "false":
form[key] = False
call.data = json_body or form or {}
body = (req.get_json(force=True, silent=False) if req.data else None) or {}
if req.args:
self._apply_multi_dict(body, req.args)
if req.form:
self._apply_multi_dict(body, req.form)
call.data = body
def _call_or_empty_with_error(self, call, req, msg, code=500, subcode=0):
call = call or APICall(

View File

@@ -95,8 +95,8 @@ class DataContainer(object):
@raw_data.setter
def raw_data(self, value):
assert isinstance(
value, string_types + (types.GeneratorType,)
), "Raw data must be a string type or generator"
value, string_types + (types.GeneratorType, bytes)
), "Raw data must be a string type or bytes or generator"
self._raw_data = value
@property
@@ -310,6 +310,12 @@ class APICall(DataContainer):
_transaction_headers = _get_headers("Trx")
""" Transaction ID """
_redacted_headers = {
HEADER_AUTHORIZATION: " ",
"Cookie": "=",
}
""" Headers whose value should be redacted. Maps header name to partition char """
@property
def HEADER_TRANSACTION(self):
return self._transaction_headers[0]
@@ -389,6 +395,10 @@ class APICall(DataContainer):
self._auth_cookie = auth_cookie
self._json_flags = {}
@property
def files(self):
return self._files
@property
def id(self):
return self._id
@@ -584,6 +594,10 @@ class APICall(DataContainer):
def json_flags(self):
return self._json_flags
@property
def extra_meta_fields(self):
return {}
def mark_end(self):
self._end_ts = time.time()
self._duration = int((self._end_ts - self._start_ts) * 1000)
@@ -634,6 +648,7 @@ class APICall(DataContainer):
"result_msg": self.result.msg,
"error_stack": self.result.traceback if include_stack else None,
"error_data": self.result.error_data,
**self.extra_meta_fields,
},
"data": self.result.data,
}
@@ -668,3 +683,15 @@ class APICall(DataContainer):
error_data=error_data,
cookies=self._result.cookies,
)
def get_redacted_headers(self):
headers = self.headers.copy()
if not self.requires_authorization or self.auth:
# We won't log the authorization header if call shouldn't be authorized, or if it was successfully
# authorized. This means we'll only log authorization header for calls that failed to authorize (hopefully
# this will allow us to debug authorization errors).
for header, sep in self._redacted_headers.items():
if header in headers:
prefix, _, redact = headers[header].partition(sep)
headers[header] = prefix + sep + f"<{len(redact)} bytes redacted>"
return headers

View File

@@ -51,7 +51,7 @@ def authorize_token(jwt_token, *_, **__):
)
def authorize_credentials(auth_data, service, action, call_data_items):
def authorize_credentials(auth_data, service, action, call):
"""Validate credentials against service/action and request data (dicts).
Returns a new basic object (auth payload)
"""
@@ -100,7 +100,12 @@ def authorize_credentials(auth_data, service, action, call_data_items):
if not fixed_user:
# In case these are proper credentials, update last used time
User.objects(id=user.id, credentials__key=access_key).update(
**{"set__credentials__$__last_used": datetime.utcnow()}
**{
"set__credentials__$__last_used": datetime.utcnow(),
"set__credentials__$__last_used_from": call.get_worker(
default=call.real_ip
),
}
)
with TimingContext("mongo", "company_by_id"):

View File

@@ -12,6 +12,9 @@ from .payload import Payload
token_secret = config.get('secure.auth.token_secret')
log = config.logger(__file__)
class Token(Payload):
default_expiration_sec = config.get('apiserver.auth.default_expiration_sec')
@@ -94,3 +97,14 @@ class Token(Payload):
token.exp = now + timedelta(seconds=expiration_sec)
return token.encode(**extra_payload)
@classmethod
def decode_identity(cls, encoded_token):
# noinspection PyBroadException
try:
from ..auth import Identity
decoded = cls.decode(encoded_token, verify=False)
return Identity.from_dict(decoded.get("identity", {}))
except Exception as ex:
log.error(f"Failed parsing identity from encoded token: {ex}")

View File

@@ -10,6 +10,7 @@ from apiserver.apierrors import APIError, errors
from apiserver.config_repo import config
from apiserver.utilities.partial_version import PartialVersion
from .apicall import APICall
from .auth import Identity
from .endpoint import Endpoint
from .errors import MalformedPathError, InvalidVersionError, CallFailedError
from .util import parse_return_stack_on_code
@@ -37,7 +38,7 @@ class ServiceRepo(object):
"""If the check is set, parsing will fail for endpoint request with the version that is grater than the current
maximum """
_max_version = PartialVersion("2.14")
_max_version = PartialVersion("2.17")
""" Maximum version number (the highest min_version value across all endpoints) """
_endpoint_exp = (
@@ -233,19 +234,27 @@ class ServiceRepo(object):
return subcode in subcode_list
@classmethod
def _get_company(
def _get_identity(
cls, call: APICall, endpoint: Endpoint = None, ignore_error: bool = False
) -> Optional[str]:
) -> Optional[Identity]:
authorize = endpoint and endpoint.authorize
if ignore_error or not authorize:
try:
return call.identity.company
return call.identity
except Exception:
return None
return call.identity.company
return call.identity
@classmethod
def _get_company(
cls, call: APICall, endpoint: Endpoint = None, ignore_error: bool = False
) -> Optional[str]:
identity = cls._get_identity(call, endpoint=endpoint, ignore_error=ignore_error)
return None if identity is None else identity.company
@classmethod
def handle_call(cls, call: APICall, load_data_callback: Callable = None):
company = None
try:
if call.failed:
raise CallFailedError()
@@ -316,4 +325,4 @@ class ServiceRepo(object):
else:
log.error(console_msg)
return content, content_type
return content, content_type, company

View File

@@ -69,7 +69,7 @@ def validate_auth(endpoint, call):
auth = call.authorization or ""
auth_type, _, auth_data = auth.partition(" ")
authorize_func = get_auth_func(auth_type)
call.auth = authorize_func(auth_data, service, action, call.batched_data)
call.auth = authorize_func(auth_data, service, action, call)
except Exception:
if endpoint.authorize:
# if endpoint requires authorization, re-raise exception

View File

@@ -13,6 +13,7 @@ from apiserver.apimodels.auth import (
CredentialsResponse,
RevokeCredentialsRequest,
EditUserReq,
CreateCredentialsRequest,
)
from apiserver.apimodels.base import UpdateResponse
from apiserver.bll.auth import AuthBLL
@@ -58,9 +59,13 @@ def get_token_for_user(call: APICall, _: str, request: GetTokenForUserRequest):
""" Generates a token based on a requested user and company. INTERNAL. """
if call.identity.role not in Role.get_system_roles():
if call.identity.role != Role.admin and call.identity.user != request.user:
raise errors.bad_request.InvalidUserId("cannot generate token for another user")
raise errors.bad_request.InvalidUserId(
"cannot generate token for another user"
)
if call.identity.company != request.company:
raise errors.bad_request.InvalidId("cannot generate token in another company")
raise errors.bad_request.InvalidId(
"cannot generate token in another company"
)
call.result.data_model = AuthBLL.get_token_for_user(
user_id=request.user,
@@ -93,7 +98,10 @@ def validate_token_endpoint(call: APICall, _, __):
)
def create_user(call: APICall, _, request: CreateUserRequest):
""" Create a user from. INTERNAL. """
if call.identity.role not in Role.get_system_roles() and request.company != call.identity.company:
if (
call.identity.role not in Role.get_system_roles()
and request.company != call.identity.company
):
raise errors.bad_request.InvalidId("cannot create user in another company")
user_id = AuthBLL.create_user(request=request, call=call)
@@ -101,7 +109,7 @@ def create_user(call: APICall, _, request: CreateUserRequest):
@endpoint("auth.create_credentials", response_data_model=CreateCredentialsResponse)
def create_credentials(call: APICall, _, __):
def create_credentials(call: APICall, _, request: CreateCredentialsRequest):
if _is_protected_user(call.identity.user):
raise errors.bad_request.InvalidUserId("protected identity")
@@ -109,6 +117,7 @@ def create_credentials(call: APICall, _, __):
user_id=call.identity.user,
company_id=call.identity.company,
role=call.identity.role,
label=request.label,
)
call.result.data_model = CreateCredentialsResponse(credentials=credentials)
@@ -151,7 +160,12 @@ def get_credentials(call: APICall, _, __):
# we return ONLY the key IDs, never the secrets (want a secret? create new credentials)
call.result.data_model = GetCredentialsResponse(
credentials=[
CredentialsResponse(access_key=c.key, last_used=c.last_used)
CredentialsResponse(
access_key=c.key,
last_used=c.last_used,
label=c.label,
last_used_from=c.last_used_from,
)
for c in user.credentials
]
)

View File

@@ -1,9 +1,12 @@
import itertools
import math
from collections import defaultdict
from operator import itemgetter
from typing import Sequence, Optional
import attr
from typing import Sequence, Optional
import jsonmodels.fields
from boltons.iterutils import bucketize
from apiserver.apierrors import errors
from apiserver.apimodels.events import (
@@ -20,12 +23,17 @@ from apiserver.apimodels.events import (
NextDebugImageSampleRequest,
MetricVariants as ApiMetrics,
TaskPlotsRequest,
TaskEventsRequest,
ScalarMetricsIterRawRequest,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_common import EventType, MetricVariants
from apiserver.bll.event.events_iterator import Scroll
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
from apiserver.bll.task import TaskBLL
from apiserver.config_repo import config
from apiserver.service_repo import APICall, endpoint
from apiserver.utilities import json
from apiserver.utilities import json, extract_properties_to_lists
task_bll = TaskBLL()
event_bll = EventBLL()
@@ -39,7 +47,6 @@ def add(call: APICall, company_id, _):
company_id, [data], call.worker, allow_locked_tasks=allow_locked
)
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
call.kpis["events"] = 1
@endpoint("events.add_batch")
@@ -50,7 +57,6 @@ def add_batch(call: APICall, company_id, _):
added, err_count, err_info = event_bll.add_events(company_id, events, call.worker)
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
call.kpis["events"] = len(events)
@endpoint("events.get_task_log", required_fields=["task"])
@@ -113,7 +119,8 @@ def get_task_log(call, company_id, request: LogEventsRequest):
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
res = event_bll.log_events_iterator.get_task_events(
res = event_bll.events_iterator.get_task_events(
event_type=EventType.task_log,
company_id=task.get_index_company(),
task_id=task_id,
batch_size=request.batch_size,
@@ -258,31 +265,94 @@ def vector_metrics_iter_histogram(call, company_id, _):
)
@endpoint("events.get_task_events", required_fields=["task"])
def get_task_events(call, company_id, _):
task_id = call.data["task"]
batch_size = call.data.get("batch_size", 500)
event_type = call.data.get("event_type")
scroll_id = call.data.get("scroll_id")
order = call.data.get("order") or "asc"
class GetTaskEventsScroll(Scroll):
from_key_value = jsonmodels.fields.StringField()
total = jsonmodels.fields.IntField()
request: TaskEventsRequest = jsonmodels.fields.EmbeddedField(TaskEventsRequest)
def make_response(
total: int, returned: int = 0, scroll_id: str = None, **kwargs
) -> dict:
return {
"returned": returned,
"total": total,
"scroll_id": scroll_id,
**kwargs,
}
@endpoint("events.get_task_events", request_data_model=TaskEventsRequest)
def get_task_events(call, company_id, request: TaskEventsRequest):
task_id = request.task
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
company_id, task_id, allow_public=True, only=("company",),
)[0]
result = event_bll.get_task_events(
task.get_index_company(),
task_id,
sort=[{"timestamp": {"order": order}}],
event_type=EventType(event_type) if event_type else EventType.all,
scroll_id=scroll_id,
size=batch_size,
key = ScalarKeyEnum.iter
scalar_key = ScalarKey.resolve(key)
if not request.scroll_id:
from_key_value = None if (request.order == LogOrderEnum.desc) else 0
total = None
else:
try:
scroll = GetTaskEventsScroll.from_scroll_id(request.scroll_id)
except ValueError:
raise errors.bad_request.InvalidScrollId(scroll_id=request.scroll_id)
if scroll.from_key_value is None:
return make_response(
scroll_id=request.scroll_id, total=scroll.total, events=[]
)
from_key_value = scalar_key.cast_value(scroll.from_key_value)
total = scroll.total
scroll.request.batch_size = request.batch_size or scroll.request.batch_size
request = scroll.request
navigate_earlier = request.order == LogOrderEnum.desc
metric_variants = _get_metric_variants_from_request(request.metrics)
if request.count_total and total is None:
total = event_bll.events_iterator.count_task_events(
event_type=request.event_type,
company_id=task.company,
task_id=task_id,
metric_variants=metric_variants,
)
batch_size = min(
request.batch_size,
int(
config.get("services.events.events_retrieval.max_raw_scalars_size", 10_000)
),
)
call.result.data = dict(
events=result.events,
returned=len(result.events),
total=result.total_events,
scroll_id=result.next_scroll_id,
res = event_bll.events_iterator.get_task_events(
event_type=request.event_type,
company_id=task.company,
task_id=task_id,
batch_size=batch_size,
key=ScalarKeyEnum.iter,
navigate_earlier=navigate_earlier,
from_key_value=from_key_value,
metric_variants=metric_variants,
)
scroll = GetTaskEventsScroll(
from_key_value=str(res.events[-1][scalar_key.field]) if res.events else None,
total=total,
request=request,
)
return make_response(
returned=len(res.events),
total=total,
scroll_id=scroll.get_scroll_id(),
events=res.events,
)
@@ -291,6 +361,7 @@ def get_scalar_metric_data(call, company_id, _):
task_id = call.data["task"]
metric = call.data["metric"]
scroll_id = call.data.get("scroll_id")
no_scroll = call.data.get("no_scroll", False)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
@@ -302,6 +373,7 @@ def get_scalar_metric_data(call, company_id, _):
sort=[{"iter": {"order": "desc"}}],
metric=metric,
scroll_id=scroll_id,
no_scroll=no_scroll,
)
call.result.data = dict(
@@ -424,6 +496,7 @@ def get_multi_task_plots(call, company_id, req_model):
task_ids = call.data["tasks"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
no_scroll = call.data.get("no_scroll", False)
tasks = task_bll.assert_exists(
company_id=call.identity.company,
@@ -445,6 +518,7 @@ def get_multi_task_plots(call, company_id, req_model):
sort=[{"iter": {"order": "desc"}}],
last_iter_count=iters,
scroll_id=scroll_id,
no_scroll=no_scroll,
)
tasks = {t.id: t.name for t in tasks}
@@ -523,6 +597,7 @@ def get_task_plots(call, company_id, request: TaskPlotsRequest):
sort=[{"iter": {"order": "desc"}}],
last_iterations_per_plot=iters,
scroll_id=scroll_id,
no_scroll=request.no_scroll,
metric_variants=_get_metric_variants_from_request(request.metrics),
)
@@ -759,3 +834,105 @@ def _get_top_iter_unique_events(events, max_iters):
)
unique_events.sort(key=lambda e: e["iter"], reverse=True)
return unique_events
class ScalarMetricsIterRawScroll(Scroll):
from_key_value = jsonmodels.fields.StringField()
total = jsonmodels.fields.IntField()
request: ScalarMetricsIterRawRequest = jsonmodels.fields.EmbeddedField(
ScalarMetricsIterRawRequest
)
@endpoint("events.scalar_metrics_iter_raw", min_version="2.16")
def scalar_metrics_iter_raw(
call: APICall, company_id: str, request: ScalarMetricsIterRawRequest
):
key = request.key or ScalarKeyEnum.iter
scalar_key = ScalarKey.resolve(key)
if request.batch_size and request.batch_size < 0:
raise errors.bad_request.ValidationError(
"batch_size should be non negative number"
)
if not request.scroll_id:
from_key_value = None
total = None
request.batch_size = request.batch_size or 10_000
else:
try:
scroll = ScalarMetricsIterRawScroll.from_scroll_id(request.scroll_id)
except ValueError:
raise errors.bad_request.InvalidScrollId(scroll_id=request.scroll_id)
if scroll.from_key_value is None:
return make_response(
scroll_id=request.scroll_id, total=scroll.total, variants={}
)
from_key_value = scalar_key.cast_value(scroll.from_key_value)
total = scroll.total
request.batch_size = request.batch_size or scroll.request.batch_size
task_id = request.task
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",),
)[0]
metric_variants = _get_metric_variants_from_request([request.metric])
if request.count_total and total is None:
total = event_bll.events_iterator.count_task_events(
event_type=EventType.metrics_scalar,
company_id=task.company,
task_id=task_id,
metric_variants=metric_variants,
)
batch_size = min(
request.batch_size,
int(
config.get("services.events.events_retrieval.max_raw_scalars_size", 200_000)
),
)
events = []
for iteration in range(0, math.ceil(batch_size / 10_000)):
res = event_bll.events_iterator.get_task_events(
event_type=EventType.metrics_scalar,
company_id=task.company,
task_id=task_id,
batch_size=min(batch_size, 10_000),
navigate_earlier=False,
from_key_value=from_key_value,
metric_variants=metric_variants,
key=key,
)
if not res.events:
break
events.extend(res.events)
from_key_value = str(events[-1][scalar_key.field])
key = str(key)
variants = {
variant: extract_properties_to_lists(
["value", scalar_key.field], events, target_keys=["y", key]
)
for variant, events in bucketize(events, key=itemgetter("variant")).items()
}
call.kpis["events"] = len(events)
scroll = ScalarMetricsIterRawScroll(
from_key_value=str(events[-1][scalar_key.field]) if events else None,
total=total,
request=request,
)
return make_response(
returned=len(events),
total=total,
scroll_id=scroll.get_scroll_id(),
variants=variants,
)

View File

@@ -21,16 +21,14 @@ from apiserver.apimodels.models import (
ModelsPublishManyRequest,
ModelsDeleteManyRequest,
)
from apiserver.bll.model import ModelBLL
from apiserver.bll.model import ModelBLL, Metadata
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL, project_ids_with_children
from apiserver.bll.task import TaskBLL
from apiserver.bll.task.task_operations import publish_task
from apiserver.bll.util import run_batch_operation
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model import validate_id
from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete
from apiserver.database.model.model import Model
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import (
@@ -50,8 +48,8 @@ from apiserver.services.utils import (
conform_tag_fields,
conform_output_tags,
ModelsBackwardsCompatibility,
validate_metadata,
get_metadata_from_api,
unescape_metadata,
escape_metadata,
)
from apiserver.timing_context import TimingContext
@@ -64,19 +62,20 @@ project_bll = ProjectBLL()
def get_by_id(call: APICall, company_id, _):
model_id = call.data["model"]
with translate_errors_context():
models = Model.get_many(
company=company_id,
query_dict=call.data,
query=Q(id=model_id),
allow_public=True,
Metadata.escape_query_parameters(call)
models = Model.get_many(
company=company_id,
query_dict=call.data,
query=Q(id=model_id),
allow_public=True,
)
if not models:
raise errors.bad_request.InvalidModelId(
"no such public or company model", id=model_id, company=company_id,
)
if not models:
raise errors.bad_request.InvalidModelId(
"no such public or company model", id=model_id, company=company_id,
)
conform_output_tags(call, models[0])
call.result.data = {"model": models[0]}
conform_output_tags(call, models[0])
unescape_metadata(call, models[0])
call.result.data = {"model": models[0]}
@endpoint("models.get_by_task_id", required_fields=["task"])
@@ -86,25 +85,25 @@ def get_by_task_id(call: APICall, company_id, _):
task_id = call.data["task"]
with translate_errors_context():
query = dict(id=task_id, company=company_id)
task = Task.get(_only=["models"], **query)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
if not task.models or not task.models.output:
raise errors.bad_request.MissingTaskFields(field="models.output")
query = dict(id=task_id, company=company_id)
task = Task.get(_only=["models"], **query)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
if not task.models or not task.models.output:
raise errors.bad_request.MissingTaskFields(field="models.output")
model_id = task.models.output[-1].model
model = Model.objects(
Q(id=model_id) & get_company_or_none_constraint(company_id)
).first()
if not model:
raise errors.bad_request.InvalidModelId(
"no such public or company model", id=model_id, company=company_id,
)
model_dict = model.to_proper_dict()
conform_output_tags(call, model_dict)
call.result.data = {"model": model_dict}
model_id = task.models.output[-1].model
model = Model.objects(
Q(id=model_id) & get_company_or_none_constraint(company_id)
).first()
if not model:
raise errors.bad_request.InvalidModelId(
"no such public or company model", id=model_id, company=company_id,
)
model_dict = model.to_proper_dict()
conform_output_tags(call, model_dict)
unescape_metadata(call, model_dict)
call.result.data = {"model": model_dict}
def _process_include_subprojects(call_data: dict):
@@ -121,41 +120,50 @@ def _process_include_subprojects(call_data: dict):
@endpoint("models.get_all_ex", required_fields=[])
def get_all_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
with translate_errors_context():
_process_include_subprojects(call.data)
with TimingContext("mongo", "models_get_all_ex"):
models = Model.get_many_with_join(
company=company_id, query_dict=call.data, allow_public=True
)
conform_output_tags(call, models)
call.result.data = {"models": models}
_process_include_subprojects(call.data)
Metadata.escape_query_parameters(call)
with TimingContext("mongo", "models_get_all_ex"):
ret_params = {}
models = Model.get_many_with_join(
company=company_id,
query_dict=call.data,
allow_public=True,
ret_params=ret_params,
)
conform_output_tags(call, models)
unescape_metadata(call, models)
call.result.data = {"models": models, **ret_params}
@endpoint("models.get_by_id_ex", required_fields=["id"])
def get_by_id_ex(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
with translate_errors_context():
with TimingContext("mongo", "models_get_by_id_ex"):
models = Model.get_many_with_join(
company=company_id, query_dict=call.data, allow_public=True
)
conform_output_tags(call, models)
call.result.data = {"models": models}
Metadata.escape_query_parameters(call)
with TimingContext("mongo", "models_get_by_id_ex"):
models = Model.get_many_with_join(
company=company_id, query_dict=call.data, allow_public=True
)
conform_output_tags(call, models)
unescape_metadata(call, models)
call.result.data = {"models": models}
@endpoint("models.get_all", required_fields=[])
def get_all(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
with translate_errors_context():
with TimingContext("mongo", "models_get_all"):
models = Model.get_many(
company=company_id,
parameters=call.data,
query_dict=call.data,
allow_public=True,
)
conform_output_tags(call, models)
call.result.data = {"models": models}
Metadata.escape_query_parameters(call)
with TimingContext("mongo", "models_get_all"):
ret_params = {}
models = Model.get_many(
company=company_id,
parameters=call.data,
query_dict=call.data,
allow_public=True,
ret_params=ret_params,
)
conform_output_tags(call, models)
unescape_metadata(call, models)
call.result.data = {"models": models, **ret_params}
@endpoint("models.get_frameworks", request_data_model=GetFrameworksRequest)
@@ -183,15 +191,22 @@ create_fields = {
"metadata": list,
}
last_update_fields = ("uri", "framework", "design", "labels", "ready", "metadata")
last_update_fields = (
"uri",
"framework",
"design",
"labels",
"ready",
"metadata",
"system_tags",
"tags",
)
def parse_model_fields(call, valid_fields):
fields = parse_from_call(call.data, valid_fields, Model.get_fields())
conform_tag_fields(call, fields, validate=True)
metadata = fields.get("metadata")
if metadata:
validate_metadata(metadata)
escape_metadata(fields)
return fields
@@ -225,82 +240,80 @@ def update_for_task(call: APICall, company_id, _):
"exactly one field is required", fields=("uri", "override_model_id")
)
with translate_errors_context():
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(
id=task_id,
company=company_id,
_only=["models", "execution", "name", "status", "project"],
)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(
id=task_id,
allowed_states = [TaskStatus.created, TaskStatus.in_progress]
if task.status not in allowed_states:
raise errors.bad_request.InvalidTaskStatus(
f"model can only be updated for tasks in the {allowed_states} states",
**query,
)
if override_model_id:
model = ModelBLL.get_company_model_by_id(
company_id=company_id, model_id=override_model_id
)
else:
if "name" not in call.data:
# use task name if name not provided
call.data["name"] = task.name
if "comment" not in call.data:
call.data["comment"] = f"Created by task `{task.name}` ({task.id})"
if task.models and task.models.output:
# model exists, update
model_id = task.models.output[-1].model
res = _update_model(call, company_id, model_id=model_id).to_struct()
res.update({"id": model_id, "created": False})
call.result.data = res
return
# new model, create
fields = parse_model_fields(call, create_fields)
# create and save model
now = datetime.utcnow()
model = Model(
id=database.utils.id(),
created=now,
last_update=now,
user=call.identity.user,
company=company_id,
_only=["models", "execution", "name", "status", "project"],
project=task.project,
framework=task.execution.framework,
parent=task.models.input[0].model
if task.models and task.models.input
else None,
design=task.execution.model_desc,
labels=task.execution.model_labels,
ready=(task.status == TaskStatus.published),
**fields,
)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
model.save()
_update_cached_tags(company_id, project=model.project, fields=fields)
allowed_states = [TaskStatus.created, TaskStatus.in_progress]
if task.status not in allowed_states:
raise errors.bad_request.InvalidTaskStatus(
f"model can only be updated for tasks in the {allowed_states} states",
**query,
TaskBLL.update_statistics(
task_id=task_id,
company_id=company_id,
last_iteration_max=iteration,
models__output=[
ModelItem(
model=model.id,
name=TaskModelNames[TaskModelTypes.output],
updated=datetime.utcnow(),
)
],
)
if override_model_id:
model = ModelBLL.get_company_model_by_id(
company_id=company_id, model_id=override_model_id
)
else:
if "name" not in call.data:
# use task name if name not provided
call.data["name"] = task.name
if "comment" not in call.data:
call.data["comment"] = f"Created by task `{task.name}` ({task.id})"
if task.models and task.models.output:
# model exists, update
model_id = task.models.output[-1].model
res = _update_model(call, company_id, model_id=model_id).to_struct()
res.update({"id": model_id, "created": False})
call.result.data = res
return
# new model, create
fields = parse_model_fields(call, create_fields)
# create and save model
now = datetime.utcnow()
model = Model(
id=database.utils.id(),
created=now,
last_update=now,
user=call.identity.user,
company=company_id,
project=task.project,
framework=task.execution.framework,
parent=task.models.input[0].model
if task.models and task.models.input
else None,
design=task.execution.model_desc,
labels=task.execution.model_labels,
ready=(task.status == TaskStatus.published),
**fields,
)
model.save()
_update_cached_tags(company_id, project=model.project, fields=fields)
TaskBLL.update_statistics(
task_id=task_id,
company_id=company_id,
last_iteration_max=iteration,
models__output=[
ModelItem(
model=model.id,
name=TaskModelNames[TaskModelTypes.output],
updated=datetime.utcnow(),
)
],
)
call.result.data = {"id": model.id, "created": True}
call.result.data = {"id": model.id, "created": True}
@endpoint(
@@ -313,36 +326,33 @@ def create(call: APICall, company_id, req_model: CreateModelRequest):
if req_model.public:
company_id = ""
with translate_errors_context():
project = req_model.project
if project:
validate_id(Project, company=company_id, project=project)
project = req_model.project
if project:
validate_id(Project, company=company_id, project=project)
task = req_model.task
req_data = req_model.to_struct()
if task:
validate_task(company_id, req_data)
task = req_model.task
req_data = req_model.to_struct()
if task:
validate_task(company_id, req_data)
fields = filter_fields(Model, req_data)
conform_tag_fields(call, fields, validate=True)
escape_metadata(fields)
fields = filter_fields(Model, req_data)
conform_tag_fields(call, fields, validate=True)
# create and save model
now = datetime.utcnow()
model = Model(
id=database.utils.id(),
user=call.identity.user,
company=company_id,
created=now,
last_update=now,
**fields,
)
model.save()
_update_cached_tags(company_id, project=model.project, fields=fields)
validate_metadata(fields.get("metadata"))
# create and save model
now = datetime.utcnow()
model = Model(
id=database.utils.id(),
user=call.identity.user,
company=company_id,
created=now,
last_update=now,
**fields,
)
model.save()
_update_cached_tags(company_id, project=model.project, fields=fields)
call.result.data_model = CreateModelResponse(id=model.id, created=True)
call.result.data_model = CreateModelResponse(id=model.id, created=True)
def prepare_update_fields(call, company_id, fields: dict):
@@ -377,6 +387,7 @@ def prepare_update_fields(call, company_id, fields: dict):
)
conform_tag_fields(call, fields, validate=True)
escape_metadata(fields)
return fields
@@ -388,89 +399,85 @@ def validate_task(company_id, fields: dict):
def edit(call: APICall, company_id, _):
model_id = call.data["model"]
with translate_errors_context():
model = ModelBLL.get_company_model_by_id(
company_id=company_id, model_id=model_id
model = ModelBLL.get_company_model_by_id(
company_id=company_id, model_id=model_id
)
fields = parse_model_fields(call, create_fields)
fields = prepare_update_fields(call, company_id, fields)
for key in fields:
field = getattr(model, key, None)
value = fields[key]
if (
field
and isinstance(value, dict)
and isinstance(field, EmbeddedDocument)
):
d = field.to_mongo(use_db_field=False).to_dict()
d.update(value)
fields[key] = d
iteration = call.data.get("iteration")
task_id = model.task or fields.get("task")
if task_id and iteration is not None:
TaskBLL.update_statistics(
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
)
fields = parse_model_fields(call, create_fields)
fields = prepare_update_fields(call, company_id, fields)
if fields:
if any(uf in fields for uf in last_update_fields):
fields.update(last_update=datetime.utcnow())
for key in fields:
field = getattr(model, key, None)
value = fields[key]
if (
field
and isinstance(value, dict)
and isinstance(field, EmbeddedDocument)
):
d = field.to_mongo(use_db_field=False).to_dict()
d.update(value)
fields[key] = d
iteration = call.data.get("iteration")
task_id = model.task or fields.get("task")
if task_id and iteration is not None:
TaskBLL.update_statistics(
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
)
if fields:
if any(uf in fields for uf in last_update_fields):
fields.update(last_update=datetime.utcnow())
updated = model.update(upsert=False, **fields)
if updated:
new_project = fields.get("project", model.project)
if new_project != model.project:
_reset_cached_tags(
company_id, projects=[new_project, model.project]
)
else:
_update_cached_tags(
company_id, project=model.project, fields=fields
)
conform_output_tags(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
else:
call.result.data_model = UpdateResponse(updated=0)
updated = model.update(upsert=False, **fields)
if updated:
new_project = fields.get("project", model.project)
if new_project != model.project:
_reset_cached_tags(
company_id, projects=[new_project, model.project]
)
else:
_update_cached_tags(
company_id, project=model.project, fields=fields
)
conform_output_tags(call, fields)
unescape_metadata(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
else:
call.result.data_model = UpdateResponse(updated=0)
def _update_model(call: APICall, company_id, model_id=None):
model_id = model_id or call.data["model"]
with translate_errors_context():
model = ModelBLL.get_company_model_by_id(
company_id=company_id, model_id=model_id
model = ModelBLL.get_company_model_by_id(
company_id=company_id, model_id=model_id
)
data = prepare_update_fields(call, company_id, call.data)
task_id = data.get("task")
iteration = data.get("iteration")
if task_id and iteration is not None:
TaskBLL.update_statistics(
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
)
data = prepare_update_fields(call, company_id, call.data)
updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
if updated_count:
if any(uf in updated_fields for uf in last_update_fields):
model.update(upsert=False, last_update=datetime.utcnow())
task_id = data.get("task")
iteration = data.get("iteration")
if task_id and iteration is not None:
TaskBLL.update_statistics(
task_id=task_id, company_id=company_id, last_iteration_max=iteration,
new_project = updated_fields.get("project", model.project)
if new_project != model.project:
_reset_cached_tags(company_id, projects=[new_project, model.project])
else:
_update_cached_tags(
company_id, project=model.project, fields=updated_fields
)
metadata = data.get("metadata")
if metadata:
validate_metadata(metadata)
updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
if updated_count:
if any(uf in updated_fields for uf in last_update_fields):
model.update(upsert=False, last_update=datetime.utcnow())
new_project = updated_fields.get("project", model.project)
if new_project != model.project:
_reset_cached_tags(company_id, projects=[new_project, model.project])
else:
_update_cached_tags(
company_id, project=model.project, fields=updated_fields
)
conform_output_tags(call, updated_fields)
return UpdateResponse(updated=updated_count, fields=updated_fields)
conform_output_tags(call, updated_fields)
unescape_metadata(call, updated_fields)
return UpdateResponse(updated=updated_count, fields=updated_fields)
@endpoint(
@@ -635,26 +642,25 @@ def add_or_update_metadata(
_: APICall, company_id: str, request: AddOrUpdateMetadataRequest
):
model_id = request.model
ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
updated = metadata_add_or_update(
cls=Model, _id=model_id, items=get_metadata_from_api(request.metadata),
)
if updated:
Model.objects(id=model_id).update_one(last_update=datetime.utcnow())
return {"updated": updated}
model = ModelBLL.get_company_model_by_id(company_id=company_id, model_id=model_id)
return {
"updated": Metadata.edit_metadata(
model,
items=request.metadata,
replace_metadata=request.replace_metadata,
last_update=datetime.utcnow(),
)
}
@endpoint("models.delete_metadata", min_version="2.13")
def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
model_id = request.model
ModelBLL.get_company_model_by_id(
model = ModelBLL.get_company_model_by_id(
company_id=company_id, model_id=model_id, only_fields=("id",)
)
updated = metadata_delete(cls=Model, _id=model_id, keys=request.keys)
if updated:
Model.objects(id=model_id).update_one(last_update=datetime.utcnow())
return {"updated": updated}
return {
"updated": Metadata.delete_metadata(
model, keys=request.keys, last_update=datetime.utcnow()
)
}

View File

@@ -5,7 +5,7 @@ from apiserver.apimodels.organization import TagsRequest
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.database.model import User
from apiserver.service_repo import endpoint, APICall
from apiserver.services.utils import get_tags_filter_dictionary, get_tags_response
from apiserver.services.utils import get_tags_filter_dictionary, sort_tags_response
org_bll = OrgBLL()
@@ -21,17 +21,13 @@ def get_tags(call: APICall, company, request: TagsRequest):
for field, vals in tags.items():
ret[field] |= vals
call.result.data = get_tags_response(ret)
call.result.data = sort_tags_response(ret)
@endpoint("organization.get_user_companies")
def get_user_companies(call: APICall, company_id: str, _):
users = [
{
"id": u.id,
"name": u.name,
"avatar": u.avatar,
}
{"id": u.id, "name": u.name, "avatar": u.avatar}
for u in User.objects(company=company_id).only("avatar", "name", "company")
]

View File

@@ -0,0 +1,68 @@
import re
from apiserver.apimodels.pipelines import StartPipelineResponse, StartPipelineRequest
from apiserver.bll.organization import OrgBLL
from apiserver.bll.project import ProjectBLL
from apiserver.bll.task import TaskBLL
from apiserver.bll.task.task_operations import enqueue_task
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task
from apiserver.service_repo import APICall, endpoint
org_bll = OrgBLL()
project_bll = ProjectBLL()
task_bll = TaskBLL()
def _update_task_name(task: Task):
if not task or not task.project:
return
project = Project.objects(id=task.project).only("name").first()
if not project:
return
_, _, name_prefix = project.name.rpartition("/")
name_mask = re.compile(rf"{re.escape(name_prefix)}( #\d+)?$")
count = Task.objects(
project=task.project, system_tags__in=["pipeline"], name=name_mask
).count()
new_name = f"{name_prefix} #{count}" if count > 0 else name_prefix
task.update(name=new_name)
@endpoint(
"pipelines.start_pipeline", response_data_model=StartPipelineResponse,
)
def start_pipeline(call: APICall, company_id: str, request: StartPipelineRequest):
hyperparams = None
if request.args:
hyperparams = {
"Args": {
str(arg.name): {
"section": "Args",
"name": str(arg.name),
"value": str(arg.value),
}
for arg in request.args or []
}
}
task, _ = task_bll.clone_task(
company_id=company_id,
user_id=call.identity.user,
task_id=request.task,
hyperparams=hyperparams,
)
_update_task_name(task)
queued, res = enqueue_task(
task_id=task.id,
company_id=company_id,
queue_id=request.queue,
status_message="Starting pipeline",
status_reason="",
)
return StartPipelineResponse(pipeline=task.id, enqueued=bool(queued))

View File

@@ -7,7 +7,7 @@ from apiserver.apierrors import errors
from apiserver.apierrors.errors.bad_request import InvalidProjectId
from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse
from apiserver.apimodels.projects import (
GetHyperParamRequest,
GetParamsRequest,
ProjectTagsRequest,
ProjectTaskParentsRequest,
ProjectHyperparamValuesRequest,
@@ -17,14 +17,14 @@ from apiserver.apimodels.projects import (
MergeRequest,
ProjectOrNoneRequest,
ProjectRequest,
ProjectModelMetadataValuesRequest,
)
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.bll.project import ProjectBLL, ProjectQueries
from apiserver.bll.project.project_cleanup import (
delete_project,
validate_project_delete,
)
from apiserver.bll.task import TaskBLL
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.project import Project
from apiserver.database.utils import (
@@ -36,13 +36,13 @@ from apiserver.services.utils import (
conform_tag_fields,
conform_output_tags,
get_tags_filter_dictionary,
get_tags_response,
sort_tags_response,
)
from apiserver.timing_context import TimingContext
org_bll = OrgBLL()
task_bll = TaskBLL()
project_bll = ProjectBLL()
project_queries = ProjectQueries()
create_fields = {
"name": None,
@@ -111,8 +111,12 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
_adjust_search_parameters(data, shallow_search=request.shallow_search)
ret_params = {}
projects = Project.get_many_with_join(
company=company_id, query_dict=data, allow_public=allow_public,
company=company_id,
query_dict=data,
allow_public=allow_public,
ret_params=ret_params,
)
if request.check_own_contents and requested_ids:
@@ -121,14 +125,16 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
}
if existing_requested_ids:
contents = project_bll.calc_own_contents(
company=company_id, project_ids=list(existing_requested_ids)
company=company_id,
project_ids=list(existing_requested_ids),
filter_=request.include_stats_filter,
)
for project in projects:
project.update(**contents.get(project["id"], {}))
conform_output_tags(call, projects)
if not request.include_stats:
call.result.data = {"projects": projects}
call.result.data = {"projects": projects, **ret_params}
return
project_ids = {project["id"] for project in projects}
@@ -136,13 +142,15 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
company=company_id,
project_ids=list(project_ids),
specific_state=request.stats_for_state,
include_children=request.stats_with_children,
filter_=request.include_stats_filter,
)
for project in projects:
project["stats"] = stats[project["id"]]
project["sub_projects"] = children[project["id"]]
call.result.data = {"projects": projects}
call.result.data = {"projects": projects, **ret_params}
@endpoint("projects.get_all")
@@ -151,15 +159,17 @@ def get_all(call: APICall):
data = call.data
_adjust_search_parameters(data, shallow_search=data.get("shallow_search", False))
with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
ret_params = {}
projects = Project.get_many(
company=call.identity.company,
query_dict=data,
parameters=data,
allow_public=True,
ret_params=ret_params,
)
conform_output_tags(call, projects)
call.result.data = {"projects": projects}
call.result.data = {"projects": projects, **ret_params}
@endpoint(
@@ -260,7 +270,7 @@ def get_unique_metric_variants(
call: APICall, company_id: str, request: ProjectOrNoneRequest
):
metrics = task_bll.get_unique_metric_variants(
metrics = project_queries.get_unique_metric_variants(
company_id,
[request.project] if request.project else None,
include_subprojects=request.include_subprojects,
@@ -269,14 +279,48 @@ def get_unique_metric_variants(
call.result.data = {"metrics": metrics}
@endpoint("projects.get_model_metadata_keys",)
def get_model_metadata_keys(call: APICall, company_id: str, request: GetParamsRequest):
total, remaining, keys = project_queries.get_model_metadata_keys(
company_id,
project_ids=[request.project] if request.project else None,
include_subprojects=request.include_subprojects,
page=request.page,
page_size=request.page_size,
)
call.result.data = {
"total": total,
"remaining": remaining,
"keys": keys,
}
@endpoint("projects.get_model_metadata_values")
def get_model_metadata_values(
call: APICall, company_id: str, request: ProjectModelMetadataValuesRequest
):
total, values = project_queries.get_model_metadata_distinct_values(
company_id,
project_ids=request.projects,
key=request.key,
include_subprojects=request.include_subprojects,
allow_public=request.allow_public,
)
call.result.data = {
"total": total,
"values": values,
}
@endpoint(
"projects.get_hyper_parameters",
min_version="2.9",
request_data_model=GetHyperParamRequest,
request_data_model=GetParamsRequest,
)
def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamRequest):
def get_hyper_parameters(call: APICall, company_id: str, request: GetParamsRequest):
total, remaining, parameters = TaskBLL.get_aggregated_project_parameters(
total, remaining, parameters = project_queries.get_aggregated_project_parameters(
company_id,
project_ids=[request.project] if request.project else None,
include_subprojects=request.include_subprojects,
@@ -299,7 +343,7 @@ def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamR
def get_hyperparam_values(
call: APICall, company_id: str, request: ProjectHyperparamValuesRequest
):
total, values = task_bll.get_hyperparam_distinct_values(
total, values = project_queries.get_task_hyperparam_distinct_values(
company_id,
project_ids=request.projects,
section=request.section,
@@ -313,6 +357,17 @@ def get_hyperparam_values(
}
@endpoint("projects.get_project_tags")
def get_tags(call: APICall, company, request: ProjectTagsRequest):
tags, system_tags = project_bll.get_project_tags(
company,
include_system=request.include_system,
filter_=get_tags_filter_dictionary(request.filter),
projects=request.projects,
)
call.result.data = sort_tags_response({"tags": tags, "system_tags": system_tags})
@endpoint(
"projects.get_task_tags", min_version="2.8", request_data_model=ProjectTagsRequest
)
@@ -324,7 +379,7 @@ def get_tags(call: APICall, company, request: ProjectTagsRequest):
filter_=get_tags_filter_dictionary(request.filter),
projects=request.projects,
)
call.result.data = get_tags_response(ret)
call.result.data = sort_tags_response(ret)
@endpoint(
@@ -338,7 +393,7 @@ def get_tags(call: APICall, company, request: ProjectTagsRequest):
filter_=get_tags_filter_dictionary(request.filter),
projects=request.projects,
)
call.result.data = get_tags_response(ret)
call.result.data = sort_tags_response(ret)
@endpoint(

View File

@@ -13,17 +13,19 @@ from apiserver.apimodels.queues import (
QueueMetrics,
AddOrUpdateMetadataRequest,
DeleteMetadataRequest,
GetNextTaskRequest,
)
from apiserver.bll.model import Metadata
from apiserver.bll.queue import QueueBLL
from apiserver.bll.workers import WorkerBLL
from apiserver.database.model.metadata import metadata_add_or_update, metadata_delete
from apiserver.database.model.queue import Queue
from apiserver.database.model.task.task import Task
from apiserver.service_repo import APICall, endpoint
from apiserver.services.utils import (
conform_tag_fields,
conform_output_tags,
conform_tags,
get_metadata_from_api,
escape_metadata,
unescape_metadata,
)
from apiserver.utilities import extract_properties_to_lists
@@ -36,6 +38,7 @@ def get_by_id(call: APICall, company_id, req_model: QueueRequest):
queue = queue_bll.get_by_id(company_id, req_model.queue)
queue_dict = queue.to_proper_dict()
conform_output_tags(call, queue_dict)
unescape_metadata(call, queue_dict)
call.result.data = {"queue": queue_dict}
@@ -48,21 +51,28 @@ def get_by_id(call: APICall):
@endpoint("queues.get_all_ex", min_version="2.4")
def get_all_ex(call: APICall):
conform_tag_fields(call, call.data)
ret_params = {}
Metadata.escape_query_parameters(call)
queues = queue_bll.get_queue_infos(
company_id=call.identity.company, query_dict=call.data
company_id=call.identity.company, query_dict=call.data, ret_params=ret_params,
)
conform_output_tags(call, queues)
call.result.data = {"queues": queues}
unescape_metadata(call, queues)
call.result.data = {"queues": queues, **ret_params}
@endpoint("queues.get_all", min_version="2.4")
def get_all(call: APICall):
conform_tag_fields(call, call.data)
queues = queue_bll.get_all(company_id=call.identity.company, query_dict=call.data)
ret_params = {}
Metadata.escape_query_parameters(call)
queues = queue_bll.get_all(
company_id=call.identity.company, query_dict=call.data, ret_params=ret_params,
)
conform_output_tags(call, queues)
call.result.data = {"queues": queues}
unescape_metadata(call, queues)
call.result.data = {"queues": queues, **ret_params}
@endpoint("queues.create", min_version="2.4", request_data_model=CreateRequest)
@@ -75,7 +85,7 @@ def create(call: APICall, company_id, request: CreateRequest):
name=request.name,
tags=tags,
system_tags=system_tags,
metadata=get_metadata_from_api(request.metadata),
metadata=Metadata.metadata_from_api(request.metadata),
)
call.result.data = {"id": queue.id}
@@ -89,10 +99,12 @@ def create(call: APICall, company_id, request: CreateRequest):
def update(call: APICall, company_id, req_model: UpdateRequest):
data = call.data_model_for_partial_update
conform_tag_fields(call, data, validate=True)
escape_metadata(data)
updated, fields = queue_bll.update(
company_id=company_id, queue_id=req_model.queue, **data
)
conform_output_tags(call, fields)
unescape_metadata(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
@@ -113,11 +125,19 @@ def add_task(call: APICall, company_id, req_model: TaskRequest):
}
@endpoint("queues.get_next_task", min_version="2.4", request_data_model=QueueRequest)
def get_next_task(call: APICall, company_id, req_model: QueueRequest):
task = queue_bll.get_next_task(company_id=company_id, queue_id=req_model.queue)
if task:
call.result.data = {"entry": task.to_proper_dict()}
@endpoint("queues.get_next_task", request_data_model=GetNextTaskRequest)
def get_next_task(call: APICall, company_id, req_model: GetNextTaskRequest):
entry = queue_bll.get_next_task(
company_id=company_id, queue_id=req_model.queue
)
if entry:
data = {"entry": entry.to_proper_dict()}
if req_model.get_task_info:
task = Task.objects(id=entry.task).first()
if task:
data["task_info"] = {"company": task.company, "user": task.user}
call.result.data = data
@endpoint("queues.remove_task", min_version="2.4", request_data_model=TaskRequest)
@@ -237,21 +257,19 @@ def get_queue_metrics(
@endpoint("queues.add_or_update_metadata", min_version="2.13")
def add_or_update_metadata(
_: APICall, company_id: str, request: AddOrUpdateMetadataRequest
call: APICall, company_id: str, request: AddOrUpdateMetadataRequest
):
queue_id = request.queue
queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
queue = queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
return {
"updated": metadata_add_or_update(
cls=Queue, _id=queue_id, items=get_metadata_from_api(request.metadata),
"updated": Metadata.edit_metadata(
queue, items=request.metadata, replace_metadata=request.replace_metadata
)
}
@endpoint("queues.delete_metadata", min_version="2.13")
def delete_metadata(_: APICall, company_id: str, request: DeleteMetadataRequest):
def delete_metadata(call: APICall, company_id: str, request: DeleteMetadataRequest):
queue_id = request.queue
queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
return {"updated": metadata_delete(cls=Queue, _id=queue_id, keys=request.keys)}
queue = queue_bll.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
return {"updated": Metadata.delete_metadata(queue, keys=request.keys)}

View File

@@ -221,11 +221,15 @@ def get_all_ex(call: APICall, company_id, _):
with TimingContext("mongo", "task_get_all_ex"):
_process_include_subprojects(call_data)
ret_params = {}
tasks = Task.get_many_with_join(
company=company_id, query_dict=call_data, allow_public=True,
company=company_id,
query_dict=call_data,
allow_public=True,
ret_params=ret_params,
)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks}
call.result.data = {"tasks": tasks, **ret_params}
@endpoint("tasks.get_by_id_ex", required_fields=["id"])
@@ -250,14 +254,16 @@ def get_all(call: APICall, company_id, _):
call_data = escape_execution_parameters(call)
with TimingContext("mongo", "task_get_all"):
ret_params = {}
tasks = Task.get_many(
company=company_id,
parameters=call_data,
query_dict=call_data,
allow_public=True,
ret_params=ret_params,
)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks}
call.result.data = {"tasks": tasks, **ret_params}
@endpoint("tasks.get_types", request_data_model=GetTypesRequest)
@@ -1050,6 +1056,8 @@ def delete(call: APICall, company_id, request: DeleteRequest):
force=request.force,
return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
status_message=request.status_message,
status_reason=request.status_reason,
)
if deleted:
_reset_cached_tags(company_id, projects=[task.project] if task.project else [])
@@ -1066,6 +1074,8 @@ def delete_many(call: APICall, company_id, request: DeleteManyRequest):
force=request.force,
return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
status_message=request.status_message,
status_reason=request.status_reason,
),
ids=request.ids,
)
@@ -1163,7 +1173,7 @@ def add_or_update_artifacts(
company_id=company_id,
task_id=request.task,
artifacts=request.artifacts,
force=request.force,
force=True,
)
}
@@ -1179,7 +1189,7 @@ def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest)
company_id=company_id,
task_id=request.task,
artifact_ids=request.artifacts,
force=request.force,
force=True,
)
}

View File

@@ -2,7 +2,6 @@ from datetime import datetime
from typing import Union, Sequence, Tuple
from apiserver.apierrors import errors
from apiserver.apimodels.metadata import MetadataItem as ApiMetadataItem
from apiserver.apimodels.organization import Filter
from apiserver.database.model.base import GetMixin
from apiserver.database.model.task.task import TaskModelTypes, TaskModelNames
@@ -24,7 +23,7 @@ def get_tags_filter_dictionary(input_: Filter) -> dict:
}
def get_tags_response(ret: dict) -> dict:
def sort_tags_response(ret: dict) -> dict:
return {field: sorted(vals) for field, vals in ret.items()}
@@ -222,22 +221,38 @@ class DockerCmdBackwardsCompatibility:
nested_set(task, cls.field, docker_cmd)
def validate_metadata(metadata: Sequence[dict]):
def escape_metadata(document: dict):
"""
Escape special characters in metadata keys
"""
metadata = document.get("metadata")
if not metadata:
return
keys = [m.get("key") for m in metadata]
unique_keys = set(keys)
unique_keys.discard(None)
if len(keys) != len(set(keys)):
raise errors.bad_request.ValidationError("Metadata keys should be unique")
document["metadata"] = {
ParameterKeyEscaper.escape(k): v
for k, v in metadata.items()
}
def get_metadata_from_api(api_metadata: Sequence[ApiMetadataItem]) -> Sequence:
if not api_metadata:
return api_metadata
def unescape_metadata(call: APICall, documents: Union[dict, Sequence[dict]]):
"""
Unescape special characters in metadata keys
"""
if isinstance(documents, dict):
documents = [documents]
metadata = [m.to_struct() for m in api_metadata]
validate_metadata(metadata)
old_client = call.requested_endpoint_version <= PartialVersion("2.16")
for doc in documents:
if old_client and "metadata" in doc:
doc["metadata"] = []
continue
return metadata
metadata = doc.get("metadata")
if not metadata:
continue
doc["metadata"] = {
ParameterKeyEscaper.unescape(k): v
for k, v in metadata.items()
}

View File

@@ -71,7 +71,7 @@ class TestService(TestCase, TestServiceInterface):
delete_params=delete_params,
)
def setUp(self, version="1.7"):
def setUp(self, version="999.0"):
self._api = APIClient(base_url=f"http://localhost:8008/v{version}")
self._deferred = []
self._version = parse(version)

View File

@@ -38,7 +38,7 @@ class TestEntityOrdering(TestService):
self._assertGetTasksWithOrdering(order_by=order_field, page=0, page_size=20)
field_vals = []
page_size = 2
page_size = 4
num_pages = 5
for page in range(num_pages):
paged_tasks = self._get_page_tasks(

View File

@@ -2,9 +2,6 @@ from apiserver.tests.automated import TestService
class TestOrganization(TestService):
def setUp(self, version="2.12"):
super().setUp(version=version)
def test_get_user_companies(self):
company = self.api.organization.get_user_companies().companies[0]
self.assertEqual(len(company.owners), company.allocated)

View File

@@ -0,0 +1,80 @@
import math
from apiserver.tests.automated import TestService
class TestPagingAndScrolling(TestService):
name_prefix = f"Test paging "
def setUp(self, **kwargs):
super().setUp(**kwargs)
self.task_ids = self._create_tasks()
def _create_tasks(self):
tasks = [
self._temp_task(
name=f"{self.name_prefix}{i}",
hyperparams={
"test": {
"param": {
"section": "test",
"name": "param",
"type": "str",
"value": str(i),
}
}
},
)
for i in range(18)
]
return tasks
def test_paging(self):
page_size = 10
for page in range(0, math.ceil(len(self.task_ids) / page_size)):
start = page * page_size
expected_size = min(page_size, len(self.task_ids) - start)
tasks = self._get_tasks(page=page, page_size=page_size,).tasks
self.assertEqual(len(tasks), expected_size)
for i, t in enumerate(tasks):
self.assertEqual(t.name, f"{self.name_prefix}{start + i}")
def test_scrolling(self):
page_size = 10
scroll_id = None
for page in range(0, math.ceil(len(self.task_ids) / page_size)):
start = page * page_size
expected_size = min(page_size, len(self.task_ids) - start)
res = self._get_tasks(size=page_size, scroll_id=scroll_id,)
self.assertTrue(res.scroll_id)
scroll_id = res.scroll_id
tasks = res.tasks
self.assertEqual(len(tasks), expected_size)
for i, t in enumerate(tasks):
self.assertEqual(t.name, f"{self.name_prefix}{start + i}")
# no more data in this scroll
tasks = self._get_tasks(size=page_size, scroll_id=scroll_id,).tasks
self.assertFalse(tasks)
# refresh brings all
tasks = self._get_tasks(
size=page_size, scroll_id=scroll_id, refresh_scroll=True,
).tasks
self.assertEqual([t.id for t in tasks], self.task_ids)
def _get_tasks(self, **page_params):
return self.api.tasks.get_all_ex(
name="^Test paging ",
order_by=["hyperparams.test.param.value"],
**page_params,
)
def _temp_task(self, name, **kwargs):
return self.create_temp(
"tasks",
name=name,
comment="Test task",
type="testing",
input=dict(view=dict()),
**kwargs,
)

View File

@@ -4,10 +4,29 @@ from apiserver.tests.automated import TestService
class TestProjectTags(TestService):
def setUp(self, version="2.12"):
super().setUp(version=version)
def test_project_own_tags(self):
p1_tags = ["Tag 1", "Tag 2"]
p1 = self.create_temp(
"projects", name="Test project tags1", description="test", tags=p1_tags
)
p2_tags = ["Tag 1", "Tag 3"]
p2 = self.create_temp(
"projects",
name="Test project tags2",
description="test",
tags=p2_tags,
system_tags=["hidden"],
)
def test_project_tags(self):
res = self.api.projects.get_project_tags(projects=[p1, p2])
self.assertEqual(set(res.tags), set(p1_tags) | set(p2_tags))
res = self.api.projects.get_project_tags(
projects=[p1, p2], filter={"system_tags": ["__$not", "hidden"]}
)
self.assertEqual(res.tags, p1_tags)
def test_project_entities_tags(self):
tags_1 = ["Test tag 1", "Test tag 2"]
tags_2 = ["Test tag 3", "Test tag 4"]

View File

@@ -1,15 +1,11 @@
from functools import partial
from typing import Sequence
from apiserver.tests.api_client import APIClient
from apiserver.tests.automated import TestService
class TestQueueAndModelMetadata(TestService):
def setUp(self, version="2.13"):
super().setUp(version=version)
meta1 = [{"key": "test_key", "type": "str", "value": "test_value"}]
meta1 = {"test_key": {"key": "test_key", "type": "str", "value": "test_value"}}
def test_queue_metas(self):
queue_id = self._temp_queue("TestMetadata", metadata=self.meta1)
@@ -26,20 +22,51 @@ class TestQueueAndModelMetadata(TestService):
)
model_id = self._temp_model("TestMetadata1")
self.api.models.edit(model=model_id, metadata=[self.meta1[0]])
self.api.models.edit(model=model_id, metadata=self.meta1)
self._assertMeta(service=service, entity=entity, _id=model_id, meta=self.meta1)
def test_project_meta_query(self):
self._temp_model("TestMetadata", metadata=self.meta1)
project = self.temp_project(name="MetaParent")
test_key = "test_key"
test_key2 = "test_key2"
test_value = "test_value"
test_value2 = "test_value2"
model_id = self._temp_model(
"TestMetadata2",
project=project,
metadata={
test_key: {"key": test_key, "type": "str", "value": test_value},
test_key2: {"key": test_key2, "type": "str", "value": test_value2},
},
)
res = self.api.projects.get_model_metadata_keys()
self.assertTrue({test_key, test_key2}.issubset(set(res["keys"])))
res = self.api.projects.get_model_metadata_keys(include_subprojects=False)
self.assertTrue(test_key in res["keys"])
self.assertFalse(test_key2 in res["keys"])
model = self.api.models.get_all_ex(
id=[model_id], only_fields=["metadata.test_key"]
).models[0]
self.assertTrue(test_key in model.metadata)
self.assertFalse(test_key2 in model.metadata)
res = self.api.projects.get_model_metadata_values(key=test_key)
self.assertEqual(res.total, 1)
self.assertEqual(res["values"], [test_value])
def _test_meta_operations(
self, service: APIClient.Service, entity: str, _id: str,
):
assert_meta = partial(self._assertMeta, service=service, entity=entity)
assert_meta(_id=_id, meta=self.meta1)
meta2 = [
{"key": "test1", "type": "str", "value": "data1"},
{"key": "test2", "type": "str", "value": "data2"},
{"key": "test3", "type": "str", "value": "data3"},
]
meta2 = {
"test1": {"key": "test1", "type": "str", "value": "data1"},
"test2": {"key": "test2", "type": "str", "value": "data2"},
"test3": {"key": "test3", "type": "str", "value": "data3"},
}
service.update(**{entity: _id, "metadata": meta2})
assert_meta(_id=_id, meta=meta2)
@@ -51,16 +78,17 @@ class TestQueueAndModelMetadata(TestService):
]
res = service.add_or_update_metadata(**{entity: _id, "metadata": updates})
self.assertEqual(res.updated, 1)
assert_meta(_id=_id, meta=[meta2[0], *updates])
assert_meta(_id=_id, meta={**meta2, **{u["key"]: u for u in updates}})
res = service.delete_metadata(
**{entity: _id, "keys": [f"test{idx}" for idx in range(2, 6)]}
)
self.assertEqual(res.updated, 1)
assert_meta(_id=_id, meta=meta2[:1])
# noinspection PyTypeChecker
assert_meta(_id=_id, meta=dict(list(meta2.items())[:1]))
def _assertMeta(
self, service: APIClient.Service, entity: str, _id: str, meta: Sequence[dict]
self, service: APIClient.Service, entity: str, _id: str, meta: dict
):
res = service.get_all_ex(id=[_id])[f"{entity}s"][0]
self.assertEqual(res.metadata, meta)
@@ -72,3 +100,12 @@ class TestQueueAndModelMetadata(TestService):
return self.create_temp(
"models", uri="file://test", name=name, labels={}, **kwargs
)
def temp_project(self, **kwargs) -> str:
self.update_missing(
kwargs,
name="Test models meta",
description="test",
delete_params=dict(force=True),
)
return self.create_temp("projects", **kwargs)

View File

@@ -12,9 +12,6 @@ from apiserver.tests.automated import TestService
class TestSubProjects(TestService):
def setUp(self, **kwargs):
super().setUp(version="2.13")
def test_project_aggregations(self):
"""This test requires user with user_auth_only... credentials in db"""
user2_client = APIClient(
@@ -202,7 +199,10 @@ class TestSubProjects(TestService):
res1 = next(p for p in res if p.id == project1)
self.assertEqual(res1.stats["active"]["status_count"]["created"], 0)
self.assertEqual(res1.stats["active"]["status_count"]["stopped"], 2)
self.assertEqual(res1.stats["active"]["status_count"]["in_progress"], 0)
self.assertEqual(res1.stats["active"]["total_runtime"], 2)
self.assertEqual(res1.stats["active"]["completed_tasks_24h"], 2)
self.assertEqual(res1.stats["active"]["total_tasks"], 2)
self.assertEqual(
{sp.name for sp in res1.sub_projects},
{
@@ -214,7 +214,10 @@ class TestSubProjects(TestService):
res2 = next(p for p in res if p.id == project2)
self.assertEqual(res2.stats["active"]["status_count"]["created"], 0)
self.assertEqual(res2.stats["active"]["status_count"]["stopped"], 0)
self.assertEqual(res2.stats["active"]["status_count"]["in_progress"], 0)
self.assertEqual(res2.stats["active"]["status_count"]["completed"], 0)
self.assertEqual(res2.stats["active"]["total_runtime"], 0)
self.assertEqual(res2.stats["active"]["total_tasks"], 0)
self.assertEqual(res2.sub_projects, [])
def _run_tasks(self, *tasks):

View File

@@ -133,6 +133,32 @@ class TestTags(TestService):
).models
self.assertFound(model_id, [], models)
def testQueueTags(self):
q_id = self._temp_queue(system_tags=["default"])
queues = self.api.queues.get_all_ex(
name="Test tags", system_tags=["default"]
).queues
self.assertFound(q_id, ["default"], queues)
queues = self.api.queues.get_all_ex(
name="Test tags", system_tags=["-default"]
).queues
self.assertNotFound(q_id, queues)
self.api.queues.update(queue=q_id, system_tags=[])
queues = self.api.queues.get_all_ex(
name="Test tags", system_tags=["-default"]
).queues
self.assertFound(q_id, [], queues)
# test default queue
queues = self.api.queues.get_all(system_tags=["default"]).queues
if queues:
self.assertEqual(queues[0].id, self.api.queues.get_default().id)
else:
self.api.queues.update(queue=q_id, system_tags=["default"])
self.assertEqual(q_id, self.api.queues.get_default().id)
def testTaskTags(self):
task_id = self._temp_task(
name="Test tags", system_tags=["active"]
@@ -169,35 +195,11 @@ class TestTags(TestService):
task = self.api.tasks.get_by_id(task=task_id).task
self.assertEqual(task.status, "stopped")
def testQueueTags(self):
q_id = self._temp_queue(system_tags=["default"])
queues = self.api.queues.get_all_ex(
name="Test tags", system_tags=["default"]
).queues
self.assertFound(q_id, ["default"], queues)
queues = self.api.queues.get_all_ex(
name="Test tags", system_tags=["-default"]
).queues
self.assertNotFound(q_id, queues)
self.api.queues.update(queue=q_id, system_tags=[])
queues = self.api.queues.get_all_ex(
name="Test tags", system_tags=["-default"]
).queues
self.assertFound(q_id, [], queues)
# test default queue
queues = self.api.queues.get_all(system_tags=["default"]).queues
if queues:
self.assertEqual(queues[0].id, self.api.queues.get_default().id)
else:
self.api.queues.update(queue=q_id, system_tags=["default"])
self.assertEqual(q_id, self.api.queues.get_default().id)
def assertProjectStats(self, project: AttrDict):
self.assertEqual(set(project.stats.keys()), {"active"})
self.assertAlmostEqual(project.stats.active.total_runtime, 1, places=0)
self.assertEqual(project.stats.active.completed_tasks_24h, 1)
self.assertEqual(project.stats.active.total_tasks, 1)
for status, count in project.stats.active.status_count.items():
self.assertEqual(count, 1 if status == "stopped" else 0)

View File

@@ -82,14 +82,10 @@ class TestTasksArtifacts(TestService):
# test edit running task
self.api.tasks.started(task=task)
with self.api.raises(InvalidTaskStatus):
self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit)
self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit, force=True)
self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit)
res = self.api.tasks.get_all_ex(id=[task]).tasks[0]
self._assertTaskArtifacts(artifacts, res)
with self.api.raises(InvalidTaskStatus):
self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": artifacts[-1]["key"]}])
self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": artifacts[-1]["key"]}], force=True)
self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": artifacts[-1]["key"]}])
res = self.api.tasks.get_all_ex(id=[task]).tasks[0]
self._assertTaskArtifacts(artifacts[0: len(artifacts) - 1], res)

View File

@@ -12,9 +12,6 @@ from apiserver.tests.automated import TestService
class TestTaskEvents(TestService):
def setUp(self, version="2.9"):
super().setUp(version=version)
def _temp_task(self, name="test task events"):
task_input = dict(
name=name, type="training", input=dict(mapping={}, view=dict(entries=[])),
@@ -257,6 +254,45 @@ class TestTaskEvents(TestService):
self.assertEqual(len(task_data["x"]), iterations)
self.assertEqual(len(task_data["y"]), iterations)
def test_task_metric_raw(self):
metric = "Metric1"
variant = "Variant1"
iter_count = 100
task = self._temp_task()
events = [
{
**self._create_task_event("training_stats_scalar", task, iteration),
"metric": metric,
"variant": variant,
"value": iteration,
}
for iteration in range(iter_count)
]
self.send_batch(events)
batch_size = 15
metric_param = {"metric": metric, "variants": [variant]}
res = self.api.events.scalar_metrics_iter_raw(
task=task, batch_size=batch_size, metric=metric_param, count_total=True
)
self.assertEqual(res.total, len(events))
self.assertTrue(res.scroll_id)
res_iters = []
res_ys = []
calls = 0
while res.returned or calls > 10:
calls += 1
res_iters.extend(res.variants[variant]["iter"])
res_ys.extend(res.variants[variant]["y"])
scroll_id = res.scroll_id
res = self.api.events.scalar_metrics_iter_raw(
task=task, metric=metric_param, scroll_id=scroll_id
)
self.assertEqual(calls, len(events) // batch_size + 1)
self.assertEqual(res_iters, [ev["iter"] for ev in events])
self.assertEqual(res_ys, [ev["value"] for ev in events])
def test_task_metric_value_intervals(self):
metric = "Metric1"
variant = "Variant1"

View File

@@ -1,4 +1,5 @@
import os
from datetime import timedelta, datetime
from threading import Thread
from time import sleep
from typing import Optional
@@ -10,6 +11,7 @@ from semantic_version import Version
from apiserver.config_repo import config
from apiserver.config.info import get_version
from apiserver.database.model.settings import Settings
from apiserver.redis_manager import redman
from apiserver.utilities.threads_manager import ThreadsManager
log = config.logger(__name__)
@@ -17,6 +19,8 @@ log = config.logger(__name__)
class CheckUpdatesThread(Thread):
_enabled = bool(config.get("apiserver.check_for_updates.enabled", True))
_lock_name = "check_updates"
_redis = redman.connection("apiserver")
@attr.s(auto_attribs=True)
class _VersionResponse:
@@ -29,6 +33,19 @@ class CheckUpdatesThread(Thread):
target=self._check_updates, daemon=True
)
@property
def update_interval(self):
return timedelta(
seconds=max(
float(
config.get(
"apiserver.check_for_updates.check_interval_sec", 60 * 60 * 24,
)
),
60 * 5,
)
)
def start(self) -> None:
if not self._enabled:
log.info("Checking for updates is disabled")
@@ -37,12 +54,13 @@ class CheckUpdatesThread(Thread):
@property
def component_name(self) -> str:
return config.get("apiserver.check_for_updates.component_name", "clearml-server")
return config.get(
"apiserver.check_for_updates.component_name", "clearml-server"
)
def _check_new_version_available(self) -> Optional[_VersionResponse]:
url = config.get(
"apiserver.check_for_updates.url",
"https://updates.clear.ml/updates",
"apiserver.check_for_updates.url", "https://updates.clear.ml/updates",
)
uid = Settings.get_by_key("server.uuid")
@@ -81,34 +99,31 @@ class CheckUpdatesThread(Thread):
)
def _check_updates(self):
update_interval_sec = max(
float(
config.get(
"apiserver.check_for_updates.check_interval_sec",
60 * 60 * 24,
)
),
60 * 5,
)
while not ThreadsManager.terminating:
# noinspection PyBroadException
try:
response = self._check_new_version_available()
if response:
if response.patch_upgrade:
log.info(
f"{self.component_name.upper()} new package available: upgrade to v{response.version} "
f"is recommended!\nRelease Notes:\n{os.linesep.join(response.description)}"
)
else:
log.info(
f"{self.component_name.upper()} new version available: upgrade to v{response.version}"
f" is recommended!"
)
except Exception:
log.exception("Failed obtaining updates")
if self._redis.set(
self._lock_name,
value=datetime.utcnow().isoformat(),
ex=self.update_interval - timedelta(seconds=60),
nx=True,
):
response = self._check_new_version_available()
if response:
if response.patch_upgrade:
log.info(
f"{self.component_name.upper()} new package available: upgrade to v{response.version} "
f"is recommended!\nRelease Notes:\n{os.linesep.join(response.description)}"
)
else:
log.info(
f"{self.component_name.upper()} new version available: upgrade to v{response.version}"
f" is recommended!"
)
except Exception as ex:
log.exception("Failed obtaining updates: " + str(ex))
sleep(update_interval_sec)
sleep(self.update_interval.total_seconds())
check_updates_thread = CheckUpdatesThread()

View File

@@ -37,15 +37,12 @@ def deep_merge(source: dict, override: dict) -> dict:
def nested_get(
dictionary: Mapping,
path: Union[Sequence[str], str],
path: Sequence[str],
default: Optional[Union[Any, Callable]] = None,
) -> Any:
if isinstance(path, str):
path = [path]
node = dictionary
for key in path:
if key not in node:
if not node or key not in node:
if callable(default):
return default()
return default

View File

@@ -0,0 +1,14 @@
from distutils.util import strtobool
from os import getenv
from typing import Optional
def get_bool(*keys: str, default: bool = None) -> Optional[bool]:
try:
value = next(env for env in (getenv(key) for key in keys) if env is not None)
except StopIteration:
return default
try:
return bool(strtobool(value))
except ValueError:
return bool(value)

View File

@@ -1 +1 @@
__version__ = "1.1.0"
__version__ = "1.3.0"

35
docker/build/Dockerfile Normal file
View File

@@ -0,0 +1,35 @@
FROM centos/nodejs-12-centos7 AS webapp
USER root
WORKDIR /opt
RUN git clone https://github.com/allegroai/clearml-web.git
RUN mv clearml-web /opt/open-webapp
COPY --chmod=744 docker/build/internal_files/build_webapp.sh /tmp/internal_files/
RUN /bin/bash -c '/tmp/internal_files/build_webapp.sh'
FROM centos:7 AS staging_image
COPY --chmod=744 docker/build/internal_files/entrypoint.sh /opt/clearml/
COPY fileserver /opt/clearml/fileserver/
COPY apiserver /opt/clearml/apiserver/
FROM centos:7
COPY --from=staging_image /opt/clearml/ /opt/clearml/
COPY --chmod=744 docker/build/internal_files/final_image_preparation.sh /tmp/internal_files/
COPY docker/build/internal_files/clearml.conf.template /tmp/internal_files/
RUN /bin/bash -c '/tmp/internal_files/final_image_preparation.sh'
COPY --from=webapp /opt/open-webapp/build /usr/share/nginx/html
EXPOSE 8080
EXPOSE 8008
EXPOSE 8081
ARG VERSION
ARG BUILD
ENV CLEARML_SERVER_VERSION=${VERSION}
ENV CLEARML_SERVER_BUILD=${BUILD}
WORKDIR /opt/clearml/
ENTRYPOINT ["/opt/clearml/entrypoint.sh"]

View File

@@ -0,0 +1,9 @@
#!/usr/bin/env bash
set -x
set -e
cd /opt/open-webapp/
npm ci --unsafe-perm node-sass
cd /opt/open-webapp/
npm run build

View File

@@ -0,0 +1,97 @@
# For more information on configuration, see:
# * Official English Documentation: http://nginx.org/en/docs/
# * Official Russian Documentation: http://nginx.org/ru/docs/
user nginx;
worker_processes auto;
error_log /var/log/nginx/error.log;
pid /run/nginx.pid;
# Load dynamic modules. See /usr/share/doc/nginx/README.dynamic.
include /usr/share/nginx/modules/*.conf;
events {
worker_connections 1024;
}
http {
log_format main '$remote_addr - $remote_user [$time_local] "$request" '
'$status $body_bytes_sent "$http_referer" '
'"$http_user_agent" "$http_x_forwarded_for"';
access_log /var/log/nginx/access.log main;
sendfile on;
tcp_nopush on;
tcp_nodelay on;
keepalive_timeout 65;
types_hash_max_size 2048;
include /etc/nginx/mime.types;
default_type application/octet-stream;
# Load modular configuration files from the /etc/nginx/conf.d directory.
# See http://nginx.org/en/docs/ngx_core_module.html#include
# for more information.
include /etc/nginx/conf.d/*.conf;
server {
listen 80 default_server;
listen [::]:80 default_server;
server_name _;
root /usr/share/nginx/html;
proxy_http_version 1.1;
# comppression
gzip on;
gzip_comp_level 9;
gzip_http_version 1.0;
gzip_min_length 512;
gzip_proxied expired no-cache no-store private auth;
gzip_types text/plain
text/css
application/json
application/javascript
application/x-javascript
text/xml application/xml
application/xml+rss
text/javascript
application/x-font-ttf
font/woff2
image/svg+xml
image/x-icon;
# Load configuration files for the default server block.
include /etc/nginx/default.d/*.conf;
location / {
try_files $uri$args $uri$args/ $uri index.html /index.html;
}
location /version.json {
add_header Cache-Control 'no-cache';
}
location /api {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header Host $host;
proxy_pass ${NGINX_APISERVER_ADDR};
rewrite /api/(.*) /$1 break;
}
location /files {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header Host $host;
proxy_pass ${NGINX_FILESERVER_ADDR};
rewrite /files/(.*) /$1 break;
}
error_page 404 /404.html;
location = /40x.html {
}
error_page 500 502 503 504 /50x.html;
location = /50x.html {
}
}
}

View File

@@ -0,0 +1,67 @@
#!/usr/bin/env bash
set -e
mkdir -p /var/log/clearml
SERVER_TYPE=$1
if (( $# < 1 )) ; then
echo "The server type was not stated. It should be either apiserver, webserver or fileserver."
sleep 60
exit 1
elif [[ ${SERVER_TYPE} == "apiserver" ]]; then
cd /opt/clearml/
python3 -m apiserver.apierrors_generator
if [[ -n $CLEARML_USE_GUNICORN ]]; then
MAX_REQUESTS=
if [[ -n $CLEARML_GUNICORN_MAX_REQUESTS ]]; then
MAX_REQUESTS="--max-requests $CLEARML_GUNICORN_MAX_REQUESTS"
if [[ -n $CLEARML_GUNICORN_MAX_REQUESTS_JITTER ]]; then
MAX_REQUESTS="$MAX_REQUESTS --max-requests-jitter $CLEARML_GUNICORN_MAX_REQUESTS_JITTER"
fi
fi
export GUNICORN_CMD_ARGS=${CLEARML_GUNICORN_CMD_ARGS}
# Note: don't be tempted to "fix" $MAX_REQUESTS with "$MAX_REQUESTS" as this produces an empty arg which fucks up gunicorn
gunicorn \
-w "${CLEARML_GUNICORN_WORKERS:-8}" \
-t "${CLEARML_GUNICORN_TIMEOUT:-600}" --bind="${CLEARML_GUNICORN_BIND:-0.0.0.0:8008}" \
$MAX_REQUESTS apiserver.server:app
else
python3 -m apiserver.server
fi
elif [[ ${SERVER_TYPE} == "webserver" ]]; then
if [[ "${USER_KEY}" != "" ]] || [[ "${USER_SECRET}" != "" ]] || [[ "${COMPANY_ID}" != "" ]]; then
cat << EOF > /usr/share/nginx/html/credentials.json
{
"userKey": "${USER_KEY}",
"userSecret": "${USER_SECRET}",
"companyID": "${COMPANY_ID}"
}
EOF
fi
export NGINX_APISERVER_ADDR=${NGINX_APISERVER_ADDRESS:-http://apiserver:8008}
export NGINX_FILESERVER_ADDR=${NGINX_FILESERVER_ADDRESS:-http://fileserver:8081}
envsubst '${NGINX_APISERVER_ADDR} ${NGINX_FILESERVER_ADDR}' < /etc/nginx/clearml.conf.template > /etc/nginx/nginx.conf
#start the server
/usr/sbin/nginx -g "daemon off;"
elif [[ ${SERVER_TYPE} == "fileserver" ]]; then
cd /opt/clearml/fileserver/
if [ "$FILESERVER_USE_GUNICORN" = true ] ; then
gunicorn -t 600 --bind=0.0.0.0:8081 fileserver:app
else
python3 fileserver.py
fi
else
echo "Server type ${SERVER_TYPE} is invalid. Please choose either apiserver, webserver or fileserver."
fi

View File

@@ -0,0 +1,18 @@
#!/usr/bin/env bash
set -o errexit
set -o nounset
set -o pipefail
yum update -y
yum install -y https://dl.fedoraproject.org/pub/epel/epel-release-latest-7.noarch.rpm
yum install -y python36 python36-pip nginx gcc gcc-c++ python3-devel gettext
yum -y upgrade
python3 -m pip install -r /opt/clearml/fileserver/requirements.txt
python3 -m pip install -r /opt/clearml/apiserver/requirements.txt
mkdir -p /opt/clearml/log
mkdir -p /opt/clearml/config
ln -s /dev/stdout /var/log/nginx/access.log
ln -s /dev/stderr /var/log/nginx/error.log
mv /etc/nginx/nginx.conf /etc/nginx/nginx.conf.orig
mv /tmp/internal_files/clearml.conf.template /etc/nginx/clearml.conf.template
yum clean all

View File

@@ -19,6 +19,7 @@ services:
environment:
CLEARML_ELASTIC_SERVICE_HOST: elasticsearch
CLEARML_ELASTIC_SERVICE_PORT: 9200
CLEARML_ELASTIC_SERVICE_PASSWORD: ${ELASTIC_PASSWORD}
CLEARML_MONGODB_SERVICE_HOST: mongo
CLEARML_MONGODB_SERVICE_PORT: 27017
CLEARML_REDIS_SERVICE_HOST: redis
@@ -38,7 +39,8 @@ services:
- backend
container_name: clearml-elastic
environment:
ES_JAVA_OPTS: -Xms2g -Xmx2g
ES_JAVA_OPTS: -Xms2g -Xmx2g -Dlog4j2.formatMsgNoLookups=true
ELASTIC_PASSWORD: ${ELASTIC_PASSWORD}
bootstrap.memory_lock: "true"
cluster.name: clearml
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
@@ -60,7 +62,7 @@ services:
nofile:
soft: 65536
hard: 65536
image: docker.elastic.co/elasticsearch/elasticsearch:7.6.2
image: docker.elastic.co/elasticsearch/elasticsearch:7.16.2
restart: unless-stopped
volumes:
- c:/opt/clearml/data/elastic_7:/usr/share/elasticsearch/data
@@ -87,12 +89,12 @@ services:
networks:
- backend
container_name: clearml-mongo
image: mongo:3.6.5
image: mongo:4.4.9
restart: unless-stopped
command: --setParameter internalQueryExecMaxBlockingSortBytes=196100200
command: --setParameter internalQueryMaxBlockingSortMemoryUsageBytes=196100200
volumes:
- c:/opt/clearml/data/mongo/db:/data/db
- c:/opt/clearml/data/mongo/configdb:/data/configdb
- c:/opt/clearml/data/mongo_4/db:/data/db
- c:/opt/clearml/data/mongo_4/configdb:/data/configdb
redis:
networks:

View File

@@ -19,6 +19,7 @@ services:
environment:
CLEARML_ELASTIC_SERVICE_HOST: elasticsearch
CLEARML_ELASTIC_SERVICE_PORT: 9200
CLEARML_ELASTIC_SERVICE_PASSWORD: ${ELASTIC_PASSWORD}
CLEARML_MONGODB_SERVICE_HOST: mongo
CLEARML_MONGODB_SERVICE_PORT: 27017
CLEARML_REDIS_SERVICE_HOST: redis
@@ -38,7 +39,8 @@ services:
- backend
container_name: clearml-elastic
environment:
ES_JAVA_OPTS: -Xms2g -Xmx2g
ES_JAVA_OPTS: -Xms2g -Xmx2g -Dlog4j2.formatMsgNoLookups=true
ELASTIC_PASSWORD: ${ELASTIC_PASSWORD}
bootstrap.memory_lock: "true"
cluster.name: clearml
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
@@ -60,7 +62,7 @@ services:
nofile:
soft: 65536
hard: 65536
image: docker.elastic.co/elasticsearch/elasticsearch:7.6.2
image: docker.elastic.co/elasticsearch/elasticsearch:7.16.2
restart: unless-stopped
volumes:
- /opt/clearml/data/elastic_7:/usr/share/elasticsearch/data
@@ -86,12 +88,12 @@ services:
networks:
- backend
container_name: clearml-mongo
image: mongo:3.6.5
image: mongo:4.4.9
restart: unless-stopped
command: --setParameter internalQueryExecMaxBlockingSortBytes=196100200
command: --setParameter internalQueryMaxBlockingSortMemoryUsageBytes=196100200
volumes:
- /opt/clearml/data/mongo/db:/data/db
- /opt/clearml/data/mongo/configdb:/data/configdb
- /opt/clearml/data/mongo_4/db:/data/db
- /opt/clearml/data/mongo_4/configdb:/data/configdb
redis:
networks:
@@ -121,7 +123,9 @@ services:
- backend
container_name: clearml-agent-services
image: allegroai/clearml-agent-services:latest
restart: unless-stopped
deploy:
restart_policy:
condition: on-failure
privileged: true
environment:
CLEARML_HOST_IP: ${CLEARML_HOST_IP}
@@ -132,7 +136,7 @@ services:
CLEARML_API_SECRET_KEY: ${CLEARML_API_SECRET_KEY:-}
CLEARML_AGENT_GIT_USER: ${CLEARML_AGENT_GIT_USER}
CLEARML_AGENT_GIT_PASS: ${CLEARML_AGENT_GIT_PASS}
CLEARML_AGENT_UPDATE_VERSION: ${CLEARML_AGENT_UPDATE_VERSION:->=0.17.0}
CLEARML_AGENT_UPDATE_VERSION: ${CLEARML_AGENT_UPDATE_VERSION:-">=0.17.0"}
CLEARML_AGENT_DEFAULT_BASE_DOCKER: "ubuntu:18.04"
AWS_ACCESS_KEY_ID: ${AWS_ACCESS_KEY_ID:-}
AWS_SECRET_ACCESS_KEY: ${AWS_SECRET_ACCESS_KEY:-}
@@ -142,6 +146,7 @@ services:
GOOGLE_APPLICATION_CREDENTIALS: ${GOOGLE_APPLICATION_CREDENTIALS:-}
CLEARML_WORKER_ID: "clearml-services"
CLEARML_AGENT_DOCKER_HOST_MOUNT: "/opt/clearml/agent:/root/.clearml"
SHUTDOWN_IF_NO_ACCESS_KEY: 1
volumes:
- /var/run/docker.sock:/var/run/docker.sock
- /opt/clearml/agent:/root/.clearml

Binary file not shown.

After

Width:  |  Height:  |  Size: 155 KiB

View File

@@ -1,5 +1,6 @@
""" A Simple file server for uploading and downloading files """
import json
import mimetypes
import os
from argparse import ArgumentParser
from pathlib import Path
@@ -12,12 +13,15 @@ from werkzeug.exceptions import NotFound
from werkzeug.security import safe_join
from config import config
from utils import get_env_bool
DEFAULT_UPLOAD_FOLDER = "/mnt/fileserver"
app = Flask(__name__)
CORS(app, **config.get("fileserver.cors"))
Compress(app)
if get_env_bool("CLEARML_COMPRESS_RESP", default=True):
Compress(app)
app.config["UPLOAD_FOLDER"] = first(
(os.environ.get(f"{prefix}_UPLOAD_FOLDER") for prefix in ("CLEARML", "TRAINS")),
@@ -28,6 +32,20 @@ app.config["SEND_FILE_MAX_AGE_DEFAULT"] = config.get(
)
@app.before_request
def before_request():
if request.content_encoding:
return f"Content encoding is not supported ({request.content_encoding})", 415
@app.after_request
def after_request(response):
response.headers["server"] = config.get(
"fileserver.response.headers.server", "clearml"
)
return response
@app.route("/", methods=["POST"])
def upload():
results = []
@@ -48,8 +66,15 @@ def upload():
@app.route("/<path:path>", methods=["GET"])
def download(path):
as_attachment = "download" in request.args
_, encoding = mimetypes.guess_type(os.path.basename(path))
mimetype = "application/octet-stream" if encoding == "gzip" else None
response = send_from_directory(
app.config["UPLOAD_FOLDER"], path, as_attachment=as_attachment
app.config["UPLOAD_FOLDER"],
path,
as_attachment=as_attachment,
mimetype=mimetype,
)
if config.get("fileserver.download.disable_browser_caching", False):
headers = response.headers
@@ -63,12 +88,7 @@ def download(path):
@app.route("/<path:path>", methods=["DELETE"])
def delete(path):
real_path = Path(
safe_join(
os.fspath(app.config["UPLOAD_FOLDER"]),
os.fspath(path)
)
)
real_path = Path(safe_join(os.fspath(app.config["UPLOAD_FOLDER"]), os.fspath(path)))
if not real_path.exists() or not real_path.is_file():
abort(Response(f"File {str(path)} not found", 404))

14
fileserver/utils.py Normal file
View File

@@ -0,0 +1,14 @@
from distutils.util import strtobool
from os import getenv
from typing import Optional
def get_env_bool(*keys: str, default: bool = None) -> Optional[bool]:
try:
value = next(env for env in (getenv(key) for key in keys) if env is not None)
except StopIteration:
return default
try:
return bool(strtobool(value))
except ValueError:
return bool(value)