mirror of https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Initial commit

README.md (new file, 201 lines)
@@ -0,0 +1,201 @@
# TRAINS Server

## Magic Version Control & Experiment Manager for AI

## Introduction

The **trains-server** is the infrastructure behind [trains](https://github.com/allegroai/trains).

The server provides:

* A UI (single-page web app) for browsing and managing experiments
* A REST interface for documenting and logging experiment information, statistics and results
* A REST interface for querying experiment history, logs and results
* A locally-hosted file server for storing images and models so they are easily accessible from the UI

The server is designed to allow multiple users to collaborate and manage their experiments.
The server's code is freely available [here](https://github.com/allegroai/trains-server).
We've also pre-built a docker image so **trains** users can quickly set up their own server.
## System diagram

<pre>
                            TRAINS-server
+--------------------------------------------------------------------+
|                                                                    |
|   Server Docker                  Elastic Docker      Mongo Docker  |
|  +-------------------------+    +---------------+   +------------+ |
|  |     Pythonic Server     |    |               |   |            | |
|  |  +-----------------+    |    | ElasticSearch |   |  MongoDB   | |
|  |  |   WEB server    |    |    |               |   |            | |
|  |  |   Port 8080     |    |    |               |   |            | |
|  |  +--------+--------+    |    |               |   |            | |
|  |           |             |    |               |   |            | |
|  |  +--------+--------+    |    |               |   |            | |
|  |  |   API server    +-----------------------------------+      | |
|  |  |   Port 8008     +-----------------+       |   |            | |
|  |  +-----------------+    |    +-------+-------+   +-----+------+ |
|  |                         |            |                 |        |
|  |  +-----------------+    |        +---+-----------------+------+ |
|  |  |   File Server   +-------+     |      Host Storage          | |
|  |  |   Port 8081     |    |  +-----+                            | |
|  |  +-----------------+    |        +----------------------------+ |
|  +------------+------------+                                       |
+---------------|----------------------------------------------------+
                |HTTP
                +--------+
    GPU Machine          |
+------------------------|-------------------------------------------+
|     +------------------|--------------+                            |
|     |  Training        |              |    +---------------------+ |
|     |  Code        +---+------------+ |    | trains configuration| |
|     |              |    TRAINS      | |    | ~/trains.conf       | |
|     |              |                +------+                     | |
|     |              +----------------+ |    +---------------------+ |
|     +---------------------------------+                            |
+--------------------------------------------------------------------+
</pre>
## Installation

To install and run the pre-built **trains-server**, you must be logged in as a user with sudo privileges.

### Setup

To run the pre-packaged **trains-server**, you'll need to install **docker**.

#### Install docker

```bash
sudo apt-get install docker.io
```
#### Setup docker daemon

To run the ElasticSearch docker container, you'll need to change some of the default values in the docker daemon configuration.

For systems with an `/etc/sysconfig/docker` file, add the options in quotes to the available arguments in `OPTIONS`:

```bash
OPTIONS="--default-ulimit nofile=1024:65536 --default-ulimit memlock=-1:-1"
```

For systems with an `/etc/docker/daemon.json` file, add the following `default-ulimits` section:

```json
"default-ulimits": {
    "nofile": {
        "name": "nofile",
        "hard": 65536,
        "soft": 1024
    },
    "memlock": {
        "name": "memlock",
        "soft": -1,
        "hard": -1
    }
}
```

Following this configuration change, restart the docker daemon:

```bash
sudo service docker stop
sudo service docker start
```
#### vm.max_map_count

ElasticSearch requires the `vm.max_map_count` kernel setting to be at least 262144.

The following example was tested on CentOS 7, Ubuntu 16.04, Ubuntu 18.04, Mint 18.3 and Mint 19:

```bash
echo "vm.max_map_count=262144" > /tmp/99-trains.conf
sudo mv /tmp/99-trains.conf /etc/sysctl.d/99-trains.conf
sudo sysctl -w vm.max_map_count=262144
```

For additional information about setting this parameter on other systems, see the [elastic](https://www.elastic.co/guide/en/elasticsearch/reference/current/docker.html#docker-cli-run-prod-mode) documentation.
#### Choose a data folder

Choose a directory on your system in which all data maintained by **trains-server** will be stored (among others, this includes the databases, uploaded files and logs).

The following instructions assume the directory is `/opt/trains`.

Issue the following command:

```bash
sudo mkdir -p /opt/trains/data/elastic && sudo chown -R 1000:1000 /opt/trains
```
### Launching docker images

To launch the docker images, issue the following commands:

```bash
sudo docker run -d --restart="always" --name="trains-elastic" \
    -e "ES_JAVA_OPTS=-Xms2g -Xmx2g" \
    -e "bootstrap.memory_lock=true" \
    -e "cluster.name=trains" \
    -e "discovery.zen.minimum_master_nodes=1" \
    -e "node.name=trains" \
    -e "script.inline=true" \
    -e "script.update=true" \
    -e "thread_pool.bulk.queue_size=2000" \
    -e "thread_pool.search.queue_size=10000" \
    -e "xpack.security.enabled=false" \
    -e "xpack.monitoring.enabled=false" \
    -e "cluster.routing.allocation.node_initial_primaries_recoveries=500" \
    -e "node.ingest=true" \
    -e "http.compression_level=7" \
    -e "reindex.remote.whitelist=*.*" \
    -e "script.painless.regex.enabled=true" \
    --network="host" \
    -v /opt/trains/data/elastic:/usr/share/elasticsearch/data \
    docker.elastic.co/elasticsearch/elasticsearch:5.6.16
```

```bash
sudo docker run -d --restart="always" --name="trains-mongo" \
    -v /opt/trains/data/mongo/db:/data/db \
    -v /opt/trains/data/mongo/configdb:/data/configdb \
    --network="host" mongo:3.6.5
```

```bash
sudo docker run -d --restart="always" --name="trains-fileserver" --network="host" \
    -v /opt/trains/logs:/var/log/trains \
    -v /opt/trains/data/fileserver:/mnt/fileserver \
    allegroai/trains:latest fileserver
```

```bash
sudo docker run -d --restart="always" --name="trains-apiserver" --network="host" \
    -v /opt/trains/logs:/var/log/trains \
    allegroai/trains:latest apiserver
```

```bash
sudo docker run -d --restart="always" --name="trains-webserver" --network="host" \
    -v /opt/trains/logs:/var/log/trains \
    allegroai/trains:latest webserver
```

Once the **trains-server** dockers are up, the following are available:

* API server on port `8008`
* Web server on port `8080`
* File server on port `8081`
## Upgrade

We are constantly updating, improving and adding features.
When we release a new version, we'll include a new pre-built docker image.
Once a new release is out, you can simply:

1. Shut down and remove your docker instances. Each instance can be shut down and removed using the following commands:

    ```bash
    sudo docker stop <docker-name>
    sudo docker rm -v <docker-name>
    ```

    The docker names are (see [Launching docker images](#launching-docker-images)):

    * `trains-elastic`
    * `trains-mongo`
    * `trains-fileserver`
    * `trains-apiserver`
    * `trains-webserver`

2. Back up your data folder (recommended!). A simple way to do that is:

    ```bash
    sudo tar czvf ~/trains_backup.tgz /opt/trains/data
    ```

    This backs up all data to an archive in your home folder. Since `tar` stores the paths relative to `/`, such a backup can be restored with:

    ```bash
    sudo rm -R /opt/trains/data
    sudo tar -xzf ~/trains_backup.tgz -C /
    ```

3. Launch the newly released docker image (see [Launching docker images](#launching-docker-images)).
## License

[Server Side Public License v1.0](https://github.com/mongodb/mongo/blob/master/LICENSE-Community.txt)

**trains-server** relies *heavily* on both [MongoDB](https://github.com/mongodb/mongo) and [ElasticSearch](https://github.com/elastic/elasticsearch).
With the recent changes in both MongoDB's and ElasticSearch's OSS licenses, we feel it is our job as a community to support the projects we love and cherish.
We feel the cause for the license change in both cases is more than just, and we chose [SSPL](https://www.mongodb.com/licensing/server-side-public-license) because it is the more restrictive of the two.

This is our way of saying - we support you guys!
fileserver/fileserver.py (new file, 58 lines)
@@ -0,0 +1,58 @@
""" A simple file server for uploading and downloading files """
import json
import logging.config
import os
from argparse import ArgumentParser
from pathlib import Path

from flask import Flask, request, send_from_directory, safe_join
from pyhocon import ConfigFactory

logging.config.dictConfig(ConfigFactory.parse_file("logging.conf"))

app = Flask(__name__)


@app.route("/", methods=["POST"])
def upload():
    results = []
    for filename, file in request.files.items():
        if not filename:
            continue
        file_path = filename.lstrip(os.sep)
        target = Path(safe_join(app.config["UPLOAD_FOLDER"], file_path))
        target.parent.mkdir(parents=True, exist_ok=True)
        file.save(str(target))
        results.append(file_path)
    return (json.dumps(results), 200)


@app.route("/<path:path>", methods=["GET"])
def download(path):
    return send_from_directory(app.config["UPLOAD_FOLDER"], path)


def main():
    parser = ArgumentParser(description=__doc__)
    parser.add_argument(
        "--port", "-p", type=int, default=8081, help="Port (default %(default)d)"
    )
    parser.add_argument(
        "--ip", "-i", type=str, default="0.0.0.0", help="Address (default %(default)s)"
    )
    parser.add_argument("--debug", action="store_true", default=False)
    parser.add_argument(
        "--upload-folder",
        "-u",
        default="/mnt/fileserver",
        help="Upload folder (default %(default)s)",
    )
    args = parser.parse_args()

    app.config["UPLOAD_FOLDER"] = args.upload_folder

    app.run(debug=args.debug, host=args.ip, port=args.port, threaded=True)


if __name__ == "__main__":
    main()
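The upload handler above strips leading path separators from the client-supplied filename before joining it with the upload folder (Flask's `safe_join` additionally rejects `..` traversal). A minimal standalone sketch of that path handling, assuming POSIX separators and a hypothetical upload root:

```python
import os
from pathlib import PurePosixPath

def sanitize_upload_path(upload_root, filename):
    # Mimic the upload handler: strip leading separators so the
    # client-supplied name is always resolved inside upload_root.
    relative = filename.lstrip(os.sep)
    return str(PurePosixPath(upload_root) / relative)

# An absolute client path cannot escape the upload root (on POSIX):
print(sanitize_upload_path("/mnt/fileserver", "/models/weights.bin"))
```

Note that `lstrip` alone only neutralizes absolute paths; rejecting `..` components is left to `safe_join` in the real handler.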
fileserver/logging.conf (new file, 38 lines)
@@ -0,0 +1,38 @@
{
  version: 1
  disable_existing_loggers: false
  formatters {
    default {
      format: "[%(asctime)s] [%(process)d] [%(levelname)s] [%(name)s] %(message)s"
    }
  }
  handlers {
    console {
      formatter: default
      class: "logging.StreamHandler"
    }
    text_file {
      formatter: default
      backupCount: 3
      maxBytes: 10240000
      class: "logging.handlers.RotatingFileHandler"
      filename: "/var/log/trains/fileserver.log"
    }
  }
  root {
    handlers: [console, text_file]
    level: INFO
  }
  loggers {
    urllib3 {
      handlers: [console, text_file]
      level: WARN
      propagate: false
    }
    werkzeug {
      handlers: [console, text_file]
      level: WARN
      propagate: false
    }
  }
}
fileserver/requirements.txt (new file, 1 line)
@@ -0,0 +1 @@
Flask
server/apierrors/__init__.py (new file, 116 lines)
@@ -0,0 +1,116 @@
import pathlib
from . import autogen

from .apierror import APIError


""" Error codes """
_error_codes = {
    (400, 'bad_request'): {
        1: ('not_supported', 'endpoint is not supported'),
        2: ('request_path_has_invalid_version', 'request path has invalid version'),
        5: ('invalid_headers', 'invalid headers'),
        6: ('impersonation_error', 'impersonation error'),

        10: ('invalid_id', 'invalid object id'),
        11: ('missing_required_fields', 'missing required fields'),
        12: ('validation_error', 'validation error'),
        13: ('fields_not_allowed_for_role', 'fields not allowed for role'),
        14: ('invalid_fields', 'fields not defined for object'),
        15: ('fields_conflict', 'conflicting fields'),
        16: ('fields_value_error', 'invalid value for fields'),
        17: ('batch_contains_no_items', 'batch request contains no items'),
        18: ('batch_validation_error', 'batch request validation error'),
        19: ('invalid_lucene_syntax', 'malformed lucene query'),
        20: ('fields_type_error', 'invalid type for fields'),
        21: ('invalid_regex_error', 'malformed regular expression'),
        22: ('invalid_email_address', 'malformed email address'),
        23: ('invalid_domain_name', 'malformed domain name'),
        24: ('not_public_object', 'object is not public'),

        # Tasks
        100: ('task_error', 'general task error'),
        101: ('invalid_task_id', 'invalid task id'),
        102: ('task_validation_error', 'task validation error'),
        110: ('invalid_task_status', 'invalid task status'),
        111: ('task_not_started', 'task not started (invalid task status)'),
        112: ('task_in_progress', 'task in progress (invalid task status)'),
        113: ('task_published', 'task published (invalid task status)'),
        114: ('task_status_unknown', 'task unknown (invalid task status)'),
        120: ('invalid_task_execution_progress', 'invalid task execution progress'),
        121: ('failed_changing_task_status', 'failed changing task status. probably someone changed it before you'),
        122: ('missing_task_fields', 'task is missing expected fields'),
        123: ('task_cannot_be_deleted', 'task cannot be deleted'),
        125: ('task_has_jobs_running', "task has jobs that haven't completed yet"),
        126: ('invalid_task_type', 'invalid task type for this operation'),
        127: ('invalid_task_input', 'invalid task input'),
        128: ('invalid_task_output', 'invalid task output'),
        129: ('task_publish_in_progress', 'task publish in progress'),
        130: ('task_not_found', 'task not found'),


        # Models
        200: ('model_error', 'general model error'),
        201: ('invalid_model_id', 'invalid model id'),
        202: ('model_not_ready', 'model is not ready'),
        203: ('model_is_ready', 'model is ready'),
        204: ('invalid_model_uri', 'invalid model URI'),
        205: ('model_in_use', 'model is used by tasks'),
        206: ('model_creating_task_exists', 'task that created this model exists'),

        # Users
        300: ('invalid_user', 'invalid user'),
        301: ('invalid_user_id', 'invalid user id'),
        302: ('user_id_exists', 'user id already exists'),
        305: ('invalid_preferences_update', 'malformed key and/or value'),

        # Projects
        401: ('invalid_project_id', 'invalid project id'),
        402: ('project_has_tasks', 'project has associated tasks'),
        403: ('project_not_found', 'project not found'),
        405: ('project_has_models', 'project has associated models'),

        # Database
        800: ('data_validation_error', 'data validation error'),
        801: ('expected_unique_data', 'value combination already exists'),
    },

    (401, 'unauthorized'): {
        1: ('not_authorized', 'unauthorized (not authorized for endpoint)'),
        2: ('entity_not_allowed', 'unauthorized (entity not allowed)'),
        10: ('bad_auth_type', 'unauthorized (bad authentication header type)'),
        20: ('no_credentials', 'unauthorized (missing credentials)'),
        21: ('bad_credentials', 'unauthorized (malformed credentials)'),
        22: ('invalid_credentials', 'unauthorized (invalid credentials)'),
        30: ('invalid_token', 'invalid token'),
        31: ('blocked_token', 'token is blocked')
    },

    (403, 'forbidden'): {
        10: ('routing_error', 'forbidden (routing error)'),
        11: ('missing_routing_header', 'forbidden (missing routing header)'),
        12: ('blocked_internal_endpoint', 'forbidden (blocked internal endpoint)'),
        20: ('role_not_allowed', 'forbidden (not allowed for role)'),
        21: ('no_write_permission', 'forbidden (modification not allowed)'),
    },

    (500, 'server_error'): {
        0: ('general_error', 'general server error'),
        1: ('internal_error', 'internal server error'),
        2: ('config_error', 'configuration error'),
        3: ('build_info_error', 'build info unavailable or corrupted'),
        10: ('transaction_error', 'a transaction call has returned with an error'),
        # Database-related issues
        100: ('data_error', 'general data error'),
        101: ('inconsistent_data', 'inconsistent data encountered in document'),
        102: ('database_unavailable', 'database is temporarily unavailable'),

        # Index-related issues
        201: ('missing_index', 'missing internal index'),

        9999: ('not_implemented', 'action is not yet implemented'),
    }
}


autogen.generate(pathlib.Path(__file__).parent, _error_codes)
server/apierrors/apierror.py (new file, 21 lines)
@@ -0,0 +1,21 @@
class APIError(Exception):
    def __init__(self, msg, code=500, subcode=0, **_):
        super(APIError, self).__init__()
        self._msg = msg
        self._code = code
        self._subcode = subcode

    @property
    def msg(self):
        return self._msg

    @property
    def code(self):
        return self._code

    @property
    def subcode(self):
        return self._subcode

    def __str__(self):
        return self.msg
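A short usage sketch of the exception above, with the class reproduced inline so the snippet runs standalone (the raising endpoint is hypothetical):

```python
class APIError(Exception):
    # Condensed copy of the class above, for illustration.
    def __init__(self, msg, code=500, subcode=0, **_):
        super(APIError, self).__init__()
        self._msg = msg
        self._code = code
        self._subcode = subcode

    def __str__(self):
        return self._msg


# A handler raises with an HTTP code and a subcode from the error table;
# unknown extra keyword arguments are swallowed by **_.
try:
    raise APIError("invalid task id", code=400, subcode=101, request_id="abc")
except APIError as e:
    print(str(e), e._code, e._subcode)  # invalid task id 400 101
```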
server/apierrors/autogen/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
def generate(path, error_codes):
    from .generator import Generator
    from pathlib import Path
    Generator(Path(path) / 'errors', format_pep8=False).make_errors(error_codes)
server/apierrors/autogen/__main__.py (new file, 6 lines)
@@ -0,0 +1,6 @@
if __name__ == '__main__':
    from pathlib import Path
    from apierrors import _error_codes
    from apierrors.autogen import generate

    generate(Path(__file__).parent.parent, _error_codes)
server/apierrors/autogen/generator.py (new file, 85 lines)
@@ -0,0 +1,85 @@
import re
import json
import jinja2
import hashlib

from pathlib import Path


env = jinja2.Environment(
    loader=jinja2.FileSystemLoader(str(Path(__file__).parent)),
    autoescape=jinja2.select_autoescape(disabled_extensions=('py',), default_for_string=False),
    trim_blocks=True,
    lstrip_blocks=True)


def env_filter(name=None):
    return lambda func: env.filters.setdefault(name or func.__name__, func)


@env_filter()
def cls_name(name):
    delims = list(map(re.escape, (' ', '_')))
    parts = re.split('|'.join(delims), name)
    return ''.join(x.capitalize() for x in parts)

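The `cls_name` filter turns snake_case (or space-separated) section names from the error table into CamelCase class names for the generated modules. The same logic, standalone:

```python
import re

def cls_name(name):
    # Split on spaces or underscores, then capitalize each part,
    # exactly as the Jinja filter above does.
    delims = list(map(re.escape, (' ', '_')))
    parts = re.split('|'.join(delims), name)
    return ''.join(x.capitalize() for x in parts)

print(cls_name('bad_request'))      # BadRequest
print(cls_name('invalid task id'))  # InvalidTaskId
```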
class Generator(object):
    _base_class_name = 'BaseError'
    _base_class_module = 'apierrors.base'

    def __init__(self, path, format_pep8=True, use_md5=True):
        self._use_md5 = use_md5
        self._format_pep8 = format_pep8
        self._path = Path(path)
        self._path.mkdir(parents=True, exist_ok=True)

    def _make_init_file(self, path):
        (self._path / path / '__init__.py').write_bytes(b'')

    def _do_render(self, file, template, context):
        with file.open('w') as f:
            result = template.render(
                base_class_name=self._base_class_name,
                base_class_module=self._base_class_module,
                **context)
            if self._format_pep8:
                # Imported lazily so autopep8 is only required when formatting is requested
                import autopep8
                result = autopep8.fix_code(result, options={'aggressive': 1, 'verbose': 0, 'max_line_length': 120})
            f.write(result)

    def _make_section(self, name, code, subcodes):
        self._do_render(
            file=(self._path / name).with_suffix('.py'),
            template=env.get_template('templates/section.jinja2'),
            context=dict(code=code, subcodes=list(subcodes.items())))

    def _make_init(self, sections):
        self._do_render(
            file=(self._path / '__init__.py'),
            template=env.get_template('templates/init.jinja2'),
            context=dict(sections=sections))

    def _key_to_str(self, data):
        if isinstance(data, dict):
            return {str(k): self._key_to_str(v) for k, v in data.items()}
        return data

    def _calc_digest(self, data):
        data = json.dumps(self._key_to_str(data), sort_keys=True)
        return hashlib.md5(data.encode('utf8')).hexdigest()

    def make_errors(self, errors):
        digest = None
        digest_file = self._path / 'digest.md5'
        if self._use_md5:
            digest = self._calc_digest(errors)
            if digest_file.is_file():
                if digest_file.read_text() == digest:
                    return

        self._make_init(errors)
        for (code, section_name), subcodes in errors.items():
            self._make_section(section_name, code, subcodes)

        if self._use_md5:
            digest_file.write_text(digest)
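`make_errors` skips regeneration when the error table is unchanged: the digest is an MD5 over a canonical JSON dump of the table. A standalone sketch of that digest logic (same approach as `_calc_digest` above, with a miniature error table):

```python
import json
import hashlib

def calc_digest(data):
    # Stringify keys (tuple keys are not valid JSON) and dump with
    # sorted keys so the digest is stable across runs.
    def key_to_str(d):
        if isinstance(d, dict):
            return {str(k): key_to_str(v) for k, v in d.items()}
        return d
    payload = json.dumps(key_to_str(data), sort_keys=True)
    return hashlib.md5(payload.encode('utf8')).hexdigest()

codes = {(400, 'bad_request'): {1: ('not_supported', 'endpoint is not supported')}}
print(calc_digest(codes) == calc_digest(dict(codes)))  # stable: True
```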
server/apierrors/autogen/templates/error.jinja2 (new file, 6 lines)
@@ -0,0 +1,6 @@
{% macro error_class(name, msg, code, subcode=0) %}
class {{ name }}({{ base_class_name }}):
    _default_code = {{ code }}
    _default_subcode = {{ subcode }}
    _default_msg = "{{ msg|capitalize }}"
{% endmacro -%}
server/apierrors/autogen/templates/init.jinja2 (new file, 14 lines)
@@ -0,0 +1,14 @@
{% from 'templates/error.jinja2' import error_class with context %}
{% if sections %}
from {{ base_class_module }} import {{ base_class_name }}
{% endif %}

{% for _, name in sections %}
from . import {{ name }}
{% endfor %}


{% for code, name in sections %}
{{ error_class(name|cls_name, name|replace('_', ' '), code) }}

{% endfor %}
server/apierrors/autogen/templates/section.jinja2 (new file, 9 lines)
@@ -0,0 +1,9 @@
{% from 'templates/error.jinja2' import error_class with context %}
{% if subcodes %}
from {{ base_class_module }} import {{ base_class_name }}
{% endif %}
{% for subcode, (name, msg) in subcodes %}


{{ error_class(name|cls_name, msg, code, subcode) -}}
{% endfor %}
server/apierrors/base.py (new file, 38 lines)
@@ -0,0 +1,38 @@
import six
from boltons.typeutils import classproperty
from typing import Tuple

from .apierror import APIError


class BaseError(APIError):
    _default_code = 500
    _default_subcode = 0
    _default_msg = ""

    def __init__(self, extra_msg=None, replacement_msg=None, **kwargs):
        message = replacement_msg or self._default_msg
        if extra_msg:
            message += f" ({extra_msg})"
        if kwargs:
            kwargs_msg = ", ".join(
                f"{k}={self._format_kwarg(v)}" for k, v in kwargs.items()
            )
            message += f": {kwargs_msg}"
        params = kwargs.copy()
        params.update(
            code=self._default_code, subcode=self._default_subcode, msg=message
        )
        super(BaseError, self).__init__(**params)

    @staticmethod
    def _format_kwarg(value):
        if isinstance(value, (tuple, list)):
            return f'({", ".join(str(v) for v in value)})'
        elif isinstance(value, six.string_types):
            return value
        return str(value)

    @classproperty
    def codes(self) -> Tuple[int, int]:
        return self._default_code, self._default_subcode
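Generated subclasses only set the three `_default_*` attributes; `BaseError.__init__` assembles the final message from them. A condensed standalone sketch of that message building (ignoring the `_format_kwarg` list/tuple handling):

```python
def build_message(default_msg, extra_msg=None, **kwargs):
    # Same assembly order as BaseError.__init__ above: optional extra
    # text in parentheses, then keyword arguments as key=value pairs.
    message = default_msg
    if extra_msg:
        message += f" ({extra_msg})"
    if kwargs:
        message += ": " + ", ".join(f"{k}={v}" for k, v in kwargs.items())
    return message

print(build_message("Invalid task id", extra_msg="not found", id="abc123"))
# Invalid task id (not found): id=abc123
```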
server/apimodels/__init__.py (new file, 162 lines)
@@ -0,0 +1,162 @@
from __future__ import absolute_import

from enum import Enum
from typing import Union, Type, Iterable

import jsonmodels.errors
import jsonmodels.validators
import six
import validators
from jsonmodels import fields
from jsonmodels.fields import _LazyType
from jsonmodels.models import Base as ModelBase
from jsonmodels.validators import Enum as EnumValidator
from luqum.parser import parser, ParseError

from apierrors import errors


def make_default(field_cls, default_value):
    class _FieldWithDefault(field_cls):
        def get_default_value(self):
            return default_value

    return _FieldWithDefault

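`make_default` wraps any field class so instances report a fixed default value. A standalone sketch using a stand-in field class (jsonmodels itself is not imported here):

```python
def make_default(field_cls, default_value):
    # Same pattern as above: subclass the field class and
    # override get_default_value to return the fixed default.
    class _FieldWithDefault(field_cls):
        def get_default_value(self):
            return default_value
    return _FieldWithDefault

class FakeStringField:
    # Stand-in for a jsonmodels field class, for illustration only.
    def get_default_value(self):
        return None

DefaultedField = make_default(FakeStringField, "queued")
print(DefaultedField().get_default_value())  # queued
```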
class ListField(fields.ListField):
    def _cast_value(self, value):
        try:
            return super(ListField, self)._cast_value(value)
        except TypeError:
            return value

    def validate_single_value(self, item):
        super(ListField, self).validate_single_value(item)
        if isinstance(item, ModelBase):
            item.validate()


class DictField(fields.BaseField):
    types = (dict,)

    def __init__(self, value_types=None, *args, **kwargs):
        self.value_types = self._assign_types(value_types)
        super(DictField, self).__init__(*args, **kwargs)

    def get_default_value(self):
        default = super(DictField, self).get_default_value()
        if default is None and not self.required:
            return {}
        return default

    @staticmethod
    def _assign_types(value_types):
        if value_types:
            try:
                value_types = tuple(value_types)
            except TypeError:
                value_types = (value_types,)
        else:
            value_types = tuple()

        return tuple(
            _LazyType(type_)
            if isinstance(type_, six.string_types)
            else type_
            for type_ in value_types
        )

    def validate(self, value):
        super(DictField, self).validate(value)

        if not self.value_types:
            return

        for item in value.values():
            self.validate_single_value(item)

    def validate_single_value(self, item):
        if not self.value_types:
            return

        if not isinstance(item, self.value_types):
            raise jsonmodels.errors.ValidationError(
                "All items must be instances "
                'of "{types}", and not "{type}".'.format(
                    types=", ".join([t.__name__ for t in self.value_types]),
                    type=type(item).__name__,
                )
            )


class IntField(fields.IntField):
    def parse_value(self, value):
        try:
            return super(IntField, self).parse_value(value)
        except (ValueError, TypeError):
            return value


def validate_lucene_query(value):
    if value == '':
        return
    try:
        parser.parse(value)
    except ParseError as e:
        raise errors.bad_request.InvalidLuceneSyntax(error=e)


class LuceneQueryField(fields.StringField):
    def validate(self, value):
        super(LuceneQueryField, self).validate(value)
        if value is None:
            return
        validate_lucene_query(value)


class NullableEnumValidator(EnumValidator):
    """Validator for enums that allows a None value."""

    def validate(self, value):
        if value is not None:
            super(NullableEnumValidator, self).validate(value)


class EnumField(fields.StringField):
    def __init__(
        self,
        values_or_type: Union[Iterable, Type[Enum]],
        *args,
        required=False,
        default=None,
        **kwargs
    ):
        choices = list(map(self.parse_value, values_or_type))
        validator_cls = EnumValidator if required else NullableEnumValidator
        kwargs.setdefault("validators", []).append(validator_cls(*choices))
        super().__init__(
            default=self.parse_value(default), required=required, *args, **kwargs
        )

    def parse_value(self, value):
        if isinstance(value, Enum):
            return str(value.value)
        return super().parse_value(value)


class EmailField(fields.StringField):
    def validate(self, value):
        super().validate(value)
        if value is None:
            return
        if validators.email(value) is not True:
            raise errors.bad_request.InvalidEmailAddress()


class DomainField(fields.StringField):
    def validate(self, value):
        super().validate(value)
        if value is None:
            return
        if validators.domain(value) is not True:
            raise errors.bad_request.InvalidDomainName()
116
server/apimodels/auth.py
116
server/apimodels/auth.py
Normal file
@@ -0,0 +1,116 @@
from jsonmodels.fields import IntField, StringField, BoolField, EmbeddedField
from jsonmodels.models import Base
from jsonmodels.validators import Max, Enum

from apimodels import ListField, EnumField
from config import config
from database.model.auth import Role
from database.utils import get_options


class GetTokenRequest(Base):
    """ User requests a token """

    expiration_sec = IntField(
        validators=Max(config.get("apiserver.auth.max_expiration_sec")), nullable=True
    )
    """ Expiration time for token in seconds. """


class GetTaskTokenRequest(GetTokenRequest):
    """ User requests a task token """

    task = StringField(required=True)


class GetTokenForUserRequest(GetTokenRequest):
    """ System requests a token for a user """

    user = StringField(required=True)
    company = StringField()


class GetTaskTokenForUserRequest(GetTokenForUserRequest):
    """ System requests a token for a user, for a specific task """

    task = StringField(required=True)


class GetTokenResponse(Base):
    token = StringField(required=True)


class ValidateTokenRequest(Base):
    token = StringField(required=True)


class ValidateUserRequest(Base):
    email = StringField(required=True)


class ValidateResponse(Base):
    valid = BoolField(required=True)
    msg = StringField()
    user = StringField()
    company = StringField()


class CreateUserRequest(Base):
    name = StringField(required=True)
    company = StringField(required=True)
    role = StringField(
        validators=Enum(*(set(get_options(Role)))),
        default=Role.user,
    )
    email = StringField(required=True)
    family_name = StringField()
    given_name = StringField()
    avatar = StringField()


class CreateUserResponse(Base):
    id = StringField(required=True)


class Credentials(Base):
    access_key = StringField(required=True)
    secret_key = StringField(required=True)


class CredentialsResponse(Credentials):
    secret_key = StringField()


class CreateCredentialsResponse(Base):
    credentials = EmbeddedField(Credentials)


class GetCredentialsResponse(Base):
    credentials = ListField(CredentialsResponse)


class RevokeCredentialsRequest(Base):
    access_key = StringField(required=True)


class RevokeCredentialsResponse(Base):
    revoked = IntField(required=True)


class AddUserRequest(CreateUserRequest):
    company = StringField()
    secret_key = StringField()


class AddUserResponse(CreateUserResponse):
    secret = StringField()


class DeleteUserRequest(Base):
    user = StringField(required=True)
    company = StringField()


class EditUserReq(Base):
    user = StringField(required=True)
    role = EnumField(Role.get_company_roles())
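These request models rely on jsonmodels-style declarative validation: required fields are enforced, and `role` is constrained to the options of `database.model.auth.Role` with `Role.user` as the default. A minimal standalone sketch of that pattern (not using the repo's classes; the role names here are illustrative assumptions, since the actual `Role` values are defined elsewhere in the codebase):

```python
# Hypothetical subset of the roles exposed by database.model.auth.Role
ALLOWED_ROLES = {"user", "admin"}


def validate_create_user(payload: dict) -> dict:
    """Mimic CreateUserRequest: required fields plus an enum-constrained role."""
    for field in ("name", "company", "email"):
        if not payload.get(field):
            raise ValueError(f"missing required field: {field}")
    role = payload.setdefault("role", "user")  # default mirrors Role.user
    if role not in ALLOWED_ROLES:
        raise ValueError(f"invalid role: {role}")
    return payload
```

In the real models this checking happens inside jsonmodels' `validate()` machinery; the sketch only shows the shape of the contract.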
60
server/apimodels/base.py
Normal file
@@ -0,0 +1,60 @@
from jsonmodels import models, fields
from mongoengine.base import BaseDocument

from apimodels import DictField


class MongoengineFieldsDict(DictField):
    """
    DictField representing mongoengine field names/value mapping.
    Used to convert mongoengine-style field/subfield notation to user-presentable syntax, including handling update
    operators.
    """

    mongoengine_update_operators = (
        'inc',
        'dec',
        'push',
        'push_all',
        'pop',
        'pull',
        'pull_all',
        'add_to_set',
    )

    @staticmethod
    def _normalize_mongo_value(value):
        if isinstance(value, BaseDocument):
            return value.to_mongo()
        return value

    @classmethod
    def _normalize_mongo_field_path(cls, path, value):
        parts = path.split('__')
        if len(parts) > 1:
            if parts[0] == 'set':
                parts = parts[1:]
            elif parts[0] == 'unset':
                parts = parts[1:]
                value = None
            elif parts[0] in cls.mongoengine_update_operators:
                return None, None
        return '.'.join(parts), cls._normalize_mongo_value(value)

    def parse_value(self, value):
        value = super(MongoengineFieldsDict, self).parse_value(value)
        return {
            k: v
            for k, v in (self._normalize_mongo_field_path(*p) for p in value.items())
            if k is not None
        }


class UpdateResponse(models.Base):
    updated = fields.IntField(required=True)
    fields = MongoengineFieldsDict()


class PagedRequest(models.Base):
    page = fields.IntField()
    page_size = fields.IntField()
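The core of `MongoengineFieldsDict` is the field-path normalization: mongoengine's `set__a__b` keys become dotted `a.b` paths, `unset__` clears the value, and raw update operators (`inc`, `push`, ...) are dropped from the user-facing mapping. A standalone re-implementation of just that logic, for illustration:

```python
# Update operators that should not leak into the user-facing field mapping
UPDATE_OPERATORS = {
    "inc", "dec", "push", "push_all", "pop", "pull", "pull_all", "add_to_set",
}


def normalize_field_path(path, value):
    """Convert one mongoengine-style key to a dotted path (or drop it)."""
    parts = path.split("__")
    if len(parts) > 1:
        if parts[0] == "set":
            parts = parts[1:]
        elif parts[0] == "unset":
            parts = parts[1:]
            value = None  # unset surfaces as a cleared value
        elif parts[0] in UPDATE_OPERATORS:
            return None, None  # operator-only updates are omitted
    return ".".join(parts), value


def normalize_fields(fields):
    """Apply normalize_field_path over a whole mapping, dropping None keys."""
    return {
        k: v
        for k, v in (normalize_field_path(p, v) for p, v in fields.items())
        if k is not None
    }
```

This is the same transformation `parse_value` performs, minus the mongoengine document-to-BSON conversion handled by `_normalize_mongo_value`.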
43
server/apimodels/models.py
Normal file
@@ -0,0 +1,43 @@
from jsonmodels import models, fields
from six import string_types

from apimodels import ListField, DictField
from apimodels.base import UpdateResponse
from apimodels.tasks import PublishResponse as TaskPublishResponse


class CreateModelRequest(models.Base):
    name = fields.StringField(required=True)
    uri = fields.StringField(required=True)
    labels = DictField(value_types=string_types + (int,), required=True)
    tags = ListField(items_types=string_types)
    comment = fields.StringField()
    public = fields.BoolField(default=False)
    project = fields.StringField()
    parent = fields.StringField()
    framework = fields.StringField()
    design = DictField()
    ready = fields.BoolField(default=True)
    ui_cache = DictField()
    task = fields.StringField()


class CreateModelResponse(models.Base):
    id = fields.StringField(required=True)
    created = fields.BoolField(required=True)


class PublishModelRequest(models.Base):
    model = fields.StringField(required=True)
    force_publish_task = fields.BoolField(default=False)
    publish_task = fields.BoolField(default=True)


class ModelTaskPublishResponse(models.Base):
    id = fields.StringField(required=True)
    data = fields.EmbeddedField(TaskPublishResponse)


class PublishModelResponse(UpdateResponse):
    published_task = fields.EmbeddedField(ModelTaskPublishResponse)
    updated = fields.IntField()
57
server/apimodels/tasks.py
Normal file
@@ -0,0 +1,57 @@
import six
from jsonmodels import models
from jsonmodels.fields import StringField, BoolField, IntField
from jsonmodels.validators import Enum

from apimodels import DictField, ListField
from apimodels.base import UpdateResponse
from database.model.task.task import TaskType
from database.utils import get_options


class StartedResponse(UpdateResponse):
    started = IntField()


class ResetResponse(UpdateResponse):
    deleted_indices = ListField(items_types=six.string_types)
    frames = DictField()
    events = DictField()
    model_deleted = IntField()


class TaskRequest(models.Base):
    task = StringField(required=True)


class UpdateRequest(TaskRequest):
    status_reason = StringField(default="")
    status_message = StringField(default="")
    force = BoolField(default=False)


class DeleteRequest(UpdateRequest):
    move_to_trash = BoolField(default=True)


class SetRequirementsRequest(TaskRequest):
    requirements = DictField(required=True)


class PublishRequest(UpdateRequest):
    publish_model = BoolField(default=True)


class PublishResponse(UpdateResponse):
    pass


class TaskData(models.Base):
    """
    A partial description of a task that can be updated incrementally
    """


class CreateRequest(TaskData):
    name = StringField(required=True)
    type = StringField(required=True, validators=Enum(*get_options(TaskType)))
17
server/apimodels/users.py
Normal file
@@ -0,0 +1,17 @@
from jsonmodels.fields import StringField
from jsonmodels.models import Base

from apimodels import DictField


class CreateRequest(Base):
    id = StringField(required=True)
    name = StringField(required=True)
    company = StringField(required=True)
    family_name = StringField()
    given_name = StringField()
    avatar = StringField()


class SetPreferencesRequest(Base):
    preferences = DictField(required=True)
0
server/bll/__init__.py
Normal file
167
server/bll/auth/__init__.py
Normal file
@@ -0,0 +1,167 @@
from datetime import datetime

import database
from apierrors import errors
from apimodels.auth import (
    GetTokenResponse,
    CreateUserRequest,
    Credentials as CredModel,
)
from apimodels.users import CreateRequest as Users_CreateRequest
from bll.user import UserBLL
from config import config
from database.errors import translate_errors_context
from database.model.auth import User, Role, Credentials
from database.model.company import Company
from service_repo import APICall
from service_repo.auth import (
    Identity,
    Token,
    get_client_id,
    get_secret_key,
)

log = config.logger("AuthBLL")


class AuthBLL:
    @staticmethod
    def get_token_for_user(
        user_id: str,
        company_id: str = None,
        expiration_sec: int = None,
        entities: dict = None,
    ) -> GetTokenResponse:

        with translate_errors_context():
            query = dict(id=user_id)

            if company_id:
                query.update(company=company_id)

            user = User.objects(**query).first()
            if not user:
                raise errors.bad_request.InvalidUserId(**query)

            company_id = company_id or user.company
            company = Company.objects(id=company_id).only("id", "name").first()
            if not company:
                raise errors.bad_request.InvalidId(
                    "invalid company associated with user", company=company
                )

            identity = Identity(
                user=user_id,
                company=company_id,
                role=user.role,
                user_name=user.name,
                company_name=company.name,
            )

            token = Token.create_encoded_token(
                identity=identity,
                entities=entities,
                expiration_sec=expiration_sec,
            )

            return GetTokenResponse(token=token.decode("ascii"))

    @staticmethod
    def create_user(request: CreateUserRequest, call: APICall = None) -> str:
        """
        Create a new user in both the auth database and the backend database
        :param request: New user details
        :param call: API call that triggered this call. If not None, new backend user creation
            will be performed using a new call in the same transaction.
        :return: The new user's ID
        """
        with translate_errors_context():
            if not Company.objects(id=request.company).only("id"):
                raise errors.bad_request.InvalidId(company=request.company)

            user = User(
                id=database.utils.id(),
                name=request.name,
                company=request.company,
                role=request.role or Role.user,
                email=request.email,
                created=datetime.utcnow(),
            )

            user.save()

            users_create_request = Users_CreateRequest(
                id=user.id,
                name=request.name,
                company=request.company,
                family_name=request.family_name,
                given_name=request.given_name,
                avatar=request.avatar,
            )

            try:
                UserBLL.create(users_create_request)
            except Exception as ex:
                user.delete()
                raise errors.server_error.GeneralError(
                    "failed adding new user", ex=str(ex)
                )

            return user.id

    @staticmethod
    def delete_user(
        identity: Identity, user_id: str, company_id: str = None, call: APICall = None
    ):
        """
        Delete an existing user from both the auth database and the backend database
        :param identity: Calling user identity
        :param user_id: ID of user to delete
        :param company_id: Company of user to delete
        :param call: API call that triggered this call. If not None, backend user deletion
            will be performed using a new call in the same transaction.
        """
        if user_id == identity.user:
            raise errors.bad_request.FieldsValueError(
                "cannot delete yourself", user=user_id
            )

        if not company_id:
            company_id = identity.company

        if (
            identity.role not in Role.get_system_roles()
            and company_id != identity.company
        ):
            raise errors.bad_request.FieldsNotAllowedForRole(
                "must be empty or your own company", role=identity.role, field="company"
            )

        with translate_errors_context():
            query = dict(id=user_id, company=company_id)
            res = User.objects(**query).delete()
            if not res:
                raise errors.bad_request.InvalidUserId(**query)
            try:
                UserBLL.delete(user_id)
            except Exception as ex:
                log.error(f"Exception calling users.delete: {str(ex)}")

    @classmethod
    def create_credentials(
        cls, user_id: str, company_id: str, role: str = None
    ) -> CredModel:

        with translate_errors_context():
            query = dict(id=user_id, company=company_id)
            user = User.objects(**query).first()
            if not user:
                raise errors.bad_request.InvalidUserId(**query)

            cred = CredModel(access_key=get_client_id(), secret_key=get_secret_key())
            user.credentials.append(
                Credentials(key=cred.access_key, secret=cred.secret_key)
            )
            user.save()

            return cred
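`create_credentials` delegates key generation to `get_client_id` / `get_secret_key` from `service_repo.auth`, whose implementations are not shown in this diff. A hedged sketch of what such helpers typically produce (the exact token shapes below are assumptions for illustration, not the repo's actual format):

```python
import secrets


def make_credentials():
    """Generate an access/secret key pair (assumed shapes, for illustration)."""
    access_key = secrets.token_hex(8).upper()  # assumed: short uppercase hex client id
    secret_key = secrets.token_urlsafe(32)     # assumed: longer urlsafe secret
    return {"access_key": access_key, "secret_key": secret_key}
```

The pair is then stored on the user document (`Credentials(key=..., secret=...)`) and returned once to the caller; note that `CredentialsResponse` makes `secret_key` optional, so listing endpoints can omit it.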
1
server/bll/event/__init__.py
Normal file
@@ -0,0 +1 @@
from .event_bll import EventBLL
700
server/bll/event/event_bll.py
Normal file
@@ -0,0 +1,700 @@
from collections import defaultdict
from contextlib import closing
from datetime import datetime
from operator import attrgetter

import attr
import six
from elasticsearch import helpers
from enum import Enum

from mongoengine import Q
from nested_dict import nested_dict

import database.utils as dbutils
import es_factory
from apierrors import errors
from bll.task import TaskBLL
from database.errors import translate_errors_context
from database.model.task.task import Task
from database.model.task.metrics import MetricEvent
from timing_context import TimingContext


class EventType(Enum):
    metrics_scalar = "training_stats_scalar"
    metrics_vector = "training_stats_vector"
    metrics_image = "training_debug_image"
    metrics_plot = "plot"
    task_log = "log"


# noinspection PyTypeChecker
EVENT_TYPES = set(map(attrgetter("value"), EventType))


@attr.s
class TaskEventsResult(object):
    events = attr.ib(type=list, default=attr.Factory(list))
    total_events = attr.ib(type=int, default=0)
    next_scroll_id = attr.ib(type=str, default=None)


class EventBLL(object):
    id_fields = ["task", "iter", "metric", "variant", "key"]

    def __init__(self, events_es=None):
        self.es = events_es if events_es is not None else es_factory.connect("events")

    def add_events(self, company_id, events, worker):
        actions = []
        task_ids = set()
        task_iteration = defaultdict(lambda: 0)
        task_last_events = nested_dict(
            3, dict
        )  # task_id -> metric_hash -> variant_hash -> MetricEvent

        for event in events:
            # remove spaces from event type
            if "type" not in event:
                raise errors.BadRequest("Event must have a 'type' field", event=event)

            event_type = event["type"].replace(" ", "_")
            if event_type not in EVENT_TYPES:
                raise errors.BadRequest(
                    "Invalid event type {}".format(event_type),
                    event=event,
                    types=EVENT_TYPES,
                )

            event["type"] = event_type

            # @timestamp indicates the time the event is written, not when it happened
            event["@timestamp"] = es_factory.get_es_timestamp_str()

            # for backward compatibility
            if "ts" in event:
                event["timestamp"] = event.pop("ts")

            # set timestamp and worker if not sent
            if "timestamp" not in event:
                event["timestamp"] = es_factory.get_timestamp_millis()

            if "worker" not in event:
                event["worker"] = worker

            # force iter to be a long int
            iter = event.get("iter")
            if iter is not None:
                iter = int(iter)
                event["iter"] = iter

            # used to have "values" to indicate array. no need anymore
            if "values" in event:
                event["value"] = event["values"]
                del event["values"]

            index_name = EventBLL.get_index_name(company_id, event_type)
            es_action = {
                "_op_type": "index",  # overwrite if exists with same ID
                "_index": index_name,
                "_type": "event",
                "_source": event,
            }

            # for "log" events, don't assign a custom _id - whatever is sent, is written (not overwritten)
            if event_type != "log":
                es_action["_id"] = self._get_event_id(event)
            else:
                es_action["_id"] = dbutils.id()

            task_id = event.get("task")
            if task_id is not None:
                es_action["_routing"] = task_id
                task_ids.add(task_id)
                if iter is not None:
                    task_iteration[task_id] = max(iter, task_iteration[task_id])

                if event_type == EventType.metrics_scalar.value:
                    self._update_last_metric_event_for_task(
                        task_last_events=task_last_events, task_id=task_id, event=event
                    )
            else:
                es_action["_routing"] = task_id

            actions.append(es_action)

        if task_ids:
            # verify task_ids
            with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
                res = Task.objects(id__in=task_ids, company=company_id).only("id")
                if len(res) < len(task_ids):
                    invalid_task_ids = tuple(set(task_ids) - set(r.id for r in res))
                    raise errors.bad_request.InvalidTaskId(
                        company=company_id, ids=invalid_task_ids
                    )

        errors_in_bulk = []
        added = 0
        chunk_size = 500
        with translate_errors_context(), TimingContext("es", "events_add_batch"):
            # TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
            with closing(
                helpers.streaming_bulk(
                    self.es,
                    actions,
                    chunk_size=chunk_size,
                    # thread_count=8,
                    refresh=True,
                )
            ) as it:
                for success, info in it:
                    if success:
                        added += chunk_size
                    else:
                        errors_in_bulk.append(info)

            last_metrics = {
                t.id: t.to_proper_dict().get("last_metrics", {})
                for t in Task.objects(id__in=task_ids, company=company_id).only(
                    "last_metrics"
                )
            }

            remaining_tasks = set()
            now = datetime.utcnow()
            for task_id in task_ids:
                # Update related tasks. For reasons of performance, we prefer to update all of them and not only those
                # whose events were successful

                updated = self._update_task(
                    company_id=company_id,
                    task_id=task_id,
                    now=now,
                    iter=task_iteration.get(task_id),
                    last_events=task_last_events.get(task_id),
                    last_metrics=last_metrics.get(task_id),
                )

                if not updated:
                    remaining_tasks.add(task_id)
                    continue

            if remaining_tasks:
                TaskBLL.set_last_update(remaining_tasks, company_id, last_update=now)

        # Compensate for always adding chunk_size on success (last chunk is probably smaller)
        added = min(added, len(actions))

        return added, errors_in_bulk
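The per-event normalization at the top of `add_events` can be summarized in a condensed, standalone sketch: spaces in the type are replaced with underscores and validated against the known event types, the legacy `ts` key is renamed to `timestamp`, a default worker is filled in, `iter` is coerced to int, and the highest iteration seen per task is tracked for the later task update (timestamp/ID generation and the ES bulk action are omitted here):

```python
from collections import defaultdict

# Mirrors the EventType enum values defined above
EVENT_TYPES = {
    "training_stats_scalar", "training_stats_vector",
    "training_debug_image", "plot", "log",
}


def normalize_events(events, default_worker="worker-0"):
    """Normalize events in place and return them with the max iteration per task."""
    task_iteration = defaultdict(int)
    for event in events:
        event_type = event["type"].replace(" ", "_")
        if event_type not in EVENT_TYPES:
            raise ValueError(f"invalid event type: {event_type}")
        event["type"] = event_type
        if "ts" in event:  # backward compatibility
            event["timestamp"] = event.pop("ts")
        event.setdefault("worker", default_worker)
        if event.get("iter") is not None:
            event["iter"] = int(event["iter"])
            task_iteration[event["task"]] = max(
                event["iter"], task_iteration[event["task"]]
            )
    return events, dict(task_iteration)
```

The `task_iteration` mapping is what feeds `last_iteration` in `_update_task` below.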
    def _update_last_metric_event_for_task(self, task_last_events, task_id, event):
        """
        Update task_last_events structure for the provided task_id with the provided event details if this event is more
        recent than the currently stored event for its metric/variant combination.

        task_last_events contains [hashed_metric_name -> hashed_variant_name -> event]. Keys are hashed to avoid mongodb
        key conflicts due to invalid characters and/or long field names.
        """
        metric = event.get("metric")
        variant = event.get("variant")
        if not (metric and variant):
            return

        metric_hash = dbutils.hash_field_name(metric)
        variant_hash = dbutils.hash_field_name(variant)

        last_events = task_last_events[task_id]

        timestamp = last_events[metric_hash][variant_hash].get("timestamp", None)
        if timestamp is None or timestamp < event["timestamp"]:
            last_events[metric_hash][variant_hash] = event

    def _update_task(
        self, company_id, task_id, now, iter=None, last_events=None, last_metrics=None
    ):
        """
        Update task information in DB with aggregated results after handling event(s) related to this task.

        This updates the task with the highest iteration value encountered during the last events update, as well
        as the latest metric/variant scalar values reported (according to the report timestamp) and the task's last
        update time.
        """
        fields = {}

        if iter is not None:
            fields["last_iteration"] = iter

        if last_events:

            def get_metric_event(ev):
                me = MetricEvent.from_dict(**ev)
                if "timestamp" in ev:
                    me.timestamp = datetime.utcfromtimestamp(ev["timestamp"] / 1000)
                return me

            new_last_metrics = nested_dict(2, MetricEvent)
            new_last_metrics.update(last_metrics)

            for metric_hash, variants in last_events.items():
                for variant_hash, event in variants.items():
                    new_last_metrics[metric_hash][variant_hash] = get_metric_event(
                        event
                    )

            fields["last_metrics"] = new_last_metrics.to_dict()

        if not fields:
            return False

        return TaskBLL.update_statistics(task_id, company_id, last_update=now, **fields)

    def _get_event_id(self, event):
        id_values = (str(event[field]) for field in self.id_fields if field in event)
        return "-".join(id_values)
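`_get_event_id` builds a deterministic document ID by joining whichever identifying fields are present, so re-sending the same scalar event overwrites the previous document (the `"_op_type": "index"` action above) instead of duplicating it; only `log` events get a random ID. Standalone:

```python
# Same field order as EventBLL.id_fields
ID_FIELDS = ["task", "iter", "metric", "variant", "key"]


def get_event_id(event: dict) -> str:
    """Join the identifying fields that exist into a deterministic ES _id."""
    return "-".join(str(event[f]) for f in ID_FIELDS if f in event)
```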
    def scroll_task_events(
        self,
        company_id,
        task_id,
        order,
        event_type=None,
        batch_size=10000,
        scroll_id=None,
    ):
        if scroll_id:
            with translate_errors_context(), TimingContext("es", "task_log_events"):
                es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
        else:
            size = min(batch_size, 10000)
            if event_type is None:
                event_type = "*"

            es_index = EventBLL.get_index_name(company_id, event_type)

            if not self.es.indices.exists(es_index):
                return [], None, 0

            es_req = {
                "size": size,
                "sort": {"timestamp": {"order": order}},
                "query": {"bool": {"must": [{"term": {"task": task_id}}]}},
            }

            with translate_errors_context(), TimingContext("es", "scroll_task_events"):
                es_res = self.es.search(index=es_index, body=es_req, scroll="1h")

        events = [hit["_source"] for hit in es_res["hits"]["hits"]]
        next_scroll_id = es_res["_scroll_id"]
        total_events = es_res["hits"]["total"]

        return events, next_scroll_id, total_events

    def get_task_events(
        self,
        company_id,
        task_id,
        event_type=None,
        metric=None,
        variant=None,
        last_iter_count=None,
        sort=None,
        size=500,
        scroll_id=None,
    ):

        if scroll_id:
            with translate_errors_context(), TimingContext("es", "get_task_events"):
                es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
        else:
            task_ids = [task_id] if isinstance(task_id, six.string_types) else task_id
            if event_type is None:
                event_type = "*"

            es_index = EventBLL.get_index_name(company_id, event_type)
            if not self.es.indices.exists(es_index):
                return TaskEventsResult()

            query = {"bool": defaultdict(list)}

            if metric or variant:
                must = query["bool"]["must"]
                if metric:
                    must.append({"term": {"metric": metric}})
                if variant:
                    must.append({"term": {"variant": variant}})

            if last_iter_count is None:
                must = query["bool"]["must"]
                must.append({"terms": {"task": task_ids}})
            else:
                should = query["bool"]["should"]
                for i, task_id in enumerate(task_ids):
                    last_iters = self.get_last_iters(
                        es_index, task_id, event_type, last_iter_count
                    )
                    if not last_iters:
                        continue
                    should.append(
                        {
                            "bool": {
                                "must": [
                                    {"term": {"task": task_id}},
                                    {"terms": {"iter": last_iters}},
                                ]
                            }
                        }
                    )
                if not should:
                    return TaskEventsResult()

            if sort is None:
                sort = [{"timestamp": {"order": "asc"}}]

            es_req = {"sort": sort, "size": min(size, 10000), "query": query}

            routing = ",".join(task_ids)

            with translate_errors_context(), TimingContext("es", "get_task_events"):
                es_res = self.es.search(
                    index=es_index,
                    body=es_req,
                    ignore=404,
                    routing=routing,
                    scroll="1h",
                )

        events = [doc["_source"] for doc in es_res.get("hits", {}).get("hits", [])]
        next_scroll_id = es_res["_scroll_id"]
        total_events = es_res["hits"]["total"]

        return TaskEventsResult(
            events=events, next_scroll_id=next_scroll_id, total_events=total_events
        )

    def get_metrics_and_variants(self, company_id, task_id, event_type):

        es_index = EventBLL.get_index_name(company_id, event_type)

        if not self.es.indices.exists(es_index):
            return {}

        es_req = {
            "size": 0,
            "aggs": {
                "metrics": {
                    "terms": {"field": "metric", "size": 200},
                    "aggs": {"variants": {"terms": {"field": "variant", "size": 200}}},
                }
            },
            "query": {"bool": {"must": [{"term": {"task": task_id}}]}},
        }

        with translate_errors_context(), TimingContext(
            "es", "events_get_metrics_and_variants"
        ):
            es_res = self.es.search(index=es_index, body=es_req, routing=task_id)

        metrics = {}
        for metric_bucket in es_res["aggregations"]["metrics"].get("buckets"):
            metric = metric_bucket["key"]
            metrics[metric] = [
                b["key"] for b in metric_bucket["variants"].get("buckets")
            ]

        return metrics

    def get_task_latest_scalar_values(self, company_id, task_id):
        es_index = EventBLL.get_index_name(company_id, "training_stats_scalar")

        if not self.es.indices.exists(es_index):
            return {}

        es_req = {
            "size": 0,
            "query": {
                "bool": {
                    "must": [
                        {"query_string": {"query": "value:>0"}},
                        {"term": {"task": task_id}},
                    ]
                }
            },
            "aggs": {
                "metrics": {
                    "terms": {
                        "field": "metric",
                        "size": 1000,
                        "order": {"_term": "asc"},
                    },
                    "aggs": {
                        "variants": {
                            "terms": {
                                "field": "variant",
                                "size": 1000,
                                "order": {"_term": "asc"},
                            },
                            "aggs": {
                                "last_value": {
                                    "top_hits": {
                                        "docvalue_fields": ["value"],
                                        "_source": "value",
                                        "size": 1,
                                        "sort": [{"iter": {"order": "desc"}}],
                                    }
                                },
                                "last_timestamp": {"max": {"field": "@timestamp"}},
                                "last_10_value": {
                                    "top_hits": {
                                        "docvalue_fields": ["value"],
                                        "_source": "value",
                                        "size": 10,
                                        "sort": [{"iter": {"order": "desc"}}],
                                    }
                                },
                            },
                        }
                    },
                }
            },
            "_source": {"excludes": []},
        }
        with translate_errors_context(), TimingContext(
            "es", "events_get_metrics_and_variants"
        ):
            es_res = self.es.search(index=es_index, body=es_req, routing=task_id)

        metrics = []
        max_timestamp = 0
        for metric_bucket in es_res["aggregations"]["metrics"].get("buckets"):
            metric_summary = dict(name=metric_bucket["key"], variants=[])
            for variant_bucket in metric_bucket["variants"].get("buckets"):
                variant_name = variant_bucket["key"]
                last_value = variant_bucket["last_value"]["hits"]["hits"][0]["fields"][
                    "value"
                ][0]
                last_10_value = variant_bucket["last_10_value"]["hits"]["hits"][0][
                    "fields"
                ]["value"][0]
                timestamp = variant_bucket["last_timestamp"]["value"]
                max_timestamp = max(timestamp, max_timestamp)
                metric_summary["variants"].append(
                    dict(
                        name=variant_name,
                        last_value=last_value,
                        last_10_value=last_10_value,
                    )
                )
            metrics.append(metric_summary)
        return metrics, max_timestamp

    def compare_scalar_metrics_average_per_iter(
        self, company_id, task_ids, allow_public=True
    ):
        assert isinstance(task_ids, list)

        task_name_by_id = {}
        with translate_errors_context():
            task_objs = Task.get_many(
                company=company_id,
                query=Q(id__in=task_ids),
                allow_public=allow_public,
                override_projection=("id", "name"),
                return_dicts=False,
            )
            if len(task_objs) < len(task_ids):
                invalid = tuple(set(task_ids) - set(r.id for r in task_objs))
                raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)

        task_name_by_id = {t.id: t.name for t in task_objs}

        es_index = EventBLL.get_index_name(company_id, "training_stats_scalar")
        if not self.es.indices.exists(es_index):
            return {}

        es_req = {
            "size": 0,
            "_source": {"excludes": []},
            "query": {"terms": {"task": task_ids}},
            "aggs": {
                "iters": {
                    "histogram": {"field": "iter", "interval": 1, "min_doc_count": 1},
                    "aggs": {
                        "metric_and_variant": {
                            "terms": {
                                "script": "doc['metric'].value +'/'+ doc['variant'].value",
                                "size": 10000,
                            },
                            "aggs": {
                                "tasks": {
                                    "terms": {"field": "task"},
                                    "aggs": {"avg_val": {"avg": {"field": "value"}}},
                                }
                            },
                        }
                    },
                }
            },
        }
        with translate_errors_context(), TimingContext("es", "task_stats_comparison"):
            es_res = self.es.search(index=es_index, body=es_req)

        if "aggregations" not in es_res:
            return

        metrics = {}
        for iter_bucket in es_res["aggregations"]["iters"]["buckets"]:
            iteration = int(iter_bucket["key"])
            for metric_bucket in iter_bucket["metric_and_variant"]["buckets"]:
                metric_name = metric_bucket["key"]
                if metrics.get(metric_name) is None:
                    metrics[metric_name] = {}

                metric_data = metrics[metric_name]
                for task_bucket in metric_bucket["tasks"]["buckets"]:
                    task_id = task_bucket["key"]
                    value = task_bucket["avg_val"]["value"]
                    if metric_data.get(task_id) is None:
                        metric_data[task_id] = {
                            "x": [],
                            "y": [],
                            "name": task_name_by_id[task_id],
                        }
                    metric_data[task_id]["x"].append(iteration)
                    metric_data[task_id]["y"].append(value)

        return metrics

    def get_scalar_metrics_average_per_iter(self, company_id, task_id):

        es_index = EventBLL.get_index_name(company_id, "training_stats_scalar")
        if not self.es.indices.exists(es_index):
            return {}

        es_req = {
            "size": 0,
            "_source": {"excludes": []},
            "query": {"term": {"task": task_id}},
            "aggs": {
                "iters": {
                    "histogram": {"field": "iter", "interval": 1, "min_doc_count": 1},
                    "aggs": {
                        "metrics": {
                            "terms": {
                                "field": "metric",
                                "size": 200,
                                "order": {"_term": "desc"},
                            },
                            "aggs": {
                                "variants": {
                                    "terms": {
                                        "field": "variant",
                                        "size": 500,
                                        "order": {"_term": "desc"},
                                    },
                                    "aggs": {"avg_val": {"avg": {"field": "value"}}},
                                }
                            },
                        }
                    },
                }
            },
            "version": True,
        }

        with translate_errors_context(), TimingContext("es", "task_stats_scalar"):
            es_res = self.es.search(index=es_index, body=es_req, routing=task_id)

        metrics = {}
        if "aggregations" in es_res:
            for iter_bucket in es_res["aggregations"]["iters"]["buckets"]:
                iteration = int(iter_bucket["key"])
                for metric_bucket in iter_bucket["metrics"]["buckets"]:
                    metric_name = metric_bucket["key"]
                    if metrics.get(metric_name) is None:
                        metrics[metric_name] = {}

                    metric_data = metrics[metric_name]
                    for variant_bucket in metric_bucket["variants"]["buckets"]:
                        variant = variant_bucket["key"]
                        value = variant_bucket["avg_val"]["value"]
                        if metric_data.get(variant) is None:
                            metric_data[variant] = {"x": [], "y": [], "name": variant}
                        metric_data[variant]["x"].append(iteration)
                        metric_data[variant]["y"].append(value)
        return metrics
def get_vector_metrics_per_iter(self, company_id, task_id, metric, variant):
|
||||
|
||||
es_index = EventBLL.get_index_name(company_id, "training_stats_vector")
|
||||
if not self.es.indices.exists(es_index):
|
||||
return [], []
|
||||
|
||||
es_req = {
|
||||
"size": 10000,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task_id}},
|
||||
{"term": {"metric": metric}},
|
||||
{"term": {"variant": variant}},
|
||||
]
|
||||
}
|
||||
},
|
||||
"_source": ["iter", "value"],
|
||||
"sort": ["iter"],
|
||||
}
|
||||
with translate_errors_context(), TimingContext("es", "task_stats_vector"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
|
||||
vectors = []
|
||||
iterations = []
|
||||
for hit in es_res["hits"]["hits"]:
|
||||
vectors.append(hit["_source"]["value"])
|
||||
iterations.append(hit["_source"]["iter"])
|
||||
|
||||
return iterations, vectors
|
||||
|
||||
def get_last_iters(self, es_index, task_id, event_type, iters):
|
||||
if not self.es.indices.exists(es_index):
|
||||
return []
|
||||
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"aggs": {
|
||||
"iters": {
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": iters,
|
||||
"order": {"_term": "desc"},
|
||||
}
|
||||
}
|
||||
},
|
||||
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
|
||||
}
|
||||
if event_type:
|
||||
es_req["query"]["bool"]["must"].append({"term": {"type": event_type}})
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_last_iter"):
|
||||
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
|
||||
return [b["key"] for b in es_res["aggregations"]["iters"]["buckets"]]
|
||||
|
||||
def delete_task_events(self, company_id, task_id):
|
||||
es_index = EventBLL.get_index_name(company_id, "*")
|
||||
es_req = {"query": {"term": {"task": task_id}}}
|
||||
with translate_errors_context(), TimingContext("es", "delete_task_events"):
|
||||
es_res = self.es.delete_by_query(
|
||||
index=es_index, body=es_req, routing=task_id, refresh=True
|
||||
)
|
||||
|
||||
return es_res.get("deleted", 0)
|
||||
|
||||
@staticmethod
|
||||
def get_index_name(company_id, event_type):
|
||||
event_type = event_type.lower().replace(" ", "_")
|
||||
return "events-%s-%s" % (event_type, company_id)
|
||||
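For reference, the `get_index_name` helper above maps a company ID and a human-readable event type to a per-company Elasticsearch index name; a standalone sketch of the same logic (the company ID here is only illustrative):

```python
def get_index_name(company_id, event_type):
    # Mirrors EventBLL.get_index_name: normalize the event type and
    # build a per-company index name "events-<event_type>-<company_id>"
    event_type = event_type.lower().replace(" ", "_")
    return "events-%s-%s" % (event_type, company_id)

print(get_index_name("d1bd92a3", "Training Stats Scalar"))
# events-training_stats_scalar-d1bd92a3
```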
7
server/bll/task/__init__.py
Normal file
@@ -0,0 +1,7 @@
from .task_bll import TaskBLL
from .utils import (
    ChangeStatusRequest,
    update_project_time,
    validate_status_change,
    split_by,
)
393
server/bll/task/task_bll.py
Normal file
@@ -0,0 +1,393 @@
import re
from collections import OrderedDict
from datetime import datetime
from typing import Mapping, Collection
from urllib.parse import urlparse

import six
from mongoengine import Q
from six import string_types

import es_factory
from apierrors import errors
from database.errors import translate_errors_context
from database.fields import OutputDestinationField
from database.model.model import Model
from database.model.project import Project
from database.model.task.metrics import MetricEvent
from database.model.task.output import Output
from database.model.task.task import Task, TaskStatus, TaskStatusMessage, TaskTags
from database.utils import get_company_or_none_constraint, id as create_id
from service_repo import APICall
from timing_context import TimingContext
from .utils import ChangeStatusRequest, validate_status_change


class TaskBLL(object):
    def __init__(self, events_es=None):
        self.events_es = (
            events_es if events_es is not None else es_factory.connect("events")
        )

    @staticmethod
    def get_task_with_access(
        task_id, company_id, only=None, allow_public=False, requires_write_access=False
    ) -> Task:
        """
        Gets a task, optionally verifying write access.
        :except errors.bad_request.InvalidTaskId: if the task is not found
        :except errors.forbidden.NoWritePermission: if write access was required and the task cannot be modified
        """
        with translate_errors_context():
            query = dict(id=task_id, company=company_id)
            with TimingContext("mongo", "task_with_access"):
                if requires_write_access:
                    task = Task.get_for_writing(_only=only, **query)
                else:
                    task = Task.get(_only=only, **query, include_public=allow_public)

            if not task:
                raise errors.bad_request.InvalidTaskId(**query)

            return task

    @staticmethod
    def get_by_id(
        company_id,
        task_id,
        required_status=None,
        required_dataset=None,
        only_fields=None,
    ):
        with TimingContext("mongo", "task_by_id_all"):
            qs = Task.objects(id=task_id, company=company_id)
            if only_fields:
                qs = (
                    qs.only(only_fields)
                    if isinstance(only_fields, string_types)
                    else qs.only(*only_fields)
                )
            qs = qs.only(
                "status", "input"
            )  # make sure all fields we rely on here are also returned
            task = qs.first()

        if not task:
            raise errors.bad_request.InvalidTaskId(id=task_id)

        if required_status and not task.status == required_status:
            raise errors.bad_request.InvalidTaskStatus(expected=required_status)

        if required_dataset and required_dataset not in (
            entry.dataset for entry in task.input.view.entries
        ):
            raise errors.bad_request.InvalidId(
                "not in input view", dataset=required_dataset
            )

        return task

    @staticmethod
    def assert_exists(company_id, task_ids, only=None, allow_public=False):
        task_ids = [task_ids] if isinstance(task_ids, six.string_types) else task_ids
        with translate_errors_context(), TimingContext("mongo", "task_exists"):
            ids = set(task_ids)
            q = Task.get_many(
                company=company_id,
                query=Q(id__in=ids),
                allow_public=allow_public,
                return_dicts=False,
            )
            if only:
                res = q.only(*only)
                count = len(res)
            else:
                count = q.count()
                res = q.first()
            if count != len(ids):
                raise errors.bad_request.InvalidTaskId(ids=task_ids)
            return res

    @staticmethod
    def create(call: APICall, fields: dict):
        identity = call.identity
        now = datetime.utcnow()
        return Task(
            id=create_id(),
            user=identity.user,
            company=identity.company,
            created=now,
            last_update=now,
            **fields,
        )

    @staticmethod
    def validate_execution_model(task, allow_only_public=False):
        if not task.execution or not task.execution.model:
            return

        company = None if allow_only_public else task.company
        model_id = task.execution.model
        model = Model.objects(
            Q(id=model_id) & get_company_or_none_constraint(company)
        ).first()
        if not model:
            raise errors.bad_request.InvalidModelId(model=model_id)

        return model

    @classmethod
    def validate(cls, task: Task, force=False):
        assert isinstance(task, Task)

        if task.parent and not Task.get(
            company=task.company, id=task.parent, _only=("id",), include_public=True
        ):
            raise errors.bad_request.InvalidTaskId("invalid parent", parent=task.parent)

        if task.project:
            Project.get_for_writing(company=task.company, id=task.project)

        model = cls.validate_execution_model(task)
        if model and not force and not model.ready:
            raise errors.bad_request.ModelNotReady(
                "can't be used in a task", model=model.id
            )

        if task.execution:
            if task.execution.parameters:
                cls._validate_execution_parameters(task.execution.parameters)

        if task.output and task.output.destination:
            parsed_url = urlparse(task.output.destination)
            if parsed_url.scheme not in OutputDestinationField.schemes:
                raise errors.bad_request.FieldsValueError(
                    "unsupported scheme for output destination",
                    dest=task.output.destination,
                )

    @staticmethod
    def _validate_execution_parameters(parameters):
        invalid_keys = [k for k in parameters if re.search(r"\s", k)]
        if invalid_keys:
            raise errors.bad_request.ValidationError(
                "execution.parameters keys contain whitespace", keys=invalid_keys
            )
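The parameter-key check above rejects any `execution.parameters` key containing whitespace; a standalone sketch of the same rule (the parameter names here are only illustrative):

```python
import re

def find_invalid_keys(parameters):
    # Same rule as TaskBLL._validate_execution_parameters: a key is
    # invalid if it contains any whitespace character (space, tab, newline)
    return [k for k in parameters if re.search(r"\s", k)]

print(find_invalid_keys({"learning_rate": 0.1, "batch size": 32, "epochs\t": 5}))
# ['batch size', 'epochs\t']
```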

    @staticmethod
    def get_unique_metric_variants(company_id, project_ids=None):
        pipeline = [
            {
                "$match": dict(
                    company=company_id,
                    **({"project": {"$in": project_ids}} if project_ids else {}),
                )
            },
            {"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
            {"$unwind": "$metrics"},
            {
                "$project": {
                    "metric": "$metrics.k",
                    "variants": {"$objectToArray": "$metrics.v"},
                }
            },
            {"$unwind": "$variants"},
            {
                "$group": {
                    "_id": {
                        "metric": "$variants.v.metric",
                        "variant": "$variants.v.variant",
                    },
                    "metrics": {
                        "$addToSet": {
                            "metric": "$variants.v.metric",
                            "metric_hash": "$metric",
                            "variant": "$variants.v.variant",
                            "variant_hash": "$variants.k",
                        }
                    },
                }
            },
            {"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
        ]

        with translate_errors_context():
            result = Task.objects.aggregate(*pipeline)
            return [r["metrics"][0] for r in result]

    @staticmethod
    def set_last_update(
        task_ids: Collection[str], company_id: str, last_update: datetime
    ):
        return Task.objects(id__in=task_ids, company=company_id).update(
            upsert=False, last_update=last_update
        )

    @staticmethod
    def update_statistics(
        task_id: str,
        company_id: str,
        last_update: datetime = None,
        last_iteration: int = None,
        last_iteration_max: int = None,
        last_metrics: Mapping[str, Mapping[str, MetricEvent]] = None,
        **extra_updates,
    ):
        """
        Update task statistics
        :param task_id: Task's ID.
        :param company_id: Task's company ID.
        :param last_update: Last update time. If not provided, defaults to datetime.utcnow().
        :param last_iteration: Last reported iteration. Use this to set a value regardless of the
            current task's last iteration value.
        :param last_iteration_max: Last reported iteration. Use this to conditionally set a value only
            if the current task's last iteration value is smaller than the provided value.
        :param last_metrics: Last reported metrics summary.
        :param extra_updates: Extra task updates to include in this update call.
        """
        last_update = last_update or datetime.utcnow()

        if last_iteration is not None:
            extra_updates.update(last_iteration=last_iteration)
        elif last_iteration_max is not None:
            extra_updates.update(max__last_iteration=last_iteration_max)

        if last_metrics is not None:
            extra_updates.update(last_metrics=last_metrics)

        return Task.objects(id=task_id, company=company_id).update(
            upsert=False, last_update=last_update, **extra_updates
        )

    @classmethod
    def model_set_ready(
        cls,
        model_id: str,
        company_id: str,
        publish_task: bool,
        force_publish_task: bool = False,
    ) -> tuple:
        with translate_errors_context():
            query = dict(id=model_id, company=company_id)
            model = Model.objects(**query).first()
            if not model:
                raise errors.bad_request.InvalidModelId(**query)
            elif model.ready:
                raise errors.bad_request.ModelIsReady(**query)

            published_task_data = {}
            if model.task and publish_task:
                task = (
                    Task.objects(id=model.task, company=company_id)
                    .only("id", "status")
                    .first()
                )
                if task and task.status != TaskStatus.published:
                    published_task_data["data"] = cls.publish_task(
                        task_id=model.task,
                        company_id=company_id,
                        publish_model=False,
                        force=force_publish_task,
                    )
                    published_task_data["id"] = model.task

            updated = model.update(upsert=False, ready=True)
            return updated, published_task_data

    @classmethod
    def publish_task(
        cls,
        task_id: str,
        company_id: str,
        publish_model: bool,
        force: bool,
        status_reason: str = "",
        status_message: str = "",
    ) -> dict:
        task = cls.get_task_with_access(
            task_id, company_id=company_id, requires_write_access=True
        )
        if not force:
            validate_status_change(task.status, TaskStatus.published)

        previous_task_status = task.status
        output = task.output or Output()
        publish_failed = False

        try:
            # set state to publishing
            task.status = TaskStatus.publishing
            task.save()

            # publish task models
            if task.output.model and publish_model:
                output_model = (
                    Model.objects(id=task.output.model)
                    .only("id", "task", "ready")
                    .first()
                )
                if output_model and not output_model.ready:
                    cls.model_set_ready(
                        model_id=task.output.model,
                        company_id=company_id,
                        publish_task=False,
                    )

            # set task status to published, and update (or set) its new output (view and models)
            return ChangeStatusRequest(
                task=task,
                new_status=TaskStatus.published,
                force=force,
                status_reason=status_reason,
                status_message=status_message,
            ).execute(published=datetime.utcnow(), output=output)

        except Exception as ex:
            publish_failed = True
            raise ex
        finally:
            if publish_failed:
                task.status = previous_task_status
                task.save()

    @classmethod
    def stop_task(
        cls,
        task_id: str,
        company_id: str,
        user_name: str,
        status_reason: str,
        force: bool,
    ) -> dict:
        """
        Stop a running task. Requires task status 'in_progress' and
        execution_progress 'running', or force=True. A development task, or a
        task that has no associated worker, is stopped immediately.
        For a non-development task with a worker, only the status message
        is set to 'stopping', to allow the worker to stop the task and report by itself.
        :return: updated task fields
        """
        task = TaskBLL.get_task_with_access(
            task_id,
            company_id=company_id,
            only=("status", "project", "tags", "last_update"),
            requires_write_access=True,
        )

        if TaskTags.development in task.tags:
            new_status = TaskStatus.stopped
            status_message = f"Stopped by {user_name}"
        else:
            new_status = task.status
            status_message = TaskStatusMessage.stopping

        return ChangeStatusRequest(
            task=task,
            new_status=new_status,
            status_reason=status_reason,
            status_message=status_message,
            force=force,
        ).execute()
151
server/bll/task/utils.py
Normal file
@@ -0,0 +1,151 @@
from datetime import datetime
from typing import TypeVar, Callable, Tuple, Sequence

import attr
import six

from apierrors import errors
from database.errors import translate_errors_context
from database.model.project import Project
from database.model.task.task import Task, TaskStatus
from database.utils import get_options
from timing_context import TimingContext
from utilities.attrs import typed_attrs

valid_statuses = get_options(TaskStatus)


@typed_attrs
class ChangeStatusRequest(object):
    task = attr.ib(type=Task)
    new_status = attr.ib(
        type=six.string_types, validator=attr.validators.in_(valid_statuses)
    )
    status_reason = attr.ib(type=six.string_types, default="")
    status_message = attr.ib(type=six.string_types, default="")
    force = attr.ib(type=bool, default=False)
    allow_same_state_transition = attr.ib(type=bool, default=True)

    def execute(self, **kwargs):
        current_status = self.task.status
        project_id = self.task.project

        # Verify new status is allowed from current status (will throw exception if not valid)
        self.validate_transition(current_status)

        control = dict(upsert=False, multi=False, write_concern=None, full_result=False)

        now = datetime.utcnow()
        fields = dict(
            status=self.new_status,
            status_reason=self.status_reason,
            status_message=self.status_message,
            status_changed=now,
            last_update=now,
        )

        def safe_mongoengine_key(key):
            return f"__{key}" if key in control else key

        fields.update({safe_mongoengine_key(k): v for k, v in kwargs.items()})

        with translate_errors_context(), TimingContext("mongo", "task_status"):
            # atomic change of task status by querying the task with the EXPECTED status before modifying it
            params = fields.copy()
            params.update(control)
            updated = Task.objects(id=self.task.id, status=current_status).update(
                **params
            )

            if not updated:
                # failed to change status (someone else beat us to it?)
                raise errors.bad_request.FailedChangingTaskStatus(
                    task_id=self.task.id,
                    current_status=current_status,
                    new_status=self.new_status,
                )

        update_project_time(project_id)
        return dict(updated=updated, fields=fields)
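The inner `safe_mongoengine_key` helper guards user-supplied update fields against colliding with mongoengine's reserved `update()` keyword arguments (`upsert`, `multi`, etc.): prefixing with `__` makes mongoengine resolve the name back to a plain document field. A standalone sketch (the `published` field name is only illustrative):

```python
# Reserved update() kwargs that must not be shadowed by field names
control = dict(upsert=False, multi=False, write_concern=None, full_result=False)

def safe_mongoengine_key(key):
    # A field named like a reserved update() kwarg gets a "__" prefix,
    # which mongoengine maps back to the plain field name
    return f"__{key}" if key in control else key

print(safe_mongoengine_key("published"))  # published
print(safe_mongoengine_key("upsert"))     # __upsert
```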

    def validate_transition(self, current_status):
        if self.force:
            return
        if self.new_status != current_status:
            validate_status_change(current_status, self.new_status)
        elif not self.allow_same_state_transition:
            raise errors.bad_request.InvalidTaskStatus(
                "Task already in requested status",
                current_status=current_status,
                new_status=self.new_status,
            )


def validate_status_change(current_status, new_status):
    assert current_status in valid_statuses
    assert new_status in valid_statuses

    allowed_statuses = get_possible_status_changes(current_status)
    if new_status not in allowed_statuses:
        raise errors.bad_request.InvalidTaskStatus(
            "Invalid status change",
            current_status=current_status,
            new_status=new_status,
        )


state_machine = {
    TaskStatus.created: {TaskStatus.in_progress},
    TaskStatus.in_progress: {TaskStatus.stopped, TaskStatus.failed, TaskStatus.created},
    TaskStatus.stopped: {
        TaskStatus.closed,
        TaskStatus.created,
        TaskStatus.failed,
        TaskStatus.in_progress,
        TaskStatus.published,
        TaskStatus.publishing,
    },
    TaskStatus.closed: {
        TaskStatus.created,
        TaskStatus.failed,
        TaskStatus.published,
        TaskStatus.publishing,
        TaskStatus.stopped,
    },
    TaskStatus.failed: {TaskStatus.created, TaskStatus.stopped, TaskStatus.published},
    TaskStatus.publishing: {TaskStatus.published},
    TaskStatus.published: set(),
}
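The `state_machine` mapping drives all status-change validation: a transition is legal only if the target status appears in the set keyed by the current status. A minimal standalone sketch, using plain strings as stand-ins for the `TaskStatus` constants (an assumption for illustration only):

```python
# String stand-ins for a subset of the TaskStatus constants
state_machine = {
    "created": {"in_progress"},
    "in_progress": {"stopped", "failed", "created"},
    "stopped": {"closed", "created", "failed", "in_progress", "published", "publishing"},
    "publishing": {"published"},
    "published": set(),  # published is terminal
}

def is_valid_change(current, new):
    # A transition is allowed only if "new" appears in the set of
    # possible next statuses for "current"
    return new in state_machine.get(current, set())

print(is_valid_change("in_progress", "stopped"))  # True
print(is_valid_change("published", "created"))    # False
```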


def get_possible_status_changes(current_status):
    """
    :param current_status: the current task status
    :return: the possible target statuses from the current status
    """
    possible = state_machine.get(current_status)
    assert (
        possible is not None
    ), f"Current status {current_status} not supported by state machine"
    return possible


def update_project_time(project_id):
    if project_id:
        Project.objects(id=project_id).update(last_update=datetime.utcnow())


T = TypeVar("T")


def split_by(
    condition: Callable[[T], bool], items: Sequence[T]
) -> Tuple[Sequence[T], Sequence[T]]:
    """
    Split "items" into two lists by "condition":
    items for which the condition holds, then the rest.
    """
    # materialize the pairs: a bare zip iterator would be exhausted by the
    # first comprehension, leaving the second list always empty
    applied = list(zip(map(condition, items), items))
    return (
        [item for cond, item in applied if cond],
        [item for cond, item in applied if not cond],
    )
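For illustration, `split_by` partitions a sequence by a predicate; a standalone sketch of the same behavior:

```python
from typing import Callable, Sequence, Tuple, TypeVar

T = TypeVar("T")

def split_by(
    condition: Callable[[T], bool], items: Sequence[T]
) -> Tuple[Sequence[T], Sequence[T]]:
    # Materialize the pairs first: a bare zip iterator would be
    # exhausted by the first comprehension, leaving the second list empty
    applied = list(zip(map(condition, items), items))
    return (
        [item for cond, item in applied if cond],
        [item for cond, item in applied if not cond],
    )

print(split_by(lambda n: n % 2 == 0, [1, 2, 3, 4, 5]))
# ([2, 4], [1, 3, 5])
```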
23
server/bll/user/__init__.py
Normal file
@@ -0,0 +1,23 @@
from apierrors import errors
from apimodels.users import CreateRequest
from database.errors import translate_errors_context
from database.model.user import User


class UserBLL:
    @staticmethod
    def create(request: CreateRequest):
        user_id = request.id
        with translate_errors_context("creating user"):
            if user_id and User.objects(id=user_id).only("id"):
                raise errors.bad_request.UserIdExists(id=user_id)

            user = User(**request.to_struct())
            user.save(force_insert=True)

    @staticmethod
    def delete(user_id: str):
        with translate_errors_context("deleting user"):
            res = User.objects(id=user_id).delete()
            if not res:
                raise errors.bad_request.InvalidUserId(id=user_id)
8
server/config/__init__.py
Normal file
@@ -0,0 +1,8 @@
import logging.config
from pathlib import Path

from .basic import BasicConfig

config = BasicConfig(Path(__file__).with_name("default"))

logging.config.dictConfig(config.get("logging"))
89
server/config/basic.py
Normal file
@@ -0,0 +1,89 @@
import logging
from pathlib import Path

from pyhocon import ConfigTree, ConfigFactory
from pyparsing import (
    ParseFatalException,
    ParseException,
    RecursiveGrammarException,
    ParseSyntaxException,
)


class BasicConfig:
    NotSet = object()

    def __init__(self, folder):
        self.folder = Path(folder)
        if not self.folder.is_dir():
            raise ValueError("Invalid configuration folder")

        self.prefix = "trains"

        self._load()

    def __getitem__(self, key):
        return self._config[key]

    def get(self, key, default=NotSet):
        value = self._config.get(key, default)
        # value can only be NotSet when the key is missing and no default was provided
        if value is self.NotSet:
            raise KeyError(
                f"Unable to find value for key '{key}' and default value was not provided."
            )
        return value

    def logger(self, name):
        if Path(name).is_file():
            name = Path(name).stem
        path = ".".join((self.prefix, Path(name).stem))
        return logging.getLogger(path)

    def _load(self, verbose=True):
        self._config = self._read_recursive(self.folder, verbose=verbose)

    def _read_recursive(self, conf_root, verbose=True):
        conf = ConfigTree()

        if not conf_root:
            return conf

        if not conf_root.is_dir():
            if verbose:
                if not conf_root.exists():
                    print(f"No config in {conf_root}")
                else:
                    print(f"Not a directory: {conf_root}")
            return conf

        if verbose:
            print(f"Loading config from {conf_root}")

        for file in conf_root.rglob("*.conf"):
            key = ".".join(file.relative_to(conf_root).with_suffix("").parts)
            conf.put(key, self._read_single_file(file, verbose=verbose))

        return conf

    @staticmethod
    def _read_single_file(file_path, verbose=True):
        if verbose:
            print(f"Loading config from file {file_path}")

        try:
            return ConfigFactory.parse_file(file_path)
        except ParseSyntaxException as ex:
            msg = f"Failed parsing {file_path} ({ex.__class__.__name__}): (at char {ex.loc}, line:{ex.lineno}, col:{ex.column})"
            raise ConfigurationError(msg, file_path=file_path) from ex
        except (ParseException, ParseFatalException, RecursiveGrammarException) as ex:
            msg = f"Failed parsing {file_path} ({ex.__class__.__name__}): {ex}"
            raise ConfigurationError(msg) from ex
        except Exception as ex:
            print(f"Failed loading {file_path}: {ex}")
            raise


class ConfigurationError(Exception):
    def __init__(self, msg, file_path=None, *args):
        super(ConfigurationError, self).__init__(msg, *args)
        self.file_path = file_path
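The `NotSet` sentinel in `BasicConfig.get` distinguishes "no default supplied" from a legitimate falsy default such as `None` or `False`; a minimal standalone sketch of the pattern (the `Config` class and its keys are only illustrative):

```python
class Config:
    NotSet = object()  # unique sentinel: no real value is identical to it

    def __init__(self, data):
        self._data = data

    def get(self, key, default=NotSet):
        value = self._data.get(key, default)
        if value is self.NotSet:
            # The key is missing and the caller supplied no default
            raise KeyError(f"Unable to find value for key '{key}'")
        return value

cfg = Config({"debug": False})
print(cfg.get("debug"))        # False (falsy value returned, not an error)
print(cfg.get("watch", None))  # None  (falsy default honored)
```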
51
server/config/default/apiserver.conf
Normal file
@@ -0,0 +1,51 @@
{
    watch: false  # Watch for changes (dev only)
    debug: false  # Debug mode
    pretty_json: false  # prettify json response
    return_stack: true  # return stack trace on error
    log_calls: true  # Log API Calls

    # if 'return_stack' is true and error contains a status code, return stack trace only for these status codes
    # valid values are:
    # - an integer number, specifying a status code
    # - a tuple of (code, subcode or list of subcodes)
    return_stack_on_code: [
        [500, 0]  # raise on internal server error with no subcode
    ]

    listen {
        ip: "0.0.0.0"
        port: 8008
    }

    version {
        required: false
        default: 1.0
    }

    mongo {
        # controls whether FieldDoesNotExist exception will be raised for any extra attribute existing in stored data
        # but not declared in a data model
        strict: false
    }

    auth {
        # verify user tokens
        verify_user_tokens: false

        # max token expiration timeout in seconds (1 year)
        max_expiration_sec: 31536000

        # default token expiration timeout in seconds (30 days)
        default_expiration_sec: 2592000

        # cookie containing auth token, for requests arriving from a web-browser
        session_auth_cookie_name: "trains_token_basic"
    }

    cors {
        origins: "*"
    }

    default_company: "d1bd92a3b039400cbafc60a7a5b1e52b"
}
21
server/config/default/hosts.conf
Normal file
@@ -0,0 +1,21 @@
elastic {
    events {
        hosts: [{host: "127.0.0.1", port: 9200}]
        args {
            timeout: 60
            dead_timeout: 10
            max_retries: 5
            retry_on_timeout: true
        }
        index_version: "1"
    }
}

mongo {
    backend {
        host: "mongodb://127.0.0.1:27017/backend"
    }
    auth {
        host: "mongodb://127.0.0.1:27017/auth"
    }
}
49
server/config/default/logging.conf
Normal file
@@ -0,0 +1,49 @@
{
    version: 1
    disable_existing_loggers: false
    formatters: {
        standard: {
            format: "[%(asctime)s] [%(process)d] [%(levelname)s] [%(name)s] %(message)s"
        }
    }
    handlers {
        console {
            formatter: standard
            class: "logging.StreamHandler"
        }
        text_file: {
            formatter: standard
            backupCount: 3
            maxBytes: 10240000
            class: "logging.handlers.RotatingFileHandler"
            filename: "/var/log/trains/apiserver.log"
        }
    }
    root {
        handlers: [console, text_file]
        level: INFO
    }
    loggers {
        urllib3 {
            handlers: [console, text_file]
            level: WARN
            propagate: false
        }
        werkzeug {
            handlers: [console, text_file]
            level: WARN
            propagate: false
        }
        elasticsearch {
            handlers: [console, text_file]
            level: WARN
            propagate: false
        }
        # disable pep8 auto-gen logging (at least part of it)
        RefactoringTool {
            handlers: []
            level: ERROR
            propagate: false
        }
    }
}
29
server/config/default/secure.conf
Normal file
@@ -0,0 +1,29 @@
{
    http {
        session_secret {
            apiserver: "Gx*gB-L2U8!Naqzd#8=7A4&+=In4H(da424H33ZTDQRGF6=FWw"
        }
    }

    auth {
        # token sign secret
        token_secret: "7E1ua3xP9GT2(cIQOfhjp+gwN6spBeCAmN-XuugYle00I=Wc+u"
    }

    credentials {
        # system credentials as they appear in the auth DB, used for intra-service communications
        apiserver {
            user_key: "62T8CP7HGBC6647XF9314C2VY67RJO"
            user_secret: "FhS8VZv_I4%6Mo$8S1BWc$n$=o1dMYSivuiWU-Vguq7qGOKskG-d+b@tn_Iq"
        }
        webserver {
            user_key: "EYVQ385RW7Y2QQUH88CZ7DWIQ1WUHP"
            user_secret: "yfc8KQo*GMXb*9p((qcYC7ByFIpF7I&4VH3BfUYXH%o9vX1ZUZQEEw1Inc)S"
        }
        tests {
            user_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
            user_secret: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
        }
    }
}
3
server/config/default/services/events.conf
Normal file
@@ -0,0 +1,3 @@
{
    es_index_prefix: "events"
}
58
server/database/__init__.py
Normal file
58
server/database/__init__.py
Normal file
@@ -0,0 +1,58 @@
from jsonmodels import models
from jsonmodels.errors import ValidationError
from jsonmodels.fields import StringField
from mongoengine import register_connection
from mongoengine.connection import get_connection

from config import config
from .defs import Database
from .utils import get_items

log = config.logger(__file__)

strict = config.get('apiserver.mongo.strict', True)

_entries = []


class DatabaseEntry(models.Base):
    host = StringField(required=True)
    alias = StringField()

    @property
    def health_alias(self):
        return '__health__' + self.alias


def initialize():
    db_entries = config.get('hosts.mongo', {})
    missing = []
    log.info('Initializing database connections')
    for key, alias in get_items(Database).items():
        if key not in db_entries:
            missing.append(key)
            continue
        entry = DatabaseEntry(alias=alias, **db_entries.get(key))
        try:
            entry.validate()
            log.info('Registering connection to %(alias)s (%(host)s)' % entry.to_struct())
            register_connection(alias=alias, host=entry.host)

            _entries.append(entry)
        except ValidationError as ex:
            raise Exception('Invalid database entry `%s`: %s' % (key, ex.args[0]))
    if missing:
        raise ValueError('Missing database configuration for %s' % ', '.join(missing))


def get_entries():
    return _entries


def get_aliases():
    return [entry.alias for entry in get_entries()]


def reconnect():
    for entry in get_entries():
        get_connection(entry.alias, reconnect=True)
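The `initialize()` flow above maps each known `Database` alias to a configured host, validates the entry, then registers a connection. A minimal stdlib-only sketch of the same flow (the `Entry` dataclass and `register_connection` dict below are illustrative stand-ins for jsonmodels' `DatabaseEntry` and mongoengine's `register_connection`, not the server's API):

```python
from dataclasses import dataclass


@dataclass
class Entry:
    """Stand-in for DatabaseEntry: alias plus required host."""
    alias: str
    host: str

    def validate(self):
        if not self.host:
            raise ValueError("host is required")


# Stand-in for mongoengine's connection registry
registry = {}


def register_connection(alias, host):
    registry[alias] = host


def initialize(db_entries, aliases):
    """Register one connection per known database alias, as initialize() does."""
    missing = [key for key in aliases if key not in db_entries]
    if missing:
        raise ValueError("Missing database configuration for %s" % ", ".join(missing))
    entries = []
    for key, alias in aliases.items():
        entry = Entry(alias=alias, **db_entries[key])
        entry.validate()
        register_connection(alias=entry.alias, host=entry.host)
        entries.append(entry)
    return entries


entries = initialize(
    {"backend": {"host": "mongodb://mongo:27017/backend"},
     "auth": {"host": "mongodb://mongo:27017/auth"}},
    {"backend": "backend-db", "auth": "auth-db"},
)
```

As in the server, a missing key in the host configuration aborts startup rather than silently skipping a database.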
10 server/database/defs.py Normal file
@@ -0,0 +1,10 @@


class Database(object):
    """ Database names for our different DB instances """

    backend = 'backend-db'
    ''' Used for all backend objects (tasks, models etc.) '''

    auth = 'auth-db'
    ''' Used for all authentication and permission objects '''
189 server/database/errors.py Normal file
@@ -0,0 +1,189 @@
import re
from contextlib import contextmanager
from functools import wraps

import dpath
from dpath.exceptions import InvalidKeyName
from elasticsearch import ElasticsearchException
from elasticsearch.helpers import BulkIndexError
from jsonmodels.errors import ValidationError as JsonschemaValidationError
from mongoengine.errors import (
    ValidationError,
    NotUniqueError,
    FieldDoesNotExist,
    InvalidDocumentError,
    LookUpError,
    InvalidQueryError,
)
from pymongo.errors import PyMongoError, NotMasterError

from apierrors import errors


class MakeGetAllQueryError(Exception):
    def __init__(self, error, field):
        super(MakeGetAllQueryError, self).__init__(f"{error}: field={field}")
        self.error = error
        self.field = field


class ParseCallError(Exception):
    def __init__(self, msg, **kwargs):
        super(ParseCallError, self).__init__(msg)
        self.params = kwargs


def throws_default_error(err_cls):
    """
    Used to make functions (Exception, str) -> Optional[str] searching for specialized error messages raise those
    messages in ``err_cls``. If the decorated function does not find a suitable error message,
    the underlying exception is used.
    :param err_cls: Error class (generated by apierrors)
    """

    def decorator(func):
        @wraps(func)
        def wrapper(self, e, message, **kwargs):
            extra_info = func(self, e, message, **kwargs)
            raise err_cls(message, err=e, extra_info=extra_info)

        return wrapper

    return decorator


class ElasticErrorsHandler(object):
    @classmethod
    @throws_default_error(errors.server_error.DataError)
    def bulk_error(cls, e, _, **__):
        if not e.errors:
            return

        # Else try returning a better error string
        for _, reason in dpath.search(e.errors[0], "*/error/reason", yielded=True):
            return reason


class MongoEngineErrorsHandler(object):
    # NotUniqueError
    __not_unique_regex = re.compile(
        r"collection:\s(?P<collection>[\w.]+)\sindex:\s(?P<index>\w+)\sdup\skey:\s{(?P<values>[^\}]+)\}"
    )
    __not_unique_value_regex = re.compile(r':\s"(?P<value>[^"]+)"')
    __id_index = "_id_"
    __index_sep_regex = re.compile(r"_[0-9]+_?")

    # FieldDoesNotExist
    __not_exist_fields_regex = re.compile(r'"{(?P<fields>.+?)}".+?"(?P<document>.+?)"')
    __not_exist_field_regex = re.compile(r"'(?P<field>\w+)'")

    @classmethod
    def validation_error(cls, e: ValidationError, message, **_):
        # Thrown when a document is validated. Documents are validated by default on save and on update
        err_dict = e.errors or {e.field_name: e.message}
        raise errors.bad_request.DataValidationError(message, **err_dict)

    @classmethod
    def not_unique_error(cls, e, message, **_):
        # Thrown when a save/update violates a unique index constraint
        m = cls.__not_unique_regex.search(str(e))
        if not m:
            raise errors.bad_request.ExpectedUniqueData(message, err=str(e))
        values = cls.__not_unique_value_regex.findall(m.group("values"))
        index = m.group("index")
        if index == cls.__id_index:
            fields = ["id"]
        else:
            fields = cls.__index_sep_regex.split(index)[:-1]
        raise errors.bad_request.ExpectedUniqueData(
            message, **dict(zip(fields, values))
        )

    @classmethod
    def field_does_not_exist(cls, e, message, **kwargs):
        # Strict mode. Unknown fields encountered in loaded document(s)
        field_does_not_exist_cls = kwargs.get(
            "field_does_not_exist_cls", errors.server_error.InconsistentData
        )
        m = cls.__not_exist_fields_regex.search(str(e))
        params = {}
        if m:
            params["document"] = m.group("document")
            fields = cls.__not_exist_field_regex.findall(m.group("fields"))
            if fields:
                if len(fields) > 1:
                    params["fields"] = "(%s)" % ", ".join(fields)
                else:
                    params["field"] = fields[0]
        raise field_does_not_exist_cls(message, **params)

    @classmethod
    @throws_default_error(errors.server_error.DataError)
    def invalid_document_error(cls, e, message, **_):
        # Reverse_delete_rule used in reference field
        pass

    @classmethod
    def lookup_error(cls, e, message, **_):
        raise errors.bad_request.InvalidFields(
            "probably an invalid field name or unsupported nested field",
            replacement_msg="Lookup error",
            err=str(e),
        )

    @classmethod
    @throws_default_error(errors.bad_request.InvalidRegexError)
    def invalid_regex_error(cls, e, _, **__):
        if e.args and e.args[0] == "unexpected end of regular expression":
            raise errors.bad_request.InvalidRegexError(e.args[0])

    @classmethod
    @throws_default_error(errors.server_error.InternalError)
    def invalid_query_error(cls, e, message, **_):
        pass


@contextmanager
def translate_errors_context(message=None, **kwargs):
    """
    A context manager that translates errors thrown by MongoEngine and Elastic into our apierrors classes,
    with an appropriate message.
    """
    try:
        if message:
            message = "while " + message
        yield True
    except ValidationError as e:
        MongoEngineErrorsHandler.validation_error(e, message, **kwargs)
    except NotUniqueError as e:
        MongoEngineErrorsHandler.not_unique_error(e, message, **kwargs)
    except FieldDoesNotExist as e:
        MongoEngineErrorsHandler.field_does_not_exist(e, message, **kwargs)
    except InvalidDocumentError as e:
        MongoEngineErrorsHandler.invalid_document_error(e, message, **kwargs)
    except LookUpError as e:
        MongoEngineErrorsHandler.lookup_error(e, message, **kwargs)
    except re.error as e:
        MongoEngineErrorsHandler.invalid_regex_error(e, message, **kwargs)
    except InvalidQueryError as e:
        MongoEngineErrorsHandler.invalid_query_error(e, message, **kwargs)
    except NotMasterError as e:
        # Must precede PyMongoError: NotMasterError is a PyMongoError subclass
        raise errors.server_error.InternalError(message, err=str(e))
    except PyMongoError as e:
        raise errors.server_error.InternalError(message, err=str(e))
    except MakeGetAllQueryError as e:
        raise errors.bad_request.ValidationError(e.error, field=e.field)
    except ParseCallError as e:
        raise errors.bad_request.FieldsValueError(e.args[0], **e.params)
    except JsonschemaValidationError as e:
        if len(e.args) >= 2:
            raise errors.bad_request.ValidationError(e.args[0], reason=e.args[1])
        raise errors.bad_request.ValidationError(e.args[0])
    except BulkIndexError as e:
        ElasticErrorsHandler.bulk_error(e, message, **kwargs)
    except ElasticsearchException as e:
        raise errors.server_error.DataError(e, message, **kwargs)
    except InvalidKeyName:
        raise errors.server_error.DataError("invalid empty key encountered in data")
    except Exception:
        raise
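The `translate_errors_context` pattern above — a single `try/except` ladder that rewrites driver exceptions into API error classes while prefixing an operation description — can be demonstrated with stdlib exceptions only. `ApiValidationError` here is a hypothetical stand-in for the `apierrors` classes:

```python
from contextlib import contextmanager


class ApiValidationError(Exception):
    """Stand-in for errors.bad_request.ValidationError."""


@contextmanager
def translate_errors_context(message=None):
    """Translate low-level exceptions into API errors, as the server does."""
    try:
        if message:
            message = "while " + message
        yield
    except (KeyError, ValueError) as e:
        # In the server, mongoengine/elasticsearch errors are mapped here instead
        raise ApiValidationError("%s: %s" % (message, e)) from e


def load(doc, field):
    with translate_errors_context("getting %s" % field):
        return doc[field]


try:
    load({"name": "task"}, "id")
except ApiValidationError as e:
    caught = str(e)
```

Callers wrap each database operation in the context manager once and get a uniform error type plus a "while <doing X>" message, instead of repeating the mapping at every call site.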
237 server/database/fields.py Normal file
@@ -0,0 +1,237 @@
import re
from sys import maxsize

import six
from mongoengine import (
    EmbeddedDocumentListField,
    ListField,
    FloatField,
    StringField,
    EmbeddedDocumentField,
    SortedListField,
    MapField,
    DictField,
)


class LengthRangeListField(ListField):
    def __init__(self, field=None, max_length=maxsize, min_length=0, **kwargs):
        self.__min_length = min_length
        self.__max_length = max_length
        super(LengthRangeListField, self).__init__(field, **kwargs)

    def validate(self, value):
        min, val, max = self.__min_length, len(value), self.__max_length
        if not min <= val <= max:
            self.error("Item count %d exceeds range [%d, %d]" % (val, min, max))
        super(LengthRangeListField, self).validate(value)


class LengthRangeEmbeddedDocumentListField(LengthRangeListField):
    def __init__(self, field=None, *args, **kwargs):
        super(LengthRangeEmbeddedDocumentListField, self).__init__(
            EmbeddedDocumentField(field), *args, **kwargs
        )


class UniqueEmbeddedDocumentListField(EmbeddedDocumentListField):
    def __init__(self, document_type, key, **kwargs):
        """
        Create a unique embedded document list field for a document type with a unique comparison key func/property
        :param document_type: The type of :class:`~mongoengine.EmbeddedDocument` the list will hold.
        :param key: A callable to extract a key from each item
        """
        if not callable(key):
            raise KeyError("key must be callable")
        self.__key = key
        super(UniqueEmbeddedDocumentListField, self).__init__(document_type, **kwargs)

    def validate(self, value):
        if len({self.__key(i) for i in value}) != len(value):
            self.error("Items with duplicate key exist in the list")
        super(UniqueEmbeddedDocumentListField, self).validate(value)


def object_to_key_value_pairs(obj):
    if isinstance(obj, dict):
        return [(key, object_to_key_value_pairs(value)) for key, value in obj.items()]
    if isinstance(obj, list):
        return list(map(object_to_key_value_pairs, obj))
    return obj


class EmbeddedDocumentSortedListField(EmbeddedDocumentListField):
    """
    A sorted list of embedded documents
    """

    def to_mongo(self, value, use_db_field=True, fields=None):
        value = super(EmbeddedDocumentSortedListField, self).to_mongo(
            value, use_db_field, fields
        )
        return sorted(value, key=object_to_key_value_pairs)


class LengthRangeSortedListField(LengthRangeListField, SortedListField):
    pass


class CustomFloatField(FloatField):
    def __init__(self, greater_than=None, **kwargs):
        self.greater_than = greater_than
        super(CustomFloatField, self).__init__(**kwargs)

    def validate(self, value):
        super(CustomFloatField, self).validate(value)

        if self.greater_than is not None and value <= self.greater_than:
            self.error("Float value must be greater than %s" % str(self.greater_than))


# TODO: bucket name should be at most 63 characters
aws_s3_bucket_only_regex = (
    r"^s3://"
    r"(?:(?:\w[A-Z0-9\-]+\w)\.)*(?:\w[A-Z0-9\-]+\w)"  # bucket name
)

aws_s3_url_with_bucket_regex = (
    r"^s3://"
    r"(?:(?:\w[A-Z0-9\-]+\w)\.)*(?:\w[A-Z0-9\-]+\w)"  # bucket name
    r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(?<!-)\.?))"  # domain...
)

non_aws_s3_regex = (
    r"^s3://"
    r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(?<!-)\.?)|"  # domain...
    r"localhost|"  # localhost...
    r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|"  # ...or ipv4
    r"\[?[A-F0-9]*:[A-F0-9:]+\]?)"  # ...or ipv6
    r"(?::\d+)?"  # optional port
    r"(?:/(?:(?:\w[A-Z0-9\-]+\w)\.)*(?:\w[A-Z0-9\-]+\w))"  # bucket name
)

google_gs_bucket_only_regex = (
    r"^gs://"
    r"(?:(?:\w[A-Z0-9\-_]+\w)\.)*(?:\w[A-Z0-9\-_]+\w)"  # bucket name
)

file_regex = r"^file://"

generic_url_regex = (
    r"^%s://"  # scheme placeholder
    r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(?<!-)\.?)|"  # domain...
    r"localhost|"  # localhost...
    r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|"  # ...or ipv4
    r"\[?[A-F0-9]*:[A-F0-9:]+\]?)"  # ...or ipv6
    r"(?::\d+)?"  # optional port
)

path_suffix = r"(?:/?|[/?]\S+)$"
file_path_suffix = r"(?:/\S*[^/]+)$"


class _RegexURLField(StringField):
    _regex = []

    def __init__(self, regex, **kwargs):
        super(_RegexURLField, self).__init__(**kwargs)
        regex = regex if isinstance(regex, (tuple, list)) else [regex]
        self._regex = [
            re.compile(e, re.IGNORECASE) if isinstance(e, six.string_types) else e
            for e in regex
        ]

    def validate(self, value):
        # Check first if the scheme is valid
        if not any(regex for regex in self._regex if regex.match(value)):
            self.error("Invalid URL: {}".format(value))
            return


class OutputDestinationField(_RegexURLField):
    """ A field representing a task output URL """

    schemes = ["s3", "gs", "file"]
    _expressions = (
        aws_s3_bucket_only_regex + path_suffix,
        aws_s3_url_with_bucket_regex + path_suffix,
        non_aws_s3_regex + path_suffix,
        google_gs_bucket_only_regex + path_suffix,
        file_regex + path_suffix,
    )

    def __init__(self, **kwargs):
        super(OutputDestinationField, self).__init__(self._expressions, **kwargs)


class SupportedURLField(_RegexURLField):
    """ A field representing a model URL """

    schemes = ["s3", "gs", "file", "http", "https"]

    _expressions = tuple(
        pattern + file_path_suffix
        for pattern in (
            aws_s3_bucket_only_regex,
            aws_s3_url_with_bucket_regex,
            non_aws_s3_regex,
            google_gs_bucket_only_regex,
            file_regex,
            (generic_url_regex % "http"),
            (generic_url_regex % "https"),
        )
    )

    def __init__(self, **kwargs):
        super(SupportedURLField, self).__init__(self._expressions, **kwargs)


class StrippedStringField(StringField):
    def __init__(
        self, regex=None, max_length=None, min_length=None, strip_chars=None, **kwargs
    ):
        super(StrippedStringField, self).__init__(
            regex, max_length, min_length, **kwargs
        )
        self._strip_chars = strip_chars

    def __set__(self, instance, value):
        if value is not None:
            try:
                value = value.strip(self._strip_chars)
            except AttributeError:
                pass
        super(StrippedStringField, self).__set__(instance, value)

    def prepare_query_value(self, op, value):
        if not isinstance(op, six.string_types):
            return value
        if value is not None:
            value = value.strip(self._strip_chars)
        return super(StrippedStringField, self).prepare_query_value(op, value)


def contains_empty_key(d):
    """
    Helper function to recursively determine if any key in a
    dictionary is empty (based on mongoengine.fields.key_not_string)
    """
    for k, v in list(d.items()):
        if not k or (isinstance(v, dict) and contains_empty_key(v)):
            return True
    return False


class SafeMapField(MapField):
    def validate(self, value):
        super(SafeMapField, self).validate(value)

        if contains_empty_key(value):
            self.error("Empty keys are not allowed in a MapField")


class SafeDictField(DictField):
    def validate(self, value):
        super(SafeDictField, self).validate(value)

        if contains_empty_key(value):
            self.error("Empty keys are not allowed in a DictField")
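The URL fields above compile their patterns with `re.IGNORECASE` and anchor them with `path_suffix`. The bucket-only S3 pattern can be exercised standalone with only the stdlib (patterns copied from the field definitions; `is_valid_output_dest` is an illustrative helper, not part of the server's API):

```python
import re

# Patterns copied from the field definitions above
aws_s3_bucket_only_regex = (
    r"^s3://"
    r"(?:(?:\w[A-Z0-9\-]+\w)\.)*(?:\w[A-Z0-9\-]+\w)"  # bucket name
)
path_suffix = r"(?:/?|[/?]\S+)$"

# The fields compile every pattern case-insensitively, as _RegexURLField does
pattern = re.compile(aws_s3_bucket_only_regex + path_suffix, re.IGNORECASE)


def is_valid_output_dest(url):
    """True if the URL matches the bucket-only s3:// form with an optional path."""
    return pattern.match(url) is not None
```

`path_suffix` allows either a bare bucket (`s3://my-bucket`) or a bucket with a path (`s3://my-bucket/models`); any other scheme fails the `^s3://` anchor.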
56 server/database/model/__init__.py Normal file
@@ -0,0 +1,56 @@
from mongoengine import Document, StringField

from apierrors import errors
from database.model.base import DbModelMixin, ABSTRACT_FLAG
from database.model.company import Company
from database.model.user import User


class AttributedDocument(DbModelMixin, Document):
    """
    Represents objects which are attributed to a company and a user or to "no one".
    Company must be required since it can be used as a unique field.
    """
    meta = ABSTRACT_FLAG
    company = StringField(required=True, reference_field=Company)
    user = StringField(reference_field=User)

    def is_public(self) -> bool:
        # An empty company string is the "public" marker
        return not bool(self.company)


class PrivateDocument(AttributedDocument):
    """
    Represents documents which always belong to a single company
    """
    meta = ABSTRACT_FLAG
    # can not have an empty string as this is the "public" marker
    company = StringField(required=True, reference_field=Company, min_length=1)
    user = StringField(reference_field=User, required=True)

    def is_public(self) -> bool:
        return False


def validate_id(cls, company, **kwargs):
    """
    Validate existence of objects with certain IDs within a company.
    :param cls: Model class to search in
    :param company: Company to search in
    :param kwargs: Mapping of field name to object ID. If any ID does not have a corresponding object,
        it will be reported along with the name it was assigned to.
    """
    ids = set(kwargs.values())
    objs = list(cls.objects(company=company, id__in=ids).only('id'))
    missing = ids - set(x.id for x in objs)
    if not missing:
        return
    id_to_name = {}
    for name, obj_id in kwargs.items():
        id_to_name.setdefault(obj_id, []).append(name)
    raise errors.bad_request.ValidationError(
        'Invalid {} ids'.format(cls.__name__.lower()),
        **{name: obj_id for obj_id in missing for name in id_to_name[obj_id]}
    )
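`validate_id` reports each missing ID under every parameter name it was passed as, so one bad ID used for two fields appears twice in the error. The grouping logic can be exercised standalone (`find_invalid` and the field names below are illustrative; no mongoengine query is performed):

```python
def find_invalid(existing_ids, **kwargs):
    """Return {param_name: missing_id} for IDs not present in existing_ids,
    mirroring the reporting step of validate_id()."""
    ids = set(kwargs.values())
    missing = ids - set(existing_ids)
    # Map each ID back to every parameter name it was supplied under
    id_to_name = {}
    for name, obj_id in kwargs.items():
        id_to_name.setdefault(obj_id, []).append(name)
    return {name: obj_id for obj_id in missing for name in id_to_name[obj_id]}


# "m2" was passed under two names and does not exist, so both names are reported
invalid = find_invalid(
    {"p1", "m1"},
    project="p1",
    execution_model="m2",
    output_model="m2",
)
```

In the server this dictionary becomes the keyword payload of the raised `ValidationError`, so the API response names the offending request fields directly.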
72 server/database/model/auth.py Normal file
@@ -0,0 +1,72 @@
from mongoengine import (
    StringField,
    EmbeddedDocument,
    EmbeddedDocumentListField,
    EmailField,
    DateTimeField,
)

from database import Database, strict
from database.model import DbModelMixin
from database.model.base import AuthDocument
from database.utils import get_options


class Entities(object):
    company = "company"
    task = "task"
    user = "user"
    model = "model"


class Role(object):
    system = "system"
    """ Internal system component """
    root = "root"
    """ Root admin (person) """
    admin = "admin"
    """ Company administrator """
    superuser = "superuser"
    """ Company super user """
    user = "user"
    """ Company user """
    annotator = "annotator"
    """ Annotator with limited access """

    @classmethod
    def get_system_roles(cls) -> set:
        return {cls.system, cls.root}

    @classmethod
    def get_company_roles(cls) -> set:
        return set(get_options(cls)) - cls.get_system_roles()


class Credentials(EmbeddedDocument):
    key = StringField(required=True)
    secret = StringField(required=True)


class User(DbModelMixin, AuthDocument):
    meta = {"db_alias": Database.auth, "strict": strict}

    id = StringField(primary_key=True)
    name = StringField(unique_with="company")

    created = DateTimeField()
    """ User auth entry creation time """

    validated = DateTimeField()
    """ Last validation (login) time """

    role = StringField(required=True, choices=get_options(Role), default=Role.user)
    """ User role """

    company = StringField(required=True)
    """ Company this user belongs to """

    credentials = EmbeddedDocumentListField(Credentials, default=list)
    """ Credentials generated for this user """

    email = EmailField(unique=True, required=True)
    """ Email uniquely identifying the user """
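`Role.get_company_roles` derives the company-level roles by subtracting the system roles from all declared options. A runnable sketch of that derivation (my `get_options` here is a simplified assumption about `database.utils.get_options`: it collects the public string class attributes):

```python
class Role:
    system = "system"
    root = "root"
    admin = "admin"
    superuser = "superuser"
    user = "user"
    annotator = "annotator"

    @classmethod
    def get_system_roles(cls):
        return {cls.system, cls.root}

    @classmethod
    def get_company_roles(cls):
        # All declared role options minus the system-level ones
        return set(get_options(cls)) - cls.get_system_roles()


def get_options(cls):
    """Simplified stand-in: collect public, non-callable string attributes."""
    return [
        v for k, v in vars(cls).items()
        if not k.startswith("_") and isinstance(v, str)
    ]
```

The same `get_options` result also feeds the `choices=` argument of the `role` field above, so the enum of valid roles is defined exactly once on the class.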
529 server/database/model/base.py Normal file
@@ -0,0 +1,529 @@
import re
from collections import namedtuple
from functools import reduce
from typing import Collection

from dateutil.parser import parse as parse_datetime
from mongoengine import Q, Document
from six import string_types

from apierrors import errors
from config import config
from database.errors import MakeGetAllQueryError
from database.projection import project_dict, ProjectionHelper
from database.props import PropsMixin
from database.query import RegexQ, RegexWrapper
from database.utils import get_company_or_none_constraint, get_fields_with_attr

log = config.logger("dbmodel")

ACCESS_REGEX = re.compile(r"^(?P<prefix>>=|>|<=|<)?(?P<value>.*)$")
ACCESS_MODIFIER = {">=": "gte", ">": "gt", "<=": "lte", "<": "lt"}

ABSTRACT_FLAG = {"abstract": True}


class AuthDocument(Document):
    meta = ABSTRACT_FLAG


class ProperDictMixin(object):
    def to_proper_dict(self, strip_private=True, only=None, extra_dict=None) -> dict:
        return self.properize_dict(
            self.to_mongo(use_db_field=False).to_dict(),
            strip_private=strip_private,
            only=only,
            extra_dict=extra_dict,
        )

    @classmethod
    def properize_dict(
        cls, d, strip_private=True, only=None, extra_dict=None, normalize_id=True
    ):
        res = d
        if normalize_id and "_id" in res:
            res["id"] = res.pop("_id")
        if strip_private:
            res = {k: v for k, v in res.items() if k[0] != "_"}
        if only:
            res = project_dict(res, only)
        if extra_dict:
            res.update(extra_dict)
        return res


class GetMixin(PropsMixin):
    _text_score = "$text_score"

    _ordering_key = "order_by"

    _multi_field_param_sep = "__"
    _multi_field_param_prefix = {
        ("_any_", "_or_"): lambda a, b: a | b,
        ("_all_", "_and_"): lambda a, b: a & b,
    }
    MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")

    class QueryParameterOptions(object):
        def __init__(
            self,
            pattern_fields=("name",),
            list_fields=("tags", "id"),
            datetime_fields=None,
            fields=None,
        ):
            """
            :param pattern_fields: Fields for which a "string contains" condition should be generated
            :param list_fields: Fields for which a "list contains" condition should be generated
            :param datetime_fields: Fields for which a datetime condition should be generated (see ACCESS_MODIFIER)
            :param fields: Fields for which a simple equality condition should be generated (basically filters out
                all other unsupported query fields)
            """
            self.fields = fields
            self.datetime_fields = datetime_fields
            self.list_fields = list_fields
            self.pattern_fields = pattern_fields

    get_all_query_options = QueryParameterOptions()

    @classmethod
    def get(
        cls, company, id, *, _only=None, include_public=False, **kwargs
    ) -> "GetMixin":
        q = cls.objects(
            cls._prepare_perm_query(company, allow_public=include_public)
            & Q(id=id, **kwargs)
        )
        if _only:
            q = q.only(*_only)
        return q.first()

    @classmethod
    def prepare_query(
        cls,
        company: str,
        parameters: dict = None,
        parameters_options: QueryParameterOptions = None,
        allow_public=False,
    ):
        """
        Prepare a query object based on the provided query dictionary and various fields.
        :param parameters_options: Specifies options for parsing the parameters (see QueryParameterOptions)
        :param company: Company ID (required)
        :param allow_public: Allow results from public objects
        :param parameters: Query dictionary (relevant keys are those specified by the various field name parameters).
            Supported parameters:
            - <field_name>: <value> Will query for items with this value in the field (see QueryParameterOptions for
              specific rules on handling values). Only items matching ALL of these conditions will be retrieved.
            - <any|all>: {fields: [<field1>, <field2>, ...], pattern: <pattern>} Will query for items where any or all
              provided fields match the provided pattern.
        :return: mongoengine.Q query object
        """
        return cls._prepare_query_no_company(
            parameters, parameters_options
        ) & cls._prepare_perm_query(company, allow_public=allow_public)

    @classmethod
    def _prepare_query_no_company(
        cls, parameters=None, parameters_options=QueryParameterOptions()
    ):
        """
        Prepare a query object based on the provided query dictionary and various fields.

        NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows creating queries that span across companies.

        :param parameters_options: Specifies options for parsing the parameters (see QueryParameterOptions)
        :param parameters: Query dictionary (relevant keys are those specified by the various field name parameters).
            Supported parameters:
            - <field_name>: <value> Will query for items with this value in the field (see QueryParameterOptions for
              specific rules on handling values). Only items matching ALL of these conditions will be retrieved.
            - <any|all>: {fields: [<field1>, <field2>, ...], pattern: <pattern>} Will query for items where any or all
              provided fields match the provided pattern.
        :return: mongoengine.Q query object
        """
        parameters_options = parameters_options or cls.get_all_query_options
        dict_query = {}
        query = RegexQ()
        if parameters:
            parameters = parameters.copy()
            opts = parameters_options
            for field in opts.pattern_fields:
                pattern = parameters.pop(field, None)
                if pattern:
                    dict_query[field] = RegexWrapper(pattern)

            for field in tuple(opts.list_fields or ()):
                data = parameters.pop(field, None)
                if data:
                    if not isinstance(data, (list, tuple)):
                        raise MakeGetAllQueryError("expected list", field)
                    exclude = [t for t in data if t.startswith("-")]
                    include = list(set(data).difference(exclude))
                    mongoengine_field = field.replace(".", "__")
                    if include:
                        dict_query[f"{mongoengine_field}__in"] = include
                    if exclude:
                        dict_query[f"{mongoengine_field}__nin"] = [
                            t[1:] for t in exclude
                        ]

            for field in opts.fields or []:
                data = parameters.pop(field, None)
                if data is not None:
                    dict_query[field] = data

            for field in opts.datetime_fields or []:
                data = parameters.pop(field, None)
                if data is not None:
                    if not isinstance(data, list):
                        data = [data]
                    for d in data:  # type: str
                        m = ACCESS_REGEX.match(d)
                        if not m:
                            continue
                        try:
                            value = parse_datetime(m.group("value"))
                            prefix = m.group("prefix")
                            modifier = ACCESS_MODIFIER.get(prefix)
                            f = field if not modifier else "__".join((field, modifier))
                            dict_query[f] = value
                        except (ValueError, OverflowError):
                            pass

            for field, value in parameters.items():
                for keys, func in cls._multi_field_param_prefix.items():
                    if field not in keys:
                        continue
                    try:
                        data = cls.MultiFieldParameters(**value)
                    except Exception:
                        raise MakeGetAllQueryError("incorrect field format", field)
                    if not data.fields:
                        break
                    regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
                    sep_fields = [f.replace(".", "__") for f in data.fields]
                    q = reduce(
                        lambda a, x: func(a, RegexQ(**{x: regex})), sep_fields, RegexQ()
                    )
                    query = query & q

        return query & RegexQ(**dict_query)

    @classmethod
    def _prepare_perm_query(cls, company, allow_public=False):
        if allow_public:
            return get_company_or_none_constraint(company)
        return Q(company=company)

    @classmethod
    def validate_paging(
        cls, parameters=None, default_page=None, default_page_size=None
    ):
        """ Validate and extract paging info from the provided dictionary. Supports default values. """
        if parameters is None:
            parameters = {}
        default_page = parameters.get("page", default_page)
        if default_page is None:
            return None, None
        default_page_size = parameters.get("page_size", default_page_size)
        if not default_page_size:
            raise errors.bad_request.MissingRequiredFields(
                "page_size is required when page is requested", field="page_size"
            )
        elif default_page < 0:
            raise errors.bad_request.ValidationError("page must be >=0", field="page")
        elif default_page_size < 1:
            raise errors.bad_request.ValidationError(
                "page_size must be >0", field="page_size"
            )
        return default_page, default_page_size

    @classmethod
    def get_projection(cls, parameters, override_projection=None, **__):
        """ Extract a projection list from the provided dictionary. Supports an override projection. """
        if override_projection is not None:
            return override_projection
        if not parameters:
            return []
        return parameters.get("projection") or parameters.get("only_fields", [])

    @classmethod
    def set_default_ordering(cls, parameters, value):
        parameters[cls._ordering_key] = parameters.get(cls._ordering_key) or value

    @classmethod
    def get_many_with_join(
        cls,
        company,
        query_dict=None,
        query_options=None,
        query=None,
        allow_public=False,
        override_projection=None,
        expand_reference_ids=True,
    ):
        """
        Fetch all documents matching a provided query with support for joining referenced documents according to the
        requested projection. See get_many() for more info.
        :param expand_reference_ids: If True, reference fields that contain just an ID string are expanded into
            a sub-document in the format {_id: <ID>}. Otherwise, field values are left as a string.
        """
        if issubclass(cls, AuthDocument):
            # Refuse projection (join) for auth documents (auth.User etc.) to avoid inadvertently disclosing
            # auth-related secrets and prevent security leaks
            log.error(
                f"Attempted projection of {cls.__name__} auth document (ignored)",
                stack_info=True,
            )
            return []

        override_projection = cls.get_projection(
            parameters=query_dict, override_projection=override_projection
        )

        helper = ProjectionHelper(
            doc_cls=cls,
            projection=override_projection,
            expand_reference_ids=expand_reference_ids,
        )

        # Make the main query
        results = cls.get_many(
            override_projection=helper.doc_projection,
            company=company,
            parameters=query_dict,
            query_dict=query_dict,
            query=query,
            query_options=query_options,
            allow_public=allow_public,
        )

        def projection_func(doc_type, projection, ids):
            return doc_type.get_many_with_join(
                company=company,
                override_projection=projection,
                query=Q(id__in=ids),
                expand_reference_ids=expand_reference_ids,
                allow_public=allow_public,
            )

        return helper.project(results, projection_func)

    @classmethod
    def get_many(
        cls,
        company,
        parameters: dict = None,
        query_dict: dict = None,
        query_options: QueryParameterOptions = None,
        query: Q = None,
        allow_public=False,
        override_projection: Collection[str] = None,
        return_dicts=True,
    ):
        """
        Fetch all documents matching a provided query. Supports several built-in options
        (aside from those provided by the parameters):
        - Ordering: using query field `order_by` which can contain a string or a list of strings corresponding to
          field names. Using field names not defined in the document will cause an error.
        - Paging: using query fields page and page_size. page must be larger than or equal to 0, page_size must be
          larger than 0 and is required when specifying a page.
        - Text search: using query field `search_text`. If used, text score can be used in the ordering, using the
          `@text_score` keyword. A text index must be defined on the document type, otherwise an error will
|
||||
be raised.
|
||||
:param return_dicts: Return a list of dictionaries. If True, a list of dicts is returned (if projection was
|
||||
requested, each contains only the requested projection).
|
||||
If False, a QuerySet object is returned (lazy evaluated)
|
||||
:param company: Company ID (required)
|
||||
:param parameters: Parameters dict from which paging ordering and searching parameters are extracted.
|
||||
:param query_dict: If provided, passed to prepare_query() along with all of the relevant arguments to produce
|
||||
a query. The resulting query is AND'ed with the `query` parameter (if provided).
|
||||
:param query_options: query parameters options (see ParametersOptions)
|
||||
:param query: Optional query object (mongoengine.Q)
|
||||
:param override_projection: A list of projection fields overriding any projection specified in the `param_dict`
|
||||
argument
|
||||
:param allow_public: If True, objects marked as public (no associated company) are also queried.
|
||||
:return: A list of objects matching the query.
|
||||
"""
|
||||
if query_dict is not None:
|
||||
q = cls.prepare_query(
|
||||
parameters=query_dict,
|
||||
company=company,
|
||||
parameters_options=query_options,
|
||||
allow_public=allow_public,
|
||||
)
|
||||
else:
|
||||
q = cls._prepare_perm_query(company, allow_public=allow_public)
|
||||
_query = (q & query) if query else q
|
||||
|
||||
return cls._get_many_no_company(
|
||||
query=_query,
|
||||
parameters=parameters,
|
||||
override_projection=override_projection,
|
||||
return_dicts=return_dicts,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_many_no_company(
|
||||
cls, query, parameters=None, override_projection=None, return_dicts=True
|
||||
):
|
||||
"""
|
||||
Fetch all documents matching a provided query.
|
||||
This is a company-less version for internal uses. We assume the caller has either added any necessary
|
||||
constraints to the query or that no constraints are required.
|
||||
|
||||
NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows returning data across companies.
|
||||
|
||||
:param query: Query object (mongoengine.Q)
|
||||
:param return_dicts: Return a list of dictionaries. If True, a list of dicts is returned (if projection was
|
||||
requested, each contains only the requested projection).
|
||||
If False, a QuerySet object is returned (lazy evaluated)
|
||||
:param parameters: Parameters dict from which paging ordering and searching parameters are extracted.
|
||||
:param override_projection: A list of projection fields overriding any projection specified in the `param_dict`
|
||||
argument
|
||||
"""
|
||||
parameters = parameters or {}
|
||||
|
||||
if not query:
|
||||
raise ValueError("query or call_data must be provided")
|
||||
|
||||
page, page_size = cls.validate_paging(parameters=parameters)
|
||||
|
||||
order_by = parameters.get(cls._ordering_key)
|
||||
if order_by:
|
||||
order_by = order_by if isinstance(order_by, list) else [order_by]
|
||||
order_by = [cls._text_score if x == "@text_score" else x for x in order_by]
|
||||
|
||||
search_text = parameters.get("search_text")
|
||||
|
||||
only = cls.get_projection(parameters, override_projection)
|
||||
|
||||
if not search_text and order_by and cls._text_score in order_by:
|
||||
raise errors.bad_request.FieldsValueError(
|
||||
"text score cannot be used in order_by when search text is not used"
|
||||
)
|
||||
|
||||
qs = cls.objects(query)
|
||||
if search_text:
|
||||
qs = qs.search_text(search_text)
|
||||
if order_by:
|
||||
# add ordering
|
||||
qs = (
|
||||
qs.order_by(order_by)
|
||||
if isinstance(order_by, string_types)
|
||||
else qs.order_by(*order_by)
|
||||
)
|
||||
if only:
|
||||
# add projection
|
||||
qs = qs.only(*only)
|
||||
else:
|
||||
exclude = set(cls.get_exclude_fields()).difference(only)
|
||||
if exclude:
|
||||
qs = qs.exclude(*exclude)
|
||||
if page is not None and page_size:
|
||||
# add paging
|
||||
qs = qs.skip(page * page_size).limit(page_size)
|
||||
|
||||
if return_dicts:
|
||||
return [obj.to_proper_dict(only=only) for obj in qs]
|
||||
return qs
|
||||
|
||||
@classmethod
|
||||
def get_for_writing(
|
||||
cls, *args, _only: Collection[str] = None, **kwargs
|
||||
) -> "GetMixin":
|
||||
if _only and "company" not in _only:
|
||||
_only = list(set(_only) | {"company"})
|
||||
result = cls.get(*args, _only=_only, include_public=True, **kwargs)
|
||||
if result and not result.company:
|
||||
object_name = cls.__name__.lower()
|
||||
raise errors.forbidden.NoWritePermission(
|
||||
f"cannot modify public {object_name}(s), ids={(result.id,)}"
|
||||
)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def get_many_for_writing(cls, company, *args, **kwargs):
|
||||
result = cls.get_many(
|
||||
company=company,
|
||||
*args,
|
||||
**dict(return_dicts=False, **kwargs),
|
||||
allow_public=True,
|
||||
)
|
||||
forbidden_objects = {obj.id for obj in result if not obj.company}
|
||||
if forbidden_objects:
|
||||
object_name = cls.__name__.lower()
|
||||
raise errors.forbidden.NoWritePermission(
|
||||
f"cannot modify public {object_name}(s), ids={tuple(forbidden_objects)}"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class UpdateMixin(object):
|
||||
@classmethod
|
||||
def user_set_allowed(cls):
|
||||
res = getattr(cls, "__user_set_allowed_fields", None)
|
||||
if res is None:
|
||||
res = cls.__user_set_allowed_fields = dict(
|
||||
get_fields_with_attr(cls, "user_set_allowed")
|
||||
)
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def get_safe_update_dict(cls, fields):
|
||||
if not fields:
|
||||
return {}
|
||||
valid_fields = cls.user_set_allowed()
|
||||
fields = [(k, v, fields[k]) for k, v in valid_fields.items() if k in fields]
|
||||
update_dict = {
|
||||
field: value
|
||||
for field, allowed, value in fields
|
||||
if allowed is None
|
||||
or (
|
||||
(value in allowed)
|
||||
if not isinstance(value, list)
|
||||
else all(v in allowed for v in value)
|
||||
)
|
||||
}
|
||||
return update_dict
|
||||
|
||||
@classmethod
|
||||
def safe_update(cls, company_id, id, partial_update_dict, injected_update=None):
|
||||
update_dict = cls.get_safe_update_dict(partial_update_dict)
|
||||
if not update_dict:
|
||||
return 0, {}
|
||||
if injected_update:
|
||||
update_dict.update(injected_update)
|
||||
update_count = cls.objects(id=id, company=company_id).update(
|
||||
upsert=False, **update_dict
|
||||
)
|
||||
return update_count, update_dict
|
||||
|
||||
|
||||
class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
|
||||
""" Provide convenience methods for a subclass of mongoengine.Document """
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def validate_id(cls, company, **kwargs):
|
||||
"""
|
||||
Validate existence of objects with certain IDs. within company.
|
||||
:param cls: Model class to search in
|
||||
:param company: Company to search in
|
||||
:param kwargs: Mapping of field name to object ID. If any ID does not have a corresponding object,
|
||||
it will be reported along with the name it was assigned to.
|
||||
:return:
|
||||
"""
|
||||
ids = set(kwargs.values())
|
||||
objs = list(cls.objects(company=company, id__in=ids).only("id"))
|
||||
missing = ids - set(x.id for x in objs)
|
||||
if not missing:
|
||||
return
|
||||
id_to_name = {}
|
||||
for name, obj_id in kwargs.items():
|
||||
id_to_name.setdefault(obj_id, []).append(name)
|
||||
raise errors.bad_request.ValidationError(
|
||||
"Invalid {} ids".format(cls.__name__.lower()),
|
||||
**{name: obj_id for obj_id in missing for name in id_to_name[obj_id]}
|
||||
)
|
||||
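The allowed-values filtering performed by `get_safe_update_dict()` above can be illustrated standalone. This is a minimal sketch, not the server's actual code: `allowed_fields` stands in for the mixin's `user_set_allowed()` mapping, where a value of `None` means any value is accepted and a collection restricts the accepted values.

```python
def safe_update_dict(fields, allowed_fields):
    """Keep only fields the user may set, and only values the field allows.

    allowed_fields maps field name -> collection of allowed values,
    or None when any value is accepted (mirrors the mixin's semantics).
    """
    if not fields:
        return {}
    candidates = [(k, v, fields[k]) for k, v in allowed_fields.items() if k in fields]
    return {
        field: value
        for field, allowed, value in candidates
        if allowed is None
        or (
            (value in allowed)
            if not isinstance(value, list)
            else all(v in allowed for v in value)
        )
    }


# 'comment' accepts anything; 'result' only specific values; 'secret' is not settable
allowed = {"comment": None, "result": ("success", "failure")}
print(safe_update_dict(
    {"comment": "hi", "result": "bad-value", "secret": 1}, allowed
))  # {'comment': 'hi'}
```

Note that fields absent from the allowed mapping are silently dropped rather than rejected, which is what lets `safe_update()` return `(0, {})` for a request containing only non-settable fields.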
25
server/database/model/company.py
Normal file
@@ -0,0 +1,25 @@
from mongoengine import Document, EmbeddedDocument, EmbeddedDocumentField, StringField, Q

from database import Database, strict
from database.fields import StrippedStringField
from database.model import DbModelMixin


class CompanyDefaults(EmbeddedDocument):
    cluster = StringField()


class Company(DbModelMixin, Document):
    meta = {
        'db_alias': Database.backend,
        'strict': strict,
    }

    id = StringField(primary_key=True)
    name = StrippedStringField(unique=True, min_length=3)
    defaults = EmbeddedDocumentField(CompanyDefaults)

    @classmethod
    def _prepare_perm_query(cls, company, allow_public=False):
        """ Override default behavior since a 'company' constraint is not supported for this document """
        return Q()
56
server/database/model/model.py
Normal file
@@ -0,0 +1,56 @@
from mongoengine import Document, StringField, DateTimeField, ListField, BooleanField

from database import Database, strict
from database.fields import SupportedURLField, StrippedStringField, SafeDictField
from database.model import DbModelMixin
from database.model.model_labels import ModelLabels
from database.model.company import Company
from database.model.project import Project
from database.model.task.task import Task
from database.model.user import User


class Model(DbModelMixin, Document):
    meta = {
        'db_alias': Database.backend,
        'strict': strict,
        'indexes': [
            {
                'name': '%s.model.main_text_index' % Database.backend,
                'fields': [
                    '$name',
                    '$id',
                    '$comment',
                    '$parent',
                    '$task',
                    '$project',
                ],
                'default_language': 'english',
                'weights': {
                    'name': 10,
                    'id': 10,
                    'comment': 10,
                    'parent': 5,
                    'task': 3,
                    'project': 3,
                }
            }
        ],
    }

    id = StringField(primary_key=True)
    name = StrippedStringField(user_set_allowed=True, min_length=3)
    parent = StringField(reference_field='Model', required=False)
    user = StringField(required=True, reference_field=User)
    company = StringField(required=True, reference_field=Company)
    project = StringField(reference_field=Project, user_set_allowed=True)
    created = DateTimeField(required=True, user_set_allowed=True)
    task = StringField(reference_field=Task)
    comment = StringField(user_set_allowed=True)
    tags = ListField(StringField(required=True), user_set_allowed=True)
    uri = SupportedURLField(default='', user_set_allowed=True)
    framework = StringField()
    design = SafeDictField()
    labels = ModelLabels()
    ready = BooleanField(required=True)
    ui_cache = SafeDictField(default=dict, user_set_allowed=True, exclude_by_default=True)
11
server/database/model/model_labels.py
Normal file
@@ -0,0 +1,11 @@
from mongoengine import MapField, IntField


class ModelLabels(MapField):
    def __init__(self, *args, **kwargs):
        super(ModelLabels, self).__init__(field=IntField(), *args, **kwargs)

    def validate(self, value):
        super(ModelLabels, self).validate(value)
        if value and (len(set(value.values())) < len(value)):
            self.error("Same label id appears more than once in model labels")
39
server/database/model/project.py
Normal file
@@ -0,0 +1,39 @@
from mongoengine import StringField, DateTimeField, ListField

from database import Database, strict
from database.fields import OutputDestinationField, StrippedStringField
from database.model import AttributedDocument
from database.model.base import GetMixin


class Project(AttributedDocument):

    get_all_query_options = GetMixin.QueryParameterOptions(
        pattern_fields=("name", "description"), list_fields=("tags", "id")
    )

    meta = {
        "db_alias": Database.backend,
        "strict": strict,
        "indexes": [
            {
                "name": "%s.project.main_text_index" % Database.backend,
                "fields": ["$name", "$id", "$description"],
                "default_language": "english",
                "weights": {"name": 10, "id": 10, "description": 10},
            }
        ],
    }

    id = StringField(primary_key=True)
    name = StrippedStringField(
        required=True,
        unique_with=AttributedDocument.company.name,
        min_length=3,
        sparse=True,
    )
    description = StringField(required=True)
    created = DateTimeField(required=True)
    tags = ListField(StringField(required=True), default=list)
    default_output_destination = OutputDestinationField()
    last_update = DateTimeField()
14
server/database/model/task/metrics.py
Normal file
@@ -0,0 +1,14 @@
from mongoengine import EmbeddedDocument, StringField, DateTimeField, LongField, DynamicField


class MetricEvent(EmbeddedDocument):
    metric = StringField(required=True)
    variant = StringField(required=True)
    type = StringField(required=True)
    timestamp = DateTimeField(default=0, required=True)
    iter = LongField()
    value = DynamicField(required=True)

    @classmethod
    def from_dict(cls, **kwargs):
        return cls(**{k: v for k, v in kwargs.items() if k in cls._fields})
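The `from_dict()` pattern above (drop unknown keys before constructing the document) can be shown without mongoengine. This is a hedged sketch: `MetricEventSketch` and its plain `_fields` set are stand-ins for a mongoengine document and its field registry, not part of the server code.

```python
class MetricEventSketch:
    # stand-in for mongoengine's per-document field registry
    _fields = {"metric", "variant", "type", "timestamp", "iter", "value"}

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    @classmethod
    def from_dict(cls, **kwargs):
        # silently drop keys that are not document fields
        return cls(**{k: v for k, v in kwargs.items() if k in cls._fields})


ev = MetricEventSketch.from_dict(metric="loss", value=0.3, extra="ignored")
print(ev.metric, ev.value, hasattr(ev, "extra"))  # loss 0.3 False
```

This keeps event ingestion tolerant of extra payload keys instead of failing on a strict constructor.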
16
server/database/model/task/output.py
Normal file
@@ -0,0 +1,16 @@
from mongoengine import EmbeddedDocument, StringField
from database.utils import get_options

from database.fields import OutputDestinationField


class Result(object):
    success = 'success'
    failure = 'failure'


class Output(EmbeddedDocument):
    destination = OutputDestinationField()
    model = StringField(reference_field='Model')
    error = StringField(user_set_allowed=True)
    result = StringField(choices=get_options(Result))
132
server/database/model/task/task.py
Normal file
@@ -0,0 +1,132 @@
from enum import Enum

from mongoengine import (
    StringField,
    EmbeddedDocumentField,
    EmbeddedDocument,
    DateTimeField,
    IntField,
    ListField,
)

from database import Database, strict
from database.fields import StrippedStringField, SafeMapField, SafeDictField
from database.model import AttributedDocument
from database.model.model_labels import ModelLabels
from database.model.project import Project
from database.utils import get_options
from .metrics import MetricEvent
from .output import Output

DEFAULT_LAST_ITERATION = 0


class TaskStatus(object):
    created = 'created'
    in_progress = 'in_progress'
    stopped = 'stopped'
    publishing = 'publishing'
    published = 'published'
    closed = 'closed'
    failed = 'failed'
    unknown = 'unknown'


class TaskStatusMessage(object):
    stopping = 'stopping'


class TaskTags(object):
    development = 'development'


class Script(EmbeddedDocument):
    binary = StringField(default='python')
    repository = StringField(required=True)
    tag = StringField()
    branch = StringField()
    version_num = StringField()
    entry_point = StringField(required=True)
    working_dir = StringField()
    requirements = SafeDictField()


class Execution(EmbeddedDocument):
    test_split = IntField(default=0)
    parameters = SafeDictField(default=dict)
    model = StringField(reference_field='Model')
    model_desc = SafeMapField(StringField(default=''))
    model_labels = ModelLabels()
    framework = StringField()

    queue = StringField()
    ''' Queue ID where task was queued '''


class TaskType(object):
    training = 'training'
    testing = 'testing'


class Task(AttributedDocument):
    meta = {
        'db_alias': Database.backend,
        'strict': strict,
        'indexes': [
            'created',
            'started',
            'completed',
            {
                'name': '%s.task.main_text_index' % Database.backend,
                'fields': [
                    '$name',
                    '$id',
                    '$comment',
                    '$execution.model',
                    '$output.model',
                    '$script.repository',
                    '$script.entry_point',
                ],
                'default_language': 'english',
                'weights': {
                    'name': 10,
                    'id': 10,
                    'comment': 10,
                    'execution.model': 2,
                    'output.model': 2,
                    'script.repository': 1,
                    'script.entry_point': 1,
                },
            },
        ],
    }

    id = StringField(primary_key=True)
    name = StrippedStringField(
        required=True, user_set_allowed=True, sparse=False, min_length=3
    )

    type = StringField(required=True, choices=get_options(TaskType))
    status = StringField(default=TaskStatus.created, choices=get_options(TaskStatus))
    status_reason = StringField()
    status_message = StringField()
    status_changed = DateTimeField()
    comment = StringField(user_set_allowed=True)
    created = DateTimeField(required=True, user_set_allowed=True)
    started = DateTimeField()
    completed = DateTimeField()
    published = DateTimeField()
    parent = StringField()
    project = StringField(reference_field=Project, user_set_allowed=True)
    output = EmbeddedDocumentField(Output, default=Output)
    execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
    tags = ListField(StringField(required=True), user_set_allowed=True)
    script = EmbeddedDocumentField(Script)
    last_update = DateTimeField()
    last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
    last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))


class TaskVisibility(Enum):
    active = 'active'
    archived = 'archived'
21
server/database/model/user.py
Normal file
@@ -0,0 +1,21 @@
from mongoengine import Document, StringField

from database import Database, strict
from database.fields import SafeDictField
from database.model import DbModelMixin
from database.model.company import Company


class User(DbModelMixin, Document):
    meta = {
        'db_alias': Database.backend,
        'strict': strict,
    }

    id = StringField(primary_key=True)
    company = StringField(required=True, reference_field=Company)
    name = StringField(required=True, user_set_allowed=True)
    family_name = StringField(user_set_allowed=True)
    given_name = StringField(user_set_allowed=True)
    avatar = StringField()
    preferences = SafeDictField(default=dict, exclude_by_default=True)
269
server/database/projection.py
Normal file
@@ -0,0 +1,269 @@
from concurrent.futures import ThreadPoolExecutor
from itertools import groupby, chain

import dpath

from apierrors import errors
from database.props import PropsMixin


def project_dict(data, projection, separator='.'):
    """
    Project partial data from a dictionary into a new dictionary
    :param data: Input dictionary
    :param projection: List of dictionary paths (each a string with field names separated using a separator)
    :param separator: Separator (default is '.')
    :return: A new dictionary containing only the projected parts from the original dictionary
    """
    assert isinstance(data, dict)
    result = {}

    def copy_path(path_parts, source, destination):
        src, dst = source, destination
        try:
            for depth, path_part in enumerate(path_parts[:-1]):
                src_part = src[path_part]
                if isinstance(src_part, dict):
                    src = src_part
                    dst = dst.setdefault(path_part, {})
                elif isinstance(src_part, (list, tuple)):
                    if path_part not in dst:
                        dst[path_part] = [{} for _ in range(len(src_part))]
                    elif not isinstance(dst[path_part], (list, tuple)):
                        raise TypeError('Incompatible destination type %s for %s (list expected)'
                                        % (type(dst), separator.join(path_parts[:depth + 1])))
                    elif not len(dst[path_part]) == len(src_part):
                        raise ValueError('Destination list length differs from source length for %s'
                                         % separator.join(path_parts[:depth + 1]))

                    dst[path_part] = [copy_path(path_parts[depth + 1:], s, d)
                                      for s, d in zip(src_part, dst[path_part])]

                    return destination
                else:
                    raise TypeError('Unsupported projection type %s for %s'
                                    % (type(src), separator.join(path_parts[:depth + 1])))

            last_part = path_parts[-1]
            dst[last_part] = src[last_part]
        except KeyError:
            # Projection field not in source, no biggie.
            pass
        return destination

    for projection_path in sorted(projection):
        copy_path(
            path_parts=projection_path.split(separator),
            source=data,
            destination=result)
    return result
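The behavior of `project_dict` can be demonstrated with a minimal standalone reimplementation. This is a simplified sketch for illustration only: it handles nested dicts and silently skips missing paths like the original, but omits the list/tuple handling above.

```python
def project_paths(data, paths, separator="."):
    """Simplified version of project_dict: copy only the listed dotted
    paths from a nested dict (list handling from the original is omitted)."""
    result = {}
    for path in sorted(paths):
        src, dst = data, result
        parts = path.split(separator)
        try:
            for part in parts[:-1]:
                src = src[part]          # may raise KeyError: path not in source
                dst = dst.setdefault(part, {})
            dst[parts[-1]] = src[parts[-1]]
        except KeyError:
            pass  # missing projection fields are silently skipped
    return result


doc = {"id": "t1", "execution": {"model": "m1", "parameters": {"lr": 0.1}}}
print(project_paths(doc, ["id", "execution.model", "no.such.field"]))
# {'execution': {'model': 'm1'}, 'id': 't1'}
```

Note that sibling paths under the same parent merge into one sub-dict because `setdefault` reuses the partially built destination node.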
class ProjectionHelper(object):
    pool = ThreadPoolExecutor()

    @property
    def doc_projection(self):
        return self._doc_projection

    def __init__(self, doc_cls, projection, expand_reference_ids=False):
        super(ProjectionHelper, self).__init__()
        self._should_expand_reference_ids = expand_reference_ids
        self._doc_cls = doc_cls
        self._doc_projection = None
        self._ref_projection = None
        self._parse_projection(projection)

    def _collect_projection_fields(self, doc_cls, projection):
        """
        Collect projection for the given document into an immediate document projection and a reference
        documents projection
        :param doc_cls: Document class
        :param projection: List of projection fields
        :return: A tuple of document projection and reference fields information
        """
        doc_projection = set()  # Projection fields for this class (used in the main query)
        ref_projection_info = []  # Projection information for reference fields (used in join queries)
        for field in projection:
            for ref_field, ref_field_cls in doc_cls.get_reference_fields().items():
                if not field.startswith(ref_field):
                    # Doesn't start with a reference field
                    continue
                if field == ref_field:
                    # Field is exactly a reference field. In this case we won't perform any inner projection (for that,
                    # use '<reference field name>.*')
                    continue
                subfield = field[len(ref_field):]
                if not subfield.startswith('.'):
                    # Starts with something that looks like a reference field, but isn't
                    continue

                ref_projection_info.append((ref_field, ref_field_cls, subfield[1:]))
                break
            else:
                # Not a reference field, just add to the top-level projection
                # We strip any trailing '*' since it means nothing for simple fields and for embedded documents
                orig_field = field
                if field.endswith('.*'):
                    field = field[:-2]
                    if not field:
                        raise errors.bad_request.InvalidFields(field=orig_field, object=doc_cls.__name__)
                doc_projection.add(field)
        return doc_projection, ref_projection_info

    def _parse_projection(self, projection):
        """
        Prepare the projection data structures for get_many_with_join().
        :param projection: A list of field names that should be returned by the query. Sub-fields can be specified
            using '.' (i.e. "parent.name"). A field terminated by '.*' indicates that all of the field's sub-fields
            should be returned (only relevant for fields that represent sub-documents or referenced documents)
        :type projection: list of strings
        :returns A tuple of (class fields projection, reference fields projection)
        """
        doc_cls = self._doc_cls
        assert issubclass(doc_cls, PropsMixin)
        if not projection:
            return [], {}

        doc_projection, ref_projection_info = self._collect_projection_fields(doc_cls, projection)

        def normalize_cls_projection(cls_, fields):
            """ Normalize projection for this class and group (e.g. expand '*') """
            if '*' in fields:
                return list(fields.difference('*').union(cls_.get_fields()))
            return list(fields)

        def compute_ref_cls_projection(cls_, group):
            """ Compute inner projection for this class and group """
            subfields = set([x[2] for x in group if x[2]])
            return normalize_cls_projection(cls_, subfields)

        def sort_key(proj_info):
            return proj_info[:2]

        # Aggregate by reference field. We'll leave out '*' from the projected items
        ref_projection = {
            ref_field: dict(cls=ref_cls, only=compute_ref_cls_projection(ref_cls, g))
            for (ref_field, ref_cls), g in groupby(sorted(ref_projection_info, key=sort_key), sort_key)
        }

        # Make sure this doesn't contain any reference field we'll join anyway
        # (i.e. in case only_fields=[project, project.name])
        doc_projection = normalize_cls_projection(doc_cls, doc_projection.difference(ref_projection).union({'id'}))

        # Make sure that in case one or more fields is a subfield of another field, we only use the top-level field.
        # This is done since in such a case, MongoDB will only use the most restrictive field (most nested field) and
        # won't return some of the data we need.
        # This way, we make sure to use the most inclusive field that contains all requested subfields.
        projection_set = set(doc_projection)
        doc_projection = [
            field
            for field in doc_projection
            if not any(field.startswith(f"{other_field}.") for other_field in projection_set - {field})
        ]

        # Make sure we didn't get any invalid projection fields for this class
        invalid_fields = [f for f in doc_projection if f.split('.')[0] not in doc_cls.get_fields()]
        if invalid_fields:
            raise errors.bad_request.InvalidFields(fields=invalid_fields, object=doc_cls.__name__)

        if ref_projection:
            # Join mode - use both normal projection fields and top-level reference fields
            doc_projection = set(doc_projection)
            for field in set(ref_projection).difference(doc_projection):
                if any(f for f in doc_projection if field.startswith(f)):
                    continue
                doc_projection.add(field)
            doc_projection = list(doc_projection)

        self._doc_projection = doc_projection
        self._ref_projection = ref_projection

    @staticmethod
    def _search(doc_cls, obj, path, only_values=True):
        """ Call dpath.search with yielded=True, collect result values """
        norm_path = doc_cls.get_dpath_translated_path(path)
        return [v if only_values else (k, v) for k, v in dpath.search(obj, norm_path, separator='.', yielded=True)]

    def project(self, results, projection_func):
        """
        Perform projection on query results, using the provided projection func.
        :param results: A list of results dictionaries on which projection should be performed
        :param projection_func: A callable that receives a document type, list of ids and projection and returns query
            results. This callable is used in order to perform sub-queries during projection
        :return: Modified results (in-place)
        """
        cls = self._doc_cls
        ref_projection = self._ref_projection

        if ref_projection:
            # Join mode - get results for each reference field's required projection (this is the join step)
            # Note: this is a recursive step, so we support nested reference fields

            def do_projection(item):
                ref_field_name, data = item
                res = {}
                ids = list(filter(None, set(chain.from_iterable(self._search(cls, res, ref_field_name)
                                                                for res in results))))
                if ids:
                    doc_type = data['cls']
                    doc_only = list(filter(None, data['only']))
                    doc_only = list({'id'} | set(doc_only)) if doc_only else None
                    res = {r['id']: r for r in projection_func(doc_type=doc_type, projection=doc_only, ids=ids)}
                data['res'] = res

            items = list(ref_projection.items())
            if len(ref_projection) == 1:
                do_projection(items[0])
            else:
                for _ in self.pool.map(do_projection, items):
                    # From ThreadPoolExecutor.map() documentation: If a call raises an exception then that exception
                    # will be raised when its value is retrieved from the map() iterator
                    pass

        def do_expand_reference_ids(result, skip_fields=None):
            ref_fields = cls.get_reference_fields()
            if skip_fields:
                ref_fields = set(ref_fields) - set(skip_fields)
            self._expand_reference_fields(cls, result, ref_fields)

        def merge_projection_result(result):
            for ref_field_name, data in ref_projection.items():
                res = data.get('res')
                if not res:
                    self._expand_reference_fields(cls, result, [ref_field_name])
                    continue
                ref_ids = self._search(cls, result, ref_field_name, only_values=False)
                if not ref_ids:
                    continue
                for path, value in ref_ids:
                    obj = res.get(value) or {'id': value}
                    dpath.new(result, path, obj, separator='.')

            # any reference field not projected should be expanded
            do_expand_reference_ids(result, skip_fields=list(ref_projection))

        update_func = merge_projection_result if ref_projection else \
            do_expand_reference_ids if self._should_expand_reference_ids else None

        if update_func:
            for result in results:
                update_func(result)

        return results

    @classmethod
    def _expand_reference_fields(cls, doc_cls, result, fields):
        for ref_field_name in fields:
            ref_ids = cls._search(doc_cls, result, ref_field_name, only_values=False)
            if not ref_ids:
                continue
            for path, value in ref_ids:
                dpath.set(
                    result,
                    path,
                    {'id': value} if value else {},
                    separator='.')

    @classmethod
    def expand_reference_ids(cls, doc_cls, result):
        cls._expand_reference_fields(doc_cls, result, doc_cls.get_reference_fields())
142	server/database/props.py	Normal file
@@ -0,0 +1,142 @@
from collections import OrderedDict
from operator import attrgetter
from threading import Lock

import six
from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
from mongoengine.base import get_document

from database.fields import (
    LengthRangeEmbeddedDocumentListField,
    UniqueEmbeddedDocumentListField,
    EmbeddedDocumentSortedListField,
)
from database.utils import get_fields, get_fields_and_attr


class PropsMixin(object):
    __cached_fields = None
    __cached_reference_fields = None
    __cached_exclude_fields = None
    __cached_fields_with_instance = None

    __cached_dpath_computed_fields_lock = Lock()
    __cached_dpath_computed_fields = None

    @classmethod
    def get_fields(cls):
        if cls.__cached_fields is None:
            cls.__cached_fields = get_fields(cls)
        return cls.__cached_fields

    @classmethod
    def get_fields_with_instance(cls, doc_cls):
        if cls.__cached_fields_with_instance is None:
            cls.__cached_fields_with_instance = {}
        if doc_cls not in cls.__cached_fields_with_instance:
            cls.__cached_fields_with_instance[doc_cls] = get_fields(
                doc_cls, return_instance=True
            )
        return cls.__cached_fields_with_instance[doc_cls]

    @staticmethod
    def _get_fields_with_attr(cls_, attr):
        """ Get all fields with the specified attribute (supports nested fields) """
        res = get_fields_and_attr(cls_, attr=attr)

        def resolve_doc(v):
            if not isinstance(v, six.string_types):
                return v
            if v == 'self':
                return cls_.owner_document
            return get_document(v)

        fields = {k: resolve_doc(v) for k, v in res.items()}

        def collect_embedded_docs(doc_cls, embedded_doc_field_getter):
            for field, embedded_doc_field in get_fields(
                cls_, of_type=doc_cls, return_instance=True
            ):
                embedded_doc_cls = embedded_doc_field_getter(
                    embedded_doc_field
                ).document_type
                fields.update(
                    {
                        '.'.join((field, subfield)): doc
                        for subfield, doc in PropsMixin._get_fields_with_attr(
                            embedded_doc_cls, attr
                        ).items()
                    }
                )

        collect_embedded_docs(EmbeddedDocumentField, lambda x: x)
        collect_embedded_docs(EmbeddedDocumentListField, attrgetter('field'))
        collect_embedded_docs(LengthRangeEmbeddedDocumentListField, attrgetter('field'))
        collect_embedded_docs(UniqueEmbeddedDocumentListField, attrgetter('field'))
        collect_embedded_docs(EmbeddedDocumentSortedListField, attrgetter('field'))

        return fields

    @classmethod
    def _translate_fields_path(cls, parts):
        current_cls = cls
        translated_parts = []
        for depth, part in enumerate(parts):
            if current_cls is None:
                raise ValueError(
                    'Invalid path (non-document encountered at %s)' % parts[: depth - 1]
                )
            try:
                field_name, field = next(
                    (k, v)
                    for k, v in cls.get_fields_with_instance(current_cls)
                    if k == part
                )
            except StopIteration:
                raise ValueError('Invalid field path %s' % parts[:depth])

            translated_parts.append(part)

            if isinstance(field, EmbeddedDocumentField):
                current_cls = field.document_type
            elif isinstance(
                field,
                (
                    EmbeddedDocumentListField,
                    LengthRangeEmbeddedDocumentListField,
                    UniqueEmbeddedDocumentListField,
                    EmbeddedDocumentSortedListField,
                ),
            ):
                current_cls = field.field.document_type
                translated_parts.append('*')
            else:
                current_cls = None

        return translated_parts

    @classmethod
    def get_reference_fields(cls):
        if cls.__cached_reference_fields is None:
            fields = cls._get_fields_with_attr(cls, 'reference_field')
            cls.__cached_reference_fields = OrderedDict(sorted(fields.items()))
        return cls.__cached_reference_fields

    @classmethod
    def get_exclude_fields(cls):
        if cls.__cached_exclude_fields is None:
            fields = cls._get_fields_with_attr(cls, 'exclude_by_default')
            cls.__cached_exclude_fields = OrderedDict(sorted(fields.items()))
        return cls.__cached_exclude_fields

    @classmethod
    def get_dpath_translated_path(cls, path, separator='.'):
        if cls.__cached_dpath_computed_fields is None:
            cls.__cached_dpath_computed_fields = {}
        if path not in cls.__cached_dpath_computed_fields:
            with cls.__cached_dpath_computed_fields_lock:
                parts = path.split(separator)
                translated = cls._translate_fields_path(parts)
                result = separator.join(translated)
                cls.__cached_dpath_computed_fields[path] = result
        return cls.__cached_dpath_computed_fields[path]
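The `get_fields` helper used by `PropsMixin` walks the class MRO base-first and collects mongoengine field descriptors. A minimal standalone sketch of the same idea, using a plain marker class in place of mongoengine's `BaseField` (an assumption, so the snippet runs without the library):

```python
class BaseField:
    """Stand-in for mongoengine's BaseField (any marker class works for the sketch)."""
    pass


class StringField(BaseField):
    pass


def get_fields(cls, of_type=BaseField, return_instance=False):
    """Collect field names (or (name, instance) pairs) from cls and its bases, base-first."""
    res = []
    for cls_ in reversed(cls.mro()):
        res.extend(
            k if not return_instance else (k, v)
            for k, v in vars(cls_).items()
            if isinstance(v, of_type)
        )
    return res


class Base:
    id = StringField()


class Task(Base):
    name = StringField()


print(get_fields(Task))  # ['id', 'name'] - base-class fields come first
```

Walking the MRO in reverse is what lets a subclass see inherited fields while keeping a stable base-first ordering, which the caching in `PropsMixin` then memoizes per class.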
63	server/database/query.py	Normal file
@@ -0,0 +1,63 @@
import copy
import re

from mongoengine import Q
from mongoengine.queryset.visitor import QueryCompilerVisitor, SimplificationVisitor, QCombination


class RegexWrapper(object):
    def __init__(self, pattern, flags=None):
        super(RegexWrapper, self).__init__()
        self.pattern = pattern
        self.flags = flags

    @property
    def regex(self):
        return re.compile(self.pattern, self.flags if self.flags is not None else 0)


class RegexMixin(object):

    def to_query(self, document):
        query = self.accept(SimplificationVisitor())
        query = query.accept(RegexQueryCompilerVisitor(document))
        return query

    def _combine(self, other, operation):
        """Combine this node with another node into a QCombination object."""
        if getattr(other, 'empty', True):
            return self

        if self.empty:
            return other

        return RegexQCombination(operation, [self, other])


class RegexQCombination(RegexMixin, QCombination):
    pass


class RegexQ(RegexMixin, Q):
    pass


class RegexQueryCompilerVisitor(QueryCompilerVisitor):
    """
    Improved mongoengine compiled-queries visitor class that supports compiled regex expressions as part of the query.

    We need this class since mongoengine's Q (QNode) class uses copy.deepcopy() as part of the tree simplification
    stage, which does not support re.compile() objects (since Python 2.5).
    This class allows users to provide regex strings wrapped in RegexWrapper instances, which are lazily evaluated
    to re.compile instances just before being visited for compilation (this is done after the simplification stage)
    """

    def visit_query(self, query):
        query = copy.deepcopy(query)
        query.query = self._transform_query(query.query)
        return super(RegexQueryCompilerVisitor, self).visit_query(query)

    def _transform_query(self, query):
        return {k: v.regex if isinstance(v, RegexWrapper) else v for k, v in query.items()}
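The point of `RegexWrapper` above is that the query tree may be deep-copied during simplification before it is compiled, so the pattern is kept as a plain (deepcopy-safe) string and only turned into a compiled pattern at the last moment. A self-contained sketch of that idea, with no mongoengine required:

```python
import copy
import re


class RegexWrapper:
    """Holds a regex as a plain string; compiles lazily on attribute access."""

    def __init__(self, pattern, flags=None):
        self.pattern = pattern
        self.flags = flags

    @property
    def regex(self):
        return re.compile(self.pattern, self.flags if self.flags is not None else 0)


# A query dict can be deep-copied freely while it still holds the wrapper...
query = {"name": RegexWrapper(r"^train.*", re.IGNORECASE)}
copied = copy.deepcopy(query)

# ...and is transformed into compiled patterns only just before execution,
# mirroring RegexQueryCompilerVisitor._transform_query above
compiled = {k: v.regex if isinstance(v, RegexWrapper) else v for k, v in copied.items()}
print(bool(compiled["name"].match("TRAINS server")))  # True
```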
160	server/database/utils.py	Normal file
@@ -0,0 +1,160 @@
import hashlib
from inspect import ismethod, getmembers
from uuid import uuid4

from mongoengine import EmbeddedDocumentField, ListField, Document, Q
from mongoengine.base import BaseField

from .errors import translate_errors_context, ParseCallError


def get_fields(cls, of_type=BaseField, return_instance=False):
    """ Get field names from a class containing mongoengine fields """
    res = []
    for cls_ in reversed(cls.mro()):
        res.extend([k if not return_instance else (k, v)
                    for k, v in vars(cls_).items()
                    if isinstance(v, of_type)])
    return res


def get_fields_and_attr(cls, attr):
    """ Get field names and their `attr` value from a class containing mongoengine fields """
    res = {}
    for cls_ in reversed(cls.mro()):
        res.update({k: getattr(v, attr)
                    for k, v in vars(cls_).items()
                    if isinstance(v, BaseField) and hasattr(v, attr)})
    return res


def _get_field_choices(name, field):
    field_t = type(field)
    if issubclass(field_t, EmbeddedDocumentField):
        obj = field.document_type_obj
        n, choices = _get_field_choices(field.name, obj.field)
        return '%s__%s' % (name, n), choices
    elif issubclass(field_t, ListField):
        return name, field.field.choices
    return name, field.choices


def get_fields_with_attr(cls, attr, default=False):
    fields = []
    for field_name, field in cls._fields.items():
        if not getattr(field, attr, default):
            continue
        field_t = type(field)
        if issubclass(field_t, EmbeddedDocumentField):
            fields.extend((('%s__%s' % (field_name, name), choices)
                           for name, choices in get_fields_with_attr(field.document_type, attr, default)))
        elif issubclass(field_t, ListField):
            fields.append((field_name, field.field.choices))
        else:
            fields.append((field_name, field.choices))
    return fields


def get_items(cls):
    """ Get key/value items from an enum-like class (members represent enumeration key/value) """

    res = {
        k: v
        for k, v in getmembers(cls)
        if not (k.startswith("_") or ismethod(v))
    }
    return res


def get_options(cls):
    """ Get options from an enum-like class (members represent enumeration key/value) """
    return list(get_items(cls).values())


# Return a dictionary of items which:
# 1. are in the call_data
# 2. are in the fields dictionary, and their value in the call_data matches the type in fields
# 3. are in the cls_fields
def parse_from_call(call_data, fields, cls_fields, discard_none_values=True):
    if not isinstance(fields, dict):
        # fields should be a key=>type dict
        fields = {k: None for k in fields}
    fields = {k: v for k, v in fields.items() if k in cls_fields}
    res = {}
    with translate_errors_context('parsing call data'):
        for field, desc in fields.items():
            value = call_data.get(field)
            if value is None:
                if not discard_none_values and field in call_data:
                    # keep the None value in case the field actually exists in the call data
                    res[field] = None
                continue
            if desc:
                if callable(desc):
                    desc(value)
                else:
                    if issubclass(desc, (list, tuple, dict)) and not isinstance(value, desc):
                        raise ParseCallError('expecting %s' % desc.__name__, field=field)
                    if issubclass(desc, Document) and not desc.objects(id=value).only('id'):
                        raise ParseCallError('expecting %s id' % desc.__name__, id=value, field=field)
            res[field] = value
    return res


def init_cls_from_base(cls, instance):
    return cls(**{k: v for k, v in instance.to_mongo(use_db_field=False).to_dict().items() if k[0] != '_'})


def get_company_or_none_constraint(company=None):
    return Q(company__in=(company, None, '')) | Q(company__exists=False)


def field_does_not_exist(field: str, empty_value=None, is_list=False) -> Q:
    """
    Creates a query object used for finding a field that doesn't exist, or has None or an empty value.
    :param field: Field name
    :param empty_value: The empty value to test for (None means no specific empty value will be used)
    :param is_list: Is this a list (array) field. In this case, instead of testing for an empty value,
        the length of the array will be used (len==0 means empty)
    :return: The combined query object
    """
    query = (Q(**{f"{field}__exists": False}) |
             Q(**{f"{field}__in": {empty_value, None}}))
    if is_list:
        query |= Q(**{f"{field}__size": 0})
    return query


def get_subkey(d, key_path, default=None):
    """ Get a key from a nested dictionary. key_path is a '.'-separated string of keys used to traverse
        the nested dictionary.
    """
    keys = key_path.split('.')
    for i, key in enumerate(keys):
        if not isinstance(d, dict):
            raise KeyError('Expecting a dict (%s)' % ('.'.join(keys[:i]) if i else 'bad input'))
        d = d.get(key)
        if d is None:
            return default
    return d


def id():
    return str(uuid4()).replace("-", "")


def hash_field_name(s):
    """ Hash a field name into a unique safe string """
    return hashlib.md5(s.encode()).hexdigest()


def merge_dicts(*dicts):
    base = {}
    for dct in dicts:
        base.update(dct)
    return base


def filter_fields(cls, fields):
    """From the fields dictionary return only the fields that match cls fields"""
    return {key: fields[key] for key in fields if key in get_fields(cls)}
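`get_subkey` traverses nested dictionaries with a dot-separated path, returning a default when the path dead-ends. A quick usage sketch (the helper is reproduced inline so the example is self-contained):

```python
def get_subkey(d, key_path, default=None):
    """Traverse nested dicts using a '.'-separated key path; return default if a key is missing."""
    keys = key_path.split('.')
    for i, key in enumerate(keys):
        if not isinstance(d, dict):
            raise KeyError('Expecting a dict (%s)' % ('.'.join(keys[:i]) if i else 'bad input'))
        d = d.get(key)
        if d is None:
            return default
    return d


# hypothetical config fragment, for illustration only
conf = {"hosts": {"elastic": {"events": {"port": 9200}}}}
print(get_subkey(conf, "hosts.elastic.events.port"))          # 9200
print(get_subkey(conf, "hosts.elastic.missing", default=-1))  # -1
```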
47	server/elastic/apply_mappings.py	Executable file
@@ -0,0 +1,47 @@
#!/usr/bin/env python3
"""
Apply elasticsearch mappings to given hosts.
"""
import argparse
import json
from pathlib import Path

import requests

HERE = Path(__file__).parent


def apply_mappings_to_host(host: str):
    def _send_mapping(f):
        with f.open() as json_data:
            data = json.load(json_data)
            url = f"{host}/_template/{f.stem}"
            requests.delete(url)
            r = requests.post(
                url, headers={"Content-Type": "application/json"}, data=json.dumps(data)
            )
            return {"mapping": f.stem, "result": r.text}

    p = HERE / "mappings"
    return [
        _send_mapping(f) for f in p.iterdir() if f.is_file() and f.suffix == ".json"
    ]


def parse_args():
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument("hosts", nargs="+")
    return parser.parse_args()


def main():
    for host in parse_args().hosts:
        print(">>>>> Applying mapping to " + host)
        res = apply_mappings_to_host(host)
        print(res)


if __name__ == "__main__":
    main()
27	server/elastic/mappings/events.json	Normal file
@@ -0,0 +1,27 @@
{
  "template": "events-*",
  "settings": {
    "number_of_shards": 5
  },
  "mappings": {
    "_default_": {
      "_source": {
        "enabled": true
      },
      "_routing": {
        "required": true
      },
      "properties": {
        "@timestamp": { "type": "date" },
        "task": { "type": "keyword" },
        "type": { "type": "keyword" },
        "worker": { "type": "keyword" },
        "timestamp": { "type": "date" },
        "iter": { "type": "long" },
        "metric": { "type": "keyword" },
        "variant": { "type": "keyword" },
        "value": { "type": "float" }
      }
    }
  }
}
12	server/elastic/mappings/events_log.json	Normal file
@@ -0,0 +1,12 @@
{
  "template": "events-log-*",
  "order": 1,
  "mappings": {
    "_default_": {
      "properties": {
        "msg": { "type": "text", "index": false },
        "level": { "type": "keyword" }
      }
    }
  }
}
11	server/elastic/mappings/events_plot.json	Normal file
@@ -0,0 +1,11 @@
{
  "template": "events-plot-*",
  "order": 1,
  "mappings": {
    "_default_": {
      "properties": {
        "plot_str": { "type": "text", "index": false }
      }
    }
  }
}
12	server/elastic/mappings/events_training_debug_image.json	Normal file
@@ -0,0 +1,12 @@
{
  "template": "events-training_debug_image-*",
  "order": 1,
  "mappings": {
    "_default_": {
      "properties": {
        "key": { "type": "keyword" },
        "url": { "type": "keyword" }
      }
    }
  }
}
1	server/elastic/requirements.txt	Normal file
@@ -0,0 +1 @@
requests>=2.21.0
85	server/es_factory.py	Normal file
@@ -0,0 +1,85 @@
from datetime import datetime

from elasticsearch import Elasticsearch, Transport

from config import config

log = config.logger(__file__)

_instances = {}


class MissingClusterConfiguration(Exception):
    """
    Raised when the cluster configuration is not found in the config files
    """
    pass


class InvalidClusterConfiguration(Exception):
    """
    Raised when the cluster configuration does not contain the required properties
    """
    pass


def connect(cluster_name):
    """
    Returns the es client for the cluster.
    Connects to the cluster if not previously connected.
    :param cluster_name: Dot-separated cluster path in the configuration file
    :return: es client
    :raises MissingClusterConfiguration: in case no config section is found for the cluster
    :raises InvalidClusterConfiguration: in case the cluster config section is missing needed properties
    """
    if cluster_name not in _instances:
        cluster_config = _get_cluster_config(cluster_name)
        hosts = cluster_config.get('hosts', None)
        if not hosts:
            raise InvalidClusterConfiguration(cluster_name)
        args = cluster_config.get('args', {})
        _instances[cluster_name] = Elasticsearch(hosts=hosts, transport_class=Transport, **args)

    return _instances[cluster_name]


def _get_cluster_config(cluster_name):
    """
    Returns the cluster config for the specified cluster path
    :param cluster_name: Dot-separated cluster path in the configuration file
    :return: config section for the cluster
    :raises MissingClusterConfiguration: in case no config section is found for the cluster
    """
    cluster_key = '.'.join(('hosts.elastic', cluster_name))
    cluster_config = config.get(cluster_key, None)
    if not cluster_config:
        raise MissingClusterConfiguration(cluster_name)

    return cluster_config


def connect_all():
    clusters = config.get("hosts.elastic").as_plain_ordered_dict()
    for name in clusters:
        connect(name)


def instances():
    return _instances


def timestamp_str_to_millis(ts_str):
    epoch = datetime.utcfromtimestamp(0)
    current_date = datetime.strptime(ts_str, "%Y-%m-%dT%H:%M:%S.%fZ")
    return int((current_date - epoch).total_seconds() * 1000.0)


def get_timestamp_millis():
    now = datetime.utcnow()
    epoch = datetime.utcfromtimestamp(0)
    return int((now - epoch).total_seconds() * 1000.0)


def get_es_timestamp_str():
    now = datetime.utcnow()
    return now.strftime("%Y-%m-%dT%H:%M:%S") + ".%03d" % (now.microsecond // 1000) + "Z"
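The timestamp helpers above convert between ISO-style strings and epoch milliseconds by subtracting the UTC epoch. A standalone round-trip sketch of the same arithmetic:

```python
from datetime import datetime


def timestamp_str_to_millis(ts_str):
    """Parse an ISO-style UTC timestamp string into epoch milliseconds."""
    epoch = datetime.utcfromtimestamp(0)
    current_date = datetime.strptime(ts_str, "%Y-%m-%dT%H:%M:%S.%fZ")
    return int((current_date - epoch).total_seconds() * 1000.0)


print(timestamp_str_to_millis("1970-01-01T00:00:01.500Z"))  # 1500
print(timestamp_str_to_millis("1970-01-02T00:00:00.000Z"))  # 86400000
```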
82	server/init_data.py	Normal file
@@ -0,0 +1,82 @@
from datetime import datetime

from furl import furl

from config import config
from database.model.auth import User, Credentials, Role
from database.model.company import Company
from elastic.apply_mappings import apply_mappings_to_host

log = config.logger(__file__)


class MissingElasticConfiguration(Exception):
    """
    Raised when the Elasticsearch cluster configuration is not found in the config files
    """

    pass


def init_es_data():
    hosts_key = "hosts.elastic.events.hosts"
    hosts_config = config.get(hosts_key, None)
    if not hosts_config:
        raise MissingElasticConfiguration(hosts_key)

    for conf in hosts_config:
        host = furl(scheme="http", host=conf["host"], port=conf["port"]).url
        log.info(f"Applying mappings to host: {host}")
        res = apply_mappings_to_host(host)
        log.info(res)


def _ensure_company():
    company_id = config.get("apiserver.default_company")
    company = Company.objects(id=company_id).only("id").first()
    if company:
        return company_id

    company_name = "trains"
    log.info(f"Creating company: {company_name}")
    company = Company(id=company_id, name=company_name)
    company.save()
    return company_id


def _ensure_user(user_data, company_id):
    user = User.objects(
        credentials__match=Credentials(key=user_data["key"], secret=user_data["secret"])
    ).first()
    if user:
        return user.id

    log.info(f"Creating user: {user_data['name']}")
    user = User(
        id=f"__{user_data['name']}__",
        name=user_data["name"],
        company=company_id,
        role=user_data["role"],
        email=user_data["email"],
        created=datetime.utcnow(),
        credentials=[Credentials(key=user_data["key"], secret=user_data["secret"])],
    )

    user.save()

    return user.id


def init_mongo_data():
    company_id = _ensure_company()
    users = [
        {"name": "apiserver", "role": Role.system, "email": "apiserver@example.com"},
        {"name": "webserver", "role": Role.system, "email": "webserver@example.com"},
        {"name": "tests", "role": Role.user, "email": "tests@example.com"},
    ]

    for user in users:
        credentials = config.get(f"secure.credentials.{user['name']}")
        user["key"] = credentials.user_key
        user["secret"] = credentials.user_secret
        _ensure_user(user, company_id)
27	server/requirements.txt	Normal file
@@ -0,0 +1,27 @@
six
Flask>=0.12.2
elasticsearch>=5.0.0,<6.0.0
pyhocon>=0.3.35
requests>=2.13.0
arrow>=0.10.0
pymongo==3.6.1  # 3.7 has a bug with multiple logged-in users
Flask-Cors>=3.0.5
Flask-Compress>=1.4.0
mongoengine==0.16.2
jsonmodels>=2.3
pyjwt>=1.3.0
gunicorn>=19.7.1
Jinja2==2.10
python-rapidjson>=0.6.3
jsonschema>=2.6.0
dpath>=1.4.2
funcsigs==1.0.2
luqum>=0.7.2
typing>=3.6.4
attrs>=19.1.0
nested_dict>=1.61
related>=0.7.2
validators>=0.12.4
fastjsonschema>=2.8
boltons>=19.1.0
semantic_version>=2.6.0,<3
267
server/schema.py
Normal file
267
server/schema.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""
|
||||
Objects representing schema entities
|
||||
"""
|
||||
import json
|
||||
import re
|
||||
from operator import attrgetter
|
||||
from pathlib import Path
|
||||
from typing import Mapping, Sequence
|
||||
|
||||
import attr
|
||||
from boltons.dictutils import subdict
|
||||
from pyhocon import ConfigFactory
|
||||
|
||||
from config import config
|
||||
from service_repo.base import PartialVersion
|
||||
|
||||
HERE = Path(__file__)
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
ALL_ROLES = "*"
|
||||
|
||||
|
||||
class EndpointSchema:
|
||||
REQUEST_KEY = "request"
|
||||
RESPONSE_KEY = "response"
|
||||
BATCH_REQUEST_KEY = "batch_request"
|
||||
DEFINITIONS_KEY = "definitions"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
service_name: str,
|
||||
action_name: str,
|
||||
version: PartialVersion,
|
||||
schema: dict,
|
||||
definitions: dict = None,
|
||||
):
|
||||
"""
|
||||
Class for interacting with the schema of a single endpoint
|
||||
:param service_name: name of containing service
|
||||
:param action_name: name of action
|
||||
:param version: endpoint version
|
||||
:param schema: endpoint schema
|
||||
:param definitions: service definitions
|
||||
"""
|
||||
self.service_name = service_name
|
||||
self.action_name = action_name
|
||||
self.full_name = f"{service_name}.{action_name}"
|
||||
self.version = version
|
||||
self.definitions = definitions
|
||||
self.request_schema = None
|
||||
self.batch_request_schema = None
|
||||
if self.REQUEST_KEY in schema:
|
||||
self.request_schema = {
|
||||
**schema[self.REQUEST_KEY],
|
||||
self.DEFINITIONS_KEY: self.definitions,
|
||||
}
|
||||
elif self.BATCH_REQUEST_KEY in schema:
|
||||
self.batch_request_schema = {
|
||||
**schema[self.BATCH_REQUEST_KEY],
|
||||
self.DEFINITIONS_KEY: self.definitions,
|
||||
}
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"endpoint {self.full_name} version {self.version} "
|
||||
"has no request or batch_request schema",
|
||||
schema,
|
||||
)
|
||||
self.response_schema = {
|
||||
**schema[self.RESPONSE_KEY],
|
||||
"definitions": self.definitions,
|
||||
}
|
||||
|
||||
|
||||
class EndpointVersionsGroup:
|
||||
|
||||
endpoints: Sequence[EndpointSchema]
|
||||
allow_roles: Sequence[str]
|
||||
internal: bool
|
||||
authorize: bool
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"{type(self).__name__}<{self.full_name}, "
|
||||
f"versions={tuple(e.version for e in self.endpoints)}>"
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
service_name: str,
|
||||
action_name: str,
|
||||
conf: dict,
|
||||
definitions: dict = None,
|
||||
defaults: dict = None,
|
||||
):
|
||||
"""
|
||||
Represents multiple implementations of a single endpoint, discriminated by API version
|
||||
:param service_name: name of containing service
|
||||
:param action_name: name of action
|
||||
:param conf: mapping between minimum version to endpoint schema
|
||||
:param definitions: service definitions
|
||||
:param defaults: service defaults
|
||||
"""
|
||||
self.service_name = service_name
|
||||
self.action_name = action_name
|
||||
self.full_name = f"{service_name}.{action_name}"
|
||||
self.definitions = definitions or {}
|
||||
self.defaults = defaults or {}
|
||||
self.internal = self._pop_attr_with_default(conf, "internal")
|
||||
self.allow_roles = self._pop_attr_with_default(conf, "allow_roles")
|
||||
self.authorize = self._pop_attr_with_default(conf, "authorize")
|
||||
|
||||
def parse_version(version):
|
||||
if not re.match(r"^\d+\.\d+$", version):
|
||||
raise ValueError(
|
||||
f"Encountered unrecognized key {version!r} in {self.service_name}.{self.action_name}"
|
||||
)
|
||||
return PartialVersion(version)
|
||||
|
||||
self.endpoints = sorted(
|
||||
(
|
||||
EndpointSchema(
|
||||
service_name=self.service_name,
|
||||
action_name=self.action_name,
|
||||
version=parse_version(version),
|
||||
schema=endpoint_conf,
|
||||
definitions=self.definitions,
|
||||
)
|
||||
for version, endpoint_conf in conf.items()
|
||||
),
|
||||
key=attrgetter("version"),
|
||||
)
|
||||
|
||||
def allows(self, role):
|
||||
return ALL_ROLES in self.allow_roles or role in self.allow_roles
|
||||
|
||||
def _pop_attr_with_default(self, conf, attr):
|
||||
return conf.pop(attr, self.defaults[attr])
|
||||
|
||||
def get_for_version(self, min_version: PartialVersion):
|
||||
"""
|
||||
Return endpoint schema for version
|
||||
"""
|
||||
if not self.endpoints:
|
||||
raise ValueError(f"endpoint group {self} has no versions")
|
||||
for endpoint in self.endpoints:
|
||||
if min_version <= endpoint.version:
|
||||
return endpoint
|
||||
raise ValueError(
|
||||
f"min_version {min_version} is higher than highest version in group {self}"
|
||||
)
|
||||
|
||||
|
||||
class Service:
|
||||
|
||||
endpoint_groups: Mapping[str, EndpointVersionsGroup]
|
||||
|
||||
def __init__(self, name: str, conf: dict, api_defaults: dict):
|
||||
"""
|
||||
Represents schema of one service
|
||||
:param name: name of service
|
||||
:param conf: service configuration, containing endpoint groups and other details
|
||||
:param api_defaults: API-wide endpoint attributes default values
|
||||
"""
|
||||
self.name = name
|
||||
conf = subdict(conf, drop=("_description", "_references"))
|
||||
self.defaults = {**api_defaults, **conf.pop("_default", {})}
|
||||
self.definitions = conf.pop("_definitions", None)
|
||||
        self.endpoint_groups: Mapping[str, EndpointVersionsGroup] = {
            endpoint_name: EndpointVersionsGroup(
                service_name=self.name,
                action_name=endpoint_name,
                conf=endpoint_conf,
                defaults=self.defaults,
                definitions=self.definitions,
            )
            for endpoint_name, endpoint_conf in conf.items()
        }


@attr.s()
class SchemaReader:
    root: Path = attr.ib(default=HERE.parent / "schema/services", converter=Path)
    cache_path: Path = attr.ib(default=None)

    def __attrs_post_init__(self):
        if not self.cache_path:
            self.cache_path = self.root / "_cache.json"

    @staticmethod
    def mod_time(path):
        """
        Return file modification time
        """
        return path.stat().st_mtime

    @staticmethod
    def read_file(path):
        return ConfigFactory.parse_file(path).as_plain_ordered_dict()

    def get_schema(self):
        """
        Parse the API schema into a Schema object.
        Load from config files and write to a cache file if possible.
        """
        services = [
            service
            for service in self.root.glob("*.conf")
            if not service.name.startswith("_")
        ]

        current_services_names = {path.stem for path in services}

        try:
            if self.mod_time(self.cache_path) >= max(map(self.mod_time, services)):
                log.info("loading schema from cache")
                result = json.loads(self.cache_path.read_text())
                cached_services_names = set(result.pop("services_names", []))
                if cached_services_names == current_services_names:
                    return Schema(**result)
                else:
                    log.info(
                        f"found services files changed: "
                        f"added: {list(current_services_names - cached_services_names)}, "
                        f"removed: {list(cached_services_names - current_services_names)}"
                    )
        except (IOError, KeyError, TypeError, ValueError, AttributeError) as ex:
            log.warning(f"failed loading cache: {ex}")

        log.info("regenerating schema cache")
        services = {path.stem: self.read_file(path) for path in services}
        api_defaults = self.read_file(self.root / "_api_defaults.conf")

        try:
            self.cache_path.write_text(
                json.dumps(
                    dict(
                        services_names=list(current_services_names),
                        services=services,
                        api_defaults=api_defaults,
                    )
                )
            )
        except IOError:
            log.exception(f"failed writing cache file to {self.cache_path}")

        return Schema(services, api_defaults)


class Schema:
    services: Mapping[str, Service]

    def __init__(self, services: dict, api_defaults: dict):
        """
        Represents the entire API schema

        :param services: services schema
        :param api_defaults: default values of service configuration
        """
        self.api_defaults = api_defaults
        self.services = {
            name: Service(name, conf, api_defaults=self.api_defaults)
            for name, conf in services.items()
        }


schema = SchemaReader().get_schema()
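The cache logic in `get_schema` above regenerates the schema only when a source `.conf` file is newer than the cache file, comparing modification times. A minimal standalone sketch of that invalidation pattern (function and file names here are illustrative, not the server's actual API):

```python
import json
from pathlib import Path
from tempfile import TemporaryDirectory


def load_with_cache(sources, cache_path, parse):
    """Return parsed data, regenerating the cache only when a source is newer."""
    try:
        # Cache is fresh if it is at least as new as every source file.
        if cache_path.stat().st_mtime >= max(p.stat().st_mtime for p in sources):
            return json.loads(cache_path.read_text())
    except OSError:
        pass  # no cache yet (or unreadable): fall through and regenerate

    data = {p.stem: parse(p) for p in sources}
    cache_path.write_text(json.dumps(data))
    return data


with TemporaryDirectory() as tmp:
    root = Path(tmp)
    conf = root / "auth.conf"
    conf.write_text("v1")
    cache = root / "_cache.json"

    first = load_with_cache([conf], cache, lambda p: p.read_text())   # parses, writes cache
    cached = load_with_cache([conf], cache, lambda p: "NOT CALLED")   # served from cache
    print(first == cached == {"auth": "v1"})
    # -> True
```

Note the `>=` comparison: a cache written in the same filesystem-timestamp tick as its source still counts as fresh, matching the check in `get_schema`.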
0
server/schema/meta/__init__.py
Normal file
199
server/schema/meta/meta.conf
Normal file
@@ -0,0 +1,199 @@
// some definitions
definitions {
    // description of a reference
    reference {
        type: object
        additionalProperties: false
        required: [ "$ref" ]
        properties {
            "$ref" { type: string }
        }
    }
    // description of an "additionalProperties" section
    additional_properties {
        oneOf: [
            { type: object, additionalProperties: true },
            { type: boolean }
        ]
    }
    // each endpoint is a mapping of versions to actions
    endpoint {
        type: object
        additionalProperties: false
        properties {

            // whether endpoint is internal
            internal { type: boolean }
            // whether endpoint requires authorization
            authorize { type: boolean }
            // list of roles allowed to access endpoint
            allow_roles {
                type: array
                items { type: string }
            }
        }
        patternProperties {
            "^\d\.\d+$" { "$ref": "#/definitions/action" }
        }
    }
    // an action describes request and response
    action {
        type: object
        additionalProperties: false
        // must have response
        required: [ response, description ]
        // must have either request or batch_request
        oneOf: [
            { required: [ request ] }
            { required: [ batch_request ] }
        ]
        properties {
            method { const: post }
            description { type: string }
            request {
                oneOf: [
                    { "$ref": "#/definitions/reference" }
                    { "$ref": "#/definitions/request" }
                    { "$ref": "#/definitions/multi_request" }
                ]
            }
            response {
                type: object
                oneOf: [
                    { "$ref": "#/definitions/response" }
                    // { "$ref": "#/definitions/reference" }
                    {
                        type: object
                        properties {
                            type { const: string }
                        }
                        additionalProperties: false
                    }
                ]
            }
            batch_request {
                type: object
                additionalProperties: false
                // required: [ action, version, description ]
                required: [ action, version ]
                properties {
                    action { type: string }
                    version { type: number }
                    // description { type: string }
                }
            }
        }
    }
    // describes request to server
    request {
        type: object
        additionalProperties: false
        required: [
            type
        ]
        properties {
            // says it's an object
            type { const: object }
            // required fields
            required {
                type: array
                items { type: string }
            }
            // request fields,
            // an object that can have anything
            properties {
                type: object
                additionalProperties {
                    type: object
                    required: [ description ]
                    properties {
                        description { type: string }
                    }
                    additionalProperties: true
                }
            }
            dependencies { type: object }
            // can have an "additionalProperties" section
            additionalProperties { "$ref": "#/definitions/additional_properties" }
        }
    }
    multi_request {
        type: object
        required: [ type ]
        additionalProperties: false
        oneOf: [
            {
                required: [ anyOf ]
            }
            {
                required: [ oneOf ]
            }
        ]
        properties {
            type { const: object }
            anyOf {
                type: array
                items { "$ref": "#/definitions/reference" }
            }
            oneOf {
                type: array
                items { "$ref": "#/definitions/reference" }
            }
        }
    }
    response {
        type: object
        additionalProperties: false
        required: [
            type
        ]
        properties {
            // says it's an object
            type { const: object }
            // nothing is required
            // can have anything
            properties {
                type: object
                additionalProperties: true
            }
            // can have an "additionalProperties" section
            additionalProperties { "$ref": "#/definitions/additional_properties" }
        }
    }
}

// schema starts here!
type: object
required: [ _description ]
properties {
    _description { type: string }
    // definitions for generator
    _definitions {
        type: object
        additionalProperties {
            required: [ type ]
            properties {
                type { type: string }
            }
            // can have anything
            additionalProperties: true
        }
    }
    _references = ${properties._definitions}
    // default values for actions
    // can have anything
    _default {
        additionalProperties: true
    }
}
// describing each endpoint
additionalProperties {
    type: object
    // can be:
    oneOf: [
        // a reference
        { "$ref": "#/definitions/reference" }
        // or a mapping from versions to actions
        { "$ref": "#/definitions/endpoint" }
    ]
}
303
server/schema/meta/validate.py
Executable file
@@ -0,0 +1,303 @@
#!/usr/bin/env python
from __future__ import print_function

import argparse
import json
import os
import sys
import time
from itertools import groupby
from operator import itemgetter

import pyhocon
import six
import yaml
from colors import color
from jsonschema import validate, ValidationError as JSONSchemaValidationError
from jsonschema.validators import validator_for
from pathlib import Path
from pyparsing import ParseBaseException

LINTER_URL = "https://www.jsonschemavalidator.net/"


class LocalStorage(object):
    def __init__(self, driver):
        self.driver = driver

    def __len__(self):
        return self.driver.execute_script("return window.localStorage.length;")

    def items(self):
        return self.driver.execute_script(
            """
            var ls = window.localStorage, items = {};
            for (var i = 0, k; i < ls.length; ++i)
                items[k = ls.key(i)] = ls.getItem(k);
            return items;
            """
        )

    def keys(self):
        return self.driver.execute_script(
            """
            var ls = window.localStorage, keys = [];
            for (var i = 0; i < ls.length; ++i)
                keys[i] = ls.key(i);
            return keys;
            """
        )

    def get(self, key):
        return self.driver.execute_script(
            "return window.localStorage.getItem(arguments[0]);", key
        )

    def remove(self, key):
        self.driver.execute_script("window.localStorage.removeItem(arguments[0]);", key)

    def clear(self):
        self.driver.execute_script("window.localStorage.clear();")

    def __getitem__(self, key):
        value = self.get(key)
        if value is None:
            raise KeyError(key)
        return value

    def __setitem__(self, key, value):
        self.driver.execute_script(
            "window.localStorage.setItem(arguments[0], arguments[1]);", key, value
        )

    def __contains__(self, key):
        return key in self.keys()

    def __iter__(self):
        return iter(self.keys())

    def __repr__(self):
        return repr(self.items())


class ValidationError(Exception):

    def __init__(self, *args):
        super(ValidationError, self).__init__(*args)
        self.message = self.args[0] if self.args else ""

    def report(self, schema_file):
        message = color(schema_file, fg='red')
        if self.message:
            message += ": {}".format(self.message)
        print(message)


class InvalidFile(ValidationError):
    """
    InvalidFile
    Wraps other exceptions that occur in file validation

    :param message: message to display
    """

    def __init__(self, message):
        super(InvalidFile, self).__init__(message)
        exc_type, _, _ = self.exc_info = sys.exc_info()
        if exc_type:
            self.message = "{}: {}".format(exc_type.__name__, message)

    def raise_original(self):
        six.reraise(*self.exc_info)


def load_hocon(name):
    """
    load_hocon
    load configuration from file

    :param name: file path
    """
    return pyhocon.ConfigFactory.parse_file(name).as_plain_ordered_dict()


def validate_ascii_only(name):
    invalid_char = next(
        (
            (line_num, column, char)
            for line_num, line in enumerate(Path(name).read_text().splitlines())
            for column, char in enumerate(line)
            if ord(char) not in range(128)
        ),
        None,
    )
    if invalid_char:
        line, column, char = invalid_char
        raise ValidationError(
            "file contains non-ascii character {!r} in line {} pos {}".format(
                char, line, column
            )
        )


def validate_file(meta, name):
    """
    validate_file
    validate file according to meta-scheme

    :param meta: meta-scheme
    :param name: file path
    """
    validate_ascii_only(name)
    try:
        schema = load_hocon(name)
    except ParseBaseException as e:
        raise InvalidFile(repr(e))

    try:
        validate(schema, meta)
        return schema
    except JSONSchemaValidationError as e:
        # absolute_path may contain integer array indices, so stringify each element
        path = "->".join(map(str, e.absolute_path))
        message = "{}: {}".format(path, e.args[0])
        raise InvalidFile(message)
    except Exception as e:
        raise InvalidFile(str(e))


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("files", nargs="+")
    parser.add_argument(
        "--linter", "-l", action="store_true", help="open jsonschema linter in browser"
    )
    parser.add_argument(
        "--raise",
        "-r",
        action="store_true",
        dest="raise_",
        help="raise first exception encountered and print traceback",
    )
    parser.add_argument(
        "--detect-collisions",
        action="store_true",
        help="detect objects with the same name in different modules",
    )
    return parser.parse_args()


def open_linter(driver, meta, schema):
    driver.maximize_window()
    driver.get(LINTER_URL)
    storage = LocalStorage(driver)
    storage["jsonText"] = json.dumps(schema, indent=4)
    storage["schemaText"] = json.dumps(meta, indent=4)
    driver.refresh()


class LazyDriver(object):
    def __init__(self):
        self._driver = None
        try:
            from selenium import webdriver, common
        except ImportError:
            webdriver = None
            common = None
        self.webdriver = webdriver
        self.common = common

    def __getattr__(self, item):
        return getattr(self.driver, item)

    @property
    def driver(self):
        if self._driver:
            return self._driver
        if not (self.webdriver and self.common):
            print("selenium not installed: linter unavailable")
            return None

        for driver_type in self.webdriver.Chrome, self.webdriver.Firefox:
            try:
                self._driver = driver_type()
                break
            except self.common.exceptions.WebDriverException:
                pass
        else:
            print("No webdriver found for Chrome or Firefox")

        return self._driver

    def wait(self):
        if not self._driver:
            return
        try:
            while True:
                self._driver.title
                time.sleep(0.5)
        except self.common.exceptions.WebDriverException:
            pass


def remove_description(dct):
    dct.pop("description", None)
    for value in dct.values():
        try:
            remove_description(value)
        except (TypeError, AttributeError):
            pass


def main():
    args = parse_args()
    meta = load_hocon(os.path.dirname(__file__) + "/meta.conf")
    validator_for(meta).check_schema(meta)

    driver = LazyDriver()

    collisions = {}

    for schema_file in args.files:

        if Path(schema_file).name.startswith("_"):
            continue

        try:
            schema = validate_file(meta, schema_file)
        except InvalidFile as e:
            if args.linter and driver.driver:
                open_linter(driver, meta, load_hocon(schema_file))
            elif args.raise_:
                e.raise_original()

            e.report(schema_file)
        except ValidationError as e:
            e.report(schema_file)
        else:
            for def_name, value in schema.get("_definitions", {}).items():
                service_name = str(Path(schema_file).stem)
                remove_description(value)
                collisions.setdefault(def_name, {})[service_name] = value

    warning = color("warning", fg="red")

    if args.detect_collisions:
        for name, values in collisions.items():
            if len(values) <= 1:
                continue
            groups = [
                [service for (service, _) in pairs]
                for _, pairs in groupby(values.items(), itemgetter(1))
            ]
            if not groups:
                raise RuntimeError("Unknown error")
            print(
                "{}: collision for {}:\n{}".format(warning, name, yaml.dump(groups)),
                end="",
            )

    driver.wait()


if __name__ == "__main__":
    main()
62
server/schema/services/README.md
Normal file
@@ -0,0 +1,62 @@
# Writing descriptions
There are two options for writing parameter descriptions. Mixing the two
will result in output that is not Sphinx-friendly.
Whichever you choose, lines are subject to wrapping.

- non-strict whitespace - Break the string however you like.
  Newlines and sequences of tabs/spaces are replaced by one space.
  Example:
  ```
  get_all {
      "1.5" {
          description: """This will all appear
          as one long
          sentence.
          Break lines wherever you
          like.
          """
      }
  }
  ```
  Becomes:
  ```
  class GetAllRequest(...):
      """
      This will all appear as one long sentence. Break lines wherever you
      like.
      """
  ```
- strict whitespace - Single newlines will be replaced by spaces.
  Double newlines become a single newline WITH INDENTATION PRESERVED,
  so if uniform indentation is required for all lines you MUST start new lines
  at the first column.
  Example:
  ```
  get_all {
      "1.5" {
          description: """
  Some general sentence.

  - separate lines must have double newlines between them

  - must begin at first column even though the "description" key is indented

  - you can use single newlines, the lines will be
  joined

    -- sub bullet: this line's leading spaces are preserved
  """
      }
  }
  ```
  Becomes:
  ```
  class GetAllRequest(...):
      """
      Some general sentence.
      - separate lines must have double newlines between them
      - must begin at first column even though the "description" key is indented
      - you can use single newlines, the lines will be joined
        -- sub bullet: this line's leading spaces are preserved
      """
  ```
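The non-strict mode described above can be summarized as "collapse all whitespace runs into single spaces". A minimal sketch of that normalization (illustrative only, not the server's actual generator code):

```python
import re


def normalize_non_strict(description: str) -> str:
    """Collapse newlines and runs of tabs/spaces into single spaces,
    as non-strict whitespace mode describes."""
    return re.sub(r"\s+", " ", description).strip()


raw = """This will all appear
as one long
sentence.
Break lines wherever you
like.
"""
print(normalize_non_strict(raw))
# -> This will all appear as one long sentence. Break lines wherever you like.
```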
3
server/schema/services/_api_defaults.conf
Normal file
@@ -0,0 +1,3 @@
internal: false
allow_roles: ["*"]
authorize: true
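These API-wide defaults are overridden by each service's `_default` section and again by per-endpoint settings (for instance, `auth.conf` sets `internal: true` service-wide and then flips it back for `login`). A sketch of that layered override as plain dict merges; the exact merge order is an assumption for illustration, not verified server behavior:

```python
# Layered endpoint configuration: later dicts win on key conflicts.
api_defaults = {"internal": False, "allow_roles": ["*"], "authorize": True}
service_default = {"internal": True, "allow_roles": ["system", "root"]}  # e.g. auth.conf _default
endpoint_conf = {"internal": False}  # e.g. the auth.login endpoint

effective = {**api_defaults, **service_default, **endpoint_conf}
print(effective)
# -> {'internal': False, 'allow_roles': ['system', 'root'], 'authorize': True}
```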
13
server/schema/services/_common.conf
Normal file
@@ -0,0 +1,13 @@
credentials {
    type: object
    properties {
        access_key {
            type: string
            description: Credentials access key
        }
        secret_key {
            type: string
            description: Credentials secret key
        }
    }
}
320
server/schema/services/auth.conf
Normal file
@@ -0,0 +1,320 @@
_description: """This service provides authentication management and authorization
validation for the entire system."""
_default {
    internal: true
    allow_roles: ["system", "root"]
}

_definitions {
    include "_common.conf"
    credential_key {
        type: object
        properties {
            access_key {
                type: string
                description: ""
            }
        }
    }
    role {
        type: string
        enum: [ admin, superuser, user, annotator ]
    }
}

login {
    internal: false
    allow_roles = [ "*" ]
    "2.1" {
        description: """Get a token based on supplied credentials (key/secret).
        Intended for use by users with key/secret credentials that wish to obtain a token
        for use with other services."""
        request {
            type: object
            properties {
                expiration_sec {
                    type: integer
                    description: """Requested token expiration time in seconds.
                    Not guaranteed, might be overridden by the service"""
                }
            }
        }
        response {
            type: object
            properties {
                token {
                    type: string
                    description: Token string
                }
            }
        }
    }
}

get_token_for_user {
    "2.1" {
        description: """Get a token for the specified user. Intended for internal use."""
        request {
            type: object
            required: [
                user
            ]
            properties {
                user {
                    type: string
                    description: User ID
                }
                company {
                    type: string
                    description: Company ID
                }
                expiration_sec {
                    type: integer
                    description: """Requested token expiration time in seconds.
                    Not guaranteed, might be overridden by the service"""
                }
            }
        }
        response {
            type: object
            properties {
                token {
                    type: string
                    description: ""
                }
            }
        }
    }
}

validate_token {
    "2.1" {
        description: """Validate a token and return user identity if valid.
        Intended for internal use."""
        request {
            type: object
            required: [ token ]
            properties {
                token {
                    type: string
                    description: Token string
                }
            }
        }
        response {
            type: object
            properties {
                valid {
                    type: boolean
                    description: Boolean indicating if the token is valid
                }
                user {
                    type: string
                    description: Associated user ID
                }
                company {
                    type: string
                    description: Associated company ID
                }
            }
        }
    }
}

create_user {
    "2.1" {
        description: """Creates a new user auth entry. Intended for internal use."""
        request {
            type: object
            required: [
                name
                company
                email
            ]
            properties {
                name {
                    type: string
                    description: User name (makes the auth entry more readable)
                }
                company {
                    type: string
                    description: Associated company ID
                }
                email {
                    type: string
                    description: Email address uniquely identifying the user
                }
                role {
                    description: User role
                    default: user
                    "$ref": "#/definitions/role"
                }
                given_name {
                    type: string
                    description: Given name
                }
                family_name {
                    type: string
                    description: Family name
                }
                avatar {
                    type: string
                    description: Avatar URL
                }
            }
        }
        response {
            type: object
            properties {
                id {
                    type: string
                    description: New user ID
                }
            }
        }
    }
}

create_credentials {
    allow_roles = [ "*" ]
    internal: false
    "2.1" {
        description: """Creates a new set of credentials for the authenticated user.
        New key/secret is returned.
        Note: Secret will never be returned in any other API call.
        If a secret is lost or compromised, the key should be revoked
        and a new set of credentials can be created."""
        request {
            type: object
            properties {}
            additionalProperties: false
        }
        response {
            type: object
            properties {
                credentials {
                    "$ref": "#/definitions/credentials"
                    description: Created credentials
                }
            }
        }
    }
}

get_credentials {
    allow_roles = [ "*" ]
    internal: false
    "2.1" {
        description: """Returns all existing credential keys for the authenticated user.
        Note: Only credential keys are returned."""
        request {
            type: object
            properties {}
            additionalProperties: false
        }
        response {
            type: object
            properties {
                credentials {
                    description: "List of credentials, each with an empty secret field."
                    type: array
                    items { "$ref": "#/definitions/credential_key" }
                }
            }
        }
    }
}

revoke_credentials {
    allow_roles = [ "*" ]
    internal: false
    "2.1" {
        description: """Revokes (and deletes) a set (key, secret) of credentials for
        the authenticated user."""
        request {
            type: object
            required: [ access_key ]
            properties {
                access_key {
                    type: string
                    description: Credentials key
                }
            }
        }
        response {
            type: object
            properties {
                revoked {
                    description: "Number of credentials revoked"
                    type: integer
                    enum: [0, 1]
                }
            }
        }
    }
}

delete_user {
    allow_roles = [ "system", "root", "admin" ]
    internal: false
    "2.1" {
        description: """Delete an existing user manually. Only supported in on-premises deployments. This only removes the user's auth entry so that any references to the deleted user's ID will still have valid user information"""
        request {
            type: object
            required: [ user ]
            properties {
                user {
                    type: string
                    description: User ID
                }
            }
        }
        response {
            type: object
            properties {
                deleted {
                    description: "True if user was successfully deleted, False otherwise"
                    type: boolean
                }
            }
        }
    }
}

edit_user {
    internal: false
    allow_roles: ["system", "root", "admin"]
    "2.1" {
        description: """Edit a user's auth data properties"""
        request {
            type: object
            properties {
                user {
                    description: "User ID"
                    type: string
                }
                role {
                    description: "The user's new role within the company"
                    type: string
                    enum: [admin, superuser, user, annotator]
                }
            }
        }
        response {
            type: object
            properties {
                updated {
                    description: "Number of users updated (0 or 1)"
                    type: number
                    enum: [ 0, 1 ]
                }
                fields {
                    description: "Updated fields names and values"
                    type: object
                    additionalProperties: true
                }
            }
        }
    }
}
849
server/schema/services/events.conf
Normal file
@@ -0,0 +1,849 @@
|
||||
{
|
||||
_description : "Provides an API for running tasks to report events collected by the system."
|
||||
_definitions {
|
||||
metrics_scalar_event {
|
||||
description: "Used for reporting scalar metrics during training task"
|
||||
type: object
|
||||
required: [ task, type ]
|
||||
properties {
|
||||
timestamp {
|
||||
description: "Epoch milliseconds UTC, will be set by the server if not set."
|
||||
type: number
|
||||
}
|
||||
type {
|
||||
description: "training_stats_vector"
|
||||
const: "training_stats_scalar"
|
||||
}
|
||||
task {
|
||||
description: "Task ID (required)"
|
||||
type: string
|
||||
}
|
||||
iter {
|
||||
description: "Iteration"
|
||||
type: integer
|
||||
}
|
||||
metric {
|
||||
description: "Metric name, e.g. 'count', 'loss', 'accuracy'"
|
||||
type: string
|
||||
}
|
||||
variant {
|
||||
description: "E.g. 'class_1', 'total', 'average"
|
||||
type: string
|
||||
}
|
||||
value {
|
||||
description: ""
|
||||
type: number
|
||||
}
|
||||
}
|
||||
}
|
||||
metrics_vector_event {
|
||||
description: "Used for reporting vector metrics during training task"
|
||||
type: object
|
||||
required: [ task ]
|
||||
properties {
|
||||
timestamp {
|
||||
description: "Epoch milliseconds UTC, will be set by the server if not set."
|
||||
type: number
|
||||
}
|
||||
type {
|
||||
description: "training_stats_vector"
|
||||
const: "training_stats_vector"
|
||||
}
|
||||
task {
|
||||
description: "Task ID (required)"
|
||||
type: string
|
||||
}
|
||||
iter {
|
||||
description: "Iteration"
|
||||
type: integer
|
||||
}
|
||||
metric {
|
||||
description: "Metric name, e.g. 'count', 'loss', 'accuracy'"
|
||||
type: string
|
||||
}
|
||||
variant {
|
||||
description: "E.g. 'class_1', 'total', 'average"
|
||||
type: string
|
||||
}
|
||||
values {
|
||||
description: "vector of float values"
|
||||
type: array
|
||||
items { type: number }
|
||||
}
|
||||
}
|
||||
}
|
||||
metrics_image_event {
|
||||
description: "An image or video was dumped to storage for debugging"
|
||||
type: object
|
||||
required: [ task, type ]
|
||||
properties {
|
||||
timestamp {
|
||||
description: "Epoch milliseconds UTC, will be set by the server if not set."
|
||||
type: number
|
||||
}
|
||||
type {
|
||||
description: ""
|
||||
const: "training_debug_image"
|
||||
}
|
||||
task {
|
||||
description: "Task ID (required)"
|
||||
type: string
|
||||
}
|
||||
iter {
|
||||
description: "Iteration"
|
||||
type: integer
|
||||
}
|
||||
metric {
|
||||
description: "Metric name, e.g. 'count', 'loss', 'accuracy'"
|
||||
type: string
|
||||
}
|
||||
variant {
|
||||
description: "E.g. 'class_1', 'total', 'average"
|
||||
type: string
|
||||
}
|
||||
key {
|
||||
description: "File key"
|
||||
type: string
|
||||
}
|
||||
url {
|
||||
description: "File URL"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
metrics_plot_event {
|
||||
description: """ An entire plot (not single datapoint) and it's layout.
|
||||
Used for plotting ROC curves, confidence matrices, etc. when evaluating the net."""
|
||||
type: object
|
||||
required: [ task, type ]
|
||||
properties {
|
||||
timestamp {
|
||||
description: "Epoch milliseconds UTC, will be set by the server if not set."
|
||||
type: number
|
||||
}
|
||||
type {
|
||||
description: "'plot'"
|
||||
const: "plot"
|
||||
}
|
||||
task {
|
||||
description: "Task ID (required)"
|
||||
type: string
|
||||
}
|
||||
iter {
|
||||
description: "Iteration"
|
||||
type: integer
|
||||
}
|
||||
metric {
|
||||
description: "Metric name, e.g. 'count', 'loss', 'accuracy'"
|
||||
type: string
|
||||
}
|
||||
variant {
|
||||
description: "E.g. 'class_1', 'total', 'average"
|
||||
type: string
|
||||
}
|
||||
plot_str {
|
||||
description: """An entire plot (not single datapoint) and it's layout.
|
||||
Used for plotting ROC curves, confidence matrices, etc. when evaluating the net.
|
||||
"""
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
log_level_enum {
|
||||
type: string
|
||||
enum: [
|
||||
notset
|
||||
debug
|
||||
verbose
|
||||
info
|
||||
warn
|
||||
warning
|
||||
error
|
||||
fatal
|
||||
critical
|
||||
]
|
||||
}
|
||||
task_log_event {
|
||||
description: """A log event associated with a task."""
|
||||
type: object
|
||||
required: [ task, type ]
|
||||
properties {
|
||||
timestamp {
|
||||
description: "Epoch milliseconds UTC, will be set by the server if not set."
|
||||
type: number
|
||||
}
|
||||
type {
|
||||
description: "'log'"
|
||||
const: "log"
|
||||
}
|
||||
task {
|
||||
description: "Task ID (required)"
|
||||
type: string
|
||||
}
|
||||
level {
|
||||
description: "Log level."
|
||||
"$ref": "#/definitions/log_level_enum"
|
||||
}
|
||||
worker {
|
||||
description: "Name of machine running the task."
|
||||
type: string
|
||||
}
|
||||
msg {
|
||||
description: "Log message."
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
add {
|
||||
"2.1" {
|
||||
description: "Adds a single event"
|
||||
request {
|
||||
type: object
|
||||
anyOf: [
|
||||
{ "$ref": "#/definitions/metrics_scalar_event" }
|
||||
{ "$ref": "#/definitions/metrics_vector_event" }
|
||||
{ "$ref": "#/definitions/metrics_image_event" }
|
||||
{ "$ref": "#/definitions/metrics_plot_event" }
|
||||
{ "$ref": "#/definitions/task_log_event" }
|
||||
]
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
}
|
||||
add_batch {
|
||||
"2.1" {
|
||||
description: "Adds a batch of events in a single call."
|
||||
batch_request: {
|
||||
action: add
|
||||
version: 1.5
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
added { type: integer }
|
||||
errors { type: integer }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
delete_for_task {
|
||||
"2.1" {
|
||||
description: "Delete all task event. *This cannot be undone!*"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
task
|
||||
]
|
||||
properties {
|
||||
task {
|
||||
type: string
|
||||
description: "Task ID"
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
deleted {
|
||||
type: boolean
|
||||
description: "Number of deleted events"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
debug_images {
|
||||
"2.1" {
|
||||
description: "Get all debug images of a task"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
task
|
||||
]
|
||||
properties {
|
||||
task {
|
||||
type: string
|
||||
description: "Task ID"
|
||||
}
|
||||
iters {
|
||||
type: integer
|
||||
description: "Max number of latest iterations for which to return debug images"
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID of previous call (used for getting more results)"
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
task {
|
||||
type: string
|
||||
description: "Task ID"
|
||||
}
|
||||
images {
|
||||
type: array
|
||||
items { type: object }
|
||||
description: "Images list"
|
||||
}
|
||||
returned {
|
||||
type: integer
|
||||
description: "Number of results returned"
|
||||
}
|
||||
total {
|
||||
type: number
|
||||
description: "Total number of results available for this query"
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID for getting more results"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_task_log {
|
||||
"1.5" {
|
||||
description: "Get all 'log' events for this task"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
task
|
||||
]
|
||||
properties {
|
||||
task {
|
||||
type: string
|
||||
description: "Task ID"
|
||||
}
|
||||
order {
|
||||
type: string
|
||||
description: "Timestamp order in which log events will be returned (defaults to ascending)"
|
||||
enum: [
|
||||
asc
|
||||
desc
|
||||
]
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID of previous call (used for getting more results)"
|
||||
}
|
||||
batch_size {
|
||||
type: integer
|
||||
description: "Number of events to return each time"
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
events {
|
||||
type: array
|
||||
# TODO: items: log event
|
||||
items { type: object }
|
||||
}
|
||||
returned { type: integer }
|
||||
total { type: integer }
|
||||
scroll_id { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
"1.7" {
|
||||
description: "Get all 'log' events for this task"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
task
|
||||
]
|
||||
properties {
|
||||
task {
|
||||
type: string
|
||||
description: "Task ID"
|
||||
}
|
||||
order {
|
||||
type: string
|
||||
description: "Timestamp order in which log events will be returned (defaults to ascending)"
|
||||
enum: [
|
||||
asc
|
||||
desc
|
||||
]
|
||||
}
|
||||
from {
|
||||
type: string
|
||||
description: "Where the log entries will be taken from (defaults to the head of the log)"
|
||||
enum: [
|
||||
head
|
||||
tail
|
||||
]
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID of previous call (used for getting more results)"
|
||||
}
|
||||
batch_size {
|
||||
type: integer
|
||||
description: "Number of events to return each time"
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
events {
|
||||
type: array
|
||||
# TODO: items: log event
|
||||
items { type: object }
|
||||
description: "Log items list"
|
||||
}
|
||||
returned {
|
||||
type: integer
|
||||
description: "Number of results returned"
|
||||
}
|
||||
total {
|
||||
type: number
|
||||
description: "Total number of results available for this query"
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID for getting more results"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_task_events {
|
||||
"2.1" {
|
||||
description: "Scroll through task events, sorted by timestamp"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
task
|
||||
]
|
||||
properties {
|
||||
task {
|
||||
type: string
|
||||
description: "Task ID"
|
||||
}
|
||||
order {
|
||||
type: string
|
||||
description: "'asc' (default) or 'desc'."
|
||||
enum: [
|
||||
asc
|
||||
desc
|
||||
]
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Pass this value on next call to get next page"
|
||||
}
|
||||
batch_size {
|
||||
type: integer
|
||||
description: "Number of events to return each time"
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
events {
|
||||
type: array
|
||||
items { type: object }
|
||||
description: "Events list"
|
||||
}
|
||||
returned {
|
||||
type: integer
|
||||
description: "Number of results returned"
|
||||
}
|
||||
total {
|
||||
type: number
|
||||
description: "Total number of results available for this query"
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID for getting more results"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
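The `scroll_id` fields defined above implement a simple pagination protocol: each response carries a scroll ID, and passing it on the next call returns the next page. A minimal client-side paging loop might look as follows; the `fetch` callable is an assumption standing in for whatever HTTP layer POSTs the body to `events.get_task_events` and returns the decoded JSON response:

```python
def iter_task_events(fetch, task, batch_size=500):
    """Page through events.get_task_events via its scroll_id protocol.

    `fetch` is any callable that POSTs the given request body to the
    endpoint and returns the decoded response dict, e.g.
    {"events": [...], "returned": N, "total": M, "scroll_id": "..."}.
    """
    scroll_id = None
    while True:
        body = {"task": task, "batch_size": batch_size}
        if scroll_id:
            # Pass the previous scroll_id to get the next page.
            body["scroll_id"] = scroll_id
        resp = fetch(body)
        events = resp.get("events", [])
        if not events:
            return  # an empty page means there are no more results
        yield from events
        scroll_id = resp.get("scroll_id")
```

The same loop shape applies to the other scrolling endpoints in this service (`debug_images`, `get_task_plots`, `get_task_log`), which share the same `scroll_id` semantics.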
|
||||
|
||||
download_task_log {
|
||||
"2.1" {
|
||||
description: "Get an attachment containing the task's log"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
task
|
||||
]
|
||||
properties {
|
||||
task {
|
||||
description: "Task ID"
|
||||
type: string
|
||||
}
|
||||
line_type {
|
||||
description: "Line format type"
|
||||
type: string
|
||||
enum: [
|
||||
json
|
||||
text
|
||||
]
|
||||
}
|
||||
line_format {
|
||||
type: string
|
||||
description: "Line string format. Used if the line type is 'text'"
|
||||
default: "{asctime} {worker} {level} {msg}"
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
get_task_plots {
|
||||
"2.1" {
|
||||
description: "Get all 'plot' events for this task"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
task
|
||||
]
|
||||
properties {
|
||||
task {
|
||||
type: string
|
||||
description: "Task ID"
|
||||
}
|
||||
iters {
|
||||
type: integer
|
||||
description: "Max number of latest iterations for which to return plots"
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID of previous call (used for getting more results)"
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
plots {
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
}
|
||||
description: "Plots list"
|
||||
}
|
||||
returned {
|
||||
type: integer
|
||||
description: "Number of results returned"
|
||||
}
|
||||
total {
|
||||
type: number
|
||||
description: "Total number of results available for this query"
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID for getting more results"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_multi_task_plots {
|
||||
"2.1" {
|
||||
description: "Get 'plot' events for the given tasks"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
tasks
|
||||
]
|
||||
properties {
|
||||
tasks {
|
||||
description: "List of task IDs"
|
||||
type: array
|
||||
items {
|
||||
type: string
|
||||
description: "Task ID"
|
||||
}
|
||||
}
|
||||
iters {
|
||||
type: integer
|
||||
description: "Max number of latest iterations for which to return plots"
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID of previous call (used for getting more results)"
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
plots {
|
||||
type: object
|
||||
description: "Plots mapping (keyed by task name)"
|
||||
}
|
||||
returned {
|
||||
type: integer
|
||||
description: "Number of results returned"
|
||||
}
|
||||
total {
|
||||
type: number
|
||||
description: "Total number of results available for this query"
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID for getting more results"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_vector_metrics_and_variants {
|
||||
"2.1" {
|
||||
description: "Get all vector metrics and variants for the task"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
task
|
||||
]
|
||||
properties {
|
||||
task {
|
||||
type: string
|
||||
description: "Task ID"
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
metrics {
|
||||
description: "Metrics list"
|
||||
type: array
|
||||
items: { type: object }
|
||||
# TODO: items: ???
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
vector_metrics_iter_histogram {
|
||||
"2.1" {
|
||||
description: "Get histogram data of all the vector metrics and variants in the task"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
task
|
||||
metric
|
||||
variant
|
||||
]
|
||||
properties {
|
||||
task {
|
||||
type: string
|
||||
description: "Task ID"
|
||||
}
|
||||
metric {
|
||||
type: string
|
||||
description: "Metric name"
|
||||
}
|
||||
variant {
|
||||
type: string
|
||||
description: "Variant name"
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
images {
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
scalar_metrics_iter_histogram {
|
||||
"2.1" {
|
||||
description: "Get histogram data of all the scalar metrics and variants in the task"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
task
|
||||
]
|
||||
properties {
|
||||
task {
|
||||
type: string
|
||||
description: "Task ID"
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
images {
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
multi_task_scalar_metrics_iter_histogram {
|
||||
"2.1" {
|
||||
description: "Used to compare scalar stats histogram of multiple tasks"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
tasks
|
||||
]
|
||||
properties {
|
||||
tasks {
|
||||
description: "List of task IDs"
|
||||
type: array
|
||||
items {
|
||||
type: string
|
||||
description: "Task ID"
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
// properties {}
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
}
|
||||
get_task_latest_scalar_values {
|
||||
"2.1" {
|
||||
description: "Get the task's latest scalar values"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
task
|
||||
]
|
||||
properties {
|
||||
task {
|
||||
type: string
|
||||
description: "Task ID"
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
metrics {
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
properties {
|
||||
name {
|
||||
type: string
|
||||
description: "Metric name"
|
||||
}
|
||||
variants {
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
properties {
|
||||
name {
|
||||
type: string
|
||||
description: "Variant name"
|
||||
}
|
||||
last_value {
|
||||
type: number
|
||||
description: "Last reported value"
|
||||
}
|
||||
last_100_value {
|
||||
type: number
|
||||
description: "Average of 100 last reported values"
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_scalar_metrics_and_variants {
|
||||
"2.1" {
|
||||
description: "Get the task's scalar metrics and variants"
|
||||
request {
|
||||
type: object
|
||||
required: [ task ]
|
||||
properties {
|
||||
task {
|
||||
description: "Task ID"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
metrics {
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_scalar_metric_data {
|
||||
"2.1" {
|
||||
description: "get scalar metric data for task"
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
task {
|
||||
type: string
|
||||
description: "Task ID"
|
||||
}
|
||||
metric {
|
||||
type: string
|
||||
description: "Type of metric"
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
events {
|
||||
description: "task scalar metric events"
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
}
|
||||
}
|
||||
returned {
|
||||
type: integer
|
||||
description: "Number of events returned"
|
||||
}
|
||||
total {
|
||||
type: integer
|
||||
description: "Total number of events in the task"
|
||||
}
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID of previous call (used for getting more results)"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
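As a concrete illustration of the request schemas above, a minimal `events.get_task_log` ("1.7") request body can be built as follows. The task ID is hypothetical, and against a live trains-server the body would be POSTed to the API server's `events.get_task_log` endpoint:

```python
import json

# Hypothetical task ID, for illustration only.
body = {
    "task": "4d9d0f1e22e84fb3a5c1c8f5f3a1b2c3",
    "order": "asc",     # enum: asc | desc
    "from": "tail",     # enum: head | tail (available from "1.7")
    "batch_size": 100,  # number of log events per page
}
print(json.dumps(body, sort_keys=True))
```

Subsequent pages would add the `scroll_id` value returned by the previous call.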
|
||||
server/schema/services/models.conf (new file)
@@ -0,0 +1,639 @@
|
||||
{
|
||||
_description: """This service provides a management interface for models (results of training tasks) stored in the system."""
|
||||
_definitions {
|
||||
multi_field_pattern_data {
|
||||
type: object
|
||||
properties {
|
||||
pattern {
|
||||
description: "Pattern string (regex)"
|
||||
type: string
|
||||
}
|
||||
fields {
|
||||
description: "List of field names"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
model {
|
||||
type: object
|
||||
properties {
|
||||
id {
|
||||
description: "Model id"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Model name"
|
||||
type: string
|
||||
}
|
||||
user {
|
||||
description: "Associated user id"
|
||||
type: string
|
||||
}
|
||||
company {
|
||||
description: "Company id"
|
||||
type: string
|
||||
}
|
||||
created {
|
||||
description: "Model creation time"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
task {
|
||||
description: "Task ID of task in which the model was created"
|
||||
type: string
|
||||
}
|
||||
parent {
|
||||
description: "Parent model ID"
|
||||
type: string
|
||||
}
|
||||
project {
|
||||
description: "Associated project ID"
|
||||
type: string
|
||||
}
|
||||
comment {
|
||||
description: "Model comment"
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
type: array
|
||||
description: "Tags"
|
||||
items { type: string }
|
||||
}
|
||||
framework {
|
||||
description: "Framework on which the model is based. Should be identical to the framework of the task which created the model"
|
||||
type: string
|
||||
}
|
||||
design {
|
||||
description: "Json object representing the model design. Should be identical to the network design of the task which created the model"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
labels {
|
||||
description: "Json object representing the ids of the labels in the model. The keys are the layers' names and the values are the ids."
|
||||
type: object
|
||||
additionalProperties { type: integer }
|
||||
}
|
||||
uri {
|
||||
description: "URI for the model, pointing to the destination storage."
|
||||
type: string
|
||||
}
|
||||
ready {
|
||||
description: "Indication if the model is final and can be used by other tasks"
|
||||
type: boolean
|
||||
}
|
||||
ui_cache {
|
||||
description: "UI cache for this model"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
get_by_id {
|
||||
"2.1" {
|
||||
description: "Gets model information"
|
||||
request {
|
||||
type: object
|
||||
required: [ model ]
|
||||
properties {
|
||||
model {
|
||||
description: "Model id"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
model {
|
||||
description: "Model info"
|
||||
"$ref": "#/definitions/model"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
get_by_task_id {
|
||||
"2.1" {
|
||||
description: "Gets model information"
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
task {
|
||||
description: "Task id"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
model {
|
||||
description: "Model info"
|
||||
"$ref": "#/definitions/model"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all_ex {
|
||||
internal: true
|
||||
"2.1": ${get_all."2.1"}
|
||||
}
|
||||
get_all {
|
||||
"2.1" {
|
||||
description: "Get all models"
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
name {
|
||||
description: "Get only models whose name matches this pattern (python regular expression syntax)"
|
||||
type: string
|
||||
}
|
||||
ready {
|
||||
description: "Indication whether to retrieve only models that are marked ready. If not supplied, returns both ready and not-ready models."
|
||||
type: boolean
|
||||
}
|
||||
tags {
|
||||
description: "Tags list used to filter results. Prepend '-' to tag name to indicate exclusion"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
only_fields {
|
||||
description: "List of model field names (if applicable, nesting is supported using '.'). If provided, this list defines the query's projection (only these fields will be returned for each result entry)"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
page {
|
||||
description: "Page number, returns a specific page out of the resulting list of models"
|
||||
type: integer
|
||||
minimum: 0
|
||||
}
|
||||
page_size {
|
||||
description: "Page size, specifies the number of results returned in each page (last page may contain fewer results)"
|
||||
type: integer
|
||||
minimum: 1
|
||||
}
|
||||
project {
|
||||
description: "List of associated project IDs"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
order_by {
|
||||
description: "List of field names to order by. When search_text is used, '@text_score' can be used as a field representing the text score of returned documents. Use '-' prefix to specify descending order. Optional, recommended when using page"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
task {
|
||||
description: "List of associated task IDs"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
id {
|
||||
description: "List of model IDs"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
search_text {
|
||||
description: "Free text search query"
|
||||
type: string
|
||||
}
|
||||
framework {
|
||||
description: "List of frameworks"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
uri {
|
||||
description: "List of model URIs"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
_all_ {
|
||||
description: "Multi-field pattern condition (all fields match pattern)"
|
||||
"$ref": "#/definitions/multi_field_pattern_data"
|
||||
}
|
||||
_any_ {
|
||||
description: "Multi-field pattern condition (any field matches pattern)"
|
||||
"$ref": "#/definitions/multi_field_pattern_data"
|
||||
}
|
||||
}
|
||||
dependencies {
|
||||
page: [ page_size ]
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
models: {
|
||||
description: "Models list"
|
||||
type: array
|
||||
items { "$ref": "#/definitions/model" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
update_for_task {
|
||||
"2.1" {
|
||||
description: "Create or update a new model for a task"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
task
|
||||
]
|
||||
properties {
|
||||
task {
|
||||
description: "Task id"
|
||||
type: string
|
||||
}
|
||||
uri {
|
||||
description: "URI for the model"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Model name. Unique within the company."
|
||||
type: string
|
||||
}
|
||||
comment {
|
||||
description: "Model comment"
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
description: "Tags list"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
override_model_id {
|
||||
description: "Override model ID. If provided, this model is updated in the task."
|
||||
type: string
|
||||
}
|
||||
iteration {
|
||||
description: "Iteration (used to update task statistics)"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
id {
|
||||
description: "ID of the model"
|
||||
type: string
|
||||
}
|
||||
created {
|
||||
description: "Was the model created"
|
||||
type: boolean
|
||||
}
|
||||
updated {
|
||||
description: "Number of models updated (0 or 1)"
|
||||
type: integer
|
||||
}
|
||||
fields {
|
||||
description: "Updated fields names and values"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
create {
|
||||
"2.1" {
|
||||
description: "Create a new model not associated with a task"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
uri
|
||||
name
|
||||
labels
|
||||
]
|
||||
properties {
|
||||
uri {
|
||||
description: "URI for the model"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Model name. Unique within the company."
|
||||
type: string
|
||||
}
|
||||
comment {
|
||||
description: "Model comment"
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
description: "Tags list"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
framework {
|
||||
description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model."
|
||||
type: string
|
||||
}
|
||||
design {
|
||||
description: "JSON object representing the model design. Should be identical to the network design of the task which created the model"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
labels {
|
||||
description: "JSON object representing the ids of the labels in the model. The keys are the layers' names and the values are the ids."
|
||||
type: object
|
||||
additionalProperties { type: integer }
|
||||
}
|
||||
ready {
|
||||
description: "Indication if the model is final and can be used by other tasks. Default is false."
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
public {
|
||||
description: "Create a public model. Default is false."
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
project {
|
||||
description: "Project to which the model belongs"
|
||||
type: string
|
||||
}
|
||||
parent {
|
||||
description: "Parent model"
|
||||
type: string
|
||||
}
|
||||
task {
|
||||
description: "Associated task ID"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
id {
|
||||
description: "ID of the model"
|
||||
type: string
|
||||
}
|
||||
created {
|
||||
description: "Was the model created"
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
edit {
|
||||
"2.1" {
|
||||
description: "Edit an existing model"
|
||||
request {
|
||||
type: object
|
||||
required: [
|
||||
model
|
||||
]
|
||||
properties {
|
||||
model {
|
||||
description: "Model ID"
|
||||
type: string
|
||||
}
|
||||
uri {
|
||||
description: "URI for the model"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Model name. Unique within the company."
|
||||
type: string
|
||||
}
|
||||
comment {
|
||||
description: "Model comment"
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
description: "Tags list"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
framework {
|
||||
description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model."
|
||||
type: string
|
||||
}
|
||||
design {
|
||||
description: "JSON object representing the model design. Should be identical to the network design of the task which created the model"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
labels {
|
||||
description: "JSON object representing the ids of the labels in the model. The keys are the layers' names and the values are the ids."
|
||||
type: object
|
||||
additionalProperties { type: integer }
|
||||
}
|
||||
ready {
|
||||
description: "Indication if the model is final and can be used by other tasks"
|
||||
type: boolean
|
||||
}
|
||||
project {
|
||||
description: "Project to which the model belongs"
|
||||
type: string
|
||||
}
|
||||
parent {
|
||||
description: "Parent model"
|
||||
type: string
|
||||
}
|
||||
task {
|
||||
description: "Associated task ID"
|
||||
type: string
|
||||
}
|
||||
iteration {
|
||||
description: "Iteration (used to update task statistics)"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of models updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [0, 1]
|
||||
}
|
||||
fields {
|
||||
description: "Updated fields names and values"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
update {
|
||||
"2.1" {
|
||||
description: "Update a model"
|
||||
request {
|
||||
type: object
|
||||
required: [ model ]
|
||||
properties {
|
||||
model {
|
||||
description: "Model id"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Model name. Unique within the company."
|
||||
type: string
|
||||
}
|
||||
comment {
|
||||
description: "Model comment"
|
||||
type: string
|
||||
}
|
||||
tags {
|
||||
description: "Tags list"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
ready {
|
||||
description: "Indication if the model is final and can be used by other tasks. Default is false."
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
created {
|
||||
description: "Model creation time (UTC)"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
ui_cache {
|
||||
description: "UI cache for this model"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
project {
|
||||
description: "Project to which the model belongs"
|
||||
type: string
|
||||
}
|
||||
task {
|
||||
description: "Associated task ID"
|
||||
type: "string"
|
||||
}
|
||||
iteration {
|
||||
description: "Iteration (used to update task statistics if an associated task is reported)"
|
||||
type: integer
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of models updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [0, 1]
|
||||
}
|
||||
fields {
|
||||
description: "Updated fields names and values"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
set_ready {
|
||||
"2.1" {
|
||||
description: "Set the model ready flag to True. If the model is an output model of a task then try to publish the task."
|
||||
request {
|
||||
type: object
|
||||
required: [ model ]
|
||||
properties {
|
||||
model {
|
||||
description: "Model id"
|
||||
type: string
|
||||
}
|
||||
force_publish_task {
|
||||
description: "Publish the associated task (if exists) even if it is not in the 'stopped' state. Optional, the default value is False."
|
||||
type: boolean
|
||||
}
|
||||
publish_task {
|
||||
description: "Indicates that the associated task (if exists) should be published. Optional, the default value is True."
|
||||
type: boolean
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
updated {
|
||||
description: "Number of models updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [0, 1]
|
||||
}
|
||||
published_task {
|
||||
description: "Result of publishing of the model's associated task (if exists). Returned only if the task was published successfully as part of the model publishing."
|
||||
type: object
|
||||
properties {
|
||||
id {
|
||||
description: "Task id"
|
||||
type: string
|
||||
}
|
||||
data {
|
||||
description: "Data returned from the task publishing operation."
|
||||
type: object
|
||||
properties {
|
||||
committed_versions_results {
|
||||
description: "Committed versions results"
|
||||
type: array
|
||||
items {
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
updated {
|
||||
description: "Number of tasks updated (0 or 1)"
|
||||
type: integer
|
||||
enum: [ 0, 1 ]
|
||||
}
|
||||
fields {
|
||||
description: "Updated fields names and values"
|
||||
type: object
|
||||
additionalProperties: true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
delete {
|
||||
"2.1" {
|
||||
description: "Delete a model."
|
||||
request {
|
||||
required: [
|
||||
model
|
||||
]
|
||||
type: object
|
||||
properties {
|
||||
model {
|
||||
description: "Model ID"
|
||||
type: string
|
||||
}
|
||||
force {
|
||||
description: """Force. Required if there are tasks that use the model as an execution model, or if the model's creating task is published.
|
||||
"""
|
||||
type: boolean
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
deleted {
|
||||
description: "Indicates whether the model was deleted"
|
||||
type: boolean
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
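To make the `create` schema above concrete, a request body satisfying its required fields (`uri`, `name`, `labels`) might be built as follows. All values are illustrative assumptions, not real resources:

```python
import json

body = {
    # Required fields per the create ("2.1") schema:
    "uri": "s3://example-bucket/models/model.pkl",  # hypothetical storage URI
    "name": "example-model",
    "labels": {"background": 0, "person": 1},       # label name -> integer ID
    # Optional fields shown with their schema defaults:
    "ready": False,
    "public": False,
}
print(json.dumps(body, sort_keys=True))
```

On success the server would answer with the new model's `id` and a `created` flag, per the response schema above.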
|
||||
server/schema/services/projects.conf (new file)
@@ -0,0 +1,463 @@
|
||||
{
|
||||
_description: "Provides support for defining Projects containing Tasks, Models and Dataset Versions."
|
||||
_definitions {
|
||||
multi_field_pattern_data {
|
||||
type: object
|
||||
properties {
|
||||
pattern {
|
||||
description: "Pattern string (regex)"
|
||||
type: string
|
||||
}
|
||||
fields {
|
||||
description: "List of field names"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
}
|
||||
}
|
||||
project_tags_enum {
|
||||
type: string
|
||||
enum: [ archived, public, default ]
|
||||
}
|
||||
project {
|
||||
type: object
|
||||
properties {
|
||||
id {
|
||||
description: "Project id"
|
||||
type: string
|
||||
}
|
||||
name {
|
||||
description: "Project name"
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "Project description"
|
||||
type: string
|
||||
}
|
||||
user {
|
||||
description: "Associated user id"
|
||||
type: string
|
||||
}
|
||||
company {
|
||||
description: "Company id"
|
||||
type: string
|
||||
}
|
||||
created {
|
||||
description: "Creation time"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
tags {
|
||||
description: "Tags"
|
||||
type: array
|
||||
items { "$ref": "#/definitions/project_tags_enum" }
|
||||
}
|
||||
default_output_destination {
|
||||
description: "The default output destination URL for new tasks under this project"
|
||||
type: string
|
||||
}
|
||||
last_update {
|
||||
description: """Last project update time. Reflects the last time the project metadata was changed or a task in this project has changed status"""
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
}
|
||||
}
|
||||
stats_status_count {
|
||||
type: object
|
||||
properties {
|
||||
total_runtime {
|
||||
description: "Total run time of all tasks in project (in seconds)"
|
||||
type: integer
|
||||
}
|
||||
status_count {
|
||||
description: "Status counts"
|
||||
type: object
|
||||
properties {
|
||||
created {
|
||||
description: "Number of 'created' tasks in project"
|
||||
type: integer
|
||||
}
|
||||
queued {
|
||||
description: "Number of 'queued' tasks in project"
|
||||
type: integer
|
||||
}
|
||||
in_progress {
|
||||
description: "Number of 'in_progress' tasks in project"
|
||||
type: integer
|
||||
}
|
||||
stopped {
|
||||
description: "Number of 'stopped' tasks in project"
|
||||
type: integer
|
||||
}
|
||||
published {
|
||||
description: "Number of 'published' tasks in project"
|
||||
type: integer
|
||||
}
|
||||
closed {
|
||||
                        description: "Number of 'closed' tasks in project"
                        type: integer
                    }
                    failed {
                        description: "Number of 'failed' tasks in project"
                        type: integer
                    }
                    unknown {
                        description: "Number of 'unknown' tasks in project"
                        type: integer
                    }
                }
            }
        }
    }
    stats {
        type: object
        properties {
            active {
                description: "Stats for active tasks"
                "$ref": "#/definitions/stats_status_count"
            }
            archived {
                description: "Stats for archived tasks"
                "$ref": "#/definitions/stats_status_count"
            }
        }
    }
    projects_get_all_response_single {
        // copy-paste from project definition
        type: object
        properties {
            id {
                description: "Project id"
                type: string
            }
            name {
                description: "Project name"
                type: string
            }
            description {
                description: "Project description"
                type: string
            }
            user {
                description: "Associated user id"
                type: string
            }
            company {
                description: "Company id"
                type: string
            }
            created {
                description: "Creation time"
                type: string
                format: "date-time"
            }
            tags {
                description: "Tags"
                type: array
                items { "$ref": "#/definitions/project_tags_enum" }
            }
            default_output_destination {
                description: "The default output destination URL for new tasks under this project"
                type: string
            }
            // extra properties
            stats {
                description: "Additional project stats"
                "$ref": "#/definitions/stats"
            }
        }
    }
    metric_variant_result {
        type: object
        properties {
            metric {
                description: "Metric name"
                type: string
            }
            metric_hash {
                description: """Metric name hash. Used instead of the metric name when categorizing
                    last metrics events in task objects."""
                type: string
            }
            variant {
                description: "Variant name"
                type: string
            }
            variant_hash {
                description: """Variant name hash. Used instead of the variant name when categorizing
                    last metrics events in task objects."""
                type: string
            }
        }
    }
}

create {
    "2.1" {
        description: "Create a new project"
        request {
            type: object
            required: [
                name
                description
            ]
            properties {
                name {
                    description: "Project name. Unique within the company."
                    type: string
                }
                description {
                    description: "Project description"
                    type: string
                }
                tags {
                    description: "Tags"
                    type: array
                    items { "$ref": "#/definitions/project_tags_enum" }
                }
                default_output_destination {
                    description: "The default output destination URL for new tasks under this project"
                    type: string
                }
            }
        }
        response {
            type: object
            properties {
                id {
                    description: "Project id"
                    type: string
                }
            }
        }
    }
}
get_by_id {
    "2.1" {
        description: "Gets project information"
        request {
            type: object
            required: [ project ]
            properties {
                project {
                    description: "Project id"
                    type: string
                }
            }
        }
        response {
            type: object
            properties {
                project {
                    description: "Project info"
                    "$ref": "#/definitions/project"
                }
            }
        }
    }
}
get_all {
    "2.1" {
        description: "Get all the company's projects and all public projects"
        request {
            type: object
            properties {
                id {
                    description: "List of IDs to filter by"
                    type: array
                    items { type: string }
                }
                name {
                    description: "Get only projects whose name matches this pattern (python regular expression syntax)"
                    type: string
                }
                description {
                    description: "Get only projects whose description matches this pattern (python regular expression syntax)"
                    type: string
                }
                tags {
                    description: "Tags list used to filter results. Prepend '-' to tag name to indicate exclusion"
                    type: array
                    items { type: string }
                }
                order_by {
                    description: "List of field names to order by. When search_text is used, '@text_score' can be used as a field representing the text score of returned documents. Use '-' prefix to specify descending order. Optional, recommended when using page"
                    type: array
                    items { type: string }
                }
                page {
                    description: "Page number, returns a specific page out of the resulting list of projects"
                    type: integer
                    minimum: 0
                }
                page_size {
                    description: "Page size, specifies the number of results returned in each page (last page may contain fewer results)"
                    type: integer
                    minimum: 1
                }
                search_text {
                    description: "Free text search query"
                    type: string
                }
                only_fields {
                    description: "List of document's field names (nesting is supported using '.', e.g. execution.model_labels). If provided, this list defines the query's projection (only these fields will be returned for each result entry)"
                    type: array
                    items { type: string }
                }
                _all_ {
                    description: "Multi-field pattern condition (all fields match pattern)"
                    "$ref": "#/definitions/multi_field_pattern_data"
                }
                _any_ {
                    description: "Multi-field pattern condition (any field matches pattern)"
                    "$ref": "#/definitions/multi_field_pattern_data"
                }
            }
        }
        response {
            type: object
            properties {
                projects {
                    description: "Projects list"
                    type: array
                    items { "$ref": "#/definitions/projects_get_all_response_single" }
                }
            }
        }
    }
}
get_all_ex {
    internal: true
    "2.1": ${get_all."2.1"} {
        request {
            properties {
                include_stats {
                    description: "If true, include project statistics in the response."
                    type: boolean
                    default: false
                }
                stats_for_state {
                    description: "Report stats only for tasks in the specified state. If null is provided, stats for all task states will be returned."
                    type: string
                    enum: [ active, archived ]
                    default: active
                }
            }
        }
    }
}
update {
    "2.1" {
        description: "Update project information"
        request {
            type: object
            required: [ project ]
            properties {
                project {
                    description: "Project id"
                    type: string
                }
                name {
                    description: "Project name. Unique within the company."
                    type: string
                }
                description {
                    description: "Project description"
                    type: string
                }
                tags {
                    description: "Tags list"
                    type: array
                    items { type: string }
                }
                default_output_destination {
                    description: "The default output destination URL for new tasks under this project"
                    type: string
                }
            }
        }
        response {
            type: object
            properties {
                updated {
                    description: "Number of projects updated (0 or 1)"
                    type: integer
                    enum: [ 0, 1 ]
                }
                fields {
                    description: "Updated fields names and values"
                    type: object
                    additionalProperties: true
                }
            }
        }
    }
}
delete {
    "2.1" {
        description: "Deletes a project"
        request {
            type: object
            required: [ project ]
            properties {
                project {
                    description: "Project id"
                    type: string
                }
                force {
                    description: """If not true, fails if the project has tasks.
                        If true, and the project has tasks, they will be unassigned"""
                    type: boolean
                    default: false
                }
            }
        }
        response {
            type: object
            properties {
                deleted {
                    description: "Number of projects deleted (0 or 1)"
                    type: integer
                }
                disassociated_tasks {
                    description: "Number of tasks disassociated from the deleted project"
                    type: integer
                }
            }
        }
    }
}
get_unique_metric_variants {
    "2.1" {
        description: """Get all metric/variant pairs reported for tasks in a specific project.
            If no project is specified, metric/variant pairs reported for all tasks will be returned.
            If the project does not exist, an empty list will be returned."""
        request {
            type: object
            properties {
                project {
                    description: "Project ID"
                    type: string
                }
            }
        }
        response {
            type: object
            properties {
                metrics {
                    description: "A list of metric variants reported for tasks in this project"
                    type: array
                    items { "$ref": "#/definitions/metric_variant_result" }
                }
            }
        }
    }
}
}
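As a quick illustration of the `projects.create` contract defined above, the request requires `name` and `description`. A minimal client-side payload check might look like the following sketch; `validate_create_payload` is a hypothetical helper for illustration, not part of the server code:

```python
# Required fields taken from the projects.create request schema above.
REQUIRED_CREATE_FIELDS = ("name", "description")


def validate_create_payload(payload: dict) -> None:
    """Raise ValueError if a projects.create payload is missing required fields."""
    missing = [f for f in REQUIRED_CREATE_FIELDS if f not in payload]
    if missing:
        raise ValueError(f"projects.create payload missing fields: {missing}")


# A valid payload passes silently; an incomplete one raises.
validate_create_payload({"name": "demo", "description": "example project"})
```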
1048 server/schema/services/tasks.conf (new file; diff suppressed because it is too large)
366 server/schema/services/users.conf (new file)
@@ -0,0 +1,366 @@
_description: """This service provides a management interface to users information
    and new users login restrictions."""
_default {
    internal: true
}
_definitions {
    user {
        type: object
        properties {
            id {
                description: "User ID"
                type: string
            }
            name {
                description: "Full name"
                type: string
            }
            given_name {
                description: "Given name"
                type: string
            }
            family_name {
                description: "Family name"
                type: string
            }
            avatar {
                description: "Avatar URL"
                type: string
            }
            company {
                description: "Company ID"
                type: string
            }
            # Admin only fields
            role {
                description: "User's role (admin only)"
                type: string
            }
            providers {
                description: "Providers the user has logged in with"
                type: object
                additionalProperties: true
            }
            created {
                description: "User creation date"
                type: string
                format: date-time
            }
            email {
                description: "User email"
                type: string
                format: email
            }
        }
    }
    get_current_user_response_user_object {
        type: object
        description: "Like user, but returns a company object instead of an ID"
        properties {
            id {
                type: string
            }
            name {
                type: string
            }
            given_name {
                type: string
            }
            family_name {
                type: string
            }
            role {
                type: string
            }
            avatar {
                type: string
            }
            company {
                type: object
                properties {
                    id {
                        type: string
                    }
                    name {
                        type: string
                    }
                }
            }
            preferences {
                description: "User preferences"
                type: object
                additionalProperties: true
            }
        }
    }
}

get_by_id {
    internal: false
    "2.1" {
        description: "Gets user information"
        request {
            type: object
            required: [ user ]
            properties {
                user {
                    description: "User ID"
                    type: string
                }
            }
        }
        response {
            type: object
            properties {
                user {
                    description: "User info"
                    "$ref": "#/definitions/user"
                }
            }
        }
    }
}

get_current_user {
    internal: false
    "2.1" {
        description: """Gets current user information, based on the authenticated user making the call."""
        request {
            type: object
            additionalProperties: false
        }
        response {
            type: object
            properties {
                user {
                    description: "User info"
                    "$ref": "#/definitions/get_current_user_response_user_object"
                }
            }
        }
    }
}

get_all_ex {
    internal: true
    "2.1": ${get_all."2.1"} {
    }
}

get_all {
    "2.1" {
        description: "Get all user objects"
        request {
            type: object
            properties {
                name {
                    description: "Get only users whose name matches this pattern (python regular expression syntax)"
                    type: string
                }
                id {
                    description: "List of user IDs used to filter results"
                    type: array
                    items { type: string }
                }
                only_fields {
                    description: "List of user field names (if applicable, nesting is supported using '.'). If provided, this list defines the query's projection (only these fields will be returned for each result entry)"
                    type: array
                    items { type: string }
                }
                page {
                    description: "Page number, returns a specific page out of the resulting list of users"
                    type: integer
                    minimum: 0
                }
                page_size {
                    description: "Page size, specifies the number of results returned in each page (last page may contain fewer results)"
                    type: integer
                    minimum: 1
                }
                order_by {
                    description: "List of field names to order by. Use '-' prefix to specify descending order. Optional, recommended when using page"
                    type: array
                    items { type: string }
                }
            }
        }
        response {
            type: object
            properties {
                users {
                    description: "User list"
                    type: array
                    items { "$ref": "#/definitions/user" }
                }
            }
        }
    }
}

delete {
    internal: true
    allow_roles: [ "system", "root" ]
    "2.1" {
        description: "Delete a user"
        request {
            type: object
            required: [ user ]
            properties {
                user {
                    description: "ID of the user to delete"
                    type: string
                }
            }
        }
        response {
            type: object
            additionalProperties: false
        }
    }
}

create {
    allow_roles: [ "system", "root" ]
    "2.1" {
        description: "Create a new user object. Reserved for internal use."
        request {
            type: object
            required: [
                company
                id
                name
            ]
            properties {
                id {
                    description: "User ID"
                    type: string
                }
                company {
                    description: "Company ID"
                    type: string
                }
                name {
                    description: "Full name"
                    type: string
                }
                given_name {
                    description: "Given name"
                    type: string
                }
                family_name {
                    description: "Family name"
                    type: string
                }
                avatar {
                    description: "Avatar URL"
                    type: string
                }
            }
        }
        response {
            type: object
            properties {}
            additionalProperties: false
        }
    }
}

update {
    internal: false
    "2.1" {
        description: "Update a user object"
        request {
            type: object
            required: [ user ]
            properties {
                user {
                    description: "User ID"
                    type: string
                }
                name {
                    description: "Full name"
                    type: string
                }
                given_name {
                    description: "Given name"
                    type: string
                }
                family_name {
                    description: "Family name"
                    type: string
                }
                avatar {
                    description: "Avatar URL"
                    type: string
                }
            }
        }
        response {
            type: object
            properties {
                updated {
                    description: "Number of updated user objects (0 or 1)"
                    type: integer
                }
                fields {
                    description: "Updated fields names and values"
                    type: object
                    additionalProperties: true
                }
            }
        }
    }
}

get_preferences {
    internal: false
    "2.1" {
        description: "Get user preferences"
        request {
            type: object
            properties {}
        }
        response {
            type: object
            properties {
                preferences {
                    type: object
                    additionalProperties: true
                }
            }
        }
    }
}

set_preferences {
    internal: false
    "2.1" {
        description: "Set user preferences"
        request {
            type: object
            required: [ preferences ]
            properties {
                preferences {
                    description: """Updates to user preferences. A mapping from keys in dot notation to values.
                        For example, `{"a.b": 0}` will set the key "b" in object "a" to 0."""
                    type: object
                    additionalProperties: true
                }
            }
        }
        response {
            type: object
            properties {
                updated {
                    description: "Number of updated user objects (0 or 1)"
                    type: integer
                }
                fields {
                    description: "Updated fields names and values"
                    type: object
                    additionalProperties: true
                }
            }
        }
    }
}
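The dot-notation semantics described for `set_preferences` above (a mapping like `{"a.b": 0}` setting key "b" inside object "a") can be sketched as a small helper. This is an illustrative sketch, not the server's implementation; `apply_dotted` is a hypothetical name:

```python
def apply_dotted(target: dict, updates: dict) -> dict:
    """Apply {'a.b': 0}-style updates: each dotted key walks (creating
    intermediate dicts as needed) and sets the leaf value."""
    for dotted_key, value in updates.items():
        *path, leaf = dotted_key.split(".")
        node = target
        for part in path:
            node = node.setdefault(part, {})
        node[leaf] = value
    return target


prefs = apply_dotted({}, {"a.b": 0})
assert prefs == {"a": {"b": 0}}
```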
184 server/server.py (new file)
@@ -0,0 +1,184 @@
from argparse import ArgumentParser

from flask import Flask, request, Response
from flask_compress import Compress
from flask_cors import CORS
from werkzeug.exceptions import BadRequest

import database
from apierrors.base import BaseError
from config import config
from service_repo import ServiceRepo, APICall
from service_repo.auth import AuthType
from service_repo.errors import PathParsingError
from timing_context import TimingContext
from utilities import json
from init_data import init_es_data, init_mongo_data

app = Flask(__name__, static_url_path="/static")
CORS(app, supports_credentials=True, **config.get("apiserver.cors"))
Compress(app)

log = config.logger(__file__)

log.info("################ API Server initializing #####################")

app.config["SECRET_KEY"] = config.get("secure.http.session_secret.apiserver")
app.config["JSONIFY_PRETTYPRINT_REGULAR"] = config.get("apiserver.pretty_json")

database.initialize()

init_es_data()
init_mongo_data()

ServiceRepo.load("services")
log.info(f"Exposed Services: {' '.join(ServiceRepo.endpoint_names())}")


@app.before_first_request
def before_app_first_request():
    pass


@app.before_request
def before_request():
    if request.method == "OPTIONS":
        return "", 200
    if "/static/" in request.path:
        return

    try:
        call = create_api_call(request)
        content, content_type = ServiceRepo.handle_call(call)
        headers = None
        if call.result.filename:
            headers = {
                "Content-Disposition": f"attachment; filename={call.result.filename}"
            }

        return Response(
            content, mimetype=content_type, status=call.result.code, headers=headers
        )
    except Exception as ex:
        log.exception(f"Failed processing request {request.url}: {ex}")
        return f"Failed processing request {request.url}", 500


def update_call_data(call, req):
    """ Use request payload/form to fill call data or batched data """
    if req.content_type == "application/json-lines":
        items = []
        for i, line in enumerate(req.data.splitlines()):
            try:
                event = json.loads(line)
                if not isinstance(event, dict):
                    raise BadRequest(
                        f"json lines must contain objects, found: {type(event).__name__}"
                    )
                items.append(event)
            except ValueError as e:
                msg = f"{e} in batch item #{i}"
                req.on_json_loading_failed(msg)
        call.batched_data = items
    else:
        json_body = req.get_json(force=True, silent=False) if req.data else None
        # merge form and args
        form = req.form.copy()
        form.update(req.args)
        form = form.to_dict()
        # convert string numbers to ints/floats and boolean strings to booleans
        for key in form:
            if form[key].replace(".", "", 1).isdigit():
                if "." in form[key]:
                    form[key] = float(form[key])
                else:
                    form[key] = int(form[key])
            elif form[key].lower() == "true":
                form[key] = True
            elif form[key].lower() == "false":
                form[key] = False
        # NOTE: dict() form data to make sure we won't pass along a MultiDict or some other nasty crap
        call.data = json_body or form or {}


def _call_or_empty_with_error(call, req, msg, code=500, subcode=0):
    call = call or APICall(
        "", remote_addr=req.remote_addr, headers=dict(req.headers), files=req.files
    )
    call.set_error_result(msg=msg, code=code, subcode=subcode)
    return call


def create_api_call(req):
    call = None
    try:
        # Parse the request path
        endpoint_version, endpoint_name = ServiceRepo.parse_endpoint_path(req.path)

        # Resolve authorization: if cookies contain an authorization token, use it as a starting point.
        # In any case, request headers always take precedence.
        auth_cookie = req.cookies.get(
            config.get("apiserver.auth.session_auth_cookie_name")
        )
        headers = (
            {}
            if not auth_cookie
            else {"Authorization": f"{AuthType.bearer_token} {auth_cookie}"}
        )
        headers.update(
            list(req.headers.items())
        )  # add (possibly override with) the headers

        # Construct call instance
        call = APICall(
            endpoint_name=endpoint_name,
            remote_addr=req.remote_addr,
            endpoint_version=endpoint_version,
            headers=headers,
            files=req.files,
        )

        # Update call data from request
        with TimingContext("preprocess", "update_call_data"):
            update_call_data(call, req)

    except PathParsingError as ex:
        call = _call_or_empty_with_error(call, req, ex.args[0], 400)
        call.log_api = False
    except BadRequest as ex:
        call = _call_or_empty_with_error(call, req, ex.description, 400)
    except BaseError as ex:
        call = _call_or_empty_with_error(call, req, ex.msg, ex.code, ex.subcode)
    except Exception as ex:
        log.exception("Error creating call")
        call = _call_or_empty_with_error(
            call, req, ex.args[0] if ex.args else type(ex).__name__, 500
        )

    return call


# =================== MAIN =======================
if __name__ == "__main__":
    p = ArgumentParser(description=__doc__)
    p.add_argument(
        "--port", "-p", type=int, default=config.get("apiserver.listen.port")
    )
    p.add_argument("--ip", "-i", type=str, default=config.get("apiserver.listen.ip"))
    p.add_argument(
        "--debug", action="store_true", default=config.get("apiserver.debug")
    )
    p.add_argument(
        "--watch", action="store_true", default=config.get("apiserver.watch")
    )
    args = p.parse_args()

    # logging.info("Starting API Server at %s:%s and env '%s'" % (args.ip, args.port, config.env))

    app.run(
        debug=args.debug,
        host=args.ip,
        port=args.port,
        threaded=True,
        use_reloader=args.watch,
    )
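The form-value coercion loop in `update_call_data` above can be isolated into a standalone helper for illustration. `coerce_form_value` is a hypothetical name, not part of the server code; it applies the same rules (numeric strings become int/float, "true"/"false" become booleans, everything else is kept as-is):

```python
def coerce_form_value(value: str):
    """Mimic the server's form coercion: numeric strings become
    int/float, 'true'/'false' become booleans, others pass through."""
    if value.replace(".", "", 1).isdigit():
        return float(value) if "." in value else int(value)
    if value.lower() == "true":
        return True
    if value.lower() == "false":
        return False
    return value


# "1.5" -> 1.5 (float), "3" -> 3 (int), "True" -> True, "abc" -> "abc"
print(coerce_form_value("1.5"), coerce_form_value("3"), coerce_form_value("True"))
```

Note that, like the original loop, this treats any string whose digits remain after removing a single "." as numeric, so a value such as "007" becomes the integer 7.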
52 server/service_repo/__init__.py (new file)
@@ -0,0 +1,52 @@
from typing import Text, Sequence, Callable, Union

from funcsigs import signature
from jsonmodels import models

from .apicall import APICall, APICallResult
from .endpoint import EndpointFunc, Endpoint
from .service_repo import ServiceRepo


__all__ = ["endpoint"]


LegacyEndpointFunc = Callable[[APICall], None]


def endpoint(
    name: Text,
    min_version: Text = "1.0",
    required_fields: Sequence[Text] = None,
    request_data_model: models.Base = None,
    response_data_model: models.Base = None,
    validate_schema=False,
):
    """ Endpoint decorator, used to declare a method as an endpoint handler """

    def decorator(f: Union[EndpointFunc, LegacyEndpointFunc]) -> EndpointFunc:
        # Backwards compatibility: support endpoints with both the old-style signature (call)
        # and the new-style signature (call, company, request_model)
        func = f
        if len(signature(f).parameters) == 1:
            # old-style
            def adapter(call, *_, **__):
                return f(call)

            func = adapter

        ServiceRepo.register(
            Endpoint(
                name=name,
                func=func,
                min_version=min_version,
                required_fields=required_fields,
                request_data_model=request_data_model,
                response_data_model=response_data_model,
                validate_schema=validate_schema,
            )
        )

        return func

    return decorator
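The backwards-compatibility trick above (wrapping single-argument handlers so they can be invoked with the newer three-argument convention) can be sketched in isolation, with the stdlib `inspect` module standing in for `funcsigs`. The names here (`adapt_handler`, `old_style`, `new_style`) are illustrative only:

```python
import inspect


def adapt_handler(f):
    """Wrap old-style handlers (call) so they can be invoked with the
    new-style signature (call, company, request_model)."""
    if len(inspect.signature(f).parameters) == 1:
        def adapter(call, *_, **__):
            return f(call)
        return adapter
    return f


def old_style(call):
    return f"old:{call}"


def new_style(call, company, request_model):
    return f"new:{call}:{company}"


# Both handlers can now be called with the same three-argument convention.
assert adapt_handler(old_style)("c", "comp", None) == "old:c"
assert adapt_handler(new_style)("c", "comp", None) == "new:c:comp"
```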
562 server/service_repo/apicall.py (new file)
@@ -0,0 +1,562 @@
import time
import types
from traceback import format_exc
from typing import Type, Optional

from jsonmodels import models
from six import string_types

import database
import timing_context
from timing_context import TimingContext
from utilities import json
from .auth import Identity
from .auth import Payload as AuthPayload
from .base import PartialVersion
from .errors import CallParsingError
from .schema_validator import SchemaValidator

JSON_CONTENT_TYPE = "application/json"


class DataContainer(object):
    """ Data container that supports raw data (dict or a list of batched dicts) and a data model """

    def __init__(self, data=None, batched_data=None):
        if data and batched_data:
            raise ValueError("data and batched data are not supported simultaneously")
        self._batched_data = None
        self._data = None
        self._data_model = None
        self._data_model_cls = None
        self._schema_validator: SchemaValidator = SchemaValidator(None)
        # use setter to properly initialize data
        self.data = data
        self.batched_data = batched_data
        self._raw_data = None
        self._content_type = JSON_CONTENT_TYPE

    @property
    def schema_validator(self):
        return self._schema_validator

    @schema_validator.setter
    def schema_validator(self, value):
        self._schema_validator = value
        self._update_data_model()

    @property
    def data(self):
        return self._data or {}

    @data.setter
    def data(self, value):
        """ Set the data using a raw dict. If a model cls is defined, validate the raw data """
        if value is not None:
            assert isinstance(value, dict), "Data should be a dict"
        self._data = value
        self._update_data_model()

    @property
    def batched_data(self):
        if self._batched_data is not None:
            return self._batched_data
        elif self.data != {}:
            return [self.data]
        else:
            return []

    @batched_data.setter
    def batched_data(self, value):
        if not value:
            return
        assert isinstance(value, (tuple, list)), "Batched data should be a list"
        self._batched_data = value
        self._update_data_model()

    @property
    def raw_data(self):
        return self._raw_data

    @raw_data.setter
    def raw_data(self, value):
        assert isinstance(
            value, string_types + (types.GeneratorType,)
        ), "Raw data must be a string type or generator"
        self._raw_data = value

    @property
    def content_type(self):
        return self._content_type

    @content_type.setter
    def content_type(self, value):
        self._content_type = value

    def _update_data_model(self):
        self.schema_validator.detailed_validate(self._data)

        cls = self.data_model_cls
        if not cls or (self._data is None and self._batched_data is None):
            return

        # handle batched items
        if self._batched_data:
            try:
                data_model = [cls(**item) for item in self._batched_data]
            except TypeError as ex:
                raise CallParsingError(str(ex))

            for m in data_model:
                m.validate()
        else:
            try:
                data_model = cls(**self.data)
            except TypeError as ex:
                raise CallParsingError(str(ex))

            if not self.schema_validator.enabled:
                data_model.validate()
        self._data_model = data_model

    @property
    def data_model(self):
        return self._data_model

    # @property
    # def get_partial_update(self, data_model_class):
    #     return {k: v for k, v in self.data_model.to_struct().iteritems() if k in self.data}
    @property
    def data_model_for_partial_update(self):
        """
        Return only the data model fields that were actually passed by the user
        """
        return {k: v for k, v in self.data_model.to_struct().items() if k in self.data}

    @data_model.setter
    def data_model(self, value):
        """ Set the data using a model instance. NOTE: batched_data is never updated. """
        cls = self.data_model_cls
        if not cls:
            raise ValueError("Data model is not defined")
        if isinstance(value, cls):
            # instance of the data model class - just take it
            self._data_model = value
        elif issubclass(cls, type(value)):
            # instance of a subclass of the data model class - create the expected class instance
            # and use the instance we received to initialize it
            self._data_model = cls(**value.to_struct())
        else:
            raise ValueError(f"Invalid data model (expecting {cls} or super classes)")
        self._data = value.to_struct()

    @property
    def data_model_cls(self) -> Optional[Type[models.Base]]:
        return self._data_model_cls

    @data_model_cls.setter
    def data_model_cls(self, value: Type[models.Base]):
        assert issubclass(value, models.Base)
        self._data_model_cls = value
        self._update_data_model()


class APICallResult(DataContainer):
    def __init__(self, data=None, code=200, subcode=0, msg="OK", traceback=""):
        super(APICallResult, self).__init__(data)
        self._code = code
        self._subcode = subcode
        self._msg = msg
        self._traceback = traceback
        self._extra = None
        self._filename = None

    def get_log_entry(self):
        res = dict(
            msg=self.msg,
            code=self.code,
            subcode=self.subcode,
            traceback=self._traceback,
            extra=self._extra,
        )
        if self.log_data:
            res["data"] = self.data
        return res

    def copy_from(self, result):
        self._code = result.code
        self._subcode = result.subcode
        self._msg = result.msg
        self._traceback = result.traceback
        self._extra = result.extra_log

    @property
    def msg(self):
        return self._msg

    @msg.setter
    def msg(self, value):
        self._msg = value

    @property
    def code(self):
        return self._code

    @code.setter
    def code(self, value):
        self._code = value

    @property
    def subcode(self):
        return self._subcode

    @subcode.setter
    def subcode(self, value):
        self._subcode = value

    @property
    def traceback(self):
        return self._traceback

    @traceback.setter
    def traceback(self, value):
        self._traceback = value

    @property
    def extra_log(self):
        """ Extra data to be logged into ES """
        return self._extra

    @extra_log.setter
    def extra_log(self, value):
        self._extra = value

    @property
    def filename(self):
        return self._filename

    @filename.setter
    def filename(self, value):
        self._filename = value


class MissingIdentity(Exception):
    pass


class APICall(DataContainer):
    HEADER_AUTHORIZATION = "Authorization"
    HEADER_REAL_IP = "X-Real-IP"
    """ Standard headers """

    _transaction_headers = ("X-Trains-Trx",)
    """ Transaction ID """

    @property
    def HEADER_TRANSACTION(self):
        return self._transaction_headers[0]

    _worker_headers = ("X-Trains-Worker",)
    """ Worker (machine) ID """

    @property
    def HEADER_WORKER(self):
        return self._worker_headers[0]

    _impersonate_as_headers = ("X-Trains-Impersonate-As",)
    """ Impersonate as someone else (using their identity and permissions) """

    @property
    def HEADER_IMPERSONATE_AS(self):
        return self._impersonate_as_headers[0]

    _act_as_headers = ("X-Trains-Act-As",)
    """ Act as someone else (using their identity, but with your own role and permissions) """

    @property
    def HEADER_ACT_AS(self):
        return self._act_as_headers[0]

    _async_headers = ("X-Trains-Async",)
    """ Specifies that this call should be done asynchronously """

    @property
    def HEADER_ASYNC(self):
        return self._async_headers[0]

    def __init__(
        self,
        endpoint_name,
        remote_addr=None,
        endpoint_version: PartialVersion = PartialVersion("1.0"),
        data=None,
        batched_data=None,
        headers=None,
        files=None,
        trx=None,
    ):
        super(APICall, self).__init__(data=data, batched_data=batched_data)

        timing_context.clear()

        self._id = database.utils.id()
        self._files = files  # currently a dict mapping key to flask's FileStorage
        self._start_ts = time.time()
        self._end_ts = 0
        self._duration = 0
        self._endpoint_name = endpoint_name
        self._remote_addr = remote_addr
        assert isinstance(endpoint_version, PartialVersion), endpoint_version
        self._requested_endpoint_version = endpoint_version
        self._actual_endpoint_version = None
        self._headers = {}
        self._kpis = {}
        self._log_api = True
        if headers:
            self._headers.update(headers)
        self._result = APICallResult()
        self._auth = None
        self._impersonation = None
        if trx:
            self.set_header(self._transaction_headers, trx)
        self._requires_authorization = True

    @property
    def id(self):
        return self._id

    @property
    def requires_authorization(self):
        return self._requires_authorization

    @requires_authorization.setter
    def requires_authorization(self, value):
        self._requires_authorization = value

    @property
    def log_api(self):
        return self._log_api

    @log_api.setter
    def log_api(self, value):
        self._log_api = value

    def assign_new_id(self):
        self._id = database.utils.id()

    def get_header(self, header, default=None):
|
||||
"""
|
||||
Get header value
|
||||
:param header: Header name options (more than on supported, listed by priority)
|
||||
:param default: Default value if no such headers were found
|
||||
"""
|
||||
for option in header if isinstance(header, (tuple, list)) else (header,):
|
||||
if option in self._headers:
|
||||
return self._headers[option]
|
||||
return default
|
||||
|
||||
def clear_header(self, header):
|
||||
"""
|
||||
Clear header value
|
||||
:param header: Header name options (more than on supported, all will be cleared)
|
||||
"""
|
||||
for value in header if isinstance(header, (tuple, list)) else (header,):
|
||||
self.headers.pop(value, None)
|
||||
|
||||
def set_header(self, header, value):
|
||||
"""
|
||||
Set header value
|
||||
:param header: header name (if a list is provided, first item is used)
|
||||
:param value: Value to set
|
||||
:return:
|
||||
"""
|
||||
self._headers[
|
||||
header[0] if isinstance(header, (tuple, list)) else header
|
||||
] = value
|
||||
|
||||
@property
|
||||
def real_ip(self):
|
||||
real_ip = self.get_header(self.HEADER_REAL_IP)
|
||||
return real_ip or self._remote_addr or "untrackable"
|
||||
|
||||
@property
|
||||
def failed(self):
|
||||
return self.result and self.result.code != 200
|
||||
|
||||
@property
|
||||
def duration(self):
|
||||
return self._duration
|
||||
|
||||
@property
|
||||
def endpoint_name(self):
|
||||
return self._endpoint_name
|
||||
|
||||
@property
|
||||
def requested_endpoint_version(self) -> PartialVersion:
|
||||
return self._requested_endpoint_version
|
||||
|
||||
@property
|
||||
def auth(self):
|
||||
""" Authenticated payload (Token or Basic) """
|
||||
return self._auth
|
||||
|
||||
@auth.setter
|
||||
def auth(self, value):
|
||||
if value:
|
||||
assert isinstance(value, AuthPayload)
|
||||
self._auth = value
|
||||
|
||||
@property
|
||||
def impersonation_headers(self):
|
||||
return {
|
||||
k: v
|
||||
for k, v in self._headers.items()
|
||||
if k in (self._impersonate_as_headers + self._act_as_headers)
|
||||
}
|
||||
|
||||
@property
|
||||
def impersonate_as(self):
|
||||
return self.get_header(self._impersonate_as_headers)
|
||||
|
||||
@property
|
||||
def act_as(self):
|
||||
return self.get_header(self._act_as_headers)
|
||||
|
||||
@property
|
||||
def impersonation(self):
|
||||
return self._impersonation
|
||||
|
||||
@impersonation.setter
|
||||
def impersonation(self, value):
|
||||
if value:
|
||||
assert isinstance(value, AuthPayload)
|
||||
self._impersonation = value
|
||||
|
||||
@property
|
||||
def identity(self) -> Identity:
|
||||
if self.impersonation:
|
||||
if not self.impersonation.identity:
|
||||
raise Exception("Missing impersonate identity")
|
||||
return self.impersonation.identity
|
||||
if self.auth:
|
||||
if not self.auth.identity:
|
||||
raise Exception("Missing authorized identity (not authorized?)")
|
||||
return self.auth.identity
|
||||
raise MissingIdentity("Missing identity")
|
||||
|
||||
@property
|
||||
def actual_endpoint_version(self):
|
||||
return self._actual_endpoint_version
|
||||
|
||||
@actual_endpoint_version.setter
|
||||
def actual_endpoint_version(self, value):
|
||||
self._actual_endpoint_version = value
|
||||
|
||||
@property
|
||||
def headers(self):
|
||||
return self._headers
|
||||
|
||||
@property
|
||||
def kpis(self):
|
||||
"""
|
||||
Key Performance Indicators, holding things like number of returned frames/rois, etc.
|
||||
:return:
|
||||
"""
|
||||
return self._kpis
|
||||
|
||||
@property
|
||||
def trx(self):
|
||||
return self.get_header(self._transaction_headers, self.id)
|
||||
|
||||
@trx.setter
|
||||
def trx(self, value):
|
||||
self.set_header(self._transaction_headers, value)
|
||||
|
||||
@property
|
||||
def worker(self):
|
||||
return self.get_header(self._worker_headers, "<unknown>")
|
||||
|
||||
@property
|
||||
def authorization(self):
|
||||
""" Call authorization data used to authenticate the call """
|
||||
return self.get_header(self.HEADER_AUTHORIZATION)
|
||||
|
||||
@property
|
||||
def result(self):
|
||||
return self._result
|
||||
|
||||
@property
|
||||
def exec_async(self):
|
||||
return self.get_header(self._async_headers) is not None
|
||||
|
||||
@exec_async.setter
|
||||
def exec_async(self, value):
|
||||
if value:
|
||||
self.set_header(self._async_headers, "1")
|
||||
else:
|
||||
self.clear_header(self._async_headers)
|
||||
|
||||
def mark_end(self):
|
||||
self._end_ts = time.time()
|
||||
self._duration = int((self._end_ts - self._start_ts) * 1000)
|
||||
self.stats = timing_context.stats()
|
||||
|
||||
def get_response(self):
|
||||
def make_version_number(version):
|
||||
"""
|
||||
Client versions <=2.0 expect expect endpoint versions in float format, otherwise throwing an exception
|
||||
"""
|
||||
if version is None:
|
||||
return None
|
||||
if self.requested_endpoint_version < PartialVersion("2.1"):
|
||||
return float(str(version))
|
||||
return str(version)
|
||||
|
||||
if self.result.raw_data and not self.failed:
|
||||
# endpoint returned raw data and no error was detected, return raw data, no fancy dicts
|
||||
return self.result.raw_data, self.result.content_type
|
||||
|
||||
else:
|
||||
res = {
|
||||
"meta": {
|
||||
"id": self.id,
|
||||
"trx": self.trx,
|
||||
"endpoint": {
|
||||
"name": self.endpoint_name,
|
||||
"requested_version": make_version_number(
|
||||
self.requested_endpoint_version
|
||||
),
|
||||
"actual_version": make_version_number(
|
||||
self.actual_endpoint_version
|
||||
),
|
||||
},
|
||||
"result_code": self.result.code,
|
||||
"result_subcode": self.result.subcode,
|
||||
"result_msg": self.result.msg,
|
||||
"error_stack": self.result.traceback,
|
||||
},
|
||||
"data": self.result.data,
|
||||
}
|
||||
if self.content_type.lower() == JSON_CONTENT_TYPE:
|
||||
with TimingContext("json", "serialization"):
|
||||
try:
|
||||
res = json.dumps(res)
|
||||
except Exception as ex:
|
||||
# JSON serialization may fail, probably problem with data so pop it and try again
|
||||
if not self.result.data:
|
||||
raise
|
||||
self.result.data = None
|
||||
msg = "Error serializing response data: " + str(ex)
|
||||
self.set_error_result(
|
||||
code=500, subcode=0, msg=msg, include_stack=False
|
||||
)
|
||||
return self.get_response()
|
||||
|
||||
return res, self.content_type
|
||||
|
||||
def set_error_result(self, msg, code=500, subcode=0, include_stack=False):
|
||||
tb = format_exc() if include_stack else None
|
||||
self._result = APICallResult(
|
||||
data=self._result.data, code=code, subcode=subcode, msg=msg, traceback=tb
|
||||
)
|
||||
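The priority-ordered header lookup implemented by `APICall.get_header` can be sketched standalone (a minimal, hypothetical rework outside the class, using a plain dict in place of the call's header store):

```python
def get_header(headers, header, default=None):
    # Mirror APICall.get_header: 'header' may be a single name or a
    # tuple/list of names listed by priority; return the first one present.
    for option in header if isinstance(header, (tuple, list)) else (header,):
        if option in headers:
            return headers[option]
    return default


headers = {"X-Trains-Trx": "trx-123", "Authorization": "Bearer abc"}

# first option is missing, second one is found
assert get_header(headers, ("X-Trains-Worker", "X-Trains-Trx")) == "trx-123"

# nothing found -> default, matching the real_ip fallback pattern above
assert get_header(headers, "X-Real-IP", "untrackable") == "untrackable"
```

This is why header constants such as `_transaction_headers` are tuples: older and newer header names can coexist, with the first tuple item used when writing.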
4
server/service_repo/auth/__init__.py
Normal file
@@ -0,0 +1,4 @@
from .auth import get_auth_func, authorize_impersonation
from .payload import Token, Basic, AuthType, Payload
from .identity import Identity
from .utils import get_client_id, get_secret_key
86
server/service_repo/auth/auth.py
Normal file
@@ -0,0 +1,86 @@
import base64
import jwt

from database.errors import translate_errors_context
from database.model.company import Company
from database.utils import get_options
from database.model.auth import User, Entities, Credentials
from apierrors import errors
from config import config
from timing_context import TimingContext

from .payload import Payload, Token, Basic, AuthType
from .identity import Identity


log = config.logger(__file__)

entity_keys = set(get_options(Entities))

verify_user_tokens = config.get("apiserver.auth.verify_user_tokens", True)


def get_auth_func(auth_type):
    if auth_type == AuthType.bearer_token:
        return authorize_token
    elif auth_type == AuthType.basic:
        return authorize_credentials
    raise errors.unauthorized.BadAuthType()


def authorize_token(jwt_token, *_, **__):
    """ Validate token against service/endpoint and request data (dicts).
        Returns a parsed token object (auth payload)
    """
    try:
        return Token.from_encoded_token(jwt_token)

    except jwt.exceptions.InvalidKeyError as ex:
        raise errors.unauthorized.InvalidToken('jwt invalid key error', reason=ex.args[0])
    except jwt.InvalidTokenError as ex:
        raise errors.unauthorized.InvalidToken('invalid jwt token', reason=ex.args[0])
    except ValueError as ex:
        log.exception('Failed while processing token: %s' % ex.args[0])
        raise errors.unauthorized.InvalidToken('failed processing token', reason=ex.args[0])


def authorize_credentials(auth_data, service, action, call_data_items):
    """ Validate credentials against service/action and request data (dicts).
        Returns a new basic object (auth payload)
    """
    try:
        access_key, _, secret_key = base64.b64decode(auth_data.encode()).decode('latin-1').partition(':')
    except Exception as e:
        log.exception('malformed credentials')
        raise errors.unauthorized.BadCredentials(str(e))

    with TimingContext("mongo", "user_by_cred"), translate_errors_context('authorizing request'):
        user = User.objects(credentials__match=Credentials(key=access_key, secret=secret_key)).first()

    if not user:
        raise errors.unauthorized.InvalidCredentials('failed to locate provided credentials')

    with TimingContext("mongo", "company_by_id"):
        company = Company.objects(id=user.company).only('id', 'name').first()

    if not company:
        raise errors.unauthorized.InvalidCredentials('invalid user company')

    identity = Identity(user=user.id, company=user.company, role=user.role,
                        user_name=user.name, company_name=company.name)

    basic = Basic(user_key=access_key, identity=identity)

    return basic


def authorize_impersonation(user, identity, service, action, call_data_items):
    """ Returns a new basic object (auth payload) """
    if not user:
        raise ValueError('missing user')

    company = Company.objects(id=user.company).only('id', 'name').first()
    if not company:
        raise errors.unauthorized.InvalidCredentials('invalid user company')

    return Payload(auth_type=None, identity=identity)
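The credential-decoding step in `authorize_credentials` above is plain HTTP Basic auth: the header payload is base64 of `"<access_key>:<secret_key>"`, split on the first `:` only. A standalone round trip with hypothetical keys:

```python
import base64

# encode as a client would: "<access_key>:<secret_key>", then base64
auth_data = base64.b64encode(b"my-access-key:my-secret").decode()

# the same decode/partition used by authorize_credentials;
# partition(':') splits on the FIRST colon, so secrets may contain ':'
access_key, _, secret_key = (
    base64.b64decode(auth_data.encode()).decode("latin-1").partition(":")
)

assert access_key == "my-access-key"
assert secret_key == "my-secret"
```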
24
server/service_repo/auth/dictable.py
Normal file
@@ -0,0 +1,24 @@
class Dictable(object):
    _cached_props = None

    @classmethod
    def _get_cached_props(cls):
        if cls._cached_props is None:
            props = set()
            for c in cls.mro():
                props.update(k for k, v in vars(c).items() if isinstance(v, property) and not k.startswith('_'))
            cls._cached_props = list(props)
        return cls._cached_props

    def to_dict(self, **extra):
        props = self._get_cached_props()
        d = {k: getattr(self, k) for k in props if getattr(self, k)}
        res = {k: (v.to_dict() if isinstance(v, Dictable) else v) for k, v in d.items()}
        if extra:
            # add the extra items to our result, making sure not to overwrite existing properties (claims etc.)
            res.update({k: v for k, v in extra.items() if k not in props})
        return res

    @classmethod
    def from_dict(cls, d):
        return cls(**d)
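A minimal usage sketch of `Dictable.to_dict`: note that it collects public `property` attributes across the MRO and drops falsy values. The `Point` subclass here is hypothetical, and the base class is a trimmed copy of `Dictable` above so the sketch runs standalone:

```python
class Dictable(object):
    # trimmed copy of the Dictable base above, for a self-contained sketch
    _cached_props = None

    @classmethod
    def _get_cached_props(cls):
        if cls._cached_props is None:
            props = set()
            for c in cls.mro():
                props.update(
                    k for k, v in vars(c).items()
                    if isinstance(v, property) and not k.startswith("_")
                )
            cls._cached_props = list(props)
        return cls._cached_props

    def to_dict(self, **extra):
        props = self._get_cached_props()
        d = {k: getattr(self, k) for k in props if getattr(self, k)}
        return {k: (v.to_dict() if isinstance(v, Dictable) else v) for k, v in d.items()}


class Point(Dictable):  # hypothetical subclass, for illustration only
    def __init__(self, x=0, y=0):
        self._x, self._y = x, y

    @property
    def x(self):
        return self._x

    @property
    def y(self):
        return self._y


# only truthy properties are serialized: y == 0 is dropped
assert Point(3, 0).to_dict() == {"x": 3}
assert Point(3, 4).to_dict() == {"x": 3, "y": 4}
```

The falsy-value filtering keeps serialized JWT payloads free of empty claims, which is why unset `exp`/`nbf` never appear in an encoded token.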
30
server/service_repo/auth/identity.py
Normal file
@@ -0,0 +1,30 @@
from .dictable import Dictable


class Identity(Dictable):
    def __init__(self, user, company, role, user_name=None, company_name=None):
        self._user = user
        self._company = company
        self._role = role
        self._user_name = user_name
        self._company_name = company_name

    @property
    def user(self):
        return self._user

    @property
    def company(self):
        return self._company

    @property
    def role(self):
        return self._role

    @property
    def user_name(self):
        return self._user_name

    @property
    def company_name(self):
        return self._company_name
4
server/service_repo/auth/payload/__init__.py
Normal file
@@ -0,0 +1,4 @@
from .auth_type import AuthType
from .payload import Payload
from .basic import Basic
from .token import Token
3
server/service_repo/auth/payload/auth_type.py
Normal file
@@ -0,0 +1,3 @@
class AuthType(object):
    basic = 'Basic'
    bearer_token = 'Bearer'
18
server/service_repo/auth/payload/basic.py
Normal file
@@ -0,0 +1,18 @@
from .payload import Payload
from .auth_type import AuthType


class Basic(Payload):
    def __init__(self, user_key, identity=None, entities=None, **_):
        super(Basic, self).__init__(
            AuthType.basic, identity=identity, entities=entities)
        self._user_key = user_key

    @property
    def user_key(self):
        return self._user_key

    def get_log_entry(self):
        d = super(Basic, self).get_log_entry()
        d.update(user_key=self.user_key)
        return d
52
server/service_repo/auth/payload/payload.py
Normal file
@@ -0,0 +1,52 @@
from apierrors import errors

from ..identity import Identity
from ..dictable import Dictable


class Payload(Dictable):
    def __init__(self, auth_type, identity=None, entities=None):
        self._auth_type = auth_type
        self.identity = identity
        self.entities = entities or {}

    @property
    def auth_type(self):
        return self._auth_type

    @property
    def identity(self):
        return self._identity

    @identity.setter
    def identity(self, value):
        if isinstance(value, dict):
            value = Identity(**value)
        else:
            assert isinstance(value, Identity)
        self._identity = value

    @property
    def entities(self):
        return self._entities

    @entities.setter
    def entities(self, value):
        self._entities = value

    def get_log_entry(self):
        return {
            "type": self.auth_type,
            "identity": self.identity.to_dict(),
            "entities": self.entities,
        }

    def validate_entities(self, **entities):
        """ Validate entities. key/value represents entity_name/entity_id(s) """
        if not self.entities:
            return
        for entity_name, entity_id in entities.items():
            constraints = self.entities.get(entity_name)
            ids = set(entity_id if isinstance(entity_id, (tuple, list)) else (entity_id,))
            if constraints and not ids.issubset(constraints):
                raise errors.unauthorized.EntityNotAllowed(entity=entity_name)
96
server/service_repo/auth/payload/token.py
Normal file
@@ -0,0 +1,96 @@
import jwt

from datetime import datetime, timedelta

from apierrors import errors
from config import config
from database.model.auth import Role

from .auth_type import AuthType
from .payload import Payload

token_secret = config.get('secure.auth.token_secret')


class Token(Payload):
    default_expiration_sec = config.get('apiserver.auth.default_expiration_sec')

    def __init__(self, exp=None, iat=None, nbf=None, env=None, identity=None, entities=None, **_):
        super(Token, self).__init__(
            AuthType.bearer_token, identity=identity, entities=entities)
        self.exp = exp
        self.iat = iat
        self.nbf = nbf
        self._env = env or config.get('env', '<unknown>')

    @property
    def env(self):
        return self._env

    @property
    def exp(self):
        return self._exp

    @exp.setter
    def exp(self, value):
        self._exp = value

    @property
    def iat(self):
        return self._iat

    @iat.setter
    def iat(self, value):
        self._iat = value

    @property
    def nbf(self):
        return self._nbf

    @nbf.setter
    def nbf(self, value):
        self._nbf = value

    def get_log_entry(self):
        d = super(Token, self).get_log_entry()
        d.update(iat=self.iat, exp=self.exp, env=self.env)
        return d

    def encode(self, **extra_payload):
        payload = self.to_dict(**extra_payload)
        return jwt.encode(payload, token_secret)

    @classmethod
    def decode(cls, encoded_token, verify=True):
        return jwt.decode(encoded_token, token_secret, verify=verify)

    @classmethod
    def from_encoded_token(cls, encoded_token, verify=True):
        decoded = cls.decode(encoded_token, verify=verify)
        try:
            token = Token.from_dict(decoded)
            assert isinstance(token, Token)
            if not token.identity:
                raise errors.unauthorized.InvalidToken('token missing identity')
            return token
        except Exception as e:
            raise errors.unauthorized.InvalidToken('failed parsing token, %s' % e.args[0])

    @classmethod
    def create_encoded_token(cls, identity, expiration_sec=None, entities=None, **extra_payload):
        if identity.role not in (Role.system,):
            # limit expiration time for all roles except the internal service role
            expiration_sec = expiration_sec or cls.default_expiration_sec

        now = datetime.utcnow()

        token = cls(
            identity=identity,
            entities=entities,
            iat=now)

        if expiration_sec:
            # add the 'expiration' claim
            token.exp = now + timedelta(seconds=expiration_sec)

        return token.encode(**extra_payload)
35
server/service_repo/auth/utils.py
Normal file
@@ -0,0 +1,35 @@
import random

sys_random = random.SystemRandom()


def get_random_string(length=12, allowed_chars='abcdefghijklmnopqrstuvwxyz'
                                               'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'):
    """
    Return a securely generated random string.

    The default length of 12 with the a-z, A-Z, 0-9 character set returns
    a 71-bit value: log_2((26+26+10)^12) =~ 71 bits.

    Taken from the django.utils.crypto module.
    """
    return ''.join(sys_random.choice(allowed_chars) for _ in range(length))


def get_client_id(length=20):
    """
    Create a random client ID.

    Taken from the Django project.
    """
    chars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'
    return get_random_string(length, chars)


def get_secret_key(length=50):
    """
    Create a random secret key.

    Taken from the Django project.
    """
    chars = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*(-_=+)'
    return get_random_string(length, chars)
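The 71-bit figure in the `get_random_string` docstring follows from the alphabet size: a string of length n over an alphabet of k characters carries n * log2(k) bits of entropy. Checking the default parameters directly:

```python
import math

# 12 characters over a 62-character alphabet (26 + 26 + 10)
bits = 12 * math.log2(62)

# log2(62) is about 5.95, so 12 characters carry about 71.45 bits
assert 71 < bits < 72
```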
7
server/service_repo/base.py
Normal file
@@ -0,0 +1,7 @@
from semantic_version import Version


class PartialVersion(Version):
    def __init__(self, version_string: str):
        assert isinstance(version_string, str)
        super().__init__(version_string, partial=True)
118
server/service_repo/endpoint.py
Normal file
@@ -0,0 +1,118 @@
from typing import Callable, Sequence, Text

from jsonmodels import models
from jsonmodels.errors import FieldNotSupported

from schema import schema
from .apicall import APICall
from .base import PartialVersion
from .schema_validator import SchemaValidator

EndpointFunc = Callable[[APICall, Text, models.Base], None]


class Endpoint(object):
    _endpoint_config_cache = {}
    """
    Endpoints configuration cache, in the format of {full endpoint name: dict}
    """

    def __init__(
        self,
        name: Text,
        func: EndpointFunc,
        min_version: Text = "1.0",
        required_fields: Sequence[Text] = None,
        request_data_model: models.Base = None,
        response_data_model: models.Base = None,
        validate_schema: bool = False,
    ):
        """
        Endpoint configuration
        :param name: full endpoint name
        :param func: endpoint implementation
        :param min_version: minimum supported version
        :param required_fields: required request fields, cannot be used with validate_schema
        :param request_data_model: request jsonschema model, will be validated if validate_schema=False
        :param response_data_model: response jsonschema model, will be validated if validate_schema=False
        :param validate_schema: whether the request and response schemas should be validated
        """
        super(Endpoint, self).__init__()
        self.name = name
        self.min_version = PartialVersion(min_version)
        self.func = func
        self.required_fields = required_fields
        self.request_data_model = request_data_model
        self.response_data_model = response_data_model
        service, _, endpoint_name = self.name.partition(".")
        try:
            self.endpoint_group = schema.services[service].endpoint_groups[
                endpoint_name
            ]
        except KeyError:
            raise RuntimeError(
                f"schema for endpoint {service}.{endpoint_name} not found"
            )
        if validate_schema:
            if self.required_fields:
                raise ValueError(
                    f"endpoint {self.name}: cannot use 'required_fields' with 'validate_schema'"
                )
            endpoint = self.endpoint_group.get_for_version(self.min_version)
            request_schema = endpoint.request_schema
            response_schema = endpoint.response_schema
        else:
            request_schema = None
            response_schema = None
        self.request_schema_validator = SchemaValidator(request_schema)
        self.response_schema_validator = SchemaValidator(response_schema)

    def __repr__(self):
        return f"{type(self).__name__}<{self.name}>"

    def to_dict(self):
        """
        Used by the `server.endpoints` endpoint.
        Provides endpoints and their schemas on a best-effort basis.
        """
        d = {
            "min_version": self.min_version,
            "required_fields": self.required_fields,
            "request_data_model": None,
            "response_data_model": None,
        }

        def safe_to_json_schema(data_model: models.Base):
            """
            Provides the data_model schema if available
            """
            try:
                return data_model.to_json_schema()
            except (FieldNotSupported, TypeError):
                return str(data_model.__name__)

        if self.request_data_model:
            d["request_data_model"] = safe_to_json_schema(self.request_data_model)
        if self.response_data_model:
            d["response_data_model"] = safe_to_json_schema(self.response_data_model)
        if self.request_schema_validator.enabled:
            d["request_schema"] = self.request_schema_validator.schema
        if self.response_schema_validator.enabled:
            d["response_schema"] = self.response_schema_validator.schema

        return d

    @property
    def authorize(self):
        return self.endpoint_group.authorize

    @property
    def allow_roles(self):
        return self.endpoint_group.allow_roles

    def allows(self, role):
        return self.endpoint_group.allows(role)

    @property
    def is_internal(self):
        return self.endpoint_group.internal
19
server/service_repo/errors.py
Normal file
@@ -0,0 +1,19 @@
class PathParsingError(Exception):
    def __init__(self, msg):
        super(PathParsingError, self).__init__(msg)


class MalformedPathError(PathParsingError):
    pass


class InvalidVersionError(PathParsingError):
    pass


class CallParsingError(Exception):
    pass


class CallFailedError(Exception):
    pass
72
server/service_repo/schema_validator.py
Normal file
@@ -0,0 +1,72 @@
import sys
from typing import Optional, Callable

import attr
import fastjsonschema
import jsonschema
from boltons.iterutils import remap

from apierrors import errors
from config import config

log = config.logger(__file__)


@attr.s(auto_attribs=True, auto_exc=True)
class FastValidationError(Exception):
    error: fastjsonschema.JsonSchemaException
    data: dict


class SchemaValidator:
    def __init__(self, schema: Optional[dict]):
        """
        Utility for different schema validation strategies
        :param schema: jsonschema to validate against
        """
        self.schema = schema
        self.validator: Callable = schema and fastjsonschema.compile(schema)

    @property
    def enabled(self) -> bool:
        return self.schema is not None

    def fast_validate(self, data: dict) -> None:
        """
        Perform a quick validation with laconic error messages
        :param data: data to validate
        :raises: fastjsonschema.JsonSchemaException
        """
        if self.enabled and data is not None:
            data = remap(data, lambda path, key, value: value is not None)
            try:
                self.validator(data)
            except fastjsonschema.JsonSchemaException as e:
                raise FastValidationError(e, data) from e

    def detailed_validate(self, data: dict) -> None:
        """
        Perform a slow validation with detailed error messages
        :param data: data to validate
        :raises: errors.bad_request.ValidationError
        """
        try:
            self.fast_validate(data)
        except FastValidationError as error:
            _, _, traceback = sys.exc_info()
            try:
                jsonschema.validate(error.data, self.schema)
            except jsonschema.exceptions.ValidationError as detailed_error:
                raise errors.bad_request.ValidationError(
                    message=detailed_error.message,
                    path=list(detailed_error.path),
                    context=detailed_error.context,
                    cause=detailed_error.cause,
                    validator=detailed_error.validator,
                    validator_value=detailed_error.validator_value,
                    instance=detailed_error.instance,
                    parent=detailed_error.parent,
                )
            else:
                log.error("fast validation failed while detailed validation succeeded")
                raise error.error.with_traceback(traceback)
289
server/service_repo/service_repo.py
Normal file
@@ -0,0 +1,289 @@
import re
from importlib import import_module
from itertools import chain
from typing import cast, Iterable, List, MutableMapping

import jsonmodels.models
from pathlib import Path

import timing_context
from apierrors import APIError
from apierrors.errors.bad_request import RequestPathHasInvalidVersion
from config import config
from service_repo.base import PartialVersion
from .apicall import APICall
from .endpoint import Endpoint
from .errors import MalformedPathError, InvalidVersionError, CallFailedError
from .util import parse_return_stack_on_code
from .validators import validate_all

log = config.logger(__file__)


class ServiceRepo(object):
    _endpoints: MutableMapping[str, List[Endpoint]] = {}
    """
    Registered endpoints, in the format of {endpoint_name: [Endpoint]};
    each list of endpoints is sorted by min_version
    """

    _version_required = config.get("apiserver.version.required")
    """ If a version is required, parsing will fail for endpoint paths that do not contain a valid version """

    _max_version = PartialVersion("2.1")
    """ Maximum version number (the highest min_version value across all endpoints) """

    _endpoint_exp = (
        re.compile(
            r"^/?v(?P<endpoint_version>\d+\.?\d+)/(?P<endpoint_name>[a-zA-Z_]\w+\.[a-zA-Z_]\w+)/?$"
        )
        if config.get("apiserver.version.required")
        else re.compile(
            r"^/?(v(?P<endpoint_version>\d+\.?\d+)/)?(?P<endpoint_name>[a-zA-Z_]\w+\.[a-zA-Z_]\w+)/?$"
        )
    )
    """
    Endpoint structure expressions. We have two expressions, one with an optional version part.
    Constraints for the first (strict) expression:
    1. May start with a leading '/'
    2. Followed by a version number (int or float) preceded by a leading 'v'
    3. Followed by a '/'
    4. Followed by a service name, which must start with an English letter (lower or upper case) or underscore,
       followed by any number of alphanumeric or underscore characters
    5. Followed by a '.'
    6. Followed by an action name, which must start with an English letter (lower or upper case) or underscore,
       followed by any number of alphanumeric or underscore characters
    7. May end with a trailing '/'

    The second (optional version) expression does not require steps 2 and 3.
    """
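The optional-version expression described above can be exercised directly (the sample endpoint paths here are hypothetical):

```python
import re

# the optional-version variant of the endpoint expression
exp = re.compile(
    r"^/?(v(?P<endpoint_version>\d+\.?\d+)/)?"
    r"(?P<endpoint_name>[a-zA-Z_]\w+\.[a-zA-Z_]\w+)/?$"
)

# versioned path: both groups are captured
m = exp.match("/v2.1/tasks.get_all")
assert m.group("endpoint_version") == "2.1"
assert m.group("endpoint_name") == "tasks.get_all"

# the version part is optional in this variant
m = exp.match("auth.login")
assert m.group("endpoint_version") is None
assert m.group("endpoint_name") == "auth.login"
```

Note that `\d+\.?\d+` requires at least two digits, so a bare `v2` would not match; clients are expected to send versions such as `2.0` or `2.1`.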
|
||||
    _return_stack = config.get("apiserver.return_stack")
    """ Return stack trace on error """

    _return_stack_on_code = parse_return_stack_on_code(
        config.get("apiserver.return_stack_on_code", [])
    )
    """ If 'return_stack' is true and the error contains a return code, return the stack trace only for these error codes """

    _credentials = config["secure.credentials.apiserver"]
    """ API Server credentials used for intra-service communication """

    _token = None
    """ Token for internal calls """

    @classmethod
    def load(cls, root_module="services"):
        root_module = Path(root_module)
        sub_module = None
        for sub_module in root_module.glob("*"):
            if (
                sub_module.is_file()
                and sub_module.suffix == ".py"
                and not sub_module.stem == "__init__"
            ):
                import_module(f"{root_module.stem}.{sub_module.stem}")
            if sub_module.is_dir():
                import_module(f"{root_module.stem}.{sub_module.stem}")
        # leave no trace of the 'sub_module' local
        del sub_module

        cls._max_version = max(
            cls._max_version,
            max(
                ep.min_version
                for ep in cast(Iterable[Endpoint], chain(*cls._endpoints.values()))
            ),
        )

    @classmethod
    def register(cls, endpoint):
        assert isinstance(endpoint, Endpoint)
        if cls._endpoints.get(endpoint.name):
            if any(
                ep.min_version == endpoint.min_version
                for ep in cls._endpoints[endpoint.name]
            ):
                raise Exception(
                    f"Trying to register an existing endpoint. name={endpoint.name}, version={endpoint.min_version}"
                )
            else:
                cls._endpoints[endpoint.name].append(endpoint)
        else:
            cls._endpoints[endpoint.name] = [endpoint]

        cls._endpoints[endpoint.name].sort(key=lambda ep: ep.min_version, reverse=True)

    @classmethod
    def endpoint_names(cls):
        return sorted(cls._endpoints.keys())

    @classmethod
    def endpoints_summary(cls):
        return {
            "endpoints": {
                name: list(map(Endpoint.to_dict, eps))
                for name, eps in cls._endpoints.items()
            },
            "models": {},
        }

    @classmethod
    def max_endpoint_version(cls) -> PartialVersion:
        return cls._max_version

    @classmethod
    def _get_endpoint(cls, name, version):
        versions = cls._endpoints.get(name)
        if not versions:
            return None
        try:
            # versions are sorted by min_version in descending order; take the first
            # endpoint whose min_version does not exceed the requested version
            return next(ep for ep in versions if ep.min_version <= version)
        except StopIteration:
            # no appropriate version found
            return None

    @classmethod
    def _resolve_endpoint_from_call(cls, call):
        assert isinstance(call, APICall)
        endpoint = cls._get_endpoint(
            call.endpoint_name, call.requested_endpoint_version
        )
        if endpoint is None:
            call.log_api = False
            call.set_error_result(
                msg=(
                    f"Unable to find endpoint for name {call.endpoint_name} "
                    f"and version {call.requested_endpoint_version}"
                ),
                code=404,
                subcode=0,
            )
            return

        assert isinstance(endpoint, Endpoint)
        call.actual_endpoint_version: PartialVersion = endpoint.min_version
        call.requires_authorization = endpoint.authorize
        return endpoint

    @classmethod
    def parse_endpoint_path(cls, path):
        """ Parse endpoint version, service and action from request path. """
        m = cls._endpoint_exp.match(path)
        if not m:
            raise MalformedPathError("Invalid request path %s" % path)
        endpoint_name = m.group("endpoint_name")
        version = m.group("endpoint_version")
        if version is None:
            # no version in the path; default to the maximum supported version
            version = cls._max_version
        else:
            try:
                version = PartialVersion(version)
            except ValueError as e:
                raise RequestPathHasInvalidVersion(version=version, reason=e)
            if version > cls._max_version:
                raise InvalidVersionError(
                    f"Invalid API version (max. supported version is {cls._max_version})"
                )
        return version, endpoint_name

    @classmethod
    def _should_return_stack(cls, code, subcode):
        if not cls._return_stack or code not in cls._return_stack_on_code:
            return False
        if subcode is None:
            # code is in the dict, but no subcode was provided. We'll allow it.
            return True
        subcode_list = cls._return_stack_on_code.get(code)
        if subcode_list is None:
            # if the code is there but we don't have any subcode list, always return the stack
            return True
        return subcode in subcode_list

    @classmethod
    def _validate_call(cls, call):
        endpoint = cls._resolve_endpoint_from_call(call)
        if call.failed:
            return
        validate_all(call, endpoint)
        return endpoint

    @classmethod
    def validate_call(cls, call):
        cls._validate_call(call)

    @classmethod
    def _get_company(cls, call, endpoint=None, ignore_error=False):
        authorize = endpoint and endpoint.authorize
        if ignore_error or not authorize:
            # authorization was not required, so identity.company may not be available
            try:
                return call.identity.company
            except Exception:
                return None
        return call.identity.company

    @classmethod
    def handle_call(cls, call):
        try:
            assert isinstance(call, APICall)

            if call.failed:
                raise CallFailedError()

            endpoint = cls._resolve_endpoint_from_call(call)

            if call.failed:
                raise CallFailedError()

            with timing_context.TimingContext("service_repo", "validate_call"):
                validate_all(call, endpoint)

            if call.failed:
                raise CallFailedError()

            # in case the call does not require authorization, parsing identity.company might raise an exception
            company = cls._get_company(call, endpoint)

            ret = endpoint.func(call, company, call.data_model)

            # allow endpoints to return a dict or a model (instead of setting them explicitly on the call)
            if ret is not None:
                if isinstance(ret, jsonmodels.models.Base):
                    call.result.data_model = ret
                elif isinstance(ret, dict):
                    call.result.data = ret

        except APIError as ex:
            # report the stack trace only when configured to do so for this error code/subcode
            include_stack = cls._return_stack and cls._should_return_stack(
                ex.code, ex.subcode
            )
            call.set_error_result(
                code=ex.code,
                subcode=ex.subcode,
                msg=str(ex),
                include_stack=include_stack,
            )
        except CallFailedError:
            # do nothing, let 'finally' wrap up
            pass
        except Exception as ex:
            log.exception(ex)
            call.set_error_result(
                code=500, subcode=0, msg=str(ex), include_stack=cls._return_stack
            )
        finally:
            content, content_type = call.get_response()
            call.mark_end()
            console_msg = f"Returned {call.result.code} for {call.endpoint_name} in {call.duration}ms"
            if call.result.code < 300:
                log.info(console_msg)
            else:
                console_msg = f"{console_msg}, msg={call.result.msg}"
                if call.result.code < 500:
                    log.warning(console_msg)
                else:
                    log.error(console_msg)

        return content, content_type
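The endpoint path parsing above is driven by a single regular expression. A quick standalone sketch of how the strict (version-required) variant behaves; the pattern is copied from `_endpoint_exp`, while `parse` is a simplified stand-in for `parse_endpoint_path` (no `PartialVersion` or config involved):

```python
import re

# re-creation of the strict, version-required endpoint path expression
ENDPOINT_EXP = re.compile(
    r"^/?v(?P<endpoint_version>\d+\.?\d+)/(?P<endpoint_name>[a-zA-Z_]\w+\.[a-zA-Z_]\w+)/?$"
)


def parse(path):
    # returns (version, "service.action") on a match, None otherwise
    m = ENDPOINT_EXP.match(path)
    if not m:
        return None
    return m.group("endpoint_version"), m.group("endpoint_name")


parse("/v2.1/auth.login")   # version and service.action are captured
parse("auth.login")         # no version part -> no match under the strict pattern
```

Under the optional-version variant (used when `apiserver.version.required` is off), the second call would instead match with `endpoint_version` set to `None`, which is why `parse_endpoint_path` falls back to the maximum supported version.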
47  server/service_repo/util.py  Normal file
@@ -0,0 +1,47 @@
import socket

import six


def get_local_addr():
    """ Get the local IP address (that isn't localhost) """
    _, _, ipaddrlist = socket.gethostbyname_ex(socket.gethostname())
    try:
        return next(ip for ip in ipaddrlist if ip not in ('127.0.0.1',))
    except StopIteration:
        raise ValueError('Cannot find non-loopback ip address for this server (received %s)' % ', '.join(ipaddrlist))


def resolve_addr(addr):
    """ Resolve an address (IP string or host name) into an IP string. """
    try:
        socket.inet_aton(addr)
        return addr
    except socket.error:
        try:
            return socket.gethostbyname(addr)
        except socket.error:
            # not resolvable; fall through and implicitly return None
            pass


def parse_return_stack_on_code(codes):
    assert isinstance(codes, list), "return_stack_on_code must be a list"

    def parse(e):
        if isinstance(e, six.integer_types):
            code, subcodes = e, None
        elif isinstance(e, (list, tuple)):
            code, subcodes = e[:2]
            assert isinstance(code, six.integer_types), "return_stack_on_code/code must be int"
            if isinstance(subcodes, six.integer_types):
                subcodes = [subcodes]
            if isinstance(subcodes, (list, tuple)):
                assert all(isinstance(x, six.integer_types) for x in subcodes), \
                    "return_stack_on_code/subcode must be list(int)"
            else:
                raise ValueError("invalid return_stack_on_code/subcode(s): %s" % subcodes)
        else:
            raise ValueError("invalid return_stack_on_code/subcode(s): %s" % e)
        return code, subcodes

    return dict(map(parse, codes))
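For reference, here is a minimal sketch of the mapping `parse_return_stack_on_code` produces, which `ServiceRepo._should_return_stack` then consults. This is a simplified re-creation (ints only, no `six` dependency), and the code/subcode values are purely illustrative, not taken from any real config:

```python
def parse_return_stack_on_code(codes):
    # simplified re-creation of util.parse_return_stack_on_code
    def parse(e):
        if isinstance(e, int):
            return e, None              # bare code: any subcode returns the stack
        code, subcodes = e[:2]          # (code, subcode) or (code, [subcodes])
        if isinstance(subcodes, int):
            subcodes = [subcodes]
        return code, list(subcodes)
    return dict(map(parse, codes))


mapping = parse_return_stack_on_code([500, (400, 12), (401, [10, 11])])
# {500: None, 400: [12], 401: [10, 11]}
```

A `None` value means "return the stack for every subcode of this code", which matches the early-return branches in `_should_return_stack`.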
171  server/service_repo/validators.py  Normal file
@@ -0,0 +1,171 @@
import fastjsonschema
import jsonmodels.errors

from apierrors import errors, APIError
from config import config
from database.model import Company
from database.model.auth import Role, User
from service_repo import APICall
from service_repo.apicall import MissingIdentity
from service_repo.endpoint import Endpoint
from .auth import get_auth_func, Identity, authorize_impersonation, Payload
from .errors import CallParsingError

log = config.logger(__file__)


def validate_all(call: APICall, endpoint: Endpoint):
    """ Perform all required call/endpoint validation, and update the call result appropriately """
    try:
        validate_auth(endpoint, call)

        validate_role(endpoint, call)

        if validate_impersonation(endpoint, call):
            # if impersonating, validate the role again
            validate_role(endpoint, call)

        # todo: remove validate_required_fields once all endpoints have a json schema
        validate_required_fields(endpoint, call)

        # set models. models will be validated automatically
        call.schema_validator = endpoint.request_schema_validator
        if endpoint.request_data_model:
            call.data_model_cls = endpoint.request_data_model

        call.result.schema_validator = endpoint.response_schema_validator
        if endpoint.response_data_model:
            call.result.data_model_cls = endpoint.response_data_model

        return True

    except CallParsingError as ex:
        raise errors.bad_request.ValidationError(str(ex))
    except jsonmodels.errors.ValidationError as ex:
        raise errors.bad_request.ValidationError(
            " ".join(map(str.lower, map(str, ex.args)))
        )
    except fastjsonschema.exceptions.JsonSchemaException as ex:
        log.exception(f"{endpoint.name}: fastjsonschema exception")
        raise errors.bad_request.ValidationError(ex.args[0])


def validate_role(endpoint, call):
    try:
        if not endpoint.allows(call.identity.role):
            raise errors.forbidden.RoleNotAllowed(role=call.identity.role, allowed=endpoint.allow_roles)
    except MissingIdentity:
        pass


def validate_auth(endpoint, call):
    """ Validate authorization for this endpoint and call.
        If authentication has occurred, the call is updated with the authentication results.
    """
    if not call.authorization:
        # No auth data. Invalid if we need to authorize, valid otherwise
        if endpoint.authorize:
            raise errors.unauthorized.NoCredentials()
        return

    # prepare arguments for validation
    service, _, action = endpoint.name.partition(".")

    # If we have auth data, we'll try to validate anyway (just so we'll have auth-based permissions whenever possible,
    # even if the endpoint did not require authorization)
    try:
        auth = call.authorization or ""
        auth_type, _, auth_data = auth.partition(" ")
        authorize_func = get_auth_func(auth_type)
        call.auth = authorize_func(auth_data, service, action, call.batched_data)
    except Exception:
        if endpoint.authorize:
            # if the endpoint requires authorization, re-raise the exception
            raise

def validate_impersonation(endpoint, call):
    """ Validate impersonation headers and set impersonated identity and authorization data accordingly.
        :returns: True if impersonating, False otherwise
    """
    try:
        act_as = call.act_as
        impersonate_as = call.impersonate_as
        if not impersonate_as and not act_as:
            return False
        elif impersonate_as and act_as:
            raise errors.bad_request.InvalidHeaders(
                "only one allowed", headers=tuple(call.impersonation_headers.keys())
            )

        identity = call.auth.identity

        # verify this user is allowed to impersonate at all
        if identity.role not in Role.get_system_roles() | {Role.admin}:
            raise errors.bad_request.ImpersonationError(
                "impersonation not allowed", role=identity.role
            )

        # get the impersonated user's info
        user_id = act_as or impersonate_as
        if identity.role in [Role.root]:
            # only root is allowed to impersonate users in other companies
            query = dict(id=user_id)
        else:
            query = dict(id=user_id, company=identity.company)
        user = User.objects(**query).first()
        if not user:
            raise errors.bad_request.ImpersonationError("unknown user", **query)

        company = Company.objects(id=user.company).only("name").first()
        if not company:
            query.update(company=user.company)
            raise errors.bad_request.ImpersonationError("unknown company for user", **query)

        # create impersonation payload
        if act_as:
            # act as a user, using your own role and permissions
            call.impersonation = Payload(
                auth_type=None,
                identity=Identity(
                    user=user.id,
                    company=user.company,
                    role=identity.role,
                    user_name=f"{identity.user_name} (acting as {user.name})",
                    company_name=company.name,
                ),
            )
        elif impersonate_as:
            # impersonate as a user, using their own identity and permissions (requires additional validation to verify
            # the impersonated user is allowed to access the endpoint)
            service, _, action = endpoint.name.partition(".")
            call.impersonation = authorize_impersonation(
                user=user,
                identity=Identity(
                    user=user.id,
                    company=user.company,
                    role=user.role,
                    user_name=f"{user.name} (impersonated by {identity.user_name})",
                    company_name=company.name,
                ),
                service=service,
                action=action,
                call_data_items=call.batched_data,
            )
        else:
            return False
        return True

    except APIError:
        raise
    except Exception:
        raise errors.server_error.InternalError("validating impersonation")


def validate_required_fields(endpoint, call):
    if endpoint.required_fields is None:
        return

    missing = [val for val in endpoint.required_fields if val not in call.data]
    if missing:
        raise errors.bad_request.MissingRequiredFields(missing=missing)
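The first check in `validate_impersonation` enforces that at most one of the two impersonation headers is set. A standalone sketch of just that rule, with `check_impersonation_headers` as a hypothetical helper name (the real code works on the `call` object and raises `InvalidHeaders` rather than `ValueError`):

```python
def check_impersonation_headers(act_as, impersonate_as):
    # re-creation of the header rule in validate_impersonation:
    # at most one of the two impersonation headers may be set
    if not act_as and not impersonate_as:
        return None                      # no impersonation requested
    if act_as and impersonate_as:
        raise ValueError("only one allowed")
    return "act_as" if act_as else "impersonate_as"
```

The distinction matters downstream: `act_as` keeps the caller's own role and permissions, while `impersonate_as` adopts the target user's identity and therefore requires re-authorizing the call.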
0  server/services/__init__.py  Normal file

166  server/services/auth.py  Normal file
@@ -0,0 +1,166 @@
from apierrors import errors
from apimodels.auth import (
    GetTokenResponse,
    GetTokenForUserRequest,
    GetTokenRequest,
    ValidateTokenRequest,
    ValidateResponse,
    CreateUserRequest,
    CreateUserResponse,
    CreateCredentialsResponse,
    GetCredentialsResponse,
    RevokeCredentialsResponse,
    CredentialsResponse,
    RevokeCredentialsRequest,
    EditUserReq,
)
from apimodels.base import UpdateResponse
from bll.auth import AuthBLL
from config import config
from database.errors import translate_errors_context
from database.model.auth import (
    User,
)
from service_repo import APICall, endpoint
from service_repo.auth import Token

log = config.logger(__file__)


@endpoint(
    "auth.login",
    request_data_model=GetTokenRequest,
    response_data_model=GetTokenResponse,
)
def login(call):
    """ Generates a token based on the authenticated user (intended for use with credentials) """
    assert isinstance(call, APICall)

    call.result.data_model = AuthBLL.get_token_for_user(
        user_id=call.identity.user,
        company_id=call.identity.company,
        expiration_sec=call.data_model.expiration_sec,
    )


@endpoint(
    "auth.get_token_for_user",
    request_data_model=GetTokenForUserRequest,
    response_data_model=GetTokenResponse,
)
def get_token_for_user(call):
    """ Generates a token based on a requested user and company. INTERNAL. """
    assert isinstance(call, APICall)
    call.result.data_model = AuthBLL.get_token_for_user(
        user_id=call.data_model.user,
        company_id=call.data_model.company,
        expiration_sec=call.data_model.expiration_sec,
    )


@endpoint(
    "auth.validate_token",
    request_data_model=ValidateTokenRequest,
    response_data_model=ValidateResponse,
)
def validate_token_endpoint(call):
    """ Validate a token and return its identity if valid. INTERNAL. """
    assert isinstance(call, APICall)

    try:
        # if the token is invalid, decoding will fail
        token = Token.from_encoded_token(call.data_model.token)
        call.result.data_model = ValidateResponse(
            valid=True, user=token.identity.user, company=token.identity.company
        )
    except Exception as e:
        call.result.data_model = ValidateResponse(valid=False, msg=e.args[0])

@endpoint(
    "auth.create_user",
    request_data_model=CreateUserRequest,
    response_data_model=CreateUserResponse,
)
def create_user(call: APICall, _, request: CreateUserRequest):
    """ Create a new user. INTERNAL. """
    user_id = AuthBLL.create_user(request=request, call=call)
    call.result.data_model = CreateUserResponse(id=user_id)


@endpoint("auth.create_credentials", response_data_model=CreateCredentialsResponse)
def create_credentials(call: APICall, _, __):
    credentials = AuthBLL.create_credentials(
        user_id=call.identity.user,
        company_id=call.identity.company,
        role=call.identity.role,
    )
    call.result.data_model = CreateCredentialsResponse(credentials=credentials)


@endpoint(
    "auth.revoke_credentials",
    request_data_model=RevokeCredentialsRequest,
    response_data_model=RevokeCredentialsResponse,
)
def revoke_credentials(call):
    assert isinstance(call, APICall)
    identity = call.identity
    access_key = call.data_model.access_key

    with translate_errors_context():
        query = dict(
            id=identity.user, company=identity.company, credentials__key=access_key
        )
        updated = User.objects(**query).update_one(pull__credentials__key=access_key)
        if not updated:
            raise errors.bad_request.InvalidUser(
                "invalid user or invalid access key", **query
            )

    call.result.data_model = RevokeCredentialsResponse(revoked=updated)


@endpoint("auth.get_credentials", response_data_model=GetCredentialsResponse)
def get_credentials(call):
    """ Return the authenticated user's credential access keys. INTERNAL. """
    assert isinstance(call, APICall)
    identity = call.identity

    with translate_errors_context():
        query = dict(id=identity.user, company=identity.company)
        user = User.objects(**query).first()
        if not user:
            raise errors.bad_request.InvalidUserId(**query)

        # we return ONLY the key IDs, never the secrets (want a secret? create new credentials)
        call.result.data_model = GetCredentialsResponse(
            credentials=[
                CredentialsResponse(access_key=c.key) for c in user.credentials
            ]
        )


@endpoint(
    "auth.edit_user", request_data_model=EditUserReq, response_data_model=UpdateResponse
)
def update(call, company_id, _):
    assert isinstance(call, APICall)

    fields = {
        k: v
        for k, v in call.data_model.to_struct().items()
        if k != "user" and v is not None
    }

    with translate_errors_context():
        result = User.objects(company=company_id, id=call.data_model.user).update(
            **fields, full_result=True, upsert=False
        )

        if not result.matched_count:
            raise errors.bad_request.InvalidUserId()

        call.result.data_model = UpdateResponse(
            updated=result.modified_count, fields=fields
        )
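The `update` endpoint above builds its update document by filtering the request model: the `user` key selects the document, so it must not appear among the updated fields, and unset (`None`) values are dropped. A standalone sketch of that filtering, with `non_empty_update_fields` as a hypothetical helper name and the field names purely illustrative:

```python
def non_empty_update_fields(data):
    # re-creation of the dict comprehension in auth.edit_user's update():
    # drop the 'user' key (it selects the document) and any unset values
    return {k: v for k, v in data.items() if k != "user" and v is not None}


fields = non_empty_update_fields(
    {"user": "u1", "name": "New Name", "family_name": None}
)
# {'name': 'New Name'}
```

Dropping `None` values is what makes the endpoint a partial update: only fields explicitly set in the request reach the MongoDB `update()`.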
509  server/services/events.py  Normal file
@@ -0,0 +1,509 @@
import itertools
from collections import defaultdict
from operator import itemgetter

import six

from apierrors import errors
from bll.event import EventBLL
from bll.task import TaskBLL
from service_repo import APICall, endpoint
from utilities import json

task_bll = TaskBLL()
event_bll = EventBLL()


@endpoint("events.add")
def add(call, company_id, req_model):
    assert isinstance(call, APICall)
    added, batch_errors = event_bll.add_events(company_id, [call.data.copy()], call.worker)
    call.result.data = dict(
        added=added,
        errors=len(batch_errors)
    )
    call.kpis["events"] = 1


@endpoint("events.add_batch")
def add_batch(call, company_id, req_model):
    assert isinstance(call, APICall)
    events = call.batched_data
    if events is None or len(events) == 0:
        raise errors.bad_request.BatchContainsNoItems()

    added, batch_errors = event_bll.add_events(company_id, events, call.worker)
    call.result.data = dict(
        added=added,
        errors=len(batch_errors)
    )
    call.kpis["events"] = len(events)

@endpoint("events.get_task_log", required_fields=["task"])
def get_task_log(call, company_id, req_model):
    task_id = call.data["task"]
    task_bll.assert_exists(company_id, task_id, allow_public=True)
    order = call.data.get("order") or "desc"
    scroll_id = call.data.get("scroll_id")
    batch_size = int(call.data.get("batch_size") or 500)
    events, scroll_id, total_events = event_bll.scroll_task_events(
        company_id, task_id, order,
        event_type="log",
        batch_size=batch_size,
        scroll_id=scroll_id)
    call.result.data = dict(
        events=events,
        returned=len(events),
        total=total_events,
        scroll_id=scroll_id,
    )


@endpoint("events.get_task_log", min_version="1.7", required_fields=["task"])
def get_task_log_v1_7(call, company_id, req_model):
    task_id = call.data["task"]
    task_bll.assert_exists(company_id, task_id, allow_public=True)

    order = call.data.get("order") or "desc"
    from_ = call.data.get("from") or "head"
    scroll_id = call.data.get("scroll_id")
    batch_size = int(call.data.get("batch_size") or 500)

    # scroll from the start of the log when reading from the head, otherwise from the end
    scroll_order = 'asc' if (from_ == 'head') else 'desc'

    events, scroll_id, total_events = event_bll.scroll_task_events(
        company_id=company_id,
        task_id=task_id,
        order=scroll_order,
        event_type="log",
        batch_size=batch_size,
        scroll_id=scroll_id
    )

    if scroll_order != order:
        events = events[::-1]

    call.result.data = dict(
        events=events,
        returned=len(events),
        total=total_events,
        scroll_id=scroll_id,
    )

@endpoint('events.download_task_log', required_fields=['task'])
def download_task_log(call, company_id, req_model):
    task_id = call.data['task']
    task_bll.assert_exists(company_id, task_id, allow_public=True)

    line_type = call.data.get('line_type', 'json').lower()
    line_format = str(call.data.get('line_format', '{asctime} {worker} {level} {msg}'))

    is_json = (line_type == 'json')
    if not is_json:
        if not line_format:
            raise errors.bad_request.MissingRequiredFields('line_format is required for plain text lines')

        # validate line format placeholders
        valid_task_log_fields = {'asctime', 'timestamp', 'level', 'worker', 'msg'}

        invalid_placeholders = set()
        while True:
            try:
                line_format.format(**dict.fromkeys(valid_task_log_fields | invalid_placeholders))
                break
            except KeyError as e:
                invalid_placeholders.add(e.args[0])
            except Exception as e:
                raise errors.bad_request.FieldsValueError('invalid line format', error=e.args[0])

        if invalid_placeholders:
            raise errors.bad_request.FieldsValueError(
                'undefined placeholders in line format',
                placeholders=invalid_placeholders
            )

        # make sure line_format has a trailing newline
        line_format = line_format.rstrip('\n') + '\n'

    def generate():
        scroll_id = None
        batch_size = 1000
        while True:
            log_events, scroll_id, _ = event_bll.scroll_task_events(
                company_id,
                task_id,
                order="asc",
                event_type="log",
                batch_size=batch_size,
                scroll_id=scroll_id
            )
            if not log_events:
                break
            for ev in log_events:
                ev['asctime'] = ev.pop('@timestamp')
                if is_json:
                    ev.pop('type')
                    ev.pop('task')
                    yield json.dumps(ev) + '\n'
                else:
                    try:
                        yield line_format.format(**ev)
                    except KeyError as ex:
                        raise errors.bad_request.FieldsValueError(
                            'undefined placeholders in line format',
                            placeholders=[str(ex)]
                        )

            if len(log_events) < batch_size:
                break

    call.result.filename = 'task_%s.log' % task_id
    call.result.content_type = 'text/plain'
    call.result.raw_data = generate()

@endpoint("events.get_vector_metrics_and_variants", required_fields=["task"])
def get_vector_metrics_and_variants(call, company_id, req_model):
    task_id = call.data["task"]
    task_bll.assert_exists(company_id, task_id, allow_public=True)
    call.result.data = dict(
        metrics=event_bll.get_metrics_and_variants(company_id, task_id, "training_stats_vector")
    )


@endpoint("events.get_scalar_metrics_and_variants", required_fields=["task"])
def get_scalar_metrics_and_variants(call, company_id, req_model):
    task_id = call.data["task"]
    task_bll.assert_exists(company_id, task_id, allow_public=True)
    call.result.data = dict(
        metrics=event_bll.get_metrics_and_variants(company_id, task_id, "training_stats_scalar")
    )


# todo: !!! currently returning 10,000 records. should decide on a better way to control it
@endpoint("events.vector_metrics_iter_histogram", required_fields=["task", "metric", "variant"])
def vector_metrics_iter_histogram(call, company_id, req_model):
    task_id = call.data["task"]
    task_bll.assert_exists(company_id, task_id, allow_public=True)
    metric = call.data["metric"]
    variant = call.data["variant"]
    iterations, vectors = event_bll.get_vector_metrics_per_iter(company_id, task_id, metric, variant)
    call.result.data = dict(
        metric=metric,
        variant=variant,
        vectors=vectors,
        iterations=iterations
    )


@endpoint("events.get_task_events", required_fields=["task"])
def get_task_events(call, company_id, req_model):
    task_id = call.data["task"]
    event_type = call.data.get("event_type")
    scroll_id = call.data.get("scroll_id")
    order = call.data.get("order") or "asc"

    task_bll.assert_exists(company_id, task_id, allow_public=True)
    result = event_bll.get_task_events(
        company_id, task_id,
        sort=[{"timestamp": {"order": order}}],
        event_type=event_type,
        scroll_id=scroll_id
    )

    call.result.data = dict(
        events=result.events,
        returned=len(result.events),
        total=result.total_events,
        scroll_id=result.next_scroll_id,
    )


@endpoint("events.get_scalar_metric_data", required_fields=["task", "metric"])
def get_scalar_metric_data(call, company_id, req_model):
    task_id = call.data["task"]
    metric = call.data["metric"]
    scroll_id = call.data.get("scroll_id")

    task_bll.assert_exists(company_id, task_id, allow_public=True)
    result = event_bll.get_task_events(
        company_id, task_id,
        event_type="training_stats_scalar",
        sort=[{"iter": {"order": "desc"}}],
        metric=metric,
        scroll_id=scroll_id
    )

    call.result.data = dict(
        events=result.events,
        returned=len(result.events),
        total=result.total_events,
        scroll_id=result.next_scroll_id,
    )

@endpoint("events.get_task_latest_scalar_values", required_fields=["task"])
def get_task_latest_scalar_values(call, company_id, req_model):
    task_id = call.data["task"]
    task = task_bll.assert_exists(company_id, task_id, allow_public=True)
    metrics, last_timestamp = event_bll.get_task_latest_scalar_values(company_id, task_id)
    es_index = EventBLL.get_index_name(company_id, "*")
    last_iters = event_bll.get_last_iters(es_index, task_id, None, 1)
    call.result.data = dict(
        metrics=metrics,
        last_iter=last_iters[0] if last_iters else 0,
        name=task.name,
        status=task.status,
        last_timestamp=last_timestamp
    )


# todo: should not repeat iter (x-axis) for each metric/variant, JS client should get raw data and fill gaps if needed
@endpoint("events.scalar_metrics_iter_histogram", required_fields=["task"])
def scalar_metrics_iter_histogram(call, company_id, req_model):
    task_id = call.data["task"]
    task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
    metrics = event_bll.get_scalar_metrics_average_per_iter(company_id, task_id)
    call.result.data = metrics


@endpoint("events.multi_task_scalar_metrics_iter_histogram", required_fields=["tasks"])
def multi_task_scalar_metrics_iter_histogram(call, company_id, req_model):
    task_ids = call.data["tasks"]
    if isinstance(task_ids, six.string_types):
        task_ids = [s.strip() for s in task_ids.split(",")]
    # Note, bll already validates task ids as it needs their names
    call.result.data = dict(
        metrics=event_bll.compare_scalar_metrics_average_per_iter(company_id, task_ids, allow_public=True)
    )


@endpoint("events.get_multi_task_plots", required_fields=["tasks"])
def get_multi_task_plots_v1_7(call, company_id, req_model):
    task_ids = call.data["tasks"]
    iters = call.data.get("iters", 1)
    scroll_id = call.data.get("scroll_id")

    tasks = task_bll.assert_exists(
        company_id=call.identity.company, only=('id', 'name'), task_ids=task_ids, allow_public=True
    )

    # Get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
    result = event_bll.get_task_events(
        company_id, task_ids,
        event_type="plot",
        sort=[{"iter": {"order": "desc"}}],
        size=10000,
        scroll_id=scroll_id
    )

    tasks = {t.id: t.name for t in tasks}

    return_events = _get_top_iter_unique_events_per_task(result.events, max_iters=iters, tasks=tasks)

    call.result.data = dict(
        plots=return_events,
        returned=len(return_events),
        total=result.total_events,
        scroll_id=result.next_scroll_id,
    )


@endpoint("events.get_multi_task_plots", min_version="1.8", required_fields=["tasks"])
def get_multi_task_plots(call, company_id, req_model):
    task_ids = call.data["tasks"]
    iters = call.data.get("iters", 1)
    scroll_id = call.data.get("scroll_id")

    tasks = task_bll.assert_exists(
        company_id=call.identity.company, only=('id', 'name'), task_ids=task_ids, allow_public=True
    )

    result = event_bll.get_task_events(
        company_id, task_ids,
        event_type="plot",
        sort=[{"iter": {"order": "desc"}}],
        last_iter_count=iters,
        scroll_id=scroll_id
    )

    tasks = {t.id: t.name for t in tasks}

    return_events = _get_top_iter_unique_events_per_task(result.events, max_iters=iters, tasks=tasks)

    call.result.data = dict(
        plots=return_events,
        returned=len(return_events),
        total=result.total_events,
        scroll_id=result.next_scroll_id,
    )


@endpoint("events.get_task_plots", required_fields=["task"])
|
||||
def get_task_plots_v1_7(call, company_id, req_model):
|
||||
task_id = call.data["task"]
|
||||
iters = call.data.get("iters", 1)
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
|
||||
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
|
||||
# events, next_scroll_id, total_events = event_bll.get_task_events(
|
||||
# company_id, task_id,
|
||||
# event_type="plot",
|
||||
# sort=[{"iter": {"order": "desc"}}],
|
||||
# last_iter_count=iters,
|
||||
# scroll_id=scroll_id)
|
||||
|
||||
# the following is a hack for Bosch, requested by Moshik
|
||||
# get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
|
||||
result = event_bll.get_task_events(
|
||||
company_id, task_id,
|
||||
event_type="plot",
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
size=10000,
|
||||
scroll_id=scroll_id
|
||||
)
|
||||
|
||||
return_events = _get_top_iter_unique_events(result.events, max_iters=iters)
|
||||
|
||||
call.result.data = dict(
|
||||
plots=return_events,
|
||||
returned=len(return_events),
|
||||
total=result.total_events,
|
||||
scroll_id=result.next_scroll_id,
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.get_task_plots", min_version="1.8", required_fields=["task"])
|
||||
def get_task_plots(call, company_id, req_model):
|
||||
task_id = call.data["task"]
|
||||
iters = call.data.get("iters", 1)
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
|
||||
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
|
||||
result = event_bll.get_task_events(
|
||||
company_id, task_id,
|
||||
event_type="plot",
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
last_iter_count=iters,
|
||||
scroll_id=scroll_id
|
||||
)
|
||||
|
||||
return_events = result.events
|
||||
|
||||
call.result.data = dict(
|
||||
plots=return_events,
|
||||
returned=len(return_events),
|
||||
total=result.total_events,
|
||||
scroll_id=result.next_scroll_id,
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.debug_images", required_fields=["task"])
|
||||
def get_debug_images_v1_7(call, company_id, req_model):
|
||||
task_id = call.data["task"]
|
||||
iters = call.data.get("iters") or 1
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
|
||||
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
|
||||
# events, next_scroll_id, total_events = event_bll.get_task_events(
|
||||
# company_id, task_id,
|
||||
# event_type="training_debug_image",
|
||||
# sort=[{"iter": {"order": "desc"}}],
|
||||
# last_iter_count=iters,
|
||||
# scroll_id=scroll_id)
|
||||
|
||||
# the following is a hack for Bosch, requested by Moshik
|
||||
# get last 10K events by iteration and group them by unique metric+variant, returning top events for combination
|
||||
result = event_bll.get_task_events(
|
||||
company_id, task_id,
|
||||
event_type="training_debug_image",
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
size=10000,
|
||||
scroll_id=scroll_id
|
||||
)
|
||||
|
||||
return_events = _get_top_iter_unique_events(result.events, max_iters=iters)
|
||||
|
||||
call.result.data = dict(
|
||||
task=task_id,
|
||||
images=return_events,
|
||||
returned=len(return_events),
|
||||
total=result.total_events,
|
||||
scroll_id=result.next_scroll_id,
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.debug_images", min_version="1.8", required_fields=["task"])
|
||||
def get_debug_images(call, company_id, req_model):
|
||||
task_id = call.data["task"]
|
||||
iters = call.data.get("iters") or 1
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
|
||||
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
|
||||
result = event_bll.get_task_events(
|
||||
company_id, task_id,
|
||||
event_type="training_debug_image",
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
last_iter_count=iters,
|
||||
scroll_id=scroll_id
|
||||
)
|
||||
|
||||
return_events = result.events
|
||||
|
||||
call.result.data = dict(
|
||||
task=task_id,
|
||||
images=return_events,
|
||||
returned=len(return_events),
|
||||
total=result.total_events,
|
||||
scroll_id=result.next_scroll_id,
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.delete_for_task", required_fields=["task"])
|
||||
def delete_for_task(call, company_id, req_model):
|
||||
task_id = call.data["task"]
|
||||
|
||||
task_bll.assert_exists(company_id, task_id)
|
||||
call.result.data = dict(
|
||||
deleted=event_bll.delete_task_events(company_id, task_id)
|
||||
)
|
||||
|
||||
|
||||
def _get_top_iter_unique_events_per_task(events, max_iters, tasks):
|
||||
key = itemgetter('metric', 'variant', 'task', 'iter')
|
||||
|
||||
unique_events = itertools.chain.from_iterable(
|
||||
itertools.islice(group, max_iters)
|
||||
for _, group in itertools.groupby(sorted(events, key=key, reverse=True), key=key))
|
||||
|
||||
def collect(evs, fields):
|
||||
if not fields:
|
||||
evs = list(evs)
|
||||
return {
|
||||
'name': tasks.get(evs[0].get('task')),
|
||||
'plots': evs
|
||||
}
|
||||
return {
|
||||
str(k): collect(group, fields[1:])
|
||||
for k, group in itertools.groupby(evs, key=itemgetter(fields[0]))
|
||||
}
|
||||
|
||||
collect_fields = ('metric', 'variant', 'task', 'iter')
|
||||
return collect(
|
||||
sorted(unique_events, key=itemgetter(*collect_fields), reverse=True),
|
||||
collect_fields
|
||||
)
|
||||
|
||||
|
||||
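The deduplication in `_get_top_iter_unique_events_per_task` hinges on `itertools.groupby`, which only merges *adjacent* items, so the events must first be sorted by the same key used for grouping. A minimal, self-contained sketch of that unique-per-(metric, variant, task, iter) selection, using made-up sample events:

```python
import itertools
from operator import itemgetter

def top_iter_unique(events, max_iters):
    # groupby only groups adjacent items, so sort by the grouping key first
    key = itemgetter("metric", "variant", "task", "iter")
    return list(itertools.chain.from_iterable(
        itertools.islice(group, max_iters)
        for _, group in itertools.groupby(sorted(events, key=key, reverse=True), key=key)
    ))

events = [
    {"metric": "loss", "variant": "train", "task": "t1", "iter": 2, "value": 0.5},
    {"metric": "loss", "variant": "train", "task": "t1", "iter": 2, "value": 0.6},  # duplicate key
    {"metric": "loss", "variant": "train", "task": "t1", "iter": 1, "value": 0.9},
]
picked = top_iter_unique(events, max_iters=1)
# one event kept per unique (metric, variant, task, iter) combination
```

Because `sorted` is stable, the first event stored for a duplicate key wins within each group.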
def _get_top_iter_unique_events(events, max_iters):
    top_unique_events = defaultdict(lambda: [])
    for e in events:
        key = e.get("metric", "") + e.get("variant", "")
        evs = top_unique_events[key]
        if len(evs) < max_iters:
            evs.append(e)
    unique_events = list(itertools.chain.from_iterable(list(top_unique_events.values())))
    unique_events.sort(key=lambda e: e["iter"], reverse=True)
    return unique_events
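Read `_get_top_iter_unique_events` as: keep at most `max_iters` of the most recent events per metric+variant series (the input arrives sorted by iteration, descending, from the Elasticsearch query). A standalone sketch with illustrative sample events:

```python
import itertools
from collections import defaultdict

def get_top_iter_unique_events(events, max_iters):
    top_unique_events = defaultdict(list)
    for e in events:  # events are assumed pre-sorted by iteration, descending
        key = e.get("metric", "") + e.get("variant", "")
        evs = top_unique_events[key]
        if len(evs) < max_iters:
            evs.append(e)
    unique_events = list(itertools.chain.from_iterable(top_unique_events.values()))
    unique_events.sort(key=lambda e: e["iter"], reverse=True)
    return unique_events

events = [
    {"metric": "loss", "variant": "train", "iter": 3},
    {"metric": "loss", "variant": "train", "iter": 2},  # dropped when max_iters=1
    {"metric": "acc", "variant": "val", "iter": 3},
]
top = get_top_iter_unique_events(events, max_iters=1)
```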
433
server/services/models.py
Normal file
@@ -0,0 +1,433 @@
from datetime import datetime
from urllib.parse import urlparse

from mongoengine import Q, EmbeddedDocument

import database
from apierrors import errors
from apimodels.base import UpdateResponse
from apimodels.models import (
    CreateModelRequest,
    CreateModelResponse,
    PublishModelRequest,
    PublishModelResponse,
    ModelTaskPublishResponse,
)
from bll.task import TaskBLL
from config import config
from database.errors import translate_errors_context
from database.fields import SupportedURLField
from database.model import validate_id
from database.model.model import Model
from database.model.project import Project
from database.model.task.task import Task, TaskStatus
from database.utils import (
    parse_from_call,
    get_company_or_none_constraint,
    filter_fields,
)
from service_repo import APICall, endpoint
from timing_context import TimingContext

log = config.logger(__file__)

get_all_query_options = Model.QueryParameterOptions(
    pattern_fields=("name", "comment"),
    fields=("ready",),
    list_fields=("tags", "framework", "uri", "id", "project", "task", "parent"),
)


@endpoint("models.get_by_id", required_fields=["model"])
def get_by_id(call):
    assert isinstance(call, APICall)
    model_id = call.data["model"]

    with translate_errors_context():
        res = Model.get_many(
            company=call.identity.company,
            query_dict=call.data,
            query=Q(id=model_id),
            allow_public=True,
        )
        if not res:
            raise errors.bad_request.InvalidModelId(
                "no such public or company model",
                id=model_id,
                company=call.identity.company,
            )

        call.result.data = {"model": res[0]}


@endpoint("models.get_by_task_id", required_fields=["task"])
def get_by_task_id(call):
    assert isinstance(call, APICall)
    task_id = call.data["task"]

    with translate_errors_context():
        query = dict(id=task_id, company=call.identity.company)
        res = Task.get(_only=["output"], **query)
        if not res:
            raise errors.bad_request.InvalidTaskId(**query)
        if not res.output:
            raise errors.bad_request.MissingTaskFields(field="output")
        if not res.output.model:
            raise errors.bad_request.MissingTaskFields(field="output.model")

        model_id = res.output.model
        res = Model.objects(
            Q(id=model_id) & get_company_or_none_constraint(call.identity.company)
        ).first()
        if not res:
            raise errors.bad_request.InvalidModelId(
                "no such public or company model",
                id=model_id,
                company=call.identity.company,
            )
        call.result.data = {"model": res.to_proper_dict()}


@endpoint("models.get_all_ex", required_fields=[])
def get_all_ex(call):
    assert isinstance(call, APICall)

    with translate_errors_context():
        with TimingContext("mongo", "models_get_all_ex"):
            models = Model.get_many_with_join(
                company=call.identity.company,
                query_dict=call.data,
                allow_public=True,
                query_options=get_all_query_options,
            )

        call.result.data = {"models": models}


@endpoint("models.get_all", required_fields=[])
def get_all(call):
    assert isinstance(call, APICall)

    with translate_errors_context():
        with TimingContext("mongo", "models_get_all"):
            models = Model.get_many(
                company=call.identity.company,
                parameters=call.data,
                query_dict=call.data,
                allow_public=True,
                query_options=get_all_query_options,
            )

        call.result.data = {"models": models}


create_fields = {
    "name": None,
    "tags": list,
    "task": Task,
    "comment": None,
    "uri": None,
    "project": Project,
    "parent": Model,
    "framework": None,
    "design": None,
    "labels": dict,
    "ready": None,
}

schemes = list(SupportedURLField.schemes)


def _validate_uri(uri):
    parsed_uri = urlparse(uri)
    if parsed_uri.scheme not in schemes:
        raise errors.bad_request.InvalidModelUri("unsupported scheme", uri=uri)
    elif not parsed_uri.path:
        raise errors.bad_request.InvalidModelUri("missing path", uri=uri)
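`_validate_uri` accepts a model URI only when its scheme is in the supported list and it carries a non-empty path. A hedged, standalone version of the same check; the scheme list here is an assumed example, not the server's actual `SupportedURLField.schemes`:

```python
from urllib.parse import urlparse

SCHEMES = ["s3", "gs", "file", "http", "https"]  # assumed example scheme list

def validate_uri(uri):
    parsed = urlparse(uri)
    if parsed.scheme not in SCHEMES:
        raise ValueError("unsupported scheme: %s" % uri)
    if not parsed.path:
        raise ValueError("missing path: %s" % uri)

validate_uri("s3://bucket/models/model.bin")  # passes: known scheme, non-empty path

try:
    validate_uri("ftp://host/file")  # rejected: ftp is not in the list
except ValueError as e:
    error = str(e)
```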
def parse_model_fields(call, valid_fields):
    fields = parse_from_call(call.data, valid_fields, Model.get_fields())
    tags = fields.get("tags")
    if tags:
        fields["tags"] = list(set(tags))
    return fields


@endpoint("models.update_for_task", required_fields=["task"])
def update_for_task(call, company_id, _):
    assert isinstance(call, APICall)
    task_id = call.data["task"]
    uri = call.data.get("uri")
    iteration = call.data.get("iteration")
    override_model_id = call.data.get("override_model_id")
    if not (uri or override_model_id) or (uri and override_model_id):
        raise errors.bad_request.MissingRequiredFields(
            "exactly one field is required", fields=("uri", "override_model_id")
        )

    with translate_errors_context():

        query = dict(id=task_id, company=company_id)
        task = Task.get_for_writing(
            id=task_id,
            company=company_id,
            _only=["output", "execution", "name", "status", "project"],
        )
        if not task:
            raise errors.bad_request.InvalidTaskId(**query)

        allowed_states = [TaskStatus.created, TaskStatus.in_progress]
        if task.status not in allowed_states:
            raise errors.bad_request.InvalidTaskStatus(
                f"model can only be updated for tasks in the {allowed_states} states",
                **query,
            )

        if override_model_id:
            query = dict(company=company_id, id=override_model_id)
            model = Model.objects(**query).first()
            if not model:
                raise errors.bad_request.InvalidModelId(**query)
        else:
            if "name" not in call.data:
                # use task name if name not provided
                call.data["name"] = task.name

            if "comment" not in call.data:
                call.data["comment"] = f"Created by task `{task.name}` ({task.id})"

        if task.output and task.output.model:
            # model exists, update
            res = _update_model(call, model_id=task.output.model).to_struct()
            res.update({"id": task.output.model, "created": False})
            call.result.data = res
            return

        # new model, create
        fields = parse_model_fields(call, create_fields)

        # create and save model
        model = Model(
            id=database.utils.id(),
            created=datetime.utcnow(),
            user=call.identity.user,
            company=company_id,
            project=task.project,
            framework=task.execution.framework,
            parent=task.execution.model,
            design=task.execution.model_desc,
            labels=task.execution.model_labels,
            ready=(task.status == TaskStatus.published),
            **fields,
        )
        model.save()

        TaskBLL.update_statistics(
            task_id=task_id,
            company_id=company_id,
            last_iteration_max=iteration,
            output__model=model.id,
        )

        call.result.data = {"id": model.id, "created": True}


@endpoint(
    "models.create",
    request_data_model=CreateModelRequest,
    response_data_model=CreateModelResponse,
)
def create(call, company, req_model):
    assert isinstance(call, APICall)
    assert isinstance(req_model, CreateModelRequest)
    identity = call.identity

    if req_model.public:
        company = ""

    with translate_errors_context():

        project = req_model.project
        if project:
            validate_id(Project, company=company, project=project)

        uri = req_model.uri
        if uri:
            _validate_uri(uri)
        task = req_model.task
        req_data = req_model.to_struct()
        if task:
            validate_task(call, req_data)

        fields = filter_fields(Model, req_data)
        # create and save model
        model = Model(
            id=database.utils.id(),
            user=identity.user,
            company=company,
            created=datetime.utcnow(),
            **fields,
        )
        model.save()

        call.result.data_model = CreateModelResponse(id=model.id, created=True)


def prepare_update_fields(call, fields):
    fields = fields.copy()
    if "uri" in fields:
        _validate_uri(fields["uri"])

        # clear UI cache if URI is provided (model updated)
        fields["ui_cache"] = fields.pop("ui_cache", {})
    if "task" in fields:
        validate_task(call, fields)
    return fields


def validate_task(call, fields):
    Task.get_for_writing(company=call.identity.company, id=fields["task"], _only=["id"])


@endpoint("models.edit", required_fields=["model"], response_data_model=UpdateResponse)
def edit(call):
    assert isinstance(call, APICall)
    identity = call.identity
    model_id = call.data["model"]

    with translate_errors_context():
        query = dict(id=model_id, company=identity.company)
        model = Model.objects(**query).first()
        if not model:
            raise errors.bad_request.InvalidModelId(**query)

        fields = parse_model_fields(call, create_fields)
        fields = prepare_update_fields(call, fields)

        for key in fields:
            field = getattr(model, key, None)
            value = fields[key]
            if (
                field
                and isinstance(value, dict)
                and isinstance(field, EmbeddedDocument)
            ):
                d = field.to_mongo(use_db_field=False).to_dict()
                d.update(value)
                fields[key] = d

        iteration = call.data.get("iteration")
        task_id = model.task or fields.get('task')
        if task_id and iteration is not None:
            TaskBLL.update_statistics(
                task_id=task_id,
                company_id=identity.company,
                last_iteration_max=iteration,
            )

        if fields:
            updated = model.update(upsert=False, **fields)
            call.result.data_model = UpdateResponse(updated=updated, fields=fields)
        else:
            call.result.data_model = UpdateResponse(updated=0)

def _update_model(call, model_id=None):
    assert isinstance(call, APICall)
    identity = call.identity
    model_id = model_id or call.data["model"]

    with translate_errors_context():
        # get model by id
        query = dict(id=model_id, company=identity.company)
        model = Model.objects(**query).first()
        if not model:
            raise errors.bad_request.InvalidModelId(**query)

        data = prepare_update_fields(call, call.data)

        task_id = data.get("task")
        iteration = data.get("iteration")
        if task_id and iteration is not None:
            TaskBLL.update_statistics(
                task_id=task_id,
                company_id=identity.company,
                last_iteration_max=iteration,
            )

        updated_count, updated_fields = Model.safe_update(
            call.identity.company, model.id, data
        )
        return UpdateResponse(updated=updated_count, fields=updated_fields)


@endpoint(
    "models.update", required_fields=["model"], response_data_model=UpdateResponse
)
def update(call):
    call.result.data_model = _update_model(call)


@endpoint(
    "models.set_ready",
    request_data_model=PublishModelRequest,
    response_data_model=PublishModelResponse,
)
def set_ready(call: APICall, company, req_model: PublishModelRequest):
    updated, published_task_data = TaskBLL.model_set_ready(
        model_id=req_model.model,
        company_id=company,
        publish_task=req_model.publish_task,
        force_publish_task=req_model.force_publish_task
    )

    call.result.data_model = PublishModelResponse(
        updated=updated,
        published_task=ModelTaskPublishResponse(
            **published_task_data
        ) if published_task_data else None
    )


@endpoint("models.delete", required_fields=["model"])
def delete(call):
    assert isinstance(call, APICall)
    identity = call.identity
    model_id = call.data["model"]
    force = call.data.get("force", False)

    with translate_errors_context():
        query = dict(id=model_id, company=identity.company)
        model = Model.objects(**query).only("id", "task").first()
        if not model:
            raise errors.bad_request.InvalidModelId(**query)

        deleted_model_id = f"__DELETED__{model_id}"

        using_tasks = Task.objects(execution__model=model_id).only("id")
        if using_tasks:
            if not force:
                raise errors.bad_request.ModelInUse(
                    "as execution model, use force=True to delete",
                    num_tasks=len(using_tasks),
                )
            # update deleted model id in using tasks
            using_tasks.update(
                execution__model=deleted_model_id, upsert=False, multi=True
            )

        if model.task:
            task = Task.objects(id=model.task).first()
            if task and task.status == TaskStatus.published:
                if not force:
                    raise errors.bad_request.ModelCreatingTaskExists(
                        "and published, use force=True to delete", task=model.task
                    )
                task.update(
                    output__model=deleted_model_id,
                    output__error=f"model deleted on {datetime.utcnow().isoformat()}",
                    upsert=False,
                )

        del_count = Model.objects(**query).delete()
        call.result.data = dict(deleted=del_count > 0)
||||
340
server/services/projects.py
Normal file
@@ -0,0 +1,340 @@
from collections import defaultdict
from datetime import datetime
from itertools import groupby
from operator import itemgetter

import dpath
from mongoengine import Q

import database
from apierrors import errors
from apimodels.base import UpdateResponse
from bll.task import TaskBLL
from database.errors import translate_errors_context
from database.model.model import Model
from database.model.project import Project
from database.model.task.task import Task, TaskStatus, TaskVisibility
from database.utils import parse_from_call, get_options, get_company_or_none_constraint
from service_repo import APICall, endpoint
from timing_context import TimingContext

task_bll = TaskBLL()
archived_tasks_cond = {"$in": [TaskVisibility.archived.value, "$tags"]}

create_fields = {
    "name": None,
    "description": None,
    "tags": list,
    "default_output_destination": None,
}

get_all_query_options = Project.QueryParameterOptions(
    pattern_fields=("name", "description"), list_fields=("tags", "id")
)


@endpoint("projects.get_by_id", required_fields=["project"])
def get_by_id(call):
    assert isinstance(call, APICall)
    project_id = call.data["project"]

    with translate_errors_context():
        with TimingContext("mongo", "projects_by_id"):
            query = Q(id=project_id) & get_company_or_none_constraint(
                call.identity.company
            )
            res = Project.objects(query).first()
        if not res:
            raise errors.bad_request.InvalidProjectId(id=project_id)

        res = res.to_proper_dict()

        call.result.data = {"project": res}


def make_projects_get_all_pipelines(project_ids, specific_state=None):
    archived = TaskVisibility.archived.value
    status_count_pipeline = [
        # count tasks per project per status
        {"$match": {"project": {"$in": project_ids}}},
        # make sure tags is always an array (required by the subsequent $in in archived_tasks_cond)
        {
            "$addFields": {
                "tags": {
                    "$cond": {
                        "if": {"$ne": [{"$type": "$tags"}, "array"]},
                        "then": [],
                        "else": "$tags",
                    }
                }
            }
        },
        {
            "$group": {
                "_id": {
                    "project": "$project",
                    "status": "$status",
                    archived: archived_tasks_cond,
                },
                "count": {"$sum": 1},
            }
        },
        # for each project, create a list of (status, count, archived)
        {
            "$group": {
                "_id": "$_id.project",
                "counts": {
                    "$push": {
                        "status": "$_id.status",
                        "count": "$count",
                        archived: "$_id.%s" % archived,
                    }
                },
            }
        },
    ]

    def runtime_subquery(additional_cond):
        return {
            # the sum of
            "$sum": {
                # for each task
                "$cond": {
                    # if completed and started and completed > started
                    "if": {
                        "$and": [
                            "$started",
                            "$completed",
                            {"$gt": ["$completed", "$started"]},
                            additional_cond,
                        ]
                    },
                    # then: floor((completed - started) / 1000)
                    "then": {
                        "$floor": {
                            "$divide": [
                                {"$subtract": ["$completed", "$started"]},
                                1000.0,
                            ]
                        }
                    },
                    "else": 0,
                }
            }
        }

    group_step = {"_id": "$project"}

    for state in TaskVisibility:
        if specific_state and state != specific_state:
            continue
        if state == TaskVisibility.active:
            group_step[state.value] = runtime_subquery({"$not": archived_tasks_cond})
        elif state == TaskVisibility.archived:
            group_step[state.value] = runtime_subquery(archived_tasks_cond)

    runtime_pipeline = [
        # only count run time for these types of tasks
        {
            "$match": {
                "type": {"$in": ["training", "testing", "annotation"]},
                "project": {"$in": project_ids},
            }
        },
        {
            # for each project
            "$group": group_step
        },
    ]

    return status_count_pipeline, runtime_pipeline
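The `runtime_pipeline` built above sums, per project, floor((completed − started) / 1000) seconds over tasks where both timestamps exist and `completed > started`. The same computation in plain Python, for intuition; timestamps are in milliseconds and the sample data is made up:

```python
import math
from collections import defaultdict

def total_runtime_seconds(tasks):
    # mirrors the $group / $cond / $floor stages of the Mongo pipeline
    totals = defaultdict(int)
    for t in tasks:
        started, completed = t.get("started"), t.get("completed")
        # $and treats missing/zero timestamps as false, so require truthy values
        if started and completed and completed > started:
            totals[t["project"]] += math.floor((completed - started) / 1000.0)
    return dict(totals)

tasks = [
    {"project": "p1", "started": 1_000, "completed": 61_500},  # 60.5 s, floored to 60
    {"project": "p1", "started": 5_000, "completed": 5_000},   # ignored: not completed > started
    {"project": "p2", "started": 0, "completed": 10_000},      # ignored: started is falsy
]
runtimes = total_runtime_seconds(tasks)
```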
@endpoint("projects.get_all_ex")
|
||||
def get_all_ex(call):
|
||||
assert isinstance(call, APICall)
|
||||
include_stats = call.data.get("include_stats")
|
||||
stats_for_state = call.data.get("stats_for_state", TaskVisibility.active.value)
|
||||
|
||||
if stats_for_state:
|
||||
try:
|
||||
specific_state = TaskVisibility(stats_for_state)
|
||||
except ValueError:
|
||||
raise errors.bad_request.FieldsValueError(stats_for_state=stats_for_state)
|
||||
else:
|
||||
specific_state = None
|
||||
|
||||
with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
|
||||
res = Project.get_many_with_join(
|
||||
company=call.identity.company,
|
||||
query_dict=call.data,
|
||||
query_options=get_all_query_options,
|
||||
allow_public=True,
|
||||
)
|
||||
|
||||
if not include_stats:
|
||||
call.result.data = {"projects": res}
|
||||
return
|
||||
|
||||
ids = [project["id"] for project in res]
|
||||
status_count_pipeline, runtime_pipeline = make_projects_get_all_pipelines(
|
||||
ids, specific_state=specific_state
|
||||
)
|
||||
|
||||
default_counts = dict.fromkeys(get_options(TaskStatus), 0)
|
||||
|
||||
def set_default_count(entry):
|
||||
return dict(default_counts, **entry)
|
||||
|
||||
status_count = defaultdict(lambda: {})
|
||||
key = itemgetter(TaskVisibility.archived.value)
|
||||
for result in Task.objects.aggregate(*status_count_pipeline):
|
||||
for k, group in groupby(sorted(result["counts"], key=key), key):
|
||||
section = (
|
||||
TaskVisibility.archived if k else TaskVisibility.active
|
||||
).value
|
||||
status_count[result["_id"]][section] = set_default_count(
|
||||
{
|
||||
count_entry["status"]: count_entry["count"]
|
||||
for count_entry in group
|
||||
}
|
||||
)
|
||||
|
||||
runtime = {
|
||||
result["_id"]: {k: v for k, v in result.items() if k != "_id"}
|
||||
for result in Task.objects.aggregate(*runtime_pipeline)
|
||||
}
|
||||
|
||||
def safe_get(obj, path, default=None):
|
||||
try:
|
||||
return dpath.get(obj, path)
|
||||
except KeyError:
|
||||
return default
|
||||
|
||||
def get_status_counts(project_id, section):
|
||||
path = "/".join((project_id, section))
|
||||
return {
|
||||
"total_runtime": safe_get(runtime, path, 0),
|
||||
"status_count": safe_get(status_count, path, default_counts),
|
||||
}
|
||||
|
||||
report_for_states = [
|
||||
s for s in TaskVisibility if not specific_state or specific_state == s
|
||||
]
|
||||
|
||||
for project in res:
|
||||
project["stats"] = {
|
||||
task_state.value: get_status_counts(project["id"], task_state.value)
|
||||
for task_state in report_for_states
|
||||
}
|
||||
|
||||
call.result.data = {"projects": res}
|
||||
|
||||
|
||||
@endpoint("projects.get_all")
|
||||
def get_all(call):
|
||||
assert isinstance(call, APICall)
|
||||
|
||||
with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
|
||||
res = Project.get_many(
|
||||
company=call.identity.company,
|
||||
query_dict=call.data,
|
||||
query_options=get_all_query_options,
|
||||
parameters=call.data,
|
||||
allow_public=True,
|
||||
)
|
||||
|
||||
call.result.data = {"projects": res}
|
||||
|
||||
|
||||
@endpoint("projects.create", required_fields=["name", "description"])
|
||||
def create(call):
|
||||
assert isinstance(call, APICall)
|
||||
identity = call.identity
|
||||
|
||||
with translate_errors_context():
|
||||
fields = parse_from_call(call.data, create_fields, Project.get_fields())
|
||||
now = datetime.utcnow()
|
||||
project = Project(
|
||||
id=database.utils.id(),
|
||||
user=identity.user,
|
||||
company=identity.company,
|
||||
created=now,
|
||||
last_update=now,
|
||||
**fields
|
||||
)
|
||||
with TimingContext("mongo", "projects_save"):
|
||||
project.save()
|
||||
call.result.data = {"id": project.id}
|
||||
|
||||
|
||||
@endpoint(
|
||||
"projects.update", required_fields=["project"], response_data_model=UpdateResponse
|
||||
)
|
||||
def update(call):
|
||||
"""
|
||||
update
|
||||
|
||||
:summary: Update project information.
|
||||
See `project.create` for parameters.
|
||||
:return: updated - `int` - number of projects updated
|
||||
fields - `[string]` - updated fields
|
||||
"""
|
||||
assert isinstance(call, APICall)
|
||||
project_id = call.data["project"]
|
||||
|
||||
with translate_errors_context():
|
||||
project = Project.get_for_writing(company=call.identity.company, id=project_id)
|
||||
if not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=project_id)
|
||||
|
||||
fields = parse_from_call(
|
||||
call.data, create_fields, Project.get_fields(), discard_none_values=False
|
||||
)
|
||||
fields["last_update"] = datetime.utcnow()
|
||||
with TimingContext("mongo", "projects_update"):
|
||||
updated = project.update(upsert=False, **fields)
|
||||
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
||||
|
||||
|
||||
@endpoint("projects.delete", required_fields=["project"])
|
||||
def delete(call):
|
||||
assert isinstance(call, APICall)
|
||||
project_id = call.data["project"]
|
||||
force = call.data.get("force", False)
|
||||
|
||||
with translate_errors_context():
|
||||
project = Project.get_for_writing(company=call.identity.company, id=project_id)
|
||||
if not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=project_id)
|
||||
|
||||
# NOTE: from this point on we'll use the project ID and won't check for company, since we assume we already
|
||||
# have the correct project ID.
|
||||
|
||||
# Find the tasks which belong to the project
|
||||
for cls, error in (
|
||||
(Task, errors.bad_request.ProjectHasTasks),
|
||||
(Model, errors.bad_request.ProjectHasModels),
|
||||
):
|
||||
res = cls.objects(
|
||||
project=project_id, tags__nin=[TaskVisibility.archived.value]
|
||||
).only("id")
|
||||
if res and not force:
|
||||
raise error("use force=true to delete", id=project_id)
|
||||
|
||||
updated_count = res.update(project=None)
|
||||
|
||||
project.delete()
|
||||
|
||||
call.result.data = {"deleted": 1, "disassociated_tasks": updated_count}
|
||||
|
||||
|
||||
@endpoint("projects.get_unique_metric_variants")
|
||||
def get_unique_metric_variants(call, company_id, req_model):
|
||||
project_id = call.data.get("project")
|
||||
|
||||
metrics = task_bll.get_unique_metric_variants(
|
||||
company_id, [project_id] if project_id else None
|
||||
)
|
||||
|
||||
call.result.data = {"metrics": metrics}
|
||||