Compare commits

39 Commits

Author SHA1 Message Date
allegroai
96ffc89c64 Update AMI's to new version 2019-09-24 21:46:07 +03:00
allegroai
4f2564d33a Add Artifacts support, changed tags to system_tags and added user tags
Add hyper parameter sorting
Add min/max value for all time series metrics
2019-09-24 21:35:41 +03:00
allegroai
70ae090cc0 Documentation 2019-09-19 23:04:45 +03:00
Allegro AI
4f01778961 Merge pull request #26 from jayanthkoushik/patch-1
Bind conf for apiserver in docker-compose.yml
2019-09-12 21:53:17 +03:00
Jayanth Koushik
596bdd06ec Bind conf for apiserver in docker-compose.yml 2019-09-12 14:20:30 -04:00
allegroai
6c56d0fc33 Documentation 2019-09-02 01:00:45 +03:00
allegroai
5f0213d2de Update AWS images 2019-08-22 13:48:07 +03:00
Allegro AI
15eb00a931 Update LICENSE 2019-08-21 00:19:59 +03:00
allegroai
becc4fb6a2 Documentation 2019-08-15 23:56:19 +03:00
allegroai
32476a216a Documentation 2019-08-14 04:01:41 +03:00
allegroai
a9ba1580dc Documentation 2019-08-09 03:44:17 +03:00
allegroai
cfcd0b22a0 Documentation 2019-08-09 03:40:28 +03:00
allegroai
780355250c Documentation 2019-08-09 03:33:45 +03:00
allegroai
fd65ad38bc Documentation 2019-08-09 03:24:47 +03:00
allegroai
e29973a0b2 Typo 2019-08-09 00:30:16 +03:00
allegroai
c259d0883e Documentation 2019-08-08 12:02:30 +03:00
allegroai
9eab017a31 Documentation 2019-08-08 12:01:11 +03:00
allegroai
68c7f307a2 Documentation 2019-08-08 11:58:19 +03:00
allegroai
0aa5694b58 Documentation 2019-08-08 02:22:36 +03:00
allegroai
639d72c5d6 Documentation 2019-08-08 02:08:39 +03:00
allegroai
70708ecdcc Documentation 2019-08-08 02:01:59 +03:00
allegroai
dacdd5e965 Documentation 2019-08-08 02:00:15 +03:00
allegroai
c199976f70 Improved docker-compose installation process 2019-08-08 01:51:40 +03:00
allegroai
c3e2bc5ad7 Add FAQ 2019-08-01 19:36:58 +03:00
allegroai
f0c900c174 Documentation 2019-07-29 23:47:52 +03:00
allegroai
1bdbc44720 Fix, always restart trains-server container 2019-07-25 02:29:39 +03:00
allegroai
c6e765bd07 renamed 2019-07-25 02:26:07 +03:00
allegroai
c037ddd044 Add unified docker compose (all three trains-server services running on the same docker). Used for easier installation, such as on OS X. 2019-07-25 02:15:34 +03:00
allegroai
ffe4764f20 Add automatically updating AMIs 2019-07-22 11:48:42 +03:00
allegroai
1681fd6bf4 Fix AMI image ids 2019-07-21 19:18:59 +03:00
allegroai
e55ce5536a Add fixed users mode documentation 2019-07-17 18:46:12 +03:00
allegroai
b714952ab1 Add v0.10.1 pre-built AMI 2019-07-17 18:18:56 +03:00
allegroai
07fd8b9f2f Changed, web serving through NGINX 2019-07-17 18:18:33 +03:00
allegroai
d24f633a8e Add easier sub-domains configuration 2019-07-17 18:17:27 +03:00
allegroai
bed714890d Add File server CORS support 2019-07-17 18:16:43 +03:00
allegroai
02671910b2 Add support for fix user list credentials 2019-07-17 18:16:27 +03:00
allegroai
1a00f29415 Add support for fix user list credentials 2019-07-17 18:15:58 +03:00
allegroai
b7614622fc Changed webserver is deprecated, Web UI served through NGINX 2019-07-17 18:15:19 +03:00
allegroai
bc2cbe9a91 Documentation 2019-07-12 01:05:07 +03:00
95 changed files with 4173 additions and 2989 deletions

View File

@@ -1,7 +1,7 @@
Server Side Public License
VERSION 1, OCTOBER 16, 2018
Copyright © 2018 MongoDB, Inc.
Copyright © 2019 allegro.ai, Inc.
Everyone is permitted to copy and distribute verbatim copies of this
license document, but changing it is not allowed.

385
README.md
View File

