Compare commits

42 Commits

Author SHA1 Message Date
allegroai
2c3f0e4ba3 Update AWS images 0.12.1 2019-12-14 23:46:21 +02:00
allegroai
c48eb34d8d Add resource monitoring 2019-12-14 23:35:42 +02:00
allegroai
49515e06e1 Optimize thread processing 2019-12-14 23:35:18 +02:00
allegroai
4a1d97c02f typo 2019-12-14 23:34:00 +02:00
allegroai
6c6c1c3f41 Add server resource monitoring 2019-12-14 23:33:36 +02:00
allegroai
0ad687008c Improve server update checks 2019-12-14 23:33:04 +02:00
Allegro AI
fe3dbc92dc Update README.md 2019-11-19 00:14:45 +02:00
Allegro AI
dc53970ff0 Update README.md 2019-11-19 00:01:12 +02:00
Allegro AI
73592b991b Update README.md 2019-11-16 00:10:19 +02:00
Allegro AI
47b981a993 Update README.md 2019-11-16 00:08:36 +02:00
Allegro AI
b500bcab0b Update faq.md 2019-11-16 00:07:30 +02:00
allegroai
59e910db1a Add docker-compose Windows support 2019-11-16 00:04:04 +02:00
allegroai
2ecb430f02 Documentation 2019-11-10 00:23:45 +02:00
Allegro AI
a08722e394 Update README.md 2019-11-10 00:18:16 +02:00
Allegro AI
67c210d9d7 Update README.md 2019-11-10 00:14:30 +02:00
Allegro AI
101ba540f4 Update README.md 2019-11-10 00:08:52 +02:00
Allegro AI
82fc28d477 Update README.md 2019-11-10 00:06:12 +02:00
Allegro AI
7b73f699d2 Update README.md 2019-11-10 00:05:21 +02:00
allegroai
a7e5380f67 Add configuration example, experiments watchdog 2019-11-10 00:03:57 +02:00
allegroai
bcade31786 Add configuration example, limit user login 2019-11-09 23:59:08 +02:00
Allegro AI
6b902f85f4 Update README.md 2019-11-09 23:54:59 +02:00
allegroai
6d4c974045 Documentation 2019-11-09 23:45:12 +02:00
allegroai
2346c6f3f5 Documentation 2019-11-09 23:19:21 +02:00
Allegro AI
82e51b4d36 Update README.md 2019-11-09 23:07:43 +02:00
allegroai
e63599254e Documentation 2019-11-09 21:32:30 +02:00
allegroai
8e7e234161 Add finer control for mongo/elastic/redis host configuration 2019-11-09 21:29:23 +02:00
allegroai
17d94b26c3 Documentation 2019-11-06 12:25:39 +02:00
allegroai
1e701becd3 Upgrade to v0.12 2019-10-29 20:43:46 +02:00
allegroai
18c8dd449d Fix jupyter support 2019-10-29 20:43:40 +02:00
allegroai
50031c4d6d Upgrade to v0.12 2019-10-29 20:37:29 +02:00
allegroai
6101dc4f11 Add check for server updates 2019-10-28 21:49:16 +02:00
allegroai
5d17059cbe Improve docker compose support 2019-10-27 00:10:08 +03:00
allegroai
b93e843143 Add schema files 2019-10-26 01:14:47 +03:00
allegroai
1a732ccd8e Add API version 2.4 with new trains-server capabilities including DevOps and scheduling 2019-10-25 15:36:58 +03:00
allegroai
2ea25e498f Removed redundant license file 2019-10-22 22:54:20 +03:00
allegroai
1b1cdb34ad Fix docker compose file, replaced deprecated 'links' with 'depends_on' 2019-10-22 18:43:27 +03:00
allegroai
e171a8b523 Update auto update AMI images 2019-10-16 23:28:59 +03:00
allegroai
539b76d362 Fix increase mongodb memory limit for large queries 2019-10-12 21:37:24 +03:00
Allegro AI
64b5e1f1f0 Update faq.md 2019-10-07 14:04:08 +03:00
Allegro AI
6a1eb9cea0 Update docker_setup.md 2019-10-07 14:03:08 +03:00
Allegro AI
24907b4eaa Update README.md 2019-10-07 14:00:51 +03:00
allegroai
efc540b837 Documentation 2019-09-25 17:52:41 +03:00
79 changed files with 5906 additions and 640 deletions

2
.gitignore vendored
View File

@@ -4,6 +4,8 @@ static/build.json
static/dashboard/node_modules
static/webapp/node_modules
static/webapp/.git
scripts/
generators/
*.pyc
__pycache__
.ropeproject

260
README.md
View File

