Compare commits

29 Commits

Author SHA1 Message Date
allegroai
b93591ec32 Improve startup sequence 2020-08-24 14:05:48 +03:00
allegroai
0abfd8da0d Version bump to v0.16.1 2020-08-23 15:43:38 +03:00
allegroai
a9cc4e36c6 Update docs 2020-08-23 15:41:05 +03:00
allegroai
fe1c963eec Fix internal export utility 2020-08-23 15:40:57 +03:00
allegroai
111d80e88d Add migration to verify correct project ordering 2020-08-23 15:39:36 +03:00
allegroai
6718862dbe Update fixed user name if user already exists 2020-08-23 15:38:53 +03:00
allegroai
0fe1bf8a61 Add elasticsearch log filtering while trying to connect 2020-08-23 15:38:22 +03:00
allegroai
10f326eda9 Fix KeyError when accessing log results in events.get_task_logs 2020-08-23 15:36:43 +03:00
allegroai
cd0d6c1a3d Fix max buckets calculation for iters histogram 2020-08-23 15:34:59 +03:00
allegroai
3205f2df97 Add services.tasks.multi_task_histogram_limit configuration option 2020-08-23 15:30:32 +03:00
allegroai
5bdbcfcd8d Update README and docker-compose files for v0.16.0 2020-08-10 23:48:38 +03:00
allegroai
a2e2052b30 Version bump 2020-08-10 08:56:50 +03:00
allegroai
0146ded4f4 Fix empty projection handling 2020-08-10 08:56:43 +03:00
allegroai
dccf9dd8f8 Fix incorrect formatted timestamp in events.download_task_log 2020-08-10 08:55:01 +03:00
allegroai
7816b402bb Enhance ES7 initialization and migration support
Support older task hyper-parameter migration on pre-population
2020-08-10 08:53:41 +03:00
allegroai
cd4ce30f7c Add support for field exclusion in get_all endpoints
Add support for ephemeral worker tags (valid while worker has not timed out)
2020-08-10 08:48:48 +03:00
allegroai
8c7e230898 Add support for Task hyper-parameter sections and meta-data
Add new Task configuration section
2020-08-10 08:45:25 +03:00
allegroai
42ba696518 Support order parameter in events.get_task_log 2020-08-10 08:37:41 +03:00
allegroai
3f84e60a1f Add debug.ping endpoint
Optimize exhausted scrolls by using a fixed empty scroll
2020-08-10 08:35:34 +03:00
allegroai
baba8b5b73 Move to ElasticSearch 7
Add initial support for project ordering
Add support for sortable task duration (used by the UI in the experiment's table)
Add support for project name in worker's current task info
Add support for results and artifacts in pre-populates examples
Add demo server features
2020-08-10 08:30:40 +03:00
Allegro AI
77397c4f21 Update docker-compose.yml 2020-07-09 13:21:44 +03:00
allegroai
8678091d8f Fix documentation, remove sudo from docker-compose up (issue #48) 2020-07-06 22:07:59 +03:00
allegroai
aa22170ab4 Fix support for example projects and experiments in demo server 2020-07-06 22:06:42 +03:00
allegroai
901ec37290 Improve pre-populate on server startup (including sync lock) 2020-07-06 22:05:36 +03:00
allegroai
21f2ea8b17 Add events.get_task_log for improved log retrieval support 2020-07-06 21:54:25 +03:00
allegroai
8219e3d4e2 Fix trains-agent-services default ubuntu docker to support unicode in tty 2020-07-06 21:52:32 +03:00
allegroai
3ed71a61d5 Add models.get_frameworks endpoint 2020-07-06 21:50:43 +03:00
allegroai
18a88a8e8f Update AWS AMIs 2020-06-24 23:15:47 +03:00
allegroai
318a72987c Update GCP images for v0.15.1 2020-06-22 13:00:30 +03:00
96 changed files with 4614 additions and 1801 deletions

View File

@@ -11,6 +11,12 @@
## :rocket: Trains-Agent Services is now included, for more information see [services](https://github.com/allegroai/trains-server#services)
## v0.16 Upgrade Notice
In v0.16, the Elasticsearch subsystem of Trains Server has been upgraded from version 5.6 to version 7.6. This change necessitates the migration of the database contents to accommodate the change in index structure across the different versions.
Follow [this procedure](https://allegro.ai/docs/deploying_trains/trains_server_es7_migration/) to migrate existing data.
## Introduction
The **trains-server** is the backend service infrastructure for [Trains](https://github.com/allegroai/trains).
@@ -64,15 +70,15 @@ For example, to see if port `8080` is in use:
Launch **trains-server** in any of the following formats:
- Pre-built [AWS EC2 AMI](https://github.com/allegroai/trains-server/blob/master/docs/install_aws.md)
- Pre-built [GCP Custom Image](https://github.com/allegroai/trains-server/blob/master/docs/install_gcp.md)
- Pre-built [AWS EC2 AMI](https://allegro.ai/docs/deploying_trains/trains_server_aws_ec2_ami/)
- Pre-built [GCP Custom Image](https://allegro.ai/docs/deploying_trains/trains_server_gcp/)
- Pre-built Docker Image
- [Linux](https://github.com/allegroai/trains-server/blob/master/docs/install_linux_mac.md)
- [macOS](https://github.com/allegroai/trains-server/blob/master/docs/install_linux_mac.md)
- [Windows 10](https://github.com/allegroai/trains-server/blob/master/docs/install_win.md)
- [Linux](https://allegro.ai/docs/deploying_trains/trains_server_linux_mac/)
- [macOS](https://allegro.ai/docs/deploying_trains/trains_server_linux_mac/)
- [Windows 10](https://allegro.ai/docs/deploying_trains/trains_server_win/)
- Kubernetes
- [Kubernetes Helm](https://github.com/allegroai/trains-server-helm#prerequisites)
- Manual [Kubernetes installation](https://github.com/allegroai/trains-server-k8s#prerequisites)
- [Kubernetes Helm](https://allegro.ai/docs/deploying_trains/trains_server_kubernetes_helm/)
- Manual [Kubernetes installation](https://allegro.ai/docs/deploying_trains/trains_server_kubernetes/)
## Connecting Trains to your trains-server
@@ -124,8 +130,8 @@ Do not enqueue training / inference tasks into the `services` queue, as it will
**trains-server** provides a few additional useful features, which can be manually enabled:
* [Web login authentication](https://github.com/allegroai/trains-server/blob/master/docs/faq.md#web-auth)
* [Non-responsive experiments watchdog](https://github.com/allegroai/trains-server/blob/master/docs/faq.md#watchdog-the-non-responsive-task-watchdog-settings)
* [Web login authentication](https://allegro.ai/docs/faq/faq/#web-auth)
* [Non-responsive experiments watchdog](https://allegro.ai/docs/faq/faq/#watchdog)
## Restarting trains-server
@@ -191,12 +197,12 @@ To upgrade your existing **trains-server** deployment:
docker-compose -f docker-compose.yml up
```
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://github.com/allegroai/trains-server/blob/master/docs/faq.md#common-docker-upgrade-errors).**
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://allegro.ai/docs/faq/faq/#common-docker-upgrade-errors).**
## Community & Support
If you have any questions, look to the Trains server [FAQ](https://github.com/allegroai/trains-server/blob/master/docs/faq.md), or
If you have any questions, look to the Trains [FAQ](https://allegro.ai/docs/faq/faq/), or
tag your questions on [stackoverflow](https://stackoverflow.com/questions/tagged/trains) with '**trains**' tag.
For feature requests or bug reports, please use [GitHub issues](https://github.com/allegroai/trains-server/issues).

View File

@@ -40,15 +40,11 @@ services:
cluster.name: trains
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
discovery.zen.minimum_master_nodes: "1"
discovery.type: "single-node"
http.compression_level: "7"
node.ingest: "true"
node.name: trains
reindex.remote.whitelist: '*.*'
script.inline: "true"
script.painless.regex.enabled: "true"
script.update: "true"
thread_pool.bulk.queue_size: "2000"
thread_pool.search.queue_size: "10000"
xpack.monitoring.enabled: "false"
xpack.security.enabled: "false"
ulimits:
@@ -58,10 +54,10 @@ services:
nofile:
soft: 65536
hard: 65536
image: docker.elastic.co/elasticsearch/elasticsearch:5.6.16
image: docker.elastic.co/elasticsearch/elasticsearch:7.6.2
restart: unless-stopped
volumes:
- /opt/trains/data/elastic:/usr/share/elasticsearch/data
- /opt/trains/data/elastic_7:/usr/share/elasticsearch/data
ports:
- "9200:9200"
mongo:

View File

@@ -40,15 +40,11 @@ services:
cluster.name: trains
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
discovery.zen.minimum_master_nodes: "1"
discovery.type: "single-node"
http.compression_level: "7"
node.ingest: "true"
node.name: trains
reindex.remote.whitelist: '*.*'
script.inline: "true"
script.painless.regex.enabled: "true"
script.update: "true"
thread_pool.bulk.queue_size: "2000"
thread_pool.search.queue_size: "10000"
xpack.monitoring.enabled: "false"
xpack.security.enabled: "false"
ulimits:
@@ -58,10 +54,10 @@ services:
nofile:
soft: 65536
hard: 65536
image: docker.elastic.co/elasticsearch/elasticsearch:5.6.16
image: docker.elastic.co/elasticsearch/elasticsearch:7.6.2
restart: unless-stopped
volumes:
- c:/opt/trains/data/elastic:/usr/share/elasticsearch/data
- c:/opt/trains/data/elastic_7:/usr/share/elasticsearch/data
ports:
- "9200:9200"

View File

@@ -10,6 +10,7 @@ services:
volumes:
- /opt/trains/logs:/var/log/trains
- /opt/trains/config:/opt/trains/config
- /opt/trains/data/fileserver:/mnt/fileserver
depends_on:
- redis
- mongo
@@ -23,8 +24,9 @@ services:
TRAINS_REDIS_SERVICE_HOST: redis
TRAINS_REDIS_SERVICE_PORT: 6379
TRAINS_SERVER_DEPLOYMENT_TYPE: ${TRAINS_SERVER_DEPLOYMENT_TYPE:-linux}
TRAINS__apiserver__mongo__pre_populate__enabled: "true"
TRAINS__apiserver__mongo__pre_populate__zip_file: "/opt/trains/db-pre-populate/export.zip"
TRAINS__apiserver__pre_populate__enabled: "true"
TRAINS__apiserver__pre_populate__zip_files: "/opt/trains/db-pre-populate"
TRAINS__apiserver__pre_populate__artifacts_path: "/mnt/fileserver"
ports:
- "8008:8008"
networks:
@@ -40,15 +42,11 @@ services:
cluster.name: trains
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
discovery.zen.minimum_master_nodes: "1"
discovery.type: "single-node"
http.compression_level: "7"
node.ingest: "true"
node.name: trains
reindex.remote.whitelist: '*.*'
script.inline: "true"
script.painless.regex.enabled: "true"
script.update: "true"
thread_pool.bulk.queue_size: "2000"
thread_pool.search.queue_size: "10000"
xpack.monitoring.enabled: "false"
xpack.security.enabled: "false"
ulimits:
@@ -58,10 +56,10 @@ services:
nofile:
soft: 65536
hard: 65536
image: docker.elastic.co/elasticsearch/elasticsearch:5.6.16
image: docker.elastic.co/elasticsearch/elasticsearch:7.6.2
restart: unless-stopped
volumes:
- /opt/trains/data/elastic:/usr/share/elasticsearch/data
- /opt/trains/data/elastic_7:/usr/share/elasticsearch/data
ports:
- "9200:9200"

View File

@@ -54,41 +54,41 @@ The following sections contain lists of AMI Image IDs, per region, for each rele
For easier upgrades, the following AMIs automatically update to the latest release every reboot:
* **eu-north-1** : ami-0f63429f8e5d57315
* **ap-south-1** : ami-058a2a70b7fb8ec87
* **eu-west-3** : ami-0fc9f9e8e986f39c4
* **eu-west-2** : ami-0b0bc1ff2f0239bd9
* **eu-west-1** : ami-0056ec5d22b0fac91
* **ap-northeast-2** : ami-0898c9aa7f580fec7
* **ap-northeast-1** : ami-011036ddcc9398871
* **sa-east-1** : ami-04feeded12192438c
* **ca-central-1** : ami-02c717776c9e75025
* **ap-southeast-1** : ami-05b5866e7029bb9f1
* **ap-southeast-2** : ami-0384bd2b69467fff8
* **eu-central-1** : ami-01f15be85297d6f06
* **us-east-2** : ami-094070ca8aa110180
* **us-west-1** : ami-0d08ec5bc29eddb29
* **us-west-2** : ami-04715cceedaf6eae7
* **us-east-1** : ami-071dbaa1847585c4c
* **eu-north-1** : ami-0f30c84b905d354b9
* **ap-south-1** : ami-050e7acec52c8c74e
* **eu-west-3** : ami-03911c5b5bc77ef75
* **eu-west-2** : ami-0a5ed8aa2573ccc70
* **eu-west-1** : ami-0a53c65e922ec0611
* **ap-northeast-2** : ami-08cd017a37b8e8aab
* **ap-northeast-1** : ami-056b3ca1ad5af9322
* **sa-east-1** : ami-01ddc9325bafb400c
* **ca-central-1** : ami-0fc3cbbd982b18b45
* **ap-southeast-1** : ami-04c7a358df7002ef5
* **ap-southeast-2** : ami-0eeaf54231b4ae22a
* **eu-central-1** : ami-00b8e44041f8175fd
* **us-east-2** : ami-0ac7deebb3f738f6d
* **us-west-1** : ami-06bc07deb8b8c44d6
* **us-west-2** : ami-01ba85ffe79a422f1
* **us-east-1** : ami-04cf5a66cb4928ac3
### v0.15.1 (static update)
* **eu-north-1** : ami-0bb36c4dbe61f8c46
* **ap-south-1** : ami-0ac93ff85a5c770f9
* **eu-west-3** : ami-015ebfa846b8de5bb
* **eu-west-2** : ami-082aacd59408713d9
* **eu-west-1** : ami-066aad8c6b9b9991b
* **ap-northeast-2** : ami-0cb47f1c8591c799d
* **ap-northeast-1** : ami-005131d3037da9d2a
* **sa-east-1** : ami-0f7fdc4e19c8444a3
* **ca-central-1** : ami-07c234dad3ece2d78
* **ap-southeast-1** : ami-0d8e0475d7d4897e4
* **ap-southeast-2** : ami-053e3f25dee0424b9
* **eu-central-1** : ami-00d25558c5242708e
* **us-east-2** : ami-0bd45f800dfbde456
* **us-west-1** : ami-05e79bf1704721148
* **us-west-2** : ami-037c328649048409b
* **us-east-1** : ami-0a3cafe46bf085200
* **eu-north-1** : ami-0cd314e267426d1b7
* **ap-south-1** : ami-086182cbe29151f96
* **eu-west-3** : ami-0062366012182815b
* **eu-west-2** : ami-022b8f2e32a9d18d0
* **eu-west-1** : ami-0d8cf60446e09aa3d
* **ap-northeast-2** : ami-0d4c168a815b56889
* **ap-northeast-1** : ami-0daf7887db1053ae4
* **sa-east-1** : ami-020a759a3ba4ff22b
* **ca-central-1** : ami-0c10b5e04b707f3e3
* **ap-southeast-1** : ami-0f61bb3529a165fcd
* **ap-southeast-2** : ami-032dcdc82749c66c5
* **eu-central-1** : ami-08f364f32d2eb3bae
* **us-east-2** : ami-0b7efc3591803eba4
* **us-west-1** : ami-08b2df27b0ada6faf
* **us-west-2** : ami-0693029c4bad28816
* **us-east-1** : ami-0200954fa9c2819ff
### v0.15.0 (static update)

View File

@@ -3,13 +3,16 @@
To easily deploy Trains Server on GCP, use one of our pre-built GCP Custom Images.
We provide Custom Images for each released version of Trains Server, see [Released versions](#released-versions) below.
Once your GCP instance is up and running using our Custom Image, [configure the Trains client](https://github.com/allegroai/trains/blob/master/README.md#configuration) to use your **trains-server**.
Once your GCP instance is up and running using our Custom Image, [configure the Trains client](https://github.com/allegroai/trains/blob/master/README.md#configuration) to use your **trains-server**.
#### Default Trains Server Service ports
The service port numbers on our Trains Server GCP Custom Image are:
- Web application: `8080`
- API Server: `8008`
- File Server: `8081`
#### Default Trains Server Storage paths
The persistent storage configuration:
- MongoDB: `/opt/trains/data/mongo/`
@@ -49,6 +52,15 @@ The minimum recommended requirements for Trains Server are:
To upgrade **trains-server** on an existing GCP instance based on one of these Custom Images, SSH into the instance and follow the [upgrade instructions](../README.md#upgrade) for **trains-server**.
## Network and Security
Please make sure your instance is properly secured.
If not specifically set, a GCP instance will use default firewall rules that allow public access to various ports.
If your instance is open for public access, we recommend you follow best practices for access management, including:
- Allow access only to the specific ports used by Trains Server (see [Default Trains Server Service ports](#default-trains-server-service-ports)). Remember to allow access to port `443` if `https` access is configured for your instance.
- Configure Trains Server to use fixed user names and passwords (see [Can I add web login authentication to trains-server?](./faq.md#web-auth))
## Released versions
The following sections contain lists of Custom Image URLs (exported in different formats) for each released **trains-server** version.
@@ -59,5 +71,6 @@ The following sections contain lists of Custom Image URLs (exported in different
### All released images
- v0.15.1 - https://storage.googleapis.com/allegro-files/trains-server/trains-server-0-15-1.tar.gz
- v0.15.0 - https://storage.googleapis.com/allegro-files/trains-server/trains-server-0-15-0.tar.gz
- v0.14.1 - https://storage.googleapis.com/allegro-files/trains-server/trains-server-0-14-1.tar.gz

View File

@@ -1,6 +1,6 @@
# Launching the **trains-server** Docker in Linux or macOS
For Linux or macOS, use our pre-built Docker image for easy deployment. The latest Docker images can be found [here](https://hub.docker.com/r/allegroai/trains).
For Linux or macOS, use our pre-built Docker image for easy deployment. The latest Docker images can be found [here](https://hub.docker.com/r/allegroai/trains).
For Linux users:
@@ -16,20 +16,20 @@ To launch **trains-server** on Linux or macOS:
1. Verify the Docker CE installation. Execute the command:
sudo docker run hello-world
docker run hello-world
The expected is output is:
Hello from Docker!
This message shows that your installation appears to be working correctly.
To generate this message, Docker took the following steps:
1. The Docker client contacted the Docker daemon.
2. The Docker daemon pulled the "hello-world" image from the Docker Hub. (amd64)
3. The Docker daemon created a new container from that image which runs the executable that produces the output you are currently reading.
4. The Docker daemon streamed that output to the Docker client, which sent it to your terminal.
1. For Linux only, install `docker-compose`. Execute the following commands (for more information, see [Install Docker Compose](https://docs.docker.com/compose/install/) in the Docker documentation):
1. For Linux only, install `docker-compose`. Execute the following commands (for more information, see [Install Docker Compose](https://docs.docker.com/compose/install/) in the Docker documentation):
sudo curl -L "https://github.com/docker/compose/releases/download/1.24.1/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose
sudo chmod +x /usr/local/bin/docker-compose
@@ -42,12 +42,12 @@ To launch **trains-server** on Linux or macOS:
sudo mv /tmp/99-trains.conf /etc/sysctl.d/99-trains.conf
sudo sysctl -w vm.max_map_count=262144
sudo service docker restart
macOS:
screen ~/Library/Containers/com.docker.docker/Data/vms/0/tty
sysctl -w vm.max_map_count=262144
1. Remove any previous installation of **trains-server**.
@@ -64,28 +64,28 @@ To launch **trains-server** on Linux or macOS:
sudo mkdir -p /opt/trains/logs
sudo mkdir -p /opt/trains/config
sudo mkdir -p /opt/trains/data/fileserver
1. For macOS only, open the Docker app, select **Preferences**, and then on the **File Sharing** tab, add `/opt/trains`.
1. Grant access to the Dockers.
Linux:
sudo chown -R 1000:1000 /opt/trains
macOS:
sudo chown -R $(whoami):staff /opt/trains
1. Download the **trains-server** docker-compose YAML file.
cd /opt/trains
curl https://raw.githubusercontent.com/allegroai/trains-server/master/docker-compose.yml -o docker-compose.yml
1. Run `docker-compose` with the downloaded configuration file.
sudo docker-compose -f docker-compose.yml up
docker-compose -f docker-compose.yml up
Your server is now running on [http://localhost:8080](http://localhost:8080) and the following ports are available:
* Web server on port `8080`
@@ -94,4 +94,4 @@ To launch **trains-server** on Linux or macOS:
## Next Step
Configure the [Trains client for trains-server](https://github.com/allegroai/trains/blob/master/README.md#configuration).
Configure the [Trains client for trains-server](https://github.com/allegroai/trains/blob/master/README.md#configuration).

View File

@@ -1 +1 @@
__version__ = "2.8.0"
__version__ = "2.9.0"

View File

@@ -1,7 +1,8 @@
from jsonmodels import models, fields
from jsonmodels.validators import Length
from mongoengine.base import BaseDocument
from apimodels import DictField
from apimodels import DictField, ListField
class MongoengineFieldsDict(DictField):
@@ -12,14 +13,14 @@ class MongoengineFieldsDict(DictField):
"""
mongoengine_update_operators = (
'inc',
'dec',
'push',
'push_all',
'pop',
'pull',
'pull_all',
'add_to_set',
"inc",
"dec",
"push",
"push_all",
"pop",
"pull",
"pull_all",
"add_to_set",
)
@staticmethod
@@ -30,16 +31,16 @@ class MongoengineFieldsDict(DictField):
@classmethod
def _normalize_mongo_field_path(cls, path, value):
parts = path.split('__')
parts = path.split("__")
if len(parts) > 1:
if parts[0] == 'set':
if parts[0] == "set":
parts = parts[1:]
elif parts[0] == 'unset':
elif parts[0] == "unset":
parts = parts[1:]
value = None
elif parts[0] in cls.mongoengine_update_operators:
return None, None
return '.'.join(parts), cls._normalize_mongo_value(value)
return ".".join(parts), cls._normalize_mongo_value(value)
def parse_value(self, value):
value = super(MongoengineFieldsDict, self).parse_value(value)
@@ -62,3 +63,7 @@ class PagedRequest(models.Base):
class IdResponse(models.Base):
id = fields.StringField(required=True)
class MakePublicRequest(models.Base):
ids = ListField(items_types=str, validators=[Length(minimum_value=1)])

View File

@@ -1,17 +1,20 @@
from typing import Sequence
from enum import auto
from typing import Sequence, Optional
from jsonmodels import validators
from jsonmodels.fields import StringField, BoolField
from jsonmodels.models import Base
from jsonmodels.validators import Length
from jsonmodels.validators import Length, Min, Max
from apimodels import ListField, IntField, ActualEnumField
from bll.event.event_metrics import EventType
from bll.event.scalar_key import ScalarKeyEnum
from config import config
from utilities.stringenum import StringEnum
class HistogramRequestBase(Base):
samples: int = IntField(default=10000)
samples: int = IntField(default=6000, validators=[Min(1), Max(6000)])
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
@@ -21,7 +24,15 @@ class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
tasks: Sequence[str] = ListField(
items_types=str, validators=[Length(minimum_value=1)]
items_types=str,
validators=[
Length(
minimum_value=1,
maximum_value=config.get(
"services.tasks.multi_task_histogram_limit", 10
),
)
],
)
@@ -40,12 +51,17 @@ class DebugImagesRequest(Base):
scroll_id: str = StringField()
class LogOrderEnum(StringEnum):
asc = auto()
desc = auto()
class LogEventsRequest(Base):
task: str = StringField(required=True)
batch_size: int = IntField(default=500)
navigate_earlier: bool = BoolField(default=True)
refresh: bool = BoolField(default=False)
scroll_id: str = StringField()
from_timestamp: Optional[int] = IntField()
order: Optional[str] = ActualEnumField(LogOrderEnum)
class IterationEvents(Base):

View File

@@ -6,6 +6,10 @@ from apimodels.base import UpdateResponse
from apimodels.tasks import PublishResponse as TaskPublishResponse
class GetFrameworksRequest(models.Base):
projects = fields.ListField(items_types=[str])
class CreateModelRequest(models.Base):
name = fields.StringField(required=True)
uri = fields.StringField(required=True)

View File

@@ -13,11 +13,5 @@ class GetHyperParamReq(ProjectReq):
page_size = fields.IntField(default=500)
class GetHyperParamResp(models.Base):
parameters = fields.ListField(str)
remaining = fields.IntField()
total = fields.IntField()
class ProjectTagsRequest(TagsRequest):
projects = ListField(str)

View File

@@ -1,7 +1,9 @@
from typing import Sequence
import six
from jsonmodels import models
from jsonmodels.fields import StringField, BoolField, IntField, EmbeddedField
from jsonmodels.validators import Enum
from jsonmodels.validators import Enum, Length
from apimodels import DictField, ListField
from apimodels.base import UpdateResponse
@@ -103,6 +105,8 @@ class CloneRequest(TaskRequest):
new_task_system_tags = ListField([str])
new_task_parent = StringField()
new_task_project = StringField()
new_hyperparams = DictField()
new_configuration = DictField()
execution_overrides = DictField()
validate_references = BoolField(default=False)
@@ -118,3 +122,72 @@ class AddOrUpdateArtifactsResponse(models.Base):
class ResetRequest(UpdateRequest):
clear_all = BoolField(default=False)
class MultiTaskRequest(models.Base):
tasks = ListField([str], validators=Length(minimum_value=1))
class GetHyperParamsRequest(MultiTaskRequest):
pass
class HyperParamItem(models.Base):
section = StringField(required=True, validators=Length(minimum_value=1))
name = StringField(required=True, validators=Length(minimum_value=1))
value = StringField(required=True)
type = StringField()
description = StringField()
class ReplaceHyperparams(object):
none = "none"
section = "section"
all = "all"
class EditHyperParamsRequest(TaskRequest):
hyperparams: Sequence[HyperParamItem] = ListField(
[HyperParamItem], validators=Length(minimum_value=1)
)
replace_hyperparams = StringField(
validators=Enum(*get_options(ReplaceHyperparams)),
default=ReplaceHyperparams.none,
)
class HyperParamKey(models.Base):
section = StringField(required=True, validators=Length(minimum_value=1))
name = StringField(nullable=True)
class DeleteHyperParamsRequest(TaskRequest):
hyperparams: Sequence[HyperParamKey] = ListField(
[HyperParamKey], validators=Length(minimum_value=1)
)
class GetConfigurationsRequest(MultiTaskRequest):
names = ListField([str])
class GetConfigurationNamesRequest(MultiTaskRequest):
pass
class Configuration(models.Base):
name = StringField(required=True, validators=Length(minimum_value=1))
value = StringField(required=True)
type = StringField()
description = StringField()
class EditConfigurationRequest(TaskRequest):
configuration: Sequence[Configuration] = ListField(
[Configuration], validators=Length(minimum_value=1)
)
replace_configuration = BoolField(default=False)
class DeleteConfigurationRequest(TaskRequest):
configuration: Sequence[str] = ListField([str], validators=Length(minimum_value=1))

View File

@@ -19,6 +19,7 @@ DEFAULT_TIMEOUT = 10 * 60
class WorkerRequest(Base):
worker = StringField(required=True)
tags = ListField(str)
class RegisterRequest(WorkerRequest):
@@ -67,12 +68,14 @@ class WorkerEntry(Base, JsonSerializableMixin):
company = EmbeddedField(IdNameEntry)
ip = StringField()
task = EmbeddedField(IdNameEntry)
project = EmbeddedField(IdNameEntry)
queue = StringField() # queue from which current task was taken
queues = ListField(str) # list of queues this worker listens to
register_time = DateTimeField(required=True)
register_timeout = IntField(required=True)
last_activity_time = DateTimeField(required=True)
last_report_time = DateTimeField()
tags = ListField(str)
class CurrentTaskEntry(IdNameEntry):

View File

@@ -208,7 +208,11 @@ class DebugImagesIterator:
"size": 0,
"query": {
"bool": {
"must": [{"term": {"task": task}}, {"terms": {"metric": metrics}}]
"must": [
{"term": {"task": task}},
{"terms": {"metric": metrics}},
{"exists": {"field": "url"}},
]
}
},
"aggs": {
@@ -251,7 +255,7 @@ class DebugImagesIterator:
}
with translate_errors_context(), TimingContext("es", "_init_metric_states"):
es_res = self.es.search(index=es_index, body=es_req, routing=task)
es_res = self.es.search(index=es_index, body=es_req)
if "aggregations" not in es_res:
return []
@@ -298,6 +302,7 @@ class DebugImagesIterator:
must_conditions = [
{"term": {"task": metric.task}},
{"term": {"metric": metric.name}},
{"exists": {"field": "url"}},
]
must_not_conditions = []
@@ -368,7 +373,7 @@ class DebugImagesIterator:
"terms": {
"field": "iter",
"size": iter_count,
"order": {"_term": "desc" if navigate_earlier else "asc"},
"order": {"_key": "desc" if navigate_earlier else "asc"},
},
"aggs": {
"variants": {
@@ -387,7 +392,7 @@ class DebugImagesIterator:
},
}
with translate_errors_context(), TimingContext("es", "get_debug_image_events"):
es_res = self.es.search(index=es_index, body=es_req, routing=metric.task)
es_res = self.es.search(index=es_index, body=es_req)
if "aggregations" not in es_res:
return metric.task, metric.name, []

View File

@@ -3,7 +3,7 @@ from collections import defaultdict
from contextlib import closing
from datetime import datetime
from operator import attrgetter
from typing import Sequence, Set, Tuple
from typing import Sequence, Set, Tuple, Optional
import six
from elasticsearch import helpers
@@ -22,6 +22,7 @@ from database.errors import translate_errors_context
from database.model.task.task import Task, TaskStatus
from redis_manager import redman
from timing_context import TimingContext
from tools import safe_get
from utilities.dicts import flatten_nested_items
# noinspection PyTypeChecker
@@ -31,6 +32,7 @@ LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
class EventBLL(object):
id_fields = ("task", "iter", "metric", "variant", "key")
empty_scroll = "FFFF"
def __init__(self, events_es=None, redis=None):
self.es = events_es or es_factory.connect("events")
@@ -40,7 +42,7 @@ class EventBLL(object):
)
self.redis = redis or redman.connection("apiserver")
self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
self.log_events_iterator = LogEventsIterator(es=self.es, redis=self.redis)
self.log_events_iterator = LogEventsIterator(es=self.es)
@property
def metrics(self) -> EventMetrics:
@@ -134,7 +136,6 @@ class EventBLL(object):
es_action = {
"_op_type": "index", # overwrite if exists with same ID
"_index": index_name,
"_type": "event",
"_source": event,
}
@@ -144,7 +145,6 @@ class EventBLL(object):
else:
es_action["_id"] = dbutils.id()
es_action["_routing"] = task_id
task_ids.add(task_id)
if (
iter is not None
@@ -322,6 +322,9 @@ class EventBLL(object):
batch_size=10000,
scroll_id=None,
):
if scroll_id == self.empty_scroll:
return [], scroll_id, 0
if scroll_id:
with translate_errors_context(), TimingContext("es", "task_log_events"):
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
@@ -342,14 +345,9 @@ class EventBLL(object):
}
with translate_errors_context(), TimingContext("es", "scroll_task_events"):
es_res = self.es.search(
index=es_index, body=es_req, scroll="1h", routing=task_id
)
events = [hit["_source"] for hit in es_res["hits"]["hits"]]
next_scroll_id = es_res["_scroll_id"]
total_events = es_res["hits"]["total"]
es_res = self.es.search(index=es_index, body=es_req, scroll="1h")
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
return events, next_scroll_id, total_events
def get_last_iterations_per_event_metric_variant(
@@ -377,7 +375,7 @@ class EventBLL(object):
"terms": {
"field": "iter",
"size": num_last_iterations,
"order": {"_term": "desc"},
"order": {"_key": "desc"},
}
}
},
@@ -393,7 +391,7 @@ class EventBLL(object):
with translate_errors_context(), TimingContext(
"es", "task_last_iter_metric_variant"
):
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
es_res = self.es.search(index=es_index, body=es_req)
if "aggregations" not in es_res:
return []
@@ -413,6 +411,9 @@ class EventBLL(object):
size: int = 500,
scroll_id: str = None,
):
if scroll_id == self.empty_scroll:
return [], scroll_id, 0
if scroll_id:
with translate_errors_context(), TimingContext("es", "get_task_events"):
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
@@ -422,13 +423,11 @@ class EventBLL(object):
if not self.es.indices.exists(es_index):
return TaskEventsResult()
query = {"bool": defaultdict(list)}
must = []
if last_iterations_per_plot is None:
must = query["bool"]["must"]
must.append({"terms": {"task": tasks}})
else:
should = query["bool"]["should"]
should = []
for i, task_id in enumerate(tasks):
last_iters = self.get_last_iterations_per_event_metric_variant(
es_index, task_id, last_iterations_per_plot, event_type
@@ -451,32 +450,41 @@ class EventBLL(object):
)
if not should:
return TaskEventsResult()
must.append({"bool": {"should": should}})
if sort is None:
sort = [{"timestamp": {"order": "asc"}}]
es_req = {"sort": sort, "size": min(size, 10000), "query": query}
routing = ",".join(tasks)
es_req = {
"sort": sort,
"size": min(size, 10000),
"query": {"bool": {"must": must}},
}
with translate_errors_context(), TimingContext("es", "get_task_plots"):
es_res = self.es.search(
index=es_index,
body=es_req,
ignore=404,
routing=routing,
scroll="1h",
index=es_index, body=es_req, ignore=404, scroll="1h",
)
events = [doc["_source"] for doc in es_res.get("hits", {}).get("hits", [])]
# scroll id may be missing when queering a totally empty DB
next_scroll_id = es_res.get("_scroll_id")
total_events = es_res["hits"]["total"]
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
return TaskEventsResult(
events=events, next_scroll_id=next_scroll_id, total_events=total_events
)
def _get_events_from_es_res(self, es_res: dict) -> Tuple[list, int, Optional[str]]:
"""
Return events and next scroll id from the scrolled query
Release the scroll once it is exhausted
"""
total_events = safe_get(es_res, "hits/total/value", default=0)
events = [doc["_source"] for doc in safe_get(es_res, "hits/hits", default=[])]
next_scroll_id = es_res.get("_scroll_id")
if next_scroll_id and not events:
self.es.clear_scroll(scroll_id=next_scroll_id)
next_scroll_id = self.empty_scroll
return events, total_events, next_scroll_id
def get_task_events(
self,
company_id,
@@ -489,6 +497,8 @@ class EventBLL(object):
size=500,
scroll_id=None,
):
if scroll_id == self.empty_scroll:
return [], scroll_id, 0
if scroll_id:
with translate_errors_context(), TimingContext("es", "get_task_events"):
@@ -502,20 +512,16 @@ class EventBLL(object):
if not self.es.indices.exists(es_index):
return TaskEventsResult()
query = {"bool": defaultdict(list)}
if metric or variant:
must = query["bool"]["must"]
if metric:
must.append({"term": {"metric": metric}})
if variant:
must.append({"term": {"variant": variant}})
must = []
if metric:
must.append({"term": {"metric": metric}})
if variant:
must.append({"term": {"variant": variant}})
if last_iter_count is None:
must = query["bool"]["must"]
must.append({"terms": {"task": task_ids}})
else:
should = query["bool"]["should"]
should = []
for i, task_id in enumerate(task_ids):
last_iters = self.get_last_iters(
es_index, task_id, event_type, last_iter_count
@@ -534,27 +540,23 @@ class EventBLL(object):
)
if not should:
return TaskEventsResult()
must.append({"bool": {"should": should}})
if sort is None:
sort = [{"timestamp": {"order": "asc"}}]
es_req = {"sort": sort, "size": min(size, 10000), "query": query}
routing = ",".join(task_ids)
es_req = {
"sort": sort,
"size": min(size, 10000),
"query": {"bool": {"must": must}},
}
with translate_errors_context(), TimingContext("es", "get_task_events"):
es_res = self.es.search(
index=es_index,
body=es_req,
ignore=404,
routing=routing,
scroll="1h",
index=es_index, body=es_req, ignore=404, scroll="1h",
)
events = [doc["_source"] for doc in es_res.get("hits", {}).get("hits", [])]
next_scroll_id = es_res["_scroll_id"]
total_events = es_res["hits"]["total"]
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
return TaskEventsResult(
events=events, next_scroll_id=next_scroll_id, total_events=total_events
)
@@ -590,7 +592,7 @@ class EventBLL(object):
with translate_errors_context(), TimingContext(
"es", "events_get_metrics_and_variants"
):
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
es_res = self.es.search(index=es_index, body=es_req)
metrics = {}
for metric_bucket in es_res["aggregations"]["metrics"].get("buckets"):
@@ -622,14 +624,14 @@ class EventBLL(object):
"terms": {
"field": "metric",
"size": EventMetrics.MAX_METRICS_COUNT,
"order": {"_term": "asc"},
"order": {"_key": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": EventMetrics.MAX_VARIANTS_COUNT,
"order": {"_term": "asc"},
"order": {"_key": "asc"},
},
"aggs": {
"last_value": {
@@ -659,7 +661,7 @@ class EventBLL(object):
with translate_errors_context(), TimingContext(
"es", "events_get_metrics_and_variants"
):
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
es_res = self.es.search(index=es_index, body=es_req)
metrics = []
max_timestamp = 0
@@ -706,7 +708,7 @@ class EventBLL(object):
"sort": ["iter"],
}
with translate_errors_context(), TimingContext("es", "task_stats_vector"):
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
es_res = self.es.search(index=es_index, body=es_req)
vectors = []
iterations = []
@@ -727,7 +729,7 @@ class EventBLL(object):
"terms": {
"field": "iter",
"size": iters,
"order": {"_term": "desc"},
"order": {"_key": "desc"},
}
}
},
@@ -737,7 +739,7 @@ class EventBLL(object):
es_req["query"]["bool"]["must"].append({"term": {"type": event_type}})
with translate_errors_context(), TimingContext("es", "task_last_iter"):
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
es_res = self.es.search(index=es_index, body=es_req)
if "aggregations" not in es_res:
return []
@@ -759,8 +761,6 @@ class EventBLL(object):
es_index = EventMetrics.get_index_name(company_id, "*")
es_req = {"query": {"term": {"task": task_id}}}
with translate_errors_context(), TimingContext("es", "delete_task_events"):
es_res = self.es.delete_by_query(
index=es_index, body=es_req, routing=task_id, refresh=True
)
es_res = self.es.delete_by_query(index=es_index, body=es_req, refresh=True)
return es_res.get("deleted", 0)

View File

@@ -1,12 +1,12 @@
import itertools
import math
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures.thread import ThreadPoolExecutor
from enum import Enum
from functools import partial
from operator import itemgetter
from typing import Sequence, Tuple, Callable, Iterable
from typing import Sequence, Tuple
from boltons.iterutils import bucketize
from elasticsearch import Elasticsearch
from mongoengine import Q
@@ -16,7 +16,7 @@ from config import config
from database.errors import translate_errors_context
from database.model.task.task import Task
from timing_context import TimingContext
from utilities import safe_get
from tools import safe_get
log = config.logger(__file__)
@@ -30,14 +30,18 @@ class EventType(Enum):
class EventMetrics:
MAX_TASKS_COUNT = 50
MAX_METRICS_COUNT = 200
MAX_VARIANTS_COUNT = 500
MAX_METRICS_COUNT = 100
MAX_VARIANTS_COUNT = 100
MAX_AGGS_ELEMENTS_COUNT = 50
MAX_SAMPLE_BUCKETS = 6000
def __init__(self, es: Elasticsearch):
self.es = es
@property
def _max_concurrency(self):
return config.get("services.events.max_metrics_concurrency", 4)
@staticmethod
def get_index_name(company_id, event_type):
event_type = event_type.lower().replace(" ", "_")
@@ -51,15 +55,48 @@ class EventMetrics:
The amount of points in each histogram should not exceed
the requested samples
"""
es_index = self.get_index_name(company_id, "training_stats_scalar")
if not self.es.indices.exists(es_index):
return {}
return self._run_get_scalar_metrics_as_parallel(
company_id,
task_ids=[task_id],
samples=samples,
key=ScalarKey.resolve(key),
get_func=self._get_scalar_average,
return self._get_scalar_average_per_iter_core(
task_id, es_index, samples, ScalarKey.resolve(key)
)
def _get_scalar_average_per_iter_core(
self,
task_id: str,
es_index: str,
samples: int,
key: ScalarKey,
run_parallel: bool = True,
) -> dict:
intervals = self._get_task_metric_intervals(
es_index=es_index, task_id=task_id, samples=samples, field=key.field
)
if not intervals:
return {}
interval_groups = self._group_task_metric_intervals(intervals)
get_scalar_average = partial(
self._get_scalar_average, task_id=task_id, es_index=es_index, key=key
)
if run_parallel:
with ThreadPoolExecutor(max_workers=self._max_concurrency) as pool:
metrics = itertools.chain.from_iterable(
pool.map(get_scalar_average, interval_groups)
)
else:
metrics = itertools.chain.from_iterable(
get_scalar_average(group) for group in interval_groups
)
ret = defaultdict(dict)
for metric_key, metric_values in metrics:
ret[metric_key].update(metric_values)
return ret
def compare_scalar_metrics_average_per_iter(
self,
company_id,
@@ -72,159 +109,115 @@ class EventMetrics:
Compare scalar metrics for different tasks per metric and variant
The amount of points in each histogram should not exceed the requested samples
"""
if len(task_ids) > self.MAX_TASKS_COUNT:
raise errors.BadRequest(
f"Up to {self.MAX_TASKS_COUNT} tasks supported for comparison",
len(task_ids),
)
task_name_by_id = {}
with translate_errors_context():
task_objs = Task.get_many(
company=company_id,
query=Q(id__in=task_ids),
allow_public=allow_public,
override_projection=("id", "name"),
override_projection=("id", "name", "company"),
return_dicts=False,
)
if len(task_objs) < len(task_ids):
invalid = tuple(set(task_ids) - set(r.id for r in task_objs))
raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
task_name_by_id = {t.id: t.name for t in task_objs}
ret = self._run_get_scalar_metrics_as_parallel(
company_id,
task_ids=task_ids,
samples=samples,
key=ScalarKey.resolve(key),
get_func=self._get_scalar_average_per_task,
)
companies = {t.company for t in task_objs}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
)
for metric_data in ret.values():
for variant_data in metric_data.values():
for task_id, task_data in variant_data.items():
task_data["name"] = task_name_by_id[task_id]
return ret
TaskMetric = Tuple[str, str, str]
MetricInterval = Tuple[int, Sequence[TaskMetric]]
MetricData = Tuple[str, dict]
def _split_metrics_by_max_aggs_count(
self, task_metrics: Sequence[TaskMetric]
) -> Iterable[Sequence[TaskMetric]]:
"""
Return task metrics in groups where amount of task metrics in each group
is roughly limited by MAX_AGGS_ELEMENTS_COUNT. The split is done on metrics and
variants while always preserving all their tasks in the same group
"""
if len(task_metrics) < self.MAX_AGGS_ELEMENTS_COUNT:
yield task_metrics
return
tm_grouped = bucketize(task_metrics, key=itemgetter(1, 2))
groups = []
for group in tm_grouped.values():
groups.append(group)
if sum(map(len, groups)) >= self.MAX_AGGS_ELEMENTS_COUNT:
yield list(itertools.chain(*groups))
groups = []
if groups:
yield list(itertools.chain(*groups))
return
def _run_get_scalar_metrics_as_parallel(
self,
company_id: str,
task_ids: Sequence[str],
samples: int,
key: ScalarKey,
get_func: Callable[
[MetricInterval, Sequence[str], str, ScalarKey], Sequence[MetricData]
],
) -> dict:
"""
Group metrics per interval length and execute get_func for each group in parallel
:param company_id: id of the company
:params task_ids: ids of the tasks to collect data for
:param samples: maximum number of samples per metric
:param get_func: callable that given metric names for the same interval
performs histogram aggregation for the metrics and return the aggregated data
"""
es_index = self.get_index_name(company_id, "training_stats_scalar")
es_index = self.get_index_name(next(iter(companies)), "training_stats_scalar")
if not self.es.indices.exists(es_index):
return {}
intervals = self._get_metric_intervals(
es_index=es_index, task_ids=task_ids, samples=samples, field=key.field
get_scalar_average_per_iter = partial(
self._get_scalar_average_per_iter_core,
es_index=es_index,
samples=samples,
key=ScalarKey.resolve(key),
run_parallel=False,
)
if not intervals:
return {}
intervals = list(
itertools.chain.from_iterable(
zip(itertools.repeat(i), self._split_metrics_by_max_aggs_count(tms))
for i, tms in intervals
)
)
max_concurrency = config.get("services.events.max_metrics_concurrency", 4)
with ThreadPoolExecutor(max_workers=max_concurrency) as pool:
metrics = itertools.chain.from_iterable(
pool.map(
partial(get_func, task_ids=task_ids, es_index=es_index, key=key),
intervals,
)
with ThreadPoolExecutor(max_workers=self._max_concurrency) as pool:
task_metrics = zip(
task_ids, pool.map(get_scalar_average_per_iter, task_ids)
)
ret = defaultdict(dict)
for metric_key, metric_values in metrics:
ret[metric_key].update(metric_values)
res = defaultdict(lambda: defaultdict(dict))
for task_id, task_data in task_metrics:
task_name = task_name_by_id[task_id]
for metric_key, metric_data in task_data.items():
for variant_key, variant_data in metric_data.items():
variant_data["name"] = task_name
res[metric_key][variant_key][task_id] = variant_data
return ret
return res
def _get_metric_intervals(
self, es_index, task_ids: Sequence[str], samples: int, field: str = "iter"
MetricInterval = Tuple[str, str, int, int]
MetricIntervalGroup = Tuple[int, Sequence[Tuple[str, str]]]
@classmethod
def _group_task_metric_intervals(
cls, intervals: Sequence[MetricInterval]
) -> Sequence[MetricIntervalGroup]:
"""
Group task metric intervals so that the following conditions are meat:
- All the metrics in the same group have the same interval (with 10% rounding)
- The amount of metrics in the group does not exceed MAX_AGGS_ELEMENTS_COUNT
- The total count of samples in the group does not exceed MAX_SAMPLE_BUCKETS
"""
metric_interval_groups = []
interval_group = []
group_interval_upper_bound = 0
group_max_interval = 0
group_samples = 0
for metric, variant, interval, size in sorted(intervals, key=itemgetter(2)):
if (
interval > group_interval_upper_bound
or (group_samples + size) > cls.MAX_SAMPLE_BUCKETS
or len(interval_group) >= cls.MAX_AGGS_ELEMENTS_COUNT
):
if interval_group:
metric_interval_groups.append((group_max_interval, interval_group))
interval_group = []
group_max_interval = interval
group_interval_upper_bound = interval + int(interval * 0.1)
group_samples = 0
interval_group.append((metric, variant))
group_samples += size
group_max_interval = max(group_max_interval, interval)
if interval_group:
metric_interval_groups.append((group_max_interval, interval_group))
return metric_interval_groups
def _get_task_metric_intervals(
self, es_index, task_id: str, samples: int, field: str = "iter"
) -> Sequence[MetricInterval]:
"""
Calculate interval per task metric variant so that the resulting
amount of points does not exceed sample.
Return metric variants grouped by interval value with 10% rounding
For samples==0 return empty list
Return the list og metric variant intervals as the following tuple:
(metric, variant, interval, samples)
"""
default_intervals = [(1, [])]
if not samples:
return default_intervals
es_req = {
"size": 0,
"query": {"terms": {"task": task_ids}},
"query": {"term": {"task": task_id}},
"aggs": {
"tasks": {
"terms": {"field": "task", "size": self.MAX_TASKS_COUNT},
"metrics": {
"terms": {"field": "metric", "size": self.MAX_METRICS_COUNT},
"aggs": {
"metrics": {
"variants": {
"terms": {
"field": "metric",
"size": self.MAX_METRICS_COUNT,
"field": "variant",
"size": self.MAX_VARIANTS_COUNT,
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": self.MAX_VARIANTS_COUNT,
},
"aggs": {
"count": {"value_count": {"field": field}},
"min_index": {"min": {"field": field}},
"max_index": {"max": {"field": field}},
},
}
"count": {"value_count": {"field": field}},
"min_index": {"min": {"field": field}},
"max_index": {"max": {"field": field}},
},
}
},
@@ -233,88 +226,78 @@ class EventMetrics:
}
with translate_errors_context(), TimingContext("es", "task_stats_get_interval"):
es_res = self.es.search(
index=es_index, body=es_req, routing=",".join(task_ids)
)
es_res = self.es.search(index=es_index, body=es_req)
aggs_result = es_res.get("aggregations")
if not aggs_result:
return default_intervals
return []
intervals = [
(
task["key"],
metric["key"],
variant["key"],
self._calculate_metric_interval(variant, samples),
)
for task in aggs_result["tasks"]["buckets"]
for metric in task["metrics"]["buckets"]
return [
self._build_metric_interval(metric["key"], variant["key"], variant, samples)
for metric in aggs_result["metrics"]["buckets"]
for variant in metric["variants"]["buckets"]
]
metric_intervals = []
upper_border = 0
interval_metrics = None
for task, metric, variant, interval in sorted(intervals, key=itemgetter(3)):
if not interval_metrics or interval > upper_border:
interval_metrics = []
metric_intervals.append((interval, interval_metrics))
upper_border = interval + int(interval * 0.1)
interval_metrics.append((task, metric, variant))
return metric_intervals
@staticmethod
def _calculate_metric_interval(metric_variant: dict, samples: int) -> int:
def _build_metric_interval(
metric: str, variant: str, data: dict, samples: int
) -> Tuple[str, str, int, int]:
"""
Calculate index interval per metric_variant variant so that the
total amount of intervals does not exceeds the samples
Return the interval and resulting amount of intervals
"""
count = safe_get(metric_variant, "count/value")
if not count or count < samples:
return 1
count = safe_get(data, "count/value", default=0)
if count < samples:
return metric, variant, 1, count
min_index = safe_get(metric_variant, "min_index/value", default=0)
max_index = safe_get(metric_variant, "max_index/value", default=min_index)
return max(1, int(max_index - min_index + 1) // samples)
min_index = safe_get(data, "min_index/value", default=0)
max_index = safe_get(data, "max_index/value", default=min_index)
index_range = max_index - min_index + 1
interval = max(1, math.ceil(float(index_range) / samples))
max_samples = math.ceil(float(index_range) / interval)
return (
metric,
variant,
interval,
max_samples,
)
MetricData = Tuple[str, dict]
def _get_scalar_average(
self,
metrics_interval: MetricInterval,
task_ids: Sequence[str],
metrics_interval: MetricIntervalGroup,
task_id: str,
es_index: str,
key: ScalarKey,
) -> Sequence[MetricData]:
"""
Retrieve scalar histograms per several metric variants that share the same interval
Note: the function works with a single task only
"""
assert len(task_ids) == 1
interval, task_metrics = metrics_interval
interval, metrics = metrics_interval
aggregation = self._add_aggregation_average(key.get_aggregation(interval))
aggs = {
"metrics": {
"terms": {
"field": "metric",
"size": self.MAX_METRICS_COUNT,
"order": {"_term": "desc"},
"order": {"_key": "desc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": self.MAX_VARIANTS_COUNT,
"order": {"_term": "desc"},
"order": {"_key": "desc"},
},
"aggs": aggregation,
}
},
}
}
aggs_result = self._query_aggregation_for_metrics_and_tasks(
es_index, aggs=aggs, task_ids=task_ids, task_metrics=task_metrics
aggs_result = self._query_aggregation_for_task_metrics(
es_index, aggs=aggs, task_id=task_id, metrics=metrics
)
if not aggs_result:
@@ -335,61 +318,6 @@ class EventMetrics:
]
return metrics
def _get_scalar_average_per_task(
self,
metrics_interval: MetricInterval,
task_ids: Sequence[str],
es_index: str,
key: ScalarKey,
) -> Sequence[MetricData]:
"""
Retrieve scalar histograms per several metric variants that share the same interval
"""
interval, task_metrics = metrics_interval
aggregation = self._add_aggregation_average(key.get_aggregation(interval))
aggs = {
"metrics": {
"terms": {"field": "metric", "size": self.MAX_METRICS_COUNT},
"aggs": {
"variants": {
"terms": {"field": "variant", "size": self.MAX_VARIANTS_COUNT},
"aggs": {
"tasks": {
"terms": {
"field": "task",
"size": self.MAX_TASKS_COUNT,
},
"aggs": aggregation,
}
},
}
},
}
}
aggs_result = self._query_aggregation_for_metrics_and_tasks(
es_index, aggs=aggs, task_ids=task_ids, task_metrics=task_metrics
)
if not aggs_result:
return {}
metrics = [
(
metric["key"],
{
variant["key"]: {
task["key"]: key.get_iterations_data(task)
for task in variant["tasks"]["buckets"]
}
for variant in metric["variants"]["buckets"]
},
)
for metric in aggs_result["metrics"]["buckets"]
]
return metrics
@staticmethod
def _add_aggregation_average(aggregation):
average_agg = {"avg_val": {"avg": {"field": "value"}}}
@@ -398,69 +326,55 @@ class EventMetrics:
for key, value in aggregation.items()
}
def _query_aggregation_for_metrics_and_tasks(
def _query_aggregation_for_task_metrics(
self,
es_index: str,
aggs: dict,
task_ids: Sequence[str],
task_metrics: Sequence[TaskMetric],
task_id: str,
metrics: Sequence[Tuple[str, str]],
) -> dict:
"""
Return the result of elastic search query for the given aggregation filtered
by the given task_ids and metrics
"""
if task_metrics:
condition = {
"should": [
self._build_metric_terms(task, metric, variant)
for task, metric, variant in task_metrics
]
}
else:
condition = {"must": [{"terms": {"task": task_ids}}]}
must = [{"term": {"task": task_id}}]
if metrics:
should = [
{
"bool": {
"must": [
{"term": {"metric": metric}},
{"term": {"variant": variant}},
]
}
}
for metric, variant in metrics
]
must.append({"bool": {"should": should}})
es_req = {
"size": 0,
"_source": {"excludes": []},
"query": {"bool": condition},
"query": {"bool": {"must": must}},
"aggs": aggs,
"version": True,
}
with translate_errors_context(), TimingContext("es", "task_stats_scalar"):
es_res = self.es.search(
index=es_index, body=es_req, routing=",".join(task_ids)
)
es_res = self.es.search(index=es_index, body=es_req)
return es_res.get("aggregations")
@staticmethod
def _build_metric_terms(task: str, metric: str, variant: str) -> dict:
"""
Build query term for a metric + variant
"""
return {
"bool": {
"must": [
{"term": {"task": task}},
{"term": {"metric": metric}},
{"term": {"variant": variant}},
]
}
}
def get_tasks_metrics(
self, company_id, task_ids: Sequence, event_type: EventType
) -> Sequence[Tuple]:
) -> Sequence:
"""
For the requested tasks return all the metrics that
reported events of the requested types
"""
es_index = EventMetrics.get_index_name(company_id, event_type.value)
if not self.es.indices.exists(es_index):
return [(tid, []) for tid in task_ids]
return {}
max_concurrency = config.get("services.events.max_metrics_concurrency", 4)
with ThreadPoolExecutor(max_concurrency) as pool:
with ThreadPoolExecutor(self._max_concurrency) as pool:
res = pool.map(
partial(
self._get_task_metrics, es_index=es_index, event_type=event_type,
@@ -488,7 +402,7 @@ class EventMetrics:
}
with translate_errors_context(), TimingContext("es", "_get_task_metrics"):
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
es_res = self.es.search(index=es_index, body=es_req)
return [
metric["key"]

View File

@@ -2,30 +2,12 @@ from typing import Optional, Tuple, Sequence
import attr
from elasticsearch import Elasticsearch
from jsonmodels.fields import StringField, IntField
from jsonmodels.models import Base
from redis import StrictRedis
from apierrors import errors
from apimodels import JsonSerializableMixin
from bll.event.event_metrics import EventMetrics
from bll.redis_cache_manager import RedisCacheManager
from config import config
from database.errors import translate_errors_context
from timing_context import TimingContext
class LogEventsScrollState(Base, JsonSerializableMixin):
id: str = StringField(required=True)
task: str = StringField(required=True)
last_min_timestamp: Optional[int] = IntField()
last_max_timestamp: Optional[int] = IntField()
def reset(self):
"""Reset the scrolling state """
self.last_min_timestamp = self.last_max_timestamp = None
@attr.s(auto_attribs=True)
class TaskEventsResult:
total_events: int = 0
@@ -36,19 +18,8 @@ class TaskEventsResult:
class LogEventsIterator:
EVENT_TYPE = "log"
@property
def state_expiration_sec(self):
return config.get(
f"services.events.events_retrieval.state_expiration_sec", 3600
)
def __init__(self, redis: StrictRedis, es: Elasticsearch):
def __init__(self, es: Elasticsearch):
self.es = es
self.cache_manager = RedisCacheManager(
state_class=LogEventsScrollState,
redis=redis,
expiration_interval=self.state_expiration_sec,
)
def get_task_events(
self,
@@ -56,48 +27,29 @@ class LogEventsIterator:
task_id: str,
batch_size: int,
navigate_earlier: bool = True,
refresh: bool = False,
state_id: str = None,
from_timestamp: Optional[int] = None,
) -> TaskEventsResult:
es_index = EventMetrics.get_index_name(company_id, self.EVENT_TYPE)
if not self.es.indices.exists(es_index):
return TaskEventsResult()
def init_state(state_: LogEventsScrollState):
state_.task = task_id
def validate_state(state_: LogEventsScrollState):
"""
Checks that the task id stored in the state
is equal to the one passed with the current call
Refresh the state if requested
"""
if state_.task != task_id:
raise errors.bad_request.InvalidScrollId(
"Task stored in the state does not match the passed one",
scroll_id=state_.id,
)
if refresh:
state_.reset()
with self.cache_manager.get_or_create_state(
state_id=state_id, init_state=init_state, validate_state=validate_state,
) as state:
res = TaskEventsResult(next_scroll_id=state.id)
res.events, res.total_events = self._get_events(
es_index=es_index,
batch_size=batch_size,
navigate_earlier=navigate_earlier,
state=state,
)
return res
res = TaskEventsResult()
res.events, res.total_events = self._get_events(
es_index=es_index,
task_id=task_id,
batch_size=batch_size,
navigate_earlier=navigate_earlier,
from_timestamp=from_timestamp,
)
return res
def _get_events(
self,
es_index,
task_id: str,
batch_size: int,
navigate_earlier: bool,
state: LogEventsScrollState,
from_timestamp: Optional[int],
) -> Tuple[Sequence[dict], int]:
"""
Return up to 'batch size' events starting from the previous timestamp either in the
@@ -111,29 +63,21 @@ class LogEventsIterator:
# retrieve the next batch of events
es_req = {
"size": batch_size,
"query": {"term": {"task": state.task}},
"query": {"term": {"task": task_id}},
"sort": {"timestamp": "desc" if navigate_earlier else "asc"},
}
if navigate_earlier and state.last_min_timestamp is not None:
es_req["search_after"] = [state.last_min_timestamp]
elif not navigate_earlier and state.last_max_timestamp is not None:
es_req["search_after"] = [state.last_max_timestamp]
if from_timestamp:
es_req["search_after"] = [from_timestamp]
with translate_errors_context(), TimingContext("es", "get_task_events"):
es_result = self.es.search(index=es_index, body=es_req, routing=state.task)
es_result = self.es.search(index=es_index, body=es_req)
hits = es_result["hits"]["hits"]
hits_total = es_result["hits"]["total"]
hits_total = es_result["hits"]["total"]["value"]
if not hits:
return [], hits_total
events = [hit["_source"] for hit in hits]
if navigate_earlier:
state.last_max_timestamp = events[0]["timestamp"]
state.last_min_timestamp = events[-1]["timestamp"]
else:
state.last_min_timestamp = events[0]["timestamp"]
state.last_max_timestamp = events[-1]["timestamp"]
# retrieve the events that match the last event timestamp
# but did not make it into the previous call due to batch_size limitation
@@ -142,28 +86,29 @@ class LogEventsIterator:
"query": {
"bool": {
"must": [
{"term": {"task": state.task}},
{"term": {"task": task_id}},
{"term": {"timestamp": events[-1]["timestamp"]}},
]
}
},
}
es_result = self.es.search(index=es_index, body=es_req, routing=state.task)
hits = es_result["hits"]["hits"]
if not hits or len(hits) < 2:
es_result = self.es.search(index=es_index, body=es_req)
last_second_hits = es_result["hits"]["hits"]
if not last_second_hits or len(last_second_hits) < 2:
# if only one element is returned for the last timestamp
# then it is already present in the events
return events, hits_total
last_events = [hit["_source"] for hit in es_result["hits"]["hits"]]
already_present_ids = set(ev["_id"] for ev in events)
already_present_ids = set(hit["_id"] for hit in hits)
last_second_events = [
hit["_source"]
for hit in last_second_hits
if hit["_id"] not in already_present_ids
]
# return the list merged from original query results +
# leftovers from the last timestamp
return (
[
*events,
*(ev for ev in last_events if ev["_id"] not in already_present_ids),
],
[*events, *last_second_events],
hits_total,
)

View File

@@ -111,7 +111,7 @@ class TimestampKey(ScalarKey):
self.name: {
"date_histogram": {
"field": "timestamp",
"interval": f"{interval}ms",
"fixed_interval": f"{interval}ms",
"min_doc_count": 1,
}
}
@@ -150,7 +150,7 @@ class ISOTimeKey(ScalarKey):
self.name: {
"date_histogram": {
"field": "timestamp",
"interval": f"{interval}ms",
"fixed_interval": f"{interval}ms",
"min_doc_count": 1,
"format": "strict_date_time",
}

View File

@@ -0,0 +1,18 @@
from typing import Optional, Sequence
from mongoengine import Q
from database.model.model import Model
from database.utils import get_company_or_none_constraint
class ModelBLL:
def get_frameworks(self, company, project_ids: Optional[Sequence]) -> Sequence:
"""
Return the list of unique frameworks used by company and public models
If project ids passed then only models from these projects are considered
"""
query = get_company_or_none_constraint(company)
if project_ids:
query &= Q(project__in=project_ids)
return Model.objects(query).distinct(field="framework")

View File

@@ -18,7 +18,6 @@ log = config.logger(__file__)
class QueueMetrics:
class EsKeys:
DOC_TYPE = "metrics"
WAITING_TIME_FIELD = "average_waiting_time"
QUEUE_LENGTH_FIELD = "queue_length"
TIMESTAMP_FIELD = "timestamp"
@@ -66,7 +65,6 @@ class QueueMetrics:
entries = [e for e in queue.entries if e.added]
return dict(
_index=es_index,
_type=self.EsKeys.DOC_TYPE,
_source={
self.EsKeys.TIMESTAMP_FIELD: timestamp,
self.EsKeys.QUEUE_FIELD: queue.id,
@@ -93,7 +91,6 @@ class QueueMetrics:
def _search_company_metrics(self, company_id: str, es_req: dict) -> dict:
return self.es.search(
index=f"{self._queue_metrics_prefix_for_company(company_id)}*",
doc_type=self.EsKeys.DOC_TYPE,
body=es_req,
)
@@ -109,7 +106,7 @@ class QueueMetrics:
"dates": {
"date_histogram": {
"field": cls.EsKeys.TIMESTAMP_FIELD,
"interval": f"{interval}s",
"fixed_interval": f"{interval}s",
"min_doc_count": 1,
},
"aggs": {

View File

@@ -19,7 +19,7 @@ from config.info import get_deployment_type
from database.model import Company, User
from database.model.queue import Queue
from database.model.task.task import Task
from utilities import safe_get
from tools import safe_get
from utilities.json import dumps
from utilities.threads_manager import ThreadsManager
from version import __version__ as current_version
@@ -237,7 +237,6 @@ class StatisticsReporter:
def _run_worker_stats_query(cls, company_id, es_req) -> dict:
return worker_bll.es_client.search(
index=f"{WorkerStats.worker_stats_prefix_for_company(company_id)}*",
doc_type="stat",
body=es_req,
)

View File

@@ -4,5 +4,4 @@ from .utils import (
update_project_time,
validate_status_change,
split_by,
ParameterKeyEscaper,
)

View File

@@ -0,0 +1,229 @@
from datetime import datetime
from itertools import chain
from operator import attrgetter
from typing import Sequence, Dict
from boltons import iterutils
from apierrors import errors
from apimodels.tasks import (
HyperParamKey,
HyperParamItem,
ReplaceHyperparams,
Configuration,
)
from bll.task import TaskBLL
from config import config
from database.model.task.task import ParamsItem, Task, ConfigurationItem, TaskStatus
from utilities.parameter_key_escaper import ParameterKeyEscaper
log = config.logger(__file__)
task_bll = TaskBLL()
class HyperParams:
_properties_section = "properties"
@classmethod
def get_params(cls, company_id: str, task_ids: Sequence[str]) -> Dict[str, dict]:
only = ("id", "hyperparams")
tasks = task_bll.assert_exists(
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
)
return {
task.id: {"hyperparams": cls._get_params_list(items=task.hyperparams)}
for task in tasks
}
@classmethod
def _get_params_list(
cls, items: Dict[str, Dict[str, ParamsItem]]
) -> Sequence[dict]:
ret = list(chain.from_iterable(v.values() for v in items.values()))
return [
p.to_proper_dict() for p in sorted(ret, key=attrgetter("section", "name"))
]
@classmethod
def _normalize_params(cls, params: Sequence) -> bool:
"""
Lower case properties section and return True if it is the only section
"""
for p in params:
if p.section.lower() == cls._properties_section:
p.section = cls._properties_section
return all(p.section == cls._properties_section for p in params)
@classmethod
def delete_params(
cls, company_id: str, task_id: str, hyperparams=Sequence[HyperParamKey]
) -> int:
properties_only = cls._normalize_params(hyperparams)
task = cls._get_task_for_update(
company=company_id, id=task_id, allow_all_statuses=properties_only
)
with_param, without_param = iterutils.partition(
hyperparams, key=lambda p: bool(p.name)
)
sections_to_delete = {p.section for p in without_param}
delete_cmds = {
f"unset__hyperparams__{ParameterKeyEscaper.escape(section)}": 1
for section in sections_to_delete
}
for item in with_param:
section = ParameterKeyEscaper.escape(item.section)
if item.section in sections_to_delete:
raise errors.bad_request.FieldsConflict(
"Cannot delete section field if the whole section was scheduled for deletion"
)
name = ParameterKeyEscaper.escape(item.name)
delete_cmds[f"unset__hyperparams__{section}__{name}"] = 1
return task.update(**delete_cmds, last_update=datetime.utcnow())
@classmethod
def edit_params(
cls,
company_id: str,
task_id: str,
hyperparams: Sequence[HyperParamItem],
replace_hyperparams: str,
) -> int:
properties_only = cls._normalize_params(hyperparams)
task = cls._get_task_for_update(
company=company_id, id=task_id, allow_all_statuses=properties_only
)
update_cmds = dict()
hyperparams = cls._db_dicts_from_list(hyperparams)
if replace_hyperparams == ReplaceHyperparams.all:
update_cmds["set__hyperparams"] = hyperparams
elif replace_hyperparams == ReplaceHyperparams.section:
for section, value in hyperparams.items():
update_cmds[f"set__hyperparams__{section}"] = value
else:
for section, section_params in hyperparams.items():
for name, value in section_params.items():
update_cmds[f"set__hyperparams__{section}__{name}"] = value
return task.update(**update_cmds, last_update=datetime.utcnow())
@classmethod
def _db_dicts_from_list(cls, items: Sequence[HyperParamItem]) -> Dict[str, dict]:
sections = iterutils.bucketize(items, key=attrgetter("section"))
return {
ParameterKeyEscaper.escape(section): {
ParameterKeyEscaper.escape(param.name): ParamsItem(**param.to_struct())
for param in params
}
for section, params in sections.items()
}
@classmethod
def get_configurations(
cls, company_id: str, task_ids: Sequence[str], names: Sequence[str]
) -> Dict[str, dict]:
only = ["id"]
if names:
only.extend(
f"configuration.{ParameterKeyEscaper.escape(name)}" for name in names
)
else:
only.append("configuration")
tasks = task_bll.assert_exists(
company_id=company_id, task_ids=task_ids, only=only, allow_public=True,
)
return {
task.id: {
"configuration": [
c.to_proper_dict()
for c in sorted(task.configuration.values(), key=attrgetter("name"))
]
}
for task in tasks
}
@classmethod
def get_configuration_names(
cls, company_id: str, task_ids: Sequence[str]
) -> Dict[str, list]:
pipeline = [
{
"$match": {
"company": {"$in": [None, "", company_id]},
"_id": {"$in": task_ids},
}
},
{"$project": {"items": {"$objectToArray": "$configuration"}}},
{"$unwind": "$items"},
{"$group": {"_id": "$_id", "names": {"$addToSet": "$items.k"}}},
]
tasks = Task.aggregate(pipeline)
return {
task["_id"]: {
"names": sorted(
ParameterKeyEscaper.unescape(name) for name in task["names"]
)
}
for task in tasks
}
@classmethod
def edit_configuration(
cls,
company_id: str,
task_id: str,
configuration: Sequence[Configuration],
replace_configuration: bool,
) -> int:
task = cls._get_task_for_update(company=company_id, id=task_id)
update_cmds = dict()
configuration = {
ParameterKeyEscaper.escape(c.name): ConfigurationItem(**c.to_struct())
for c in configuration
}
if replace_configuration:
update_cmds["set__configuration"] = configuration
else:
for name, value in configuration.items():
update_cmds[f"set__configuration__{name}"] = value
return task.update(**update_cmds, last_update=datetime.utcnow())
@classmethod
def delete_configuration(
cls, company_id: str, task_id: str, configuration=Sequence[str]
) -> int:
task = cls._get_task_for_update(company=company_id, id=task_id)
delete_cmds = {
f"unset__configuration__{ParameterKeyEscaper.escape(name)}": 1
for name in set(configuration)
}
return task.update(**delete_cmds, last_update=datetime.utcnow())
@staticmethod
def _get_task_for_update(
company: str, id: str, allow_all_statuses: bool = False
) -> Task:
task = Task.get_for_writing(company=company, id=id, _only=("id", "status"))
if not task:
raise errors.bad_request.InvalidTaskId(id=id)
if allow_all_statuses:
return task
if task.status != TaskStatus.created:
raise errors.bad_request.InvalidTaskStatus(
expected=TaskStatus.created, status=task.status
)
return task

View File

@@ -0,0 +1,201 @@
import itertools
from typing import Sequence, Tuple
import dpath
from apierrors import errors
from database.model.task.task import Task
from tools import safe_get
from utilities.parameter_key_escaper import ParameterKeyEscaper
hyperparams_default_section = "Args"
hyperparams_legacy_type = "legacy"
tf_define_section = "TF_DEFINE"
def split_param_name(full_name: str, default_section: str) -> Tuple[str, str]:
"""
Return parameter section and name. The section is either TF_DEFINE or the default one
"""
if default_section is None:
return None, full_name
section, _, name = full_name.partition("/")
if section != tf_define_section:
return default_section, full_name
if not name:
raise errors.bad_request.ValidationError("Parameter name cannot be empty")
return section, name
def _get_full_param_name(param: dict) -> str:
section = param.get("section")
if section != tf_define_section:
return param["name"]
return "/".join((section, param["name"]))
def _remove_legacy_params(data: dict, with_sections: bool = False) -> int:
"""
Remove the legacy params from the data dict and return the number of removed params
If the path not found then return 0
"""
removed = 0
if not data:
return removed
if with_sections:
for section, section_data in list(data.items()):
removed += _remove_legacy_params(section_data)
if not section_data:
"""If section is empty after removing legacy params then delete it"""
del data[section]
else:
for key, param in list(data.items()):
if param.get("type") == hyperparams_legacy_type:
removed += 1
del data[key]
return removed
def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[str]:
"""
Remove the legacy params from the data dict and return the number of removed params
If the path not found then return 0
"""
if not data:
return []
if with_sections:
return itertools.chain.from_iterable(
_get_legacy_params(section_data) for section_data in data.values()
)
return [
param for param in data.values() if param.get("type") == hyperparams_legacy_type
]
def params_prepare_for_save(fields: dict, previous_task: Task = None):
"""
If legacy hyper params or configuration is passed then replace the corresponding section in the new structure
Escape all the section and param names for hyper params and configuration to make it mongo sage
"""
for old_params_field, new_params_field, default_section in (
("execution/parameters", "hyperparams", hyperparams_default_section),
("execution/model_desc", "configuration", None),
):
legacy_params = safe_get(fields, old_params_field)
if legacy_params is None:
continue
if (
not safe_get(fields, new_params_field)
and previous_task
and previous_task[new_params_field]
):
previous_data = previous_task.to_proper_dict().get(new_params_field)
removed = _remove_legacy_params(
previous_data, with_sections=default_section is not None
)
if not legacy_params and not removed:
# if we only need to delete legacy fields from the db
# but they are not there then there is no point to proceed
continue
fields_update = {new_params_field: previous_data}
params_unprepare_from_saved(fields_update)
fields.update(fields_update)
for full_name, value in legacy_params.items():
section, name = split_param_name(full_name, default_section)
new_path = list(filter(None, (new_params_field, section, name)))
new_param = dict(name=name, type=hyperparams_legacy_type, value=str(value))
if section is not None:
new_param["section"] = section
dpath.new(fields, new_path, new_param)
dpath.delete(fields, old_params_field)
for param_field in ("hyperparams", "configuration"):
params = safe_get(fields, param_field)
if params:
escaped_params = {
ParameterKeyEscaper.escape(key): {
ParameterKeyEscaper.escape(k): v for k, v in value.items()
}
if isinstance(value, dict)
else value
for key, value in params.items()
}
dpath.set(fields, param_field, escaped_params)
def params_unprepare_from_saved(fields, copy_to_legacy=False):
"""
Unescape all section and param names for hyper params and configuration
If copy_to_legacy is set then copy hyperparams and configuration data to the legacy location for the old clients
"""
for param_field in ("hyperparams", "configuration"):
params = safe_get(fields, param_field)
if params:
unescaped_params = {
ParameterKeyEscaper.unescape(key): {
ParameterKeyEscaper.unescape(k): v for k, v in value.items()
}
if isinstance(value, dict)
else value
for key, value in params.items()
}
dpath.set(fields, param_field, unescaped_params)
if copy_to_legacy:
for new_params_field, old_params_field, use_sections in (
(f"hyperparams", "execution/parameters", True),
(f"configuration", "execution/model_desc", False),
):
legacy_params = _get_legacy_params(
safe_get(fields, new_params_field), with_sections=use_sections
)
if legacy_params:
dpath.new(
fields,
old_params_field,
{_get_full_param_name(p): p["value"] for p in legacy_params},
)
def _process_path(path: str):
"""
Frontend does a partial escaping on the path so the all '.' in section and key names are escaped
Need to unescape and apply a full mongo escaping
"""
parts = path.split(".")
if len(parts) < 2 or len(parts) > 3:
raise errors.bad_request.ValidationError("invalid task field", path=path)
return ".".join(
ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
)
def escape_paths(paths: Sequence[str]) -> Sequence[str]:
for old_prefix, new_prefix in (
("execution.parameters", f"hyperparams.{hyperparams_default_section}"),
("execution.model_desc", f"configuration"),
):
path: str
paths = [path.replace(old_prefix, new_prefix) for path in paths]
for prefix in (
"hyperparams.",
"-hyperparams.",
"configuration.",
"-configuration.",
):
paths = [
_process_path(path) if path.startswith(prefix) else path for path in paths
]
return paths

View File

@@ -5,6 +5,7 @@ from random import random
from time import sleep
from typing import Collection, Sequence, Tuple, Any, Optional, List, Dict
import dpath
import pymongo.results
import six
from mongoengine import Q
@@ -32,10 +33,11 @@ from database.model.task.task import (
)
from database.utils import get_company_or_none_constraint, id as create_id
from service_repo import APICall
from services.utils import validate_tags
from timing_context import TimingContext
from utilities.dicts import deep_merge
from .utils import ChangeStatusRequest, validate_status_change, ParameterKeyEscaper
from utilities.parameter_key_escaper import ParameterKeyEscaper
from .param_utils import params_prepare_for_save
from .utils import ChangeStatusRequest, validate_status_change
log = config.logger(__file__)
org_bll = OrgBLL()
@@ -83,25 +85,24 @@ class TaskBLL(object):
@staticmethod
def get_by_id(
company_id,
task_id,
required_status=None,
required_dataset=None,
only_fields=None,
company_id, task_id, required_status=None, only_fields=None, allow_public=False,
):
if only_fields:
if isinstance(only_fields, string_types):
only_fields = [only_fields]
else:
only_fields = list(only_fields)
only_fields = only_fields + ["status"]
with TimingContext("mongo", "task_by_id_all"):
qs = Task.objects(id=task_id, company=company_id)
if only_fields:
qs = (
qs.only(only_fields)
if isinstance(only_fields, string_types)
else qs.only(*only_fields)
)
qs = qs.only(
"status", "input"
) # make sure all fields we rely on here are also returned
task = qs.first()
tasks = Task.get_many(
company=company_id,
query=Q(id=task_id),
allow_public=allow_public,
override_projection=only_fields,
return_dicts=False,
)
task = None if not tasks else tasks[0]
if not task:
raise errors.bad_request.InvalidTaskId(id=task_id)
@@ -109,17 +110,12 @@ class TaskBLL(object):
if required_status and not task.status == required_status:
raise errors.bad_request.InvalidTaskStatus(expected=required_status)
if required_dataset and required_dataset not in (
entry.dataset for entry in task.input.view.entries
):
raise errors.bad_request.InvalidId(
"not in input view", dataset=required_dataset
)
return task
@staticmethod
def assert_exists(company_id, task_ids, only=None, allow_public=False):
def assert_exists(
company_id, task_ids, only=None, allow_public=False, return_tasks=True
) -> Optional[Sequence[Task]]:
task_ids = [task_ids] if isinstance(task_ids, six.string_types) else task_ids
with translate_errors_context(), TimingContext("mongo", "task_exists"):
ids = set(task_ids)
@@ -130,14 +126,13 @@ class TaskBLL(object):
return_dicts=False,
)
if only:
res = q.only(*only)
count = len(res)
else:
count = q.count()
res = q.first()
if count != len(ids):
q = q.only(*only)
if q.count() != len(ids):
raise errors.bad_request.InvalidTaskId(ids=task_ids)
return res
if return_tasks:
return list(q)
@staticmethod
def create(call: APICall, fields: dict):
@@ -179,21 +174,31 @@ class TaskBLL(object):
project: Optional[str] = None,
tags: Optional[Sequence[str]] = None,
system_tags: Optional[Sequence[str]] = None,
hyperparams: Optional[dict] = None,
configuration: Optional[dict] = None,
execution_overrides: Optional[dict] = None,
validate_references: bool = False,
) -> Task:
validate_tags(tags, system_tags)
task = cls.get_by_id(company_id=company_id, task_id=task_id)
task = cls.get_by_id(company_id=company_id, task_id=task_id, allow_public=True)
execution_dict = task.execution.to_proper_dict() if task.execution else {}
execution_model_overriden = False
params_dict = {
field: value
for field, value in (
("hyperparams", hyperparams),
("configuration", configuration),
)
if value is not None
}
if execution_overrides:
parameters = execution_overrides.get("parameters")
if parameters is not None:
execution_overrides["parameters"] = {
ParameterKeyEscaper.escape(k): v for k, v in parameters.items()
}
params_dict["execution"] = {}
for legacy_param in ("parameters", "configuration"):
legacy_value = execution_overrides.pop(legacy_param, None)
if legacy_value is not None:
params_dict["execution"] = legacy_value
execution_dict = deep_merge(execution_dict, execution_overrides)
execution_model_overriden = execution_overrides.get("model") is not None
params_prepare_for_save(params_dict, previous_task=task)
artifacts = execution_dict.get("artifacts")
if artifacts:
@@ -221,6 +226,8 @@ class TaskBLL(object):
if task.output
else None,
execution=execution_dict,
configuration=params_dict.get("configuration") or task.configuration,
hyperparams=params_dict.get("hyperparams") or task.hyperparams,
)
cls.validate(
new_task,
@@ -626,28 +633,34 @@ class TaskBLL(object):
return [a.key for a in added], [a.key for a in updated]
@staticmethod
def get_aggregated_project_execution_parameters(
def get_aggregated_project_parameters(
company_id,
project_ids: Sequence[str] = None,
page: int = 0,
page_size: int = 500,
) -> Tuple[int, int, Sequence[str]]:
) -> Tuple[int, int, Sequence[dict]]:
page = max(0, page)
page_size = max(1, page_size)
pipeline = [
{
"$match": {
"company": company_id,
"execution.parameters": {"$exists": True, "$gt": {}},
"hyperparams": {"$exists": True, "$gt": {}},
**({"project": {"$in": project_ids}} if project_ids else {}),
}
},
{"$project": {"parameters": {"$objectToArray": "$execution.parameters"}}},
{"$unwind": "$parameters"},
{"$group": {"_id": "$parameters.k"}},
{"$sort": {"_id": 1}},
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
{"$unwind": "$sections"},
{
"$project": {
"section": "$sections.k",
"names": {"$objectToArray": "$sections.v"},
}
},
{"$unwind": "$names"},
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
{
"$group": {
"_id": 1,
@@ -673,7 +686,12 @@ class TaskBLL(object):
if result:
total = int(result.get("total", -1))
results = [
ParameterKeyEscaper.unescape(r["_id"])
{
"section": ParameterKeyEscaper.unescape(
dpath.get(r, "_id/section")
),
"name": ParameterKeyEscaper.unescape(dpath.get(r, "_id/name")),
}
for r in result.get("results", [])
]
remaining = max(0, total - (len(results) + page * page_size))

View File

@@ -3,7 +3,6 @@ from typing import TypeVar, Callable, Tuple, Sequence
import attr
import six
from boltons.dictutils import OneToOne
from apierrors import errors
from database.errors import translate_errors_context
@@ -172,26 +171,3 @@ def split_by(
[item for cond, item in applied if cond],
[item for cond, item in applied if not cond],
)
class ParameterKeyEscaper:
_mapping = OneToOne({".": "%2E", "$": "%24"})
@classmethod
def escape(cls, value):
""" Quote a parameter key """
value = value.strip().replace("%", "%%")
for c, r in cls._mapping.items():
value = value.replace(c, r)
return value
@classmethod
def _unescape(cls, value):
for c, r in cls._mapping.inv.items():
value = value.replace(c, r)
return value
@classmethod
def unescape(cls, value):
""" Unquote a quoted parameter key """
return "%".join(map(cls._unescape, value.split("%%")))

View File

@@ -35,14 +35,21 @@ class SetFieldsResolver:
SET_MODIFIERS = ("min", "max")
def __init__(self, set_fields: Dict[str, Any]):
self.orig_fields = set_fields
self.fields = {
f: fname
for f, modifier, dunder, fname in (
(f,) + f.partition("__") for f in set_fields.keys()
)
if dunder and modifier in self.SET_MODIFIERS
}
self.orig_fields = {}
self.fields = {}
self.add_fields(**set_fields)
def add_fields(self, **set_fields: Any):
self.orig_fields.update(set_fields)
self.fields.update(
{
f: fname
for f, modifier, dunder, fname in (
(f,) + f.partition("__") for f in set_fields.keys()
)
if dunder and modifier in self.SET_MODIFIERS
}
)
def _get_updated_name(self, doc: AttributedDocument, name: str) -> str:
if name in self.fields and doc.get_field_value(self.fields[name]) is None:

View File

@@ -21,6 +21,7 @@ from config import config
from database.errors import translate_errors_context
from database.model.auth import User
from database.model.company import Company
from database.model.project import Project
from database.model.queue import Queue
from database.model.task.task import Task
from redis_manager import redman
@@ -49,6 +50,7 @@ class WorkerBLL:
ip: str = "",
queues: Sequence[str] = None,
timeout: int = 0,
tags: Sequence[str] = None,
) -> WorkerEntry:
"""
Register a worker
@@ -58,6 +60,7 @@ class WorkerBLL:
:param ip: the real ip of the worker
:param queues: queues reported as being monitored by the worker
:param timeout: registration expiration timeout in seconds
:param tags: a list of tags for this worker
:raise bad_request.InvalidUserId: in case the calling user or company does not exist
:return: worker entry instance
"""
@@ -91,6 +94,7 @@ class WorkerBLL:
register_time=now,
register_timeout=timeout,
last_activity_time=now,
tags=tags,
)
self.redis.setex(key, timedelta(seconds=timeout), entry.to_json())
@@ -113,12 +117,15 @@ class WorkerBLL:
raise bad_request.WorkerNotRegistered(worker=worker)
def status_report(
self, company_id: str, user_id: str, ip: str, report: StatusReportRequest
self, company_id: str, user_id: str, ip: str, report: StatusReportRequest, tags: Sequence[str] = None,
) -> None:
"""
Write worker status report
:param company_id: worker's company ID
:param user_id: user_id ID under which this worker is running
:param ip: worker IP
:param report: the report itself
:param tags: tags for this worker
:raise bad_request.InvalidTaskId: the reported task was not found
:return: worker entry instance
"""
@@ -129,6 +136,9 @@ class WorkerBLL:
now = datetime.utcnow()
entry.last_activity_time = now
if tags is not None:
entry.tags = tags
if report.machine_stats:
self._log_stats_to_es(
company_id=company_id,
@@ -146,6 +156,7 @@ class WorkerBLL:
if not report.task:
entry.task = None
entry.project = None
else:
with translate_errors_context():
query = dict(id=report.task, company=company_id)
@@ -160,6 +171,12 @@ class WorkerBLL:
raise bad_request.InvalidTaskId(**query)
entry.task = IdNameEntry(id=task.id, name=task.name)
entry.project = None
if task.project:
project = Project.objects(id=task.project).only("name").first()
if project:
entry.project = IdNameEntry(id=project.id, name=project.name)
entry.last_report_time = now
except APIError:
raise
@@ -369,7 +386,6 @@ class WorkerBLL:
def make_doc(category, metric, variant, value) -> dict:
return dict(
_index=es_index,
_type="stat",
_source=dict(
timestamp=timestamp,
worker=worker,

View File

@@ -25,7 +25,6 @@ class WorkerStats:
def _search_company_stats(self, company_id: str, es_req: dict) -> dict:
return self.es.search(
index=f"{self.worker_stats_prefix_for_company(company_id)}*",
doc_type="stat",
body=es_req,
)
@@ -53,7 +52,7 @@ class WorkerStats:
res = self._search_company_stats(company_id, es_req)
if not res["hits"]["total"]:
if not res["hits"]["total"]["value"]:
raise bad_request.WorkerStatsNotFound(
f"No statistic metrics found for the company {company_id} and workers {worker_ids}"
)
@@ -87,7 +86,7 @@ class WorkerStats:
"dates": {
"date_histogram": {
"field": "timestamp",
"interval": f"{request.interval}s",
"fixed_interval": f"{request.interval}s",
"min_doc_count": 1,
},
"aggs": {
@@ -216,7 +215,7 @@ class WorkerStats:
"dates": {
"date_histogram": {
"field": "timestamp",
"interval": f"{interval}s",
"fixed_interval": f"{interval}s",
},
"aggs": {"workers_count": {"cardinality": {"field": "worker"}}},
}

View File

@@ -26,6 +26,17 @@
check_max_version: false
}
pre_populate {
enabled: false
zip_files: ["/path/to/export.zip"]
fail_on_error: false
# artifacts_path: "/mnt/fileserver"
}
# time in seconds to take an exclusive lock to init es and mongodb
# not including the pre_populate
db_init_timout: 120
mongo {
# controls whether FieldDoesNotExist exception will be raised for any extra attribute existing in stored data
# but not declared in a data model
@@ -34,11 +45,16 @@
aggregate {
allow_disk_use: true
}
}
pre_populate {
enabled: false
zip_file: "/path/to/export.zip"
fail_on_error: false
elastic {
probing {
# settings for inital probing of elastic connection
max_retries: 4
timeout: 30
}
upgrade_monitoring {
v16_migration_verification: true
}
}

View File

@@ -1,21 +1,21 @@
elastic {
events {
hosts: [{host: "127.0.0.1", port: 9200}]
hosts: [{host: "127.0.0.1", port: 9211}]
args {
timeout: 60
dead_timeout: 10
max_retries: 5
max_retries: 3
retry_on_timeout: true
}
index_version: "1"
}
workers {
hosts: [{host:"127.0.0.1", port:9200}]
hosts: [{host:"127.0.0.1", port:9211}]
args {
timeout: 60
dead_timeout: 10
max_retries: 5
max_retries: 3
retry_on_timeout: true
}
index_version: "1"

View File

@@ -0,0 +1,16 @@
fixed_users {
guest {
enabled: false
default_company: "025315a9321f49f8be07f5ac48fbcf92"
name: "Guest"
username: "guest"
password: "guest"
# Allow access only to the following endpoints when using user/pass credentials
allow_endpoints: [
"auth.login"
]
}
}

View File

@@ -0,0 +1,8 @@
# Order of featured projects, by name or ID
featured_order: [
# {id: "<project-id>"}
# OR
# {name: "<project-name>"}
# OR
# {name_regex: "<python-regex>"}
]

View File

@@ -11,4 +11,6 @@ non_responsive_tasks_watchdog {
artifacts {
update_attempts: 10
update_retry_msec: 500
}
}
multi_task_histogram_limit: 100

View File

@@ -41,3 +41,7 @@ def get_deployment_type() -> str:
def get_default_company():
return config.get("apiserver.default_company")
missed_es_upgrade = False
es_connection_error = False

View File

@@ -79,6 +79,10 @@ def get_entries():
return _entries
def get_hosts():
return [entry.host for entry in get_entries()]
def get_aliases():
return [entry.alias for entry in get_entries()]

View File

@@ -32,6 +32,8 @@ class Role(object):
""" Company user """
annotator = "annotator"
""" Annotator with limited access"""
guest = "guest"
""" Guest user. Read Only."""
@classmethod
def get_system_roles(cls) -> set:

View File

@@ -1,14 +1,15 @@
import re
from collections import namedtuple
from functools import reduce
from typing import Collection, Sequence, Union, Optional
from typing import Collection, Sequence, Union, Optional, Type, Tuple
from boltons.iterutils import first, bucketize
from boltons.iterutils import first, bucketize, partition
from dateutil.parser import parse as parse_datetime
from mongoengine import Q, Document, ListField, StringField
from pymongo.command_cursor import CommandCursor
from apierrors import errors
from apierrors.base import BaseError
from config import config
from database.errors import MakeGetAllQueryError
from database.projection import project_dict, ProjectionHelper
@@ -347,6 +348,20 @@ class GetMixin(PropsMixin):
return []
return parameters.get(cls._projection_key) or parameters.get("only_fields", [])
@classmethod
def split_projection(
cls, projection: Sequence[str]
) -> Tuple[Collection[str], Collection[str]]:
"""Return include and exclude lists based on passed projection and class definition"""
if projection:
include, exclude = partition(
projection, key=lambda x: x[0] != ProjectionHelper.exclusion_prefix,
)
else:
include, exclude = [], []
exclude = {x.lstrip(ProjectionHelper.exclusion_prefix) for x in exclude}
return include, set(cls.get_exclude_fields()).union(exclude).difference(include)
@classmethod
def set_projection(cls, parameters: dict, value: Sequence[str]) -> Sequence[str]:
parameters.pop("only_fields", None)
@@ -483,10 +498,25 @@ class GetMixin(PropsMixin):
query=_query, parameters=parameters, override_projection=override_projection
)
@classmethod
def get_many_public(
cls, query: Q = None, projection: Collection[str] = None,
):
"""
Fetch all public documents matching a provided query.
:param query: Optional query object (mongoengine.Q).
:param projection: A list of projection fields.
:return: A list of documents matching the query.
"""
q = get_company_or_none_constraint()
_query = (q & query) if query else q
return cls._get_many_no_company(query=_query, override_projection=projection)
@classmethod
def _get_many_no_company(
cls: Union["GetMixin", Document],
query,
query: Q,
parameters=None,
override_projection=None,
):
@@ -509,7 +539,9 @@ class GetMixin(PropsMixin):
search_text = parameters.get(cls._search_text_key)
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
page, page_size = cls.validate_paging(parameters=parameters)
only = cls.get_projection(parameters, override_projection)
include, exclude = cls.split_projection(
cls.get_projection(parameters, override_projection)
)
qs = cls.objects(query)
if search_text:
@@ -517,13 +549,14 @@ class GetMixin(PropsMixin):
if order_by:
# add ordering
qs = qs.order_by(*order_by)
if only:
if include:
# add projection
qs = qs.only(*only)
else:
exclude = set(cls.get_exclude_fields()).difference(only)
if exclude:
qs = qs.exclude(*exclude)
qs = qs.only(*include)
if exclude:
qs = qs.exclude(*exclude)
if page is not None and page_size:
# add paging
qs = qs.skip(page * page_size).limit(page_size)
@@ -559,7 +592,9 @@ class GetMixin(PropsMixin):
search_text = parameters.get(cls._search_text_key)
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
page, page_size = cls.validate_paging(parameters=parameters)
only = cls.get_projection(parameters, override_projection)
include, exclude = cls.split_projection(
cls.get_projection(parameters, override_projection)
)
query_sets = [cls.objects(query)]
if order_by:
@@ -596,16 +631,15 @@ class GetMixin(PropsMixin):
if search_text:
query_sets = [qs.search_text(search_text) for qs in query_sets]
if only:
if include:
# add projection
query_sets = [qs.only(*only) for qs in query_sets]
else:
exclude = set(cls.get_exclude_fields())
if exclude:
query_sets = [qs.exclude(*exclude) for qs in query_sets]
query_sets = [qs.only(*include) for qs in query_sets]
if exclude:
query_sets = [qs.exclude(*exclude) for qs in query_sets]
if page is None or not page_size:
return [obj.to_proper_dict(only=only) for qs in query_sets for obj in qs]
return [obj.to_proper_dict(only=include) for qs in query_sets for obj in qs]
# add paging
ret = []
@@ -616,7 +650,8 @@ class GetMixin(PropsMixin):
start -= qs_size
continue
ret.extend(
obj.to_proper_dict(only=only) for obj in qs.skip(start).limit(page_size)
obj.to_proper_dict(only=include)
for obj in qs.skip(start).limit(page_size)
)
if len(ret) >= page_size:
break
@@ -728,6 +763,31 @@ class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
)
return cls.objects.aggregate(pipeline, **kwargs)
@classmethod
def set_public(
cls: Type[Document],
company_id: str,
ids: Sequence[str],
invalid_cls: Type[BaseError],
enabled: bool = True,
):
if enabled:
items = list(cls.objects(id__in=ids, company=company_id).only("id"))
update = dict(set__company_origin=company_id, unset__company=1)
else:
items = list(
cls.objects(
id__in=ids, company__in=(None, ""), company_origin=company_id
).only("id")
)
update = dict(set__company=company_id, unset__company_origin=1)
if len(items) < len(ids):
missing = tuple(set(ids).difference(i.id for i in items))
raise invalid_cls(ids=missing)
return {"updated": cls.objects(id__in=ids).update(**update)}
def validate_id(cls, company, **kwargs):
"""

View File

@@ -19,6 +19,7 @@ class Model(DbModelMixin, Document):
"parent",
"project",
"task",
("company", "framework"),
("company", "name"),
("company", "user"),
{
@@ -71,3 +72,4 @@ class Model(DbModelMixin, Document):
ui_cache = SafeDictField(
default=dict, user_set_allowed=True, exclude_by_default=True
)
company_origin = StringField(exclude_by_default=True)

View File

@@ -1,4 +1,4 @@
from mongoengine import StringField, DateTimeField
from mongoengine import StringField, DateTimeField, IntField
from database import Database, strict
from database.fields import StrippedStringField, SafeSortedListField
@@ -40,3 +40,7 @@ class Project(AttributedDocument):
system_tags = SafeSortedListField(StringField(required=True))
default_output_destination = StrippedStringField()
last_update = DateTimeField()
featured = IntField(default=9999)
logo_url = StringField()
logo_blob = StringField(exclude_by_default=True)
company_origin = StringField(exclude_by_default=True)

View File

@@ -49,13 +49,13 @@ class TaskSystemTags(object):
development = "development"
class Script(EmbeddedDocument):
class Script(EmbeddedDocument, ProperDictMixin):
binary = StringField(default="python")
repository = StringField(required=True)
repository = StringField(default="")
tag = StringField()
branch = StringField()
version_num = StringField()
entry_point = StringField(required=True)
entry_point = StringField(default="")
working_dir = StringField()
requirements = SafeDictField()
diff = StringField()
@@ -84,7 +84,23 @@ class Artifact(EmbeddedDocument):
display_data = SafeSortedListField(ListField(UnionField((int, float, str))))
class ParamsItem(EmbeddedDocument, ProperDictMixin):
section = StringField(required=True)
name = StringField(required=True)
value = StringField(required=True)
type = StringField()
description = StringField()
class ConfigurationItem(EmbeddedDocument, ProperDictMixin):
name = StringField(required=True)
value = StringField(required=True)
type = StringField()
description = StringField()
class Execution(EmbeddedDocument, ProperDictMixin):
meta = {"strict": strict}
test_split = IntField(default=0)
parameters = SafeDictField(default=dict)
model = StringField(reference_field="Model")
@@ -115,9 +131,12 @@ external_task_types = set(get_options(TaskType))
class Task(AttributedDocument):
_numeric_locale = {"locale": "en_US", "numericOrdering": True}
_field_collation_overrides = {
"execution.parameters.": {"locale": "en_US", "numericOrdering": True},
"last_metrics.": {"locale": "en_US", "numericOrdering": True}
"execution.parameters.": _numeric_locale,
"last_metrics.": _numeric_locale,
"hyperparams.": _numeric_locale,
"configuration.": _numeric_locale,
}
meta = {
@@ -186,10 +205,15 @@ class Task(AttributedDocument):
execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
system_tags = SafeSortedListField(StringField(required=True), user_set_allowed=True)
script: Script = EmbeddedDocumentField(Script)
script: Script = EmbeddedDocumentField(Script, default=Script)
last_worker = StringField()
last_worker_report = DateTimeField()
last_update = DateTimeField()
last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
metric_stats = SafeMapField(field=EmbeddedDocumentField(MetricEventStats))
company_origin = StringField(exclude_by_default=True)
duration = IntField() # task duration in seconds
hyperparams = SafeMapField(field=SafeMapField(EmbeddedDocumentField(ParamsItem)))
configuration = SafeMapField(field=EmbeddedDocumentField(ConfigurationItem))
runtime = SafeDictField(default=dict)

View File

@@ -45,7 +45,7 @@ def project_dict(data, projection, separator=SEP):
)
dst[path_part] = [
copy_path(path_parts[depth + 1:], s, d)
copy_path(path_parts[depth + 1 :], s, d)
for s, d in zip(src_part, dst[path_part])
]
@@ -96,6 +96,7 @@ class _ProxyManager:
class ProjectionHelper(object):
pool = ThreadPoolExecutor()
exclusion_prefix = "-"
@property
def doc_projection(self):
@@ -128,20 +129,28 @@ class ProjectionHelper(object):
[]
) # Projection information for reference fields (used in join queries)
for field in projection:
field_ = field.lstrip(self.exclusion_prefix)
for ref_field, ref_field_cls in doc_cls.get_reference_fields().items():
if not field.startswith(ref_field):
if not field_.startswith(ref_field):
# Doesn't start with a reference field
continue
if field == ref_field:
if field_ == ref_field:
# Field is exactly a reference field. In this case we won't perform any inner projection (for that,
# use '<reference field name>.*')
continue
subfield = field[len(ref_field):]
subfield = field_[len(ref_field) :]
if not subfield.startswith(SEP):
# Starts with something that looks like a reference field, but isn't
continue
ref_projection_info.append((ref_field, ref_field_cls, subfield[1:]))
ref_projection_info.append(
(
ref_field,
ref_field_cls,
("" if field_[0] == field[0] else self.exclusion_prefix)
+ subfield[1:],
)
)
break
else:
# Not a reference field, just add to the top-level projection
@@ -149,7 +158,7 @@ class ProjectionHelper(object):
orig_field = field
if field.endswith(".*"):
field = field[:-2]
if not field:
if not field.lstrip(self.exclusion_prefix):
raise errors.bad_request.InvalidFields(
field=orig_field, object=doc_cls.__name__
)
@@ -199,7 +208,7 @@ class ProjectionHelper(object):
# Make sure this doesn't contain any reference field we'll join anyway
# (i.e. in case only_fields=[project, project.name])
doc_projection = normalize_cls_projection(
doc_cls, doc_projection.difference(ref_projection).union({"id"})
doc_cls, doc_projection.difference(ref_projection)
)
# Make sure that in case one or more field is a subfield of another field, we only use the the top-level field.
@@ -218,7 +227,10 @@ class ProjectionHelper(object):
# Make sure we didn't get any invalid projection fields for this class
invalid_fields = [
f for f in doc_projection if f.split(SEP)[0] not in doc_cls.get_fields()
f
for f in doc_projection
if f.partition(SEP)[0].lstrip(self.exclusion_prefix)
not in doc_cls.get_fields()
]
if invalid_fields:
raise errors.bad_request.InvalidFields(
@@ -234,6 +246,13 @@ class ProjectionHelper(object):
doc_projection.add(field)
doc_projection = list(doc_projection)
# If there are include fields (not only exclude) then add an id field
if (
not all(p.startswith(self.exclusion_prefix) for p in doc_projection)
and "id" not in doc_projection
):
doc_projection.append("id")
self._doc_projection = doc_projection
self._ref_projection = ref_projection
@@ -314,6 +333,7 @@ class ProjectionHelper(object):
]
if items:
def do_projection(item):
ref_field_name, data, ids = item

View File

@@ -4,53 +4,54 @@ Apply elasticsearch mappings to given hosts.
"""
import argparse
import json
import requests
from pathlib import Path
from typing import Optional, Sequence
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
from elasticsearch import Elasticsearch
HERE = Path(__file__).resolve().parent
session = requests.Session()
adapter = HTTPAdapter(max_retries=Retry(5, backoff_factor=0.5))
session.mount('http://', adapter)
def apply_mappings_to_cluster(
hosts: Sequence, key: Optional[str] = None, es_args: dict = None
):
"""Hosts maybe a sequence of strings or dicts in the form {"host": <host>, "port": <port>}"""
def apply_mappings_to_host(host: str):
def _send_mapping(f):
def _send_template(f):
with f.open() as json_data:
data = json.load(json_data)
es_server = host
url = f"{es_server}/_template/{f.stem}"
session.delete(url)
r = session.post(
url,
headers={"Content-Type": "application/json"},
data=json.dumps(data),
)
return {"mapping": f.stem, "result": r.text}
template_name = f.stem
res = es.indices.put_template(template_name, body=data)
return {"mapping": template_name, "result": res}
p = HERE / "mappings"
return [
_send_mapping(f) for f in p.iterdir() if f.is_file() and f.suffix == ".json"
]
if key:
files = (p / key).glob("*.json")
else:
files = p.glob("**/*.json")
es = Elasticsearch(hosts=hosts, **(es_args or {}))
return [_send_template(f) for f in files]
def parse_args():
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument("hosts", nargs="+")
parser.add_argument("--key", help="host key, e.g. events, datasets etc.")
parser.add_argument(
"--hosts",
nargs="+",
help="list of es hosts from the same cluster, where each host is http[s]://[user:password@]host:port",
)
return parser.parse_args()
def main():
for host in parse_args().hosts:
print(">>>>> Applying mapping to " + host)
res = apply_mappings_to_host(host)
print(res)
args = parse_args()
print(">>>>> Applying mapping to " + str(args.hosts))
res = apply_mappings_to_cluster(args.hosts, args.key)
print(res)
if __name__ == "__main__":

View File

@@ -1,8 +1,13 @@
from furl import furl
import logging
from time import sleep
from typing import Type, Optional, Sequence, Any, Union
import urllib3.exceptions
from elasticsearch import Elasticsearch, exceptions
import es_factory
from config import config
from elastic.apply_mappings import apply_mappings_to_host
from es_factory import get_cluster_config
from elastic.apply_mappings import apply_mappings_to_cluster
log = config.logger(__file__)
@@ -15,13 +20,94 @@ class MissingElasticConfiguration(Exception):
pass
def init_es_data():
hosts_config = get_cluster_config("events").get("hosts")
if not hosts_config:
raise MissingElasticConfiguration("for cluster 'events'")
class ElasticConnectionError(Exception):
"""
Exception when could not connect to elastic during init
"""
for conf in hosts_config:
host = furl(scheme="http", host=conf["host"], port=conf["port"]).url
log.info(f"Applying mappings to host: {host}")
res = apply_mappings_to_host(host)
pass
class ConnectionErrorFilter(logging.Filter):
def __init__(
self,
level: Optional[Union[int, str]] = None,
err_type: Optional[Type] = None,
args_prefix: Optional[Sequence[Any]] = None,
):
super(ConnectionErrorFilter, self).__init__()
if level is None:
self.level = None
else:
try:
self.level = int(level)
except ValueError:
self.level = logging.getLevelName(level)
self.err_type = err_type
self.args = args_prefix and tuple(args_prefix)
self.last_blocked = None
def filter(self, record):
try:
allow = (
(self.err_type is None or record.exc_info[0] != self.err_type)
and (self.level is None or record.levelno != self.level)
and (self.args is None or record.args[: len(self.args)] != self.args)
)
if not allow:
self.last_blocked = record
return allow
except Exception:
return True
def check_elastic_empty() -> bool:
"""
Check for elasticsearch connection
Use probing settings and not the default es cluster ones
so that we can handle correctly the connection rejects due to ES not fully started yet
:return:
"""
cluster_conf = es_factory.get_cluster_config("events")
max_retries = config.get("apiserver.elastic.probing.max_retries", 4)
timeout = config.get("apiserver.elastic.probing.timeout", 30)
es_logger = logging.getLogger("elasticsearch")
log_filter = ConnectionErrorFilter(
err_type=urllib3.exceptions.NewConnectionError, args_prefix=("GET",)
)
try:
es_logger.addFilter(log_filter)
for retry in range(max_retries):
try:
es = Elasticsearch(hosts=cluster_conf.get("hosts"))
return not es.indices.get_template(name="events*")
except exceptions.NotFoundError as ex:
log.error(ex)
return True
except exceptions.ConnectionError as ex:
if retry >= max_retries - 1:
raise ElasticConnectionError(
f"Error connecting to Elasticsearch: {str(ex)}"
)
log.warn(
f"Could not connect to ElasticSearch Service. Retry {retry+1} of {max_retries}. Waiting for {timeout}sec"
)
sleep(timeout)
finally:
es_logger.removeFilter(log_filter)
def init_es_data():
for name in es_factory.get_all_cluster_names():
cluster_conf = es_factory.get_cluster_config(name)
hosts_config = cluster_conf.get("hosts")
if not hosts_config:
raise MissingElasticConfiguration(f"for cluster '{name}'")
log.info(f"Applying mappings to ES host: {hosts_config}")
args = cluster_conf.get("args", {})
res = apply_mappings_to_cluster(hosts_config, name, es_args=args)
log.info(res)

View File

@@ -1,27 +0,0 @@
{
"template": "events-*",
"settings": {
"number_of_shards": 1
},
"mappings": {
"_default_": {
"_source": {
"enabled": true
},
"_routing": {
"required": true
},
"properties": {
"@timestamp": { "type": "date" },
"task": { "type": "keyword" },
"type": { "type": "keyword" },
"worker": { "type": "keyword" },
"timestamp": { "type": "date" },
"iter": { "type": "long" },
"metric": { "type": "keyword" },
"variant": { "type": "keyword" },
"value": { "type": "float" }
}
}
}
}

View File

@@ -0,0 +1,40 @@
{
"index_patterns": "events-*",
"settings": {
"number_of_shards": 1
},
"mappings": {
"_source": {
"enabled": true
},
"properties": {
"@timestamp": {
"type": "date"
},
"task": {
"type": "keyword"
},
"type": {
"type": "keyword"
},
"worker": {
"type": "keyword"
},
"timestamp": {
"type": "date"
},
"iter": {
"type": "long"
},
"metric": {
"type": "keyword"
},
"variant": {
"type": "keyword"
},
"value": {
"type": "float"
}
}
}
}

View File

@@ -0,0 +1,15 @@
{
"index_patterns": "events-log-*",
"order": 1,
"mappings": {
"properties": {
"msg": {
"type": "text",
"index": false
},
"level": {
"type": "keyword"
}
}
}
}

View File

@@ -0,0 +1,12 @@
{
"index_patterns": "events-plot-*",
"order": 1,
"mappings": {
"properties": {
"plot_str": {
"type": "text",
"index": false
}
}
}
}

View File

@@ -0,0 +1,14 @@
{
"index_patterns": "events-training_debug_image-*",
"order": 1,
"mappings": {
"properties": {
"key": {
"type": "keyword"
},
"url": {
"type": "keyword"
}
}
}
}

View File

@@ -1,12 +0,0 @@
{
"template": "events-log-*",
"order" : 1,
"mappings": {
"_default_": {
"properties": {
"msg": { "type":"text", "index": false },
"level": { "type":"keyword" }
}
}
}
}

View File

@@ -1,11 +0,0 @@
{
"template": "events-plot-*",
"order" : 1,
"mappings": {
"_default_": {
"properties": {
"plot_str": { "type":"text", "index": false }
}
}
}
}

View File

@@ -1,12 +0,0 @@
{
"template": "events-training_debug_image-*",
"order" : 1,
"mappings": {
"_default_": {
"properties": {
"key": { "type": "keyword" },
"url": { "type": "keyword" }
}
}
}
}

View File

@@ -1,27 +0,0 @@
{
"template": "queue_metrics_*",
"settings": {
"number_of_shards": 1
},
"mappings": {
"metrics": {
"_source": {
"enabled": true
},
"properties": {
"timestamp": {
"type": "date"
},
"queue": {
"type": "keyword"
},
"average_waiting_time": {
"type": "float"
},
"queue_length": {
"type": "integer"
}
}
}
}
}

View File

@@ -1,23 +0,0 @@
{
"template": "worker_stats_*",
"settings": {
"number_of_shards": 1
},
"mappings": {
"stat": {
"_source": {
"enabled": true
},
"properties": {
"timestamp": { "type": "date" },
"worker": { "type": "keyword" },
"category": { "type": "keyword" },
"metric": { "type": "keyword" },
"variant": { "type": "keyword" },
"value": { "type": "float" },
"unit": { "type": "keyword" },
"task": { "type": "keyword" }
}
}
}
}

View File

@@ -0,0 +1,25 @@
{
"index_patterns": "queue_metrics_*",
"settings": {
"number_of_shards": 1
},
"mappings": {
"_source": {
"enabled": true
},
"properties": {
"timestamp": {
"type": "date"
},
"queue": {
"type": "keyword"
},
"average_waiting_time": {
"type": "float"
},
"queue_length": {
"type": "integer"
}
}
}
}

View File

@@ -0,0 +1,37 @@
{
"index_patterns": "worker_stats_*",
"settings": {
"number_of_shards": 1
},
"mappings": {
"_source": {
"enabled": true
},
"properties": {
"timestamp": {
"type": "date"
},
"worker": {
"type": "keyword"
},
"category": {
"type": "keyword"
},
"metric": {
"type": "keyword"
},
"variant": {
"type": "keyword"
},
"value": {
"type": "float"
},
"unit": {
"type": "keyword"
},
"task": {
"type": "keyword"
}
}
}
}

View File

@@ -65,6 +65,10 @@ def connect(cluster_name):
return _instances[cluster_name]
def get_all_cluster_names():
return list(config.get("hosts.elastic"))
def get_cluster_config(cluster_name):
"""
Returns cluster config for the specified cluster path

View File

@@ -1,9 +1,11 @@
from pathlib import Path
from typing import Sequence, Union
from config import config
from config.info import get_default_company
from database.model.auth import Role
from service_repo.auth.fixed_user import FixedUser
from .migration import _apply_migrations
from .migration import _apply_migrations, check_mongo_empty, get_last_server_version
from .pre_populate import PrePopulate
from .user import ensure_fixed_user, _ensure_auth_user, _ensure_backend_user
from .util import _ensure_company, _ensure_default_queue, _ensure_uuid
@@ -11,33 +13,51 @@ from .util import _ensure_company, _ensure_default_queue, _ensure_uuid
log = config.logger(__package__)
def _pre_populate(company_id: str, zip_file: str):
if not zip_file or not Path(zip_file).is_file():
msg = f"Invalid pre-populate zip file: {zip_file}"
if config.get("apiserver.pre_populate.fail_on_error", False):
log.error(msg)
raise ValueError(msg)
else:
log.warning(msg)
else:
log.info(f"Pre-populating using {zip_file}")
PrePopulate.import_from_zip(
zip_file,
artifacts_path=config.get("apiserver.pre_populate.artifacts_path", None),
)
def _resolve_zip_files(zip_files: Union[Sequence[str], str]) -> Sequence[str]:
if isinstance(zip_files, str):
zip_files = [zip_files]
for p in map(Path, zip_files):
if p.is_file():
yield p
if p.is_dir():
yield from p.glob("*.zip")
log.warning(f"Invalid pre-populate entry {str(p)}, skipping")
def pre_populate_data():
for zip_file in _resolve_zip_files(config.get("apiserver.pre_populate.zip_files")):
_pre_populate(company_id=get_default_company(), zip_file=zip_file)
PrePopulate.update_featured_projects_order()
def init_mongo_data():
try:
empty_dbs = _apply_migrations(log)
_apply_migrations(log)
_ensure_uuid()
company_id = _ensure_company(log)
company_id = _ensure_company(get_default_company(), "trains", log)
_ensure_default_queue(company_id)
if empty_dbs and config.get("apiserver.mongo.pre_populate.enabled", False):
zip_file = config.get("apiserver.mongo.pre_populate.zip_file")
if not zip_file or not Path(zip_file).is_file():
msg = f"Failed pre-populating database: invalid zip file {zip_file}"
if config.get("apiserver.mongo.pre_populate.fail_on_error", False):
log.error(msg)
raise ValueError(msg)
else:
log.warning(msg)
else:
user_id = _ensure_backend_user(
"__allegroai__", company_id, "Allegro.ai"
)
PrePopulate.import_from_zip(zip_file, user_id=user_id)
fixed_mode = FixedUser.enabled()
for user, credentials in config.get("secure.credentials", {}).items():
@@ -56,9 +76,13 @@ def init_mongo_data():
if fixed_mode:
log.info("Fixed users mode is enabled")
FixedUser.validate()
if FixedUser.guest_enabled():
_ensure_company(FixedUser.get_guest_user().company, "guests", log)
for user in FixedUser.from_config():
try:
ensure_fixed_user(user, company_id, log=log)
ensure_fixed_user(user, log=log)
except Exception as ex:
log.error(f"Failed creating fixed user {user.name}: {ex}")
except Exception as ex:

View File

@@ -13,7 +13,26 @@ from database.model.version import Version as DatabaseVersion
migration_dir = Path(__file__).resolve().parent.with_name("migrations")
def _apply_migrations(log: Logger) -> bool:
def check_mongo_empty() -> bool:
return not all(
get_db(alias).collection_names()
for alias in database.utils.get_options(Database)
)
def get_last_server_version() -> Version:
try:
previous_versions = sorted(
(Version(ver.num) for ver in DatabaseVersion.objects().only("num")),
reverse=True,
)
except ValueError as ex:
raise ValueError(f"Invalid database version number encountered: {ex}")
return previous_versions[0] if previous_versions else Version("0.0.0")
def _apply_migrations(log: Logger):
"""
Apply migrations as found in the migration dir.
Returns a boolean indicating whether the database was empty prior to migration.
@@ -25,20 +44,8 @@ def _apply_migrations(log: Logger) -> bool:
if not migration_dir.is_dir():
raise ValueError(f"Invalid migration dir {migration_dir}")
empty_dbs = not any(
get_db(alias).collection_names()
for alias in database.utils.get_options(Database)
)
try:
previous_versions = sorted(
(Version(ver.num) for ver in DatabaseVersion.objects().only("num")),
reverse=True,
)
except ValueError as ex:
raise ValueError(f"Invalid database version number encountered: {ex}")
last_version = previous_versions[0] if previous_versions else Version("0.0.0")
empty_dbs = check_mongo_empty()
last_version = get_last_server_version()
try:
new_scripts = {
@@ -82,5 +89,3 @@ def _apply_migrations(log: Logger) -> bool:
).save()
log.info("Finished mongodb migrations")
return empty_dbs

View File

@@ -1,31 +1,390 @@
import hashlib
import importlib
import os
import re
from collections import defaultdict
from datetime import datetime
from datetime import datetime, timezone
from functools import partial
from io import BytesIO
from itertools import chain
from operator import attrgetter
from os.path import splitext
from typing import List, Optional, Any, Type, Set, Dict
from pathlib import Path
from typing import (
Optional,
Any,
Type,
Set,
Dict,
Sequence,
Tuple,
BinaryIO,
Union,
Mapping,
)
from urllib.parse import unquote, urlparse
from zipfile import ZipFile, ZIP_BZIP2
import dpath
import mongoengine
from tqdm import tqdm
from boltons.iterutils import chunked_iter
from furl import furl
from mongoengine import Q
from bll.event import EventBLL
from bll.task.param_utils import (
split_param_name,
hyperparams_default_section,
hyperparams_legacy_type,
)
from config import config
from config.info import get_default_company
from database.model import EntityVisibility
from database.model.model import Model
from database.model.project import Project
from database.model.task.task import Task, ArtifactModes, TaskStatus
from database.utils import get_options
from tools import safe_get
from utilities import json
from .user import _ensure_backend_user
class PrePopulate:
@classmethod
def export_to_zip(
cls, filename: str, experiments: List[str] = None, projects: List[str] = None
):
with ZipFile(filename, mode="w", compression=ZIP_BZIP2) as zfile:
cls._export(zfile, experiments, projects)
event_bll = EventBLL()
events_file_suffix = "_events"
export_tag_prefix = "Exported:"
export_tag = f"{export_tag_prefix} %Y-%m-%d %H:%M:%S"
metadata_filename = "metadata.json"
zip_args = dict(mode="w", compression=ZIP_BZIP2)
artifacts_ext = ".artifacts"
img_source_regex = re.compile(
r"['\"]source['\"]:\s?['\"](https?://(?:localhost:8081|files.*?)/.*?)['\"]",
flags=re.IGNORECASE,
)
class JsonLinesWriter:
def __init__(self, file: BinaryIO):
self.file = file
self.empty = True
def __enter__(self):
self._write("[")
return self
def __exit__(self, exc_type, exc_value, exc_traceback):
self._write("\n]")
def _write(self, data: str):
self.file.write(data.encode("utf-8"))
def write(self, line: str):
if not self.empty:
self._write(",")
self._write("\n" + line)
self.empty = False
@staticmethod
def _get_last_update_time(entity) -> datetime:
return getattr(entity, "last_update", None) or getattr(entity, "created")
@classmethod
def import_from_zip(cls, filename: str, user_id: str = None):
def _check_for_update(
cls, map_file: Path, entities: dict, metadata_hash: str
) -> Tuple[bool, Sequence[str]]:
if not map_file.is_file():
return True, []
files = []
try:
map_data = json.loads(map_file.read_text())
files = map_data.get("files", [])
for file in files:
if not Path(file).is_file():
return True, files
new_times = {
item.id: cls._get_last_update_time(item).replace(tzinfo=timezone.utc)
for item in chain.from_iterable(entities.values())
}
old_times = map_data.get("entities", {})
if set(new_times.keys()) != set(old_times.keys()):
return True, files
for id_, new_timestamp in new_times.items():
if new_timestamp != old_times[id_]:
return True, files
if metadata_hash != map_data.get("metadata_hash", ""):
return True, files
except Exception as ex:
print("Error reading map file. " + str(ex))
return True, files
return False, files
@classmethod
def _write_update_file(
cls,
map_file: Path,
entities: dict,
created_files: Sequence[str],
metadata_hash: str,
):
map_file.write_text(
json.dumps(
dict(
files=created_files,
entities={
entity.id: cls._get_last_update_time(entity)
for entity in chain.from_iterable(entities.values())
},
metadata_hash=metadata_hash,
)
)
)
@staticmethod
def _filter_artifacts(artifacts: Sequence[str]) -> Sequence[str]:
def is_fileserver_link(a: str) -> bool:
a = a.lower()
if a.startswith("https://files."):
return True
if a.startswith("http"):
parsed = urlparse(a)
if parsed.scheme in {"http", "https"} and parsed.netloc.endswith(
"8081"
):
return True
return False
fileserver_links = [a for a in artifacts if is_fileserver_link(a)]
print(
f"Found {len(fileserver_links)} files on the fileserver from {len(artifacts)} total"
)
return fileserver_links
@classmethod
def export_to_zip(
cls,
filename: str,
experiments: Sequence[str] = None,
projects: Sequence[str] = None,
artifacts_path: str = None,
task_statuses: Sequence[str] = None,
tag_exported_entities: bool = False,
metadata: Mapping[str, Any] = None,
) -> Sequence[str]:
if task_statuses and not set(task_statuses).issubset(get_options(TaskStatus)):
raise ValueError("Invalid task statuses")
file = Path(filename)
entities = cls._resolve_entities(
experiments=experiments, projects=projects, task_statuses=task_statuses
)
hash_ = hashlib.md5()
if metadata:
meta_str = json.dumps(metadata)
hash_.update(meta_str.encode())
metadata_hash = hash_.hexdigest()
else:
meta_str, metadata_hash = "", ""
map_file = file.with_suffix(".map")
updated, old_files = cls._check_for_update(
map_file, entities=entities, metadata_hash=metadata_hash
)
if not updated:
print(f"There are no updates from the last export")
return old_files
for old in old_files:
old_path = Path(old)
if old_path.is_file():
old_path.unlink()
with ZipFile(file, **cls.zip_args) as zfile:
if metadata:
zfile.writestr(cls.metadata_filename, meta_str)
artifacts = cls._export(
zfile,
entities=entities,
hash_=hash_,
tag_entities=tag_exported_entities,
)
file_with_hash = file.with_name(f"{file.stem}_{hash_.hexdigest()}{file.suffix}")
file.replace(file_with_hash)
created_files = [str(file_with_hash)]
artifacts = cls._filter_artifacts(artifacts)
if artifacts and artifacts_path and os.path.isdir(artifacts_path):
artifacts_file = file_with_hash.with_suffix(cls.artifacts_ext)
with ZipFile(artifacts_file, **cls.zip_args) as zfile:
cls._export_artifacts(zfile, artifacts, artifacts_path)
created_files.append(str(artifacts_file))
cls._write_update_file(
map_file,
entities=entities,
created_files=created_files,
metadata_hash=metadata_hash,
)
return created_files
@classmethod
def import_from_zip(
cls,
filename: str,
artifacts_path: str,
company_id: Optional[str] = None,
user_id: str = "",
user_name: str = "",
):
metadata = None
with ZipFile(filename) as zfile:
cls._import(zfile, user_id)
try:
with zfile.open(cls.metadata_filename) as f:
metadata = json.loads(f.read())
meta_public = metadata.get("public")
if company_id is None and meta_public is not None:
company_id = "" if meta_public else get_default_company()
if not user_id:
meta_user_id = metadata.get("user_id", "")
meta_user_name = metadata.get("user_name", "")
user_id, user_name = meta_user_id, meta_user_name
except Exception:
pass
if not user_id:
user_id, user_name = "__allegroai__", "Allegro.ai"
# Make sure we won't end up with an invalid company ID
if company_id is None:
company_id = ""
# Always use a public user for pre-populated data
user_id = _ensure_backend_user(
user_id=user_id, user_name=user_name, company_id="",
)
cls._import(zfile, company_id, user_id, metadata)
if artifacts_path and os.path.isdir(artifacts_path):
artifacts_file = Path(filename).with_suffix(cls.artifacts_ext)
if artifacts_file.is_file():
print(f"Unzipping artifacts into {artifacts_path}")
with ZipFile(artifacts_file) as zfile:
zfile.extractall(artifacts_path)
@classmethod
def upgrade_zip(cls, filename) -> Sequence:
hash_ = hashlib.md5()
task_file = cls._get_base_filename(Task) + ".json"
temp_file = Path("temp.zip")
file = Path(filename)
with ZipFile(file) as reader, ZipFile(temp_file, **cls.zip_args) as writer:
for file_info in reader.filelist:
if file_info.orig_filename == task_file:
with reader.open(file_info) as f:
content = cls._upgrade_tasks(f)
else:
content = reader.read(file_info)
writer.writestr(file_info, content)
hash_.update(content)
base_file_name, _, old_hash = file.stem.rpartition("_")
new_hash = hash_.hexdigest()
if old_hash == new_hash:
print(f"The file {filename} was not updated")
temp_file.unlink()
return []
new_file = file.with_name(f"{base_file_name}_{new_hash}{file.suffix}")
temp_file.replace(new_file)
upadated = [str(new_file)]
artifacts_file = file.with_suffix(cls.artifacts_ext)
if artifacts_file.is_file():
new_artifacts = new_file.with_suffix(cls.artifacts_ext)
artifacts_file.replace(new_artifacts)
upadated.append(str(new_artifacts))
print(f"File {str(file)} replaced with {str(new_file)}")
file.unlink()
return upadated
@staticmethod
def _upgrade_task_data(task_data: dict):
for old_param_field, new_param_field, default_section in (
("execution/parameters", "hyperparams", hyperparams_default_section),
("execution/model_desc", "configuration", None),
):
legacy = safe_get(task_data, old_param_field)
if not legacy:
continue
for full_name, value in legacy.items():
section, name = split_param_name(full_name, default_section)
new_path = list(filter(None, (new_param_field, section, name)))
if not safe_get(task_data, new_path):
new_param = dict(
name=name, type=hyperparams_legacy_type, value=str(value)
)
if section is not None:
new_param["section"] = section
dpath.new(task_data, new_path, new_param)
dpath.delete(task_data, old_param_field)
@classmethod
def _upgrade_tasks(cls, f: BinaryIO) -> bytes:
"""
Build content array that contains fixed tasks from the passed file
For each task the old execution.parameters and model.design are
converted to the new structure.
The fix is done on Task objects (not the dictionary) so that
the fields are serialized back in the same order as they were in the original file
"""
with BytesIO() as temp:
with cls.JsonLinesWriter(temp) as w:
for line in cls.json_lines(f):
task_data = Task.from_json(line).to_proper_dict()
cls._upgrade_task_data(task_data)
new_task = Task(**task_data)
w.write(new_task.to_json())
return temp.getvalue()
@classmethod
def update_featured_projects_order(cls):
featured_order = config.get("services.projects.featured_order", [])
if not featured_order:
return
def get_index(p: Project):
for index, entry in enumerate(featured_order):
if (
entry.get("id", None) == p.id
or entry.get("name", None) == p.name
or ("name_regex" in entry and re.match(entry["name_regex"], p.name))
):
return index
return 999
for project in Project.get_many_public(projection=["id", "name"]):
featured_index = get_index(project)
Project.objects(id=project.id).update(featured=featured_index)
@staticmethod
def _resolve_type(
cls: Type[mongoengine.Document], ids: Optional[List[str]]
) -> List[Any]:
cls: Type[mongoengine.Document], ids: Optional[Sequence[str]]
) -> Sequence[Any]:
ids = set(ids)
items = list(cls.objects(id__in=list(ids)))
resolved = {i.id for i in items}
@@ -43,20 +402,24 @@ class PrePopulate:
@classmethod
def _resolve_entities(
cls, experiments: List[str] = None, projects: List[str] = None
cls,
experiments: Sequence[str] = None,
projects: Sequence[str] = None,
task_statuses: Sequence[str] = None,
) -> Dict[Type[mongoengine.Document], Set[mongoengine.Document]]:
from database.model.project import Project
from database.model.task.task import Task
entities = defaultdict(set)
if projects:
print("Reading projects...")
entities[Project].update(cls._resolve_type(Project, projects))
print("--> Reading project experiments...")
objs = Task.objects(
project__in=list(set(filter(None, (p.id for p in entities[Project]))))
query = Q(
project__in=list(set(filter(None, (p.id for p in entities[Project])))),
system_tags__nin=[EntityVisibility.archived.value],
)
if task_statuses:
query &= Q(status__in=list(set(task_statuses)))
objs = Task.objects(query)
entities[Task].update(o for o in objs if o.id not in (experiments or []))
if experiments:
@@ -69,85 +432,297 @@ class PrePopulate:
project_ids = {p.id for p in entities[Project]}
entities[Project].update(o for o in objs if o.id not in project_ids)
model_ids = {
model_id
for task in entities[Task]
for model_id in (task.output.model, task.execution.model)
if model_id
}
if model_ids:
print("Reading models...")
entities[Model] = set(Model.objects(id__in=list(model_ids)))
return entities
@classmethod
def _cleanup_task(cls, task):
from database.model.task.task import TaskStatus
def _filter_out_export_tags(cls, tags: Sequence[str]) -> Sequence[str]:
if not tags:
return tags
return [tag for tag in tags if not tag.startswith(cls.export_tag_prefix)]
task.completed = None
task.started = None
if task.execution:
task.execution.model = None
task.execution.model_desc = None
task.execution.model_labels = None
if task.output:
task.output.model = None
@classmethod
def _cleanup_model(cls, model: Model):
model.company = ""
model.user = ""
model.tags = cls._filter_out_export_tags(model.tags)
task.status = TaskStatus.created
@classmethod
def _cleanup_task(cls, task: Task):
task.comment = "Auto generated by Allegro.ai"
task.created = datetime.utcnow()
task.last_iteration = 0
task.last_update = task.created
task.status_changed = task.created
task.status_message = ""
task.status_reason = ""
task.user = ""
task.company = ""
task.tags = cls._filter_out_export_tags(task.tags)
if task.output:
task.output.destination = None
@classmethod
def _cleanup_project(cls, project: Project):
project.user = ""
project.company = ""
project.tags = cls._filter_out_export_tags(project.tags)
@classmethod
def _cleanup_entity(cls, entity_cls, entity):
from database.model.task.task import Task
if entity_cls == Task:
cls._cleanup_task(entity)
elif entity_cls == Model:
cls._cleanup_model(entity)
elif entity == Project:
cls._cleanup_project(entity)
@classmethod
def _add_tag(cls, items: Sequence[Union[Project, Task, Model]], tag: str):
try:
for item in items:
item.update(upsert=False, tags=sorted(item.tags + [tag]))
except AttributeError:
pass
@classmethod
def _export_task_events(
cls, task: Task, base_filename: str, writer: ZipFile, hash_
) -> Sequence[str]:
artifacts = []
filename = f"{base_filename}_{task.id}{cls.events_file_suffix}.json"
print(f"Writing task events into {writer.filename}:{filename}")
with BytesIO() as f:
with cls.JsonLinesWriter(f) as w:
scroll_id = None
while True:
res = cls.event_bll.get_task_events(
task.company, task.id, scroll_id=scroll_id
)
if not res.events:
break
scroll_id = res.next_scroll_id
for event in res.events:
event_type = event.get("type")
if event_type == "training_debug_image":
url = cls._get_fixed_url(event.get("url"))
if url:
event["url"] = url
artifacts.append(url)
elif event_type == "plot":
plot_str: str = event.get("plot_str", "")
for match in cls.img_source_regex.findall(plot_str):
url = cls._get_fixed_url(match)
if match != url:
plot_str = plot_str.replace(match, url)
artifacts.append(url)
w.write(json.dumps(event))
data = f.getvalue()
hash_.update(data)
writer.writestr(filename, data)
return artifacts
@staticmethod
def _get_fixed_url(url: Optional[str]) -> Optional[str]:
if not (url and url.lower().startswith("s3://")):
return url
try:
fixed = furl(url)
fixed.scheme = "https"
fixed.host += ".s3.amazonaws.com"
return fixed.url
except Exception as ex:
print(f"Failed processing link {url}. " + str(ex))
return url
@classmethod
def _export_entity_related_data(
cls, entity_cls, entity, base_filename: str, writer: ZipFile, hash_
):
if entity_cls == Task:
return [
*cls._get_task_output_artifacts(entity),
*cls._export_task_events(entity, base_filename, writer, hash_),
]
if entity_cls == Model:
entity.uri = cls._get_fixed_url(entity.uri)
return [entity.uri] if entity.uri else []
return []
@classmethod
def _get_task_output_artifacts(cls, task: Task) -> Sequence[str]:
if not task.execution.artifacts:
return []
for a in task.execution.artifacts:
if a.mode == ArtifactModes.output:
a.uri = cls._get_fixed_url(a.uri)
return [
a.uri
for a in task.execution.artifacts
if a.mode == ArtifactModes.output and a.uri
]
@classmethod
def _export_artifacts(
cls, writer: ZipFile, artifacts: Sequence[str], artifacts_path: str
):
unique_paths = set(unquote(str(furl(artifact).path)) for artifact in artifacts)
print(f"Writing {len(unique_paths)} artifacts into {writer.filename}")
for path in unique_paths:
path = path.lstrip("/")
full_path = os.path.join(artifacts_path, path)
if os.path.isfile(full_path):
writer.write(full_path, path)
else:
print(f"Artifact {full_path} not found")
@staticmethod
def _get_base_filename(cls_: type):
return f"{cls_.__module__}.{cls_.__name__}"
@classmethod
def _export(
cls, writer: ZipFile, experiments: List[str] = None, projects: List[str] = None
):
entities = cls._resolve_entities(experiments, projects)
for cls_, items in entities.items():
cls, writer: ZipFile, entities: dict, hash_, tag_entities: bool = False
) -> Sequence[str]:
"""
Export the requested experiments, projects and models and return the list of artifact files
Always do the export on sorted items since the order of items influence hash
"""
artifacts = []
now = datetime.utcnow()
for cls_ in sorted(entities, key=attrgetter("__name__")):
items = sorted(entities[cls_], key=attrgetter("id"))
if not items:
continue
filename = f"{cls_.__module__}.{cls_.__name__}.json"
base_filename = cls._get_base_filename(cls_)
for item in items:
artifacts.extend(
cls._export_entity_related_data(
cls_, item, base_filename, writer, hash_
)
)
filename = base_filename + ".json"
print(f"Writing {len(items)} items into {writer.filename}:{filename}")
with writer.open(filename, "w") as f:
f.write("[\n".encode("utf-8"))
last = len(items) - 1
for i, item in enumerate(items):
cls._cleanup_entity(cls_, item)
f.write(item.to_json().encode("utf-8"))
if i != last:
f.write(",".encode("utf-8"))
f.write("\n".encode("utf-8"))
f.write("]\n".encode("utf-8"))
with BytesIO() as f:
with cls.JsonLinesWriter(f) as w:
for item in items:
cls._cleanup_entity(cls_, item)
w.write(item.to_json())
data = f.getvalue()
hash_.update(data)
writer.writestr(filename, data)
if tag_entities:
cls._add_tag(items, now.strftime(cls.export_tag))
return artifacts
@staticmethod
def _import(reader: ZipFile, user_id: str = None):
for file_info in reader.filelist:
full_name = splitext(file_info.orig_filename)[0]
print(f"Reading {reader.filename}:{full_name}...")
module_name, _, class_name = full_name.rpartition(".")
module = importlib.import_module(module_name)
cls_: Type[mongoengine.Document] = getattr(module, class_name)
def json_lines(file: BinaryIO):
for line in file:
clean = (
line.decode("utf-8")
.rstrip("\r\n")
.strip()
.lstrip("[")
.rstrip(",]")
.strip()
)
if not clean:
continue
yield clean
with reader.open(file_info) as f:
for item in tqdm(
f.readlines(),
desc=f"Writing {cls_.__name__.lower()}s into database",
unit="doc",
):
item = (
item.decode("utf-8")
.strip()
.lstrip("[")
.rstrip("]")
.rstrip(",")
.strip()
)
if not item:
continue
doc = cls_.from_json(item)
if user_id is not None and hasattr(doc, "user"):
doc.user = user_id
doc.save(force_insert=True)
@classmethod
def _import(
cls,
reader: ZipFile,
company_id: str = "",
user_id: str = None,
metadata: Mapping[str, Any] = None,
):
"""
Import entities and events from the zip file
Start from entities since event import will require the tasks already in DB
"""
event_file_ending = cls.events_file_suffix + ".json"
entity_files = (
fi
for fi in reader.filelist
if not fi.orig_filename.endswith(event_file_ending)
and fi.orig_filename != cls.metadata_filename
)
event_files = (
fi for fi in reader.filelist if fi.orig_filename.endswith(event_file_ending)
)
for files, reader_func in (
(entity_files, partial(cls._import_entity, metadata=metadata or {})),
(event_files, cls._import_events),
):
for file_info in files:
with reader.open(file_info) as f:
full_name = splitext(file_info.orig_filename)[0]
print(f"Reading {reader.filename}:{full_name}...")
reader_func(f, full_name, company_id, user_id)
@classmethod
def _import_entity(
cls,
f: BinaryIO,
full_name: str,
company_id: str,
user_id: str,
metadata: Mapping[str, Any],
):
module_name, _, class_name = full_name.rpartition(".")
module = importlib.import_module(module_name)
cls_: Type[mongoengine.Document] = getattr(module, class_name)
print(f"Writing {cls_.__name__.lower()}s into database")
override_project_count = 0
for item in cls.json_lines(f):
doc = cls_.from_json(item, created=True)
if hasattr(doc, "user"):
doc.user = user_id
if hasattr(doc, "company"):
doc.company = company_id
if isinstance(doc, Project):
override_project_name = metadata.get("project_name", None)
if override_project_name:
if override_project_count:
override_project_name = (
f"{override_project_name} {override_project_count + 1}"
)
override_project_count += 1
doc.name = override_project_name
doc.logo_url = metadata.get("logo_url", None)
doc.logo_blob = metadata.get("logo_blob", None)
cls_.objects(company=company_id, name=doc.name, id__ne=doc.id).update(
set__name=f"{doc.name}_{datetime.utcnow().strftime('%Y-%m-%d_%H-%M-%S')}"
)
doc.save()
if isinstance(doc, Task):
cls.event_bll.delete_task_events(company_id, doc.id, allow_locked=True)
@classmethod
def _import_events(cls, f: BinaryIO, full_name: str, company_id: str, _):
_, _, task_id = full_name[0 : -len(cls.events_file_suffix)].rpartition("_")
print(f"Writing events for task {task_id} into database")
for events_chunk in chunked_iter(cls.json_lines(f), 1000):
events = [json.loads(item) for item in events_chunk]
cls.event_bll.add_events(
company_id, events=events, worker="", allow_locked_tasks=True
)

View File

@@ -58,15 +58,23 @@ def _ensure_backend_user(user_id: str, company_id: str, user_name: str):
return user_id
def ensure_fixed_user(user: FixedUser, company_id: str, log: Logger):
if User.objects(id=user.user_id).first():
def ensure_fixed_user(user: FixedUser, log: Logger):
db_user = User.objects(company=user.company, id=user.user_id).first()
if db_user:
# noinspection PyBroadException
try:
log.info(f"Updating user name: {user.name}")
given_name, _, family_name = user.name.partition(" ")
db_user.update(name=user.name, given_name=given_name, family_name=family_name)
except Exception:
pass
return
data = attr.asdict(user)
data["id"] = user.user_id
data["email"] = f"{user.user_id}@example.com"
data["role"] = Role.user
data["role"] = Role.guest if user.is_guest else Role.user
_ensure_auth_user(user_data=data, company_id=company_id, log=log)
_ensure_auth_user(user_data=data, company_id=user.company, log=log)
return _ensure_backend_user(user.user_id, company_id, user.name)
return _ensure_backend_user(user.user_id, user.company, user.name)

View File

@@ -3,7 +3,6 @@ from uuid import uuid4
from bll.queue import QueueBLL
from config import config
from config.info import get_default_company
from database.model.company import Company
from database.model.queue import Queue
from database.model.settings import Settings, SettingKeys
@@ -11,13 +10,11 @@ from database.model.settings import Settings, SettingKeys
log = config.logger(__file__)
def _ensure_company(log: Logger):
company_id = get_default_company()
def _ensure_company(company_id, company_name, log: Logger):
company = Company.objects(id=company_id).only("id").first()
if company:
return company_id
company_name = "trains"
log.info(f"Creating company: {company_name}")
company = Company(id=company_id, name=company_name)
company.save()

View File

@@ -0,0 +1,36 @@
from pymongo.database import Database, Collection
from bll.task.param_utils import (
hyperparams_legacy_type,
hyperparams_default_section,
split_param_name,
)
from tools import safe_get
def migrate_backend(db: Database):
hyperparam_fields = ("execution.parameters", "hyperparams")
configuration_fields = ("execution.model_desc", "configuration")
collection: Collection = db["task"]
for doc in collection.find(projection=hyperparam_fields + configuration_fields):
set_commands = {}
for (old_field, new_field), default_section in zip(
(hyperparam_fields, configuration_fields),
(hyperparams_default_section, None),
):
legacy = safe_get(doc, old_field, separator=".")
if not legacy:
continue
for full_name, value in legacy.items():
section, name = split_param_name(full_name, default_section)
new_path = list(filter(None, (new_field, section, name)))
# if safe_get(doc, new_path) is not None:
# continue
new_value = dict(
name=name, type=hyperparams_legacy_type, value=str(value)
)
if section is not None:
new_value["section"] = section
set_commands[".".join(new_path)] = new_value
if set_commands:
collection.update_one({"_id": doc["_id"]}, {"$set": set_commands})

View File

@@ -0,0 +1,11 @@
from pymongo.database import Database, Collection
def migrate_backend(db: Database):
collection: Collection = db["project"]
featured = "featured"
query = {featured: {"$exists": False}}
for doc in collection.find(filter=query, projection=()):
collection.update_one(
{"_id": doc["_id"]}, {"$set": {featured: 9999}},
)

View File

@@ -1,7 +1,8 @@
attrs>=19.1.0
boltons>=19.1.0
boto3==1.14.13
dpath>=1.4.2,<2.0
elasticsearch>=5.0.0,<6.0.0
elasticsearch>=7.0.0,<8.0.0
fastjsonschema>=2.8
Flask-Compress>=1.4.0
Flask-Cors>=3.0.5
@@ -24,7 +25,7 @@ python-rapidjson>=0.6.3
redis>=2.10.5
related>=0.7.2
requests>=2.13.0
semantic_version>=2.8.0,<3
semantic_version>=2.8.3,<3
six
tqdm
validators>=0.12.4

View File

@@ -328,6 +328,11 @@ fixed_users_mode {
description: "Fixed users mode enabled"
type: boolean
}
server_errors {
description: "Server initialization errors"
type: object
additionalProperties: True
}
}
}
}

View File

@@ -0,0 +1,16 @@
_description: "debugging utilities"
ping {
authorize: false
"2.9" {
description: "Ping server"
request {
type: object
additionalProperties: true
}
response {
type: object
properties: {
}
}
}
}

View File

@@ -530,59 +530,56 @@
}
}
}
// "2.7" {
// description: "Get 'log' events for this task"
// request {
// type: object
// required: [
// task
// ]
// properties {
// task {
// type: string
// description: "Task ID"
// }
// batch_size {
// type: integer
// description: "The amount of log events to return"
// }
// navigate_earlier {
// type: boolean
// description: "If set then log events are retreived from the latest to the earliest ones (in timestamp descending order). Otherwise from the earliest to the latest ones (in timestamp ascending order). The default is True"
// }
// refresh {
// type: boolean
// description: "If set then scroll will be moved to the latest logs (if 'navigate_earlier' is set to True) or to the earliest (otherwise)"
// }
// scroll_id {
// type: string
// description: "Scroll ID of previous call (used for getting more results)"
// }
// }
// }
// response {
// type: object
// properties {
// events {
// type: array
// items { type: object }
// description: "Log items list"
// }
// returned {
// type: integer
// description: "Number of log events returned"
// }
// total {
// type: number
// description: "Total number of log events available for this query"
// }
// scroll_id {
// type: string
// description: "Scroll ID for getting more results"
// }
// }
// }
// }
"2.9" {
description: "Get 'log' events for this task"
request {
type: object
required: [
task
]
properties {
task {
type: string
description: "Task ID"
}
batch_size {
type: integer
description: "The amount of log events to return"
}
navigate_earlier {
type: boolean
description: "If set then log events are retreived from the latest to the earliest ones (in timestamp descending order, unless order='asc'). Otherwise from the earliest to the latest ones (in timestamp ascending order, unless order='desc'). The default is True"
}
from_timestamp {
type: number
description: "Epoch time in UTC ms to use as the navigation start. Optional. If not provided, reference timestamp is determined by the 'navigate_earlier' parameter (if true, reference timestamp is the last timestamp and if false, reference timestamp is the first timestamp)"
}
order {
type: string
description: "If set, changes the order in which log events are returned based on the value of 'navigate_earlier'"
enum: [asc, desc]
}
}
}
response {
type: object
properties {
events {
type: array
items { type: object }
description: "Log items list"
}
returned {
type: integer
description: "Number of log events returned"
}
total {
type: number
description: "Total number of log events available for this query"
}
}
}
}
}
get_task_events {
"2.1" {
@@ -856,7 +853,7 @@
description: "Task ID"
}
samples {
description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 10000."
description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 6000."
type: integer
}
key {
@@ -894,7 +891,7 @@
]
properties {
tasks {
description: "List of task Task IDs"
description: "List of task Task IDs. Maximum amount of tasks is 10"
type: array
items {
type: string
@@ -902,7 +899,7 @@
}
}
samples {
description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 10000."
description: "The amount of histogram points to return. Optional, the default value is 6000"
type: integer
}
key {

File diff suppressed because it is too large Load Diff

View File

@@ -405,6 +405,11 @@ get_all_ex {
enum: [ active, archived ]
default: active
}
non_public {
description: "Return only non-public projects"
type: boolean
default: false
}
}
}
}
@@ -527,8 +532,8 @@ get_unique_metric_variants {
}
}
get_hyper_parameters {
"2.2" {
description: """Get a list of all hyper parameter names used in tasks within the given project."""
"2.9" {
description: """Get a list of all hyper parameter sections and names used in tasks within the given project."""
request {
type: object
properties {
@@ -552,9 +557,9 @@ get_hyper_parameters {
type: object
properties {
parameters {
description: "A list of hyper parameter names"
description: "A list of parameter sections and names"
type: array
items {type: string}
items {type: object}
}
remaining {
description: "Remaining results"
@@ -568,6 +573,7 @@ get_hyper_parameters {
}
}
}
get_task_tags {
"2.8" {
description: "Get user and system tags used for the tasks under the specified projects"
@@ -575,10 +581,61 @@ get_task_tags {
response = ${_definitions.tags_response}
}
}
get_model_tags {
"2.8" {
description: "Get user and system tags used for the models under the specified projects"
request = ${_definitions.tags_request}
response = ${_definitions.tags_response}
}
}
make_public {
"2.9" {
description: """Convert company projects to public"""
request {
type: object
properties {
ids {
description: "Ids of the projects to convert"
type: array
items { type: string}
}
}
}
response {
type: object
properties {
updated {
description: "Number of projects updated"
type: integer
}
}
}
}
}
make_private {
"2.9" {
description: """Convert public projects to private"""
request {
type: object
properties {
ids {
description: "Ids of the projects to convert. Only the projects originated by the company can be converted"
type: array
items { type: string}
}
}
}
response {
type: object
properties {
updated {
description: "Number of projects updated"
type: integer
}
}
}
}
}

View File

@@ -297,7 +297,80 @@ _definitions {
"$ref": "#/definitions/last_metrics_event"
}
}
params_item {
type: object
properties {
section {
description: "Section that the parameter belongs to"
type: string
}
name {
description: "Name of the parameter. The combination of section and name should be unique"
type: string
}
value {
description: "Value of the parameter"
type: string
}
type {
description: "Type of the parameter. Optional"
type: string
}
description {
description: "The parameter description. Optional"
type: string
}
}
}
configuration_item {
type: object
properties {
name {
description: "Name of the parameter. Should be unique"
type: string
}
value {
description: "Value of the parameter"
type: string
}
type {
description: "Type of the parameter. Optional"
type: string
}
description {
description: "The parameter description. Optional"
type: string
}
}
}
param_key {
type: object
properties {
section {
description: "Section that the parameter belongs to"
type: string
}
name {
description: "Name of the parameter. If the name is ommitted then the corresponding operation is performed on the whole section"
type: string
}
}
}
section_params {
description: "Task section params"
type: object
additionalProperties {
"$ref": "#/definitions/params_item"
}
}
replace_hyperparams_enum {
type: string
enum: [
none,
section,
all
]
}
task {
type: object
properties {
@@ -418,9 +491,24 @@ _definitions {
"$ref": "#/definitions/last_metrics_variants"
}
}
hyperparams {
description: "Task hyper params per section"
type: object
additionalProperties {
"$ref": "#/definitions/section_params"
}
}
configuration {
description: "Task configuration params"
type: object
additionalProperties {
"$ref": "#/definitions/configuration_item"
}
}
}
}
}
get_by_id {
"2.1" {
description: "Gets task information"
@@ -625,6 +713,20 @@ clone {
description: "The project of the cloned task. If not provided then taken from the original task"
type: string
}
new_task_hyperparams {
description: "The hyper params for the new task. If not provided then taken from the original task"
type: object
additionalProperties {
"$ref": "#/definitions/section_params"
}
}
new_task_configuration {
description: "The configuration for the new task. If not provided then taken from the original task"
type: object
additionalProperties {
"$ref": "#/definitions/configuration_item"
}
}
execution_overrides {
description: "The execution params for the cloned task. The params not specified are taken from the original task"
"$ref": "#/definitions/execution"
@@ -698,6 +800,20 @@ create {
description: "Script info"
"$ref": "#/definitions/script"
}
hyperparams {
description: "Task hyper params per section"
type: object
additionalProperties {
"$ref": "#/definitions/section_params"
}
}
configuration {
description: "Task configuration params"
type: object
additionalProperties {
"$ref": "#/definitions/configuration_item"
}
}
}
}
response {
@@ -759,6 +875,20 @@ validate {
description: "Task execution params"
"$ref": "#/definitions/execution"
}
hyperparams {
description: "Task hyper params per section"
type: object
additionalProperties {
"$ref": "#/definitions/section_params"
}
}
configuration {
description: "Task configuration params"
type: object
additionalProperties {
"$ref": "#/definitions/configuration_item"
}
}
script {
description: "Script info"
"$ref": "#/definitions/script"
@@ -909,6 +1039,20 @@ edit {
description: "Task execution params"
"$ref": "#/definitions/execution"
}
hyperparams {
description: "Task hyper params per section"
type: object
additionalProperties {
"$ref": "#/definitions/section_params"
}
}
configuration {
description: "Task configuration params"
type: object
additionalProperties {
"$ref": "#/definitions/configuration_item"
}
}
script {
description: "Script info"
"$ref": "#/definitions/script"
@@ -1441,4 +1585,263 @@ add_or_update_artifacts {
}
}
}
}
make_public {
"2.9" {
description: """Convert company tasks to public"""
request {
type: object
properties {
ids {
description: "Ids of the tasks to convert"
type: array
items { type: string}
}
}
}
response {
type: object
properties {
updated {
description: "Number of tasks updated"
type: integer
}
}
}
}
}
make_private {
"2.9" {
description: """Convert public tasks to private"""
request {
type: object
properties {
ids {
description: "Ids of the tasks to convert. Only the tasks originated by the company can be converted"
type: array
items { type: string}
}
}
}
response {
type: object
properties {
updated {
description: "Number of tasks updated"
type: integer
}
}
}
}
}
get_hyper_params {
"2.9": {
description: "Get the list of task hyper parameters"
request {
type: object
required: [tasks]
properties {
tasks {
description: "Task IDs"
type: array
items { type: string }
}
}
}
response {
type: object
properties {
params {
type: object
description: "Hyper parameters (keyed by task ID)"
}
}
}
}
}
edit_hyper_params {
"2.9" {
description: "Add or update task hyper parameters"
request {
type: object
required: [ task, hyperparams ]
properties {
task {
description: "Task ID"
type: string
}
hyperparams {
description: "Task hyper parameters. The new ones will be added and the already existing ones will be updated"
type: array
items {"$ref": "#/definitions/params_item"}
}
replace_hyperparams {
description: """Can be set to one of the following:
'all' - all the hyper parameters will be replaced with the provided ones
'section' - the sections that present in the new parameters will be replaced with the provided parameters
'none' (the default value) - only the specific parameters will be updated or added"""
"$ref": "#/definitions/replace_hyperparams_enum"
}
}
}
response {
type: object
properties {
updated {
description: "Indicates if the task was updated successfully"
type: integer
}
}
}
}
}
delete_hyper_params {
"2.9": {
description: "Delete task hyper parameters"
request {
type: object
required: [ task, hyperparams ]
properties {
task {
description: "Task ID"
type: string
}
hyperparams {
description: "List of hyper parameters to delete. In case a parameter with an empty name is passed all the section will be deleted"
type: array
items { "$ref": "#/definitions/param_key" }
}
}
}
response {
type: object
properties {
deleted {
description: "Indicates if the task was updated successfully"
type: integer
}
}
}
}
}
get_configurations {
"2.9": {
description: "Get the list of task configurations"
request {
type: object
required: [tasks]
properties {
tasks {
description: "Task IDs"
type: array
items { type: string }
}
names {
description: "Names of the configuration items to retreive. If not passed or empty then all the configurations will be retreived."
type: array
items { type: string }
}
}
}
response {
type: object
properties {
configurations {
type: object
description: "Configurations (keyed by task ID)"
}
}
}
}
}
get_configuration_names {
"2.9": {
description: "Get the list of task configuration items names"
request {
type: object
required: [tasks]
properties {
tasks {
description: "Task IDs"
type: array
items { type: string }
}
}
}
response {
type: object
properties {
configurations {
type: object
description: "Names of task configuration items (keyed by task ID)"
}
}
}
}
}
edit_configuration {
"2.9" {
description: "Add or update task configuration"
request {
type: object
required: [ task, configuration ]
properties {
task {
description: "Task ID"
type: string
}
configuration {
description: "Task configuration items. The new ones will be added and the already existing ones will be updated"
type: array
items {"$ref": "#/definitions/configuration_item"}
}
replace_configuration {
description: "If set then the all the configuration items will be replaced with the provided ones. Otherwise only the provided configuration items will be updated or added"
type: boolean
}
}
}
response {
type: object
properties {
updated {
description: "Indicates if the task was updated successfully"
type: integer
}
}
}
}
}
delete_configuration {
"2.9": {
description: "Delete task configuration items"
request {
type: object
required: [ task, configuration ]
properties {
task {
description: "Task ID"
type: string
}
configuration {
description: "List of configuration itemss to delete"
type: array
items { type: string }
}
}
}
response {
type: object
properties {
deleted {
description: "Indicates if the task was updated successfully"
type: integer
}
}
}
}
}

View File

@@ -135,6 +135,10 @@
description: "Task currently being run by the worker"
"$ref": "#/definitions/current_task_entry"
}
project {
description: "Project in which currently executing task resides"
"$ref": "#/definitions/id_name_entry"
}
queue {
description: "Queue from which running task was taken"
"$ref": "#/definitions/queue_entry"
@@ -144,6 +148,11 @@
type: array
items { "$ref": "#/definitions/queue_entry" }
}
tags {
description: "User tags for the worker"
type: array
items: { type: string }
}
}
}
@@ -151,11 +160,11 @@
type: object
properties {
id {
description: "Worker ID"
description: "ID"
type: string
}
name {
description: "Worker name"
description: "Name"
type: string
}
}
@@ -301,6 +310,11 @@
type: array
items { type: string }
}
tags {
description: "User tags for the worker"
type: array
items: { type: string }
}
}
}
response {
@@ -363,6 +377,11 @@
description: "The machine statistics."
"$ref": "#/definitions/machine_stats"
}
tags {
description: "New user tags for the worker"
type: array
items: { type: string }
}
}
}
response {

View File

@@ -1,20 +1,28 @@
import atexit
from argparse import ArgumentParser
from hashlib import md5
from flask import Flask, request, Response
from flask_compress import Compress
from flask_cors import CORS
from semantic_version import Version
from werkzeug.exceptions import BadRequest
import database
from apierrors.base import BaseError
from bll.statistics.stats_reporter import StatisticsReporter
from config import config
from elastic.initialize import init_es_data
from mongo.initialize import init_mongo_data
from config import config, info
from elastic.initialize import init_es_data, check_elastic_empty, ElasticConnectionError
from mongo.initialize import (
init_mongo_data,
pre_populate_data,
check_mongo_empty,
get_last_server_version,
)
from service_repo import ServiceRepo, APICall
from service_repo.auth import AuthType
from service_repo.errors import PathParsingError
from sync import distributed_lock
from timing_context import TimingContext
from updates import check_updates_thread
from utilities import json
@@ -33,8 +41,47 @@ app.config["JSONIFY_PRETTYPRINT_REGULAR"] = config.get("apiserver.pretty_json")
database.initialize()
init_es_data()
init_mongo_data()
# build a key that uniquely identifies specific mongo instance
hosts_string = ";".join(sorted(database.get_hosts()))
key = "db_init_" + md5(hosts_string.encode()).hexdigest()
with distributed_lock(key, timeout=config.get("apiserver.db_init_timout", 120)):
upgrade_monitoring = config.get(
"apiserver.elastic.upgrade_monitoring.v16_migration_verification", True
)
try:
empty_es = check_elastic_empty()
except ElasticConnectionError as err:
if not upgrade_monitoring:
raise
log.error(err)
info.es_connection_error = True
empty_db = check_mongo_empty()
if (
upgrade_monitoring
and not empty_db
and (info.es_connection_error or empty_es)
and get_last_server_version() < Version("0.16.0")
):
log.info(f"ES database seems not migrated")
info.missed_es_upgrade = True
if info.es_connection_error and not info.missed_es_upgrade:
raise Exception(
"Error starting server: failed connecting to ElasticSearch service"
)
if not info.missed_es_upgrade:
init_es_data()
init_mongo_data()
if (
not info.missed_es_upgrade
and empty_db
and config.get("apiserver.pre_populate.enabled", False)
):
pre_populate_data()
ServiceRepo.load("services")
log.info(f"Exposed Services: {' '.join(ServiceRepo.endpoint_names())}")

View File

@@ -8,7 +8,7 @@ from .endpoint import EndpointFunc, Endpoint
from .service_repo import ServiceRepo
__all__ = ["endpoint"]
__all__ = ["APICall", "endpoint"]
LegacyEndpointFunc = Callable[[APICall], None]

View File

@@ -69,6 +69,10 @@ def authorize_credentials(auth_data, service, action, call_data_items):
if fixed_user:
if secret_key != fixed_user.password:
raise errors.unauthorized.InvalidCredentials('bad username or password')
if fixed_user.is_guest and not FixedUser.is_guest_endpoint(service, action):
raise errors.unauthorized.InvalidCredentials('endpoint not allowed for guest')
query = Q(id=fixed_user.user_id)
with TimingContext("mongo", "user_by_cred"), translate_errors_context('authorizing request'):

View File

@@ -1,14 +1,12 @@
import hashlib
from functools import lru_cache
from typing import Sequence, TypeVar
from typing import Sequence, Optional
import attr
from config import config
from config.info import get_default_company
T = TypeVar("T", bound="FixedUser")
class FixedUsersError(Exception):
pass
@@ -21,6 +19,8 @@ class FixedUser:
name: str
company: str = get_default_company()
is_guest: bool = False
def __attrs_post_init__(self):
self.user_id = hashlib.md5(f"{self.company}:{self.username}".encode()).hexdigest()
@@ -28,6 +28,10 @@ class FixedUser:
def enabled(cls):
return config.get("apiserver.auth.fixed_users.enabled", False)
@classmethod
def guest_enabled(cls):
return cls.enabled() and config.get("services.auth.fixed_users.guest.enabled", False)
@classmethod
def validate(cls):
if not cls.enabled():
@@ -39,18 +43,50 @@ class FixedUser:
)
@classmethod
@lru_cache()
def from_config(cls) -> Sequence[T]:
return [
# @lru_cache()
def from_config(cls) -> Sequence["FixedUser"]:
users = [
cls(**user) for user in config.get("apiserver.auth.fixed_users.users", [])
]
if cls.guest_enabled():
users.insert(
0,
cls.get_guest_user()
)
return users
@classmethod
@lru_cache()
def get_by_username(cls, username) -> T:
def get_by_username(cls, username) -> "FixedUser":
return next(
(user for user in cls.from_config() if user.username == username), None
)
@classmethod
@lru_cache()
def is_guest_endpoint(cls, service, action):
"""
Validate a potential guest user,
This method will verify the user is indeed the guest user,
and that the guest user may access the service/action using its username/password
"""
return any(
ep == ".".join((service, action))
for ep in config.get("services.auth.fixed_users.guest.allow_endpoints", [])
)
@classmethod
def get_guest_user(cls) -> Optional["FixedUser"]:
if cls.guest_enabled():
return cls(
is_guest=True,
username=config.get("services.auth.fixed_users.guest.username"),
password=config.get("services.auth.fixed_users.guest.password"),
name=config.get("services.auth.fixed_users.guest.name"),
company=config.get("services.auth.fixed_users.guest.default_company"),
)
def __hash__(self):
return hash(self.user_id)

View File

@@ -16,7 +16,7 @@ from apimodels.auth import (
)
from apimodels.base import UpdateResponse
from bll.auth import AuthBLL
from config import config
from config import config, info
from database.errors import translate_errors_context
from database.model.auth import User
from service_repo import APICall, endpoint
@@ -176,4 +176,24 @@ def update(call, company_id, _):
@endpoint("auth.fixed_users_mode")
def fixed_users_mode(call: APICall, *_, **__):
call.result.data = dict(enabled=FixedUser.enabled())
server_errors = {
name: error
for name, error in zip(
("missed_es_upgrade", "es_connection_error"),
(info.missed_es_upgrade, info.es_connection_error),
)
if error
}
data = {
"enabled": FixedUser.enabled(),
"guest": {"enabled": FixedUser.guest_enabled()},
"server_errors": server_errors,
}
guest_user = FixedUser.get_guest_user()
if guest_user:
data["guest"]["name"] = guest_user.name
data["guest"]["username"] = guest_user.username
data["guest"]["password"] = guest_user.password
call.result.data = data

6
server/services/debug.py Normal file
View File

@@ -0,0 +1,6 @@
from service_repo import APICall, endpoint
@endpoint("debug.ping")
def ping(call: APICall, _, __):
call.result.data = {"msg": "Because it trains cats and dogs"}

View File

@@ -12,6 +12,7 @@ from apimodels.events import (
IterationEvents,
TaskMetricsRequest,
LogEventsRequest,
LogOrderEnum,
)
from bll.event import EventBLL
from bll.event.event_metrics import EventMetrics
@@ -24,7 +25,7 @@ event_bll = EventBLL()
@endpoint("events.add")
def add(call: APICall, company_id, req_model):
def add(call: APICall, company_id, _):
data = call.data.copy()
allow_locked = data.pop("allow_locked", False)
added, err_count, err_info = event_bll.add_events(
@@ -35,7 +36,7 @@ def add(call: APICall, company_id, req_model):
@endpoint("events.add_batch")
def add_batch(call: APICall, company_id, req_model):
def add_batch(call: APICall, company_id, _):
events = call.batched_data
if events is None or len(events) == 0:
raise errors.bad_request.BatchContainsNoItems()
@@ -46,14 +47,16 @@ def add_batch(call: APICall, company_id, req_model):
@endpoint("events.get_task_log", required_fields=["task"])
def get_task_log_v1_5(call, company_id, req_model):
def get_task_log_v1_5(call, company_id, _):
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id, allow_public=True)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
)[0]
order = call.data.get("order") or "desc"
scroll_id = call.data.get("scroll_id")
batch_size = int(call.data.get("batch_size") or 500)
events, scroll_id, total_events = event_bll.scroll_task_events(
company_id,
task.company,
task_id,
order,
event_type="log",
@@ -66,9 +69,11 @@ def get_task_log_v1_5(call, company_id, req_model):
@endpoint("events.get_task_log", min_version="1.7", required_fields=["task"])
def get_task_log_v1_7(call, company_id, req_model):
def get_task_log_v1_7(call, company_id, _):
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id, allow_public=True)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
)[0]
order = call.data.get("order") or "desc"
from_ = call.data.get("from") or "head"
@@ -78,7 +83,7 @@ def get_task_log_v1_7(call, company_id, req_model):
scroll_order = "asc" if (from_ == "head") else "desc"
events, scroll_id, total_events = event_bll.scroll_task_events(
company_id=company_id,
company_id=task.company,
task_id=task_id,
order=scroll_order,
event_type="log",
@@ -94,33 +99,40 @@ def get_task_log_v1_7(call, company_id, req_model):
)
# uncomment this once the front end is ready
# @endpoint("events.get_task_log", min_version="2.7", request_data_model=LogEventsRequest)
# def get_task_log(call, company_id, req_model: LogEventsRequest):
# task_id = req_model.task
# task_bll.assert_exists(company_id, task_id, allow_public=True)
#
# res = event_bll.log_events_iterator.get_task_events(
# company_id=company_id,
# task_id=task_id,
# batch_size=req_model.batch_size,
# navigate_earlier=req_model.navigate_earlier,
# refresh=req_model.refresh,
# state_id=req_model.scroll_id,
# )
#
# call.result.data = dict(
# events=res.events,
# returned=len(res.events),
# total=res.total_events,
# scroll_id=res.next_scroll_id,
# )
@endpoint("events.get_task_log", min_version="2.9", request_data_model=LogEventsRequest)
def get_task_log(call, company_id, request: LogEventsRequest):
task_id = request.task
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
)[0]
res = event_bll.log_events_iterator.get_task_events(
company_id=task.company,
task_id=task_id,
batch_size=request.batch_size,
navigate_earlier=request.navigate_earlier,
from_timestamp=request.from_timestamp,
)
if (
request.order and (
(request.navigate_earlier and request.order == LogOrderEnum.asc)
or (not request.navigate_earlier and request.order == LogOrderEnum.desc)
)
):
res.events.reverse()
call.result.data = dict(
events=res.events, returned=len(res.events), total=res.total_events
)
@endpoint("events.download_task_log", required_fields=["task"])
def download_task_log(call, company_id, req_model):
def download_task_log(call, company_id, _):
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id, allow_public=True)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
)[0]
line_type = call.data.get("line_type", "json").lower()
line_format = str(call.data.get("line_format", "{asctime} {worker} {level} {msg}"))
@@ -163,7 +175,7 @@ def download_task_log(call, company_id, req_model):
batch_size = 1000
while True:
log_events, scroll_id, _ = event_bll.scroll_task_events(
company_id,
task.company,
task_id,
order="asc",
event_type="log",
@@ -173,7 +185,7 @@ def download_task_log(call, company_id, req_model):
if not log_events:
break
for ev in log_events:
ev["asctime"] = ev.pop("@timestamp")
ev["asctime"] = ev.pop("timestamp")
if is_json:
ev.pop("type")
ev.pop("task")
@@ -196,23 +208,27 @@ def download_task_log(call, company_id, req_model):
@endpoint("events.get_vector_metrics_and_variants", required_fields=["task"])
def get_vector_metrics_and_variants(call, company_id, req_model):
def get_vector_metrics_and_variants(call, company_id, _):
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id, allow_public=True)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
)[0]
call.result.data = dict(
metrics=event_bll.get_metrics_and_variants(
company_id, task_id, "training_stats_vector"
task.company, task_id, "training_stats_vector"
)
)
@endpoint("events.get_scalar_metrics_and_variants", required_fields=["task"])
def get_scalar_metrics_and_variants(call, company_id, req_model):
def get_scalar_metrics_and_variants(call, company_id, _):
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id, allow_public=True)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
)[0]
call.result.data = dict(
metrics=event_bll.get_metrics_and_variants(
company_id, task_id, "training_stats_scalar"
task.company, task_id, "training_stats_scalar"
)
)
@@ -222,13 +238,15 @@ def get_scalar_metrics_and_variants(call, company_id, req_model):
"events.vector_metrics_iter_histogram",
required_fields=["task", "metric", "variant"],
)
def vector_metrics_iter_histogram(call, company_id, req_model):
def vector_metrics_iter_histogram(call, company_id, _):
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id, allow_public=True)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
)[0]
metric = call.data["metric"]
variant = call.data["variant"]
iterations, vectors = event_bll.get_vector_metrics_per_iter(
company_id, task_id, metric, variant
task.company, task_id, metric, variant
)
call.result.data = dict(
metric=metric, variant=variant, vectors=vectors, iterations=iterations
@@ -243,9 +261,11 @@ def get_task_events(call, company_id, _):
scroll_id = call.data.get("scroll_id")
order = call.data.get("order") or "asc"
task_bll.assert_exists(company_id, task_id, allow_public=True)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
)[0]
result = event_bll.get_task_events(
company_id,
task.company,
task_id,
sort=[{"timestamp": {"order": order}}],
event_type=event_type,
@@ -262,14 +282,16 @@ def get_task_events(call, company_id, _):
@endpoint("events.get_scalar_metric_data", required_fields=["task", "metric"])
def get_scalar_metric_data(call, company_id, req_model):
def get_scalar_metric_data(call, company_id, _):
task_id = call.data["task"]
metric = call.data["metric"]
scroll_id = call.data.get("scroll_id")
task_bll.assert_exists(company_id, task_id, allow_public=True)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
)[0]
result = event_bll.get_task_events(
company_id,
task.company,
task_id,
event_type="training_stats_scalar",
sort=[{"iter": {"order": "desc"}}],
@@ -286,13 +308,15 @@ def get_scalar_metric_data(call, company_id, req_model):
@endpoint("events.get_task_latest_scalar_values", required_fields=["task"])
def get_task_latest_scalar_values(call, company_id, req_model):
def get_task_latest_scalar_values(call, company_id, _):
task_id = call.data["task"]
task = task_bll.assert_exists(company_id, task_id, allow_public=True)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
)[0]
metrics, last_timestamp = event_bll.get_task_latest_scalar_values(
company_id, task_id
task.company, task_id
)
es_index = EventMetrics.get_index_name(company_id, "*")
es_index = EventMetrics.get_index_name(task.company, "*")
last_iters = event_bll.get_last_iters(es_index, task_id, None, 1)
call.result.data = dict(
metrics=metrics,
@@ -309,11 +333,13 @@ def get_task_latest_scalar_values(call, company_id, req_model):
request_data_model=ScalarMetricsIterHistogramRequest,
)
def scalar_metrics_iter_histogram(
call, company_id, req_model: ScalarMetricsIterHistogramRequest
call, company_id, request: ScalarMetricsIterHistogramRequest
):
task_bll.assert_exists(call.identity.company, req_model.task, allow_public=True)
task = task_bll.assert_exists(
company_id, request.task, allow_public=True, only=("company",)
)[0]
metrics = event_bll.metrics.get_scalar_metrics_average_per_iter(
company_id, task_id=req_model.task, samples=req_model.samples, key=req_model.key
task.company, task_id=request.task, samples=request.samples, key=request.key
)
call.result.data = metrics
@@ -341,21 +367,27 @@ def multi_task_scalar_metrics_iter_histogram(
@endpoint("events.get_multi_task_plots", required_fields=["tasks"])
def get_multi_task_plots_v1_7(call, company_id, req_model):
def get_multi_task_plots_v1_7(call, company_id, _):
task_ids = call.data["tasks"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
tasks = task_bll.assert_exists(
company_id=call.identity.company,
only=("id", "name"),
company_id=company_id,
only=("id", "name", "company"),
task_ids=task_ids,
allow_public=True,
)
companies = {t.company for t in tasks}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
)
# Get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
result = event_bll.get_task_events(
company_id,
next(iter(companies)),
task_ids,
event_type="plot",
sort=[{"iter": {"order": "desc"}}],
@@ -385,13 +417,19 @@ def get_multi_task_plots(call, company_id, req_model):
tasks = task_bll.assert_exists(
company_id=call.identity.company,
only=("id", "name"),
only=("id", "name", "company"),
task_ids=task_ids,
allow_public=True,
)
companies = {t.company for t in tasks}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
)
result = event_bll.get_task_events(
company_id,
next(iter(companies)),
task_ids,
event_type="plot",
sort=[{"iter": {"order": "desc"}}],
@@ -414,12 +452,14 @@ def get_multi_task_plots(call, company_id, req_model):
@endpoint("events.get_task_plots", required_fields=["task"])
def get_task_plots_v1_7(call, company_id, req_model):
def get_task_plots_v1_7(call, company_id, _):
task_id = call.data["task"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
)[0]
# events, next_scroll_id, total_events = event_bll.get_task_events(
# company, task_id,
# event_type="plot",
@@ -429,7 +469,7 @@ def get_task_plots_v1_7(call, company_id, req_model):
# get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
result = event_bll.get_task_events(
company_id,
task.company,
task_id,
event_type="plot",
sort=[{"iter": {"order": "desc"}}],
@@ -448,14 +488,16 @@ def get_task_plots_v1_7(call, company_id, req_model):
@endpoint("events.get_task_plots", min_version="1.8", required_fields=["task"])
def get_task_plots(call, company_id, req_model):
def get_task_plots(call, company_id, _):
task_id = call.data["task"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
)[0]
result = event_bll.get_task_plots(
company_id,
task.company,
tasks=[task_id],
sort=[{"iter": {"order": "desc"}}],
last_iterations_per_plot=iters,
@@ -473,12 +515,14 @@ def get_task_plots(call, company_id, req_model):
@endpoint("events.debug_images", required_fields=["task"])
def get_debug_images_v1_7(call, company_id, req_model):
def get_debug_images_v1_7(call, company_id, _):
task_id = call.data["task"]
iters = call.data.get("iters") or 1
scroll_id = call.data.get("scroll_id")
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
)[0]
# events, next_scroll_id, total_events = event_bll.get_task_events(
# company, task_id,
# event_type="training_debug_image",
@@ -488,7 +532,7 @@ def get_debug_images_v1_7(call, company_id, req_model):
# get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
result = event_bll.get_task_events(
company_id,
task.company,
task_id,
event_type="training_debug_image",
sort=[{"iter": {"order": "desc"}}],
@@ -508,14 +552,16 @@ def get_debug_images_v1_7(call, company_id, req_model):
@endpoint("events.debug_images", min_version="1.8", required_fields=["task"])
def get_debug_images_v1_8(call, company_id, req_model):
def get_debug_images_v1_8(call, company_id, _):
task_id = call.data["task"]
iters = call.data.get("iters") or 1
scroll_id = call.data.get("scroll_id")
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",)
)[0]
result = event_bll.get_task_events(
company_id,
task.company,
task_id,
event_type="training_debug_image",
sort=[{"iter": {"order": "desc"}}],
@@ -540,16 +586,25 @@ def get_debug_images_v1_8(call, company_id, req_model):
request_data_model=DebugImagesRequest,
response_data_model=DebugImageResponse,
)
def get_debug_images(call, company_id, req_model: DebugImagesRequest):
tasks = set(m.task for m in req_model.metrics)
task_bll.assert_exists(call.identity.company, task_ids=tasks, allow_public=True)
def get_debug_images(call, company_id, request: DebugImagesRequest):
task_ids = {m.task for m in request.metrics}
tasks = task_bll.assert_exists(
company_id, task_ids=task_ids, allow_public=True, only=("company",)
)
companies = {t.company for t in tasks}
if len(companies) > 1:
raise errors.bad_request.InvalidTaskId(
"only tasks from the same company are supported"
)
result = event_bll.debug_images_iterator.get_task_events(
company_id=company_id,
metrics=[(m.task, m.metric) for m in req_model.metrics],
iter_count=req_model.iters,
navigate_earlier=req_model.navigate_earlier,
refresh=req_model.refresh,
state_id=req_model.scroll_id,
company_id=next(iter(companies)),
metrics=[(m.task, m.metric) for m in request.metrics],
iter_count=request.iters,
navigate_earlier=request.navigate_earlier,
refresh=request.refresh,
state_id=request.scroll_id,
)
call.result.data_model = DebugImageResponse(
@@ -569,12 +624,12 @@ def get_debug_images(call, company_id, req_model: DebugImagesRequest):
@endpoint("events.get_task_metrics", request_data_model=TaskMetricsRequest)
def get_tasks_metrics(call: APICall, company_id, req_model: TaskMetricsRequest):
task_bll.assert_exists(
call.identity.company, task_ids=req_model.tasks, allow_public=True
)
def get_tasks_metrics(call: APICall, company_id, request: TaskMetricsRequest):
task = task_bll.assert_exists(
company_id, task_ids=request.tasks, allow_public=True, only=("company",)
)[0]
res = event_bll.metrics.get_tasks_metrics(
company_id, task_ids=req_model.tasks, event_type=req_model.event_type
task.company, task_ids=request.tasks, event_type=request.event_type
)
call.result.data = {
"metrics": [{"task": task, "metrics": metrics} for (task, metrics) in res]
@@ -586,7 +641,7 @@ def delete_for_task(call, company_id, req_model):
task_id = call.data["task"]
allow_locked = call.data.get("allow_locked", False)
task_bll.assert_exists(company_id, task_id)
task_bll.assert_exists(company_id, task_id, return_tasks=False)
call.result.data = dict(
deleted=event_bll.delete_task_events(
company_id, task_id, allow_locked=allow_locked

View File

@@ -5,14 +5,17 @@ from mongoengine import Q, EmbeddedDocument
import database
from apierrors import errors
from apimodels.base import UpdateResponse
from apierrors.errors.bad_request import InvalidModelId
from apimodels.base import UpdateResponse, MakePublicRequest
from apimodels.models import (
CreateModelRequest,
CreateModelResponse,
PublishModelRequest,
PublishModelResponse,
ModelTaskPublishResponse,
GetFrameworksRequest,
)
from bll.model import ModelBLL
from bll.organization import OrgBLL, Tags
from bll.task import TaskBLL
from config import config
@@ -32,6 +35,7 @@ from timing_context import TimingContext
log = config.logger(__file__)
org_bll = OrgBLL()
model_bll = ModelBLL()
@endpoint("models.get_by_id", required_fields=["model"])
@@ -107,6 +111,15 @@ def get_all(call: APICall, company_id, _):
call.result.data = {"models": models}
@endpoint("models.get_frameworks", request_data_model=GetFrameworksRequest)
def get_frameworks(call: APICall, company_id, request: GetFrameworksRequest):
call.result.data = {
"frameworks": sorted(
model_bll.get_frameworks(company_id, project_ids=request.projects)
)
}
create_fields = {
"name": None,
"tags": list,
@@ -455,3 +468,21 @@ def update(call: APICall, company_id, _):
if del_count:
_reset_cached_tags(company_id, projects=[model.project])
call.result.data = dict(deleted=del_count > 0)
@endpoint("models.make_public", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest):
with translate_errors_context():
call.result.data = Model.set_public(
company_id, ids=request.ids, invalid_cls=InvalidModelId, enabled=True
)
@endpoint(
"models.make_private", min_version="2.9", request_data_model=MakePublicRequest
)
def make_public(call: APICall, company_id, request: MakePublicRequest):
with translate_errors_context():
call.result.data = Model.set_public(
company_id, request.ids, invalid_cls=InvalidModelId, enabled=False
)

View File

@@ -8,10 +8,10 @@ from mongoengine import Q
import database
from apierrors import errors
from apimodels.base import UpdateResponse
from apierrors.errors.bad_request import InvalidProjectId
from apimodels.base import UpdateResponse, MakePublicRequest
from apimodels.projects import (
GetHyperParamReq,
GetHyperParamResp,
ProjectReq,
ProjectTagsRequest,
)
@@ -185,6 +185,7 @@ def make_projects_get_all_pipelines(company_id, project_ids, specific_state=None
def get_all_ex(call: APICall):
include_stats = call.data.get("include_stats")
stats_for_state = call.data.get("stats_for_state", EntityVisibility.active.value)
allow_public = not call.data.get("non_public", False)
if stats_for_state:
try:
@@ -200,7 +201,7 @@ def get_all_ex(call: APICall):
company=call.identity.company,
query_dict=call.data,
query_options=get_all_query_options,
allow_public=True,
allow_public=allow_public,
)
conform_output_tags(call, projects)
@@ -375,13 +376,12 @@ def get_unique_metric_variants(call: APICall, company_id: str, request: ProjectR
@endpoint(
"projects.get_hyper_parameters",
min_version="2.2",
min_version="2.9",
request_data_model=GetHyperParamReq,
response_data_model=GetHyperParamResp,
)
def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamReq):
total, remaining, parameters = TaskBLL.get_aggregated_project_execution_parameters(
total, remaining, parameters = TaskBLL.get_aggregated_project_parameters(
company_id,
project_ids=[request.project] if request.project else None,
page=request.page,
@@ -421,3 +421,23 @@ def get_tags(call: APICall, company, request: ProjectTagsRequest):
projects=request.projects,
)
call.result.data = get_tags_response(ret)
@endpoint(
"projects.make_public", min_version="2.9", request_data_model=MakePublicRequest
)
def make_public(call: APICall, company_id, request: MakePublicRequest):
with translate_errors_context():
call.result.data = Project.set_public(
company_id, ids=request.ids, invalid_cls=InvalidProjectId, enabled=True
)
@endpoint(
"projects.make_private", min_version="2.9", request_data_model=MakePublicRequest
)
def make_public(call: APICall, company_id, request: MakePublicRequest):
with translate_errors_context():
call.result.data = Project.set_public(
company_id, ids=request.ids, invalid_cls=InvalidProjectId, enabled=False
)

View File

@@ -11,7 +11,8 @@ from mongoengine.queryset.transform import COMPARISON_OPERATORS
from pymongo import UpdateOne
from apierrors import errors, APIError
from apimodels.base import UpdateResponse, IdResponse
from apierrors.errors.bad_request import InvalidTaskId
from apimodels.base import UpdateResponse, IdResponse, MakePublicRequest
from apimodels.tasks import (
StartedResponse,
ResetResponse,
@@ -31,6 +32,13 @@ from apimodels.tasks import (
AddOrUpdateArtifactsResponse,
GetTypesRequest,
ResetRequest,
GetHyperParamsRequest,
EditHyperParamsRequest,
DeleteHyperParamsRequest,
GetConfigurationsRequest,
EditConfigurationRequest,
DeleteConfigurationRequest,
GetConfigurationNamesRequest,
)
from bll.event import EventBLL
from bll.organization import OrgBLL, Tags
@@ -40,9 +48,14 @@ from bll.task import (
ChangeStatusRequest,
update_project_time,
split_by,
ParameterKeyEscaper,
)
from bll.task.hyperparams import HyperParams
from bll.task.non_responsive_tasks_watchdog import NonResponsiveTasksWatchdog
from bll.task.param_utils import (
params_prepare_for_save,
params_unprepare_from_saved,
escape_paths,
)
from bll.util import SetFieldsResolver
from database.errors import translate_errors_context
from database.model.model import Model
@@ -56,9 +69,9 @@ from database.model.task.task import (
)
from database.utils import get_fields, parse_from_call
from service_repo import APICall, endpoint
from services.utils import conform_tag_fields, conform_output_tags
from service_repo.base import PartialVersion
from services.utils import conform_tag_fields, conform_output_tags, validate_tags
from timing_context import TimingContext
from utilities import safe_get
task_fields = set(Task.get_fields())
task_script_fields = set(get_fields(Script))
@@ -78,10 +91,24 @@ def set_task_status_from_call(
task = TaskBLL.get_task_with_access(
request.task,
company_id=company_id,
only=tuple({"status", "project"} | fields_resolver.get_names()),
only=tuple(
{"status", "project", "started", "duration"} | fields_resolver.get_names()
),
requires_write_access=True,
)
if "duration" not in fields_resolver.get_names():
if new_status == Task.started:
fields_resolver.add_fields(min__duration=max(0, task.duration or 0))
elif new_status in (
TaskStatus.completed,
TaskStatus.failed,
TaskStatus.stopped,
):
fields_resolver.add_fields(
duration=int((task.started - datetime.utcnow()).total_seconds())
)
status_reason = request.status_reason
status_message = request.status_message
force = request.force
@@ -105,30 +132,13 @@ def get_by_id(call: APICall, company_id, req_model: TaskRequest):
def escape_execution_parameters(call: APICall):
default_prefix = "execution.parameters."
def escape_paths(paths, prefix=default_prefix):
escaped_paths = []
for path in paths:
if path == prefix:
raise errors.bad_request.ValidationError(
"invalid task field", path=path
)
escaped_paths.append(
prefix + ParameterKeyEscaper.escape(path[len(prefix) :])
if path.startswith(prefix)
else path
)
return escaped_paths
projection = Task.get_projection(call.data)
if projection:
Task.set_projection(call.data, escape_paths(projection))
ordering = Task.get_ordering(call.data)
if ordering:
ordering = Task.set_ordering(call.data, escape_paths(ordering, default_prefix))
Task.set_ordering(call.data, escape_paths(ordering, "-" + default_prefix))
Task.set_ordering(call.data, escape_paths(ordering))
@endpoint("tasks.get_all_ex", required_fields=[])
@@ -260,12 +270,15 @@ create_fields = {
"input": None,
"output_dest": None,
"execution": None,
"hyperparams": None,
"configuration": None,
"script": None,
}
def prepare_for_save(call: APICall, fields: dict):
def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
conform_tag_fields(call, fields, validate=True)
params_prepare_for_save(fields, previous_task=previous_task)
# Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
for field in task_script_fields:
@@ -278,12 +291,6 @@ def prepare_for_save(call: APICall, fields: dict):
except KeyError:
pass
parameters = safe_get(fields, "execution/parameters")
if parameters is not None:
# Escape keys to make them mongo-safe
parameters = {ParameterKeyEscaper.escape(k): v for k, v in parameters.items()}
dpath.set(fields, "execution/parameters", parameters)
return fields
@@ -293,18 +300,15 @@ def unprepare_from_saved(call: APICall, tasks_data: Union[Sequence[dict], dict])
conform_output_tags(call, tasks_data)
for task_data in tasks_data:
parameters = safe_get(task_data, "execution/parameters")
if parameters is not None:
# Escape keys to make them mongo-safe
parameters = {
ParameterKeyEscaper.unescape(k): v for k, v in parameters.items()
}
dpath.set(task_data, "execution/parameters", parameters)
for data in tasks_data:
params_unprepare_from_saved(
fields=data,
copy_to_legacy=call.requested_endpoint_version < PartialVersion("2.9"),
)
def prepare_create_fields(
call: APICall, valid_fields=None, output=None, previous_task: Task = None
call: APICall, valid_fields=None, output=None, previous_task: Task = None,
):
valid_fields = valid_fields if valid_fields is not None else create_fields
t_fields = task_fields
@@ -322,7 +326,7 @@ def prepare_create_fields(
output = Output(destination=output_dest)
fields["output"] = output
return prepare_for_save(call, fields)
return prepare_for_save(call, fields, previous_task=previous_task)
def _validate_and_get_task_from_call(call: APICall, **kwargs) -> Tuple[Task, dict]:
@@ -354,9 +358,7 @@ def _update_cached_tags(company: str, project: str, fields: dict):
def _reset_cached_tags(company: str, projects: Sequence[str]):
org_bll.reset_tags(
company, Tags.Task, projects=projects
)
org_bll.reset_tags(company, Tags.Task, projects=projects)
@endpoint(
@@ -377,6 +379,7 @@ def create(call: APICall, company_id, req_model: CreateRequest):
"tasks.clone", request_data_model=CloneRequest, response_data_model=IdResponse
)
def clone_task(call: APICall, company_id, request: CloneRequest):
validate_tags(request.new_task_tags, request.new_task_system_tags)
task = task_bll.clone_task(
company_id=company_id,
user_id=call.identity.user,
@@ -387,6 +390,8 @@ def clone_task(call: APICall, company_id, request: CloneRequest):
project=request.new_task_project,
tags=request.new_task_tags,
system_tags=request.new_task_system_tags,
hyperparams=request.new_hyperparams,
configuration=request.new_configuration,
execution_overrides=request.execution_overrides,
validate_references=request.validate_references,
)
@@ -572,9 +577,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
if updated:
new_project = fixed_fields.get("project", task.project)
if new_project != task.project:
_reset_cached_tags(
company_id, projects=[new_project, task.project]
)
_reset_cached_tags(company_id, projects=[new_project, task.project])
else:
_update_cached_tags(
company_id, project=task.project, fields=fixed_fields
@@ -586,6 +589,100 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
call.result.data_model = UpdateResponse(updated=0)
@endpoint(
"tasks.get_hyper_params", request_data_model=GetHyperParamsRequest,
)
def get_hyper_params(call: APICall, company_id, request: GetHyperParamsRequest):
with translate_errors_context():
tasks_params = HyperParams.get_params(company_id, task_ids=request.tasks)
call.result.data = {
"params": [{"task": task, **data} for task, data in tasks_params.items()]
}
@endpoint("tasks.edit_hyper_params", request_data_model=EditHyperParamsRequest)
def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest):
with translate_errors_context():
call.result.data = {
"updated": HyperParams.edit_params(
company_id,
task_id=request.task,
hyperparams=request.hyperparams,
replace_hyperparams=request.replace_hyperparams,
)
}
@endpoint("tasks.delete_hyper_params", request_data_model=DeleteHyperParamsRequest)
def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsRequest):
with translate_errors_context():
call.result.data = {
"deleted": HyperParams.delete_params(
company_id, task_id=request.task, hyperparams=request.hyperparams
)
}
@endpoint(
"tasks.get_configurations", request_data_model=GetConfigurationsRequest,
)
def get_configurations(call: APICall, company_id, request: GetConfigurationsRequest):
with translate_errors_context():
tasks_params = HyperParams.get_configurations(
company_id, task_ids=request.tasks, names=request.names
)
call.result.data = {
"configurations": [
{"task": task, **data} for task, data in tasks_params.items()
]
}
@endpoint(
"tasks.get_configuration_names", request_data_model=GetConfigurationNamesRequest,
)
def get_configuration_names(
call: APICall, company_id, request: GetConfigurationNamesRequest
):
with translate_errors_context():
tasks_params = HyperParams.get_configuration_names(
company_id, task_ids=request.tasks
)
call.result.data = {
"configurations": [
{"task": task, **data} for task, data in tasks_params.items()
]
}
@endpoint("tasks.edit_configuration", request_data_model=EditConfigurationRequest)
def edit_configuration(call: APICall, company_id, request: EditConfigurationRequest):
with translate_errors_context():
call.result.data = {
"updated": HyperParams.edit_configuration(
company_id,
task_id=request.task,
configuration=request.configuration,
replace_configuration=request.replace_configuration,
)
}
@endpoint("tasks.delete_configuration", request_data_model=DeleteConfigurationRequest)
def delete_configuration(
call: APICall, company_id, request: DeleteConfigurationRequest
):
with translate_errors_context():
call.result.data = {
"deleted": HyperParams.delete_configuration(
company_id, task_id=request.task, configuration=request.configuration
)
}
@endpoint(
"tasks.enqueue",
request_data_model=EnqueueRequest,
@@ -1004,3 +1101,19 @@ def add_or_update_artifacts(
task_id=request.task, company_id=company_id, artifacts=request.artifacts
)
call.result.data_model = AddOrUpdateArtifactsResponse(added=added, updated=updated)
@endpoint("tasks.make_public", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest):
with translate_errors_context():
call.result.data = Task.set_public(
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=True
)
@endpoint("tasks.make_private", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest):
with translate_errors_context():
call.result.data = Task.set_public(
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False
)

View File

@@ -46,10 +46,10 @@ def get_all(call: APICall, company_id: str, request: GetAllRequest):
@endpoint("workers.register", min_version="2.4", request_data_model=RegisterRequest)
def register(call: APICall, company_id, req_model: RegisterRequest):
worker = req_model.worker
timeout = req_model.timeout
queues = req_model.queues
def register(call: APICall, company_id, request: RegisterRequest):
worker = request.worker
timeout = request.timeout
queues = request.queues
if not timeout or timeout <= 0:
raise bad_request.WorkerRegistrationFailed(
@@ -63,6 +63,7 @@ def register(call: APICall, company_id, req_model: RegisterRequest):
ip=call.real_ip,
queues=queues,
timeout=timeout,
tags=request.tags,
)
@@ -78,6 +79,7 @@ def status_report(call: APICall, company_id, request: StatusReportRequest):
user_id=call.identity.user,
ip=call.real_ip,
report=request,
tags=request.tags,
)

28
server/sync.py Normal file
View File

@@ -0,0 +1,28 @@
import time
from contextlib import contextmanager
from time import sleep
from redis_manager import redman
_redis = redman.connection("apiserver")
@contextmanager
def distributed_lock(name: str, timeout: int, max_wait: int = 0):
"""
Context manager that acquires a distributed lock on enter
and releases it on exit. The has a ttl equal to timeout seconds
If the lock can not be acquired for wait seconds (defaults to timeout * 2)
then the exception is thrown
"""
lock_name = f"dist_lock_{name}"
start = time.time()
max_wait = max_wait or timeout * 2
while not _redis.set(lock_name, value="", ex=timeout, nx=True):
sleep(1)
if time.time() - start > max_wait:
raise Exception(f"Could not acquire {name} lock for {max_wait} seconds")
try:
yield
finally:
_redis.delete(lock_name)

View File

@@ -1,3 +1,4 @@
from apierrors.errors.bad_request import InvalidModelId
from tests.automated import TestService
MODEL_CANNOT_BE_UPDATED_CODES = (400, 203)
@@ -7,6 +8,9 @@ IN_PROGRESS = "in_progress"
class TestModelsService(TestService):
def setUp(self, version="2.9"):
super().setUp(version=version)
def test_publish_output_model_running_task(self):
task_id, model_id = self._create_task_and_model()
self._assert_model_ready(model_id, False)
@@ -164,6 +168,58 @@ class TestModelsService(TestService):
1000
)
def test_get_frameworks(self):
framework_1 = "Test framework 1"
framework_2 = "Test framework 2"
# create model on top level
self._create_model(name="framework model test", framework=framework_1)
# create model under a project as make it inherit its framework from the task
project = self.create_temp("projects", name="Frameworks test", description="")
task = self._create_task(project=project, execution=dict(framework=framework_2))
self.api.models.update_for_task(
task=task,
name="framework output model test",
uri="file:///b",
iteration=999,
)
# get all frameworks
res = self.api.models.get_frameworks()
self.assertTrue({framework_1, framework_2}.issubset(set(res.frameworks)))
# get frameworks under the project
res = self.api.models.get_frameworks(projects=[project])
self.assertEqual([framework_2], res.frameworks)
# empty result
self.api.tasks.delete(task=task, force=True)
res = self.api.models.get_frameworks(projects=[project])
self.assertEqual([], res.frameworks)
def test_make_public(self):
m1 = self._create_model(name="public model test")
# model with company_origin not set to the current company cannot be converted to private
with self.api.raises(InvalidModelId):
self.api.models.make_private(ids=[m1])
# public model can be retrieved but not updated
res = self.api.models.make_public(ids=[m1])
self.assertEqual(res.updated, 1)
res = self.api.models.get_all(id=[m1])
self.assertEqual([m.id for m in res.models], [m1])
with self.api.raises(InvalidModelId):
self.api.models.update(model=m1, name="public model test change 1")
# task made private again and can be both retrieved and updated
res = self.api.models.make_private(ids=[m1])
self.assertEqual(res.updated, 1)
res = self.api.models.get_all(id=[m1])
self.assertEqual([m.id for m in res.models], [m1])
self.api.models.update(model=m1, name="public model test change 2")
def _assert_task_status(self, task_id, status):
task = self.api.tasks.get_by_id(task=task_id).task
assert task.status == status
@@ -178,24 +234,23 @@ class TestModelsService(TestService):
def _assert_update_task_failure(self):
return self.api.raises(TASK_CANNOT_BE_UPDATED_CODES)
def _create_model(self):
model_id = self.create_temp(
def _create_model(self, **kwargs):
return self.create_temp(
service="models",
name='test',
uri='file:///a',
labels={}
delete_params=dict(can_fail=True, force=True),
name=kwargs.pop("name", 'test'),
uri=kwargs.pop("name", 'file:///a'),
labels=kwargs.pop("labels", {}),
**kwargs,
)
self.defer(self.api.models.delete, can_fail=True, model=model_id, force=True)
return model_id
def _create_task(self):
def _create_task(self, **kwargs):
task_id = self.create_temp(
service="tasks",
type='testing',
name='server-test',
input=dict(view={}),
type=kwargs.pop("type", 'testing'),
name=kwargs.pop("name", 'server-test'),
input=kwargs.pop("input", dict(view={})),
**kwargs,
)
return task_id

View File

@@ -6,14 +6,86 @@ log = config.logger(__file__)
class TestProjection(TestService):
def setUp(self, **kwargs):
super().setUp(version="2.6")
def _temp_task(self, **kwargs):
self.update_missing(
kwargs,
type="testing",
name="test projection",
input=dict(view=dict()),
delete_params=dict(force=True),
)
return self.create_temp("tasks", **kwargs)
def _temp_project(self):
return self.create_temp(
"projects",
name="Test projection",
description="test",
delete_params=dict(force=True),
)
def test_overlapping_fields(self):
message = "task started"
task_id = self.create_temp(
"tasks", name="test", type="testing", input=dict(view=dict())
)
task_id = self._temp_task()
self.api.tasks.started(task=task_id, status_message=message)
task = self.api.tasks.get_all_ex(
id=[task_id], only_fields=["status", "status_message"]
).tasks[0]
assert task["status"] == TaskStatus.in_progress
assert task["status_message"] == message
def test_task_projection(self):
project = self._temp_project()
task1 = self._temp_task(project=project)
task2 = self._temp_task(project=project)
self.api.tasks.started(task=task2, status_message="Started")
res = self.api.tasks.get_all_ex(
project=[project],
only_fields=[
"system_tags",
"company",
"type",
"name",
"tags",
"status",
"project.name",
"user.name",
"started",
"last_update",
"last_iteration",
"comment",
],
order_by=["-started"],
page=0,
page_size=15,
system_tags=["-archived"],
type=[
"__$not",
"annotation_manual",
"__$not",
"annotation",
"__$not",
"dataset_import",
],
).tasks
self.assertEqual([task2, task1], [t.id for t in res])
self.assertEqual("Test projection", res[0].project.name)
def test_exclude_projection(self):
task_id = self._temp_task()
res = self.api.tasks.get_all_ex(
id=[task_id]
).tasks[0]
self.assertEqual("test projection", res.name)
task = self.api.tasks.get_all_ex(
id=[task_id],
only_fields=["-name"]
).tasks[0]
self.assertFalse("name" in task)
self.assertEqual("testing", res.type)

View File

@@ -0,0 +1,34 @@
from apierrors.errors.bad_request import InvalidProjectId
from apierrors.errors.forbidden import NoWritePermission
from config import config
from tests.automated import TestService
log = config.logger(__file__)
class TestProjectsEdit(TestService):
def setUp(self, **kwargs):
super().setUp(version="2.9")
def test_make_public(self):
p1 = self.create_temp("projects", name="Test public", description="test")
# project with company_origin not set to the current company cannot be converted to private
with self.api.raises(InvalidProjectId):
self.api.projects.make_private(ids=[p1])
# public project can be retrieved but not updated
res = self.api.projects.make_public(ids=[p1])
self.assertEqual(res.updated, 1)
res = self.api.projects.get_all(id=[p1])
self.assertEqual([p.id for p in res.projects], [p1])
with self.api.raises(NoWritePermission):
self.api.projects.update(project=p1, name="Test public change 1")
# task made private again and can be both retrieved and updated
res = self.api.projects.make_private(ids=[p1])
self.assertEqual(res.updated, 1)
res = self.api.projects.get_all(id=[p1])
self.assertEqual([p.id for p in res.projects], [p1])
self.api.projects.update(project=p1, name="Test public change 2")

View File

@@ -6,7 +6,7 @@ import operator
import unittest
from functools import partial
from statistics import mean
from typing import Sequence
from typing import Sequence, Optional, Tuple
from boltons.iterutils import first
@@ -16,7 +16,7 @@ from tests.automated import TestService
class TestTaskEvents(TestService):
def setUp(self, version="2.7"):
def setUp(self, version="2.9"):
super().setUp(version=version)
def _temp_task(self, name="test task events"):
@@ -213,7 +213,6 @@ class TestTaskEvents(TestService):
self.assertEqual(len(res.events), 1)
def test_task_logs(self):
# this test will fail until the new api is uncommented
task = self._temp_task()
timestamp = es_factory.get_timestamp_millis()
events = [
@@ -229,32 +228,29 @@ class TestTaskEvents(TestService):
self.send_batch(events)
# test forward navigation
scroll_id = None
for page in range(3):
scroll_id = self._assert_log_events(
task=task, scroll_id=scroll_id, expected_page=page
ftime, ltime = None, None
for page in range(2):
ftime, ltime = self._assert_log_events(
task=task, timestamp=ltime, expected_page=page
)
# test backwards navigation
scroll_id = self._assert_log_events(
task=task, scroll_id=scroll_id, navigate_earlier=False
)
self._assert_log_events(task=task, timestamp=ftime, navigate_earlier=False)
# refresh
self._assert_log_events(task=task, scroll_id=scroll_id)
self._assert_log_events(task=task, scroll_id=scroll_id, refresh=True)
# test order
self._assert_log_events(task=task, order="asc")
def _assert_log_events(
self,
task,
scroll_id,
batch_size: int = 5,
timestamp: Optional[int] = None,
expected_total: int = 10,
expected_page: int = 0,
**extra_params,
):
) -> Tuple[int, int]:
res = self.api.events.get_task_log(
task=task, batch_size=batch_size, scroll_id=scroll_id, **extra_params,
task=task, batch_size=batch_size, from_timestamp=timestamp, **extra_params,
)
self.assertEqual(res.total, expected_total)
expected_events = max(
@@ -266,7 +262,10 @@ class TestTaskEvents(TestService):
self.assertEqual(len(res.events), unique_events)
if res.events:
cmp_operator = operator.ge
if not extra_params.get("navigate_earlier", True):
if (
not extra_params.get("navigate_earlier", True)
or extra_params.get("order", None) == "asc"
):
cmp_operator = operator.le
self.assertTrue(
all(
@@ -274,7 +273,12 @@ class TestTaskEvents(TestService):
for first, second in zip(res.events, res.events[1:])
)
)
return res.scroll_id
return (
(res.events[0].timestamp, res.events[-1].timestamp)
if res.events
else (None, None)
)
def test_task_metric_value_intervals_keys(self):
metric = "Metric1"

View File

@@ -0,0 +1,281 @@
from operator import itemgetter
from typing import Sequence, List, Tuple
from boltons import iterutils
from apierrors.errors.bad_request import InvalidTaskStatus
from tests.api_client import APIClient
from tests.automated import TestService
class TestTasksHyperparams(TestService):
def setUp(self, **kwargs):
super().setUp(version="2.9")
def new_task(self, **kwargs) -> Tuple[str, str]:
if "project" not in kwargs:
kwargs["project"] = self.create_temp(
"projects",
name="Test hyperparams",
description="test",
delete_params=dict(force=True),
)
self.update_missing(
kwargs,
type="testing",
name="test hyperparams",
input=dict(view=dict()),
delete_params=dict(force=True),
)
return self.create_temp("tasks", **kwargs), kwargs["project"]
def test_hyperparams(self):
legacy_params = {"legacy$1": "val1", "legacy2/name": "val2"}
new_params = [
dict(section="1/1", name="param1/1", type="type1", value="10"),
dict(section="1/1", name="param2", type="type1", value="20"),
dict(section="2", name="param2", type="type2", value="xxx"),
]
new_params_dict = self._param_dict_from_list(new_params)
task, project = self.new_task(
execution={"parameters": legacy_params}, hyperparams=new_params_dict,
)
# both params and hyper params are set correctly
old_params = self._new_params_from_legacy(legacy_params)
params_dict = new_params_dict.copy()
params_dict["Args"] = {p["name"]: p for p in old_params}
res = self.api.tasks.get_by_id(task=task).task
self.assertEqual(params_dict, res.hyperparams)
# returned as one list with params in the _legacy section
res = self.api.tasks.get_hyper_params(tasks=[task]).params[0]
self.assertEqual(new_params + old_params, res.hyperparams)
# replace section
replace_params = [
dict(section="1/1", name="param1", type="type1", value="40"),
dict(section="2", name="param5", type="type1", value="11"),
]
self.api.tasks.edit_hyper_params(
task=task, hyperparams=replace_params, replace_hyperparams="section"
)
res = self.api.tasks.get_hyper_params(tasks=[task]).params[0]
self.assertEqual(replace_params + old_params, res.hyperparams)
# replace all
replace_params = [
dict(section="1/1", name="param1/1", type="type1", value="30"),
dict(section="Args", name="legacy$1", value="123", type="legacy"),
]
self.api.tasks.edit_hyper_params(
task=task, hyperparams=replace_params, replace_hyperparams="all"
)
res = self.api.tasks.get_hyper_params(tasks=[task]).params[0]
self.assertEqual(replace_params, res.hyperparams)
# add and update
self.api.tasks.edit_hyper_params(task=task, hyperparams=new_params + old_params)
res = self.api.tasks.get_hyper_params(tasks=[task]).params[0]
self.assertEqual(new_params + old_params, res.hyperparams)
# delete
new_to_delete = self._get_param_keys(new_params[1:])
old_to_delete = self._get_param_keys(old_params[:1])
self.api.tasks.delete_hyper_params(
task=task, hyperparams=new_to_delete + old_to_delete
)
res = self.api.tasks.get_hyper_params(tasks=[task]).params[0]
self.assertEqual(new_params[:1] + old_params[1:], res.hyperparams)
# delete section
self.api.tasks.delete_hyper_params(
task=task, hyperparams=[{"section": "1/1"}, {"section": "2"}]
)
res = self.api.tasks.get_hyper_params(tasks=[task]).params[0]
self.assertEqual(old_params[1:], res.hyperparams)
# project hyperparams
res = self.api.projects.get_hyper_parameters(project=project)
self.assertEqual(
[
{k: v for k, v in p.items() if k in ("section", "name")}
for p in old_params[1:]
],
res.parameters,
)
# clone task
new_task = self.api.tasks.clone(task=task, new_hyperparams=new_params_dict).id
try:
res = self.api.tasks.get_hyper_params(tasks=[new_task]).params[0]
self.assertEqual(new_params, res.hyperparams)
finally:
self.api.tasks.delete(task=new_task, force=True)
# editing of started task
self.api.tasks.started(task=task)
with self.api.raises(InvalidTaskStatus):
self.api.tasks.edit_hyper_params(
task=task, hyperparams=[dict(section="test", name="x", value="123")]
)
self.api.tasks.edit_hyper_params(
task=task, hyperparams=[dict(section="properties", name="x", value="123")]
)
self.api.tasks.delete_hyper_params(
task=task, hyperparams=[dict(section="Properties")]
)
@staticmethod
def _get_param_keys(params: Sequence[dict]) -> List[dict]:
return [{k: p[k] for k in ("name", "section")} for p in params]
@staticmethod
def _new_params_from_legacy(legacy: dict) -> List[dict]:
return [
dict(section="Args", name=k, value=str(v), type="legacy")
if not k.startswith("TF_DEFINE/")
else dict(section="TF_DEFINE", name=k[len("TF_DEFINE/"):], value=str(v), type="legacy")
for k, v in legacy.items()
]
@staticmethod
def _param_dict_from_list(params: Sequence[dict]) -> dict:
return {
k: {v["name"]: v for v in values}
for k, values in iterutils.bucketize(
params, key=itemgetter("section")
).items()
}
@staticmethod
def _config_dict_from_list(config: Sequence[dict]) -> dict:
return {c["name"]: c for c in config}
def test_configuration(self):
legacy_config = {"design": "hello"}
new_config = [
dict(name="param$1", type="type1", value="10"),
dict(name="param/2", type="type1", value="20"),
]
new_config_dict = self._config_dict_from_list(new_config)
task, _ = self.new_task(
execution={"model_desc": legacy_config}, configuration=new_config_dict
)
# both params and hyper params are set correctly
old_config = self._new_config_from_legacy(legacy_config)
config_dict = new_config_dict.copy()
config_dict["design"] = old_config[0]
res = self.api.tasks.get_by_id(task=task).task
self.assertEqual(config_dict, res.configuration)
# returned as one list
res = self.api.tasks.get_configurations(tasks=[task]).configurations[0]
self.assertEqual(old_config + new_config, res.configuration)
# names
res = self.api.tasks.get_configuration_names(tasks=[task]).configurations[0]
self.assertEqual(task, res.task)
self.assertEqual(["design", "param$1", "param/2"], res.names)
# returned as one list with names filtering
res = self.api.tasks.get_configurations(
tasks=[task], names=[new_config[1]["name"]]
).configurations[0]
self.assertEqual([new_config[1]], res.configuration)
# replace all
replace_configs = [
dict(name="design", value="123", type="legacy"),
dict(name="param/2", type="type1", value="30"),
]
self.api.tasks.edit_configuration(
task=task, configuration=replace_configs, replace_configuration=True
)
res = self.api.tasks.get_configurations(tasks=[task]).configurations[0]
self.assertEqual(replace_configs, res.configuration)
# add and update
self.api.tasks.edit_configuration(
task=task, configuration=new_config + old_config
)
res = self.api.tasks.get_configurations(tasks=[task]).configurations[0]
self.assertEqual(old_config + new_config, res.configuration)
# delete
new_to_delete = self._get_config_keys(new_config[1:])
res = self.api.tasks.delete_configuration(
task=task, configuration=new_to_delete
)
res = self.api.tasks.get_configurations(tasks=[task]).configurations[0]
self.assertEqual(old_config + new_config[:1], res.configuration)
# clone task
new_task = self.api.tasks.clone(task=task, new_configuration=new_config_dict).id
try:
res = self.api.tasks.get_configurations(tasks=[new_task]).configurations[0]
self.assertEqual(new_config, res.configuration)
finally:
self.api.tasks.delete(task=new_task, force=True)
@staticmethod
def _get_config_keys(config: Sequence[dict]) -> List[dict]:
return [c["name"] for c in config]
@staticmethod
def _new_config_from_legacy(legacy: dict) -> List[dict]:
return [dict(name=k, value=str(v), type="legacy") for k, v in legacy.items()]
def test_hyperparams_projection(self):
legacy_param = {"legacy.1": "val1"}
new_params1 = [
dict(section="sec.tion1", name="param1", type="type1", value="10")
]
new_params_dict1 = self._param_dict_from_list(new_params1)
task1, project = self.new_task(
execution={"parameters": legacy_param}, hyperparams=new_params_dict1,
)
new_params2 = [
dict(section="sec.tion1", name="param1", type="type1", value="20")
]
new_params_dict2 = self._param_dict_from_list(new_params2)
task2, _ = self.new_task(hyperparams=new_params_dict2, project=project)
old_params = self._new_params_from_legacy(legacy_param)
params_dict = new_params_dict1.copy()
params_dict["Args"] = {p["name"]: p for p in old_params}
res = self.api.tasks.get_all_ex(id=[task1], only_fields=["hyperparams"]).tasks[
0
]
self.assertEqual(params_dict, res.hyperparams)
res = self.api.tasks.get_all_ex(
project=[project],
only_fields=["hyperparams.sec%2Etion1"],
order_by=["-hyperparams.sec%2Etion1"],
).tasks[0]
self.assertEqual(new_params_dict2, res.hyperparams)
def test_old_api(self):
legacy_params = {"legacy.1": "val1", "TF_DEFINE/param2": "val2"}
legacy_config = {"design": "hello"}
task_id, _ = self.new_task(
execution={"parameters": legacy_params, "model_desc": legacy_config}
)
config = self._config_dict_from_list(self._new_config_from_legacy(legacy_config))
params = self._param_dict_from_list(self._new_params_from_legacy(legacy_params))
old_api = APIClient(base_url="http://localhost:8008/v2.8")
task = old_api.tasks.get_all_ex(id=[task_id]).tasks[0]
self.assertEqual(legacy_params, task.execution.parameters)
self.assertEqual(legacy_config, task.execution.model_desc)
self.assertEqual(params, task.hyperparams)
self.assertEqual(config, task.configuration)
modified_params = {"legacy.2": "val2"}
modified_config = {"design": "by"}
old_api.tasks.edit(task=task_id, execution=dict(parameters=modified_params, model_desc=modified_config))
task = old_api.tasks.get_all_ex(id=[task_id]).tasks[0]
self.assertEqual(modified_params, task.execution.parameters)
self.assertEqual(modified_config, task.execution.model_desc)

View File

@@ -5,7 +5,6 @@ log = config.logger(__file__)
class TestTasksDiff(TestService):
def setUp(self, version="2.0"):
super(TestTasksDiff, self).setUp(version=version)
@@ -17,7 +16,14 @@ class TestTasksDiff(TestService):
def _compare_script(self, task_id, script):
task = self.api.tasks.get_by_id(task=task_id).task
if not script:
self.assertFalse(task.get("script", None))
self.assertTrue(
task.get(
"script",
dict(
binary="python", repository="", entry_point="", requirements={}
),
)
)
else:
for key, value in script.items():
self.assertEqual(task.script[key], value)

View File

@@ -1,4 +1,5 @@
from apierrors.errors.bad_request import InvalidModelId, ValidationError
from apierrors.errors.bad_request import InvalidModelId, ValidationError, InvalidTaskId
from apierrors.errors.forbidden import NoWritePermission
from config import config
from tests.automated import TestService
@@ -8,7 +9,7 @@ log = config.logger(__file__)
class TestTasksEdit(TestService):
def setUp(self, **kwargs):
super().setUp(version=2.5)
super().setUp(version="2.9")
def new_task(self, **kwargs):
self.update_missing(
@@ -113,7 +114,7 @@ class TestTasksEdit(TestService):
self.assertEqual(new_task.status, "created")
self.assertEqual(new_task.script, script)
self.assertEqual(new_task.parent, task)
self.assertEqual(new_task.execution.parameters, execution["parameters"])
# self.assertEqual(new_task.execution.parameters, execution["parameters"])
self.assertEqual(new_task.execution.framework, execution_overrides["framework"])
self.assertEqual(new_task.system_tags, [])
@@ -145,3 +146,28 @@ class TestTasksEdit(TestService):
self.api.tasks.delete, task=new_task, move_to_trash=False, force=True
)
return new_task
def test_make_public(self):
task = self.new_task()
# task is created as private and can be updated
self.api.tasks.started(task=task)
# task with company_origin not set to the current company cannot be converted to private
with self.api.raises(InvalidTaskId):
self.api.tasks.make_private(ids=[task])
# public task can be retrieved but not updated
res = self.api.tasks.make_public(ids=[task])
self.assertEqual(res.updated, 1)
res = self.api.tasks.get_all_ex(id=[task])
self.assertEqual([t.id for t in res.tasks], [task])
with self.api.raises(NoWritePermission):
self.api.tasks.stopped(task=task)
# task made private again and can be both retrieved and updated
res = self.api.tasks.make_private(ids=[task])
self.assertEqual(res.updated, 1)
res = self.api.tasks.get_all_ex(id=[task])
self.assertEqual([t.id for t in res.tasks], [task])
self.api.tasks.stopped(task=task)

View File

@@ -1,7 +1,6 @@
from typing import Sequence
from uuid import uuid4
from apierrors import errors
from config import config
from tests.automated import TestService

View File

@@ -1,12 +1,2 @@
import dpath
def strict_map(*args, **kwargs):
return list(map(*args, **kwargs))
def safe_get(obj, glob, default=None):
try:
return dpath.get(obj, glob)
except KeyError:
return default

View File

@@ -0,0 +1,46 @@
from boltons.dictutils import OneToOne
from apierrors import errors
class ParameterKeyEscaper:
"""
Makes the fields name ready for use with MongoDB and Mongoengine
. and $ are replaced with their codes
__ and leading _ are escaped
Since % is used as an escape character the % is also escaped
"""
_mapping = OneToOne({".": "%2E", "$": "%24", "__": "%_%_"})
@classmethod
def escape(cls, value):
""" Quote a parameter key """
if value is None:
raise errors.bad_request.ValidationError("Key cannot be empty")
value = value.strip().replace("%", "%%")
for c, r in cls._mapping.items():
value = value.replace(c, r)
if value.startswith("_"):
value = "%_" + value[1:]
return value
@classmethod
def _unescape(cls, value):
for c, r in cls._mapping.inv.items():
value = value.replace(c, r)
return value
@classmethod
def unescape(cls, value):
""" Unquote a quoted parameter key """
value = "%".join(map(cls._unescape, value.split("%%")))
if value.startswith("%_"):
value = "_" + value[2:]
return value

View File

@@ -1 +1 @@
__version__ = "0.15.1"
__version__ = "0.16.1"