@@ -22,170 +22,143 @@ In order to host your own server, you will need to install **trains-server** and
* Querying experiments history, logs and results
* Locally-hosted file server for storing images and models making them easily accessible using the Web-App
You can quickly setup your **trains-server** using a pre-built Docker image (see [Installation](#installation)).
You can quickly setup your **trains-server** using:
- [Docker Installation](#installation)
- Pre-built Amazon [AWS image](#aws)
- [Kubernetes Helm](https://github.com/allegroai/trains-server-helm#trains-server-for-kubernetes-clusters-using-helm)
or manual [Kubernetes installation](https://github.com/allegroai/trains-server-k8s#trains-server-for-kubernetes-clusters)
When new releases are available, you can upgrade your pre-built Docker image (see [Upgrade](#upgrade)).
## System diagram
## System design
![Alt Text](https://github.com/allegroai/trains/blob/master/docs/system_diagram.png?raw=true)
**trains-server** has two supported configurations:
- Single IP (domain) with the following open ports
- Web application on port 8080
- API service on port 8008
- File storage service on port 8081
- Sub-Domain configuration with default http/s ports (80 or 443)
- Web application on sub-domain: app.\*.\*
- API service on sub-domain: api.\*.\*
- File storage service on sub-domain: files.\*.\*
## Install / Upgrade - AWS <a name="aws"></a>
## Install / Upgrade - AWS
Use one of our pre-installed Amazon Machine Images for easy deployment in AWS.
Use our pre-installed Amazon Machine Image for easy deployment in AWS.
For details and instructions, see [TRAINS-server: AWS pre-installed images](docs/install_aws.md).
Details and instructions can be found [here](docs/install_aws.md).
## Docker Installation - Linux, Mac OS X <a name="installation"></a>
## Installation - Docker
Use our pre-built Docker image for easy deployment in Linux and Mac OS X.
For Windows, we recommend installing our pre-built Docker image on a Linux virtual machine.
Latest docker images can be found [here](https://hub.docker.com/r/allegroai/trains).
This section contains the instructions to setup and launch a pre-built Docker image for the **trains-server**.
This is the quickest way to get started with your own server.
Alternatively, you can build the entire trains-server architecture using the code available in our repositories.
**Please Note**:
* This Docker image was tested with Linux, only. For Windows users, we recommend running the server
on a Linux virtual machine.
* All command-line instructions below assume you're using `bash`.
### Prerequisites
Make sure you are logged in as a user with sudo privileges.
### Setup
#### Step 1: Install Docker CE
In order to run the pre-packaged **trains-server**, install Docker.
* See [Supported platforms](https://docs.docker.com/install//#support) in the Docker documentation for instructions
* For example, to install in [Ubuntu](https://docs.docker.com/install/linux/docker-ce/ubuntu/) / Mint (x86_64/amd64):
1. Setup Docker ([docker-compose Ubuntu](docs/faq.md#ubuntu), [docker-compose OS X](docs/faq.md#mac-osx), [Setup Docker Service Manually](docs/docker_setup.md#setup-docker))
Make sure port 8080/8081/8008 are available for the `trains-server` services
Increase vm.max_map_count for `ElasticSearch` docker
```bash
sudo apt-get install -y apt-transport-https ca-certificates curl software-properties-common
curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add -
. /etc/os-release
sudo add-apt-repository "deb [arch=amd64] https://download.docker.com/linux/ubuntu $UBUNTU_CODENAME stable"
sudo apt-get update
sudo apt-get install -y docker-ce
```
#### Step 2: Setup the Docker daemon
To run the ElasticSearch Docker container, setup the Docker daemon by modifying the default
values required by Elastic in your Docker configuration file (see [Notes for production use and defaults](https://www.elastic.co/guide/en/elasticsearch/reference/master/docker.html#_notes_for_production_use_and_defaults)). We provide instructions for the most common Docker configuration files.
Edit or create the Docker configuration file:
* If your system contains a `/etc/sysconfig/docker` Docker configuration file, edit it.
Add the options in quotes to the available arguments in the `OPTIONS` section:
```bash
OPTIONS="--default-ulimit nofile=1024:65536 --default-ulimit memlock=-1:-1"
```
* Otherwise, edit `/etc/docker/daemon.json` (if it exists) or create it (if it does not exist).
Add or modify the `defaults-ulimits` section as shown below. Be sure the `defaults-ulimits` section contains the `nofile` and `memlock` sub-sections and values shown.
**Note**: Your configuration file may contain other sections. If so, confirm that the sections are separated by commas (valid JSON format). For more information about Docker configuration files, see [Daemon configuration file](https://docs.docker.com/engine/reference/commandline/dockerd/#daemon-configuration-file) in the Docker documentation.
The **trains-server** required defaults values are:
```json
{
"default-ulimits": {
"nofile": {
"name": "nofile",
"hard": 65536,
"soft": 1024
},
"memlock":
{
"name": "memlock",
"soft": -1,
"hard": -1
}
}
}
```
#### Step 3: Restart the Docker daemon
After modifying the configuration file, restart the Docker daemon:
```bash
sudo service docker stop
sudo service docker start
```
#### Step 4: Set the Maximum Number of Memory Map Areas
The maximum number of memory map areas a process can use is defined
using the `vm.max_map_count` kernel setting.
Elastic requires that `vm.max_map_count` is at least 262144 (see [Production mode](https://www.elastic.co/guide/en/elasticsearch/reference/master/docker.html#docker-cli-run-prod-mode)).
* For CentOS 7, Ubuntu 16.04, Mint 18.3, Ubuntu 18.04 and Mint 19 users, we tested the following commands to set
`vm.max_map_count`:
```bash
sudo echo "vm.max_map_count=262144" > /tmp/99-trains.conf
echo "vm.max_map_count=262144" > /tmp/99-trains.conf
sudo mv /tmp/99-trains.conf /etc/sysctl.d/99-trains.conf
sudo sysctl -w vm.max_map_count=262144
sudo service docker restart
```
1. Create local directories for the databases and storage.
```bash
sudo mkdir -p /opt/trains/data/elastic
sudo mkdir -p /opt/trains/data/mongo/db
sudo mkdir -p /opt/trains/data/mongo/configdb
sudo mkdir -p /opt/trains/logs
sudo mkdir -p /opt/trains/data/fileserver
```
Linux
```bash
$ sudo chown -R 1000:1000 /opt/trains
```
Mac OS X
```bash
$ sudo chown -R $(whoami):staff /opt/trains
```
1. Clone the [trains-server](https://github.com/allegroai/trains-server) repository and change directories to the new **trains-server** directory.
```bash
$ git clone https://github.com/allegroai/trains-server.git
$ cd trains-server
```
1. Launch the Docker containers <a name="launch-docker"></a>
* For information about setting this parameter on other systems, see the [elastic](https://www.elastic.co/guide/en/elasticsearch/reference/current/docker.html#docker-cli-run-prod-mode) documentation.
* Automatically with docker-compose (details: [Linux/Ubuntu](docs/faq.md#ubuntu), [OS X](docs/faq.md#mac-osx))
```bash
$ docker-compose up
```
* Manually, see [Launching Docker Containers Manually](docs/docker_setup.md#launch) for instructions.
1. Your server is now running on [http://localhost:8080](http://localhost:8080) and the following ports are available:
* Web server on port `8080`
* API server on port `8008`
* File server on port `8081`
#### Step 5: Choose a Data Directory
Choose a directory on your system in which all data maintained by the **trains-server** is stored.
Create this directory, and set its owner and group to `uid` 1000. The data stored in this directory will include the database, uploaded files and logs.
For example, if your data directory is `/opt/trains`, then use the following command:
```bash
sudo mkdir -p /opt/trains/data/elastic && sudo chown -R 1000:1000 /opt/trains
```
### Configuration
## Optional Configuration
The **trains-server** default configuration can be easily overridden using external configuration files. By default, the server will look for these files in `/opt/trains/config`.
If the configuration is changed while the server is running, the server should be restarted for changes to take effect.
In order to apply the new configuration, you must restart the server (see [Restarting trains-server](#restart-server)).
<!---
#### Fixed users mode (basic users management)
### Adding Web Login Authentication
In this mode, the server authenticates users based on a pre-configured users list.
By default anyone can login to the **trains-server** Web-App.
You can configure the **trains-server** to allow only a specific set of users to access the system.
Enable this feature by placing an `apiserver.conf` file under `/opt/trains/config`, containing for example:
Enable this feature by placing `apiserver.conf` file under `/opt/trains/config`.
fixed_users {
enabled: true
users: [
{
username: "jane"
password: "123456"
name: "Jane Doe"
},
{
username: "john"
password: "abcdef"
name: "John Doe"
}
]
Sample fixed user configuration file `/opt/trains/config/apiserver.conf`:
auth {
# Fixed users login credetials
# No other user will be able to login
fixed_users {
enabled: true
users: [
{
username: "jane"
password: "12345678"
name: "Jane Doe"
},
{
username: "john"
password: "12345678"
name: "John Doe"
},
]
}
}
-->
#### Non-responsive experiments watchdog
To apply the `apiserver.conf` changes, you must restart the *trains-apiserver* (docker) (see [Restarting trains-server](#restart-server)).
This watchdog monitors experiments that were not updated for a given period of time, and marks them as `stopped`. The watchdog is always active.
### Configuring the Non-Responsive Experiments Watchdog
To change the watchdog's timeouts, place a `services.conf` file under `/opt/trains/config`, containing for example:
The non-responsive experiment watchdog, monitors experiments that were not updated for a given period of time,
and marks them as `aborted`. The watchdog is always active with a default of 7200 seconds (2 hours) of inactivity threshold.
To change the watchdog's timeouts, place a `services.conf` file under `/opt/trains/config`.
Sample watchdog configuration file `/opt/trains/config/services.conf`:
tasks {
non_responsive_tasks_watchdog {
@@ -197,55 +170,43 @@ To change the watchdog's timeouts, place a `services.conf` file under `/opt/trai
}
}
### Launching Docker Containers
To apply the `services.conf` changes, you must restart the *trains-apiserver* (docker) (see [Restarting trains-server](#restart-server)).
**Note**:
* If your data directory is not `/opt/trains`, please find and replace `/opt/trains` in the following commands with your data directory path
### Restarting trains-server <a name="restart-server"></a>
* Make sure ports `8008`, `8080` and `8081` are not in use before starting the docker containers, as the containers will fail to initialize if these ports are already taken. If the following commands shows no output, the ports are available:
```bash
sudo netstat -tplna | egrep "8008|8080|8081"
```
To restart the **trains-server**, you must first stop and remove the containers, and then restart.
To launch the Docker containers, use the following commands:
1. Restarting docker-compose containers.
```bash
sudo docker run -d --restart="always" --name="trains-elastic" -e "ES_JAVA_OPTS=-Xms2g -Xmx2g" -e "bootstrap.memory_lock=true" -e "cluster.name=trains" -e "discovery.zen.minimum_master_nodes=1" -e "node.name=trains" -e "script.inline=true" -e "script.update=true" -e "thread_pool.bulk.queue_size=2000" -e "thread_pool.search.queue_size=10000" -e "xpack.security.enabled=false" -e "xpack.monitoring.enabled=false" -e "cluster.routing.allocation.node_initial_primaries_recoveries=500" -e "node.ingest=true" -e "http.compression_level=7" -e "reindex.remote.whitelist=*.*" -e "script.painless.regex.enabled=true" --network="host" -v /opt/trains/data/elastic:/usr/share/elasticsearch/data docker.elastic.co/elasticsearch/elasticsearch:5.6.16
```
$ docker-compose down
$ docker-compose up
1. Manually restarting dockers [instructions](docs/docker_setup.md#launch).
```bash
sudo docker run -d --restart="always" --name="trains-mongo" -v /opt/trains/data/mongo/db:/data/db -v /opt/trains/data/mongo/configdb:/data/configdb --network="host" mongo:3.6.5
```
## Configuring **TRAINS** client
```bash
sudo docker run -d --restart="always" --name="trains-fileserver" --network="host" -v /opt/trains/logs:/var/log/trains -v /opt/trains/data/fileserver:/mnt/fileserver allegroai/trains:latest fileserver
```
Once you have installed the **trains-server**, make sure to configure **TRAINS** [client](https://github.com/allegroai/trains)
to use your locally installed server (and not the demo server).
```bash
sudo docker run -d --restart="always" --name="trains-apiserver" --network="host" -v /opt/trains/logs:/var/log/trains -v /opt/trains/config:/opt/trains/config allegroai/trains:latest apiserver
```
- Run the `trains-init` command for an interactive setup
```bash
sudo docker run -d --restart="always" --name="trains-webserver" --network="host" -v /opt/trains/logs:/var/log/trains allegroai/trains:latest webserver
```
- Or manually edit `~/trains.conf` file, making sure the `api_server` value is configured correctly, for example:
After the **trains-server** Dockers are up, the following are available:
api {
# API server on port 8008
api_server: "http://localhost:8008"
# web_server on port 8080
web_server: "http://localhost:8080"
# file server on port 8081
files_server: "http://localhost:8081"
}
* API server on port `8008`
* Web server on port `8080`
* File server on port `8081`
* Notice that if you setup **trains-server** in a sub-domain configuration, there is no need to specify a port number,
it will be inferred from the http/s scheme.
### Configuring **trains**
Once you have installed the **trains-server**, make sure to configure **trains** to use your locally installed server (and not the demo server).
If you have already installed **trains**, run the `trains-init` command for an interactive setup or edit your `trains.conf` file and make sure the `api.host` value is configured as follows:
api {
host: "http://localhost:8008"
}
See [Installing and Configuring TRAINS](https://github.com/allegroai/trains#installing-and-configuring-trains) for more details.
See [Installing and Configuring TRAINS](https://github.com/allegroai/trains#configuration) for more details.
## What next?
@@ -253,7 +214,7 @@ Now that the **trains-server** is installed, and TRAINS is configured to use it,
you can [use](https://github.com/allegroai/trains#using-trains) TRAINS in your experiments and view them in the web server,
for example http://localhost:8080
## Upgrade
## Upgrading <a name="upgrade"></a>
We are constantly updating, improving and adding to the **trains-server**.
New releases will include new pre-built Docker images.
@@ -261,39 +222,63 @@ When we release a new version and include a new pre-built Docker image for it, u
1. Shut down and remove each of your Docker instances using the following commands:
sudo docker stop <docker-name>
sudo docker rm -v <docker-name>
The Docker names are (see [Launching Docker Containers](#launching-docker-containers)):
* `trains-elastic`
* `trains-mongo`
* `trains-fileserver`
* `trains-apiserver`
* `trains-webserver`
2. Pull the new **trains-server** docker image using the following command:
sudo docker pull allegroai/trains:latest
* Using Docker-Compose
If you wish to pull a different version, replace `latest` with the required version number, for example:
```bash
$ docker-compose down
```
sudo docker pull allegroai/trains:0.10.0
3. We highly recommend backing up your data directory!. A simple way to do that is using `tar`:
* Manual Docker launching
```bash
$ sudo docker stop <docker-name>
$ sudo docker rm -v <docker-name>
```
The Docker names are (see [Launching Docker Containers](#launch-docker)):
* `trains-elastic`
* `trains-mongo`
* `trains-fileserver`
* `trains-apiserver`
* `trains-webserver`
2. We highly recommend backing up your data directory!. A simple way to do that is using `tar`:
For example, if your data directory is `/opt/trains`, use the following command:
sudo tar czvf ~/trains_backup.tgz /opt/trains/data
This back ups all data to an archive in your home directory.
```bash
$ sudo tar czvf ~/trains_backup.tgz /opt/trains/data
```
This backups all data to an archive in your home directory.
To restore this example backup, use the following command:
```bash
$ sudo rm -R /opt/trains/data
$ sudo tar -xzf ~/trains_backup.tgz -C /opt/trains/data
```
3. Pull the new **trains-server** docker image using the following command:
sudo rm -R /opt/trains/data
sudo tar -xzf ~/trains_backup.tgz -C /opt/trains/data
```bash
$ sudo docker pull allegroai/trains:latest
```
If you wish to pull a different version, replace `latest` with the required version number, for example:
```bash
$ sudo docker pull allegroai/trains:0.10.1
```
4. Launch the newly released Docker image (see [Launching Docker Containers](#launch-docker)).
4. Launch the newly released Docker image (see [Launching Docker Containers](#launching-docker-containers)).
## Community & Support
If you have any questions, look to the TRAINS-server [FAQ](https://github.com/allegroai/trains-server/blob/master/docs/faq.md), or
tag your questions on [stackoverflow](https://stackoverflow.com/questions/tagged/trains) with '**trains**' tag.
For feature requests or bug reports, please use [GitHub issues](https://github.com/allegroai/trains-server/issues).
Additionally, you can always find us at *trains@allegro.ai*
## License

View File

@@ -0,0 +1,81 @@
version: "3.6"
services:
trainsserver:
command:
- -c
- "echo \"#!/bin/bash\" > /opt/trains/all.sh && echo \"/opt/trains/wrapper.sh webserver&\" >> /opt/trains/all.sh && echo \"/opt/trains/wrapper.sh fileserver&\" >> /opt/trains/all.sh && echo \"/opt/trains/wrapper.sh apiserver\" >> /opt/trains/all.sh && cat /opt/trains/all.sh && chmod +x /opt/trains/all.sh && /opt/trains/all.sh"
entrypoint: /bin/bash
container_name: trains-server
image: allegroai/trains:latest
ports:
- 8008:8008
- 8080:80
- 8081:8081
restart: always
volumes:
- type: bind
source: /opt/trains/logs
target: /var/log/trains
- type: bind
source: /opt/trains/data/fileserver
target: /mnt/fileserver
links:
- mongo:mongo
- elasticsearch:elasticsearch
environment:
ELASTIC_SERVICE_SERVICE_HOST: elasticsearch
MONGODB_SERVICE_SERVICE_HOST: mongo
networks:
- backend
elasticsearch:
networks:
- backend
container_name: trains-elastic
environment:
ES_JAVA_OPTS: -Xms2g -Xmx2g
bootstrap.memory_lock: "true"
cluster.name: trains
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
discovery.zen.minimum_master_nodes: "1"
http.compression_level: "7"
node.ingest: "true"
node.name: trains
reindex.remote.whitelist: '*.*'
script.inline: "true"
script.painless.regex.enabled: "true"
script.update: "true"
thread_pool.bulk.queue_size: "2000"
thread_pool.search.queue_size: "10000"
xpack.monitoring.enabled: "false"
xpack.security.enabled: "false"
ulimits:
memlock:
soft: -1
hard: -1
image: docker.elastic.co/elasticsearch/elasticsearch:5.6.16
restart: always
volumes:
- type: bind
source: /opt/trains/data/elastic
target: /usr/share/elasticsearch/data
ports:
- "9200:9200"
mongo:
networks:
- backend
container_name: trains-mongo
image: mongo:3.6.5
restart: always
volumes:
- type: bind
source: /opt/trains/data/mongo/db
target: /data/db
- type: bind
source: /opt/trains/data/mongo/configdb
target: /data/configdb
ports:
- "27017:27017"
networks:
backend:
driver: bridge

View File

@@ -5,13 +5,28 @@ services:
- apiserver
container_name: trains-apiserver
image: allegroai/trains:latest
network_mode: host
restart: always
volumes:
- type: bind
source: /opt/trains/logs
target: /var/log/trains
- type: bind
source: /opt/trains/config
target: /opt/trains/config
links:
- mongo:mongo
- elasticsearch:elasticsearch
- fileserver:fileserver
environment:
ELASTIC_SERVICE_SERVICE_HOST: elasticsearch
MONGODB_SERVICE_SERVICE_HOST: mongo
ports:
- "8008:8008"
networks:
- backend
elasticsearch:
networks:
- backend
container_name: trains-elastic
environment:
ES_JAVA_OPTS: -Xms2g -Xmx2g
@@ -30,19 +45,25 @@ services:
thread_pool.search.queue_size: "10000"
xpack.monitoring.enabled: "false"
xpack.security.enabled: "false"
ulimits:
memlock:
soft: -1
hard: -1
image: docker.elastic.co/elasticsearch/elasticsearch:5.6.16
network_mode: host
restart: always
volumes:
- type: bind
source: /opt/trains/data/elastic
target: /usr/share/elasticsearch/data
ports:
- "9200:9200"
fileserver:
networks:
- backend
command:
- fileserver
container_name: trains-fileserver
image: allegroai/trains:latest
network_mode: host
restart: always
volumes:
- type: bind
@@ -51,10 +72,13 @@ services:
- type: bind
source: /opt/trains/data/fileserver
target: /mnt/fileserver
ports:
- "8081:8081"
mongo:
networks:
- backend
container_name: trains-mongo
image: mongo:3.6.5
network_mode: host
restart: always
volumes:
- type: bind
@@ -63,14 +87,25 @@ services:
- type: bind
source: /opt/trains/data/mongo/configdb
target: /data/configdb
ports:
- "27017:27017"
webserver:
networks:
- backend
command:
- webserver
container_name: trains-webserver
image: allegroai/trains:latest
network_mode: host
restart: always
volumes:
- type: bind
source: /opt/trains/logs
target: /var/log/trains
links:
- apiserver
ports:
- "8080:80"
networks:
backend:
driver: bridge

100
docs/docker_setup.md Normal file
View File

@@ -0,0 +1,100 @@
# TRAINS-server: Using Docker Pre-Built Images
The pre-built Docker image for the **trains-server** is the quickest way to get started with your own **TRAINS** server.
You can also build the entire **trains-server** architecture using the code available in the [trains-server](https://github.com/allegroai/trains-server) repository.
**Note**: We tested this pre-built Docker image with Linux, only. For Windows users, we recommend installing the pre-built image on a Linux virtual machine.
## Prerequisites
* You must be logged in as a user with sudo privileges
* Use `bash` for all command-line instructions in this installation
## Setup Docker
### Step 1: Install Docker CE
You must first install Docker. For instructions about installing Docker, see [Supported platforms](https://docs.docker.com/install//#support) in the Docker documentation.
For example, to [install in Ubuntu](https://docs.docker.com/install/linux/docker-ce/ubuntu/) / Mint (x86_64/amd64):
```bash
sudo apt-get install -y apt-transport-https ca-certificates curl software-properties-common
curl -fsSL https://download.docker.com/linux/ubuntu/gpg | sudo apt-key add -
. /etc/os-release
sudo add-apt-repository "deb [arch=amd64] https://download.docker.com/linux/ubuntu $UBUNTU_CODENAME stable"
sudo apt-get update
sudo apt-get install -y docker-ce
```
### Step 2: Set the Maximum Number of Memory Map Areas
Elastic requires that the `vm.max_map_count` kernel setting, which is the maximum number of memory map areas a process can use, is set to at least 262144.
For CentOS 7, Ubuntu 16.04, Mint 18.3, Ubuntu 18.04 and Mint 19.x, we tested the following commands to set `vm.max_map_count`:
```bash
echo "vm.max_map_count=262144" > /tmp/99-trains.conf
sudo mv /tmp/99-trains.conf /etc/sysctl.d/99-trains.conf
sudo sysctl -w vm.max_map_count=262144
```
For information about setting this parameter on other systems, see the [elastic](https://www.elastic.co/guide/en/elasticsearch/reference/current/docker.html#docker-cli-run-prod-mode) documentation.
### Step 3: Restart the Docker daemon
Restart the Docker daemon.
```bash
sudo service docker restart
```
### Step 4: Choose a Data Directory
Choose a directory on your system in which all data maintained by the **trains-server** is stored.
Create this directory, and set its owner and group to `uid` 1000. The data stored in this directory includes the database, uploaded files and logs.
For example, if your data directory is `/opt/trains`, then use the following command:
```bash
sudo mkdir -p /opt/trains/data/elastic
sudo mkdir -p /opt/trains/data/mongo/db
sudo mkdir -p /opt/trains/data/mongo/configdb
sudo mkdir -p /opt/trains/logs
sudo mkdir -p /opt/trains/data/fileserver
sudo chown -R 1000:1000 /opt/trains
```
## TRAINS-server: Manually Launching Docker Containers <a name="launch"></a>
You can manually launch the Docker containers using the following commands.
If your data directory is not `/opt/trains`, then in the five `docker run` commands below, you must replace all occurrences of `/opt/trains` with your data directory path.
1. Launch the **trains-elastic** Docker container.
sudo docker run -d --restart="always" --name="trains-elastic" -e "bootstrap.memory_lock=true" --ulimit memlock=-1:-1 -e "ES_JAVA_OPTS=-Xms2g -Xmx2g" -e "bootstrap.memory_lock=true" -e "cluster.name=trains" -e "discovery.zen.minimum_master_nodes=1" -e "node.name=trains" -e "script.inline=true" -e "script.update=true" -e "thread_pool.bulk.queue_size=2000" -e "thread_pool.search.queue_size=10000" -e "xpack.security.enabled=false" -e "xpack.monitoring.enabled=false" -e "cluster.routing.allocation.node_initial_primaries_recoveries=500" -e "node.ingest=true" -e "http.compression_level=7" -e "reindex.remote.whitelist=*.*" -e "script.painless.regex.enabled=true" --network="host" -v /opt/trains/data/elastic:/usr/share/elasticsearch/data docker.elastic.co/elasticsearch/elasticsearch:5.6.16
1. Launch the **trains-mongo** Docker container.
sudo docker run -d --restart="always" --name="trains-mongo" -v /opt/trains/data/mongo/db:/data/db -v /opt/trains/data/mongo/configdb:/data/configdb --network="host" mongo:3.6.5
1. Launch the **trains-fileserver** Docker container.
sudo docker run -d --restart="always" --name="trains-fileserver" --network="host" -v /opt/trains/logs:/var/log/trains -v /opt/trains/data/fileserver:/mnt/fileserver allegroai/trains:latest fileserver
1. Launch the **trains-apiserver** Docker container.
sudo docker run -d --restart="always" --name="trains-apiserver" --network="host" -v /opt/trains/logs:/var/log/trains -v /opt/trains/config:/opt/trains/config allegroai/trains:latest apiserver
1. Launch the **trains-webserver** Docker container.
sudo docker run -d --restart="always" --name="trains-webserver" -p 8080:80 allegroai/trains:latest webserver
1. Your server is now running on [http://localhost:8080](http://localhost:8080) and the following ports are available:
* API server on port `8008`
* Web server on port `8080`
* File server on port `8081`

181
docs/faq.md Normal file
View File

@@ -0,0 +1,181 @@
# TRAINS-server FAQ
* [Deploying trains-server on Kubernetes clusters](#kubernetes)
* [Creating a Helm Chart for trains-server Kubernetes deployment](#helm)
* [Running trains-server on Mac OS X](#mac-osx)
* [Installing trains-server on stand alone Linux Ubuntu systems ](#ubuntu)
* [Resolving port conflicts preventing fixed users mode authentication and login](#port-conflict)
* [Configuring trains-server for sub-domains and load balancers](#sub-domains)
### Deploying trains-server on Kubernetes clusters <a name="kubernetes"></a>
**trains-server** supports Kubernetes. See [trains-server-k8s](https://github.com/allegroai/trains-server-k8s)
which contains the YAML files describing the required services and detailed instructions for deploying
**trains-server** to a Kubernetes clusters.
### Creating a Helm Chart for trains-server Kubernetes deployment <a name="helm"></a>
**trains-server** supports creating a Helm chart for Kubernetes deployment. See [trains-server-helm](https://github.com/allegroai/trains-server-helm)
which you can use to create a Helm chart for **trains-server** and contains detailed instructions for deploying
**trains-server** to a Kubernetes clusters using Helm.
### Running trains-server on Mac OS X <a name="mac-osx"></a>
To install and configure **trains-server** on Mac OS X, follow the steps below.
1. Install [docker for OS X](https://docs.docker.com/docker-for-mac/install/).
1. Configure [Docker](https://www.elastic.co/guide/en/elasticsearch/reference/current/docker.html#docker-cli-run-prod-mode).
$ screen ~/Library/Containers/com.docker.docker/Data/vms/0/tty
sysctl -w vm.max_map_count=262144
1. Create local directories for the databases and storage.
$ sudo mkdir -p /opt/trains/data/elastic
$ sudo mkdir -p /opt/trains/data/mongo/db
$ sudo mkdir -p /opt/trains/data/mongo/configdb
$ sudo mkdir -p /opt/trains/logs
$ sudo mkdir -p /opt/trains/data/fileserver
$ sudo chown -R $(whoami):staff /opt/trains
1. Open the Docker app, select **Preferences**, and then on the **File Sharing** tab, add `/opt/trains`.
1. Clone the [trains-server](https://github.com/allegroai/trains-server) repository and change directories to the new **trains-server** directory.
$ git clone https://github.com/allegroai/trains-server.git
$ cd trains-server
1. Run `docker-compose` with the unified docker image.
$ docker-compose -f docker-compose-unified.yml up
Your server is now running on [http://localhost:8080](http://localhost:8080)
### Installing trains-server on stand alone Linux Ubuntu systems <a name="ubuntu"></a>
To install **trains-server** on a stand alone Linux Ubuntu, follow the steps belows.
1. Install [docker for Ubuntu](https://docs.docker.com/install/linux/docker-ce/ubuntu/).
1. Install `docker-compose` using the following commands (for more detailed information, see the [Install Docker Compose](https://docs.docker.com/compose/install/) in the Docker documentation):
sudo curl -L "https://github.com/docker/compose/releases/download/1.24.1/docker-compose-$(uname -s)-$(uname -m)" -o /usr/local/bin/docker-compose
sudo chmod +x /usr/local/bin/docker-compose
1. Remove the previous installation of **trains-server**.
**WARNING**: This clears all existing **TRAINS** databases.
$ sudo rm -R /opt/trains/
1. Create local directories for the databases and storage.
$ sudo mkdir -p /opt/trains/data/elastic
$ sudo mkdir -p /opt/trains/data/mongo/db
$ sudo mkdir -p /opt/trains/data/mongo/configdb
$ sudo mkdir -p /opt/trains/logs
$ sudo mkdir -p /opt/trains/data/fileserver
$ sudo chown -R 1000:1000 /opt/trains
1. Clone the [trains-server](https://github.com/allegroai/trains-server) repository and change directories to the new **trains-server** directory.
$ git clone https://github.com/allegroai/trains-server.git
$ cd trains-server
1. Run `docker-compose`
$ /usr/local/bin/docker-compose -f docker-compose.yml up
Your server is now running on [http://localhost:8080](http://localhost:8080)
### Resolving port conflicts preventing fixed users mode authentication and login <a name="port-conflict"></a>
A port conflict may occur between the **trains-server** MongoDB and Elastic instances and other
instances running on your system. **trains-server** uses the following default ports which may be in conflict with other instances:
* MongoDB port `27017`
* Elastic port `9200`
You can check for port conflicts in the logs in `/opt/trains/log`.
If a port conflict occurs, first change the port in your **trains-server** `/opt/trains/server/config/default/hosts.conf` file to the new port and then
run the `docker run` command with the `port` option specifying the new port to restart the **trains-server** instance.
For example, to resolve a MongoDB port conflict change port `27017` to `27018`:
1. Modify `/opt/trains/server/config/default/hosts.conf` changing the ports in the `mongo` section:
elastic {
events {
hosts: [{host: "127.0.0.1", port: 9200}]
args {
timeout: 60
dead_timeout: 10
max_retries: 5
retry_on_timeout: true
}
index_version: "1"
}
}
mongo {
backend {
host: "mongodb://127.0.0.1:27018/backend"
}
auth {
host: "mongodb://127.0.0.1:27018/auth"
}
}
2. Start the **trains-server** MongoDB container using `--port 27018`.
sudo docker run -d --restart="always" --name="trains-mongo" -v /opt/trains/data/mongo/db:/data/db -v /opt/trains/data/mongo/configdb:/data/configdb --network="host" mongo:3.6.5 mongod --port 27018
In a future version of **trains-server**, to start the API server, environment variables will be available to use instead of modifying the configuration file (instead of Step 1 above).
The environment variables will be available to set different ports for both MongoDB and Elastic instances:
* `MONGODB_SERVICE_PORT` (e.g., `MONGODB_SERVICE_PORT=27018`)
* `ELASTIC_SERVICE_POST` (e.g., `ELASTIC_SERVICE_POST=9201`)
### Configuring trains-server for sub-domains and load balancers <a name="sub-domains"></a>
You can configure **trains-server** for sub-domains and a load balancer.
For example, if your domain is `trains.mydomain.com` and your sub-domains are `app` and `api`, then do the following:
1. If you are not using the current **trains-server** version, [upgrade](https://github.com/allegroai/trains-server#upgrade) **trains-server**.
1. Add the following to `/opt/trains/config/apiserver.conf`:
auth {
cookies {
httponly: true
secure: true
domain: ".trains.mydomain.com"
max_age: 99999999999
}
}
1. Use the following load balancer configuration:
* Listeners:
* Optional: HTTP listener, that redirects all traffic to HTTPS.
* HTTPS listener for `app.` forwarded to `AppTargetGroup`
* HTTPS listener for `api.` forwarded to `ApiTargetGroup`
* HTTPS listener for `files.` forwarded to `FilesTargetGroup`
* Target groups:
* `AppTargetGroup`: HTTP based target group, port `8080`
* `ApiTargetGroup`: HTTP based target group, port `8008`
* `FilesTargetGroup`: HTTP based target group, port `8081`
* Security and routing:
* Load balancer: make sure the load balancers are able to receive traffic from the relevant IP addresses (Security groups and Subnets definitions).
* Instances: make sure the load balancers are able to access the instances, using the relevant ports (Security groups definitions).
1. Run the Docker containers with our updated `docker run` commands (see [Launching Docker Containers](#https://github.com/allegroai/trains-server#launching-docker-containers)).

View File

@@ -25,6 +25,62 @@ In order to upgrade **trains-server** on an existing EC2 instance based on one o
The following sections provide a list containing AMI Image ID per region for each released **trains-server** version.
### Latest Version AMI <a name="autoupdate"></a>
**For easier upgrades: The following AMI automatically update to the latest release every reboot**
* **eu-north-1** : ami-047eb12cf0b47b2d1
* **ap-south-1** : ami-0a2facc5f027ab528
* **eu-west-3** : ami-08ef18e0e4ca1e6c6
* **eu-west-2** : ami-0a7133d9a3c800bbd
* **eu-west-1** : ami-0f1cce84bb2187729
* **ap-northeast-2** : ami-0825c4e06cc194272
* **ap-northeast-1** : ami-024db084d549289f3
* **sa-east-1** : ami-04eca8d7ab944a48c
* **ca-central-1** : ami-03b7bfbb8607c9bc4
* **ap-southeast-1** : ami-0a8667b8ba3564202
* **ap-southeast-2** : ami-0866de3db64f63e15
* **eu-central-1** : ami-04898b0923493de1b
* **us-east-2** : ami-06afbbc84f5d829da
* **us-west-1** : ami-045fe6664792a00d7
* **us-west-2** : ami-0132184364da97720
* **us-east-1** : ami-08747037c11256d44
### v0.11.0
* **eu-north-1** : ami-0303acd0967b3df38
* **ap-south-1** : ami-0e14dc1e886344a3e
* **eu-west-3** : ami-00de3fa500c2e7ea9
* **eu-west-2** : ami-0bd68bec0c2631535
* **eu-west-1** : ami-094b8dcc9b6f9a04c
* **ap-northeast-2** : ami-0091bb348c218d4c5
* **ap-northeast-1** : ami-0e06fbc71a9e7a74d
* **sa-east-1** : ami-0e99a346d8e585f76
* **ca-central-1** : ami-09874b823457e5874
* **ap-southeast-1** : ami-0823fd4963b3d4ff4
* **ap-southeast-2** : ami-0463d77897f1c0569
* **eu-central-1** : ami-0bb5cb2f5d444f905
* **us-east-2** : ami-0b364bf4c7dc12f67
* **us-west-1** : ami-0a97c0548d53d9f1d
* **us-west-2** : ami-06588b5bde813c28c
* **us-east-1** : ami-0a43a4b03215b0144
### v0.10.1
* **eu-north-1** : ami-09937ec4d18350c32
* **ap-south-1** : ami-089d6ba7541ec4c7f
* **eu-west-3** : ami-0accb1a94bdd5c5c1
* **eu-west-2** : ami-0dd2c97bc678b8570
* **eu-west-1** : ami-07a38865cbe7ca3cb
* **ap-northeast-2** : ami-09aa0b7fe1cf3dd55
* **ap-northeast-1** : ami-0905e7d1543e5ed36
* **sa-east-1** : ami-08c0627daa67d7372
* **ca-central-1** : ami-034add081712ff648
* **ap-southeast-1** : ami-0c6caee3689b6e066
* **ap-southeast-2** : ami-04994afd8dae5b417
* **eu-central-1** : ami-06b10f8c30e1434f1
* **us-east-2** : ami-0d3abe7a1fec535cc
* **us-west-1** : ami-02bb610b70c55018b
* **us-west-2** : ami-0d1cb8ba7de246ff0
* **us-east-1** : ami-049ccba6abdb40cba
### v0.10.0
* **eu-north-1** : ami-05ba33c763877e54e
* **ap-south-1** : ami-0529eec569161cae5

View File

@@ -0,0 +1,8 @@
download {
# Add response headers requesting no caching for served files
disable_browser_caching: false
}
cors {
origins: "*"
}

View File

@@ -1,16 +1,18 @@
""" A Simple file server for uploading and downloading files """
import json
import logging.config
import os
from argparse import ArgumentParser
from pathlib import Path
from flask import Flask, request, send_from_directory, safe_join
from pyhocon import ConfigFactory
from flask_compress import Compress
from flask_cors import CORS
logging.config.dictConfig(ConfigFactory.parse_file("logging.conf"))
from config import config
app = Flask(__name__)
CORS(app, **config.get("fileserver.cors"))
Compress(app)
@app.route("/", methods=["POST"])
@@ -29,7 +31,15 @@ def upload():
@app.route("/<path:path>", methods=["GET"])
def download(path):
return send_from_directory(app.config["UPLOAD_FOLDER"], path)
response = send_from_directory(app.config["UPLOAD_FOLDER"], path)
if config.get("fileserver.download.disable_browser_caching", False):
headers = response.headers
headers["Pragma-directive"] = "no-cache"
headers["Cache-directive"] = "no-cache"
headers["Cache-control"] = "no-cache"
headers["Pragma"] = "no-cache"
headers["Expires"] = "0"
return response
def main():

View File

@@ -1,2 +1,4 @@
Flask
Flask-Cors>=3.0.5
Flask-Compress>=1.4.0
pyhocon>=0.3.35

View File

@@ -0,0 +1,18 @@
from pymongo.database import Database, Collection
from database.utils import partition_tags
def migrate_backend(db: Database):
for name in ("project", "task", "model"):
collection: Collection = db[name]
for doc in collection.find(projection=["tags", "system_tags"]):
tags = doc.get("tags")
if tags is not None:
user_tags, system_tags = partition_tags(
name, tags, doc.get("system_tags", [])
)
collection.update_one(
{"_id": doc["_id"]},
{"$set": {"system_tags": system_tags, "tags": user_tags}}
)

View File

@@ -83,7 +83,8 @@ _error_codes = {
21: ('bad_credentials', 'unauthorized (malformed credentials)'),
22: ('invalid_credentials', 'unauthorized (invalid credentials)'),
30: ('invalid_token', 'invalid token'),
31: ('blocked_token', 'token is blocked')
31: ('blocked_token', 'token is blocked'),
40: ('invalid_fixed_user', 'fixed user ID was not found')
},
(403, 'forbidden'): {

View File

@@ -4,11 +4,10 @@ from enum import Enum
from typing import Union, Type, Iterable
import jsonmodels.errors
import jsonmodels.validators
import six
import validators
from jsonmodels import fields
from jsonmodels.fields import _LazyType
from jsonmodels.fields import _LazyType, NotSet
from jsonmodels.models import Base as ModelBase
from jsonmodels.validators import Enum as EnumValidator
from luqum.parser import parser, ParseError
@@ -25,6 +24,12 @@ def make_default(field_cls, default_value):
class ListField(fields.ListField):
def __init__(self, items_types=None, *args, default=NotSet, **kwargs):
if default is not NotSet and callable(default):
default = default()
super(ListField, self).__init__(items_types, *args, default=default, **kwargs)
def _cast_value(self, value):
try:
return super(ListField, self)._cast_value(value)
@@ -144,6 +149,46 @@ class EnumField(fields.StringField):
return super().parse_value(value)
class ActualEnumField(fields.StringField):
@property
def types(self):
return (self.__enum,)
def __init__(
self,
enum_class: Type[Enum],
*args,
validators=None,
required=False,
default=None,
**kwargs
):
self.__enum = enum_class
# noinspection PyTypeChecker
choices = list(enum_class)
validator_cls = EnumValidator if required else NullableEnumValidator
validators = [*(validators or []), validator_cls(*choices)]
super().__init__(
default=default and self.parse_value(default),
*args,
required=required,
validators=validators,
**kwargs
)
def parse_value(self, value):
if value is None and not self.required:
return self.get_default_value()
try:
# noinspection PyArgumentList
return self.__enum(value)
except ValueError:
return value
def to_struct(self, value):
return super().to_struct(value.value)
class EmailField(fields.StringField):
def validate(self, value):
super().validate(value)
@@ -160,3 +205,12 @@ class DomainField(fields.StringField):
return
if validators.domain(value) is not True:
raise errors.bad_request.InvalidDomainName()
class StringEnum(Enum):
def __str__(self):
return self.value
# noinspection PyMethodParameters
def _generate_next_value_(name, start, count, last_values):
return name

View File

@@ -1,4 +1,4 @@
from jsonmodels.fields import IntField, StringField, BoolField, EmbeddedField
from jsonmodels.fields import IntField, StringField, BoolField, EmbeddedField, DateTimeField
from jsonmodels.models import Base
from jsonmodels.validators import Max, Enum
@@ -79,6 +79,7 @@ class Credentials(Base):
class CredentialsResponse(Credentials):
secret_key = StringField()
last_used = DateTimeField(default=None)
class CreateCredentialsResponse(Base):

View File

@@ -0,0 +1,20 @@
from typing import Sequence
from jsonmodels.fields import StringField
from jsonmodels.models import Base
from apimodels import ListField, IntField, ActualEnumField
from bll.event.scalar_key import ScalarKeyEnum
class HistogramRequestBase(Base):
samples: int = IntField(default=10000)
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
task: str = StringField(required=True)
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
tasks: Sequence[str] = ListField(items_types=str)

View File

@@ -11,6 +11,7 @@ class CreateModelRequest(models.Base):
uri = fields.StringField(required=True)
labels = DictField(value_types=string_types+(int,), required=True)
tags = ListField(items_types=string_types)
system_tags = ListField(items_types=string_types)
comment = fields.StringField()
public = fields.BoolField(default=False)
project = fields.StringField()

View File

@@ -0,0 +1,16 @@
from jsonmodels import models, fields
class ProjectReq(models.Base):
project = fields.StringField()
class GetHyperParamReq(ProjectReq):
page = fields.IntField(default=0)
page_size = fields.IntField(default=500)
class GetHyperParamResp(models.Base):
parameters = fields.ListField(str)
remaining = fields.IntField()
total = fields.IntField()

View File

@@ -57,5 +57,5 @@ class CreateRequest(TaskData):
type = StringField(required=True, validators=Enum(*get_options(TaskType)))
class PingRequest(models.Base):
class PingRequest(TaskRequest):
task = StringField(required=True)

View File

@@ -1,24 +1,24 @@
from collections import defaultdict
from contextlib import closing
from datetime import datetime
from enum import Enum
from operator import attrgetter
from typing import Sequence
import attr
import six
from elasticsearch import helpers
from enum import Enum
from mongoengine import Q
from nested_dict import nested_dict
import database.utils as dbutils
import es_factory
from apierrors import errors
from bll.event.event_metrics import EventMetrics
from bll.task import TaskBLL
from database.errors import translate_errors_context
from database.model.task.task import Task
from database.model.task.metrics import MetricEvent
from timing_context import TimingContext
from utilities.dicts import flatten_nested_items
class EventType(Enum):
@@ -44,7 +44,12 @@ class EventBLL(object):
id_fields = ["task", "iter", "metric", "variant", "key"]
def __init__(self, events_es=None):
self.es = events_es if events_es is not None else es_factory.connect("events")
self.es = events_es or es_factory.connect("events")
self._metrics = EventMetrics(self.es)
@property
def metrics(self) -> EventMetrics:
return self._metrics
def add_events(self, company_id, events, worker):
actions = []
@@ -94,7 +99,7 @@ class EventBLL(object):
event["value"] = event["values"]
del event["values"]
index_name = EventBLL.get_index_name(company_id, event_type)
index_name = EventMetrics.get_index_name(company_id, event_type)
es_action = {
"_op_type": "index", # overwrite if exists with same ID
"_index": index_name,
@@ -154,13 +159,6 @@ class EventBLL(object):
else:
errors_in_bulk.append(info)
last_metrics = {
t.id: t.to_proper_dict().get("last_metrics", {})
for t in Task.objects(id__in=task_ids, company=company_id).only(
"last_metrics"
)
}
remaining_tasks = set()
now = datetime.utcnow()
for task_id in task_ids:
@@ -173,7 +171,6 @@ class EventBLL(object):
now=now,
iter=task_iteration.get(task_id),
last_events=task_last_events.get(task_id),
last_metrics=last_metrics.get(task_id),
)
if not updated:
@@ -210,9 +207,7 @@ class EventBLL(object):
if timestamp is None or timestamp < event["timestamp"]:
last_events[metric_hash][variant_hash] = event
def _update_task(
self, company_id, task_id, now, iter=None, last_events=None, last_metrics=None
):
def _update_task(self, company_id, task_id, now, iter=None, last_events=None):
"""
Update task information in DB with aggregated results after handling event(s) related to this task.
@@ -226,23 +221,13 @@ class EventBLL(object):
fields["last_iteration"] = iter
if last_events:
def get_metric_event(ev):
me = MetricEvent.from_dict(**ev)
if "timestamp" in ev:
me.timestamp = datetime.utcfromtimestamp(ev["timestamp"] / 1000)
return me
new_last_metrics = nested_dict(2, MetricEvent)
new_last_metrics.update(last_metrics)
for metric_hash, variants in last_events.items():
for variant_hash, event in variants.items():
new_last_metrics[metric_hash][variant_hash] = get_metric_event(
event
)
fields["last_metrics"] = new_last_metrics.to_dict()
fields["last_values"] = list(
flatten_nested_items(
last_events,
nesting=2,
include_leaves=["value", "metric", "variant"],
)
)
if not fields:
return False
@@ -270,7 +255,7 @@ class EventBLL(object):
if event_type is None:
event_type = "*"
es_index = EventBLL.get_index_name(company_id, event_type)
es_index = EventMetrics.get_index_name(company_id, event_type)
if not self.es.indices.exists(es_index):
return [], None, 0
@@ -290,6 +275,125 @@ class EventBLL(object):
return events, next_scroll_id, total_events
def get_last_iterations_per_event_metric_variant(
self, es_index: str, task_id: str, num_last_iterations: int, event_type: str
):
if not self.es.indices.exists(es_index):
return []
es_req: dict = {
"size": 0,
"aggs": {
"metrics": {
"terms": {"field": "metric"},
"aggs": {
"variants": {
"terms": {"field": "variant"},
"aggs": {
"iters": {
"terms": {
"field": "iter",
"size": num_last_iterations,
"order": {"_term": "desc"},
}
}
},
}
},
}
},
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
}
if event_type:
es_req["query"]["bool"]["must"].append({"term": {"type": event_type}})
with translate_errors_context(), TimingContext(
"es", "task_last_iter_metric_variant"
):
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
if "aggregations" not in es_res:
return []
return [
(metric["key"], variant["key"], iter["key"])
for metric in es_res["aggregations"]["metrics"]["buckets"]
for variant in metric["variants"]["buckets"]
for iter in variant["iters"]["buckets"]
]
def get_task_plots(
self,
company_id: str,
tasks: Sequence[str],
last_iterations_per_plot: int = None,
sort=None,
size: int = 500,
scroll_id: str = None,
):
if scroll_id:
with translate_errors_context(), TimingContext("es", "get_task_events"):
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
else:
event_type = "plot"
es_index = EventMetrics.get_index_name(company_id, event_type)
if not self.es.indices.exists(es_index):
return TaskEventsResult()
query = {"bool": defaultdict(list)}
if last_iterations_per_plot is None:
must = query["bool"]["must"]
must.append({"terms": {"task": tasks}})
else:
should = query["bool"]["should"]
for i, task_id in enumerate(tasks):
last_iters = self.get_last_iterations_per_event_metric_variant(
es_index, task_id, last_iterations_per_plot, event_type
)
if not last_iters:
continue
for metric, variant, iter in last_iters:
should.append(
{
"bool": {
"must": [
{"term": {"task": task_id}},
{"term": {"metric": metric}},
{"term": {"variant": variant}},
{"term": {"iter": iter}},
]
}
}
)
if not should:
return TaskEventsResult()
if sort is None:
sort = [{"timestamp": {"order": "asc"}}]
es_req = {"sort": sort, "size": min(size, 10000), "query": query}
routing = ",".join(tasks)
with translate_errors_context(), TimingContext("es", "get_task_plots"):
es_res = self.es.search(
index=es_index,
body=es_req,
ignore=404,
routing=routing,
scroll="1h",
)
events = [doc["_source"] for doc in es_res.get("hits", {}).get("hits", [])]
# scroll id may be missing when queering a totally empty DB
next_scroll_id = es_res.get("_scroll_id")
total_events = es_res["hits"]["total"]
return TaskEventsResult(
events=events, next_scroll_id=next_scroll_id, total_events=total_events
)
def get_task_events(
self,
company_id,
@@ -311,7 +415,7 @@ class EventBLL(object):
if event_type is None:
event_type = "*"
es_index = EventBLL.get_index_name(company_id, event_type)
es_index = EventMetrics.get_index_name(company_id, event_type)
if not self.es.indices.exists(es_index):
return TaskEventsResult()
@@ -374,7 +478,7 @@ class EventBLL(object):
def get_metrics_and_variants(self, company_id, task_id, event_type):
es_index = EventBLL.get_index_name(company_id, event_type)
es_index = EventMetrics.get_index_name(company_id, event_type)
if not self.es.indices.exists(es_index):
return {}
@@ -405,7 +509,7 @@ class EventBLL(object):
return metrics
def get_task_latest_scalar_values(self, company_id, task_id):
es_index = EventBLL.get_index_name(company_id, "training_stats_scalar")
es_index = EventMetrics.get_index_name(company_id, "training_stats_scalar")
if not self.es.indices.exists(es_index):
return {}
@@ -488,147 +592,9 @@ class EventBLL(object):
metrics.append(metric_summary)
return metrics, max_timestamp
def compare_scalar_metrics_average_per_iter(
self, company_id, task_ids, allow_public=True
):
assert isinstance(task_ids, list)
task_name_by_id = {}
with translate_errors_context():
task_objs = Task.get_many(
company=company_id,
query=Q(id__in=task_ids),
allow_public=allow_public,
override_projection=("id", "name"),
return_dicts=False,
)
if len(task_objs) < len(task_ids):
invalid = tuple(set(task_ids) - set(r.id for r in task_objs))
raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
task_name_by_id = {t.id: t.name for t in task_objs}
es_index = EventBLL.get_index_name(company_id, "training_stats_scalar")
if not self.es.indices.exists(es_index):
return {}
es_req = {
"size": 0,
"_source": {"excludes": []},
"query": {"terms": {"task": task_ids}},
"aggs": {
"iters": {
"histogram": {"field": "iter", "interval": 1, "min_doc_count": 1},
"aggs": {
"metric_and_variant": {
"terms": {
"script": "doc['metric'].value +'/'+ doc['variant'].value",
"size": 10000,
},
"aggs": {
"tasks": {
"terms": {"field": "task"},
"aggs": {"avg_val": {"avg": {"field": "value"}}},
}
},
}
},
}
},
}
with translate_errors_context(), TimingContext("es", "task_stats_comparison"):
es_res = self.es.search(index=es_index, body=es_req)
if "aggregations" not in es_res:
return
metrics = {}
for iter_bucket in es_res["aggregations"]["iters"]["buckets"]:
iteration = int(iter_bucket["key"])
for metric_bucket in iter_bucket["metric_and_variant"]["buckets"]:
metric_name = metric_bucket["key"]
if metrics.get(metric_name) is None:
metrics[metric_name] = {}
metric_data = metrics[metric_name]
for task_bucket in metric_bucket["tasks"]["buckets"]:
task_id = task_bucket["key"]
value = task_bucket["avg_val"]["value"]
if metric_data.get(task_id) is None:
metric_data[task_id] = {
"x": [],
"y": [],
"name": task_name_by_id[
task_id
], # todo: lookup task name from id
}
metric_data[task_id]["x"].append(iteration)
metric_data[task_id]["y"].append(value)
return metrics
def get_scalar_metrics_average_per_iter(self, company_id, task_id):
es_index = EventBLL.get_index_name(company_id, "training_stats_scalar")
if not self.es.indices.exists(es_index):
return {}
es_req = {
"size": 0,
"_source": {"excludes": []},
"query": {"term": {"task": task_id}},
"aggs": {
"iters": {
"histogram": {"field": "iter", "interval": 1, "min_doc_count": 1},
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": 200,
"order": {"_term": "desc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": 500,
"order": {"_term": "desc"},
},
"aggs": {"avg_val": {"avg": {"field": "value"}}},
}
},
}
},
}
},
"version": True,
}
with translate_errors_context(), TimingContext("es", "task_stats_scalar"):
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
metrics = {}
if "aggregations" in es_res:
for iter_bucket in es_res["aggregations"]["iters"]["buckets"]:
iteration = int(iter_bucket["key"])
for metric_bucket in iter_bucket["metrics"]["buckets"]:
metric_name = metric_bucket["key"]
if metrics.get(metric_name) is None:
metrics[metric_name] = {}
metric_data = metrics[metric_name]
for variant_bucket in metric_bucket["variants"]["buckets"]:
variant = variant_bucket["key"]
value = variant_bucket["avg_val"]["value"]
if metric_data.get(variant) is None:
metric_data[variant] = {"x": [], "y": [], "name": variant}
metric_data[variant]["x"].append(iteration)
metric_data[variant]["y"].append(value)
return metrics
def get_vector_metrics_per_iter(self, company_id, task_id, metric, variant):
es_index = EventBLL.get_index_name(company_id, "training_stats_vector")
es_index = EventMetrics.get_index_name(company_id, "training_stats_vector")
if not self.es.indices.exists(es_index):
return [], []
@@ -685,7 +651,7 @@ class EventBLL(object):
return [b["key"] for b in es_res["aggregations"]["iters"]["buckets"]]
def delete_task_events(self, company_id, task_id):
es_index = EventBLL.get_index_name(company_id, "*")
es_index = EventMetrics.get_index_name(company_id, "*")
es_req = {"query": {"term": {"task": task_id}}}
with translate_errors_context(), TimingContext("es", "delete_task_events"):
es_res = self.es.delete_by_query(
@@ -693,8 +659,3 @@ class EventBLL(object):
)
return es_res.get("deleted", 0)
@staticmethod
def get_index_name(company_id, event_type):
event_type = event_type.lower().replace(" ", "_")
return "events-%s-%s" % (event_type, company_id)

View File

@@ -0,0 +1,398 @@
import itertools
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from operator import itemgetter
from elasticsearch import Elasticsearch
from typing import Sequence, Tuple, Callable
from mongoengine import Q
from apierrors import errors
from bll.event.scalar_key import ScalarKey, ScalarKeyEnum
from config import config
from database.errors import translate_errors_context
from database.model.task.task import Task
from timing_context import TimingContext
from utilities import safe_get
log = config.logger(__file__)
class EventMetrics:
MAX_TASKS_COUNT = 100
MAX_METRICS_COUNT = 200
MAX_VARIANTS_COUNT = 500
def __init__(self, es: Elasticsearch):
self.es = es
@staticmethod
def get_index_name(company_id, event_type):
event_type = event_type.lower().replace(" ", "_")
return f"events-{event_type}-{company_id}"
def get_scalar_metrics_average_per_iter(
self, company_id: str, task_id: str, samples: int, key: ScalarKeyEnum
) -> dict:
"""
Get scalar metric histogram per metric and variant
The amount of points in each histogram should not exceed
the requested samples
"""
return self._run_get_scalar_metrics_as_parallel(
company_id,
task_ids=[task_id],
samples=samples,
key=ScalarKey.resolve(key),
get_func=self._get_scalar_average,
)
def compare_scalar_metrics_average_per_iter(
self,
company_id,
task_ids: Sequence[str],
samples,
key: ScalarKeyEnum,
allow_public=True,
):
"""
Compare scalar metrics for different tasks per metric and variant
The amount of points in each histogram should not exceed the requested samples
"""
task_name_by_id = {}
with translate_errors_context():
task_objs = Task.get_many(
company=company_id,
query=Q(id__in=task_ids),
allow_public=allow_public,
override_projection=("id", "name"),
return_dicts=False,
)
if len(task_objs) < len(task_ids):
invalid = tuple(set(task_ids) - set(r.id for r in task_objs))
raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
task_name_by_id = {t.id: t.name for t in task_objs}
ret = self._run_get_scalar_metrics_as_parallel(
company_id,
task_ids=task_ids,
samples=samples,
key=ScalarKey.resolve(key),
get_func=self._get_scalar_average_per_task,
)
for metric_data in ret.values():
for variant_data in metric_data.values():
for task_id, task_data in variant_data.items():
task_data["name"] = task_name_by_id[task_id]
return ret
TaskMetric = Tuple[str, str, str]
MetricInterval = Tuple[int, Sequence[TaskMetric]]
MetricData = Tuple[str, dict]
def _run_get_scalar_metrics_as_parallel(
self,
company_id: str,
task_ids: Sequence[str],
samples: int,
key: ScalarKey,
get_func: Callable[
[MetricInterval, Sequence[str], str, ScalarKey], Sequence[MetricData]
],
) -> dict:
"""
Group metrics per interval length and execute get_func for each group in parallel
:param company_id: id of the company
:params task_ids: ids of the tasks to collect data for
:param samples: maximum number of samples per metric
:param get_func: callable that given metric names for the same interval
performs histogram aggregation for the metrics and return the aggregated data
"""
es_index = self.get_index_name(company_id, "training_stats_scalar")
if not self.es.indices.exists(es_index):
return {}
intervals = self._get_metric_intervals(
es_index=es_index, task_ids=task_ids, samples=samples, field=key.field
)
if not intervals:
return {}
with ThreadPoolExecutor(len(intervals)) as pool:
metrics = list(
itertools.chain.from_iterable(
pool.map(
partial(
get_func, task_ids=task_ids, es_index=es_index, key=key
),
intervals,
)
)
)
ret = defaultdict(dict)
for metric_key, metric_values in metrics:
ret[metric_key].update(metric_values)
return ret
def _get_metric_intervals(
self, es_index, task_ids: Sequence[str], samples: int, field: str = "iter"
) -> Sequence[MetricInterval]:
"""
Calculate interval per task metric variant so that the resulting
amount of points does not exceed sample.
Return metric variants grouped by interval value with 10% rounding
For samples==0 return empty list
"""
default_intervals = [(1, [])]
if not samples:
return default_intervals
es_req = {
"size": 0,
"query": {"terms": {"task": task_ids}},
"aggs": {
"tasks": {
"terms": {"field": "task", "size": self.MAX_TASKS_COUNT},
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": self.MAX_METRICS_COUNT,
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": self.MAX_VARIANTS_COUNT,
},
"aggs": {
"count": {"value_count": {"field": field}},
"min_index": {"min": {"field": field}},
"max_index": {"max": {"field": field}},
},
}
},
}
},
}
},
}
with translate_errors_context(), TimingContext("es", "task_stats_get_interval"):
es_res = self.es.search(
index=es_index, body=es_req, routing=",".join(task_ids)
)
aggs_result = es_res.get("aggregations")
if not aggs_result:
return default_intervals
intervals = [
(
task["key"],
metric["key"],
variant["key"],
self._calculate_metric_interval(variant, samples),
)
for task in aggs_result["tasks"]["buckets"]
for metric in task["metrics"]["buckets"]
for variant in metric["variants"]["buckets"]
]
metric_intervals = []
upper_border = 0
interval_metrics = None
for task, metric, variant, interval in sorted(intervals, key=itemgetter(3)):
if not interval_metrics or interval > upper_border:
interval_metrics = []
metric_intervals.append((interval, interval_metrics))
upper_border = interval + int(interval * 0.1)
interval_metrics.append((task, metric, variant))
return metric_intervals
@staticmethod
def _calculate_metric_interval(metric_variant: dict, samples: int) -> int:
"""
Calculate index interval per metric_variant variant so that the
total amount of intervals does not exceeds the samples
"""
count = safe_get(metric_variant, "count/value")
if not count or count < samples:
return 1
min_index = safe_get(metric_variant, "min_index/value", default=0)
max_index = safe_get(metric_variant, "max_index/value", default=min_index)
return max(1, int(max_index - min_index + 1) // samples)
def _get_scalar_average(
self,
metrics_interval: MetricInterval,
task_ids: Sequence[str],
es_index: str,
key: ScalarKey,
) -> Sequence[MetricData]:
"""
Retrieve scalar histograms per several metric variants that share the same interval
Note: the function works with a single task only
"""
assert len(task_ids) == 1
interval, task_metrics = metrics_interval
aggregation = self._add_aggregation_average(key.get_aggregation(interval))
aggs = {
"metrics": {
"terms": {
"field": "metric",
"size": self.MAX_METRICS_COUNT,
"order": {"_term": "desc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": self.MAX_VARIANTS_COUNT,
"order": {"_term": "desc"},
},
"aggs": aggregation,
}
},
}
}
aggs_result = self._query_aggregation_for_metrics_and_tasks(
es_index, aggs=aggs, task_ids=task_ids, task_metrics=task_metrics
)
if not aggs_result:
return {}
metrics = [
(
metric["key"],
{
variant["key"]: {
"name": variant["key"],
**key.get_iterations_data(variant),
}
for variant in metric["variants"]["buckets"]
},
)
for metric in aggs_result["metrics"]["buckets"]
]
return metrics
def _get_scalar_average_per_task(
self,
metrics_interval: MetricInterval,
task_ids: Sequence[str],
es_index: str,
key: ScalarKey,
) -> Sequence[MetricData]:
"""
Retrieve scalar histograms per several metric variants that share the same interval
"""
interval, task_metrics = metrics_interval
aggregation = self._add_aggregation_average(key.get_aggregation(interval))
aggs = {
"metrics": {
"terms": {"field": "metric", "size": self.MAX_METRICS_COUNT},
"aggs": {
"variants": {
"terms": {"field": "variant", "size": self.MAX_VARIANTS_COUNT},
"aggs": {
"tasks": {"terms": {"field": "task"}, "aggs": aggregation}
},
}
},
}
}
aggs_result = self._query_aggregation_for_metrics_and_tasks(
es_index, aggs=aggs, task_ids=task_ids, task_metrics=task_metrics
)
if not aggs_result:
return {}
metrics = [
(
metric["key"],
{
variant["key"]: {
task["key"]: key.get_iterations_data(task)
for task in variant["tasks"]["buckets"]
}
for variant in metric["variants"]["buckets"]
},
)
for metric in aggs_result["metrics"]["buckets"]
]
return metrics
@staticmethod
def _add_aggregation_average(aggregation):
average_agg = {"avg_val": {"avg": {"field": "value"}}}
return {
key: {**value, "aggs": {**value.get("aggs", {}), **average_agg}}
for key, value in aggregation.items()
}
def _query_aggregation_for_metrics_and_tasks(
self,
es_index: str,
aggs: dict,
task_ids: Sequence[str],
task_metrics: Sequence[TaskMetric],
) -> dict:
"""
Return the result of elastic search query for the given aggregation filtered
by the given task_ids and metrics
"""
if task_metrics:
condition = {
"should": [
self._build_metric_terms(task, metric, variant)
for task, metric, variant in task_metrics
]
}
else:
condition = {"must": [{"terms": {"task": task_ids}}]}
es_req = {
"size": 0,
"_source": {"excludes": []},
"query": {"bool": condition},
"aggs": aggs,
"version": True,
}
with translate_errors_context(), TimingContext("es", "task_stats_scalar"):
es_res = self.es.search(
index=es_index, body=es_req, routing=",".join(task_ids)
)
return es_res.get("aggregations")
@staticmethod
def _build_metric_terms(task: str, metric: str, variant: str) -> dict:
"""
Build query term for a metric + variant
"""
return {
"bool": {
"must": [
{"term": {"task": task}},
{"term": {"metric": metric}},
{"term": {"variant": variant}},
]
}
}

View File

@@ -0,0 +1,161 @@
"""
Module for polymorphism over different types of X axes in scalar aggregations
"""
from abc import ABC, abstractmethod
from enum import auto
from apimodels import StringEnum
from bll.util import extract_properties_to_lists
from config import config
log = config.logger(__file__)
class ScalarKeyEnum(StringEnum):
"""
String enum representing X axes key
"""
iter = auto()
timestamp = auto()
iso_time = auto()
class ScalarKey(ABC):
"""
Abstract scalar key
"""
_enum_to_key = {}
bucket_key_key = "key"
@property
@abstractmethod
def enum_value(self) -> ScalarKeyEnum:
"""
Enum value accepted in API requests
"""
pass
@property
@abstractmethod
def name(self) -> str:
"""
Key name. Used as arbitrary internal key in elasticsearch queries
"""
pass
@property
@abstractmethod
def field(self) -> str:
"""
Event key to aggregate by
"""
pass
@abstractmethod
def get_aggregation(self, interval: int) -> dict:
"""
Get aggregation for this type of key
:param interval: elasticsearch aggregation interval
"""
pass
def __init_subclass__(cls, **kwargs):
"""
Save a mapping from enum values to key class
"""
if cls.enum_value not in ScalarKeyEnum:
raise ValueError(f"{cls.enum_value!r} not in {ScalarKeyEnum.__name__}")
if cls.enum_value in cls._enum_to_key:
log.warning(
f"'{cls.enum_value.value}' is already registered to {ScalarKey.__name__}"
)
cls._enum_to_key[cls.enum_value] = cls
@classmethod
def resolve(cls, key: ScalarKeyEnum):
"""
Create a key instance from enum instance
"""
return cls._enum_to_key[key]()
def get_iterations_data(self, iter_buckets: dict) -> dict:
"""
Convert a list of bucket entries to `x`s array and `y`s array
"""
return extract_properties_to_lists(
("x", "y"),
iter_buckets[self.name]["buckets"],
self._get_iterations_data_single,
)
def _get_iterations_data_single(self, iter_data):
"""
Extract x value and y value from a single bucket item
"""
return int(iter_data[self.bucket_key_key]), iter_data["avg_val"]["value"]
class TimestampKey(ScalarKey):
"""
Aggregate by timestamp in milliseconds since epoch
"""
name = "timestamp"
field = "timestamp"
enum_value = ScalarKeyEnum.timestamp
def get_aggregation(self, interval: int) -> dict:
return {
self.name: {
"date_histogram": {
"field": "timestamp",
"interval": interval,
"min_doc_count": 1,
}
}
}
class IterKey(ScalarKey):
"""
Aggregate by iteration number
"""
name = "iters"
field = "iter"
enum_value = ScalarKeyEnum.iter
def get_aggregation(self, interval: int) -> dict:
return {
self.name: {
"histogram": {"field": "iter", "interval": interval, "min_doc_count": 1}
}
}
class ISOTimeKey(ScalarKey):
"""
Aggregate by time formatted as ISO strings
"""
name = "iso_time"
field = "timestamp"
enum_value = ScalarKeyEnum.iso_time
bucket_key_key = "key_as_string"
def get_aggregation(self, interval: int) -> dict:
return {
self.name: {
"date_histogram": {
"field": "timestamp",
"interval": interval,
"min_doc_count": 1,
"format": "strict_date_time",
}
}
}
def _get_iterations_data_single(self, iter_data):
return iter_data[self.bucket_key_key], iter_data["avg_val"]["value"]

View File

@@ -2,8 +2,7 @@ import re
from collections import OrderedDict
from datetime import datetime, timedelta
from time import sleep
from typing import Mapping, Collection
from urllib.parse import urlparse
from typing import Collection, Sequence, Tuple, Any
import six
from mongoengine import Q
@@ -13,12 +12,15 @@ import es_factory
from apierrors import errors
from config import config
from database.errors import translate_errors_context
from database.fields import OutputDestinationField
from database.model.model import Model
from database.model.project import Project
from database.model.task.metrics import MetricEvent
from database.model.task.output import Output
from database.model.task.task import Task, TaskStatus, TaskStatusMessage, TaskTags
from database.model.task.task import (
Task,
TaskStatus,
TaskStatusMessage,
TaskSystemTags,
)
from database.utils import get_company_or_none_constraint, id as create_id
from service_repo import APICall
from timing_context import TimingContext
@@ -143,7 +145,7 @@ class TaskBLL(object):
return model
@classmethod
def validate(cls, task: Task, force=False):
def validate(cls, task: Task):
assert isinstance(task, Task)
if task.parent and not Task.get(
@@ -154,24 +156,12 @@ class TaskBLL(object):
if task.project:
Project.get_for_writing(company=task.company, id=task.project)
model = cls.validate_execution_model(task)
if model and not force and not model.ready:
raise errors.bad_request.ModelNotReady(
"can't be used in a task", model=model.id
)
cls.validate_execution_model(task)
if task.execution:
if task.execution.parameters:
cls._validate_execution_parameters(task.execution.parameters)
if task.output and task.output.destination:
parsed_url = urlparse(task.output.destination)
if parsed_url.scheme not in OutputDestinationField.schemes:
raise errors.bad_request.FieldsValueError(
"unsupported scheme for output destination",
dest=task.output.destination,
)
@staticmethod
def _validate_execution_parameters(parameters):
invalid_keys = [k for k in parameters if re.search(r"\s", k)]
@@ -236,7 +226,7 @@ class TaskBLL(object):
last_update: datetime = None,
last_iteration: int = None,
last_iteration_max: int = None,
last_metrics: Mapping[str, Mapping[str, MetricEvent]] = None,
last_values: Sequence[Tuple[Tuple[str, ...], Any]] = None,
**extra_updates,
):
"""
@@ -248,7 +238,7 @@ class TaskBLL(object):
task's last iteration value.
:param last_iteration_max: Last reported iteration. Use this to conditionally set a value only
if the current task's last iteration value is smaller than the provided value.
:param last_metrics: Last reported metrics summary.
:param last_values: Last reported metrics summary (value, metric, variant).
:param extra_updates: Extra task updates to include in this update call.
:return:
"""
@@ -259,10 +249,18 @@ class TaskBLL(object):
elif last_iteration_max is not None:
extra_updates.update(max__last_iteration=last_iteration_max)
if last_metrics is not None:
extra_updates.update(last_metrics=last_metrics)
if last_values is not None:
return Task.objects(id=task_id, company=company_id).update(
def op_path(op, *path):
return "__".join((op, "last_metrics") + path)
for path, value in last_values:
extra_updates[op_path("set", *path)] = value
if path[-1] == "value":
extra_updates[op_path("min", *path[:-1], "min_value")] = value
extra_updates[op_path("max", *path[:-1], "max_value")] = value
Task.objects(id=task_id, company=company_id).update(
upsert=False, last_update=last_update, **extra_updates
)
@@ -378,11 +376,11 @@ class TaskBLL(object):
task = TaskBLL.get_task_with_access(
task_id,
company_id=company_id,
only=("status", "project", "tags", "last_update"),
only=("status", "project", "tags", "system_tags", "last_update"),
requires_write_access=True,
)
if TaskTags.development in task.tags:
if TaskSystemTags.development in task.system_tags:
new_status = TaskStatus.stopped
status_message = f"Stopped by {user_name}"
else:
@@ -448,3 +446,55 @@ class TaskBLL(object):
except Exception as ex:
log.exception(f"Failed stopping tasks: {str(ex)}")
@staticmethod
def get_aggregated_project_execution_parameters(
company_id,
project_ids: Sequence[str] = None,
page: int = 0,
page_size: int = 500,
) -> Tuple[int, int, Sequence[str]]:
page = max(0, page)
page_size = max(1, page_size)
pipeline = [
{
"$match": {
"company": company_id,
"execution.parameters": {"$exists": True, "$gt": {}},
**({"project": {"$in": project_ids}} if project_ids else {}),
}
},
{"$project": {"parameters": {"$objectToArray": "$execution.parameters"}}},
{"$unwind": "$parameters"},
{"$group": {"_id": "$parameters.k"}},
{"$sort": {"_id": 1}},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT"},
}
},
{
"$project": {
"total": 1,
"results": {"$slice": ["$results", page * page_size, page_size]},
}
},
]
with translate_errors_context():
result = next(Task.objects.aggregate(*pipeline), None)
total = 0
remaining = 0
results = []
if result:
total = int(result.get("total", -1))
results = [r["_id"] for r in result.get("results", [])]
remaining = max(0, total - (len(results) + page * page_size))
return total, remaining, results

View File

@@ -66,6 +66,10 @@ class ChangeStatusRequest(object):
)
update_project_time(project_id)
# make sure that _raw_ queries are not returned back to the client
fields.pop("__raw__", None)
return dict(updated=updated, fields=fields)
def validate_transition(self, current_status):
@@ -135,9 +139,11 @@ def get_possible_status_changes(current_status):
:return possible states from current state
"""
possible = state_machine.get(current_status)
assert (
possible is not None
), f"Current status {current_status} not supported by state machine"
if possible is None:
raise errors.server_error.InternalError(
f"Current status {current_status} not supported by state machine"
)
return possible

20
server/bll/util.py Normal file
View File

@@ -0,0 +1,20 @@
from operator import itemgetter
from typing import Sequence, Optional, Callable, Tuple
def extract_properties_to_lists(
key_names: Sequence[str],
data: Sequence[dict],
extract_func: Optional[Callable[[dict], Tuple]] = None,
) -> dict:
"""
Given a list of dictionaries and names of dictionary keys
builds a dictionary with the requested keys and values lists
:param key_names: names of the keys in the resulting dictionary
:param data: sequence of dictionaries to extract values from
:param extract_func: the optional callable that extracts properties
from a dictionary and put them in a tuple in the order corresponding to
key_names. If not specified then properties are extracted according to key_names
"""
value_sequences = zip(*map(extract_func or itemgetter(*key_names), data))
return dict(zip(key_names, map(list, value_sequences)))

View File

@@ -1,4 +1,5 @@
import logging
import os
from functools import reduce
from os import getenv
from os.path import expandvars
@@ -16,6 +17,9 @@ DEFAULT_EXTRA_CONFIG_PATH = "/opt/trains/config"
EXTRA_CONFIG_PATH_ENV_KEY = "TRAINS_CONFIG_DIR"
EXTRA_CONFIG_PATH_SEP = ":"
EXTRA_CONFIG_VALUES_ENV_KEY_SEP = "__"
EXTRA_CONFIG_VALUES_ENV_KEY_PREFIX = f"TRAINS{EXTRA_CONFIG_VALUES_ENV_KEY_SEP}"
class BasicConfig:
NotSet = object()
@@ -46,6 +50,20 @@ class BasicConfig:
path = ".".join((self.prefix, Path(name).stem))
return logging.getLogger(path)
def _read_extra_env_config_values(self):
""" Loads extra configuration from environment-injected values """
result = ConfigTree()
prefix = EXTRA_CONFIG_VALUES_ENV_KEY_PREFIX
keys = sorted(k for k in os.environ if k.startswith(prefix))
for key in keys:
path = key[len(prefix) :].replace(EXTRA_CONFIG_VALUES_ENV_KEY_SEP, ".")
result = ConfigTree.merge_configs(
result, ConfigFactory.parse_string(f"{path}: {os.environ[key]}")
)
return result
def _read_env_paths(self, key):
value = getenv(EXTRA_CONFIG_PATH_ENV_KEY, DEFAULT_EXTRA_CONFIG_PATH)
if value is None:
@@ -64,12 +82,17 @@ class BasicConfig:
def _load(self, verbose=True):
extra_config_paths = self._read_env_paths(EXTRA_CONFIG_PATH_ENV_KEY) or []
extra_config_values = self._read_extra_env_config_values()
configs = [
self._read_recursive(path, verbose=verbose)
for path in [self.folder] + extra_config_paths
]
self._config = reduce(
lambda config, path: ConfigTree.merge_configs(
config, self._read_recursive(path, verbose=verbose), copy_trees=True
lambda last, config: ConfigTree.merge_configs(
last, config, copy_trees=True
),
[self.folder] + extra_config_paths,
configs + [extra_config_values],
ConfigTree(),
)

View File

@@ -21,6 +21,9 @@
version {
required: false
default: 1.0
# if set then calls to endpoints with the version
# greater that the current max version will be rejected
check_max_version: false
}
mongo {
@@ -66,6 +69,9 @@
cors {
origins: "*"
# Not supported when origins is "*"
supports_credentials: true
}
default_company: "d1bd92a3b039400cbafc60a7a5b1e52b"

View File

@@ -18,3 +18,11 @@ def get_version():
return (root / "VERSION").read_text().strip()
except FileNotFoundError:
return ""
@lru_cache()
def get_commit_number():
try:
return (root / "COMMIT").read_text().strip()
except FileNotFoundError:
return ""

View File

@@ -1,3 +1,6 @@
from os import getenv
from furl import furl
from jsonmodels import models
from jsonmodels.errors import ValidationError
from jsonmodels.fields import StringField
@@ -8,9 +11,14 @@ from config import config
from .defs import Database
from .utils import get_items
log = config.logger(__file__)
from boltons.iterutils import first
strict = config.get('apiserver.mongo.strict', True)
log = config.logger("database")
strict = config.get("apiserver.mongo.strict", True)
OVERRIDE_HOST_ENV_KEY = ("MONGODB_SERVICE_HOST", "MONGODB_SERVICE_SERVICE_HOST")
OVERRIDE_PORT_ENV_KEY = "MONGODB_SERVICE_PORT"
_entries = []
@@ -21,28 +29,47 @@ class DatabaseEntry(models.Base):
@property
def health_alias(self):
return '__health__' + self.alias
return "__health__" + self.alias
def initialize():
db_entries = config.get('hosts.mongo', {})
db_entries = config.get("hosts.mongo", {})
missing = []
log.info('Initializing database connections')
log.info("Initializing database connections")
override_hostname = first(map(getenv, OVERRIDE_HOST_ENV_KEY), None)
if override_hostname:
log.info(f"Using override mongodb host {override_hostname}")
override_port = getenv(OVERRIDE_PORT_ENV_KEY)
if override_port:
log.info(f"Using override mongodb port {override_port}")
for key, alias in get_items(Database).items():
if key not in db_entries:
missing.append(key)
continue
entry = DatabaseEntry(alias=alias, **db_entries.get(key))
if override_hostname:
entry.host = furl(entry.host).set(host=override_hostname).url
if override_port:
entry.host = furl(entry.host).set(port=override_port).url
try:
entry.validate()
log.info('Registering connection to %(alias)s (%(host)s)' % entry.to_struct())
log.info(
"Registering connection to %(alias)s (%(host)s)" % entry.to_struct()
)
register_connection(alias=alias, host=entry.host)
_entries.append(entry)
except ValidationError as ex:
raise Exception('Invalid database entry `%s`: %s' % (key, ex.args[0]))
raise Exception("Invalid database entry `%s`: %s" % (key, ex.args[0]))
if missing:
raise ValueError('Missing database configuration for %s' % ', '.join(missing))
raise ValueError("Missing database configuration for %s" % ", ".join(missing))
def get_entries():

View File

@@ -1,5 +1,6 @@
import re
from operator import itemgetter
from sys import maxsize
from typing import Type, Tuple
import six
from mongoengine import (
@@ -11,6 +12,7 @@ from mongoengine import (
SortedListField,
MapField,
DictField,
DynamicField,
)
@@ -88,104 +90,6 @@ class CustomFloatField(FloatField):
self.error("Float value must be greater than %s" % str(self.greater_than))
# TODO: bucket name should be at most 63 characters....
aws_s3_bucket_only_regex = (
r"^s3://"
r"(?:(?:\w[A-Z0-9\-]+\w)\.)*(?:\w[A-Z0-9\-]+\w)" # bucket name
)
aws_s3_url_with_bucket_regex = (
r"^s3://"
r"(?:(?:\w[A-Z0-9\-]+\w)\.)*(?:\w[A-Z0-9\-]+\w)" # bucket name
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(?<!-)\.?))" # domain...
)
non_aws_s3_regex = (
r"^s3://"
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(?<!-)\.?)|" # domain...
r"localhost|" # localhost...
r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|" # ...or ipv4
r"\[?[A-F0-9]*:[A-F0-9:]+\]?)" # ...or ipv6
r"(?::\d+)?" # optional port
r"(?:/(?:(?:\w[A-Z0-9\-]+\w)\.)*(?:\w[A-Z0-9\-]+\w))" # bucket name
)
google_gs_bucket_only_regex = (
r"^gs://"
r"(?:(?:\w[A-Z0-9\-_]+\w)\.)*(?:\w[A-Z0-9\-_]+\w)" # bucket name
)
file_regex = r"^file://"
generic_url_regex = (
r"^%s://" # scheme placeholder
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(?<!-)\.?)|" # domain...
r"localhost|" # localhost...
r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|" # ...or ipv4
r"\[?[A-F0-9]*:[A-F0-9:]+\]?)" # ...or ipv6
r"(?::\d+)?" # optional port
)
path_suffix = r"(?:/?|[/?]\S+)$"
file_path_suffix = r"(?:/\S*[^/]+)$"
class _RegexURLField(StringField):
_regex = []
def __init__(self, regex, **kwargs):
super(_RegexURLField, self).__init__(**kwargs)
regex = regex if isinstance(regex, (tuple, list)) else [regex]
self._regex = [
re.compile(e, re.IGNORECASE) if isinstance(e, six.string_types) else e
for e in regex
]
def validate(self, value):
# Check first if the scheme is valid
if not any(regex for regex in self._regex if regex.match(value)):
self.error("Invalid URL: {}".format(value))
return
class OutputDestinationField(_RegexURLField):
""" A field representing task output URL """
schemes = ["s3", "gs", "file"]
_expressions = (
aws_s3_bucket_only_regex + path_suffix,
aws_s3_url_with_bucket_regex + path_suffix,
non_aws_s3_regex + path_suffix,
google_gs_bucket_only_regex + path_suffix,
file_regex + path_suffix,
)
def __init__(self, **kwargs):
super(OutputDestinationField, self).__init__(self._expressions, **kwargs)
class SupportedURLField(_RegexURLField):
""" A field representing a model URL """
schemes = ["s3", "gs", "file", "http", "https"]
_expressions = tuple(
pattern + file_path_suffix
for pattern in (
aws_s3_bucket_only_regex,
aws_s3_url_with_bucket_regex,
non_aws_s3_regex,
google_gs_bucket_only_regex,
file_regex,
(generic_url_regex % "http"),
(generic_url_regex % "https"),
)
)
def __init__(self, **kwargs):
super(SupportedURLField, self).__init__(self._expressions, **kwargs)
class StrippedStringField(StringField):
def __init__(
self, regex=None, max_length=None, min_length=None, strip_chars=None, **kwargs
@@ -235,3 +139,42 @@ class SafeDictField(DictField):
if contains_empty_key(value):
self.error("Empty keys are not allowed in a DictField")
class SafeSortedListField(SortedListField):
"""
SortedListField that does not raise an error in case items are not comparable
(in which case they will be sorted by their string representation)
"""
def to_mongo(self, *args, **kwargs):
try:
return super(SafeSortedListField, self).to_mongo(*args, **kwargs)
except TypeError:
return self._safe_to_mongo(*args, **kwargs)
def _safe_to_mongo(self, value, use_db_field=True, fields=None):
value = super(SortedListField, self).to_mongo(value, use_db_field, fields)
if self._ordering is not None:
def key(v): return str(itemgetter(self._ordering)(v))
else:
key = str
return sorted(value, key=key, reverse=self._order_reverse)
class UnionField(DynamicField):
def __init__(self, types, *args, **kwargs):
super(UnionField, self).__init__(*args, **kwargs)
self.types: Tuple[Type] = tuple(types)
def validate(self, value, clean=True):
if not isinstance(value, self.types):
type_names = [t.__name__ for t in self.types]
expected = " or ".join(
filter(
None,
(", ".join(type_names[:-1]), type_names[-1]))
)
self.error(
f"Expected {expected}, got {type(value).__name__}: {value}"
)
super(UnionField, self).validate(value, clean)

View File

@@ -1,3 +1,5 @@
from enum import Enum
from mongoengine import Document, StringField
from apierrors import errors
@@ -54,3 +56,7 @@ def validate_id(cls, company, **kwargs):
**{name: obj_id for obj_id in missing for name in id_to_name[obj_id]}
)
class EntityVisibility(Enum):
active = "active"
archived = "archived"

View File

@@ -45,6 +45,7 @@ class Role(object):
class Credentials(EmbeddedDocument):
key = StringField(required=True)
secret = StringField(required=True)
last_used = DateTimeField()
class User(DbModelMixin, AuthDocument):

View File

@@ -1,8 +1,9 @@
import re
from collections import namedtuple
from functools import reduce
from typing import Collection
from typing import Collection, Sequence
from boltons.iterutils import first
from dateutil.parser import parse as parse_datetime
from mongoengine import Q, Document
from six import string_types
@@ -13,7 +14,12 @@ from database.errors import MakeGetAllQueryError
from database.projection import project_dict, ProjectionHelper
from database.props import PropsMixin
from database.query import RegexQ, RegexWrapper
from database.utils import get_company_or_none_constraint, get_fields_with_attr
from database.utils import (
get_company_or_none_constraint,
get_fields_with_attr,
field_exists,
field_does_not_exist,
)
log = config.logger("dbmodel")
@@ -68,7 +74,7 @@ class GetMixin(PropsMixin):
def __init__(
self,
pattern_fields=("name",),
list_fields=("tags", "id"),
list_fields=("tags", "system_tags", "id"),
datetime_fields=None,
fields=None,
):
@@ -261,6 +267,7 @@ class GetMixin(PropsMixin):
allow_public=False,
override_projection=None,
expand_reference_ids=True,
override_none_ordering=False,
):
"""
Fetch all documents matching a provided query with support for joining referenced documents according to the
@@ -296,6 +303,7 @@ class GetMixin(PropsMixin):
query=query,
query_options=query_options,
allow_public=allow_public,
override_none_ordering=override_none_ordering,
)
def projection_func(doc_type, projection, ids):
@@ -320,6 +328,7 @@ class GetMixin(PropsMixin):
allow_public=False,
override_projection: Collection[str] = None,
return_dicts=True,
override_none_ordering=False,
):
"""
Fetch all documents matching a provided query. Supported several built-in options
@@ -343,6 +352,8 @@ class GetMixin(PropsMixin):
:param override_projection: A list of projection fields overriding any projection specified in the `param_dict`
argument
:param allow_public: If True, objects marked as public (no associated company) are also queried.
:param override_none_ordering: If True, then items with the None values in the first ordered field
are always sorted in the end
:return: A list of objects matching the query.
"""
if query_dict is not None:
@@ -356,6 +367,15 @@ class GetMixin(PropsMixin):
q = cls._prepare_perm_query(company, allow_public=allow_public)
_query = (q & query) if query else q
if override_none_ordering:
return cls._get_many_override_none_ordering(
query=_query,
parameters=parameters,
query_dict=query_dict,
query_options=query_options,
override_projection=override_projection,
)
return cls._get_many_no_company(
query=_query,
parameters=parameters,
@@ -428,6 +448,105 @@ class GetMixin(PropsMixin):
return [obj.to_proper_dict(only=only) for obj in qs]
return qs
@classmethod
def _get_many_override_none_ordering(
cls,
query: Q = None,
parameters: dict = None,
query_dict: dict = None,
query_options: QueryParameterOptions = None,
override_projection: Collection[str] = None,
) -> Sequence[dict]:
"""
Fetch all documents matching a provided query. For the first order by field
the None values are sorted in the end regardless of the sorting order.
This is a company-less version for internal uses. We assume the caller has either added any necessary
constraints to the query or that no constraints are required.
NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows returning data across companies.
:param query: Query object (mongoengine.Q)
:param parameters: Parameters dict from which paging ordering and searching parameters are extracted.
:param query_dict: If provided, passed to prepare_query() along with all of the relevant arguments to produce
a query. The resulting query is AND'ed with the `query` parameter (if provided).
:param query_options: query parameters options (see ParametersOptions)
:param override_projection: A list of projection fields overriding any projection specified in the `param_dict`
argument
"""
parameters = parameters or {}
search_text = parameters.get("search_text")
page, page_size = cls.validate_paging(parameters=parameters)
query_sets = []
order_by = parameters.get(cls._ordering_key)
if order_by:
order_by = order_by if isinstance(order_by, list) else [order_by]
order_by = [cls._text_score if x == "@text_score" else x for x in order_by]
if not search_text and cls._text_score in order_by:
raise errors.bad_request.FieldsValueError(
"text score cannot be used in order_by when search text is not used"
)
order_field = first(
field for field in order_by if not field.startswith("$")
)
if (
order_field
and not order_field.startswith("-")
and (not query_dict or order_field not in query_dict)
):
empty_value = None
if order_field in query_options.list_fields:
empty_value = []
elif order_field in query_options.pattern_fields:
empty_value = ""
mongo_field = order_field.replace(".", "__")
non_empty = query & field_exists(mongo_field, empty_value=empty_value)
empty = query & field_does_not_exist(
mongo_field, empty_value=empty_value
)
query_sets = [cls.objects(non_empty), cls.objects(empty)]
if not query_sets:
query_sets = [cls.objects(query)]
if search_text:
query_sets = [qs.search_text(search_text) for qs in query_sets]
if order_by:
# add ordering
query_sets = [qs.order_by(*order_by) for qs in query_sets]
only = cls.get_projection(parameters, override_projection)
if only:
# add projection
query_sets = [qs.only(*only) for qs in query_sets]
else:
exclude = set(cls.get_exclude_fields())
if exclude:
query_sets = [qs.exclude(*exclude) for qs in query_sets]
if page is None or not page_size:
return [obj.to_proper_dict(only=only) for qs in query_sets for obj in qs]
# add paging
ret = []
start = page * page_size
for qs in query_sets:
qs_size = qs.count()
if qs_size < start:
start -= qs_size
continue
ret.extend(
obj.to_proper_dict(only=only) for obj in qs.skip(start).limit(page_size)
)
if len(ret) >= page_size:
break
start = 0
page_size -= len(ret)
return ret
@classmethod
def get_for_writing(
cls, *args, _only: Collection[str] = None, **kwargs

View File

@@ -1,7 +1,7 @@
from mongoengine import Document, StringField, DateTimeField, ListField, BooleanField
from database import Database, strict
from database.fields import SupportedURLField, StrippedStringField, SafeDictField
from database.fields import StrippedStringField, SafeDictField
from database.model import DbModelMixin
from database.model.model_labels import ModelLabels
from database.model.company import Company
@@ -48,7 +48,8 @@ class Model(DbModelMixin, Document):
task = StringField(reference_field=Task)
comment = StringField(user_set_allowed=True)
tags = ListField(StringField(required=True), user_set_allowed=True)
uri = SupportedURLField(default='', user_set_allowed=True)
system_tags = ListField(StringField(required=True), user_set_allowed=True)
uri = StrippedStringField(default='', user_set_allowed=True)
framework = StringField()
design = SafeDictField()
labels = ModelLabels()

View File

@@ -1,7 +1,7 @@
from mongoengine import StringField, DateTimeField, ListField
from database import Database, strict
from database.fields import OutputDestinationField, StrippedStringField
from database.fields import StrippedStringField
from database.model import AttributedDocument
from database.model.base import GetMixin
@@ -9,7 +9,8 @@ from database.model.base import GetMixin
class Project(AttributedDocument):
get_all_query_options = GetMixin.QueryParameterOptions(
pattern_fields=("name", "description"), list_fields=("tags", "id")
pattern_fields=("name", "description"),
list_fields=("tags", "system_tags", "id"),
)
meta = {
@@ -34,6 +35,7 @@ class Project(AttributedDocument):
)
description = StringField(required=True)
created = DateTimeField(required=True)
tags = ListField(StringField(required=True), default=list)
default_output_destination = OutputDestinationField()
tags = ListField(StringField(required=True))
system_tags = ListField(StringField(required=True))
default_output_destination = StrippedStringField()
last_update = DateTimeField()

View File

@@ -1,14 +1,14 @@
from mongoengine import EmbeddedDocument, StringField, DateTimeField, LongField, DynamicField
from mongoengine import EmbeddedDocument, StringField, DynamicField
class MetricEvent(EmbeddedDocument):
metric = StringField(required=True, )
variant = StringField(required=True)
type = StringField(required=True)
timestamp = DateTimeField(default=0, required=True)
iter = LongField()
value = DynamicField(required=True)
meta = {
# For backwards compatibility reasons
'strict': False,
}
@classmethod
def from_dict(cls, **kwargs):
return cls(**{k: v for k, v in kwargs.items() if k in cls._fields})
metric = StringField(required=True)
variant = StringField(required=True)
value = DynamicField(required=True)
min_value = DynamicField() # for backwards compatibility reasons
max_value = DynamicField() # for backwards compatibility reasons

View File

@@ -1,7 +1,7 @@
from mongoengine import EmbeddedDocument, StringField
from database.utils import get_options
from database.fields import OutputDestinationField
from database.fields import StrippedStringField
from database.utils import get_options
class Result(object):
@@ -10,7 +10,7 @@ class Result(object):
class Output(EmbeddedDocument):
destination = OutputDestinationField()
destination = StrippedStringField()
model = StringField(reference_field='Model')
error = StringField(user_set_allowed=True)
result = StringField(choices=get_options(Result))

View File

@@ -1,5 +1,3 @@
from enum import Enum
from mongoengine import (
StringField,
EmbeddedDocumentField,
@@ -7,10 +5,18 @@ from mongoengine import (
DateTimeField,
IntField,
ListField,
LongField,
)
from database import Database, strict
from database.fields import StrippedStringField, SafeMapField, SafeDictField
from database.fields import (
StrippedStringField,
SafeMapField,
SafeDictField,
UnionField,
EmbeddedDocumentSortedListField,
SafeSortedListField,
)
from database.model import AttributedDocument
from database.model.model_labels import ModelLabels
from database.model.project import Project
@@ -22,27 +28,27 @@ DEFAULT_LAST_ITERATION = 0
class TaskStatus(object):
created = 'created'
in_progress = 'in_progress'
stopped = 'stopped'
publishing = 'publishing'
published = 'published'
closed = 'closed'
failed = 'failed'
completed = 'completed'
unknown = 'unknown'
created = "created"
in_progress = "in_progress"
stopped = "stopped"
publishing = "publishing"
published = "published"
closed = "closed"
failed = "failed"
completed = "completed"
unknown = "unknown"
class TaskStatusMessage(object):
stopping = 'stopping'
stopping = "stopping"
class TaskTags(object):
development = 'development'
class TaskSystemTags(object):
development = "development"
class Script(EmbeddedDocument):
binary = StringField(default='python')
binary = StringField(default="python")
repository = StringField(required=True)
tag = StringField()
branch = StringField()
@@ -53,51 +59,70 @@ class Script(EmbeddedDocument):
diff = StringField()
class ArtifactTypeData(EmbeddedDocument):
preview = StringField()
content_type = StringField()
data_hash = StringField()
class Artifact(EmbeddedDocument):
key = StringField(required=True)
type = StringField(required=True)
mode = StringField(choices=("input", "output"), default="output")
uri = StringField()
hash = StringField()
content_size = LongField()
timestamp = LongField()
type_data = EmbeddedDocumentField(ArtifactTypeData)
display_data = SafeSortedListField(ListField(UnionField((int, float, str))))
class Execution(EmbeddedDocument):
test_split = IntField(default=0)
parameters = SafeDictField(default=dict)
model = StringField(reference_field='Model')
model_desc = SafeMapField(StringField(default=''))
model = StringField(reference_field="Model")
model_desc = SafeMapField(StringField(default=""))
model_labels = ModelLabels()
framework = StringField()
artifacts = EmbeddedDocumentSortedListField(Artifact)
queue = StringField()
''' Queue ID where task was queued '''
""" Queue ID where task was queued """
class TaskType(object):
training = 'training'
testing = 'testing'
training = "training"
testing = "testing"
class Task(AttributedDocument):
meta = {
'db_alias': Database.backend,
'strict': strict,
'indexes': [
'created',
'started',
'completed',
"db_alias": Database.backend,
"strict": strict,
"indexes": [
"created",
"started",
"completed",
{
'name': '%s.task.main_text_index' % Database.backend,
'fields': [
'$name',
'$id',
'$comment',
'$execution.model',
'$output.model',
'$script.repository',
'$script.entry_point',
"name": "%s.task.main_text_index" % Database.backend,
"fields": [
"$name",
"$id",
"$comment",
"$execution.model",
"$output.model",
"$script.repository",
"$script.entry_point",
],
'default_language': 'english',
'weights': {
'name': 10,
'id': 10,
'comment': 10,
'execution.model': 2,
'output.model': 2,
'script.repository': 1,
'script.entry_point': 1,
"default_language": "english",
"weights": {
"name": 10,
"id": 10,
"comment": 10,
"execution.model": 2,
"output.model": 2,
"script.repository": 1,
"script.entry_point": 1,
},
},
],
@@ -123,12 +148,8 @@ class Task(AttributedDocument):
output = EmbeddedDocumentField(Output, default=Output)
execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
tags = ListField(StringField(required=True), user_set_allowed=True)
system_tags = ListField(StringField(required=True), user_set_allowed=True)
script = EmbeddedDocumentField(Script)
last_update = DateTimeField()
last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
class TaskVisibility(Enum):
active = 'active'
archived = 'archived'

View File

@@ -0,0 +1,18 @@
from mongoengine import Document, DateTimeField, StringField
from database import Database, strict
from database.model import DbModelMixin
class Version(DbModelMixin, Document):
meta = {
"collection": "versions", # custom collection name ('version' is not a proper collection name...)
"db_alias": Database.backend, # although we'll use this model for all databases, a default must be defined
"strict": strict,
"indexes": [("-created", "-num")],
}
id = StringField(primary_key=True)
num = StringField(required=True)
created = DateTimeField(required=True)
desc = StringField()

View File

@@ -1,13 +1,17 @@
import threading
from concurrent.futures import ThreadPoolExecutor
from itertools import groupby, chain
from typing import Sequence, Dict, Callable, Tuple, Any, Type
import dpath
import dpath.path
from apierrors import errors
from database.props import PropsMixin
SEP = "."
def project_dict(data, projection, separator='.'):
def project_dict(data, projection, separator=SEP):
"""
Project partial data from a dictionary into a new dictionary
:param data: Input dictionary
@@ -30,19 +34,27 @@ def project_dict(data, projection, separator='.'):
if path_part not in dst:
dst[path_part] = [{} for _ in range(len(src_part))]
elif not isinstance(dst[path_part], (list, tuple)):
raise TypeError('Incompatible destination type %s for %s (list expected)'
% (type(dst), separator.join(path_parts[:depth + 1])))
raise TypeError(
"Incompatible destination type %s for %s (list expected)"
% (type(dst), separator.join(path_parts[: depth + 1]))
)
elif not len(dst[path_part]) == len(src_part):
raise ValueError('Destination list length differs from source length for %s'
% separator.join(path_parts[:depth + 1]))
raise ValueError(
"Destination list length differs from source length for %s"
% separator.join(path_parts[: depth + 1])
)
dst[path_part] = [copy_path(path_parts[depth + 1:], s, d)
for s, d in zip(src_part, dst[path_part])]
dst[path_part] = [
copy_path(path_parts[depth + 1:], s, d)
for s, d in zip(src_part, dst[path_part])
]
return destination
else:
raise TypeError('Unsupported projection type %s for %s'
% (type(src), separator.join(path_parts[:depth + 1])))
raise TypeError(
"Unsupported projection type %s for %s"
% (type(src), separator.join(path_parts[: depth + 1]))
)
last_part = path_parts[-1]
dst[last_part] = src[last_part]
@@ -53,12 +65,35 @@ def project_dict(data, projection, separator='.'):
for projection_path in sorted(projection):
copy_path(
path_parts=projection_path.split(separator),
source=data,
destination=result)
path_parts=projection_path.split(separator), source=data, destination=result
)
return result
class _ReferenceProxy(dict):
def __init__(self, id):
super(_ReferenceProxy, self).__init__(**({"id": id} if id else {}))
class _ProxyManager:
lock = threading.Lock()
def __init__(self):
self._proxies: Dict[str, _ReferenceProxy] = {}
def add(self, id):
with self.lock:
proxy = self._proxies.get(id)
if proxy is None:
proxy = self._proxies[id] = _ReferenceProxy(id)
return proxy
def update(self, result):
proxy = self._proxies.get(result.get("id"))
if proxy is not None:
proxy.update(result)
class ProjectionHelper(object):
pool = ThreadPoolExecutor()
@@ -72,6 +107,11 @@ class ProjectionHelper(object):
self._doc_cls = doc_cls
self._doc_projection = None
self._ref_projection = None
self._proxy_manager = _ProxyManager()
# Cached dpath paths for each of the result documents
self._cached_results_paths: Dict[int, Sequence[Tuple[Any, Type]]] = {}
self._parse_projection(projection)
def _collect_projection_fields(self, doc_cls, projection):
@@ -81,8 +121,12 @@ class ProjectionHelper(object):
:param projection: List of projection fields
:return: A tuple of document projection and reference fields information
"""
doc_projection = set() # Projection fields for this class (used in the main query)
ref_projection_info = [] # Projection information for reference fields (used in join queries)
doc_projection = (
set()
) # Projection fields for this class (used in the main query)
ref_projection_info = (
[]
) # Projection information for reference fields (used in join queries)
for field in projection:
for ref_field, ref_field_cls in doc_cls.get_reference_fields().items():
if not field.startswith(ref_field):
@@ -93,7 +137,7 @@ class ProjectionHelper(object):
# use '<reference field name>.*')
continue
subfield = field[len(ref_field):]
if not subfield.startswith('.'):
if not subfield.startswith(SEP):
# Starts with something that looks like a reference field, but isn't
continue
@@ -103,10 +147,12 @@ class ProjectionHelper(object):
# Not a reference field, just add to the top-level projection
# We strip any trailing '*' since it means nothing for simple fields and for embedded documents
orig_field = field
if field.endswith('.*'):
if field.endswith(".*"):
field = field[:-2]
if not field:
raise errors.bad_request.InvalidFields(field=orig_field, object=doc_cls.__name__)
raise errors.bad_request.InvalidFields(
field=orig_field, object=doc_cls.__name__
)
doc_projection.add(field)
return doc_projection, ref_projection_info
@@ -124,12 +170,14 @@ class ProjectionHelper(object):
if not projection:
return [], {}
doc_projection, ref_projection_info = self._collect_projection_fields(doc_cls, projection)
doc_projection, ref_projection_info = self._collect_projection_fields(
doc_cls, projection
)
def normalize_cls_projection(cls_, fields):
""" Normalize projection for this class and group (expand *, for once) """
if '*' in fields:
return list(fields.difference('*').union(cls_.get_fields()))
if "*" in fields:
return list(fields.difference("*").union(cls_.get_fields()))
return list(fields)
def compute_ref_cls_projection(cls_, group):
@@ -143,12 +191,16 @@ class ProjectionHelper(object):
# Aggregate by reference field. We'll leave out '*' from the projected items since
ref_projection = {
ref_field: dict(cls=ref_cls, only=compute_ref_cls_projection(ref_cls, g))
for (ref_field, ref_cls), g in groupby(sorted(ref_projection_info, key=sort_key), sort_key)
for (ref_field, ref_cls), g in groupby(
sorted(ref_projection_info, key=sort_key), sort_key
)
}
# Make sure this doesn't contain any reference field we'll join anyway
# (i.e. in case only_fields=[project, project.name])
doc_projection = normalize_cls_projection(doc_cls, doc_projection.difference(ref_projection).union({'id'}))
doc_projection = normalize_cls_projection(
doc_cls, doc_projection.difference(ref_projection).union({"id"})
)
# Make sure that in case one or more field is a subfield of another field, we only use the the top-level field.
# This is done since in such a case, MongoDB will only use the most restrictive field (most nested field) and
@@ -158,13 +210,20 @@ class ProjectionHelper(object):
doc_projection = [
field
for field in doc_projection
if not any(field.startswith(f"{other_field}.") for other_field in projection_set - {field})
if not any(
field.startswith(f"{other_field}.")
for other_field in projection_set - {field}
)
]
# Make sure we didn't get any invalid projection fields for this class
invalid_fields = [f for f in doc_projection if f.split('.')[0] not in doc_cls.get_fields()]
invalid_fields = [
f for f in doc_projection if f.split(SEP)[0] not in doc_cls.get_fields()
]
if invalid_fields:
raise errors.bad_request.InvalidFields(fields=invalid_fields, object=doc_cls.__name__)
raise errors.bad_request.InvalidFields(
fields=invalid_fields, object=doc_cls.__name__
)
if ref_projection:
# Join mode - use both normal projection fields and top-level reference fields
@@ -178,11 +237,44 @@ class ProjectionHelper(object):
self._doc_projection = doc_projection
self._ref_projection = ref_projection
@staticmethod
def _search(doc_cls, obj, path, only_values=True):
""" Call dpath.search with yielded=True, collect result values """
def _search(
self,
doc_cls: PropsMixin,
obj: dict,
path: str,
factory: Callable[[str], dict] = None,
) -> Sequence[str]:
"""
Search for a path in the given object, return the list of values found for the
given path (multiple values may exist if the path is a glob expression)
:param doc_cls: The document class represented by the object
:param obj: Data object
:param path: Path to a leaf in the data object ("." separated, may contain "*")
(in case the path contains "*", there may be multiple values)
:param factory: If provided, replace each value found with an instance provided by the factory.
"""
norm_path = doc_cls.get_dpath_translated_path(path)
return [v if only_values else (k, v) for k, v in dpath.search(obj, norm_path, separator='.', yielded=True)]
globlist = norm_path.strip(SEP).split(SEP)
obj_paths = self._cached_results_paths.get(id(obj))
if obj_paths is None:
obj_paths = self._cached_results_paths[id(obj)] = list(
dpath.path.paths(obj, dirs=True, skip=True)
)
paths = [p for p in obj_paths if dpath.path.match(p, globlist)]
def search_and_replace(p: Sequence[Tuple[str, Type]]) -> Any:
parent = None
target = obj
for part in p:
parent = target
target = target[part[0]]
if parent and factory:
parent[p[-1][0]] = factory(target)
return target
return [search_and_replace(p) for p in paths]
def project(self, results, projection_func):
"""
@@ -197,28 +289,50 @@ class ProjectionHelper(object):
if ref_projection:
# Join mode - get results for each reference fields projection required (this is the join step)
# Note: this is a recursive step, so we support nested reference fields
# Note: this is a recursive step, so nested reference fields are supported
def do_projection(item):
ref_field_name, data = item
res = {}
ids = list(filter(None, set(chain.from_iterable(self._search(cls, res, ref_field_name)
for res in results))))
if ids:
doc_type = data['cls']
doc_only = list(filter(None, data['only']))
doc_only = list({'id'} | set(doc_only)) if doc_only else None
res = {r['id']: r for r in projection_func(doc_type=doc_type, projection=doc_only, ids=ids)}
data['res'] = res
def collect_ids(ref_field_name):
"""
Collect unique IDs for the given reference path from all result documents.
All collected IDs are replaced in the result dictionaries with a reference proxy generated by the
proxies manager to allow rapid update later on when projection results are obtained.
"""
all_ids = (
self._search(
cls, res, ref_field_name, factory=self._proxy_manager.add
)
for res in results
)
return list(filter(None, set(chain.from_iterable(all_ids))))
items = list(ref_projection.items())
if len(ref_projection) == 1:
do_projection(items[0])
else:
for _ in self.pool.map(do_projection, items):
# From ThreadPoolExecutor.map() documentation: If a call raises an exception then that exception
# will be raised when its value is retrieved from the map() iterator
pass
items = [
tup
for tup in (
(*item, collect_ids(item[0])) for item in ref_projection.items()
)
if tup[2]
]
if items:
def do_projection(item):
ref_field_name, data, ids = item
doc_type = data["cls"]
doc_only = list(filter(None, data["only"]))
doc_only = list({"id"} | set(doc_only)) if doc_only else None
for res in projection_func(
doc_type=doc_type, projection=doc_only, ids=ids
):
self._proxy_manager.update(res)
if len(ref_projection) == 1:
do_projection(items[0])
else:
for _ in self.pool.map(do_projection, items):
# From ThreadPoolExecutor.map() documentation: If a call raises an exception then that exception
# will be raised when its value is retrieved from the map() iterator
pass
def do_expand_reference_ids(result, skip_fields=None):
ref_fields = cls.get_reference_fields()
@@ -226,44 +340,18 @@ class ProjectionHelper(object):
ref_fields = set(ref_fields) - set(skip_fields)
self._expand_reference_fields(cls, result, ref_fields)
def merge_projection_result(result):
for ref_field_name, data in ref_projection.items():
res = data.get('res')
if not res:
self._expand_reference_fields(cls, result, [ref_field_name])
continue
ref_ids = self._search(cls, result, ref_field_name, only_values=False)
if not ref_ids:
continue
for path, value in ref_ids:
obj = res.get(value) or {'id': value}
dpath.new(result, path, obj, separator='.')
# any reference field not projected should be expanded
do_expand_reference_ids(result, skip_fields=list(ref_projection))
update_func = merge_projection_result if ref_projection else \
do_expand_reference_ids if self._should_expand_reference_ids else None
if update_func:
# any reference field not projected should be expanded
if self._should_expand_reference_ids:
for result in results:
update_func(result)
do_expand_reference_ids(
result, skip_fields=list(ref_projection) if ref_projection else None
)
return results
@classmethod
def _expand_reference_fields(cls, doc_cls, result, fields):
def _expand_reference_fields(self, doc_cls, result, fields):
for ref_field_name in fields:
ref_ids = cls._search(doc_cls, result, ref_field_name, only_values=False)
if not ref_ids:
continue
for path, value in ref_ids:
dpath.set(
result,
path,
{'id': value} if value else {},
separator='.')
self._search(doc_cls, result, ref_field_name, factory=_ReferenceProxy)
@classmethod
def expand_reference_ids(cls, doc_cls, result):
cls._expand_reference_fields(doc_cls, result, doc_cls.get_reference_fields())
def expand_reference_ids(self, doc_cls, result):
self._expand_reference_fields(doc_cls, result, doc_cls.get_reference_fields())

View File

@@ -1,5 +1,6 @@
import hashlib
from inspect import ismethod, getmembers
from typing import Sequence, Tuple, Set, Optional
from uuid import uuid4
from mongoengine import EmbeddedDocumentField, ListField, Document, Q
@@ -12,9 +13,13 @@ def get_fields(cls, of_type=BaseField, return_instance=False):
""" get field names from a class containing mongoengine fields """
res = []
for cls_ in reversed(cls.mro()):
res.extend([k if not return_instance else (k, v)
for k, v in vars(cls_).items()
if isinstance(v, of_type)])
res.extend(
[
k if not return_instance else (k, v)
for k, v in vars(cls_).items()
if isinstance(v, of_type)
]
)
return res
@@ -22,9 +27,13 @@ def get_fields_and_attr(cls, attr):
""" get field names from a class containing mongoengine fields """
res = {}
for cls_ in reversed(cls.mro()):
res.update({k: getattr(v, attr)
for k, v in vars(cls_).items()
if isinstance(v, BaseField) and hasattr(v, attr)})
res.update(
{
k: getattr(v, attr)
for k, v in vars(cls_).items()
if isinstance(v, BaseField) and hasattr(v, attr)
}
)
return res
@@ -33,7 +42,7 @@ def _get_field_choices(name, field):
if issubclass(field_t, EmbeddedDocumentField):
obj = field.document_type_obj
n, choices = _get_field_choices(field.name, obj.field)
return '%s__%s' % (name, n), choices
return "%s__%s" % (name, n), choices
elif issubclass(type(field), ListField):
return name, field.field.choices
return name, field.choices
@@ -46,8 +55,14 @@ def get_fields_with_attr(cls, attr, default=False):
continue
field_t = type(field)
if issubclass(field_t, EmbeddedDocumentField):
fields.extend((('%s__%s' % (field_name, name), choices)
for name, choices in get_fields_with_attr(field.document_type, attr, default)))
fields.extend(
(
("%s__%s" % (field_name, name), choices)
for name, choices in get_fields_with_attr(
field.document_type, attr, default
)
)
)
elif issubclass(type(field), ListField):
fields.append((field_name, field.field.choices))
else:
@@ -58,11 +73,7 @@ def get_fields_with_attr(cls, attr, default=False):
def get_items(cls):
""" get key/value items from an enum-like class (members represent enumeration key/value) """
res = {
k: v
for k, v in getmembers(cls)
if not (k.startswith("_") or ismethod(v))
}
res = {k: v for k, v in getmembers(cls) if not (k.startswith("_") or ismethod(v))}
return res
@@ -81,7 +92,7 @@ def parse_from_call(call_data, fields, cls_fields, discard_none_values=True):
fields = {k: None for k in fields}
fields = {k: v for k, v in fields.items() if k in cls_fields}
res = {}
with translate_errors_context('parsing call data'):
with translate_errors_context("parsing call data"):
for field, desc in fields.items():
value = call_data.get(field)
if value is None:
@@ -93,20 +104,34 @@ def parse_from_call(call_data, fields, cls_fields, discard_none_values=True):
if callable(desc):
desc(value)
else:
if issubclass(desc, (list, tuple, dict)) and not isinstance(value, desc):
raise ParseCallError('expecting %s' % desc.__name__, field=field)
if issubclass(desc, Document) and not desc.objects(id=value).only('id'):
raise ParseCallError('expecting %s id' % desc.__name__, id=value, field=field)
if issubclass(desc, (list, tuple, dict)) and not isinstance(
value, desc
):
raise ParseCallError(
"expecting %s" % desc.__name__, field=field
)
if issubclass(desc, Document) and not desc.objects(id=value).only(
"id"
):
raise ParseCallError(
"expecting %s id" % desc.__name__, id=value, field=field
)
res[field] = value
return res
def init_cls_from_base(cls, instance):
return cls(**{k: v for k, v in instance.to_mongo(use_db_field=False).to_dict().items() if k[0] != '_'})
return cls(
**{
k: v
for k, v in instance.to_mongo(use_db_field=False).to_dict().items()
if k[0] != "_"
}
)
def get_company_or_none_constraint(company=None):
return Q(company__in=(company, None, '')) | Q(company__exists=False)
return Q(company__in=(company, None, "")) | Q(company__exists=False)
def field_does_not_exist(field: str, empty_value=None, is_list=False) -> Q:
@@ -118,23 +143,40 @@ def field_does_not_exist(field: str, empty_value=None, is_list=False) -> Q:
the length of the array will be used (len==0 means empty)
:return:
"""
query = (Q(**{f"{field}__exists": False}) |
Q(**{f"{field}__in": {empty_value, None}}))
query = Q(**{f"{field}__exists": False}) | Q(
**{f"{field}__in": {empty_value, None}}
)
if is_list:
query |= Q(**{f"{field}__size": 0})
return query
def field_exists(field: str, empty_value=None) -> Q:
"""
Creates a query object used for finding a field that exists and is not None or empty.
:param field: Field name
:param empty_value: The empty value to test for (None means no specific empty value will be used).
For lists pass [] for empty_value
:return:
"""
query = Q(**{f"{field}__exists": True}) & Q(
**{f"{field}__nin": {empty_value, None}}
)
return query
def get_subkey(d, key_path, default=None):
""" Get a key from a nested dictionary. kay_path is a '.' separated string of keys used to traverse
the nested dictionary.
"""
keys = key_path.split('.')
keys = key_path.split(".")
for i, key in enumerate(keys):
if not isinstance(d, dict):
raise KeyError('Expecting a dict (%s)' % ('.'.join(keys[:i]) if i else 'bad input'))
raise KeyError(
"Expecting a dict (%s)" % (".".join(keys[:i]) if i else "bad input")
)
d = d.get(key)
if key is None:
if d is None:
return default
return d
@@ -158,3 +200,41 @@ def merge_dicts(*dicts):
def filter_fields(cls, fields):
"""From the fields dictionary return only the fields that match cls fields"""
return {key: fields[key] for key in fields if key in get_fields(cls)}
def _names_set(*names: str) -> Set[str]:
"""
Given a list of names return set with names and '-names'
"""
return set(names) | set(f"-{name}" for name in names)
system_tag_names = {
"model": _names_set("active", "archived"),
"project": _names_set("archived", "public", "default"),
"task": _names_set("active", "archived", "development"),
}
system_tag_prefixes = {"task": _names_set("annotat")}
def partition_tags(
entity: str, tags: Sequence[str], system_tags: Optional[Sequence[str]] = ()
) -> Tuple[Sequence[str], Sequence[str]]:
"""
Partition the given tags sequence into system and user-defined tags
:param entity: The name of the entity that defines the list of the system tags
:param tags: The tags to partition
:param system_tags: Optional. If passed then these tags are considered system together
with those defined for the entity.
:return: a tuple where the first element is the sequence of user-defined tags and
the second element is the sequence of system tags
"""
tags = set(tags)
system_tags = set(system_tags)
system_tags |= tags & system_tag_names[entity]
prefixes = system_tag_prefixes.get(entity, [])
system_tags |= {t for t in tags for p in prefixes if t.lower().startswith(p)}
return list(tags - system_tags), list(system_tags)

View File

@@ -1,4 +1,5 @@
from datetime import datetime
from os import getenv
from elasticsearch import Elasticsearch, Transport
@@ -6,6 +7,17 @@ from config import config
log = config.logger(__file__)
OVERRIDE_HOST_ENV_KEY = ("ELASTIC_SERVICE_HOST", "ELASTIC_SERVICE_SERVICE_HOST")
OVERRIDE_PORT_ENV_KEY = "ELASTIC_SERVICE_PORT"
OVERRIDE_HOST = next(filter(None, map(getenv, OVERRIDE_HOST_ENV_KEY)), None)
if OVERRIDE_HOST:
log.info(f"Using override elastic host {OVERRIDE_HOST}")
OVERRIDE_PORT = getenv(OVERRIDE_PORT_ENV_KEY)
if OVERRIDE_PORT:
log.info(f"Using override elastic port {OVERRIDE_PORT}")
_instances = {}
@@ -33,17 +45,18 @@ def connect(cluster_name):
:raises InvalidClusterConfiguration: in case cluster config section misses needed properties
"""
if cluster_name not in _instances:
cluster_config = _get_cluster_config(cluster_name)
cluster_config = get_cluster_config(cluster_name)
hosts = cluster_config.get('hosts', None)
if not hosts:
raise InvalidClusterConfiguration(cluster_name)
args = cluster_config.get('args', {})
_instances[cluster_name] = Elasticsearch(hosts=hosts, transport_class=Transport, **args)
return _instances[cluster_name]
def _get_cluster_config(cluster_name):
def get_cluster_config(cluster_name):
"""
Returns cluster config for the specified cluster path
:param cluster_name: Dot separated cluster path in the configuration file
@@ -55,6 +68,16 @@ def _get_cluster_config(cluster_name):
if not cluster_config:
raise MissingClusterConfiguration(cluster_name)
def set_host_prop(key, value):
for host in cluster_config.get('hosts', []):
host[key] = value
if OVERRIDE_HOST:
set_host_prop("host", OVERRIDE_HOST)
if OVERRIDE_PORT:
set_host_prop("port", OVERRIDE_PORT)
return cluster_config

View File

@@ -1,17 +1,28 @@
import importlib.util
from datetime import datetime
from pathlib import Path
import attr
from furl import furl
from mongoengine.connection import get_db
from semantic_version import Version
from database.model.user import User
from database.model.auth import User as AuthUser, Credentials
import database.utils
from config import config
from database import Database
from database.model.auth import Role
from database.model.auth import User as AuthUser, Credentials
from database.model.company import Company
from database.model.user import User
from database.model.version import Version as DatabaseVersion
from elastic.apply_mappings import apply_mappings_to_host
from es_factory import get_cluster_config
from service_repo.auth.fixed_user import FixedUser
log = config.logger(__file__)
migration_dir = (Path(__file__) / "../../migration/mongodb").resolve()
class MissingElasticConfiguration(Exception):
"""
@@ -22,10 +33,9 @@ class MissingElasticConfiguration(Exception):
def init_es_data():
hosts_key = "hosts.elastic.events.hosts"
hosts_config = config.get(hosts_key, None)
hosts_config = get_cluster_config("events").get("hosts")
if not hosts_config:
raise MissingElasticConfiguration(hosts_key)
raise MissingElasticConfiguration("for cluster 'events'")
for conf in hosts_config:
host = furl(scheme="http", host=conf["host"], port=conf["port"]).url
@@ -101,8 +111,64 @@ def _ensure_user(user: FixedUser, company_id: str):
).save()
def _apply_migrations():
if not migration_dir.is_dir():
raise ValueError(f"Invalid migration dir {migration_dir}")
try:
previous_versions = sorted(
(Version(ver.num) for ver in DatabaseVersion.objects().only("num")),
reverse=True,
)
except ValueError as ex:
raise ValueError(f"Invalid database version number encountered: {ex}")
last_version = previous_versions[0] if previous_versions else Version("0.0.0")
try:
new_scripts = {
ver: path
for ver, path in (
(Version(f.stem), f) for f in migration_dir.glob("*.py")
)
if ver > last_version
}
except ValueError as ex:
raise ValueError(f"Failed parsing migration version from file: {ex}")
dbs = {Database.auth: "migrate_auth", Database.backend: "migrate_backend"}
migration_log = log.getChild("mongodb_migration")
for script_version in sorted(new_scripts.keys()):
script = new_scripts[script_version]
spec = importlib.util.spec_from_file_location(script.stem, str(script))
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
for alias, func_name in dbs.items():
func = getattr(module, func_name, None)
if not func:
continue
try:
migration_log.info(f"Applying {script.stem}/{func_name}()")
func(get_db(alias))
except Exception:
migration_log.exception(f"Failed applying {script}:{func_name}()")
raise ValueError("Migration failed, aborting. Please restore backup.")
DatabaseVersion(
id=database.utils.id(),
num=script.stem,
created=datetime.utcnow(),
desc="Applied on server startup",
).save()
def init_mongo_data():
try:
_apply_migrations()
company_id = _ensure_company()
users = [
{"name": "apiserver", "role": Role.system, "email": "apiserver@example.com"},
@@ -117,7 +183,11 @@ def init_mongo_data():
_ensure_auth_user(user, company_id)
if FixedUser.enabled():
log.info("Fixed users mode is enabled")
for user in FixedUser.from_config():
_ensure_user(user, company_id)
try:
_ensure_user(user, company_id)
except Exception as ex:
log.error(f"Failed creating fixed user {user['name']}: {ex}")
except Exception as ex:
pass
log.exception("Failed initializing mongodb")

View File

@@ -15,6 +15,11 @@ _definitions {
type: string
description: ""
}
last_used {
type: string
description: ""
format: "date-time"
}
}
}
role {
@@ -52,6 +57,22 @@ login {
}
}
logout {
internal: false
allow_roles = [ "*" ]
"2.2" {
description: """Removes the authentication cookie from the current session"""
request {
type: object
additionalProperties: false
}
response {
type: object
additionalProperties: false
}
}
}
get_token_for_user {
"2.1" {
description: """Get a token for the specified user. Intended for internal use."""

View File

@@ -149,6 +149,14 @@
}
}
}
scalar_key_enum {
type: string
enum: [
iter
timestamp
iso_time
]
}
log_level_enum {
type: string
enum: [
@@ -682,6 +690,19 @@
type: string
description: "Task ID"
}
samples {
description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 10000."
type: integer
}
key {
description: """
Histogram x axis to use:
iter - iteration number
iso_time - event time as ISO formatted string
timestamp - event timestamp as milliseconds since epoch
"""
"$ref": "#/definitions/scalar_key_enum"
}
}
}
response {
@@ -715,7 +736,19 @@
description: "List of task Task IDs"
}
}
samples {
description: "The amount of histogram points to return (0 to return all the points). Optional, the default value is 10000."
type: integer
}
key {
description: """
Histogram x axis to use:
iter - iteration number
iso_time - event time as ISO formatted string
timestamp - event timestamp as milliseconds since epoch
"""
"$ref": "#/definitions/scalar_key_enum"
}
}
}
response {

View File

@@ -57,9 +57,14 @@
}
tags {
type: array
description: "Tags"
description: "User-defined tags"
items { type: string }
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please dont use it."
items {type: string}
}
framework {
description: "Framework on which the model is based. Should be identical to the framework of the task which created the model"
type: string
@@ -159,7 +164,12 @@
type: boolean
}
tags {
description: "Tags list used to filter results. Prepend '-' to tag name to indicate exclusion"
description: "User-defined tags list used to filter results. Prepend '-' to tag name to indicate exclusion"
type: array
items { type: string }
}
system_tags {
description: "System tags list used to filter results. Prepend '-' to system tag name to indicate exclusion"
type: array
items { type: string }
}
@@ -263,10 +273,15 @@
type: string
}
tags {
description: "Tags list"
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please dont use it."
items {type: string}
}
override_model_id {
description: "Override model ID. If provided, this model is updated in the task."
type: string
@@ -325,10 +340,15 @@
type: string
}
tags {
description: "Tags list"
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please dont use it."
items {type: string}
}
framework {
description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model."
type: string
@@ -344,7 +364,7 @@
additionalProperties { type: integer }
}
ready {
description: "Indication if the model is final and can be used by other tasks Default is false."
description: "Indication if the model is final and can be used by other tasks. Default is false."
type: boolean
default: false
}
@@ -408,10 +428,15 @@
type: string
}
tags {
description: "Tags list"
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please dont use it."
items {type: string}
}
framework {
description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model."
type: string
@@ -485,10 +510,15 @@
type: string
}
tags {
description: "Tags list"
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please dont use it."
items {type: string}
}
ready {
description: "Indication if the model is final and can be used by other tasks Default is false."
type: boolean

View File

@@ -1,462 +1,523 @@
{
_description: "Provides support for defining Projects containing Tasks, Models and Dataset Versions."
_definitions {
multi_field_pattern_data {
_description: "Provides support for defining Projects containing Tasks, Models and Dataset Versions."
_definitions {
multi_field_pattern_data {
type: object
properties {
pattern {
description: "Pattern string (regex)"
type: string
}
fields {
description: "List of field names"
type: array
items { type: string }
}
}
}
project {
type: object
properties {
id {
description: "Project id"
type: string
}
name {
description: "Project name"
type: string
}
description {
description: "Project description"
type: string
}
user {
description: "Associated user id"
type: string
}
company {
description: "Company id"
type: string
}
created {
description: "Creation time"
type: string
format: "date-time"
}
tags {
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please dont use it."
items {type: string}
}
default_output_destination {
description: "The default output destination URL for new tasks under this project"
type: string
}
last_update {
description: """Last project update time. Reflects the last time the project metadata was changed or a task in this project has changed status"""
type: string
format: "date-time"
}
}
}
stats_status_count {
type: object
properties {
total_runtime {
description: "Total run time of all tasks in project (in seconds)"
type: integer
}
status_count {
description: "Status counts"
type: object
properties {
created {
description: "Number of 'created' tasks in project"
type: integer
}
queued {
description: "Number of 'queued' tasks in project"
type: integer
}
in_progress {
description: "Number of 'in_progress' tasks in project"
type: integer
}
stopped {
description: "Number of 'stopped' tasks in project"
type: integer
}
published {
description: "Number of 'published' tasks in project"
type: integer
}
closed {
description: "Number of 'closed' tasks in project"
type: integer
}
failed {
description: "Number of 'failed' tasks in project"
type: integer
}
unknown {
description: "Number of 'unknown' tasks in project"
type: integer
}
}
}
}
}
stats {
type: object
properties {
active {
description: "Stats for active tasks"
"$ref": "#/definitions/stats_status_count"
}
archived {
description: "Stats for archived tasks"
"$ref": "#/definitions/stats_status_count"
}
}
}
projects_get_all_response_single {
// copy-paste from project definition
type: object
properties {
id {
description: "Project id"
type: string
}
name {
description: "Project name"
type: string
}
description {
description: "Project description"
type: string
}
user {
description: "Associated user id"
type: string
}
company {
description: "Company id"
type: string
}
created {
description: "Creation time"
type: string
format: "date-time"
}
tags {
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please dont use it."
items {type: string}
}
default_output_destination {
description: "The default output destination URL for new tasks under this project"
type: string
}
// extra properties
stats: {
description: "Additional project stats"
"$ref": "#/definitions/stats"
}
}
}
metric_variant_result {
type: object
properties {
metric {
description: "Metric name"
type: string
}
metric_hash {
description: """Metric name hash. Used instead of the metric name when categorizing
last metrics events in task objects."""
type: string
}
variant {
description: "Variant name"
type: string
}
variant_hash {
description: """Variant name hash. Used instead of the variant name when categorizing
last metrics events in task objects."""
type: string
}
}
}
}
create {
"2.1" {
description: "Create a new project"
request {
type: object
required :[
name
description
]
properties {
pattern {
description: "Pattern string (regex)"
name {
description: "Project name Unique within the company."
type: string
}
fields {
description: "List of field names"
description {
description: "Project description. "
type: string
}
tags {
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please dont use it."
items {type: string}
}
default_output_destination {
description: "The default output destination URL for new tasks under this project"
type: string
}
}
}
project_tags_enum {
type: string
enum: [ archived, public, default ]
}
project {
response {
type: object
properties {
id {
description: "Project id"
type: string
}
name {
description: "Project name"
type: string
}
description {
description: "Project description"
type: string
}
user {
description: "Associated user id"
type: string
}
company {
description: "Company id"
type: string
}
created {
description: "Creation time"
type: string
format: "date-time"
}
tags {
description: "Tags"
type: array
items { "$ref": "#/definitions/project_tags_enum" }
}
default_output_destination {
description: "The default output destination URL for new tasks under this project"
type: string
}
last_update {
description: """Last project update time. Reflects the last time the project metadata was changed or a task in this project has changed status"""
type: string
format: "date-time"
}
}
}
stats_status_count {
type: object
properties {
total_runtime {
description: "Total run time of all tasks in project (in seconds)"
type: integer
}
status_count {
description: "Status counts"
type: object
properties {
created {
description: "Number of 'created' tasks in project"
type: integer
}
queued {
description: "Number of 'queued' tasks in project"
type: integer
}
in_progress {
description: "Number of 'in_progress' tasks in project"
type: integer
}
stopped {
description: "Number of 'stopped' tasks in project"
type: integer
}
published {
description: "Number of 'published' tasks in project"
type: integer
}
closed {
description: "Number of 'closed' tasks in project"
type: integer
}
failed {
description: "Number of 'failed' tasks in project"
type: integer
}
unknown {
description: "Number of 'unknown' tasks in project"
type: integer
}
}
}
}
}
stats {
type: object
properties {
active {
description: "Stats for active tasks"
"$ref": "#/definitions/stats_status_count"
}
archived {
description: "Stats for archived tasks"
"$ref": "#/definitions/stats_status_count"
}
}
}
projects_get_all_response_single {
// copy-paste from project definition
type: object
properties {
id {
description: "Project id"
type: string
}
name {
description: "Project name"
type: string
}
description {
description: "Project description"
type: string
}
user {
description: "Associated user id"
type: string
}
company {
description: "Company id"
type: string
}
created {
description: "Creation time"
type: string
format: "date-time"
}
tags {
description: "Tags"
type: array
items { "$ref": "#/definitions/project_tags_enum" }
}
default_output_destination {
description: "The default output destination URL for new tasks under this project"
type: string
}
// extra properties
stats: {
description: "Additional project stats"
"$ref": "#/definitions/stats"
}
}
}
metric_variant_result {
type: object
properties {
metric {
description: "Metric name"
type: string
}
metric_hash {
description: """Metric name hash. Used instead of the metric name when categorizing
last metrics events in task objects."""
type: string
}
variant {
description: "Variant name"
type: string
}
variant_hash {
description: """Variant name hash. Used instead of the variant name when categorizing
last metrics events in task objects."""
type: string
}
}
}
}
create {
"2.1" {
description: "Create a new project"
request {
type: object
required :[
name
description
]
properties {
name {
description: "Project name Unique within the company."
type: string
}
description {
description: "Project description. "
type: string
}
tags {
description: "Tags"
type: array
items { "$ref": "#/definitions/project_tags_enum" }
}
default_output_destination {
description: "The default output destination URL for new tasks under this project"
type: string
}
}
}
response {
type: object
properties {
id {
description: "Project id"
type: string
}
}
}
}
}
get_by_id {
"2.1" {
description: ""
request {
type: object
required: [ project ]
properties {
project {
description: "Project id"
type: string
}
}
}
response {
type: object
properties {
project {
description: "Project info"
"$ref": "#/definitions/project"
}
}
}
}
}
get_all {
"2.1" {
description: "Get all the company's projects and all public projects"
request {
type: object
properties {
id {
description: "List of IDs to filter by"
type: array
items { type: string }
}
name {
description: "Get only projects whose name matches this pattern (python regular expression syntax)"
type: string
}
description {
description: "Get only projects whose description matches this pattern (python regular expression syntax)"
type: string
}
tags {
description: "Tags list used to filter results. Prepend '-' to tag name to indicate exclusion"
type: array
items { type: string }
}
order_by {
description: "List of field names to order by. When search_text is used, '@text_score' can be used as a field representing the text score of returned documents. Use '-' prefix to specify descending order. Optional, recommended when using page"
type: array
items { type: string }
}
page {
description: "Page number, returns a specific page out of the resulting list of dataviews"
type: integer
minimum: 0
}
page_size {
description: "Page size, specifies the number of results returned in each page (last page may contain fewer results)"
type: integer
minimum: 1
}
search_text {
description: "Free text search query"
type: string
}
only_fields {
description: "List of document's field names (nesting is supported using '.', e.g. execution.model_labels). If provided, this list defines the query's projection (only these fields will be returned for each result entry)"
type: array
items { type: string }
}
_all_ {
description: "Multi-field pattern condition (all fields match pattern)"
"$ref": "#/definitions/multi_field_pattern_data"
}
_any_ {
description: "Multi-field pattern condition (any field matches pattern)"
"$ref": "#/definitions/multi_field_pattern_data"
}
}
}
response {
type: object
properties {
projects {
description: "Projects list"
type: array
items { "$ref": "#/definitions/projects_get_all_response_single" }
}
}
}
}
}
get_all_ex {
internal: true
"2.1": ${get_all."2.1"} {
request {
properties {
include_stats {
description: "If true, include project statistic in response."
type: boolean
default: false
}
stats_for_state {
description: "Report stats include only statistics for tasks in the specified state. If Null is provided, stats for all task states will be returned."
type: string
enum: [ active, archived ]
default: active
}
}
}
}
}
update {
"2.1" {
description: "Update project information"
request {
type: object
required: [ project ]
properties {
project {
description: "Project id"
type: string
}
name {
description: "Project name. Unique within the company."
type: string
}
description {
description: "Project description. "
type: string
}
description {
description: "Project description"
type: string
}
tags {
description: "Tags list"
type: array
items { type: string }
}
default_output_destination {
description: "The default output destination URL for new tasks under this project"
type: string
}
}
}
response {
type: object
properties {
updated {
description: "Number of projects updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
}
}
delete {
"2.1" {
description: "Deletes a project"
request {
type: object
required: [ project ]
properties {
project {
description: "Project id"
type: string
}
force {
description: """If not true, fails if project has tasks.
If true, and project has tasks, they will be unassigned"""
type: boolean
default: false
}
}
}
response {
type: object
properties {
deleted {
description: "Number of projects deleted (0 or 1)"
type: integer
}
disassociated_tasks {
description: "Number of tasks disassociated from the deleted project"
type: integer
}
}
}
}
}
get_unique_metric_variants {
"2.1" {
description: """Get all metric/variant pairs reported for tasks in a specific project.
If no project is specified, metrics/variant paris reported for all tasks will be returned.
If the project does not exist, an empty list will be returned."""
request {
type: object
properties {
project {
description: "Project ID"
type: string
}
}
}
response {
type: object
properties {
metrics {
description: "A list of metric variants reported for tasks in this project"
type: array
items { "$ref": "#/definitions/metric_variant_result" }
}
}
}
}
}
}
get_by_id {
"2.1" {
description: ""
request {
type: object
required: [ project ]
properties {
project {
description: "Project id"
type: string
}
}
}
response {
type: object
properties {
project {
description: "Project info"
"$ref": "#/definitions/project"
}
}
}
}
}
get_all {
"2.1" {
description: "Get all the company's projects and all public projects"
request {
type: object
properties {
id {
description: "List of IDs to filter by"
type: array
items { type: string }
}
name {
description: "Get only projects whose name matches this pattern (python regular expression syntax)"
type: string
}
description {
description: "Get only projects whose description matches this pattern (python regular expression syntax)"
type: string
}
tags {
description: "User-defined tags list used to filter results. Prepend '-' to tag name to indicate exclusion"
type: array
items { type: string }
}
system_tags {
description: "System tags list used to filter results. Prepend '-' to system tag name to indicate exclusion"
type: array
items { type: string }
}
order_by {
description: "List of field names to order by. When search_text is used, '@text_score' can be used as a field representing the text score of returned documents. Use '-' prefix to specify descending order. Optional, recommended when using page"
type: array
items { type: string }
}
page {
description: "Page number, returns a specific page out of the resulting list of dataviews"
type: integer
minimum: 0
}
page_size {
description: "Page size, specifies the number of results returned in each page (last page may contain fewer results)"
type: integer
minimum: 1
}
search_text {
description: "Free text search query"
type: string
}
only_fields {
description: "List of document's field names (nesting is supported using '.', e.g. execution.model_labels). If provided, this list defines the query's projection (only these fields will be returned for each result entry)"
type: array
items { type: string }
}
_all_ {
description: "Multi-field pattern condition (all fields match pattern)"
"$ref": "#/definitions/multi_field_pattern_data"
}
_any_ {
description: "Multi-field pattern condition (any field matches pattern)"
"$ref": "#/definitions/multi_field_pattern_data"
}
}
}
response {
type: object
properties {
projects {
description: "Projects list"
type: array
items { "$ref": "#/definitions/projects_get_all_response_single" }
}
}
}
}
}
get_all_ex {
internal: true
"2.1": ${get_all."2.1"} {
request {
properties {
include_stats {
description: "If true, include project statistic in response."
type: boolean
default: false
}
stats_for_state {
description: "Report stats include only statistics for tasks in the specified state. If Null is provided, stats for all task states will be returned."
type: string
enum: [ active, archived ]
default: active
}
}
}
}
}
update {
"2.1" {
description: "Update project information"
request {
type: object
required: [ project ]
properties {
project {
description: "Project id"
type: string
}
name {
description: "Project name. Unique within the company."
type: string
}
description {
description: "Project description. "
type: string
}
description {
description: "Project description"
type: string
}
tags {
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please dont use it."
items {type: string}
}
default_output_destination {
description: "The default output destination URL for new tasks under this project"
type: string
}
}
}
response {
type: object
properties {
updated {
description: "Number of projects updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
}
}
delete {
"2.1" {
description: "Deletes a project"
request {
type: object
required: [ project ]
properties {
project {
description: "Project id"
type: string
}
force {
description: """If not true, fails if project has tasks.
If true, and project has tasks, they will be unassigned"""
type: boolean
default: false
}
}
}
response {
type: object
properties {
deleted {
description: "Number of projects deleted (0 or 1)"
type: integer
}
disassociated_tasks {
description: "Number of tasks disassociated from the deleted project"
type: integer
}
}
}
}
}
get_unique_metric_variants {
"2.1" {
description: """Get all metric/variant pairs reported for tasks in a specific project.
If no project is specified, metrics/variant paris reported for all tasks will be returned.
If the project does not exist, an empty list will be returned."""
request {
type: object
properties {
project {
description: "Project ID"
type: string
}
}
}
response {
type: object
properties {
metrics {
description: "A list of metric variants reported for tasks in this project"
type: array
items { "$ref": "#/definitions/metric_variant_result" }
}
}
}
}
}
get_hyper_parameters {
"2.2" {
description: """Get a list of all hyper parameter names used in tasks within the given project."""
request {
type: object
properties {
project {
description: "Project ID"
type: string
}
page {
description: "Page number"
default: 0
type: integer
}
page_size {
description: "Page size"
default: 500
type: integer
}
}
}
response {
type: object
properties {
parameters {
description: "A list of hyper parameter names"
type: array
items { type: string }
}
remaining {
description: "Remaining results"
type: integer
}
total {
description: "Total number of results"
type: integer
}
}
}
}

View File

@@ -0,0 +1,68 @@
_description: "server utilities"
_default {
internal: true
allow_roles: ["root", "system"]
}
config {
"2.1" {
description: "Get server configuration. Secure section is not returned."
request {
type: object
properties {
path {
description: "Path of config value. Defaults to root"
type: string
}
}
}
response {
type: object
properties {
}
}
}
}
info {
authorize = false
allow_roles = [ "*" ]
"2.1" {
description: "Get server information, including version and build number"
request {
type: object
properties {
}
}
response {
type: object
properties {
version {
description: "Version string"
type: string
}
build {
description: "Build number"
type: string
}
commit {
description: "VCS commit number"
type: string
}
}
}
}
}
endpoints {
"2.1" {
description: "Show available endpoints"
request {
type: object
properties {
}
}
response {
type: object
properties {
}
}
}
}

View File

@@ -120,6 +120,76 @@ _definitions {
frame_per_roi
]
}
artifact_type_data {
type: object
properties {
preview {
description: "Description or textual data"
type: string
}
content_type {
description: "System defined raw data content type"
type: string
}
data_hash {
description: "Hash of raw data, without any headers or descriptive parts"
type: string
}
}
}
artifact {
type: object
required: [key, type]
properties {
key {
description: "Entry key"
type: string
}
type {
description: "System defined type"
type: string
}
mode {
description: "System defined input/output indication"
type: string
enum: [
input
output
]
default: output
}
uri {
description: "Raw data location"
type: string
}
content_size {
description: "Raw data length in bytes"
type: integer
}
hash {
description: "Hash of entire raw data"
type: string
}
timestamp {
description: "Epoch time when artifact was created"
type: integer
}
type_data {
description: "Additional fields defined by the system"
"$ref": "#/definitions/artifact_type_data"
}
display_data {
description: "User-defined list of key/value pairs, sorted"
type: array
items {
type: array
items {
type: string # can also be a number... TODO: upgrade the generator
}
}
}
}
}
execution {
type: object
properties {
@@ -149,6 +219,11 @@ _definitions {
description: """Framework related to the task. Case insensitive. Mandatory for Training tasks. """
type: string
}
artifacts {
description: "Task artifacts"
type: array
items { "$ref": "#/definitions/artifact" }
}
}
}
task_status_enum {
@@ -183,21 +258,16 @@ _definitions {
description: "Variant name"
type: string
}
type {
description: "Event type"
type: string
}
timestamp {
description: "Event report time (UTC)"
type: string
format: "date-time"
}
iter {
description: "Iteration number"
type: integer
}
value {
description: "Value"
description: "Last value reported"
type: number
}
min_value {
description: "Minimum value reported"
type: number
}
max_value {
description: "Maximum value reported"
type: number
}
}
@@ -278,10 +348,15 @@ _definitions {
"$ref": "#/definitions/script"
}
tags {
description: "Tags list"
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please dont use it."
items {type: string}
}
status_changed {
description: "Last status change time"
type: string
@@ -392,7 +467,12 @@ get_all {
items { type: string }
}
tags {
description: "List of task tags. Use '-' prefix to exclude tags"
description: "User-defined tags list used to filter results. Prepend '-' to tag name to indicate exclusion"
type: array
items { type: string }
}
system_tags {
description: "System tags list used to filter results. Prepend '-' to system tag name to indicate exclusion"
type: array
items { type: string }
}
@@ -467,10 +547,15 @@ create {
type: string
}
tags {
description: "Tags list"
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please dont use it."
items {type: string}
}
type {
description: "Type of task"
"$ref": "#/definitions/task_type_enum"
@@ -527,10 +612,15 @@ validate {
type: string
}
tags {
description: "Tags list"
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please dont use it."
items {type: string}
}
type {
description: "Type of task"
"$ref": "#/definitions/task_type_enum"
@@ -585,10 +675,15 @@ update {
type: string
}
tags {
description: "Tags list"
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please dont use it."
items {type: string}
}
comment {
description: "Free text comment "
type: string
@@ -667,10 +762,15 @@ edit {
type: string
}
tags {
description: "Tags list"
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please dont use it."
items {type: string}
}
type {
description: "Type of task"
"$ref": "#/definitions/task_type_enum"
@@ -1062,7 +1162,7 @@ completed {
task
]
properties.force = ${_references.force_arg} {
description: "If not true, call fails if the task status is not created/in_progress/published"
description: "If not true, call fails if the task status is not in_progress/stopped"
}
} ${_references.status_change_request}
response {

View File

@@ -16,7 +16,7 @@ from utilities import json
from init_data import init_es_data, init_mongo_data
app = Flask(__name__, static_url_path="/static")
CORS(app, supports_credentials=True, **config.get("apiserver.cors"))
CORS(app, **config.get("apiserver.cors"))
Compress(app)
log = config.logger(__file__)
@@ -63,7 +63,10 @@ def before_request():
if call.result.cookies:
for key, value in call.result.cookies.items():
response.set_cookie(key, value, **config.get("apiserver.auth.cookies"))
if value is None:
response.set_cookie(key, "", expires=0)
else:
response.set_cookie(key, value, **config.get("apiserver.auth.cookies"))
return response
except Exception as ex:
@@ -104,7 +107,6 @@ def update_call_data(call, req):
form[key] = True
elif form[key].lower() == "false":
form[key] = False
# NOTE: dict() form data to make sure we won't pass along a MultiDict or some other nasty crap
call.data = json_body or form or {}

View File

@@ -104,7 +104,7 @@ class DataContainer(object):
if self._batched_data:
try:
data_model = [cls(**item) for item in self._batched_data]
except TypeError as ex:
except (ValueError, TypeError) as ex:
raise CallParsingError(str(ex))
for m in data_model:
@@ -112,7 +112,7 @@ class DataContainer(object):
else:
try:
data_model = cls(**self.data)
except TypeError as ex:
except (ValueError, TypeError) as ex:
raise CallParsingError(str(ex))
if not self.schema_validator.enabled:
@@ -182,8 +182,6 @@ class APICallResult(DataContainer):
traceback=self._traceback,
extra=self._extra,
)
if self.log_data:
res["data"] = self.data
return res
def copy_from(self, result):

View File

@@ -1,19 +1,19 @@
import base64
from datetime import datetime
import jwt
from mongoengine import Q
from database.errors import translate_errors_context
from database.model.company import Company
from database.utils import get_options
from database.model.auth import User, Entities, Credentials
from apierrors import errors
from config import config
from database.errors import translate_errors_context
from database.model.auth import User, Entities, Credentials
from database.model.company import Company
from database.utils import get_options
from timing_context import TimingContext
from .payload import Payload, Token, Basic, AuthType
from .identity import Identity
from .fixed_user import FixedUser
from .identity import Identity
from .payload import Payload, Token, Basic, AuthType
log = config.logger(__file__)
@@ -38,12 +38,16 @@ def authorize_token(jwt_token, *_, **__):
return Token.from_encoded_token(jwt_token)
except jwt.exceptions.InvalidKeyError as ex:
raise errors.unauthorized.InvalidToken('jwt invalid key error', reason=ex.args[0])
raise errors.unauthorized.InvalidToken(
"jwt invalid key error", reason=ex.args[0]
)
except jwt.InvalidTokenError as ex:
raise errors.unauthorized.InvalidToken('invalid jwt token', reason=ex.args[0])
raise errors.unauthorized.InvalidToken("invalid jwt token", reason=ex.args[0])
except ValueError as ex:
log.exception('Failed while processing token: %s' % ex.args[0])
raise errors.unauthorized.InvalidToken('failed processing token', reason=ex.args[0])
log.exception("Failed while processing token: %s" % ex.args[0])
raise errors.unauthorized.InvalidToken(
"failed processing token", reason=ex.args[0]
)
def authorize_credentials(auth_data, service, action, call_data_items):
@@ -67,9 +71,14 @@ def authorize_credentials(auth_data, service, action, call_data_items):
with TimingContext("mongo", "user_by_cred"), translate_errors_context('authorizing request'):
user = User.objects(query).first()
if not user:
raise errors.unauthorized.InvalidCredentials('failed to locate provided credentials')
if not user:
raise errors.unauthorized.InvalidCredentials('failed to locate provided credentials')
if not FixedUser.enabled():
# In case these are proper credentials, update last used time
User.objects(id=user.id, credentials__key=access_key).update(
**{"set__credentials__$__last_used": datetime.utcnow()}
)
with TimingContext("mongo", "company_by_id"):
company = Company.objects(id=user.company).only('id', 'name').first()
@@ -85,13 +94,13 @@ def authorize_credentials(auth_data, service, action, call_data_items):
return basic
def authorize_impersonation(user, identity, service, action, call_data_items):
def authorize_impersonation(user, identity, service, action, call):
""" Returns a new basic object (auth payload)"""
if not user:
raise ValueError('missing user')
raise ValueError("missing user")
company = Company.objects(id=user.company).only('id', 'name').first()
company = Company.objects(id=user.company).only("id", "name").first()
if not company:
raise errors.unauthorized.InvalidCredentials('invalid user company')
raise errors.unauthorized.InvalidCredentials("invalid user company")
return Payload(auth_type=None, identity=identity)

View File

@@ -30,6 +30,8 @@ def get_secret_key(length=50):
Create a random secret key.
Taken from the Django project.
NOTE: asterisk is not supported due to issues with environment variables containing
asterisks (in case the secret key is stored in an environment variable)
"""
chars = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*(-_=+)'
chars = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&(-_=+)'
return get_random_string(length, chars)

View File

@@ -76,7 +76,7 @@ class Endpoint(object):
Provided endpoints and their schemas on a best-effort basis.
"""
d = {
"min_version": self.min_version,
"min_version": str(self.min_version),
"required_fields": self.required_fields,
"request_data_model": None,
"response_data_model": None,

View File

@@ -12,7 +12,7 @@ from config import config
log = config.logger(__file__)
@attr.s(auto_attribs=True, auto_exc=True)
@attr.s(auto_attribs=True, cmp=False)
class FastValidationError(Exception):
error: fastjsonschema.JsonSchemaException
data: dict

View File

@@ -1,10 +1,10 @@
import re
from importlib import import_module
from itertools import chain
from typing import cast, Iterable, List, MutableMapping
from pathlib import Path
from typing import cast, Iterable, List, MutableMapping, Optional, Tuple
import jsonmodels.models
from pathlib import Path
import timing_context
from apierrors import APIError
@@ -30,7 +30,11 @@ class ServiceRepo(object):
_version_required = config.get("apiserver.version.required")
""" If version is required, parsing will fail for endpoint paths that do not contain a valid version """
_max_version = PartialVersion("2.1")
_check_max_version = config.get("apiserver.version.check_max_version")
"""If the check is set, parsing will fail for endpoint request with the version that is grater than the current
maximum """
_max_version = PartialVersion("2.3")
""" Maximum version number (the highest min_version value across all endpoints) """
_endpoint_exp = (
@@ -133,7 +137,7 @@ class ServiceRepo(object):
return cls._max_version
@classmethod
def _get_endpoint(cls, name, version):
def _get_endpoint(cls, name, version) -> Optional[Endpoint]:
versions = cls._endpoints.get(name)
if not versions:
return None
@@ -144,7 +148,7 @@ class ServiceRepo(object):
return None
@classmethod
def _resolve_endpoint_from_call(cls, call):
def _resolve_endpoint_from_call(cls, call: APICall) -> Optional[Endpoint]:
assert isinstance(call, APICall)
endpoint = cls._get_endpoint(
call.endpoint_name, call.requested_endpoint_version
@@ -167,7 +171,7 @@ class ServiceRepo(object):
return endpoint
@classmethod
def parse_endpoint_path(cls, path):
def parse_endpoint_path(cls, path: str) -> Tuple[PartialVersion, str]:
""" Parse endpoint version, service and action from request path. """
m = cls._endpoint_exp.match(path)
if not m:
@@ -182,14 +186,14 @@ class ServiceRepo(object):
version = PartialVersion(version)
except ValueError as e:
raise RequestPathHasInvalidVersion(version=version, reason=e)
if version > cls._max_version:
if cls._check_max_version and version > cls._max_version:
raise InvalidVersionError(
f"Invalid API version (max. supported version is {cls._max_version})"
)
return version, endpoint_name
@classmethod
def _should_return_stack(cls, code, subcode):
def _should_return_stack(cls, code: int, subcode: int) -> bool:
if not cls._return_stack or code not in cls._return_stack_on_code:
return False
if subcode is None:
@@ -202,7 +206,7 @@ class ServiceRepo(object):
return subcode in subcode_list
@classmethod
def _validate_call(cls, call):
def _validate_call(cls, call: APICall) -> Optional[Endpoint]:
endpoint = cls._resolve_endpoint_from_call(call)
if call.failed:
return
@@ -210,11 +214,13 @@ class ServiceRepo(object):
return endpoint
@classmethod
def validate_call(cls, call):
def validate_call(cls, call: APICall):
cls._validate_call(call)
@classmethod
def _get_company(cls, call, endpoint=None, ignore_error=False):
def _get_company(
cls, call: APICall, endpoint: Endpoint = None, ignore_error: bool = False
) -> Optional[str]:
authorize = endpoint and endpoint.authorize
if ignore_error or not authorize:
try:
@@ -224,7 +230,7 @@ class ServiceRepo(object):
return call.identity.company
@classmethod
def handle_call(cls, call):
def handle_call(cls, call: APICall):
try:
assert isinstance(call, APICall)

View File

@@ -150,7 +150,7 @@ def validate_impersonation(endpoint, call):
),
service=service,
action=action,
call_data_items=call.batched_data,
call=call,
)
else:
return False

View File

@@ -31,10 +31,8 @@ log = config.logger(__file__)
request_data_model=GetTokenRequest,
response_data_model=GetTokenResponse,
)
def login(call):
def login(call: APICall, *_, **__):
""" Generates a token based on the authenticated user (intended for use with credentials) """
assert isinstance(call, APICall)
call.result.data_model = AuthBLL.get_token_for_user(
user_id=call.identity.user,
company_id=call.identity.company,
@@ -47,6 +45,11 @@ def login(call):
] = call.result.data_model.token
@endpoint("auth.logout", min_version="2.2")
def logout(call: APICall, *_, **__):
call.result.cookies[config.get("apiserver.auth.session_auth_cookie_name")] = None
@endpoint(
"auth.get_token_for_user",
request_data_model=GetTokenForUserRequest,
@@ -140,7 +143,8 @@ def get_credentials(call):
# we return ONLY the key IDs, never the secrets (want a secret? create new credentials)
call.result.data_model = GetCredentialsResponse(
credentials=[
CredentialsResponse(access_key=c.key) for c in user.credentials
CredentialsResponse(access_key=c.key, last_used=c.last_used)
for c in user.credentials
]
)

View File

@@ -5,7 +5,12 @@ from operator import itemgetter
import six
from apierrors import errors
from apimodels.events import (
MultiTaskScalarMetricsIterHistogramRequest,
ScalarMetricsIterHistogramRequest,
)
from bll.event import EventBLL
from bll.event.event_metrics import EventMetrics
from bll.task import TaskBLL
from service_repo import APICall, endpoint
from utilities import json
@@ -17,11 +22,10 @@ event_bll = EventBLL()
@endpoint("events.add")
def add(call, company_id, req_model):
assert isinstance(call, APICall)
added, batch_errors = event_bll.add_events(company_id, [call.data.copy()], call.worker)
call.result.data = dict(
added=added,
errors=len(batch_errors)
added, batch_errors = event_bll.add_events(
company_id, [call.data.copy()], call.worker
)
call.result.data = dict(added=added, errors=len(batch_errors))
call.kpis["events"] = 1
@@ -33,10 +37,7 @@ def add_batch(call, company_id, req_model):
raise errors.bad_request.BatchContainsNoItems()
added, batch_errors = event_bll.add_events(company_id, events, call.worker)
call.result.data = dict(
added=added,
errors=len(batch_errors)
)
call.result.data = dict(added=added, errors=len(batch_errors))
call.kpis["events"] = len(events)
@@ -48,16 +49,16 @@ def get_task_log(call, company_id, req_model):
scroll_id = call.data.get("scroll_id")
batch_size = int(call.data.get("batch_size") or 500)
events, scroll_id, total_events = event_bll.scroll_task_events(
company_id, task_id, order,
company_id,
task_id,
order,
event_type="log",
batch_size=batch_size,
scroll_id=scroll_id)
call.result.data = dict(
events=events,
returned=len(events),
total=total_events,
scroll_id=scroll_id,
)
call.result.data = dict(
events=events, returned=len(events), total=total_events, scroll_id=scroll_id
)
@endpoint("events.get_task_log", min_version="1.7", required_fields=["task"])
@@ -70,7 +71,7 @@ def get_task_log_v1_7(call, company_id, req_model):
scroll_id = call.data.get("scroll_id")
batch_size = int(call.data.get("batch_size") or 500)
scroll_order = 'asc' if (from_ == 'head') else 'desc'
scroll_order = "asc" if (from_ == "head") else "desc"
events, scroll_id, total_events = event_bll.scroll_task_events(
company_id=company_id,
@@ -78,54 +79,57 @@ def get_task_log_v1_7(call, company_id, req_model):
order=scroll_order,
event_type="log",
batch_size=batch_size,
scroll_id=scroll_id
scroll_id=scroll_id,
)
if scroll_order != order:
events = events[::-1]
call.result.data = dict(
events=events,
returned=len(events),
total=total_events,
scroll_id=scroll_id,
events=events, returned=len(events), total=total_events, scroll_id=scroll_id
)
@endpoint('events.download_task_log', required_fields=['task'])
@endpoint("events.download_task_log", required_fields=["task"])
def download_task_log(call, company_id, req_model):
task_id = call.data['task']
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id, allow_public=True)
line_type = call.data.get('line_type', 'json').lower()
line_format = str(call.data.get('line_format', '{asctime} {worker} {level} {msg}'))
line_type = call.data.get("line_type", "json").lower()
line_format = str(call.data.get("line_format", "{asctime} {worker} {level} {msg}"))
is_json = (line_type == 'json')
is_json = line_type == "json"
if not is_json:
if not line_format:
raise errors.bad_request.MissingRequiredFields('line_format is required for plain text lines')
raise errors.bad_request.MissingRequiredFields(
"line_format is required for plain text lines"
)
# validate line format placeholders
valid_task_log_fields = {'asctime', 'timestamp', 'level', 'worker', 'msg'}
valid_task_log_fields = {"asctime", "timestamp", "level", "worker", "msg"}
invalid_placeholders = set()
while True:
try:
line_format.format(**dict.fromkeys(valid_task_log_fields | invalid_placeholders))
line_format.format(
**dict.fromkeys(valid_task_log_fields | invalid_placeholders)
)
break
except KeyError as e:
invalid_placeholders.add(e.args[0])
except Exception as e:
raise errors.bad_request.FieldsValueError('invalid line format', error=e.args[0])
raise errors.bad_request.FieldsValueError(
"invalid line format", error=e.args[0]
)
if invalid_placeholders:
raise errors.bad_request.FieldsValueError(
'undefined placeholders in line format',
placeholders=invalid_placeholders
"undefined placeholders in line format",
placeholders=invalid_placeholders,
)
# make sure line_format has a trailing newline
line_format = line_format.rstrip('\n') + '\n'
line_format = line_format.rstrip("\n") + "\n"
def generate():
scroll_id = None
@@ -137,30 +141,30 @@ def download_task_log(call, company_id, req_model):
order="asc",
event_type="log",
batch_size=batch_size,
scroll_id=scroll_id
scroll_id=scroll_id,
)
if not log_events:
break
for ev in log_events:
ev['asctime'] = ev.pop('@timestamp')
ev["asctime"] = ev.pop("@timestamp")
if is_json:
ev.pop('type')
ev.pop('task')
yield json.dumps(ev) + '\n'
ev.pop("type")
ev.pop("task")
yield json.dumps(ev) + "\n"
else:
try:
yield line_format.format(**ev)
except KeyError as ex:
raise errors.bad_request.FieldsValueError(
'undefined placeholders in line format',
placeholders=[str(ex)]
"undefined placeholders in line format",
placeholders=[str(ex)],
)
if len(log_events) < batch_size:
break
call.result.filename = 'task_%s.log' % task_id
call.result.content_type = 'text/plain'
call.result.filename = "task_%s.log" % task_id
call.result.content_type = "text/plain"
call.result.raw_data = generate()
@@ -169,7 +173,9 @@ def get_vector_metrics_and_variants(call, company_id, req_model):
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id, allow_public=True)
call.result.data = dict(
metrics=event_bll.get_metrics_and_variants(company_id, task_id, "training_stats_vector")
metrics=event_bll.get_metrics_and_variants(
company_id, task_id, "training_stats_vector"
)
)
@@ -178,23 +184,27 @@ def get_scalar_metrics_and_variants(call, company_id, req_model):
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id, allow_public=True)
call.result.data = dict(
metrics=event_bll.get_metrics_and_variants(company_id, task_id, "training_stats_scalar")
metrics=event_bll.get_metrics_and_variants(
company_id, task_id, "training_stats_scalar"
)
)
# todo: !!! currently returning 10,000 records. should decide on a better way to control it
@endpoint("events.vector_metrics_iter_histogram", required_fields=["task", "metric", "variant"])
@endpoint(
"events.vector_metrics_iter_histogram",
required_fields=["task", "metric", "variant"],
)
def vector_metrics_iter_histogram(call, company_id, req_model):
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id, allow_public=True)
metric = call.data["metric"]
variant = call.data["variant"]
iterations, vectors = event_bll.get_vector_metrics_per_iter(company_id, task_id, metric, variant)
iterations, vectors = event_bll.get_vector_metrics_per_iter(
company_id, task_id, metric, variant
)
call.result.data = dict(
metric=metric,
variant=variant,
vectors=vectors,
iterations=iterations
metric=metric, variant=variant, vectors=vectors, iterations=iterations
)
@@ -207,10 +217,11 @@ def get_task_events(call, company_id, req_model):
task_bll.assert_exists(company_id, task_id, allow_public=True)
result = event_bll.get_task_events(
company_id, task_id,
company_id,
task_id,
sort=[{"timestamp": {"order": order}}],
event_type=event_type,
scroll_id=scroll_id
scroll_id=scroll_id,
)
call.result.data = dict(
@@ -229,11 +240,12 @@ def get_scalar_metric_data(call, company_id, req_model):
task_bll.assert_exists(company_id, task_id, allow_public=True)
result = event_bll.get_task_events(
company_id, task_id,
company_id,
task_id,
event_type="training_stats_scalar",
sort=[{"iter": {"order": "desc"}}],
metric=metric,
scroll_id=scroll_id
scroll_id=scroll_id,
)
call.result.data = dict(
@@ -248,35 +260,50 @@ def get_scalar_metric_data(call, company_id, req_model):
def get_task_latest_scalar_values(call, company_id, req_model):
task_id = call.data["task"]
task = task_bll.assert_exists(company_id, task_id, allow_public=True)
metrics, last_timestamp = event_bll.get_task_latest_scalar_values(company_id, task_id)
es_index = EventBLL.get_index_name(company_id, "*")
metrics, last_timestamp = event_bll.get_task_latest_scalar_values(
company_id, task_id
)
es_index = EventMetrics.get_index_name(company_id, "*")
last_iters = event_bll.get_last_iters(es_index, task_id, None, 1)
call.result.data = dict(
metrics=metrics,
last_iter=last_iters[0] if last_iters else 0,
name=task.name,
status=task.status,
last_timestamp=last_timestamp
last_timestamp=last_timestamp,
)
# todo: should not repeat iter (x-axis) for each metric/variant, JS client should get raw data and fill gaps if needed
@endpoint("events.scalar_metrics_iter_histogram", required_fields=["task"])
def scalar_metrics_iter_histogram(call, company_id, req_model):
task_id = call.data["task"]
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
metrics = event_bll.get_scalar_metrics_average_per_iter(company_id, task_id)
@endpoint(
"events.scalar_metrics_iter_histogram",
request_data_model=ScalarMetricsIterHistogramRequest,
)
def scalar_metrics_iter_histogram(
call, company_id, req_model: ScalarMetricsIterHistogramRequest
):
task_bll.assert_exists(call.identity.company, req_model.task, allow_public=True)
metrics = event_bll.metrics.get_scalar_metrics_average_per_iter(
company_id, task_id=req_model.task, samples=req_model.samples, key=req_model.key
)
call.result.data = metrics
@endpoint("events.multi_task_scalar_metrics_iter_histogram", required_fields=["tasks"])
def multi_task_scalar_metrics_iter_histogram(call, company_id, req_model):
task_ids = call.data["tasks"]
@endpoint(
"events.multi_task_scalar_metrics_iter_histogram",
request_data_model=MultiTaskScalarMetricsIterHistogramRequest,
)
def multi_task_scalar_metrics_iter_histogram(
call, company_id, req_model: MultiTaskScalarMetricsIterHistogramRequest
):
task_ids = req_model.tasks
if isinstance(task_ids, six.string_types):
task_ids = [s.strip() for s in task_ids.split(",")]
# Note, bll already validates task ids as it needs their names
call.result.data = dict(
metrics=event_bll.compare_scalar_metrics_average_per_iter(company_id, task_ids, allow_public=True)
metrics=event_bll.metrics.compare_scalar_metrics_average_per_iter(
company_id, task_ids=task_ids, samples=req_model.samples, allow_public=True, key=req_model.key
)
)
@@ -287,21 +314,27 @@ def get_multi_task_plots_v1_7(call, company_id, req_model):
scroll_id = call.data.get("scroll_id")
tasks = task_bll.assert_exists(
company_id=call.identity.company, only=('id', 'name'), task_ids=task_ids, allow_public=True
company_id=call.identity.company,
only=("id", "name"),
task_ids=task_ids,
allow_public=True,
)
# Get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
result = event_bll.get_task_events(
company_id, task_ids,
company_id,
task_ids,
event_type="plot",
sort=[{"iter": {"order": "desc"}}],
size=10000,
scroll_id=scroll_id
scroll_id=scroll_id,
)
tasks = {t.id: t.name for t in tasks}
return_events = _get_top_iter_unique_events_per_task(result.events, max_iters=iters, tasks=tasks)
return_events = _get_top_iter_unique_events_per_task(
result.events, max_iters=iters, tasks=tasks
)
call.result.data = dict(
plots=return_events,
@@ -318,20 +351,26 @@ def get_multi_task_plots(call, company_id, req_model):
scroll_id = call.data.get("scroll_id")
tasks = task_bll.assert_exists(
company_id=call.identity.company, only=('id', 'name'), task_ids=task_ids, allow_public=True
company_id=call.identity.company,
only=("id", "name"),
task_ids=task_ids,
allow_public=True,
)
result = event_bll.get_task_events(
company_id, task_ids,
company_id,
task_ids,
event_type="plot",
sort=[{"iter": {"order": "desc"}}],
last_iter_count=iters,
scroll_id=scroll_id
scroll_id=scroll_id,
)
tasks = {t.id: t.name for t in tasks}
return_events = _get_top_iter_unique_events_per_task(result.events, max_iters=iters, tasks=tasks)
return_events = _get_top_iter_unique_events_per_task(
result.events, max_iters=iters, tasks=tasks
)
call.result.data = dict(
plots=return_events,
@@ -357,11 +396,12 @@ def get_task_plots_v1_7(call, company_id, req_model):
# get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
result = event_bll.get_task_events(
company_id, task_id,
company_id,
task_id,
event_type="plot",
sort=[{"iter": {"order": "desc"}}],
size=10000,
scroll_id=scroll_id
scroll_id=scroll_id,
)
return_events = _get_top_iter_unique_events(result.events, max_iters=iters)
@@ -381,12 +421,13 @@ def get_task_plots(call, company_id, req_model):
scroll_id = call.data.get("scroll_id")
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
result = event_bll.get_task_events(
company_id, task_id,
event_type="plot",
result = event_bll.get_task_plots(
company_id,
tasks=[task_id],
sort=[{"iter": {"order": "desc"}}],
last_iter_count=iters,
scroll_id=scroll_id
last_iterations_per_plot=iters,
scroll_id=scroll_id,
)
return_events = result.events
@@ -415,11 +456,12 @@ def get_debug_images_v1_7(call, company_id, req_model):
# get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
result = event_bll.get_task_events(
company_id, task_id,
company_id,
task_id,
event_type="training_debug_image",
sort=[{"iter": {"order": "desc"}}],
size=10000,
scroll_id=scroll_id
scroll_id=scroll_id,
)
return_events = _get_top_iter_unique_events(result.events, max_iters=iters)
@@ -441,11 +483,12 @@ def get_debug_images(call, company_id, req_model):
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
result = event_bll.get_task_events(
company_id, task_id,
company_id,
task_id,
event_type="training_debug_image",
sort=[{"iter": {"order": "desc"}}],
last_iter_count=iters,
scroll_id=scroll_id
scroll_id=scroll_id,
)
return_events = result.events
@@ -464,34 +507,32 @@ def delete_for_task(call, company_id, req_model):
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id)
call.result.data = dict(
deleted=event_bll.delete_task_events(company_id, task_id)
)
call.result.data = dict(deleted=event_bll.delete_task_events(company_id, task_id))
def _get_top_iter_unique_events_per_task(events, max_iters, tasks):
key = itemgetter('metric', 'variant', 'task', 'iter')
key = itemgetter("metric", "variant", "task", "iter")
unique_events = itertools.chain.from_iterable(
itertools.islice(group, max_iters)
for _, group in itertools.groupby(sorted(events, key=key, reverse=True), key=key))
for _, group in itertools.groupby(
sorted(events, key=key, reverse=True), key=key
)
)
def collect(evs, fields):
if not fields:
evs = list(evs)
return {
'name': tasks.get(evs[0].get('task')),
'plots': evs
}
return {"name": tasks.get(evs[0].get("task")), "plots": evs}
return {
str(k): collect(group, fields[1:])
for k, group in itertools.groupby(evs, key=itemgetter(fields[0]))
}
collect_fields = ('metric', 'variant', 'task', 'iter')
collect_fields = ("metric", "variant", "task", "iter")
return collect(
sorted(unique_events, key=itemgetter(*collect_fields), reverse=True),
collect_fields
collect_fields,
)
@@ -502,6 +543,8 @@ def _get_top_iter_unique_events(events, max_iters):
evs = top_unique_events[key]
if len(evs) < max_iters:
evs.append(e)
unique_events = list(itertools.chain.from_iterable(list(top_unique_events.values())))
unique_events = list(
itertools.chain.from_iterable(list(top_unique_events.values()))
)
unique_events.sort(key=lambda e: e["iter"], reverse=True)
return unique_events

View File

@@ -1,5 +1,4 @@
from datetime import datetime
from urllib.parse import urlparse
from mongoengine import Q, EmbeddedDocument
@@ -16,7 +15,6 @@ from apimodels.models import (
from bll.task import TaskBLL
from config import config
from database.errors import translate_errors_context
from database.fields import SupportedURLField
from database.model import validate_id
from database.model.model import Model
from database.model.project import Project
@@ -27,13 +25,23 @@ from database.utils import (
filter_fields,
)
from service_repo import APICall, endpoint
from services.utils import conform_tag_fields, conform_output_tags
from timing_context import TimingContext
log = config.logger(__file__)
get_all_query_options = Model.QueryParameterOptions(
pattern_fields=("name", "comment"),
fields=("ready",),
list_fields=("tags", "framework", "uri", "id", "project", "task", "parent"),
list_fields=(
"tags",
"system_tags",
"framework",
"uri",
"id",
"project",
"task",
"parent",
),
)
@@ -43,20 +51,20 @@ def get_by_id(call):
model_id = call.data["model"]
with translate_errors_context():
res = Model.get_many(
models = Model.get_many(
company=call.identity.company,
query_dict=call.data,
query=Q(id=model_id),
allow_public=True,
)
if not res:
if not models:
raise errors.bad_request.InvalidModelId(
"no such public or company model",
id=model_id,
company=call.identity.company,
)
call.result.data = {"model": res[0]}
conform_output_tags(call, models[0])
call.result.data = {"model": models[0]}
@endpoint("models.get_by_task_id", required_fields=["task"])
@@ -66,31 +74,32 @@ def get_by_task_id(call):
with translate_errors_context():
query = dict(id=task_id, company=call.identity.company)
res = Task.get(_only=["output"], **query)
if not res:
task = Task.get(_only=["output"], **query)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
if not res.output:
if not task.output:
raise errors.bad_request.MissingTaskFields(field="output")
if not res.output.model:
if not task.output.model:
raise errors.bad_request.MissingTaskFields(field="output.model")
model_id = res.output.model
res = Model.objects(
model_id = task.output.model
model = Model.objects(
Q(id=model_id) & get_company_or_none_constraint(call.identity.company)
).first()
if not res:
if not model:
raise errors.bad_request.InvalidModelId(
"no such public or company model",
id=model_id,
company=call.identity.company,
)
call.result.data = {"model": res.to_proper_dict()}
model_dict = model.to_proper_dict()
conform_output_tags(call, model_dict)
call.result.data = {"model": model_dict}
@endpoint("models.get_all_ex", required_fields=[])
def get_all_ex(call):
assert isinstance(call, APICall)
def get_all_ex(call: APICall):
conform_tag_fields(call, call.data)
with translate_errors_context():
with TimingContext("mongo", "models_get_all_ex"):
models = Model.get_many_with_join(
@@ -99,14 +108,13 @@ def get_all_ex(call):
allow_public=True,
query_options=get_all_query_options,
)
conform_output_tags(call, models)
call.result.data = {"models": models}
@endpoint("models.get_all", required_fields=[])
def get_all(call):
assert isinstance(call, APICall)
def get_all(call: APICall):
conform_tag_fields(call, call.data)
with translate_errors_context():
with TimingContext("mongo", "models_get_all"):
models = Model.get_many(
@@ -116,13 +124,14 @@ def get_all(call):
allow_public=True,
query_options=get_all_query_options,
)
conform_output_tags(call, models)
call.result.data = {"models": models}
create_fields = {
"name": None,
"tags": list,
"system_tags": list,
"task": Task,
"comment": None,
"uri": None,
@@ -134,22 +143,10 @@ create_fields = {
"ready": None,
}
schemes = list(SupportedURLField.schemes)
def _validate_uri(uri):
parsed_uri = urlparse(uri)
if parsed_uri.scheme not in schemes:
raise errors.bad_request.InvalidModelUri("unsupported scheme", uri=uri)
elif not parsed_uri.path:
raise errors.bad_request.InvalidModelUri("missing path", uri=uri)
def parse_model_fields(call, valid_fields):
fields = parse_from_call(call.data, valid_fields, Model.get_fields())
tags = fields.get("tags")
if tags:
fields["tags"] = list(set(tags))
conform_tag_fields(call, fields)
return fields
@@ -251,15 +248,14 @@ def create(call, company, req_model):
if project:
validate_id(Project, company=company, project=project)
uri = req_model.uri
if uri:
_validate_uri(uri)
task = req_model.task
req_data = req_model.to_struct()
if task:
validate_task(call, req_data)
fields = filter_fields(Model, req_data)
conform_tag_fields(call, fields)
# create and save model
model = Model(
id=database.utils.id(),
@@ -276,12 +272,12 @@ def create(call, company, req_model):
def prepare_update_fields(call, fields):
fields = fields.copy()
if "uri" in fields:
_validate_uri(fields["uri"])
# clear UI cache if URI is provided (model updated)
fields["ui_cache"] = fields.pop("ui_cache", {})
if "task" in fields:
validate_task(call, fields)
conform_tag_fields(call, fields)
return fields
@@ -290,8 +286,7 @@ def validate_task(call, fields):
@endpoint("models.edit", required_fields=["model"], response_data_model=UpdateResponse)
def edit(call):
assert isinstance(call, APICall)
def edit(call: APICall):
identity = call.identity
model_id = call.data["model"]
@@ -327,13 +322,13 @@ def edit(call):
if fields:
updated = model.update(upsert=False, **fields)
conform_output_tags(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
else:
call.result.data_model = UpdateResponse(updated=0)
def _update_model(call, model_id=None):
assert isinstance(call, APICall)
def _update_model(call: APICall, model_id=None):
identity = call.identity
model_id = model_id or call.data["model"]
@@ -358,6 +353,7 @@ def _update_model(call, model_id=None):
updated_count, updated_fields = Model.safe_update(
call.identity.company, model.id, data
)
conform_output_tags(call, updated_fields)
return UpdateResponse(updated=updated_count, fields=updated_fields)

View File

@@ -9,27 +9,32 @@ from mongoengine import Q
import database
from apierrors import errors
from apimodels.base import UpdateResponse
from apimodels.projects import GetHyperParamReq, GetHyperParamResp, ProjectReq
from bll.task import TaskBLL
from database.errors import translate_errors_context
from database.model import EntityVisibility
from database.model.model import Model
from database.model.project import Project
from database.model.task.task import Task, TaskStatus, TaskVisibility
from database.model.task.task import Task, TaskStatus
from database.utils import parse_from_call, get_options, get_company_or_none_constraint
from service_repo import APICall, endpoint
from services.utils import conform_tag_fields, conform_output_tags
from timing_context import TimingContext
task_bll = TaskBLL()
archived_tasks_cond = {"$in": [TaskVisibility.archived.value, "$tags"]}
archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}
create_fields = {
"name": None,
"description": None,
"tags": list,
"system_tags": list,
"default_output_destination": None,
}
get_all_query_options = Project.QueryParameterOptions(
pattern_fields=("name", "description"), list_fields=("tags", "id")
pattern_fields=("name", "description"),
list_fields=("tags", "system_tags", "id"),
)
@@ -43,32 +48,39 @@ def get_by_id(call):
query = Q(id=project_id) & get_company_or_none_constraint(
call.identity.company
)
res = Project.objects(query).first()
if not res:
project = Project.objects(query).first()
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
res = res.to_proper_dict()
project_dict = project.to_proper_dict()
conform_output_tags(call, project_dict)
call.result.data = {"project": res}
call.result.data = {"project": project_dict}
def make_projects_get_all_pipelines(project_ids, specific_state=None):
archived = TaskVisibility.archived.value
status_count_pipeline = [
# count tasks per project per status
{"$match": {"project": {"$in": project_ids}}},
# make sure tags is always an array (required by subsequent $in in archived_tasks_cond)
{
archived = EntityVisibility.archived.value
def ensure_system_tags():
"""
Make sure system tags is always an array (required by subsequent $in in archived_tasks_cond
"""
return {
"$addFields": {
"tags": {
"system_tags": {
"$cond": {
"if": {"$ne": [{"$type": "$tags"}, "array"]},
"if": {"$ne": [{"$type": "$system_tags"}, "array"]},
"then": [],
"else": "$tags",
"else": "$system_tags",
}
}
}
},
}
status_count_pipeline = [
# count tasks per project per status
{"$match": {"project": {"$in": project_ids}}},
ensure_system_tags(),
{
"$group": {
"_id": {
@@ -125,12 +137,12 @@ def make_projects_get_all_pipelines(project_ids, specific_state=None):
group_step = {"_id": "$project"}
for state in TaskVisibility:
for state in EntityVisibility:
if specific_state and state != specific_state:
continue
if state == TaskVisibility.active:
if state == EntityVisibility.active:
group_step[state.value] = runtime_subquery({"$not": archived_tasks_cond})
elif state == TaskVisibility.archived:
elif state == EntityVisibility.archived:
group_step[state.value] = runtime_subquery(archived_tasks_cond)
runtime_pipeline = [
@@ -141,6 +153,7 @@ def make_projects_get_all_pipelines(project_ids, specific_state=None):
"project": {"$in": project_ids},
}
},
ensure_system_tags(),
{
# for each project
"$group": group_step
@@ -151,32 +164,33 @@ def make_projects_get_all_pipelines(project_ids, specific_state=None):
@endpoint("projects.get_all_ex")
def get_all_ex(call):
assert isinstance(call, APICall)
def get_all_ex(call: APICall):
include_stats = call.data.get("include_stats")
stats_for_state = call.data.get("stats_for_state", TaskVisibility.active.value)
stats_for_state = call.data.get("stats_for_state", EntityVisibility.active.value)
if stats_for_state:
try:
specific_state = TaskVisibility(stats_for_state)
specific_state = EntityVisibility(stats_for_state)
except ValueError:
raise errors.bad_request.FieldsValueError(stats_for_state=stats_for_state)
else:
specific_state = None
conform_tag_fields(call, call.data)
with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
res = Project.get_many_with_join(
projects = Project.get_many_with_join(
company=call.identity.company,
query_dict=call.data,
query_options=get_all_query_options,
allow_public=True,
)
conform_output_tags(call, projects)
if not include_stats:
call.result.data = {"projects": res}
call.result.data = {"projects": projects}
return
ids = [project["id"] for project in res]
ids = [project["id"] for project in projects]
status_count_pipeline, runtime_pipeline = make_projects_get_all_pipelines(
ids, specific_state=specific_state
)
@@ -187,11 +201,11 @@ def get_all_ex(call):
return dict(default_counts, **entry)
status_count = defaultdict(lambda: {})
key = itemgetter(TaskVisibility.archived.value)
key = itemgetter(EntityVisibility.archived.value)
for result in Task.objects.aggregate(*status_count_pipeline):
for k, group in groupby(sorted(result["counts"], key=key), key):
section = (
TaskVisibility.archived if k else TaskVisibility.active
EntityVisibility.archived if k else EntityVisibility.active
).value
status_count[result["_id"]][section] = set_default_count(
{
@@ -219,32 +233,32 @@ def get_all_ex(call):
}
report_for_states = [
s for s in TaskVisibility if not specific_state or specific_state == s
s for s in EntityVisibility if not specific_state or specific_state == s
]
for project in res:
for project in projects:
project["stats"] = {
task_state.value: get_status_counts(project["id"], task_state.value)
for task_state in report_for_states
}
call.result.data = {"projects": res}
call.result.data = {"projects": projects}
@endpoint("projects.get_all")
def get_all(call):
assert isinstance(call, APICall)
def get_all(call: APICall):
conform_tag_fields(call, call.data)
with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
res = Project.get_many(
projects = Project.get_many(
company=call.identity.company,
query_dict=call.data,
query_options=get_all_query_options,
parameters=call.data,
allow_public=True,
)
conform_output_tags(call, projects)
call.result.data = {"projects": res}
call.result.data = {"projects": projects}
@endpoint("projects.create", required_fields=["name", "description"])
@@ -254,6 +268,7 @@ def create(call):
with translate_errors_context():
fields = parse_from_call(call.data, create_fields, Project.get_fields())
conform_tag_fields(call, fields)
now = datetime.utcnow()
project = Project(
id=database.utils.id(),
@@ -271,7 +286,7 @@ def create(call):
@endpoint(
"projects.update", required_fields=["project"], response_data_model=UpdateResponse
)
def update(call):
def update(call: APICall):
"""
update
@@ -280,7 +295,6 @@ def update(call):
:return: updated - `int` - number of projects updated
fields - `[string]` - updated fields
"""
assert isinstance(call, APICall)
project_id = call.data["project"]
with translate_errors_context():
@@ -291,9 +305,11 @@ def update(call):
fields = parse_from_call(
call.data, create_fields, Project.get_fields(), discard_none_values=False
)
conform_tag_fields(call, fields)
fields["last_update"] = datetime.utcnow()
with TimingContext("mongo", "projects_update"):
updated = project.update(upsert=False, **fields)
conform_output_tags(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
@@ -317,7 +333,7 @@ def delete(call):
(Model, errors.bad_request.ProjectHasModels),
):
res = cls.objects(
project=project_id, tags__nin=[TaskVisibility.archived.value]
project=project_id, system_tags__nin=[EntityVisibility.archived.value]
).only("id")
if res and not force:
raise error("use force=true to delete", id=project_id)
@@ -329,12 +345,33 @@ def delete(call):
call.result.data = {"deleted": 1, "disassociated_tasks": updated_count}
@endpoint("projects.get_unique_metric_variants")
def get_unique_metric_variants(call, company_id, req_model):
project_id = call.data.get("project")
@endpoint("projects.get_unique_metric_variants", request_data_model=ProjectReq)
def get_unique_metric_variants(call: APICall, company_id: str, request: ProjectReq):
metrics = task_bll.get_unique_metric_variants(
company_id, [project_id] if project_id else None
company_id, [request.project] if request.project else None
)
call.result.data = {"metrics": metrics}
@endpoint(
"projects.get_hyper_parameters",
min_version="2.2",
request_data_model=GetHyperParamReq,
response_data_model=GetHyperParamResp,
)
def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamReq):
total, remaining, parameters = TaskBLL.get_aggregated_project_execution_parameters(
company_id,
project_ids=[request.project] if request.project else None,
page=request.page,
page_size=request.page_size,
)
call.result.data = {
"total": total,
"remaining": remaining,
"parameters": parameters,
}

View File

@@ -0,0 +1,45 @@
from pyhocon.config_tree import NoneValue
from config import config
from config.info import get_version, get_build_number, get_commit_number
from service_repo import ServiceRepo, APICall, endpoint
@endpoint("server.config")
def get_config(call: APICall):
path = call.data.get("path")
if path:
c = dict(config.get(path))
else:
c = config.to_dict()
def remove_none_value(x):
"""
Pyhocon bug in Python 3: leaves dummy "NoneValue"s in tree,
see: https://github.com/chimpler/pyhocon/issues/111
"""
if isinstance(x, dict):
return {key: remove_none_value(value) for key, value in x.items()}
if isinstance(x, list):
return list(map(remove_none_value, x))
if isinstance(x, NoneValue):
return None
return x
c.pop("secure", None)
call.result.data = remove_none_value(c)
@endpoint("server.endpoints")
def get_endpoints(call: APICall):
call.result.data = ServiceRepo.endpoints_summary()
@endpoint("server.info")
def info(call: APICall):
call.result.data = {
"version": get_version(),
"build": get_build_number(),
"commit": get_commit_number(),
}

View File

@@ -33,13 +33,14 @@ from database.model.task.output import Output
from database.model.task.task import Task, TaskStatus, Script, DEFAULT_LAST_ITERATION
from database.utils import get_fields, parse_from_call
from service_repo import APICall, endpoint
from services.utils import conform_tag_fields, conform_output_tags
from timing_context import TimingContext
from utilities import safe_get
task_fields = set(Task.get_fields())
task_script_fields = set(get_fields(Script))
get_all_query_options = Task.QueryParameterOptions(
list_fields=("id", "user", "tags", "type", "status", "project"),
list_fields=("id", "user", "tags", "system_tags", "type", "status", "project"),
datetime_fields=("status_changed",),
pattern_fields=("name", "comment"),
fields=("parent",),
@@ -79,11 +80,13 @@ def get_by_id(call: APICall, company_id, req_model: TaskRequest):
req_model.task, company_id=company_id, allow_public=True
)
task_dict = task.to_proper_dict()
conform_output_tags(call, task_dict)
call.result.data = {"task": task_dict}
@endpoint("tasks.get_all_ex", required_fields=[])
def get_all_ex(call: APICall):
conform_tag_fields(call, call.data)
with translate_errors_context():
with TimingContext("mongo", "task_get_all_ex"):
tasks = Task.get_many_with_join(
@@ -91,13 +94,15 @@ def get_all_ex(call: APICall):
query_dict=call.data,
query_options=get_all_query_options,
allow_public=True, # required in case projection is requested for public dataset/versions
override_none_ordering=True,
)
conform_output_tags(call, tasks)
call.result.data = {"tasks": tasks}
@endpoint("tasks.get_all", required_fields=[])
def get_all(call: APICall):
conform_tag_fields(call, call.data)
with translate_errors_context():
with TimingContext("mongo", "task_get_all"):
tasks = Task.get_many(
@@ -106,7 +111,9 @@ def get_all(call: APICall):
query_dict=call.data,
query_options=get_all_query_options,
allow_public=True, # required in case projection is requested for public dataset/versions
override_none_ordering=True,
)
conform_output_tags(call, tasks)
call.result.data = {"tasks": tasks}
@@ -188,6 +195,7 @@ def close(call: APICall, company_id, req_model: UpdateRequest):
create_fields = {
"name": None,
"tags": list,
"system_tags": list,
"type": None,
"error": None,
"comment": None,
@@ -219,10 +227,7 @@ def prepare_create_fields(
output = Output(destination=output_dest)
fields["output"] = output
# Make sure there are no duplicate tags
tags = fields.get("tags")
if tags:
fields["tags"] = list(set(tags))
conform_tag_fields(call, fields)
# Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
for field in task_script_fields:
@@ -251,7 +256,7 @@ def _validate_and_get_task_from_call(call: APICall, **kwargs):
task = task_bll.create(call, fields)
with TimingContext("code", "validate"):
task_bll.validate(task, force=call.data.get("force", False))
task_bll.validate(task)
return task
@@ -272,16 +277,14 @@ def create(call: APICall, company_id, req_model: CreateRequest):
call.result.data = {"id": task.id}
def prepare_update_fields(task, call_data):
def prepare_update_fields(call: APICall, task, call_data):
valid_fields = deepcopy(task.__class__.user_set_allowed())
update_fields = {k: v for k, v in create_fields.items() if k in valid_fields}
update_fields["output__error"] = None
t_fields = task_fields
t_fields.add("output__error")
fields = parse_from_call(call_data, update_fields, t_fields)
tags = fields.get("tags")
if tags:
fields["tags"] = list(set(tags))
conform_tag_fields(call, fields)
return fields, valid_fields
@@ -296,7 +299,7 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
if not task:
raise errors.bad_request.InvalidTaskId(id=task_id)
partial_update_dict, valid_fields = prepare_update_fields(task, call.data)
partial_update_dict, valid_fields = prepare_update_fields(call, task, call.data)
if not partial_update_dict:
return UpdateResponse(updated=0)
@@ -309,7 +312,7 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
)
update_project_time(updated_fields.get("project"))
conform_output_tags(call, updated_fields)
return UpdateResponse(updated=updated_count, fields=updated_fields)
@@ -364,7 +367,7 @@ def update_batch(call: APICall):
bulk_ops = []
for id, data in items.items():
fields, valid_fields = prepare_update_fields(tasks[id], data)
fields, valid_fields = prepare_update_fields(call, tasks[id], data)
partial_update_dict = Task.get_safe_update_dict(fields)
if not partial_update_dict:
continue
@@ -421,7 +424,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
d.update(value)
fields[key] = d
task_bll.validate(task_bll.create(call, fields), force=force)
task_bll.validate(task_bll.create(call, fields))
# make sure field names do not end in mongoengine comparison operators
fixed_fields = {
@@ -434,6 +437,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
fixed_fields.update(last_update=now)
updated = task.update(upsert=False, **fixed_fields)
update_project_time(fields.get("project"))
conform_output_tags(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
else:
call.result.data_model = UpdateResponse(updated=0)
@@ -463,6 +467,7 @@ def reset(call: APICall, company_id, req_model: UpdateRequest):
set__last_metrics={},
unset__output__result=1,
unset__output__model=1,
__raw__={"$pull": {"execution.artifacts": {"mode": {"$ne": "input"}}}},
)
res = ResetResponse(
@@ -670,7 +675,10 @@ def publish(call: APICall, company_id, req_model: PublishRequest):
@endpoint(
"tasks.completed", min_version="2.2", request_data_model=UpdateRequest, response_data_model=UpdateResponse
"tasks.completed",
min_version="2.2",
request_data_model=UpdateRequest,
response_data_model=UpdateResponse,
)
def completed(call: APICall, company_id, request: PublishRequest):
call.result.data_model = UpdateResponse(
@@ -688,4 +696,3 @@ def ping(_, company_id, request: PingRequest):
TaskBLL.set_last_update(
task_ids=[request.task], company_id=company_id, last_update=datetime.utcnow()
)

52
server/services/utils.py Normal file
View File

@@ -0,0 +1,52 @@
from typing import Union, Sequence
from database.utils import partition_tags
from service_repo import APICall
from service_repo.base import PartialVersion
def conform_output_tags(call: APICall, documents: Union[dict, Sequence[dict]]):
if call.requested_endpoint_version >= PartialVersion("2.3"):
return
if isinstance(documents, dict):
documents = [documents]
for doc in documents:
system_tags = doc.get("system_tags")
if system_tags:
doc["tags"] = list(set(doc.get("tags", [])) | set(system_tags))
def conform_tag_fields(call: APICall, document: dict):
"""
Make sure that 'tags' from the old SDK clients
are correctly split into 'tags' and 'system_tags'
Make sure that there are no duplicate tags
"""
if call.requested_endpoint_version < PartialVersion("2.3"):
service_name = call.endpoint_name.partition(".")[0]
upgrade_tags(
service_name[:-1] if service_name.endswith("s") else service_name, document
)
remove_duplicate_tags(document)
def upgrade_tags(entity: str, document: dict):
"""
If only 'tags' is present in the fields then extract
the system tags from it to a separate field 'system_tags'
"""
tags = document.get("tags")
if tags is not None and not document.get("system_tags"):
user_tags, system_tags = partition_tags(entity, tags)
document["tags"] = user_tags
document["system_tags"] = system_tags
def remove_duplicate_tags(document: dict):
"""
Remove duplicates from 'tags' and 'system_tags' fields
"""
for name in ("tags", "system_tags"):
values = document.get(name)
if values:
document[name] = list(set(values))

View File

@@ -1,159 +0,0 @@
"""
Comprehensive test of all(?) use cases of datasets and frames
"""
import json
import unittest
import es_factory
from tests.api_client import APIClient
from config import config
log = config.logger(__file__)
class TestDatasetsService(unittest.TestCase):
def setUp(self):
self.api = APIClient(base_url="http://localhost:5100/v1.0")
self.created_tasks = []
self.task = dict(
name="test task events",
type="training",
)
res, self.task_id = self.api.send('tasks.create', self.task, extract="id")
assert (res.meta.result_code == 200)
self.created_tasks.append(self.task_id)
def tearDown(self):
log.info("Cleanup...")
for task_id in self.created_tasks:
try:
self.api.send('tasks.delete', dict(task=task_id, force=True))
except Exception as ex:
log.exception(ex)
def create_task_event(self, type, iteration):
return {
"worker": "test",
"type": type,
"task": self.task_id,
"iter": iteration,
"timestamp": es_factory.get_timestamp_millis()
}
def copy_and_update(self, src_obj, new_data):
obj = src_obj.copy()
obj.update(new_data)
return obj
def test_task_logs(self):
events = []
for iter in range(10):
log_event = self.create_task_event("log", iteration=iter)
events.append(self.copy_and_update(log_event, {
"msg": "This is a log message from test task iter " + str(iter)
}))
# sleep so timestamp is not the same
import time
time.sleep(0.01)
self.send_batch(events)
data = self.api.events.get_task_log(task=self.task_id)
assert len(data["events"]) == 10
self.api.tasks.reset(task=self.task_id)
data = self.api.events.get_task_log(task=self.task_id)
assert len(data["events"]) == 0
def test_task_plots(self):
event = self.create_task_event("plot", 0)
event["metric"] = "roc"
event.update({
"plot_str": json.dumps({
"data": [
{
"x": [0, 1, 2, 3, 4, 5, 6, 7, 8],
"y": [0, 1, 2, 3, 4, 5, 6, 7, 8],
"text": ["Th=0.1", "Th=0.2", "Th=0.3", "Th=0.4", "Th=0.5", "Th=0.6", "Th=0.7", "Th=0.8"],
"name": 'class1'
},
{
"x": [0, 1, 2, 3, 4, 5, 6, 7, 8],
"y": [2.0, 3.0, 5.0, 8.2, 6.4, 7.5, 9.2, 8.1, 10.0],
"text": ["Th=0.1", "Th=0.2", "Th=0.3", "Th=0.4", "Th=0.5", "Th=0.6", "Th=0.7", "Th=0.8"],
"name": 'class2',
}
],
"layout": {
"title": "ROC for iter 0",
"xaxis": {
"title": 'my x axis'
},
"yaxis": {
"title": 'my y axis'
}
}
})
})
self.send(event)
event = self.create_task_event("plot", 100)
event["metric"] = "confusion"
event.update({
"plot_str": json.dumps({
"data": [
{
"y": [
"lying",
"sitting",
"standing",
"people",
"backgroun"
],
"x": [
"lying",
"sitting",
"standing",
"people",
"backgroun"
],
"z": [
[758, 163, 0, 0, 23],
[63, 858, 3, 0, 0],
[0, 50, 188, 21, 35],
[0, 22, 8, 40, 4, ],
[12, 91, 26, 29, 368]
],
"type": "heatmap"
}
],
"layout": {
"title": "Confusion Matrix for iter 100",
"xaxis": {
"title": "Predicted value"
},
"yaxis": {
"title": "Real value"
}
}
})
})
self.send(event)
data = self.api.events.get_task_plots(task=self.task_id)
assert len(data["plots"]) == 2
self.api.tasks.reset(task=self.task_id)
data = self.api.events.get_task_plots(task=self.task_id)
assert len(data["plots"]) == 0
def send_batch(self, events):
self.api.send_batch('events.add_batch', events)
def send(self, event):
self.api.send('events.add', event)
if __name__ == '__main__':
unittest.main()

View File

@@ -0,0 +1,218 @@
from time import sleep
from typing import Sequence
from apierrors.errors import bad_request
from database.utils import partition_tags
from tests.api_client import APIClient, AttrDict
from tests.automated import TestService
from config import config
log = config.logger(__file__)
class TestTags(TestService):
def setUp(self, version="2.3"):
super().setUp(version)
def testPartition(self):
tags, system_tags = partition_tags("project", ["test"])
self.assertTagsEqual(tags, ["test"])
self.assertTagsEqual(system_tags, [])
tags, system_tags = partition_tags("project", ["test", "archived"])
self.assertTagsEqual(tags, ["test"])
self.assertTagsEqual(system_tags, ["archived"])
tags, system_tags = partition_tags("project", ["test", "archived"], ["custom"])
self.assertTagsEqual(tags, ["test"])
self.assertTagsEqual(system_tags, ["archived", "custom"])
tags, system_tags = partition_tags(
"task", ["test", "development", "annotator20", "Annotation"]
)
self.assertTagsEqual(tags, ["test"])
self.assertTagsEqual(system_tags, ["development", "annotator20", "Annotation"])
def testBackwardsCompatibility(self):
new_api = self.api
self.api = APIClient(base_url="http://localhost:8008/v2.2")
entity_tags = {
"model": "archived",
"project": "public",
"task": "development",
}
for name, system_tag in entity_tags.items():
create_func = getattr(self, f"_temp_{name}")
_id = create_func(tags=[system_tag, "test"])
names = f"{name}s"
# when accessed through the old api all the tags are in the tags field
self.assertGetById(
service=names, entity=name, _id=_id, tags=[system_tag, "test"]
)
entities = self._send(
names, "get_all", name="Test tags", tags=[f"-{system_tag}"]
)[names]
self.assertNotFound(_id, entities)
# when accessed through the new api the tags are in tags and system_tags fields
self.assertGetById(
service=names,
entity=name,
_id=_id,
tags=["test"],
system_tags=[system_tag],
api=new_api,
)
# update operation, remove system tag through the old api
self._send(names, "update", tags=["test"], **{name: _id})
self.assertGetById(service=names, entity=name, _id=_id, tags=["test"])
def testProjectTags(self):
pr_id = self._temp_project(system_tags=["default"])
# Test getting project with system tags
projects = self.api.projects.get_all(name="Test tags").projects
self.assertFound(pr_id, ["default"], projects)
projects = self.api.projects.get_all(
name="Test tags", system_tags=["default"]
).projects
self.assertFound(pr_id, ["default"], projects)
projects = self.api.projects.get_all(
name="Test tags", system_tags=["-default"]
).projects
self.assertNotFound(pr_id, projects)
self.api.projects.update(project=pr_id, system_tags=[])
projects = self.api.projects.get_all(
name="Test tags", system_tags=["-default"]
).projects
self.assertFound(pr_id, [], projects)
# Test task statistics and delete
task1_id = self._temp_task(
name="Tags test1", project=pr_id, system_tags=["active"]
)
self._run_task(task1_id)
task2_id = self._temp_task(
name="Tags test2", project=pr_id, system_tags=["archived"]
)
projects = self.api.projects.get_all_ex(name="Test tags").projects
self.assertFound(pr_id, [], projects)
projects = self.api.projects.get_all_ex(
name="Test tags", include_stats=True
).projects
project = next(p for p in projects if p.id == pr_id)
self.assertProjectStats(project)
with self.api.raises(bad_request.ProjectHasTasks):
self.api.projects.delete(project=pr_id)
self.api.projects.delete(project=pr_id, force=True)
def testModelTags(self):
model_id = self._temp_model(system_tags=["default"])
models = self.api.models.get_all_ex(
name="Test tags", system_tags=["default"]
).models
self.assertFound(model_id, ["default"], models)
models = self.api.models.get_all_ex(
name="Test tags", system_tags=["-default"]
).models
self.assertNotFound(model_id, models)
self.api.models.update(model=model_id, system_tags=[])
models = self.api.models.get_all_ex(
name="Test tags", system_tags=["-default"]
).models
self.assertFound(model_id, [], models)
def testTaskTags(self):
task_id = self._temp_task(
name="Test tags", system_tags=["active"]
)
tasks = self.api.tasks.get_all_ex(
name="Test tags", system_tags=["active"]
).tasks
self.assertFound(task_id, ["active"], tasks)
tasks = self.api.tasks.get_all_ex(
name="Test tags", system_tags=["-active"]
).tasks
self.assertNotFound(task_id, tasks)
self.api.tasks.update(task=task_id, system_tags=[])
tasks = self.api.tasks.get_all_ex(
name="Test tags", system_tags=["-active"]
).tasks
self.assertFound(task_id, [], tasks)
# test development system tag
self.api.tasks.started(task=task_id)
self.api.tasks.stop(task=task_id)
task = self.api.tasks.get_by_id(task=task_id).task
self.assertEqual(task.status, "in_progress")
self.api.tasks.update(task=task_id, system_tags=["development"])
self.api.tasks.stop(task=task_id)
task = self.api.tasks.get_by_id(task=task_id).task
self.assertEqual(task.status, "stopped")
def assertProjectStats(self, project: AttrDict):
self.assertEqual(set(project.stats.keys()), {"active"})
self.assertAlmostEqual(project.stats.active.total_runtime, 1, places=0)
for status, count in project.stats.active.status_count.items():
self.assertEqual(count, 1 if status == "stopped" else 0)
def _run_task(self, task_id):
"""Imitate 1 second of running"""
self.api.tasks.started(task=task_id)
sleep(1)
self.api.tasks.stopped(task=task_id)
def _temp_project(self, **kwargs):
self._update_missing(kwargs, name="Test tags", description="test")
return self.create_temp("projects", **kwargs)
def _temp_model(self, **kwargs):
self._update_missing(kwargs, name="Test tags", uri="file:///a/b", labels={})
return self.create_temp("models", **kwargs)
def _temp_task(self, **kwargs):
self._update_missing(kwargs, name="Test tags", type="testing", input=dict(view=dict()))
return self.create_temp("tasks", **kwargs)
@staticmethod
def _update_missing(target: dict, **update):
target.update({k: v for k, v in update.items() if k not in target})
def _send(self, service, action, **kwargs):
api = kwargs.pop("api", self.api)
return AttrDict(
api.send(f"{service}.{action}", kwargs)[1]
)
def assertGetById(self, service, entity, _id, tags, system_tags=None, **kwargs):
entity = self._send(service, "get_by_id", **{entity: _id}, **kwargs)[entity]
self.assertEqual(set(entity.tags), set(tags))
if system_tags is not None:
self.assertEqual(set(entity.system_tags), set(system_tags))
def assertFound(
self, _id: str, system_tags: Sequence[str], res: Sequence[AttrDict]
):
found = next((r for r in res if _id == r.id), None)
assert found
self.assertTagsEqual(found.system_tags, system_tags)
def assertNotFound(
self, _id: str, res: Sequence[AttrDict]
):
self.assertFalse(any(r for r in res if r.id == _id))
def assertTagsEqual(self, tags: Sequence[str], expected_tags: Sequence[str]):
self.assertEqual(set(tags), set(expected_tags))

View File

@@ -0,0 +1,240 @@
"""
Comprehensive test of all(?) use cases of datasets and frames
"""
import json
import unittest
from statistics import mean
import es_factory
from config import config
from tests.automated import TestService
log = config.logger(__file__)
class TestTaskEvents(TestService):
def setUp(self, version="1.7"):
super().setUp(version=version)
self.created_tasks = []
self.task = dict(
name="test task events",
type="training",
input=dict(mapping={}, view=dict(entries=[])),
)
res, self.task_id = self.api.send("tasks.create", self.task, extract="id")
assert res.meta.result_code == 200
self.created_tasks.append(self.task_id)
def tearDown(self):
log.info("Cleanup...")
for task_id in self.created_tasks:
try:
self.api.send("tasks.delete", dict(task=task_id, force=True))
except Exception as ex:
log.exception(ex)
def create_task_event(self, type, iteration):
return {
"worker": "test",
"type": type,
"task": self.task_id,
"iter": iteration,
"timestamp": es_factory.get_timestamp_millis()
}
def copy_and_update(self, src_obj, new_data):
obj = src_obj.copy()
obj.update(new_data)
return obj
def test_task_logs(self):
events = []
for iter in range(10):
log_event = self.create_task_event("log", iteration=iter)
events.append(
self.copy_and_update(
log_event,
{"msg": "This is a log message from test task iter " + str(iter)},
)
)
# sleep so timestamp is not the same
import time
time.sleep(0.01)
self.send_batch(events)
data = self.api.events.get_task_log(task=self.task_id)
assert len(data["events"]) == 10
self.api.tasks.reset(task=self.task_id)
data = self.api.events.get_task_log(task=self.task_id)
assert len(data["events"]) == 0
def test_task_metric_value_intervals_keys(self):
metric = "Metric1"
variant = "Variant1"
iter_count = 100
events = [
{
**self.create_task_event("training_stats_scalar", iteration),
"metric": metric,
"variant": variant,
"value": iteration,
}
for iteration in range(iter_count)
]
self.send_batch(events)
for key in None, "iter", "timestamp", "iso_time":
with self.subTest(key=key):
data = self.api.events.scalar_metrics_iter_histogram(task=self.task_id, key=key)
self.assertIn(metric, data)
self.assertIn(variant, data[metric])
self.assertIn("x", data[metric][variant])
self.assertIn("y", data[metric][variant])
def test_task_metric_value_intervals(self):
metric = "Metric1"
variant = "Variant1"
iter_count = 100
events = [
{
**self.create_task_event("training_stats_scalar", iteration),
"metric": metric,
"variant": variant,
"value": iteration,
}
for iteration in range(iter_count)
]
self.send_batch(events)
data = self.api.events.scalar_metrics_iter_histogram(task=self.task_id)
self._assert_metrics_histogram(data[metric][variant], iter_count, 100)
data = self.api.events.scalar_metrics_iter_histogram(task=self.task_id, samples=100)
self._assert_metrics_histogram(data[metric][variant], iter_count, 100)
data = self.api.events.scalar_metrics_iter_histogram(task=self.task_id, samples=10)
self._assert_metrics_histogram(data[metric][variant], iter_count, 10)
def _assert_metrics_histogram(self, data, iters, samples):
interval = iters // samples
self.assertEqual(len(data["x"]), samples)
self.assertEqual(len(data["y"]), samples)
for curr in range(samples):
self.assertEqual(data["x"][curr], curr * interval)
self.assertEqual(
data["y"][curr],
mean(v for v in range(curr * interval, (curr + 1) * interval)),
)
def test_task_plots(self):
event = self.create_task_event("plot", 0)
event["metric"] = "roc"
event.update(
{
"plot_str": json.dumps(
{
"data": [
{
"x": [0, 1, 2, 3, 4, 5, 6, 7, 8],
"y": [0, 1, 2, 3, 4, 5, 6, 7, 8],
"text": [
"Th=0.1",
"Th=0.2",
"Th=0.3",
"Th=0.4",
"Th=0.5",
"Th=0.6",
"Th=0.7",
"Th=0.8",
],
"name": "class1",
},
{
"x": [0, 1, 2, 3, 4, 5, 6, 7, 8],
"y": [2.0, 3.0, 5.0, 8.2, 6.4, 7.5, 9.2, 8.1, 10.0],
"text": [
"Th=0.1",
"Th=0.2",
"Th=0.3",
"Th=0.4",
"Th=0.5",
"Th=0.6",
"Th=0.7",
"Th=0.8",
],
"name": "class2",
},
],
"layout": {
"title": "ROC for iter 0",
"xaxis": {"title": "my x axis"},
"yaxis": {"title": "my y axis"},
},
}
)
}
)
self.send(event)
event = self.create_task_event("plot", 100)
event["metric"] = "confusion"
event.update(
{
"plot_str": json.dumps(
{
"data": [
{
"y": [
"lying",
"sitting",
"standing",
"people",
"backgroun",
],
"x": [
"lying",
"sitting",
"standing",
"people",
"backgroun",
],
"z": [
[758, 163, 0, 0, 23],
[63, 858, 3, 0, 0],
[0, 50, 188, 21, 35],
[0, 22, 8, 40, 4],
[12, 91, 26, 29, 368],
],
"type": "heatmap",
}
],
"layout": {
"title": "Confusion Matrix for iter 100",
"xaxis": {"title": "Predicted value"},
"yaxis": {"title": "Real value"},
},
}
)
}
)
self.send(event)
data = self.api.events.get_task_plots(task=self.task_id)
assert len(data["plots"]) == 2
self.api.tasks.reset(task=self.task_id)
data = self.api.events.get_task_plots(task=self.task_id)
assert len(data["plots"]) == 0
def send_batch(self, events):
self.api.send_batch("events.add_batch", events)
def send(self, event):
self.api.send("events.add", event)
if __name__ == "__main__":
unittest.main()

View File

@@ -36,8 +36,8 @@ class TestTasksResetDelete(TestService):
TASK_CANNOT_BE_DELETED_CODES = (400, 123)
def setUp(self):
super(TestTasksResetDelete, self).setUp()
def setUp(self, version="1.7"):
super(TestTasksResetDelete, self).setUp(version=version)
self.tasks = self.api.tasks
self.models = self.api.models

View File

@@ -0,0 +1,108 @@
import operator
from time import sleep
from typing import Sequence
from tests.automated import TestService
class TestTasksOrdering(TestService):
test_comment = "Task ordering test"
only_fields = ["id", "started", "comment"]
def setUp(self, **kwargs):
super().setUp(**kwargs)
self.task_ids = self._create_tasks()
def test_order(self):
# test no ordering
self._assertGetTasksWithOrdering()
# sort ascending
self._assertGetTasksWithOrdering(order_by="started")
# sort descending
self._assertGetTasksWithOrdering(order_by="-started")
# sort by the same field that we use for the search
self._assertGetTasksWithOrdering(order_by="comment")
def test_order_with_paging(self):
order_field = "started"
# all results in one page
self._assertGetTasksWithOrdering(order_by=order_field, page=0, page_size=20)
field_vals = []
page_size = 2
num_pages = 5
for page in range(num_pages):
paged_tasks = self._get_page_tasks(
order_by=order_field, page=page, page_size=page_size
)
self.assertEqual(len(paged_tasks), page_size)
field_vals.extend(t.get(order_field) for t in paged_tasks)
paged_tasks = self._get_page_tasks(
order_by=order_field, page=num_pages, page_size=page_size
)
self.assertTrue(not paged_tasks)
self._assertSorted(field_vals)
def _get_page_tasks(self, order_by, page: int, page_size: int) -> Sequence:
return self.api.tasks.get_all_ex(
only_fields=self.only_fields,
order_by=order_by,
comment=self.test_comment,
page=page,
page_size=page_size,
).tasks
def _assertSorted(self, vals: Sequence, ascending=True):
"""
Assert that vals are sorted in the ascending or descending order
with None values are always coming from the end
"""
if None in vals:
first_null_idx = vals.index(None)
none_tail = vals[first_null_idx:]
vals = vals[:first_null_idx]
self.assertTrue(all(val is None for val in none_tail))
self.assertTrue(all(val is not None for val in vals))
if ascending:
cmp = operator.le
else:
cmp = operator.ge
self.assertTrue(all(cmp(i, j) for i, j in zip(vals, vals[1:])))
def _assertGetTasksWithOrdering(self, order_by: str = None, **kwargs):
tasks = self.api.tasks.get_all_ex(
only_fields=self.only_fields,
order_by=order_by,
comment=self.test_comment,
**kwargs,
).tasks
self.assertLessEqual(set(self.task_ids), set(t.id for t in tasks))
if order_by:
# test that the output is correctly ordered
field_name = order_by if not order_by.startswith("-") else order_by[1:]
field_vals = [t.get(field_name) for t in tasks]
self._assertSorted(field_vals, ascending=not order_by.startswith("-"))
def _create_tasks(self):
tasks = [self._temp_task() for _ in range(10)]
for _, task in zip(range(5), tasks):
self.api.tasks.started(task=task)
sleep(0.1)
return tasks
def _temp_task(self, **kwargs):
return self.create_temp(
"tasks",
name="test",
comment=self.test_comment,
type="testing",
input=dict(view=dict()),
**kwargs,
)

17
server/utilities/dicts.py Normal file
View File

@@ -0,0 +1,17 @@
from typing import Sequence, Tuple, Any
def flatten_nested_items(
dictionary: dict, nesting: int = None, include_leaves=None, prefix=None
) -> Sequence[Tuple[Tuple[str, ...], Any]]:
"""
iterate through dictionary and return with nested keys flattened into a tuple
"""
next_nesting = None if nesting is None else (nesting - 1)
prefix = prefix or ()
for key, value in dictionary.items():
path = prefix + (key,)
if isinstance(value, dict) and nesting != 0:
yield from flatten_nested_items(value, next_nesting, include_leaves, prefix=path)
elif include_leaves is None or key in include_leaves:
yield path, value

View File

@@ -1,557 +0,0 @@
Server Side Public License
VERSION 1, OCTOBER 16, 2018
Copyright © 2018 MongoDB, Inc.
Everyone is permitted to copy and distribute verbatim copies of this
license document, but changing it is not allowed.
TERMS AND CONDITIONS
0. Definitions.
“This License” refers to Server Side Public License.
“Copyright” also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.
“The Program” refers to any copyrightable work licensed under this
License. Each licensee is addressed as “you”. “Licensees” and
“recipients” may be individuals or organizations.
To “modify” a work means to copy from or adapt all or part of the work in
a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a “modified version” of the
earlier work or a work “based on” the earlier work.
A “covered work” means either the unmodified Program or a work based on
the Program.
To “propagate” a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.
To “convey” a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through a
computer network, with no transfer of a copy, is not conveying.
An interactive user interface displays “Appropriate Legal Notices” to the
extent that it includes a convenient and prominently visible feature that
(1) displays an appropriate copyright notice, and (2) tells the user that
there is no warranty for the work (except to the extent that warranties
are provided), that licensees may convey the work under this License, and
how to view a copy of this License. If the interface presents a list of
user commands or options, such as a menu, a prominent item in the list
meets this criterion.
1. Source Code.
The “source code” for a work means the preferred form of the work for
making modifications to it. “Object code” means any non-source form of a
work.
A “Standard Interface” means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that is
widely used among developers working in that language. The “System
Libraries” of an executable work include anything, other than the work as
a whole, that (a) is included in the normal form of packaging a Major
Component, but which is not part of that Major Component, and (b) serves
only to enable use of the work with that Major Component, or to implement
a Standard Interface for which an implementation is available to the
public in source code form. A “Major Component”, in this context, means a
major essential component (kernel, window system, and so on) of the
specific operating system (if any) on which the executable work runs, or
a compiler used to produce the work, or an object code interpreter used
to run it.
The “Corresponding Source” for a work in object code form means all the
source code needed to generate, install, and (for an executable work) run
the object code and to modify the work, including scripts to control
those activities. However, it does not include the work's System
Libraries, or general-purpose tools or generally available free programs
which are used unmodified in performing those activities but which are
not part of the work. For example, Corresponding Source includes
interface definition files associated with source files for the work, and
the source code for shared libraries and dynamically linked subprograms
that the work is specifically designed to require, such as by intimate
data communication or control flow between those subprograms and other
parts of the work.
The Corresponding Source need not include anything that users can
regenerate automatically from other parts of the Corresponding Source.
The Corresponding Source for a work in source code form is that same work.
2. Basic Permissions.
All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program, subject to section 13. The
output from running a covered work is covered by this License only if the
output, given its content, constitutes a covered work. This License
acknowledges your rights of fair use or other equivalent, as provided by
copyright law. Subject to section 13, you may make, run and propagate
covered works that you do not convey, without conditions so long as your
license otherwise remains in force. You may convey covered works to
others for the sole purpose of having them make modifications exclusively
for you, or provide you with facilities for running those works, provided
that you comply with the terms of this License in conveying all
material for which you do not control copyright. Those thus making or
running the covered works for you must do so exclusively on your
behalf, under your direction and control, on terms that prohibit them
from making any copies of your copyrighted material outside their
relationship with you.
Conveying under any other circumstances is permitted solely under the
conditions stated below. Sublicensing is not allowed; section 10 makes it
unnecessary.
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article 11
of the WIPO copyright treaty adopted on 20 December 1996, or similar laws
prohibiting or restricting circumvention of such measures.
When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention is
effected by exercising rights under this License with respect to the
covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's users,
your or third parties' legal rights to forbid circumvention of
technological measures.
4. Conveying Verbatim Copies.
You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice; keep
intact all notices stating that this License and any non-permissive terms
added in accord with section 7 apply to the code; keep intact all notices
of the absence of any warranty; and give all recipients a copy of this
License along with the Program. You may charge any price or no price for
each copy that you convey, and you may offer support or warranty
protection for a fee.
5. Conveying Modified Source Versions.
You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the terms
of section 4, provided that you also meet all of these conditions:
a) The work must carry prominent notices stating that you modified it,
and giving a relevant date.
b) The work must carry prominent notices stating that it is released
under this License and any conditions added under section 7. This
requirement modifies the requirement in section 4 to “keep intact all
notices”.
c) You must license the entire work, as a whole, under this License to
anyone who comes into possession of a copy. This License will therefore
apply, along with any applicable section 7 additional terms, to the
whole of the work, and all its parts, regardless of how they are
packaged. This License gives no permission to license the work in any
other way, but it does not invalidate such permission if you have
separately received it.
d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your work
need not make them do so.
A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work, and
which are not combined with it such as to form a larger program, in or on
a volume of a storage or distribution medium, is called an “aggregate” if
the compilation and its resulting copyright are not used to limit the
access or legal rights of the compilation's users beyond what the
individual works permit. Inclusion of a covered work in an aggregate does
not cause this License to apply to the other parts of the aggregate.
6. Conveying Non-Source Forms.
You may convey a covered work in object code form under the terms of
sections 4 and 5, provided that you also convey the machine-readable
Corresponding Source under the terms of this License, in one of these
ways:
a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium customarily
used for software interchange.
b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a written
offer, valid for at least three years and valid for as long as you
offer spare parts or customer support for that product model, to give
anyone who possesses the object code either (1) a copy of the
Corresponding Source for all the software in the product that is
covered by this License, on a durable physical medium customarily used
for software interchange, for a price no more than your reasonable cost
of physically performing this conveying of source, or (2) access to
copy the Corresponding Source from a network server at no charge.
c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This alternative is
allowed only occasionally and noncommercially, and only if you received
the object code with such an offer, in accord with subsection 6b.
d) Convey the object code by offering access from a designated place
(gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to copy
the object code is a network server, the Corresponding Source may be on
a different server (operated by you or a third party) that supports
equivalent copying facilities, provided you maintain clear directions
next to the object code saying where to find the Corresponding Source.
Regardless of what server hosts the Corresponding Source, you remain
obligated to ensure that it is available for as long as needed to
satisfy these requirements.
e) Convey the object code using peer-to-peer transmission, provided you
inform other peers where the object code and Corresponding Source of
the work are being offered to the general public at no charge under
subsection 6d.
A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be included
in conveying the object code work.
A “User Product” is either (1) a “consumer product”, which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, “normally used” refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.
“Installation Information” for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.
If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as part
of a transaction in which the right of possession and use of the User
Product is transferred to the recipient in perpetuity or for a fixed term
(regardless of how the transaction is characterized), the Corresponding
Source conveyed under this section must be accompanied by the
Installation Information. But this requirement does not apply if neither
you nor any third party retains the ability to install modified object
code on the User Product (for example, the work has been installed in
ROM).
The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access
to a network may be denied when the modification itself materially
and adversely affects the operation of the network or violates the
rules and protocols for communication across the network.
Corresponding Source conveyed, and Installation Information provided, in
accord with this section must be in a format that is publicly documented
(and with an implementation available to the public in source code form),
and must require no special password or key for unpacking, reading or
copying.
7. Additional Terms.
“Additional permissions” are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall be
treated as though they were included in this License, to the extent that
they are valid under applicable law. If additional permissions apply only
to part of the Program, that part may be used separately under those
permissions, but the entire Program remains governed by this License
without regard to the additional permissions. When you convey a copy of
a covered work, you may at your option remove any additional permissions
from that copy, or from any part of it. (Additional permissions may be
written to require their own removal in certain cases when you modify the
work.) You may place additional permissions on material, added by you to
a covered work, for which you have or can give appropriate copyright
permission.
Notwithstanding any other provision of this License, for material you add
to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:
a) Disclaiming warranty or limiting liability differently from the
terms of sections 15 and 16 of this License; or
b) Requiring preservation of specified reasonable legal notices or
author attributions in that material or in the Appropriate Legal
Notices displayed by works containing it; or
c) Prohibiting misrepresentation of the origin of that material, or
requiring that modified versions of such material be marked in
reasonable ways as different from the original version; or
d) Limiting the use for publicity purposes of names of licensors or
authors of the material; or
e) Declining to grant rights under trademark law for use of some trade
names, trademarks, or service marks; or
f) Requiring indemnification of licensors and authors of that material
by anyone who conveys the material (or modified versions of it) with
contractual assumptions of liability to the recipient, for any
liability that these contractual assumptions directly impose on those
licensors and authors.
All other non-permissive additional terms are considered “further
restrictions” within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further restriction,
you may remove that term. If a license document contains a further
restriction but permits relicensing or conveying under this License, you
may add to a covered work material governed by the terms of that license
document, provided that the further restriction does not survive such
relicensing or conveying.
If you add terms to a covered work in accord with this section, you must
place, in the relevant source files, a statement of the additional terms
that apply to those files, or a notice indicating where to find the
applicable terms. Additional terms, permissive or non-permissive, may be
stated in the form of a separately written license, or stated as
exceptions; the above requirements apply either way.
8. Termination.
You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or modify
it is void, and will automatically terminate your rights under this
License (including any patent licenses granted under the third paragraph
of section 11).
However, if you cease all violation of this License, then your license
from a particular copyright holder is reinstated (a) provisionally,
unless and until the copyright holder explicitly and finally terminates
your license, and (b) permanently, if the copyright holder fails to
notify you of the violation by some reasonable means prior to 60 days
after the cessation.
Moreover, your license from a particular copyright holder is reinstated
permanently if the copyright holder notifies you of the violation by some
reasonable means, this is the first time you have received notice of
violation of this License (for any work) from that copyright holder, and
you cure the violation prior to 30 days after your receipt of the notice.
Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.
9. Acceptance Not Required for Having Copies.
You are not required to accept this License in order to receive or run a
copy of the Program. Ancillary propagation of a covered work occurring
solely as a consequence of using peer-to-peer transmission to receive a
copy likewise does not require acceptance. However, nothing other than
this License grants you permission to propagate or modify any covered
work. These actions infringe copyright if you do not accept this License.
Therefore, by modifying or propagating a covered work, you indicate your
acceptance of this License to do so.
10. Automatic Licensing of Downstream Recipients.
Each time you convey a covered work, the recipient automatically receives
a license from the original licensors, to run, modify and propagate that
work, subject to this License. You are not responsible for enforcing
compliance by third parties with this License.
An “entity transaction” is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered work
results from an entity transaction, each party to that transaction who
receives a copy of the work also receives whatever licenses to the work
the party's predecessor in interest had or could give under the previous
paragraph, plus a right to possession of the Corresponding Source of the
work from the predecessor in interest, if the predecessor has it or can
get it with reasonable efforts.
You may not impose any further restrictions on the exercise of the rights
granted or affirmed under this License. For example, you may not impose a
license fee, royalty, or other charge for exercise of rights granted
under this License, and you may not initiate litigation (including a
cross-claim or counterclaim in a lawsuit) alleging that any patent claim
is infringed by making, using, selling, offering for sale, or importing
the Program or any portion of it.
11. Patents.
A “contributor” is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The work
thus licensed is called the contributor's “contributor version”.
A contributor's “essential patent claims” are all patent claims owned or
controlled by the contributor, whether already acquired or hereafter
acquired, that would be infringed by some manner, permitted by this
License, of making, using, or selling its contributor version, but do not
include claims that would be infringed only as a consequence of further
modification of the contributor version. For purposes of this definition,
“control” includes the right to grant patent sublicenses in a manner
consistent with the requirements of this License.
Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to make,
use, sell, offer for sale, import and otherwise run, modify and propagate
the contents of its contributor version.
In the following three paragraphs, a “patent license” is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To “grant” such a patent license to a party
means to make such an agreement or commitment not to enforce a patent
against the party.
If you convey a covered work, knowingly relying on a patent license, and
the Corresponding Source of the work is not available for anyone to copy,
free of charge and under the terms of this License, through a publicly
available network server or other readily accessible means, then you must
either (1) cause the Corresponding Source to be so available, or (2)
arrange to deprive yourself of the benefit of the patent license for this
particular work, or (3) arrange, in a manner consistent with the
requirements of this License, to extend the patent license to downstream
recipients. “Knowingly relying” means you have actual knowledge that, but
for the patent license, your conveying the covered work in a country, or
your recipient's use of the covered work in a country, would infringe
one or more identifiable patents in that country that you have reason
to believe are valid.
If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties receiving
the covered work authorizing them to use, propagate, modify or convey a
specific copy of the covered work, then the patent license you grant is
automatically extended to all recipients of the covered work and works
based on it.
A patent license is “discriminatory” if it does not include within the
scope of its coverage, prohibits the exercise of, or is conditioned on
the non-exercise of one or more of the rights that are specifically
granted under this License. You may not convey a covered work if you are
a party to an arrangement with a third party that is in the business of
distributing software, under which you make payment to the third party
based on the extent of your activity of conveying the work, and under
which the third party grants, to any of the parties who would receive the
covered work from you, a discriminatory patent license (a) in connection
with copies of the covered work conveyed by you (or copies made from
those copies), or (b) primarily for and in connection with specific
products or compilations that contain the covered work, unless you
entered into that arrangement, or that patent license was granted, prior
to 28 March 2007.
Nothing in this License shall be construed as excluding or limiting any
implied license or other defenses to infringement that may otherwise be
available to you under applicable patent law.
12. No Surrender of Others' Freedom.
If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot use,
propagate or convey a covered work so as to satisfy simultaneously your
obligations under this License and any other pertinent obligations, then
as a consequence you may not use, propagate or convey it at all. For
example, if you agree to terms that obligate you to collect a royalty for
further conveying from those to whom you convey the Program, the only way
you could satisfy both those terms and this License would be to refrain
entirely from conveying the Program.
13. Offering the Program as a Service.
If you make the functionality of the Program or a modified version
available to third parties as a service, you must make the Service Source
Code available via network download to everyone at no charge, under the
terms of this License. Making the functionality of the Program or
modified version available to third parties as a service includes,
without limitation, enabling third parties to interact with the
functionality of the Program or modified version remotely through a
computer network, offering a service the value of which entirely or
primarily derives from the value of the Program or modified version, or
offering a service that accomplishes for users the primary purpose of the
Program or modified version.
“Service Source Code” means the Corresponding Source for the Program or
the modified version, and the Corresponding Source for all programs that
you use to make the Program or modified version available as a service,
including, without limitation, management software, user interfaces,
application program interfaces, automation software, monitoring software,
backup software, storage software and hosting software, all such that a
user could run an instance of the service using the Service Source Code
you make available.
14. Revised Versions of this License.
MongoDB, Inc. may publish revised and/or new versions of the Server Side
Public License from time to time. Such new versions will be similar in
spirit to the present version, but may differ in detail to address new
problems or concerns.
Each version is given a distinguishing version number. If the Program
specifies that a certain numbered version of the Server Side Public
License “or any later version” applies to it, you have the option of
following the terms and conditions either of that numbered version or of
any later version published by MongoDB, Inc. If the Program does not
specify a version number of the Server Side Public License, you may
choose any version ever published by MongoDB, Inc.
If the Program specifies that a proxy can decide which future versions of
the Server Side Public License can be used, that proxy's public statement
of acceptance of a version permanently authorizes you to choose that
version for the Program.
Later license versions may give you additional or different permissions.
However, no additional obligations are imposed on any author or copyright
holder as a result of your choosing to follow a later version.
15. Disclaimer of Warranty.
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM “AS IS” WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
16. Limitation of Liability.
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING
ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF
THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO
LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU
OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER
PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE
POSSIBILITY OF SUCH DAMAGES.
17. Interpretation of Sections 15 and 16.
If the disclaimer of warranty and limitation of liability provided above
cannot be given local legal effect according to their terms, reviewing
courts shall apply local law that most closely approximates an absolute
waiver of all civil liability in connection with the Program, unless a
warranty or assumption of liability accompanies a copy of the Program in
return for a fee.
END OF TERMS AND CONDITIONS

24
webserver/README.md Normal file
View File

@@ -0,0 +1,24 @@
# Webserver (NGINX)
## Introduction
The webserver is the **trains-server**'s component responsible for serving the TRAINS webapp.
For this purpose, we use an [NGINX](https://www.nginx.com/) server.
## Configuration
In order to serve the TRAINS webapp, the following is required:
* The pre-built TRAINS webapp should be copied to the NGINX html directory (usually `/usr/share/nginx/html`)
* The default NGINX port (usually `80`) should be changed to match the **trains-server** configuration (usually `8080`)
NOTE: This configuration may vary in different systems, depending on the NGINX version and distribution used.
#### Example: Centos 7
The following commands can be used to install and run NGINX in the Centos 7 OS:
```bash
yum install nginx
cp -R /path/to/trains-webapp/build/* /var/www/html
systemctl enable nginx
systemctl start nginx
```

View File

@@ -1 +0,0 @@
api_server: "127.0.0.1:8008"

View File

@@ -1,38 +0,0 @@
{
version: 1
disable_existing_loggers: false
formatters: {
standard: {
format: "[%(asctime)s] [%(process)d] [%(levelname)s] [%(name)s] %(message)s"
}
}
handlers {
console {
formatter: standard
class: "logging.StreamHandler"
}
text_file: {
formatter: standard,
backupCount: 3
maxBytes: 10240000,
class: "logging.handlers.RotatingFileHandler",
filename: "/var/log/trains/webserver.log"
}
}
root {
handlers: [console, text_file]
level: INFO
}
loggers {
urllib3 {
handlers: [console, text_file]
level: WARN
propagate: false
}
werkzeug {
handlers: [console, text_file]
level: WARN
propagate: false
}
}
}

View File

@@ -1,15 +0,0 @@
{
http {
session_secret {
webserver: "n(Dtd6!w(QXW^mLmEKTjsWiOODd9gi@SHJMyt4UF*tiYN3Q@!T"
}
}
credentials {
# system credentials as they appear in the auth DB, used for intra-service communications
webserver {
user_key: "EYVQ385RW7Y2QQUH88CZ7DWIQ1WUHP"
user_secret: "yfc8KQo*GMXb*9p((qcYC7ByFIpF7I&4VH3BfUYXH%o9vX1ZUZQEEw1Inc)S"
}
}
}

View File

@@ -1,47 +0,0 @@
# requested token expiration in seconds (one month)
apiserver_token_expiration: 2592000
debug: false
flask {
# Uncomment next line to disable login requirement while testing (or unit-testing)
TESTING: False
# Uncomment to allow reloading of templates if the caches version differs from the latest version
TEMPLATES_AUTO_RELOAD: True
# Flask-Login session protection ('basic', 'strong' or null)
SESSION_PROTECTION: basic
SESSION_COOKIE_HTTPONLY: True
REMEMBER_COOKIE_HTTPONLY: True
SESSION_COOKIE_SECURE: False
REMEMBER_COOKIE_SECURE: False
}
listen {
ip : "0.0.0.0"
port: 8080
}
auth {
cookies {
httponly: true # allow only http to access the cookies (no JS etc)
secure: false # not using HTTPS
domain: null # Limit to localhost is not supported
}
session_auth_cookie_name: "trains_token_basic"
user_token_expiration_sec: 3600
}
docs {
# Default filename used when file not found error is reported when serving docs.
# This usually happens when the path is to a folder and not a file.
default_filename: "index.html"
}
default_company: "d1bd92a3b039400cbafc60a7a5b1e52b"
redirect_to_https: false

View File

@@ -1,20 +0,0 @@
import functools
class Factory:
default_cls = None
registered_cls = None
@classmethod
@functools.lru_cache(maxsize=None)
def get(cls, *args, **kwargs):
return cls.get_class()(*args, **kwargs)
@classmethod
@functools.lru_cache(maxsize=None)
def get_class(cls):
return cls.registered_cls or cls.default_cls
@classmethod
def register(cls, registered_cls):
cls.registered_cls = registered_cls

View File

@@ -1,11 +0,0 @@
Jinja2>=2.9.6
requests>=2.0
Werkzeug==0.12.2
Flask==0.12.2
Flask-Caching==1.4.0
Flask-Login==0.4.1
Flask-Compress==1.4.0
PyJWT==1.3.0
attrs>=18
pyhocon>=0.3.35
furl>=2.0.0

View File

@@ -1,6 +0,0 @@
from factory import Factory
from .simple import SimpleSession
class SessionFactory(Factory):
default_cls = SimpleSession

View File

@@ -1,21 +0,0 @@
import requests
from furl import furl
from requests.auth import HTTPBasicAuth
from config import config
class SimpleSession:
def __init__(self):
self.host = config["hosts.api_server"]
if not self.host.startswith("http"):
self.host = f"http://{self.host}"
self.key = config.get("secure.credentials.webserver.user_key")
self.secret = config.get("secure.credentials.webserver.user_secret")
self.auth = HTTPBasicAuth(self.key, self.secret)
def send_request(self, endpoint, json=None):
url = furl(self.host).set(path=endpoint)
return requests.get(str(url), json=json, auth=self.auth)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 32 KiB

View File

@@ -1,109 +0,0 @@
body {
background-color: #202432;
}
.jumbotron {
background-color: #384161;
color: #c3cdf0;
font-family: Heebo;
margin-bottom: 15px !important;
}
.username-input {
margin-bottom: 20px;
padding: 0 20px;
}
.jumbotron .btn {
background-color: #5a658e;
color: #ffffff;
font-size: 13px;
font-weight: bold;
margin: auto;
padding: 10px 30px;
border:0px solid transparent; /* this was 1px earlier */
}
.jumbotron .btn:hover {
background-color: #c3cdf0;
color: #2c3246;
border:0px solid transparent; /* this was 1px earlier */
}
.login-buttons .btn {
text-align: left !important;
margin-bottom: 10px;
}
.login-buttons .btn i {
margin-right: 10px;
}
.mobile-warn {
font-size: 40px;
display: none;
}
.mobile-warn.show {
display: block;
}
.mobile-warn > strong {
margin-right: 5px;
}
.logo {
width: 60%;
margin-bottom: 30px;
}
.or-container {
position: relative;
margin-top: 50px;
background-color: #384161;
padding-bottom: 20px;
}
.or-container hr{
border-top: 1px solid rgba(255,255,255,0.3);
}
.or {
position: absolute;
left: 0;
right: 0;
top: -10px;
margin: auto;
width: 30px;
font-size: 15px;
font-weight: bold;
text-align: center;
color: #8693be;
background-color: inherit;
}
.github-link {
display: flex;
justify-content: center;
align-items: center;
color: #ffffff;
}
a {
color: #ffffff;
margin-left: 10px;
}
a:hover, a:active, a:focus {
color: #ffffff;
}
.fa.fa-2x {
font-size: 1.8em;
}
.fork-github {
opacity: 0.5;
}
/*# sourceMappingURL=index.css.map */

View File

@@ -1,7 +0,0 @@
{
"version": 3,
"mappings": "AACE,gBAAE;EACA,UAAU,EAAE,eAAe;EAE3B,kBAAE;IACA,YAAY,EAAE,IAAI",
"sources": ["index.scss"],
"names": [],
"file": "index.css"
}

View File

@@ -1,30 +0,0 @@
body {
background-color: #202432;
}
.jumbotron {
background-color: #384161;
color: #c3cdf0;
font-family: Heebo;
}
.jumbotron .btn {
background-color: #2c3246;
color: #c3cdf0;
border:0px solid transparent; /* this was 1px earlier */
}
.jumbotron .btn:hover {
background-color: #c3cdf0;
color: #2c3246;
border:0px solid transparent; /* this was 1px earlier */
}
.login-buttons .btn {
text-align: left !important;
margin-bottom: 10px;
}
.login-buttons .btn i {
margin-right: 10px;
}

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 5.6 KiB

View File

@@ -1,73 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>Trains - Magic Version Control & Experiment Manager for AI</title>
<link rel="stylesheet" href="https://maxcdn.bootstrapcdn.com/bootstrap/3.3.7/css/bootstrap.min.css">
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css">
<link rel="stylesheet" href="https://fonts.googleapis.com/css?family=Heebo" type="text/css" media="all"/>
<link rel="stylesheet" href="/static/styles/index.css">
<script charset="utf-8">
function acknowlage() {
var div = document.getElementById('mobile-warn');
div.className = div.className.slice(0, div.className.length - 5);
}
function mobilecheck() {
var check = false;
(function(a){if(/(android|bb\d+|meego).+mobile|avantgo|bada\/|blackberry|blazer|compal|elaine|fennec|hiptop|iemobile|ip(hone|od)|iris|kindle|lge |maemo|midp|mmp|mobile.+firefox|netfront|opera m(ob|in)i|palm( os)?|phone|p(ixi|re)\/|plucker|pocket|psp|series(4|6)0|symbian|treo|up\.(browser|link)|vodafone|wap|windows ce|xda|xiino/i.test(a)||/1207|6310|6590|3gso|4thp|50[1-6]i|770s|802s|a wa|abac|ac(er|oo|s\-)|ai(ko|rn)|al(av|ca|co)|amoi|an(ex|ny|yw)|aptu|ar(ch|go)|as(te|us)|attw|au(di|\-m|r |s )|avan|be(ck|ll|nq)|bi(lb|rd)|bl(ac|az)|br(e|v)w|bumb|bw\-(n|u)|c55\/|capi|ccwa|cdm\-|cell|chtm|cldc|cmd\-|co(mp|nd)|craw|da(it|ll|ng)|dbte|dc\-s|devi|dica|dmob|do(c|p)o|ds(12|\-d)|el(49|ai)|em(l2|ul)|er(ic|k0)|esl8|ez([4-7]0|os|wa|ze)|fetc|fly(\-|_)|g1 u|g560|gene|gf\-5|g\-mo|go(\.w|od)|gr(ad|un)|haie|hcit|hd\-(m|p|t)|hei\-|hi(pt|ta)|hp( i|ip)|hs\-c|ht(c(\-| |_|a|g|p|s|t)|tp)|hu(aw|tc)|i\-(20|go|ma)|i230|iac( |\-|\/)|ibro|idea|ig01|ikom|im1k|inno|ipaq|iris|ja(t|v)a|jbro|jemu|jigs|kddi|keji|kgt( |\/)|klon|kpt |kwc\-|kyo(c|k)|le(no|xi)|lg( g|\/(k|l|u)|50|54|\-[a-w])|libw|lynx|m1\-w|m3ga|m50\/|ma(te|ui|xo)|mc(01|21|ca)|m\-cr|me(rc|ri)|mi(o8|oa|ts)|mmef|mo(01|02|bi|de|do|t(\-| |o|v)|zz)|mt(50|p1|v )|mwbp|mywa|n10[0-2]|n20[2-3]|n30(0|2)|n50(0|2|5)|n7(0(0|1)|10)|ne((c|m)\-|on|tf|wf|wg|wt)|nok(6|i)|nzph|o2im|op(ti|wv)|oran|owg1|p800|pan(a|d|t)|pdxg|pg(13|\-([1-8]|c))|phil|pire|pl(ay|uc)|pn\-2|po(ck|rt|se)|prox|psio|pt\-g|qa\-a|qc(07|12|21|32|60|\-[2-7]|i\-)|qtek|r380|r600|raks|rim9|ro(ve|zo)|s55\/|sa(ge|ma|mm|ms|ny|va)|sc(01|h\-|oo|p\-)|sdk\/|se(c(\-|0|1)|47|mc|nd|ri)|sgh\-|shar|sie(\-|m)|sk\-0|sl(45|id)|sm(al|ar|b3|it|t5)|so(ft|ny)|sp(01|h\-|v\-|v )|sy(01|mb)|t2(18|50)|t6(00|10|18)|ta(gt|lk)|tcl\-|tdg\-|tel(i|m)|tim\-|t\-mo|to(pl|sh)|ts(70|m\-|m3|m5)|tx\-9|up(\.b|g1|si)|utst|v400|v750|veri|vi(rg|te)|vk(40|5[0-3]|\-v)|vm40|voda|vulc|vx(52|53|60|61|70|80|81|83|85|98)|w3c(\-| )|webc|whit|wi(g |nc|nw)|wmlb|wonu|x700|yas\-|your|zeto|zte\-/i.test(a.substr(0,4))) check = true;})(navigator.userAgent||navigator.vendor||window.opera);
return check;
};
document.addEventListener('DOMContentLoaded', function(){
var div = document.getElementById('mobile-warn');
if(mobilecheck()) {
div.className += ' show';
}
}, false);
</script>
{% block head %}
{% endblock %}
</head>
<body>
<div class="container text-center">
<br/><br/><br/>
<div id="mobile-warn" class="alert alert-warning mobile-warn" role="alert">
<button type="button" class="close" data-dismiss="alert" aria-label="Close" onclick="acknowlage()">
<span aria-hidden="true">&times;</span></button>
<strong>Notice!</strong> The Web-App is not yet optimized for mobile devices.
</div>
<div class="row">
{% set pwidth = popup_width or 4 %}
<div class="col col-md-{{ ((12 - pwidth) / 2)|int }}"></div>
<div class="col col-md-{{ pwidth }} jumbotron" style="border-radius: 0px; !important;">
<div class="row login-section">
<section class="mx-auto">
<img src="/static/trains-wht-svg.svg" alt="TRAINS" class="logo"/>
<br class="my-4 nav-divider"/>
{% block popup %}
{% endblock %}
</section>
</div>
</div>
</div>
<a class="fork-github" href="https://github.com/allegroai/trains/issues">Questions or comments?</a>
{% if get_flashed_messages() %}
<div class="row">
<div class="col col-md-2"></div>
<div class="col col-md-8">
{% block messages %}
{% endblock %}
</div>
<div class="col col-md-2"></div>
{% endif %}
</div>
</div>
</body>
</html>

View File

@@ -1,42 +0,0 @@
{% extends "base.html" %}
{% block head %}
{% endblock %}
{% block popup %}
<form role="form" class='form-horizontal' name="login_form" action="/create_user" method="post">
<div class="username-input">
<input id="username-input" name="name" type="text" class="form-control" value="" placeholder="Full name" required/>
</div>
<div >
<button class="btn btn-primary" type="submit">LOGIN<span class="new-user-login"> AS NEW USER</span></button>
</div>
</form>
<div class="btn-group-vertical btn-group-lg login-buttons" style="width: 100%">
{% for user in users %}
<a class="btn btn-default provider" style="border-radius: 0px; !important;" href="{{ url_for('login_by_id', user_id=user.id, **request.args) }}">
Login as <b>{{ user.name|capitalize }}</b>
</a>
{% endfor %}
</div>
<div class="or-container">
<hr/>
<div class="or"> OR </div>
</div>
<div class="github-link">
<i class="fa fa-2x fa-github" aria-hidden="true"></i>
<a href="https://github.com/allegroai/trains">Fork on Github</a>
</div>
{% endblock %}
{% block messages %}
<hr class="my-4 nav-divider"/>
{% for message in get_flashed_messages() %}
<div class="alert alert-danger text-left" role="alert">
<span class="glyphicon glyphicon-exclamation-sign" aria-hidden="true"></span>
<span class="sr-only">Error:</span>
{{ message }}
</div>
{% endfor %}
{% endblock %}

View File

@@ -1,7 +0,0 @@
from factory import Factory
from .simple import SimpleUser, CreateUserError, AUTH_TOKEN_COOKIE_KEY
class UserFactory(Factory):
default_cls = SimpleUser

View File

@@ -1,133 +0,0 @@
import functools
import re
from uuid import uuid4
import attr
from flask import request
from flask_login import UserMixin
from config import config
from session import SessionFactory
log = config.logger(__file__)
AUTH_TOKEN_COOKIE_KEY = config["webserver.auth.session_auth_cookie_name"]
@attr.s(auto_attribs=True)
class UserData:
id: str = None
company: str = None
name: str = None
family_name: str = None
given_name: str = None
@classmethod
def from_dict(cls, d):
return cls(**{k: v for k, v in d.items() if k in attr.fields_dict(cls)})
class CreateUserError(Exception):
pass
class UsersGetAllError(Exception):
def __init__(self, res, *args):
self.res = res
super(UsersGetAllError, self).__init__(res, *args)
class SimpleUser(UserMixin):
_cache = None
@property
def user_data(self) -> UserData:
return self._user_data
@property
def token(self):
return self._get_token()
def __init__(self, user_data: UserData):
super(SimpleUser, self).__init__()
self._user_data = user_data
def get_id(self):
return self._user_data.id
@classmethod
def get(cls, user_id):
res = SessionFactory.get().send_request(
"users.get_by_id", json={"user": user_id}
)
if not res.ok:
return None
return cls(user_data=UserData.from_dict(res.json()["data"]["user"]))
@classmethod
def get_all(cls):
res = SessionFactory.get().send_request("users.get_all")
if not res.ok:
raise UsersGetAllError(res)
return [
cls(user_data=UserData.from_dict(user))
for user in res.json()["data"]["users"]
]
@classmethod
def create_by_name(cls, name: str):
name = re.sub(r"\s+", " ", name.strip())
existing_user = next(
(user for user in cls.get_all() if user.user_data.name.lower() == name.lower()),
None,
)
if existing_user:
return existing_user
company_id = config.get("webserver.default_company")
unique_email = f"{str(uuid4()).replace('-', '')}@example.com"
given_name, _, family_name = name.partition(" ")
res = SessionFactory.get().send_request(
"auth.create_user",
json={
"email": unique_email,
"name": name,
"company": company_id,
"given_name": given_name,
"family_name": family_name,
},
)
if not res.ok:
resp = res.json()
log.error(f"Failed creating user {name} ({resp['meta']})")
raise CreateUserError(
f"Failed creating user: {res.json().get(resp['meta']['result_msg'])}"
)
return cls.get(res.json()["data"]["id"])
@property
def is_authenticated(self) -> bool:
if AUTH_TOKEN_COOKIE_KEY not in request.cookies:
return False
token = request.cookies[AUTH_TOKEN_COOKIE_KEY]
# Assume we're authenticated if we have a token
return bool(token)
@functools.lru_cache(maxsize=None)
def _get_token(self):
res = SessionFactory.get().send_request(
"auth.get_token_for_user",
json={"user": self._user_data.id, "company": self._user_data.company},
)
if not res.ok:
log.error(
f"Failed generating token for user {self._user_data.id} ({res.json()['meta']})"
)
raise ValueError(f"Failed generating token for user {self._user_data.id}")
return res.json()["data"]["token"]

View File

@@ -1,261 +0,0 @@
import json
from argparse import ArgumentParser
from operator import attrgetter
from os.path import join
from flask import (
Flask,
render_template,
request,
redirect,
url_for,
send_from_directory,
flash,
make_response,
send_file,
session,
)
from flask_compress import Compress
from flask_login import (
LoginManager,
login_required,
logout_user,
login_user,
current_user,
)
from config import config
from user import (
UserFactory,
AUTH_TOKEN_COOKIE_KEY,
CreateUserError,
)
from user.simple import UsersGetAllError
log = config.logger(__file__)
log.info("################ Web Server initializing #####################")
app = Flask(__name__)
app.config.from_mapping(config.get("webserver.flask"))
app.config.update(
SECRET_KEY=config["secure.http.session_secret.webserver"],
COMPRESS_MIMETYPES=[
"text/html",
"text/css",
"text/xml",
"application/json",
"application/javascript",
"text/javascript",
],
COMPRESS_LEVEL=9,
)
login_manager = LoginManager(app)
login_manager.login_view = "login"
Compress(app)
def _secure_url_for(endpoint, external=True, **values):
if not config.get("webserver.redirect_to_https", False):
return url_for(endpoint, _external=external, **values)
return url_for(endpoint, _external=external, _scheme="https", **values)
@login_manager.user_loader
def load_user(id):
return UserFactory.get_class().get(id)
@login_manager.unauthorized_handler
def handle_needs_login():
session.pop("_flashes", None)
flash("You have to be logged in to access this page.")
return redirect(
_secure_url_for(".login", next=request.endpoint, **request.view_args)
)
@app.route("/create_user", methods=["GET", "POST"])
def create_user():
data = request.args or request.form
name = data["name"].strip()
try:
user = UserFactory.get_class().create_by_name(name)
except CreateUserError as e:
session.pop("_flashes", None)
message = e.args[0]
if "value combination" in message.lower():
message = f"Failed creating user {name}"
flash(message)
return redirect(_secure_url_for(".login"))
return _complete_user_login(user)
@app.route("/login/<user_id>")
def login_by_id(user_id):
if current_user.get_id() == user_id and current_user.is_authenticated:
return redirect(_secure_url_for(".login"))
try:
user = load_user(user_id)
except Exception as e:
# Some callback issue, try again.
# For example, the user tried to reload the callback page and the provider token was already redeemed.
session.pop("_flashes", None)
if e.args:
flash(e.args[0])
else:
flash(repr(e))
return redirect(_secure_url_for(".login"))
args = dict(request.args)
endpoint = args.pop("next", ".index")
return _complete_user_login(user, endpoint, **request.args)
def _complete_user_login(user, endpoint=".index", **kwargs):
login_user(user, True)
response = redirect(_secure_url_for(endpoint, **kwargs))
set_response_cookie(response, user)
return response
def set_response_cookie(response, user, copy_request=None):
if copy_request and AUTH_TOKEN_COOKIE_KEY in copy_request.cookies:
token = request.cookies[AUTH_TOKEN_COOKIE_KEY]
else:
token = user.token
response.set_cookie(
AUTH_TOKEN_COOKIE_KEY,
token,
**config.get("webserver.auth.cookies"),
)
@app.route("/logout")
def logout():
logout_user()
return redirect(_secure_url_for(".login"))
@app.route("/static/network/")
@login_required
def send_netron():
return send_from_directory("static", join("network", "index.html"))
@app.route("/static/<path:path>")
@login_required
def send_static(path):
return send_from_directory("static", path)
@app.route("/webapp_conf.js")
def webapp_conf():
webapp_conf = _get_webapp_conf()
response = make_response(f"SM_CONFIG={json.dumps(webapp_conf)}")
response.content_type = "text/javascript"
return response
def _get_webapp_conf():
webapp_conf = config.get("webapp")
webapp_build = None
webapp_path = "pages/webapp"
try:
with open(f"{webapp_path}/assets/build.json") as data_file:
webapp_build = json.load(data_file)
webapp_build["branch"] = webapp_build["branch"].replace("@", "")
except Exception:
webapp_build = {}
backend_path = "."
try:
with open(f"{backend_path}/static/build.json") as data_file:
backend_build = json.load(data_file)
except Exception:
backend_build = {}
webapp_conf.put("backend_build", backend_build)
webapp_conf.put("webapp_build", webapp_build)
webapp_conf.put("env", config.env)
return webapp_conf
@app.route("/login")
def login():
try:
users = sorted((user.user_data for user in UserFactory.get_class().get_all()), key=attrgetter("name"))
except UsersGetAllError as ex:
log.warning("error when getting users: %r", ex)
try:
result_msg = ex.res.json()["meta"]["result_msg"]
except Exception as ex:
log.exception("error when getting users: error when parsing error data")
flash(f"Unknown error: {ex!r} (check logs for more info)")
else:
flash(result_msg)
return make_response(render_template("login.html", users=[]))
response = make_response(render_template("login.html", users=users))
# make sure to clear out session token cookie
response.set_cookie(
AUTH_TOKEN_COOKIE_KEY, "", expires=0, **config.get("webserver.auth.cookies")
)
return response
def _serve_webapp(path=None):
if not path:
response = make_response(send_file("pages/webapp/index.html"))
else:
try:
response = make_response(send_from_directory("pages/webapp", path))
except Exception:
response = make_response(send_file("pages/webapp/index.html"))
set_response_cookie(response, current_user, request)
response.headers["X-Trains-Environment"] = "oss"
return response
@app.route("/favicon.ico")
def favicon():
return send_from_directory("static", "favicon.ico")
@app.route("/")
def index():
if not current_user.is_authenticated:
return redirect(_secure_url_for(".login"))
return _serve_webapp()
@app.route("/<path:path>")
@login_required
def webapp(path=None):
return _serve_webapp(path)
def parse_args():
parser = ArgumentParser(description=__doc__)
parser.add_argument(
"--port", "-p", type=int, default=config.get("webserver.listen.port")
)
parser.add_argument("--ip", "-i", type=str, default=config.get("webserver.listen.ip"))
parser.add_argument(
"--debug", action="store_true", default=config.get("webserver.debug")
)
return parser.parse_args()
def main():
args = parse_args()
app.run(debug=args.debug, host=args.ip, port=args.port, threaded=True)
if __name__ == "__main__":
main()