Initial commit

allegroai
2019-06-11 00:24:35 +03:00
parent 6eea80c4a2
commit a6344bad57
138 changed files with 15951 additions and 0 deletions

README.md Normal file

@@ -0,0 +1,201 @@
# TRAINS Server
## Magic Version Control & Experiment Manager for AI
## Introduction
The **trains-server** is the infrastructure behind [trains](https://github.com/allegroai/trains).
The server provides:
* UI (single-page webapp) for experiment management and browsing
* REST interface for documenting and logging experiment information, statistics and results
* REST interface for querying experiments history, logs and results
* Locally-hosted fileserver, for storing images and models to be easily accessible from the UI
The server is designed to allow multiple users to collaborate and manage their experiments.
The server's code is freely available [here](https://github.com/allegroai/trains-server).
We've also pre-built a docker image so **trains** users can quickly set up their own server.
## System diagram
<pre>
TRAINS-server
+--------------------------------------------------------------------+
| |
| Server Docker Elastic Docker Mongo Docker |
| +-------------------------+ +---------------+ +------------+ |
| | Pythonic Server | | | | | |
| | +-----------------+ | | ElasticSearch | | MongoDB | |
| | | WEB server | | | | | | |
| | | Port 8080 | | | | | | |
| | +--------+--------+ | | | | | |
| | | | | | | | |
| | +--------+--------+ | | | | | |
| | | API server +----------------------------+ | |
| | | Port 8008 +---------+ | | | |
| | +-----------------+ | +-------+-------+ +-----+------+ |
| | | | | |
| | +-----------------+ | +---+----------------+------+ |
| | | File Server +-------+ | Host Storage | |
| | | Port 8081 | | +-----+ | |
| | +-----------------+ | +---------------------------+ |
| +------------+------------+ |
+---------------|----------------------------------------------------+
|HTTP
+--------+
GPU Machine |
+------------------------|-------------------------------------------+
| +------------------|--------------+ |
| | Training | | +---------------------+ |
| | Code +---+------------+ | | trains configuration| |
| | | TRAINS | | | ~/trains.conf | |
| | | +------+ | |
| | +----------------+ | +---------------------+ |
| +---------------------------------+ |
+--------------------------------------------------------------------+
</pre>
## Installation
To install and run the pre-built **trains-server**, you must be logged in as a user with sudo privileges.
### Setup
To run the pre-packaged **trains-server**, you'll need to install **docker**.
#### Install docker
```bash
sudo apt-get install docker
```
#### Setup docker daemon
In order to run the ElasticSearch docker container, you'll need to change some of the default values in the Docker configuration file.
For systems with an `/etc/sysconfig/docker` file, add the options in quotes to the available arguments in `OPTIONS`:
```bash
OPTIONS="--default-ulimit nofile=1024:65536 --default-ulimit memlock=-1:-1"
```
For systems with an `/etc/docker/daemon.json` file, add the section in curly brackets to `default-ulimits`:
```json
"default-ulimits": {
"nofile": {
"name": "nofile",
"hard": 65536,
"soft": 1024
},
"memlock":
{
"name": "memlock",
"soft": -1,
"hard": -1
}
}
```
Following this configuration change, you will have to restart the docker daemon:
```bash
sudo service docker stop
sudo service docker start
```
#### vm.max_map_count
The `vm.max_map_count` kernel setting must be at least 262144.
The following example was tested with CentOS 7, Ubuntu 16.04, Mint 18.3, Ubuntu 18.04 and Mint 19:
```bash
sudo echo "vm.max_map_count=262144" > /tmp/99-trains.conf
sudo mv /tmp/99-trains.conf /etc/sysctl.d/99-trains.conf
sudo sysctl -w vm.max_map_count=262144
```
For additional information about setting this parameter on other systems, see the [elastic](https://www.elastic.co/guide/en/elasticsearch/reference/current/docker.html#docker-cli-run-prod-mode) documentation.
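After applying the commands above, you can confirm the setting took effect (a quick check on the host, assuming a Linux system with `sysctl` available):

```shell
# should report the value written to /etc/sysctl.d/99-trains.conf
sysctl vm.max_map_count
```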
#### Choose a data folder
You will need to choose a directory on your system in which all data maintained by **trains-server** will be stored (among other things, this includes the database, uploaded files and logs).
The following instructions assume the directory is `/opt/trains`.
Issue the following commands:
```bash
sudo mkdir -p /opt/trains/data/elastic && sudo chown -R 1000:1000 /opt/trains
```
### Launching docker images
To launch the docker images, issue the following commands:
```bash
sudo docker run -d --restart="always" --name="trains-elastic" -e "ES_JAVA_OPTS=-Xms2g -Xmx2g" -e "bootstrap.memory_lock=true" -e "cluster.name=trains" -e "discovery.zen.minimum_master_nodes=1" -e "node.name=trains" -e "script.inline=true" -e "script.update=true" -e "thread_pool.bulk.queue_size=2000" -e "thread_pool.search.queue_size=10000" -e "xpack.security.enabled=false" -e "xpack.monitoring.enabled=false" -e "cluster.routing.allocation.node_initial_primaries_recoveries=500" -e "node.ingest=true" -e "http.compression_level=7" -e "reindex.remote.whitelist=*.*" -e "script.painless.regex.enabled=true" --network="host" -v /opt/trains/data/elastic:/usr/share/elasticsearch/data docker.elastic.co/elasticsearch/elasticsearch:5.6.16
```
```bash
sudo docker run -d --restart="always" --name="trains-mongo" -v /opt/trains/data/mongo/db:/data/db -v /opt/trains/data/mongo/configdb:/data/configdb --network="host" mongo:3.6.5
```
```bash
sudo docker run -d --restart="always" --name="trains-fileserver" --network="host" -v /opt/trains/logs:/var/log/trains -v /opt/trains/data/fileserver:/mnt/fileserver allegroai/trains:latest fileserver
```
```bash
sudo docker run -d --restart="always" --name="trains-apiserver" --network="host" -v /opt/trains/logs:/var/log/trains allegroai/trains:latest apiserver
```
```bash
sudo docker run -d --restart="always" --name="trains-webserver" --network="host" -v /opt/trains/logs:/var/log/trains allegroai/trains:latest webserver
```
Once the **trains-server** dockers are up, the following are available:
* API server on port `8008`
* Web server on port `8080`
* File server on port `8081`
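As a quick sanity check that everything came up (a sketch, assuming the dockers were launched with the commands above and you are on the host machine):

```shell
# list the trains containers and their status
sudo docker ps --filter "name=trains-" --format "{{.Names}}: {{.Status}}"

# the web UI should now answer on port 8080
curl -sI http://localhost:8080 | head -n 1
```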
## Upgrade
We are constantly updating and adding features.
When we release a new version, we'll include a new pre-built docker image.
Once a new release is out, you can simply:
1. Shut down and remove your docker instances. Each instance can be shut down and removed using the following commands:
```bash
sudo docker stop <docker-name>
sudo docker rm -v <docker-name>
```
The docker names are (see [Launching docker images](#Launching-docker-images)):
* `trains-elastic`
* `trains-mongo`
* `trains-fileserver`
* `trains-apiserver`
* `trains-webserver`
2. Back up your data folder (recommended!). A simple way to do that is using this command:
```bash
sudo tar czvf ~/trains_backup.tgz /opt/trains/data
```
This backs up all data to an archive in your home folder. Since the archive stores paths relative to the filesystem root, restore such a backup by extracting at `/`:
```bash
sudo rm -R /opt/trains/data
sudo tar -xzf ~/trains_backup.tgz -C /
```
3. Launch the newly released docker image (see [Launching docker images](#Launching-docker-images))
## License
[Server Side Public License v1.0](https://github.com/mongodb/mongo/blob/master/LICENSE-Community.txt)
**trains-server** relies *heavily* on both [MongoDB](https://github.com/mongodb/mongo) and [ElasticSearch](https://github.com/elastic/elasticsearch).
With the recent changes in both MongoDB's and ElasticSearch's OSS license, we feel it is our job as a community to support the projects we love and cherish.
We feel the cause for the license change in both cases is more than just, and chose [SSPL](https://www.mongodb.com/licensing/server-side-public-license) because it is the more restrictive of the two.
This is our way to say - we support you guys!

fileserver/fileserver.py Normal file

@@ -0,0 +1,58 @@
""" A Simple file server for uploading and downloading files """
import json
import logging.config
import os
from argparse import ArgumentParser
from pathlib import Path

from flask import Flask, request, send_from_directory, safe_join
from pyhocon import ConfigFactory

logging.config.dictConfig(ConfigFactory.parse_file("logging.conf"))

app = Flask(__name__)


@app.route("/", methods=["POST"])
def upload():
    results = []
    for filename, file in request.files.items():
        if not filename:
            continue
        file_path = filename.lstrip(os.sep)
        target = Path(safe_join(app.config["UPLOAD_FOLDER"], file_path))
        target.parent.mkdir(parents=True, exist_ok=True)
        file.save(str(target))
        results.append(file_path)
    return (json.dumps(results), 200)


@app.route("/<path:path>", methods=["GET"])
def download(path):
    return send_from_directory(app.config["UPLOAD_FOLDER"], path)


def main():
    parser = ArgumentParser(description=__doc__)
    parser.add_argument(
        "--port", "-p", type=int, default=8081, help="Port (default %(default)d)"
    )
    parser.add_argument(
        "--ip", "-i", type=str, default="0.0.0.0", help="Address (default %(default)s)"
    )
    parser.add_argument("--debug", action="store_true", default=False)
    parser.add_argument(
        "--upload-folder",
        "-u",
        default="/mnt/fileserver",
        help="Upload folder (default %(default)s)",
    )
    args = parser.parse_args()
    app.config["UPLOAD_FOLDER"] = args.upload_folder
    app.run(debug=args.debug, host=args.ip, port=args.port, threaded=True)


if __name__ == "__main__":
    main()
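A usage sketch (assuming the file server is running locally on the default port 8081; `plot.png` and the `data/` prefix are hypothetical). Note that the multipart field name, not the local filename, becomes the stored path — matching the `request.files.items()` loop above:

```shell
# upload: the form field name ("data/plot.png") is used as the stored path
curl -X POST -F "data/plot.png=@plot.png" http://localhost:8081/

# download the same file back via the GET route
curl -O http://localhost:8081/data/plot.png
```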

fileserver/logging.conf Normal file

@@ -0,0 +1,38 @@
{
    version: 1
    disable_existing_loggers: false
    formatters: {
        default: {
            format: "[%(asctime)s] [%(process)d] [%(levelname)s] [%(name)s] %(message)s"
        }
    }
    handlers {
        console {
            formatter: default
            class: "logging.StreamHandler"
        }
        text_file: {
            formatter: default,
            backupCount: 3,
            maxBytes: 10240000,
            class: "logging.handlers.RotatingFileHandler",
            filename: "/var/log/trains/fileserver.log"
        }
    }
    root {
        handlers: [console, text_file]
        level: INFO
    }
    loggers {
        urllib3 {
            handlers: [console, text_file]
            level: WARN
            propagate: false
        }
        werkzeug {
            handlers: [console, text_file]
            level: WARN
            propagate: false
        }
    }
}


@@ -0,0 +1 @@
Flask

server/apierrors/__init__.py Normal file

@@ -0,0 +1,116 @@
import pathlib

from . import autogen
from .apierror import APIError

""" Error codes """

_error_codes = {
    (400, 'bad_request'): {
        1: ('not_supported', 'endpoint is not supported'),
        2: ('request_path_has_invalid_version', 'request path has invalid version'),
        5: ('invalid_headers', 'invalid headers'),
        6: ('impersonation_error', 'impersonation error'),
        10: ('invalid_id', 'invalid object id'),
        11: ('missing_required_fields', 'missing required fields'),
        12: ('validation_error', 'validation error'),
        13: ('fields_not_allowed_for_role', 'fields not allowed for role'),
        14: ('invalid_fields', 'fields not defined for object'),
        15: ('fields_conflict', 'conflicting fields'),
        16: ('fields_value_error', 'invalid value for fields'),
        17: ('batch_contains_no_items', 'batch request contains no items'),
        18: ('batch_validation_error', 'batch request validation error'),
        19: ('invalid_lucene_syntax', 'malformed lucene query'),
        20: ('fields_type_error', 'invalid type for fields'),
        21: ('invalid_regex_error', 'malformed regular expression'),
        22: ('invalid_email_address', 'malformed email address'),
        23: ('invalid_domain_name', 'malformed domain name'),
        24: ('not_public_object', 'object is not public'),

        # Tasks
        100: ('task_error', 'general task error'),
        101: ('invalid_task_id', 'invalid task id'),
        102: ('task_validation_error', 'task validation error'),
        110: ('invalid_task_status', 'invalid task status'),
        111: ('task_not_started', 'task not started (invalid task status)'),
        112: ('task_in_progress', 'task in progress (invalid task status)'),
        113: ('task_published', 'task published (invalid task status)'),
        114: ('task_status_unknown', 'task unknown (invalid task status)'),
        120: ('invalid_task_execution_progress', 'invalid task execution progress'),
        121: ('failed_changing_task_status', 'failed changing task status. probably someone changed it before you'),
        122: ('missing_task_fields', 'task is missing expected fields'),
        123: ('task_cannot_be_deleted', 'task cannot be deleted'),
        125: ('task_has_jobs_running', "task has jobs that haven't completed yet"),
        126: ('invalid_task_type', 'invalid task type for this operation'),
        127: ('invalid_task_input', 'invalid task input'),
        128: ('invalid_task_output', 'invalid task output'),
        129: ('task_publish_in_progress', 'task publish in progress'),
        130: ('task_not_found', 'task not found'),

        # Models
        200: ('model_error', 'general model error'),
        201: ('invalid_model_id', 'invalid model id'),
        202: ('model_not_ready', 'model is not ready'),
        203: ('model_is_ready', 'model is ready'),
        204: ('invalid_model_uri', 'invalid model URI'),
        205: ('model_in_use', 'model is used by tasks'),
        206: ('model_creating_task_exists', 'task that created this model exists'),

        # Users
        300: ('invalid_user', 'invalid user'),
        301: ('invalid_user_id', 'invalid user id'),
        302: ('user_id_exists', 'user id already exists'),
        305: ('invalid_preferences_update', 'malformed key and/or value'),

        # Projects
        401: ('invalid_project_id', 'invalid project id'),
        402: ('project_has_tasks', 'project has associated tasks'),
        403: ('project_not_found', 'project not found'),
        405: ('project_has_models', 'project has associated models'),

        # Database
        800: ('data_validation_error', 'data validation error'),
        801: ('expected_unique_data', 'value combination already exists'),
    },
    (401, 'unauthorized'): {
        1: ('not_authorized', 'unauthorized (not authorized for endpoint)'),
        2: ('entity_not_allowed', 'unauthorized (entity not allowed)'),
        10: ('bad_auth_type', 'unauthorized (bad authentication header type)'),
        20: ('no_credentials', 'unauthorized (missing credentials)'),
        21: ('bad_credentials', 'unauthorized (malformed credentials)'),
        22: ('invalid_credentials', 'unauthorized (invalid credentials)'),
        30: ('invalid_token', 'invalid token'),
        31: ('blocked_token', 'token is blocked'),
    },
    (403, 'forbidden'): {
        10: ('routing_error', 'forbidden (routing error)'),
        11: ('missing_routing_header', 'forbidden (missing routing header)'),
        12: ('blocked_internal_endpoint', 'forbidden (blocked internal endpoint)'),
        20: ('role_not_allowed', 'forbidden (not allowed for role)'),
        21: ('no_write_permission', 'forbidden (modification not allowed)'),
    },
    (500, 'server_error'): {
        0: ('general_error', 'general server error'),
        1: ('internal_error', 'internal server error'),
        2: ('config_error', 'configuration error'),
        3: ('build_info_error', 'build info unavailable or corrupted'),
        10: ('transaction_error', 'a transaction call has returned with an error'),

        # Database-related issues
        100: ('data_error', 'general data error'),
        101: ('inconsistent_data', 'inconsistent data encountered in document'),
        102: ('database_unavailable', 'database is temporarily unavailable'),

        # Index-related issues
        201: ('missing_index', 'missing internal index'),

        9999: ('not_implemented', 'action is not yet implemented'),
    }
}

autogen.generate(pathlib.Path(__file__).parent, _error_codes)
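The table above maps a (HTTP status, section) pair to numeric subcodes, each carrying a machine name and a human-readable message. A minimal sketch of how such a table can be resolved (the `lookup` helper is hypothetical, for illustration only — the server instead generates one class per entry):

```python
# minimal excerpt of the table: (HTTP code, section) -> {subcode: (name, message)}
error_codes = {
    (400, 'bad_request'): {
        101: ('invalid_task_id', 'invalid task id'),
    },
    (500, 'server_error'): {
        0: ('general_error', 'general server error'),
    },
}


def lookup(code, subcode):
    # resolve a (code, subcode) pair to its section and error name
    for (c, section), subcodes in error_codes.items():
        if c == code and subcode in subcodes:
            return section, subcodes[subcode][0]
    return None


print(lookup(400, 101))  # ('bad_request', 'invalid_task_id')
```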

server/apierrors/apierror.py Normal file

@@ -0,0 +1,21 @@
class APIError(Exception):
    def __init__(self, msg, code=500, subcode=0, **_):
        super(APIError, self).__init__()
        self._msg = msg
        self._code = code
        self._subcode = subcode

    @property
    def msg(self):
        return self._msg

    @property
    def code(self):
        return self._code

    @property
    def subcode(self):
        return self._subcode

    def __str__(self):
        return self.msg

server/apierrors/autogen/__init__.py Normal file

@@ -0,0 +1,4 @@
def generate(path, error_codes):
    from .generator import Generator
    from pathlib import Path

    Generator(Path(path) / 'errors', format_pep8=False).make_errors(error_codes)

server/apierrors/autogen/__main__.py Normal file

@@ -0,0 +1,6 @@
if __name__ == '__main__':
    from pathlib import Path

    from apierrors import _error_codes
    from apierrors.autogen import generate

    generate(Path(__file__).parent.parent, _error_codes)

server/apierrors/autogen/generator.py Normal file

@@ -0,0 +1,85 @@
import re
import json
import hashlib
from pathlib import Path

import autopep8  # used when format_pep8=True
import jinja2

env = jinja2.Environment(
    loader=jinja2.FileSystemLoader(str(Path(__file__).parent)),
    autoescape=jinja2.select_autoescape(disabled_extensions=('py',), default_for_string=False),
    trim_blocks=True,
    lstrip_blocks=True)


def env_filter(name=None):
    return lambda func: env.filters.setdefault(name or func.__name__, func)


@env_filter()
def cls_name(name):
    delims = list(map(re.escape, (' ', '_')))
    parts = re.split('|'.join(delims), name)
    return ''.join(x.capitalize() for x in parts)


class Generator(object):
    _base_class_name = 'BaseError'
    _base_class_module = 'apierrors.base'

    def __init__(self, path, format_pep8=True, use_md5=True):
        self._use_md5 = use_md5
        self._format_pep8 = format_pep8
        self._path = Path(path)
        self._path.mkdir(parents=True, exist_ok=True)

    def _make_init_file(self, path):
        (self._path / path / '__init__.py').write_bytes(b'')

    def _do_render(self, file, template, context):
        with file.open('w') as f:
            result = template.render(
                base_class_name=self._base_class_name,
                base_class_module=self._base_class_module,
                **context)
            if self._format_pep8:
                result = autopep8.fix_code(result, options={'aggressive': 1, 'verbose': 0, 'max_line_length': 120})
            f.write(result)

    def _make_section(self, name, code, subcodes):
        self._do_render(
            file=(self._path / name).with_suffix('.py'),
            template=env.get_template('templates/section.jinja2'),
            context=dict(code=code, subcodes=list(subcodes.items()),))

    def _make_init(self, sections):
        self._do_render(
            file=(self._path / '__init__.py'),
            template=env.get_template('templates/init.jinja2'),
            context=dict(sections=sections,))

    def _key_to_str(self, data):
        if isinstance(data, dict):
            return {str(k): self._key_to_str(v) for k, v in data.items()}
        return data

    def _calc_digest(self, data):
        data = json.dumps(self._key_to_str(data), sort_keys=True)
        return hashlib.md5(data.encode('utf8')).hexdigest()

    def make_errors(self, errors):
        digest = None
        digest_file = self._path / 'digest.md5'
        if self._use_md5:
            digest = self._calc_digest(errors)
            if digest_file.is_file():
                if digest_file.read_text() == digest:
                    return
        self._make_init(errors)
        for (code, section_name), subcodes in errors.items():
            self._make_section(section_name, code, subcodes)
        if self._use_md5:
            digest_file.write_text(digest)
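The `cls_name` filter above is what turns an error-code name into the generated class name; a standalone sketch of the same logic:

```python
import re


def cls_name(name):
    # split on spaces and underscores, then CamelCase the parts
    parts = re.split('|'.join(map(re.escape, (' ', '_'))), name)
    return ''.join(x.capitalize() for x in parts)


print(cls_name('invalid_task_id'))  # InvalidTaskId
```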

server/apierrors/autogen/templates/error.jinja2 Normal file

@@ -0,0 +1,6 @@
{% macro error_class(name, msg, code, subcode=0) %}
class {{ name }}({{ base_class_name }}):
    _default_code = {{ code }}
    _default_subcode = {{ subcode }}
    _default_msg = "{{ msg|capitalize }}"
{% endmacro -%}

server/apierrors/autogen/templates/init.jinja2 Normal file

@@ -0,0 +1,14 @@
{% from 'templates/error.jinja2' import error_class with context %}
{% if sections %}
from {{ base_class_module }} import {{ base_class_name }}
{% endif %}
{% for _, name in sections %}
from . import {{ name }}
{% endfor %}
{% for code, name in sections %}
{{ error_class(name|cls_name, name|replace('_', ' '), code) }}
{% endfor %}

server/apierrors/autogen/templates/section.jinja2 Normal file

@@ -0,0 +1,9 @@
{% from 'templates/error.jinja2' import error_class with context %}
{% if subcodes %}
from {{ base_class_module }} import {{ base_class_name }}
{% endif %}
{% for subcode, (name, msg) in subcodes %}
{{ error_class(name|cls_name, msg, code, subcode) -}}
{% endfor %}

server/apierrors/base.py Normal file

@@ -0,0 +1,38 @@
import six
from boltons.typeutils import classproperty
from typing import Tuple

from .apierror import APIError


class BaseError(APIError):
    _default_code = 500
    _default_subcode = 0
    _default_msg = ""

    def __init__(self, extra_msg=None, replacement_msg=None, **kwargs):
        message = replacement_msg or self._default_msg
        if extra_msg:
            message += f" ({extra_msg})"
        if kwargs:
            kwargs_msg = ", ".join(
                f"{k}={self._format_kwarg(v)}" for k, v in kwargs.items()
            )
            message += f": {kwargs_msg}"
        params = kwargs.copy()
        params.update(
            code=self._default_code, subcode=self._default_subcode, msg=message
        )
        super(BaseError, self).__init__(**params)

    @staticmethod
    def _format_kwarg(value):
        if isinstance(value, (tuple, list)):
            return f'({", ".join(str(v) for v in value)})'
        elif isinstance(value, six.string_types):
            return value
        return str(value)

    @classproperty
    def codes(self) -> Tuple[int, int]:
        return self._default_code, self._default_subcode
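A minimal, self-contained sketch of how `BaseError` composes its message (mirroring the logic above; `InvalidTaskId` here is a stand-in for one of the generated subclasses):

```python
class BaseError(Exception):
    _default_code = 500
    _default_subcode = 0
    _default_msg = ""

    def __init__(self, extra_msg=None, replacement_msg=None, **kwargs):
        # start from the class default (or a full replacement), then append
        # any extra message and keyword details, as the real BaseError does
        message = replacement_msg or self._default_msg
        if extra_msg:
            message += f" ({extra_msg})"
        if kwargs:
            message += ": " + ", ".join(f"{k}={v}" for k, v in kwargs.items())
        super().__init__(message)
        self.msg = message


class InvalidTaskId(BaseError):
    _default_code = 400
    _default_subcode = 101
    _default_msg = "Invalid task id"


print(InvalidTaskId(task="abc123").msg)  # Invalid task id: task=abc123
```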

server/apimodels/__init__.py Normal file

@@ -0,0 +1,162 @@
from __future__ import absolute_import

from enum import Enum
from typing import Union, Type, Iterable

import jsonmodels.errors
import jsonmodels.validators
import six
import validators
from jsonmodels import fields
from jsonmodels.fields import _LazyType
from jsonmodels.models import Base as ModelBase
from jsonmodels.validators import Enum as EnumValidator
from luqum.parser import parser, ParseError

from apierrors import errors


def make_default(field_cls, default_value):
    class _FieldWithDefault(field_cls):
        def get_default_value(self):
            return default_value

    return _FieldWithDefault


class ListField(fields.ListField):
    def _cast_value(self, value):
        try:
            return super(ListField, self)._cast_value(value)
        except TypeError:
            return value

    def validate_single_value(self, item):
        super(ListField, self).validate_single_value(item)
        if isinstance(item, ModelBase):
            item.validate()


class DictField(fields.BaseField):
    types = (dict,)

    def __init__(self, value_types=None, *args, **kwargs):
        self.value_types = self._assign_types(value_types)
        super(DictField, self).__init__(*args, **kwargs)

    def get_default_value(self):
        default = super(DictField, self).get_default_value()
        if default is None and not self.required:
            return {}
        return default

    @staticmethod
    def _assign_types(value_types):
        if value_types:
            try:
                value_types = tuple(value_types)
            except TypeError:
                value_types = (value_types,)
        else:
            value_types = tuple()
        return tuple(
            _LazyType(type_)
            if isinstance(type_, six.string_types)
            else type_
            for type_ in value_types
        )

    def validate(self, value):
        super(DictField, self).validate(value)
        if not self.value_types:
            return
        for item in value.values():
            self.validate_single_value(item)

    def validate_single_value(self, item):
        if not self.value_types:
            return
        if not isinstance(item, self.value_types):
            raise jsonmodels.errors.ValidationError(
                "All items must be instances "
                'of "{types}", and not "{type}".'.format(
                    types=", ".join([t.__name__ for t in self.value_types]),
                    type=type(item).__name__,
                )
            )


class IntField(fields.IntField):
    def parse_value(self, value):
        try:
            return super(IntField, self).parse_value(value)
        except (ValueError, TypeError):
            return value


def validate_lucene_query(value):
    if value == '':
        return
    try:
        parser.parse(value)
    except ParseError as e:
        raise errors.bad_request.InvalidLuceneSyntax(error=e)


class LuceneQueryField(fields.StringField):
    def validate(self, value):
        super(LuceneQueryField, self).validate(value)
        if value is None:
            return
        validate_lucene_query(value)


class NullableEnumValidator(EnumValidator):
    """Validator for enums that allows a None value."""

    def validate(self, value):
        if value is not None:
            super(NullableEnumValidator, self).validate(value)


class EnumField(fields.StringField):
    def __init__(
        self,
        values_or_type: Union[Iterable, Type[Enum]],
        *args,
        required=False,
        default=None,
        **kwargs
    ):
        choices = list(map(self.parse_value, values_or_type))
        validator_cls = EnumValidator if required else NullableEnumValidator
        kwargs.setdefault("validators", []).append(validator_cls(*choices))
        super().__init__(
            default=self.parse_value(default), required=required, *args, **kwargs
        )

    def parse_value(self, value):
        if isinstance(value, Enum):
            return str(value.value)
        return super().parse_value(value)


class EmailField(fields.StringField):
    def validate(self, value):
        super().validate(value)
        if value is None:
            return
        if validators.email(value) is not True:
            raise errors.bad_request.InvalidEmailAddress()


class DomainField(fields.StringField):
    def validate(self, value):
        super().validate(value)
        if value is None:
            return
        if validators.domain(value) is not True:
            raise errors.bad_request.InvalidDomainName()
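The `make_default` helper above subclasses a field class on the fly so instances report a fixed default value. A stdlib-only sketch with a stand-in `StringField` (the real field classes come from `jsonmodels`):

```python
def make_default(field_cls, default_value):
    # create a throwaway subclass overriding only the default
    class _FieldWithDefault(field_cls):
        def get_default_value(self):
            return default_value

    return _FieldWithDefault


class StringField:
    # stand-in for jsonmodels.fields.StringField
    def get_default_value(self):
        return None


DefaultedString = make_default(StringField, "hello")
print(DefaultedString().get_default_value())  # hello
```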

server/apimodels/auth.py Normal file

@@ -0,0 +1,116 @@
from jsonmodels.fields import IntField, StringField, BoolField, EmbeddedField
from jsonmodels.models import Base
from jsonmodels.validators import Max, Enum

from apimodels import ListField, EnumField
from config import config
from database.model.auth import Role
from database.utils import get_options


class GetTokenRequest(Base):
    """ User requests a token """

    expiration_sec = IntField(
        validators=Max(config.get("apiserver.auth.max_expiration_sec")), nullable=True
    )
    """ Expiration time for token in seconds. """


class GetTaskTokenRequest(GetTokenRequest):
    """ User requests a task token """

    task = StringField(required=True)


class GetTokenForUserRequest(GetTokenRequest):
    """ System requests a token for a user """

    user = StringField(required=True)
    company = StringField()


class GetTaskTokenForUserRequest(GetTokenForUserRequest):
    """ System requests a token for a user, for a specific task """

    task = StringField(required=True)


class GetTokenResponse(Base):
    token = StringField(required=True)


class ValidateTokenRequest(Base):
    token = StringField(required=True)


class ValidateUserRequest(Base):
    email = StringField(required=True)


class ValidateResponse(Base):
    valid = BoolField(required=True)
    msg = StringField()
    user = StringField()
    company = StringField()


class CreateUserRequest(Base):
    name = StringField(required=True)
    company = StringField(required=True)
    role = StringField(
        validators=Enum(*(set(get_options(Role)))),
        default=Role.user,
    )
    email = StringField(required=True)
    family_name = StringField()
    given_name = StringField()
    avatar = StringField()


class CreateUserResponse(Base):
    id = StringField(required=True)


class Credentials(Base):
    access_key = StringField(required=True)
    secret_key = StringField(required=True)


class CredentialsResponse(Credentials):
    secret_key = StringField()


class CreateCredentialsResponse(Base):
    credentials = EmbeddedField(Credentials)


class GetCredentialsResponse(Base):
    credentials = ListField(CredentialsResponse)


class RevokeCredentialsRequest(Base):
    access_key = StringField(required=True)


class RevokeCredentialsResponse(Base):
    revoked = IntField(required=True)


class AddUserRequest(CreateUserRequest):
    company = StringField()
    secret_key = StringField()


class AddUserResponse(CreateUserResponse):
    secret = StringField()


class DeleteUserRequest(Base):
    user = StringField(required=True)
    company = StringField()


class EditUserReq(Base):
    user = StringField(required=True)
    role = EnumField(Role.get_company_roles())

server/apimodels/base.py Normal file

@@ -0,0 +1,60 @@
from jsonmodels import models, fields
from mongoengine.base import BaseDocument

from apimodels import DictField


class MongoengineFieldsDict(DictField):
    """
    DictField representing mongoengine field names/value mapping.
    Used to convert mongoengine-style field/subfield notation to user-presentable
    syntax, including handling update operators.
    """

    mongoengine_update_operators = (
        'inc',
        'dec',
        'push',
        'push_all',
        'pop',
        'pull',
        'pull_all',
        'add_to_set',
    )

    @staticmethod
    def _normalize_mongo_value(value):
        if isinstance(value, BaseDocument):
            return value.to_mongo()
        return value

    @classmethod
    def _normalize_mongo_field_path(cls, path, value):
        parts = path.split('__')
        if len(parts) > 1:
            if parts[0] == 'set':
                parts = parts[1:]
            elif parts[0] == 'unset':
                parts = parts[1:]
                value = None
            elif parts[0] in cls.mongoengine_update_operators:
                return None, None
        return '.'.join(parts), cls._normalize_mongo_value(value)

    def parse_value(self, value):
        value = super(MongoengineFieldsDict, self).parse_value(value)
        return {
            k: v
            for k, v in (self._normalize_mongo_field_path(*p) for p in value.items())
            if k is not None
        }


class UpdateResponse(models.Base):
    updated = fields.IntField(required=True)
    fields = MongoengineFieldsDict()


class PagedRequest(models.Base):
    page = fields.IntField()
    page_size = fields.IntField()
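The path normalization above can be illustrated standalone (a stdlib-only sketch of `_normalize_mongo_field_path`, without the mongoengine value handling): `set__`/`unset__` prefixes are stripped, other update operators are dropped entirely, and `__` separators become dots.

```python
OPERATORS = ('inc', 'dec', 'push', 'push_all', 'pop', 'pull', 'pull_all', 'add_to_set')


def normalize(path, value):
    # convert mongoengine "op__field__subfield" notation to "field.subfield"
    parts = path.split('__')
    if len(parts) > 1:
        if parts[0] == 'set':
            parts = parts[1:]
        elif parts[0] == 'unset':
            parts = parts[1:]
            value = None  # unset means the field loses its value
        elif parts[0] in OPERATORS:
            return None, None  # other update operators are filtered out
    return '.'.join(parts), value


print(normalize('set__execution__parameters', 5))  # ('execution.parameters', 5)
```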


@@ -0,0 +1,43 @@
from jsonmodels import models, fields
from six import string_types

from apimodels import ListField, DictField
from apimodels.base import UpdateResponse
from apimodels.tasks import PublishResponse as TaskPublishResponse


class CreateModelRequest(models.Base):
    name = fields.StringField(required=True)
    uri = fields.StringField(required=True)
    labels = DictField(value_types=string_types + (int,), required=True)
    tags = ListField(items_types=string_types)
    comment = fields.StringField()
    public = fields.BoolField(default=False)
    project = fields.StringField()
    parent = fields.StringField()
    framework = fields.StringField()
    design = DictField()
    ready = fields.BoolField(default=True)
    ui_cache = DictField()
    task = fields.StringField()


class CreateModelResponse(models.Base):
    id = fields.StringField(required=True)
    created = fields.BoolField(required=True)


class PublishModelRequest(models.Base):
    model = fields.StringField(required=True)
    force_publish_task = fields.BoolField(default=False)
    publish_task = fields.BoolField(default=True)


class ModelTaskPublishResponse(models.Base):
    id = fields.StringField(required=True)
    data = fields.EmbeddedField(TaskPublishResponse)


class PublishModelResponse(UpdateResponse):
    published_task = fields.EmbeddedField(ModelTaskPublishResponse)
    updated = fields.IntField()

server/apimodels/tasks.py Normal file

@@ -0,0 +1,57 @@
import six
from jsonmodels import models
from jsonmodels.fields import StringField, BoolField, IntField
from jsonmodels.validators import Enum

from apimodels import DictField, ListField
from apimodels.base import UpdateResponse
from database.model.task.task import TaskType
from database.utils import get_options


class StartedResponse(UpdateResponse):
    started = IntField()


class ResetResponse(UpdateResponse):
    deleted_indices = ListField(items_types=six.string_types)
    frames = DictField()
    events = DictField()
    model_deleted = IntField()


class TaskRequest(models.Base):
    task = StringField(required=True)


class UpdateRequest(TaskRequest):
    status_reason = StringField(default="")
    status_message = StringField(default="")
    force = BoolField(default=False)


class DeleteRequest(UpdateRequest):
    move_to_trash = BoolField(default=True)


class SetRequirementsRequest(TaskRequest):
    requirements = DictField(required=True)


class PublishRequest(UpdateRequest):
    publish_model = BoolField(default=True)


class PublishResponse(UpdateResponse):
    pass


class TaskData(models.Base):
    """
    A partial description of a task that can be updated incrementally
    """


class CreateRequest(TaskData):
    name = StringField(required=True)
    type = StringField(required=True, validators=Enum(*get_options(TaskType)))

server/apimodels/users.py Normal file

@@ -0,0 +1,17 @@
from jsonmodels.fields import StringField
from jsonmodels.models import Base

from apimodels import DictField


class CreateRequest(Base):
    id = StringField(required=True)
    name = StringField(required=True)
    company = StringField(required=True)
    family_name = StringField()
    given_name = StringField()
    avatar = StringField()


class SetPreferencesRequest(Base):
    preferences = DictField(required=True)

server/bll/__init__.py Normal file

server/bll/auth/__init__.py Normal file

@@ -0,0 +1,167 @@
from datetime import datetime

import database
from apierrors import errors
from apimodels.auth import (
    GetTokenResponse,
    CreateUserRequest,
    Credentials as CredModel,
)
from apimodels.users import CreateRequest as Users_CreateRequest
from bll.user import UserBLL
from config import config
from database.errors import translate_errors_context
from database.model.auth import User, Role, Credentials
from database.model.company import Company
from service_repo import APICall
from service_repo.auth import (
    Identity,
    Token,
    get_client_id,
    get_secret_key,
)

log = config.logger("AuthBLL")


class AuthBLL:
    @staticmethod
    def get_token_for_user(
        user_id: str,
        company_id: str = None,
        expiration_sec: int = None,
        entities: dict = None,
    ) -> GetTokenResponse:
        with translate_errors_context():
            query = dict(id=user_id)
            if company_id:
                query.update(company=company_id)
            user = User.objects(**query).first()
            if not user:
                raise errors.bad_request.InvalidUserId(**query)
            company_id = company_id or user.company
            company = Company.objects(id=company_id).only("id", "name").first()
            if not company:
                raise errors.bad_request.InvalidId(
                    "invalid company associated with user", company=company
                )
            identity = Identity(
                user=user_id,
                company=company_id,
                role=user.role,
                user_name=user.name,
                company_name=company.name,
            )
            token = Token.create_encoded_token(
                identity=identity,
                entities=entities,
                expiration_sec=expiration_sec,
            )
            return GetTokenResponse(token=token.decode("ascii"))

    @staticmethod
    def create_user(request: CreateUserRequest, call: APICall = None) -> str:
        """
        Create a new user in both the auth database and the backend database
        :param request: New user details
        :param call: API call that triggered this call. If not None, new backend user
            creation will be performed using a new call in the same transaction.
        :return: The new user's ID
        """
        with translate_errors_context():
            if not Company.objects(id=request.company).only("id"):
                raise errors.bad_request.InvalidId(company=request.company)
            user = User(
                id=database.utils.id(),
                name=request.name,
                company=request.company,
                role=request.role or Role.user,
                email=request.email,
                created=datetime.utcnow(),
            )
            user.save()
            users_create_request = Users_CreateRequest(
                id=user.id,
                name=request.name,
                company=request.company,
                family_name=request.family_name,
                given_name=request.given_name,
                avatar=request.avatar,
            )
            try:
                UserBLL.create(users_create_request)
            except Exception as ex:
                user.delete()
                raise errors.server_error.GeneralError(
                    "failed adding new user", ex=str(ex)
                )
            return user.id

    @staticmethod
    def delete_user(
        identity: Identity, user_id: str, company_id: str = None, call: APICall = None
    ):
        """
        Delete an existing user from both the auth database and the backend database
        :param identity: Calling user identity
        :param user_id: ID of user to delete
        :param company_id: Company of user to delete
        :param call: API call that triggered this call. If not None, backend user
            deletion will be performed using a new call in the same transaction.
        """
        if user_id == identity.user:
            raise errors.bad_request.FieldsValueError(
                "cannot delete yourself", user=user_id
            )
        if not company_id:
            company_id = identity.company
        if (
            identity.role not in Role.get_system_roles()
            and company_id != identity.company
        ):
            raise errors.bad_request.FieldsNotAllowedForRole(
"must be empty or your own company", role=identity.role, field="company"
)
with translate_errors_context():
query = dict(id=user_id, company=company_id)
res = User.objects(**query).delete()
if not res:
raise errors.bad_request.InvalidUserId(**query)
try:
UserBLL.delete(user_id)
except Exception as ex:
log.error(f"Exception calling users.delete: {str(ex)}")
@classmethod
def create_credentials(
cls, user_id: str, company_id: str, role: str = None
) -> CredModel:
with translate_errors_context():
query = dict(id=user_id, company=company_id)
user = User.objects(**query).first()
if not user:
raise errors.bad_request.InvalidUserId(**query)
cred = CredModel(access_key=get_client_id(), secret_key=get_secret_key())
user.credentials.append(
Credentials(key=cred.access_key, secret=cred.secret_key)
)
user.save()
return cred
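The `get_client_id` / `get_secret_key` helpers used by `create_credentials` above are imported from `service_repo.auth` and are not part of this diff; a minimal self-contained sketch of the access-key/secret-key generation pattern (function names and key lengths here are assumptions, not the actual implementation) might look like:

```python
import secrets
import string

def make_client_id(length: int = 20) -> str:
    # Access keys are typically short, uppercase alphanumeric identifiers
    alphabet = string.ascii_uppercase + string.digits
    return "".join(secrets.choice(alphabet) for _ in range(length))

def make_secret_key(length: int = 40) -> str:
    # Secret keys need more entropy; use the URL-safe token alphabet
    return secrets.token_urlsafe(length)[:length]

# Generate a credentials pair, analogous to what create_credentials() stores
access_key = make_client_id()
secret_key = make_secret_key()
```

Using `secrets` rather than `random` matters here: credentials must come from a cryptographically secure source.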


@@ -0,0 +1 @@
from .event_bll import EventBLL


@@ -0,0 +1,700 @@
from collections import defaultdict
from contextlib import closing
from datetime import datetime
from operator import attrgetter
import attr
import six
from elasticsearch import helpers
from enum import Enum
from mongoengine import Q
from nested_dict import nested_dict
import database.utils as dbutils
import es_factory
from apierrors import errors
from bll.task import TaskBLL
from database.errors import translate_errors_context
from database.model.task.task import Task
from database.model.task.metrics import MetricEvent
from timing_context import TimingContext
class EventType(Enum):
metrics_scalar = "training_stats_scalar"
metrics_vector = "training_stats_vector"
metrics_image = "training_debug_image"
metrics_plot = "plot"
task_log = "log"
# noinspection PyTypeChecker
EVENT_TYPES = set(map(attrgetter("value"), EventType))
@attr.s
class TaskEventsResult(object):
events = attr.ib(type=list, default=attr.Factory(list))
total_events = attr.ib(type=int, default=0)
next_scroll_id = attr.ib(type=str, default=None)
class EventBLL(object):
id_fields = ["task", "iter", "metric", "variant", "key"]
def __init__(self, events_es=None):
self.es = events_es if events_es is not None else es_factory.connect("events")
def add_events(self, company_id, events, worker):
actions = []
task_ids = set()
task_iteration = defaultdict(lambda: 0)
task_last_events = nested_dict(
3, dict
) # task_id -> metric_hash -> variant_hash -> MetricEvent
for event in events:
# remove spaces from event type
if "type" not in event:
raise errors.BadRequest("Event must have a 'type' field", event=event)
event_type = event["type"].replace(" ", "_")
if event_type not in EVENT_TYPES:
raise errors.BadRequest(
"Invalid event type {}".format(event_type),
event=event,
types=EVENT_TYPES,
)
event["type"] = event_type
# @timestamp indicates the time the event is written, not when it happened
event["@timestamp"] = es_factory.get_es_timestamp_str()
# for backward compatibility
if "ts" in event:
event["timestamp"] = event.pop("ts")
# set timestamp and worker if not sent
if "timestamp" not in event:
event["timestamp"] = es_factory.get_timestamp_millis()
if "worker" not in event:
event["worker"] = worker
# force iter to be a long int
iter = event.get("iter")
if iter is not None:
iter = int(iter)
event["iter"] = iter
# "values" was previously used to indicate an array value; normalize to "value"
if "values" in event:
event["value"] = event["values"]
del event["values"]
index_name = EventBLL.get_index_name(company_id, event_type)
es_action = {
"_op_type": "index", # overwrite if exists with same ID
"_index": index_name,
"_type": "event",
"_source": event,
}
# for "log" events, don't assign a content-based custom _id - whatever is sent is written (not overwritten)
if event_type != "log":
es_action["_id"] = self._get_event_id(event)
else:
es_action["_id"] = dbutils.id()
task_id = event.get("task")
if task_id is not None:
es_action["_routing"] = task_id
task_ids.add(task_id)
if iter is not None:
task_iteration[task_id] = max(iter, task_iteration[task_id])
if event_type == EventType.metrics_scalar.value:
self._update_last_metric_event_for_task(
task_last_events=task_last_events, task_id=task_id, event=event
)
else:
es_action["_routing"] = task_id
actions.append(es_action)
if task_ids:
# verify task_ids
with translate_errors_context(), TimingContext("mongo", "task_by_ids"):
res = Task.objects(id__in=task_ids, company=company_id).only("id")
if len(res) < len(task_ids):
invalid_task_ids = tuple(set(task_ids) - set(r.id for r in res))
raise errors.bad_request.InvalidTaskId(
company=company_id, ids=invalid_task_ids
)
errors_in_bulk = []
added = 0
chunk_size = 500
with translate_errors_context(), TimingContext("es", "events_add_batch"):
# TODO: replace it with helpers.parallel_bulk in the future once the parallel pool leak is fixed
with closing(
helpers.streaming_bulk(
self.es,
actions,
chunk_size=chunk_size,
# thread_count=8,
refresh=True,
)
) as it:
for success, info in it:
if success:
added += chunk_size
else:
errors_in_bulk.append(info)
last_metrics = {
t.id: t.to_proper_dict().get("last_metrics", {})
for t in Task.objects(id__in=task_ids, company=company_id).only(
"last_metrics"
)
}
remaining_tasks = set()
now = datetime.utcnow()
for task_id in task_ids:
# Update related tasks. For performance reasons, we prefer to update all of them,
# not only those whose events were successful
updated = self._update_task(
company_id=company_id,
task_id=task_id,
now=now,
iter=task_iteration.get(task_id),
last_events=task_last_events.get(task_id),
last_metrics=last_metrics.get(task_id),
)
if not updated:
remaining_tasks.add(task_id)
continue
if remaining_tasks:
TaskBLL.set_last_update(remaining_tasks, company_id, last_update=now)
# Compensate for always adding chunk_size on success (last chunk is probably smaller)
added = min(added, len(actions))
return added, errors_in_bulk
def _update_last_metric_event_for_task(self, task_last_events, task_id, event):
"""
Update task_last_events structure for the provided task_id with the provided event details if this event is more
recent than the currently stored event for its metric/variant combination.
task_last_events contains [hashed_metric_name -> hashed_variant_name -> event]. Keys are hashed to avoid mongodb
key conflicts due to invalid characters and/or long field names.
"""
metric = event.get("metric")
variant = event.get("variant")
if not (metric and variant):
return
metric_hash = dbutils.hash_field_name(metric)
variant_hash = dbutils.hash_field_name(variant)
last_events = task_last_events[task_id]
timestamp = last_events[metric_hash][variant_hash].get("timestamp", None)
if timestamp is None or timestamp < event["timestamp"]:
last_events[metric_hash][variant_hash] = event
def _update_task(
self, company_id, task_id, now, iter=None, last_events=None, last_metrics=None
):
"""
Update task information in DB with aggregated results after handling event(s) related to this task.
This updates the task with the highest iteration value encountered during the last events update, as well
as the latest metric/variant scalar values reported (according to the report timestamp) and the task's last
update time.
"""
fields = {}
if iter is not None:
fields["last_iteration"] = iter
if last_events:
def get_metric_event(ev):
me = MetricEvent.from_dict(**ev)
if "timestamp" in ev:
me.timestamp = datetime.utcfromtimestamp(ev["timestamp"] / 1000)
return me
new_last_metrics = nested_dict(2, MetricEvent)
new_last_metrics.update(last_metrics)
for metric_hash, variants in last_events.items():
for variant_hash, event in variants.items():
new_last_metrics[metric_hash][variant_hash] = get_metric_event(
event
)
fields["last_metrics"] = new_last_metrics.to_dict()
if not fields:
return False
return TaskBLL.update_statistics(task_id, company_id, last_update=now, **fields)
def _get_event_id(self, event):
id_values = (str(event[field]) for field in self.id_fields if field in event)
return "-".join(id_values)
def scroll_task_events(
self,
company_id,
task_id,
order,
event_type=None,
batch_size=10000,
scroll_id=None,
):
if scroll_id:
with translate_errors_context(), TimingContext("es", "task_log_events"):
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
else:
size = min(batch_size, 10000)
if event_type is None:
event_type = "*"
es_index = EventBLL.get_index_name(company_id, event_type)
if not self.es.indices.exists(es_index):
return [], None, 0
es_req = {
"size": size,
"sort": {"timestamp": {"order": order}},
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
}
with translate_errors_context(), TimingContext("es", "scroll_task_events"):
es_res = self.es.search(index=es_index, body=es_req, scroll="1h")
events = [hit["_source"] for hit in es_res["hits"]["hits"]]
next_scroll_id = es_res["_scroll_id"]
total_events = es_res["hits"]["total"]
return events, next_scroll_id, total_events
def get_task_events(
self,
company_id,
task_id,
event_type=None,
metric=None,
variant=None,
last_iter_count=None,
sort=None,
size=500,
scroll_id=None,
):
if scroll_id:
with translate_errors_context(), TimingContext("es", "get_task_events"):
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
else:
task_ids = [task_id] if isinstance(task_id, six.string_types) else task_id
if event_type is None:
event_type = "*"
es_index = EventBLL.get_index_name(company_id, event_type)
if not self.es.indices.exists(es_index):
return TaskEventsResult()
query = {"bool": defaultdict(list)}
if metric or variant:
must = query["bool"]["must"]
if metric:
must.append({"term": {"metric": metric}})
if variant:
must.append({"term": {"variant": variant}})
if last_iter_count is None:
must = query["bool"]["must"]
must.append({"terms": {"task": task_ids}})
else:
should = query["bool"]["should"]
for i, task_id in enumerate(task_ids):
last_iters = self.get_last_iters(
es_index, task_id, event_type, last_iter_count
)
if not last_iters:
continue
should.append(
{
"bool": {
"must": [
{"term": {"task": task_id}},
{"terms": {"iter": last_iters}},
]
}
}
)
if not should:
return TaskEventsResult()
if sort is None:
sort = [{"timestamp": {"order": "asc"}}]
es_req = {"sort": sort, "size": min(size, 10000), "query": query}
routing = ",".join(task_ids)
with translate_errors_context(), TimingContext("es", "get_task_events"):
es_res = self.es.search(
index=es_index,
body=es_req,
ignore=404,
routing=routing,
scroll="1h",
)
events = [doc["_source"] for doc in es_res.get("hits", {}).get("hits", [])]
next_scroll_id = es_res["_scroll_id"]
total_events = es_res["hits"]["total"]
return TaskEventsResult(
events=events, next_scroll_id=next_scroll_id, total_events=total_events
)
def get_metrics_and_variants(self, company_id, task_id, event_type):
es_index = EventBLL.get_index_name(company_id, event_type)
if not self.es.indices.exists(es_index):
return {}
es_req = {
"size": 0,
"aggs": {
"metrics": {
"terms": {"field": "metric", "size": 200},
"aggs": {"variants": {"terms": {"field": "variant", "size": 200}}},
}
},
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
}
with translate_errors_context(), TimingContext(
"es", "events_get_metrics_and_variants"
):
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
metrics = {}
for metric_bucket in es_res["aggregations"]["metrics"].get("buckets"):
metric = metric_bucket["key"]
metrics[metric] = [
b["key"] for b in metric_bucket["variants"].get("buckets")
]
return metrics
def get_task_latest_scalar_values(self, company_id, task_id):
es_index = EventBLL.get_index_name(company_id, "training_stats_scalar")
if not self.es.indices.exists(es_index):
return {}
es_req = {
"size": 0,
"query": {
"bool": {
"must": [
{"query_string": {"query": "value:>0"}},
{"term": {"task": task_id}},
]
}
},
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": 1000,
"order": {"_term": "asc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": 1000,
"order": {"_term": "asc"},
},
"aggs": {
"last_value": {
"top_hits": {
"docvalue_fields": ["value"],
"_source": "value",
"size": 1,
"sort": [{"iter": {"order": "desc"}}],
}
},
"last_timestamp": {"max": {"field": "@timestamp"}},
"last_10_value": {
"top_hits": {
"docvalue_fields": ["value"],
"_source": "value",
"size": 10,
"sort": [{"iter": {"order": "desc"}}],
}
},
},
}
},
}
},
"_source": {"excludes": []},
}
with translate_errors_context(), TimingContext(
"es", "events_get_metrics_and_variants"
):
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
metrics = []
max_timestamp = 0
for metric_bucket in es_res["aggregations"]["metrics"].get("buckets"):
metric_summary = dict(name=metric_bucket["key"], variants=[])
for variant_bucket in metric_bucket["variants"].get("buckets"):
variant_name = variant_bucket["key"]
last_value = variant_bucket["last_value"]["hits"]["hits"][0]["fields"][
"value"
][0]
last_10_value = variant_bucket["last_10_value"]["hits"]["hits"][0][
"fields"
]["value"][0]
timestamp = variant_bucket["last_timestamp"]["value"]
max_timestamp = max(timestamp, max_timestamp)
metric_summary["variants"].append(
dict(
name=variant_name,
last_value=last_value,
last_10_value=last_10_value,
)
)
metrics.append(metric_summary)
return metrics, max_timestamp
def compare_scalar_metrics_average_per_iter(
self, company_id, task_ids, allow_public=True
):
assert isinstance(task_ids, list)
task_name_by_id = {}
with translate_errors_context():
task_objs = Task.get_many(
company=company_id,
query=Q(id__in=task_ids),
allow_public=allow_public,
override_projection=("id", "name"),
return_dicts=False,
)
if len(task_objs) < len(task_ids):
invalid = tuple(set(task_ids) - set(r.id for r in task_objs))
raise errors.bad_request.InvalidTaskId(company=company_id, ids=invalid)
task_name_by_id = {t.id: t.name for t in task_objs}
es_index = EventBLL.get_index_name(company_id, "training_stats_scalar")
if not self.es.indices.exists(es_index):
return {}
es_req = {
"size": 0,
"_source": {"excludes": []},
"query": {"terms": {"task": task_ids}},
"aggs": {
"iters": {
"histogram": {"field": "iter", "interval": 1, "min_doc_count": 1},
"aggs": {
"metric_and_variant": {
"terms": {
"script": "doc['metric'].value +'/'+ doc['variant'].value",
"size": 10000,
},
"aggs": {
"tasks": {
"terms": {"field": "task"},
"aggs": {"avg_val": {"avg": {"field": "value"}}},
}
},
}
},
}
},
}
with translate_errors_context(), TimingContext("es", "task_stats_comparison"):
es_res = self.es.search(index=es_index, body=es_req)
if "aggregations" not in es_res:
return
metrics = {}
for iter_bucket in es_res["aggregations"]["iters"]["buckets"]:
iteration = int(iter_bucket["key"])
for metric_bucket in iter_bucket["metric_and_variant"]["buckets"]:
metric_name = metric_bucket["key"]
if metrics.get(metric_name) is None:
metrics[metric_name] = {}
metric_data = metrics[metric_name]
for task_bucket in metric_bucket["tasks"]["buckets"]:
task_id = task_bucket["key"]
value = task_bucket["avg_val"]["value"]
if metric_data.get(task_id) is None:
metric_data[task_id] = {
"x": [],
"y": [],
"name": task_name_by_id[task_id],
}
metric_data[task_id]["x"].append(iteration)
metric_data[task_id]["y"].append(value)
return metrics
def get_scalar_metrics_average_per_iter(self, company_id, task_id):
es_index = EventBLL.get_index_name(company_id, "training_stats_scalar")
if not self.es.indices.exists(es_index):
return {}
es_req = {
"size": 0,
"_source": {"excludes": []},
"query": {"term": {"task": task_id}},
"aggs": {
"iters": {
"histogram": {"field": "iter", "interval": 1, "min_doc_count": 1},
"aggs": {
"metrics": {
"terms": {
"field": "metric",
"size": 200,
"order": {"_term": "desc"},
},
"aggs": {
"variants": {
"terms": {
"field": "variant",
"size": 500,
"order": {"_term": "desc"},
},
"aggs": {"avg_val": {"avg": {"field": "value"}}},
}
},
}
},
}
},
"version": True,
}
with translate_errors_context(), TimingContext("es", "task_stats_scalar"):
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
metrics = {}
if "aggregations" in es_res:
for iter_bucket in es_res["aggregations"]["iters"]["buckets"]:
iteration = int(iter_bucket["key"])
for metric_bucket in iter_bucket["metrics"]["buckets"]:
metric_name = metric_bucket["key"]
if metrics.get(metric_name) is None:
metrics[metric_name] = {}
metric_data = metrics[metric_name]
for variant_bucket in metric_bucket["variants"]["buckets"]:
variant = variant_bucket["key"]
value = variant_bucket["avg_val"]["value"]
if metric_data.get(variant) is None:
metric_data[variant] = {"x": [], "y": [], "name": variant}
metric_data[variant]["x"].append(iteration)
metric_data[variant]["y"].append(value)
return metrics
def get_vector_metrics_per_iter(self, company_id, task_id, metric, variant):
es_index = EventBLL.get_index_name(company_id, "training_stats_vector")
if not self.es.indices.exists(es_index):
return [], []
es_req = {
"size": 10000,
"query": {
"bool": {
"must": [
{"term": {"task": task_id}},
{"term": {"metric": metric}},
{"term": {"variant": variant}},
]
}
},
"_source": ["iter", "value"],
"sort": ["iter"],
}
with translate_errors_context(), TimingContext("es", "task_stats_vector"):
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
vectors = []
iterations = []
for hit in es_res["hits"]["hits"]:
vectors.append(hit["_source"]["value"])
iterations.append(hit["_source"]["iter"])
return iterations, vectors
def get_last_iters(self, es_index, task_id, event_type, iters):
if not self.es.indices.exists(es_index):
return []
es_req: dict = {
"size": 0,
"aggs": {
"iters": {
"terms": {
"field": "iter",
"size": iters,
"order": {"_term": "desc"},
}
}
},
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
}
if event_type:
es_req["query"]["bool"]["must"].append({"term": {"type": event_type}})
with translate_errors_context(), TimingContext("es", "task_last_iter"):
es_res = self.es.search(index=es_index, body=es_req, routing=task_id)
if "aggregations" not in es_res:
return []
return [b["key"] for b in es_res["aggregations"]["iters"]["buckets"]]
def delete_task_events(self, company_id, task_id):
es_index = EventBLL.get_index_name(company_id, "*")
es_req = {"query": {"term": {"task": task_id}}}
with translate_errors_context(), TimingContext("es", "delete_task_events"):
es_res = self.es.delete_by_query(
index=es_index, body=es_req, routing=task_id, refresh=True
)
return es_res.get("deleted", 0)
@staticmethod
def get_index_name(company_id, event_type):
event_type = event_type.lower().replace(" ", "_")
return f"events-{event_type}-{company_id}"
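`_get_event_id` above derives a deterministic document `_id` by joining whichever of the `id_fields` are present, so re-sending the same scalar event overwrites the previous Elasticsearch document instead of duplicating it. A self-contained sketch of that behavior:

```python
# Fixed field order, mirroring EventBLL.id_fields
ID_FIELDS = ["task", "iter", "metric", "variant", "key"]

def get_event_id(event: dict) -> str:
    # Join the values of whichever id fields are present, in a fixed order,
    # so identical events always map to the same document _id
    return "-".join(str(event[field]) for field in ID_FIELDS if field in event)

event = {"task": "t1", "iter": 7, "metric": "loss", "variant": "train", "value": 0.3}
event_id = get_event_id(event)  # → "t1-7-loss-train"
```

This is why "log" events get a random id instead: log lines for the same task and iteration are not unique per metric/variant, so a derived id would silently overwrite them.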


@@ -0,0 +1,7 @@
from .task_bll import TaskBLL
from .utils import (
ChangeStatusRequest,
update_project_time,
validate_status_change,
split_by,
)

393
server/bll/task/task_bll.py Normal file

@@ -0,0 +1,393 @@
import re
from collections import OrderedDict
from datetime import datetime
from typing import Mapping, Collection
from urllib.parse import urlparse
import six
from mongoengine import Q
from six import string_types
import es_factory
from apierrors import errors
from database.errors import translate_errors_context
from database.fields import OutputDestinationField
from database.model.model import Model
from database.model.project import Project
from database.model.task.metrics import MetricEvent
from database.model.task.output import Output
from database.model.task.task import Task, TaskStatus, TaskStatusMessage, TaskTags
from database.utils import get_company_or_none_constraint, id as create_id
from service_repo import APICall
from timing_context import TimingContext
from .utils import ChangeStatusRequest, validate_status_change
class TaskBLL(object):
def __init__(self, events_es=None):
self.events_es = (
events_es if events_es is not None else es_factory.connect("events")
)
@staticmethod
def get_task_with_access(
task_id, company_id, only=None, allow_public=False, requires_write_access=False
) -> Task:
"""
Gets a task, validating the required access level
:raises errors.bad_request.InvalidTaskId: if the task is not found
:raises errors.forbidden.NoWritePermission: if write access was required and the task cannot be modified
"""
with translate_errors_context():
query = dict(id=task_id, company=company_id)
with TimingContext("mongo", "task_with_access"):
if requires_write_access:
task = Task.get_for_writing(_only=only, **query)
else:
task = Task.get(_only=only, **query, include_public=allow_public)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
return task
@staticmethod
def get_by_id(
company_id,
task_id,
required_status=None,
required_dataset=None,
only_fields=None,
):
with TimingContext("mongo", "task_by_id_all"):
qs = Task.objects(id=task_id, company=company_id)
if only_fields:
qs = (
qs.only(only_fields)
if isinstance(only_fields, string_types)
else qs.only(*only_fields)
)
qs = qs.only(
"status", "input"
) # make sure all fields we rely on here are also returned
task = qs.first()
if not task:
raise errors.bad_request.InvalidTaskId(id=task_id)
if required_status and task.status != required_status:
raise errors.bad_request.InvalidTaskStatus(expected=required_status)
if required_dataset and required_dataset not in (
entry.dataset for entry in task.input.view.entries
):
raise errors.bad_request.InvalidId(
"not in input view", dataset=required_dataset
)
return task
@staticmethod
def assert_exists(company_id, task_ids, only=None, allow_public=False):
task_ids = [task_ids] if isinstance(task_ids, six.string_types) else task_ids
with translate_errors_context(), TimingContext("mongo", "task_exists"):
ids = set(task_ids)
q = Task.get_many(
company=company_id,
query=Q(id__in=ids),
allow_public=allow_public,
return_dicts=False,
)
if only:
res = q.only(*only)
count = len(res)
else:
count = q.count()
res = q.first()
if count != len(ids):
raise errors.bad_request.InvalidTaskId(ids=task_ids)
return res
@staticmethod
def create(call: APICall, fields: dict):
identity = call.identity
now = datetime.utcnow()
return Task(
id=create_id(),
user=identity.user,
company=identity.company,
created=now,
last_update=now,
**fields,
)
@staticmethod
def validate_execution_model(task, allow_only_public=False):
if not task.execution or not task.execution.model:
return
company = None if allow_only_public else task.company
model_id = task.execution.model
model = Model.objects(
Q(id=model_id) & get_company_or_none_constraint(company)
).first()
if not model:
raise errors.bad_request.InvalidModelId(model=model_id)
return model
@classmethod
def validate(cls, task: Task, force=False):
assert isinstance(task, Task)
if task.parent and not Task.get(
company=task.company, id=task.parent, _only=("id",), include_public=True
):
raise errors.bad_request.InvalidTaskId("invalid parent", parent=task.parent)
if task.project:
Project.get_for_writing(company=task.company, id=task.project)
model = cls.validate_execution_model(task)
if model and not force and not model.ready:
raise errors.bad_request.ModelNotReady(
"can't be used in a task", model=model.id
)
if task.execution:
if task.execution.parameters:
cls._validate_execution_parameters(task.execution.parameters)
if task.output and task.output.destination:
parsed_url = urlparse(task.output.destination)
if parsed_url.scheme not in OutputDestinationField.schemes:
raise errors.bad_request.FieldsValueError(
"unsupported scheme for output destination",
dest=task.output.destination,
)
@staticmethod
def _validate_execution_parameters(parameters):
invalid_keys = [k for k in parameters if re.search(r"\s", k)]
if invalid_keys:
raise errors.bad_request.ValidationError(
"execution.parameters keys contain whitespace", keys=invalid_keys
)
@staticmethod
def get_unique_metric_variants(company_id, project_ids=None):
pipeline = [
{
"$match": dict(
company=company_id,
**({"project": {"$in": project_ids}} if project_ids else {}),
)
},
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
{"$unwind": "$metrics"},
{
"$project": {
"metric": "$metrics.k",
"variants": {"$objectToArray": "$metrics.v"},
}
},
{"$unwind": "$variants"},
{
"$group": {
"_id": {
"metric": "$variants.v.metric",
"variant": "$variants.v.variant",
},
"metrics": {
"$addToSet": {
"metric": "$variants.v.metric",
"metric_hash": "$metric",
"variant": "$variants.v.variant",
"variant_hash": "$variants.k",
}
},
}
},
{"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
]
with translate_errors_context():
result = Task.objects.aggregate(*pipeline)
return [r["metrics"][0] for r in result]
@staticmethod
def set_last_update(
task_ids: Collection[str], company_id: str, last_update: datetime
):
return Task.objects(id__in=task_ids, company=company_id).update(
upsert=False, last_update=last_update
)
@staticmethod
def update_statistics(
task_id: str,
company_id: str,
last_update: datetime = None,
last_iteration: int = None,
last_iteration_max: int = None,
last_metrics: Mapping[str, Mapping[str, MetricEvent]] = None,
**extra_updates,
):
"""
Update task statistics
:param task_id: Task's ID.
:param company_id: Task's company ID.
:param last_update: Last update time. If not provided, defaults to datetime.utcnow().
:param last_iteration: Last reported iteration. Use this to set a value regardless of current
task's last iteration value.
:param last_iteration_max: Last reported iteration. Use this to conditionally set a value only
if the current task's last iteration value is smaller than the provided value.
:param last_metrics: Last reported metrics summary.
:param extra_updates: Extra task updates to include in this update call.
:return:
"""
last_update = last_update or datetime.utcnow()
if last_iteration is not None:
extra_updates.update(last_iteration=last_iteration)
elif last_iteration_max is not None:
extra_updates.update(max__last_iteration=last_iteration_max)
if last_metrics is not None:
extra_updates.update(last_metrics=last_metrics)
return Task.objects(id=task_id, company=company_id).update(
upsert=False, last_update=last_update, **extra_updates
)
@classmethod
def model_set_ready(
cls,
model_id: str,
company_id: str,
publish_task: bool,
force_publish_task: bool = False,
) -> tuple:
with translate_errors_context():
query = dict(id=model_id, company=company_id)
model = Model.objects(**query).first()
if not model:
raise errors.bad_request.InvalidModelId(**query)
elif model.ready:
raise errors.bad_request.ModelIsReady(**query)
published_task_data = {}
if model.task and publish_task:
task = (
Task.objects(id=model.task, company=company_id)
.only("id", "status")
.first()
)
if task and task.status != TaskStatus.published:
published_task_data["data"] = cls.publish_task(
task_id=model.task,
company_id=company_id,
publish_model=False,
force=force_publish_task,
)
published_task_data["id"] = model.task
updated = model.update(upsert=False, ready=True)
return updated, published_task_data
@classmethod
def publish_task(
cls,
task_id: str,
company_id: str,
publish_model: bool,
force: bool,
status_reason: str = "",
status_message: str = "",
) -> dict:
task = cls.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
)
if not force:
validate_status_change(task.status, TaskStatus.published)
previous_task_status = task.status
output = task.output or Output()
publish_failed = False
try:
# set state to publishing
task.status = TaskStatus.publishing
task.save()
# publish task models
if task.output.model and publish_model:
output_model = (
Model.objects(id=task.output.model)
.only("id", "task", "ready")
.first()
)
if output_model and not output_model.ready:
cls.model_set_ready(
model_id=task.output.model,
company_id=company_id,
publish_task=False,
)
# set task status to published, and update (or set) its new output (view and models)
return ChangeStatusRequest(
task=task,
new_status=TaskStatus.published,
force=force,
status_reason=status_reason,
status_message=status_message,
).execute(published=datetime.utcnow(), output=output)
except Exception:
publish_failed = True
raise
finally:
if publish_failed:
task.status = previous_task_status
task.save()
@classmethod
def stop_task(
cls,
task_id: str,
company_id: str,
user_name: str,
status_reason: str,
force: bool,
) -> dict:
"""
Stop a running task. Requires task status 'in_progress' and
execution_progress 'running', or force=True. A development task, or a
task that has no associated worker, is stopped immediately.
For a non-development task with a worker, only the status message is
set to 'stopping', allowing the worker to stop the task and report back by itself
:return: updated task fields
"""
task = TaskBLL.get_task_with_access(
task_id,
company_id=company_id,
only=("status", "project", "tags", "last_update"),
requires_write_access=True,
)
if TaskTags.development in task.tags:
new_status = TaskStatus.stopped
status_message = f"Stopped by {user_name}"
else:
new_status = task.status
status_message = TaskStatusMessage.stopping
return ChangeStatusRequest(
task=task,
new_status=new_status,
status_reason=status_reason,
status_message=status_message,
force=force,
).execute()
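`publish_task` and `stop_task` both delegate to `ChangeStatusRequest` (in `utils.py`, below), which changes a task's status atomically by filtering the MongoDB update on the currently expected status, so two concurrent callers cannot both win. A minimal in-memory sketch of that compare-and-set pattern (the plain dict stands in for the MongoDB collection; names and the trimmed state machine are illustrative):

```python
# Trimmed-down version of the allowed-transition state machine
ALLOWED = {
    "created": {"in_progress"},
    "in_progress": {"stopped", "failed", "created"},
    "stopped": {"closed", "created", "failed", "in_progress", "published"},
}

def change_status(store: dict, task_id: str, expected: str, new_status: str) -> bool:
    # Validate the transition against the state machine
    if new_status not in ALLOWED.get(expected, set()):
        raise ValueError(f"invalid status change {expected} -> {new_status}")
    # Compare-and-set: only update if the task still has the expected status
    if store.get(task_id) != expected:
        return False  # someone else changed the status first
    store[task_id] = new_status
    return True

tasks = {"t1": "in_progress"}
assert change_status(tasks, "t1", "in_progress", "stopped")      # wins the race
assert not change_status(tasks, "t1", "in_progress", "stopped")  # loses the race
```

In the real implementation the compare-and-set is a single `Task.objects(id=..., status=current_status).update(...)` call, so the check and the write happen atomically on the database side.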

151
server/bll/task/utils.py Normal file

@@ -0,0 +1,151 @@
from datetime import datetime
from typing import TypeVar, Callable, Tuple, Sequence
import attr
import six
from apierrors import errors
from database.errors import translate_errors_context
from database.model.project import Project
from database.model.task.task import Task, TaskStatus
from database.utils import get_options
from timing_context import TimingContext
from utilities.attrs import typed_attrs
valid_statuses = get_options(TaskStatus)
@typed_attrs
class ChangeStatusRequest(object):
task = attr.ib(type=Task)
new_status = attr.ib(
type=six.string_types, validator=attr.validators.in_(valid_statuses)
)
status_reason = attr.ib(type=six.string_types, default="")
status_message = attr.ib(type=six.string_types, default="")
force = attr.ib(type=bool, default=False)
allow_same_state_transition = attr.ib(type=bool, default=True)
def execute(self, **kwargs):
current_status = self.task.status
project_id = self.task.project
# Verify new status is allowed from current status (will throw exception if not valid)
self.validate_transition(current_status)
control = dict(upsert=False, multi=False, write_concern=None, full_result=False)
now = datetime.utcnow()
fields = dict(
status=self.new_status,
status_reason=self.status_reason,
status_message=self.status_message,
status_changed=now,
last_update=now,
)
def safe_mongoengine_key(key):
return f"__{key}" if key in control else key
fields.update({safe_mongoengine_key(k): v for k, v in kwargs.items()})
with translate_errors_context(), TimingContext("mongo", "task_status"):
# atomic change of task status by querying the task with the EXPECTED status before modifying it
params = fields.copy()
params.update(control)
updated = Task.objects(id=self.task.id, status=current_status).update(
**params
)
if not updated:
# failed to change status (someone else beat us to it?)
raise errors.bad_request.FailedChangingTaskStatus(
task_id=self.task.id,
current_status=current_status,
new_status=self.new_status,
)
update_project_time(project_id)
return dict(updated=updated, fields=fields)
def validate_transition(self, current_status):
if self.force:
return
if self.new_status != current_status:
validate_status_change(current_status, self.new_status)
elif not self.allow_same_state_transition:
raise errors.bad_request.InvalidTaskStatus(
"Task already in requested status",
current_status=current_status,
new_status=self.new_status,
)
def validate_status_change(current_status, new_status):
assert current_status in valid_statuses
assert new_status in valid_statuses
allowed_statuses = get_possible_status_changes(current_status)
if new_status not in allowed_statuses:
raise errors.bad_request.InvalidTaskStatus(
"Invalid status change",
current_status=current_status,
new_status=new_status,
)
state_machine = {
TaskStatus.created: {TaskStatus.in_progress},
TaskStatus.in_progress: {TaskStatus.stopped, TaskStatus.failed, TaskStatus.created},
TaskStatus.stopped: {
TaskStatus.closed,
TaskStatus.created,
TaskStatus.failed,
TaskStatus.in_progress,
TaskStatus.published,
TaskStatus.publishing,
},
TaskStatus.closed: {
TaskStatus.created,
TaskStatus.failed,
TaskStatus.published,
TaskStatus.publishing,
TaskStatus.stopped,
},
TaskStatus.failed: {TaskStatus.created, TaskStatus.stopped, TaskStatus.published},
TaskStatus.publishing: {TaskStatus.published},
TaskStatus.published: set(),
}
def get_possible_status_changes(current_status):
"""
:param current_status: the current task status
:return: possible statuses reachable from the current status
"""
possible = state_machine.get(current_status)
assert (
possible is not None
), f"Current status {current_status} not supported by state machine"
return possible
def update_project_time(project_id):
if project_id:
Project.objects(id=project_id).update(last_update=datetime.utcnow())
T = TypeVar("T")
def split_by(
condition: Callable[[T], bool], items: Sequence[T]
) -> Tuple[Sequence[T], Sequence[T]]:
"""
split "items" to two lists by "condition"
"""
applied = list(zip(map(condition, items), items))  # materialize: iterated twice below
return (
[item for cond, item in applied if cond],
[item for cond, item in applied if not cond],
)
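The state machine above maps each task status to the set of statuses it may transition to; `validate_status_change` rejects anything else. A self-contained sketch of the same check with plain strings (a subset of the statuses, for illustration):

```python
# Plain-string sketch of the task-status state machine and transition check
# defined above (subset of statuses, illustrative only).
state_machine = {
    "created": {"in_progress"},
    "in_progress": {"stopped", "failed", "created"},
    "stopped": {"closed", "created", "failed", "in_progress", "published", "publishing"},
    "publishing": {"published"},
    "published": set(),  # terminal: no transitions allowed
}

def validate_status_change(current, new):
    allowed = state_machine.get(current)
    if allowed is None:
        raise AssertionError(f"status {current!r} not supported by state machine")
    if new not in allowed:
        raise ValueError(f"invalid status change: {current!r} -> {new!r}")
```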


@@ -0,0 +1,23 @@
from apierrors import errors
from apimodels.users import CreateRequest
from database.errors import translate_errors_context
from database.model.user import User
class UserBLL:
@staticmethod
def create(request: CreateRequest):
user_id = request.id
with translate_errors_context("creating user"):
if user_id and User.objects(id=user_id).only("id"):
raise errors.bad_request.UserIdExists(id=user_id)
user = User(**request.to_struct())
user.save(force_insert=True)
@staticmethod
def delete(user_id: str):
with translate_errors_context("deleting user"):
res = User.objects(id=user_id).delete()
if not res:
raise errors.bad_request.InvalidUserId(id=user_id)


@@ -0,0 +1,8 @@
import logging.config
from pathlib import Path
from .basic import BasicConfig
config = BasicConfig(Path(__file__).with_name("default"))
logging.config.dictConfig(config.get("logging"))

server/config/basic.py Normal file

@@ -0,0 +1,89 @@
import logging
from pathlib import Path
from pyhocon import ConfigTree, ConfigFactory
from pyparsing import (
ParseFatalException,
ParseException,
RecursiveGrammarException,
ParseSyntaxException,
)
class BasicConfig:
NotSet = object()
def __init__(self, folder):
self.folder = Path(folder)
if not self.folder.is_dir():
raise ValueError("Invalid configuration folder")
self.prefix = "trains"
self._load()
def __getitem__(self, key):
return self._config[key]
def get(self, key, default=NotSet):
value = self._config.get(key, default)
if value is self.NotSet:
raise KeyError(
f"Unable to find value for key '{key}' and default value was not provided."
)
return value
def logger(self, name):
if Path(name).is_file():
name = Path(name).stem
path = ".".join((self.prefix, Path(name).stem))
return logging.getLogger(path)
def _load(self, verbose=True):
self._config = self._read_recursive(self.folder, verbose=verbose)
def _read_recursive(self, conf_root, verbose=True):
conf = ConfigTree()
if not conf_root:
return conf
if not conf_root.is_dir():
if verbose:
if not conf_root.exists():
print(f"No config in {conf_root}")
else:
print(f"Not a directory: {conf_root}")
return conf
if verbose:
print(f"Loading config from {conf_root}")
for file in conf_root.rglob("*.conf"):
key = ".".join(file.relative_to(conf_root).with_suffix("").parts)
conf.put(key, self._read_single_file(file, verbose=verbose))
return conf
@staticmethod
def _read_single_file(file_path, verbose=True):
if verbose:
print(f"Loading config from file {file_path}")
try:
return ConfigFactory.parse_file(file_path)
except ParseSyntaxException as ex:
msg = f"Failed parsing {file_path} ({ex.__class__.__name__}): (at char {ex.loc}, line:{ex.lineno}, col:{ex.column})"
raise ConfigurationError(msg, file_path=file_path) from ex
except (ParseException, ParseFatalException, RecursiveGrammarException) as ex:
msg = f"Failed parsing {file_path} ({ex.__class__.__name__}): {ex}"
raise ConfigurationError(msg) from ex
except Exception as ex:
print(f"Failed loading {file_path}: {ex}")
raise
class ConfigurationError(Exception):
def __init__(self, msg, file_path=None, *args):
super(ConfigurationError, self).__init__(msg, *args)
self.file_path = file_path


@@ -0,0 +1,51 @@
{
watch: false # Watch for changes (dev only)
debug: false # Debug mode
pretty_json: false # prettify json response
return_stack: true # return stack trace on error
log_calls: true # Log API Calls
# if 'return_stack' is true and error contains a status code, return stack trace only for these status codes
# valid values are:
# - an integer number, specifying a status code
# - a tuple of (code, subcode or list of subcodes)
return_stack_on_code: [
[500, 0] # raise on internal server error with no subcode
]
listen {
ip : "0.0.0.0"
port: 8008
}
version {
required: false
default: 1.0
}
mongo {
# controls whether FieldDoesNotExist exception will be raised for any extra attribute existing in stored data
# but not declared in a data model
strict: false
}
auth {
# verify user tokens
verify_user_tokens: false
# max token expiration timeout in seconds (1 year)
max_expiration_sec: 31536000
# default token expiration timeout in seconds (30 days)
default_expiration_sec: 2592000
# cookie containing auth token, for requests arriving from a web-browser
session_auth_cookie_name: "trains_token_basic"
}
cors {
origins: "*"
}
default_company: "d1bd92a3b039400cbafc60a7a5b1e52b"
}


@@ -0,0 +1,21 @@
elastic {
events {
hosts: [{host: "127.0.0.1", port: 9200}]
args {
timeout: 60
dead_timeout: 10
max_retries: 5
retry_on_timeout: true
}
index_version: "1"
}
}
mongo {
backend {
host: "mongodb://127.0.0.1:27017/backend"
}
auth {
host: "mongodb://127.0.0.1:27017/auth"
}
}


@@ -0,0 +1,49 @@
{
version: 1
disable_existing_loggers: false
formatters: {
standard: {
format: "[%(asctime)s] [%(process)d] [%(levelname)s] [%(name)s] %(message)s"
}
}
handlers {
console {
formatter: standard
class: "logging.StreamHandler"
}
text_file: {
formatter: standard,
backupCount: 3
maxBytes: 10240000,
class: "logging.handlers.RotatingFileHandler",
filename: "/var/log/trains/apiserver.log"
}
}
root {
handlers: [console, text_file]
level: INFO
}
loggers {
urllib3 {
handlers: [console, text_file]
level: WARN
propagate: false
}
werkzeug {
handlers: [console, text_file]
level: WARN
propagate: false
}
elasticsearch {
handlers: [console, text_file]
level: WARN
propagate: false
}
# disable pep8 auto-gen logging (at least part of it)
RefactoringTool {
handlers: []
level: ERROR
propagate: false
}
}
}
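The HOCON file above is parsed and handed to `logging.config.dictConfig` (see `server/config/__init__.py`). An equivalent stdlib dictionary, trimmed to the console handler so it runs anywhere (the rotating file handler needs a writable `/var/log/trains`):

```python
import logging
import logging.config

# Stdlib-dict equivalent of the HOCON logging config above, console handler only.
LOGGING = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "standard": {
            "format": "[%(asctime)s] [%(process)d] [%(levelname)s] [%(name)s] %(message)s"
        }
    },
    "handlers": {
        "console": {"formatter": "standard", "class": "logging.StreamHandler"}
    },
    "root": {"handlers": ["console"], "level": "INFO"},
    "loggers": {
        # quiet noisy libraries, matching the HOCON config's intent
        "urllib3": {"handlers": ["console"], "level": "WARN", "propagate": False}
    },
}
logging.config.dictConfig(LOGGING)
```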


@@ -0,0 +1,29 @@
{
http {
session_secret {
apiserver: "Gx*gB-L2U8!Naqzd#8=7A4&+=In4H(da424H33ZTDQRGF6=FWw"
}
}
auth {
# token sign secret
token_secret: "7E1ua3xP9GT2(cIQOfhjp+gwN6spBeCAmN-XuugYle00I=Wc+u"
}
credentials {
# system credentials as they appear in the auth DB, used for intra-service communications
apiserver {
user_key: "62T8CP7HGBC6647XF9314C2VY67RJO"
user_secret: "FhS8VZv_I4%6Mo$8S1BWc$n$=o1dMYSivuiWU-Vguq7qGOKskG-d+b@tn_Iq"
}
webserver {
user_key: "EYVQ385RW7Y2QQUH88CZ7DWIQ1WUHP"
user_secret: "yfc8KQo*GMXb*9p((qcYC7ByFIpF7I&4VH3BfUYXH%o9vX1ZUZQEEw1Inc)S"
}
tests {
user_key: "EGRTCO8JMSIGI6S39GTP43NFWXDQOW"
user_secret: "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"
}
}
}


@@ -0,0 +1,3 @@
{
es_index_prefix:"events"
}


@@ -0,0 +1,58 @@
from jsonmodels import models
from jsonmodels.errors import ValidationError
from jsonmodels.fields import StringField
from mongoengine import register_connection
from mongoengine.connection import get_connection
from config import config
from .defs import Database
from .utils import get_items
log = config.logger(__file__)
strict = config.get('apiserver.mongo.strict', True)
_entries = []
class DatabaseEntry(models.Base):
host = StringField(required=True)
alias = StringField()
@property
def health_alias(self):
return '__health__' + self.alias
def initialize():
db_entries = config.get('hosts.mongo', {})
missing = []
log.info('Initializing database connections')
for key, alias in get_items(Database).items():
if key not in db_entries:
missing.append(key)
continue
entry = DatabaseEntry(alias=alias, **db_entries.get(key))
try:
entry.validate()
log.info('Registering connection to %(alias)s (%(host)s)' % entry.to_struct())
register_connection(alias=alias, host=entry.host)
_entries.append(entry)
except ValidationError as ex:
raise Exception('Invalid database entry `%s`: %s' % (key, ex.args[0]))
if missing:
raise ValueError('Missing database configuration for %s' % ', '.join(missing))
def get_entries():
return _entries
def get_aliases():
return [entry.alias for entry in get_entries()]
def reconnect():
for entry in get_entries():
get_connection(entry.alias, reconnect=True)

server/database/defs.py Normal file

@@ -0,0 +1,10 @@
class Database(object):
""" Database names for our different DB instances """
backend = 'backend-db'
''' Used for all backend objects (tasks, models etc.) '''
auth = 'auth-db'
''' Used for all authentication and permission objects '''

server/database/errors.py Normal file

@@ -0,0 +1,189 @@
import re
from contextlib import contextmanager
from functools import wraps
import dpath
from dpath.exceptions import InvalidKeyName
from elasticsearch import ElasticsearchException
from elasticsearch.helpers import BulkIndexError
from jsonmodels.errors import ValidationError as JsonschemaValidationError
from mongoengine.errors import (
ValidationError,
NotUniqueError,
FieldDoesNotExist,
InvalidDocumentError,
LookUpError,
InvalidQueryError,
)
from pymongo.errors import PyMongoError, NotMasterError
from apierrors import errors
class MakeGetAllQueryError(Exception):
def __init__(self, error, field):
super(MakeGetAllQueryError, self).__init__(f"{error}: field={field}")
self.error = error
self.field = field
class ParseCallError(Exception):
def __init__(self, msg, **kwargs):
super(ParseCallError, self).__init__(msg)
self.params = kwargs
def throws_default_error(err_cls):
"""
Decorator for error-message extractors of the form (Exception, str) -> Optional[str]:
the decorated function's return value is attached as extra info, and ``err_cls`` is
always raised with the original message and exception.
:param err_cls: Error class (generated by apierrors)
"""
def decorator(func):
@wraps(func)
def wrapper(self, e, message, **kwargs):
extra_info = func(self, e, message, **kwargs)
raise err_cls(message, err=e, extra_info=extra_info)
return wrapper
return decorator
class ElasticErrorsHandler(object):
@classmethod
@throws_default_error(errors.server_error.DataError)
def bulk_error(cls, e, _, **__):
if not e.errors:
return
# Else try returning a better error string
for _, reason in dpath.search(e.errors[0], "*/error/reason", yielded=True):
return reason
class MongoEngineErrorsHandler(object):
# NotUniqueError
__not_unique_regex = re.compile(
r"collection:\s(?P<collection>[\w.]+)\sindex:\s(?P<index>\w+)\sdup\skey:\s{(?P<values>[^\}]+)\}"
)
__not_unique_value_regex = re.compile(r':\s"(?P<value>[^"]+)"')
__id_index = "_id_"
__index_sep_regex = re.compile(r"_[0-9]+_?")
# FieldDoesNotExist
__not_exist_fields_regex = re.compile(r'"{(?P<fields>.+?)}".+?"(?P<document>.+?)"')
__not_exist_field_regex = re.compile(r"'(?P<field>\w+)'")
@classmethod
def validation_error(cls, e: ValidationError, message, **_):
# Thrown when a document is validated. Documents are validated by default on save and on update
err_dict = e.errors or {e.field_name: e.message}
raise errors.bad_request.DataValidationError(message, **err_dict)
@classmethod
def not_unique_error(cls, e, message, **_):
# Thrown when a save/update violates a unique index constraint
m = cls.__not_unique_regex.search(str(e))
if not m:
raise errors.bad_request.ExpectedUniqueData(message, err=str(e))
values = cls.__not_unique_value_regex.findall(m.group("values"))
index = m.group("index")
if index == cls.__id_index:
fields = "id"
else:
fields = cls.__index_sep_regex.split(index)[:-1]
raise errors.bad_request.ExpectedUniqueData(
message, **dict(zip(fields, values))
)
@classmethod
def field_does_not_exist(cls, e, message, **kwargs):
# Strict mode. Unknown fields encountered in loaded document(s)
field_does_not_exist_cls = kwargs.get(
"field_does_not_exist_cls", errors.server_error.InconsistentData
)
m = cls.__not_exist_fields_regex.search(str(e))
params = {}
if m:
params["document"] = m.group("document")
fields = cls.__not_exist_field_regex.findall(m.group("fields"))
if fields:
if len(fields) > 1:
params["fields"] = "(%s)" % ", ".join(fields)
else:
params["field"] = fields[0]
raise field_does_not_exist_cls(message, **params)
@classmethod
@throws_default_error(errors.server_error.DataError)
def invalid_document_error(cls, e, message, **_):
# Reverse_delete_rule used in reference field
pass
@classmethod
def lookup_error(cls, e, message, **_):
raise errors.bad_request.InvalidFields(
"probably an invalid field name or unsupported nested field",
replacement_msg="Lookup error",
err=str(e),
)
@classmethod
@throws_default_error(errors.bad_request.InvalidRegexError)
def invalid_regex_error(cls, e, _, **__):
if e.args and e.args[0] == "unexpected end of regular expression":
raise errors.bad_request.InvalidRegexError(e.args[0])
@classmethod
@throws_default_error(errors.server_error.InternalError)
def invalid_query_error(cls, e, message, **_):
pass
@contextmanager
def translate_errors_context(message=None, **kwargs):
"""
A context manager that translates MongoEngine's and Elastic thrown errors into our apierrors classes,
with an appropriate message.
"""
try:
if message:
message = "while " + message
yield True
except ValidationError as e:
MongoEngineErrorsHandler.validation_error(e, message, **kwargs)
except NotUniqueError as e:
MongoEngineErrorsHandler.not_unique_error(e, message, **kwargs)
except FieldDoesNotExist as e:
MongoEngineErrorsHandler.field_does_not_exist(e, message, **kwargs)
except InvalidDocumentError as e:
MongoEngineErrorsHandler.invalid_document_error(e, message, **kwargs)
except LookUpError as e:
MongoEngineErrorsHandler.lookup_error(e, message, **kwargs)
except re.error as e:
MongoEngineErrorsHandler.invalid_regex_error(e, message, **kwargs)
except InvalidQueryError as e:
MongoEngineErrorsHandler.invalid_query_error(e, message, **kwargs)
except (NotMasterError, PyMongoError) as e:
raise errors.server_error.InternalError(message, err=str(e))
except MakeGetAllQueryError as e:
raise errors.bad_request.ValidationError(e.error, field=e.field)
except ParseCallError as e:
raise errors.bad_request.FieldsValueError(e.args[0], **e.params)
except JsonschemaValidationError as e:
if len(e.args) >= 2:
raise errors.bad_request.ValidationError(e.args[0], reason=e.args[1])
raise errors.bad_request.ValidationError(e.args[0])
except BulkIndexError as e:
ElasticErrorsHandler.bulk_error(e, message, **kwargs)
except ElasticsearchException as e:
raise errors.server_error.DataError(e, message, **kwargs)
except InvalidKeyName:
raise errors.server_error.DataError("invalid empty key encountered in data")
except Exception:
raise
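`translate_errors_context` wraps driver-level exceptions (MongoEngine, Elastic) into API error classes while attaching a contextual "while ..." message. A stripped-down sketch of the same context-manager pattern, with a hypothetical error class standing in for `apierrors`:

```python
from contextlib import contextmanager

class APIError(Exception):
    """Hypothetical stand-in for the apierrors error classes."""

@contextmanager
def translate_errors(message=None):
    # Minimal version of translate_errors_context: catch a known low-level
    # error and re-raise it as an API-level error with a contextual message.
    try:
        if message:
            message = "while " + message
        yield
    except KeyError as e:
        raise APIError(f"{message}: missing key {e.args[0]!r}") from e
```

Callers wrap a DB operation in `with translate_errors("creating user"): ...` and handle only API-level errors above it.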

server/database/fields.py Normal file

@@ -0,0 +1,237 @@
import re
from sys import maxsize
import six
from mongoengine import (
EmbeddedDocumentListField,
ListField,
FloatField,
StringField,
EmbeddedDocumentField,
SortedListField,
MapField,
DictField,
)
class LengthRangeListField(ListField):
def __init__(self, field=None, max_length=maxsize, min_length=0, **kwargs):
self.__min_length = min_length
self.__max_length = max_length
super(LengthRangeListField, self).__init__(field, **kwargs)
def validate(self, value):
low, count, high = self.__min_length, len(value), self.__max_length
if not low <= count <= high:
self.error("Item count %d exceeds range [%d, %d]" % (count, low, high))
super(LengthRangeListField, self).validate(value)
class LengthRangeEmbeddedDocumentListField(LengthRangeListField):
def __init__(self, field=None, *args, **kwargs):
super(LengthRangeEmbeddedDocumentListField, self).__init__(
EmbeddedDocumentField(field), *args, **kwargs
)
class UniqueEmbeddedDocumentListField(EmbeddedDocumentListField):
def __init__(self, document_type, key, **kwargs):
"""
Create a unique embedded document list field for a document type with a unique comparison key func/property
:param document_type: The type of :class:`~mongoengine.EmbeddedDocument` the list will hold.
:param key: A callable to extract a key from each item
"""
if not callable(key):
raise KeyError("key must be callable")
self.__key = key
super(UniqueEmbeddedDocumentListField, self).__init__(document_type, **kwargs)
def validate(self, value):
if len({self.__key(i) for i in value}) != len(value):
self.error("Items with duplicate key exist in the list")
super(UniqueEmbeddedDocumentListField, self).validate(value)
def object_to_key_value_pairs(obj):
if isinstance(obj, dict):
return [(key, object_to_key_value_pairs(value)) for key, value in obj.items()]
if isinstance(obj, list):
return list(map(object_to_key_value_pairs, obj))
return obj
class EmbeddedDocumentSortedListField(EmbeddedDocumentListField):
"""
A sorted list of embedded documents
"""
def to_mongo(self, value, use_db_field=True, fields=None):
value = super(EmbeddedDocumentSortedListField, self).to_mongo(
value, use_db_field, fields
)
return sorted(value, key=object_to_key_value_pairs)
class LengthRangeSortedListField(LengthRangeListField, SortedListField):
pass
class CustomFloatField(FloatField):
def __init__(self, greater_than=None, **kwargs):
self.greater_than = greater_than
super(CustomFloatField, self).__init__(**kwargs)
def validate(self, value):
super(CustomFloatField, self).validate(value)
if self.greater_than is not None and value <= self.greater_than:
self.error("Float value must be greater than %s" % str(self.greater_than))
# TODO: bucket name should be at most 63 characters....
aws_s3_bucket_only_regex = (
r"^s3://"
r"(?:(?:\w[A-Z0-9\-]+\w)\.)*(?:\w[A-Z0-9\-]+\w)" # bucket name
)
aws_s3_url_with_bucket_regex = (
r"^s3://"
r"(?:(?:\w[A-Z0-9\-]+\w)\.)*(?:\w[A-Z0-9\-]+\w)" # bucket name
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(?<!-)\.?))" # domain...
)
non_aws_s3_regex = (
r"^s3://"
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(?<!-)\.?)|" # domain...
r"localhost|" # localhost...
r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|" # ...or ipv4
r"\[?[A-F0-9]*:[A-F0-9:]+\]?)" # ...or ipv6
r"(?::\d+)?" # optional port
r"(?:/(?:(?:\w[A-Z0-9\-]+\w)\.)*(?:\w[A-Z0-9\-]+\w))" # bucket name
)
google_gs_bucket_only_regex = (
r"^gs://"
r"(?:(?:\w[A-Z0-9\-_]+\w)\.)*(?:\w[A-Z0-9\-_]+\w)" # bucket name
)
file_regex = r"^file://"
generic_url_regex = (
r"^%s://" # scheme placeholder
r"(?:(?:[A-Z0-9](?:[A-Z0-9-]{0,61}[A-Z0-9])?\.)+(?:[A-Z]{2,6}\.?|[A-Z0-9-]{2,}(?<!-)\.?)|" # domain...
r"localhost|" # localhost...
r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}|" # ...or ipv4
r"\[?[A-F0-9]*:[A-F0-9:]+\]?)" # ...or ipv6
r"(?::\d+)?" # optional port
)
path_suffix = r"(?:/?|[/?]\S+)$"
file_path_suffix = r"(?:/\S*[^/]+)$"
class _RegexURLField(StringField):
_regex = []
def __init__(self, regex, **kwargs):
super(_RegexURLField, self).__init__(**kwargs)
regex = regex if isinstance(regex, (tuple, list)) else [regex]
self._regex = [
re.compile(e, re.IGNORECASE) if isinstance(e, six.string_types) else e
for e in regex
]
def validate(self, value):
# Check first if the scheme is valid
if not any(regex.match(value) for regex in self._regex):
self.error("Invalid URL: {}".format(value))
class OutputDestinationField(_RegexURLField):
""" A field representing task output URL """
schemes = ["s3", "gs", "file"]
_expressions = (
aws_s3_bucket_only_regex + path_suffix,
aws_s3_url_with_bucket_regex + path_suffix,
non_aws_s3_regex + path_suffix,
google_gs_bucket_only_regex + path_suffix,
file_regex + path_suffix,
)
def __init__(self, **kwargs):
super(OutputDestinationField, self).__init__(self._expressions, **kwargs)
class SupportedURLField(_RegexURLField):
""" A field representing a model URL """
schemes = ["s3", "gs", "file", "http", "https"]
_expressions = tuple(
pattern + file_path_suffix
for pattern in (
aws_s3_bucket_only_regex,
aws_s3_url_with_bucket_regex,
non_aws_s3_regex,
google_gs_bucket_only_regex,
file_regex,
(generic_url_regex % "http"),
(generic_url_regex % "https"),
)
)
def __init__(self, **kwargs):
super(SupportedURLField, self).__init__(self._expressions, **kwargs)
class StrippedStringField(StringField):
def __init__(
self, regex=None, max_length=None, min_length=None, strip_chars=None, **kwargs
):
super(StrippedStringField, self).__init__(
regex, max_length, min_length, **kwargs
)
self._strip_chars = strip_chars
def __set__(self, instance, value):
if value is not None:
try:
value = value.strip(self._strip_chars)
except AttributeError:
pass
super(StrippedStringField, self).__set__(instance, value)
def prepare_query_value(self, op, value):
if not isinstance(op, six.string_types):
return value
if value is not None:
value = value.strip(self._strip_chars)
return super(StrippedStringField, self).prepare_query_value(op, value)
def contains_empty_key(d):
"""
Helper function to recursively determine if any key in a
dictionary is empty (based on mongoengine.fields.key_not_string)
"""
for k, v in list(d.items()):
if not k or (isinstance(v, dict) and contains_empty_key(v)):
return True
class SafeMapField(MapField):
def validate(self, value):
super(SafeMapField, self).validate(value)
if contains_empty_key(value):
self.error("Empty keys are not allowed in a MapField")
class SafeDictField(DictField):
def validate(self, value):
super(SafeDictField, self).validate(value)
if contains_empty_key(value):
self.error("Empty keys are not allowed in a DictField")
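`_RegexURLField` validates a value by matching it against any of several compiled expressions. The check in isolation, using the module's bucket-only S3 pattern and path suffix (the function name is illustrative):

```python
import re

# Standalone sketch of the _RegexURLField check above, using the module's
# aws_s3_bucket_only_regex and path_suffix. The function name is hypothetical.
aws_s3_bucket_only_regex = (
    r"^s3://"
    r"(?:(?:\w[A-Z0-9\-]+\w)\.)*(?:\w[A-Z0-9\-]+\w)"  # bucket name
)
path_suffix = r"(?:/?|[/?]\S+)$"

# compiled case-insensitively, as _RegexURLField does for string patterns
pattern = re.compile(aws_s3_bucket_only_regex + path_suffix, re.IGNORECASE)

def is_valid_s3_url(value):
    return bool(pattern.match(value))
```

The real field accepts a value if any of its patterns match; this sketch uses a single pattern for clarity.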


@@ -0,0 +1,56 @@
from mongoengine import Document, StringField
from apierrors import errors
from database.model.base import DbModelMixin, ABSTRACT_FLAG
from database.model.company import Company
from database.model.user import User
class AttributedDocument(DbModelMixin, Document):
"""
Represents objects which are attributed to a company and a user or to "no one".
Company must be required since it can be used as unique field.
"""
meta = ABSTRACT_FLAG
company = StringField(required=True, reference_field=Company)
user = StringField(reference_field=User)
def is_public(self) -> bool:
return not bool(self.company)
class PrivateDocument(AttributedDocument):
"""
Represents documents which always belong to a single company
"""
meta = ABSTRACT_FLAG
# can not have an empty string as this is the "public" marker
company = StringField(required=True, reference_field=Company, min_length=1)
user = StringField(reference_field=User, required=True)
def is_public(self) -> bool:
return False
def validate_id(cls, company, **kwargs):
"""
Validate the existence of objects with certain IDs within a company.
:param cls: Model class to search in
:param company: Company to search in
:param kwargs: Mapping of field name to object ID. If any ID does not have a corresponding object,
it will be reported along with the name it was assigned to.
:return:
"""
ids = set(kwargs.values())
objs = list(cls.objects(company=company, id__in=ids).only('id'))
missing = ids - set(x.id for x in objs)
if not missing:
return
id_to_name = {}
for name, obj_id in kwargs.items():
id_to_name.setdefault(obj_id, []).append(name)
raise errors.bad_request.ValidationError(
'Invalid {} ids'.format(cls.__name__.lower()),
**{name: obj_id for obj_id in missing for name in id_to_name[obj_id]}
)
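`validate_id` reports each missing object ID under every parameter name it was passed as. The set-and-mapping bookkeeping in isolation (the function name and its `found_ids` parameter are illustrative, standing in for the DB query):

```python
# Sketch of the id-to-name bookkeeping in validate_id above: given a mapping of
# parameter name -> object id and the set of ids actually found in the DB,
# report each missing id under every name it was assigned to.
def find_missing(kwargs, found_ids):
    ids = set(kwargs.values())
    missing = ids - set(found_ids)
    id_to_name = {}
    for name, obj_id in kwargs.items():
        id_to_name.setdefault(obj_id, []).append(name)
    return {name: obj_id for obj_id in missing for name in id_to_name[obj_id]}
```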


@@ -0,0 +1,72 @@
from mongoengine import (
StringField,
EmbeddedDocument,
EmbeddedDocumentListField,
EmailField,
DateTimeField,
)
from database import Database, strict
from database.model import DbModelMixin
from database.model.base import AuthDocument
from database.utils import get_options
class Entities(object):
company = "company"
task = "task"
user = "user"
model = "model"
class Role(object):
system = "system"
""" Internal system component """
root = "root"
""" Root admin (person) """
admin = "admin"
""" Company administrator """
superuser = "superuser"
""" Company super user """
user = "user"
""" Company user """
annotator = "annotator"
""" Annotator with limited access"""
@classmethod
def get_system_roles(cls) -> set:
return {cls.system, cls.root}
@classmethod
def get_company_roles(cls) -> set:
return set(get_options(cls)) - cls.get_system_roles()
class Credentials(EmbeddedDocument):
key = StringField(required=True)
secret = StringField(required=True)
class User(DbModelMixin, AuthDocument):
meta = {"db_alias": Database.auth, "strict": strict}
id = StringField(primary_key=True)
name = StringField(unique_with="company")
created = DateTimeField()
""" User auth entry creation time """
validated = DateTimeField()
""" Last validation (login) time """
role = StringField(required=True, choices=get_options(Role), default=Role.user)
""" User role """
company = StringField(required=True)
""" Company this user belongs to """
credentials = EmbeddedDocumentListField(Credentials, default=list)
""" Credentials generated for this user """
email = EmailField(unique=True, required=True)
""" Email uniquely identifying the user """


@@ -0,0 +1,529 @@
import re
from collections import namedtuple
from functools import reduce
from typing import Collection
from dateutil.parser import parse as parse_datetime
from mongoengine import Q, Document
from six import string_types
from apierrors import errors
from config import config
from database.errors import MakeGetAllQueryError
from database.projection import project_dict, ProjectionHelper
from database.props import PropsMixin
from database.query import RegexQ, RegexWrapper
from database.utils import get_company_or_none_constraint, get_fields_with_attr
log = config.logger("dbmodel")
ACCESS_REGEX = re.compile(r"^(?P<prefix>>=|>|<=|<)?(?P<value>.*)$")
ACCESS_MODIFIER = {">=": "gte", ">": "gt", "<=": "lte", "<": "lt"}
ABSTRACT_FLAG = {"abstract": True}
class AuthDocument(Document):
meta = ABSTRACT_FLAG
class ProperDictMixin(object):
def to_proper_dict(self, strip_private=True, only=None, extra_dict=None) -> dict:
return self.properize_dict(
self.to_mongo(use_db_field=False).to_dict(),
strip_private=strip_private,
only=only,
extra_dict=extra_dict,
)
@classmethod
def properize_dict(
cls, d, strip_private=True, only=None, extra_dict=None, normalize_id=True
):
res = d
if normalize_id and "_id" in res:
res["id"] = res.pop("_id")
if strip_private:
res = {k: v for k, v in res.items() if k[0] != "_"}
if only:
res = project_dict(res, only)
if extra_dict:
res.update(extra_dict)
return res
class GetMixin(PropsMixin):
_text_score = "$text_score"
_ordering_key = "order_by"
_multi_field_param_sep = "__"
_multi_field_param_prefix = {
("_any_", "_or_"): lambda a, b: a | b,
("_all_", "_and_"): lambda a, b: a & b,
}
MultiFieldParameters = namedtuple("MultiFieldParameters", "pattern fields")
class QueryParameterOptions(object):
def __init__(
self,
pattern_fields=("name",),
list_fields=("tags", "id"),
datetime_fields=None,
fields=None,
):
"""
:param pattern_fields: Fields for which a "string contains" condition should be generated
:param list_fields: Fields for which a "list contains" condition should be generated
:param datetime_fields: Fields for which datetime condition should be generated (see ACCESS_MODIFIER)
:param fields: Fields for which a simple equality condition should be generated (effectively filters out all
other unsupported query fields)
"""
self.fields = fields
self.datetime_fields = datetime_fields
self.list_fields = list_fields
self.pattern_fields = pattern_fields
get_all_query_options = QueryParameterOptions()
@classmethod
def get(
cls, company, id, *, _only=None, include_public=False, **kwargs
) -> "GetMixin":
q = cls.objects(
cls._prepare_perm_query(company, allow_public=include_public)
& Q(id=id, **kwargs)
)
if _only:
q = q.only(*_only)
return q.first()
@classmethod
def prepare_query(
cls,
company: str,
parameters: dict = None,
parameters_options: QueryParameterOptions = None,
allow_public=False,
):
"""
Prepare a query object based on the provided query dictionary and various fields.
:param parameters_options: Specifies options for parsing the parameters (see ParametersOptions)
:param company: Company ID (required)
:param allow_public: Allow results from public objects
:param parameters: Query dictionary (relevant keys are these specified by the various field names parameters).
Supported parameters:
- <field_name>: <value> Will query for items with this value in the field (see QueryParameterOptions for
specific rules on handling values). Only items matching ALL of these conditions will be retrieved.
- <any|all>: {fields: [<field1>, <field2>, ...], pattern: <pattern>} Will query for items where any or all
provided fields match the provided pattern.
:return: mongoengine.Q query object
"""
return cls._prepare_query_no_company(
parameters, parameters_options
) & cls._prepare_perm_query(company, allow_public=allow_public)
@classmethod
def _prepare_query_no_company(
cls, parameters=None, parameters_options=QueryParameterOptions()
):
"""
Prepare a query object based on the provided query dictionary and various fields.
NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows creating queries that span across companies.
:param parameters_options: Specifies options for parsing the parameters (see ParametersOptions)
:param parameters: Query dictionary (relevant keys are these specified by the various field names parameters).
Supported parameters:
- <field_name>: <value> Will query for items with this value in the field (see QueryParameterOptions for
specific rules on handling values). Only items matching ALL of these conditions will be retrieved.
- <any|all>: {fields: [<field1>, <field2>, ...], pattern: <pattern>} Will query for items where any or all
provided fields match the provided pattern.
:return: mongoengine.Q query object
"""
parameters_options = parameters_options or cls.get_all_query_options
dict_query = {}
query = RegexQ()
if parameters:
parameters = parameters.copy()
opts = parameters_options
for field in opts.pattern_fields:
pattern = parameters.pop(field, None)
if pattern:
dict_query[field] = RegexWrapper(pattern)
for field in tuple(opts.list_fields or ()):
data = parameters.pop(field, None)
if data:
if not isinstance(data, (list, tuple)):
raise MakeGetAllQueryError("expected list", field)
exclude = [t for t in data if t.startswith("-")]
include = list(set(data).difference(exclude))
mongoengine_field = field.replace(".", "__")
if include:
dict_query[f"{mongoengine_field}__in"] = include
if exclude:
dict_query[f"{mongoengine_field}__nin"] = [
t[1:] for t in exclude
]
for field in opts.fields or []:
data = parameters.pop(field, None)
if data is not None:
dict_query[field] = data
for field in opts.datetime_fields or []:
data = parameters.pop(field, None)
if data is not None:
if not isinstance(data, list):
data = [data]
for d in data: # type: str
m = ACCESS_REGEX.match(d)
if not m:
continue
try:
value = parse_datetime(m.group("value"))
prefix = m.group("prefix")
modifier = ACCESS_MODIFIER.get(prefix)
f = field if not modifier else "__".join((field, modifier))
dict_query[f] = value
except (ValueError, OverflowError):
pass
for field, value in parameters.items():
for keys, func in cls._multi_field_param_prefix.items():
if field not in keys:
continue
try:
data = cls.MultiFieldParameters(**value)
except Exception:
raise MakeGetAllQueryError("incorrect field format", field)
if not data.fields:
break
regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
sep_fields = [f.replace(".", "__") for f in data.fields]
q = reduce(
lambda a, x: func(a, RegexQ(**{x: regex})), sep_fields, RegexQ()
)
query = query & q
return query & RegexQ(**dict_query)
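The list-field handling above splits values prefixed with `-` into an exclusion set and translates dotted field names into mongoengine's `__` syntax. A standalone sketch of that splitting rule (plain Python, no mongoengine; the helper name is ours, not the server's):

```python
def split_list_field(field, data):
    """Split list values into include/exclude query terms, mirroring the
    '-'-prefix convention used in _prepare_query_no_company (sketch)."""
    if not isinstance(data, (list, tuple)):
        raise ValueError("expected list for field %r" % field)
    # values starting with '-' are exclusions; everything else is an inclusion
    exclude = [t for t in data if t.startswith("-")]
    include = sorted(set(data).difference(exclude))
    # dotted paths become mongoengine's double-underscore syntax
    mongo_field = field.replace(".", "__")
    query = {}
    if include:
        query[f"{mongo_field}__in"] = include
    if exclude:
        query[f"{mongo_field}__nin"] = [t[1:] for t in exclude]
    return query
```

For example, `split_list_field("tags", ["train", "-dev"])` yields `{"tags__in": ["train"], "tags__nin": ["dev"]}`.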
@classmethod
def _prepare_perm_query(cls, company, allow_public=False):
if allow_public:
return get_company_or_none_constraint(company)
return Q(company=company)
@classmethod
def validate_paging(
cls, parameters=None, default_page=None, default_page_size=None
):
""" Validate and extract paging info from from the provided dictionary. Supports default values. """
if parameters is None:
parameters = {}
default_page = parameters.get("page", default_page)
if default_page is None:
return None, None
default_page_size = parameters.get("page_size", default_page_size)
if not default_page_size:
raise errors.bad_request.MissingRequiredFields(
"page_size is required when page is requested", field="page_size"
)
elif default_page < 0:
raise errors.bad_request.ValidationError("page must be >=0", field="page")
elif default_page_size < 1:
raise errors.bad_request.ValidationError(
"page_size must be >0", field="page_size"
)
return default_page, default_page_size
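The `(page, page_size)` pair returned here is later applied as `qs.skip(page * page_size).limit(page_size)` in `_get_many_no_company()`. The same semantics on a plain list (a sketch, for illustration only):

```python
def paginate(items, page, page_size):
    """Apply the page/page_size semantics used by validate_paging():
    page is 0-based, page_size must be positive (sketch)."""
    if page < 0:
        raise ValueError("page must be >=0")
    if page_size < 1:
        raise ValueError("page_size must be >0")
    start = page * page_size          # same as qs.skip(page * page_size)
    return items[start:start + page_size]  # same as .limit(page_size)
```

A page past the end simply yields an empty list, matching MongoDB's skip/limit behavior.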
@classmethod
def get_projection(cls, parameters, override_projection=None, **__):
""" Extract a projection list from the provided dictionary. Supports an override projection. """
if override_projection is not None:
return override_projection
if not parameters:
return []
return parameters.get("projection") or parameters.get("only_fields", [])
@classmethod
def set_default_ordering(cls, parameters, value):
parameters[cls._ordering_key] = parameters.get(cls._ordering_key) or value
@classmethod
def get_many_with_join(
cls,
company,
query_dict=None,
query_options=None,
query=None,
allow_public=False,
override_projection=None,
expand_reference_ids=True,
):
"""
Fetch all documents matching a provided query with support for joining referenced documents according to the
requested projection. See get_many() for more info.
:param expand_reference_ids: If True, reference fields that contain just an ID string are expanded into
a sub-document in the format {_id: <ID>}. Otherwise, field values are left as a string.
"""
if issubclass(cls, AuthDocument):
# Refuse projection (join) for auth documents (auth.User etc.) to avoid inadvertently disclosing
# auth-related secrets and prevent security leaks
log.error(
f"Attempted projection of {cls.__name__} auth document (ignored)",
stack_info=True,
)
return []
override_projection = cls.get_projection(
parameters=query_dict, override_projection=override_projection
)
helper = ProjectionHelper(
doc_cls=cls,
projection=override_projection,
expand_reference_ids=expand_reference_ids,
)
# Make the main query
results = cls.get_many(
override_projection=helper.doc_projection,
company=company,
parameters=query_dict,
query_dict=query_dict,
query=query,
query_options=query_options,
allow_public=allow_public,
)
def projection_func(doc_type, projection, ids):
return doc_type.get_many_with_join(
company=company,
override_projection=projection,
query=Q(id__in=ids),
expand_reference_ids=expand_reference_ids,
allow_public=allow_public,
)
return helper.project(results, projection_func)
@classmethod
def get_many(
cls,
company,
parameters: dict = None,
query_dict: dict = None,
query_options: QueryParameterOptions = None,
query: Q = None,
allow_public=False,
override_projection: Collection[str] = None,
return_dicts=True,
):
"""
Fetch all documents matching a provided query. Supports several built-in options
(aside from those provided by the parameters):
- Ordering: using query field `order_by` which can contain a string or a list of strings corresponding to
field names. Using field names not defined in the document will cause an error.
- Paging: using query fields page and page_size. page must be larger than or equal to 0, page_size must be
larger than 0 and is required when specifying a page.
- Text search: using query field `search_text`. If used, text score can be used in the ordering, using the
`@text_score` keyword. A text index must be defined on the document type, otherwise an error will
be raised.
:param return_dicts: Return a list of dictionaries. If True, a list of dicts is returned (if projection was
requested, each contains only the requested projection).
If False, a QuerySet object is returned (lazy evaluated)
:param company: Company ID (required)
:param parameters: Parameters dict from which paging, ordering and searching parameters are extracted.
:param query_dict: If provided, passed to prepare_query() along with all of the relevant arguments to produce
a query. The resulting query is AND'ed with the `query` parameter (if provided).
:param query_options: query parameters options (see ParametersOptions)
:param query: Optional query object (mongoengine.Q)
:param override_projection: A list of projection fields overriding any projection specified in the `parameters`
argument
:param allow_public: If True, objects marked as public (no associated company) are also queried.
:return: A list of objects matching the query.
"""
if query_dict is not None:
q = cls.prepare_query(
parameters=query_dict,
company=company,
parameters_options=query_options,
allow_public=allow_public,
)
else:
q = cls._prepare_perm_query(company, allow_public=allow_public)
_query = (q & query) if query else q
return cls._get_many_no_company(
query=_query,
parameters=parameters,
override_projection=override_projection,
return_dicts=return_dicts,
)
@classmethod
def _get_many_no_company(
cls, query, parameters=None, override_projection=None, return_dicts=True
):
"""
Fetch all documents matching a provided query.
This is a company-less version for internal uses. We assume the caller has either added any necessary
constraints to the query or that no constraints are required.
NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows returning data across companies.
:param query: Query object (mongoengine.Q)
:param return_dicts: Return a list of dictionaries. If True, a list of dicts is returned (if projection was
requested, each contains only the requested projection).
If False, a QuerySet object is returned (lazy evaluated)
:param parameters: Parameters dict from which paging, ordering and searching parameters are extracted.
:param override_projection: A list of projection fields overriding any projection specified in the `parameters`
argument
"""
parameters = parameters or {}
if not query:
raise ValueError("query or call_data must be provided")
page, page_size = cls.validate_paging(parameters=parameters)
order_by = parameters.get(cls._ordering_key)
if order_by:
order_by = order_by if isinstance(order_by, list) else [order_by]
order_by = [cls._text_score if x == "@text_score" else x for x in order_by]
search_text = parameters.get("search_text")
only = cls.get_projection(parameters, override_projection)
if not search_text and order_by and cls._text_score in order_by:
raise errors.bad_request.FieldsValueError(
"text score cannot be used in order_by when search text is not used"
)
qs = cls.objects(query)
if search_text:
qs = qs.search_text(search_text)
if order_by:
# add ordering (order_by is always a list at this point)
qs = qs.order_by(*order_by)
if only:
# add projection
qs = qs.only(*only)
else:
exclude = set(cls.get_exclude_fields()).difference(only)
if exclude:
qs = qs.exclude(*exclude)
if page is not None and page_size:
# add paging
qs = qs.skip(page * page_size).limit(page_size)
if return_dicts:
return [obj.to_proper_dict(only=only) for obj in qs]
return qs
@classmethod
def get_for_writing(
cls, *args, _only: Collection[str] = None, **kwargs
) -> "GetMixin":
if _only and "company" not in _only:
_only = list(set(_only) | {"company"})
result = cls.get(*args, _only=_only, include_public=True, **kwargs)
if result and not result.company:
object_name = cls.__name__.lower()
raise errors.forbidden.NoWritePermission(
f"cannot modify public {object_name}(s), ids={(result.id,)}"
)
return result
@classmethod
def get_many_for_writing(cls, company, *args, **kwargs):
result = cls.get_many(
company=company,
*args,
**dict(return_dicts=False, **kwargs),
allow_public=True,
)
forbidden_objects = {obj.id for obj in result if not obj.company}
if forbidden_objects:
object_name = cls.__name__.lower()
raise errors.forbidden.NoWritePermission(
f"cannot modify public {object_name}(s), ids={tuple(forbidden_objects)}"
)
return result
class UpdateMixin(object):
@classmethod
def user_set_allowed(cls):
res = getattr(cls, "__user_set_allowed_fields", None)
if res is None:
res = cls.__user_set_allowed_fields = dict(
get_fields_with_attr(cls, "user_set_allowed")
)
return res
@classmethod
def get_safe_update_dict(cls, fields):
if not fields:
return {}
valid_fields = cls.user_set_allowed()
fields = [(k, v, fields[k]) for k, v in valid_fields.items() if k in fields]
update_dict = {
field: value
for field, allowed, value in fields
if allowed is None
or (
(value in allowed)
if not isinstance(value, list)
else all(v in allowed for v in value)
)
}
return update_dict
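`get_safe_update_dict()` keeps only fields explicitly marked as user-settable, and when a field declares a whitelist of allowed values, drops values (or list members) outside it. The filtering rule in isolation (a sketch; the real version derives `allowed_fields` from the `user_set_allowed` field attribute):

```python
def safe_update_dict(fields, allowed_fields):
    """Filter a partial-update dict down to user-settable fields.
    `allowed_fields` maps field name -> None (any value allowed) or a
    collection of permitted values (sketch of the rule above)."""
    result = {}
    for name, allowed in allowed_fields.items():
        if name not in fields:
            continue
        value = fields[name]
        ok = (
            allowed is None
            # lists are accepted only if every member is allowed
            or (all(v in allowed for v in value) if isinstance(value, list)
                else value in allowed)
        )
        if ok:
            result[name] = value
    return result
```

Fields absent from `allowed_fields` are silently discarded, which is what makes the update "safe".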
@classmethod
def safe_update(cls, company_id, id, partial_update_dict, injected_update=None):
update_dict = cls.get_safe_update_dict(partial_update_dict)
if not update_dict:
return 0, {}
if injected_update:
update_dict.update(injected_update)
update_count = cls.objects(id=id, company=company_id).update(
upsert=False, **update_dict
)
return update_count, update_dict
class DbModelMixin(GetMixin, ProperDictMixin, UpdateMixin):
""" Provide convenience methods for a subclass of mongoengine.Document """
pass
def validate_id(cls, company, **kwargs):
"""
Validate the existence of objects with certain IDs within a company.
:param cls: Model class to search in
:param company: Company to search in
:param kwargs: Mapping of field name to object ID. If any ID does not have a corresponding object,
it will be reported along with the name it was assigned to.
:return:
"""
ids = set(kwargs.values())
objs = list(cls.objects(company=company, id__in=ids).only("id"))
missing = ids - set(x.id for x in objs)
if not missing:
return
id_to_name = {}
for name, obj_id in kwargs.items():
id_to_name.setdefault(obj_id, []).append(name)
raise errors.bad_request.ValidationError(
"Invalid {} ids".format(cls.__name__.lower()),
**{name: obj_id for obj_id in missing for name in id_to_name[obj_id]}
)
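`validate_id()` reports each missing ID under every argument name it was passed as (the same ID may arrive via two names, e.g. `task` and `parent`). The reverse-mapping step as a standalone sketch:

```python
def missing_id_report(kwargs, existing_ids):
    """Map each argument name whose ID was not found to that ID,
    mirroring the error payload built by validate_id() (sketch)."""
    missing = set(kwargs.values()) - set(existing_ids)
    # invert the name->id mapping so one id can carry several names
    id_to_name = {}
    for name, obj_id in kwargs.items():
        id_to_name.setdefault(obj_id, []).append(name)
    return {name: obj_id for obj_id in missing for name in id_to_name[obj_id]}
```

Only names whose IDs were not found appear in the result.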


@@ -0,0 +1,25 @@
from mongoengine import Document, EmbeddedDocument, EmbeddedDocumentField, StringField, Q
from database import Database, strict
from database.fields import StrippedStringField
from database.model import DbModelMixin
class CompanyDefaults(EmbeddedDocument):
cluster = StringField()
class Company(DbModelMixin, Document):
meta = {
'db_alias': Database.backend,
'strict': strict,
}
id = StringField(primary_key=True)
name = StrippedStringField(unique=True, min_length=3)
defaults = EmbeddedDocumentField(CompanyDefaults)
@classmethod
def _prepare_perm_query(cls, company, allow_public=False):
""" Override default behavior since a 'company' constraint is not supported for this document... """
return Q()


@@ -0,0 +1,56 @@
from mongoengine import Document, StringField, DateTimeField, ListField, BooleanField
from database import Database, strict
from database.fields import SupportedURLField, StrippedStringField, SafeDictField
from database.model import DbModelMixin
from database.model.model_labels import ModelLabels
from database.model.company import Company
from database.model.project import Project
from database.model.task.task import Task
from database.model.user import User
class Model(DbModelMixin, Document):
meta = {
'db_alias': Database.backend,
'strict': strict,
'indexes': [
{
'name': '%s.model.main_text_index' % Database.backend,
'fields': [
'$name',
'$id',
'$comment',
'$parent',
'$task',
'$project',
],
'default_language': 'english',
'weights': {
'name': 10,
'id': 10,
'comment': 10,
'parent': 5,
'task': 3,
'project': 3,
}
}
],
}
id = StringField(primary_key=True)
name = StrippedStringField(user_set_allowed=True, min_length=3)
parent = StringField(reference_field='Model', required=False)
user = StringField(required=True, reference_field=User)
company = StringField(required=True, reference_field=Company)
project = StringField(reference_field=Project, user_set_allowed=True)
created = DateTimeField(required=True, user_set_allowed=True)
task = StringField(reference_field=Task)
comment = StringField(user_set_allowed=True)
tags = ListField(StringField(required=True), user_set_allowed=True)
uri = SupportedURLField(default='', user_set_allowed=True)
framework = StringField()
design = SafeDictField()
labels = ModelLabels()
ready = BooleanField(required=True)
ui_cache = SafeDictField(default=dict, user_set_allowed=True, exclude_by_default=True)


@@ -0,0 +1,11 @@
from mongoengine import MapField, IntField
class ModelLabels(MapField):
def __init__(self, *args, **kwargs):
super(ModelLabels, self).__init__(field=IntField(), *args, **kwargs)
def validate(self, value):
super(ModelLabels, self).validate(value)
if value and (len(set(value.values())) < len(value)):
self.error("Same label id appears more than once in model labels")
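`ModelLabels.validate()` rejects label mappings where two labels share the same numeric ID. The duplicate check itself, extracted as a one-liner (sketch):

```python
def has_duplicate_values(mapping):
    """True when two keys map to the same value - the condition
    ModelLabels.validate() rejects (sketch)."""
    # a set collapses duplicates, so a shorter set means repeated values
    return bool(mapping) and len(set(mapping.values())) < len(mapping)
```

An empty mapping is not considered a duplicate, matching the `if value and ...` guard above.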


@@ -0,0 +1,39 @@
from mongoengine import StringField, DateTimeField, ListField
from database import Database, strict
from database.fields import OutputDestinationField, StrippedStringField
from database.model import AttributedDocument
from database.model.base import GetMixin
class Project(AttributedDocument):
get_all_query_options = GetMixin.QueryParameterOptions(
pattern_fields=("name", "description"), list_fields=("tags", "id")
)
meta = {
"db_alias": Database.backend,
"strict": strict,
"indexes": [
{
"name": "%s.project.main_text_index" % Database.backend,
"fields": ["$name", "$id", "$description"],
"default_language": "english",
"weights": {"name": 10, "id": 10, "description": 10},
}
],
}
id = StringField(primary_key=True)
name = StrippedStringField(
required=True,
unique_with=AttributedDocument.company.name,
min_length=3,
sparse=True,
)
description = StringField(required=True)
created = DateTimeField(required=True)
tags = ListField(StringField(required=True), default=list)
default_output_destination = OutputDestinationField()
last_update = DateTimeField()


@@ -0,0 +1,14 @@
from mongoengine import EmbeddedDocument, StringField, DateTimeField, LongField, DynamicField
class MetricEvent(EmbeddedDocument):
metric = StringField(required=True, )
variant = StringField(required=True)
type = StringField(required=True)
timestamp = DateTimeField(default=0, required=True)
iter = LongField()
value = DynamicField(required=True)
@classmethod
def from_dict(cls, **kwargs):
return cls(**{k: v for k, v in kwargs.items() if k in cls._fields})
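`MetricEvent.from_dict()` drops any keys not declared as document fields, so arbitrary event payloads can be passed without tripping mongoengine's strict mode. The same pattern on a plain class (a sketch; `_fields` here is just a set of names standing in for mongoengine's field registry):

```python
class Event:
    # stand-in for a mongoengine EmbeddedDocument; _fields is an
    # assumption for this sketch, mimicking the declared-fields registry
    _fields = {"metric", "variant", "type", "timestamp", "iter", "value"}

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    @classmethod
    def from_dict(cls, **kwargs):
        # keep only declared fields, as MetricEvent.from_dict() does
        return cls(**{k: v for k, v in kwargs.items() if k in cls._fields})
```

Unknown keys are silently ignored rather than raising.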


@@ -0,0 +1,16 @@
from mongoengine import EmbeddedDocument, StringField
from database.utils import get_options
from database.fields import OutputDestinationField
class Result(object):
success = 'success'
failure = 'failure'
class Output(EmbeddedDocument):
destination = OutputDestinationField()
model = StringField(reference_field='Model')
error = StringField(user_set_allowed=True)
result = StringField(choices=get_options(Result))


@@ -0,0 +1,132 @@
from enum import Enum
from mongoengine import (
StringField,
EmbeddedDocumentField,
EmbeddedDocument,
DateTimeField,
IntField,
ListField,
)
from database import Database, strict
from database.fields import StrippedStringField, SafeMapField, SafeDictField
from database.model import AttributedDocument
from database.model.model_labels import ModelLabels
from database.model.project import Project
from database.utils import get_options
from .metrics import MetricEvent
from .output import Output
DEFAULT_LAST_ITERATION = 0
class TaskStatus(object):
created = 'created'
in_progress = 'in_progress'
stopped = 'stopped'
publishing = 'publishing'
published = 'published'
closed = 'closed'
failed = 'failed'
unknown = 'unknown'
class TaskStatusMessage(object):
stopping = 'stopping'
class TaskTags(object):
development = 'development'
class Script(EmbeddedDocument):
binary = StringField(default='python')
repository = StringField(required=True)
tag = StringField()
branch = StringField()
version_num = StringField()
entry_point = StringField(required=True)
working_dir = StringField()
requirements = SafeDictField()
class Execution(EmbeddedDocument):
test_split = IntField(default=0)
parameters = SafeDictField(default=dict)
model = StringField(reference_field='Model')
model_desc = SafeMapField(StringField(default=''))
model_labels = ModelLabels()
framework = StringField()
queue = StringField()
''' Queue ID where task was queued '''
class TaskType(object):
training = 'training'
testing = 'testing'
class Task(AttributedDocument):
meta = {
'db_alias': Database.backend,
'strict': strict,
'indexes': [
'created',
'started',
'completed',
{
'name': '%s.task.main_text_index' % Database.backend,
'fields': [
'$name',
'$id',
'$comment',
'$execution.model',
'$output.model',
'$script.repository',
'$script.entry_point',
],
'default_language': 'english',
'weights': {
'name': 10,
'id': 10,
'comment': 10,
'execution.model': 2,
'output.model': 2,
'script.repository': 1,
'script.entry_point': 1,
},
},
],
}
id = StringField(primary_key=True)
name = StrippedStringField(
required=True, user_set_allowed=True, sparse=False, min_length=3
)
type = StringField(required=True, choices=get_options(TaskType))
status = StringField(default=TaskStatus.created, choices=get_options(TaskStatus))
status_reason = StringField()
status_message = StringField()
status_changed = DateTimeField()
comment = StringField(user_set_allowed=True)
created = DateTimeField(required=True, user_set_allowed=True)
started = DateTimeField()
completed = DateTimeField()
published = DateTimeField()
parent = StringField()
project = StringField(reference_field=Project, user_set_allowed=True)
output = EmbeddedDocumentField(Output, default=Output)
execution: Execution = EmbeddedDocumentField(Execution, default=Execution)
tags = ListField(StringField(required=True), user_set_allowed=True)
script = EmbeddedDocumentField(Script)
last_update = DateTimeField()
last_iteration = IntField(default=DEFAULT_LAST_ITERATION)
last_metrics = SafeMapField(field=SafeMapField(EmbeddedDocumentField(MetricEvent)))
class TaskVisibility(Enum):
active = 'active'
archived = 'archived'


@@ -0,0 +1,21 @@
from mongoengine import Document, StringField
from database import Database, strict
from database.fields import SafeDictField
from database.model import DbModelMixin
from database.model.company import Company
class User(DbModelMixin, Document):
meta = {
'db_alias': Database.backend,
'strict': strict,
}
id = StringField(primary_key=True)
company = StringField(required=True, reference_field=Company)
name = StringField(required=True, user_set_allowed=True)
family_name = StringField(user_set_allowed=True)
given_name = StringField(user_set_allowed=True)
avatar = StringField()
preferences = SafeDictField(default=dict, exclude_by_default=True)


@@ -0,0 +1,269 @@
from concurrent.futures import ThreadPoolExecutor
from itertools import groupby, chain
import dpath
from apierrors import errors
from database.props import PropsMixin
def project_dict(data, projection, separator='.'):
"""
Project partial data from a dictionary into a new dictionary
:param data: Input dictionary
:param projection: List of dictionary paths (each a string with field names separated using a separator)
:param separator: Separator (default is '.')
:return: A new dictionary containing only the projected parts from the original dictionary
"""
assert isinstance(data, dict)
result = {}
def copy_path(path_parts, source, destination):
src, dst = source, destination
try:
for depth, path_part in enumerate(path_parts[:-1]):
src_part = src[path_part]
if isinstance(src_part, dict):
src = src_part
dst = dst.setdefault(path_part, {})
elif isinstance(src_part, (list, tuple)):
if path_part not in dst:
dst[path_part] = [{} for _ in range(len(src_part))]
elif not isinstance(dst[path_part], (list, tuple)):
raise TypeError('Incompatible destination type %s for %s (list expected)'
% (type(dst), separator.join(path_parts[:depth + 1])))
elif not len(dst[path_part]) == len(src_part):
raise ValueError('Destination list length differs from source length for %s'
% separator.join(path_parts[:depth + 1]))
dst[path_part] = [copy_path(path_parts[depth + 1:], s, d)
for s, d in zip(src_part, dst[path_part])]
return destination
else:
raise TypeError('Unsupported projection type %s for %s'
% (type(src), separator.join(path_parts[:depth + 1])))
last_part = path_parts[-1]
dst[last_part] = src[last_part]
except KeyError:
# Projection field not in source, no biggie.
pass
return destination
for projection_path in sorted(projection):
copy_path(
path_parts=projection_path.split(separator),
source=data,
destination=result)
return result
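To illustrate what `project_dict()` produces, here is a minimal variant that handles nested dicts only (no list handling), silently skipping paths missing from the source. It is a simplified sketch, not a replacement for the full function above:

```python
def project_simple(data, projection, separator="."):
    """Copy only the dotted paths in `projection` from `data` into a new
    dict. Minimal sketch of project_dict(): nested dicts only."""
    result = {}
    for path in sorted(projection):
        parts = path.split(separator)
        src, dst = data, result
        try:
            # walk down to the parent of the leaf, creating dicts as needed
            for part in parts[:-1]:
                src = src[part]
                dst = dst.setdefault(part, {})
            dst[parts[-1]] = src[parts[-1]]
        except (KeyError, TypeError):
            # projection field not in source, no biggie
            pass
    return result
```

So projecting `["id", "execution.model"]` from a task-like dict keeps only those two leaves, preserving the nesting.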
class ProjectionHelper(object):
pool = ThreadPoolExecutor()
@property
def doc_projection(self):
return self._doc_projection
def __init__(self, doc_cls, projection, expand_reference_ids=False):
super(ProjectionHelper, self).__init__()
self._should_expand_reference_ids = expand_reference_ids
self._doc_cls = doc_cls
self._doc_projection = None
self._ref_projection = None
self._parse_projection(projection)
def _collect_projection_fields(self, doc_cls, projection):
"""
Collect projection for the given document into immediate document projection and reference documents projection
:param doc_cls: Document class
:param projection: List of projection fields
:return: A tuple of document projection and reference fields information
"""
doc_projection = set() # Projection fields for this class (used in the main query)
ref_projection_info = [] # Projection information for reference fields (used in join queries)
for field in projection:
for ref_field, ref_field_cls in doc_cls.get_reference_fields().items():
if not field.startswith(ref_field):
# Doesn't start with a reference field
continue
if field == ref_field:
# Field is exactly a reference field. In this case we won't perform any inner projection (for that,
# use '<reference field name>.*')
continue
subfield = field[len(ref_field):]
if not subfield.startswith('.'):
# Starts with something that looks like a reference field, but isn't
continue
ref_projection_info.append((ref_field, ref_field_cls, subfield[1:]))
break
else:
# Not a reference field, just add to the top-level projection
# We strip any trailing '*' since it means nothing for simple fields and for embedded documents
orig_field = field
if field.endswith('.*'):
field = field[:-2]
if not field:
raise errors.bad_request.InvalidFields(field=orig_field, object=doc_cls.__name__)
doc_projection.add(field)
return doc_projection, ref_projection_info
def _parse_projection(self, projection):
"""
Prepare the projection data structures for get_many_with_join().
:param projection: A list of field names that should be returned by the query. Sub-fields can be specified
using '.' (i.e. "parent.name"). A field terminated by '.*' indicates that all of the field's sub-fields
should be returned (only relevant for fields that represent sub-documents or referenced documents)
:type projection: list of strings
:returns A tuple of (class fields projection, reference fields projection)
"""
doc_cls = self._doc_cls
assert issubclass(doc_cls, PropsMixin)
if not projection:
return [], {}
doc_projection, ref_projection_info = self._collect_projection_fields(doc_cls, projection)
def normalize_cls_projection(cls_, fields):
""" Normalize projection for this class and group (expand *, for once) """
if '*' in fields:
return list(fields.difference('*').union(cls_.get_fields()))
return list(fields)
def compute_ref_cls_projection(cls_, group):
""" Compute inner projection for this class and group """
subfields = set([x[2] for x in group if x[2]])
return normalize_cls_projection(cls_, subfields)
def sort_key(proj_info):
return proj_info[:2]
# Aggregate by reference field. We'll leave out '*' from the projected items since it was already
# expanded during normalization
ref_projection = {
ref_field: dict(cls=ref_cls, only=compute_ref_cls_projection(ref_cls, g))
for (ref_field, ref_cls), g in groupby(sorted(ref_projection_info, key=sort_key), sort_key)
}
# Make sure this doesn't contain any reference field we'll join anyway
# (i.e. in case only_fields=[project, project.name])
doc_projection = normalize_cls_projection(doc_cls, doc_projection.difference(ref_projection).union({'id'}))
# Make sure that in case one or more fields are subfields of another field, we only use the top-level field.
# This is done since in such a case, MongoDB will only use the most restrictive field (most nested field) and
# won't return some of the data we need.
# This way, we make sure to use the most inclusive field that contains all requested subfields.
projection_set = set(doc_projection)
doc_projection = [
field
for field in doc_projection
if not any(field.startswith(f"{other_field}.") for other_field in projection_set - {field})
]
# Make sure we didn't get any invalid projection fields for this class
invalid_fields = [f for f in doc_projection if f.split('.')[0] not in doc_cls.get_fields()]
if invalid_fields:
raise errors.bad_request.InvalidFields(fields=invalid_fields, object=doc_cls.__name__)
if ref_projection:
# Join mode - use both normal projection fields and top-level reference fields
doc_projection = set(doc_projection)
for field in set(ref_projection).difference(doc_projection):
if any(f for f in doc_projection if field.startswith(f)):
continue
doc_projection.add(field)
doc_projection = list(doc_projection)
self._doc_projection = doc_projection
self._ref_projection = ref_projection
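The subfield-pruning step above keeps only the most inclusive fields: any projection field nested under another requested field is dropped, so MongoDB returns the full parent document. In isolation (a sketch with a name of our choosing):

```python
def prune_subfields(fields):
    """Drop any field that is nested under another requested field, so
    the query uses the most inclusive projection (sketch of the rule
    applied in _parse_projection)."""
    field_set = set(fields)
    return sorted(
        f for f in field_set
        # f is redundant if some other requested field is its ancestor
        if not any(f.startswith(other + ".") for other in field_set - {f})
    )
```

For example, requesting both `project` and `project.name` collapses to just `project`; sibling subfields like `a.b` and `a.c` are both kept.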
@staticmethod
def _search(doc_cls, obj, path, only_values=True):
""" Call dpath.search with yielded=True, collect result values """
norm_path = doc_cls.get_dpath_translated_path(path)
return [v if only_values else (k, v) for k, v in dpath.search(obj, norm_path, separator='.', yielded=True)]
def project(self, results, projection_func):
"""
Perform projection on query results, using the provided projection func.
:param results: A list of results dictionaries on which projection should be performed
:param projection_func: A callable that receives a document type, list of ids and projection and returns query
results. This callable is used in order to perform sub-queries during projection
:return: Modified results (in-place)
"""
cls = self._doc_cls
ref_projection = self._ref_projection
if ref_projection:
# Join mode - get results for each reference fields projection required (this is the join step)
# Note: this is a recursive step, so we support nested reference fields
def do_projection(item):
ref_field_name, data = item
res = {}
ids = list(filter(None, set(chain.from_iterable(self._search(cls, res, ref_field_name)
for res in results))))
if ids:
doc_type = data['cls']
doc_only = list(filter(None, data['only']))
doc_only = list({'id'} | set(doc_only)) if doc_only else None
res = {r['id']: r for r in projection_func(doc_type=doc_type, projection=doc_only, ids=ids)}
data['res'] = res
items = list(ref_projection.items())
if len(ref_projection) == 1:
do_projection(items[0])
else:
for _ in self.pool.map(do_projection, items):
# From ThreadPoolExecutor.map() documentation: If a call raises an exception then that exception
# will be raised when its value is retrieved from the map() iterator
pass
def do_expand_reference_ids(result, skip_fields=None):
ref_fields = cls.get_reference_fields()
if skip_fields:
ref_fields = set(ref_fields) - set(skip_fields)
self._expand_reference_fields(cls, result, ref_fields)
def merge_projection_result(result):
for ref_field_name, data in ref_projection.items():
res = data.get('res')
if not res:
self._expand_reference_fields(cls, result, [ref_field_name])
continue
ref_ids = self._search(cls, result, ref_field_name, only_values=False)
if not ref_ids:
continue
for path, value in ref_ids:
obj = res.get(value) or {'id': value}
dpath.new(result, path, obj, separator='.')
# any reference field not projected should be expanded
do_expand_reference_ids(result, skip_fields=list(ref_projection))
update_func = merge_projection_result if ref_projection else \
do_expand_reference_ids if self._should_expand_reference_ids else None
if update_func:
for result in results:
update_func(result)
return results
@classmethod
def _expand_reference_fields(cls, doc_cls, result, fields):
for ref_field_name in fields:
ref_ids = cls._search(doc_cls, result, ref_field_name, only_values=False)
if not ref_ids:
continue
for path, value in ref_ids:
dpath.set(
result,
path,
{'id': value} if value else {},
separator='.')
@classmethod
def expand_reference_ids(cls, doc_cls, result):
cls._expand_reference_fields(doc_cls, result, doc_cls.get_reference_fields())
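`_expand_reference_fields()` rewrites plain ID strings into `{'id': <ID>}` sub-documents, which is the `expand_reference_ids` behavior documented on `get_many_with_join()`. A flat-dict sketch of the transformation (the real helper walks nested dpath paths):

```python
def expand_reference_ids(result, ref_fields):
    """Replace plain ID strings under the given top-level keys with
    {'id': <ID>} sub-documents; empty values become {} (flat sketch of
    _expand_reference_fields)."""
    for field in ref_fields:
        if field in result:
            value = result[field]
            result[field] = {"id": value} if value else {}
    return result
```

This keeps reference fields structurally uniform whether or not a join actually fetched the referenced document.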

server/database/props.py Normal file

@@ -0,0 +1,142 @@
from collections import OrderedDict
from operator import attrgetter
from threading import Lock
import six
from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
from mongoengine.base import get_document
from database.fields import (
LengthRangeEmbeddedDocumentListField,
UniqueEmbeddedDocumentListField,
EmbeddedDocumentSortedListField,
)
from database.utils import get_fields, get_fields_and_attr
class PropsMixin(object):
__cached_fields = None
__cached_reference_fields = None
__cached_exclude_fields = None
__cached_fields_with_instance = None
__cached_dpath_computed_fields_lock = Lock()
__cached_dpath_computed_fields = None
@classmethod
def get_fields(cls):
if cls.__cached_fields is None:
cls.__cached_fields = get_fields(cls)
return cls.__cached_fields
@classmethod
def get_fields_with_instance(cls, doc_cls):
if cls.__cached_fields_with_instance is None:
cls.__cached_fields_with_instance = {}
if doc_cls not in cls.__cached_fields_with_instance:
cls.__cached_fields_with_instance[doc_cls] = get_fields(
doc_cls, return_instance=True
)
return cls.__cached_fields_with_instance[doc_cls]
@staticmethod
def _get_fields_with_attr(cls_, attr):
""" Get all fields with the specified attribute (supports nested fields) """
res = get_fields_and_attr(cls_, attr=attr)
def resolve_doc(v):
if not isinstance(v, six.string_types):
return v
if v == 'self':
return cls_.owner_document
return get_document(v)
fields = {k: resolve_doc(v) for k, v in res.items()}
def collect_embedded_docs(doc_cls, embedded_doc_field_getter):
for field, embedded_doc_field in get_fields(
cls_, of_type=doc_cls, return_instance=True
):
embedded_doc_cls = embedded_doc_field_getter(
embedded_doc_field
).document_type
fields.update(
{
'.'.join((field, subfield)): doc
for subfield, doc in PropsMixin._get_fields_with_attr(
embedded_doc_cls, attr
).items()
}
)
collect_embedded_docs(EmbeddedDocumentField, lambda x: x)
collect_embedded_docs(EmbeddedDocumentListField, attrgetter('field'))
collect_embedded_docs(LengthRangeEmbeddedDocumentListField, attrgetter('field'))
collect_embedded_docs(UniqueEmbeddedDocumentListField, attrgetter('field'))
collect_embedded_docs(EmbeddedDocumentSortedListField, attrgetter('field'))
return fields
@classmethod
def _translate_fields_path(cls, parts):
current_cls = cls
translated_parts = []
for depth, part in enumerate(parts):
if current_cls is None:
raise ValueError(
'Invalid path (non-document encountered at %s)' % parts[:depth]
)
try:
field_name, field = next(
(k, v)
for k, v in cls.get_fields_with_instance(current_cls)
if k == part
)
except StopIteration:
raise ValueError('Invalid field path %s' % parts[: depth + 1])
translated_parts.append(part)
if isinstance(field, EmbeddedDocumentField):
current_cls = field.document_type
elif isinstance(
field,
(
EmbeddedDocumentListField,
LengthRangeEmbeddedDocumentListField,
UniqueEmbeddedDocumentListField,
EmbeddedDocumentSortedListField,
),
):
current_cls = field.field.document_type
translated_parts.append('*')
else:
current_cls = None
return translated_parts
@classmethod
def get_reference_fields(cls):
if cls.__cached_reference_fields is None:
fields = cls._get_fields_with_attr(cls, 'reference_field')
cls.__cached_reference_fields = OrderedDict(sorted(fields.items()))
return cls.__cached_reference_fields
@classmethod
def get_exclude_fields(cls):
if cls.__cached_exclude_fields is None:
fields = cls._get_fields_with_attr(cls, 'exclude_by_default')
cls.__cached_exclude_fields = OrderedDict(sorted(fields.items()))
return cls.__cached_exclude_fields
@classmethod
def get_dpath_translated_path(cls, path, separator='.'):
if cls.__cached_dpath_computed_fields is None:
cls.__cached_dpath_computed_fields = {}
if path not in cls.__cached_dpath_computed_fields:
with cls.__cached_dpath_computed_fields_lock:
parts = path.split(separator)
translated = cls._translate_fields_path(parts)
result = separator.join(translated)
cls.__cached_dpath_computed_fields[path] = result
return cls.__cached_dpath_computed_fields[path]

63
server/database/query.py Normal file

@@ -0,0 +1,63 @@
import copy
import re
from mongoengine import Q
from mongoengine.queryset.visitor import QueryCompilerVisitor, SimplificationVisitor, QCombination
class RegexWrapper(object):
def __init__(self, pattern, flags=None):
super(RegexWrapper, self).__init__()
self.pattern = pattern
self.flags = flags
@property
def regex(self):
return re.compile(self.pattern, self.flags if self.flags is not None else 0)
class RegexMixin(object):
def to_query(self, document):
query = self.accept(SimplificationVisitor())
query = query.accept(RegexQueryCompilerVisitor(document))
return query
def _combine(self, other, operation):
"""Combine this node with another node into a QCombination
object.
"""
if getattr(other, 'empty', True):
return self
if self.empty:
return other
return RegexQCombination(operation, [self, other])
class RegexQCombination(RegexMixin, QCombination):
pass
class RegexQ(RegexMixin, Q):
pass
class RegexQueryCompilerVisitor(QueryCompilerVisitor):
"""
Improved mongoengine compiled-queries visitor class that supports compiled regex expressions as part of the query.
We need this class since mongoengine's Q (QNode) class uses copy.deepcopy() as part of the tree simplification
stage, which does not support re.compile() pattern objects (since Python 2.5).
This class allows users to provide regex strings wrapped in RegexWrapper instances, which are lazily evaluated
to re.compile instances just before being visited for compilation (this is done after the simplification stage).
"""
def visit_query(self, query):
query = copy.deepcopy(query)
query.query = self._transform_query(query.query)
return super(RegexQueryCompilerVisitor, self).visit_query(query)
def _transform_query(self, query):
return {k: v.regex if isinstance(v, RegexWrapper) else v for k, v in query.items()}
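The reason `RegexWrapper` exists can be sketched without mongoengine: the wrapper is a plain object that survives `copy.deepcopy()` during query simplification, and the compiled pattern is only produced at the last step, mirroring `_transform_query`. A minimal sketch (the names below are stand-ins, not the server's API):

```python
import copy
import re

# stand-in for the server's RegexWrapper: holds the pattern string and flags,
# compiling only when .regex is read
class LazyRegex:
    def __init__(self, pattern, flags=0):
        self.pattern = pattern
        self.flags = flags

    @property
    def regex(self):
        return re.compile(self.pattern, self.flags)

# a query dict holding the wrapper deep-copies cleanly during simplification
query = {"name": LazyRegex(r"^exp.*", re.IGNORECASE)}
copied = copy.deepcopy(query)

# compilation happens last, mirroring _transform_query
compiled = {k: v.regex if isinstance(v, LazyRegex) else v for k, v in copied.items()}
assert compiled["name"].match("Experiment 1")
```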

160
server/database/utils.py Normal file

@@ -0,0 +1,160 @@
import hashlib
from inspect import ismethod, getmembers
from uuid import uuid4
from mongoengine import EmbeddedDocumentField, ListField, Document, Q
from mongoengine.base import BaseField
from .errors import translate_errors_context, ParseCallError
def get_fields(cls, of_type=BaseField, return_instance=False):
""" get field names from a class containing mongoengine fields """
res = []
for cls_ in reversed(cls.mro()):
res.extend([k if not return_instance else (k, v)
for k, v in vars(cls_).items()
if isinstance(v, of_type)])
return res
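`get_fields` walks the class MRO in reverse so that base-class fields are listed before subclass fields. A toy sketch, with `FakeField` standing in for mongoengine's `BaseField`:

```python
# toy sketch of the reversed-MRO walk performed by get_fields
class FakeField:
    pass

class BaseDoc:
    company = FakeField()

class Task(BaseDoc):
    name = FakeField()

def field_names(cls, of_type=FakeField):
    res = []
    for cls_ in reversed(cls.mro()):  # object first, leaf class last
        res.extend(k for k, v in vars(cls_).items() if isinstance(v, of_type))
    return res

assert field_names(Task) == ["company", "name"]
```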
def get_fields_and_attr(cls, attr):
""" get field names from a class containing mongoengine fields """
res = {}
for cls_ in reversed(cls.mro()):
res.update({k: getattr(v, attr)
for k, v in vars(cls_).items()
if isinstance(v, BaseField) and hasattr(v, attr)})
return res
def _get_field_choices(name, field):
field_t = type(field)
if issubclass(field_t, EmbeddedDocumentField):
obj = field.document_type_obj
n, choices = _get_field_choices(field.name, obj.field)
return '%s__%s' % (name, n), choices
elif issubclass(type(field), ListField):
return name, field.field.choices
return name, field.choices
def get_fields_with_attr(cls, attr, default=False):
fields = []
for field_name, field in cls._fields.items():
if not getattr(field, attr, default):
continue
field_t = type(field)
if issubclass(field_t, EmbeddedDocumentField):
fields.extend((('%s__%s' % (field_name, name), choices)
for name, choices in get_fields_with_attr(field.document_type, attr, default)))
elif issubclass(type(field), ListField):
fields.append((field_name, field.field.choices))
else:
fields.append((field_name, field.choices))
return fields
def get_items(cls):
""" get key/value items from an enum-like class (members represent enumeration key/value) """
res = {
k: v
for k, v in getmembers(cls)
if not (k.startswith("_") or ismethod(v))
}
return res
def get_options(cls):
""" get options from an enum-like class (members represent enumeration key/value) """
return list(get_items(cls).values())
# return a dictionary of items which:
# 1. are in the call_data
# 2. are in the fields dictionary, and their value in the call_data matches the type in fields
# 3. are in the cls_fields
def parse_from_call(call_data, fields, cls_fields, discard_none_values=True):
if not isinstance(fields, dict):
# fields should be key=>type dict
fields = {k: None for k in fields}
fields = {k: v for k, v in fields.items() if k in cls_fields}
res = {}
with translate_errors_context('parsing call data'):
for field, desc in fields.items():
value = call_data.get(field)
if value is None:
if not discard_none_values and field in call_data:
# we'll keep the None value in case the field actually exists in the call data
res[field] = None
continue
if desc:
if callable(desc) and not isinstance(desc, type):
desc(value)
else:
if issubclass(desc, (list, tuple, dict)) and not isinstance(value, desc):
raise ParseCallError('expecting %s' % desc.__name__, field=field)
if issubclass(desc, Document) and not desc.objects(id=value).only('id'):
raise ParseCallError('expecting %s id' % desc.__name__, id=value, field=field)
res[field] = value
return res
def init_cls_from_base(cls, instance):
return cls(**{k: v for k, v in instance.to_mongo(use_db_field=False).to_dict().items() if k[0] != '_'})
def get_company_or_none_constraint(company=None):
return Q(company__in=(company, None, '')) | Q(company__exists=False)
def field_does_not_exist(field: str, empty_value=None, is_list=False) -> Q:
"""
Creates a query object used for finding a field that doesn't exist, or has None or an empty value.
:param field: Field name
:param empty_value: The empty value to test for (None means no specific empty value will be used)
:param is_list: Is this a list (array) field. In this case, instead of testing for an empty value,
the length of the array will be used (len==0 means empty)
:return:
"""
query = (Q(**{f"{field}__exists": False}) |
Q(**{f"{field}__in": {empty_value, None}}))
if is_list:
query |= Q(**{f"{field}__size": 0})
return query
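For illustration, the raw MongoDB filter that this `Q` combination compiles to can be sketched as a plain dict (illustrative only; the real helper returns mongoengine `Q` objects):

```python
# pure-dict sketch of the Mongo filter built by field_does_not_exist
def field_missing_filter(field, empty_value=None, is_list=False):
    clauses = [
        {field: {"$exists": False}},            # field absent from the document
        {field: {"$in": [empty_value, None]}},  # field present but None/empty
    ]
    if is_list:
        clauses.append({field: {"$size": 0}})   # empty array counts as missing
    return {"$or": clauses}

assert field_missing_filter("tags", empty_value="", is_list=True) == {
    "$or": [
        {"tags": {"$exists": False}},
        {"tags": {"$in": ["", None]}},
        {"tags": {"$size": 0}},
    ]
}
```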
def get_subkey(d, key_path, default=None):
""" Get a key from a nested dictionary. kay_path is a '.' separated string of keys used to traverse
the nested dictionary.
"""
keys = key_path.split('.')
for i, key in enumerate(keys):
if not isinstance(d, dict):
raise KeyError('Expecting a dict (%s)' % ('.'.join(keys[:i]) if i else 'bad input'))
d = d.get(key)
if d is None:
return default
return d
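A standalone sketch of `get_subkey`'s dotted-path traversal, returning the default as soon as any lookup yields `None`, plus usage:

```python
# standalone sketch of get_subkey's dotted-path traversal; raises when an
# intermediate value is not a dict, returns the default on a missing key
def get_nested(d, key_path, default=None):
    for key in key_path.split('.'):
        if not isinstance(d, dict):
            raise KeyError('Expecting a dict while traversing %r' % key_path)
        d = d.get(key)
        if d is None:
            return default
    return d

conf = {"hosts": {"elastic": {"events": {"port": 9200}}}}
assert get_nested(conf, "hosts.elastic.events.port") == 9200
assert get_nested(conf, "hosts.mongo.port", default="missing") == "missing"
```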
def id():
return str(uuid4()).replace("-", "")
def hash_field_name(s):
""" Hash field name into a unique safe string """
return hashlib.md5(s.encode()).hexdigest()
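Both helpers produce 32-character hex strings; `uuid4().hex` is equivalent to stripping the dashes:

```python
import hashlib
from uuid import uuid4

# uuid4().hex yields the same 32-char string as str(uuid4()).replace("-", "")
u = uuid4()
assert str(u).replace("-", "") == u.hex

# hash_field_name maps arbitrary field names to fixed-length hex-only keys
digest = hashlib.md5("some.field:name".encode()).hexdigest()
assert len(digest) == 32
assert all(c in "0123456789abcdef" for c in digest)
```

MD5 is acceptable here because the digest is used only as a stable key, not for security.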
def merge_dicts(*dicts):
base = {}
for dct in dicts:
base.update(dct)
return base
def filter_fields(cls, fields):
"""From the fields dictionary return only the fields that match cls fields"""
return {key: fields[key] for key in fields if key in get_fields(cls)}


@@ -0,0 +1,47 @@
#!/usr/bin/env python3
"""
Apply elasticsearch mappings to given hosts.
"""
import argparse
import json
import requests
from pathlib import Path
HERE = Path(__file__).parent
def apply_mappings_to_host(host: str):
def _send_mapping(f):
with f.open() as json_data:
data = json.load(json_data)
url = f"{host}/_template/{f.stem}"
requests.delete(url)
r = requests.post(
url, headers={"Content-Type": "application/json"}, data=json.dumps(data)
)
return {"mapping": f.stem, "result": r.text}
p = HERE / "mappings"
return [
_send_mapping(f) for f in p.iterdir() if f.is_file() and f.suffix == ".json"
]
def parse_args():
parser = argparse.ArgumentParser(
description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
)
parser.add_argument("hosts", nargs="+")
return parser.parse_args()
def main():
for host in parse_args().hosts:
print(">>>>> Applying mapping to " + host)
res = apply_mappings_to_host(host)
print(res)
if __name__ == "__main__":
main()


@@ -0,0 +1,27 @@
{
"template": "events-*",
"settings": {
"number_of_shards": 5
},
"mappings": {
"_default_": {
"_source": {
"enabled": true
},
"_routing": {
"required": true
},
"properties": {
"@timestamp": { "type": "date" },
"task": { "type": "keyword" },
"type": { "type": "keyword" },
"worker": { "type": "keyword" },
"timestamp": { "type": "date" },
"iter": { "type": "long" },
"metric": { "type": "keyword" },
"variant": { "type": "keyword" },
"value": { "type": "float" }
}
}
}
}


@@ -0,0 +1,12 @@
{
"template": "events-log-*",
"order" : 1,
"mappings": {
"_default_": {
"properties": {
"msg": { "type":"text", "index": false },
"level": { "type":"keyword" }
}
}
}
}


@@ -0,0 +1,11 @@
{
"template": "events-plot-*",
"order" : 1,
"mappings": {
"_default_": {
"properties": {
"plot_str": { "type":"text", "index": false }
}
}
}
}


@@ -0,0 +1,12 @@
{
"template": "events-training_debug_image-*",
"order" : 1,
"mappings": {
"_default_": {
"properties": {
"key": { "type": "keyword" },
"url": { "type": "keyword" }
}
}
}
}


@@ -0,0 +1 @@
requests>=2.21.0

85
server/es_factory.py Normal file

@@ -0,0 +1,85 @@
from datetime import datetime
from elasticsearch import Elasticsearch, Transport
from config import config
log = config.logger(__file__)
_instances = {}
class MissingClusterConfiguration(Exception):
"""
Raised when the cluster configuration is not found in the config files
"""
pass
class InvalidClusterConfiguration(Exception):
"""
Raised when the cluster configuration does not contain the required properties
"""
pass
def connect(cluster_name):
"""
Returns the es client for the cluster.
Connects to the cluster if no connection was made previously
:param cluster_name: Dot separated cluster path in the configuration file
:return: es client
:raises MissingClusterConfiguration: in case no config section is found for the cluster
:raises InvalidClusterConfiguration: in case cluster config section misses needed properties
"""
if cluster_name not in _instances:
cluster_config = _get_cluster_config(cluster_name)
hosts = cluster_config.get('hosts', None)
if not hosts:
raise InvalidClusterConfiguration(cluster_name)
args = cluster_config.get('args', {})
_instances[cluster_name] = Elasticsearch(hosts=hosts, transport_class=Transport, **args)
return _instances[cluster_name]
def _get_cluster_config(cluster_name):
"""
Returns cluster config for the specified cluster path
:param cluster_name: Dot separated cluster path in the configuration file
:return: config section for the cluster
:raises MissingClusterConfiguration: in case no config section is found for the cluster
"""
cluster_key = '.'.join(('hosts.elastic', cluster_name))
cluster_config = config.get(cluster_key, None)
if not cluster_config:
raise MissingClusterConfiguration(cluster_name)
return cluster_config
def connect_all():
clusters = config.get("hosts.elastic").as_plain_ordered_dict()
for name in clusters:
connect(name)
def instances():
return _instances
def timestamp_str_to_millis(ts_str):
epoch = datetime.utcfromtimestamp(0)
current_date = datetime.strptime(ts_str, "%Y-%m-%dT%H:%M:%S.%fZ")
return int((current_date - epoch).total_seconds() * 1000.0)
def get_timestamp_millis():
now = datetime.utcnow()
epoch = datetime.utcfromtimestamp(0)
return int((now - epoch).total_seconds() * 1000.0)
def get_es_timestamp_str():
now = datetime.utcnow()
return now.strftime("%Y-%m-%dT%H:%M:%S") + ".%03d" % (now.microsecond // 1000) + "Z"
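The timestamp helpers share one format; a round-trip sketch using the same format strings (millisecond precision only):

```python
from datetime import datetime

EPOCH = datetime.utcfromtimestamp(0)
ES_FMT = "%Y-%m-%dT%H:%M:%S.%fZ"

# mirror of timestamp_str_to_millis
def to_millis(ts_str):
    return int((datetime.strptime(ts_str, ES_FMT) - EPOCH).total_seconds() * 1000.0)

# mirror of get_es_timestamp_str, truncating microseconds to milliseconds
def to_es_str(dt):
    return dt.strftime("%Y-%m-%dT%H:%M:%S") + ".%03d" % (dt.microsecond // 1000) + "Z"

ts = "2019-06-11T00:24:35.000Z"
dt = datetime.strptime(ts, ES_FMT)
assert to_millis(ts) == 1560212675000
assert to_es_str(dt) == ts  # round-trips at millisecond precision
```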

82
server/init_data.py Normal file

@@ -0,0 +1,82 @@
from datetime import datetime
from furl import furl
from database.model.auth import User, Credentials
from config import config
from database.model.auth import Role
from database.model.company import Company
from elastic.apply_mappings import apply_mappings_to_host
log = config.logger(__file__)
class MissingElasticConfiguration(Exception):
"""
Raised when the Elasticsearch hosts configuration is not found in the config files
"""
pass
def init_es_data():
hosts_key = "hosts.elastic.events.hosts"
hosts_config = config.get(hosts_key, None)
if not hosts_config:
raise MissingElasticConfiguration(hosts_key)
for conf in hosts_config:
host = furl(scheme="http", host=conf["host"], port=conf["port"]).url
log.info(f"Applying mappings to host: {host}")
res = apply_mappings_to_host(host)
log.info(res)
def _ensure_company():
company_id = config.get("apiserver.default_company")
company = Company.objects(id=company_id).only("id").first()
if company:
return company_id
company_name = "trains"
log.info(f"Creating company: {company_name}")
company = Company(id=company_id, name=company_name)
company.save()
return company_id
def _ensure_user(user_data, company_id):
user = User.objects(
credentials__match=Credentials(key=user_data["key"], secret=user_data["secret"])
).first()
if user:
return user.id
log.info(f"Creating user: {user_data['name']}")
user = User(
id=f"__{user_data['name']}__",
name=user_data["name"],
company=company_id,
role=user_data["role"],
email=user_data["email"],
created=datetime.utcnow(),
credentials=[Credentials(key=user_data["key"], secret=user_data["secret"])],
)
user.save()
return user.id
def init_mongo_data():
company_id = _ensure_company()
users = [
{"name": "apiserver", "role": Role.system, "email": "apiserver@example.com"},
{"name": "webserver", "role": Role.system, "email": "webserver@example.com"},
{"name": "tests", "role": Role.user, "email": "tests@example.com"},
]
for user in users:
credentials = config.get(f"secure.credentials.{user['name']}")
user["key"] = credentials.user_key
user["secret"] = credentials.user_secret
_ensure_user(user, company_id)

27
server/requirements.txt Normal file

@@ -0,0 +1,27 @@
six
Flask>=0.12.2
elasticsearch>=5.0.0,<6.0.0
pyhocon>=0.3.35
requests>=2.13.0
arrow>=0.10.0
pymongo==3.6.1 # 3.7 has a bug affecting multiple logged-in users
Flask-Cors>=3.0.5
Flask-Compress>=1.4.0
mongoengine==0.16.2
jsonmodels>=2.3
pyjwt>=1.3.0
gunicorn>=19.7.1
Jinja2==2.10
python-rapidjson>=0.6.3
jsonschema>=2.6.0
dpath>=1.4.2
funcsigs==1.0.2
luqum>=0.7.2
typing>=3.6.4
attrs>=19.1.0
nested_dict>=1.61
related>=0.7.2
validators>=0.12.4
fastjsonschema>=2.8
boltons>=19.1.0
semantic_version>=2.6.0,<3

267
server/schema.py Normal file

@@ -0,0 +1,267 @@
"""
Objects representing schema entities
"""
import json
import re
from operator import attrgetter
from pathlib import Path
from typing import Mapping, Sequence
import attr
from boltons.dictutils import subdict
from pyhocon import ConfigFactory
from config import config
from service_repo.base import PartialVersion
HERE = Path(__file__)
log = config.logger(__file__)
ALL_ROLES = "*"
class EndpointSchema:
REQUEST_KEY = "request"
RESPONSE_KEY = "response"
BATCH_REQUEST_KEY = "batch_request"
DEFINITIONS_KEY = "definitions"
def __init__(
self,
service_name: str,
action_name: str,
version: PartialVersion,
schema: dict,
definitions: dict = None,
):
"""
Class for interacting with the schema of a single endpoint
:param service_name: name of containing service
:param action_name: name of action
:param version: endpoint version
:param schema: endpoint schema
:param definitions: service definitions
"""
self.service_name = service_name
self.action_name = action_name
self.full_name = f"{service_name}.{action_name}"
self.version = version
self.definitions = definitions
self.request_schema = None
self.batch_request_schema = None
if self.REQUEST_KEY in schema:
self.request_schema = {
**schema[self.REQUEST_KEY],
self.DEFINITIONS_KEY: self.definitions,
}
elif self.BATCH_REQUEST_KEY in schema:
self.batch_request_schema = {
**schema[self.BATCH_REQUEST_KEY],
self.DEFINITIONS_KEY: self.definitions,
}
else:
raise RuntimeError(
f"endpoint {self.full_name} version {self.version} "
"has no request or batch_request schema",
schema,
)
self.response_schema = {
**schema[self.RESPONSE_KEY],
"definitions": self.definitions,
}
class EndpointVersionsGroup:
endpoints: Sequence[EndpointSchema]
allow_roles: Sequence[str]
internal: bool
authorize: bool
def __repr__(self):
return (
f"{type(self).__name__}<{self.full_name}, "
f"versions={tuple(e.version for e in self.endpoints)}>"
)
def __init__(
self,
service_name: str,
action_name: str,
conf: dict,
definitions: dict = None,
defaults: dict = None,
):
"""
Represents multiple implementations of a single endpoint, discriminated by API version
:param service_name: name of containing service
:param action_name: name of action
:param conf: mapping between minimum version to endpoint schema
:param definitions: service definitions
:param defaults: service defaults
"""
self.service_name = service_name
self.action_name = action_name
self.full_name = f"{service_name}.{action_name}"
self.definitions = definitions or {}
self.defaults = defaults or {}
self.internal = self._pop_attr_with_default(conf, "internal")
self.allow_roles = self._pop_attr_with_default(conf, "allow_roles")
self.authorize = self._pop_attr_with_default(conf, "authorize")
def parse_version(version):
if not re.match(r"^\d+\.\d+$", version):
raise ValueError(
f"Encountered unrecognized key {version!r} in {self.service_name}.{self.action_name}"
)
return PartialVersion(version)
self.endpoints = sorted(
(
EndpointSchema(
service_name=self.service_name,
action_name=self.action_name,
version=parse_version(version),
schema=endpoint_conf,
definitions=self.definitions,
)
for version, endpoint_conf in conf.items()
),
key=attrgetter("version"),
)
def allows(self, role):
return ALL_ROLES in self.allow_roles or role in self.allow_roles
def _pop_attr_with_default(self, conf, attr):
return conf.pop(attr, self.defaults[attr])
def get_for_version(self, min_version: PartialVersion):
"""
Return endpoint schema for version
"""
if not self.endpoints:
raise ValueError(f"endpoint group {self} has no versions")
for endpoint in self.endpoints:
if min_version <= endpoint.version:
return endpoint
raise ValueError(
f"min_version {min_version} is higher than highest version in group {self}"
)
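Since `self.endpoints` is kept sorted in ascending version order, the loop above returns the lowest implemented version that satisfies the caller's minimum. A sketch with `(major, minor)` tuples standing in for `PartialVersion`:

```python
# sketch of EndpointVersionsGroup.get_for_version, with (major, minor)
# tuples standing in for PartialVersion instances
def get_for_version(endpoints, min_version):
    # endpoints: (version, schema) pairs sorted by ascending version
    for version, schema in endpoints:
        if min_version <= version:
            return schema
    raise ValueError("min_version %r is higher than highest version" % (min_version,))

endpoints = [((1, 5), "schema-1.5"), ((2, 1), "schema-2.1")]
assert get_for_version(endpoints, (1, 5)) == "schema-1.5"  # exact match
assert get_for_version(endpoints, (1, 9)) == "schema-2.1"  # next version up
```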
class Service:
endpoint_groups: Mapping[str, EndpointVersionsGroup]
def __init__(self, name: str, conf: dict, api_defaults: dict):
"""
Represents schema of one service
:param name: name of service
:param conf: service configuration, containing endpoint groups and other details
:param api_defaults: API-wide endpoint attributes default values
"""
self.name = name
conf = subdict(conf, drop=("_description", "_references"))
self.defaults = {**api_defaults, **conf.pop("_default", {})}
self.definitions = conf.pop("_definitions", None)
self.endpoint_groups: Mapping[str, EndpointVersionsGroup] = {
endpoint_name: EndpointVersionsGroup(
service_name=self.name,
action_name=endpoint_name,
conf=endpoint_conf,
defaults=self.defaults,
definitions=self.definitions,
)
for endpoint_name, endpoint_conf in conf.items()
}
@attr.s()
class SchemaReader:
root: Path = attr.ib(default=HERE.parent / "schema/services", converter=Path)
cache_path: Path = attr.ib(default=None)
def __attrs_post_init__(self):
if not self.cache_path:
self.cache_path = self.root / "_cache.json"
@staticmethod
def mod_time(path):
"""
return file modification time
"""
return path.stat().st_mtime
@staticmethod
def read_file(path):
return ConfigFactory.parse_file(path).as_plain_ordered_dict()
def get_schema(self):
"""
Parse the API schema to schema object.
Load from config files and write to cache file if possible.
"""
services = [
service
for service in self.root.glob("*.conf")
if not service.name.startswith("_")
]
current_services_names = {path.stem for path in services}
try:
if self.mod_time(self.cache_path) >= max(map(self.mod_time, services)):
log.info("loading schema from cache")
result = json.loads(self.cache_path.read_text())
cached_services_names = set(result.pop("services_names", []))
if cached_services_names == current_services_names:
return Schema(**result)
else:
log.info(
f"found services files changed: "
f"added: {list(current_services_names - cached_services_names)}, "
f"removed: {list(cached_services_names - current_services_names)}"
)
except (IOError, KeyError, TypeError, ValueError, AttributeError) as ex:
log.warning(f"failed loading cache: {ex}")
log.info("regenerating schema cache")
services = {path.stem: self.read_file(path) for path in services}
api_defaults = self.read_file(self.root / "_api_defaults.conf")
try:
self.cache_path.write_text(
json.dumps(
dict(
services_names=list(current_services_names),
services=services,
api_defaults=api_defaults,
)
)
)
except IOError:
log.exception(f"failed writing cache file to {self.cache_path}")
return Schema(services, api_defaults)
class Schema:
services: Mapping[str, Service]
def __init__(self, services: dict, api_defaults: dict):
"""
Represents the entire API schema
:param services: services schema
:param api_defaults: default values of service configuration
"""
self.api_defaults = api_defaults
self.services = {
name: Service(name, conf, api_defaults=self.api_defaults)
for name, conf in services.items()
}
schema = SchemaReader().get_schema()
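The freshness test in `get_schema` — regenerate whenever any `.conf` source is newer than the cache — can be sketched with plain `Path.stat()` calls:

```python
import tempfile
from pathlib import Path

# sketch of the mtime comparison used by SchemaReader.get_schema: the cache
# is fresh only if it is at least as new as every source file
def cache_is_fresh(cache_path: Path, sources) -> bool:
    try:
        return cache_path.stat().st_mtime >= max(p.stat().st_mtime for p in sources)
    except OSError:  # cache file missing -> must regenerate
        return False

with tempfile.TemporaryDirectory() as d:
    src = Path(d) / "tasks.conf"
    src.write_text("_description: tasks service")
    cache = Path(d) / "_cache.json"
    assert not cache_is_fresh(cache, [src])  # no cache written yet
    cache.write_text("{}")
    assert cache_is_fresh(cache, [src])      # cache written after the source
```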


@@ -0,0 +1,199 @@
// some definitions
definitions {
// description of a reference
reference {
type: object
additionalProperties: false
required: [ "$ref" ]
properties {
"$ref" { type: string }
}
}
// description of an "additionalProperties" section
additional_properties {
oneOf: [
{ type: object, additionalProperties: true },
{ type: boolean }
]
}
// each endpoint is a mapping of versions to actions
endpoint {
type: object
additionalProperties: false
properties {
// whether endpoint is internal
internal { type: boolean }
// whether endpoint requires authorization
authorize { type: boolean }
// list of roles allowed to access endpoint
allow_roles {
type: array
items { type: string }
}
}
patternProperties {
"^\d\.\d+$" { "$ref": "#/definitions/action" }
}
}
// an action describes request and response
action {
type: object
additionalProperties: false
// must have response
required: [ response, description ]
// must have either request or batch_request
oneOf: [
{ required: [ request ] }
{ required: [ batch_request ] }
]
properties {
method { const: post }
description { type: string }
request {
oneOf: [
{ "$ref": "#/definitions/reference" }
{ "$ref": "#/definitions/request" }
{ "$ref": "#/definitions/multi_request" }
]
}
response {
type: object
oneOf: [
{ "$ref": "#/definitions/response" }
// { "$ref": "#/definitions/reference" }
{
type: object
properties {
type { const: string }
}
additionalProperties: false
}
]
}
batch_request {
type: object
additionalProperties: false
// required: [ action, version, description ]
required: [ action, version ]
properties {
action { type: string }
version { type: number }
// description { type: string }
}
}
}
}
// describes request to server
request {
type: object
additionalProperties: false
required: [
type
]
properties {
// says it's an object
type { const: object }
// required fields
required {
type: array
items { type: string }
}
// request fields,
// an object that can have anything
properties {
type: object
additionalProperties {
type: object
required: [ description ]
properties {
description { type: string }
}
additionalProperties: true
}
}
dependencies { type: object }
// can have an "additionalProperties" section
additionalProperties { "$ref": "#/definitions/additional_properties" }
}
}
multi_request {
type: object
required: [ type ]
additionalProperties: false
oneOf: [
{
required: [ anyOf ]
}
{
required: [ oneOf ]
}
]
properties {
type { const: object }
anyOf {
type: array
items { "$ref": "#/definitions/reference" }
}
oneOf {
type: array
items { "$ref": "#/definitions/reference" }
}
}
}
response {
type: object
additionalProperties: false
required: [
type
]
properties {
// says it's an object
type { const: object }
// nothing is required
// can have anything
properties {
type: object
additionalProperties: true
}
// can have an "additionalProperties" section
additionalProperties { "$ref": "#/definitions/additional_properties" }
}
}
}
// schema starts here!
type: object
required: [ _description ]
properties {
_description { type: string }
// definitions for generator
_definitions {
type: object
additionalProperties {
required: [ type ]
properties {
type { type: string }
}
// can have anything
additionalProperties: true
}
}
_references = ${properties._definitions}
// default values for actions
// can have anything
_default {
additionalProperties: true
}
}
// describing each endpoint
additionalProperties {
type: object
// can be:
oneOf: [
// a reference
{ "$ref": "#/definitions/reference" }
// or a mapping from versions to actions
{ "$ref": "#/definitions/endpoint" }
]
}

303
server/schema/meta/validate.py Executable file

@@ -0,0 +1,303 @@
#!/usr/bin/env python
from __future__ import print_function
import argparse
import json
import os
import sys
import time
from itertools import groupby
from operator import itemgetter
import pyhocon
import six
import yaml
from colors import color
from jsonschema import validate, ValidationError as JSONSchemaValidationError
from jsonschema.validators import validator_for
from pathlib import Path
from pyparsing import ParseBaseException
LINTER_URL = "https://www.jsonschemavalidator.net/"
class LocalStorage(object):
def __init__(self, driver):
self.driver = driver
def __len__(self):
return self.driver.execute_script("return window.localStorage.length;")
def items(self):
return self.driver.execute_script(
"""
var ls = window.localStorage, items = {};
for (var i = 0, k; i < ls.length; ++i)
items[k = ls.key(i)] = ls.getItem(k);
return items;
"""
)
def keys(self):
return self.driver.execute_script(
"""
var ls = window.localStorage, keys = [];
for (var i = 0; i < ls.length; ++i)
keys[i] = ls.key(i);
return keys;
"""
)
def get(self, key):
return self.driver.execute_script(
"return window.localStorage.getItem(arguments[0]);", key
)
def remove(self, key):
self.driver.execute_script("window.localStorage.removeItem(arguments[0]);", key)
def clear(self):
self.driver.execute_script("window.localStorage.clear();")
def __getitem__(self, key):
value = self.get(key)
if value is None:
raise KeyError(key)
return value
def __setitem__(self, key, value):
self.driver.execute_script(
"window.localStorage.setItem(arguments[0], arguments[1]);", key, value
)
def __contains__(self, key):
return key in self.keys()
def __iter__(self):
return iter(self.keys())
def __repr__(self):
return repr(self.items())
class ValidationError(Exception):
def __init__(self, *args):
super(ValidationError, self).__init__(*args)
self.message = self.args[0] if self.args else ""
def report(self, schema_file):
message = color(schema_file, fg='red')
if self.message:
message += ": {}".format(self.message)
print(message)
class InvalidFile(ValidationError):
"""
InvalidFile
Wraps other exceptions that occur in file validation
:param message: message to display
"""
def __init__(self, message):
super(InvalidFile, self).__init__(message)
exc_type, _, _ = self.exc_info = sys.exc_info()
if exc_type:
self.message = "{}: {}".format(exc_type.__name__, message)
def raise_original(self):
six.reraise(*self.exc_info)
def load_hocon(name):
"""
load_hocon
load configuration from file
:param name: file path
"""
return pyhocon.ConfigFactory.parse_file(name).as_plain_ordered_dict()
def validate_ascii_only(name):
invalid_char = next(
(
(line_num, column, char)
for line_num, line in enumerate(Path(name).read_text().splitlines())
for column, char in enumerate(line)
if ord(char) not in range(128)
),
None,
)
if invalid_char:
line, column, char = invalid_char
raise ValidationError(
"file contains non-ascii character {!r} in line {} pos {}".format(
char, line, column
)
)
def validate_file(meta, name):
"""
validate_file
validate file according to meta-scheme
:param meta: meta-scheme
:param name: file path
"""
validate_ascii_only(name)
try:
schema = load_hocon(name)
except ParseBaseException as e:
raise InvalidFile(repr(e))
try:
validate(schema, meta)
return schema
except JSONSchemaValidationError as e:
path = "->".join(e.absolute_path)
message = "{}: {}".format(path, e.args[0])
raise InvalidFile(message)
except Exception as e:
raise InvalidFile(str(e))
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("files", nargs="+")
parser.add_argument(
"--linter", "-l", action="store_true", help="open jsonschema linter in browser"
)
parser.add_argument(
"--raise",
"-r",
action="store_true",
dest="raise_",
help="raise first exception encountered and print traceback",
)
parser.add_argument(
"--detect-collisions",
action="store_true",
help="detect objects with the same name in different modules",
)
return parser.parse_args()
def open_linter(driver, meta, schema):
driver.maximize_window()
driver.get(LINTER_URL)
storage = LocalStorage(driver)
storage["jsonText"] = json.dumps(schema, indent=4)
storage["schemaText"] = json.dumps(meta, indent=4)
driver.refresh()
class LazyDriver(object):
def __init__(self):
self._driver = None
try:
from selenium import webdriver, common
except ImportError:
webdriver = None
common = None
self.webdriver = webdriver
self.common = common
def __getattr__(self, item):
return getattr(self.driver, item)
@property
def driver(self):
if self._driver:
return self._driver
if not (self.webdriver and self.common):
print("selenium not installed: linter unavailable")
return None
for driver_type in self.webdriver.Chrome, self.webdriver.Firefox:
try:
self._driver = driver_type()
break
except self.common.exceptions.WebDriverException:
pass
else:
print("No webdriver is found for chrome or firefox")
return self._driver
def wait(self):
if not self._driver:
return
try:
while True:
self._driver.title
time.sleep(0.5)
except self.common.exceptions.WebDriverException:
pass
def remove_description(dct):
dct.pop("description", None)
for value in dct.values():
try:
remove_description(value)
except (TypeError, AttributeError):
pass
def main():
args = parse_args()
meta = load_hocon(os.path.dirname(__file__) + "/meta.conf")
validator_for(meta).check_schema(meta)
driver = LazyDriver()
collisions = {}
for schema_file in args.files:
if Path(schema_file).name.startswith("_"):
continue
try:
schema = validate_file(meta, schema_file)
except InvalidFile as e:
if args.linter and driver.driver:
open_linter(driver, meta, load_hocon(schema_file))
elif args.raise_:
e.raise_original()
e.report(schema_file)
except ValidationError as e:
e.report(schema_file)
else:
for def_name, value in schema.get("_definitions", {}).items():
service_name = str(Path(schema_file).stem)
remove_description(value)
collisions.setdefault(def_name, {})[service_name] = value
warning = color("warning", fg="red")
if args.detect_collisions:
for name, values in collisions.items():
if len(values) <= 1:
continue
groups = [
[service for (service, _) in pairs]
for _, pairs in groupby(values.items(), itemgetter(1))
]
if not groups:
raise RuntimeError("Unknown error")
print(
"{}: collision for {}:\n{}".format(warning, name, yaml.dump(groups)),
end="",
)
driver.wait()
if __name__ == "__main__":
main()
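
The `remove_description` helper above strips `description` keys so that collision detection compares definitions by structure alone. A standalone sketch of that behavior (the sample schema is illustrative):

```python
def remove_description(dct):
    # Recursively drop "description" keys so that definitions can be
    # compared by structure alone, as in the collision detection above.
    dct.pop("description", None)
    for value in dct.values():
        try:
            remove_description(value)
        except (TypeError, AttributeError):
            # Non-dict values (strings, numbers, lists) are left untouched.
            pass

schema = {
    "type": "object",
    "description": "top-level doc",
    "properties": {"task": {"type": "string", "description": "Task ID"}},
}
remove_description(schema)
# All "description" keys are gone; the rest of the structure is untouched.
```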


@@ -0,0 +1,62 @@
# Writing descriptions
There are two options for writing parameter descriptions. Mixing the two
will result in output that is not Sphinx-friendly.
Whatever you choose, lines are subject to wrapping.
- non-strict whitespace - Break the string however you like.
Newlines and sequences of tabs/spaces are replaced by one space.
Example:
```
get_all {
"1.5" {
description: """This will all appear
as one long
sentence.
Break lines wherever you
like.
"""
}
}
```
Becomes:
```
class GetAllRequest(...):
"""
This will all appear as one long sentence. Break lines wherever you
like.
"""
```
- strict whitespace - Single newlines will be replaced by spaces.
Double newlines become a single newline WITH INDENTATION PRESERVED,
so if uniform indentation is required for all lines you MUST start new lines
at the first column.
Example:
```
get_all {
"1.5" {
description: """
Some general sentence.
- separate lines must have double newlines between them
- must begin at first column even though the "description" key is indented
- you can use single newlines, the lines will be
joined
-- sub bullet: this line's leading spaces are preserved
"""
}
}
```
Becomes:
```
class GetAllRequest(...):
"""
Some general sentence.
- separate lines must have double newlines between them
- must begin at first column even though the "description" key is indented
- you can use single newlines, the lines will be joined
-- sub bullet: this line's leading spaces are preserved
"""
```
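
The non-strict mode above amounts to collapsing every whitespace run into a single space. A minimal sketch of that normalization (the function name is illustrative, not part of the codebase):

```python
import re

def normalize_non_strict(text):
    # Newlines and sequences of tabs/spaces are replaced by one space,
    # matching the non-strict description mode described above.
    return re.sub(r"\s+", " ", text).strip()

doc = """This will all appear
as one long
sentence."""
normalize_non_strict(doc)  # "This will all appear as one long sentence."
```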


@@ -0,0 +1,3 @@
internal: false
allow_roles: ["*"]
authorize: true


@@ -0,0 +1,13 @@
credentials {
type: object
properties {
access_key {
type: string
description: Credentials access key
}
secret_key {
type: string
description: Credentials secret key
}
}
}


@@ -0,0 +1,320 @@
_description: """This service provides authentication management and authorization
validation for the entire system."""
_default {
internal: true
allow_roles: ["system", "root"]
}
_definitions {
include "_common.conf"
credential_key {
type: object
properties {
access_key {
type: string
description: ""
}
}
}
role {
type: string
enum: [ admin, superuser, user, annotator ]
}
}
login {
internal: false
allow_roles = [ "*" ]
"2.1" {
description: """Get a token based on supplied credentials (key/secret).
Intended for use by users with key/secret credentials that wish to obtain a token
for use with other services."""
request {
type: object
properties {
expiration_sec {
type: integer
description: """Requested token expiration time in seconds.
Not guaranteed, might be overridden by the service"""
}
}
}
response {
type: object
properties {
token {
type: string
description: Token string
}
}
}
}
}
get_token_for_user {
"2.1" {
description: """Get a token for the specified user. Intended for internal use."""
request {
type: object
required: [
user
]
properties {
user {
type: string
description: User ID
}
company {
type: string
description: Company ID
}
expiration_sec {
type: integer
description: """Requested token expiration time in seconds.
Not guaranteed, might be overridden by the service"""
}
}
}
response {
type: object
properties {
token {
type: string
description: ""
}
}
}
}
}
validate_token {
"2.1" {
description: """Validate a token and return user identity if valid.
Intended for internal use. """
request {
type: object
required: [ token ]
properties {
token {
type: string
description: Token string
}
}
}
response {
type: object
properties {
valid {
type: boolean
description: Boolean indicating if the token is valid
}
user {
type: string
description: Associated user ID
}
company {
type: string
description: Associated company ID
}
}
}
}
}
create_user {
"2.1" {
description: """Creates a new user auth entry. Intended for internal use. """
request {
type: object
required: [
name
company
email
]
properties {
name {
type: string
description: User name (makes the auth entry more readable)
}
company {
type: string
description: Associated company ID
}
email {
type: string
description: Email address uniquely identifying the user
}
role {
description: User role
default: user
"$ref": "#/definitions/role"
}
given_name {
type: string
description: Given name
}
family_name {
type: string
description: Family name
}
avatar {
type: string
description: Avatar URL
}
}
}
response {
type: object
properties {
id {
type: string
description: New user ID
}
}
}
}
}
create_credentials {
allow_roles = [ "*" ]
internal: false
"2.1" {
description: """Creates a new set of credentials for the authenticated user.
New key/secret is returned.
Note: Secret will never be returned in any other API call.
If a secret is lost or compromised, the key should be revoked
and a new set of credentials can be created."""
request {
type: object
properties {}
additionalProperties: false
}
response {
type: object
properties {
credentials {
"$ref": "#/definitions/credentials"
description: Created credentials
}
}
}
}
}
get_credentials {
allow_roles = [ "*" ]
internal: false
"2.1" {
description: """Returns all existing credential keys for the authenticated user.
Note: Only credential keys are returned."""
request {
type: object
properties {}
additionalProperties: false
}
response {
type: object
properties {
credentials {
description: "List of credentials, each with an empty secret field."
type: array
items { "$ref": "#/definitions/credential_key" }
}
}
}
}
}
revoke_credentials {
allow_roles = [ "*" ]
internal: false
"2.1" {
description: """Revokes (and deletes) a set (key, secret) of credentials for
the authenticated user."""
request {
type: object
required: [ access_key ]
properties {
access_key {
type: string
description: Credentials key
}
}
}
response {
type: object
properties {
revoked {
description: "Number of credentials revoked"
type: integer
enum: [0, 1]
}
}
}
}
}
delete_user {
allow_roles = [ "system", "root", "admin" ]
internal: false
"2.1" {
description: """Delete a new user manually. Only supported in on-premises deployments. This only removes the user's auth entry so that any references to the deleted user's ID will still have valid user information"""
request {
type: object
required: [ user ]
properties {
user {
type: string
description: User ID
}
}
}
response {
type: object
properties {
deleted {
description: "True if user was successfully deleted, False otherwise"
type: boolean
}
}
}
}
}
edit_user {
internal: false
allow_roles: ["system", "root", "admin"]
"2.1" {
description: """ Edit a users' auth data properties"""
request {
type: object
properties {
user {
description: "User ID"
type: string
}
role {
description: "The new user's role within the company"
type: string
enum: [admin, superuser, user, annotator]
}
}
}
response {
type: object
properties {
updated {
description: "Number of users updated (0 or 1)"
type: number
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
}
}
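
As a rough illustration of how a client might call the `login` endpoint above, here is a hedged sketch. The host/port and the HTTP Basic authentication scheme are assumptions; the endpoint name and the `expiration_sec` body field come from the schema:

```python
import base64

def build_login_request(access_key, secret_key, expiration_sec=None,
                        host="http://localhost:8008"):
    # Hypothetical helper: encodes key/secret as HTTP Basic credentials
    # (an assumption about the transport) and targets the auth.login
    # endpoint with the optional expiration_sec field from the schema.
    token = base64.b64encode(
        "{}:{}".format(access_key, secret_key).encode()
    ).decode()
    body = {}
    if expiration_sec is not None:
        body["expiration_sec"] = expiration_sec
    return {
        "url": host + "/auth.login",
        "headers": {"Authorization": "Basic " + token},
        "json": body,
    }
```

The response, per the schema, carries a single `token` string for use with other services.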


@@ -0,0 +1,849 @@
{
_description : "Provides an API for running tasks to report events collected by the system."
_definitions {
metrics_scalar_event {
description: "Used for reporting scalar metrics during training task"
type: object
required: [ task, type ]
properties {
timestamp {
description: "Epoch milliseconds UTC, will be set by the server if not set."
type: number
}
type {
description: "training_stats_vector"
const: "training_stats_scalar"
}
task {
description: "Task ID (required)"
type: string
}
iter {
description: "Iteration"
type: integer
}
metric {
description: "Metric name, e.g. 'count', 'loss', 'accuracy'"
type: string
}
variant {
description: "E.g. 'class_1', 'total', 'average"
type: string
}
value {
description: ""
type: number
}
}
}
metrics_vector_event {
description: "Used for reporting vector metrics during training task"
type: object
required: [ task ]
properties {
timestamp {
description: "Epoch milliseconds UTC, will be set by the server if not set."
type: number
}
type {
description: "training_stats_vector"
const: "training_stats_vector"
}
task {
description: "Task ID (required)"
type: string
}
iter {
description: "Iteration"
type: integer
}
metric {
description: "Metric name, e.g. 'count', 'loss', 'accuracy'"
type: string
}
variant {
description: "E.g. 'class_1', 'total', 'average"
type: string
}
values {
description: "vector of float values"
type: array
items { type: number }
}
}
}
metrics_image_event {
description: "An image or video was dumped to storage for debugging"
type: object
required: [ task, type ]
properties {
timestamp {
description: "Epoch milliseconds UTC, will be set by the server if not set."
type: number
}
type {
description: ""
const: "training_debug_image"
}
task {
description: "Task ID (required)"
type: string
}
iter {
description: "Iteration"
type: integer
}
metric {
description: "Metric name, e.g. 'count', 'loss', 'accuracy'"
type: string
}
variant {
description: "E.g. 'class_1', 'total', 'average"
type: string
}
key {
description: "File key"
type: string
}
url {
description: "File URL"
type: string
}
}
}
metrics_plot_event {
description: """ An entire plot (not single datapoint) and it's layout.
Used for plotting ROC curves, confidence matrices, etc. when evaluating the net."""
type: object
required: [ task, type ]
properties {
timestamp {
description: "Epoch milliseconds UTC, will be set by the server if not set."
type: number
}
type {
description: "'plot'"
const: "plot"
}
task {
description: "Task ID (required)"
type: string
}
iter {
description: "Iteration"
type: integer
}
metric {
description: "Metric name, e.g. 'count', 'loss', 'accuracy'"
type: string
}
variant {
description: "E.g. 'class_1', 'total', 'average"
type: string
}
plot_str {
description: """An entire plot (not single datapoint) and it's layout.
Used for plotting ROC curves, confidence matrices, etc. when evaluating the net.
"""
type: string
}
}
}
log_level_enum {
type: string
enum: [
notset
debug
verbose
info
warn
warning
error
fatal
critical
]
}
task_log_event {
description: """A log event associated with a task."""
type: object
required: [ task, type ]
properties {
timestamp {
description: "Epoch milliseconds UTC, will be set by the server if not set."
type: number
}
type {
description: "'log'"
const: "log"
}
task {
description: "Task ID (required)"
type: string
}
level {
description: "Log level."
"$ref": "#/definitions/log_level_enum"
}
worker {
description: "Name of machine running the task."
type: string
}
msg {
description: "Log message."
type: string
}
}
}
}
add {
"2.1" {
description: "Adds a single event"
request {
type: object
anyOf: [
{ "$ref": "#/definitions/metrics_scalar_event" }
{ "$ref": "#/definitions/metrics_vector_event" }
{ "$ref": "#/definitions/metrics_image_event" }
{ "$ref": "#/definitions/metrics_plot_event" }
{ "$ref": "#/definitions/task_log_event" }
]
}
response {
type: object
additionalProperties: true
}
}
}
add_batch {
"2.1" {
description: "Adds a batch of events in a single call."
batch_request: {
action: add
version: 1.5
}
response {
type: object
properties {
added { type: integer }
errors { type: integer }
}
}
}
}
delete_for_task {
"2.1" {
description: "Delete all task event. *This cannot be undone!*"
request {
type: object
required: [
task
]
properties {
task {
type: string
description: "Task ID"
}
}
}
response {
type: object
properties {
deleted {
type: boolean
description: "Number of deleted events"
}
}
}
}
}
debug_images {
"2.1" {
description: "Get all debug images of a task"
request {
type: object
required: [
task
]
properties {
task {
type: string
description: "Task ID"
}
iters {
type: integer
description: "Max number of latest iterations for which to return debug images"
}
scroll_id {
type: string
description: "Scroll ID of previous call (used for getting more results)"
}
}
}
response {
type: object
properties {
task {
type: string
description: "Task ID"
}
images {
type: array
items { type: object }
description: "Images list"
}
returned {
type: integer
description: "Number of results returned"
}
total {
type: number
description: "Total number of results available for this query"
}
scroll_id {
type: string
description: "Scroll ID for getting more results"
}
}
}
}
}
get_task_log {
"1.5" {
description: "Get all 'log' events for this task"
request {
type: object
required: [
task
]
properties {
task {
type: string
description: "Task ID"
}
order {
type: string
description: "Timestamp order in which log events will be returned (defaults to ascending)"
enum: [
asc
desc
]
}
scroll_id {
type: string
description: ""
}
batch_size {
type: integer
description: ""
}
}
}
response {
type: object
properties {
events {
type: array
# TODO: items: log event
items { type: object }
}
returned { type: integer }
total { type: integer }
scroll_id { type: string }
}
}
}
"1.7" {
description: "Get all 'log' events for this task"
request {
type: object
required: [
task
]
properties {
task {
type: string
description: "Task ID"
}
order {
type: string
description: "Timestamp order in which log events will be returned (defaults to ascending)"
enum: [
asc
desc
]
}
from {
type: string
description: "Where will the log entries be taken from (default to the head of the log)"
enum: [
head
tail
]
}
scroll_id {
type: string
description: ""
}
batch_size {
type: integer
description: ""
}
}
}
response {
type: object
properties {
events {
type: array
# TODO: items: log event
items { type: object }
description: "Log items list"
}
returned {
type: integer
description: "Number of results returned"
}
total {
type: number
description: "Total number of results available for this query"
}
scroll_id {
type: string
description: "Scroll ID for getting more results"
}
}
}
}
}
get_task_events {
"2.1" {
description: "Scroll through task events, sorted by timestamp"
request {
type: object
required: [
task
]
properties {
task {
type: string
description: "Task ID"
}
order {
type: string
description: "'asc' (default) or 'desc'."
enum: [
asc
desc
]
}
scroll_id {
type: string
description: "Pass this value on next call to get next page"
}
batch_size {
type: integer
description: "Number of events to return each time"
}
}
}
response {
type: object
properties {
events {
type: array
items { type: object }
description: "Events list"
}
returned {
type: integer
description: "Number of results returned"
}
total {
type: number
description: "Total number of results available for this query"
}
scroll_id {
type: string
description: "Scroll ID for getting more results"
}
}
}
}
}
download_task_log {
"2.1" {
description: "Get an attachment containing the task's log"
request {
type: object
required: [
task
]
properties {
task {
description: "Task ID"
type: string
}
line_type {
description: "Line format type"
type: string
enum: [
json
text
]
}
line_format {
type: string
description: "Line string format. Used if the line type is 'text'"
default: "{asctime} {worker} {level} {msg}"
}
}
}
response {
type: string
}
}
}
get_task_plots {
"2.1" {
description: "Get all 'plot' events for this task"
request {
type: object
required: [
task
]
properties {
task {
type: string
description: "Task ID"
}
iters {
type: integer
description: "Max number of latest iterations for which to return debug images"
}
scroll_id {
type: string
description: "Scroll ID of previous call (used for getting more results)"
}
}
}
response {
type: object
properties {
plots {
type: array
items {
type: object
}
description: "Plots list"
}
returned {
type: integer
description: "Number of results returned"
}
total {
type: number
description: "Total number of results available for this query"
}
scroll_id {
type: string
description: "Scroll ID for getting more results"
}
}
}
}
}
get_multi_task_plots {
"2.1" {
description: "Get 'plot' events for the given tasks"
request {
type: object
required: [
tasks
]
properties {
tasks {
description: "List of task IDs"
type: array
items {
type: string
description: "Task ID"
}
}
iters {
type: integer
description: "Max number of latest iterations for which to return debug images"
}
scroll_id {
type: string
description: "Scroll ID of previous call (used for getting more results)"
}
}
}
response {
type: object
properties {
plots {
type: object
description: "Plots mapping (keyed by task name)"
}
returned {
type: integer
description: "Number of results returned"
}
total {
type: number
description: "Total number of results available for this query"
}
scroll_id {
type: string
description: "Scroll ID for getting more results"
}
}
}
}
}
get_vector_metrics_and_variants {
"2.1" {
description: ""
request {
type: object
required: [
task
]
properties {
task {
type: string
description: "Task ID"
}
}
}
response {
type: object
properties {
metrics {
description: ""
type: array
items: { type: object }
# TODO: items: ???
}
}
}
}
}
vector_metrics_iter_histogram {
"2.1" {
description: "Get histogram data of all the scalar metrics and variants in the task"
request {
type: object
required: [
task
metric
variant
]
properties {
task {
type: string
description: "Task ID"
}
metric {
type: string
description: ""
}
variant {
type: string
description: ""
}
}
}
response {
type: object
properties {
images {
type: array
items {
type: object
}
}
}
}
}
}
scalar_metrics_iter_histogram {
"2.1" {
description: "Get histogram data of all the vector metrics and variants in the task"
request {
type: object
required: [
task
]
properties {
task {
type: string
description: "Task ID"
}
}
}
response {
type: object
properties {
images {
type: array
items {
type: object
}
}
}
}
}
}
multi_task_scalar_metrics_iter_histogram {
"2.1" {
description: "Used to compare scalar stats histogram of multiple tasks"
request {
type: object
required: [
tasks
]
properties {
tasks {
description: "List of task Task IDs"
type: array
items {
type: string
description: "List of task Task IDs"
}
}
}
}
response {
type: object
// properties {}
additionalProperties: true
}
}
}
get_task_latest_scalar_values {
"2.1" {
description: "Get the tasks's latest scalar values"
request {
type: object
required: [
task
]
properties {
task {
type: string
description: "Task ID"
}
}
}
response {
type: object
properties {
metrics {
type: array
items {
type: object
properties {
name {
type: string
description: "Metric name"
}
variants {
type: array
items {
type: object
properties {
name {
type: string
description: "Variant name"
}
last_value {
type: number
description: "Last reported value"
}
last_100_value {
type: number
description: "Average of 100 last reported values"
}
}
}
}
}
}
}
}
}
}
}
get_scalar_metrics_and_variants {
"2.1" {
description: "Get task scalar metrics and variants"
request {
type: object
required: [ task ]
properties {
task {
description: "Task ID"
type: string
}
}
}
response {
type: object
properties {
metrics {
type: object
additionalProperties: true
}
}
}
}
}
get_scalar_metric_data {
"2.1" {
description: "get scalar metric data for task"
request {
type: object
properties {
task {
type: string
description: "Task ID"
}
metric {
type: string
description: "Type of metric"
}
}
}
response {
type: object
properties {
events {
description: "task scalar metric events"
type: array
items {
type: object
}
}
returned {
type: integer
description: "amount of events returned"
}
total {
type: integer
description: "amount of events in task"
}
scroll_id {
type: string
description: "Scroll ID of previous call (used for getting more results)"
}
}
}
}
}
}
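
To make the event schemas above concrete, here is a hedged sketch of building `training_stats_scalar` events for the `add_batch` endpoint. Field names follow the `metrics_scalar_event` definition; the helper itself is illustrative:

```python
def make_scalar_event(task, metric, variant, value, iteration=None):
    # Minimal event matching the metrics_scalar_event definition:
    # "task" and "type" are required; the server fills "timestamp"
    # if it is not set by the client.
    event = {
        "type": "training_stats_scalar",
        "task": task,
        "metric": metric,
        "variant": variant,
        "value": value,
    }
    if iteration is not None:
        event["iter"] = iteration
    return event

# A batch suitable for add_batch: one JSON event per line in the payload.
batch = [
    make_scalar_event("task-id", "loss", "total", 0.42, iteration=10),
    make_scalar_event("task-id", "accuracy", "class_1", 0.91, iteration=10),
]
```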


@@ -0,0 +1,639 @@
{
_description: """This service provides a management interface for models (results of training tasks) stored in the system."""
_definitions {
multi_field_pattern_data {
type: object
properties {
pattern {
description: "Pattern string (regex)"
type: string
}
fields {
description: "List of field names"
type: array
items { type: string }
}
}
}
model {
type: object
properties {
id {
description: "Model id"
type: string
}
name {
description: "Model name"
type: string
}
user {
description: "Associated user id"
type: string
}
company {
description: "Company id"
type: string
}
created {
description: "Model creation time"
type: string
format: "date-time"
}
task {
description: "Task ID of task in which the model was created"
type: string
}
parent {
description: "Parent model ID"
type: string
}
project {
description: "Associated project ID"
type: string
}
comment {
description: "Model comment"
type: string
}
tags {
type: array
description: "Tags"
items { type: string }
}
framework {
description: "Framework on which the model is based. Should be identical to the framework of the task which created the model"
type: string
}
design {
description: "Json object representing the model design. Should be identical to the network design of the task which created the model"
type: object
additionalProperties: true
}
labels {
description: "Json object representing the ids of the labels in the model. The keys are the layers' names and the values are the ids."
type: object
additionalProperties { type: integer }
}
uri {
description: "URI for the model, pointing to the destination storage."
type: string
}
ready {
description: "Indication if the model is final and can be used by other tasks"
type: boolean
}
ui_cache {
description: "UI cache for this model"
type: object
additionalProperties: true
}
}
}
}
get_by_id {
"2.1" {
description: "Gets model information"
request {
type: object
required: [ model ]
properties {
model {
description: "Model id"
type: string
}
}
}
response {
type: object
properties {
model {
description: "Model info"
"$ref": "#/definitions/model"
}
}
}
}
}
get_by_task_id {
"2.1" {
description: "Gets model information"
request {
type: object
properties {
task {
description: "Task id"
type: string
}
}
}
response {
type: object
properties {
model {
description: "Model info"
"$ref": "#/definitions/model"
}
}
}
}
}
get_all_ex {
internal: true
"2.1": ${get_all."2.1"}
}
get_all {
"2.1" {
description: "Get all models"
request {
type: object
properties {
name {
description: "Get only models whose name matches this pattern (python regular expression syntax)"
type: string
}
ready {
description: "Indication whether to retrieve only models that are marked ready If not supplied returns both ready and not-ready projects."
type: boolean
}
tags {
description: "Tags list used to filter results. Prepend '-' to tag name to indicate exclusion"
type: array
items { type: string }
}
only_fields {
description: "List of model field names (if applicable, nesting is supported using '.'). If provided, this list defines the query's projection (only these fields will be returned for each result entry)"
type: array
items { type: string }
}
page {
description: "Page number, returns a specific page out of the resulting list of models"
type: integer
minimum: 0
}
page_size {
description: "Page size, specifies the number of results returned in each page (last page may contain fewer results)"
type: integer
minimum: 1
}
project {
description: "List of associated project IDs"
type: array
items { type: string }
}
order_by {
description: "List of field names to order by. When search_text is used, '@text_score' can be used as a field representing the text score of returned documents. Use '-' prefix to specify descending order. Optional, recommended when using page"
type: array
items { type: string }
}
task {
description: "List of associated task IDs"
type: array
items { type: string }
}
id {
description: "List of model IDs"
type: array
items { type: string }
}
search_text {
description: "Free text search query"
type: string
}
framework {
description: "List of frameworks"
type: array
items { type: string }
}
uri {
description: "List of model URIs"
type: array
items { type: string }
}
_all_ {
description: "Multi-field pattern condition (all fields match pattern)"
"$ref": "#/definitions/multi_field_pattern_data"
}
_any_ {
description: "Multi-field pattern condition (any field matches pattern)"
"$ref": "#/definitions/multi_field_pattern_data"
}
}
dependencies {
page: [ page_size ]
}
}
response {
type: object
properties {
models: {
description: "Models list"
type: array
items { "$ref": "#/definitions/model" }
}
}
}
}
}
update_for_task {
"2.1" {
description: "Create or update a new model for a task"
request {
type: object
required: [
task
]
properties {
task {
description: "Task id"
type: string
}
uri {
description: "URI for the model"
type: string
}
name {
description: "Model name Unique within the company."
type: string
}
comment {
description: "Model comment"
type: string
}
tags {
description: "Tags list"
type: array
items { type: string }
}
override_model_id {
description: "Override model ID. If provided, this model is updated in the task."
type: string
}
iteration {
description: "Iteration (used to update task statistics)"
type: integer
}
}
}
response {
type: object
properties {
id {
description: "ID of the model"
type: string
}
created {
description: "Was the model created"
type: boolean
}
updated {
description: "Number of models updated (0 or 1)"
type: integer
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
}
}
create {
"2.1" {
description: "Create a new model not associated with a task"
request {
type: object
required: [
uri
name
labels
]
properties {
uri {
description: "URI for the model"
type: string
}
name {
description: "Model name Unique within the company."
type: string
}
comment {
description: "Model comment"
type: string
}
tags {
description: "Tags list"
type: array
items { type: string }
}
framework {
description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model."
type: string
}
design {
description: "Json[d] object representing the model design. Should be identical to the network design of the task which created the model"
type: object
additionalProperties: true
}
labels {
description: "Json object"
type: object
additionalProperties { type: integer }
}
ready {
description: "Indication if the model is final and can be used by other tasks Default is false."
type: boolean
default: false
}
public {
description: "Create a public model Default is false."
type: boolean
default: false
}
project {
description: "Project to which to model belongs"
type: string
}
parent {
description: "Parent model"
type: string
}
task {
description: "Associated task ID"
type: string
}
}
}
response {
type: object
properties {
id {
description: "ID of the model"
type: string
}
created {
description: "Was the model created"
type: boolean
}
}
}
}
}
edit {
"2.1" {
description: "Edit an existing model"
request {
type: object
required: [
model
]
properties {
model {
description: "Model ID"
type: string
}
uri {
description: "URI for the model"
type: string
}
name {
description: "Model name Unique within the company."
type: string
}
comment {
description: "Model comment"
type: string
}
tags {
description: "Tags list"
type: array
items { type: string }
}
framework {
description: "Framework on which the model is based. Case insensitive. Should be identical to the framework of the task which created the model."
type: string
}
design {
description: "Json[d] object representing the model design. Should be identical to the network design of the task which created the model"
type: object
additionalProperties: true
}
labels {
description: "Json object"
type: object
additionalProperties { type: integer }
}
ready {
description: "Indication if the model is final and can be used by other tasks"
type: boolean
}
project {
description: "Project to which to model belongs"
type: string
}
parent {
description: "Parent model"
type: string
}
task {
description: "Associated task ID"
type: string
}
iteration {
description: "Iteration (used to update task statistics)"
type: integer
}
}
}
response {
type: object
properties {
updated {
description: "Number of models updated (0 or 1)"
type: integer
enum: [0, 1]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
}
}
update {
"2.1" {
description: "Update a model"
request {
type: object
required: [ model ]
properties {
model {
description: "Model id"
type: string
}
name {
description: "Model name Unique within the company."
type: string
}
comment {
description: "Model comment"
type: string
}
tags {
description: "Tags list"
type: array
items { type: string }
}
ready {
description: "Indication if the model is final and can be used by other tasks Default is false."
type: boolean
default: false
}
created {
description: "Model creation time (UTC) "
type: string
format: "date-time"
}
ui_cache {
description: "UI cache for this model"
type: object
additionalProperties: true
}
project {
description: "Project to which to model belongs"
type: string
}
task {
description: "Associated task ID"
type: "string"
}
iteration {
description: "Iteration (used to update task statistics if an associated task is reported)"
type: integer
}
}
}
response {
type: object
properties {
updated {
description: "Number of models updated (0 or 1)"
type: integer
enum: [0, 1]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
}
}
set_ready {
"2.1" {
description: "Set the model ready flag to True. If the model is an output model of a task then try to publish the task."
request {
type: object
required: [ model ]
properties {
model {
description: "Model id"
type: string
}
force_publish_task {
description: "Publish the associated task (if exists) even if it is not in the 'stopped' state. Optional, the default value is False."
type: boolean
}
publish_task {
description: "Indicates that the associated task (if exists) should be published. Optional, the default value is True."
type: boolean
}
}
}
response {
type: object
properties {
updated {
description: "Number of models updated (0 or 1)"
type: integer
enum: [0, 1]
}
published_task {
description: "Result of publishing of the model's associated task (if exists). Returned only if the task was published successfully as part of the model publishing."
type: object
properties {
id {
description: "Task id"
type: string
}
data {
description: "Data returned from the task publishing operation."
type: object
properties {
committed_versions_results {
description: "Committed versions results"
type: array
items {
type: object
additionalProperties: true
}
}
updated {
description: "Number of tasks updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
}
}
}
}
}
}
delete {
"2.1" {
description: "Delete a model."
request {
required: [
model
]
type: object
properties {
model {
description: "Model ID"
type: string
}
force {
description: """Force. Required if there are tasks that use the model as an execution model, or if the model's creating task is published.
"""
type: boolean
}
}
}
response {
type: object
properties {
deleted {
description: "Indicates whether the model was deleted"
type: boolean
}
}
}
}
}
}
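
The `get_all` request above declares a `dependencies` clause: `page` is only valid when `page_size` is also given, and both carry minimum values. A small hedged sketch of client-side validation for those rules (the helper is illustrative, not part of the service):

```python
def validate_get_all_query(query):
    # Enforce the schema's "dependencies" clause: if "page" is present,
    # "page_size" must be as well, and both must honor their minimums.
    if "page" in query and "page_size" not in query:
        raise ValueError("'page' requires 'page_size'")
    if "page" in query and query["page"] < 0:
        raise ValueError("'page' must be >= 0 (schema minimum)")
    if "page_size" in query and query["page_size"] < 1:
        raise ValueError("'page_size' must be >= 1 (schema minimum)")
    return query

query = validate_get_all_query({"name": "resnet.*", "page": 0, "page_size": 20})
```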


@@ -0,0 +1,463 @@
{
_description: "Provides support for defining Projects containing Tasks, Models and Dataset Versions."
_definitions {
multi_field_pattern_data {
type: object
properties {
pattern {
description: "Pattern string (regex)"
type: string
}
fields {
description: "List of field names"
type: array
items { type: string }
}
}
}
project_tags_enum {
type: string
enum: [ archived, public, default ]
}
project {
type: object
properties {
id {
description: "Project id"
type: string
}
name {
description: "Project name"
type: string
}
description {
description: "Project description"
type: string
}
user {
description: "Associated user id"
type: string
}
company {
description: "Company id"
type: string
}
created {
description: "Creation time"
type: string
format: "date-time"
}
tags {
description: "Tags"
type: array
items { "$ref": "#/definitions/project_tags_enum" }
}
default_output_destination {
description: "The default output destination URL for new tasks under this project"
type: string
}
last_update {
description: """Last project update time. Reflects the last time the project metadata was changed or a task in this project has changed status"""
type: string
format: "date-time"
}
}
}
stats_status_count {
type: object
properties {
total_runtime {
description: "Total run time of all tasks in project (in seconds)"
type: integer
}
status_count {
description: "Status counts"
type: object
properties {
created {
description: "Number of 'created' tasks in project"
type: integer
}
queued {
description: "Number of 'queued' tasks in project"
type: integer
}
in_progress {
description: "Number of 'in_progress' tasks in project"
type: integer
}
stopped {
description: "Number of 'stopped' tasks in project"
type: integer
}
published {
description: "Number of 'published' tasks in project"
type: integer
}
closed {
description: "Number of 'closed' tasks in project"
type: integer
}
failed {
description: "Number of 'failed' tasks in project"
type: integer
}
unknown {
description: "Number of 'unknown' tasks in project"
type: integer
}
}
}
}
}
stats {
type: object
properties {
active {
description: "Stats for active tasks"
"$ref": "#/definitions/stats_status_count"
}
archived {
description: "Stats for archived tasks"
"$ref": "#/definitions/stats_status_count"
}
}
}
projects_get_all_response_single {
// copy-paste from project definition
type: object
properties {
id {
description: "Project id"
type: string
}
name {
description: "Project name"
type: string
}
description {
description: "Project description"
type: string
}
user {
description: "Associated user id"
type: string
}
company {
description: "Company id"
type: string
}
created {
description: "Creation time"
type: string
format: "date-time"
}
tags {
description: "Tags"
type: array
items { "$ref": "#/definitions/project_tags_enum" }
}
default_output_destination {
description: "The default output destination URL for new tasks under this project"
type: string
}
// extra properties
stats {
description: "Additional project stats"
"$ref": "#/definitions/stats"
}
}
}
metric_variant_result {
type: object
properties {
metric {
description: "Metric name"
type: string
}
metric_hash {
description: """Metric name hash. Used instead of the metric name when categorizing
last metrics events in task objects."""
type: string
}
variant {
description: "Variant name"
type: string
}
variant_hash {
description: """Variant name hash. Used instead of the variant name when categorizing
last metrics events in task objects."""
type: string
}
}
}
}
create {
"2.1" {
description: "Create a new project"
request {
type: object
required: [
name
description
]
properties {
name {
description: "Project name. Unique within the company."
type: string
}
description {
description: "Project description"
type: string
}
tags {
description: "Tags"
type: array
items { "$ref": "#/definitions/project_tags_enum" }
}
default_output_destination {
description: "The default output destination URL for new tasks under this project"
type: string
}
}
}
response {
type: object
properties {
id {
description: "Project id"
type: string
}
}
}
}
}
get_by_id {
"2.1" {
description: "Get project information by ID"
request {
type: object
required: [ project ]
properties {
project {
description: "Project id"
type: string
}
}
}
response {
type: object
properties {
project {
description: "Project info"
"$ref": "#/definitions/project"
}
}
}
}
}
get_all {
"2.1" {
description: "Get all the company's projects and all public projects"
request {
type: object
properties {
id {
description: "List of IDs to filter by"
type: array
items { type: string }
}
name {
description: "Get only projects whose name matches this pattern (python regular expression syntax)"
type: string
}
description {
description: "Get only projects whose description matches this pattern (python regular expression syntax)"
type: string
}
tags {
description: "Tags list used to filter results. Prepend '-' to tag name to indicate exclusion"
type: array
items { type: string }
}
order_by {
description: "List of field names to order by. When search_text is used, '@text_score' can be used as a field representing the text score of returned documents. Use '-' prefix to specify descending order. Optional, recommended when using page"
type: array
items { type: string }
}
page {
description: "Page number, returns a specific page out of the resulting list of projects"
type: integer
minimum: 0
}
page_size {
description: "Page size, specifies the number of results returned in each page (last page may contain fewer results)"
type: integer
minimum: 1
}
search_text {
description: "Free text search query"
type: string
}
only_fields {
description: "List of document's field names (nesting is supported using '.', e.g. execution.model_labels). If provided, this list defines the query's projection (only these fields will be returned for each result entry)"
type: array
items { type: string }
}
_all_ {
description: "Multi-field pattern condition (all fields match pattern)"
"$ref": "#/definitions/multi_field_pattern_data"
}
_any_ {
description: "Multi-field pattern condition (any field matches pattern)"
"$ref": "#/definitions/multi_field_pattern_data"
}
}
}
response {
type: object
properties {
projects {
description: "Projects list"
type: array
items { "$ref": "#/definitions/projects_get_all_response_single" }
}
}
}
}
}
get_all_ex {
internal: true
"2.1": ${get_all."2.1"} {
request {
properties {
include_stats {
description: "If true, include project statistics in the response."
type: boolean
default: false
}
stats_for_state {
description: "If provided, returned stats will include only statistics for tasks in the specified state. If null, stats for all task states are returned."
type: string
enum: [ active, archived ]
default: active
}
}
}
}
}
update {
"2.1" {
description: "Update project information"
request {
type: object
required: [ project ]
properties {
project {
description: "Project id"
type: string
}
name {
description: "Project name. Unique within the company."
type: string
}
description {
description: "Project description"
type: string
}
tags {
description: "Tags list"
type: array
items { type: string }
}
default_output_destination {
description: "The default output destination URL for new tasks under this project"
type: string
}
}
}
response {
type: object
properties {
updated {
description: "Number of projects updated (0 or 1)"
type: integer
enum: [ 0, 1 ]
}
fields {
description: "Updated fields names and values"
type: object
additionalProperties: true
}
}
}
}
}
delete {
"2.1" {
description: "Deletes a project"
request {
type: object
required: [ project ]
properties {
project {
description: "Project id"
type: string
}
force {
description: """If not true, fails if project has tasks.
If true, and project has tasks, they will be unassigned"""
type: boolean
default: false
}
}
}
response {
type: object
properties {
deleted {
description: "Number of projects deleted (0 or 1)"
type: integer
}
disassociated_tasks {
description: "Number of tasks disassociated from the deleted project"
type: integer
}
}
}
}
}
get_unique_metric_variants {
"2.1" {
description: """Get all metric/variant pairs reported for tasks in a specific project.
If no project is specified, metric/variant pairs reported for all tasks will be returned.
If the project does not exist, an empty list will be returned."""
request {
type: object
properties {
project {
description: "Project ID"
type: string
}
}
}
response {
type: object
properties {
metrics {
description: "A list of metric variants reported for tasks in this project"
type: array
items { "$ref": "#/definitions/metric_variant_result" }
}
}
}
}
}
}
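To illustrate the `projects.create` request schema above, here is a minimal client-side sketch that checks a payload against the declared required and optional fields. The `validate_create_project` helper is illustrative only and not part of the server code; only the field names come from the schema.

```python
# Field sets taken from the projects.create request schema.
REQUIRED_FIELDS = {"name", "description"}
OPTIONAL_FIELDS = {"tags", "default_output_destination"}

def validate_create_project(payload):
    """Check a projects.create payload against the declared schema."""
    missing = REQUIRED_FIELDS - payload.keys()
    if missing:
        raise ValueError(f"missing required fields: {sorted(missing)}")
    unknown = payload.keys() - (REQUIRED_FIELDS | OPTIONAL_FIELDS)
    if unknown:
        raise ValueError(f"unknown fields: {sorted(unknown)}")
    return payload

# A valid payload passes through unchanged.
payload = validate_create_project({
    "name": "detection-baselines",
    "description": "Object detection baseline experiments",
    "tags": ["default"],
})
```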

_description: """This service provides a management interface to users information
and new users login restrictions."""
_default {
internal: true
}
_definitions {
user {
type: object
properties {
id {
description: User ID
type: string
}
name {
description: Full name
type: string
}
given_name {
description: Given name
type: string
}
family_name {
description: Family name
type: string
}
avatar {
description: Avatar URL
type: string
}
company {
description: Company ID
type: string
}
# Admin only fields
role {
description: """User's role (admin only)"""
type: string
}
providers {
description: """Providers the user has logged in with"""
type: object
additionalProperties: true
}
created {
description: User creation date
type: string
format: date-time
}
email {
description: User email
type: string
format: email
}
}
}
get_current_user_response_user_object {
type: object
description: "Like the user definition, but returns a company object instead of an ID"
properties {
id {
type: string
}
name {
type: string
}
given_name {
type: string
}
family_name {
type: string
}
role {
type: string
}
avatar {
type: string
}
company {
type: object
properties {
id {
type: string
}
name {
type: string
}
}
}
preferences {
description: User preferences
type: object
additionalProperties: true
}
}
}
}
get_by_id {
internal: false
"2.1" {
description: Gets user information
request {
type: object
required: [ user ]
properties {
user {
description: User ID
type: string
}
}
}
response {
type: object
properties {
user {
description: User info
"$ref": "#/definitions/user"
}
}
}
}
}
get_current_user {
internal: false
"2.1" {
description: """Gets current user information, based on the authenticated user making the call."""
request {
type: object
additionalProperties: false
}
response {
type: object
properties {
user {
description: "User info"
"$ref": "#/definitions/get_current_user_response_user_object"
}
}
}
}
}
get_all_ex {
internal: true
"2.1": ${get_all."2.1"} {
}
}
get_all {
"2.1" {
description: Get all user objects
request {
type: object
properties {
name {
description: "Get only users whose name matches this pattern (python regular expression syntax)"
type: string
}
id {
description: "List of user IDs used to filter results"
type: array
items { type: string }
}
only_fields {
description: "List of user field names (if applicable, nesting is supported using '.'). If provided, this list defines the query's projection (only these fields will be returned for each result entry)"
type: array
items { type: string }
}
page {
description: "Page number, returns a specific page out of the resulting list of users"
type: integer
minimum: 0
}
page_size {
description: "Page size, specifies the number of results returned in each page (last page may contain fewer results)"
type: integer
minimum: 1
}
order_by {
description: "List of field names to order by. Use '-' prefix to specify descending order. Optional, recommended when using page"
type: array
items { type: string }
}
}
}
response {
type: object
properties {
users {
description: User list
type: array
items { "$ref": "#/definitions/user" }
}
}
}
}
}
delete {
internal: true
allow_roles: [ "system", "root" ]
"2.1" {
description: Delete a user
request {
type: object
required: [ user ]
properties {
user {
description: ID of user to delete
type: string
}
}
}
response {
type: object
additionalProperties: false
}
}
}
create {
allow_roles: [ "system", "root" ]
"2.1" {
description: Create a new user object. Reserved for internal use.
request {
type: object
required: [
company
id
name
]
properties {
id {
description: User ID
type: string
}
company {
description: Company ID
type: string
}
name {
description: Full name
type: string
}
given_name {
description: Given name
type: string
}
family_name {
description: Family name
type: string
}
avatar {
description: Avatar URL
type: string
}
}
}
response {
type: object
properties {}
additionalProperties: false
}
}
}
update {
internal: false
"2.1" {
description: Update a user object
request {
type: object
required: [ user ]
properties {
user {
description: User ID
type: string
}
name {
description: Full name
type: string
}
given_name {
description: Given name
type: string
}
family_name {
description: Family name
type: string
}
avatar {
description: Avatar URL
type: string
}
}
}
response {
type: object
properties {
updated {
description: Number of updated user objects (0 or 1)
type: integer
}
fields {
description: Updated fields names and values
type: object
additionalProperties: true
}
}
}
}
}
get_preferences {
internal: false
"2.1" {
description: Get user preferences
request {
type: object
properties {}
}
response {
type: object
properties {
preferences {
type: object
additionalProperties: true
}
}
}
}
}
set_preferences {
internal: false
"2.1" {
description: Set user preferences
request {
type: object
required: [ preferences ]
properties {
preferences {
description: """Updates to user preferences. A mapping from keys in dot notation to values.
For example, `{"a.b": 0}` will set the key "b" in object "a" to 0."""
type: object
additionalProperties: true
}
}
}
response {
type: object
properties {
updated {
description: Number of updated user objects (0 or 1)
type: integer
}
fields {
description: Updated fields names and values
type: object
additionalProperties: true
}
}
}
}
}
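The `users.set_preferences` description above specifies dot-notation keys, e.g. `{"a.b": 0}` sets key `"b"` inside object `"a"`. The following standalone sketch shows one way that expansion can work; it is illustrative and not the server's actual implementation.

```python
def apply_preferences(prefs, updates):
    """Expand dot-notation update keys into a nested preferences dict,
    mirroring the behavior documented for users.set_preferences."""
    for dotted_key, value in updates.items():
        node = prefs
        *parents, leaf = dotted_key.split(".")
        for part in parents:
            # Create intermediate objects as needed.
            node = node.setdefault(part, {})
        node[leaf] = value
    return prefs
```

For example, `apply_preferences({}, {"a.b": 0})` produces `{"a": {"b": 0}}`, and existing sibling keys are preserved.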

server/server.py
from argparse import ArgumentParser
from flask import Flask, request, Response
from flask_compress import Compress
from flask_cors import CORS
from werkzeug.exceptions import BadRequest
import database
from apierrors.base import BaseError
from config import config
from service_repo import ServiceRepo, APICall
from service_repo.auth import AuthType
from service_repo.errors import PathParsingError
from timing_context import TimingContext
from utilities import json
from init_data import init_es_data, init_mongo_data
app = Flask(__name__, static_url_path="/static")
CORS(app, supports_credentials=True, **config.get("apiserver.cors"))
Compress(app)
log = config.logger(__file__)
log.info("################ API Server initializing #####################")
app.config["SECRET_KEY"] = config.get("secure.http.session_secret.apiserver")
app.config["JSONIFY_PRETTYPRINT_REGULAR"] = config.get("apiserver.pretty_json")
database.initialize()
init_es_data()
init_mongo_data()
ServiceRepo.load("services")
log.info(f"Exposed Services: {' '.join(ServiceRepo.endpoint_names())}")
@app.before_first_request
def before_app_first_request():
pass
@app.before_request
def before_request():
if request.method == "OPTIONS":
return "", 200
if "/static/" in request.path:
return
try:
call = create_api_call(request)
content, content_type = ServiceRepo.handle_call(call)
headers = None
if call.result.filename:
headers = {
"Content-Disposition": f"attachment; filename={call.result.filename}"
}
return Response(
content, mimetype=content_type, status=call.result.code, headers=headers
)
except Exception as ex:
log.exception(f"Failed processing request {request.url}: {ex}")
return f"Failed processing request {request.url}", 500
def update_call_data(call, req):
""" Use request payload/form to fill call data or batched data """
if req.content_type == "application/json-lines":
items = []
for i, line in enumerate(req.data.splitlines()):
try:
event = json.loads(line)
if not isinstance(event, dict):
raise BadRequest(
f"json lines must contain objects, found: {type(event).__name__}"
)
items.append(event)
except ValueError as e:
msg = f"{e} in batch item #{i}"
req.on_json_loading_failed(msg)
call.batched_data = items
else:
json_body = req.get_json(force=True, silent=False) if req.data else None
# merge form and args
form = req.form.copy()
form.update(req.args)
form = form.to_dict()
# convert string numbers to floats
for key in form:
if form[key].replace(".", "", 1).isdigit():
if "." in form[key]:
form[key] = float(form[key])
else:
form[key] = int(form[key])
elif form[key].lower() == "true":
form[key] = True
elif form[key].lower() == "false":
form[key] = False
# NOTE: dict() form data to make sure we won't pass along a MultiDict or some other nasty crap
call.data = json_body or form or {}
def _call_or_empty_with_error(call, req, msg, code=500, subcode=0):
call = call or APICall(
"", remote_addr=req.remote_addr, headers=dict(req.headers), files=req.files
)
call.set_error_result(msg=msg, code=code, subcode=subcode)
return call
def create_api_call(req):
call = None
try:
# Parse the request path
endpoint_version, endpoint_name = ServiceRepo.parse_endpoint_path(req.path)
# Resolve authorization: if cookies contain an authorization token, use it as a starting point.
# in any case, request headers always take precedence.
auth_cookie = req.cookies.get(
config.get("apiserver.auth.session_auth_cookie_name")
)
headers = (
{}
if not auth_cookie
else {"Authorization": f"{AuthType.bearer_token} {auth_cookie}"}
)
headers.update(
list(req.headers.items())
) # add (possibly override with) the headers
# Construct call instance
call = APICall(
endpoint_name=endpoint_name,
remote_addr=req.remote_addr,
endpoint_version=endpoint_version,
headers=headers,
files=req.files,
)
# Update call data from request
with TimingContext("preprocess", "update_call_data"):
update_call_data(call, req)
except PathParsingError as ex:
call = _call_or_empty_with_error(call, req, ex.args[0], 400)
call.log_api = False
except BadRequest as ex:
call = _call_or_empty_with_error(call, req, ex.description, 400)
except BaseError as ex:
call = _call_or_empty_with_error(call, req, ex.msg, ex.code, ex.subcode)
except Exception as ex:
log.exception("Error creating call")
call = _call_or_empty_with_error(
call, req, ex.args[0] if ex.args else type(ex).__name__, 500
)
return call
# =================== MAIN =======================
if __name__ == "__main__":
p = ArgumentParser(description=__doc__)
p.add_argument(
"--port", "-p", type=int, default=config.get("apiserver.listen.port")
)
p.add_argument("--ip", "-i", type=str, default=config.get("apiserver.listen.ip"))
p.add_argument(
"--debug", action="store_true", default=config.get("apiserver.debug")
)
p.add_argument(
"--watch", action="store_true", default=config.get("apiserver.watch")
)
args = p.parse_args()
# logging.info("Starting API Server at %s:%s and env '%s'" % (args.ip, args.port, config.env))
app.run(
debug=args.debug,
host=args.ip,
port=args.port,
threaded=True,
use_reloader=args.watch,
)
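The two payload-handling branches of `update_call_data()` above can be sketched as standalone helpers: one for `application/json-lines` batching (each line must be a JSON object), one for the string-to-number/boolean coercion applied to form values. These are simplified, framework-free restatements of the logic, not the functions the server actually runs.

```python
import json

def parse_json_lines(data):
    """Parse an application/json-lines body: one JSON object per line."""
    items = []
    for i, line in enumerate(data.splitlines()):
        if not line.strip():
            continue
        event = json.loads(line)
        if not isinstance(event, dict):
            raise ValueError(f"json lines must contain objects, found item #{i}")
        items.append(event)
    return items

def coerce_form_value(value):
    """Coerce a form string: numeric strings become int/float,
    'true'/'false' become booleans, anything else is kept as-is."""
    if value.replace(".", "", 1).isdigit():
        return float(value) if "." in value else int(value)
    if value.lower() == "true":
        return True
    if value.lower() == "false":
        return False
    return value
```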

from typing import Text, Sequence, Callable, Union
from funcsigs import signature
from jsonmodels import models
from .apicall import APICall, APICallResult
from .endpoint import EndpointFunc, Endpoint
from .service_repo import ServiceRepo
__all__ = ["endpoint"]
LegacyEndpointFunc = Callable[[APICall], None]
def endpoint(
name: Text,
min_version: Text = "1.0",
required_fields: Sequence[Text] = None,
request_data_model: models.Base = None,
response_data_model: models.Base = None,
validate_schema=False,
):
""" Endpoint decorator, used to declare a method as an endpoint handler """
def decorator(f: Union[EndpointFunc, LegacyEndpointFunc]) -> EndpointFunc:
# Backwards compatibility: support endpoints with both old-style signature (call) and new-style signature
# (call, company, request_model)
func = f
if len(signature(f).parameters) == 1:
# old-style
def adapter(call, *_, **__):
return f(call)
func = adapter
ServiceRepo.register(
Endpoint(
name=name,
func=func,
min_version=min_version,
required_fields=required_fields,
request_data_model=request_data_model,
response_data_model=response_data_model,
validate_schema=validate_schema,
)
)
return func
return decorator
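The backwards-compatibility shim inside the decorator above can be isolated as follows. This sketch uses the stdlib `inspect.signature` in place of `funcsigs` and skips the `ServiceRepo` registration, so it only demonstrates the adapter logic.

```python
from inspect import signature

def adapt_handler(f):
    """Wrap single-argument (call) handlers so they can be invoked with
    the new-style (call, company, request_model) signature; new-style
    handlers pass through unchanged."""
    if len(signature(f).parameters) == 1:
        def adapter(call, *_, **__):
            return f(call)
        return adapter
    return f

def legacy_handler(call):
    return f"handled {call}"

def new_handler(call, company, request_model):
    return f"handled {call} for {company}"

adapted = adapt_handler(legacy_handler)
```

`adapted` can now be called with the full three-argument signature even though `legacy_handler` only accepts `call`.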

import time
import types
from traceback import format_exc
from typing import Type, Optional
from jsonmodels import models
from six import string_types
import database
import timing_context
from timing_context import TimingContext
from utilities import json
from .auth import Identity
from .auth import Payload as AuthPayload
from .base import PartialVersion
from .errors import CallParsingError
from .schema_validator import SchemaValidator
JSON_CONTENT_TYPE = "application/json"
class DataContainer(object):
""" Data container that supports raw data (dict or a list of batched dicts) and a data model """
def __init__(self, data=None, batched_data=None):
if data and batched_data:
raise ValueError("data and batched data are not supported simultaneously")
self._batched_data = None
self._data = None
self._data_model = None
self._data_model_cls = None
self._schema_validator: SchemaValidator = SchemaValidator(None)
# use setter to properly initialize data
self.data = data
self.batched_data = batched_data
self._raw_data = None
self._content_type = JSON_CONTENT_TYPE
@property
def schema_validator(self):
return self._schema_validator
@schema_validator.setter
def schema_validator(self, value):
self._schema_validator = value
self._update_data_model()
@property
def data(self):
return self._data or {}
@data.setter
def data(self, value):
""" Set the data using a raw dict. If a model cls is defined, validate the raw data """
if value is not None:
assert isinstance(value, dict), "Data should be a dict"
self._data = value
self._update_data_model()
@property
def batched_data(self):
if self._batched_data is not None:
return self._batched_data
elif self.data != {}:
return [self.data]
else:
return []
@batched_data.setter
def batched_data(self, value):
if not value:
return
assert isinstance(value, (tuple, list)), "Batched data should be a list"
self._batched_data = value
self._update_data_model()
@property
def raw_data(self):
return self._raw_data
@raw_data.setter
def raw_data(self, value):
assert isinstance(
value, string_types + (types.GeneratorType,)
), "Raw data must be a string type or generator"
self._raw_data = value
@property
def content_type(self):
return self._content_type
@content_type.setter
def content_type(self, value):
self._content_type = value
def _update_data_model(self):
self.schema_validator.detailed_validate(self._data)
cls = self.data_model_cls
if not cls or (self._data is None and self._batched_data is None):
return
# handle batched items
if self._batched_data:
try:
data_model = [cls(**item) for item in self._batched_data]
except TypeError as ex:
raise CallParsingError(str(ex))
for m in data_model:
m.validate()
else:
try:
data_model = cls(**self.data)
except TypeError as ex:
raise CallParsingError(str(ex))
if not self.schema_validator.enabled:
data_model.validate()
self._data_model = data_model
@property
def data_model(self):
return self._data_model
# @property
# def get_partial_update(self, data_model_class):
# return {k: v for k, v in self.data_model.to_struct().iteritems() if k in self.data}
@property
def data_model_for_partial_update(self):
"""
Return only data model fields that we actually passed by the user
:return:
"""
return {k: v for k, v in self.data_model.to_struct().items() if k in self.data}
@data_model.setter
def data_model(self, value):
""" Set the data using a model instance. NOTE: batched_data is never updated. """
cls = self.data_model_cls
if not cls:
raise ValueError("Data model is not defined")
if isinstance(value, cls):
# instance of the data model class - just take it
self._data_model = value
elif issubclass(cls, type(value)):
# instance of a subclass of the data model class - create the expected class instance and use the instance
# we received to initialize it
self._data_model = cls(**value.to_struct())
else:
raise ValueError(f"Invalid data model (expecting {cls} or super classes)")
self._data = value.to_struct()
@property
def data_model_cls(self) -> Optional[Type[models.Base]]:
return self._data_model_cls
@data_model_cls.setter
def data_model_cls(self, value: Type[models.Base]):
assert issubclass(value, models.Base)
self._data_model_cls = value
self._update_data_model()
class APICallResult(DataContainer):
def __init__(self, data=None, code=200, subcode=0, msg="OK", traceback=""):
super(APICallResult, self).__init__(data)
self._code = code
self._subcode = subcode
self._msg = msg
self._traceback = traceback
self._extra = None
self._filename = None
def get_log_entry(self):
res = dict(
msg=self.msg,
code=self.code,
subcode=self.subcode,
traceback=self._traceback,
extra=self._extra,
)
if self.log_data:
res["data"] = self.data
return res
def copy_from(self, result):
self._code = result.code
self._subcode = result.subcode
self._msg = result.msg
self._traceback = result.traceback
self._extra = result.extra_log
@property
def msg(self):
return self._msg
@msg.setter
def msg(self, value):
self._msg = value
@property
def code(self):
return self._code
@code.setter
def code(self, value):
self._code = value
@property
def subcode(self):
return self._subcode
@subcode.setter
def subcode(self, value):
self._subcode = value
@property
def traceback(self):
return self._traceback
@traceback.setter
def traceback(self, value):
self._traceback = value
@property
def extra_log(self):
""" Extra data to be logged into ES """
return self._extra
@extra_log.setter
def extra_log(self, value):
self._extra = value
@property
def filename(self):
return self._filename
@filename.setter
def filename(self, value):
self._filename = value
class MissingIdentity(Exception):
pass
class APICall(DataContainer):
HEADER_AUTHORIZATION = "Authorization"
HEADER_REAL_IP = "X-Real-IP"
""" Standard headers """
_transaction_headers = ("X-Trains-Trx",)
""" Transaction ID """
@property
def HEADER_TRANSACTION(self):
return self._transaction_headers[0]
_worker_headers = ("X-Trains-Worker",)
""" Worker (machine) ID """
@property
def HEADER_WORKER(self):
return self._worker_headers[0]
_impersonate_as_headers = ("X-Trains-Impersonate-As",)
""" Impersonate as someone else (using his identity and permissions) """
@property
def HEADER_IMPERSONATE_AS(self):
return self._impersonate_as_headers[0]
_act_as_headers = ("X-Trains-Act-As",)
""" Act as someone else (using his identity, but with your own role and permissions) """
@property
def HEADER_ACT_AS(self):
return self._act_as_headers[0]
_async_headers = ("X-Trains-Async",)
""" Specifies that this call should be done asynchronously """
@property
def HEADER_ASYNC(self):
return self._async_headers[0]
def __init__(
self,
endpoint_name,
remote_addr=None,
endpoint_version: PartialVersion = PartialVersion("1.0"),
data=None,
batched_data=None,
headers=None,
files=None,
trx=None,
):
super(APICall, self).__init__(data=data, batched_data=batched_data)
timing_context.clear()
self._id = database.utils.id()
self._files = files  # currently a dict of key to flask's FileStorage
self._start_ts = time.time()
self._end_ts = 0
self._duration = 0
self._endpoint_name = endpoint_name
self._remote_addr = remote_addr
assert isinstance(endpoint_version, PartialVersion), endpoint_version
self._requested_endpoint_version = endpoint_version
self._actual_endpoint_version = None
self._headers = {}
self._kpis = {}
self._log_api = True
if headers:
self._headers.update(headers)
self._result = APICallResult()
self._auth = None
self._impersonation = None
if trx:
self.set_header(self._transaction_headers, trx)
self._requires_authorization = True
@property
def id(self):
return self._id
@property
def requires_authorization(self):
return self._requires_authorization
@requires_authorization.setter
def requires_authorization(self, value):
self._requires_authorization = value
@property
def log_api(self):
return self._log_api
@log_api.setter
def log_api(self, value):
self._log_api = value
def assign_new_id(self):
self._id = database.utils.id()
def get_header(self, header, default=None):
"""
Get header value
:param header: Header name options (more than one supported, listed by priority)
:param default: Default value if no such headers were found
"""
for option in header if isinstance(header, (tuple, list)) else (header,):
if option in self._headers:
return self._headers[option]
return default
def clear_header(self, header):
"""
Clear header value
:param header: Header name options (more than one supported, all will be cleared)
"""
for value in header if isinstance(header, (tuple, list)) else (header,):
self.headers.pop(value, None)
def set_header(self, header, value):
"""
Set header value
:param header: header name (if a list is provided, first item is used)
:param value: Value to set
:return:
"""
self._headers[
header[0] if isinstance(header, (tuple, list)) else header
] = value
@property
def real_ip(self):
real_ip = self.get_header(self.HEADER_REAL_IP)
return real_ip or self._remote_addr or "untrackable"
@property
def failed(self):
return self.result and self.result.code != 200
@property
def duration(self):
return self._duration
@property
def endpoint_name(self):
return self._endpoint_name
@property
def requested_endpoint_version(self) -> PartialVersion:
return self._requested_endpoint_version
@property
def auth(self):
""" Authenticated payload (Token or Basic) """
return self._auth
@auth.setter
def auth(self, value):
if value:
assert isinstance(value, AuthPayload)
self._auth = value
@property
def impersonation_headers(self):
return {
k: v
for k, v in self._headers.items()
if k in (self._impersonate_as_headers + self._act_as_headers)
}
@property
def impersonate_as(self):
return self.get_header(self._impersonate_as_headers)
@property
def act_as(self):
return self.get_header(self._act_as_headers)
@property
def impersonation(self):
return self._impersonation
@impersonation.setter
def impersonation(self, value):
if value:
assert isinstance(value, AuthPayload)
self._impersonation = value
@property
def identity(self) -> Identity:
if self.impersonation:
if not self.impersonation.identity:
raise Exception("Missing impersonate identity")
return self.impersonation.identity
if self.auth:
if not self.auth.identity:
raise Exception("Missing authorized identity (not authorized?)")
return self.auth.identity
raise MissingIdentity("Missing identity")
@property
def actual_endpoint_version(self):
return self._actual_endpoint_version
@actual_endpoint_version.setter
def actual_endpoint_version(self, value):
self._actual_endpoint_version = value
@property
def headers(self):
return self._headers
@property
def kpis(self):
"""
Key Performance Indicators, holding things like number of returned frames/rois, etc.
:return:
"""
return self._kpis
@property
def trx(self):
return self.get_header(self._transaction_headers, self.id)
@trx.setter
def trx(self, value):
self.set_header(self._transaction_headers, value)
@property
def worker(self):
return self.get_header(self._worker_headers, "<unknown>")
@property
def authorization(self):
""" Call authorization data used to authenticate the call """
return self.get_header(self.HEADER_AUTHORIZATION)
@property
def result(self):
return self._result
@property
def exec_async(self):
return self.get_header(self._async_headers) is not None
@exec_async.setter
def exec_async(self, value):
if value:
self.set_header(self._async_headers, "1")
else:
self.clear_header(self._async_headers)
def mark_end(self):
self._end_ts = time.time()
self._duration = int((self._end_ts - self._start_ts) * 1000)
self.stats = timing_context.stats()
def get_response(self):
def make_version_number(version):
"""
Client versions <= 2.0 expect endpoint versions in float format, otherwise they throw an exception
"""
if version is None:
return None
if self.requested_endpoint_version < PartialVersion("2.1"):
return float(str(version))
return str(version)
if self.result.raw_data and not self.failed:
# endpoint returned raw data and no error was detected, return raw data, no fancy dicts
return self.result.raw_data, self.result.content_type
else:
res = {
"meta": {
"id": self.id,
"trx": self.trx,
"endpoint": {
"name": self.endpoint_name,
"requested_version": make_version_number(
self.requested_endpoint_version
),
"actual_version": make_version_number(
self.actual_endpoint_version
),
},
"result_code": self.result.code,
"result_subcode": self.result.subcode,
"result_msg": self.result.msg,
"error_stack": self.result.traceback,
},
"data": self.result.data,
}
if self.content_type.lower() == JSON_CONTENT_TYPE:
with TimingContext("json", "serialization"):
try:
res = json.dumps(res)
except Exception as ex:
# JSON serialization may fail, probably a problem with the data, so drop it and try again
if not self.result.data:
raise
self.result.data = None
msg = "Error serializing response data: " + str(ex)
self.set_error_result(
code=500, subcode=0, msg=msg, include_stack=False
)
return self.get_response()
return res, self.content_type
def set_error_result(self, msg, code=500, subcode=0, include_stack=False):
tb = format_exc() if include_stack else None
self._result = APICallResult(
data=self._result.data, code=code, subcode=subcode, msg=msg, traceback=tb
)
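For illustration, a dependency-free sketch of the version formatting rule in `make_version_number`, using plain version tuples as a stand-in for `PartialVersion` (the tuple comparison is an assumption of this sketch, not the server's actual type):

```python
def make_version_number(version, requested_version):
    """Clients <= 2.0 expect endpoint versions as floats; newer clients get strings."""
    if version is None:
        return None
    if requested_version < (2, 1):
        # old client: return the endpoint version as a float
        return float(str(version))
    return str(version)


old_client = make_version_number("2.1", (2, 0))  # float for clients <= 2.0
new_client = make_version_number("2.1", (2, 1))  # string for newer clients
```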


@@ -0,0 +1,4 @@
from .auth import get_auth_func, authorize_impersonation
from .payload import Token, Basic, AuthType, Payload
from .identity import Identity
from .utils import get_client_id, get_secret_key


@@ -0,0 +1,86 @@
import base64
import jwt
from database.errors import translate_errors_context
from database.model.company import Company
from database.utils import get_options
from database.model.auth import User, Entities, Credentials
from apierrors import errors
from config import config
from timing_context import TimingContext
from .payload import Payload, Token, Basic, AuthType
from .identity import Identity
log = config.logger(__file__)
entity_keys = set(get_options(Entities))
verify_user_tokens = config.get("apiserver.auth.verify_user_tokens", True)
def get_auth_func(auth_type):
if auth_type == AuthType.bearer_token:
return authorize_token
elif auth_type == AuthType.basic:
return authorize_credentials
raise errors.unauthorized.BadAuthType()
def authorize_token(jwt_token, *_, **__):
""" Validate token against service/endpoint and requests data (dicts).
Returns a parsed token object (auth payload)
"""
try:
return Token.from_encoded_token(jwt_token)
except jwt.exceptions.InvalidKeyError as ex:
raise errors.unauthorized.InvalidToken('jwt invalid key error', reason=ex.args[0])
except jwt.InvalidTokenError as ex:
raise errors.unauthorized.InvalidToken('invalid jwt token', reason=ex.args[0])
except ValueError as ex:
log.exception('Failed while processing token: %s' % ex.args[0])
raise errors.unauthorized.InvalidToken('failed processing token', reason=ex.args[0])
def authorize_credentials(auth_data, service, action, call_data_items):
""" Validate credentials against service/action and request data (dicts).
Returns a new basic object (auth payload)
"""
try:
access_key, _, secret_key = base64.b64decode(auth_data.encode()).decode('latin-1').partition(':')
except Exception as e:
log.exception('malformed credentials')
raise errors.unauthorized.BadCredentials(str(e))
with TimingContext("mongo", "user_by_cred"), translate_errors_context('authorizing request'):
user = User.objects(credentials__match=Credentials(key=access_key, secret=secret_key)).first()
if not user:
raise errors.unauthorized.InvalidCredentials('failed to locate provided credentials')
with TimingContext("mongo", "company_by_id"):
company = Company.objects(id=user.company).only('id', 'name').first()
if not company:
raise errors.unauthorized.InvalidCredentials('invalid user company')
identity = Identity(user=user.id, company=user.company, role=user.role,
user_name=user.name, company_name=company.name)
basic = Basic(user_key=access_key, identity=identity)
return basic
def authorize_impersonation(user, identity, service, action, call_data_items):
""" Returns a new basic object (auth payload)"""
if not user:
raise ValueError('missing user')
company = Company.objects(id=user.company).only('id', 'name').first()
if not company:
raise errors.unauthorized.InvalidCredentials('invalid user company')
return Payload(auth_type=None, identity=identity)
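The credential parsing in `authorize_credentials` expects an HTTP Basic payload of the form `base64(key:secret)`. A minimal round-trip, with hypothetical key/secret values:

```python
import base64

# Encode the way a client builds an HTTP Basic authorization payload
auth_data = base64.b64encode(b"ACCESSKEY123:secret456").decode()

# Decode the way authorize_credentials does: base64-decode, then partition on ':'
access_key, _, secret_key = (
    base64.b64decode(auth_data.encode()).decode("latin-1").partition(":")
)
```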


@@ -0,0 +1,24 @@
class Dictable(object):
_cached_props = None
@classmethod
def _get_cached_props(cls):
if cls._cached_props is None:
props = set()
for c in cls.mro():
props.update(k for k, v in vars(c).items() if isinstance(v, property) and not k.startswith('_'))
cls._cached_props = list(props)
return cls._cached_props
def to_dict(self, **extra):
props = self._get_cached_props()
d = {k: getattr(self, k) for k in props if getattr(self, k)}
res = {k: (v.to_dict() if isinstance(v, Dictable) else v) for k, v in d.items()}
if extra:
# add the extra items to our result, make sure not to overwrite existing properties (claims etc)
res.update({k: v for k, v in extra.items() if k not in props})
return res
@classmethod
def from_dict(cls, d):
return cls(**d)
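A self-contained sketch of how `Dictable.to_dict` collects public properties (the class body is reproduced inline without the caching, and `Point` is a hypothetical subclass for illustration):

```python
class Dictable:
    """Inline copy of the property-collecting to_dict logic (caching omitted)."""

    def to_dict(self, **extra):
        props = [
            k
            for c in type(self).mro()
            for k, v in vars(c).items()
            if isinstance(v, property) and not k.startswith("_")
        ]
        # note: falsy values (None, 0, "") are dropped by the truthiness check
        d = {k: getattr(self, k) for k in props if getattr(self, k)}
        res = {k: v.to_dict() if isinstance(v, Dictable) else v for k, v in d.items()}
        # extra items must not overwrite existing properties
        res.update({k: v for k, v in extra.items() if k not in props})
        return res


class Point(Dictable):
    """Hypothetical subclass for illustration."""

    def __init__(self, x, y=None):
        self._x, self._y = x, y

    @property
    def x(self):
        return self._x

    @property
    def y(self):
        return self._y


d = Point(3).to_dict(extra_claim="abc")  # y is None, so it is dropped
```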


@@ -0,0 +1,30 @@
from .dictable import Dictable
class Identity(Dictable):
def __init__(self, user, company, role, user_name=None, company_name=None):
self._user = user
self._company = company
self._role = role
self._user_name = user_name
self._company_name = company_name
@property
def user(self):
return self._user
@property
def company(self):
return self._company
@property
def role(self):
return self._role
@property
def user_name(self):
return self._user_name
@property
def company_name(self):
return self._company_name


@@ -0,0 +1,4 @@
from .auth_type import AuthType
from .payload import Payload
from .basic import Basic
from .token import Token


@@ -0,0 +1,3 @@
class AuthType(object):
basic = 'Basic'
bearer_token = 'Bearer'


@@ -0,0 +1,18 @@
from .payload import Payload
from .auth_type import AuthType
class Basic(Payload):
def __init__(self, user_key, identity=None, entities=None, **_):
super(Basic, self).__init__(
AuthType.basic, identity=identity, entities=entities)
self._user_key = user_key
@property
def user_key(self):
return self._user_key
def get_log_entry(self):
d = super(Basic, self).get_log_entry()
d.update(user_key=self.user_key)
return d


@@ -0,0 +1,52 @@
from apierrors import errors
from ..identity import Identity
from ..dictable import Dictable
class Payload(Dictable):
def __init__(self, auth_type, identity=None, entities=None):
self._auth_type = auth_type
self.identity = identity
self.entities = entities or {}
@property
def auth_type(self):
return self._auth_type
@property
def identity(self):
return self._identity
@identity.setter
def identity(self, value):
if isinstance(value, dict):
value = Identity(**value)
else:
assert isinstance(value, Identity)
self._identity = value
@property
def entities(self):
return self._entities
@entities.setter
def entities(self, value):
self._entities = value
def get_log_entry(self):
return {
"type": self.auth_type,
"identity": self.identity.to_dict(),
"entities": self.entities,
}
def validate_entities(self, **entities):
""" Validate entities. key/value represents entity_name/entity_id(s) """
if not self.entities:
return
for entity_name, entity_id in entities.items():
constraints = self.entities.get(entity_name)
ids = set(entity_id if isinstance(entity_id, (tuple, list)) else (entity_id,))
if constraints and not ids.issubset(constraints):
raise errors.unauthorized.EntityNotAllowed(entity=entity_name)
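`Payload.validate_entities` boils down to a subset check per entity name. A dependency-free sketch that returns a bool instead of raising `EntityNotAllowed`:

```python
def validate_entities(constraints_map, **entities):
    """Sketch of Payload.validate_entities: each entity id must fall within
    its allowed set, if a constraint for that entity name exists."""
    for entity_name, entity_id in entities.items():
        constraints = constraints_map.get(entity_name)
        ids = set(entity_id if isinstance(entity_id, (tuple, list)) else (entity_id,))
        if constraints and not ids.issubset(constraints):
            return False  # the real code raises errors.unauthorized.EntityNotAllowed
    return True


allowed = validate_entities({"company": {"c1", "c2"}}, company="c1")
denied = validate_entities({"company": {"c1", "c2"}}, company=["c1", "c3"])
```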


@@ -0,0 +1,96 @@
import jwt
from datetime import datetime, timedelta
from apierrors import errors
from config import config
from database.model.auth import Role
from .auth_type import AuthType
from .payload import Payload
token_secret = config.get('secure.auth.token_secret')
class Token(Payload):
default_expiration_sec = config.get('apiserver.auth.default_expiration_sec')
def __init__(self, exp=None, iat=None, nbf=None, env=None, identity=None, entities=None, **_):
super(Token, self).__init__(
AuthType.bearer_token, identity=identity, entities=entities)
self.exp = exp
self.iat = iat
self.nbf = nbf
self._env = env or config.get('env', '<unknown>')
@property
def env(self):
return self._env
@property
def exp(self):
return self._exp
@exp.setter
def exp(self, value):
self._exp = value
@property
def iat(self):
return self._iat
@iat.setter
def iat(self, value):
self._iat = value
@property
def nbf(self):
return self._nbf
@nbf.setter
def nbf(self, value):
self._nbf = value
def get_log_entry(self):
d = super(Token, self).get_log_entry()
d.update(iat=self.iat, exp=self.exp, env=self.env)
return d
def encode(self, **extra_payload):
payload = self.to_dict(**extra_payload)
return jwt.encode(payload, token_secret)
@classmethod
def decode(cls, encoded_token, verify=True):
return jwt.decode(encoded_token, token_secret, verify=verify)
@classmethod
def from_encoded_token(cls, encoded_token, verify=True):
decoded = cls.decode(encoded_token, verify=verify)
try:
token = Token.from_dict(decoded)
assert isinstance(token, Token)
if not token.identity:
raise errors.unauthorized.InvalidToken('token missing identity')
return token
except Exception as e:
raise errors.unauthorized.InvalidToken('failed parsing token, %s' % e.args[0])
@classmethod
def create_encoded_token(cls, identity, expiration_sec=None, entities=None, **extra_payload):
if identity.role not in (Role.system,):
# limit expiration time for all roles but an internal service
expiration_sec = expiration_sec or cls.default_expiration_sec
now = datetime.utcnow()
token = cls(
identity=identity,
entities=entities,
iat=now)
if expiration_sec:
# add 'expiration' claim
token.exp = now + timedelta(seconds=expiration_sec)
return token.encode(**extra_payload)
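A sketch of the expiration policy in `create_encoded_token`: every role except an internal system role gets an `exp` claim capped by the configured default. The `default_expiration_sec` value and the role names here are placeholders, and the JWT encoding step is omitted:

```python
from datetime import datetime, timedelta

default_expiration_sec = 3600  # assumed stand-in for apiserver.auth.default_expiration_sec


def build_claims(role, expiration_sec=None, system_roles=("system",)):
    """Non-system roles always get an 'exp' claim capped by the default."""
    if role not in system_roles:
        expiration_sec = expiration_sec or default_expiration_sec
    now = datetime.utcnow()
    claims = {"iat": now}
    if expiration_sec:
        claims["exp"] = now + timedelta(seconds=expiration_sec)
    return claims


user_claims = build_claims("user")      # expires after the default period
system_claims = build_claims("system")  # no 'exp': internal service token
```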


@@ -0,0 +1,35 @@
import random
sys_random = random.SystemRandom()
def get_random_string(length=12, allowed_chars='abcdefghijklmnopqrstuvwxyz'
'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'):
"""
Returns a securely generated random string.
The default length of 12 with the a-z, A-Z, 0-9 character set returns
a 71-bit value. log_2((26+26+10)^12) =~ 71 bits.
Taken from the django.utils.crypto module.
"""
return ''.join(sys_random.choice(allowed_chars) for _ in range(length))
def get_client_id(length=20):
"""
Create a random secret key.
Taken from the Django project.
"""
chars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'
return get_random_string(length, chars)
def get_secret_key(length=50):
"""
Create a random secret key.
Taken from the Django project.
"""
chars = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*(-_=+)'
return get_random_string(length, chars)
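The docstring's entropy claim can be checked directly: a 12-character string over a 62-character alphabet carries about 71 bits:

```python
import math


def entropy_bits(alphabet, length):
    """Bits of entropy for a uniformly random string over the given alphabet."""
    return length * math.log2(len(alphabet))


default_alphabet = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
default_bits = entropy_bits(default_alphabet, 12)  # ~71 bits, as the docstring states
```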


@@ -0,0 +1,7 @@
from semantic_version import Version
class PartialVersion(Version):
def __init__(self, version_string: str):
assert isinstance(version_string, str)
super().__init__(version_string, partial=True)


@@ -0,0 +1,118 @@
from typing import Callable, Sequence, Text
from jsonmodels import models
from jsonmodels.errors import FieldNotSupported
from schema import schema
from .apicall import APICall
from .base import PartialVersion
from .schema_validator import SchemaValidator
EndpointFunc = Callable[[APICall, Text, models.Base], None]
class Endpoint(object):
_endpoint_config_cache = {}
"""
Endpoints configuration cache, in the format of {full endpoint name: dict}
"""
def __init__(
self,
name: Text,
func: EndpointFunc,
min_version: Text = "1.0",
required_fields: Sequence[Text] = None,
request_data_model: models.Base = None,
response_data_model: models.Base = None,
validate_schema: bool = False,
):
"""
Endpoint configuration
:param name: full endpoint name
:param func: endpoint implementation
:param min_version: minimum supported version
:param required_fields: required request fields, can not be used with validate_schema
:param request_data_model: request data model (jsonmodels), validated when validate_schema=False
:param response_data_model: response data model (jsonmodels), validated when validate_schema=False
:param validate_schema: whether request and response schema should be validated
"""
super(Endpoint, self).__init__()
self.name = name
self.min_version = PartialVersion(min_version)
self.func = func
self.required_fields = required_fields
self.request_data_model = request_data_model
self.response_data_model = response_data_model
service, _, endpoint_name = self.name.partition(".")
try:
self.endpoint_group = schema.services[service].endpoint_groups[
endpoint_name
]
except KeyError:
raise RuntimeError(
f"schema for endpoint {service}.{endpoint_name} not found"
)
if validate_schema:
if self.required_fields:
raise ValueError(
f"endpoint {self.name}: can not use 'required_fields' with 'validate_schema'"
)
endpoint = self.endpoint_group.get_for_version(self.min_version)
request_schema = endpoint.request_schema
response_schema = endpoint.response_schema
else:
request_schema = None
response_schema = None
self.request_schema_validator = SchemaValidator(request_schema)
self.response_schema_validator = SchemaValidator(response_schema)
def __repr__(self):
return f"{type(self).__name__}<{self.name}>"
def to_dict(self):
"""
Used by `server.endpoints` endpoint.
Provides endpoint information and schemas on a best-effort basis.
"""
d = {
"min_version": self.min_version,
"required_fields": self.required_fields,
"request_data_model": None,
"response_data_model": None,
}
def safe_to_json_schema(data_model: models.Base):
"""
Provides the data_model's JSON schema if available
"""
try:
return data_model.to_json_schema()
except (FieldNotSupported, TypeError):
return str(data_model.__name__)
if self.request_data_model:
d["request_data_model"] = safe_to_json_schema(self.request_data_model)
if self.response_data_model:
d["response_data_model"] = safe_to_json_schema(self.response_data_model)
if self.request_schema_validator.enabled:
d["request_schema"] = self.request_schema_validator.schema
if self.response_schema_validator.enabled:
d["response_schema"] = self.response_schema_validator.schema
return d
@property
def authorize(self):
return self.endpoint_group.authorize
@property
def allow_roles(self):
return self.endpoint_group.allow_roles
def allows(self, role):
return self.endpoint_group.allows(role)
@property
def is_internal(self):
return self.endpoint_group.internal


@@ -0,0 +1,19 @@
class PathParsingError(Exception):
def __init__(self, msg):
super(PathParsingError, self).__init__(msg)
class MalformedPathError(PathParsingError):
pass
class InvalidVersionError(PathParsingError):
pass
class CallParsingError(Exception):
pass
class CallFailedError(Exception):
pass


@@ -0,0 +1,72 @@
import sys
from typing import Optional, Callable
import attr
import fastjsonschema
import jsonschema
from boltons.iterutils import remap
from apierrors import errors
from config import config
log = config.logger(__file__)
@attr.s(auto_attribs=True, auto_exc=True)
class FastValidationError(Exception):
error: fastjsonschema.JsonSchemaException
data: dict
class SchemaValidator:
def __init__(self, schema: Optional[dict]):
"""
Utility for different schema validation strategies
:param schema: jsonschema to validate against
"""
self.schema = schema
self.validator: Callable = schema and fastjsonschema.compile(schema)
@property
def enabled(self) -> bool:
return self.schema is not None
def fast_validate(self, data: dict) -> None:
"""
Perform a quick validate with laconic error messages
:param data: data to validate
:raises: fastjsonschema.JsonSchemaException
"""
if self.enabled and data is not None:
data = remap(data, lambda path, key, value: value is not None)
try:
self.validator(data)
except fastjsonschema.JsonSchemaException as e:
raise FastValidationError(e, data) from e
def detailed_validate(self, data: dict) -> None:
"""
Perform a slow validate with detailed error messages
:param data: data to validate
:raises: errors.bad_request.ValidationError
"""
try:
self.fast_validate(data)
except FastValidationError as error:
_, _, traceback = sys.exc_info()
try:
jsonschema.validate(error.data, self.schema)
except jsonschema.exceptions.ValidationError as detailed_error:
raise errors.bad_request.ValidationError(
message=detailed_error.message,
path=list(detailed_error.path),
context=detailed_error.context,
cause=detailed_error.cause,
validator=detailed_error.validator,
validator_value=detailed_error.validator_value,
instance=detailed_error.instance,
parent=detailed_error.parent,
)
else:
log.error("fast validation failed while detailed validation succeeded")
raise error.error.with_traceback(traceback)
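The `remap` call strips `None` values from the request data before fast validation. A dependency-free equivalent of that visitor, shown on a hypothetical payload:

```python
def drop_none(value):
    """Recursively drop None values from nested dicts and lists,
    equivalent in spirit to the boltons.iterutils.remap call above."""
    if isinstance(value, dict):
        return {k: drop_none(v) for k, v in value.items() if v is not None}
    if isinstance(value, list):
        return [drop_none(v) for v in value if v is not None]
    return value


cleaned = drop_none({"name": "exp1", "parent": None, "tags": ["a", None]})
```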


@@ -0,0 +1,289 @@
import re
from importlib import import_module
from itertools import chain
from typing import cast, Iterable, List, MutableMapping
import jsonmodels.models
from pathlib import Path
import timing_context
from apierrors import APIError
from apierrors.errors.bad_request import RequestPathHasInvalidVersion
from config import config
from service_repo.base import PartialVersion
from .apicall import APICall
from .endpoint import Endpoint
from .errors import MalformedPathError, InvalidVersionError, CallFailedError
from .util import parse_return_stack_on_code
from .validators import validate_all
log = config.logger(__file__)
class ServiceRepo(object):
_endpoints: MutableMapping[str, List[Endpoint]] = {}
"""
Registered endpoints, in the format of {endpoint_name: [Endpoint]};
each list of endpoints is sorted by min_version in descending order
"""
_version_required = config.get("apiserver.version.required")
""" If version is required, parsing will fail for endpoint paths that do not contain a valid version """
_max_version = PartialVersion("2.1")
""" Maximum version number (the highest min_version value across all endpoints) """
_endpoint_exp = (
re.compile(
r"^/?v(?P<endpoint_version>\d+\.?\d+)/(?P<endpoint_name>[a-zA-Z_]\w+\.[a-zA-Z_]\w+)/?$"
)
if config.get("apiserver.version.required")
else re.compile(
r"^/?(v(?P<endpoint_version>\d+\.?\d+)/)?(?P<endpoint_name>[a-zA-Z_]\w+\.[a-zA-Z_]\w+)/?$"
)
)
"""
Endpoint structure expressions. We have two expressions, one with an optional version part.
Constraints for the first (strict) expression:
1. May start with a leading '/'
2. Followed by a version number (int or float) preceded by a leading 'v'
3. Followed by a '/'
4. Followed by a service name, which must start with an english letter (lower or upper case) or underscore,
and followed by any number of alphanumeric or underscore characters
5. Followed by a '.'
6. Followed by an action name, which must start with an english letter (lower or upper case) or underscore,
and followed by any number of alphanumeric or underscore characters
7. May end with a trailing '/'
The second (optional version) expression does not require steps 2 and 3.
"""
_return_stack = config.get("apiserver.return_stack")
""" return stack trace on error """
_return_stack_on_code = parse_return_stack_on_code(
config.get("apiserver.return_stack_on_code", {})
)
""" if 'return_stack' is true and error contains a return code, return stack trace only for these error codes """
_credentials = config["secure.credentials.apiserver"]
""" Api Server credentials used for intra-service communication """
_token = None
""" Token for internal calls """
@classmethod
def load(cls, root_module="services"):
root_module = Path(root_module)
sub_module = None
for sub_module in root_module.glob("*"):
if (
sub_module.is_file()
and sub_module.suffix == ".py"
and not sub_module.stem == "__init__"
):
import_module(f"{root_module.stem}.{sub_module.stem}")
if sub_module.is_dir():
import_module(f"{root_module.stem}.{sub_module.stem}")
# leave no trace of the 'sub_module' local
del sub_module
cls._max_version = max(
cls._max_version,
max(
ep.min_version
for ep in cast(Iterable[Endpoint], chain(*cls._endpoints.values()))
),
)
@classmethod
def register(cls, endpoint):
assert isinstance(endpoint, Endpoint)
if cls._endpoints.get(endpoint.name):
if any(
ep.min_version == endpoint.min_version
for ep in cls._endpoints[endpoint.name]
):
raise Exception(
f"Trying to register an existing endpoint. name={endpoint.name}, version={endpoint.min_version}"
)
else:
cls._endpoints[endpoint.name].append(endpoint)
else:
cls._endpoints[endpoint.name] = [endpoint]
cls._endpoints[endpoint.name].sort(key=lambda ep: ep.min_version, reverse=True)
@classmethod
def endpoint_names(cls):
return sorted(cls._endpoints.keys())
@classmethod
def endpoints_summary(cls):
return {
"endpoints": {
name: list(map(Endpoint.to_dict, eps))
for name, eps in cls._endpoints.items()
},
"models": {},
}
@classmethod
def max_endpoint_version(cls) -> PartialVersion:
return cls._max_version
@classmethod
def _get_endpoint(cls, name, version):
versions = cls._endpoints.get(name)
if not versions:
return None
try:
return next(ep for ep in versions if ep.min_version <= version)
except StopIteration:
# no appropriate version found
return None
@classmethod
def _resolve_endpoint_from_call(cls, call):
assert isinstance(call, APICall)
endpoint = cls._get_endpoint(
call.endpoint_name, call.requested_endpoint_version
)
if endpoint is None:
call.log_api = False
call.set_error_result(
msg=(
f"Unable to find endpoint for name {call.endpoint_name} "
f"and version {call.requested_endpoint_version}"
),
code=404,
subcode=0,
)
return
assert isinstance(endpoint, Endpoint)
call.actual_endpoint_version: PartialVersion = endpoint.min_version
call.requires_authorization = endpoint.authorize
return endpoint
@classmethod
def parse_endpoint_path(cls, path):
""" Parse endpoint version, service and action from request path. """
m = cls._endpoint_exp.match(path)
if not m:
raise MalformedPathError("Invalid request path %s" % path)
endpoint_name = m.group("endpoint_name")
version = m.group("endpoint_version")
if version is None:
# No version was specified in the path, use the max supported version
version = cls._max_version
else:
try:
version = PartialVersion(version)
except ValueError as e:
raise RequestPathHasInvalidVersion(version=version, reason=e)
if version > cls._max_version:
raise InvalidVersionError(
f"Invalid API version (max. supported version is {cls._max_version})"
)
return version, endpoint_name
@classmethod
def _should_return_stack(cls, code, subcode):
if not cls._return_stack or code not in cls._return_stack_on_code:
return False
if subcode is None:
# Code in dict, but no subcode. We'll allow it.
return True
subcode_list = cls._return_stack_on_code.get(code)
if subcode_list is None:
# if the code is there but we don't have any subcode list, always return stack
return True
return subcode in subcode_list
@classmethod
def _validate_call(cls, call):
endpoint = cls._resolve_endpoint_from_call(call)
if call.failed:
return
validate_all(call, endpoint)
return endpoint
@classmethod
def validate_call(cls, call):
cls._validate_call(call)
@classmethod
def _get_company(cls, call, endpoint=None, ignore_error=False):
authorize = endpoint and endpoint.authorize
if ignore_error or not authorize:
try:
return call.identity.company
except Exception:
return None
return call.identity.company
@classmethod
def handle_call(cls, call):
try:
assert isinstance(call, APICall)
if call.failed:
raise CallFailedError()
endpoint = cls._resolve_endpoint_from_call(call)
if call.failed:
raise CallFailedError()
with timing_context.TimingContext("service_repo", "validate_call"):
validate_all(call, endpoint)
if call.failed:
raise CallFailedError()
# In case call does not require authorization, parsing the identity.company might raise an exception
company = cls._get_company(call, endpoint)
ret = endpoint.func(call, company, call.data_model)
# allow endpoints to return dict or model (instead of setting them explicitly on the call)
if ret is not None:
if isinstance(ret, jsonmodels.models.Base):
call.result.data_model = ret
elif isinstance(ret, dict):
call.result.data = ret
except APIError as ex:
# report stack trace only for error codes configured in return_stack_on_code
include_stack = cls._return_stack and cls._should_return_stack(
ex.code, ex.subcode
)
call.set_error_result(
code=ex.code,
subcode=ex.subcode,
msg=str(ex),
include_stack=include_stack,
)
except CallFailedError:
# Do nothing, let 'finally' wrap up
pass
except Exception as ex:
log.exception(ex)
call.set_error_result(
code=500, subcode=0, msg=str(ex), include_stack=cls._return_stack
)
finally:
content, content_type = call.get_response()
call.mark_end()
console_msg = f"Returned {call.result.code} for {call.endpoint_name} in {call.duration}ms"
if call.result.code < 300:
log.info(console_msg)
else:
console_msg = f"{console_msg}, msg={call.result.msg}"
if call.result.code < 500:
log.warning(console_msg)
else:
log.error(console_msg)
return content, content_type
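The strict endpoint-path expression can be exercised directly. Note that the version group `\d+\.?\d+` requires at least two digits, so a single-digit version such as `v2` does not match it:

```python
import re

# The strict (version-required) endpoint path expression from ServiceRepo
endpoint_exp = re.compile(
    r"^/?v(?P<endpoint_version>\d+\.?\d+)/(?P<endpoint_name>[a-zA-Z_]\w+\.[a-zA-Z_]\w+)/?$"
)

m = endpoint_exp.match("/v2.1/tasks.get_all")
version, name = m.group("endpoint_version"), m.group("endpoint_name")

# a bare single-digit version does not satisfy the two-digit requirement
no_match = endpoint_exp.match("/v2/tasks.get_all")
```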


@@ -0,0 +1,47 @@
import socket
import six
def get_local_addr():
""" Get the local IP address (that isn't localhost) """
_, _, ipaddrlist = socket.gethostbyname_ex(socket.gethostname())
try:
return next(ip for ip in ipaddrlist if ip not in ('127.0.0.1',))
except StopIteration:
raise ValueError('Cannot find non-loopback ip address for this server (received %s)' % ', '.join(ipaddrlist))
def resolve_addr(addr):
""" Resolve address (IP string of host name) into an IP string. """
try:
socket.inet_aton(addr)
return addr
except socket.error:
try:
return socket.gethostbyname(addr)
except socket.error:
pass
def parse_return_stack_on_code(codes):
assert isinstance(codes, list), "return_stack_on_code must be a list"
def parse(e):
if isinstance(e, six.integer_types):
code, subcodes = e, None
elif isinstance(e, (list, tuple)):
code, subcodes = e[:2]
assert isinstance(code, six.integer_types), "return_stack_on_code/code must be int"
if isinstance(subcodes, six.integer_types):
subcodes = [subcodes]
if isinstance(subcodes, (list, tuple)):
assert all(isinstance(x, six.integer_types) for x in subcodes),\
"return_stack_on_code/subcode must be list(int)"
else:
raise ValueError("invalid return_stack_on_code/subcode(s): %s" % subcodes)
else:
raise ValueError("invalid return_stack_on_code/subcode(s): %s" % e)
return code, subcodes
return dict(map(parse, codes))
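A sketch of `parse_return_stack_on_code` using plain `int` checks in place of `six.integer_types`, applied to a hypothetical `return_stack_on_code` config value:

```python
def parse_return_stack_on_code(codes):
    """Sketch of the config parser: a bare code means 'any subcode' (None),
    a (code, subcode) or (code, [subcodes]) pair restricts to those subcodes."""
    def parse(e):
        if isinstance(e, int):
            return e, None
        code, subcodes = e[:2]
        if isinstance(subcodes, int):
            subcodes = [subcodes]
        return code, list(subcodes)
    return dict(map(parse, codes))


# e.g. apiserver.return_stack_on_code: [500, (400, 10), (401, [20, 30])]
parsed = parse_return_stack_on_code([500, (400, 10), (401, [20, 30])])
```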


@@ -0,0 +1,171 @@
import fastjsonschema
import jsonmodels.errors
from apierrors import errors, APIError
from config import config
from database.model import Company
from database.model.auth import Role, User
from service_repo import APICall
from service_repo.apicall import MissingIdentity
from service_repo.endpoint import Endpoint
from .auth import get_auth_func, Identity, authorize_impersonation, Payload
from .errors import CallParsingError
log = config.logger(__file__)
def validate_all(call: APICall, endpoint: Endpoint):
""" Perform all required call/endpoint validation, update call result appropriately """
try:
validate_auth(endpoint, call)
validate_role(endpoint, call)
if validate_impersonation(endpoint, call):
# if impersonating, validate role again
validate_role(endpoint, call)
# todo: remove validate_required_fields once all endpoints have json schema
validate_required_fields(endpoint, call)
# set models. models will be validated automatically
call.schema_validator = endpoint.request_schema_validator
if endpoint.request_data_model:
call.data_model_cls = endpoint.request_data_model
call.result.schema_validator = endpoint.response_schema_validator
if endpoint.response_data_model:
call.result.data_model_cls = endpoint.response_data_model
return True
except CallParsingError as ex:
raise errors.bad_request.ValidationError(str(ex))
except jsonmodels.errors.ValidationError as ex:
raise errors.bad_request.ValidationError(
" ".join(map(str.lower, map(str, ex.args)))
)
except fastjsonschema.exceptions.JsonSchemaException as ex:
log.exception(f"{endpoint.name}: fastjsonschema exception")
raise errors.bad_request.ValidationError(ex.args[0])
def validate_role(endpoint, call):
try:
if not endpoint.allows(call.identity.role):
raise errors.forbidden.RoleNotAllowed(role=call.identity.role, allowed=endpoint.allow_roles)
except MissingIdentity:
pass
def validate_auth(endpoint, call):
""" Validate authorization for this endpoint and call.
If authentication has occurred, the call is updated with the authentication results.
"""
if not call.authorization:
# No auth data. Invalid if we need to authorize and valid otherwise
if endpoint.authorize:
raise errors.unauthorized.NoCredentials()
return
# prepare arguments for validation
service, _, action = endpoint.name.partition(".")
# If we have auth data, we'll try to validate anyway (just so we'll have auth-based permissions whenever possible,
# even if endpoint did not require authorization)
try:
auth = call.authorization or ""
auth_type, _, auth_data = auth.partition(" ")
authorize_func = get_auth_func(auth_type)
call.auth = authorize_func(auth_data, service, action, call.batched_data)
except Exception:
if endpoint.authorize:
# if endpoint requires authorization, re-raise exception
raise
def validate_impersonation(endpoint, call):
""" Validate impersonation headers and set impersonated identity and authorization data accordingly.
:returns: True if impersonating, False otherwise
"""
try:
act_as = call.act_as
impersonate_as = call.impersonate_as
if not impersonate_as and not act_as:
return False
elif impersonate_as and act_as:
raise errors.bad_request.InvalidHeaders(
"only one allowed", headers=tuple(call.impersonation_headers.keys())
)
identity = call.auth.identity
# verify this user is allowed to impersonate at all
if identity.role not in Role.get_system_roles() | {Role.admin}:
raise errors.bad_request.ImpersonationError(
"impersonation not allowed", role=identity.role
)
# get the impersonated user's info
user_id = act_as or impersonate_as
if identity.role in [Role.root]:
# only root is allowed to impersonate users in other companies
query = dict(id=user_id)
else:
query = dict(id=user_id, company=identity.company)
user = User.objects(**query).first()
if not user:
raise errors.bad_request.ImpersonationError("unknown user", **query)
company = Company.objects(id=user.company).only("name").first()
if not company:
query.update(company=user.company)
raise errors.bad_request.ImpersonationError("unknown company for user", **query)
# create impersonation payload
if act_as:
# act as a user, using your own role and permissions
call.impersonation = Payload(
auth_type=None,
identity=Identity(
user=user.id,
company=user.company,
role=identity.role,
user_name=f"{identity.user_name} (acting as {user.name})",
company_name=company.name,
),
)
elif impersonate_as:
# impersonate as a user, using their own identity and permissions (requires additional validation to verify
# impersonated user is allowed to access the endpoint)
service, _, action = endpoint.name.partition(".")
call.impersonation = authorize_impersonation(
user=user,
identity=Identity(
user=user.id,
company=user.company,
role=user.role,
user_name=f"{user.name} (impersonated by {identity.user_name})",
company_name=company.name,
),
service=service,
action=action,
call_data_items=call.batched_data,
)
else:
return False
return True
except APIError:
raise
except Exception:
raise errors.server_error.InternalError("validating impersonation")
def validate_required_fields(endpoint, call):
if endpoint.required_fields is None:
return
missing = [val for val in endpoint.required_fields if val not in call.data]
if missing:
raise errors.bad_request.MissingRequiredFields(missing=missing)
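The header checks at the top of `validate_impersonation` reduce to "at most one of act_as / impersonate_as". A minimal sketch, raising `ValueError` in place of `errors.bad_request.InvalidHeaders`:

```python
def check_impersonation_headers(act_as, impersonate_as):
    """The two impersonation headers are mutually exclusive."""
    if not act_as and not impersonate_as:
        return None  # no impersonation requested
    if act_as and impersonate_as:
        raise ValueError("only one allowed")
    return "act_as" if act_as else "impersonate_as"


none_given = check_impersonation_headers(None, None)
acting = check_impersonation_headers("u1", None)

both_given_raises = False
try:
    check_impersonation_headers("u1", "u2")
except ValueError:
    both_given_raises = True
```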

View File

166
server/services/auth.py Normal file

@@ -0,0 +1,166 @@
from apierrors import errors
from apimodels.auth import (
GetTokenResponse,
GetTokenForUserRequest,
GetTokenRequest,
ValidateTokenRequest,
ValidateResponse,
CreateUserRequest,
CreateUserResponse,
CreateCredentialsResponse,
GetCredentialsResponse,
RevokeCredentialsResponse,
CredentialsResponse,
RevokeCredentialsRequest,
EditUserReq,
)
from apimodels.base import UpdateResponse
from bll.auth import AuthBLL
from config import config
from database.errors import translate_errors_context
from database.model.auth import (
User,
)
from service_repo import APICall, endpoint
from service_repo.auth import Token
log = config.logger(__file__)
@endpoint(
"auth.login",
request_data_model=GetTokenRequest,
response_data_model=GetTokenResponse,
)
def login(call):
""" Generates a token based on the authenticated user (intended for use with credentials) """
assert isinstance(call, APICall)
call.result.data_model = AuthBLL.get_token_for_user(
user_id=call.identity.user,
company_id=call.identity.company,
expiration_sec=call.data_model.expiration_sec,
)
@endpoint(
"auth.get_token_for_user",
request_data_model=GetTokenForUserRequest,
response_data_model=GetTokenResponse,
)
def get_token_for_user(call):
""" Generates a token based on a requested user and company. INTERNAL. """
assert isinstance(call, APICall)
call.result.data_model = AuthBLL.get_token_for_user(
user_id=call.data_model.user,
company_id=call.data_model.company,
expiration_sec=call.data_model.expiration_sec,
)
@endpoint(
"auth.validate_token",
request_data_model=ValidateTokenRequest,
response_data_model=ValidateResponse,
)
def validate_token_endpoint(call):
""" Validate a token and return identity if valid. INTERNAL. """
assert isinstance(call, APICall)
try:
# if invalid, decoding will fail
token = Token.from_encoded_token(call.data_model.token)
call.result.data_model = ValidateResponse(
valid=True, user=token.identity.user, company=token.identity.company
)
except Exception as e:
call.result.data_model = ValidateResponse(valid=False, msg=e.args[0])
@endpoint(
"auth.create_user",
request_data_model=CreateUserRequest,
response_data_model=CreateUserResponse,
)
def create_user(call: APICall, _, request: CreateUserRequest):
""" Create a new user. INTERNAL. """
user_id = AuthBLL.create_user(request=request, call=call)
call.result.data_model = CreateUserResponse(id=user_id)
@endpoint("auth.create_credentials", response_data_model=CreateCredentialsResponse)
def create_credentials(call: APICall, _, __):
credentials = AuthBLL.create_credentials(
user_id=call.identity.user,
company_id=call.identity.company,
role=call.identity.role,
)
call.result.data_model = CreateCredentialsResponse(credentials=credentials)
@endpoint(
"auth.revoke_credentials",
request_data_model=RevokeCredentialsRequest,
response_data_model=RevokeCredentialsResponse,
)
def revoke_credentials(call):
assert isinstance(call, APICall)
identity = call.identity
access_key = call.data_model.access_key
with translate_errors_context():
query = dict(
id=identity.user, company=identity.company, credentials__key=access_key
)
updated = User.objects(**query).update_one(pull__credentials__key=access_key)
if not updated:
raise errors.bad_request.InvalidUser(
"invalid user or invalid access key", **query
)
call.result.data_model = RevokeCredentialsResponse(revoked=updated)
@endpoint("auth.get_credentials", response_data_model=GetCredentialsResponse)
def get_credentials(call):
""" Return the calling user's credential access keys (secrets are never returned). INTERNAL. """
assert isinstance(call, APICall)
identity = call.identity
with translate_errors_context():
query = dict(id=identity.user, company=identity.company)
user = User.objects(**query).first()
if not user:
raise errors.bad_request.InvalidUserId(**query)
# we return ONLY the key IDs, never the secrets (want a secret? create new credentials)
call.result.data_model = GetCredentialsResponse(
credentials=[
CredentialsResponse(access_key=c.key) for c in user.credentials
]
)
@endpoint(
"auth.edit_user", request_data_model=EditUserReq, response_data_model=UpdateResponse
)
def update(call, company_id, _):
assert isinstance(call, APICall)
fields = {
k: v
for k, v in call.data_model.to_struct().items()
if k != "user" and v is not None
}
with translate_errors_context():
result = User.objects(company=company_id, id=call.data_model.user).update(
**fields, full_result=True, upsert=False
)
if not result.matched_count:
raise errors.bad_request.InvalidUserId()
call.result.data_model = UpdateResponse(
updated=result.modified_count, fields=fields
)

server/services/events.py Normal file
@@ -0,0 +1,509 @@
import itertools
from collections import defaultdict
from operator import itemgetter
import six
from apierrors import errors
from bll.event import EventBLL
from bll.task import TaskBLL
from service_repo import APICall, endpoint
from utilities import json
task_bll = TaskBLL()
event_bll = EventBLL()
@endpoint("events.add")
def add(call, company_id, req_model):
assert isinstance(call, APICall)
added, batch_errors = event_bll.add_events(company_id, [call.data.copy()], call.worker)
call.result.data = dict(
added=added,
errors=len(batch_errors)
)
call.kpis["events"] = 1
@endpoint("events.add_batch")
def add_batch(call, company_id, req_model):
assert isinstance(call, APICall)
events = call.batched_data
if events is None or len(events) == 0:
raise errors.bad_request.BatchContainsNoItems()
added, batch_errors = event_bll.add_events(company_id, events, call.worker)
call.result.data = dict(
added=added,
errors=len(batch_errors)
)
call.kpis["events"] = len(events)
@endpoint("events.get_task_log", required_fields=["task"])
def get_task_log(call, company_id, req_model):
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id, allow_public=True)
order = call.data.get("order") or "desc"
scroll_id = call.data.get("scroll_id")
batch_size = int(call.data.get("batch_size") or 500)
events, scroll_id, total_events = event_bll.scroll_task_events(
company_id, task_id, order,
event_type="log",
batch_size=batch_size,
scroll_id=scroll_id)
call.result.data = dict(
events=events,
returned=len(events),
total=total_events,
scroll_id=scroll_id,
)
@endpoint("events.get_task_log", min_version="1.7", required_fields=["task"])
def get_task_log_v1_7(call, company_id, req_model):
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id, allow_public=True)
order = call.data.get("order") or "desc"
from_ = call.data.get("from") or "head"
scroll_id = call.data.get("scroll_id")
batch_size = int(call.data.get("batch_size") or 500)
scroll_order = 'asc' if (from_ == 'head') else 'desc'
events, scroll_id, total_events = event_bll.scroll_task_events(
company_id=company_id,
task_id=task_id,
order=scroll_order,
event_type="log",
batch_size=batch_size,
scroll_id=scroll_id
)
if scroll_order != order:
events = events[::-1]
call.result.data = dict(
events=events,
returned=len(events),
total=total_events,
scroll_id=scroll_id,
)
@endpoint('events.download_task_log', required_fields=['task'])
def download_task_log(call, company_id, req_model):
task_id = call.data['task']
task_bll.assert_exists(company_id, task_id, allow_public=True)
line_type = call.data.get('line_type', 'json').lower()
line_format = str(call.data.get('line_format', '{asctime} {worker} {level} {msg}'))
is_json = (line_type == 'json')
if not is_json:
if not line_format:
raise errors.bad_request.MissingRequiredFields('line_format is required for plain text lines')
# validate line format placeholders
valid_task_log_fields = {'asctime', 'timestamp', 'level', 'worker', 'msg'}
invalid_placeholders = set()
while True:
try:
line_format.format(**dict.fromkeys(valid_task_log_fields | invalid_placeholders))
break
except KeyError as e:
invalid_placeholders.add(e.args[0])
except Exception as e:
raise errors.bad_request.FieldsValueError('invalid line format', error=e.args[0])
if invalid_placeholders:
raise errors.bad_request.FieldsValueError(
'undefined placeholders in line format',
placeholders=invalid_placeholders
)
# make sure line_format has a trailing newline
line_format = line_format.rstrip('\n') + '\n'
def generate():
scroll_id = None
batch_size = 1000
while True:
log_events, scroll_id, _ = event_bll.scroll_task_events(
company_id,
task_id,
order="asc",
event_type="log",
batch_size=batch_size,
scroll_id=scroll_id
)
if not log_events:
break
for ev in log_events:
ev['asctime'] = ev.pop('@timestamp')
if is_json:
ev.pop('type')
ev.pop('task')
yield json.dumps(ev) + '\n'
else:
try:
yield line_format.format(**ev)
except KeyError as ex:
raise errors.bad_request.FieldsValueError(
'undefined placeholders in line format',
placeholders=[str(ex)]
)
if len(log_events) < batch_size:
break
call.result.filename = 'task_%s.log' % task_id
call.result.content_type = 'text/plain'
call.result.raw_data = generate()
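The placeholder validation in `download_task_log` repeatedly formats the template against the known field names and collects each `KeyError` until formatting succeeds, surfacing every undefined placeholder. A standalone sketch of the technique (illustrative, not part of the service code):

```python
def find_undefined_placeholders(line_format, valid_fields):
    """Collect placeholders in line_format that are not valid fields.

    Each failed .format() raises KeyError for one missing name; adding it
    to the trial mapping and retrying eventually surfaces all of them.
    """
    valid_fields = set(valid_fields)
    invalid = set()
    while True:
        try:
            line_format.format(**dict.fromkeys(valid_fields | invalid))
            return invalid
        except KeyError as e:
            invalid.add(e.args[0])

fmt = "{asctime} {worker} {level} {msg} {oops}"
print(find_undefined_placeholders(fmt, {"asctime", "worker", "level", "msg"}))
# -> {'oops'}
```

The loop always terminates because each iteration adds one new placeholder name to the trial mapping; positional placeholders like `{}` would raise `IndexError` instead, which the endpoint handles separately via its generic `Exception` branch.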
@endpoint("events.get_vector_metrics_and_variants", required_fields=["task"])
def get_vector_metrics_and_variants(call, company_id, req_model):
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id, allow_public=True)
call.result.data = dict(
metrics=event_bll.get_metrics_and_variants(company_id, task_id, "training_stats_vector")
)
@endpoint("events.get_scalar_metrics_and_variants", required_fields=["task"])
def get_scalar_metrics_and_variants(call, company_id, req_model):
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id, allow_public=True)
call.result.data = dict(
metrics=event_bll.get_metrics_and_variants(company_id, task_id, "training_stats_scalar")
)
# todo: !!! currently returning 10,000 records. should decide on a better way to control it
@endpoint("events.vector_metrics_iter_histogram", required_fields=["task", "metric", "variant"])
def vector_metrics_iter_histogram(call, company_id, req_model):
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id, allow_public=True)
metric = call.data["metric"]
variant = call.data["variant"]
iterations, vectors = event_bll.get_vector_metrics_per_iter(company_id, task_id, metric, variant)
call.result.data = dict(
metric=metric,
variant=variant,
vectors=vectors,
iterations=iterations
)
@endpoint("events.get_task_events", required_fields=["task"])
def get_task_events(call, company_id, req_model):
task_id = call.data["task"]
event_type = call.data.get("event_type")
scroll_id = call.data.get("scroll_id")
order = call.data.get("order") or "asc"
task_bll.assert_exists(company_id, task_id, allow_public=True)
result = event_bll.get_task_events(
company_id, task_id,
sort=[{"timestamp": {"order": order}}],
event_type=event_type,
scroll_id=scroll_id
)
call.result.data = dict(
events=result.events,
returned=len(result.events),
total=result.total_events,
scroll_id=result.next_scroll_id,
)
@endpoint("events.get_scalar_metric_data", required_fields=["task", "metric"])
def get_scalar_metric_data(call, company_id, req_model):
task_id = call.data["task"]
metric = call.data["metric"]
scroll_id = call.data.get("scroll_id")
task_bll.assert_exists(company_id, task_id, allow_public=True)
result = event_bll.get_task_events(
company_id, task_id,
event_type="training_stats_scalar",
sort=[{"iter": {"order": "desc"}}],
metric=metric,
scroll_id=scroll_id
)
call.result.data = dict(
events=result.events,
returned=len(result.events),
total=result.total_events,
scroll_id=result.next_scroll_id,
)
@endpoint("events.get_task_latest_scalar_values", required_fields=["task"])
def get_task_latest_scalar_values(call, company_id, req_model):
task_id = call.data["task"]
task = task_bll.assert_exists(company_id, task_id, allow_public=True)
metrics, last_timestamp = event_bll.get_task_latest_scalar_values(company_id, task_id)
es_index = EventBLL.get_index_name(company_id, "*")
last_iters = event_bll.get_last_iters(es_index, task_id, None, 1)
call.result.data = dict(
metrics=metrics,
last_iter=last_iters[0] if last_iters else 0,
name=task.name,
status=task.status,
last_timestamp=last_timestamp
)
# todo: should not repeat iter (x-axis) for each metric/variant, JS client should get raw data and fill gaps if needed
@endpoint("events.scalar_metrics_iter_histogram", required_fields=["task"])
def scalar_metrics_iter_histogram(call, company_id, req_model):
task_id = call.data["task"]
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
metrics = event_bll.get_scalar_metrics_average_per_iter(company_id, task_id)
call.result.data = metrics
@endpoint("events.multi_task_scalar_metrics_iter_histogram", required_fields=["tasks"])
def multi_task_scalar_metrics_iter_histogram(call, company_id, req_model):
task_ids = call.data["tasks"]
if isinstance(task_ids, six.string_types):
task_ids = [s.strip() for s in task_ids.split(",")]
# Note, bll already validates task ids as it needs their names
call.result.data = dict(
metrics=event_bll.compare_scalar_metrics_average_per_iter(company_id, task_ids, allow_public=True)
)
@endpoint("events.get_multi_task_plots", required_fields=["tasks"])
def get_multi_task_plots_v1_7(call, company_id, req_model):
task_ids = call.data["tasks"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
tasks = task_bll.assert_exists(
company_id=call.identity.company, only=('id', 'name'), task_ids=task_ids, allow_public=True
)
# Get the last 10K events by iteration and group them by unique metric+variant, returning the top events for each combination
result = event_bll.get_task_events(
company_id, task_ids,
event_type="plot",
sort=[{"iter": {"order": "desc"}}],
size=10000,
scroll_id=scroll_id
)
tasks = {t.id: t.name for t in tasks}
return_events = _get_top_iter_unique_events_per_task(result.events, max_iters=iters, tasks=tasks)
call.result.data = dict(
plots=return_events,
returned=len(return_events),
total=result.total_events,
scroll_id=result.next_scroll_id,
)
@endpoint("events.get_multi_task_plots", min_version="1.8", required_fields=["tasks"])
def get_multi_task_plots(call, company_id, req_model):
task_ids = call.data["tasks"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
tasks = task_bll.assert_exists(
company_id=call.identity.company, only=('id', 'name'), task_ids=task_ids, allow_public=True
)
result = event_bll.get_task_events(
company_id, task_ids,
event_type="plot",
sort=[{"iter": {"order": "desc"}}],
last_iter_count=iters,
scroll_id=scroll_id
)
tasks = {t.id: t.name for t in tasks}
return_events = _get_top_iter_unique_events_per_task(result.events, max_iters=iters, tasks=tasks)
call.result.data = dict(
plots=return_events,
returned=len(return_events),
total=result.total_events,
scroll_id=result.next_scroll_id,
)
@endpoint("events.get_task_plots", required_fields=["task"])
def get_task_plots_v1_7(call, company_id, req_model):
task_id = call.data["task"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
# events, next_scroll_id, total_events = event_bll.get_task_events(
# company_id, task_id,
# event_type="plot",
# sort=[{"iter": {"order": "desc"}}],
# last_iter_count=iters,
# scroll_id=scroll_id)
# the following is a hack for Bosch, requested by Moshik
# get the last 10K events by iteration and group them by unique metric+variant, returning the top events for each combination
result = event_bll.get_task_events(
company_id, task_id,
event_type="plot",
sort=[{"iter": {"order": "desc"}}],
size=10000,
scroll_id=scroll_id
)
return_events = _get_top_iter_unique_events(result.events, max_iters=iters)
call.result.data = dict(
plots=return_events,
returned=len(return_events),
total=result.total_events,
scroll_id=result.next_scroll_id,
)
@endpoint("events.get_task_plots", min_version="1.8", required_fields=["task"])
def get_task_plots(call, company_id, req_model):
task_id = call.data["task"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
result = event_bll.get_task_events(
company_id, task_id,
event_type="plot",
sort=[{"iter": {"order": "desc"}}],
last_iter_count=iters,
scroll_id=scroll_id
)
return_events = result.events
call.result.data = dict(
plots=return_events,
returned=len(return_events),
total=result.total_events,
scroll_id=result.next_scroll_id,
)
@endpoint("events.debug_images", required_fields=["task"])
def get_debug_images_v1_7(call, company_id, req_model):
task_id = call.data["task"]
iters = call.data.get("iters") or 1
scroll_id = call.data.get("scroll_id")
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
# events, next_scroll_id, total_events = event_bll.get_task_events(
# company_id, task_id,
# event_type="training_debug_image",
# sort=[{"iter": {"order": "desc"}}],
# last_iter_count=iters,
# scroll_id=scroll_id)
# the following is a hack for Bosch, requested by Moshik
# get the last 10K events by iteration and group them by unique metric+variant, returning the top events for each combination
result = event_bll.get_task_events(
company_id, task_id,
event_type="training_debug_image",
sort=[{"iter": {"order": "desc"}}],
size=10000,
scroll_id=scroll_id
)
return_events = _get_top_iter_unique_events(result.events, max_iters=iters)
call.result.data = dict(
task=task_id,
images=return_events,
returned=len(return_events),
total=result.total_events,
scroll_id=result.next_scroll_id,
)
@endpoint("events.debug_images", min_version="1.8", required_fields=["task"])
def get_debug_images(call, company_id, req_model):
task_id = call.data["task"]
iters = call.data.get("iters") or 1
scroll_id = call.data.get("scroll_id")
task_bll.assert_exists(call.identity.company, task_id, allow_public=True)
result = event_bll.get_task_events(
company_id, task_id,
event_type="training_debug_image",
sort=[{"iter": {"order": "desc"}}],
last_iter_count=iters,
scroll_id=scroll_id
)
return_events = result.events
call.result.data = dict(
task=task_id,
images=return_events,
returned=len(return_events),
total=result.total_events,
scroll_id=result.next_scroll_id,
)
@endpoint("events.delete_for_task", required_fields=["task"])
def delete_for_task(call, company_id, req_model):
task_id = call.data["task"]
task_bll.assert_exists(company_id, task_id)
call.result.data = dict(
deleted=event_bll.delete_task_events(company_id, task_id)
)
def _get_top_iter_unique_events_per_task(events, max_iters, tasks):
key = itemgetter('metric', 'variant', 'task', 'iter')
unique_events = itertools.chain.from_iterable(
itertools.islice(group, max_iters)
for _, group in itertools.groupby(sorted(events, key=key, reverse=True), key=key))
def collect(evs, fields):
if not fields:
evs = list(evs)
return {
'name': tasks.get(evs[0].get('task')),
'plots': evs
}
return {
str(k): collect(group, fields[1:])
for k, group in itertools.groupby(evs, key=itemgetter(fields[0]))
}
collect_fields = ('metric', 'variant', 'task', 'iter')
return collect(
sorted(unique_events, key=itemgetter(*collect_fields), reverse=True),
collect_fields
)
def _get_top_iter_unique_events(events, max_iters):
top_unique_events = defaultdict(lambda: [])
for e in events:
key = e.get("metric", "") + e.get("variant", "")
evs = top_unique_events[key]
if len(evs) < max_iters:
evs.append(e)
unique_events = list(itertools.chain.from_iterable(list(top_unique_events.values())))
unique_events.sort(key=lambda e: e["iter"], reverse=True)
return unique_events
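A small usage sketch of the per-series top-N selection that `_get_top_iter_unique_events` implements (hypothetical event dicts; assumes the input arrives sorted by iteration descending, as the Elasticsearch query's sort guarantees):

```python
from collections import defaultdict
import itertools

# events sorted by iter desc, as returned by the ES query
events = [
    {"metric": "loss", "variant": "train", "iter": 3},
    {"metric": "loss", "variant": "train", "iter": 2},
    {"metric": "loss", "variant": "train", "iter": 1},
    {"metric": "accuracy", "variant": "val", "iter": 3},
]

# keep at most max_iters=2 events per metric+variant series
max_iters = 2
top = defaultdict(list)
for e in events:
    key = e.get("metric", "") + e.get("variant", "")
    if len(top[key]) < max_iters:
        top[key].append(e)

# merge the per-series lists and re-sort by iteration, newest first
merged = list(itertools.chain.from_iterable(top.values()))
merged.sort(key=lambda e: e["iter"], reverse=True)
print([(e["metric"], e["iter"]) for e in merged])
# -> [('loss', 3), ('accuracy', 3), ('loss', 2)]
```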

server/services/models.py Normal file
@@ -0,0 +1,433 @@
from datetime import datetime
from urllib.parse import urlparse
from mongoengine import Q, EmbeddedDocument
import database
from apierrors import errors
from apimodels.base import UpdateResponse
from apimodels.models import (
CreateModelRequest,
CreateModelResponse,
PublishModelRequest,
PublishModelResponse,
ModelTaskPublishResponse,
)
from bll.task import TaskBLL
from config import config
from database.errors import translate_errors_context
from database.fields import SupportedURLField
from database.model import validate_id
from database.model.model import Model
from database.model.project import Project
from database.model.task.task import Task, TaskStatus
from database.utils import (
parse_from_call,
get_company_or_none_constraint,
filter_fields,
)
from service_repo import APICall, endpoint
from timing_context import TimingContext
log = config.logger(__file__)
get_all_query_options = Model.QueryParameterOptions(
pattern_fields=("name", "comment"),
fields=("ready",),
list_fields=("tags", "framework", "uri", "id", "project", "task", "parent"),
)
@endpoint("models.get_by_id", required_fields=["model"])
def get_by_id(call):
assert isinstance(call, APICall)
model_id = call.data["model"]
with translate_errors_context():
res = Model.get_many(
company=call.identity.company,
query_dict=call.data,
query=Q(id=model_id),
allow_public=True,
)
if not res:
raise errors.bad_request.InvalidModelId(
"no such public or company model",
id=model_id,
company=call.identity.company,
)
call.result.data = {"model": res[0]}
@endpoint("models.get_by_task_id", required_fields=["task"])
def get_by_task_id(call):
assert isinstance(call, APICall)
task_id = call.data["task"]
with translate_errors_context():
query = dict(id=task_id, company=call.identity.company)
res = Task.get(_only=["output"], **query)
if not res:
raise errors.bad_request.InvalidTaskId(**query)
if not res.output:
raise errors.bad_request.MissingTaskFields(field="output")
if not res.output.model:
raise errors.bad_request.MissingTaskFields(field="output.model")
model_id = res.output.model
res = Model.objects(
Q(id=model_id) & get_company_or_none_constraint(call.identity.company)
).first()
if not res:
raise errors.bad_request.InvalidModelId(
"no such public or company model",
id=model_id,
company=call.identity.company,
)
call.result.data = {"model": res.to_proper_dict()}
@endpoint("models.get_all_ex", required_fields=[])
def get_all_ex(call):
assert isinstance(call, APICall)
with translate_errors_context():
with TimingContext("mongo", "models_get_all_ex"):
models = Model.get_many_with_join(
company=call.identity.company,
query_dict=call.data,
allow_public=True,
query_options=get_all_query_options,
)
call.result.data = {"models": models}
@endpoint("models.get_all", required_fields=[])
def get_all(call):
assert isinstance(call, APICall)
with translate_errors_context():
with TimingContext("mongo", "models_get_all"):
models = Model.get_many(
company=call.identity.company,
parameters=call.data,
query_dict=call.data,
allow_public=True,
query_options=get_all_query_options,
)
call.result.data = {"models": models}
create_fields = {
"name": None,
"tags": list,
"task": Task,
"comment": None,
"uri": None,
"project": Project,
"parent": Model,
"framework": None,
"design": None,
"labels": dict,
"ready": None,
}
schemes = list(SupportedURLField.schemes)
def _validate_uri(uri):
parsed_uri = urlparse(uri)
if parsed_uri.scheme not in schemes:
raise errors.bad_request.InvalidModelUri("unsupported scheme", uri=uri)
elif not parsed_uri.path:
raise errors.bad_request.InvalidModelUri("missing path", uri=uri)
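`_validate_uri` leans on `urllib.parse.urlparse` for both checks; a self-contained sketch of the same validation (the scheme list here is assumed for illustration — the real one comes from `SupportedURLField.schemes`):

```python
from urllib.parse import urlparse

ALLOWED_SCHEMES = {"s3", "gs", "azure", "http", "https", "file"}  # assumed list

def check_model_uri(uri):
    """Reject URIs with an unknown scheme or no path component."""
    parsed = urlparse(uri)
    if parsed.scheme not in ALLOWED_SCHEMES:
        raise ValueError(f"unsupported scheme: {uri}")
    if not parsed.path:
        raise ValueError(f"missing path: {uri}")
    return parsed

print(check_model_uri("s3://my-bucket/models/model.pkl").path)  # -> /models/model.pkl
```

Note that `urlparse("s3://bucket")` yields an empty `path`, so a bare bucket URI fails the second check just as in the endpoint code.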
def parse_model_fields(call, valid_fields):
fields = parse_from_call(call.data, valid_fields, Model.get_fields())
tags = fields.get("tags")
if tags:
fields["tags"] = list(set(tags))
return fields
@endpoint("models.update_for_task", required_fields=["task"])
def update_for_task(call, company_id, _):
assert isinstance(call, APICall)
task_id = call.data["task"]
uri = call.data.get("uri")
iteration = call.data.get("iteration")
override_model_id = call.data.get("override_model_id")
if not (uri or override_model_id) or (uri and override_model_id):
raise errors.bad_request.MissingRequiredFields(
"exactly one field is required", fields=("uri", "override_model_id")
)
with translate_errors_context():
query = dict(id=task_id, company=company_id)
task = Task.get_for_writing(
id=task_id,
company=company_id,
_only=["output", "execution", "name", "status", "project"],
)
if not task:
raise errors.bad_request.InvalidTaskId(**query)
allowed_states = [TaskStatus.created, TaskStatus.in_progress]
if task.status not in allowed_states:
raise errors.bad_request.InvalidTaskStatus(
f"model can only be updated for tasks in the {allowed_states} states",
**query,
)
if override_model_id:
query = dict(company=company_id, id=override_model_id)
model = Model.objects(**query).first()
if not model:
raise errors.bad_request.InvalidModelId(**query)
else:
if "name" not in call.data:
# use task name if name not provided
call.data["name"] = task.name
if "comment" not in call.data:
call.data["comment"] = f"Created by task `{task.name}` ({task.id})"
if task.output and task.output.model:
# model exists, update
res = _update_model(call, model_id=task.output.model).to_struct()
res.update({"id": task.output.model, "created": False})
call.result.data = res
return
# new model, create
fields = parse_model_fields(call, create_fields)
# create and save model
model = Model(
id=database.utils.id(),
created=datetime.utcnow(),
user=call.identity.user,
company=company_id,
project=task.project,
framework=task.execution.framework,
parent=task.execution.model,
design=task.execution.model_desc,
labels=task.execution.model_labels,
ready=(task.status == TaskStatus.published),
**fields,
)
model.save()
TaskBLL.update_statistics(
task_id=task_id,
company_id=company_id,
last_iteration_max=iteration,
output__model=model.id,
)
call.result.data = {"id": model.id, "created": True}
@endpoint(
"models.create",
request_data_model=CreateModelRequest,
response_data_model=CreateModelResponse,
)
def create(call, company, req_model):
assert isinstance(call, APICall)
assert isinstance(req_model, CreateModelRequest)
identity = call.identity
if req_model.public:
company = ""
with translate_errors_context():
project = req_model.project
if project:
validate_id(Project, company=company, project=project)
uri = req_model.uri
if uri:
_validate_uri(uri)
task = req_model.task
req_data = req_model.to_struct()
if task:
validate_task(call, req_data)
fields = filter_fields(Model, req_data)
# create and save model
model = Model(
id=database.utils.id(),
user=identity.user,
company=company,
created=datetime.utcnow(),
**fields,
)
model.save()
call.result.data_model = CreateModelResponse(id=model.id, created=True)
def prepare_update_fields(call, fields):
fields = fields.copy()
if "uri" in fields:
_validate_uri(fields["uri"])
# clear UI cache if URI is provided (model updated)
fields["ui_cache"] = fields.pop("ui_cache", {})
if "task" in fields:
validate_task(call, fields)
return fields
def validate_task(call, fields):
Task.get_for_writing(company=call.identity.company, id=fields["task"], _only=["id"])
@endpoint("models.edit", required_fields=["model"], response_data_model=UpdateResponse)
def edit(call):
assert isinstance(call, APICall)
identity = call.identity
model_id = call.data["model"]
with translate_errors_context():
query = dict(id=model_id, company=identity.company)
model = Model.objects(**query).first()
if not model:
raise errors.bad_request.InvalidModelId(**query)
fields = parse_model_fields(call, create_fields)
fields = prepare_update_fields(call, fields)
for key in fields:
field = getattr(model, key, None)
value = fields[key]
if (
field
and isinstance(value, dict)
and isinstance(field, EmbeddedDocument)
):
d = field.to_mongo(use_db_field=False).to_dict()
d.update(value)
fields[key] = d
iteration = call.data.get("iteration")
task_id = model.task or fields.get('task')
if task_id and iteration is not None:
TaskBLL.update_statistics(
task_id=task_id,
company_id=identity.company,
last_iteration_max=iteration,
)
if fields:
updated = model.update(upsert=False, **fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
else:
call.result.data_model = UpdateResponse(updated=0)
def _update_model(call, model_id=None):
assert isinstance(call, APICall)
identity = call.identity
model_id = model_id or call.data["model"]
with translate_errors_context():
# get model by id
query = dict(id=model_id, company=identity.company)
model = Model.objects(**query).first()
if not model:
raise errors.bad_request.InvalidModelId(**query)
data = prepare_update_fields(call, call.data)
task_id = data.get("task")
iteration = data.get("iteration")
if task_id and iteration is not None:
TaskBLL.update_statistics(
task_id=task_id,
company_id=identity.company,
last_iteration_max=iteration,
)
updated_count, updated_fields = Model.safe_update(
call.identity.company, model.id, data
)
return UpdateResponse(updated=updated_count, fields=updated_fields)
@endpoint(
"models.update", required_fields=["model"], response_data_model=UpdateResponse
)
def update(call):
call.result.data_model = _update_model(call)
@endpoint(
"models.set_ready",
request_data_model=PublishModelRequest,
response_data_model=PublishModelResponse,
)
def set_ready(call: APICall, company, req_model: PublishModelRequest):
updated, published_task_data = TaskBLL.model_set_ready(
model_id=req_model.model,
company_id=company,
publish_task=req_model.publish_task,
force_publish_task=req_model.force_publish_task
)
call.result.data_model = PublishModelResponse(
updated=updated,
published_task=ModelTaskPublishResponse(
**published_task_data
) if published_task_data else None
)
@endpoint("models.delete", required_fields=["model"])
def delete(call):
assert isinstance(call, APICall)
identity = call.identity
model_id = call.data["model"]
force = call.data.get("force", False)
with translate_errors_context():
query = dict(id=model_id, company=identity.company)
model = Model.objects(**query).only("id", "task").first()
if not model:
raise errors.bad_request.InvalidModelId(**query)
deleted_model_id = f"__DELETED__{model_id}"
using_tasks = Task.objects(execution__model=model_id).only("id")
if using_tasks:
if not force:
raise errors.bad_request.ModelInUse(
"as execution model, use force=True to delete",
num_tasks=len(using_tasks),
)
# update deleted model id in using tasks
using_tasks.update(
execution__model=deleted_model_id, upsert=False, multi=True
)
if model.task:
task = Task.objects(id=model.task).first()
if task and task.status == TaskStatus.published:
if not force:
raise errors.bad_request.ModelCreatingTaskExists(
"and published, use force=True to delete", task=model.task
)
task.update(
output__model=deleted_model_id,
output__error=f"model deleted on {datetime.utcnow().isoformat()}",
upsert=False,
)
del_count = Model.objects(**query).delete()
call.result.data = dict(deleted=del_count > 0)

server/services/projects.py Normal file
@@ -0,0 +1,340 @@
from collections import defaultdict
from datetime import datetime
from itertools import groupby
from operator import itemgetter
import dpath
from mongoengine import Q
import database
from apierrors import errors
from apimodels.base import UpdateResponse
from bll.task import TaskBLL
from database.errors import translate_errors_context
from database.model.model import Model
from database.model.project import Project
from database.model.task.task import Task, TaskStatus, TaskVisibility
from database.utils import parse_from_call, get_options, get_company_or_none_constraint
from service_repo import APICall, endpoint
from timing_context import TimingContext
task_bll = TaskBLL()
archived_tasks_cond = {"$in": [TaskVisibility.archived.value, "$tags"]}
create_fields = {
"name": None,
"description": None,
"tags": list,
"default_output_destination": None,
}
get_all_query_options = Project.QueryParameterOptions(
pattern_fields=("name", "description"), list_fields=("tags", "id")
)
@endpoint("projects.get_by_id", required_fields=["project"])
def get_by_id(call):
assert isinstance(call, APICall)
project_id = call.data["project"]
with translate_errors_context():
with TimingContext("mongo", "projects_by_id"):
query = Q(id=project_id) & get_company_or_none_constraint(
call.identity.company
)
res = Project.objects(query).first()
if not res:
raise errors.bad_request.InvalidProjectId(id=project_id)
res = res.to_proper_dict()
call.result.data = {"project": res}
def make_projects_get_all_pipelines(project_ids, specific_state=None):
archived = TaskVisibility.archived.value
status_count_pipeline = [
# count tasks per project per status
{"$match": {"project": {"$in": project_ids}}},
# make sure tags is always an array (required by subsequent $in in archived_tasks_cond)
{
"$addFields": {
"tags": {
"$cond": {
"if": {"$ne": [{"$type": "$tags"}, "array"]},
"then": [],
"else": "$tags",
}
}
}
},
{
"$group": {
"_id": {
"project": "$project",
"status": "$status",
archived: archived_tasks_cond,
},
"count": {"$sum": 1},
}
},
# for each project, create a list of (status, count, archived)
{
"$group": {
"_id": "$_id.project",
"counts": {
"$push": {
"status": "$_id.status",
"count": "$count",
archived: "$_id.%s" % archived,
}
},
}
},
]
def runtime_subquery(additional_cond):
return {
# the sum of
"$sum": {
# for each task
"$cond": {
# if completed and started and completed > started
"if": {
"$and": [
"$started",
"$completed",
{"$gt": ["$completed", "$started"]},
additional_cond,
]
},
# then: floor((completed - started) / 1000)
"then": {
"$floor": {
"$divide": [
{"$subtract": ["$completed", "$started"]},
1000.0,
]
}
},
"else": 0,
}
}
}
group_step = {"_id": "$project"}
for state in TaskVisibility:
if specific_state and state != specific_state:
continue
if state == TaskVisibility.active:
group_step[state.value] = runtime_subquery({"$not": archived_tasks_cond})
elif state == TaskVisibility.archived:
group_step[state.value] = runtime_subquery(archived_tasks_cond)
runtime_pipeline = [
# only count run time for these types of tasks
{
"$match": {
"type": {"$in": ["training", "testing", "annotation"]},
"project": {"$in": project_ids},
}
},
{
# for each project
"$group": group_step
},
]
return status_count_pipeline, runtime_pipeline


@endpoint("projects.get_all_ex")
def get_all_ex(call):
assert isinstance(call, APICall)
include_stats = call.data.get("include_stats")
stats_for_state = call.data.get("stats_for_state", TaskVisibility.active.value)
if stats_for_state:
try:
specific_state = TaskVisibility(stats_for_state)
except ValueError:
raise errors.bad_request.FieldsValueError(stats_for_state=stats_for_state)
else:
specific_state = None
with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
res = Project.get_many_with_join(
company=call.identity.company,
query_dict=call.data,
query_options=get_all_query_options,
allow_public=True,
)
if not include_stats:
call.result.data = {"projects": res}
return
ids = [project["id"] for project in res]
status_count_pipeline, runtime_pipeline = make_projects_get_all_pipelines(
ids, specific_state=specific_state
)
        default_counts = dict.fromkeys(get_options(TaskStatus), 0)

        def set_default_count(entry):
            return dict(default_counts, **entry)

        status_count = defaultdict(dict)
key = itemgetter(TaskVisibility.archived.value)
for result in Task.objects.aggregate(*status_count_pipeline):
for k, group in groupby(sorted(result["counts"], key=key), key):
section = (
TaskVisibility.archived if k else TaskVisibility.active
).value
status_count[result["_id"]][section] = set_default_count(
{
count_entry["status"]: count_entry["count"]
for count_entry in group
}
)
runtime = {
result["_id"]: {k: v for k, v in result.items() if k != "_id"}
for result in Task.objects.aggregate(*runtime_pipeline)
}

        def safe_get(obj, path, default=None):
            """Return the value at `path` in `obj`, or `default` if missing."""
            try:
                return dpath.get(obj, path)
            except KeyError:
                return default

        def get_status_counts(project_id, section):
path = "/".join((project_id, section))
return {
"total_runtime": safe_get(runtime, path, 0),
"status_count": safe_get(status_count, path, default_counts),
}
report_for_states = [
s for s in TaskVisibility if not specific_state or specific_state == s
]
for project in res:
project["stats"] = {
task_state.value: get_status_counts(project["id"], task_state.value)
for task_state in report_for_states
}
call.result.data = {"projects": res}


@endpoint("projects.get_all")
def get_all(call):
assert isinstance(call, APICall)
with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
res = Project.get_many(
company=call.identity.company,
query_dict=call.data,
query_options=get_all_query_options,
parameters=call.data,
allow_public=True,
)
call.result.data = {"projects": res}


@endpoint("projects.create", required_fields=["name", "description"])
def create(call):
assert isinstance(call, APICall)
identity = call.identity
with translate_errors_context():
fields = parse_from_call(call.data, create_fields, Project.get_fields())
now = datetime.utcnow()
project = Project(
id=database.utils.id(),
user=identity.user,
company=identity.company,
created=now,
last_update=now,
**fields
)
with TimingContext("mongo", "projects_save"):
project.save()
call.result.data = {"id": project.id}


@endpoint(
"projects.update", required_fields=["project"], response_data_model=UpdateResponse
)
def update(call):
    """
    update
    :summary: Update project information.
        See `projects.create` for parameters.
    :return: updated - `int` - number of projects updated
             fields - `[string]` - updated fields
    """
assert isinstance(call, APICall)
project_id = call.data["project"]
with translate_errors_context():
project = Project.get_for_writing(company=call.identity.company, id=project_id)
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
fields = parse_from_call(
call.data, create_fields, Project.get_fields(), discard_none_values=False
)
fields["last_update"] = datetime.utcnow()
with TimingContext("mongo", "projects_update"):
updated = project.update(upsert=False, **fields)
call.result.data_model = UpdateResponse(updated=updated, fields=fields)


@endpoint("projects.delete", required_fields=["project"])
def delete(call):
assert isinstance(call, APICall)
project_id = call.data["project"]
force = call.data.get("force", False)
with translate_errors_context():
project = Project.get_for_writing(company=call.identity.company, id=project_id)
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
# NOTE: from this point on we'll use the project ID and won't check for company, since we assume we already
# have the correct project ID.
        # Find the tasks and models which belong to the project
        updated_count = 0
        for cls, error in (
            (Task, errors.bad_request.ProjectHasTasks),
            (Model, errors.bad_request.ProjectHasModels),
        ):
            res = cls.objects(
                project=project_id, tags__nin=[TaskVisibility.archived.value]
            ).only("id")
            if res and not force:
                raise error("use force=true to delete", id=project_id)
            # accumulate counts for both Task and Model; a plain assignment
            # would report only the last iteration's (Model) count
            updated_count += res.update(project=None)
project.delete()
call.result.data = {"deleted": 1, "disassociated_tasks": updated_count}


@endpoint("projects.get_unique_metric_variants")
def get_unique_metric_variants(call, company_id, req_model):
project_id = call.data.get("project")
metrics = task_bll.get_unique_metric_variants(
company_id, [project_id] if project_id else None
)
call.result.data = {"metrics": metrics}
