mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| | 5ce202cc99 | | |
| | d09528bc26 | | |
| | 42d2a41dbe | | |
| | 82be1840b0 | | |
| | 27352c5cb6 | | |
| | 1ea6408d41 | | |
| | 5e095af3aa | | |
| | ab3dceed92 | | |
| | 3bf5126d84 | | |
| | ab2ab7b23a | | |
| | c9184d125b | | |
| | 0c0fdb72b9 | | |
| | 86378053d4 | | |
| | b1cbba0cf1 | | |
| | f31526042d | | |
| | 3f8d5bc346 | | |
35
README.md
@@ -1,6 +1,6 @@
# Trains Server

## Auto-Magical Experiment Manager & Version Control for AI
## Auto-Magical Experiment Manager & Version Control for AI - ε Devops Included!

[](https://img.shields.io/badge/license-SSPL-green.svg)
[](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)
@@ -9,6 +9,8 @@

### Help improve Trains by filling out our 2-min [user survey](https://allegro.ai/lp/trains-user-survey/)

## :rocket: Trains-Agent Services is now included. For more information, see [services](https://github.com/allegroai/trains-server#services)

## Introduction

The **trains-server** is the backend service infrastructure for [Trains](https://github.com/allegroai/trains).
@@ -98,6 +100,26 @@ you can [use](https://github.com/allegroai/trains#using-trains) **Trains** in yo
for example http://localhost:8080.
For more information about the Trains client, see [**Trains**](https://github.com/allegroai/trains).

## Trains-Agent Services <a name="services"></a>

As of version 0.15 of **trains-server**, dockerized deployment includes a **Trains-Agent Services** container running as
part of the docker container collection.

Trains-Agent Services is an extension of Trains-Agent that provides the ability to launch long-lasting jobs
that previously had to be executed on local / dedicated machines. It allows a single agent to
launch multiple dockers (Tasks) for different use cases. Example use cases include an auto-scaler service (spinning up instances
when the need arises and the budget allows), Controllers (implementing pipelines and more sophisticated DevOps logic),
Optimizers (such as hyper-parameter optimization or sweeping), and Applications (such as interactive Bokeh apps for
increased data transparency).

The Trains-Agent Services container will spin up **any** task enqueued into the dedicated `services` queue.
Every task launched by Trains-Agent Services will be registered as a new node in the system,
providing tracking and transparency capabilities.
You can also run the Trains-Agent Services manually; see details in [trains-agent services mode](https://github.com/allegroai/trains-agent#trains-agent-services-mode-).

**Note**: It is the user's responsibility to make sure the proper tasks are pushed into the `services` queue.
Do not enqueue training / inference tasks into the `services` queue, as this will put unnecessary load on the server.

## Advanced Functionality

**trains-server** provides a few additional useful features, which can be manually enabled:
@@ -152,6 +174,17 @@ To upgrade your existing **trains-server** deployment:
curl https://raw.githubusercontent.com/allegroai/trains-server/master/docker-compose.yml -o docker-compose.yml
```

1. Configure the Trains-Agent Services (not supported on Windows installations).
If `TRAINS_HOST_IP` is not provided, Trains-Agent Services will use the external
public address of the **trains-server**. If `TRAINS_AGENT_GIT_USER` / `TRAINS_AGENT_GIT_PASS` are not provided,
the Trains-Agent Services will not be able to access any private repositories for running service tasks.

```bash
export TRAINS_HOST_IP=server_host_ip_here
export TRAINS_AGENT_GIT_USER=git_username_here
export TRAINS_AGENT_GIT_PASS=git_password_here
```

1. Spin up the docker containers. This automatically pulls the latest **trains-server** build.
```bash
docker-compose -f docker-compose.yml pull
@@ -15,6 +15,8 @@ services:
|
||||
volumes:
|
||||
- /opt/trains/logs:/var/log/trains
|
||||
- /opt/trains/data/fileserver:/mnt/fileserver
|
||||
- /opt/trains/config:/opt/trains/config
|
||||
|
||||
depends_on:
|
||||
- redis
|
||||
- mongo
|
||||
|
||||
@@ -22,6 +22,7 @@ services:
|
||||
TRAINS_MONGODB_SERVICE_PORT: 27017
|
||||
TRAINS_REDIS_SERVICE_HOST: redis
|
||||
TRAINS_REDIS_SERVICE_PORT: 6379
|
||||
TRAINS_SERVER_DEPLOYMENT_TYPE: ${TRAINS_SERVER_DEPLOYMENT_TYPE:-win10}
|
||||
TRAINS__apiserver__mongo__pre_populate__enabled: "true"
|
||||
TRAINS__apiserver__mongo__pre_populate__zip_file: "/opt/trains/db-pre-populate/export.zip"
|
||||
ports:
|
||||
@@ -75,6 +76,8 @@ services:
|
||||
volumes:
|
||||
- c:/opt/trains/logs:/var/log/trains
|
||||
- c:/opt/trains/data/fileserver:/mnt/fileserver
|
||||
- c:/opt/trains/config:/opt/trains/config
|
||||
|
||||
ports:
|
||||
- "8081:8081"
|
||||
|
||||
@@ -86,7 +89,8 @@ services:
|
||||
restart: unless-stopped
|
||||
command: --setParameter internalQueryExecMaxBlockingSortBytes=196100200
|
||||
volumes:
|
||||
- mongodata:/data
|
||||
- c:/opt/trains/data/mongo/db:/data/db
|
||||
- c:/opt/trains/data/mongo/configdb:/data/configdb
|
||||
ports:
|
||||
- "27017:27017"
|
||||
|
||||
@@ -117,6 +121,3 @@ services:
|
||||
networks:
|
||||
backend:
|
||||
driver: bridge
|
||||
|
||||
volumes:
|
||||
mongodata:
|
||||
|
||||
@@ -22,6 +22,7 @@ services:
|
||||
TRAINS_MONGODB_SERVICE_PORT: 27017
|
||||
TRAINS_REDIS_SERVICE_HOST: redis
|
||||
TRAINS_REDIS_SERVICE_PORT: 6379
|
||||
TRAINS_SERVER_DEPLOYMENT_TYPE: ${TRAINS_SERVER_DEPLOYMENT_TYPE:-linux}
|
||||
TRAINS__apiserver__mongo__pre_populate__enabled: "true"
|
||||
TRAINS__apiserver__mongo__pre_populate__zip_file: "/opt/trains/db-pre-populate/export.zip"
|
||||
ports:
|
||||
@@ -75,6 +76,7 @@ services:
|
||||
volumes:
|
||||
- /opt/trains/logs:/var/log/trains
|
||||
- /opt/trains/data/fileserver:/mnt/fileserver
|
||||
- /opt/trains/config:/opt/trains/config
|
||||
ports:
|
||||
- "8081:8081"
|
||||
|
||||
@@ -108,14 +110,14 @@ services:
|
||||
container_name: trains-webserver
|
||||
image: allegroai/trains:latest
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- /opt/trains/logs:/var/log/trains
|
||||
depends_on:
|
||||
- apiserver
|
||||
ports:
|
||||
- "8080:80"
|
||||
|
||||
agent-services:
|
||||
networks:
|
||||
- backend
|
||||
container_name: trains-agent-services
|
||||
image: allegroai/trains-agent-services:latest
|
||||
restart: unless-stopped
|
||||
@@ -123,7 +125,7 @@ services:
|
||||
environment:
|
||||
TRAINS_HOST_IP: ${TRAINS_HOST_IP}
|
||||
TRAINS_WEB_HOST: ${TRAINS_WEB_HOST:-}
|
||||
TRAINS_API_HOST: ${TRAINS_API_HOST:-}
|
||||
TRAINS_API_HOST: http://apiserver:8008
|
||||
TRAINS_FILES_HOST: ${TRAINS_FILES_HOST:-}
|
||||
TRAINS_API_ACCESS_KEY: ${TRAINS_API_ACCESS_KEY:-}
|
||||
TRAINS_API_SECRET_KEY: ${TRAINS_API_SECRET_KEY:-}
|
||||
|
||||
@@ -50,45 +50,64 @@ To upgrade the AMI:
|
||||
|
||||
The following sections contain lists of AMI Image IDs, per region, for each released **trains-server** version.
|
||||
|
||||
### Latest version AMI - v0.15.0 (auto update)<a name="autoupdate"></a>
|
||||
### Latest version AMI - v0.15.1 (auto update)<a name="autoupdate"></a>
|
||||
|
||||
For easier upgrades, the following AMIs automatically update to the latest release every reboot:
|
||||
|
||||
* **eu-north-1** : ami-0a05eb5b384a84609
|
||||
* **ap-south-1** : ami-00f190b50e60b1eb5
|
||||
* **eu-west-3** : ami-044fad585e1d1798e
|
||||
* **eu-west-2** : ami-04ab930416a4af8c5
|
||||
* **eu-west-1** : ami-00c022f333417e78e
|
||||
* **ap-northeast-2** : ami-0c436e94f461a9a22
|
||||
* **ap-northeast-1** : ami-018e761ad0009d5d4
|
||||
* **sa-east-1** : ami-0b6c0e8e93b6ebbdd
|
||||
* **ca-central-1** : ami-0cf12aab70c14237d
|
||||
* **ap-southeast-1** : ami-0fe7840b9bde05581
|
||||
* **ap-southeast-2** : ami-00f230e86e1afda91
|
||||
* **eu-central-1** : ami-0635d13b79f76e04f
|
||||
* **us-east-2** : ami-0b323078d0206db0e
|
||||
* **us-west-1** : ami-07fdc1d461906f957
|
||||
* **us-west-2** : ami-0a5cac167c3ebdedb
|
||||
* **us-east-1** : ami-0d03956bea3aa5a44
|
||||
* **eu-north-1** : ami-0f63429f8e5d57315
|
||||
* **ap-south-1** : ami-058a2a70b7fb8ec87
|
||||
* **eu-west-3** : ami-0fc9f9e8e986f39c4
|
||||
* **eu-west-2** : ami-0b0bc1ff2f0239bd9
|
||||
* **eu-west-1** : ami-0056ec5d22b0fac91
|
||||
* **ap-northeast-2** : ami-0898c9aa7f580fec7
|
||||
* **ap-northeast-1** : ami-011036ddcc9398871
|
||||
* **sa-east-1** : ami-04feeded12192438c
|
||||
* **ca-central-1** : ami-02c717776c9e75025
|
||||
* **ap-southeast-1** : ami-05b5866e7029bb9f1
|
||||
* **ap-southeast-2** : ami-0384bd2b69467fff8
|
||||
* **eu-central-1** : ami-01f15be85297d6f06
|
||||
* **us-east-2** : ami-094070ca8aa110180
|
||||
* **us-west-1** : ami-0d08ec5bc29eddb29
|
||||
* **us-west-2** : ami-04715cceedaf6eae7
|
||||
* **us-east-1** : ami-071dbaa1847585c4c
|
||||
|
||||
### v0.15.1 (static update)
|
||||
|
||||
* **eu-north-1** : ami-0bb36c4dbe61f8c46
|
||||
* **ap-south-1** : ami-0ac93ff85a5c770f9
|
||||
* **eu-west-3** : ami-015ebfa846b8de5bb
|
||||
* **eu-west-2** : ami-082aacd59408713d9
|
||||
* **eu-west-1** : ami-066aad8c6b9b9991b
|
||||
* **ap-northeast-2** : ami-0cb47f1c8591c799d
|
||||
* **ap-northeast-1** : ami-005131d3037da9d2a
|
||||
* **sa-east-1** : ami-0f7fdc4e19c8444a3
|
||||
* **ca-central-1** : ami-07c234dad3ece2d78
|
||||
* **ap-southeast-1** : ami-0d8e0475d7d4897e4
|
||||
* **ap-southeast-2** : ami-053e3f25dee0424b9
|
||||
* **eu-central-1** : ami-00d25558c5242708e
|
||||
* **us-east-2** : ami-0bd45f800dfbde456
|
||||
* **us-west-1** : ami-05e79bf1704721148
|
||||
* **us-west-2** : ami-037c328649048409b
|
||||
* **us-east-1** : ami-0a3cafe46bf085200
|
||||
|
||||
### v0.15.0 (static update)
|
||||
|
||||
* **eu-north-1** : ami-0475a5068d615769b
|
||||
* **ap-south-1** : ami-00c7e642badaa2ebf
|
||||
* **eu-west-3** : ami-0655f769c28843e25
|
||||
* **eu-west-2** : ami-04d82f48f09e2b846
|
||||
* **eu-west-1** : ami-07a2aab2dc7b4ec5f
|
||||
* **ap-northeast-2** : ami-0257ab220a8bc7a52
|
||||
* **ap-northeast-1** : ami-0c4900af758b91dde
|
||||
* **sa-east-1** : ami-021f758a4a21d5725
|
||||
* **ca-central-1** : ami-0ce9703b3b47cfe70
|
||||
* **ap-southeast-1** : ami-0b38689fdb8f71b74
|
||||
* **ap-southeast-2** : ami-0c2b3a171e7ae4b00
|
||||
* **eu-central-1** : ami-0fdd3420d6e6b4a1f
|
||||
* **us-east-2** : ami-0288e9654da36ed1c
|
||||
* **us-west-1** : ami-0f1d6ee0b73fe9ca2
|
||||
* **us-west-2** : ami-025f0c5bfeacbf390
|
||||
* **us-east-1** : ami-0b17b0bfa8b91f805
|
||||
* **eu-north-1** : ami-0bef15c03eab64c0c
|
||||
* **ap-south-1** : ami-06ac6248e583e2cd2
|
||||
* **eu-west-3** : ami-0541d86ef47a5714e
|
||||
* **eu-west-2** : ami-01381ef4c4ed22482
|
||||
* **eu-west-1** : ami-064626a0dd38b21f1
|
||||
* **ap-northeast-2** : ami-0a2490a7a3a8aa675
|
||||
* **ap-northeast-1** : ami-063f1de819a2524b8
|
||||
* **sa-east-1** : ami-07980486741b94987
|
||||
* **ca-central-1** : ami-0ced3b8b21ded839e
|
||||
* **ap-southeast-1** : ami-0c493c5093fde8741
|
||||
* **ap-southeast-2** : ami-0320a727eccb8dc6c
|
||||
* **eu-central-1** : ami-0aa85cfc78674c526
|
||||
* **us-east-2** : ami-01791485051e1880c
|
||||
* **us-west-1** : ami-0d8eade4d5888ea73
|
||||
* **us-west-2** : ami-02ceaef72cdf60f7e
|
||||
* **us-east-1** : ami-0fc3f9d1d0eba1d62
|
||||
|
||||
### v0.14.2 (static update)
|
||||
|
||||
|
||||
@@ -53,6 +53,11 @@ To upgrade **trains-server** on an existing GCP instance based on one of these C
|
||||
|
||||
The following sections contain lists of Custom Image URLs (exported in different formats) for each released **trains-server** version.
|
||||
|
||||
### Latest version image (v0.14.1)
|
||||
### Latest version image
|
||||
|
||||
- https://storage.googleapis.com/allegro-files/trains-server/trains-server.tar.gz
|
||||
|
||||
### All released images
|
||||
|
||||
- v0.15.0 - https://storage.googleapis.com/allegro-files/trains-server/trains-server-0-15-0.tar.gz
|
||||
- v0.14.1 - https://storage.googleapis.com/allegro-files/trains-server/trains-server-0-14-1.tar.gz
|
||||
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import os
|
||||
from functools import reduce
|
||||
from os import getenv
|
||||
from os.path import expandvars
|
||||
@@ -16,6 +17,9 @@ DEFAULT_EXTRA_CONFIG_PATH = "/opt/trains/config"
|
||||
EXTRA_CONFIG_PATH_ENV_KEY = "TRAINS_CONFIG_DIR"
|
||||
EXTRA_CONFIG_PATH_SEP = ":"
|
||||
|
||||
EXTRA_CONFIG_VALUES_ENV_KEY_SEP = "__"
|
||||
EXTRA_CONFIG_VALUES_ENV_KEY_PREFIX = f"TRAINS{EXTRA_CONFIG_VALUES_ENV_KEY_SEP}"
|
||||
|
||||
|
||||
class BasicConfig:
|
||||
NotSet = object()
|
||||
@@ -46,7 +50,23 @@ class BasicConfig:
        path = ".".join((self.prefix, Path(name).stem))
        return logging.getLogger(path)

    def _read_env_paths(self, key):
    @staticmethod
    def _read_extra_env_config_values():
        """ Loads extra configuration from environment-injected values """
        result = ConfigTree()
        prefix = EXTRA_CONFIG_VALUES_ENV_KEY_PREFIX

        keys = sorted(k for k in os.environ if k.startswith(prefix))
        for key in keys:
            path = key[len(prefix) :].replace(EXTRA_CONFIG_VALUES_ENV_KEY_SEP, ".").lower()
            result = ConfigTree.merge_configs(
                result, ConfigFactory.parse_string(f"{path}: {os.environ[key]}")
            )

        return result

    @staticmethod
    def _read_env_paths(key):
        value = getenv(EXTRA_CONFIG_PATH_ENV_KEY, DEFAULT_EXTRA_CONFIG_PATH)
        if value is None:
            return
@@ -64,12 +84,17 @@ class BasicConfig:

    def _load(self, verbose=True):
        extra_config_paths = self._read_env_paths(EXTRA_CONFIG_PATH_ENV_KEY) or []
        extra_config_values = self._read_extra_env_config_values()
        configs = [
            self._read_recursive(path, verbose=verbose)
            for path in [self.folder] + extra_config_paths
        ]

        self._config = reduce(
            lambda config, path: ConfigTree.merge_configs(
                config, self._read_recursive(path, verbose=verbose), copy_trees=True
            lambda last, config: ConfigTree.merge_configs(
                last, config, copy_trees=True
            ),
            [self.folder] + extra_config_paths,
            configs + [extra_config_values],
            ConfigTree(),
        )

|
||||
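The `TRAINS__apiserver__mongo__pre_populate__enabled` style variables seen in the compose files are turned into dotted config paths by the `_read_extra_env_config_values` change above. The following is a minimal, stdlib-only sketch of that key-to-path mapping. It assumes plain nested dicts in place of pyhocon's `ConfigTree`, keeps values as raw strings instead of parsing them as HOCON, and the helper name `env_to_config` is hypothetical.

```python
from typing import Any, Dict, Mapping

SEP = "__"               # mirrors EXTRA_CONFIG_VALUES_ENV_KEY_SEP
PREFIX = f"TRAINS{SEP}"  # mirrors EXTRA_CONFIG_VALUES_ENV_KEY_PREFIX


def env_to_config(environ: Mapping[str, str]) -> Dict[str, Any]:
    """Collect TRAINS__a__b=value variables into a nested dict (hypothetical helper)."""
    result: Dict[str, Any] = {}
    # Sorted iteration keeps the merge order deterministic, as in the diff.
    for key in sorted(k for k in environ if k.startswith(PREFIX)):
        parts = key[len(PREFIX):].lower().split(SEP)
        node = result
        for part in parts[:-1]:
            node = node.setdefault(part, {})
        node[parts[-1]] = environ[key]  # raw string; the real code parses HOCON
    return result


example = env_to_config(
    {
        "TRAINS__apiserver__mongo__pre_populate__enabled": "true",
        "TRAINS_REDIS_SERVICE_PORT": "6379",  # single underscore: not matched
    }
)
print(example)
```

Only keys starting with the double-underscore prefix are picked up, so the existing `TRAINS_*_SERVICE_*` variables are left alone. The sketch does not handle a key that is both a leaf and a parent of another key.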
@@ -1,6 +1,9 @@
|
||||
download {
|
||||
# Add response headers requesting no caching for served files
|
||||
disable_browser_caching: false
|
||||
|
||||
# Cache timeout to be set for downloaded files
|
||||
cache_timeout_sec: 300
|
||||
}
|
||||
|
||||
cors {
|
||||
|
||||
@@ -17,6 +17,7 @@ CORS(app, **config.get("fileserver.cors"))
|
||||
Compress(app)
|
||||
|
||||
app.config["UPLOAD_FOLDER"] = os.environ.get("TRAINS_UPLOAD_FOLDER") or DEFAULT_UPLOAD_FOLDER
|
||||
app.config["SEND_FILE_MAX_AGE_DEFAULT"] = config.get("fileserver.download.cache_timeout_sec", 5 * 60)
|
||||
|
||||
|
||||
@app.route("/", methods=["POST"])
|
||||
|
||||
@@ -2,6 +2,7 @@ from jsonmodels import fields, models
|
||||
|
||||
|
||||
class Filter(models.Base):
|
||||
tags = fields.ListField([str])
|
||||
system_tags = fields.ListField([str])
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
from jsonmodels import models, fields
|
||||
|
||||
from apimodels import ListField
|
||||
from apimodels.organization import TagsRequest
|
||||
|
||||
|
||||
class ProjectReq(models.Base):
|
||||
project = fields.StringField()
|
||||
@@ -14,3 +17,7 @@ class GetHyperParamResp(models.Base):
|
||||
parameters = fields.ListField(str)
|
||||
remaining = fields.IntField()
|
||||
total = fields.IntField()
|
||||
|
||||
|
||||
class ProjectTagsRequest(TagsRequest):
|
||||
projects = ListField(str)
|
||||
|
||||
@@ -230,9 +230,25 @@ class EventBLL(object):
        metric_hash = dbutils.hash_field_name(metric)
        variant_hash = dbutils.hash_field_name(variant)

        timestamp = last_events[metric_hash][variant_hash].get("timestamp", None)
        if timestamp is None or timestamp < event["timestamp"]:
            last_events[metric_hash][variant_hash] = event
        last_event = last_events[metric_hash][variant_hash]
        event_iter = event.get("iter", 0)
        event_timestamp = event.get("timestamp", 0)
        value = event.get("value")
        if value is not None and (
            (event_iter, event_timestamp)
            >= (
                last_event.get("iter", event_iter),
                last_event.get("timestamp", event_timestamp),
            )
        ):
            event_data = {
                k: event[k]
                for k in ("value", "metric", "variant", "iter", "timestamp")
                if k in event
            }
            event_data["min_value"] = min(value, last_event.get("min_value", value))
            event_data["max_value"] = max(value, last_event.get("max_value", value))
            last_events[metric_hash][variant_hash] = event_data

    def _update_last_metric_events_for_task(self, last_events, event):
        """
|
||||
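The hunk above replaces a timestamp-only comparison with an `(iter, timestamp)` ordering and starts carrying a running `min_value` / `max_value`. A minimal sketch of that bookkeeping follows. It assumes events are keyed by plain metric and variant names rather than the hashed field names used in the diff, and the function name `update_last_scalar` is hypothetical.

```python
from collections import defaultdict
from typing import Any, Dict


def update_last_scalar(last_events: Dict[str, Dict[str, dict]], event: Dict[str, Any]) -> None:
    """Keep the newest event per metric/variant plus the running min and max."""
    last_event = last_events[event["metric"]][event["variant"]]
    event_iter = event.get("iter", 0)
    event_ts = event.get("timestamp", 0)
    value = event.get("value")
    if value is None:
        return
    # Tuples compare element by element: iteration first, then timestamp.
    newest_seen = (
        last_event.get("iter", event_iter),
        last_event.get("timestamp", event_ts),
    )
    if (event_iter, event_ts) < newest_seen:
        return  # stale event, ignored entirely
    data = {
        k: event[k]
        for k in ("value", "metric", "variant", "iter", "timestamp")
        if k in event
    }
    data["min_value"] = min(value, last_event.get("min_value", value))
    data["max_value"] = max(value, last_event.get("max_value", value))
    last_events[event["metric"]][event["variant"]] = data


last_events = defaultdict(lambda: defaultdict(dict))
for it, val in [(1, 5.0), (2, 3.0), (3, 9.0), (1, 100.0)]:
    update_last_scalar(
        last_events,
        {"metric": "loss", "variant": "train", "iter": it, "timestamp": it * 10, "value": val},
    )
summary = last_events["loss"]["train"]
```

In this sketch, as in the diff, an out-of-order event does not contribute to the minimum or maximum, which is why the late `100.0` sample is dropped.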
@@ -275,7 +291,13 @@ class EventBLL(object):
|
||||
flatten_nested_items(
|
||||
last_scalar_events,
|
||||
nesting=2,
|
||||
include_leaves=["value", "metric", "variant"],
|
||||
include_leaves=[
|
||||
"value",
|
||||
"min_value",
|
||||
"max_value",
|
||||
"metric",
|
||||
"variant",
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
from typing import Sequence
|
||||
from collections import defaultdict
|
||||
from enum import Enum
|
||||
from itertools import chain
|
||||
from typing import Sequence, Union, Type, Dict
|
||||
|
||||
from mongoengine import Q
|
||||
from redis import Redis
|
||||
|
||||
from config import config
|
||||
from database.model.base import GetMixin
|
||||
@@ -10,40 +14,65 @@ from redis_manager import redman
|
||||
from utilities import json
|
||||
|
||||
log = config.logger(__file__)
|
||||
_settings_prefix = "services.organization"
|
||||
|
||||
|
||||
class OrgBLL:
|
||||
class _TagsCache:
|
||||
_tags_field = "tags"
|
||||
_system_tags_field = "system_tags"
|
||||
_settings_prefix = "services.organization"
|
||||
|
||||
def __init__(self, redis=None):
|
||||
self.redis = redis or redman.connection("apiserver")
|
||||
def __init__(self, db_cls: Union[Type[Model], Type[Task]], redis: Redis):
|
||||
self.db_cls = db_cls
|
||||
self.redis = redis
|
||||
|
||||
@property
|
||||
def _tags_cache_expiration_seconds(self):
|
||||
return config.get(
|
||||
f"{self._settings_prefix}.tags_cache.expiration_seconds", 3600
|
||||
)
|
||||
return config.get(f"{_settings_prefix}.tags_cache.expiration_seconds", 3600)
|
||||
|
||||
@staticmethod
|
||||
def _get_tags_cache_key(company, field: str, filter_: Sequence[str] = None):
|
||||
filter_str = "_".join(filter_) if filter_ else ""
|
||||
return f"{field}_{company}_{filter_str}"
|
||||
|
||||
@staticmethod
|
||||
def _get_tags_from_db(company, field, filter_: Sequence[str] = None) -> set:
|
||||
def _get_tags_from_db(
|
||||
self,
|
||||
company: str,
|
||||
field: str,
|
||||
project: str = None,
|
||||
filter_: Dict[str, Sequence[str]] = None,
|
||||
) -> set:
|
||||
query = Q(company=company)
|
||||
if filter_:
|
||||
query &= GetMixin.get_list_field_query("system_tags", filter_)
|
||||
for name, vals in filter_.items():
|
||||
if vals:
|
||||
query &= GetMixin.get_list_field_query(name, vals)
|
||||
if project:
|
||||
query &= Q(project=project)
|
||||
|
||||
tags = set()
|
||||
for cls_ in (Task, Model):
|
||||
tags |= set(cls_.objects(query).distinct(field))
|
||||
return tags
|
||||
return self.db_cls.objects(query).distinct(field)
|
||||
|
||||
def _get_tags_cache_key(
|
||||
self,
|
||||
company: str,
|
||||
field: str,
|
||||
project: str = None,
|
||||
filter_: Dict[str, Sequence[str]] = None,
|
||||
):
|
||||
"""
|
||||
A project of None means 'from all company projects'.
The key is built so that scanning company keys for 'all company projects'
does not return the keys related to particular company projects, and vice versa.
This gives fine-grained control over which Redis keys to invalidate.
|
||||
"""
|
||||
filter_str = None
|
||||
if filter_:
|
||||
filter_str = "_".join(
|
||||
["filter", *chain.from_iterable([f, *v] for f, v in filter_.items())]
|
||||
)
|
||||
key_parts = [company, project, self.db_cls.__name__, field, filter_str]
|
||||
return "_".join(filter(None, key_parts))
|
||||
|
||||
def get_tags(
|
||||
self, company, include_system: bool = False, filter_: Sequence[str] = None
|
||||
self,
|
||||
company: str,
|
||||
include_system: bool = False,
|
||||
filter_: Dict[str, Sequence[str]] = None,
|
||||
project: str = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Get tags and optionally system tags for the company
|
||||
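The new `_get_tags_cache_key` joins only the non-empty parts, so company-wide and per-project keys never share a prefix. Below is a small sketch of that construction outside the class. The standalone function name `tags_cache_key` and the `entity` argument (standing in for `self.db_cls.__name__`) are assumptions, and `fnmatchcase` is used only as a rough stand-in for Redis glob matching.

```python
from fnmatch import fnmatchcase
from itertools import chain
from typing import Dict, Optional, Sequence


def tags_cache_key(
    company: str,
    entity: str,
    field: str,
    project: Optional[str] = None,
    filter_: Optional[Dict[str, Sequence[str]]] = None,
) -> str:
    """Build a cache key; None parts (project, filter) are simply skipped."""
    filter_str = None
    if filter_:
        filter_str = "_".join(
            ["filter", *chain.from_iterable([f, *v] for f, v in filter_.items())]
        )
    return "_".join(filter(None, [company, project, entity, field, filter_str]))


company_wide = tags_cache_key("acme", "Task", "tags")
per_project = tags_cache_key("acme", "Task", "tags", project="p1")
filtered = tags_cache_key("acme", "Task", "tags", filter_={"system_tags": ["archived"]})

# Invalidation scans "<key>*": it catches filtered variants of the same scope
# but not keys that belong to a different scope.
pattern = company_wide + "*"
matches_filtered = fnmatchcase(filtered, pattern)
matches_project = fnmatchcase(per_project, pattern)
```

Because the project sits before the entity name in the key, the company-wide pattern cannot accidentally match a per-project key in this example.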
@@ -51,35 +80,114 @@ class OrgBLL:
|
||||
The function retrieves both cached values from Redis in one call
|
||||
and recalculates any of them that are missing in Redis
|
||||
"""
|
||||
fields = [
|
||||
self._tags_field,
|
||||
*([self._system_tags_field] if include_system else []),
|
||||
fields = [self._tags_field]
|
||||
if include_system:
|
||||
fields.append(self._system_tags_field)
|
||||
redis_keys = [
|
||||
self._get_tags_cache_key(company, field=f, project=project, filter_=filter_)
|
||||
for f in fields
|
||||
]
|
||||
redis_keys = [self._get_tags_cache_key(company, f, filter_) for f in fields]
|
||||
cached = self.redis.mget(redis_keys)
|
||||
ret = {}
|
||||
for field, tag_data, key in zip(fields, cached, redis_keys):
|
||||
if tag_data is not None:
|
||||
tags = json.loads(tag_data)
|
||||
else:
|
||||
tags = list(self._get_tags_from_db(company, field, filter_))
|
||||
tags = list(self._get_tags_from_db(company, field, project, filter_))
|
||||
self.redis.setex(
|
||||
key,
|
||||
time=self._tags_cache_expiration_seconds,
|
||||
value=json.dumps(tags),
|
||||
)
|
||||
ret[field] = tags
|
||||
ret[field] = set(tags)
|
||||
|
||||
return ret
|
||||
|
||||
def update_org_tags(self, company, tags=None, system_tags=None, reset=False):
|
||||
def update_tags(self, company: str, project: str, tags=None, system_tags=None):
|
||||
"""
|
||||
Updates system tags. If reset is set then both tags and system_tags
|
||||
Updates tags. If reset is set then both tags and system_tags
|
||||
are recalculated. Otherwise only those that are not 'None'
|
||||
"""
|
||||
if reset or tags is not None:
|
||||
self.redis.delete(self._get_tags_cache_key(company, self._tags_field))
|
||||
if reset or system_tags is not None:
|
||||
self.redis.delete(
|
||||
self._get_tags_cache_key(company, self._system_tags_field)
|
||||
fields = [
|
||||
field
|
||||
for field, update in (
|
||||
(self._tags_field, tags),
|
||||
(self._system_tags_field, system_tags),
|
||||
)
|
||||
if update is not None
|
||||
]
|
||||
if not fields:
|
||||
return
|
||||
|
||||
self._delete_redis_keys(company, projects=[project], fields=fields)
|
||||
|
||||
def reset_tags(self, company: str, projects: Sequence[str]):
|
||||
self._delete_redis_keys(
|
||||
company,
|
||||
projects=projects,
|
||||
fields=(self._tags_field, self._system_tags_field),
|
||||
)
|
||||
|
||||
def _delete_redis_keys(
|
||||
self, company: str, projects: Sequence[str], fields: Sequence[str]
|
||||
):
|
||||
redis_keys = list(
|
||||
chain.from_iterable(
|
||||
self.redis.keys(
|
||||
self._get_tags_cache_key(company, field=f, project=p) + "*"
|
||||
)
|
||||
for f in fields
|
||||
for p in set(projects) | {None}
|
||||
)
|
||||
)
|
||||
if redis_keys:
|
||||
self.redis.delete(*redis_keys)
|
||||
|
||||
|
||||
class Tags(Enum):
|
||||
Task = "task"
|
||||
Model = "model"
|
||||
|
||||
|
||||
class OrgBLL:
|
||||
def __init__(self, redis=None):
|
||||
self.redis = redis or redman.connection("apiserver")
|
||||
self._task_tags = _TagsCache(Task, self.redis)
|
||||
self._model_tags = _TagsCache(Model, self.redis)
|
||||
|
||||
def get_tags(
|
||||
self,
|
||||
company: str,
|
||||
entity: Tags,
|
||||
include_system: bool = False,
|
||||
filter_: Dict[str, Sequence[str]] = None,
|
||||
projects: Sequence[str] = None,
|
||||
) -> dict:
|
||||
tags_cache = self._get_tags_cache_for_entity(entity)
|
||||
if not projects:
|
||||
return tags_cache.get_tags(
|
||||
company, include_system=include_system, filter_=filter_
|
||||
)
|
||||
|
||||
ret = defaultdict(set)
|
||||
for project in projects:
|
||||
project_tags = tags_cache.get_tags(
|
||||
company, include_system=include_system, filter_=filter_, project=project
|
||||
)
|
||||
for field, tags in project_tags.items():
|
||||
ret[field] |= tags
|
||||
|
||||
return ret
|
||||
|
||||
def update_tags(
|
||||
self, company: str, entity: Tags, project: str, tags=None, system_tags=None,
|
||||
):
|
||||
tags_cache = self._get_tags_cache_for_entity(entity)
|
||||
tags_cache.update_tags(company, project, tags, system_tags)
|
||||
|
||||
def reset_tags(self, company: str, entity: Tags, projects: Sequence[str]):
|
||||
tags_cache = self._get_tags_cache_for_entity(entity)
|
||||
tags_cache.reset_tags(company, projects=projects)
|
||||
|
||||
def _get_tags_cache_for_entity(self, entity: Tags) -> _TagsCache:
|
||||
return self._task_tags if entity == Tags.Task else self._model_tags
|
||||
|
||||
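`OrgBLL.get_tags` now queries the cache once per project and unions the results, which is why `_TagsCache.get_tags` returns sets rather than lists. A minimal sketch of that merge step, assuming each per-project result is a dict of field name to set; the helper name `merge_project_tags` is hypothetical.

```python
from collections import defaultdict
from typing import Dict, Iterable, Set


def merge_project_tags(per_project: Iterable[Dict[str, Set[str]]]) -> Dict[str, Set[str]]:
    """Union the tag sets of several projects, field by field."""
    merged = defaultdict(set)
    for project_tags in per_project:
        for field, tags in project_tags.items():
            merged[field] |= tags  # in-place set union removes duplicates
    return dict(merged)


merged = merge_project_tags(
    [
        {"tags": {"baseline", "gpu"}},
        {"tags": {"gpu", "nightly"}, "system_tags": {"development"}},
    ]
)
```

A field that appears in only one project still shows up in the merged result, since the `defaultdict` creates its set on first use.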
@@ -14,7 +14,7 @@ import database.utils as dbutils
|
||||
import es_factory
|
||||
from apierrors import errors
|
||||
from apimodels.tasks import Artifact as ApiArtifact
|
||||
from bll.organization import OrgBLL
|
||||
from bll.organization import OrgBLL, Tags
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
from database.model.model import Model
|
||||
@@ -229,7 +229,21 @@ class TaskBLL(object):
|
||||
validate_project=validate_references or project,
|
||||
)
|
||||
new_task.save()
|
||||
org_bll.update_org_tags(company_id, tags=tags, system_tags=system_tags)
|
||||
|
||||
if task.project == new_task.project:
|
||||
updated_tags = tags
|
||||
updated_system_tags = system_tags
|
||||
else:
|
||||
updated_tags = new_task.tags
|
||||
updated_system_tags = new_task.system_tags
|
||||
org_bll.update_tags(
|
||||
company_id,
|
||||
Tags.Task,
|
||||
project=new_task.project,
|
||||
tags=updated_tags,
|
||||
system_tags=updated_system_tags,
|
||||
)
|
||||
|
||||
return new_task
|
||||
|
||||
@classmethod
|
||||
@@ -346,10 +360,12 @@ class TaskBLL(object):
|
||||
return "__".join((op, "last_metrics") + path)
|
||||
|
||||
for path, value in last_scalar_values:
|
||||
extra_updates[op_path("set", *path)] = value
|
||||
if path[-1] == "value":
|
||||
if path[-1] == "min_value":
|
||||
extra_updates[op_path("min", *path[:-1], "min_value")] = value
|
||||
elif path[-1] == "max_value":
|
||||
extra_updates[op_path("max", *path[:-1], "max_value")] = value
|
||||
else:
|
||||
extra_updates[op_path("set", *path)] = value
|
||||
|
||||
if last_events is not None:
|
||||
|
||||
|
||||
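The hunk above routes `min_value` and `max_value` leaves to `min` / `max` update operators while every other leaf keeps using `set`. The sketch below reproduces only the key-building logic as a standalone function; `build_last_metrics_updates` is a hypothetical name, and the assumption that the resulting keys are passed on as mongoengine-style atomic update keywords is inferred from the `op__field` naming, not shown in the diff.

```python
from typing import Any, Dict, Iterable, Tuple


def build_last_metrics_updates(
    last_scalar_values: Iterable[Tuple[Tuple[str, ...], Any]]
) -> Dict[str, Any]:
    """Map flattened (path, value) pairs to operator-prefixed update keys."""

    def op_path(op: str, *path: str) -> str:
        return "__".join((op, "last_metrics") + path)

    updates: Dict[str, Any] = {}
    for path, value in last_scalar_values:
        if path[-1] == "min_value":
            updates[op_path("min", *path[:-1], "min_value")] = value
        elif path[-1] == "max_value":
            updates[op_path("max", *path[:-1], "max_value")] = value
        else:
            updates[op_path("set", *path)] = value
    return updates


updates = build_last_metrics_updates(
    [
        (("m1", "v1", "value"), 3),
        (("m1", "v1", "min_value"), 1),
        (("m1", "v1", "max_value"), 7),
    ]
)
```

Using `min` / `max` rather than `set` lets the stored extremes only ever move in one direction, even when batches arrive out of order.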
@@ -1,5 +1,6 @@
|
||||
import logging
|
||||
import os
|
||||
import platform
|
||||
from functools import reduce
|
||||
from os import getenv
|
||||
from os.path import expandvars
|
||||
@@ -15,7 +16,7 @@ from pyparsing import (
|
||||
|
||||
DEFAULT_EXTRA_CONFIG_PATH = "/opt/trains/config"
|
||||
EXTRA_CONFIG_PATH_ENV_KEY = "TRAINS_CONFIG_DIR"
|
||||
EXTRA_CONFIG_PATH_SEP = ":"
|
||||
EXTRA_CONFIG_PATH_SEP = ":" if platform.system() != "Windows" else ';'
|
||||
|
||||
EXTRA_CONFIG_VALUES_ENV_KEY_SEP = "__"
|
||||
EXTRA_CONFIG_VALUES_ENV_KEY_PREFIX = f"TRAINS{EXTRA_CONFIG_VALUES_ENV_KEY_SEP}"
|
||||
|
||||
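The separator change above matters because a Windows path such as `c:/opt/trains/config` itself contains a colon. A short sketch of the effect, assuming the config directory list is split with a plain `str.split`; the helper `split_config_paths` is hypothetical, and `os.pathsep` is shown as a standard-library value that already follows the same per-platform convention.

```python
import os
from typing import List


def split_config_paths(value: str, sep: str = os.pathsep) -> List[str]:
    """Split a path list on the separator, dropping empty entries."""
    return [p for p in value.split(sep) if p]


posix_paths = split_config_paths("/opt/trains/config:/etc/trains", sep=":")
windows_paths = split_config_paths("c:/opt/trains/config;d:/extra", sep=";")
# Splitting a Windows path on ":" cuts it at the drive letter.
broken = split_config_paths("c:/opt/trains/config", sep=":")
```

The default `sep` is `os.pathsep`, which is `":"` on POSIX systems and `";"` on Windows, matching the values chosen in the diff.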
@@ -10,22 +10,25 @@ from service_repo.auth.fixed_user import FixedUser
|
||||
|
||||
|
||||
def _ensure_auth_user(user_data: dict, company_id: str, log: Logger, revoke: bool = False):
|
||||
ensure_credentials = {"key", "secret"}.issubset(user_data)
|
||||
if ensure_credentials:
|
||||
user = AuthUser.objects(
|
||||
credentials__match=Credentials(
|
||||
key=user_data["key"], secret=user_data["secret"]
|
||||
)
|
||||
).first()
|
||||
key, secret = user_data.get("key"), user_data.get("secret")
|
||||
if not (key and secret):
|
||||
credentials = None
|
||||
else:
|
||||
creds = Credentials(key=key, secret=secret)
|
||||
|
||||
user = AuthUser.objects(credentials__match=creds).first()
|
||||
if user:
|
||||
if revoke:
|
||||
user.credentials = []
|
||||
user.save()
|
||||
return user.id
|
||||
|
||||
credentials = [] if revoke else [creds]
|
||||
|
||||
user_id = user_data.get("id", f"__{user_data['name']}__")
|
||||
|
||||
log.info(f"Creating user: {user_data['name']}")
|
||||
|
||||
user = AuthUser(
|
||||
id=user_id,
|
||||
name=user_data["name"],
|
||||
@@ -33,9 +36,7 @@ def _ensure_auth_user(user_data: dict, company_id: str, log: Logger, revoke: boo
|
||||
role=user_data["role"],
|
||||
email=user_data["email"],
|
||||
created=datetime.utcnow(),
|
||||
credentials=[Credentials(key=user_data["key"], secret=user_data["secret"])] if not revoke else []
|
||||
if ensure_credentials
|
||||
else None,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
user.save()
|
||||
@@ -68,12 +69,4 @@ def ensure_fixed_user(user: FixedUser, company_id: str, log: Logger):
|
||||
|
||||
_ensure_auth_user(user_data=data, company_id=company_id, log=log)
|
||||
|
||||
given_name, _, family_name = user.name.partition(" ")
|
||||
|
||||
User(
|
||||
id=user.user_id,
|
||||
company=company_id,
|
||||
name=user.name,
|
||||
given_name=given_name,
|
||||
family_name=family_name,
|
||||
).save()
|
||||
return _ensure_backend_user(user.user_id, company_id, user.name)
|
||||
|
||||
@@ -1,43 +1,48 @@
|
||||
_description: "This service provides organization level operations"
|
||||
|
||||
get_tags {
|
||||
"2.8" {
|
||||
description: "Get all the user and system tags used for the company tasks and models"
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
include_system {
|
||||
description: "If set to 'true' then the list of the system tags is also returned. The default value is 'false'"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
filter {
|
||||
description: "Filter on entities to collect tags from"
|
||||
type: object
|
||||
properties {
|
||||
system_tags {
|
||||
description: "The list of system tag values to filter by. Use 'null' value to specify empty tags. Use '__Snot' value to specify that the following value should be excluded"
|
||||
type: array
|
||||
items {type: string}
|
||||
"2.8" {
|
||||
description: "Get all the user and system tags used for the company tasks and models"
|
||||
request {
|
||||
type: object
|
||||
properties {
|
||||
include_system {
|
||||
description: "If set to 'true' then the list of the system tags is also returned. The default value is 'false'"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
filter {
|
||||
description: "Filter on entities to collect tags from"
|
||||
type: object
|
||||
properties {
|
||||
tags {
|
||||
description: "The list of tag values to filter by. Use 'null' value to specify empty tags. Use '__Snot' value to specify that the following value should be excluded"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
system_tags {
|
||||
description: "The list of system tag values to filter by. Use 'null' value to specify empty system tags. Use '__Snot' value to specify that the following value should be excluded"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
tags {
|
||||
description: "The list of unique tag values"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
system_tags {
|
||||
description: "The list of unique system tag values. Returned only if 'include_system' is set to 'true' in the request"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
tags {
|
||||
description: "The list of unique tag values"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
system_tags {
|
||||
description: "The list of unique system tag values. Returned only if 'include_system' is set to 'true' in the request"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -196,6 +196,52 @@ _definitions {
|
||||
}
|
||||
}
|
||||
}
|
||||
tags_request {
|
||||
type: object
|
||||
properties {
|
||||
include_system {
|
||||
description: "If set to 'true' then the list of the system tags is also returned. The default value is 'false'"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
projects {
|
||||
description: "The list of projects under which the tags are searched. If not passed or empty then all the projects are searched"
|
||||
type: array
|
||||
items { type: string }
|
||||
}
|
||||
filter {
|
||||
description: "Filter on entities to collect tags from"
|
||||
type: object
|
||||
properties {
|
||||
tags {
|
||||
description: "The list of tag values to filter by. Use 'null' value to specify empty tags. Use '__$not' value to specify that the following value should be excluded"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
system_tags {
|
||||
description: "The list of system tag values to filter by. Use 'null' value to specify empty system tags. Use '__$not' value to specify that the following value should be excluded"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
tags_response {
|
||||
type: object
|
||||
properties {
|
||||
tags {
|
||||
description: "The list of unique tag values"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
system_tags {
|
||||
description: "The list of unique system tag values. Returned only if 'include_system' is set to 'true' in the request"
|
||||
type: array
|
||||
items {type: string}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
create {
|
||||
@@ -508,7 +554,7 @@ get_hyper_parameters {
|
||||
parameters {
|
||||
description: "A list of hyper parameter names"
|
||||
type: array
|
||||
items { type: string }
|
||||
items {type: string}
|
||||
}
|
||||
remaining {
|
||||
description: "Remaining results"
|
||||
@@ -522,3 +568,17 @@ get_hyper_parameters {
|
||||
}
|
||||
}
|
||||
}
|
||||
get_task_tags {
|
||||
"2.8" {
|
||||
description: "Get user and system tags used for the tasks under the specified projects"
|
||||
request = ${_definitions.tags_request}
|
||||
response = ${_definitions.tags_response}
|
||||
}
|
||||
}
|
||||
get_model_tags {
|
||||
"2.8" {
|
||||
description: "Get user and system tags used for the models under the specified projects"
|
||||
request = ${_definitions.tags_request}
|
||||
response = ${_definitions.tags_response}
|
||||
}
|
||||
}
|
||||
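The `tags_request` definition above can be illustrated with a small sketch of a request body for `projects.get_task_tags` or `projects.get_model_tags`. The helper name `build_tags_request`, the `<project-id>` placeholder, and the `archived` system tag value are assumptions made for illustration; only the field names (`include_system`, `projects`, `filter.tags`, `filter.system_tags`) and the `__$not` marker come from the schema.

```python
from typing import Optional, Sequence


def build_tags_request(
    projects: Optional[Sequence[str]] = None,
    include_system: bool = False,
    tags: Optional[Sequence[str]] = None,
    system_tags: Optional[Sequence[str]] = None,
) -> dict:
    """Hypothetical helper: assemble a body shaped like `tags_request`."""
    body = {"include_system": include_system}
    if projects:
        # Per the schema, omitting `projects` (or passing an empty list)
        # means all projects are searched.
        body["projects"] = list(projects)
    filter_ = {}
    if tags:
        filter_["tags"] = list(tags)
    if system_tags:
        filter_["system_tags"] = list(system_tags)
    if filter_:
        body["filter"] = filter_
    return body


# '__$not' marks the value that follows it as excluded.
request = build_tags_request(
    projects=["<project-id>"],
    include_system=True,
    system_tags=["__$not", "archived"],
)
print(request)
```

The body is kept sparse on purpose: optional keys are only emitted when they carry a value, which matches the schema's "if not passed or empty" wording for `projects`.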
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Sequence
|
||||
|
||||
from mongoengine import Q, EmbeddedDocument
|
||||
|
||||
@@ -12,7 +13,7 @@ from apimodels.models import (
|
||||
PublishModelResponse,
|
||||
ModelTaskPublishResponse,
|
||||
)
|
||||
from bll.organization import OrgBLL
|
||||
from bll.organization import OrgBLL, Tags
|
||||
from bll.task import TaskBLL
|
||||
from config import config
|
||||
from database.errors import translate_errors_context
|
||||
@@ -128,9 +129,19 @@ def parse_model_fields(call, valid_fields):
|
||||
return fields
|
||||
|
||||
|
||||
def _update_org_tags(company, fields: dict):
|
||||
org_bll.update_org_tags(
|
||||
company, tags=fields.get("tags"), system_tags=fields.get("system_tags")
|
||||
def _update_cached_tags(company: str, project: str, fields: dict):
|
||||
org_bll.update_tags(
|
||||
company,
|
||||
Tags.Model,
|
||||
project=project,
|
||||
tags=fields.get("tags"),
|
||||
system_tags=fields.get("system_tags"),
|
||||
)
|
||||
|
||||
|
||||
def _reset_cached_tags(company: str, projects: Sequence[str]):
|
||||
org_bll.reset_tags(
|
||||
company, Tags.Model, projects=projects,
|
||||
)
|
||||
|
||||
|
||||
@@ -203,7 +214,7 @@ def update_for_task(call: APICall, company_id, _):
|
||||
**fields,
|
||||
)
|
||||
model.save()
|
||||
_update_org_tags(company_id, fields)
|
||||
_update_cached_tags(company_id, project=model.project, fields=fields)
|
||||
|
||||
TaskBLL.update_statistics(
|
||||
task_id=task_id,
|
||||
@@ -248,7 +259,7 @@ def create(call: APICall, company_id, req_model: CreateModelRequest):
|
||||
**fields,
|
||||
)
|
||||
model.save()
|
||||
_update_org_tags(company_id, fields)
|
||||
_update_cached_tags(company_id, project=model.project, fields=fields)
|
||||
|
||||
call.result.data_model = CreateModelResponse(id=model.id, created=True)
|
||||
|
||||
@@ -327,7 +338,15 @@ def edit(call: APICall, company_id, _):
|
||||
if fields:
|
||||
updated = model.update(upsert=False, **fields)
|
||||
if updated:
|
||||
_update_org_tags(company_id, fields)
|
||||
new_project = fields.get("project", model.project)
|
||||
if new_project != model.project:
|
||||
_reset_cached_tags(
|
||||
company_id, projects=[new_project, model.project]
|
||||
)
|
||||
else:
|
||||
_update_cached_tags(
|
||||
company_id, project=model.project, fields=fields
|
||||
)
|
||||
conform_output_tags(call, fields)
|
||||
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
||||
else:
|
||||
@@ -355,7 +374,13 @@ def _update_model(call: APICall, company_id, model_id=None):
|
||||
|
||||
updated_count, updated_fields = Model.safe_update(company_id, model.id, data)
|
||||
if updated_count:
|
||||
_update_org_tags(company_id, updated_fields)
|
||||
new_project = updated_fields.get("project", model.project)
|
||||
if new_project != model.project:
|
||||
_reset_cached_tags(company_id, projects=[new_project, model.project])
|
||||
else:
|
||||
_update_cached_tags(
|
||||
company_id, project=model.project, fields=updated_fields
|
||||
)
|
||||
conform_output_tags(call, updated_fields)
|
||||
return UpdateResponse(updated=updated_count, fields=updated_fields)
|
||||
|
||||
@@ -395,7 +420,7 @@ def update(call: APICall, company_id, _):
|
||||
|
||||
with translate_errors_context():
|
||||
query = dict(id=model_id, company=company_id)
|
||||
model = Model.objects(**query).only("id", "task").first()
|
||||
model = Model.objects(**query).only("id", "task", "project").first()
|
||||
if not model:
|
||||
raise errors.bad_request.InvalidModelId(**query)
|
||||
|
||||
@@ -428,5 +453,5 @@ def update(call: APICall, company_id, _):
|
||||
|
||||
del_count = Model.objects(**query).delete()
|
||||
if del_count:
|
||||
org_bll.update_org_tags(company_id, reset=True)
|
||||
_reset_cached_tags(company_id, projects=[model.project])
|
||||
call.result.data = dict(deleted=del_count > 0)
|
||||
|
||||
@@ -1,13 +1,22 @@
|
||||
from collections import defaultdict
|
||||
|
||||
from apimodels.organization import TagsRequest
|
||||
from bll.organization import OrgBLL
|
||||
from bll.organization import OrgBLL, Tags
|
||||
from service_repo import endpoint, APICall
|
||||
from services.utils import get_tags_filter_dictionary, get_tags_response
|
||||
|
||||
org_bll = OrgBLL()
|
||||
|
||||
|
||||
@endpoint("organization.get_tags", request_data_model=TagsRequest)
|
||||
def get_tags(call: APICall, company, request: TagsRequest):
|
||||
filter_ = request.filter.system_tags if request.filter else None
|
||||
call.result.data = org_bll.get_tags(
|
||||
company, include_system=request.include_system, filter_=filter_
|
||||
)
|
||||
filter_dict = get_tags_filter_dictionary(request.filter)
|
||||
ret = defaultdict(set)
|
||||
for entity in Tags.Model, Tags.Task:
|
||||
tags = org_bll.get_tags(
|
||||
company, entity, include_system=request.include_system, filter_=filter_dict,
|
||||
)
|
||||
for field, vals in tags.items():
|
||||
ret[field] |= vals
|
||||
|
||||
call.result.data = get_tags_response(ret)
|
||||
|
||||
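The merge performed by `organization.get_tags` above (one tag lookup per entity type, unioned per field, then sorted for the response) can be sketched in isolation. This is a minimal stand-alone illustration: `merge_tag_sets` is a hypothetical name, and the plain dictionaries stand in for whatever `org_bll.get_tags` actually returns.

```python
from collections import defaultdict
from typing import Dict, Iterable, List, Set


def merge_tag_sets(per_entity: Iterable[Dict[str, Set[str]]]) -> Dict[str, List[str]]:
    """Union the tag sets field by field, then sort each field for output."""
    merged = defaultdict(set)
    for tags in per_entity:
        for field, vals in tags.items():
            merged[field] |= vals
    return {field: sorted(vals) for field, vals in merged.items()}


# Stand-ins for the per-entity results (models first, then tasks).
model_tags = {"tags": {"b", "a"}}
task_tags = {"tags": {"b", "c"}, "system_tags": {"dev"}}
print(merge_tag_sets([model_tags, task_tags]))
# {'tags': ['a', 'b', 'c'], 'system_tags': ['dev']}
```

Using sets for the intermediate result keeps duplicates out when the same tag appears on both a model and a task; sorting only at the end gives a stable response order.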
@@ -9,7 +9,13 @@ from mongoengine import Q
|
||||
import database
|
||||
from apierrors import errors
|
||||
from apimodels.base import UpdateResponse
|
||||
from apimodels.projects import GetHyperParamReq, GetHyperParamResp, ProjectReq
|
||||
from apimodels.projects import (
|
||||
GetHyperParamReq,
|
||||
GetHyperParamResp,
|
||||
ProjectReq,
|
||||
ProjectTagsRequest,
|
||||
)
|
||||
from bll.organization import OrgBLL, Tags
|
||||
from bll.task import TaskBLL
|
||||
from database.errors import translate_errors_context
|
||||
from database.model import EntityVisibility
|
||||
@@ -18,9 +24,15 @@ from database.model.project import Project
|
||||
from database.model.task.task import Task, TaskStatus
|
||||
from database.utils import parse_from_call, get_options, get_company_or_none_constraint
|
||||
from service_repo import APICall, endpoint
|
||||
from services.utils import conform_tag_fields, conform_output_tags
|
||||
from services.utils import (
|
||||
conform_tag_fields,
|
||||
conform_output_tags,
|
||||
get_tags_filter_dictionary,
|
||||
get_tags_response,
|
||||
)
|
||||
from timing_context import TimingContext
|
||||
|
||||
org_bll = OrgBLL()
|
||||
task_bll = TaskBLL()
|
||||
archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}
|
||||
|
||||
@@ -381,3 +393,31 @@ def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamR
|
||||
"remaining": remaining,
|
||||
"parameters": parameters,
|
||||
}
|
||||
|
||||
|
||||
@endpoint(
|
||||
"projects.get_task_tags", min_version="2.8", request_data_model=ProjectTagsRequest
|
||||
)
|
||||
def get_tags(call: APICall, company, request: ProjectTagsRequest):
|
||||
ret = org_bll.get_tags(
|
||||
company,
|
||||
Tags.Task,
|
||||
include_system=request.include_system,
|
||||
filter_=get_tags_filter_dictionary(request.filter),
|
||||
projects=request.projects,
|
||||
)
|
||||
call.result.data = get_tags_response(ret)
|
||||
|
||||
|
||||
@endpoint(
|
||||
"projects.get_model_tags", min_version="2.8", request_data_model=ProjectTagsRequest
|
||||
)
|
||||
def get_tags(call: APICall, company, request: ProjectTagsRequest):
|
||||
ret = org_bll.get_tags(
|
||||
company,
|
||||
Tags.Model,
|
||||
include_system=request.include_system,
|
||||
filter_=get_tags_filter_dictionary(request.filter),
|
||||
projects=request.projects,
|
||||
)
|
||||
call.result.data = get_tags_response(ret)
|
||||
|
||||
@@ -33,7 +33,7 @@ from apimodels.tasks import (
|
||||
ResetRequest,
|
||||
)
|
||||
from bll.event import EventBLL
|
||||
from bll.organization import OrgBLL
|
||||
from bll.organization import OrgBLL, Tags
|
||||
from bll.queue import QueueBLL
|
||||
from bll.task import (
|
||||
TaskBLL,
|
||||
@@ -343,9 +343,19 @@ def validate(call: APICall, company_id, req_model: CreateRequest):
|
||||
_validate_and_get_task_from_call(call)
|
||||
|
||||
|
||||
def _update_org_tags(company, fields: dict):
|
||||
org_bll.update_org_tags(
|
||||
company, tags=fields.get("tags"), system_tags=fields.get("system_tags")
|
||||
def _update_cached_tags(company: str, project: str, fields: dict):
|
||||
org_bll.update_tags(
|
||||
company,
|
||||
Tags.Task,
|
||||
project=project,
|
||||
tags=fields.get("tags"),
|
||||
system_tags=fields.get("system_tags"),
|
||||
)
|
||||
|
||||
|
||||
def _reset_cached_tags(company: str, projects: Sequence[str]):
|
||||
org_bll.reset_tags(
|
||||
company, Tags.Task, projects=projects
|
||||
)
|
||||
|
||||
|
||||
@@ -357,7 +367,7 @@ def create(call: APICall, company_id, req_model: CreateRequest):
|
||||
|
||||
with translate_errors_context(), TimingContext("mongo", "save_task"):
|
||||
task.save()
|
||||
_update_org_tags(company_id, fields)
|
||||
_update_cached_tags(company_id, project=task.project, fields=fields)
|
||||
update_project_time(task.project)
|
||||
|
||||
call.result.data_model = IdResponse(id=task.id)
|
||||
@@ -400,7 +410,9 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
|
||||
task_id = req_model.task
|
||||
|
||||
with translate_errors_context():
|
||||
task = Task.get_for_writing(id=task_id, company=company_id, _only=["id"])
|
||||
task = Task.get_for_writing(
|
||||
id=task_id, company=company_id, _only=["id", "project"]
|
||||
)
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(id=task_id)
|
||||
|
||||
@@ -416,7 +428,13 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
|
||||
injected_update=dict(last_update=datetime.utcnow()),
|
||||
)
|
||||
if updated_count:
|
||||
_update_org_tags(company_id, updated_fields)
|
||||
new_project = updated_fields.get("project", task.project)
|
||||
if new_project != task.project:
|
||||
_reset_cached_tags(company_id, projects=[new_project, task.project])
|
||||
else:
|
||||
_update_cached_tags(
|
||||
company_id, project=task.project, fields=updated_fields
|
||||
)
|
||||
update_project_time(updated_fields.get("project"))
|
||||
unprepare_from_saved(call, updated_fields)
|
||||
return UpdateResponse(updated=updated_count, fields=updated_fields)
|
||||
@@ -470,8 +488,10 @@ def update_batch(call: APICall, company_id, _):
|
||||
now = datetime.utcnow()
|
||||
|
||||
bulk_ops = []
|
||||
updated_projects = set()
|
||||
for id, data in items.items():
|
||||
fields, valid_fields = prepare_update_fields(call, tasks[id], data)
|
||||
task = tasks[id]
|
||||
fields, valid_fields = prepare_update_fields(call, task, data)
|
||||
partial_update_dict = Task.get_safe_update_dict(fields)
|
||||
if not partial_update_dict:
|
||||
continue
|
||||
@@ -481,12 +501,20 @@ def update_batch(call: APICall, company_id, _):
|
||||
)
|
||||
bulk_ops.append(update_op)
|
||||
|
||||
new_project = partial_update_dict.get("project", task.project)
|
||||
if new_project != task.project:
|
||||
updated_projects.update({new_project, task.project})
|
||||
elif any(f in partial_update_dict for f in ("tags", "system_tags")):
|
||||
updated_projects.add(task.project)
|
||||
|
||||
updated = 0
|
||||
if bulk_ops:
|
||||
res = Task._get_collection().bulk_write(bulk_ops)
|
||||
updated = res.modified_count
|
||||
if updated:
|
||||
org_bll.update_org_tags(company_id, reset=True)
|
||||
|
||||
if updated and updated_projects:
|
||||
_reset_cached_tags(company_id, projects=list(updated_projects))
|
||||
|
||||
call.result.data = {"updated": updated}
|
||||
|
||||
|
||||
@@ -542,7 +570,15 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
|
||||
fixed_fields.update(last_update=now)
|
||||
updated = task.update(upsert=False, **fixed_fields)
|
||||
if updated:
|
||||
_update_org_tags(company_id, fixed_fields)
|
||||
new_project = fixed_fields.get("project", task.project)
|
||||
if new_project != task.project:
|
||||
_reset_cached_tags(
|
||||
company_id, projects=[new_project, task.project]
|
||||
)
|
||||
else:
|
||||
_update_cached_tags(
|
||||
company_id, project=task.project, fields=fixed_fields
|
||||
)
|
||||
update_project_time(fields.get("project"))
|
||||
unprepare_from_saved(call, fields)
|
||||
call.result.data_model = UpdateResponse(updated=updated, fields=fields)
|
||||
@@ -710,12 +746,11 @@ def reset(call: APICall, company_id, request: ResetRequest):
|
||||
|
||||
if request.clear_all:
|
||||
updates.update(
|
||||
set__execution=Execution(),
|
||||
unset__script=1,
|
||||
set__execution=Execution(), unset__script=1,
|
||||
)
|
||||
else:
|
||||
updates.update(unset__execution__queue=1)
|
||||
updates.update(
|
||||
unset__execution__queue=1,
|
||||
__raw__={"$pull": {"execution.artifacts": {"mode": {"$ne": "input"}}}},
|
||||
)
|
||||
|
||||
@@ -909,7 +944,8 @@ def delete(call: APICall, company_id, req_model: DeleteRequest):
|
||||
task.switch_collection(collection_name)
|
||||
|
||||
task.delete()
|
||||
org_bll.update_org_tags(company_id, reset=True)
|
||||
_reset_cached_tags(company_id, projects=[task.project])
|
||||
|
||||
call.result.data = dict(deleted=True, **attr.asdict(result))
|
||||
|
||||
|
||||
|
||||
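The cache handling in the task handlers above follows one rule in `update_batch`: a project change marks both the old and the new project as stale, while an in-place tag change marks only the current project. A minimal sketch of that decision, assuming a hypothetical helper name `projects_to_reset` and plain dictionaries in place of the real update payloads:

```python
from typing import Optional, Set


def projects_to_reset(current_project: Optional[str], update: dict) -> Set[Optional[str]]:
    """Hypothetical helper: projects whose cached tags are stale after one update."""
    new_project = update.get("project", current_project)
    if new_project != current_project:
        # The entity moved: source and destination caches are both stale.
        return {new_project, current_project}
    if any(field in update for field in ("tags", "system_tags")):
        # Tags changed in place: only the current project's cache is stale.
        return {current_project}
    return set()


# Accumulate over a batch, as update_batch does with `updated_projects`.
stale = set()
stale |= projects_to_reset("p1", {"project": "p2"})
stale |= projects_to_reset("p3", {"tags": ["x"]})
stale |= projects_to_reset("p4", {"name": "renamed"})
print(sorted(stale))
```

Collecting the projects in a set and resetting once after the bulk write avoids repeating the reset for every task that touches the same project.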
@@ -1,12 +1,28 @@
|
||||
from typing import Union, Sequence, Tuple
|
||||
|
||||
from apierrors import errors
|
||||
from apimodels.organization import Filter
|
||||
from database.model.base import GetMixin
|
||||
from database.utils import partition_tags
|
||||
from service_repo import APICall
|
||||
from service_repo.base import PartialVersion
|
||||
|
||||
|
||||
def get_tags_filter_dictionary(input_: Filter) -> dict:
|
||||
if not input_:
|
||||
return {}
|
||||
|
||||
return {
|
||||
field: vals
|
||||
for field, vals in (("tags", input_.tags), ("system_tags", input_.system_tags))
|
||||
if vals
|
||||
}
|
||||
|
||||
|
||||
def get_tags_response(ret: dict) -> dict:
|
||||
return {field: sorted(vals) for field, vals in ret.items()}
|
||||
|
||||
|
||||
def conform_output_tags(call: APICall, documents: Union[dict, Sequence[dict]]):
|
||||
"""
|
||||
For old clients both tags and system tags are returned in 'tags' field
|
||||
|
||||
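The behaviour of `get_tags_filter_dictionary` above can be checked outside the server with a stand-in for the request model. In this sketch the `Filter` dataclass is an assumption that only mirrors the two fields the function reads; the real class lives in `apimodels.organization` and may differ.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass
class Filter:
    """Stand-in for apimodels.organization.Filter (assumed fields only)."""
    tags: Optional[Sequence[str]] = None
    system_tags: Optional[Sequence[str]] = None


def get_tags_filter_dictionary(input_: Optional[Filter]) -> dict:
    # Same logic as the diff: no filter gives an empty dict,
    # and empty or missing fields are dropped.
    if not input_:
        return {}
    return {
        field: vals
        for field, vals in (("tags", input_.tags), ("system_tags", input_.system_tags))
        if vals
    }


print(get_tags_filter_dictionary(None))
print(get_tags_filter_dictionary(Filter(tags=[])))
print(get_tags_filter_dictionary(Filter(system_tags=["__$not", "archived"])))
```

Dropping empty fields means a caller that sends `filter: {}` or `filter: {tags: []}` is treated the same as one that sends no filter at all.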
@@ -1,36 +0,0 @@
|
||||
from tests.automated import TestService
|
||||
|
||||
|
||||
class TestOrganization(TestService):
|
||||
def setUp(self, version="2.8"):
|
||||
super().setUp(version=version)
|
||||
|
||||
def test_tags(self):
|
||||
tag1 = "Orgtest tag1"
|
||||
tag2 = "Orgtest tag2"
|
||||
system_tag = "Orgtest system tag"
|
||||
|
||||
model = self.create_temp(
|
||||
"models", name="test_org", uri="file:///a", tags=[tag1]
|
||||
)
|
||||
task = self.create_temp(
|
||||
"tasks", name="test org", type="training", input=dict(view={}), tags=[tag1]
|
||||
)
|
||||
data = self.api.organization.get_tags()
|
||||
self.assertTrue(tag1 in data.tags)
|
||||
|
||||
self.api.tasks.edit(task=task, tags=[tag2], system_tags=[system_tag])
|
||||
data = self.api.organization.get_tags(include_system=True)
|
||||
self.assertTrue({tag1, tag2}.issubset(set(data.tags)))
|
||||
self.assertTrue(system_tag in data.system_tags)
|
||||
|
||||
data = self.api.organization.get_tags(
|
||||
filter={"system_tags": ["__$not", system_tag]}
|
||||
)
|
||||
self.assertTrue(tag1 in data.tags)
|
||||
self.assertFalse(tag2 in data.tags)
|
||||
|
||||
self.api.models.delete(model=model)
|
||||
data = self.api.organization.get_tags()
|
||||
self.assertFalse(tag1 in data.tags)
|
||||
self.assertTrue(tag2 in data.tags)
|
||||
82
server/tests/automated/test_project_tags.py
Normal file
@@ -0,0 +1,82 @@
|
||||
from tests.automated import TestService
|
||||
|
||||
|
||||
class TestProjectTags(TestService):
|
||||
def setUp(self, version="2.8"):
|
||||
super().setUp(version=version)
|
||||
|
||||
def test_project_tags(self):
|
||||
tags_1 = ["Test tag 1", "Test tag 2"]
|
||||
tags_2 = ["Test tag 3", "Test tag 4"]
|
||||
|
||||
p1 = self.create_temp("projects", name="Test tags1", description="test")
|
||||
task1_1 = self.new_task(project=p1, tags=tags_1[:1])
|
||||
task1_2 = self.new_task(project=p1, tags=tags_1[1:])
|
||||
|
||||
p2 = self.create_temp("projects", name="Test tasks2", description="test")
|
||||
task2 = self.new_task(project=p2, tags=tags_2)
|
||||
|
||||
# test tags per project
|
||||
data = self.api.projects.get_task_tags(projects=[p1])
|
||||
self.assertEqual(set(tags_1), set(data.tags))
|
||||
data = self.api.projects.get_model_tags(projects=[p1])
|
||||
self.assertEqual(set(), set(data.tags))
|
||||
data = self.api.projects.get_task_tags(projects=[p2])
|
||||
self.assertEqual(set(tags_2), set(data.tags))
|
||||
|
||||
# test tags for projects list
|
||||
data = self.api.projects.get_task_tags(projects=[p1, p2])
|
||||
self.assertEqual(set(tags_1) | set(tags_2), set(data.tags))
|
||||
|
||||
# test tags for all projects
|
||||
data = self.api.projects.get_task_tags(projects=[p1, p2])
|
||||
self.assertTrue((set(tags_1) | set(tags_2)).issubset(data.tags))
|
||||
|
||||
# test move to another project
|
||||
self.api.tasks.edit(task=task1_2, project=p2)
|
||||
data = self.api.projects.get_task_tags(projects=[p1])
|
||||
self.assertEqual(set(tags_1[:1]), set(data.tags))
|
||||
data = self.api.projects.get_task_tags(projects=[p2])
|
||||
self.assertEqual(set(tags_1[1:]) | set(tags_2), set(data.tags))
|
||||
|
||||
# test tags update
|
||||
self.api.tasks.delete(task=task1_1, force=True)
|
||||
self.api.tasks.delete(task=task2, force=True)
|
||||
data = self.api.projects.get_task_tags(projects=[p1, p2])
|
||||
self.assertEqual(set(tags_1[1:]), set(data.tags))
|
||||
|
||||
def test_organization_tags(self):
|
||||
tag1 = "Orgtest tag1"
|
||||
tag2 = "Orgtest tag2"
|
||||
system_tag = "Orgtest system tag"
|
||||
|
||||
model = self.new_model(tags=[tag1])
|
||||
task = self.new_task(tags=[tag1])
|
||||
data = self.api.organization.get_tags()
|
||||
self.assertTrue(tag1 in data.tags)
|
||||
|
||||
self.api.tasks.edit(task=task, tags=[tag2], system_tags=[system_tag])
|
||||
data = self.api.organization.get_tags(include_system=True)
|
||||
self.assertTrue({tag1, tag2}.issubset(set(data.tags)))
|
||||
self.assertTrue(system_tag in data.system_tags)
|
||||
|
||||
data = self.api.organization.get_tags(
|
||||
filter={"system_tags": ["__$not", system_tag]}
|
||||
)
|
||||
self.assertTrue(tag1 in data.tags)
|
||||
self.assertFalse(tag2 in data.tags)
|
||||
|
||||
self.api.models.delete(model=model)
|
||||
data = self.api.organization.get_tags()
|
||||
self.assertFalse(tag1 in data.tags)
|
||||
self.assertTrue(tag2 in data.tags)
|
||||
|
||||
def new_task(self, **kwargs):
|
||||
self.update_missing(
|
||||
kwargs, type="testing", name="test project tags", input=dict(view=dict())
|
||||
)
|
||||
return self.create_temp("tasks", **kwargs)
|
||||
|
||||
def new_model(self, **kwargs):
|
||||
self.update_missing(kwargs, name="test project tags", uri="file:///a")
|
||||
return self.create_temp("models", **kwargs)
|
||||
@@ -8,6 +8,8 @@ from functools import partial
|
||||
from statistics import mean
|
||||
from typing import Sequence
|
||||
|
||||
from boltons.iterutils import first
|
||||
|
||||
import es_factory
|
||||
from apierrors.errors.bad_request import EventsNotAdded
|
||||
from tests.automated import TestService
|
||||
@@ -72,6 +74,31 @@ class TestTaskEvents(TestService):
|
||||
),
|
||||
)
|
||||
|
||||
def test_last_scalar_metrics(self):
|
||||
metric = "Metric1"
|
||||
variant = "Variant1"
|
||||
iter_count = 100
|
||||
task = self._temp_task()
|
||||
events = [
|
||||
{
|
||||
**self._create_task_event("training_stats_scalar", task, iteration),
|
||||
"metric": metric,
|
||||
"variant": variant,
|
||||
"value": iteration,
|
||||
}
|
||||
for iteration in range(iter_count)
|
||||
]
|
||||
# send 2 batches to check the interaction with already stored db value
|
||||
# each batch contains multiple iterations
|
||||
self.send_batch(events[:50])
|
||||
self.send_batch(events[50:])
|
||||
|
||||
task_data = self.api.tasks.get_by_id(task=task).task
|
||||
metric_data = first(first(task_data.last_metrics.values()).values())
|
||||
self.assertEqual(iter_count - 1, metric_data.value)
|
||||
self.assertEqual(iter_count - 1, metric_data.max_value)
|
||||
self.assertEqual(0, metric_data.min_value)
|
||||
|
||||
def test_task_debug_images(self):
|
||||
task = self._temp_task()
|
||||
metric = "Metric1"
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "0.15.0"
|
||||
__version__ = "0.15.1"
|
||||
|
||||