@@ -11,7 +11,7 @@
The **trains-server** is the backend service infrastructure for [TRAINS](https://github.com/allegroai/trains).
It allows multiple users to collaborate and manage their experiments.
By default, TRAINS is set up to work with the TRAINS demo server, which is open to anyone and resets periodically.
By default, TRAINS is set up to work with the TRAINS demo server, which is open to anyone and resets periodically.
In order to host your own server, you will need to install **trains-server** and point TRAINS to it.
**trains-server** contains the following components:
@@ -23,9 +23,9 @@ In order to host your own server, you will need to install **trains-server** and
* Locally-hosted file server for storing images and models making them easily accessible using the Web-App
You can quickly setup your **trains-server** using:
- [Docker Installation](#installation)
- [Docker Installation](#installation)
- Pre-built Amazon [AWS image](#aws)
- [Kubernetes Helm](https://github.com/allegroai/trains-server-helm#trains-server-for-kubernetes-clusters-using-helm)
- [Kubernetes Helm](https://github.com/allegroai/trains-server-helm#trains-server-for-kubernetes-clusters-using-helm)
or manual [Kubernetes installation](https://github.com/allegroai/trains-server-k8s#trains-server-for-kubernetes-clusters)
@@ -36,83 +36,99 @@ You can quickly setup your **trains-server** using:
**trains-server** has two supported configurations:
- Single IP (domain) with the following open ports
- Web application on port 8080
- Web application on port 8080
- API service on port 8008
- File storage service on port 8081
- Sub-Domain configuration with default http/s ports (80 or 443)
- Web application on sub-domain: app.\*.\*
- API service on sub-domain: api.\*.\*
- File storage service on sub-domain: files.\*.\*
## Install / Upgrade - AWS <a name="aws"></a>
Use one of our pre-installed Amazon Machine Images for easy deployment in AWS.
Use one of our pre-installed Amazon Machine Images for easy deployment in AWS.
For details and instructions, see [TRAINS-server: AWS pre-installed images](docs/install_aws.md).
## Docker Installation - Linux, Mac OS X <a name="installation"></a>
## Docker Installation - Linux, macOS, and Windows <a name="installation"></a>
Use our pre-built Docker image for easy deployment in Linux and Mac OS X.
For Windows, we recommend installing our pre-built Docker image on a Linux virtual machine.
Use our pre-built Docker image for easy deployment in Linux and macOS. <br>
For [Windows](https://github.com/allegroai/trains-server/blob/master/docs/faq.md#docker_compose_win10), please see detailed docker-compose installation instructions on our [FAQ](https://github.com/allegroai/trains-server/blob/master/docs/faq.md#docker_compose_win10).<br>
Latest docker images can be found [here](https://hub.docker.com/r/allegroai/trains).
1. Setup Docker ([docker-compose Ubuntu](docs/faq.md#ubuntu), [docker-compose OS X](docs/faq.md#mac-osx), [Setup Docker Service Manually](docs/docker_setup.md#setup-docker))
1. Setup Docker (docker-compose installation details: [Ubuntu](docs/faq.md#ubuntu) / [macOS](docs/faq.md#mac-osx))
Make sure port 8080/8081/8008 are available for the `trains-server` services
<details>
<summary>Make sure ports 8080/8081/8008 are available for the TRAINS-server services:</summary>
For example, to see if port `8080` is in use:
```bash
$ sudo lsof -Pn -i4 | grep :8080 | grep LISTEN
```
</details>
Increase vm.max_map_count for `ElasticSearch` docker
```bash
echo "vm.max_map_count=262144" > /tmp/99-trains.conf
sudo mv /tmp/99-trains.conf /etc/sysctl.d/99-trains.conf
sudo sysctl -w vm.max_map_count=262144
sudo service docker restart
```
- Linux
```bash
$ echo "vm.max_map_count=262144" > /tmp/99-trains.conf
$ sudo mv /tmp/99-trains.conf /etc/sysctl.d/99-trains.conf
$ sudo sysctl -w vm.max_map_count=262144
$ sudo service docker restart
```
- macOS
```bash
$ screen ~/Library/Containers/com.docker.docker/Data/vms/0/tty
$ sysctl -w vm.max_map_count=262144
```
1. Create local directories for the databases and storage.
```bash
sudo mkdir -p /opt/trains/data/elastic
sudo mkdir -p /opt/trains/data/mongo/db
sudo mkdir -p /opt/trains/data/mongo/configdb
sudo mkdir -p /opt/trains/logs
sudo mkdir -p /opt/trains/data/fileserver
```
Linux
```bash
$ sudo chown -R 1000:1000 /opt/trains
$ sudo mkdir -p /opt/trains/data/elastic
$ sudo mkdir -p /opt/trains/data/mongo/db
$ sudo mkdir -p /opt/trains/data/mongo/configdb
$ sudo mkdir -p /opt/trains/data/redis
$ sudo mkdir -p /opt/trains/logs
$ sudo mkdir -p /opt/trains/data/fileserver
$ sudo mkdir -p /opt/trains/config
```
Mac OS X
```bash
$ sudo chown -R $(whoami):staff /opt/trains
Set folder permissions
- Linux
```bash
$ sudo chown -R 1000:1000 /opt/trains
```
- macOS
```bash
$ sudo chown -R $(whoami):staff /opt/trains
```
1. Download the `docker-compose.yml` file, either download [manually](https://raw.githubusercontent.com/allegroai/trains-server/master/docker-compose.yml) or execute:
```bash
$ curl https://raw.githubusercontent.com/allegroai/trains-server/master/docker-compose.yml -o docker-compose.yml
```
1. Clone the [trains-server](https://github.com/allegroai/trains-server) repository and change directories to the new **trains-server** directory.
```bash
$ git clone https://github.com/allegroai/trains-server.git
$ cd trains-server
```
1. Launch the Docker containers <a name="launch-docker"></a>
* Automatically with docker-compose (details: [Linux/Ubuntu](docs/faq.md#ubuntu), [OS X](docs/faq.md#mac-osx))
```bash
$ docker-compose up
```bash
$ docker-compose -f docker-compose.yml up
```
* Manually, see [Launching Docker Containers Manually](docs/docker_setup.md#launch) for instructions.
1. Your server is now running on [http://localhost:8080](http://localhost:8080) and the following ports are available:
* Web server on port `8080`
* API server on port `8008`
* File server on port `8081`
**\* If something went wrong along the way, check our FAQ: [Docker Setup](docs/docker_setup.md#setup-docker), [Ubuntu Support](docs/faq.md#ubuntu), [macOS Support](docs/faq.md#mac-osx)**
## Optional Configuration
The **trains-server** default configuration can be easily overridden using external configuration files. By default, the server will look for these files in `/opt/trains/config`.
@@ -125,93 +141,60 @@ By default anyone can login to the **trains-server** Web-App.
You can configure the **trains-server** to allow only a specific set of users to access the system.
Enable this feature by placing `apiserver.conf` file under `/opt/trains/config`.
Sample fixed user configuration file `/opt/trains/config/apiserver.conf`:
Sample `apiserver.conf` configuration file can be found [here](https://github.com/allegroai/trains-server/blob/master/docs/apiserver.conf)
auth {
# Fixed users login credetials
# No other user will be able to login
fixed_users {
enabled: true
users: [
{
username: "jane"
password: "12345678"
name: "Jane Doe"
},
{
username: "john"
password: "12345678"
name: "John Doe"
},
]
}
}
To apply the `apiserver.conf` changes, you must restart the *trains-apiserver* (docker) (see [Restarting trains-server](#restart-server)).
To apply the changes, you must [restart the *trains-server*](#restart-server).
### Configuring the Non-Responsive Experiments Watchdog
The non-responsive experiment watchdog, monitors experiments that were not updated for a given period of time,
The non-responsive experiment watchdog, monitors experiments that were not updated for a given period of time,
and marks them as `aborted`. The watchdog is always active with a default of 7200 seconds (2 hours) of inactivity threshold.
To change the watchdog's timeouts, place a `services.conf` file under `/opt/trains/config`.
Sample watchdog configuration file `/opt/trains/config/services.conf`:
Sample watchdog `services.conf` configuration file can be found [here](https://github.com/allegroai/trains-server/blob/master/docs/services.conf)
tasks {
non_responsive_tasks_watchdog {
# In-progress tasks that haven't been updated for at least 'value' seconds will be stopped by the watchdog
threshold_sec: 7200
# Watchdog will sleep for this number of seconds after each cycle
watch_interval_sec: 900
}
}
To apply the `services.conf` changes, you must restart the *trains-apiserver* (docker) (see [Restarting trains-server](#restart-server)).
To apply the changes, you must [restart the *trains-server*](#restart-server).
### Restarting trains-server <a name="restart-server"></a>
To restart the **trains-server**, you must first stop and remove the containers, and then restart.
To restart the **trains-server**, you must first stop the containers, and then restart them.
```bash
$ docker-compose down
$ docker-compose -f docker-compose.yml up
```
1. Restarting docker-compose containers.
$ docker-compose down
$ docker-compose up
1. Manually restarting dockers [instructions](docs/docker_setup.md#launch).
## Configuring **TRAINS** client
Once you have installed the **trains-server**, make sure to configure **TRAINS** [client](https://github.com/allegroai/trains)
Once you have installed the **trains-server**, make sure to configure **TRAINS** [client](https://github.com/allegroai/trains)
to use your locally installed server (and not the demo server).
- Run the `trains-init` command for an interactive setup
- Run the `trains-init` command for an interactive setup
- Or manually edit `~/trains.conf` file, making sure the `api_server` value is configured correctly, for example:
api {
# API server on port 8008
api_server: "http://localhost:8008"
# web_server on port 8080
web_server: "http://localhost:8080"
# file server on port 8081
files_server: "http://localhost:8081"
}
* Notice that if you setup **trains-server** in a sub-domain configuration, there is no need to specify a port number,
* Notice that if you setup **trains-server** in a sub-domain configuration, there is no need to specify a port number,
it will be inferred from the http/s scheme.
See [Installing and Configuring TRAINS](https://github.com/allegroai/trains#configuration) for more details.
## What next?
Now that the **trains-server** is installed, and TRAINS is configured to use it,
you can [use](https://github.com/allegroai/trains#using-trains) TRAINS in your experiments and view them in the web server,
Now that the **trains-server** is installed, and TRAINS is configured to use it,
you can [use](https://github.com/allegroai/trains#using-trains) TRAINS in your experiments and view them in the web server,
for example http://localhost:8080
## Upgrading <a name="upgrade"></a>
@@ -220,56 +203,43 @@ We are constantly updating, improving and adding to the **trains-server**.
New releases will include new pre-built Docker images.
When we release a new version and include a new pre-built Docker image for it, upgrade as follows:
1. Shut down and remove each of your Docker instances using the following commands:
1. Shut down the docker containers
```bash
$ docker-compose down
```
* Using Docker-Compose
```bash
$ docker-compose down
```
1. We highly recommend backing up your data directory before upgrading.
* Manual Docker launching
```bash
$ sudo docker stop <docker-name>
$ sudo docker rm -v <docker-name>
```
The Docker names are (see [Launching Docker Containers](#launch-docker)):
* `trains-elastic`
* `trains-mongo`
* `trains-fileserver`
* `trains-apiserver`
* `trains-webserver`
Assuming your data directory is `/opt/trains`, to archive all data into `~/trains_backup.tgz` execute:
2. We highly recommend backing up your data directory!. A simple way to do that is using `tar`:
```bash
$ sudo tar czvf ~/trains_backup.tgz /opt/trains/data
```
For example, if your data directory is `/opt/trains`, use the following command:
```bash
$ sudo tar czvf ~/trains_backup.tgz /opt/trains/data
```
This backups all data to an archive in your home directory.
<details>
<summary>Restore instructions:</summary>
To restore this example backup, use the following command:
```bash
$ sudo rm -R /opt/trains/data
$ sudo tar -xzf ~/trains_backup.tgz -C /opt/trains/data
```
3. Pull the new **trains-server** docker image using the following command:
To restore this example backup, execute:
```bash
$ sudo rm -R /opt/trains/data
$ sudo tar -xzf ~/trains_backup.tgz -C /opt/trains/data
```
</details>
1. Download the latest `docker-compose.yml` file, either [manually](https://raw.githubusercontent.com/allegroai/trains-server/master/docker-compose.yml) or execute:
```bash
$ curl https://raw.githubusercontent.com/allegroai/trains-server/master/docker-compose.yml -o docker-compose.yml
```
1. Spin up the docker containers, it will automatically pull the latest trains-server build
```bash
$ docker-compose -f docker-compose.yml pull
$ docker-compose -f docker-compose.yml up
```
**\* If something went wrong along the way, check our FAQ: [Docker Upgrade](docs/docker_setup.md#common-docker-upgrade-errors)**
```bash
$ sudo docker pull allegroai/trains:latest
```
If you wish to pull a different version, replace `latest` with the required version number, for example:
```bash
$ sudo docker pull allegroai/trains:0.10.1
```
4. Launch the newly released Docker image (see [Launching Docker Containers](#launch-docker)).
## Community & Support
@@ -285,9 +255,9 @@ Additionally, you can always find us at *trains@allegro.ai*
[Server Side Public License v1.0](https://github.com/mongodb/mongo/blob/master/LICENSE-Community.txt)
**trains-server** relies on both [MongoDB](https://github.com/mongodb/mongo) and [ElasticSearch](https://github.com/elastic/elasticsearch).
With the recent changes in both MongoDB's and ElasticSearch's OSS license, we feel it is our responsibility as a
With the recent changes in both MongoDB's and ElasticSearch's OSS license, we feel it is our responsibility as a
member of the community to support the projects we love and cherish.
We believe the cause for the license change in both cases is more than just,
We believe the cause for the license change in both cases is more than just,
and chose [SSPL](https://www.mongodb.com/licensing/server-side-public-license) because it is the more general and flexible of the two licenses.
This is our way to say - we support you guys!

View File

@@ -11,20 +11,18 @@ services:
- 8008:8008
- 8080:80
- 8081:8081
restart: always
restart: unless-stopped
volumes:
- type: bind
source: /opt/trains/logs
target: /var/log/trains
- type: bind
source: /opt/trains/data/fileserver
target: /mnt/fileserver
links:
- mongo:mongo
- elasticsearch:elasticsearch
- /opt/trains/logs:/var/log/trains
- /opt/trains/data/fileserver:/mnt/fileserver
depends_on:
- redis
- mongo
- elasticsearch
environment:
ELASTIC_SERVICE_SERVICE_HOST: elasticsearch
MONGODB_SERVICE_SERVICE_HOST: mongo
ELASTIC_SERVICE_HOST: elasticsearch
MONGODB_SERVICE_HOST: mongo
REDIS_SERVICE_HOST: redis
networks:
- backend
elasticsearch:
@@ -52,12 +50,13 @@ services:
memlock:
soft: -1
hard: -1
nofile:
soft: 65536
hard: 65536
image: docker.elastic.co/elasticsearch/elasticsearch:5.6.16
restart: always
restart: unless-stopped
volumes:
- type: bind
source: /opt/trains/data/elastic
target: /usr/share/elasticsearch/data
- /opt/trains/data/elastic:/usr/share/elasticsearch/data
ports:
- "9200:9200"
mongo:
@@ -65,16 +64,23 @@ services:
- backend
container_name: trains-mongo
image: mongo:3.6.5
restart: always
restart: unless-stopped
command: --setParameter internalQueryExecMaxBlockingSortBytes=196100200
volumes:
- type: bind
source: /opt/trains/data/mongo/db
target: /data/db
- type: bind
source: /opt/trains/data/mongo/configdb
target: /data/configdb
- /opt/trains/data/mongo/db:/data/db
- /opt/trains/data/mongo/configdb:/data/configdb
ports:
- "27017:27017"
redis:
networks:
- backend
container_name: trains-redis
image: redis:5.0
restart: unless-stopped
volumes:
- /opt/trains/data/redis:/data
ports:
- "6379:6379"
networks:
backend:

117
docker-compose-win10.yml Normal file
View File

@@ -0,0 +1,117 @@
version: "3.6"
services:
apiserver:
command:
- apiserver
container_name: trains-apiserver
image: allegroai/trains:latest
restart: unless-stopped
volumes:
- c:/opt/trains/logs:/var/log/trains
- c:/opt/trains/config:/opt/trains/config
depends_on:
- redis
- mongo
- elasticsearch
- fileserver
environment:
ELASTIC_SERVICE_HOST: elasticsearch
MONGODB_SERVICE_HOST: mongo
REDIS_SERVICE_HOST: redis
ports:
- "8008:8008"
networks:
- backend
elasticsearch:
networks:
- backend
container_name: trains-elastic
environment:
ES_JAVA_OPTS: -Xms2g -Xmx2g
bootstrap.memory_lock: "true"
cluster.name: trains
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
discovery.zen.minimum_master_nodes: "1"
http.compression_level: "7"
node.ingest: "true"
node.name: trains
reindex.remote.whitelist: '*.*'
script.inline: "true"
script.painless.regex.enabled: "true"
script.update: "true"
thread_pool.bulk.queue_size: "2000"
thread_pool.search.queue_size: "10000"
xpack.monitoring.enabled: "false"
xpack.security.enabled: "false"
ulimits:
memlock:
soft: -1
hard: -1
nofile:
soft: 65536
hard: 65536
image: docker.elastic.co/elasticsearch/elasticsearch:5.6.16
restart: unless-stopped
volumes:
- c:/opt/trains/data/elastic:/usr/share/elasticsearch/data
ports:
- "9200:9200"
fileserver:
networks:
- backend
command:
- fileserver
container_name: trains-fileserver
image: allegroai/trains:latest
restart: unless-stopped
volumes:
- c:/opt/trains/logs:/var/log/trains
- c:/opt/trains/data/fileserver:/mnt/fileserver
ports:
- "8081:8081"
mongo:
networks:
- backend
container_name: trains-mongo
image: mongo:3.6.5
restart: unless-stopped
command: --setParameter internalQueryExecMaxBlockingSortBytes=196100200
volumes:
- mongodata:/data
ports:
- "27017:27017"
redis:
networks:
- backend
container_name: trains-redis
image: redis:5.0
restart: unless-stopped
volumes:
- c:/opt/trains/data/redis:/data
ports:
- "6379:6379"
webserver:
command:
- webserver
container_name: trains-webserver
image: allegroai/trains:latest
restart: unless-stopped
volumes:
- c:/trains/logs:/var/log/trains
depends_on:
- apiserver
ports:
- "8080:80"
networks:
backend:
driver: bridge
volumes:
mongodata:

View File

@@ -1,29 +1,29 @@
version: "3.6"
services:
apiserver:
command:
- apiserver
container_name: trains-apiserver
image: allegroai/trains:latest
restart: always
restart: unless-stopped
volumes:
- type: bind
source: /opt/trains/logs
target: /var/log/trains
- type: bind
source: /opt/trains/config
target: /opt/trains/config
links:
- mongo:mongo
- elasticsearch:elasticsearch
- fileserver:fileserver
- /opt/trains/logs:/var/log/trains
- /opt/trains/config:/opt/trains/config
depends_on:
- redis
- mongo
- elasticsearch
- fileserver
environment:
ELASTIC_SERVICE_SERVICE_HOST: elasticsearch
MONGODB_SERVICE_SERVICE_HOST: mongo
ELASTIC_SERVICE_HOST: elasticsearch
MONGODB_SERVICE_HOST: mongo
REDIS_SERVICE_HOST: redis
ports:
- "8008:8008"
networks:
- backend
elasticsearch:
networks:
- backend
@@ -49,14 +49,16 @@ services:
memlock:
soft: -1
hard: -1
nofile:
soft: 65536
hard: 65536
image: docker.elastic.co/elasticsearch/elasticsearch:5.6.16
restart: always
restart: unless-stopped
volumes:
- type: bind
source: /opt/trains/data/elastic
target: /usr/share/elasticsearch/data
- /opt/trains/data/elastic:/usr/share/elasticsearch/data
ports:
- "9200:9200"
fileserver:
networks:
- backend
@@ -64,44 +66,46 @@ services:
- fileserver
container_name: trains-fileserver
image: allegroai/trains:latest
restart: always
restart: unless-stopped
volumes:
- type: bind
source: /opt/trains/logs
target: /var/log/trains
- type: bind
source: /opt/trains/data/fileserver
target: /mnt/fileserver
- /opt/trains/logs:/var/log/trains
- /opt/trains/data/fileserver:/mnt/fileserver
ports:
- "8081:8081"
mongo:
networks:
- backend
container_name: trains-mongo
image: mongo:3.6.5
restart: always
restart: unless-stopped
command: --setParameter internalQueryExecMaxBlockingSortBytes=196100200
volumes:
- type: bind
source: /opt/trains/data/mongo/db
target: /data/db
- type: bind
source: /opt/trains/data/mongo/configdb
target: /data/configdb
- /opt/trains/data/mongo/db:/data/db
- /opt/trains/data/mongo/configdb:/data/configdb
ports:
- "27017:27017"
webserver:
redis:
networks:
- backend
container_name: trains-redis
image: redis:5.0
restart: unless-stopped
volumes:
- /opt/trains/data/redis:/data
ports:
- "6379:6379"
webserver:
command:
- webserver
container_name: trains-webserver
image: allegroai/trains:latest
restart: always
restart: unless-stopped
volumes:
- type: bind
source: /opt/trains/logs
target: /var/log/trains
links:
- /opt/trains/logs:/var/log/trains
depends_on:
- apiserver
ports:
- "8080:80"

19
docs/apiserver.conf Normal file
View File

@@ -0,0 +1,19 @@
auth {
# Fixed users login credetials
# No other user will be able to login
fixed_users {
enabled: true
users: [
{
username: "jane"
password: "12345678"
name: "Jane Doe"
},
{
username: "john"
password: "12345678"
name: "John Doe"
},
]
}
}

View File

@@ -1,6 +1,6 @@
# TRAINS-server: Using Docker Pre-Built Images
The pre-built Docker image for the **trains-server** is the quickest way to get started with your own **TRAINS** server.
The pre-built Docker image for the **trains-server** is the quickest way to get started with your own **TRAINS** server.
You can also build the entire **trains-server** architecture using the code available in the [trains-server](https://github.com/allegroai/trains-server) repository.
@@ -58,13 +58,15 @@ Create this directory, and set its owner and group to `uid` 1000. The data store
For example, if your data directory is `/opt/trains`, then use the following command:
```bash
sudo mkdir -p /opt/trains/data/elastic
sudo mkdir -p /opt/trains/data/mongo/db
sudo mkdir -p /opt/trains/data/mongo/configdb
sudo mkdir -p /opt/trains/logs
sudo mkdir -p /opt/trains/data/fileserver
sudo mkdir -p /opt/trains/data/elastic
sudo mkdir -p /opt/trains/data/mongo/db
sudo mkdir -p /opt/trains/data/mongo/configdb
sudo mkdir -p /opt/trains/data/redis
sudo mkdir -p /opt/trains/logs
sudo mkdir -p /opt/trains/data/fileserver
sudo mkdir -p /opt/trains/config
sudo chown -R 1000:1000 /opt/trains
sudo chown -R 1000:1000 /opt/trains
```
## TRAINS-server: Manually Launching Docker Containers <a name="launch"></a>
@@ -77,24 +79,88 @@ If your data directory is not `/opt/trains`, then in the five `docker run` comma
sudo docker run -d --restart="always" --name="trains-elastic" -e "bootstrap.memory_lock=true" --ulimit memlock=-1:-1 -e "ES_JAVA_OPTS=-Xms2g -Xmx2g" -e "bootstrap.memory_lock=true" -e "cluster.name=trains" -e "discovery.zen.minimum_master_nodes=1" -e "node.name=trains" -e "script.inline=true" -e "script.update=true" -e "thread_pool.bulk.queue_size=2000" -e "thread_pool.search.queue_size=10000" -e "xpack.security.enabled=false" -e "xpack.monitoring.enabled=false" -e "cluster.routing.allocation.node_initial_primaries_recoveries=500" -e "node.ingest=true" -e "http.compression_level=7" -e "reindex.remote.whitelist=*.*" -e "script.painless.regex.enabled=true" --network="host" -v /opt/trains/data/elastic:/usr/share/elasticsearch/data docker.elastic.co/elasticsearch/elasticsearch:5.6.16
1. Launch the **trains-mongo** Docker container.
1. Launch the **trains-mongo** Docker container.
sudo docker run -d --restart="always" --name="trains-mongo" -v /opt/trains/data/mongo/db:/data/db -v /opt/trains/data/mongo/configdb:/data/configdb --network="host" mongo:3.6.5
1. Launch the **trains-fileserver** Docker container.
1. Launch the **trains-redis** Docker container.
sudo docker run -d --restart="always" --name="trains-redis" -v /opt/trains/data/redis:/data --network="host" redis:5.0
1. Launch the **trains-fileserver** Docker container.
sudo docker run -d --restart="always" --name="trains-fileserver" --network="host" -v /opt/trains/logs:/var/log/trains -v /opt/trains/data/fileserver:/mnt/fileserver allegroai/trains:latest fileserver
1. Launch the **trains-apiserver** Docker container.
1. Launch the **trains-apiserver** Docker container.
sudo docker run -d --restart="always" --name="trains-apiserver" --network="host" -v /opt/trains/logs:/var/log/trains -v /opt/trains/config:/opt/trains/config allegroai/trains:latest apiserver
1. Launch the **trains-webserver** Docker container.
1. Launch the **trains-webserver** Docker container.
sudo docker run -d --restart="always" --name="trains-webserver" -p 8080:80 allegroai/trains:latest webserver
1. Your server is now running on [http://localhost:8080](http://localhost:8080) and the following ports are available:
* API server on port `8008`
* Web server on port `8080`
* File server on port `8081`
## Manually Upgrading TRAINS-server Containers <a name="upgrade"></a>
We are constantly updating, improving and adding to the **trains-server**.
New releases will include new pre-built Docker images.
When we release a new version and include a new pre-built Docker image for it, upgrade as follows:
1. Shut down and remove each of your Docker instances using the following commands:
```bash
$ sudo docker stop <docker-name>
$ sudo docker rm -v <docker-name>
```
The Docker names are (see [Launching Docker Containers](#launch-docker)):
* `trains-elastic`
* `trains-mongo`
* `trains-redis`
* `trains-fileserver`
* `trains-apiserver`
* `trains-webserver`
2. We highly recommend backing up your data directory!. A simple way to do that is using `tar`:
For example, if your data directory is `/opt/trains`, use the following command:
```bash
$ sudo tar czvf ~/trains_backup.tgz /opt/trains/data
```
This backups all data to an archive in your home directory.
To restore this example backup, use the following command:
```bash
$ sudo rm -R /opt/trains/data
$ sudo tar -xzf ~/trains_backup.tgz -C /opt/trains/data
```
3. Pull the new **trains-server** docker image using the following command:
```bash
$ sudo docker pull allegroai/trains:latest
```
If you wish to pull a different version, replace `latest` with the required version number, for example:
```bash
$ sudo docker pull allegroai/trains:0.11.0
```
4. Launch the newly released Docker image (see [Launching Docker Containers](#trains-server-manually-launching-docker-containers-)).
#### Common Docker Upgrade Errors
* In case of a docker error: "... The container name "/trains-???" is already in use by ..."
Try removing deprecated images with:
```bash
$ docker rm -f $(docker ps -a -q)
```

View File

@@ -6,12 +6,15 @@
* [Running trains-server on Mac OS X](#mac-osx)
* [Running trains-server on Windows 10](#docker_compose_win10)
* [Installing trains-server on stand alone Linux Ubuntu systems ](#ubuntu)
* [Resolving port conflicts preventing fixed users mode authentication and login](#port-conflict)
* [Configuring trains-server for sub-domains and load balancers](#sub-domains)
### Deploying trains-server on Kubernetes clusters <a name="kubernetes"></a>
**trains-server** supports Kubernetes. See [trains-server-k8s](https://github.com/allegroai/trains-server-k8s)
@@ -33,14 +36,16 @@ To install and configure **trains-server** on Mac OS X, follow the steps below.
1. Configure [Docker](https://www.elastic.co/guide/en/elasticsearch/reference/current/docker.html#docker-cli-run-prod-mode).
$ screen ~/Library/Containers/com.docker.docker/Data/vms/0/tty
sysctl -w vm.max_map_count=262144
$ sysctl -w vm.max_map_count=262144
1. Create local directories for the databases and storage.
$ sudo mkdir -p /opt/trains/data/elastic
$ sudo mkdir -p /opt/trains/data/mongo/db
$ sudo mkdir -p /opt/trains/data/mongo/configdb
$ sudo mkdir -p /opt/trains/data/redis
$ sudo mkdir -p /opt/trains/logs
$ sudo mkdir -p /opt/trains/config
$ sudo mkdir -p /opt/trains/data/fileserver
$ sudo chown -R $(whoami):staff /opt/trains
@@ -57,6 +62,43 @@ To install and configure **trains-server** on Mac OS X, follow the steps below.
Your server is now running on [http://localhost:8080](http://localhost:8080)
### Running trains-server on Windows 10 <a name="docker_compose_win10"></a>
You can run **trains-server** on Windows 10 using Docker Desktop for Windows (see the Docker [System Requirements](https://docs.docker.com/docker-for-windows/install/#system-requirements)).
To run **trains-server** on Windows 10, follow the steps below.
1. Install the Docker Desktop for Windows application by either:
* Following the [Install Docker Desktop on Windows](https://docs.docker.com/docker-for-windows/install/) instructions.
* Running the Docker installation [wizard](https://hub.docker.com/?overlay=onboarding).
1. Increase the memory allocation in Docker Desktop to `4GB`.
1. In your Windows notification area (system tray), right click the Docker icon.
1. Click *Settings*, *Advanced*, and then set the memory to at least `4096`.
1. Click *Apply*.
1. Create local directories for data and logs. Open PowerShell and execute the following commands:
mkdir c:\opt\trains\logs
mkdir c:\opt\trains\config
mkdir c:\opt\trains\data
mkdir c:\opt\trains\data\elastic
mkdir c:\opt\trains\data\redis
mkdir c:\opt\trains\data\fileserver
1. Save the **trains-server** docker-compose YAML file [docker-compose-win10.yml](https://raw.githubusercontent.com/allegroai/trains-server/master/docker-compose-win10.yml) as `c:\opt\trains\docker-compose.yml`.
1. Run `docker-compose`. In PowerShell, execute the following commands:
cd c:\opt\trains\
docker-compose up
Your server is now running on [http://localhost:8080](http://localhost:8080)
### Installing trains-server on stand alone Linux Ubuntu systems <a name="ubuntu"></a>
To install **trains-server** on a stand alone Linux Ubuntu, follow the steps belows.
@@ -80,6 +122,7 @@ To install **trains-server** on a stand alone Linux Ubuntu, follow the steps bel
$ sudo mkdir -p /opt/trains/data/mongo/db
$ sudo mkdir -p /opt/trains/data/mongo/configdb
$ sudo mkdir -p /opt/trains/logs
$ sudo mkdir -p /opt/trains/config
$ sudo mkdir -p /opt/trains/data/fileserver
$ sudo chown -R 1000:1000 /opt/trains

View File

@@ -21,6 +21,27 @@ The minimum recommended instance type is **t3a.large**
In order to upgrade **trains-server** on an existing EC2 instance based on one of these AMIs, SSH into the instance and follow the [upgrade instructions](../README.md#upgrade) for **trains-server**.
### Upgrading AMI's to v0.12
**Including the automatically updated AMI**
Version 0.12 introduced an additional REDIS docker to the trains-server setup.
AMI upgrading instructions:
1. SSH to the EC2 machine running one of the `Latest Version AMI's`
2. Execute the following bash commands
```bash
sudo bash
echo "" >> /usr/bin/start_or_update_server.sh
echo "sudo mkdir -p \${datadir}/redis" >> /usr/bin/start_or_update_server.sh
echo "sudo docker stop trains-redis || true && sudo docker rm -v trains-redis || true" >> /usr/bin/start_or_update_server.sh
echo "echo never | sudo tee -a /sys/kernel/mm/transparent_hugepage/enabled" >> /usr/bin/start_or_update_server.sh
echo "sudo sysctl vm.overcommit_memory=1" >> /usr/bin/start_or_update_server.sh
echo "sudo docker run -d --restart=always --name=trains-redis -v \${datadir}/redis:/data --network=host redis:5 redis-server" >> /usr/bin/start_or_update_server.sh
```
3. Reboot the EC2 machine
## Released versions
The following sections provide a list containing AMI Image ID per region for each released **trains-server** version.
@@ -28,40 +49,76 @@ The following sections provide a list containing AMI Image ID per region for eac
### Latest Version AMI <a name="autoupdate"></a>
**For easier upgrades: The following AMI automatically update to the latest release every reboot**
* **eu-north-1** : ami-047eb12cf0b47b2d1
* **ap-south-1** : ami-0a2facc5f027ab528
* **eu-west-3** : ami-08ef18e0e4ca1e6c6
* **eu-west-2** : ami-0a7133d9a3c800bbd
* **eu-west-1** : ami-0f1cce84bb2187729
* **ap-northeast-2** : ami-0825c4e06cc194272
* **ap-northeast-1** : ami-024db084d549289f3
* **sa-east-1** : ami-04eca8d7ab944a48c
* **ca-central-1** : ami-03b7bfbb8607c9bc4
* **ap-southeast-1** : ami-0a8667b8ba3564202
* **ap-southeast-2** : ami-0866de3db64f63e15
* **eu-central-1** : ami-04898b0923493de1b
* **us-east-2** : ami-06afbbc84f5d829da
* **us-west-1** : ami-045fe6664792a00d7
* **us-west-2** : ami-0132184364da97720
* **us-east-1** : ami-08747037c11256d44
* **eu-north-1** : ami-055909c1b9471451d
* **ap-south-1** : ami-0476123cc77226faf
* **eu-west-3** : ami-01df7d35ab63cca70
* **eu-west-2** : ami-00e8004c11fd0228e
* **eu-west-1** : ami-04293fbba6d3acad1
* **ap-northeast-2** : ami-004331f9c5eb13e94
* **ap-northeast-1** : ami-08cc80e2049b30e61
* **sa-east-1** : ami-06d814a0b6ffa3153
* **ca-central-1** : ami-069210ff757e9c1b7
* **ap-southeast-1** : ami-0d12cc70d6e9c0f39
* **ap-southeast-2** : ami-0b4615aa76c055267
* **eu-central-1** : ami-06537f431e52e4763
* **us-east-2** : ami-0c3cfbcb8e72ecfc5
* **us-west-1** : ami-0d83de031b83b6880
* **us-west-2** : ami-06968633c4f7187c4
* **us-east-1** : ami-07ff2f5f7ef99e8f6
### v0.12.1
* **eu-north-1** : ami-003118a8103286d84
* **ap-south-1** : ami-02dfe86baa48e096f
* **eu-west-3** : ami-0cc1f01267d2a780d
* **eu-west-2** : ami-0e4c8332e5ce09585
* **eu-west-1** : ami-03459a2f0b0a3b1ab
* **ap-northeast-2** : ami-08f6c2aed3a53f24c
* **ap-northeast-1** : ami-0b798eab95a7c5435
* **sa-east-1** : ami-0d3ee166c09f0d1b2
* **ca-central-1** : ami-00a758c56bd63acd5
* **ap-southeast-1** : ami-0be64d4988cd03fbb
* **ap-southeast-2** : ami-02087310d43a63f31
* **eu-central-1** : ami-097bbefeac0c74225
* **us-east-2** : ami-07eda256712b90f4d
* **us-west-1** : ami-02ef2b55cbd01c7df
* **us-west-2** : ami-037c6176ef4735360
* **us-east-1** : ami-08715c20c0e3f1c15
### v0.12.0
* **eu-north-1** : ami-03ff8ab48cd43e77e
* **ap-south-1** : ami-079c1a41ff836487c
* **eu-west-3** : ami-0121ef0398ae87ab0
* **eu-west-2** : ami-09f0f97654d8c79de
* **eu-west-1** : ami-0b7ba303f757bfcd9
* **ap-northeast-2** : ami-053f416517b5f40a6
* **ap-northeast-1** : ami-056dff06c698c2d9d
* **sa-east-1** : ami-017ab655119258639
* **ca-central-1** : ami-03bf5fa1d86ac97f6
* **ap-southeast-1** : ami-0e667958002b0360c
* **ap-southeast-2** : ami-091f1b69cb43b1933
* **eu-central-1** : ami-068ec2f0e98c26541
* **us-east-2** : ami-0524bbdc1b64ff83f
* **us-west-1** : ami-0b4facd7534e393c9
* **us-west-2** : ami-0018d5a7e58966848
* **us-east-1** : ami-08f24178fc14a84d2
### v0.11.0
* **eu-north-1** : ami-0303acd0967b3df38
* **ap-south-1** : ami-0e14dc1e886344a3e
* **eu-west-3** : ami-00de3fa500c2e7ea9
* **eu-west-2** : ami-0bd68bec0c2631535
* **eu-west-1** : ami-094b8dcc9b6f9a04c
* **ap-northeast-2** : ami-0091bb348c218d4c5
* **ap-northeast-1** : ami-0e06fbc71a9e7a74d
* **sa-east-1** : ami-0e99a346d8e585f76
* **ca-central-1** : ami-09874b823457e5874
* **ap-southeast-1** : ami-0823fd4963b3d4ff4
* **ap-southeast-2** : ami-0463d77897f1c0569
* **eu-central-1** : ami-0bb5cb2f5d444f905
* **us-east-2** : ami-0b364bf4c7dc12f67
* **us-west-1** : ami-0a97c0548d53d9f1d
* **us-west-2** : ami-06588b5bde813c28c
* **us-east-1** : ami-0a43a4b03215b0144
* **eu-north-1** : ami-0cbe338f058018c97
* **ap-south-1** : ami-06d72ff894f7a5e5d
* **eu-west-3** : ami-00f2a45d67df2d2f3
* **eu-west-2** : ami-0627ae688f4533237
* **eu-west-1** : ami-00bf924ccb0354418
* **ap-northeast-2** : ami-0800edf1d1dec1da8
* **ap-northeast-1** : ami-07b2ed9709cdc4b15
* **sa-east-1** : ami-0012c1648618b812c
* **ca-central-1** : ami-02870b965d002fc8a
* **ap-southeast-1** : ami-068ec23abf2473192
* **ap-southeast-2** : ami-06664624728b5e01a
* **eu-central-1** : ami-05f2a9304f237a6f0
* **us-east-2** : ami-0ec242e6dca2b72b9
* **us-west-1** : ami-050b6577acf246ceb
* **us-west-2** : ami-0e384b6f78bf96ebe
* **us-east-1** : ami-0a7b46f907d5d9c4a
### v0.10.1
* **eu-north-1** : ami-09937ec4d18350c32

9
docs/services.conf Normal file
View File

@@ -0,0 +1,9 @@
tasks {
non_responsive_tasks_watchdog {
# In-progress tasks that haven't been updated for at least 'value' seconds will be stopped by the watchdog
threshold_sec: 7200
# Watchdog will sleep for this number of seconds after each cycle
watch_interval_sec: 900
}
}

View File

@@ -1,7 +1,7 @@
Server Side Public License
VERSION 1, OCTOBER 16, 2018
Copyright © 2018 MongoDB, Inc.
Copyright © 2019 allegro.ai, Inc.
Everyone is permitted to copy and distribute verbatim copies of this
license document, but changing it is not allowed.

View File

@@ -1,7 +1,7 @@
Server Side Public License
VERSION 1, OCTOBER 16, 2018
Copyright © 2018 MongoDB, Inc.
Copyright © 2019 allegro.ai, Inc.
Everyone is permitted to copy and distribute verbatim copies of this
license document, but changing it is not allowed.

View File

@@ -48,7 +48,6 @@ _error_codes = {
129: ('task_publish_in_progress', 'Task publish in progress'),
130: ('task_not_found', 'task not found'),
# Models
200: ('model_error', 'general task error'),
201: ('invalid_model_id', 'invalid model id'),
@@ -70,9 +69,26 @@ _error_codes = {
403: ('project_not_found', 'project not found'),
405: ('project_has_models', 'project has associated models'),
# Queues
701: ('invalid_queue_id', 'invalid queue id'),
702: ('queue_not_empty', 'queue is not empty'),
703: ('invalid_queue_or_task_not_queued', 'invalid queue id or task not in queue'),
704: ('removed_during_reposition', 'task was removed by another party during reposition'),
705: ('failed_adding_during_reposition', 'failed adding task back to queue during reposition'),
706: ('task_already_queued', 'failed adding task to queue since task is already queued'),
707: ('no_default_queue', 'no queue is tagged as the default queue for this company'),
708: ('multiple_default_queues', 'more than one queue is tagged as the default queue for this company'),
# Database
800: ('data_validation_error', 'data validation error'),
801: ('expected_unique_data', 'value combination already exists'),
# Workers
1001: ('invalid_worker_id', 'invalid worker id'),
1002: ('worker_registration_failed', 'worker registration failed'),
1003: ('worker_registered', 'worker is already registered'),
1004: ('worker_not_registered', 'worker is not registered'),
1005: ('worker_stats_not_found', 'worker stats not found'),
},
(401, 'unauthorized'): {

View File

@@ -0,0 +1,60 @@
from jsonmodels import validators
from jsonmodels.fields import StringField, IntField, BoolField, FloatField
from jsonmodels.models import Base
from apimodels import ListField
class GetDefaultResp(Base):
id = StringField(required=True)
name = StringField(required=True)
class CreateRequest(Base):
name = StringField(required=True)
tags = ListField(items_types=[str])
system_tags = ListField(items_types=[str])
class QueueRequest(Base):
queue = StringField(required=True)
class DeleteRequest(QueueRequest):
force = BoolField(default=False)
class UpdateRequest(QueueRequest):
name = StringField()
tags = ListField(items_types=[str])
system_tags = ListField(items_types=[str])
class TaskRequest(QueueRequest):
task = StringField(required=True)
class MoveTaskRequest(TaskRequest):
count = IntField(default=1)
class MoveTaskResponse(Base):
position = IntField()
class GetMetricsRequest(Base):
queue_ids = ListField([str])
from_date = FloatField(required=True, validators=validators.Min(0))
to_date = FloatField(required=True, validators=validators.Min(0))
interval = IntField(required=True, validators=validators.Min(1))
class QueueMetrics(Base):
queue = StringField()
dates = ListField(int)
avg_waiting_times = ListField([float, int])
queue_lengths = ListField(int)
class GetMetricsResponse(Base):
queues = ListField(QueueMetrics)

View File

@@ -0,0 +1,14 @@
from jsonmodels.fields import BoolField, DateTimeField, StringField
from jsonmodels.models import Base
class ReportStatsOptionRequest(Base):
enabled = BoolField(default=None, nullable=True)
class ReportStatsOptionResponse(Base):
supported = BoolField(default=True)
enabled = BoolField()
enabled_time = DateTimeField(nullable=True)
enabled_version = StringField(nullable=True)
enabled_user = StringField(nullable=True)

View File

@@ -13,8 +13,17 @@ class StartedResponse(UpdateResponse):
started = IntField()
class EnqueueResponse(UpdateResponse):
queued = IntField()
class DequeueResponse(UpdateResponse):
dequeued = IntField()
class ResetResponse(UpdateResponse):
deleted_indices = ListField(items_types=six.string_types)
dequeued = DictField()
frames = DictField()
events = DictField()
model_deleted = IntField()
@@ -30,6 +39,10 @@ class UpdateRequest(TaskRequest):
force = BoolField(default=False)
class EnqueueRequest(UpdateRequest):
queue = StringField()
class DeleteRequest(UpdateRequest):
move_to_trash = BoolField(default=True)
@@ -58,4 +71,4 @@ class CreateRequest(TaskData):
class PingRequest(TaskRequest):
task = StringField(required=True)
pass

183
server/apimodels/workers.py Normal file
View File

@@ -0,0 +1,183 @@
import json
from enum import Enum
import six
from jsonmodels import validators
from jsonmodels.fields import (
StringField,
EmbeddedField,
DateTimeField,
IntField,
FloatField,
BoolField,
)
from jsonmodels.models import Base
from apimodels import make_default, ListField, EnumField
DEFAULT_TIMEOUT = 10 * 60
class WorkerRequest(Base):
worker = StringField(required=True)
class RegisterRequest(WorkerRequest):
timeout = make_default(
IntField, DEFAULT_TIMEOUT
)() # registration timeout in seconds (default is 10min)
queues = ListField(six.string_types) # list of queues this worker listens to
class MachineStats(Base):
cpu_usage = ListField(six.integer_types + (float,))
cpu_temperature = ListField(six.integer_types + (float,))
gpu_usage = ListField(six.integer_types + (float,))
gpu_temperature = ListField(six.integer_types + (float,))
gpu_memory_free = ListField(six.integer_types + (float,))
gpu_memory_used = ListField(six.integer_types + (float,))
memory_used = FloatField()
memory_free = FloatField()
network_tx = FloatField()
network_rx = FloatField()
disk_free_home = FloatField()
disk_free_temp = FloatField()
disk_read = FloatField()
disk_write = FloatField()
class StatusReportRequest(WorkerRequest):
task = StringField() # task the worker is running on
queue = StringField() # queue from which task was taken
queues = ListField(
str
) # list of queues this worker listens to. if None, this will not update the worker's queues list.
timestamp = IntField(required=True)
machine_stats = EmbeddedField(MachineStats)
class IdNameEntry(Base):
id = StringField(required=True)
name = StringField()
class WorkerEntry(Base):
key = StringField() # not required due to migration issues
id = StringField(required=True)
user = EmbeddedField(IdNameEntry)
company = EmbeddedField(IdNameEntry)
ip = StringField()
task = EmbeddedField(IdNameEntry)
queue = StringField() # queue from which current task was taken
queues = ListField(str) # list of queues this worker listens to
register_time = DateTimeField(required=True)
register_timeout = IntField(required=True)
last_activity_time = DateTimeField(required=True)
last_report_time = DateTimeField()
def to_json(self):
return json.dumps(self.to_struct())
@classmethod
def from_json(cls, s):
return cls(**json.loads(s))
class CurrentTaskEntry(IdNameEntry):
running_time = IntField()
last_iteration = IntField()
class QueueEntry(IdNameEntry):
next_task = EmbeddedField(IdNameEntry)
num_tasks = IntField()
class WorkerResponseEntry(WorkerEntry):
task = EmbeddedField(CurrentTaskEntry)
queue = EmbeddedField(QueueEntry)
queues = ListField(QueueEntry)
class GetAllRequest(Base):
last_seen = IntField(default=3600)
class GetAllResponse(Base):
workers = ListField(WorkerResponseEntry)
class StatsBase(Base):
worker_ids = ListField(str)
class StatsReportBase(StatsBase):
from_date = FloatField(required=True, validators=validators.Min(0))
to_date = FloatField(required=True, validators=validators.Min(0))
interval = IntField(required=True, validators=validators.Min(1))
class AggregationType(Enum):
avg = "avg"
min = "min"
max = "max"
class StatItem(Base):
key = StringField(required=True)
aggregation = EnumField(AggregationType, default=AggregationType.avg)
class GetStatsRequest(StatsReportBase):
items = ListField(
StatItem, required=True, validators=validators.Length(minimum_value=1)
)
split_by_variant = BoolField(default=False)
class AggregationStats(Base):
aggregation = EnumField(AggregationType)
values = ListField(float)
class MetricStats(Base):
metric = StringField()
variant = StringField()
dates = ListField(int)
stats = ListField(AggregationStats)
class WorkerStatistics(Base):
worker = StringField()
metrics = ListField(MetricStats)
class GetStatsResponse(Base):
workers = ListField(WorkerStatistics)
class GetMetricKeysRequest(StatsBase):
pass
class MetricCategory(Base):
name = StringField()
metric_keys = ListField(str)
class GetMetricKeysResponse(Base):
categories = ListField(MetricCategory)
class GetActivityReportRequest(StatsReportBase):
pass
class ActivityReportSeries(Base):
dates = ListField(int)
counts = ListField(int)
class GetActivityReportResponse(Base):
total = EmbeddedField(ActivityReportSeries)
active = EmbeddedField(ActivityReportSeries)

View File

@@ -8,6 +8,7 @@ from typing import Sequence
import attr
import six
from elasticsearch import helpers
from mongoengine import Q
from nested_dict import nested_dict
import database.utils as dbutils
@@ -16,7 +17,7 @@ from apierrors import errors
from bll.event.event_metrics import EventMetrics
from bll.task import TaskBLL
from database.errors import translate_errors_context
from database.model.task.task import Task
from database.model.task.task import Task, TaskStatus
from timing_context import TimingContext
from utilities.dicts import flatten_nested_items
@@ -33,6 +34,9 @@ class EventType(Enum):
EVENT_TYPES = set(map(attrgetter("value"), EventType))
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
@attr.s
class TaskEventsResult(object):
events = attr.ib(type=list, default=attr.Factory(list))
@@ -51,7 +55,7 @@ class EventBLL(object):
def metrics(self) -> EventMetrics:
return self._metrics
def add_events(self, company_id, events, worker):
def add_events(self, company_id, events, worker, allow_locked_tasks=False):
actions = []
task_ids = set()
task_iteration = defaultdict(lambda: 0)
@@ -132,11 +136,16 @@ class EventBLL(object):
if task_ids:
# verify task_ids
with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
res = Task.objects(id__in=task_ids, company=company_id).only("id")
extra_msg = None
query = Q(id__in=task_ids, company=company_id)
if not allow_locked_tasks:
query &= Q(status__nin=LOCKED_TASK_STATUSES)
extra_msg = "or task published"
res = Task.objects(query).only("id")
if len(res) < len(task_ids):
invalid_task_ids = tuple(set(task_ids) - set(r.id for r in res))
raise errors.bad_request.InvalidTaskId(
company=company_id, ids=invalid_task_ids
extra_msg, company=company_id, ids=invalid_task_ids
)
errors_in_bulk = []
@@ -169,7 +178,7 @@ class EventBLL(object):
company_id=company_id,
task_id=task_id,
now=now,
iter=task_iteration.get(task_id),
iter_max=task_iteration.get(task_id),
last_events=task_last_events.get(task_id),
)
@@ -207,7 +216,7 @@ class EventBLL(object):
if timestamp is None or timestamp < event["timestamp"]:
last_events[metric_hash][variant_hash] = event
def _update_task(self, company_id, task_id, now, iter=None, last_events=None):
def _update_task(self, company_id, task_id, now, iter_max=None, last_events=None):
"""
Update task information in DB with aggregated results after handling event(s) related to this task.
@@ -217,8 +226,8 @@ class EventBLL(object):
"""
fields = {}
if iter is not None:
fields["last_iteration"] = iter
if iter_max is not None:
fields["last_iteration_max"] = iter_max
if last_events:
fields["last_values"] = list(
@@ -650,7 +659,19 @@ class EventBLL(object):
return [b["key"] for b in es_res["aggregations"]["iters"]["buckets"]]
def delete_task_events(self, company_id, task_id):
def delete_task_events(self, company_id, task_id, allow_locked=False):
with translate_errors_context():
extra_msg = None
query = Q(id=task_id, company=company_id)
if not allow_locked:
query &= Q(status__nin=LOCKED_TASK_STATUSES)
extra_msg = "or task published"
res = Task.objects(query).only("id").first()
if not res:
raise errors.bad_request.InvalidTaskId(
extra_msg, company=company_id, id=task_id
)
es_index = EventMetrics.get_index_name(company_id, "*")
es_req = {"query": {"term": {"task": task_id}}}
with translate_errors_context(), TimingContext("es", "delete_task_events"):

View File

@@ -0,0 +1 @@
from .builder import Builder

View File

@@ -0,0 +1,36 @@
from typing import Optional, Sequence, Iterable, Union
from config import config
log = config.logger(__file__)
RANGE_IGNORE_VALUE = -1
class Builder:
@staticmethod
def dates_range(from_date: Union[int, float], to_date: Union[int, float]) -> dict:
return {
"range": {
"timestamp": {
"gte": int(from_date),
"lte": int(to_date),
"format": "epoch_second",
}
}
}
@staticmethod
def terms(field: str, values: Iterable[str]) -> dict:
return {"terms": {field: list(values)}}
@staticmethod
def normalize_range(
range_: Sequence[Union[int, float]],
ignore_value: Union[int, float] = RANGE_IGNORE_VALUE,
) -> Optional[Sequence[Union[int, float]]]:
if not range_ or set(range_) == {ignore_value}:
return None
if len(range_) < 2:
return [range_[0]] * 2
return range_

View File

@@ -0,0 +1 @@
from .queue_bll import QueueBLL

View File

@@ -0,0 +1,264 @@
from collections import defaultdict
from datetime import datetime
from typing import Callable, Sequence, Optional, Tuple
from elasticsearch import Elasticsearch
import database
import es_factory
from apierrors import errors
from bll.queue.queue_metrics import QueueMetrics
from bll.workers import WorkerBLL
from database.errors import translate_errors_context
from database.model.queue import Queue, Entry
class QueueBLL(object):
def __init__(self, worker_bll: WorkerBLL = None, es: Elasticsearch = None):
self.worker_bll = worker_bll or WorkerBLL()
self.es = es or es_factory.connect("workers")
self._metrics = QueueMetrics(self.es)
@property
def metrics(self) -> QueueMetrics:
return self._metrics
@staticmethod
def create(
company_id: str,
name: str,
tags: Optional[Sequence[str]] = None,
system_tags: Optional[Sequence[str]] = None,
) -> Queue:
"""Creates a queue"""
with translate_errors_context():
now = datetime.utcnow()
queue = Queue(
id=database.utils.id(),
company=company_id,
created=now,
name=name,
tags=tags or [],
system_tags=system_tags or [],
last_update=now,
)
queue.save()
return queue
def get_by_id(
self, company_id: str, queue_id: str, only: Optional[Sequence[str]] = None
) -> Queue:
"""
Get queue by id
:raise errors.bad_request.InvalidQueueId: if the queue is not found
"""
with translate_errors_context():
query = dict(id=queue_id, company=company_id)
qs = Queue.objects(**query)
if only:
qs = qs.only(*only)
queue = qs.first()
if not queue:
raise errors.bad_request.InvalidQueueId(**query)
return queue
@classmethod
def get_queue_with_task(cls, company_id: str, queue_id: str, task_id: str) -> Queue:
with translate_errors_context():
query = dict(id=queue_id, company=company_id)
queue = Queue.objects(entries__task=task_id, **query).first()
if not queue:
raise errors.bad_request.InvalidQueueOrTaskNotQueued(
task=task_id, **query
)
return queue
def get_default(self, company_id: str) -> Queue:
"""
Get the default queue
:raise errors.bad_request.NoDefaultQueue: if the default queue not found
:raise errors.bad_request.MultipleDefaultQueues: if more than one default queue is found
"""
with translate_errors_context():
res = Queue.objects(company=company_id, system_tags="default").only(
"id", "name"
)
if not res:
raise errors.bad_request.NoDefaultQueue()
if len(res) > 1:
raise errors.bad_request.MultipleDefaultQueues(
queues=tuple(r.id for r in res)
)
return res.first()
def update(
self, company_id: str, queue_id: str, **update_fields
) -> Tuple[int, dict]:
"""
Partial update of the queue from update_fields
:raise errors.bad_request.InvalidQueueId: if the queue is not found
:return: number of updated objects and updated fields dictionary
"""
with translate_errors_context():
# validate the queue exists
self.get_by_id(company_id=company_id, queue_id=queue_id, only=("id",))
return Queue.safe_update(company_id, queue_id, update_fields)
def delete(self, company_id: str, queue_id: str, force: bool) -> None:
"""
Delete the queue
:raise errors.bad_request.InvalidQueueId: if the queue is not found
:raise errors.bad_request.QueueNotEmpty: if the queue is not empty and 'force' not set
"""
with translate_errors_context():
queue = self.get_by_id(company_id=company_id, queue_id=queue_id)
if queue.entries and not force:
raise errors.bad_request.QueueNotEmpty(
"use force=true to delete", id=queue_id
)
queue.delete()
def get_all(self, company_id: str, query_dict: dict) -> Sequence[dict]:
"""Get all the queues according to the query"""
with translate_errors_context():
return Queue.get_many(
company=company_id, parameters=query_dict, query_dict=query_dict
)
def get_queue_infos(self, company_id: str, query_dict: dict) -> Sequence[dict]:
"""
Get infos on all the company queues, including queue tasks and workers
"""
projection = Queue.get_extra_projection("entries.task.name")
with translate_errors_context():
res = Queue.get_many_with_join(
company=company_id,
query_dict=query_dict,
override_projection=projection,
)
queue_workers = defaultdict(list)
for worker in self.worker_bll.get_all(company_id):
for queue in worker.queues:
queue_workers[queue].append(worker)
for item in res:
item["workers"] = [
{
"name": w.id,
"ip": w.ip,
"task": w.task.to_struct() if w.task else None,
}
for w in queue_workers.get(item["id"], [])
]
return res
def add_task(self, company_id: str, queue_id: str, task_id: str) -> dict:
"""
Add the task to the queue and return the queue update results
:raise errors.bad_request.TaskAlreadyQueued: if the task is already in the queue
:raise errors.bad_request.InvalidQueueOrTaskNotQueued: if the queue update operation failed
"""
with translate_errors_context():
queue = self.get_by_id(company_id=company_id, queue_id=queue_id)
if any(e.task == task_id for e in queue.entries):
raise errors.bad_request.TaskAlreadyQueued(task=task_id)
self.metrics.log_queue_metrics_to_es(company_id=company_id, queues=[queue])
entry = Entry(added=datetime.utcnow(), task=task_id)
query = dict(id=queue_id, company=company_id)
res = Queue.objects(entries__task__ne=task_id, **query).update_one(
push__entries=entry, last_update=datetime.utcnow(), upsert=False
)
if not res:
raise errors.bad_request.InvalidQueueOrTaskNotQueued(
task=task_id, **query
)
return res
def get_next_task(self, company_id: str, queue_id: str) -> Optional[Entry]:
"""
Atomically pop and return the first task from the queue (or None)
:raise errors.bad_request.InvalidQueueId: if the queue does not exist
"""
with translate_errors_context():
query = dict(id=queue_id, company=company_id)
queue = Queue.objects(**query).modify(
pop__entries=-1, last_update=datetime.utcnow(), upsert=False
)
if not queue:
raise errors.bad_request.InvalidQueueId(**query)
self.metrics.log_queue_metrics_to_es(company_id, queues=[queue])
if not queue.entries:
return
return queue.entries[0]
def remove_task(self, company_id: str, queue_id: str, task_id: str) -> int:
"""
Removes the task from the queue and returns the number of removed items
:raise errors.bad_request.InvalidQueueOrTaskNotQueued: if the task is not found in the queue
"""
with translate_errors_context():
queue = self.get_queue_with_task(
company_id=company_id, queue_id=queue_id, task_id=task_id
)
self.metrics.log_queue_metrics_to_es(company_id, queues=[queue])
entries_to_remove = [e for e in queue.entries if e.task == task_id]
query = dict(id=queue_id, company=company_id)
res = Queue.objects(entries__task=task_id, **query).update_one(
pull_all__entries=entries_to_remove, last_update=datetime.utcnow()
)
return len(entries_to_remove) if res else 0
def reposition_task(
self,
company_id: str,
queue_id: str,
task_id: str,
pos_func: Callable[[int], int],
) -> int:
"""
Moves the task in the queue to the position calculated by pos_func
Returns the updated task position in the queue
"""
with translate_errors_context():
queue = self.get_queue_with_task(
company_id=company_id, queue_id=queue_id, task_id=task_id
)
position = next(i for i, e in enumerate(queue.entries) if e.task == task_id)
new_position = pos_func(position)
if new_position != position:
entry = queue.entries[position]
query = dict(id=queue_id, company=company_id)
updated = Queue.objects(entries__task=task_id, **query).update_one(
pull__entries=entry, last_update=datetime.utcnow()
)
if not updated:
raise errors.bad_request.RemovedDuringReposition(
task=task_id, **query
)
inst = {"$push": {"entries": {"$each": [entry.to_proper_dict()]}}}
if new_position >= 0:
inst["$push"]["entries"]["$position"] = new_position
res = Queue.objects(entries__task__ne=task_id, **query).update_one(
__raw__=inst
)
if not res:
raise errors.bad_request.FailedAddingDuringReposition(
task=task_id, **query
)
return new_position

View File

@@ -0,0 +1,265 @@
from collections import defaultdict
from datetime import datetime
from typing import Sequence
import elasticsearch.helpers
from elasticsearch import Elasticsearch
import es_factory
from apierrors.errors import bad_request
from bll.query import Builder as QueryBuilder
from config import config
from database.errors import translate_errors_context
from database.model.queue import Queue, Entry
from timing_context import TimingContext
log = config.logger(__file__)
class QueueMetrics:
class EsKeys:
DOC_TYPE = "metrics"
WAITING_TIME_FIELD = "average_waiting_time"
QUEUE_LENGTH_FIELD = "queue_length"
TIMESTAMP_FIELD = "timestamp"
QUEUE_FIELD = "queue"
def __init__(self, es: Elasticsearch):
self.es = es
@staticmethod
def _queue_metrics_prefix_for_company(company_id: str) -> str:
"""Returns the es index prefix for the company"""
return f"queue_metrics_{company_id}_"
@staticmethod
def _get_es_index_suffix():
"""Get the index name suffix for storing current month data"""
return datetime.utcnow().strftime("%Y-%m")
@staticmethod
def _calc_avg_waiting_time(entries: Sequence[Entry]) -> float:
"""
Calculate avg waiting time for the given tasks.
Return 0 if the list is empty
"""
if not entries:
return 0
now = datetime.utcnow()
total_waiting_in_secs = sum((now - e.added).total_seconds() for e in entries)
return total_waiting_in_secs / len(entries)
def log_queue_metrics_to_es(self, company_id: str, queues: Sequence[Queue]) -> bool:
"""
Calculate and write queue statistics (avg waiting time and queue length) to Elastic
:return: True if the write to es was successful, false otherwise
"""
es_index = (
self._queue_metrics_prefix_for_company(company_id)
+ self._get_es_index_suffix()
)
timestamp = es_factory.get_timestamp_millis()
def make_doc(queue: Queue) -> dict:
entries = [e for e in queue.entries if e.added]
return dict(
_index=es_index,
_type=self.EsKeys.DOC_TYPE,
_source={
self.EsKeys.TIMESTAMP_FIELD: timestamp,
self.EsKeys.QUEUE_FIELD: queue.id,
self.EsKeys.WAITING_TIME_FIELD: self._calc_avg_waiting_time(
entries
),
self.EsKeys.QUEUE_LENGTH_FIELD: len(entries),
},
)
actions = list(map(make_doc, queues))
es_res = elasticsearch.helpers.bulk(self.es, actions)
added, errors = es_res[:2]
return (added == len(actions)) and not errors
def _log_current_metrics(self, company_id: str, queue_ids=Sequence[str]):
query = dict(company=company_id)
if queue_ids:
query["id__in"] = list(queue_ids)
queues = Queue.objects(**query)
self.log_queue_metrics_to_es(company_id, queues=list(queues))
def _search_company_metrics(self, company_id: str, es_req: dict) -> dict:
return self.es.search(
index=f"{self._queue_metrics_prefix_for_company(company_id)}*",
doc_type=self.EsKeys.DOC_TYPE,
body=es_req,
)
@classmethod
def _get_dates_agg(cls, interval) -> dict:
"""
Aggregation for building date histogram with internal grouping per queue.
We are grouping by queue inside date histogram and not vice versa so that
it will be easy to average between queue metrics inside each date bucket.
Ignore empty buckets.
"""
return {
"dates": {
"date_histogram": {
"field": cls.EsKeys.TIMESTAMP_FIELD,
"interval": f"{interval}s",
"min_doc_count": 1,
},
"aggs": {
"queues": {
"terms": {"field": cls.EsKeys.QUEUE_FIELD},
"aggs": cls._get_top_waiting_agg(),
}
},
}
}
@classmethod
def _get_top_waiting_agg(cls) -> dict:
"""
Aggregation for getting max waiting time and the corresponding queue length
inside each date->queue bucket
"""
return {
"top_avg_waiting": {
"top_hits": {
"sort": [
{cls.EsKeys.WAITING_TIME_FIELD: {"order": "desc"}},
{cls.EsKeys.QUEUE_LENGTH_FIELD: {"order": "desc"}},
],
"_source": {
"includes": [
cls.EsKeys.WAITING_TIME_FIELD,
cls.EsKeys.QUEUE_LENGTH_FIELD,
]
},
"size": 1,
}
}
}
def get_queue_metrics(
self,
company_id: str,
from_date: float,
to_date: float,
interval: int,
queue_ids: Sequence[str],
) -> dict:
"""
Get the company queue metrics in the specified time range.
Returned as date histograms of average values per queue and metric type.
The from_date is extended by 'metrics_before_from_date' seconds from
queues.conf due to possibly small amount of points. The default extension is 3600s
In case no queue ids are specified the avg across all the
company queues is calculated for each metric
"""
# self._log_current_metrics(company, queue_ids=queue_ids)
if from_date >= to_date:
raise bad_request.FieldsValueError("from_date must be less than to_date")
seconds_before = config.get("services.queues.metrics_before_from_date", 3600)
must_terms = [QueryBuilder.dates_range(from_date - seconds_before, to_date)]
if queue_ids:
must_terms.append(QueryBuilder.terms("queue", queue_ids))
es_req = {
"size": 0,
"query": {"bool": {"must": must_terms}},
"aggs": self._get_dates_agg(interval),
}
with translate_errors_context(), TimingContext("es", "get_queue_metrics"):
res = self._search_company_metrics(company_id, es_req)
if "aggregations" not in res:
return {}
date_metrics = [
dict(
timestamp=d["key"],
queue_metrics=self._extract_queue_metrics(d["queues"]["buckets"]),
)
for d in res["aggregations"]["dates"]["buckets"]
if d["doc_count"] > 0
]
if queue_ids:
return self._datetime_histogram_per_queue(date_metrics)
return self._average_datetime_histogram(date_metrics)
@classmethod
def _datetime_histogram_per_queue(cls, date_metrics: Sequence[dict]) -> dict:
"""
Build datetime histogram per queue from datetime histogram where every
bucket contains all the queues metrics
"""
queues_data = defaultdict(list)
for date_data in date_metrics:
timestamp = date_data["timestamp"]
for queue, metrics in date_data["queue_metrics"].items():
queues_data[queue].append({"date": timestamp, **metrics})
return queues_data
@classmethod
def _average_datetime_histogram(cls, date_metrics: Sequence[dict]) -> dict:
"""
Calculate weighted averages and total count for each bucket of date_metrics histogram.
If for any queue the data is missing then take it from the previous bucket
The result is returned as a dictionary with one key 'total'
"""
queues_total = []
last_values = {}
for date_data in date_metrics:
date_metrics = date_data["queue_metrics"]
queue_metrics = {
**date_metrics,
**{k: v for k, v in last_values.items() if k not in date_metrics},
}
total_length = sum(m["queue_length"] for m in queue_metrics.values())
if total_length:
total_average = sum(
m["avg_waiting_time"] * m["queue_length"] / total_length
for m in queue_metrics.values()
)
else:
total_average = 0
queues_total.append(
dict(
date=date_data["timestamp"],
avg_waiting_time=total_average,
queue_length=total_length,
)
)
for k, v in date_metrics.items():
last_values[k] = v
return dict(total=queues_total)
@classmethod
def _extract_queue_metrics(cls, queue_buckets: Sequence[dict]) -> dict:
"""
Extract ES data for single date and queue bucket
"""
queue_metrics = dict()
for queue_data in queue_buckets:
if not queue_data["doc_count"]:
continue
res = queue_data["top_avg_waiting"]["hits"]["hits"][0]["_source"]
queue_metrics[queue_data["key"]] = {
"queue_length": res[cls.EsKeys.QUEUE_LENGTH_FIELD],
"avg_waiting_time": res[cls.EsKeys.WAITING_TIME_FIELD],
}
return queue_metrics

View File

@@ -0,0 +1,87 @@
from datetime import datetime
import operator
from threading import Thread, Lock
from time import sleep
import attr
import psutil
class ResourceMonitor(Thread):
@attr.s(auto_attribs=True)
class Sample:
cpu_usage: float = 0.0
mem_used_gb: float = 0
mem_free_gb: float = 0
@classmethod
def _apply(cls, op, *samples):
return cls(
**{
field: op(*(getattr(sample, field) for sample in samples))
for field in attr.fields_dict(cls)
}
)
def min(self, sample):
return self._apply(min, self, sample)
def max(self, sample):
return self._apply(max, self, sample)
def avg(self, sample, count):
res = self._apply(lambda x: x * count, self)
res = self._apply(operator.add, res, sample)
res = self._apply(lambda x: x / (count + 1), res)
return res
def __init__(self, sample_interval_sec=5):
super(ResourceMonitor, self).__init__(daemon=True)
self.sample_interval_sec = sample_interval_sec
self._lock = Lock()
self._clear()
def _clear(self):
sample = self._get_sample()
self._avg = sample
self._min = sample
self._max = sample
self._clear_time = datetime.utcnow()
self._count = 1
@classmethod
def _get_sample(cls) -> Sample:
return cls.Sample(
cpu_usage=psutil.cpu_percent(),
mem_used_gb=psutil.virtual_memory().used / (1024 ** 3),
mem_free_gb=psutil.virtual_memory().free / (1024 ** 3),
)
def run(self):
while True:
sample = self._get_sample()
with self._lock:
self._min = self._min.min(sample)
self._max = self._max.max(sample)
self._avg = self._avg.avg(sample, self._count)
self._count += 1
sleep(self.sample_interval_sec)
def get_stats(self) -> dict:
""" Returns current resource statistics and clears internal resource statistics """
with self._lock:
min_ = attr.asdict(self._min)
max_ = attr.asdict(self._max)
avg = attr.asdict(self._avg)
res = {
"interval_sec": (datetime.utcnow() - self._clear_time).total_seconds(),
"num_cores": psutil.cpu_count(),
**{
k: {"min": v, "max": max_[k], "avg": avg[k]}
for k, v in min_.items()
}
}
self._clear()
return res

View File

@@ -0,0 +1,306 @@
import logging
import queue
import random
import time
from datetime import timedelta, datetime
from time import sleep
from typing import Sequence, Optional
import dpath
import requests
from requests.adapters import HTTPAdapter
from requests.packages.urllib3.util.retry import Retry
from bll.query import Builder as QueryBuilder
from bll.util import get_server_uuid
from bll.workers import WorkerStats, WorkerBLL
from config import config
from config.info import get_deployment_type
from database.model import Company, User
from database.model.queue import Queue
from database.model.task.task import Task
from utilities import safe_get
from utilities.json import dumps
from utilities.threads_manager import ThreadsManager
from version import __version__ as current_version
from .resource_monitor import ResourceMonitor
log = config.logger(__file__)
worker_bll = WorkerBLL()
class StatisticsReporter:
threads = ThreadsManager("Statistics", resource_monitor=ResourceMonitor)
send_queue = queue.Queue()
supported = config.get("apiserver.statistics.supported", True)
@classmethod
def start(cls):
cls.start_sender()
cls.start_reporter()
@classmethod
@threads.register("reporter", daemon=True)
def start_reporter(cls):
"""
Periodically send statistics reports for companies who have opted in.
Note: in trains we usually have only a single company
"""
if not cls.supported:
return
report_interval = timedelta(
hours=config.get("apiserver.statistics.report_interval_hours", 24)
)
while True:
sleep(report_interval.total_seconds())
try:
for company in Company.objects(
defaults__stats_option__enabled=True
).only("id"):
stats = cls.get_statistics(company.id)
cls.send_queue.put(stats)
except Exception as ex:
log.exception(f"Failed collecting stats: {str(ex)}")
@classmethod
@threads.register("sender", daemon=True)
def start_sender(cls):
if not cls.supported:
return
url = config.get("apiserver.statistics.url")
retries = config.get("apiserver.statistics.max_retries", 5)
max_backoff = config.get("apiserver.statistics.max_backoff_sec", 5)
session = requests.Session()
adapter = HTTPAdapter(max_retries=Retry(retries))
session.mount("http://", adapter)
session.mount("https://", adapter)
session.headers["Content-type"] = "application/json"
WarningFilter.attach()
while True:
try:
report = cls.send_queue.get()
# Set a random backoff factor each time we send a report
adapter.max_retries.backoff_factor = random.random() * max_backoff
session.post(url, data=dumps(report))
except Exception as ex:
pass
@classmethod
def get_statistics(cls, company_id: str) -> dict:
"""
Returns a statistics report per company
"""
return {
"time": datetime.utcnow(),
"company_id": company_id,
"server": {
"version": current_version,
"deployment": get_deployment_type(),
"uuid": get_server_uuid(),
"queues": {"count": Queue.objects(company=company_id).count()},
"users": {"count": User.objects(company=company_id).count()},
"resources": cls.threads.resource_monitor.get_stats(),
"experiments": next(
iter(cls._get_experiments_stats(company_id).values()), {}
),
},
"agents": cls._get_agents_statistics(company_id),
}
@classmethod
def _get_agents_statistics(cls, company_id: str) -> Sequence[dict]:
result = cls._get_resource_stats_per_agent(company_id, key="resources")
dpath.merge(
result, cls._get_experiments_stats_per_agent(company_id, key="experiments")
)
return [{"uuid": agent_id, **data} for agent_id, data in result.items()]
@classmethod
def _get_resource_stats_per_agent(cls, company_id: str, key: str) -> dict:
agent_resource_threshold_sec = timedelta(
hours=config.get("apiserver.statistics.report_interval_hours", 24)
).total_seconds()
to_timestamp = int(time.time())
from_timestamp = to_timestamp - int(agent_resource_threshold_sec)
es_req = {
"size": 0,
"query": QueryBuilder.dates_range(from_timestamp, to_timestamp),
"aggs": {
"workers": {
"terms": {"field": "worker"},
"aggs": {
"categories": {
"terms": {"field": "category"},
"aggs": {"count": {"cardinality": {"field": "variant"}}},
},
"metrics": {
"terms": {"field": "metric"},
"aggs": {
"min": {"min": {"field": "value"}},
"max": {"max": {"field": "value"}},
"avg": {"avg": {"field": "value"}},
},
},
},
}
},
}
res = cls._run_worker_stats_query(company_id, es_req)
def _get_cardinality_fields(categories: Sequence[dict]) -> dict:
names = {"cpu": "num_cores"}
return {
names[c["key"]]: safe_get(c, "count/value")
for c in categories
if c["key"] in names
}
def _get_metric_fields(metrics: Sequence[dict]) -> dict:
names = {
"cpu_usage": "cpu_usage",
"memory_used": "mem_used_gb",
"memory_free": "mem_free_gb",
}
return {
names[m["key"]]: {
"min": safe_get(m, "min/value"),
"max": safe_get(m, "max/value"),
"avg": safe_get(m, "avg/value"),
}
for m in metrics
if m["key"] in names
}
buckets = safe_get(res, "aggregations/workers/buckets", default=[])
return {
b["key"]: {
key: {
"interval_sec": agent_resource_threshold_sec,
**_get_cardinality_fields(safe_get(b, "categories/buckets", [])),
**_get_metric_fields(safe_get(b, "metrics/buckets", [])),
}
}
for b in buckets
}
@classmethod
def _get_experiments_stats_per_agent(cls, company_id: str, key: str) -> dict:
agent_relevant_threshold = timedelta(
days=config.get("apiserver.statistics.agent_relevant_threshold_days", 30)
)
to_timestamp = int(time.time())
from_timestamp = to_timestamp - int(agent_relevant_threshold.total_seconds())
workers = cls._get_active_workers(company_id, from_timestamp, to_timestamp)
if not workers:
return {}
stats = cls._get_experiments_stats(company_id, list(workers.keys()))
return {
worker_id: {key: {**workers[worker_id], **stat}}
for worker_id, stat in stats.items()
}
@classmethod
def _get_active_workers(
cls, company_id, from_timestamp: int, to_timestamp: int
) -> dict:
es_req = {
"size": 0,
"query": QueryBuilder.dates_range(from_timestamp, to_timestamp),
"aggs": {
"workers": {
"terms": {"field": "worker"},
"aggs": {"last_activity_time": {"max": {"field": "timestamp"}}},
}
},
}
res = cls._run_worker_stats_query(company_id, es_req)
buckets = safe_get(res, "aggregations/workers/buckets", default=[])
return {
b["key"]: {"last_activity_time": b["last_activity_time"]["value"]}
for b in buckets
}
@classmethod
def _run_worker_stats_query(cls, company_id, es_req) -> dict:
return worker_bll.es_client.search(
index=f"{WorkerStats.worker_stats_prefix_for_company(company_id)}*",
doc_type="stat",
body=es_req,
)
@classmethod
def _get_experiments_stats(
cls, company_id, workers: Optional[Sequence] = None
) -> dict:
pipeline = [
{
"$match": {
"company": company_id,
"started": {"$exists": True, "$ne": None},
"last_update": {"$exists": True, "$ne": None},
"status": {"$nin": ["created", "queued"]},
**({"last_worker": {"$in": workers}} if workers else {}),
}
},
{
"$group": {
"_id": "$last_worker" if workers else None,
"count": {"$sum": 1},
"avg_run_time_sec": {
"$avg": {
"$divide": [
{"$subtract": ["$last_update", "$started"]},
1000,
]
}
},
"avg_iterations": {"$avg": "$last_iteration"},
}
},
{
"$project": {
"count": 1,
"avg_run_time_sec": {"$trunc": "$avg_run_time_sec"},
"avg_iterations": {"$trunc": "$avg_iterations"},
}
},
]
return {
group["_id"]: {k: v for k, v in group.items() if k != "_id"}
for group in Task.aggregate(*pipeline)
}
class WarningFilter(logging.Filter):
@classmethod
def attach(cls):
from urllib3.connectionpool import (
ConnectionPool,
) # required to make sure the logger is created
assert ConnectionPool # make sure import is not optimized out
logging.getLogger("urllib3.connectionpool").addFilter(cls())
def filter(self, record):
if (
record.levelno == logging.WARNING
and len(record.args) > 2
and record.args[2] == "/stats"
):
return False
return True

View File

@@ -29,7 +29,7 @@ from .utils import ChangeStatusRequest, validate_status_change
class TaskBLL(object):
threads = ThreadsManager()
threads = ThreadsManager("TaskBLL")
def __init__(self, events_es=None):
self.events_es = (
@@ -208,7 +208,7 @@ class TaskBLL(object):
]
with translate_errors_context():
result = Task.objects.aggregate(*pipeline)
result = Task.aggregate(*pipeline)
return [r["metrics"][0] for r in result]
@staticmethod
@@ -376,11 +376,27 @@ class TaskBLL(object):
task = TaskBLL.get_task_with_access(
task_id,
company_id=company_id,
only=("status", "project", "tags", "system_tags", "last_update"),
only=(
"status",
"project",
"tags",
"system_tags",
"last_worker",
"last_update",
),
requires_write_access=True,
)
if TaskSystemTags.development in task.system_tags:
def is_run_by_worker(t: Task) -> bool:
"""Checks if there is an active worker running the task"""
update_timeout = config.get("apiserver.workers.task_update_timeout", 600)
return (
t.last_worker
and t.last_update
and (datetime.utcnow() - t.last_update).total_seconds() < update_timeout
)
if TaskSystemTags.development in task.system_tags or not is_run_by_worker(task):
new_status = TaskStatus.stopped
status_message = f"Stopped by {user_name}"
else:
@@ -486,7 +502,10 @@ class TaskBLL(object):
]
with translate_errors_context():
result = next(Task.objects.aggregate(*pipeline), None)
result = next(
Task.aggregate(*pipeline),
None,
)
total = 0
remaining = 0

View File

@@ -7,7 +7,7 @@ import six
from apierrors import errors
from database.errors import translate_errors_context
from database.model.project import Project
from database.model.task.task import Task, TaskStatus
from database.model.task.task import Task, TaskStatus, TaskSystemTags
from database.utils import get_options
from timing_context import TimingContext
from utilities.attrs import typed_attrs
@@ -25,9 +25,10 @@ class ChangeStatusRequest(object):
status_message = attr.ib(type=six.string_types, default="")
force = attr.ib(type=bool, default=False)
allow_same_state_transition = attr.ib(type=bool, default=True)
current_status_override = attr.ib(default=None)
def execute(self, **kwargs):
current_status = self.task.status
current_status = self.current_status_override or self.task.status
project_id = self.task.project
# Verify new status is allowed from current status (will throw exception if not valid)
@@ -44,6 +45,9 @@ class ChangeStatusRequest(object):
last_update=now,
)
if self.new_status == TaskStatus.queued:
fields["pull__system_tags"] = TaskSystemTags.development
def safe_mongoengine_key(key):
return f"__{key}" if key in control else key
@@ -99,7 +103,8 @@ def validate_status_change(current_status, new_status):
state_machine = {
TaskStatus.created: {TaskStatus.in_progress},
TaskStatus.created: {TaskStatus.queued, TaskStatus.in_progress},
TaskStatus.queued: {TaskStatus.created, TaskStatus.in_progress},
TaskStatus.in_progress: {
TaskStatus.stopped,
TaskStatus.failed,
@@ -129,7 +134,7 @@ state_machine = {
TaskStatus.published,
TaskStatus.in_progress,
TaskStatus.created,
}
},
}

View File

@@ -1,5 +1,9 @@
import functools
from operator import itemgetter
from typing import Sequence, Optional, Callable, Tuple
from typing import Sequence, Optional, Callable, Tuple, Dict, Any, Set
from database.model import AttributedDocument
from database.model.settings import Settings
def extract_properties_to_lists(
@@ -18,3 +22,52 @@ def extract_properties_to_lists(
"""
value_sequences = zip(*map(extract_func or itemgetter(*key_names), data))
return dict(zip(key_names, map(list, value_sequences)))
class SetFieldsResolver:
"""
The class receives set fields dictionary
and for the set fields that require 'min' or 'max'
operation replace them with a simple set in case the
DB document does not have these fields set
"""
SET_MODIFIERS = ("min", "max")
def __init__(self, set_fields: Dict[str, Any]):
self.orig_fields = set_fields
self.fields = {
f: fname
for f, modifier, dunder, fname in (
(f,) + f.partition("__") for f in set_fields.keys()
)
if dunder and modifier in self.SET_MODIFIERS
}
def _get_updated_name(self, doc: AttributedDocument, name: str) -> str:
if name in self.fields and doc.get_field_value(self.fields[name]) is None:
return self.fields[name]
return name
def get_fields(self, doc: AttributedDocument):
"""
For the given document return the set fields instructions
with min/max operations replaced with a single set in case
the document does not have the field set
"""
return {
self._get_updated_name(doc, name): value
for name, value in self.orig_fields.items()
}
def get_names(self) -> Set[str]:
"""
Returns the names of the fields that had min/max modifiers
in the format suitable for projection (dot separated)
"""
return set(name.replace("__", ".") for name in self.fields.values())
@functools.lru_cache()
def get_server_uuid() -> Optional[str]:
return Settings.get_by_key("server.uuid")

View File

@@ -0,0 +1,422 @@
import itertools
from datetime import datetime, timedelta
from typing import Sequence, Set, Optional
import attr
import elasticsearch.helpers
import es_factory
from apierrors import APIError
from apierrors.errors import bad_request, server_error
from apimodels.workers import (
DEFAULT_TIMEOUT,
IdNameEntry,
WorkerEntry,
StatusReportRequest,
WorkerResponseEntry,
QueueEntry,
MachineStats,
)
from config import config
from database.errors import translate_errors_context
from database.model.auth import User
from database.model.company import Company
from database.model.queue import Queue
from database.model.task.task import Task
from redis_manager import redman
from timing_context import TimingContext
from tools import safe_get
from .stats import WorkerStats
log = config.logger(__file__)
class WorkerBLL:
def __init__(self, es=None, redis=None):
self.es_client = es if es is not None else es_factory.connect("workers")
self.redis = redis if redis is not None else redman.connection("workers")
self._stats = WorkerStats(self.es_client)
@property
def stats(self) -> WorkerStats:
return self._stats
def register_worker(
self,
company_id: str,
user_id: str,
worker: str,
ip: str = "",
queues: Sequence[str] = None,
timeout: int = 0,
) -> WorkerEntry:
"""
Register a worker
:param company_id: worker's company ID
:param user_id: user ID under which this worker is running
:param worker: worker ID
:param ip: the real ip of the worker
:param queues: queues reported as being monitored by the worker
:param timeout: registration expiration timeout in seconds
:raise bad_request.InvalidUserId: in case the calling user or company does not exist
:return: worker entry instance
"""
key = WorkerBLL._get_worker_key(company_id, user_id, worker)
timeout = timeout or DEFAULT_TIMEOUT
queues = queues or []
with translate_errors_context():
query = dict(id=user_id, company=company_id)
user = User.objects(**query).only("id", "name").first()
if not user:
raise bad_request.InvalidUserId(**query)
company = Company.objects(id=company_id).only("id", "name").first()
if not company:
raise server_error.InternalError("invalid company", company=company_id)
queue_objs = Queue.objects(company=company_id, id__in=queues).only("id")
if len(queue_objs) < len(queues):
invalid = set(queues).difference(q.id for q in queue_objs)
raise bad_request.InvalidQueueId(ids=invalid)
now = datetime.utcnow()
entry = WorkerEntry(
key=key,
id=worker,
user=user.to_proper_dict(),
company=company.to_proper_dict(),
ip=ip,
queues=queues,
register_time=now,
register_timeout=timeout,
last_activity_time=now,
)
self.redis.setex(key, timedelta(seconds=timeout), entry.to_json())
return entry
def unregister_worker(self, company_id: str, user_id: str, worker: str) -> None:
"""
Unregister a worker
:param company_id: worker's company ID
:param user_id: user ID under which this worker is running
:param worker: worker ID
:raise bad_request.WorkerNotRegistered: the worker was not previously registered
"""
with TimingContext("redis", "workers_unregister"):
res = self.redis.delete(
company_id, self._get_worker_key(company_id, user_id, worker)
)
if not res:
raise bad_request.WorkerNotRegistered(worker=worker)
def status_report(
self, company_id: str, user_id: str, ip: str, report: StatusReportRequest
) -> None:
"""
Write worker status report
:param company_id: worker's company ID
:param user_id: user_id ID under which this worker is running
:raise bad_request.InvalidTaskId: the reported task was not found
:return: worker entry instance
"""
entry = self._get_worker(company_id, user_id, report.worker)
try:
entry.ip = ip
now = datetime.utcnow()
entry.last_activity_time = now
if report.machine_stats:
self._log_stats_to_es(
company_id=company_id,
company_name=entry.company.name,
worker=report.worker,
timestamp=report.timestamp,
task=report.task,
machine_stats=report.machine_stats,
)
entry.queue = report.queue
if report.queues:
entry.queues = report.queues
if not report.task:
entry.task = None
else:
with translate_errors_context():
query = dict(id=report.task, company=company_id)
update = dict(
last_worker=report.worker,
last_worker_report=now,
last_update=now,
)
# modify(new=True, ...) returns the modified object
task = Task.objects(**query).modify(new=True, **update)
if not task:
raise bad_request.InvalidTaskId(**query)
entry.task = IdNameEntry(id=task.id, name=task.name)
entry.last_report_time = now
except APIError:
raise
except Exception as e:
msg = "Failed processing worker status report"
log.exception(msg)
raise server_error.DataError(msg, err=e.args[0])
finally:
self._save_worker(entry)
def get_all(
self, company_id: str, last_seen: Optional[int] = None
) -> Sequence[WorkerEntry]:
"""
Get all the company workers that were active during the last_seen period
:param company_id: worker's company id
:param last_seen: period in seconds to check. Min value is 1 second
:return:
"""
try:
workers = self._get(company_id)
except Exception as e:
raise server_error.DataError("failed loading worker entries", err=e.args[0])
if last_seen:
ref_time = datetime.utcnow() - timedelta(seconds=max(1, last_seen))
workers = [
w
for w in workers
if w.last_activity_time.replace(tzinfo=None) >= ref_time
]
return workers
def get_all_with_projection(
self, company_id: str, last_seen: int
) -> Sequence[WorkerResponseEntry]:
helpers = list(
map(
WorkerConversionHelper.from_worker_entry,
self.get_all(company_id=company_id, last_seen=last_seen),
)
)
task_ids = set(filter(None, (helper.task_id for helper in helpers)))
all_queues = set(
itertools.chain.from_iterable(helper.queue_ids for helper in helpers)
)
queues_info = {}
if all_queues:
projection = [
{"$match": {"_id": {"$in": list(all_queues)}}},
{
"$project": {
"name": 1,
"next_entry": {"$arrayElemAt": ["$entries", 0]},
"num_entries": {"$size": "$entries"},
}
},
]
queues_info = {
res["_id"]: res for res in Queue.objects.aggregate(*projection)
}
task_ids = task_ids.union(
filter(
None,
(
safe_get(info, "next_entry/task")
for info in queues_info.values()
),
)
)
tasks_info = {}
if task_ids:
tasks_info = {
task.id: task
for task in Task.objects(id__in=task_ids).only(
"name", "started", "last_iteration"
)
}
def update_queue_entries(*entries):
for entry in entries:
if not entry:
continue
info = queues_info.get(entry.id, None)
if not info:
continue
entry.name = info.get("name", None)
entry.num_tasks = info.get("num_entries", 0)
task_id = safe_get(info, "next_entry/task")
if task_id:
task = tasks_info.get(task_id, None)
entry.next_task = IdNameEntry(
id=task_id, name=task.name if task else None
)
for helper in helpers:
worker = helper.worker
if helper.task_id:
task = tasks_info.get(helper.task_id, None)
if task:
worker.task.running_time = (
int((datetime.utcnow() - task.started).total_seconds() * 1000)
if task.started
else 0
)
worker.task.last_iteration = task.last_iteration
update_queue_entries(worker.queue)
if worker.queues:
update_queue_entries(*worker.queues)
return [helper.worker for helper in helpers]
@staticmethod
def _get_worker_key(company: str, user: str, worker_id: str) -> str:
"""Build redis key from company, user and worker_id"""
return f"worker_{company}_{user}_{worker_id}"
def _get_worker(self, company_id: str, user_id: str, worker: str) -> WorkerEntry:
"""
Get a worker entry for the provided worker ID. The entry is loaded from Redis
if it exists (i.e. worker has already been registered), otherwise the worker
is registered and its entry stored into Redis).
:param company_id: worker's company ID
:param user_id: user ID under which this worker is running
:param worker: worker ID
:raise bad_request.InvalidWorkerId: in case the worker id was not found
:return: worker entry instance
"""
key = self._get_worker_key(company_id, user_id, worker)
with TimingContext("redis", "get_worker"):
data = self.redis.get(key)
if data:
try:
entry = WorkerEntry.from_json(data)
if not entry.key:
entry.key = key
self._save_worker(entry)
return entry
except Exception as e:
msg = "Failed parsing worker entry"
log.exception(msg)
raise server_error.DataError(msg, err=e.args[0])
# Failed loading worker from Redis
if config.get("apiserver.workers.auto_register", False):
try:
return self.register_worker(company_id, user_id, worker)
except Exception:
log.error(
"Failed auto registration of {} for company {}".format(
worker, company_id
)
)
raise bad_request.InvalidWorkerId(worker=worker)
def _save_worker(self, entry: WorkerEntry) -> None:
"""Save worker entry in Redis"""
try:
self.redis.setex(
entry.key, timedelta(seconds=entry.register_timeout), entry.to_json()
)
except Exception:
msg = "Failed saving worker entry"
log.exception(msg)
def _get(
self, company: str, user: str = "*", worker_id: str = "*"
) -> Sequence[WorkerEntry]:
"""Get worker entries matching the company and user, worker patterns"""
match = self._get_worker_key(company, user, worker_id)
with TimingContext("redis", "workers_get_all"):
res = self.redis.scan_iter(match)
return [WorkerEntry.from_json(self.redis.get(r)) for r in res]
@staticmethod
def _get_es_index_suffix():
"""Get the index name suffix for storing current month data"""
return datetime.utcnow().strftime("%Y-%m")
def _log_stats_to_es(
self,
company_id: str,
company_name: str,
worker: str,
timestamp: int,
task: str,
machine_stats: MachineStats,
) -> bool:
"""
Actually writing the worker statistics to Elastic
:return: True if successful, False otherwise
"""
es_index = (
f"{self._stats.worker_stats_prefix_for_company(company_id)}"
f"{self._get_es_index_suffix()}"
)
def make_doc(category, metric, variant, value) -> dict:
return dict(
_index=es_index,
_type="stat",
_source=dict(
timestamp=timestamp,
worker=worker,
company=company_name,
task=task,
category=category,
metric=metric,
variant=variant,
value=float(value),
),
)
actions = []
for field, value in machine_stats.to_struct().items():
if not value:
continue
category = field.partition("_")[0]
metric = field
if not isinstance(value, (list, tuple)):
actions.append(make_doc(category, metric, "total", value))
else:
actions.extend(
make_doc(category, metric, str(i), val)
for i, val in enumerate(value)
)
es_res = elasticsearch.helpers.bulk(self.es_client, actions)
added, errors = es_res[:2]
return (added == len(actions)) and not errors
@attr.s(auto_attribs=True)
class WorkerConversionHelper:
worker: WorkerResponseEntry
task_id: str
queue_ids: Set[str]
@classmethod
def from_worker_entry(cls, worker: WorkerEntry):
data = worker.to_struct()
queue = data.pop("queue", None) or None
queue_ids = set(data.pop("queues", []))
queues = [QueueEntry(id=id) for id in queue_ids]
if queue:
queue = next((q for q in queues if q.id == queue), None)
return cls(
worker=WorkerResponseEntry(queues=queues, queue=queue, **data),
task_id=worker.task.id if worker.task else None,
queue_ids=queue_ids,
)

244
server/bll/workers/stats.py Normal file
View File

@@ -0,0 +1,244 @@
from operator import attrgetter
from typing import Optional, Sequence
from boltons.iterutils import bucketize
from apierrors.errors import bad_request
from apimodels.workers import AggregationType, GetStatsRequest, StatItem
from bll.query import Builder as QueryBuilder
from config import config
from database.errors import translate_errors_context
from timing_context import TimingContext
log = config.logger(__file__)
class WorkerStats:
def __init__(self, es):
self.es = es
@staticmethod
def worker_stats_prefix_for_company(company_id: str) -> str:
"""Returns the es index prefix for the company"""
return f"worker_stats_{company_id}_"
def _search_company_stats(self, company_id: str, es_req: dict) -> dict:
return self.es.search(
index=f"{self.worker_stats_prefix_for_company(company_id)}*",
doc_type="stat",
body=es_req,
)
def get_worker_stats_keys(
self, company_id: str, worker_ids: Optional[Sequence[str]]
) -> dict:
"""
Get dictionary of metric types grouped by categories
:param company_id: company id
:param worker_ids: optional list of workers to get metric types from.
If not specified them metrics for all the company workers returned
:return:
"""
es_req = {
"size": 0,
"aggs": {
"categories": {
"terms": {"field": "category"},
"aggs": {"metrics": {"terms": {"field": "metric"}}},
}
},
}
if worker_ids:
es_req["query"] = QueryBuilder.terms("worker", worker_ids)
res = self._search_company_stats(company_id, es_req)
if not res["hits"]["total"]:
raise bad_request.WorkerStatsNotFound(
f"No statistic metrics found for the company {company_id} and workers {worker_ids}"
)
return {
category["key"]: [
metric["key"] for metric in category["metrics"]["buckets"]
]
for category in res["aggregations"]["categories"]["buckets"]
}
def get_worker_stats(self, company_id: str, request: GetStatsRequest) -> dict:
"""
Get statistics for company workers metrics in the specified time range
Returned as date histograms for different aggregation types
grouped by worker, metric type (and optionally metric variant)
Buckets with no metrics are not returned
Note: all the statistics are retrieved as one ES query
"""
if request.from_date >= request.to_date:
raise bad_request.FieldsValueError("from_date must be less than to_date")
def get_dates_agg() -> dict:
es_to_agg_types = (
("avg", AggregationType.avg.value),
("min", AggregationType.min.value),
("max", AggregationType.max.value),
)
return {
"dates": {
"date_histogram": {
"field": "timestamp",
"interval": f"{request.interval}s",
"min_doc_count": 1,
},
"aggs": {
agg_type: {es_agg: {"field": "value"}}
for es_agg, agg_type in es_to_agg_types
},
}
}
def get_variants_agg() -> dict:
return {
"variants": {"terms": {"field": "variant"}, "aggs": get_dates_agg()}
}
es_req = {
"size": 0,
"aggs": {
"workers": {
"terms": {"field": "worker"},
"aggs": {
"metrics": {
"terms": {"field": "metric"},
"aggs": get_variants_agg()
if request.split_by_variant
else get_dates_agg(),
}
},
}
},
}
query_terms = [
QueryBuilder.dates_range(request.from_date, request.to_date),
QueryBuilder.terms("metric", {item.key for item in request.items}),
]
if request.worker_ids:
query_terms.append(QueryBuilder.terms("worker", request.worker_ids))
es_req["query"] = {"bool": {"must": query_terms}}
with translate_errors_context(), TimingContext("es", "get_worker_stats"):
data = self._search_company_stats(company_id, es_req)
return self._extract_results(data, request.items, request.split_by_variant)
@staticmethod
def _extract_results(
data: dict, request_items: Sequence[StatItem], split_by_variant: bool
) -> dict:
"""
Clean results returned from elastic search (remove "aggregations", "buckets" etc.),
leave only aggregation types requested by the user and return a clean dictionary
and return a "clean" dictionary of
:param data: aggregation data retrieved from ES
:param request_items: aggs types requested by the user
:param split_by_variant: if False then aggregate by metric type, otherwise metric type + variant
"""
if "aggregations" not in data:
return {}
items_by_key = bucketize(request_items, key=attrgetter("key"))
aggs_per_metric = {
key: [item.aggregation for item in items]
for key, items in items_by_key.items()
}
def extract_date_stats(date: dict, metric_key) -> dict:
return {
"date": date["key"],
"count": date["doc_count"],
**{agg: date[agg]["value"] for agg in aggs_per_metric[metric_key]},
}
def extract_metric_results(
metric_or_variant: dict, metric_key: str
) -> Sequence[dict]:
return [
extract_date_stats(date, metric_key)
for date in metric_or_variant["dates"]["buckets"]
if date["doc_count"]
]
def extract_variant_results(metric: dict) -> dict:
metric_key = metric["key"]
return {
variant["key"]: extract_metric_results(variant, metric_key)
for variant in metric["variants"]["buckets"]
}
def extract_worker_results(worker: dict) -> dict:
return {
metric["key"]: extract_variant_results(metric)
if split_by_variant
else extract_metric_results(metric, metric["key"])
for metric in worker["metrics"]["buckets"]
}
return {
worker["key"]: extract_worker_results(worker)
for worker in data["aggregations"]["workers"]["buckets"]
}
def get_activity_report(
self,
company_id: str,
from_date: float,
to_date: float,
interval: int,
active_only: bool,
) -> Sequence[dict]:
"""
Get statistics for company workers metrics in the specified time range
Returned as date histograms for different aggregation types
grouped by worker, metric type (and optionally metric variant)
Note: all the statistics are retrieved using one ES query
"""
if from_date >= to_date:
raise bad_request.FieldsValueError("from_date must be less than to_date")
must = [QueryBuilder.dates_range(from_date, to_date)]
if active_only:
must.append({"exists": {"field": "task"}})
es_req = {
"size": 0,
"aggs": {
"dates": {
"date_histogram": {
"field": "timestamp",
"interval": f"{interval}s",
},
"aggs": {"workers_count": {"cardinality": {"field": "worker"}}},
}
},
"query": {"bool": {"must": must}},
}
with translate_errors_context(), TimingContext(
"es", "get_worker_activity_report"
):
data = self._search_company_stats(company_id, es_req)
if "aggregations" not in data:
return {}
ret = [
dict(date=date["key"], count=date["workers_count"]["value"])
for date in data["aggregations"]["dates"]["buckets"]
]
if ret and ret[-1]["date"] > (to_date - 0.9 * interval):
# remove last interval if it's incomplete. Allow 10% tolerance
ret.pop()
return ret

View File

@@ -30,6 +30,10 @@
# controls whether FieldDoesNotExist exception will be raised for any extra attribute existing in stored data
# but not declared in a data model
strict: false
aggregate {
allow_disk_use: true
}
}
auth {
@@ -75,4 +79,40 @@
}
default_company: "d1bd92a3b039400cbafc60a7a5b1e52b"
workers {
# Auto-register unknown workers on status reports and other calls
auto_register: true
# Timeout in seconds on task status update. If exceeded
# then task can be stopped without communicating to the worker
task_update_timeout: 600
}
check_for_updates {
enabled: true
# Check for updates every 24 hours
check_interval_sec: 86400
url: "https://updates.trains.allegro.ai/updates"
component_name: "trains-server"
# GET request timeout
request_timeout_sec: 3.0
}
statistics {
# Note: statistics are sent ONLY if the user has actively opted-in
supported: true
url: "https://updates.trains.allegro.ai/stats"
report_interval_hours: 24
agent_relevant_threshold_days: 30
max_retries: 5
max_backoff_sec: 5
}
}

View File

@@ -9,6 +9,17 @@ elastic {
}
index_version: "1"
}
workers {
hosts: [{host:"127.0.0.1", port:9200}]
args {
timeout: 60
dead_timeout: 10
max_retries: 5
retry_on_timeout: true
}
index_version: "1"
}
}
mongo {
@@ -19,3 +30,11 @@ mongo {
host: "mongodb://127.0.0.1:27017/auth"
}
}
redis {
workers {
host: "127.0.0.1"
port: 6379
db: 4
}
}

View File

@@ -1,5 +1,6 @@
from functools import lru_cache
from pathlib import Path
from os import getenv
root = Path(__file__).parent.parent
@@ -26,3 +27,17 @@ def get_commit_number():
return (root / "COMMIT").read_text().strip()
except FileNotFoundError:
return ""
@lru_cache()
def get_deployment_type() -> str:
value = getenv("TRAINS_SERVER_DEPLOYMENT_TYPE")
if value:
return value
try:
value = (root / "DEPLOY").read_text().strip()
except FileNotFoundError:
pass
return value or "manual"

View File

@@ -1,5 +1,6 @@
from os import getenv
from boltons.iterutils import first
from furl import furl
from jsonmodels import models
from jsonmodels.errors import ValidationError
@@ -11,14 +12,16 @@ from config import config
from .defs import Database
from .utils import get_items
from boltons.iterutils import first
log = config.logger("database")
strict = config.get("apiserver.mongo.strict", True)
OVERRIDE_HOST_ENV_KEY = ("MONGODB_SERVICE_HOST", "MONGODB_SERVICE_SERVICE_HOST")
OVERRIDE_PORT_ENV_KEY = "MONGODB_SERVICE_PORT"
OVERRIDE_HOST_ENV_KEY = (
"TRAINS_MONGODB_SERVICE_HOST",
"MONGODB_SERVICE_HOST",
"MONGODB_SERVICE_SERVICE_HOST",
)
OVERRIDE_PORT_ENV_KEY = ("TRAINS_MONGODB_SERVICE_PORT", "MONGODB_SERVICE_PORT")
_entries = []
@@ -41,7 +44,7 @@ def initialize():
if override_hostname:
log.info(f"Using override mongodb host {override_hostname}")
override_port = getenv(OVERRIDE_PORT_ENV_KEY)
override_port = first(map(getenv, OVERRIDE_PORT_ENV_KEY), None)
if override_port:
log.info(f"Using override mongodb port {override_port}")

View File

@@ -1,12 +1,12 @@
import re
from collections import namedtuple
from functools import reduce
from typing import Collection, Sequence
from typing import Collection, Sequence, Union
from boltons.iterutils import first
from dateutil.parser import parse as parse_datetime
from mongoengine import Q, Document
from six import string_types
from mongoengine import Q, Document, ListField, StringField
from pymongo.command_cursor import CommandCursor
from apierrors import errors
from config import config
@@ -16,9 +16,9 @@ from database.props import PropsMixin
from database.query import RegexQ, RegexWrapper
from database.utils import (
get_company_or_none_constraint,
get_fields_with_attr,
field_exists,
get_fields_choices,
field_does_not_exist,
field_exists,
)
log = config.logger("dbmodel")
@@ -62,6 +62,7 @@ class GetMixin(PropsMixin):
_text_score = "$text_score"
_ordering_key = "order_by"
_search_text_key = "search_text"
_multi_field_param_sep = "__"
_multi_field_param_prefix = {
@@ -221,6 +222,24 @@ class GetMixin(PropsMixin):
return get_company_or_none_constraint(company)
return Q(company=company)
@classmethod
def validate_order_by(cls, parameters, search_text) -> Sequence:
"""
Validate and extract order_by params as a list
"""
order_by = parameters.get(cls._ordering_key)
if not order_by:
return []
order_by = order_by if isinstance(order_by, list) else [order_by]
order_by = [cls._text_score if x == "@text_score" else x for x in order_by]
if not search_text and cls._text_score in order_by:
raise errors.bad_request.FieldsValueError(
"text score cannot be used in order_by when search text is not used"
)
return order_by
@classmethod
def validate_paging(
cls, parameters=None, default_page=None, default_page_size=None
@@ -267,7 +286,6 @@ class GetMixin(PropsMixin):
allow_public=False,
override_projection=None,
expand_reference_ids=True,
override_none_ordering=False,
):
"""
Fetch all documents matching a provided query with support for joining referenced documents according to the
@@ -303,7 +321,6 @@ class GetMixin(PropsMixin):
query=query,
query_options=query_options,
allow_public=allow_public,
override_none_ordering=override_none_ordering,
)
def projection_func(doc_type, projection, ids):
@@ -328,7 +345,6 @@ class GetMixin(PropsMixin):
allow_public=False,
override_projection: Collection[str] = None,
return_dicts=True,
override_none_ordering=False,
):
"""
Fetch all documents matching a provided query. Supported several built-in options
@@ -341,8 +357,9 @@ class GetMixin(PropsMixin):
`@text_score` keyword. A text index must be defined on the document type, otherwise an error will
be raised.
:param return_dicts: Return a list of dictionaries. If True, a list of dicts is returned (if projection was
requested, each contains only the requested projection).
If False, a QuerySet object is returned (lazy evaluated)
requested, each contains only the requested projection). If False, a QuerySet object is returned
(lazy evaluated). If return_dicts is requested then the entities with the None value in order_by field
are returned last in the ordering.
:param company: Company ID (required)
:param parameters: Parameters dict from which paging ordering and searching parameters are extracted.
:param query_dict: If provided, passed to prepare_query() along with all of the relevant arguments to produce
@@ -352,8 +369,6 @@ class GetMixin(PropsMixin):
:param override_projection: A list of projection fields overriding any projection specified in the `param_dict`
argument
:param allow_public: If True, objects marked as public (no associated company) are also queried.
:param override_none_ordering: If True, then items with the None values in the first ordered field
are always sorted in the end
:return: A list of objects matching the query.
"""
if query_dict is not None:
@@ -367,26 +382,19 @@ class GetMixin(PropsMixin):
q = cls._prepare_perm_query(company, allow_public=allow_public)
_query = (q & query) if query else q
if override_none_ordering:
if return_dicts:
return cls._get_many_override_none_ordering(
query=_query,
parameters=parameters,
query_dict=query_dict,
query_options=query_options,
override_projection=override_projection,
)
return cls._get_many_no_company(
query=_query,
parameters=parameters,
override_projection=override_projection,
return_dicts=return_dicts,
query=_query, parameters=parameters, override_projection=override_projection
)
@classmethod
def _get_many_no_company(
cls, query, parameters=None, override_projection=None, return_dicts=True
):
def _get_many_no_company(cls, query, parameters=None, override_projection=None):
"""
Fetch all documents matching a provided query.
This is a company-less version for internal uses. We assume the caller has either added any necessary
@@ -395,44 +403,25 @@ class GetMixin(PropsMixin):
NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows returning data across companies.
:param query: Query object (mongoengine.Q)
:param return_dicts: Return a list of dictionaries. If True, a list of dicts is returned (if projection was
requested, each contains only the requested projection).
If False, a QuerySet object is returned (lazy evaluated)
:param parameters: Parameters dict from which paging ordering and searching parameters are extracted.
:param override_projection: A list of projection fields overriding any projection specified in the `param_dict`
argument
"""
parameters = parameters or {}
if not query:
raise ValueError("query or call_data must be provided")
parameters = parameters or {}
search_text = parameters.get(cls._search_text_key)
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
page, page_size = cls.validate_paging(parameters=parameters)
order_by = parameters.get(cls._ordering_key)
if order_by:
order_by = order_by if isinstance(order_by, list) else [order_by]
order_by = [cls._text_score if x == "@text_score" else x for x in order_by]
search_text = parameters.get("search_text")
only = cls.get_projection(parameters, override_projection)
if not search_text and order_by and cls._text_score in order_by:
raise errors.bad_request.FieldsValueError(
"text score cannot be used in order_by when search text is not used"
)
qs = cls.objects(query)
if search_text:
qs = qs.search_text(search_text)
if order_by:
# add ordering
qs = (
qs.order_by(order_by)
if isinstance(order_by, string_types)
else qs.order_by(*order_by)
)
qs = qs.order_by(*order_by)
if only:
# add projection
qs = qs.only(*only)
@@ -444,17 +433,13 @@ class GetMixin(PropsMixin):
# add paging
qs = qs.skip(page * page_size).limit(page_size)
if return_dicts:
return [obj.to_proper_dict(only=only) for obj in qs]
return qs
@classmethod
def _get_many_override_none_ordering(
cls,
cls: Union[Document, "GetMixin"],
query: Q = None,
parameters: dict = None,
query_dict: dict = None,
query_options: QueryParameterOptions = None,
override_projection: Collection[str] = None,
) -> Sequence[dict]:
"""
@@ -467,57 +452,43 @@ class GetMixin(PropsMixin):
:param query: Query object (mongoengine.Q)
:param parameters: Parameters dict from which paging ordering and searching parameters are extracted.
:param query_dict: If provided, passed to prepare_query() along with all of the relevant arguments to produce
a query. The resulting query is AND'ed with the `query` parameter (if provided).
:param query_options: query parameters options (see ParametersOptions)
:param override_projection: A list of projection fields overriding any projection specified in the `param_dict`
argument
"""
if not query:
raise ValueError("query or call_data must be provided")
parameters = parameters or {}
search_text = parameters.get("search_text")
search_text = parameters.get(cls._search_text_key)
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
page, page_size = cls.validate_paging(parameters=parameters)
only = cls.get_projection(parameters, override_projection)
query_sets = []
order_by = parameters.get(cls._ordering_key)
query_sets = [cls.objects(query)]
if order_by:
order_by = order_by if isinstance(order_by, list) else [order_by]
order_by = [cls._text_score if x == "@text_score" else x for x in order_by]
if not search_text and cls._text_score in order_by:
raise errors.bad_request.FieldsValueError(
"text score cannot be used in order_by when search text is not used"
)
order_field = first(
field for field in order_by if not field.startswith("$")
)
if (
order_field
and not order_field.startswith("-")
and (not query_dict or order_field not in query_dict)
and "[" not in order_field
):
empty_value = None
if order_field in query_options.list_fields:
empty_value = []
elif order_field in query_options.pattern_fields:
empty_value = ""
params = {}
mongo_field = order_field.replace(".", "__")
non_empty = query & field_exists(mongo_field, empty_value=empty_value)
empty = query & field_does_not_exist(
mongo_field, empty_value=empty_value
)
if mongo_field in cls.get_field_names_for_type(of_type=ListField):
params["is_list"] = True
elif mongo_field in cls.get_field_names_for_type(of_type=StringField):
params["empty_value"] = ""
non_empty = query & field_exists(mongo_field, **params)
empty = query & field_does_not_exist(mongo_field, **params)
query_sets = [cls.objects(non_empty), cls.objects(empty)]
if not query_sets:
query_sets = [cls.objects(query)]
query_sets = [qs.order_by(*order_by) for qs in query_sets]
if search_text:
query_sets = [qs.search_text(search_text) for qs in query_sets]
if order_by:
# add ordering
query_sets = [qs.order_by(*order_by) for qs in query_sets]
only = cls.get_projection(parameters, override_projection)
if only:
# add projection
query_sets = [qs.only(*only) for qs in query_sets]
@@ -583,8 +554,8 @@ class UpdateMixin(object):
def user_set_allowed(cls):
res = getattr(cls, "__user_set_allowed_fields", None)
if res is None:
res = cls.__user_set_allowed_fields = dict(
get_fields_with_attr(cls, "user_set_allowed")
res = cls.__user_set_allowed_fields = get_fields_choices(
cls, "user_set_allowed"
)
return res
@@ -622,7 +593,24 @@ class UpdateMixin(object):
class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
""" Provide convenience methods for a subclass of mongoengine.Document """
pass
@classmethod
def aggregate(
cls: Document, *pipeline: dict, allow_disk_use=None, **kwargs
) -> CommandCursor:
"""
Aggregate objects of this document class according to the provided pipeline.
:param pipeline: a list of dictionaries describing the pipeline stages
:param allow_disk_use: if True, allow the server to use disk space if aggregation query cannot fit in memory.
If None, default behavior will be used (see apiserver.conf/mongo/aggregate/allow_disk_use)
:param kwargs: additional keyword arguments passed to mongoengine
:return:
"""
kwargs.update(
allowDiskUse=allow_disk_use
if allow_disk_use is not None
else config.get("apiserver.mongo.aggregate.allow_disk_use", True)
)
return cls.objects.aggregate(*pipeline, **kwargs)
def validate_id(cls, company, **kwargs):

View File

@@ -1,23 +1,36 @@
from mongoengine import Document, EmbeddedDocument, EmbeddedDocumentField, StringField, Q
from mongoengine import (
Document,
EmbeddedDocument,
EmbeddedDocumentField,
StringField,
Q,
BooleanField,
DateTimeField,
)
from database import Database, strict
from database.fields import StrippedStringField
from database.model import DbModelMixin
class ReportStatsOption(EmbeddedDocument):
enabled = BooleanField(default=False) # opt-in for statistics reporting
enabled_version = StringField() # server version when enabled
enabled_time = DateTimeField() # time when enabled
enabled_user = StringField() # ID of user who enabled
class CompanyDefaults(EmbeddedDocument):
cluster = StringField()
stats_option = EmbeddedDocumentField(ReportStatsOption, default=ReportStatsOption)
class Company(DbModelMixin, Document):
meta = {
'db_alias': Database.backend,
'strict': strict,
}
meta = {"db_alias": Database.backend, "strict": strict}
id = StringField(primary_key=True)
name = StrippedStringField(unique=True, min_length=3)
defaults = EmbeddedDocumentField(CompanyDefaults)
defaults = EmbeddedDocumentField(CompanyDefaults, default=CompanyDefaults)
@classmethod
def _prepare_perm_query(cls, company, allow_public=False):

View File

@@ -0,0 +1,47 @@
from mongoengine import (
Document,
EmbeddedDocument,
StringField,
DateTimeField,
EmbeddedDocumentListField,
ListField,
)
from database import Database, strict
from database.fields import StrippedStringField
from database.model import DbModelMixin
from database.model.base import ProperDictMixin, GetMixin
from database.model.company import Company
from database.model.task.task import Task
class Entry(EmbeddedDocument, ProperDictMixin):
""" Entry representing a task waiting in the queue """
task = StringField(required=True, reference_field=Task)
''' Task ID '''
added = DateTimeField(required=True)
''' Added to the queue '''
class Queue(DbModelMixin, Document):
get_all_query_options = GetMixin.QueryParameterOptions(
pattern_fields=("name",),
list_fields=("tags", "system_tags", "id"),
)
meta = {
'db_alias': Database.backend,
'strict': strict,
}
id = StringField(primary_key=True)
name = StrippedStringField(
required=True, unique_with="company", min_length=3, user_set_allowed=True
)
company = StringField(required=True, reference_field=Company)
created = DateTimeField(required=True)
tags = ListField(StringField(required=True), default=list, user_set_allowed=True)
system_tags = ListField(StringField(required=True), user_set_allowed=True)
entries = EmbeddedDocumentListField(Entry, default=list)
last_update = DateTimeField()

View File

@@ -0,0 +1,57 @@
from typing import Any, Optional, Sequence, Tuple
from mongoengine import Document, StringField, DynamicField, Q
from mongoengine.errors import NotUniqueError
from database import Database, strict
from database.model import DbModelMixin
class Settings(DbModelMixin, Document):
meta = {
"db_alias": Database.backend,
"strict": strict,
}
key = StringField(primary_key=True)
value = DynamicField()
@classmethod
def get_by_key(cls, key: str, default: Optional[Any] = None, sep: str = ".") -> Any:
key = key.strip(sep)
res = Settings.objects(key=key).first()
if not res:
return default
return res.value
@classmethod
def get_by_prefix(
cls, key_prefix: str, default: Optional[Any] = None, sep: str = "."
) -> Sequence[Tuple[str, Any]]:
key_prefix = key_prefix.strip(sep)
query = Q(key=key_prefix) | Q(key__startswith=key_prefix + sep)
res = Settings.objects(query)
if not res:
return default
return [(x.key, x.value) for x in res]
@classmethod
def set_or_add_value(cls, key: str, value: Any, sep: str = ".") -> bool:
""" Sets a new value or adds a new key/value setting (if key does not exist) """
key = key.strip(sep)
res = Settings.objects(key=key).update(key=key, value=value, upsert=True)
# if Settings.objects(key=key).only("key"):
#
# else:
# res = Settings(key=key, value=value).save()
return bool(res)
@classmethod
def add_value(cls, key: str, value: Any, sep: str = ".") -> bool:
""" Adds a new key/value settings. Fails if key already exists. """
key = key.strip(sep)
try:
res = Settings(key=key, value=value).save(force_insert=True)
return bool(res)
except NotUniqueError:
return False

View File

@@ -29,6 +29,7 @@ DEFAULT_LAST_ITERATION = 0
class TaskStatus(object):
created = "created"
queued = "queued"
in_progress = "in_progress"
stopped = "stopped"
publishing = "publishing"
@@ -85,7 +86,7 @@ class Execution(EmbeddedDocument):
model_labels = ModelLabels()
framework = StringField()
artifacts = EmbeddedDocumentSortedListField(Artifact)
docker_cmd = StringField()
queue = StringField()
""" Queue ID where task was queued """
@@ -150,6 +151,8 @@ class Task(AttributedDocument):
tags = ListField(StringField(required=True), user_set_allowed=True)
system_tags = ListField(StringField(required=True), user_set_allowed=True)
script = EmbeddedDocumentField(Script)
last_worker = StringField()
last_worker_report = DateTimeField()
last_update = DateTimeField()
last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))

View File

@@ -1,17 +1,19 @@
from collections import OrderedDict
from collections import OrderedDict, defaultdict
from itertools import chain
from operator import attrgetter
from threading import Lock
from typing import Sequence
import six
from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
from mongoengine.base import get_document
from mongoengine.base import get_document, BaseField
from database.fields import (
LengthRangeEmbeddedDocumentListField,
UniqueEmbeddedDocumentListField,
EmbeddedDocumentSortedListField,
)
from database.utils import get_fields, get_fields_and_attr
from database.utils import get_fields, get_fields_attr
class PropsMixin(object):
@@ -19,6 +21,7 @@ class PropsMixin(object):
__cached_reference_fields = None
__cached_exclude_fields = None
__cached_fields_with_instance = None
__cached_field_names_per_type = None
__cached_dpath_computed_fields_lock = Lock()
__cached_dpath_computed_fields = None
@@ -29,6 +32,39 @@ class PropsMixin(object):
cls.__cached_fields = get_fields(cls)
return cls.__cached_fields
@classmethod
def get_field_names_for_type(cls, of_type=BaseField):
"""
Return field names per type including subfields
The fields of derived types are also returned
"""
assert issubclass(of_type, BaseField)
if cls.__cached_field_names_per_type is None:
fields = defaultdict(list)
for name, field in get_fields(cls, return_instance=True, subfields=True):
fields[type(field)].append(name)
for type_ in fields:
fields[type_].extend(
chain.from_iterable(
fields[other_type]
for other_type in fields
if other_type != type_ and issubclass(other_type, type_)
)
)
cls.__cached_field_names_per_type = fields
if of_type not in cls.__cached_field_names_per_type:
names = list(
chain.from_iterable(
field_names
for type_, field_names in cls.__cached_field_names_per_type.items()
if issubclass(type_, of_type)
)
)
cls.__cached_field_names_per_type[of_type] = names
return cls.__cached_field_names_per_type[of_type]
@classmethod
def get_fields_with_instance(cls, doc_cls):
if cls.__cached_fields_with_instance is None:
@@ -42,7 +78,7 @@ class PropsMixin(object):
@staticmethod
def _get_fields_with_attr(cls_, attr):
""" Get all fields with the specified attribute (supports nested fields) """
res = get_fields_and_attr(cls_, attr=attr)
res = get_fields_attr(cls_, attr=attr)
def resolve_doc(v):
if not isinstance(v, six.string_types):
@@ -122,6 +158,14 @@ class PropsMixin(object):
cls.__cached_reference_fields = OrderedDict(sorted(fields.items()))
return cls.__cached_reference_fields
@classmethod
def get_extra_projection(cls, fields: Sequence) -> tuple:
if isinstance(fields, str):
fields = [fields]
return tuple(
set(fields).union(cls.get_fields()).difference(cls.get_exclude_fields())
)
@classmethod
def get_exclude_fields(cls):
if cls.__cached_exclude_fields is None:
@@ -140,3 +184,18 @@ class PropsMixin(object):
result = separator.join(translated)
cls.__cached_dpath_computed_fields[path] = result
return cls.__cached_dpath_computed_fields[path]
def get_field_value(self, field_path: str, default=None):
"""
Return the document field_path value by the field_path name.
The path may contain '.'. If on any level the path is
not found then the default value is returned
"""
path_elements = field_path.split(".")
current = self
for name in path_elements:
current = getattr(current, name, default)
if current == default:
break
return current

View File

@@ -1,6 +1,6 @@
import hashlib
from inspect import ismethod, getmembers
from typing import Sequence, Tuple, Set, Optional
from typing import Sequence, Tuple, Set, Optional, Callable, Any
from uuid import uuid4
from mongoengine import EmbeddedDocumentField, ListField, Document, Q
@@ -9,64 +9,58 @@ from mongoengine.base import BaseField
from .errors import translate_errors_context, ParseCallError
def get_fields(cls, of_type=BaseField, return_instance=False):
def get_fields(cls, of_type=BaseField, return_instance=False, subfields=False):
return _get_fields(
cls,
of_type=of_type,
subfields=subfields,
selector=lambda k, v: (k, v) if return_instance else k,
)
def get_fields_attr(cls, attr):
""" get field names from a class containing mongoengine fields """
res = []
for cls_ in reversed(cls.mro()):
res.extend(
[
k if not return_instance else (k, v)
for k, v in vars(cls_).items()
if isinstance(v, of_type)
]
)
return res
return dict(
_get_fields(cls, with_attr=attr, selector=lambda k, v: (k, getattr(v, attr)))
)
def get_fields_and_attr(cls, attr):
""" get field names from a class containing mongoengine fields """
res = {}
for cls_ in reversed(cls.mro()):
res.update(
{
k: getattr(v, attr)
for k, v in vars(cls_).items()
if isinstance(v, BaseField) and hasattr(v, attr)
}
)
return res
def get_fields_choices(cls, attr):
def get_choices(field_name: str, field: BaseField) -> Tuple:
if isinstance(field, ListField):
return field_name, field.field.choices
return field_name, field.choices
return dict(_get_fields(cls, with_attr=attr, subfields=True, selector=get_choices))
def _get_field_choices(name, field):
field_t = type(field)
if issubclass(field_t, EmbeddedDocumentField):
obj = field.document_type_obj
n, choices = _get_field_choices(field.name, obj.field)
return "%s__%s" % (name, n), choices
elif issubclass(type(field), ListField):
return name, field.field.choices
return name, field.choices
def get_fields_with_attr(cls, attr, default=False):
def _get_fields(
cls,
with_attr=None,
of_type=BaseField,
subfields=False,
selector: Optional[Callable[[str, BaseField], Any]] = None,
path: Tuple[str, ...] = (),
):
fields = []
for field_name, field in cls._fields.items():
if not getattr(field, attr, default):
continue
field_t = type(field)
if issubclass(field_t, EmbeddedDocumentField):
field_path = path + (field_name,)
if isinstance(field, of_type) and (not with_attr or hasattr(field, with_attr)):
full_name = "__".join(field_path)
fields.append(selector(full_name, field) if selector else full_name)
if subfields and isinstance(field, EmbeddedDocumentField):
fields.extend(
(
("%s__%s" % (field_name, name), choices)
for name, choices in get_fields_with_attr(
field.document_type, attr, default
)
_get_fields(
field.document_type,
with_attr=with_attr,
of_type=of_type,
subfields=subfields,
selector=selector,
path=field_path,
)
)
elif issubclass(type(field), ListField):
fields.append((field_name, field.field.choices))
else:
fields.append((field_name, field.choices))
return fields
@@ -151,17 +145,20 @@ def field_does_not_exist(field: str, empty_value=None, is_list=False) -> Q:
return query
def field_exists(field: str, empty_value=None) -> Q:
def field_exists(field: str, empty_value=None, is_list=False) -> Q:
"""
Creates a query object used for finding a field that exists and is not None or empty.
:param field: Field name
:param empty_value: The empty value to test for (None means no specific empty value will be used).
For lists pass [] for empty_value
:param empty_value: The empty value to test for (None means no specific empty value will be used)
:param is_list: Is this a list (array) field. In this case, instead of testing for an empty value,
the length of the array will be used (len==0 means empty)
:return:
"""
query = Q(**{f"{field}__exists": True}) & Q(
**{f"{field}__nin": {empty_value, None}}
)
if is_list:
query &= Q(**{f"{field}__not__size": 0})
return query
@@ -213,6 +210,7 @@ system_tag_names = {
"model": _names_set("active", "archived"),
"project": _names_set("archived", "public", "default"),
"task": _names_set("active", "archived", "development"),
"queue": _names_set("default"),
}
system_tag_prefixes = {"task": _names_set("annotat")}

View File

@@ -0,0 +1,27 @@
{
"template": "queue_metrics_*",
"settings": {
"number_of_shards": 1
},
"mappings": {
"metrics": {
"_source": {
"enabled": true
},
"properties": {
"timestamp": {
"type": "date"
},
"queue": {
"type": "keyword"
},
"average_waiting_time": {
"type": "float"
},
"queue_length": {
"type": "integer"
}
}
}
}
}

View File

@@ -0,0 +1,23 @@
{
"template": "worker_stats_*",
"settings": {
"number_of_shards": 1
},
"mappings": {
"stat": {
"_source": {
"enabled": true
},
"properties": {
"timestamp": { "type": "date" },
"worker": { "type": "keyword" },
"category": { "type": "keyword" },
"metric": { "type": "keyword" },
"variant": { "type": "keyword" },
"value": { "type": "float" },
"unit": { "type": "keyword" },
"task": { "type": "keyword" }
}
}
}
}

View File

@@ -1,20 +1,25 @@
from datetime import datetime
from os import getenv
from boltons.iterutils import first
from elasticsearch import Elasticsearch, Transport
from config import config
log = config.logger(__file__)
OVERRIDE_HOST_ENV_KEY = ("ELASTIC_SERVICE_HOST", "ELASTIC_SERVICE_SERVICE_HOST")
OVERRIDE_PORT_ENV_KEY = "ELASTIC_SERVICE_PORT"
OVERRIDE_HOST_ENV_KEY = (
"TRAINS_ELASTIC_SERVICE_HOST",
"ELASTIC_SERVICE_HOST",
"ELASTIC_SERVICE_SERVICE_HOST",
)
OVERRIDE_PORT_ENV_KEY = ("TRAINS_ELASTIC_SERVICE_PORT", "ELASTIC_SERVICE_PORT")
OVERRIDE_HOST = next(filter(None, map(getenv, OVERRIDE_HOST_ENV_KEY)), None)
OVERRIDE_HOST = first(filter(None, map(getenv, OVERRIDE_HOST_ENV_KEY)))
if OVERRIDE_HOST:
log.info(f"Using override elastic host {OVERRIDE_HOST}")
OVERRIDE_PORT = getenv(OVERRIDE_PORT_ENV_KEY)
OVERRIDE_PORT = first(filter(None, map(getenv, OVERRIDE_PORT_ENV_KEY)))
if OVERRIDE_PORT:
log.info(f"Using override elastic port {OVERRIDE_PORT}")
@@ -25,6 +30,7 @@ class MissingClusterConfiguration(Exception):
"""
Exception when cluster configuration is not found in config files
"""
pass
@@ -32,6 +38,7 @@ class InvalidClusterConfiguration(Exception):
"""
Exception when cluster configuration does not contain required properties
"""
pass
@@ -46,12 +53,14 @@ def connect(cluster_name):
"""
if cluster_name not in _instances:
cluster_config = get_cluster_config(cluster_name)
hosts = cluster_config.get('hosts', None)
hosts = cluster_config.get("hosts", None)
if not hosts:
raise InvalidClusterConfiguration(cluster_name)
args = cluster_config.get('args', {})
_instances[cluster_name] = Elasticsearch(hosts=hosts, transport_class=Transport, **args)
args = cluster_config.get("args", {})
_instances[cluster_name] = Elasticsearch(
hosts=hosts, transport_class=Transport, **args
)
return _instances[cluster_name]
@@ -63,13 +72,13 @@ def get_cluster_config(cluster_name):
:return: config section for the cluster
:raises MissingClusterConfiguration: in case no config section is found for the cluster
"""
cluster_key = '.'.join(('hosts.elastic', cluster_name))
cluster_key = ".".join(("hosts.elastic", cluster_name))
cluster_config = config.get(cluster_key, None)
if not cluster_config:
raise MissingClusterConfiguration(cluster_name)
def set_host_prop(key, value):
for host in cluster_config.get('hosts', []):
for host in cluster_config.get("hosts", []):
host[key] = value
if OVERRIDE_HOST:

View File

@@ -1,6 +1,7 @@
import importlib.util
from datetime import datetime
from pathlib import Path
from uuid import uuid4
import attr
from furl import furl
@@ -8,11 +9,14 @@ from mongoengine.connection import get_db
from semantic_version import Version
import database.utils
from bll.queue import QueueBLL
from config import config
from database import Database
from database.model.auth import Role
from database.model.auth import User as AuthUser, Credentials
from database.model.company import Company
from database.model.queue import Queue
from database.model.settings import Settings
from database.model.user import User
from database.model.version import Version as DatabaseVersion
from elastic.apply_mappings import apply_mappings_to_host
@@ -57,6 +61,18 @@ def _ensure_company():
return company_id
def _ensure_default_queue(company):
"""
If no queue is present for the company then
create a new one and mark it as a default
"""
queue = Queue.objects(company=company).only("id").first()
if queue:
return
QueueBLL.create(company, name="default", system_tags=["default"])
def _ensure_auth_user(user_data, company_id):
ensure_credentials = {"key", "secret"}.issubset(user_data.keys())
if ensure_credentials:
@@ -95,10 +111,7 @@ def _ensure_user(user: FixedUser, company_id: str):
data["email"] = f"{user.user_id}@example.com"
data["role"] = Role.user
_ensure_auth_user(
user_data=data,
company_id=company_id,
)
_ensure_auth_user(user_data=data, company_id=company_id)
given_name, _, family_name = user.name.partition(" ")
@@ -128,9 +141,7 @@ def _apply_migrations():
try:
new_scripts = {
ver: path
for ver, path in (
(Version(f.stem), f) for f in migration_dir.glob("*.py")
)
for ver, path in ((Version(f.stem), f) for f in migration_dir.glob("*.py"))
if ver > last_version
}
except ValueError as ex:
@@ -165,14 +176,30 @@ def _apply_migrations():
).save()
def _ensure_uuid():
Settings.add_value("server.uuid", str(uuid4()))
def init_mongo_data():
try:
_apply_migrations()
_ensure_uuid()
company_id = _ensure_company()
_ensure_default_queue(company_id)
users = [
{"name": "apiserver", "role": Role.system, "email": "apiserver@example.com"},
{"name": "webserver", "role": Role.system, "email": "webserver@example.com"},
{
"name": "apiserver",
"role": Role.system,
"email": "apiserver@example.com",
},
{
"name": "webserver",
"role": Role.system,
"email": "webserver@example.com",
},
{"name": "tests", "role": Role.user, "email": "tests@example.com"},
]

195
server/redis_manager.py Normal file
View File

@@ -0,0 +1,195 @@
import threading
from os import getenv
from time import sleep
from boltons.iterutils import first
from redis import StrictRedis
from redis.sentinel import Sentinel, SentinelConnectionPool
from apierrors.errors.server_error import ConfigError, GeneralError
from config import config
log = config.logger(__file__)
OVERRIDE_HOST_ENV_KEY = ("TRAINS_REDIS_SERVICE_HOST", "REDIS_SERVICE_HOST")
OVERRIDE_PORT_ENV_KEY = ("TRAINS_REDIS_SERVICE_PORT", "REDIS_SERVICE_PORT")
OVERRIDE_HOST = first(filter(None, map(getenv, OVERRIDE_HOST_ENV_KEY)))
if OVERRIDE_HOST:
log.info(f"Using override redis host {OVERRIDE_HOST}")
OVERRIDE_PORT = first(filter(None, map(getenv, OVERRIDE_PORT_ENV_KEY)))
if OVERRIDE_PORT:
log.info(f"Using override redis port {OVERRIDE_PORT}")
class MyPubSubWorkerThread(threading.Thread):
def __init__(self, sentinel, on_new_master, msg_sleep_time, daemon=True):
super(MyPubSubWorkerThread, self).__init__()
self.daemon = daemon
self.sentinel = sentinel
self.on_new_master = on_new_master
self.sentinel_host = sentinel.connection_pool.connection_kwargs["host"]
self.msg_sleep_time = msg_sleep_time
self._running = False
self.pubsub = None
def subscribe(self):
if self.pubsub:
try:
self.pubsub.unsubscribe()
self.pubsub.punsubscribe()
except Exception:
pass
finally:
self.pubsub = None
subscriptions = {"+switch-master": self.on_new_master}
while not self.pubsub or not self.pubsub.subscribed:
try:
self.pubsub = self.sentinel.pubsub()
self.pubsub.subscribe(**subscriptions)
except Exception as ex:
log.warn(
f"Error while subscribing to sentinel at {self.sentinel_host} ({ex.args[0]}) Sleeping and retrying"
)
sleep(3)
log.info(f"Subscribed to sentinel {self.sentinel_host}")
def run(self):
if self._running:
return
self._running = True
self.subscribe()
while self.pubsub.subscribed:
try:
self.pubsub.get_message(
ignore_subscribe_messages=True, timeout=self.msg_sleep_time
)
except Exception as ex:
log.warn(
f"Error while getting message from sentinel {self.sentinel_host} ({ex.args[0]}) Resubscribing"
)
self.subscribe()
self.pubsub.close()
self._running = False
def stop(self):
# stopping simply unsubscribes from all channels and patterns.
# the unsubscribe responses that are generated will short circuit
# the loop in run(), calling pubsub.close() to clean up the connection
self.pubsub.unsubscribe()
self.pubsub.punsubscribe()
# todo,future - multi master clusters?
class RedisCluster(object):
def __init__(self, sentinel_hosts, service_name, **connection_kwargs):
self.service_name = service_name
self.sentinel = Sentinel(sentinel_hosts, **connection_kwargs)
self.master = None
self.master_host_port = None
self.reconfigure()
self.sentinel_threads = {}
self.listen()
def reconfigure(self):
try:
self.master_host_port = self.sentinel.discover_master(self.service_name)
self.master = self.sentinel.master_for(self.service_name)
log.info(f"Reconfigured master to {self.master_host_port}")
except Exception as ex:
log.error(f"Error while reconfiguring. {ex.args[0]}")
def listen(self):
def on_new_master(workerThread):
self.reconfigure()
for sentinel in self.sentinel.sentinels:
sentinel_host = sentinel.connection_pool.connection_kwargs["host"]
self.sentinel_threads[sentinel_host] = MyPubSubWorkerThread(
sentinel, on_new_master, msg_sleep_time=0.001, daemon=True
)
self.sentinel_threads[sentinel_host].start()
class RedisManager(object):
def __init__(self, redis_config_dict):
self.aliases = {}
for alias, alias_config in redis_config_dict.items():
alias_config = alias_config.as_plain_ordered_dict()
is_cluster = alias_config.get("cluster", False)
host = OVERRIDE_HOST or alias_config.get("host", None)
if host:
alias_config["host"] = host
port = OVERRIDE_PORT or alias_config.get("port", None)
if port:
alias_config["port"] = port
db = alias_config.get("db", 0)
sentinels = alias_config.get("sentinels", None)
service_name = alias_config.get("service_name", None)
if not is_cluster and sentinels:
raise ConfigError(
"Redis configuration is invalid. mixed regular and cluster mode",
alias=alias,
)
if is_cluster and (not sentinels or not service_name):
raise ConfigError(
"Redis configuration is invalid. missing sentinels or service_name",
alias=alias,
)
if not is_cluster and (not port or not host):
raise ConfigError(
"Redis configuration is invalid. missing port or host", alias=alias
)
if is_cluster:
# todo support all redis connection args via sentinel's connection_kwargs
del alias_config["sentinels"]
del alias_config["cluster"]
del alias_config["service_name"]
self.aliases[alias] = RedisCluster(
sentinels, service_name, **alias_config
)
else:
self.aliases[alias] = StrictRedis(**alias_config)
def connection(self, alias):
obj = self.aliases.get(alias)
if not obj:
raise GeneralError(f"Invalid Redis alias {alias}")
if isinstance(obj, RedisCluster):
obj.master.get("health")
return obj.master
else:
obj.get("health")
return obj
def host(self, alias):
r = self.connection(alias)
pool = r.connection_pool
if isinstance(pool, SentinelConnectionPool):
connections = pool.connection_kwargs[
"connection_pool"
]._available_connections
else:
connections = pool._available_connections
if len(connections) > 0:
return connections[0].host
else:
return None
redman = RedisManager(config.get("hosts.redis"))

View File

@@ -26,3 +26,6 @@ fastjsonschema>=2.8
boltons>=19.1.0
semantic_version>=2.6.0,<3
furl>=2.0.0
redis>=2.10.5
humanfriendly==4.18
psutil>=5.6.5

View File

@@ -276,33 +276,6 @@ revoke_credentials {
}
}
delete_user {
allow_roles = [ "system", "root", "admin" ]
internal: false
"2.1" {
description: """Delete a new user manually. Only supported in on-premises deployments. This only removes the user's auth entry so that any references to the deleted user's ID will still have valid user information"""
request {
type: object
required: [ user ]
properties {
user {
type: string
description: User ID
}
}
}
response {
type: object
properties {
deleted {
description: "True if user was successfully deleted, False otherwise"
type: boolean
}
}
}
}
}
edit_user {
internal: false
allow_roles: ["system", "root", "admin"]

View File

@@ -224,7 +224,7 @@
}
add_batch {
"2.1" {
description: "Adds a batch of events in a single call."
description: "Adds a batch of events in a single call (json-lines format, stream-friendly)"
batch_request: {
action: add
version: 1.5
@@ -251,6 +251,11 @@
type: string
description: "Task ID"
}
allow_locked {
type: boolean
description: "Allow deleting events even if the task is locked"
default: false
}
}
}
response {
@@ -452,6 +457,10 @@
type: integer
description: "Number of events to return each time"
}
event_type {
type: string
description: "Return only events of this type"
}
}
}
response {

View File

@@ -62,7 +62,7 @@
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please dont use it."
description: "System tags. This field is reserved for system use, please don't use it."
items {type: string}
}
framework {
@@ -279,7 +279,7 @@
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please dont use it."
description: "System tags. This field is reserved for system use, please don't use it."
items {type: string}
}
override_model_id {
@@ -346,7 +346,7 @@
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please dont use it."
description: "System tags. This field is reserved for system use, please don't use it."
items {type: string}
}
framework {
@@ -434,7 +434,7 @@
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please dont use it."
description: "System tags. This field is reserved for system use, please don't use it."
items {type: string}
}
framework {
@@ -516,7 +516,7 @@
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please dont use it."
description: "System tags. This field is reserved for system use, please don't use it."
items {type: string}
}
ready {

View File

@@ -49,7 +49,7 @@ _definitions {
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please dont use it."
description: "System tags. This field is reserved for system use, please don't use it."
items {type: string}
}
default_output_destination {
@@ -159,7 +159,7 @@ _definitions {
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please dont use it."
description: "System tags. This field is reserved for system use, please don't use it."
items {type: string}
}
default_output_destination {
@@ -223,7 +223,7 @@ create {
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please dont use it."
description: "System tags. This field is reserved for system use, please don't use it."
items {type: string}
}
default_output_destination {
@@ -393,7 +393,7 @@ update {
}
system_tags {
type: array
description: "System tags. This field is reserved for system use, please dont use it."
description: "System tags. This field is reserved for system use, please don't use it."
items {type: string}
}
default_output_destination {

View File

@@ -0,0 +1,568 @@
{
_description: "Provides a management API for queues of tasks waiting to be executed by workers deployed anywhere (see Workers Service)."
_definitions {
queue_metrics {
type: object
properties: {
queue: {
type: string
description: "ID of the queue"
}
dates {
type: array
description: "List of timestamps (in seconds from epoch) in the acceding order. The timestamps are separated by the requested interval. Timestamps where no queue status change was recorded are omitted."
items { type: integer }
}
avg_waiting_times {
type: array
description: "List of average waiting times for tasks in the queue. The points correspond to the timestamps in the dates list. If more than one value exists for the given interval then the maximum value is taken."
items { type: number }
}
queue_lengths {
type: array
description: "List of tasks counts in the queue. The points correspond to the timestamps in the dates list. If more than one value exists for the given interval then the count that corresponds to the maximum average value is taken."
items { type: integer }
}
}
}
entry {
type: object
properties: {
task {
description: "Queued task ID"
type: string
}
added {
description: "Time this entry was added to the queue"
type: string
format: "date-time"
}
}
}
queue {
type: object
properties: {
id {
description: "Queue id"
type: string
}
name {
description: "Queue name"
type: string
}
user {
description: "Associated user id"
type: string
}
company {
description: "Company id"
type: string
}
created {
description: "Queue creation time"
type: string
format: "date-time"
}
tags {
description: "User-defined tags"
type: array
items { type: string }
}
system_tags {
description: "System tags. This field is reserved for system use, please don't use it."
type: array
items { type: string }
}
entries {
description: "List of ordered queue entries"
type: array
items { "$ref": "#/definitions/entry" }
}
}
}
}
get_by_id {
"2.4" {
description: "Gets queue information"
request {
type: object
required: [ queue ]
properties {
queue {
description: "Queue ID"
type: string
}
}
}
response {
type: object
properties {
queue {
description: "Queue info"
"$ref": "#/definitions/queue"
}
}
}
}
}
// typescript generation hack
get_all_ex {
internal: true
"2.4": ${get_all."2.4"}
}
get_all {
"2.4" {
description: "Get all queues"
request {
type: object
properties {
name {
description: "Get only queues whose name matches this pattern (python regular expression syntax)"
type: string
}
id {
description: "List of Queue IDs used to filter results"
type: array
items { type: string }
}
tags {
description: "User-defined tags list used to filter results. Prepend '-' to tag name to indicate exclusion"
type: array
items { type: string }
}
system_tags {
description: "System tags list used to filter results. Prepend '-' to system tag name to indicate exclusion"
type: array
items { type: string }
}
page {
description: "Page number, returns a specific page out of the result list of results."
type: integer
minimum: 0
}
page_size {
description: "Page size, specifies the number of results returned in each page (last page may contain fewer results)"
type: integer
minimum: 1
}
order_by {
description: "List of field names to order by. When search_text is used, '@text_score' can be used as a field representing the text score of returned documents. Use '-' prefix to specify descending order. Optional, recommended when using page"
type: array
items { type: string }
}
search_text {
description: "Free text search query"
type: string
}
only_fields {
description: "List of document field names (nesting is supported using '.', e.g. execution.model_labels). If provided, this list defines the query's projection (only these fields will be returned for each result entry)"
type: array
items { type: string }
}
}
}
response {
type: object
properties {
queues {
description: "Queues list"
type: array
items { "$ref": "#/definitions/queue"}
}
}
}
}
}
get_default {
"2.4" {
description: ""
request {
type: object
properties {}
additionalProperties: false
}
response {
type: object
properties {
id {
description: "Queue id"
type: string
}
name {
description: "Queue name"
type: string
}
}
}
}
}
create {
"2.4" {
description: "Create a new queue"
request {
type: object
required: [ name ]
properties {
name {
description: "Queue name Unique within the company."
type: string
}
tags {
description: "User-defined tags list"
type: array
items { type: string }
}
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
items { type: string }
}
}
}
response {
type: object
properties {
id {
description: "New queue ID"
type: string
}
}
}
}
}
update {
"2.4" {
description: "Update queue information"
request {
type: object
required: [ queue ]
properties {
queue {
description: "Queue id"
type: string
}
name {
description: "Queue name Unique within the company."
type: string
}
tags {
description: "User-defined tags list"
type: array
items { type: string }
}
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
items { type: string }
}
}
}
response {
type: object
properties {
updated {
description: "Number of queues updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
}
}
delete {
"2.4" {
description: "Deletes a queue. If the queue is not empty and force is not set to true, queue will not be deleted."
request {
type: object
required: [ queue ]
properties {
queue {
description: "Queue id"
type: string
}
force {
description: "Force delete of non-empty queue. Defaults to false"
type: boolean
default: false
}
}
}
response {
type: object
properties {
deleted {
description: "Number of queues deleted (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
}
}
}
}
add_task {
"2.4" {
description: "Adds a task entry to the queue."
request {
type: object
required: [
queue
task
]
properties {
queue {
description: "Queue id"
type: string
}
task {
description: "Task id"
type: string
}
}
}
response {
type: object
properties {
added {
description: "Number of tasks added (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
}
}
}
}
get_next_task {
"2.4" {
description: "Gets the next task from the top of the queue (FIFO). The task entry is removed from the queue."
request {
type: object
required: [ queue ]
properties {
queue {
description: "Queue id"
type: string
}
}
}
response {
type: object
properties {
entry {
description: "Entry information"
"$ref": "#/definitions/entry"
}
}
}
}
}
remove_task {
"2.4" {
description: "Removes a task entry from the queue."
request {
type: object
required: [
queue
task
]
properties {
queue {
description: "Queue id"
type: string
}
task {
description: "Task id"
type: string
}
}
}
response {
type: object
properties {
removed {
description: "Number of tasks removed (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
}
}
}
}
move_task_forward: {
"2.4" {
description: "Moves a task entry one step forward towards the top of the queue."
request {
type: object
required: [
queue
task
]
properties {
queue {
description: "Queue id"
type: string
}
task {
description: "Task id"
type: string
}
count {
description: "Number of positions in the queue to move the task forward relative to the current position. Optional, the default value is 1."
type: integer
}
}
}
response {
type: object
properties {
position {
description: "The new position of the task entry in the queue (index, -1 represents bottom of queue)"
type: integer
}
}
}
}
}
move_task_backward: {
"2.4" {
description: ""
request {
type: object
required: [
queue
task
]
properties {
queue {
description: "Queue id"
type: string
}
task {
description: "Task id"
type: string
}
count {
description: "Number of positions in the queue to move the task forward relative to the current position. Optional, the default value is 1."
type: integer
}
}
}
response {
type: object
properties {
position {
description: "The new position of the task entry in the queue (index, -1 represents bottom of queue)"
type: integer
}
}
}
}
}
move_task_to_front: {
"2.4" {
description: ""
request {
type: object
required: [
queue
task
]
properties {
queue {
description: "Queue id"
type: string
}
task {
description: "Task id"
type: string
}
}
}
response {
type: object
properties {
position {
description: "The new position of the task entry in the queue (index, -1 represents bottom of queue)"
type: integer
}
}
}
}
}
move_task_to_back: {
"2.4" {
description: ""
request {
type: object
required: [
queue
task
]
properties {
queue {
description: "Queue id"
type: string
}
task {
description: "Task id"
type: string
}
}
}
response {
type: object
properties {
position {
description: "The new position of the task entry in the queue (index, -1 represents bottom of queue)"
type: integer
}
}
}
}
}
get_queue_metrics : {
"2.4" {
description: "Returns metrics of the company queues. The metrics are avaraged in the specified interval."
request {
type: object
required: [from_date, to_date, interval]
properties: {
from_date {
description: "Starting time (in seconds from epoch) for collecting metrics"
type: number
}
to_date {
description: "Ending time (in seconds from epoch) for collecting metrics"
type: number
}
interval {
description: "Time interval in seconds for a single metrics point. The minimal value is 1"
type: integer
}
queue_ids {
description: "List of queue ids to collect metrics for. If not provided or empty then all then average metrics across all the company queues will be returned."
type: array
items { type: string }
}
}
}
response {
type: object
properties: {
queues {
type: array
description: "List of the requested queues with their metrics. If no queue ids were requested then 'all' queue is returned with the metrics averaged accross all the company queues."
items { "$ref": "#/definitions/queue_metrics" }
}
}
}
}
}
}

View File

@@ -3,6 +3,25 @@ _default {
internal: true
allow_roles: ["root", "system"]
}
get_stats {
"2.1" {
description: "Get the server collected statistics."
request {
type: object
properties {
interval {
description: "The period for statistics collection in seconds."
type: long
}
}
}
response {
type: object
properties: {
}
}
}
}
config {
"2.1" {
description: "Get server configuration. Secure section is not returned."
@@ -66,3 +85,39 @@ endpoints {
}
}
}
report_stats_option {
"2.4" {
description: "Get or set the report statistics option per-company"
request {
type: object
properties {
enabled {
description: "If provided, sets the report statistics option (true/false)"
type: boolean
}
}
}
response {
type: object
properties {
enabled {
description: "Returns the current report stats option value"
type: boolean
}
enabled_time {
description: "If enabled, returns the time at which option was enabled"
type: string
format: date-time
}
enabled_version {
description: "If enabled, returns the server version at the time option was enabled"
type: string
}
enabled_user {
description: "If enabled, returns Id of the user who enabled the option"
type: string
}
}
}
}
}

View File

@@ -193,6 +193,10 @@ _definitions {
execution {
type: object
properties {
queue {
description: "Queue ID where task was queued."
type: string
}
parameters {
description: "Json object containing the Task parameters"
type: object
@@ -219,6 +223,10 @@ _definitions {
description: """Framework related to the task. Case insensitive. Mandatory for Training tasks. """
type: string
}
docker_cmd {
description: "Command for running docker script for the execution of the task"
type: string
}
artifacts {
description: "Task artifacts"
type: array
@@ -230,6 +238,7 @@ _definitions {
type: string
enum: [
created
queued
in_progress
stopped
published
@@ -348,14 +357,14 @@ _definitions {
"$ref": "#/definitions/script"
}
tags {
description: "User-defined tags list"
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
description: "System tags. This field is reserved for system use, please dont use it."
items {type: string}
items { type: string }
}
status_changed {
description: "Last status change time"
@@ -375,6 +384,15 @@ _definitions {
type: string
format: "date-time"
}
last_worker {
description: "ID of last worker that handled the task"
type: string
}
last_worker_report {
description: "Last time a worker reported while working on this task"
type: string
format: "date-time"
}
last_update {
description: "Last time this task was created, updated, changed or events for this task were reported"
type: string
@@ -467,12 +485,12 @@ get_all {
items { type: string }
}
tags {
description: "User-defined tags list used to filter results. Prepend '-' to tag name to indicate exclusion"
description: "List of task user-defined tags. Use '-' prefix to exclude tags"
type: array
items { type: string }
}
system_tags {
description: "System tags list used to filter results. Prepend '-' to system tag name to indicate exclusion"
description: "List of task system tags. Use '-' prefix to exclude system tags"
type: array
items { type: string }
}
@@ -547,14 +565,14 @@ create {
type: string
}
tags {
description: "User-defined tags list"
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
description: "System tags. This field is reserved for system use, please dont use it."
items {type: string}
items { type: string }
}
type {
description: "Type of task"
@@ -612,14 +630,14 @@ validate {
type: string
}
tags {
description: "User-defined tags list"
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
description: "System tags. This field is reserved for system use, please dont use it."
items {type: string}
items { type: string }
}
type {
description: "Type of task"
@@ -675,14 +693,14 @@ update {
type: string
}
tags {
description: "User-defined tags list"
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
description: "System tags. This field is reserved for system use, please dont use it."
items {type: string}
items { type: string }
}
comment {
description: "Free text comment "
@@ -762,14 +780,14 @@ edit {
type: string
}
tags {
description: "User-defined tags list"
type: array
description: "User-defined tags"
items { type: string }
}
system_tags {
description: "System tags list. This field is reserved for system use, please don't use it."
type: array
description: "System tags. This field is reserved for system use, please dont use it."
items {type: string}
items { type: string }
}
type {
description: "Type of task"
@@ -838,6 +856,11 @@ reset {
type: array
items { type: string }
}
dequeued {
description: "Response from queues.remove_task"
type: object
additionalProperties: true
}
frames {
description: "Response from frames.rollback"
type: object
@@ -1114,7 +1137,85 @@ publish {
}
}
}
enqueue {
"1.5" {
description: """Adds a task into a queue.
Fails if task state is not 'created'.
Fails if the following parameters in the task were not filled:
* execution.script.repository
* execution.script.entrypoint
"""
request = {
type: object
required: [
task
]
properties {
queue {
description: "Queue id. If not provided, task is added to the default queue."
type: string
}
}
} ${_references.status_change_request}
response {
type: object
properties {
queued {
description: "Number of tasks queued (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
}
}
dequeue {
"1.5" {
description: """Remove a task from its queue.
Fails if task status is not queued."""
request = {
type: object
required: [
task
]
} ${_references.status_change_request}
response {
type: object
properties {
dequeued {
description: "Number of tasks dequeued (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
}
}
set_requirements {
"2.1" {
description: """Set the script requirements for a task"""

View File

@@ -0,0 +1,483 @@
{
_description: "Provides an API for worker machines, allowing workers to report status and get tasks for execution"
_definitions {
metrics_category {
type: object
properties {
name {
type: string
description: "Name of the metrics category."
}
metric_keys {
type: array
items { type: string }
description: "The names of the metrics in the category."
}
}
}
aggregation_type {
type: string
enum: [ avg, min, max ]
description: "Metric aggregation type"
}
stat_item {
type: object
properties {
key {
type: string
description: "Name of a metric"
}
category {
"$ref": "#/definitions/aggregation_type"
}
}
}
aggregation_stats {
type: object
properties {
aggregation {
"$ref": "#/definitions/aggregation_type"
}
values {
type: array
description: "List of values corresponding to the dates in metric statistics"
items { type: number }
}
}
}
metric_stats {
type: object
properties {
metric {
type: string
description: "Name of the metric ("cpu_usage", "memory_used" etc.)"
}
variant {
type: string
description: "Name of the metric component. Set only if 'split_by_variant' was set in the request"
}
dates {
type: array
description: "List of timestamps (in seconds from epoch) in the acceding order. The timestamps are separated by the requested interval. Timestamps where no workers activity was recorded are omitted."
items { type: integer }
}
stats {
type: array
description: "Statistics data by type"
items { "$ref": "#/definitions/aggregation_stats" }
}
}
}
worker_stats {
type: object
properties {
worker {
type: string
description: "ID of the worker"
}
metrics {
type: array
description: "List of the metrics statistics for the worker"
items { "$ref": "#/definitions/metric_stats" }
}
}
}
activity_series {
type: object
properties {
dates {
type: array
description: "List of timestamps (in seconds from epoch) in the acceding order. The timestamps are separated by the requested interval."
items {type: integer}
}
counts {
type: array
description: "List of worker counts corresponding to the timestamps in the dates list. None values are returned for the dates with no workers."
items {type: integer}
}
}
}
worker {
type: object
properties {
id {
description: "Worker ID"
type: string
}
user {
description: "Associated user (under whose credentials are used by the worker daemon)"
"$ref": "#/definitions/id_name_entry"
}
company {
description: "Associated company"
"$ref": "#/definitions/id_name_entry"
}
ip {
description: "IP of the worker"
type: string
}
register_time {
description: "Registration time"
type: string
format: "date-time"
}
last_activity_time {
description: "Last activity time (even if an error occurred)"
type: string
format: "date-time"
}
last_report_time {
description: "Last successful report time"
type: string
format: "date-time"
}
task {
description: "Task currently being run by the worker"
"$ref": "#/definitions/current_task_entry"
}
queue {
description: "Queue from which running task was taken"
"$ref": "#/definitions/queue_entry"
}
queues {
description: "List of queues on which the worker is listening"
type: array
items { "$ref": "#/definitions/queue_entry" }
}
}
}
id_name_entry {
type: object
properties {
id {
description: "Worker ID"
type: string
}
name {
description: "Worker name"
type: string
}
}
}
current_task_entry = ${_definitions.id_name_entry} {
properties {
running_time {
description: "Task running time"
type: integer
}
last_iteration {
description: "Last task iteration"
type: integer
}
}
}
queue_entry = ${_definitions.id_name_entry} {
properties {
next_task {
description: "Next task in the queue"
"$ref": "#/definitions/id_name_entry"
}
num_tasks {
description: "Number of task entries in the queue"
type: integer
}
}
}
machine_stats {
type: object
properties {
cpu_usage {
description: "Average CPU usage per core"
type: array
items { type: number }
}
gpu_usage {
description: "Average GPU usage per GPU card"
type: array
items { type: number }
}
memory_used {
description: "Used memory MBs"
type: integer
}
memory_free {
description: "Free memory MBs"
type: integer
}
gpu_memory_free {
description: "GPU free memory MBs"
type: array
items { type: integer }
}
gpu_memory_used {
description: "GPU used memory MBs"
type: array
items { type: integer }
}
network_tx {
description: "Mbytes per second"
type: integer
}
network_rx {
description: "Mbytes per second"
type: integer
}
disk_free_home {
description: "Mbytes free space of /home drive"
type: integer
}
disk_free_temp {
description: "Mbytes free space of /tmp drive"
type: integer
}
disk_read {
description: "Mbytes read per second"
type: integer
}
disk_write {
description: "Mbytes write per second"
type: integer
}
cpu_temperature {
description: "CPU temperature"
type: array
items { type: number }
}
gpu_temperature {
description: "GPU temperature"
type: array
items { type: number }
}
}
}
}
get_all {
"2.4" {
description: "Returns information on all registered workers."
request {
type: object
properties {
last_seen {
description: """Filter out workers not active for more than last_seen seconds.
A value or 0 or 'none' will disable the filter."""
type: integer
default: 3600
}
}
}
response {
type: object
properties {
workers {
type: array
items { "$ref": "#/definitions/worker" }
}
}
}
}
}
register {
"2.4" {
description: "Register a worker in the system. Called by the Worker Daemon."
request {
required: [ worker ]
type: object
properties {
worker {
description: "Worker id. Must be unique in company."
type: string
}
timeout {
description: "Registration timeout in seconds. If timeout seconds have passed since the worker's last call to register or status_report, the worker is automatically removed from the list of registered workers."
type: integer
default: 600
}
queues {
description: "List of queue IDs on which the worker is listening."
type: array
items { type: string }
}
}
}
response {
type: object
properties {}
}
}
}
unregister {
"2.4" {
description: "Unregister a worker in the system. Called by the Worker Daemon."
request {
required: [ worker ]
type: object
properties {
worker {
description: "Worker id. Must be unique in company."
type: string
}
}
}
response {
type: object
properties {}
}
}
}
status_report {
"2.4" {
description: "Called periodically by the worker daemon to report machine status"
request {
required: [
worker
timestamp
]
type: object
properties {
worker {
description: "Worker id."
type: string
}
task {
description: "ID of a task currently being run by the worker. If no task is sent, the worker's task field will be cleared."
type: string
}
queue {
description: "ID of the queue from which task was received. If no queue is sent, the worker's queue field will be cleared."
type: string
}
queues {
description: "List of queue IDs on which the worker is listening. If null, the worker's queues list will not be updated."
type: array
items { type: string }
}
timestamp {
description: "UNIX time in seconds since epoch."
type: integer
}
machine_stats {
description: "The machine statistics."
"$ref": "#/definitions/machine_stats"
}
}
}
response {
type: object
properties {}
}
}
}
get_metric_keys {
"2.4" {
description: "Returns worker statistics metric keys grouped by categories."
request {
type: object
properties {
worker_ids {
description: "List of worker ids to collect metrics for. If not provided or empty then all the company workers metrics are analyzed."
type: array
items { type: string }
}
}
}
response {
type: object
properties {
categories {
type: array
description: "List of unique metric categories found in the statistics of the requested workers."
items { "$ref": "#/definitions/metrics_category" }
}
}
}
}
}
get_stats {
"2.4" {
description: "Returns statistics for the selected workers and time range aggregated by date intervals."
request {
type: object
required: [ from_date, to_date, interval, items ]
properties {
worker_ids {
description: "List of worker ids to collect metrics for. If not provided or empty then all the company workers metrics are analyzed."
type: array
items { type: string }
}
from_date {
description: "Starting time (in seconds from epoch) for collecting statistics"
type: number
}
to_date {
description: "Ending time (in seconds from epoch) for collecting statistics"
type: number
}
interval {
description: "Time interval in seconds for a single statistics point. The minimal value is 1"
type: integer
}
items {
description: "List of metric keys and requested statistics"
type: array
items { "$ref": "#/definitions/stat_item" }
}
split_by_variant {
description: "If true then break statistics by hardware sub types"
type: boolean
default: false
}
}
}
response {
type: object
properties {
workers {
type: array
description: "List of the requested workers with their statistics"
items { "$ref": "#/definitions/worker_stats" }
}
}
}
}
}
get_activity_report {
"2.4" {
description: "Returns count of active company workers in the selected time range."
request {
type: object
required: [ from_date, to_date, interval ]
properties {
from_date {
description: "Starting time (in seconds from epoch) for collecting statistics"
type: number
}
to_date {
description: "Ending time (in seconds from epoch) for collecting statistics"
type: number
}
interval {
description: "Time interval in seconds for a single statistics point. The minimal value is 1"
type: integer
}
}
}
response {
type: object
properties {
total {
description: "Activity series that include all the workers that sent reports in the given time interval."
"$ref": "#/definitions/activity_series"
}
active {
description: "Activity series that include only workers that worked on a task in the given time interval."
"$ref": "#/definitions/activity_series"
}
}
}
}
}
}

View File

@@ -7,13 +7,15 @@ from werkzeug.exceptions import BadRequest
import database
from apierrors.base import BaseError
from bll.statistics.stats_reporter import StatisticsReporter
from config import config
from init_data import init_es_data, init_mongo_data
from service_repo import ServiceRepo, APICall
from service_repo.auth import AuthType
from service_repo.errors import PathParsingError
from timing_context import TimingContext
from updates import check_updates_thread
from utilities import json
from init_data import init_es_data, init_mongo_data
app = Flask(__name__, static_url_path="/static")
CORS(app, **config.get("apiserver.cors"))
@@ -35,6 +37,10 @@ ServiceRepo.load("services")
log.info(f"Exposed Services: {' '.join(ServiceRepo.endpoint_names())}")
check_updates_thread.start()
StatisticsReporter.start()
@app.before_first_request
def before_app_first_request():
pass
@@ -52,7 +58,9 @@ def before_request():
content, content_type = ServiceRepo.handle_call(call)
headers = {}
if call.result.filename:
headers["Content-Disposition"] = f"attachment; filename={call.result.filename}"
headers[
"Content-Disposition"
] = f"attachment; filename={call.result.filename}"
if call.result.headers:
headers.update(call.result.headers)
@@ -66,7 +74,9 @@ def before_request():
if value is None:
response.set_cookie(key, "", expires=0)
else:
response.set_cookie(key, value, **config.get("apiserver.auth.cookies"))
response.set_cookie(
key, value, **config.get("apiserver.auth.cookies")
)
return response
except Exception as ex:

View File

@@ -7,8 +7,7 @@ from jsonmodels import models
from six import string_types
import database
import timing_context
from timing_context import TimingContext
from timing_context import TimingContext, TimingStats
from utilities import json
from .auth import Identity
from .auth import Payload as AuthPayload
@@ -256,6 +255,7 @@ class MissingIdentity(Exception):
class APICall(DataContainer):
HEADER_AUTHORIZATION = "Authorization"
HEADER_REAL_IP = "X-Real-IP"
HEADER_FORWARDED_FOR = "X-Forwarded-For"
""" Standard headers """
_transaction_headers = ("X-Trains-Trx",)
@@ -306,8 +306,6 @@ class APICall(DataContainer):
):
super(APICall, self).__init__(data=data, batched_data=batched_data)
timing_context.clear()
self._id = database.utils.id()
self._files = files # currently dic of key to flask's FileStorage)
self._start_ts = time.time()
@@ -385,8 +383,13 @@ class APICall(DataContainer):
@property
def real_ip(self):
real_ip = self.get_header(self.HEADER_REAL_IP)
return real_ip or self._remote_addr or "untrackable"
""" Obtain visitor's IP address """
return (
self.get_header(self.HEADER_FORWARDED_FOR)
or self.get_header(self.HEADER_REAL_IP)
or self._remote_addr
or "untrackable"
)
@property
def failed(self):
@@ -508,7 +511,7 @@ class APICall(DataContainer):
def mark_end(self):
self._end_ts = time.time()
self._duration = int((self._end_ts - self._start_ts) * 1000)
self.stats = timing_context.stats()
self.stats = TimingStats.aggregate()
def get_response(self):
def make_version_number(version):

View File

@@ -62,6 +62,8 @@ def authorize_credentials(auth_data, service, action, call_data_items):
query = Q(credentials__match=Credentials(key=access_key, secret=secret_key))
fixed_user = None
if FixedUser.enabled():
fixed_user = FixedUser.get_by_username(access_key)
if fixed_user:
@@ -74,7 +76,7 @@ def authorize_credentials(auth_data, service, action, call_data_items):
if not user:
raise errors.unauthorized.InvalidCredentials('failed to locate provided credentials')
if not FixedUser.enabled():
if not fixed_user:
# In case these are proper credentials, update last used time
User.objects(id=user.id, credentials__key=access_key).update(
**{"set__credentials__$__last_used": datetime.utcnow()}

View File

@@ -19,7 +19,7 @@ class FixedUser:
self.user_id = hashlib.md5(f"{self.username}:{self.password}".encode()).hexdigest()
@classmethod
def enabled(self):
def enabled(cls):
return config.get("apiserver.auth.fixed_users.enabled", False)
@classmethod

View File

@@ -34,7 +34,7 @@ class ServiceRepo(object):
"""If the check is set, parsing will fail for endpoint request with the version that is grater than the current
maximum """
_max_version = PartialVersion("2.3")
_max_version = PartialVersion("2.4")
""" Maximum version number (the highest min_version value across all endpoints) """
_endpoint_exp = (

View File

@@ -20,18 +20,18 @@ event_bll = EventBLL()
@endpoint("events.add")
def add(call, company_id, req_model):
assert isinstance(call, APICall)
def add(call: APICall, company_id, req_model):
data = call.data.copy()
allow_locked = data.pop("allow_locked", False)
added, batch_errors = event_bll.add_events(
company_id, [call.data.copy()], call.worker
company_id, [data], call.worker, allow_locked_tasks=allow_locked
)
call.result.data = dict(added=added, errors=len(batch_errors))
call.kpis["events"] = 1
@endpoint("events.add_batch")
def add_batch(call, company_id, req_model):
assert isinstance(call, APICall)
def add_batch(call: APICall, company_id, req_model):
events = call.batched_data
if events is None or len(events) == 0:
raise errors.bad_request.BatchContainsNoItems()
@@ -209,8 +209,9 @@ def vector_metrics_iter_histogram(call, company_id, req_model):
@endpoint("events.get_task_events", required_fields=["task"])
def get_task_events(call, company_id, req_model):
def get_task_events(call, company_id, _):
task_id = call.data["task"]
batch_size = call.data.get("batch_size")
event_type = call.data.get("event_type")
scroll_id = call.data.get("scroll_id")
order = call.data.get("order") or "asc"
@@ -222,6 +223,7 @@ def get_task_events(call, company_id, req_model):
sort=[{"timestamp": {"order": order}}],
event_type=event_type,
scroll_id=scroll_id,
size=batch_size,
)
call.result.data = dict(
@@ -302,7 +304,11 @@ def multi_task_scalar_metrics_iter_histogram(
# Note, bll already validates task ids as it needs their names
call.result.data = dict(
metrics=event_bll.metrics.compare_scalar_metrics_average_per_iter(
company_id, task_ids=task_ids, samples=req_model.samples, allow_public=True, key=req_model.key
company_id,
task_ids=task_ids,
samples=req_model.samples,
allow_public=True,
key=req_model.key,
)
)
@@ -388,7 +394,7 @@ def get_task_plots_v1_7(call, company_id, req_model):
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
# events, next_scroll_id, total_events = event_bll.get_task_events(
# company_id, task_id,
# company, task_id,
# event_type="plot",
# sort=[{"iter": {"order": "desc"}}],
# last_iter_count=iters,
@@ -423,7 +429,6 @@ def get_task_plots(call, company_id, req_model):
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
result = event_bll.get_task_plots(
company_id,
tasks=[task_id],
sort=[{"iter": {"order": "desc"}}],
last_iterations_per_plot=iters,
@@ -448,7 +453,7 @@ def get_debug_images_v1_7(call, company_id, req_model):
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
# events, next_scroll_id, total_events = event_bll.get_task_events(
# company_id, task_id,
# company, task_id,
# event_type="training_debug_image",
# sort=[{"iter": {"order": "desc"}}],
# last_iter_count=iters,
@@ -505,9 +510,14 @@ def get_debug_images(call, company_id, req_model):
@endpoint("events.delete_for_task", required_fields=["task"])
def delete_for_task(call, company_id, req_model):
task_id = call.data["task"]
allow_locked = call.data.get("allow_locked", False)
task_bll.assert_exists(company_id, task_id)
call.result.data = dict(deleted=event_bll.delete_task_events(company_id, task_id))
call.result.data = dict(
deleted=event_bll.delete_task_events(
company_id, task_id, allow_locked=allow_locked
)
)
def _get_top_iter_unique_events_per_task(events, max_iters, tasks):

View File

@@ -277,6 +277,25 @@ def prepare_update_fields(call, fields):
if "task" in fields:
validate_task(call, fields)
if "labels" in fields:
labels = fields["labels"]
def find_other_types(iterable, type_):
res = [x for x in iterable if not isinstance(x, type_)]
try:
return set(res)
except TypeError:
# Un-hashable, probably
return res
invalid_keys = find_other_types(labels.keys(), str)
if invalid_keys:
raise errors.bad_request.ValidationError("labels keys must be strings", keys=invalid_keys)
invalid_values = find_other_types(labels.values(), int)
if invalid_values:
raise errors.bad_request.ValidationError("labels values must be integers", values=invalid_values)
conform_tag_fields(call, fields)
return fields

View File

@@ -202,7 +202,7 @@ def get_all_ex(call: APICall):
status_count = defaultdict(lambda: {})
key = itemgetter(EntityVisibility.archived.value)
for result in Task.objects.aggregate(*status_count_pipeline):
for result in Task.aggregate(*status_count_pipeline):
for k, group in groupby(sorted(result["counts"], key=key), key):
section = (
EntityVisibility.archived if k else EntityVisibility.active
@@ -216,7 +216,7 @@ def get_all_ex(call: APICall):
runtime = {
result["_id"]: {k: v for k, v in result.items() if k != "_id"}
for result in Task.objects.aggregate(*runtime_pipeline)
for result in Task.aggregate(*runtime_pipeline)
}
def safe_get(obj, path, default=None):

218
server/services/queues.py Normal file
View File

@@ -0,0 +1,218 @@
from apimodels.base import UpdateResponse
from apimodels.queues import (
GetDefaultResp,
CreateRequest,
DeleteRequest,
UpdateRequest,
MoveTaskRequest,
MoveTaskResponse,
TaskRequest,
QueueRequest,
GetMetricsRequest,
GetMetricsResponse,
QueueMetrics,
)
from bll.queue import QueueBLL
from bll.util import extract_properties_to_lists
from bll.workers import WorkerBLL
from service_repo import APICall, endpoint
from services.utils import conform_tag_fields, conform_output_tags, conform_tags
worker_bll = WorkerBLL()
queue_bll = QueueBLL(worker_bll)
@endpoint("queues.get_by_id", min_version="2.4", request_data_model=QueueRequest)
def get_by_id(call: APICall, company_id, req_model: QueueRequest):
queue = queue_bll.get_by_id(company_id, req_model.queue)
queue_dict = queue.to_proper_dict()
conform_output_tags(call, queue_dict)
call.result.data = {"queue": queue_dict}
@endpoint("queues.get_default", min_version="2.4", response_data_model=GetDefaultResp)
def get_by_id(call: APICall):
queue = queue_bll.get_default(call.identity.company)
call.result.data_model = GetDefaultResp(id=queue.id, name=queue.name)
@endpoint("queues.get_all_ex", min_version="2.4")
def get_all_ex(call: APICall):
conform_tag_fields(call, call.data)
queues = queue_bll.get_queue_infos(
company_id=call.identity.company, query_dict=call.data
)
conform_output_tags(call, queues)
call.result.data = {"queues": queues}
@endpoint("queues.get_all", min_version="2.4")
def get_all(call: APICall):
conform_tag_fields(call, call.data)
queues = queue_bll.get_all(company_id=call.identity.company, query_dict=call.data)
conform_output_tags(call, queues)
call.result.data = {"queues": queues}
@endpoint("queues.create", min_version="2.4", request_data_model=CreateRequest)
def create(call: APICall, company_id, request: CreateRequest):
tags, system_tags = conform_tags(call, request.tags, request.system_tags)
queue = queue_bll.create(
company_id=company_id, name=request.name, tags=tags, system_tags=system_tags
)
call.result.data = {"id": queue.id}
@endpoint(
"queues.update",
min_version="2.4",
request_data_model=UpdateRequest,
response_data_model=UpdateResponse,
)
def update(call: APICall, company_id, req_model: UpdateRequest):
data = call.data_model_for_partial_update
conform_tag_fields(call, data)
updated, fields = queue_bll.update(
company_id=company_id, queue_id=req_model.queue, **data
)
conform_output_tags(call, fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
@endpoint("queues.delete", min_version="2.4", request_data_model=DeleteRequest)
def delete(call: APICall, company_id, req_model: DeleteRequest):
queue_bll.delete(
company_id=company_id, queue_id=req_model.queue, force=req_model.force
)
call.result.data = {"deleted": 1}
@endpoint("queues.add_task", min_version="2.4", request_data_model=TaskRequest)
def add_task(call: APICall, company_id, req_model: TaskRequest):
call.result.data = {
"added": queue_bll.add_task(
company_id=company_id, queue_id=req_model.queue, task_id=req_model.task
)
}
@endpoint("queues.get_next_task", min_version="2.4", request_data_model=QueueRequest)
def get_next_task(call: APICall, company_id, req_model: QueueRequest):
task = queue_bll.get_next_task(company_id=company_id, queue_id=req_model.queue)
if task:
call.result.data = {"entry": task.to_proper_dict()}
@endpoint("queues.remove_task", min_version="2.4", request_data_model=TaskRequest)
def remove_task(call: APICall, company_id, req_model: TaskRequest):
call.result.data = {
"removed": queue_bll.remove_task(
company_id=company_id, queue_id=req_model.queue, task_id=req_model.task
)
}
@endpoint(
"queues.move_task_forward",
min_version="2.4",
request_data_model=MoveTaskRequest,
response_data_model=MoveTaskResponse,
)
def move_task_forward(call: APICall, company_id, req_model: MoveTaskRequest):
call.result.data_model = MoveTaskResponse(
position=queue_bll.reposition_task(
company_id=company_id,
queue_id=req_model.queue,
task_id=req_model.task,
pos_func=lambda p: max(0, p - req_model.count),
)
)
@endpoint(
"queues.move_task_backward",
min_version="2.4",
request_data_model=MoveTaskRequest,
response_data_model=MoveTaskResponse,
)
def move_task_backward(call: APICall, company_id, req_model: MoveTaskRequest):
call.result.data_model = MoveTaskResponse(
position=queue_bll.reposition_task(
company_id=company_id,
queue_id=req_model.queue,
task_id=req_model.task,
pos_func=lambda p: max(0, p + req_model.count),
)
)
@endpoint(
"queues.move_task_to_front",
min_version="2.4",
request_data_model=TaskRequest,
response_data_model=MoveTaskResponse,
)
def move_task_to_front(call: APICall, company_id, req_model: TaskRequest):
call.result.data_model = MoveTaskResponse(
position=queue_bll.reposition_task(
company_id=company_id,
queue_id=req_model.queue,
task_id=req_model.task,
pos_func=lambda p: 0,
)
)
@endpoint(
"queues.move_task_to_back",
min_version="2.4",
request_data_model=TaskRequest,
response_data_model=MoveTaskResponse,
)
def move_task_to_back(call: APICall, company_id, req_model: TaskRequest):
call.result.data_model = MoveTaskResponse(
position=queue_bll.reposition_task(
company_id=company_id,
queue_id=req_model.queue,
task_id=req_model.task,
pos_func=lambda p: -1,
)
)
@endpoint(
"queues.get_queue_metrics",
min_version="2.4",
request_data_model=GetMetricsRequest,
response_data_model=GetMetricsResponse,
)
def get_queue_metrics(
call: APICall, company_id, req_model: GetMetricsRequest
) -> GetMetricsResponse:
ret = queue_bll.metrics.get_queue_metrics(
company_id=company_id,
from_date=req_model.from_date,
to_date=req_model.to_date,
interval=req_model.interval,
queue_ids=req_model.queue_ids,
)
queue_dicts = {
queue: extract_properties_to_lists(
["date", "avg_waiting_time", "queue_length"], data
)
for queue, data in ret.items()
}
return GetMetricsResponse(
queues=[
QueueMetrics(
queue=queue,
dates=data["date"],
avg_waiting_times=data["avg_waiting_time"],
queue_lengths=data["queue_length"],
) if data else QueueMetrics(queue=queue)
for queue, data in queue_dicts.items()
]
)

View File

@@ -1,8 +1,24 @@
from datetime import datetime
from pyhocon.config_tree import NoneValue
from apierrors import errors
from apimodels.server import ReportStatsOptionRequest, ReportStatsOptionResponse
from bll.statistics.stats_reporter import StatisticsReporter
from config import config
from config.info import get_version, get_build_number, get_commit_number
from database.errors import translate_errors_context
from database.model import Company
from database.model.company import ReportStatsOption
from service_repo import ServiceRepo, APICall, endpoint
from version import __version__ as current_version
@endpoint("server.get_stats")
def get_stats(call: APICall):
call.result.data = StatisticsReporter.get_statistics(
company_id=call.identity.company
)
@endpoint("server.config")
@@ -43,3 +59,35 @@ def info(call: APICall):
"build": get_build_number(),
"commit": get_commit_number(),
}
@endpoint(
"server.report_stats_option",
request_data_model=ReportStatsOptionRequest,
response_data_model=ReportStatsOptionResponse,
)
def report_stats(call: APICall, company: str, request: ReportStatsOptionRequest):
if not StatisticsReporter.supported:
result = ReportStatsOptionResponse(supported=False)
else:
enabled = request.enabled
with translate_errors_context():
query = Company.objects(id=company)
if enabled is None:
stats_option = query.first().defaults.stats_option
else:
stats_option = ReportStatsOption(
enabled=enabled,
enabled_time=datetime.utcnow(),
enabled_version=current_version,
enabled_user=call.identity.user,
)
updated = query.update(defaults__stats_option=stats_option)
if not updated:
raise errors.server_error.InternalError(
f"Failed setting report_stats to {enabled}"
)
result = ReportStatsOptionResponse(**stats_option.to_mongo())
call.result.data_model = result

View File

@@ -11,7 +11,7 @@ from mongoengine import EmbeddedDocument, Q
from mongoengine.queryset.transform import COMPARISON_OPERATORS
from pymongo import UpdateOne
from apierrors import errors
from apierrors import errors, APIError
from apimodels.base import UpdateResponse
from apimodels.tasks import (
StartedResponse,
@@ -24,13 +24,24 @@ from apimodels.tasks import (
TaskRequest,
DeleteRequest,
PingRequest,
EnqueueRequest,
EnqueueResponse,
DequeueResponse,
)
from bll.event import EventBLL
from bll.queue import QueueBLL
from bll.task import TaskBLL, ChangeStatusRequest, update_project_time, split_by
from bll.util import SetFieldsResolver
from database.errors import translate_errors_context
from database.model.model import Model
from database.model.task.output import Output
from database.model.task.task import Task, TaskStatus, Script, DEFAULT_LAST_ITERATION
from database.model.task.task import (
Task,
TaskStatus,
Script,
DEFAULT_LAST_ITERATION,
Execution,
)
from database.utils import get_fields, parse_from_call
from service_repo import APICall, endpoint
from services.utils import conform_tag_fields, conform_output_tags
@@ -48,20 +59,23 @@ get_all_query_options = Task.QueryParameterOptions(
task_bll = TaskBLL()
event_bll = EventBLL()
queue_bll = QueueBLL()
TaskBLL.start_non_responsive_tasks_watchdog()
def set_task_status_from_call(
request: UpdateRequest, company_id, new_status=None, **kwargs
request: UpdateRequest, company_id, new_status=None, **set_fields
) -> dict:
fields_resolver = SetFieldsResolver(set_fields)
task = TaskBLL.get_task_with_access(
request.task,
company_id=company_id,
only=("status", "project"),
only=tuple({"status", "project"} | fields_resolver.get_names()),
requires_write_access=True,
)
status_reason = request.status_reason
status_message = request.status_message
force = request.force
@@ -71,7 +85,7 @@ def set_task_status_from_call(
status_reason=status_reason,
status_message=status_message,
force=force,
).execute(**kwargs)
).execute(**fields_resolver.get_fields(task))
@endpoint("tasks.get_by_id", request_data_model=TaskRequest)
@@ -94,7 +108,6 @@ def get_all_ex(call: APICall):
query_dict=call.data,
query_options=get_all_query_options,
allow_public=True, # required in case projection is requested for public dataset/versions
override_none_ordering=True,
)
conform_output_tags(call, tasks)
call.result.data = {"tasks": tasks}
@@ -111,7 +124,6 @@ def get_all(call: APICall):
query_dict=call.data,
query_options=get_all_query_options,
allow_public=True, # required in case projection is requested for public dataset/versions
override_none_ordering=True,
)
conform_output_tags(call, tasks)
call.result.data = {"tasks": tasks}
@@ -167,7 +179,7 @@ def started(call: APICall, company_id, req_model: UpdateRequest):
req_model,
company_id,
new_status=TaskStatus.in_progress,
started=datetime.utcnow(),
min__started=datetime.utcnow(), # don't override a previous, smaller "started" field value
)
)
res.started = res.updated
@@ -443,6 +455,125 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
call.result.data_model = UpdateResponse(updated=0)
@endpoint(
"tasks.enqueue",
request_data_model=EnqueueRequest,
response_data_model=EnqueueResponse,
)
def enqueue(call: APICall, company_id, req_model: EnqueueRequest):
task_id = req_model.task
queue_id = req_model.queue
status_message = req_model.status_message
status_reason = req_model.status_reason
if not queue_id:
# try to get default queue
queue_id = queue_bll.get_default(company_id).id
with translate_errors_context():
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(
_only=("type", "script", "execution", "status", "project", "id"), **query
)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
res = EnqueueResponse(
**ChangeStatusRequest(
task=task,
new_status=TaskStatus.queued,
status_reason=status_reason,
status_message=status_message,
allow_same_state_transition=False,
).execute()
)
try:
queue_bll.add_task(
company_id=company_id, queue_id=queue_id, task_id=task.id
)
except Exception:
# failed enqueueing, revert to previous state
ChangeStatusRequest(
task=task,
current_status_override=TaskStatus.queued,
new_status=task.status,
force=True,
status_reason="failed enqueueing",
).execute()
raise
# set the current queue ID in the task
if task.execution:
Task.objects(**query).update(execution__queue=queue_id, multi=False)
else:
Task.objects(**query).update(
execution=Execution(queue=queue_id), multi=False
)
res.queued = 1
res.fields.update(**{"execution.queue": queue_id})
call.result.data_model = res
@endpoint(
"tasks.dequeue",
request_data_model=UpdateRequest,
response_data_model=DequeueResponse,
)
def dequeue(call: APICall, company_id, req_model: UpdateRequest):
task = TaskBLL.get_task_with_access(
req_model.task,
company_id=company_id,
only=("id", "execution", "status", "project"),
requires_write_access=True,
)
if task.status not in (TaskStatus.queued,):
raise errors.bad_request.InvalidTaskId(
status=task.status, expected=TaskStatus.queued
)
_dequeue(task, company_id)
status_message = req_model.status_message
status_reason = req_model.status_reason
res = DequeueResponse(
**ChangeStatusRequest(
task=task,
new_status=TaskStatus.created,
status_reason=status_reason,
status_message=status_message,
).execute(unset__execution__queue=1)
)
res.dequeued = 1
call.result.data_model = res
def _dequeue(task: Task, company_id: str, silent_fail=False):
"""
Dequeue the task from the queue
:param task: task to dequeue
:param silent_fail: do not throw exceptions. APIError is still thrown
:raise errors.bad_request.MissingRequiredFields: if the task is not queued
:raise APIError or errors.server_error.TransactionError: if internal call to queues.remove_task fails
:return: the result of queues.remove_task call. None in case of silent failure
"""
if not task.execution or not task.execution.queue:
if silent_fail:
return
raise errors.bad_request.MissingRequiredFields(
"task has no queue value", field="execution.queue"
)
return {
"removed": queue_bll.remove_task(
company_id=company_id, queue_id=task.execution.queue, task_id=task.id
)
}
@endpoint(
"tasks.reset", request_data_model=UpdateRequest, response_data_model=ResetResponse
)
@@ -459,6 +590,16 @@ def reset(call: APICall, company_id, req_model: UpdateRequest):
api_results = {}
updates = {}
try:
dequeued = _dequeue(task, company_id, silent_fail=True)
except APIError:
# dequeue may fail if the task was not enqueued
pass
else:
if dequeued:
api_results.update(dequeued=dequeued)
updates.update(unset__execution__queue=1)
cleaned_up = cleanup_task(task, force)
api_results.update(attr.asdict(cleaned_up))

View File

@@ -1,4 +1,4 @@
from typing import Union, Sequence
from typing import Union, Sequence, Tuple
from database.utils import partition_tags
from service_repo import APICall
@@ -6,6 +6,9 @@ from service_repo.base import PartialVersion
def conform_output_tags(call: APICall, documents: Union[dict, Sequence[dict]]):
"""
For old clients both tags and system tags are returned in 'tags' field
"""
if call.requested_endpoint_version >= PartialVersion("2.3"):
return
if isinstance(documents, dict):
@@ -17,36 +20,44 @@ def conform_output_tags(call: APICall, documents: Union[dict, Sequence[dict]]):
def conform_tag_fields(call: APICall, document: dict):
"""
Upgrade old client tags in place
"""
if "tags" in document:
tags, system_tags = conform_tags(
call, document["tags"], document.get("system_tags")
)
if tags != document.get("tags"):
document["tags"] = tags
if system_tags != document.get("system_tags"):
document["system_tags"] = system_tags
def conform_tags(
call: APICall, tags: Sequence, system_tags: Sequence
) -> Tuple[Sequence, Sequence]:
"""
Make sure that 'tags' from the old SDK clients
are correctly split into 'tags' and 'system_tags'
Make sure that there are no duplicate tags
"""
if call.requested_endpoint_version < PartialVersion("2.3"):
tags, system_tags = _upgrade_tags(call, tags, system_tags)
return _get_unique_values(tags), _get_unique_values(system_tags)
def _upgrade_tags(call: APICall, tags: Sequence, system_tags: Sequence):
if tags is not None and not system_tags:
service_name = call.endpoint_name.partition(".")[0]
upgrade_tags(
service_name[:-1] if service_name.endswith("s") else service_name, document
)
remove_duplicate_tags(document)
entity = service_name[:-1] if service_name.endswith("s") else service_name
return partition_tags(entity, tags)
return tags, system_tags
def upgrade_tags(entity: str, document: dict):
"""
If only 'tags' is present in the fields then extract
the system tags from it to a separate field 'system_tags'
"""
tags = document.get("tags")
if tags is not None and not document.get("system_tags"):
user_tags, system_tags = partition_tags(entity, tags)
document["tags"] = user_tags
document["system_tags"] = system_tags
def _get_unique_values(values: Sequence) -> Sequence:
"""Get unique values from the given sequence"""
if not values:
return values
def remove_duplicate_tags(document: dict):
"""
Remove duplicates from 'tags' and 'system_tags' fields
"""
for name in ("tags", "system_tags"):
values = document.get(name)
if values:
document[name] = list(set(values))
return list(set(values))

202
server/services/workers.py Normal file
View File

@@ -0,0 +1,202 @@
import itertools
from operator import attrgetter
from typing import Optional, Sequence, Union
from boltons.iterutils import bucketize
from apierrors.errors import bad_request
from apimodels.workers import (
WorkerRequest,
StatusReportRequest,
GetAllRequest,
GetAllResponse,
RegisterRequest,
GetStatsRequest,
MetricCategory,
GetMetricKeysRequest,
GetMetricKeysResponse,
GetStatsResponse,
WorkerStatistics,
MetricStats,
AggregationStats,
GetActivityReportRequest,
GetActivityReportResponse,
ActivityReportSeries,
)
from bll.util import extract_properties_to_lists
from bll.workers import WorkerBLL
from config import config
from service_repo import APICall, endpoint
log = config.logger(__file__)
worker_bll = WorkerBLL()
@endpoint(
"workers.get_all",
min_version="2.4",
request_data_model=GetAllRequest,
response_data_model=GetAllResponse,
)
def get_all(call: APICall, company_id: str, request: GetAllRequest):
call.result.data_model = GetAllResponse(
workers=worker_bll.get_all_with_projection(company_id, request.last_seen)
)
@endpoint("workers.register", min_version="2.4", request_data_model=RegisterRequest)
def register(call: APICall, company_id, req_model: RegisterRequest):
worker = req_model.worker
timeout = req_model.timeout
queues = req_model.queues
if not timeout or timeout <= 0:
raise bad_request.WorkerRegistrationFailed(
"invalid timeout", timeout=timeout, worker=worker
)
worker_bll.register_worker(
company_id=company_id,
user_id=call.identity.user,
worker=worker,
ip=call.real_ip,
queues=queues,
timeout=timeout,
)
@endpoint("workers.unregister", min_version="2.4", request_data_model=WorkerRequest)
def unregister(call: APICall, company_id, req_model: WorkerRequest):
worker_bll.unregister_worker(company_id, call.identity.user, req_model.worker)
@endpoint("workers.status_report", min_version="2.4", request_data_model=StatusReportRequest)
def status_report(call: APICall, company_id, request: StatusReportRequest):
worker_bll.status_report(
company_id=company_id,
user_id=call.identity.user,
ip=call.real_ip,
report=request,
)
@endpoint(
"workers.get_metric_keys",
min_version="2.4",
request_data_model=GetMetricKeysRequest,
response_data_model=GetMetricKeysResponse,
validate_schema=True,
)
def get_metric_keys(
call: APICall, company_id, req_model: GetMetricKeysRequest
) -> GetMetricKeysResponse:
ret = worker_bll.stats.get_worker_stats_keys(
company_id, worker_ids=req_model.worker_ids
)
return GetMetricKeysResponse(
categories=[MetricCategory(name=k, metric_keys=v) for k, v in ret.items()]
)
@endpoint(
"workers.get_activity_report",
min_version="2.4",
request_data_model=GetActivityReportRequest,
response_data_model=GetActivityReportResponse,
validate_schema=True,
)
def get_activity_report(
call: APICall, company_id, req_model: GetActivityReportRequest
) -> GetActivityReportResponse:
def get_activity_series(active_only: bool = False) -> ActivityReportSeries:
ret = worker_bll.stats.get_activity_report(
company_id=company_id,
from_date=req_model.from_date,
to_date=req_model.to_date,
interval=req_model.interval,
active_only=active_only,
)
if not ret:
return ActivityReportSeries(dates=[], counts=[])
count_by_date = extract_properties_to_lists(["date", "count"], ret)
return ActivityReportSeries(
dates=count_by_date["date"], counts=count_by_date["count"]
)
return GetActivityReportResponse(
total=get_activity_series(), active=get_activity_series(active_only=True)
)
@endpoint(
"workers.get_stats",
min_version="2.4",
request_data_model=GetStatsRequest,
response_data_model=GetStatsResponse,
validate_schema=True,
)
def get_stats(call: APICall, company_id, request: GetStatsRequest):
ret = worker_bll.stats.get_worker_stats(company_id, request)
def _get_variant_metric_stats(
metric: str,
agg_names: Sequence[str],
stats: Sequence[dict],
variant: Optional[str] = None,
) -> MetricStats:
stat_by_name = extract_properties_to_lists(agg_names, stats)
return MetricStats(
metric=metric,
variant=variant,
dates=stat_by_name["date"],
stats=[
AggregationStats(aggregation=name, values=aggs)
for name, aggs in stat_by_name.items()
if name != "date"
],
)
def _get_metric_stats(
metric: str, stats: Union[dict, Sequence[dict]], agg_types: Sequence[str]
) -> Sequence[MetricStats]:
"""
Return statistics for a certain metric or a list of statistic for
metric variants if break_by_variant was requested
"""
agg_names = ["date"] + list(set(agg_types))
if not isinstance(stats, dict):
# no variants were requested
return [_get_variant_metric_stats(metric, agg_names, stats)]
return [
_get_variant_metric_stats(metric, agg_names, variant_stats, variant)
for variant, variant_stats in stats.items()
]
def _get_worker_metrics(stats: dict) -> Sequence[MetricStats]:
"""
Convert the worker statistics data from the internal format of lists of structs
to a more "compact" format for json transfer (arrays of dates and arrays of values)
"""
# removed metrics that were requested but for some reason
# do not exist in stats data
metrics = [metric for metric in request.items if metric.key in stats]
aggs_by_metric = bucketize(
metrics, key=attrgetter("key"), value_transform=attrgetter("aggregation")
)
return list(
itertools.chain.from_iterable(
_get_metric_stats(metric, metric_stats, aggs_by_metric[metric])
for metric, metric_stats in stats.items()
)
)
return GetStatsResponse(
workers=[
WorkerStatistics(worker=worker, metrics=_get_worker_metrics(stats))
for worker, stats in ret.items()
]
)

View File

@@ -1,6 +1,8 @@
import abc
import sys
from datetime import datetime, timezone
from functools import partial
from typing import Iterable
from unittest import TestCase
from tests.api_client import APIClient
@@ -88,6 +90,19 @@ class TestService(TestCase, TestServiceInterface):
log.exception(ex)
self._deferred = []
def assertEqualNoOrder(self, first: Iterable, second: Iterable):
"""Compares 2 sequences regardless of their items order"""
self.assertEqual(set(first), set(second))
def header(info, title="=" * 20):
print(title, info, title, file=sys.stderr)
def utc_now_tz_aware() -> datetime:
"""
Returns utc now with the utc time zone.
Suitable for subsequent usage with functions that
make use of tz info like 'timestamp'
"""
return datetime.now(timezone.utc)

View File

@@ -6,8 +6,8 @@ from typing import Sequence
from tests.automated import TestService
class TestTasksOrdering(TestService):
test_comment = "Task ordering test"
class TestEntityOrdering(TestService):
test_comment = "Entity ordering test"
only_fields = ["id", "started", "comment"]
def setUp(self, **kwargs):

View File

@@ -0,0 +1,203 @@
import time
from operator import itemgetter
from typing import Sequence
from future.backports.datetime import timedelta
from tests.api_client import AttrDict
from tests.automated import TestService, utc_now_tz_aware
class TestQueues(TestService):
def setUp(self, version="2.4"):
super().setUp(version=version)
def test_default_queue(self):
res = self.api.queues.get_default()
self.assertIsNotNone(res.id)
def test_create_update_delete(self):
queue = self._temp_queue("TempTest", tags=["hello", "world"])
res = self.api.queues.update(queue=queue, tags=["test"])
assert res.updated == 1
assert res.fields.tags == ["test"]
def test_queue_metrics(self):
queue_id = self._temp_queue("TestTempQueue")
task1 = self._create_temp_queued_task("temp task 1", queue_id)
time.sleep(1)
task2 = self._create_temp_queued_task("temp task 2", queue_id)
self.api.queues.get_next_task(queue=queue_id)
self.api.queues.remove_task(queue=queue_id, task=task2["id"])
to_date = utc_now_tz_aware()
from_date = to_date - timedelta(hours=1)
res = self.api.queues.get_queue_metrics(
queue_ids=[queue_id],
from_date=from_date.timestamp(),
to_date=to_date.timestamp(),
interval=5,
)
self.assertMetricQueues(res["queues"], queue_id)
def test_reset_task(self):
queue = self._temp_queue("TestTempQueue")
task = self._temp_task("TempTask", is_development=True)
self.api.tasks.enqueue(task=task, queue=queue)
res = self.api.queues.get_by_id(queue=queue)
self.assertQueueTasks(res.queue, [task])
res = self.api.tasks.reset(task=task)
self.assertEqual(res.dequeued.removed, 1)
res = self.api.queues.get_by_id(queue=queue)
self.assertQueueTasks(res.queue, [])
def test_enqueue_dev_task(self):
queue = self._temp_queue("TestTempQueue")
task_name = "TempDevTask"
task = self._temp_task(task_name, is_development=True)
self.assertTaskTags(task, system_tags=["development"])
self.api.tasks.enqueue(task=task, queue=queue)
res = self.api.queues.get_by_id(queue=queue)
self.assertQueueTasks(res.queue, [task])
self.assertTaskTags(task, system_tags=[])
def test_move_task(self):
queue = self._temp_queue("TestTempQueue")
tasks = [
self._create_temp_queued_task(t, queue)["id"]
for t in ("temp task1", "temp task2", "temp task3")
]
res = self.api.queues.get_by_id(queue=queue)
self.assertQueueTasks(res.queue, tasks)
new_pos = self.api.queues.move_task_backward(
queue=queue, task=tasks[0], count=2
).position
self.assertEqual(new_pos, 2)
res = self.api.queues.get_by_id(queue=queue)
changed_tasks = tasks[1:] + tasks[:1]
self.assertQueueTasks(res.queue, changed_tasks)
new_pos = self.api.queues.move_task_forward(
queue=queue, task=tasks[0], count=2
).position
self.assertEqual(new_pos, 0)
res = self.api.queues.get_by_id(queue=queue)
self.assertQueueTasks(res.queue, tasks)
self.assertGetNextTasks(queue, tasks)
def test_get_all_ex(self):
queue_name = "TestTempQueue1"
queue_tags = ["Test1", "Test2"]
queue = self._temp_queue(queue_name, tags=queue_tags)
res = self.api.queues.get_all_ex(name="TestTempQueue*").queues
self.assertQueue(
res, queue_id=queue, name=queue_name, tags=queue_tags, tasks=[], workers=[]
)
tasks = [
self._create_temp_queued_task(t, queue)
for t in ("temp task1", "temp task2")
]
workers = [
self._create_temp_worker(w, queue) for w in ("temp worker1", "temp worker2")
]
res = self.api.queues.get_all_ex(name="TestTempQueue*").queues
self.assertQueue(
res,
queue_id=queue,
name=queue_name,
tags=queue_tags,
tasks=tasks,
workers=workers,
)
def assertMetricQueues(self, queues_data, queue_id):
self.assertEqual(len(queues_data), 1)
queue_res = queues_data[0]
self.assertEqual(queue_res.queue, queue_id)
dates_len = len(queue_res["dates"])
self.assertTrue(2 >= dates_len >= 1)
for prop in ("avg_waiting_times", "queue_lengths"):
self.assertEqual(len(queue_res[prop]), dates_len)
dates_in_sec = [d / 1000 for d in queue_res["dates"]]
self.assertGreater(
dates_in_sec[0], (utc_now_tz_aware() - timedelta(seconds=15)).timestamp()
)
if dates_len > 1:
self.assertAlmostEqual(dates_in_sec[1] - dates_in_sec[0], 5, places=0)
def assertQueue(
self,
queues: Sequence[AttrDict],
queue_id: str,
name: str,
tags: Sequence[str],
tasks: Sequence[dict],
workers: Sequence[dict],
):
queue = next(q for q in queues if q.id == queue_id)
assert queue.last_update
self.assertEqualNoOrder(queue.tags, tags)
self.assertEqual(queue.name, name)
self.assertQueueTasks(queue, tasks)
self.assertQueueWorkers(queue, workers)
def assertTaskTags(self, task, system_tags):
res = self.api.tasks.get_by_id(task=task)
self.assertSequenceEqual(res.task.system_tags, system_tags)
def assertQueueTasks(self, queue: AttrDict, tasks: Sequence):
self.assertEqual([e.task for e in queue.entries], tasks)
def assertGetNextTasks(self, queue, tasks):
for task_id in tasks:
res = self.api.queues.get_next_task(queue=queue)
self.assertEqual(res.entry.task, task_id)
assert not self.api.queues.get_next_task(queue=queue)
def assertQueueWorkers(self, queue: AttrDict, workers: Sequence[dict]):
sort_key = itemgetter("name")
self.assertEqual(
sorted(queue.workers, key=sort_key), sorted(workers, key=sort_key)
)
def _temp_queue(self, queue_name, tags=None):
return self.create_temp("queues", name=queue_name, tags=tags)
def _temp_task(self, task_name, is_testing=False, is_development=False):
task_input = dict(
name=task_name,
type="testing" if is_testing else "training",
input=dict(mapping={}, view={}),
script={"repository": "test", "entry_point": "test"},
system_tags=["development"] if is_development else None,
)
return self.create_temp("tasks", **task_input)
def _create_temp_queued_task(self, task_name, queue) -> dict:
task_id = self._temp_task(task_name)
self.api.tasks.enqueue(task=task_id, queue=queue)
return dict(id=task_id, name=task_name)
def _create_temp_running_task(self, task_name) -> dict:
task_id = self._temp_task(task_name, is_testing=True)
self.api.tasks.started(task=task_id)
return dict(id=task_id, name=task_name)
def _create_temp_worker(self, worker, queue):
self.api.workers.register(worker=worker, queues=[queue])
task = self._create_temp_running_task(f"temp task for worker {worker}")
self.api.workers.status_report(
worker=worker,
timestamp=int(utc_now_tz_aware().timestamp() * 1000),
machine_stats=dict(cpu_usage=[10, 20]),
task=task["id"],
)
return dict(name=worker, ip="127.0.0.1", task=task)

View File

@@ -1,3 +1,4 @@
from datetime import datetime
from time import sleep
from typing import Sequence
@@ -11,7 +12,7 @@ log = config.logger(__file__)
class TestTags(TestService):
def setUp(self, version="2.3"):
def setUp(self, version="2.4"):
super().setUp(version)
def testPartition(self):
@@ -154,6 +155,12 @@ class TestTags(TestService):
# test development system tag
self.api.tasks.started(task=task_id)
self.api.workers.status_report(
worker="Test tags",
timestamp=int(datetime.utcnow().timestamp() * 1000),
machine_stats=dict(memory_used=30),
task=task_id,
)
self.api.tasks.stop(task=task_id)
task = self.api.tasks.get_by_id(task=task_id).task
self.assertEqual(task.status, "in_progress")
@@ -162,6 +169,32 @@ class TestTags(TestService):
task = self.api.tasks.get_by_id(task=task_id).task
self.assertEqual(task.status, "stopped")
def testQueueTags(self):
q_id = self._temp_queue(system_tags=["default"])
queues = self.api.queues.get_all_ex(
name="Test tags", system_tags=["default"]
).queues
self.assertFound(q_id, ["default"], queues)
queues = self.api.queues.get_all_ex(
name="Test tags", system_tags=["-default"]
).queues
self.assertNotFound(q_id, queues)
self.api.queues.update(queue=q_id, system_tags=[])
queues = self.api.queues.get_all_ex(
name="Test tags", system_tags=["-default"]
).queues
self.assertFound(q_id, [], queues)
# test default queue
queues = self.api.queues.get_all(system_tags=["default"]).queues
if queues:
self.assertEqual(queues[0].id, self.api.queues.get_default().id)
else:
self.api.queues.update(queue=q_id, system_tags=["default"])
self.assertEqual(q_id, self.api.queues.get_default().id)
def assertProjectStats(self, project: AttrDict):
self.assertEqual(set(project.stats.keys()), {"active"})
self.assertAlmostEqual(project.stats.active.total_runtime, 1, places=0)
@@ -174,6 +207,10 @@ class TestTags(TestService):
sleep(1)
self.api.tasks.stopped(task=task_id)
def _temp_queue(self, **kwargs):
self._update_missing(kwargs, name="Test tags")
return self.create_temp("queues", **kwargs)
def _temp_project(self, **kwargs):
self._update_missing(kwargs, name="Test tags", description="test")
return self.create_temp("projects", **kwargs)

View File

@@ -1,4 +1,3 @@
from apierrors.errors.bad_request import ModelNotReady
from config import config
from tests.automated import TestService
@@ -25,15 +24,7 @@ class TestTasksEdit(TestService):
model = self.new_model()
self.api.models.edit(model=model, ready=False)
self.assertFalse(self.api.models.get_by_id(model=model).model.ready)
with self.api.raises(ModelNotReady):
self.api.tasks.edit(task=task, execution=dict(model=model))
def test_edit_model_not_ready_force(self):
task = self.new_task()
model = self.new_model()
self.api.models.edit(model=model, ready=False)
self.assertFalse(self.api.models.get_by_id(model=model).model.ready)
self.api.tasks.edit(task=task, execution=dict(model=model), force=True)
self.api.tasks.edit(task=task, execution=dict(model=model))
def test_edit_had_model_model_not_ready(self):
ready_model = self.new_model()
@@ -42,14 +33,4 @@ class TestTasksEdit(TestService):
not_ready_model = self.new_model()
self.api.models.edit(model=not_ready_model, ready=False)
self.assertFalse(self.api.models.get_by_id(model=not_ready_model).model.ready)
with self.api.raises(ModelNotReady):
self.api.tasks.edit(task=task, execution=dict(model=not_ready_model))
def test_edit_had_model_model_not_ready_force(self):
ready_model = self.new_model()
self.assert_(self.api.models.get_by_id(model=ready_model).model.ready)
task = self.new_task(execution=dict(model=ready_model))
not_ready_model = self.new_model()
self.api.models.edit(model=not_ready_model, ready=False)
self.assertFalse(self.api.models.get_by_id(model=not_ready_model).model.ready)
self.api.tasks.edit(task=task, execution=dict(model=not_ready_model), force=True)
self.api.tasks.edit(task=task, execution=dict(model=not_ready_model))

View File

@@ -0,0 +1,214 @@
import time
from uuid import uuid4
from datetime import timedelta
from operator import attrgetter
from typing import Sequence
from apierrors.errors import bad_request
from tests.automated import TestService, utc_now_tz_aware
from config import config
log = config.logger(__file__)
class TestWorkersService(TestService):
def setUp(self, version="2.4"):
super().setUp(version=version)
def _check_exists(self, worker: str, exists: bool = True):
workers = self.api.workers.get_all(last_seen=100).workers
found = any(w for w in workers if w.id == worker)
assert exists == found
def test_workers_register(self):
test_worker = f"test_{uuid4().hex}"
self._check_exists(test_worker, False)
self.api.workers.register(worker=test_worker)
self._check_exists(test_worker)
self.api.workers.unregister(worker=test_worker)
self._check_exists(test_worker, False)
def test_workers_timeout(self):
test_worker = f"test_{uuid4().hex}"
self._check_exists(test_worker, False)
self.api.workers.register(worker=test_worker, timeout=3)
self._check_exists(test_worker)
time.sleep(5)
self._check_exists(test_worker, False)
def _simulate_workers(self) -> Sequence[str]:
"""
Two workers writing the same metrics. One for 4 seconds. Another one for 2
The first worker reports a task
:return: worker ids
"""
task_id = self._create_running_task(task_name="task-1")
workers = [f"test_{uuid4().hex}", f"test_{uuid4().hex}"]
workers_stats = [
(
dict(cpu_usage=[10, 20], memory_used=50),
dict(cpu_usage=[5], memory_used=30),
)
] * 4
workers_activity = [
(workers[0], workers[1]),
(workers[0], workers[1]),
(workers[0],),
(workers[0],),
]
for ws, stats in zip(workers_activity, workers_stats):
for w, s in zip(ws, stats):
data = dict(
worker=w,
timestamp=int(utc_now_tz_aware().timestamp() * 1000),
machine_stats=s,
)
if w == workers[0]:
data["task"] = task_id
self.api.workers.status_report(**data)
time.sleep(1)
return workers
def _create_running_task(self, task_name):
task_input = dict(
name=task_name, type="testing", input=dict(mapping={}, view={})
)
task_id = self.create_temp("tasks", **task_input)
self.api.tasks.started(task=task_id)
return task_id
def test_get_keys(self):
workers = self._simulate_workers()
res = self.api.workers.get_metric_keys(worker_ids=workers)
assert {"cpu", "memory"} == set(c.name for c in res["categories"])
assert all(
c.metric_keys == ["cpu_usage"] for c in res["categories"] if c.name == "cpu"
)
assert all(
c.metric_keys == ["memory_used"]
for c in res["categories"]
if c.name == "memory"
)
with self.api.raises(bad_request.WorkerStatsNotFound):
self.api.workers.get_metric_keys(worker_ids=["Non existing worker id"])
def test_get_stats(self):
workers = self._simulate_workers()
to_date = utc_now_tz_aware()
from_date = to_date - timedelta(days=1)
# no variants
res = self.api.workers.get_statistics(
items=[
dict(key="cpu_usage", aggregation="avg"),
dict(key="cpu_usage", aggregation="max"),
dict(key="memory_used", aggregation="max"),
dict(key="memory_used", aggregation="min"),
],
from_date=from_date.timestamp(),
to_date=to_date.timestamp(),
# split_by_variant=True,
interval=1,
worker_ids=workers,
)
self.assertWorkersInStats(workers, res["workers"])
assert all(
{"cpu_usage", "memory_used"}
== set(map(attrgetter("metric"), worker["metrics"]))
for worker in res["workers"]
)
def _check_dates_and_stats(metric, stats, worker_id) -> bool:
return set(
map(attrgetter("aggregation"), metric["stats"])
) == stats and len(metric["dates"]) == (4 if worker_id == workers[0] else 2)
assert all(
_check_dates_and_stats(metric, metric_stats, worker["worker"])
for worker in res["workers"]
for metric, metric_stats in zip(
worker["metrics"], ({"avg", "max"}, {"max", "min"})
)
)
# split by variants
res = self.api.workers.get_statistics(
items=[dict(key="cpu_usage", aggregation="avg")],
from_date=from_date.timestamp(),
to_date=to_date.timestamp(),
split_by_variant=True,
interval=1,
worker_ids=workers,
)
self.assertWorkersInStats(workers, res["workers"])
def _check_metric_and_variants(worker):
return (
all(
_check_dates_and_stats(metric, {"avg"}, worker["worker"])
for metric in worker["metrics"]
)
and set(map(attrgetter("variant"), worker["metrics"])) == {"0", "1"}
if worker["worker"] == workers[0]
else {"0"}
)
assert all(_check_metric_and_variants(worker) for worker in res["workers"])
res = self.api.workers.get_statistics(
items=[dict(key="cpu_usage", aggregation="avg")],
from_date=from_date.timestamp(),
to_date=to_date.timestamp(),
interval=1,
worker_ids=["Non existing worker id"],
)
assert not res["workers"]
@staticmethod
def assertWorkersInStats(workers: Sequence[str], stats: dict):
assert set(workers) == set(map(attrgetter("worker"), stats))
def test_get_activity_report(self):
# test no workers data
# run on an empty es db since we have no way
# to pass non existing workers to this api
# res = self.api.workers.get_activity_report(
# from_timestamp=from_timestamp.timestamp(),
# to_timestamp=to_timestamp.timestamp(),
# interval=20,
# )
self._simulate_workers()
to_date = utc_now_tz_aware()
from_date = to_date - timedelta(minutes=10)
# no variants
res = self.api.workers.get_activity_report(
from_date=from_date.timestamp(), to_date=to_date.timestamp(), interval=20
)
self.assertWorkerSeries(res["total"], 2)
self.assertWorkerSeries(res["active"], 1)
self.assertTotalSeriesGreaterThenActive(res["total"], res["active"])
@staticmethod
def assertTotalSeriesGreaterThenActive(total_data: dict, active_data: dict):
assert total_data["dates"][-1] == active_data["dates"][-1]
assert total_data["counts"][-1] > active_data["counts"][-1]
@staticmethod
def assertWorkerSeries(series_data: dict, min_count: int):
assert len(series_data["dates"]) == len(series_data["counts"])
# check the last 20s aggregation
# there may be more workers that we created since we are not filtering by test workers here
assert series_data["counts"][-1] >= min_count

View File

@@ -1,65 +1,15 @@
import time
from config import config
log = config.logger(__file__)
_stats = dict()
def stats():
aggregate()
return _stats
def clear():
global _stats
_stats = dict()
def get_component_total(comp):
if comp not in _stats:
return 0
return _stats[comp].get("total")
# create a "total" node for each componenet
def aggregate():
grand_total = 0
for comp in _stats:
total = 0
for op in _stats[comp]:
total += _stats[comp][op]
_stats[comp]["total"] = total
grand_total += total
_stats["_all"] = dict(total=grand_total)
class TimingStats:
@classmethod
def aggregate(cls):
return {}
class TimingContext:
def __init__(self, component, operation):
self.component = component
self.operation = operation
_stats["_all"] = dict(total=0)
if component not in _stats:
_stats[component] = dict(total=0)
def __init__(self, *_, **__):
pass
def __enter__(self):
self.start = time.time()
return self
def __exit__(self, *args):
try:
self.end = time.time()
latency_ms = int((self.end - self.start) * 1000)
if self.operation in _stats.get(self.component, {}):
previous_latency = _stats[self.component][self.operation]
new_latency = int((previous_latency + latency_ms) / 2)
else:
new_latency = latency_ms
if self.component not in _stats:
_stats[self.component] = dict(total=0)
_stats[self.component][self.operation] = new_latency
except Exception as ex:
log.error("%s calculating latency: %s" % (type(ex).__name__, str(ex)))
pass

44
server/tools.py Normal file
View File

@@ -0,0 +1,44 @@
""" Command line tools for the API server """
from argparse import ArgumentParser
import dpath
from humanfriendly import parse_timespan
def setup():
from database import initialize
initialize()
def gen_token(args):
from bll.auth import AuthBLL
resp = AuthBLL.get_token_for_user(args.user_id, args.company_id, parse_timespan(args.expiration))
print('Token:\n%s' % resp.token)
def safe_get(obj, glob, default=None, separator="/"):
try:
return dpath.get(obj, glob, separator=separator)
except KeyError:
return default
if __name__ == '__main__':
top_parser = ArgumentParser(__doc__)
subparsers = top_parser.add_subparsers(title='Sections')
token = subparsers.add_parser('token')
token_commands = token.add_subparsers(title='Commands')
token_create = token_commands.add_parser('generate', description='Generate a new token')
token_create.add_argument('--user-id', '-u', help='User ID', required=True)
token_create.add_argument('--company-id', '-c', help='Company ID', required=True)
token_create.add_argument('--expiration', '-exp',
help="Token expiration (time span, shorthand suffixes are supported, default 1m)",
default=parse_timespan('1m'))
token_create.set_defaults(_func=gen_token)
args = top_parser.parse_args()
if args._func:
setup()
args._func(args)

114
server/updates.py Normal file
View File

@@ -0,0 +1,114 @@
import os
from threading import Thread
from time import sleep
from typing import Optional
import attr
import requests
from semantic_version import Version
from config import config
from database.model.settings import Settings
from version import __version__ as current_version
log = config.logger(__name__)
class CheckUpdatesThread(Thread):
_enabled = bool(config.get("apiserver.check_for_updates.enabled", True))
@attr.s(auto_attribs=True)
class _VersionResponse:
version: str
patch_upgrade: bool
description: str = None
def __init__(self):
super(CheckUpdatesThread, self).__init__(
target=self._check_updates, daemon=True
)
def start(self) -> None:
if not self._enabled:
log.info("Checking for updates is disabled")
return
super(CheckUpdatesThread, self).start()
@property
def component_name(self) -> str:
return config.get("apiserver.check_for_updates.component_name", "trains-server")
def _check_new_version_available(self) -> Optional[_VersionResponse]:
url = config.get(
"apiserver.check_for_updates.url",
"https://updates.trains.allegro.ai/updates",
)
uid = Settings.get_by_key("server.uuid")
response = requests.get(
url,
json={"versions": {self.component_name: str(current_version)}, "uid": uid},
timeout=float(
config.get("apiserver.check_for_updates.request_timeout_sec", 3.0)
),
)
if not response.ok:
return
response = response.json().get(self.component_name)
if not response:
return
latest_version = response.get("version")
if not latest_version:
return
cur_version = Version(current_version)
latest_version = Version(latest_version)
if cur_version >= latest_version:
return
return self._VersionResponse(
version=str(latest_version),
patch_upgrade=(
latest_version.major == cur_version.major
and latest_version.minor == cur_version.minor
),
description=response.get("description").split("\r\n"),
)
def _check_updates(self):
while True:
# noinspection PyBroadException
try:
response = self._check_new_version_available()
if response:
if response.patch_upgrade:
log.info(
f"{self.component_name.upper()} new package available: upgrade to v{response.version} "
f"is recommended!\nRelease Notes:\n{os.linesep.join(response.description)}"
)
else:
log.info(
f"{self.component_name.upper()} new version available: upgrade to v{response.version}"
f" is recommended!"
)
except Exception:
log.exception("Failed obtaining updates")
sleep(
max(
float(
config.get(
"apiserver.check_for_updates.check_interval_sec",
60 * 60 * 24,
)
),
60 * 5,
)
)
check_updates_thread = CheckUpdatesThread()

View File

@@ -1,14 +1,29 @@
from functools import wraps
from threading import Lock, Thread
import attr
@attr.s(auto_attribs=True)
class ThreadsManager:
objects = {}
lock = Lock()
def __init__(self, name=None, **threads):
super(ThreadsManager, self).__init__()
self.name = name or self.__class__.name
self.objects = {}
self.lock = Lock()
for name, thread in threads.items():
if issubclass(thread, Thread):
thread = thread()
thread.start()
elif isinstance(thread, Thread):
if not thread.is_alive():
thread.start()
else:
raise Exception(f"Expected thread or thread class ({name}): {thread}")
self.objects[name] = thread
def register(self, thread_name, daemon=True):
def decorator(f):
@wraps(f)
@@ -17,7 +32,7 @@ class ThreadsManager:
thread = self.objects.get(thread_name)
if not thread:
thread = Thread(
target=f, name=thread_name, args=args, kwargs=kwargs
target=f, name=f"{self.name}_{thread_name}", args=args, kwargs=kwargs
)
thread.daemon = daemon
thread.start()
@@ -27,3 +42,13 @@ class ThreadsManager:
return wrapper
return decorator
def __getattr__(self, item):
if item in self.objects:
return self.objects[item]
return self.__getattribute__(item)
def __getitem__(self, item):
if item in self.objects:
return self.objects[item]
raise KeyError(item)

1
server/version.py Normal file
View File

@@ -0,0 +1 @@
__version__ = "0.12.0"