Compare commits

65 Commits
1.0.2 ... 1.2.0

Author SHA1 Message Date
allegroai
fc4fd9e61c Version bump to v1.2.0 2022-02-14 15:26:27 +02:00
allegroai
8908c7dcf9 Update driver requirements
Refactor ES initialization
2022-02-13 20:27:12 +02:00
allegroai
b9996e2c1a Protect against multiple connects to the update server from different processes
Code cleanup
2022-02-13 20:12:12 +02:00
allegroai
afdc56f37c Use task active duration for worker task running time 2022-02-13 20:01:47 +02:00
allegroai
a25cd5dae8 Fix version conflicts when deleting task events cause an error 2022-02-13 20:01:25 +02:00
allegroai
447adb9090 Add support for credentials label
Support no_scroll in events.get_task_plots
Support better project stats
Fix Redis required on mongodb initialization
Update tests
2022-02-13 19:59:58 +02:00
allegroai
92fd98d5ad Add support for lists and nested fields in URL args and form 2022-02-13 19:52:05 +02:00
allegroai
c4001b4037 Add Redis cluster support
Fix for lru_cache usage
2022-02-13 19:48:26 +02:00
allegroai
970a32287a Add Redis password support 2022-02-13 19:37:52 +02:00
allegroai
17cd48dada Add support for override cookie domains
Support for community invitation alarms
Remove duplicate property
Add query optimizations
2022-02-13 19:35:35 +02:00
allegroai
ea3b6e955f Optimize nested_get() 2022-02-13 19:32:22 +02:00
allegroai
843450bb9b Fix add_or_update_artifacts should always be allowed on in_progress tasks
Fix delete_artifacts should always be allowed on in_progress tasks
Fix query code
2022-02-13 19:31:54 +02:00
allegroai
e149af58b1 Support for additional mata data in api call response 2022-02-13 19:30:36 +02:00
allegroai
604a38035b Add organization.update_company_name
Fix unit-tests
2022-02-13 19:29:46 +02:00
allegroai
cae38a365b Fix base query building
Fix schema
Improve events.scalar_metrics_iter_raw implementation
2022-02-13 19:28:23 +02:00
allegroai
e334246b46 Add support for project stats with children flag 2022-02-13 19:26:47 +02:00
allegroai
36e013b40c Add support for events.scalar_metrics_iter_raw 2022-02-13 19:26:03 +02:00
allegroai
f20cd6536e Add scroll support to *.get_* 2022-02-13 19:23:29 +02:00
allegroai
446bd35006 Refactor debug images response, model ORM 2022-02-13 19:21:07 +02:00
allegroai
a377a7e315 Support status_message and status_reason in tasks.delete 2022-02-13 19:20:31 +02:00
allegroai
3d046ac282 Fix project should not be merged into itself 2022-02-13 19:18:08 +02:00
allegroai
a08fa9a0e1 Add missing API Errors 2022-02-13 19:16:58 +02:00
allegroai
5856ed2836 Update Model.last_update on changes to tags and system tags 2022-02-13 19:15:37 +02:00
allegroai
d295355d99 Better logger name if called from __init__.py 2022-02-13 19:15:10 +02:00
pollfly
77350f6119 Fix link (#104) 2022-01-27 12:15:55 +02:00
Niels ten Boom
bc2c2ebbfd Add connection string functionality for MongoDB access (#102) 2022-01-08 12:07:59 +02:00
allegroai
1502e02a1a Update ES version to 7.16.2 2021-12-22 13:53:34 +02:00
allegroai
d0e2313a24 Update README regarding CVE-2021-45046 2021-12-15 15:51:18 +02:00
allegroai
d8ba1a8ea7 Fix README 2021-12-14 15:52:53 +02:00
allegroai
ca7937fc4e Fix README 2021-12-14 15:50:30 +02:00
allegroai
df89bcceef Update README with a note regarding Apache Log4j2 Remote Code Execution (RCE) Vulnerability - CVE-2021-44228 - ESA-2021-31 2021-12-14 15:48:54 +02:00
allegroai
cfccbe05c1 Add precautionary mitigation for Apache Log4j2 Remote Code Execution (RCE) Vulnerability - CVE-2021-44228 - ESA-2021-31 2021-12-14 15:15:11 +02:00
Théo Mathieu
e352a6a1e7 Fix elasticsearch authentication when initializing (#98) 2021-12-05 09:55:06 +02:00
Théo Mathieu
8a3d992aaf Support MongoDB SRV endpoints (#96) 2021-12-02 10:07:33 +02:00
allegroai
c37f3d8d5b Fix set() not supported in ConfigTree()
Add user/pass config support
2021-11-15 18:33:49 +02:00
allegroai
a96870e092 Add admonition in case only username or password were provided 2021-11-15 15:19:07 +02:00
allegroai
6bf1032237 Rename back to docker-compose.yml 2021-11-15 15:13:09 +02:00
Weixiao Huang
3d816c747d Add ES http_auth credentials support (#93)
Also update ES and MongoDB versions and fix nginx configuration bug

Co-authored-by: huangweixiao <huangweixiao@megvii.com>
2021-11-15 15:01:27 +02:00
Jake Henning
3f2b96266b Merge pull request #91 from valeriano-manassero/fix-dockerfile-chmod
Fix chmod for file copy in Dockerfile
2021-10-19 11:35:00 +03:00
Valeriano Manassero
22b16d12eb fix chmod for file copy 2021-10-19 09:06:51 +02:00
allegroai
c55b6f30df Add Dockerfile 2021-10-18 16:52:17 +03:00
allegroai
b7045d3d28 Fix docker-compose escaping 2021-10-18 16:49:51 +03:00
Jake Henning
e31a404885 Remove README mentions of demo server (#90) 2021-10-10 16:32:51 +03:00
Revital
643588b71a edit README mention of demo server 2021-10-10 11:27:44 +03:00
Jake Henning
a64c4d264d Merge pull request #82 from IgorKasianenko/IgorKasianenko-patch-1
Fix typo TRAINS > CLEARML for env variables in README
2021-08-12 11:47:20 +03:00
Igor Kasianenko
567780e188 Fix typo TRAINS > CLEARML for env variables 2021-08-11 16:21:02 +03:00
allegroai
1bc8529d83 Version bump 2021-08-05 16:46:29 +03:00
allegroai
6b480d7e87 Fix file server GET response for gzipped data-files contains Content-Encoding: gz header, causing clients to automatically decompress the file 2021-08-05 16:46:25 +03:00
allegroai
083fd315e9 Fix server error when running with non-migrated v0.16 ElasticSearch data 2021-08-05 16:46:05 +03:00
Jake Henning
ef20e76174 Update README with artifact.io badge 2021-07-27 19:53:41 +03:00
Jake Henning
8c8910808e Merge pull request #80 from pollfly/master
Fix README links
2021-07-27 12:58:45 +03:00
Revital
f6ad379310 link to clear.ml docs in readme, add image 2021-07-27 12:54:41 +03:00
allegroai
c5d6ce3e65 Version bump 2021-07-25 14:40:57 +03:00
allegroai
694dbc31c4 Fix incorrect ES query (merge issue) 2021-07-25 14:40:49 +03:00
allegroai
6488dc54e6 Better handling of stack trace report on 500 error 2021-07-25 14:39:59 +03:00
allegroai
158da9b480 Allow setting status_message in tasks.update
Optimizations and refactoring
2021-07-25 14:35:36 +03:00
allegroai
ec2e071ab7 Fix mongoengine cannot handle field name with leading or trailing "_" when used in fields query within get_all endpoints 2021-07-25 14:34:04 +03:00
allegroai
465e270342 Fix queued task is not dequeued on tasks.stop 2021-07-25 14:32:09 +03:00
allegroai
6705aff56f Allow requesting plots and iter_histograms for all variants 2021-07-25 14:30:38 +03:00
allegroai
9069cfe1da Support querying task events per specific metrics and variants 2021-07-25 14:29:41 +03:00
allegroai
677bb3ba6d Add force parameter to tasks.enqueue 2021-07-25 14:27:46 +03:00
allegroai
cb253cff9e Don't use special characters in secrets 2021-07-25 14:26:49 +03:00
allegroai
39ceb5ac5c Fix pre-populate logic to avoid overriding existing users 2021-07-25 14:26:31 +03:00
allegroai
d4edeaaf1b Add projects.validate_delete 2021-07-25 14:17:29 +03:00
allegroai
56aea1ffb8 Fix filtering on hyperparams (https://github.com/allegroai/clearml/issues/385, https://clearml.slack.com/archives/CTK20V944/p1626600582284700) 2021-07-25 13:55:09 +03:00
87 changed files with 3869 additions and 2151 deletions

1
.gitignore vendored
View File

@@ -12,7 +12,6 @@ test-reports
.pytest_cache
venv
*.noseids
build
*.egg-info
.cache
.mypy_cache

View File

@@ -8,28 +8,43 @@
[![GitHub license](https://img.shields.io/badge/license-SSPL-green.svg)](https://img.shields.io/badge/license-SSPL-green.svg)
[![Python versions](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)
[![GitHub version](https://img.shields.io/github/release-pre/allegroai/trains-server.svg)](https://img.shields.io/github/release-pre/allegroai/trains-server.svg)
[![Artifact Hub](https://img.shields.io/endpoint?url=https://artifacthub.io/badge/repository/allegroai)](https://artifacthub.io/packages/search?repo=allegroai)
</div>
---
<div align="center">
**v0.16 Upgrade Notice**
**Note regarding Apache Log4j2 Remote Code Execution (RCE) Vulnerability - CVE-2021-44228 - ESA-2021-31**
</div>
In v0.16, the Elasticsearch subsystem of ClearML Server has been upgraded from version 5.6 to version 7.6. This change necessitates the migration of the database contents to accommodate the change in index structure across the different versions.
According to [ElasticSearch's latest report](https://discuss.elastic.co/t/apache-log4j2-remote-code-execution-rce-vulnerability-cve-2021-44228-esa-2021-31/291476),
supported versions of Elasticsearch (6.8.9+, 7.8+) used with recent versions of the JDK (JDK9+) **are not susceptible to either remote code execution or information leakage**
due to Elasticsearchs usage of the Java Security Manager.
Follow [this procedure](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_es7_migration.html) to migrate existing data.
**As the latest version of ClearML Server uses Elasticsearch 7.10+ with JDK15, it is not affected by these vulnerabilities.**
As a precaution, we've upgraded the ES version to 7.16.2 and added the mitigation recommended by ElasticSearch to our latest [docker-compose.yml](https://github.com/allegroai/clearml-server/blob/cfccbe05c158b75e520581f86e9668291da5c70a/docker/docker-compose.yml#L42) file.
While previous Elasticsearch versions (5.6.11+, 6.4.0+ and 7.0.0+) used by older ClearML Server versions are only susceptible to the information leakage vulnerability
(which in any case **does not permit access to data within the Elasticsearch cluster**),
we still recommend upgrading to the latest version of ClearML Server. Alternatively, you can apply the mitigation as implemented in our latest
[docker-compose.yml](https://github.com/allegroai/clearml-server/blob/cfccbe05c158b75e520581f86e9668291da5c70a/docker/docker-compose.yml#L42) file.
**Update 15 December**: A further vulnerability (CVE-2021-45046) was disclosed on December 14th.
ElasticSearch's guidance for Elasticsearch remains unchanged by this new vulnerability, thus **not affecting ClearML Server**.
**Update 22 December**: To keep with ElasticSearch's recommendations, we've upgraded the ES version to the newly released 7.16.2
---
### ClearML Server
## ClearML Server
#### *Formerly known as Trains Server*
The **ClearML Server** is the backend service infrastructure for [ClearML](https://github.com/allegroai/clearml).
It allows multiple users to collaborate and manage their experiments.
By default, **ClearML** is set up to work with the **ClearML** demo server, which is open to anyone and resets periodically.
**ClearML** offers a [free hosted service](https://app.clear.ml/), which is maintained by **ClearML** and open to anyone.
In order to host your own server, you will need to launch the **ClearML Server** and point **ClearML** to it.
The **ClearML Server** contains the following components:
@@ -45,7 +60,7 @@ You can quickly [deploy](#launching-the-clearml-server) your **ClearML Server**
## System design
![Alt Text](https://allegro.ai/clearml/docs/_images/ClearML_Server_Diagram.png)
![Alt Text](docs/ClearML_Server_Diagram.png)
The **ClearML Server** has two supported configurations:
- Single IP (domain) with the following open ports
@@ -78,20 +93,19 @@ For example, to see if port `8080` is in use:
Launch The **ClearML Server** in any of the following formats:
- Pre-built [AWS EC2 AMI](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_aws_ec2_ami.html)
- Pre-built [GCP Custom Image](hhttps://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_gcp.html)
- Pre-built [AWS EC2 AMI](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_aws_ec2_ami)
- Pre-built [GCP Custom Image](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_gcp)
- Pre-built Docker Image
- [Linux](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_linux_mac.html)
- [macOS](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_linux_mac.html)
- [Windows 10](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_win.html)
- [Linux](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_linux_mac)
- [macOS](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_linux_mac)
- [Windows 10](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_win)
- Kubernetes
- [Kubernetes Helm](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_kubernetes_helm.html)
- Manual [Kubernetes installation](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_kubernetes.html)
- [Kubernetes Helm](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_kubernetes_helm)
- Manual [Kubernetes installation](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_kubernetes)
## Connecting ClearML to your ClearML Server
By default, the **ClearML** client is set up to work with the [**ClearML** demo server](https://demoapp.demo.clear.ml/).
To have the **ClearML** client use your **ClearML Server** instead:
In order to set up the **ClearML** client to work with your **ClearML Server**:
- Run the `clearml-init` command for an interactive setup.
- Or manually edit `~/clearml.conf` file, making sure the server settings (`api_server`, `web_server`, `file_server`) are configured correctly, for example:
@@ -138,8 +152,8 @@ Do not enqueue training / inference tasks into the `services` queue, as it will
The **ClearML Server** provides a few additional useful features, which can be manually enabled:
* [Web login authentication](https://allegro.ai/clearml/docs/deploying_clearml/clearml_server_config/#web-login-authentication)
* [Non-responsive experiments watchdog](https://allegro.ai/clearml/docs/deploying_clearml/clearml_server_config/#task_watchdog)
* [Web login authentication](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_config#web-login-authentication)
* [Non-responsive experiments watchdog](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_config#non-responsive-task-watchdog)
## Restarting ClearML Server
@@ -189,14 +203,14 @@ To upgrade your existing **ClearML Server** deployment:
```
1. Configure the ClearML-Agent Services (not supported on Windows installation).
If `TRAINS_HOST_IP` is not provided, ClearML-Agent Services will use the external
public address of the **ClearML Server**. If `TRAINS_AGENT_GIT_USER` / `TRAINS_AGENT_GIT_PASS` are not provided,
If `CLEARML_HOST_IP` is not provided, ClearML-Agent Services will use the external
public address of the **ClearML Server**. If `CLEARML_AGENT_GIT_USER` / `CLEARML_AGENT_GIT_PASS` are not provided,
the ClearML-Agent Services will not be able to access any private repositories for running service tasks.
```bash
export TRAINS_HOST_IP=server_host_ip_here
export TRAINS_AGENT_GIT_USER=git_username_here
export TRAINS_AGENT_GIT_PASS=git_password_here
export CLEARML_HOST_IP=server_host_ip_here
export CLEARML_AGENT_GIT_USER=git_username_here
export CLEARML_AGENT_GIT_PASS=git_password_here
```
1. Spin up the docker containers, it will automatically pull the latest **ClearML Server** build
@@ -205,12 +219,12 @@ To upgrade your existing **ClearML Server** deployment:
docker-compose -f docker-compose.yml up
```
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://allegro.ai/clearml/docs/docs/faq/faq.html).**
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://clear.ml/docs/latest/docs/faq/).**
## Community & Support
If you have any questions, look to the ClearML [FAQ](https://allegro.ai/clearml/docs/docs/faq/faq.html), or
If you have any questions, look to the ClearML [FAQ](https://clear.ml/docs/latest/docs/faq), or
tag your questions on [stackoverflow](https://stackoverflow.com/questions/tagged/clearml) with '**clearml**' tag.
For feature requests or bug reports, please use [GitHub issues](https://github.com/allegroai/clearml-server/issues).

View File

@@ -71,6 +71,8 @@
408: ["cannot_update_project_location", "Cannot update project location. Use projects.move instead"]
409: ["project_path_exceeds_max", "Project path exceed the maximum allowed depth"]
410: ["project_source_and_destination_are_the_same", "Project has the same source and destination paths"]
411: ["project_cannot_be_moved_under_itself", "Project can not be moved under itself in the projects hierarchy"]
412: ["project_cannot_be_merged_into_its_child", "Project can not be merged into its own child"]
# Queues
701: ["invalid_queue_id", "invalid queue id"]

View File

@@ -75,6 +75,7 @@ class CreateUserResponse(Base):
class Credentials(Base):
access_key = StringField(required=True)
secret_key = StringField(required=True)
label = StringField()
class CredentialsResponse(Credentials):
@@ -82,6 +83,10 @@ class CredentialsResponse(Credentials):
last_used = DateTimeField(default=None)
class CreateCredentialsRequest(Base):
label = StringField()
class CreateCredentialsResponse(Base):
credentials = EmbeddedField(Credentials)

View File

@@ -2,7 +2,7 @@ from enum import auto
from typing import Sequence, Optional
from jsonmodels import validators
from jsonmodels.fields import StringField, BoolField
from jsonmodels.fields import StringField, BoolField, EmbeddedField
from jsonmodels.models import Base
from jsonmodels.validators import Length, Min, Max
@@ -14,12 +14,18 @@ from apiserver.utilities.stringenum import StringEnum
class HistogramRequestBase(Base):
samples: int = IntField(default=6000, validators=[Min(1), Max(6000)])
samples: int = IntField(default=2000, validators=[Min(1), Max(6000)])
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
class MetricVariants(Base):
metric: str = StringField(required=True)
variants: Sequence[str] = ListField(items_types=str)
class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
task: str = StringField(required=True)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
@@ -39,6 +45,7 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
class TaskMetric(Base):
task: str = StringField(required=True)
metric: str = StringField(default=None)
variants: Sequence[str] = ListField(items_types=str)
class DebugImagesRequest(Base):
@@ -59,8 +66,8 @@ class TaskMetricVariant(Base):
class GetDebugImageSampleRequest(TaskMetricVariant):
iteration: Optional[int] = IntField()
scroll_id: Optional[str] = StringField()
refresh: bool = BoolField(default=False)
scroll_id: Optional[str] = StringField()
class NextDebugImageSampleRequest(Base):
@@ -74,14 +81,34 @@ class LogOrderEnum(StringEnum):
desc = auto()
class LogEventsRequest(Base):
class TaskEventsRequestBase(Base):
task: str = StringField(required=True)
batch_size: int = IntField(default=500)
class TaskEventsRequest(TaskEventsRequestBase):
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
event_type: EventType = ActualEnumField(EventType, default=EventType.all)
order: Optional[str] = ActualEnumField(LogOrderEnum, default=LogOrderEnum.asc)
scroll_id: str = StringField()
count_total: bool = BoolField(default=True)
class LogEventsRequest(TaskEventsRequestBase):
batch_size: int = IntField(default=5000)
navigate_earlier: bool = BoolField(default=True)
from_timestamp: Optional[int] = IntField()
order: Optional[str] = ActualEnumField(LogOrderEnum)
class ScalarMetricsIterRawRequest(TaskEventsRequestBase):
batch_size: int = IntField()
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
metric: MetricVariants = EmbeddedField(MetricVariants, required=True)
count_total: bool = BoolField(default=False)
scroll_id: str = StringField()
class IterationEvents(Base):
iter: int = IntField()
events: Sequence[dict] = ListField(items_types=dict)
@@ -102,3 +129,11 @@ class TaskMetricsRequest(Base):
items_types=str, validators=[Length(minimum_value=1)]
)
event_type: EventType = ActualEnumField(EventType, required=True)
class TaskPlotsRequest(Base):
task: str = StringField(required=True)
iters: int = IntField(default=1)
scroll_id: str = StringField()
no_scroll: bool = BoolField(default=False)
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)

View File

@@ -27,7 +27,7 @@ class ProjectOrNoneRequest(models.Base):
include_subprojects = fields.BoolField(default=True)
class GetHyperParamRequest(ProjectOrNoneRequest):
class GetParamsRequest(ProjectOrNoneRequest):
page = fields.IntField(default=0)
page_size = fields.IntField(default=500)
@@ -53,6 +53,7 @@ class ProjectHyperparamValuesRequest(MultiProjectRequest):
class ProjectsGetRequest(models.Base):
include_stats = fields.BoolField(default=False)
stats_with_children = fields.BoolField(default=True)
stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active)
non_public = fields.BoolField(default=False)
active_users = fields.ListField(str)

View File

@@ -2,7 +2,11 @@ from datetime import datetime
from apiserver import database
from apiserver.apierrors import errors
from apiserver.apimodels.auth import GetTokenResponse, CreateUserRequest, Credentials as CredModel
from apiserver.apimodels.auth import (
GetTokenResponse,
CreateUserRequest,
Credentials as CredModel,
)
from apiserver.apimodels.users import CreateRequest as Users_CreateRequest
from apiserver.bll.user import UserBLL
from apiserver.config_repo import config
@@ -57,6 +61,7 @@ class AuthBLL:
api_version=str(ServiceRepo.max_endpoint_version()),
server_version=str(get_version()),
server_build=str(get_build_number()),
feature_set="basic",
)
return GetTokenResponse(token=token.decode("ascii"))
@@ -144,7 +149,7 @@ class AuthBLL:
@classmethod
def create_credentials(
cls, user_id: str, company_id: str, role: str = None
cls, user_id: str, company_id: str, role: str = None, label: str = None,
) -> CredModel:
with translate_errors_context():
@@ -153,7 +158,9 @@ class AuthBLL:
if not user:
raise errors.bad_request.InvalidUserId(**query)
cred = CredModel(access_key=get_client_id(), secret_key=get_secret_key())
cred = CredModel(
access_key=get_client_id(), secret_key=get_secret_key(), label=label
)
user.credentials.append(
Credentials(key=cred.access_key, secret=cred.secret_key)
)

View File

@@ -2,7 +2,7 @@ from concurrent.futures.thread import ThreadPoolExecutor
from datetime import datetime
from functools import partial
from operator import itemgetter
from typing import Sequence, Tuple, Optional, Mapping, Set
from typing import Sequence, Tuple, Optional, Mapping
import attr
import dpath
@@ -18,6 +18,7 @@ from apiserver.bll.event.event_common import (
check_empty_data,
search_company_events,
EventType,
get_metric_variants_condition,
)
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.database.errors import translate_errors_context
@@ -74,7 +75,7 @@ class DebugImagesIterator:
def get_task_events(
self,
company_id: str,
task_metrics: Mapping[str, Set[str]],
task_metrics: Mapping[str, dict],
iter_count: int,
navigate_earlier: bool = True,
refresh: bool = False,
@@ -118,7 +119,7 @@ class DebugImagesIterator:
self,
company_id,
state: DebugImageEventsScrollState,
task_metrics: Mapping[str, Set[str]],
task_metrics: Mapping[str, dict],
):
"""
Determine the metrics for which new debug image events were added
@@ -158,11 +159,11 @@ class DebugImagesIterator:
task_metrics_to_recalc = {}
for task, metrics_times in update_times.items():
old_metric_states = task_metric_states[task]
metrics_to_recalc = set(
m
metrics_to_recalc = {
m: task_metrics[task].get(m)
for m, t in metrics_times.items()
if m not in old_metric_states or old_metric_states[m].timestamp < t
)
}
if metrics_to_recalc:
task_metrics_to_recalc[task] = metrics_to_recalc
@@ -196,7 +197,7 @@ class DebugImagesIterator:
]
def _init_task_states(
self, company_id: str, task_metrics: Mapping[str, Set[str]]
self, company_id: str, task_metrics: Mapping[str, dict]
) -> Sequence[TaskScrollState]:
"""
Returned initialized metric scroll stated for the requested task metrics
@@ -213,7 +214,7 @@ class DebugImagesIterator:
]
def _init_metric_states_for_task(
self, task_metrics: Tuple[str, Set[str]], company_id: str
self, task_metrics: Tuple[str, dict], company_id: str
) -> Sequence[MetricState]:
"""
Return metric scroll states for the task filled with the variant states
@@ -222,10 +223,11 @@ class DebugImagesIterator:
task, metrics = task_metrics
must = [{"term": {"task": task}}, {"exists": {"field": "url"}}]
if metrics:
must.append({"terms": {"metric": list(metrics)}})
must.append(get_metric_variants_condition(metrics))
query = {"bool": {"must": must}}
es_req: dict = {
"size": 0,
"query": {"bool": {"must": must}},
"query": query,
"aggs": {
"metrics": {
"terms": {

View File

@@ -6,9 +6,8 @@ from collections import defaultdict
from contextlib import closing
from datetime import datetime
from operator import attrgetter
from typing import Sequence, Set, Tuple, Optional, Dict
from typing import Sequence, Set, Tuple, Optional, List, Mapping, Union
import six
from elasticsearch import helpers
from elasticsearch.helpers import BulkIndexError
from mongoengine import Q
@@ -22,14 +21,16 @@ from apiserver.bll.event.event_common import (
check_empty_data,
search_company_events,
delete_company_events,
MetricVariants,
get_metric_variants_condition,
)
from apiserver.bll.event.events_iterator import EventsIterator, TaskEventsResult
from apiserver.bll.util import parallel_chunked_decorator
from apiserver.database import utils as dbutils
from apiserver.es_factory import es_factory
from apiserver.apierrors import errors
from apiserver.bll.event.debug_images_iterator import DebugImagesIterator
from apiserver.bll.event.event_metrics import EventMetrics
from apiserver.bll.event.log_events_iterator import LogEventsIterator, TaskEventsResult
from apiserver.bll.task import TaskBLL
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
@@ -43,8 +44,8 @@ from apiserver.utilities.json import loads
# noinspection PyTypeChecker
EVENT_TYPES: Set[str] = set(map(attrgetter("value"), EventType))
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
MAX_LONG = 2**63 - 1
MIN_LONG = -2**63
MAX_LONG = 2 ** 63 - 1
MIN_LONG = -(2 ** 63)
class PlotFields:
@@ -72,7 +73,7 @@ class EventBLL(object):
self.redis = redis or redman.connection("apiserver")
self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
self.debug_sample_history = DebugSampleHistory(es=self.es, redis=self.redis)
self.log_events_iterator = LogEventsIterator(es=self.es)
self.events_iterator = EventsIterator(es=self.es)
@property
def metrics(self) -> EventMetrics:
@@ -94,7 +95,7 @@ class EventBLL(object):
def add_events(
self, company_id, events, worker, allow_locked_tasks=False
) -> Tuple[int, int, dict]:
actions = []
actions: List[dict] = []
task_ids = set()
task_iteration = defaultdict(lambda: 0)
task_last_scalar_events = nested_dict(
@@ -197,7 +198,6 @@ class EventBLL(object):
actions.append(es_action)
action: Dict[dict]
plot_actions = [
action["_source"]
for action in actions
@@ -260,7 +260,8 @@ class EventBLL(object):
invalid_iterations_count = errors_per_type.get(invalid_iteration_error)
if invalid_iterations_count:
raise BulkIndexError(
f"{invalid_iterations_count} document(s) failed to index.", [invalid_iteration_error]
f"{invalid_iterations_count} document(s) failed to index.",
[invalid_iteration_error],
)
if not added:
@@ -466,10 +467,16 @@ class EventBLL(object):
task_id: str,
num_last_iterations: int,
event_type: EventType,
metric_variants: MetricVariants = None,
):
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return []
must = [{"term": {"task": task_id}}]
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
query = {"bool": {"must": must}}
es_req: dict = {
"size": 0,
"aggs": {
@@ -499,7 +506,7 @@ class EventBLL(object):
},
}
},
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
"query": query,
}
with translate_errors_context(), TimingContext(
@@ -527,6 +534,8 @@ class EventBLL(object):
sort=None,
size: int = 500,
scroll_id: str = None,
no_scroll: bool = False,
metric_variants: MetricVariants = None,
):
if scroll_id == self.empty_scroll:
return TaskEventsResult()
@@ -555,6 +564,8 @@ class EventBLL(object):
if last_iterations_per_plot is None:
must.append({"terms": {"task": tasks}})
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
else:
should = []
for i, task_id in enumerate(tasks):
@@ -563,6 +574,7 @@ class EventBLL(object):
task_id=task_id,
num_last_iterations=last_iterations_per_plot,
event_type=event_type,
metric_variants=metric_variants,
)
if not last_iters:
continue
@@ -600,7 +612,7 @@ class EventBLL(object):
event_type=event_type,
body=es_req,
ignore=404,
scroll="1h",
**({} if no_scroll else {"scroll": "1h"}),
)
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
@@ -669,19 +681,20 @@ class EventBLL(object):
sort=None,
size=500,
scroll_id=None,
):
no_scroll=False,
) -> TaskEventsResult:
if scroll_id == self.empty_scroll:
return [], scroll_id, 0
return TaskEventsResult()
if scroll_id:
with translate_errors_context(), TimingContext("es", "get_task_events"):
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
else:
task_ids = [task_id] if isinstance(task_id, six.string_types) else task_id
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return TaskEventsResult()
task_ids = [task_id] if isinstance(task_id, str) else task_id
must = []
if metric:
must.append({"term": {"metric": metric}})
@@ -691,26 +704,24 @@ class EventBLL(object):
if last_iter_count is None:
must.append({"terms": {"task": task_ids}})
else:
should = []
for i, task_id in enumerate(task_ids):
last_iters = self.get_last_iters(
company_id=company_id,
event_type=event_type,
task_id=task_id,
iters=last_iter_count,
)
if not last_iters:
continue
should.append(
{
"bool": {
"must": [
{"term": {"task": task_id}},
{"terms": {"iter": last_iters}},
]
}
tasks_iters = self.get_last_iters(
company_id=company_id,
event_type=event_type,
task_id=task_ids,
iters=last_iter_count,
)
should = [
{
"bool": {
"must": [
{"term": {"task": task}},
{"terms": {"iter": last_iters}},
]
}
)
}
for task, last_iters in tasks_iters.items()
if last_iters
]
if not should:
return TaskEventsResult()
must.append({"bool": {"should": should}})
@@ -731,7 +742,7 @@ class EventBLL(object):
event_type=event_type,
body=es_req,
ignore=404,
scroll="1h",
**({} if no_scroll else {"scroll": "1h"}),
)
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
@@ -748,6 +759,7 @@ class EventBLL(object):
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return {}
query = {"bool": {"must": [{"term": {"task": task_id}}]}}
es_req = {
"size": 0,
"aggs": {
@@ -768,7 +780,7 @@ class EventBLL(object):
},
}
},
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
"query": query,
}
with translate_errors_context(), TimingContext(
@@ -787,21 +799,24 @@ class EventBLL(object):
return metrics
def get_task_latest_scalar_values(self, company_id: str, task_id: str):
def get_task_latest_scalar_values(
self, company_id, task_id
) -> Tuple[Sequence[dict], int]:
event_type = EventType.metrics_scalar
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return {}
return [], 0
query = {
"bool": {
"must": [
{"query_string": {"query": "value:>0"}},
{"term": {"task": task_id}},
]
}
}
es_req = {
"size": 0,
"query": {
"bool": {
"must": [
{"query_string": {"query": "value:>0"}},
{"term": {"task": task_id}},
]
}
},
"query": query,
"aggs": {
"metrics": {
"terms": {
@@ -905,34 +920,47 @@ class EventBLL(object):
return iterations, vectors
def get_last_iters(
self, company_id: str, event_type: EventType, task_id: str, iters: int
):
self,
company_id: str,
event_type: EventType,
task_id: Union[str, Sequence[str]],
iters: int,
) -> Mapping[str, Sequence]:
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
return []
return {}
task_ids = [task_id] if isinstance(task_id, str) else task_id
es_req: dict = {
"size": 0,
"aggs": {
"iters": {
"terms": {
"field": "iter",
"size": iters,
"order": {"_key": "desc"},
}
"tasks": {
"terms": {"field": "task"},
"aggs": {
"iters": {
"terms": {
"field": "iter",
"size": iters,
"order": {"_key": "desc"},
}
}
},
}
},
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
"query": {"bool": {"must": [{"terms": {"task": task_ids}}]}},
}
with translate_errors_context(), TimingContext("es", "task_last_iter"):
es_res = search_company_events(
self.es, company_id=company_id, event_type=event_type, body=es_req
self.es, company_id=company_id, event_type=event_type, body=es_req,
)
if "aggregations" not in es_res:
return []
return {}
return [b["key"] for b in es_res["aggregations"]["iters"]["buckets"]]
return {
tb["key"]: [ib["key"] for ib in tb["iters"]["buckets"]]
for tb in es_res["aggregations"]["tasks"]["buckets"]
}
def delete_task_events(self, company_id, task_id, allow_locked=False):
with translate_errors_context():
@@ -965,7 +993,9 @@ class EventBLL(object):
so it should be checked by the calling code
"""
es_req = {"query": {"terms": {"task": task_ids}}}
with translate_errors_context(), TimingContext("es", "delete_multi_tasks_events"):
with translate_errors_context(), TimingContext(
"es", "delete_multi_tasks_events"
):
es_res = delete_company_events(
es=self.es,
company_id=company_id,

View File

@@ -1,5 +1,5 @@
from enum import Enum
from typing import Union, Sequence
from typing import Union, Sequence, Mapping
from boltons.typeutils import classproperty
from elasticsearch import Elasticsearch
@@ -16,6 +16,9 @@ class EventType(Enum):
all = "*"
MetricVariants = Mapping[str, Sequence[str]]
class EventSettings:
@classproperty
def max_workers(self):
@@ -63,4 +66,31 @@ def delete_company_events(
es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
) -> dict:
es_index = get_index_name(company_id, event_type.value)
return es.delete_by_query(index=es_index, body=body, **kwargs)
return es.delete_by_query(
index=es_index, body=body, conflicts="proceed", **kwargs
)
def count_company_events(
es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
) -> dict:
es_index = get_index_name(company_id, event_type.value)
return es.count(index=es_index, body=body, **kwargs)
def get_metric_variants_condition(metric_variants: MetricVariants,) -> Sequence:
conditions = [
{
"bool": {
"must": [
{"term": {"metric": metric}},
{"terms": {"variant": variants}},
]
}
}
if variants
else {"term": {"metric": metric}}
for metric, variants in metric_variants.items()
]
return {"bool": {"should": conditions}}

View File

@@ -15,6 +15,8 @@ from apiserver.bll.event.event_common import (
EventSettings,
search_company_events,
check_empty_data,
MetricVariants,
get_metric_variants_condition,
)
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
from apiserver.config_repo import config
@@ -34,7 +36,12 @@ class EventMetrics:
self.es = es
def get_scalar_metrics_average_per_iter(
self, company_id: str, task_id: str, samples: int, key: ScalarKeyEnum
self,
company_id: str,
task_id: str,
samples: int,
key: ScalarKeyEnum,
metric_variants: MetricVariants = None,
) -> dict:
"""
Get scalar metric histogram per metric and variant
@@ -46,7 +53,12 @@ class EventMetrics:
return {}
return self._get_scalar_average_per_iter_core(
task_id, company_id, event_type, samples, ScalarKey.resolve(key)
task_id=task_id,
company_id=company_id,
event_type=event_type,
samples=samples,
key=ScalarKey.resolve(key),
metric_variants=metric_variants,
)
def _get_scalar_average_per_iter_core(
@@ -57,6 +69,7 @@ class EventMetrics:
samples: int,
key: ScalarKey,
run_parallel: bool = True,
metric_variants: MetricVariants = None,
) -> dict:
intervals = self._get_task_metric_intervals(
company_id=company_id,
@@ -64,6 +77,7 @@ class EventMetrics:
task_id=task_id,
samples=samples,
field=key.field,
metric_variants=metric_variants,
)
if not intervals:
return {}
@@ -197,6 +211,7 @@ class EventMetrics:
task_id: str,
samples: int,
field: str = "iter",
metric_variants: MetricVariants = None,
) -> Sequence[MetricInterval]:
"""
Calculate interval per task metric variant so that the resulting
@@ -204,9 +219,14 @@ class EventMetrics:
Return the list og metric variant intervals as the following tuple:
(metric, variant, interval, samples)
"""
must = [{"term": {"task": task_id}}]
if metric_variants:
must.append(get_metric_variants_condition(metric_variants))
query = {"bool": {"must": must}}
es_req = {
"size": 0,
"query": {"term": {"task": task_id}},
"query": query,
"aggs": {
"metrics": {
"terms": {

View File

@@ -0,0 +1,205 @@
from typing import Optional, Tuple, Sequence, Any
import attr
import jsonmodels.models
import jwt
from elasticsearch import Elasticsearch
from apiserver.bll.event.event_common import (
check_empty_data,
search_company_events,
EventType,
MetricVariants,
get_metric_variants_condition,
count_company_events,
)
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
@attr.s(auto_attribs=True)
class TaskEventsResult:
total_events: int = 0
next_scroll_id: str = None
events: list = attr.Factory(list)
class EventsIterator:
def __init__(self, es: Elasticsearch):
self.es = es
def get_task_events(
self,
event_type: EventType,
company_id: str,
task_id: str,
batch_size: int,
navigate_earlier: bool = True,
from_key_value: Optional[Any] = None,
metric_variants: MetricVariants = None,
key: ScalarKeyEnum = ScalarKeyEnum.timestamp,
**kwargs,
) -> TaskEventsResult:
if check_empty_data(self.es, company_id, event_type):
return TaskEventsResult()
from_key_value = kwargs.pop("from_timestamp", from_key_value)
res = TaskEventsResult()
res.events, res.total_events = self._get_events(
event_type=event_type,
company_id=company_id,
task_id=task_id,
batch_size=batch_size,
navigate_earlier=navigate_earlier,
from_key_value=from_key_value,
metric_variants=metric_variants,
key=ScalarKey.resolve(key),
)
return res
def count_task_events(
self,
event_type: EventType,
company_id: str,
task_id: str,
metric_variants: MetricVariants = None,
) -> int:
query, _ = self._get_initial_query_and_must(task_id, metric_variants)
es_req = {
"query": query,
}
with translate_errors_context(), TimingContext("es", "count_task_events"):
es_result = count_company_events(
self.es,
company_id=company_id,
event_type=event_type,
body=es_req,
routing=task_id,
)
return es_result["count"]
def _get_events(
self,
event_type: EventType,
company_id: str,
task_id: str,
batch_size: int,
navigate_earlier: bool,
key: ScalarKey,
from_key_value: Optional[Any],
metric_variants: MetricVariants = None,
) -> Tuple[Sequence[dict], int]:
"""
Return up to 'batch size' events starting from the previous key-field value (timestamp or iter) either in the
direction of earlier events (navigate_earlier=True) or in the direction of later events.
If from_key_field is not set then start either from latest or earliest.
For the last key-field value all the events are brought (even if the resulting size exceeds batch_size)
so that events with this value will not be lost between the calls.
"""
query, must = self._get_initial_query_and_must(task_id, metric_variants)
# retrieve the next batch of events
es_req = {
"size": batch_size,
"query": query,
"sort": {key.field: "desc" if navigate_earlier else "asc"},
}
if from_key_value:
es_req["search_after"] = [from_key_value]
with translate_errors_context(), TimingContext("es", "get_task_events"):
es_result = search_company_events(
self.es,
company_id=company_id,
event_type=event_type,
body=es_req,
routing=task_id,
)
hits = es_result["hits"]["hits"]
hits_total = es_result["hits"]["total"]["value"]
if not hits:
return [], hits_total
events = [hit["_source"] for hit in hits]
# retrieve the events that match the last event timestamp
# but did not make it into the previous call due to batch_size limitation
es_req = {
"size": 10000,
"query": {
"bool": {
"must": must + [{"term": {key.field: events[-1][key.field]}}]
}
},
}
es_result = search_company_events(
self.es,
company_id=company_id,
event_type=event_type,
body=es_req,
routing=task_id,
)
last_second_hits = es_result["hits"]["hits"]
if not last_second_hits or len(last_second_hits) < 2:
# if only one element is returned for the last timestamp
# then it is already present in the events
return events, hits_total
already_present_ids = set(hit["_id"] for hit in hits)
last_second_events = [
hit["_source"]
for hit in last_second_hits
if hit["_id"] not in already_present_ids
]
# return the list merged from original query results +
# leftovers from the last timestamp
return (
[*events, *last_second_events],
hits_total,
)
@staticmethod
def _get_initial_query_and_must(
task_id: str, metric_variants: MetricVariants = None
) -> Tuple[dict, list]:
if not metric_variants:
must = [{"term": {"task": task_id}}]
query = {"term": {"task": task_id}}
else:
must = [
{"term": {"task": task_id}},
get_metric_variants_condition(metric_variants),
]
query = {"bool": {"must": must}}
return query, must
class Scroll(jsonmodels.models.Base):
def get_scroll_id(self) -> str:
return jwt.encode(
self.to_struct(),
key=config.get(
"services.events.events_retrieval.scroll_id_key", "1234567890"
),
).decode()
@classmethod
def from_scroll_id(cls, scroll_id: str):
try:
return cls(
**jwt.decode(
scroll_id,
key=config.get(
"services.events.events_retrieval.scroll_id_key", "1234567890"
),
)
)
except jwt.PyJWTError:
raise ValueError("Invalid Scroll ID")

View File

@@ -1,127 +0,0 @@
from typing import Optional, Tuple, Sequence
import attr
from elasticsearch import Elasticsearch
from apiserver.bll.event.event_common import (
check_empty_data,
search_company_events,
EventType,
)
from apiserver.database.errors import translate_errors_context
from apiserver.timing_context import TimingContext
@attr.s(auto_attribs=True)
class TaskEventsResult:
total_events: int = 0
next_scroll_id: str = None
events: list = attr.Factory(list)
class LogEventsIterator:
EVENT_TYPE = EventType.task_log
def __init__(self, es: Elasticsearch):
self.es = es
def get_task_events(
self,
company_id: str,
task_id: str,
batch_size: int,
navigate_earlier: bool = True,
from_timestamp: Optional[int] = None,
) -> TaskEventsResult:
if check_empty_data(self.es, company_id, self.EVENT_TYPE):
return TaskEventsResult()
res = TaskEventsResult()
res.events, res.total_events = self._get_events(
company_id=company_id,
task_id=task_id,
batch_size=batch_size,
navigate_earlier=navigate_earlier,
from_timestamp=from_timestamp,
)
return res
def _get_events(
self,
company_id: str,
task_id: str,
batch_size: int,
navigate_earlier: bool,
from_timestamp: Optional[int],
) -> Tuple[Sequence[dict], int]:
"""
Return up to 'batch size' events starting from the previous timestamp either in the
direction of earlier events (navigate_earlier=True) or in the direction of later events.
If last_min_timestamp and last_max_timestamp are not set then start either from latest or earliest.
For the last timestamp all the events are brought (even if the resulting size
exceeds batch_size) so that this timestamp events will not be lost between the calls.
In case any events were received update 'last_min_timestamp' and 'last_max_timestamp'
"""
# retrieve the next batch of events
es_req = {
"size": batch_size,
"query": {"term": {"task": task_id}},
"sort": {"timestamp": "desc" if navigate_earlier else "asc"},
}
if from_timestamp:
es_req["search_after"] = [from_timestamp]
with translate_errors_context(), TimingContext("es", "get_task_events"):
es_result = search_company_events(
self.es,
company_id=company_id,
event_type=self.EVENT_TYPE,
body=es_req,
)
hits = es_result["hits"]["hits"]
hits_total = es_result["hits"]["total"]["value"]
if not hits:
return [], hits_total
events = [hit["_source"] for hit in hits]
# retrieve the events that match the last event timestamp
# but did not make it into the previous call due to batch_size limitation
es_req = {
"size": 10000,
"query": {
"bool": {
"must": [
{"term": {"task": task_id}},
{"term": {"timestamp": events[-1]["timestamp"]}},
]
}
},
}
es_result = search_company_events(
self.es,
company_id=company_id,
event_type=self.EVENT_TYPE,
body=es_req,
)
last_second_hits = es_result["hits"]["hits"]
if not last_second_hits or len(last_second_hits) < 2:
# if only one element is returned for the last timestamp
# then it is already present in the events
return events, hits_total
already_present_ids = set(hit["_id"] for hit in hits)
last_second_events = [
hit["_source"]
for hit in last_second_hits
if hit["_id"] not in already_present_ids
]
# return the list merged from original query results +
# leftovers from the last timestamp
return (
[*events, *last_second_events],
hits_total,
)

View File

@@ -4,6 +4,8 @@ Module for polymorphism over different types of X axes in scalar aggregations
from abc import ABC, abstractmethod
from enum import auto
from typing import Any
from apiserver.utilities import extract_properties_to_lists
from apiserver.utilities.stringenum import StringEnum
from apiserver.config_repo import config
@@ -96,6 +98,10 @@ class ScalarKey(ABC):
"""
return int(iter_data[self.bucket_key_key]), iter_data["avg_val"]["value"]
def cast_value(self, value: Any) -> Any:
"""Cast value to appropriate type"""
return value
class TimestampKey(ScalarKey):
"""
@@ -117,6 +123,9 @@ class TimestampKey(ScalarKey):
}
}
def cast_value(self, value: Any) -> int:
return int(value)
class IterKey(ScalarKey):
"""
@@ -134,6 +143,9 @@ class IterKey(ScalarKey):
}
}
def cast_value(self, value: Any) -> int:
return int(value)
class ISOTimeKey(ScalarKey):
"""

View File

@@ -1,2 +1,3 @@
from .project_bll import ProjectBLL
from .project_queries import ProjectQueries
from .sub_projects import _ids_with_children as project_ids_with_children

View File

@@ -1,6 +1,6 @@
import itertools
from collections import defaultdict
from datetime import datetime
from datetime import datetime, timedelta
from functools import reduce
from itertools import groupby
from operator import itemgetter
@@ -57,10 +57,14 @@ class ProjectBLL:
with TimingContext("mongo", "move_project"):
if source_id == destination_id:
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
parent=source_id
source=source_id
)
source = Project.get(company, source_id)
destination = Project.get(company, destination_id)
if source_id in destination.path:
raise errors.bad_request.ProjectCannotBeMergedIntoItsChild(
source=source_id, destination=destination_id
)
children = _get_sub_projects(
[source.id], _only=("id", "name", "parent", "path")
@@ -140,7 +144,14 @@ class ProjectBLL:
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
location=new_parent.name if new_parent else ""
)
if (
new_parent
and project_id == new_parent.id
or project_id in new_parent.path
):
raise errors.bad_request.ProjectCannotBeMovedUnderItself(
project=project_id, parent=new_parent.id
)
moved = _reposition_project_with_children(
project, children=children, parent=new_parent
)
@@ -295,6 +306,7 @@ class ProjectBLL:
return project
archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}
visibility_states = [EntityVisibility.archived, EntityVisibility.active]
@classmethod
def make_projects_get_all_pipelines(
@@ -356,6 +368,26 @@ class ProjectBLL:
},
]
def completed_after_subquery(additional_cond, time_thresh: datetime):
return {
# the sum of
"$sum": {
# for each task
"$cond": {
# if completed after the time_thresh
"if": {
"$and": [
"$completed",
{"$gt": ["$completed", time_thresh]},
additional_cond,
]
},
"then": 1,
"else": 0,
}
}
}
def runtime_subquery(additional_cond):
return {
# the sum of
@@ -386,16 +418,19 @@ class ProjectBLL:
}
group_step = {"_id": "$project"}
for state in EntityVisibility:
time_thresh = datetime.utcnow() - timedelta(hours=24)
for state in cls.visibility_states:
if specific_state and state != specific_state:
continue
if state == EntityVisibility.active:
group_step[state.value] = runtime_subquery(
{"$not": cls.archived_tasks_cond}
)
elif state == EntityVisibility.archived:
group_step[state.value] = runtime_subquery(cls.archived_tasks_cond)
cond = (
cls.archived_tasks_cond
if state == EntityVisibility.archived
else {"$not": cls.archived_tasks_cond}
)
group_step[state.value] = runtime_subquery(cond)
group_step[f"{state.value}_recently_completed"] = completed_after_subquery(
cond, time_thresh=time_thresh
)
runtime_pipeline = [
# only count run time for these types of tasks
@@ -445,11 +480,16 @@ class ProjectBLL:
company: str,
project_ids: Sequence[str],
specific_state: Optional[EntityVisibility] = None,
include_children: bool = True,
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
if not project_ids:
return {}, {}
child_projects = _get_sub_projects(project_ids, _only=("id", "name"))
child_projects = (
_get_sub_projects(project_ids, _only=("id", "name"))
if include_children
else {}
)
project_ids_with_children = set(project_ids) | {
c.id for c in itertools.chain.from_iterable(child_projects.values())
}
@@ -483,8 +523,8 @@ class ProjectBLL:
) -> Dict[str, dict]:
return {
section: {
status: nested_get(a, (section, status), 0)
+ nested_get(b, (section, status), 0)
status: nested_get(a, (section, status), default=0)
+ nested_get(b, (section, status), default=0)
for status in set(a.get(section, {})) | set(b.get(section, {}))
}
for section in set(a) | set(b)
@@ -518,15 +558,24 @@ class ProjectBLL:
)
def get_status_counts(project_id, section):
project_runtime = runtime.get(project_id, {})
project_section_statuses = nested_get(
status_count, (project_id, section), default=default_counts
)
return {
"total_runtime": nested_get(runtime, (project_id, section), 0),
"status_count": nested_get(
status_count, (project_id, section), default_counts
"status_count": project_section_statuses,
"running_tasks": project_section_statuses.get(TaskStatus.in_progress),
"total_tasks": sum(project_section_statuses.values()),
"total_runtime": project_runtime.get(section, 0),
"completed_tasks": project_runtime.get(
f"{section}_recently_completed", 0
),
}
report_for_states = [
s for s in EntityVisibility if not specific_state or specific_state == s
s
for s in cls.visibility_states
if not specific_state or specific_state == s
]
stats = {
@@ -554,7 +603,7 @@ class ProjectBLL:
user_ids: Optional[Sequence[str]] = None,
) -> Set[str]:
"""
Get the set of user ids that created tasks/models/dataviews in the given projects
Get the set of user ids that created tasks/models in the given projects
If project_ids is empty then all projects are examined
If user_ids are passed then only subset of these users is returned
"""
@@ -676,8 +725,8 @@ class ProjectBLL:
@classmethod
def calc_own_contents(cls, company: str, project_ids: Sequence[str]) -> Dict[str, dict]:
"""
Returns the amount of task/dataviews/models per requested project
Use separate aggregation calls on Task/Dataview/Model instead of lookup
Returns the amount of task/models per requested project
Use separate aggregation calls on Task/Model instead of lookup
aggregation on projects in order not to hit memory limits on large tasks
"""
if not project_ids:

View File

@@ -30,6 +30,28 @@ class DeleteProjectResult:
urls: TaskUrls = None
def validate_project_delete(company: str, project_id: str):
project = Project.get_for_writing(
company=company, id=project_id, _only=("id", "path")
)
if not project:
raise errors.bad_request.InvalidProjectId(id=project_id)
project_ids = _ids_with_children([project_id])
ret = {}
for cls in (Task, Model):
ret[f"{cls.__name__.lower()}s"] = cls.objects(
project__in=project_ids,
).count()
for cls in (Task, Model):
ret[f"non_archived_{cls.__name__.lower()}s"] = cls.objects(
project__in=project_ids,
system_tags__nin=[EntityVisibility.archived.value],
).count()
return ret
def delete_project(
company: str, project_id: str, force: bool, delete_contents: bool
) -> Tuple[DeleteProjectResult, Set[str]]:

View File

@@ -0,0 +1,241 @@
import json
from collections import OrderedDict
from datetime import datetime, timedelta
from typing import (
Sequence,
Optional,
Tuple,
)
from redis import StrictRedis
from apiserver.config_repo import config
from apiserver.database.model.task.task import Task
from apiserver.redis_manager import redman
from apiserver.utilities.dicts import nested_get
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
from .sub_projects import _ids_with_children
log = config.logger(__file__)
class ProjectQueries:
def __init__(self, redis=None):
self.redis: StrictRedis = redis or redman.connection("apiserver")
@staticmethod
def _get_project_constraint(
project_ids: Sequence[str], include_subprojects: bool
) -> dict:
if include_subprojects:
if project_ids is None:
return {}
project_ids = _ids_with_children(project_ids)
return {"project": {"$in": project_ids if project_ids is not None else [None]}}
@staticmethod
def _get_company_constraint(company_id: str, allow_public: bool = True) -> dict:
if allow_public:
return {"company": {"$in": [None, "", company_id]}}
return {"company": company_id}
@classmethod
def get_aggregated_project_parameters(
cls,
company_id,
project_ids: Sequence[str],
include_subprojects: bool,
page: int = 0,
page_size: int = 500,
) -> Tuple[int, int, Sequence[dict]]:
page = max(0, page)
page_size = max(1, page_size)
pipeline = [
{
"$match": {
**cls._get_company_constraint(company_id),
**cls._get_project_constraint(project_ids, include_subprojects),
"hyperparams": {"$exists": True, "$gt": {}},
}
},
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
{"$unwind": "$sections"},
{
"$project": {
"section": "$sections.k",
"names": {"$objectToArray": "$sections.v"},
}
},
{"$unwind": "$names"},
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
{"$skip": page * page_size},
{"$limit": page_size},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT"},
}
},
]
result = next(Task.aggregate(pipeline), None)
total = 0
remaining = 0
results = []
if result:
total = int(result.get("total", -1))
results = [
{
"section": ParameterKeyEscaper.unescape(
nested_get(r, ("_id", "section"))
),
"name": ParameterKeyEscaper.unescape(
nested_get(r, ("_id", "name"))
),
}
for r in result.get("results", [])
]
remaining = max(0, total - (len(results) + page * page_size))
return total, remaining, results
HyperParamValues = Tuple[int, Sequence[str]]
def _get_cached_hyperparam_values(
self, key: str, last_update: datetime
) -> Optional[HyperParamValues]:
allowed_delta = timedelta(
seconds=config.get(
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
)
)
try:
cached = self.redis.get(key)
if not cached:
return
data = json.loads(cached)
cached_last_update = datetime.fromtimestamp(data["last_update"])
if (last_update - cached_last_update) < allowed_delta:
return data["total"], data["values"]
except Exception as ex:
log.error(f"Error retrieving hyperparam cached values: {str(ex)}")
def get_hyperparam_distinct_values(
self,
company_id: str,
project_ids: Sequence[str],
section: str,
name: str,
include_subprojects: bool,
allow_public: bool = True,
) -> HyperParamValues:
company_constraint = self._get_company_constraint(company_id, allow_public)
project_constraint = self._get_project_constraint(
project_ids, include_subprojects
)
key_path = f"hyperparams.{ParameterKeyEscaper.escape(section)}.{ParameterKeyEscaper.escape(name)}"
last_updated_task = (
Task.objects(
**company_constraint,
**project_constraint,
**{f"{key_path.replace('.', '__')}__exists": True},
)
.only("last_update")
.order_by("-last_update")
.limit(1)
.first()
)
if not last_updated_task:
return 0, []
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
last_update = last_updated_task.last_update or datetime.utcnow()
cached_res = self._get_cached_hyperparam_values(
key=redis_key, last_update=last_update
)
if cached_res:
return cached_res
max_values = config.get("services.tasks.hyperparam_values.max_count", 100)
pipeline = [
{
"$match": {
**company_constraint,
**project_constraint,
key_path: {"$exists": True},
}
},
{"$project": {"value": f"${key_path}.value"}},
{"$group": {"_id": "$value"}},
{"$sort": {"_id": 1}},
{"$limit": max_values},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT._id"},
}
},
]
result = next(Task.aggregate(pipeline, collation=Task._numeric_locale), None)
if not result:
return 0, []
total = int(result.get("total", 0))
values = result.get("results", [])
ttl = config.get("services.tasks.hyperparam_values.cache_ttl_sec", 86400)
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
self.redis.setex(redis_key, ttl, json.dumps(cached))
return total, values
@classmethod
def get_unique_metric_variants(
cls, company_id, project_ids: Sequence[str], include_subprojects: bool
):
pipeline = [
{
"$match": {
**cls._get_company_constraint(company_id),
**cls._get_project_constraint(project_ids, include_subprojects),
}
},
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
{"$unwind": "$metrics"},
{
"$project": {
"metric": "$metrics.k",
"variants": {"$objectToArray": "$metrics.v"},
}
},
{"$unwind": "$variants"},
{
"$group": {
"_id": {
"metric": "$variants.v.metric",
"variant": "$variants.v.variant",
},
"metrics": {
"$addToSet": {
"metric": "$variants.v.metric",
"metric_hash": "$metric",
"variant": "$variants.v.variant",
"variant_hash": "$variants.k",
}
},
}
},
{"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
]
result = Task.aggregate(pipeline)
return [r["metrics"][0] for r in result]

View File

@@ -126,14 +126,27 @@ class QueueBLL(object):
)
queue.delete()
def get_all(self, company_id: str, query_dict: dict) -> Sequence[dict]:
def get_all(
self,
company_id: str,
query_dict: dict,
ret_params: dict = None,
) -> Sequence[dict]:
"""Get all the queues according to the query"""
with translate_errors_context():
return Queue.get_many(
company=company_id, parameters=query_dict, query_dict=query_dict
company=company_id,
parameters=query_dict,
query_dict=query_dict,
ret_params=ret_params,
)
def get_queue_infos(self, company_id: str, query_dict: dict) -> Sequence[dict]:
def get_queue_infos(
self,
company_id: str,
query_dict: dict,
ret_params: dict = None,
) -> Sequence[dict]:
"""
Get infos on all the company queues, including queue tasks and workers
"""
@@ -143,6 +156,7 @@ class QueueBLL(object):
company=company_id,
query_dict=query_dict,
override_projection=projection,
ret_params=ret_params,
)
queue_workers = defaultdict(list)

View File

@@ -49,6 +49,21 @@ class RedisCacheManager(Generic[T]):
def _get_redis_key(self, state_id):
return f"{self.state_class}/{state_id}"
def get_or_create_state_core(
self,
state_id=None,
init_state: Callable[[T], None] = _do_nothing,
validate_state: Callable[[T], None] = _do_nothing,
) -> T:
state = self.get_state(state_id) if state_id else None
if state:
validate_state(state)
else:
state = self.state_class(id=database.utils.id())
init_state(state)
return state
@contextmanager
def get_or_create_state(
self,
@@ -66,12 +81,9 @@ class RedisCacheManager(Generic[T]):
:param validate_state: user callback to validate the state if retrieved from cache
Should throw an exception if the state is not valid. If not passed then no validation is done
"""
state = self.get_state(state_id) if state_id else None
if state:
validate_state(state)
else:
state = self.state_class(id=database.utils.id())
init_state(state)
state = self.get_or_create_state_core(
state_id=state_id, init_state=init_state, validate_state=validate_state
)
try:
yield state

View File

@@ -1,11 +1,10 @@
import itertools
from typing import Sequence, Tuple
from typing import Sequence, Tuple, Optional
import dpath
from apiserver.apierrors import errors
from apiserver.database.model.task.task import Task
from apiserver.tools import safe_get
from apiserver.utilities.dicts import nested_get, nested_delete, nested_set
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
@@ -14,7 +13,7 @@ hyperparams_legacy_type = "legacy"
tf_define_section = "TF_DEFINE"
def split_param_name(full_name: str, default_section: str) -> Tuple[str, str]:
def split_param_name(full_name: str, default_section: str) -> Tuple[Optional[str], str]:
"""
Return parameter section and name. The section is either TF_DEFINE or the default one
"""
@@ -62,7 +61,7 @@ def _remove_legacy_params(data: dict, with_sections: bool = False) -> int:
return removed
def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[str]:
def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[dict]:
"""
Remove the legacy params from the data dict and return the number of removed params
If the path not found then return 0
@@ -71,8 +70,10 @@ def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[str]
return []
if with_sections:
return itertools.chain.from_iterable(
_get_legacy_params(section_data) for section_data in data.values()
return list(
itertools.chain.from_iterable(
_get_legacy_params(section_data) for section_data in data.values()
)
)
return [
@@ -86,15 +87,15 @@ def params_prepare_for_save(fields: dict, previous_task: Task = None):
Escape all the section and param names for hyper params and configuration to make it mongo sage
"""
for old_params_field, new_params_field, default_section in (
("execution/parameters", "hyperparams", hyperparams_default_section),
("execution/model_desc", "configuration", None),
(("execution", "parameters"), "hyperparams", hyperparams_default_section),
(("execution", "model_desc"), "configuration", None),
):
legacy_params = safe_get(fields, old_params_field)
legacy_params = nested_get(fields, old_params_field)
if legacy_params is None:
continue
if (
not safe_get(fields, new_params_field)
not fields.get(new_params_field)
and previous_task
and previous_task[new_params_field]
):
@@ -117,11 +118,11 @@ def params_prepare_for_save(fields: dict, previous_task: Task = None):
new_param = dict(name=name, type=hyperparams_legacy_type, value=str(value))
if section is not None:
new_param["section"] = section
dpath.new(fields, new_path, new_param)
dpath.delete(fields, old_params_field)
nested_set(fields, new_path, new_param)
nested_delete(fields, old_params_field)
for param_field in ("hyperparams", "configuration"):
params = safe_get(fields, param_field)
params = fields.get(param_field)
if params:
escaped_params = {
ParameterKeyEscaper.escape(key): {
@@ -131,7 +132,7 @@ def params_prepare_for_save(fields: dict, previous_task: Task = None):
else value
for key, value in params.items()
}
dpath.set(fields, param_field, escaped_params)
fields[param_field] = escaped_params
def params_unprepare_from_saved(fields, copy_to_legacy=False):
@@ -140,7 +141,7 @@ def params_unprepare_from_saved(fields, copy_to_legacy=False):
If copy_to_legacy is set then copy hyperparams and configuration data to the legacy location for the old clients
"""
for param_field in ("hyperparams", "configuration"):
params = safe_get(fields, param_field)
params = fields.get(param_field)
if params:
unescaped_params = {
ParameterKeyEscaper.unescape(key): {
@@ -150,18 +151,18 @@ def params_unprepare_from_saved(fields, copy_to_legacy=False):
else value
for key, value in params.items()
}
dpath.set(fields, param_field, unescaped_params)
fields[param_field] = unescaped_params
if copy_to_legacy:
for new_params_field, old_params_field, use_sections in (
(f"hyperparams", "execution/parameters", True),
(f"configuration", "execution/model_desc", False),
("hyperparams", ("execution", "parameters"), True),
("configuration", ("execution", "model_desc"), False),
):
legacy_params = _get_legacy_params(
safe_get(fields, new_params_field), with_sections=use_sections
fields.get(new_params_field), with_sections=use_sections
)
if legacy_params:
dpath.new(
nested_set(
fields,
old_params_field,
{_get_full_param_name(p): p["value"] for p in legacy_params},
@@ -174,7 +175,7 @@ def _process_path(path: str):
Need to unescape and apply a full mongo escaping
"""
parts = path.split(".")
if len(parts) < 2 or len(parts) > 3:
if len(parts) < 2 or len(parts) > 4:
raise errors.bad_request.ValidationError("invalid task field", path=path)
return ".".join(
ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
@@ -184,7 +185,7 @@ def _process_path(path: str):
def escape_paths(paths: Sequence[str]) -> Sequence[str]:
for old_prefix, new_prefix in (
("execution.parameters", f"hyperparams.{hyperparams_default_section}"),
("execution.model_desc", f"configuration"),
("execution.model_desc", "configuration"),
("execution.docker_cmd", "container")
):
path: str

View File

@@ -1,9 +1,6 @@
import json
from collections import OrderedDict
from datetime import datetime, timedelta
from datetime import datetime
from typing import Collection, Sequence, Tuple, Any, Optional, Dict
import dpath
import six
from mongoengine import Q
from redis import StrictRedis
@@ -14,7 +11,7 @@ from apiserver.apierrors import errors
from apiserver.apimodels.tasks import TaskInputModel
from apiserver.bll.queue import QueueBLL
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL, project_ids_with_children
from apiserver.bll.project import ProjectBLL
from apiserver.config_repo import config
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.model import Model
@@ -39,7 +36,6 @@ from apiserver.redis_manager import redman
from apiserver.service_repo import APICall
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
from apiserver.timing_context import TimingContext
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
from .artifacts import artifacts_prepare_for_save
from .param_utils import params_prepare_for_save
from .utils import (
@@ -350,54 +346,6 @@ class TaskBLL:
if validate_models:
cls.validate_input_models(task)
@staticmethod
def get_unique_metric_variants(
company_id, project_ids: Sequence[str], include_subprojects: bool
):
if project_ids:
if include_subprojects:
project_ids = project_ids_with_children(project_ids)
project_constraint = {"project": {"$in": project_ids}}
else:
project_constraint = {}
pipeline = [
{
"$match": dict(
company={"$in": [None, "", company_id]}, **project_constraint,
)
},
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
{"$unwind": "$metrics"},
{
"$project": {
"metric": "$metrics.k",
"variants": {"$objectToArray": "$metrics.v"},
}
},
{"$unwind": "$variants"},
{
"$group": {
"_id": {
"metric": "$variants.v.metric",
"variant": "$variants.v.variant",
},
"metrics": {
"$addToSet": {
"metric": "$variants.v.metric",
"metric_hash": "$metric",
"variant": "$variants.v.variant",
"variant_hash": "$variants.k",
}
},
}
},
{"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
]
with translate_errors_context():
result = Task.aggregate(pipeline)
return [r["metrics"][0] for r in result]
@staticmethod
def set_last_update(
task_ids: Collection[str],
@@ -494,173 +442,6 @@ class TaskBLL:
**extra_updates,
)
@staticmethod
def get_aggregated_project_parameters(
company_id,
project_ids: Sequence[str],
include_subprojects: bool,
page: int = 0,
page_size: int = 500,
) -> Tuple[int, int, Sequence[dict]]:
if project_ids:
if include_subprojects:
project_ids = project_ids_with_children(project_ids)
project_constraint = {"project": {"$in": project_ids}}
else:
project_constraint = {}
page = max(0, page)
page_size = max(1, page_size)
pipeline = [
{
"$match": {
"company": {"$in": [None, "", company_id]},
"hyperparams": {"$exists": True, "$gt": {}},
**project_constraint,
}
},
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
{"$unwind": "$sections"},
{
"$project": {
"section": "$sections.k",
"names": {"$objectToArray": "$sections.v"},
}
},
{"$unwind": "$names"},
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
{"$skip": page * page_size},
{"$limit": page_size},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT"},
}
},
]
result = next(Task.aggregate(pipeline), None)
total = 0
remaining = 0
results = []
if result:
total = int(result.get("total", -1))
results = [
{
"section": ParameterKeyEscaper.unescape(
dpath.get(r, "_id/section")
),
"name": ParameterKeyEscaper.unescape(dpath.get(r, "_id/name")),
}
for r in result.get("results", [])
]
remaining = max(0, total - (len(results) + page * page_size))
return total, remaining, results
HyperParamValues = Tuple[int, Sequence[str]]
def _get_cached_hyperparam_values(
self, key: str, last_update: datetime
) -> Optional[HyperParamValues]:
allowed_delta = timedelta(
seconds=config.get(
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
)
)
try:
cached = self.redis.get(key)
if not cached:
return
data = json.loads(cached)
cached_last_update = datetime.fromtimestamp(data["last_update"])
if (last_update - cached_last_update) < allowed_delta:
return data["total"], data["values"]
except Exception as ex:
log.error(f"Error retrieving hyperparam cached values: {str(ex)}")
def get_hyperparam_distinct_values(
self,
company_id: str,
project_ids: Sequence[str],
section: str,
name: str,
include_subprojects: bool,
allow_public: bool = True,
) -> HyperParamValues:
if allow_public:
company_constraint = {"company": {"$in": [None, "", company_id]}}
else:
company_constraint = {"company": company_id}
if project_ids:
if include_subprojects:
project_ids = project_ids_with_children(project_ids)
project_constraint = {"project": {"$in": project_ids}}
else:
project_constraint = {}
key_path = f"hyperparams.{ParameterKeyEscaper.escape(section)}.{ParameterKeyEscaper.escape(name)}"
last_updated_task = (
Task.objects(
**company_constraint,
**project_constraint,
**{f"{key_path.replace('.', '__')}__exists": True},
)
.only("last_update")
.order_by("-last_update")
.limit(1)
.first()
)
if not last_updated_task:
return 0, []
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
last_update = last_updated_task.last_update or datetime.utcnow()
cached_res = self._get_cached_hyperparam_values(
key=redis_key, last_update=last_update
)
if cached_res:
return cached_res
max_values = config.get("services.tasks.hyperparam_values.max_count", 100)
pipeline = [
{
"$match": {
**company_constraint,
**project_constraint,
key_path: {"$exists": True},
}
},
{"$project": {"value": f"${key_path}.value"}},
{"$group": {"_id": "$value"}},
{"$sort": {"_id": 1}},
{"$limit": max_values},
{
"$group": {
"_id": 1,
"total": {"$sum": 1},
"results": {"$push": "$$ROOT._id"},
}
},
]
result = next(Task.aggregate(pipeline, collation=Task._numeric_locale), None)
if not result:
return 0, []
total = int(result.get("total", 0))
values = result.get("results", [])
ttl = config.get("services.tasks.hyperparam_values.cache_ttl_sec", 86400)
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
self.redis.setex(redis_key, ttl, json.dumps(cached))
return total, values
@classmethod
def dequeue_and_change_status(
cls, task: Task, company_id: str, status_message: str, status_reason: str,

View File

@@ -130,14 +130,14 @@ def collect_debug_image_urls(company: str, task: str) -> Set[str]:
if not metrics:
return set()
task_metrics = {task: set(metrics)}
task_metrics = {task: {m: [] for m in metrics}}
scroll_id = None
urls = set()
while True:
res = event_bll.debug_images_iterator.get_task_events(
company_id=company,
task_metrics=task_metrics,
iter_count=100,
iter_count=10,
state_id=scroll_id,
)
if not res.metric_events or not any(

View File

@@ -109,6 +109,7 @@ def enqueue_task(
status_message: str,
status_reason: str,
validate: bool = False,
force: bool = False,
) -> Tuple[int, dict]:
if not queue_id:
# try to get default queue
@@ -128,6 +129,7 @@ def enqueue_task(
status_reason=status_reason,
status_message=status_message,
allow_same_state_transition=False,
force=force,
).execute(enqueue_status=task.status)
try:
@@ -160,6 +162,8 @@ def delete_task(
force: bool,
return_file_urls: bool,
delete_output_models: bool,
status_message: str,
status_reason: str,
) -> Tuple[int, Task, CleanupResult]:
task = TaskBLL.get_task_with_access(
task_id, company_id=company_id, requires_write_access=True
@@ -177,6 +181,17 @@ def delete_task(
current=task.status,
)
try:
TaskBLL.dequeue_and_change_status(
task,
company_id=company_id,
status_message=status_message,
status_reason=status_reason,
)
except APIError:
# dequeue may fail if the task was not enqueued
pass
cleanup_res = cleanup_task(
task,
force=force,
@@ -352,6 +367,7 @@ def stop_task(
"system_tags",
"last_worker",
"last_update",
"execution.queue",
),
requires_write_access=True,
)
@@ -365,7 +381,21 @@ def stop_task(
and (datetime.utcnow() - t.last_update).total_seconds() < update_timeout
)
if TaskSystemTags.development in task.system_tags or not is_run_by_worker(task):
is_queued = task.status == TaskStatus.queued
set_stopped = (
is_queued
or TaskSystemTags.development in task.system_tags
or not is_run_by_worker(task)
)
if set_stopped:
if is_queued:
try:
TaskBLL.dequeue(task, company_id=company_id, silent_fail=True)
except APIError:
# dequeue may fail if the task was not enqueued
pass
new_status = TaskStatus.stopped
status_message = f"Stopped by {user_name}"
else:

View File

@@ -258,7 +258,7 @@ class WorkerBLL:
tasks_info = {
task.id: task
for task in Task.objects(id__in=task_ids).only(
"name", "started", "last_iteration"
"name", "started", "last_iteration", "active_duration"
)
}
@@ -283,11 +283,7 @@ class WorkerBLL:
if helper.task_id:
task = tasks_info.get(helper.task_id, None)
if task:
worker.task.running_time = (
int((datetime.utcnow() - task.started).total_seconds() * 1000)
if task.started
else 0
)
worker.task.running_time = (task.active_duration or 0) * 1000
worker.task.last_iteration = task.last_iteration
update_queue_entries(worker.queue)

View File

@@ -79,6 +79,8 @@ class BasicConfig:
def logger(self, name: str) -> logging.Logger:
if Path(name).is_file():
name = Path(name).stem
if name == "__init__" and Path(name).parent.stem:
name = Path(name).parent.stem
path = ".".join((self.prefix, name))
return logging.getLogger(path)

View File

@@ -3,7 +3,7 @@
debug: false # Debug mode
pretty_json: false # prettify json response
return_stack: true # return stack trace on error
log_calls: true # Log API Calls
return_stack_to_caller: true # top-level control on whether to return stack trace in an API response
# if 'return_stack' is true and error contains a status code, return stack trace only for these status codes
# valid values are:
@@ -79,6 +79,11 @@
max_age: 99999999999
}
# provide a cookie domain override per company
# cookies_domain_override {
# <company-id>: <domain>
# }
# # A list of fixed users
# # Note: password may be bcrypt-hashed (generate using `python -c 'import bcrypt; print(bcrypt.hashpw("password", bcrypt.gensalt()))'`)
# fixed_users {

View File

@@ -0,0 +1,4 @@
max_page_size: 500
# expiration time in seconds for the redis scroll states in get_many family of apis
scroll_state_expiration_seconds: 600

View File

@@ -17,6 +17,10 @@ events_retrieval {
# the max amount of variants to aggregate on
max_variants_count: 100
max_raw_scalars_size: 200000
scroll_id_key: "cTN5VEtWEC6QrHvUl0FTx9kNyO0CcCK1p57akxma"
}
# if set then plot str will be checked for the valid json on plot add

View File

@@ -28,6 +28,8 @@ OVERRIDE_PORT_ENV_KEY = (
"MONGODB_SERVICE_PORT",
)
OVERRIDE_CONNECTION_STRING_ENV_KEY = "CLEARML_MONGODB_SERVICE_CONNECTION_STRING"
class DatabaseEntry(models.Base):
host = StringField(required=True)
@@ -47,13 +49,17 @@ class DatabaseFactory:
missing = []
log.info("Initializing database connections")
override_connection_string = getenv(OVERRIDE_CONNECTION_STRING_ENV_KEY)
override_hostname = first(map(getenv, OVERRIDE_HOST_ENV_KEY), None)
if override_hostname:
log.info(f"Using override mongodb host {override_hostname}")
override_port = first(map(getenv, OVERRIDE_PORT_ENV_KEY), None)
if override_port:
log.info(f"Using override mongodb port {override_port}")
if override_connection_string:
log.info(f"Using override mongodb connection string {override_connection_string}")
else:
if override_hostname:
log.info(f"Using override mongodb host {override_hostname}")
if override_port:
log.info(f"Using override mongodb port {override_port}")
for key, alias in get_items(Database).items():
if key not in db_entries:
@@ -62,11 +68,13 @@ class DatabaseFactory:
entry = cls._create_db_entry(alias=alias, settings=db_entries.get(key))
if override_hostname:
entry.host = furl(entry.host).set(host=override_hostname).url
if override_port:
entry.host = furl(entry.host).set(port=override_port).url
if override_connection_string:
entry.host = override_connection_string
else:
if override_hostname:
entry.host = furl(entry.host).set(host=override_hostname).url
if override_port:
entry.host = furl(entry.host).set(port=override_port).url
try:
entry.validate()

View File

@@ -48,6 +48,7 @@ class Credentials(EmbeddedDocument):
meta = {"strict": False}
key = StringField(required=True)
secret = StringField(required=True)
label = StringField()
last_used = DateTimeField()

View File

@@ -1,26 +1,41 @@
import re
from collections import namedtuple
from functools import reduce
from typing import Collection, Sequence, Union, Optional, Type, Tuple, Mapping, Any
from functools import reduce, partial
from typing import (
Collection,
Sequence,
Union,
Optional,
Type,
Tuple,
Mapping,
Any,
Callable,
Dict,
List,
)
from boltons.iterutils import first, bucketize, partition
from boltons.iterutils import first, partition
from dateutil.parser import parse as parse_datetime
from mongoengine import Q, Document, ListField, StringField
from mongoengine import Q, Document, ListField, StringField, IntField
from pymongo.command_cursor import CommandCursor
from apiserver.apierrors import errors
from apiserver.apierrors.base import BaseError
from apiserver.bll.redis_cache_manager import RedisCacheManager
from apiserver.config_repo import config
from apiserver.database import Database
from apiserver.database.errors import MakeGetAllQueryError
from apiserver.database.projection import project_dict, ProjectionHelper
from apiserver.database.props import PropsMixin
from apiserver.database.query import RegexQ, RegexWrapper
from apiserver.database.query import RegexQ, RegexWrapper, RegexQCombination
from apiserver.database.utils import (
get_company_or_none_constraint,
get_fields_choices,
field_does_not_exist,
field_exists,
)
from apiserver.redis_manager import redman
log = config.logger("dbmodel")
@@ -70,6 +85,9 @@ class GetMixin(PropsMixin):
_ordering_key = "order_by"
_search_text_key = "search_text"
_start_key = "start"
_size_key = "size"
_multi_field_param_sep = "__"
_multi_field_param_prefix = {
("_any_", "_or_"): lambda a, b: a | b,
@@ -103,45 +121,106 @@ class GetMixin(PropsMixin):
class ListFieldBucketHelper:
op_prefix = "__$"
legacy_exclude_prefix = "-"
_legacy_exclude_prefix = "-"
_legacy_exclude_mongo_op = "nin"
_default = "in"
default_mongo_op = "in"
_ops = {
# op -> (mongo_op, sticky)
"not": ("nin", False),
"nop": (default_mongo_op, False),
"all": ("all", True),
"and": ("all", True),
"any": (default_mongo_op, True),
"or": (default_mongo_op, True),
}
_next = _default
_sticky = False
def __init__(self, legacy=False):
self._legacy = legacy
self._current_op = None
self._sticky = False
self._support_legacy = legacy
self.allow_empty = False
def key(self, v):
def _get_op(self, v: str, translate: bool = False) -> Optional[str]:
op = (
v[len(self.op_prefix) :] if v and v.startswith(self.op_prefix) else None
)
if translate:
tup = self._ops.get(op, None)
return tup[0] if tup else None
return op
def _key(self, v) -> Optional[Union[str, bool]]:
if v is None:
self._next = self._default
return self._default
elif self._legacy and v.startswith(self.legacy_exclude_prefix):
self._next = self._default
return self._ops["not"][0]
elif v.startswith(self.op_prefix):
self._next, self._sticky = self._ops.get(
v[len(self.op_prefix) :], (self._default, self._sticky)
)
self.allow_empty = True
return None
next_ = self._next
if not self._sticky:
self._next = self._default
return next_
op = self._get_op(v)
if op is not None:
# operator - set state and return None
self._current_op, self._sticky = self._ops.get(
op, (self.default_mongo_op, self._sticky)
)
return None
elif self._current_op:
current_op = self._current_op
if not self._sticky:
self._current_op = None
return current_op
elif self._support_legacy and v.startswith(self._legacy_exclude_prefix):
self._current_op = None
return False
def value_transform(self, v):
if self._legacy and v and v.startswith(self.legacy_exclude_prefix):
return v[len(self.legacy_exclude_prefix) :]
return v
return self.default_mongo_op
def get_global_op(self, data: Sequence[str]) -> int:
op_to_res = {
"in": Q.OR,
"all": Q.AND,
}
data = (x for x in data if x is not None)
first_op = (
self._get_op(next(data, ""), translate=True) or self.default_mongo_op
)
return op_to_res.get(first_op, self.default_mongo_op)
def get_actions(self, data: Sequence[str]) -> Dict[str, List[Union[str, None]]]:
actions = {}
for val in data:
key = self._key(val)
if key is None:
continue
elif self._support_legacy and key is False:
key = self._legacy_exclude_mongo_op
val = val[len(self._legacy_exclude_prefix) :]
actions.setdefault(key, []).append(val)
return actions
get_all_query_options = QueryParameterOptions()
class GetManyScrollState(ProperDictMixin, Document):
meta = {"db_alias": Database.backend, "strict": False}
id = StringField(primary_key=True)
position = IntField(default=0)
_cache_manager = None
@classmethod
def get_cache_manager(cls):
if not cls._cache_manager:
cls._cache_manager = RedisCacheManager(
state_class=cls.GetManyScrollState,
redis=redman.connection("apiserver"),
expiration_interval=config.get(
"services._mongo.scroll_state_expiration_seconds", 600
),
)
return cls._cache_manager
@classmethod
def get(
cls: Union["GetMixin", Document],
@@ -240,7 +319,9 @@ class GetMixin(PropsMixin):
Prepare a query object based on the provided query dictionary and various fields.
NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows creating queries that span across companies.
IMPLEMENTATION NOTE: Make sure that inside this function or the functions it depends on RegexQ is always
used instead of Q. Otherwise we can and up with some combination that is not processed according to
RegexQ rules
:param parameters_options: Specifies options for parsing the parameters (see ParametersOptions)
:param parameters: Query dictionary (relevant keys are these specified by the various field names parameters).
Supported parameters:
@@ -273,10 +354,13 @@ class GetMixin(PropsMixin):
).items():
query &= cls.get_range_field_query(field, data)
for field in opts.fields or []:
data = parameters.pop(field, None)
if data is not None:
dict_query[field] = data
for field, data in cls._pop_matching_params(
patterns=opts.fields or [], parameters=parameters
).items():
if "._" in field or "_." in field:
query &= RegexQ(__raw__={field: data})
else:
dict_query[field.replace(".", "__")] = data
for field in opts.datetime_fields or []:
data = parameters.pop(field, None)
@@ -308,22 +392,31 @@ class GetMixin(PropsMixin):
break
if any("._" in f for f in data.fields):
q = reduce(
lambda a, x: func(a, Q(__raw__={x: {"$regex": data.pattern, "$options": "i"}})),
lambda a, x: func(
a,
RegexQ(
__raw__={
x: {"$regex": data.pattern, "$options": "i"}
}
),
),
data.fields,
Q()
RegexQ(),
)
else:
regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
sep_fields = [f.replace(".", "__") for f in data.fields]
q = reduce(
lambda a, x: func(a, RegexQ(**{x: regex})), sep_fields, RegexQ()
lambda a, x: func(a, RegexQ(**{x: regex})),
sep_fields,
RegexQ(),
)
query = query & q
return query & RegexQ(**dict_query)
@classmethod
def get_range_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
def get_range_field_query(cls, field: str, data: Sequence[Optional[str]]) -> RegexQ:
"""
Return a range query for the provided field. The data should contain min and max values
Both intervals are included. For open range queries either min or max can be None
@@ -347,14 +440,14 @@ class GetMixin(PropsMixin):
if max_val is not None:
query[f"{mongoengine_field}__lte"] = max_val
q = Q(**query)
q = RegexQ(**query)
if min_val is None:
q |= Q(**{mongoengine_field: None})
q |= RegexQ(**{mongoengine_field: None})
return q
@classmethod
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> RegexQ:
"""
Get a proper mongoengine Q object that represents an "or" query for the provided values
with respect to the given list field, with support for "none of empty" in case a None value
@@ -366,30 +459,31 @@ class GetMixin(PropsMixin):
"""
if not isinstance(data, (list, tuple)):
data = [data]
# raise MakeGetAllQueryError("expected list", field)
# TODO: backwards compatibility only for older API versions
helper = cls.ListFieldBucketHelper(legacy=True)
actions = bucketize(
data, key=helper.key, value_transform=helper.value_transform
)
global_op = helper.get_global_op(data)
actions = helper.get_actions(data)
allow_empty = None in actions.get("in", {})
mongoengine_field = field.replace(".", "__")
q = RegexQ()
for action in filter(None, actions):
q &= RegexQ(
**{f"{mongoengine_field}__{action}": list(set(actions[action]))}
)
queries = [
RegexQ(**{f"{mongoengine_field}__{action}": list(set(actions[action]))})
for action in filter(None, actions)
]
if not allow_empty:
if not queries:
q = RegexQ()
else:
q = RegexQCombination(operation=global_op, children=queries)
if not helper.allow_empty:
return q
return (
q
| Q(**{f"{mongoengine_field}__exists": False})
| Q(**{mongoengine_field: []})
| RegexQ(**{f"{mongoengine_field}__exists": False})
| RegexQ(**{mongoengine_field: []})
| RegexQ(**{mongoengine_field: None})
)
@classmethod
@@ -417,27 +511,41 @@ class GetMixin(PropsMixin):
return order_by
@classmethod
def validate_paging(
cls, parameters=None, default_page=None, default_page_size=None
):
""" Validate and extract paging info from from the provided dictionary. Supports default values. """
if parameters is None:
parameters = {}
default_page = parameters.get("page", default_page)
if default_page is None:
return None, None
default_page_size = parameters.get("page_size", default_page_size)
if not default_page_size:
raise errors.bad_request.MissingRequiredFields(
"page_size is required when page is requested", field="page_size"
)
elif default_page < 0:
def validate_paging(cls, parameters=None, default_page=0, default_page_size=None):
"""
Validate and extract paging info from from the provided dictionary. Supports default values.
If page is specified then it should be non-negative, if page size is specified then it should be positive
If page size is specified and page is not then 0 page is assumed
If page is specified then page size should be specified too
"""
parameters = parameters or {}
start = parameters.get(cls._start_key)
if start is not None:
return start, cls.validate_scroll_size(parameters)
max_page_size = config.get("services._mongo.max_page_size", 500)
page = parameters.get("page", default_page)
if page is not None and page < 0:
raise errors.bad_request.ValidationError("page must be >=0", field="page")
elif default_page_size < 1:
page_size = parameters.get("page_size", default_page_size or max_page_size)
if page_size is not None and page_size < 1:
raise errors.bad_request.ValidationError(
"page_size must be >0", field="page_size"
)
return default_page, default_page_size
if page_size is not None:
page = page or 0
page_size = min(page_size, max_page_size)
return page * page_size, page_size
if page is not None:
raise errors.bad_request.MissingRequiredFields(
"page_size is required when page is requested", field="page_size"
)
return None, None
@classmethod
def get_projection(cls, parameters, override_projection=None, **__):
@@ -481,6 +589,57 @@ class GetMixin(PropsMixin):
def set_default_ordering(cls, parameters: dict, value: Sequence[str]) -> None:
cls.set_ordering(parameters, cls.get_ordering(parameters) or value)
@classmethod
def validate_scroll_size(cls, query_dict: dict) -> int:
size = query_dict.get(cls._size_key)
if not size or not isinstance(size, int) or size < 1:
raise errors.bad_request.ValidationError(
"Integer size parameter greater than 1 should be provided when working with scroll"
)
return size
@classmethod
def get_data_with_scroll_and_filter_support(
cls,
query_dict: dict,
data_getter: Callable[[], Sequence[dict]],
ret_params: dict,
) -> Sequence[dict]:
"""
Retrieves the data by calling the provided data_getter api
If scroll parameters are specified then put the query_dict 'start' parameter to the last
scroll position and continue retrievals from that position
If refresh_scroll is requested then bring once more the data from the beginning
till the current scroll position
In the end the scroll position is updated and accumulated frames are returned
"""
query_dict = query_dict or {}
state: Optional[cls.GetManyScrollState] = None
if "scroll_id" in query_dict:
size = cls.validate_scroll_size(query_dict)
state = cls.get_cache_manager().get_or_create_state_core(
query_dict.get("scroll_id")
)
if query_dict.get("refresh_scroll"):
query_dict[cls._size_key] = max(state.position, size)
state.position = 0
query_dict[cls._start_key] = state.position
data = data_getter()
if cls._start_key in query_dict:
query_dict[cls._start_key] = query_dict[cls._start_key] + len(data)
def update_state(returned_len: int):
if not state:
return
state.position = query_dict[cls._start_key]
cls.get_cache_manager().set_state(state)
if ret_params is not None:
ret_params["scroll_id"] = state.id
update_state(len(data))
return data
@classmethod
def get_many_with_join(
cls,
@@ -491,6 +650,7 @@ class GetMixin(PropsMixin):
allow_public=False,
override_projection=None,
expand_reference_ids=True,
ret_params: dict = None,
):
"""
Fetch all documents matching a provided query with support for joining referenced documents according to the
@@ -526,6 +686,7 @@ class GetMixin(PropsMixin):
query=query,
query_options=query_options,
allow_public=allow_public,
ret_params=ret_params,
)
def projection_func(doc_type, projection, ids):
@@ -556,6 +717,7 @@ class GetMixin(PropsMixin):
allow_public=False,
override_projection: Collection[str] = None,
return_dicts=True,
ret_params: dict = None,
):
"""
Fetch all documents matching a provided query. Supported several built-in options
@@ -601,12 +763,16 @@ class GetMixin(PropsMixin):
_query = (q & query) if query else q
if return_dicts:
return cls._get_many_override_none_ordering(
data_getter = partial(
cls._get_many_override_none_ordering,
query=_query,
parameters=parameters,
override_projection=override_projection,
override_collation=override_collation,
)
return cls.get_data_with_scroll_and_filter_support(
query_dict=query_dict, data_getter=data_getter, ret_params=ret_params,
)
return cls._get_many_no_company(
query=_query,
@@ -658,7 +824,7 @@ class GetMixin(PropsMixin):
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
if order_by and not override_collation:
override_collation = cls._get_collation_override(order_by[0])
page, page_size = cls.validate_paging(parameters=parameters)
start, size = cls.validate_paging(parameters=parameters)
include, exclude = cls.split_projection(
cls.get_projection(parameters, override_projection)
)
@@ -679,9 +845,9 @@ class GetMixin(PropsMixin):
if exclude:
qs = qs.exclude(*exclude)
if page is not None and page_size:
if start is not None and size:
# add paging
qs = qs.skip(page * page_size).limit(page_size)
qs = qs.skip(start).limit(size)
return qs
@@ -742,7 +908,10 @@ class GetMixin(PropsMixin):
parameters = parameters or {}
search_text = parameters.get(cls._search_text_key)
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
page, page_size = cls.validate_paging(parameters=parameters)
start, size = cls.validate_paging(parameters=parameters)
if size is not None and size <= 0:
return []
include, exclude = cls.split_projection(
cls.get_projection(parameters, override_projection)
)
@@ -774,25 +943,28 @@ class GetMixin(PropsMixin):
if exclude:
query_sets = [qs.exclude(*exclude) for qs in query_sets]
if page is None or not page_size:
if start is None or not size:
return [obj.to_proper_dict(only=include) for qs in query_sets for obj in qs]
# add paging
ret = []
start = page * page_size
for qs in query_sets:
qs_size = qs.count()
if qs_size < start:
start -= qs_size
continue
last_set = len(query_sets) - 1
for i, qs in enumerate(query_sets):
last_size = len(ret)
ret.extend(
obj.to_proper_dict(only=include)
for obj in qs.skip(start).limit(page_size)
for obj in (qs.skip(start) if start else qs).limit(size)
)
if len(ret) >= page_size:
added = len(ret) - last_size
if added > 0:
start = 0
size = max(0, size - added)
elif i != last_set:
start -= min(start, qs.count())
if size <= 0:
break
start = 0
page_size -= len(ret)
return ret

View File

@@ -1,7 +1,6 @@
from typing import Sequence
from mongoengine import (
Document,
StringField,
DateTimeField,
BooleanField,
@@ -14,17 +13,15 @@ from apiserver.database.fields import (
SafeDictField,
SafeSortedListField,
)
from apiserver.database.model import DbModelMixin
from apiserver.database.model import AttributedDocument
from apiserver.database.model.base import GetMixin
from apiserver.database.model.metadata import MetadataItem
from apiserver.database.model.model_labels import ModelLabels
from apiserver.database.model.company import Company
from apiserver.database.model.project import Project
from apiserver.database.model.task.task import Task
from apiserver.database.model.user import User
class Model(DbModelMixin, Document):
class Model(AttributedDocument):
meta = {
"db_alias": Database.backend,
"strict": strict,
@@ -73,8 +70,6 @@ class Model(DbModelMixin, Document):
id = StringField(primary_key=True)
name = StrippedStringField(user_set_allowed=True, min_length=3)
parent = StringField(reference_field="Model", required=False)
user = StringField(required=True, reference_field=User)
company = StringField(required=True, reference_field=Company)
project = StringField(reference_field=Project, user_set_allowed=True)
created = DateTimeField(required=True, user_set_allowed=True)
task = StringField(reference_field=Task)

View File

@@ -11,6 +11,7 @@ class Project(AttributedDocument):
get_all_query_options = GetMixin.QueryParameterOptions(
pattern_fields=("name", "description"),
list_fields=("tags", "system_tags", "id", "parent", "path"),
range_fields=("last_update",),
)
meta = {

View File

@@ -219,6 +219,7 @@ class Task(AttributedDocument):
"status",
"project",
"parent",
"hyperparams.*",
),
range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"),
datetime_fields=("status_changed", "last_update"),
@@ -233,7 +234,7 @@ class Task(AttributedDocument):
type = StringField(required=True, choices=get_options(TaskType))
status = StringField(default=TaskStatus.created, choices=get_options(TaskStatus))
status_reason = StringField()
status_message = StringField()
status_message = StringField(user_set_allowed=True)
status_changed = DateTimeField()
comment = StringField(user_set_allowed=True)
created = DateTimeField(required=True, user_set_allowed=True)

View File

@@ -4,7 +4,12 @@ from threading import Lock
from typing import Sequence
import six
from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
from mongoengine import (
EmbeddedDocumentField,
EmbeddedDocumentListField,
EmbeddedDocument,
Document,
)
from mongoengine.base import get_document
from apiserver.database.fields import (
@@ -25,6 +30,13 @@ class PropsMixin(object):
__cached_dpath_computed_fields_lock = Lock()
__cached_dpath_computed_fields = None
_document_classes = {}
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if issubclass(cls, (Document, EmbeddedDocument)):
PropsMixin._document_classes[cls._class_name] = cls
@classmethod
def get_fields(cls):
if cls.__cached_fields is None:
@@ -57,8 +69,14 @@ class PropsMixin(object):
def resolve_doc(v):
if not isinstance(v, six.string_types):
return v
if v == 'self':
if v == "self":
return cls_.owner_document
doc_cls = PropsMixin._document_classes.get(v)
if doc_cls:
return doc_cls
return get_document(v)
fields = {k: resolve_doc(v) for k, v in res.items()}
@@ -72,7 +90,7 @@ class PropsMixin(object):
).document_type
fields.update(
{
'.'.join((field, subfield)): doc
".".join((field, subfield)): doc
for subfield, doc in PropsMixin._get_fields_with_attr(
embedded_doc_cls, attr
).items()
@@ -80,10 +98,10 @@ class PropsMixin(object):
)
collect_embedded_docs(EmbeddedDocumentField, lambda x: x)
collect_embedded_docs(EmbeddedDocumentListField, attrgetter('field'))
collect_embedded_docs(LengthRangeEmbeddedDocumentListField, attrgetter('field'))
collect_embedded_docs(UniqueEmbeddedDocumentListField, attrgetter('field'))
collect_embedded_docs(EmbeddedDocumentSortedListField, attrgetter('field'))
collect_embedded_docs(EmbeddedDocumentListField, attrgetter("field"))
collect_embedded_docs(LengthRangeEmbeddedDocumentListField, attrgetter("field"))
collect_embedded_docs(UniqueEmbeddedDocumentListField, attrgetter("field"))
collect_embedded_docs(EmbeddedDocumentSortedListField, attrgetter("field"))
return fields
@@ -94,7 +112,7 @@ class PropsMixin(object):
for depth, part in enumerate(parts):
if current_cls is None:
raise ValueError(
'Invalid path (non-document encountered at %s)' % parts[: depth - 1]
"Invalid path (non-document encountered at %s)" % parts[: depth - 1]
)
try:
field_name, field = next(
@@ -103,7 +121,7 @@ class PropsMixin(object):
if k == part
)
except StopIteration:
raise ValueError('Invalid field path %s' % parts[:depth])
raise ValueError("Invalid field path %s" % parts[:depth])
translated_parts.append(part)
@@ -119,7 +137,7 @@ class PropsMixin(object):
),
):
current_cls = field.field.document_type
translated_parts.append('*')
translated_parts.append("*")
else:
current_cls = None
@@ -128,7 +146,7 @@ class PropsMixin(object):
@classmethod
def get_reference_fields(cls):
if cls.__cached_reference_fields is None:
fields = cls._get_fields_with_attr(cls, 'reference_field')
fields = cls._get_fields_with_attr(cls, "reference_field")
cls.__cached_reference_fields = OrderedDict(sorted(fields.items()))
return cls.__cached_reference_fields
@@ -143,12 +161,12 @@ class PropsMixin(object):
@classmethod
def get_exclude_fields(cls):
if cls.__cached_exclude_fields is None:
fields = cls._get_fields_with_attr(cls, 'exclude_by_default')
fields = cls._get_fields_with_attr(cls, "exclude_by_default")
cls.__cached_exclude_fields = OrderedDict(sorted(fields.items()))
return cls.__cached_exclude_fields
@classmethod
def get_dpath_translated_path(cls, path, separator='.'):
def get_dpath_translated_path(cls, path, separator="."):
if cls.__cached_dpath_computed_fields is None:
cls.__cached_dpath_computed_fields = {}
if path not in cls.__cached_dpath_computed_fields:

View File

@@ -5,7 +5,7 @@ Apply elasticsearch mappings to given hosts.
import argparse
import json
from pathlib import Path
from typing import Optional, Sequence
from typing import Optional, Sequence, Tuple
from elasticsearch import Elasticsearch
@@ -13,7 +13,7 @@ HERE = Path(__file__).resolve().parent
def apply_mappings_to_cluster(
hosts: Sequence, key: Optional[str] = None, es_args: dict = None
hosts: Sequence, key: Optional[str] = None, es_args: dict = None, http_auth: Tuple = None
):
"""Hosts maybe a sequence of strings or dicts in the form {"host": <host>, "port": <port>}"""
@@ -30,7 +30,7 @@ def apply_mappings_to_cluster(
else:
files = p.glob("**/*.json")
es = Elasticsearch(hosts=hosts, **(es_args or {}))
es = Elasticsearch(hosts=hosts, http_auth=http_auth, **(es_args or {}))
return [_send_template(f) for f in files]

View File

@@ -82,7 +82,11 @@ def check_elastic_empty() -> bool:
es_logger.addFilter(log_filter)
for retry in range(max_retries):
try:
es = Elasticsearch(hosts=cluster_conf.get("hosts"))
es = Elasticsearch(
hosts=cluster_conf.get("hosts", None),
http_auth=es_factory.get_credentials("events", cluster_conf),
**cluster_conf.get("args", {})
)
return not es.indices.get_template(name="events*")
except exceptions.NotFoundError as ex:
log.error(ex)
@@ -109,5 +113,7 @@ def init_es_data():
log.info(f"Applying mappings to ES host: {hosts_config}")
args = cluster_conf.get("args", {})
res = apply_mappings_to_cluster(hosts_config, name, es_args=args)
http_auth = es_factory.get_credentials(name)
res = apply_mappings_to_cluster(hosts_config, name, es_args=args, http_auth=http_auth)
log.info(res)

View File

@@ -1,7 +1,8 @@
{
"index_patterns": "events-*",
"settings": {
"number_of_shards": 1
"number_of_shards": 1,
"number_of_replicas": 0
},
"mappings": {
"_source": {

View File

@@ -1,7 +1,8 @@
{
"index_patterns": "queue_metrics_*",
"settings": {
"number_of_shards": 1
"number_of_shards": 1,
"number_of_replicas": 0
},
"mappings": {
"_source": {

View File

@@ -1,7 +1,8 @@
{
"index_patterns": "worker_stats_*",
"settings": {
"number_of_shards": 1
"number_of_shards": 1,
"number_of_replicas": 0
},
"mappings": {
"_source": {

View File

@@ -1,9 +1,10 @@
from datetime import datetime
from functools import lru_cache
from os import getenv
from typing import Tuple
from typing import Tuple, Optional
from boltons.iterutils import first
from elasticsearch import Elasticsearch, Transport
from elasticsearch import Elasticsearch
from apiserver.config_repo import config
@@ -21,6 +22,10 @@ OVERRIDE_PORT_ENV_KEY = (
"ELASTIC_SERVICE_PORT",
)
OVERRIDE_USERNAME_ENV_KEY = ("CLEARML_ELASTIC_SERVICE_USERNAME",)
OVERRIDE_PASSWORD_ENV_KEY = ("CLEARML_ELASTIC_SERVICE_PASSWORD",)
OVERRIDE_HOST = first(filter(None, map(getenv, OVERRIDE_HOST_ENV_KEY)))
if OVERRIDE_HOST:
log.info(f"Using override elastic host {OVERRIDE_HOST}")
@@ -29,6 +34,14 @@ OVERRIDE_PORT = first(filter(None, map(getenv, OVERRIDE_PORT_ENV_KEY)))
if OVERRIDE_PORT:
log.info(f"Using override elastic port {OVERRIDE_PORT}")
OVERRIDE_USERNAME = first(filter(None, map(getenv, OVERRIDE_USERNAME_ENV_KEY)))
if OVERRIDE_USERNAME:
log.info(f"Using override elastic username {OVERRIDE_USERNAME}")
OVERRIDE_PASSWORD = first(filter(None, map(getenv, OVERRIDE_PASSWORD_ENV_KEY)))
if OVERRIDE_PASSWORD:
log.info("Using override elastic password ********")
_instances = {}
@@ -48,6 +61,10 @@ class InvalidClusterConfiguration(Exception):
pass
class MissingPasswordForElasticUser(Exception):
pass
class ESFactory:
@classmethod
def connect(cls, cluster_name):
@@ -65,22 +82,45 @@ class ESFactory:
if not hosts:
raise InvalidClusterConfiguration(cluster_name)
http_auth = cls.get_credentials(cluster_name)
args = cluster_config.get("args", {})
_instances[cluster_name] = Elasticsearch(
hosts=hosts, transport_class=Transport, **args
hosts=hosts, http_auth=http_auth, **args
)
return _instances[cluster_name]
@classmethod
def get_credentials(cls, cluster_name: str, cluster_config: dict = None) -> Optional[Tuple[str, str]]:
cluster_config = cluster_config or cls.get_cluster_config(cluster_name)
if not cluster_config.get("secure", True):
return None
elastic_user = OVERRIDE_USERNAME or config.get("secure.elastic.user", None)
if not elastic_user:
return None
elastic_password = OVERRIDE_PASSWORD or config.get(
"secure.elastic.password", None
)
if not elastic_password:
raise MissingPasswordForElasticUser(
f"cluster={cluster_name}, username={elastic_user}"
)
return elastic_user, elastic_password
@classmethod
def get_all_cluster_names(cls):
return list(config.get("hosts.elastic"))
@classmethod
def get_override(cls, cluster_name: str) -> Tuple[str, str]:
def get_override_host(cls, cluster_name: str) -> Tuple[str, str]:
return OVERRIDE_HOST, OVERRIDE_PORT
@classmethod
@lru_cache()
def get_cluster_config(cls, cluster_name):
"""
Returns cluster config for the specified cluster path
@@ -97,7 +137,7 @@ class ESFactory:
for entry in cluster_config.get("hosts", []):
entry[key] = value
host, port = cls.get_override(cluster_name)
host, port = cls.get_override_host(cluster_name)
if host:
set_host_prop("host", host)

View File

@@ -298,8 +298,9 @@ class PrePopulate:
if company_id is None:
company_id = ""
# Always use a public user for pre-populated data
cls.user_cls(id=user_id, name=user_name, company="").save()
existing_user = cls.user_cls.objects(id=user_id).only("id").first()
if not existing_user:
cls.user_cls(id=user_id, name=user_name, company=company_id).save()
cls._import(zfile, company_id, user_id, metadata)

View File

@@ -1,10 +1,8 @@
import threading
from os import getenv
from time import sleep
from boltons.iterutils import first
from redis import StrictRedis
from redis.sentinel import Sentinel, SentinelConnectionPool
from rediscluster import RedisCluster
from apiserver.apierrors.errors.server_error import ConfigError, GeneralError
from apiserver.config_repo import config
@@ -21,6 +19,11 @@ OVERRIDE_PORT_ENV_KEY = (
"TRAINS_REDIS_SERVICE_PORT",
"REDIS_SERVICE_PORT",
)
OVERRIDE_PASSWORD_ENV_KEY = (
"CLEARML_REDIS_SERVICE_PASSWORD",
"TRAINS_REDIS_SERVICE_PASSWORD",
"REDIS_SERVICE_PASSWORD",
)
OVERRIDE_HOST = first(filter(None, map(getenv, OVERRIDE_HOST_ENV_KEY)))
if OVERRIDE_HOST:
@@ -30,99 +33,7 @@ OVERRIDE_PORT = first(filter(None, map(getenv, OVERRIDE_PORT_ENV_KEY)))
if OVERRIDE_PORT:
log.info(f"Using override redis port {OVERRIDE_PORT}")
class MyPubSubWorkerThread(threading.Thread):
def __init__(self, sentinel, on_new_master, msg_sleep_time, daemon=True):
super(MyPubSubWorkerThread, self).__init__()
self.daemon = daemon
self.sentinel = sentinel
self.on_new_master = on_new_master
self.sentinel_host = sentinel.connection_pool.connection_kwargs["host"]
self.msg_sleep_time = msg_sleep_time
self._running = False
self.pubsub = None
def subscribe(self):
if self.pubsub:
try:
self.pubsub.unsubscribe()
self.pubsub.punsubscribe()
except Exception:
pass
finally:
self.pubsub = None
subscriptions = {"+switch-master": self.on_new_master}
while not self.pubsub or not self.pubsub.subscribed:
try:
self.pubsub = self.sentinel.pubsub()
self.pubsub.subscribe(**subscriptions)
except Exception as ex:
log.warn(
f"Error while subscribing to sentinel at {self.sentinel_host} ({ex.args[0]}) Sleeping and retrying"
)
sleep(3)
log.info(f"Subscribed to sentinel {self.sentinel_host}")
def run(self):
if self._running:
return
self._running = True
self.subscribe()
while self.pubsub.subscribed:
try:
self.pubsub.get_message(
ignore_subscribe_messages=True, timeout=self.msg_sleep_time
)
except Exception as ex:
log.warn(
f"Error while getting message from sentinel {self.sentinel_host} ({ex.args[0]}) Resubscribing"
)
self.subscribe()
self.pubsub.close()
self._running = False
def stop(self):
# stopping simply unsubscribes from all channels and patterns.
# the unsubscribe responses that are generated will short circuit
# the loop in run(), calling pubsub.close() to clean up the connection
self.pubsub.unsubscribe()
self.pubsub.punsubscribe()
# todo,future - multi master clusters?
class RedisCluster(object):
def __init__(self, sentinel_hosts, service_name, **connection_kwargs):
self.service_name = service_name
self.sentinel = Sentinel(sentinel_hosts, **connection_kwargs)
self.master = None
self.master_host_port = None
self.reconfigure()
self.sentinel_threads = {}
self.listen()
def reconfigure(self):
try:
self.master_host_port = self.sentinel.discover_master(self.service_name)
self.master = self.sentinel.master_for(self.service_name)
log.info(f"Reconfigured master to {self.master_host_port}")
except Exception as ex:
log.error(f"Error while reconfiguring. {ex.args[0]}")
def listen(self):
def on_new_master(workerThread):
self.reconfigure()
for sentinel in self.sentinel.sentinels:
sentinel_host = sentinel.connection_pool.connection_kwargs["host"]
self.sentinel_threads[sentinel_host] = MyPubSubWorkerThread(
sentinel, on_new_master, msg_sleep_time=0.001, daemon=True
)
self.sentinel_threads[sentinel_host].start()
OVERRIDE_PASSWORD = first(filter(None, map(getenv, OVERRIDE_PASSWORD_ENV_KEY)))
class RedisManager(object):
@@ -131,6 +42,9 @@ class RedisManager(object):
for alias, alias_config in redis_config_dict.items():
alias_config = alias_config.as_plain_ordered_dict()
alias_config["password"] = config.get(
f"secure.redis.{alias}.password", None
)
is_cluster = alias_config.get("cluster", False)
@@ -142,34 +56,19 @@ class RedisManager(object):
if port:
alias_config["port"] = port
db = alias_config.get("db", 0)
password = OVERRIDE_PASSWORD or alias_config.get("password", None)
if password:
alias_config["password"] = password
sentinels = alias_config.get("sentinels", None)
service_name = alias_config.get("service_name", None)
if not is_cluster and sentinels:
raise ConfigError(
"Redis configuration is invalid. mixed regular and cluster mode",
alias=alias,
)
if is_cluster and (not sentinels or not service_name):
raise ConfigError(
"Redis configuration is invalid. missing sentinels or service_name",
alias=alias,
)
if not is_cluster and (not port or not host):
if not port or not host:
raise ConfigError(
"Redis configuration is invalid. missing port or host", alias=alias
)
if is_cluster:
# todo support all redis connection args via sentinel's connection_kwargs
del alias_config["sentinels"]
del alias_config["cluster"]
del alias_config["service_name"]
self.aliases[alias] = RedisCluster(
sentinels, service_name, **alias_config
)
del alias_config["db"]
self.aliases[alias] = RedisCluster(**alias_config)
else:
self.aliases[alias] = StrictRedis(**alias_config)
@@ -177,27 +76,21 @@ class RedisManager(object):
obj = self.aliases.get(alias)
if not obj:
raise GeneralError(f"Invalid Redis alias {alias}")
if isinstance(obj, RedisCluster):
obj.master.get("health")
return obj.master
else:
obj.get("health")
return obj
obj.get("health")
return obj
def host(self, alias):
r = self.connection(alias)
pool = r.connection_pool
if isinstance(pool, SentinelConnectionPool):
connections = pool.connection_kwargs[
"connection_pool"
]._available_connections
if isinstance(r, RedisCluster):
connections = first(r.connection_pool._available_connections.values())
else:
connections = pool._available_connections
connections = r.connection_pool._available_connections
if len(connections) > 0:
return connections[0].host
else:
if not connections:
return None
return connections[0].host
redman = RedisManager(config.get("hosts.redis"))

View File

@@ -3,7 +3,7 @@ bcrypt>=3.1.4
boltons>=19.1.0
boto3==1.14.13
dpath>=1.4.2,<2.0
elasticsearch>=7.0.0,<8.0.0
elasticsearch==7.13.3
fastjsonschema>=2.8
flask-compress>=1.4.0
flask-cors>=3.0.5
@@ -16,18 +16,19 @@ jinja2==2.11.3
jsonmodels>=2.3
jsonschema>=2.6.0
luqum>=0.10.0
mongoengine==0.19.1
mongoengine==0.23.1
nested_dict>=1.61
packaging==20.3
psutil>=5.6.5
pyhocon>=0.3.35
pyjwt<2.0.0
pymongo==3.10.1
pymongo[srv]==3.12.0
python-rapidjson>=0.6.3
redis>=2.10.5
redis==3.5.3
redis-py-cluster>=2.1.3
related>=0.7.2
requests>=2.13.0
semantic_version>=2.8.3,<3
six
tqdm
validators>=0.12.4
validators>=0.12.4

View File

@@ -26,6 +26,10 @@ credentials {
type: string
description: Credentials secret key
}
label {
type: string
description: Optional credentials label
}
}
}
batch_operation {

View File

@@ -15,6 +15,10 @@ _definitions {
type: string
description: ""
}
label {
type: string
description: Optional credentials label
}
last_used {
type: string
description: ""

File diff suppressed because it is too large Load Diff

View File

@@ -199,6 +199,29 @@ get_all_ex {
}
}
}
"2.15": ${get_all_ex."2.13"} {
request {
properties {
scroll_id {
type: string
description: "Scroll ID returned from the previos calls to get_all_ex"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of models to retrieve"
}
}
}
response.properties.scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
}
}
}
get_all {
"2.1" {
@@ -302,6 +325,29 @@ get_all {
}
}
}
"2.15": ${get_all."2.1"} {
request {
properties {
scroll_id {
type: string
description: "Scroll ID returned from the previos calls to get_all"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of models to retrieve"
}
}
}
response.properties.scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
}
}
}
get_frameworks {
"2.8" {

View File

@@ -152,6 +152,11 @@ _definitions {
type: string
format: "date-time"
}
last_update {
description: "Last update time"
type: string
format: "date-time"
}
tags {
type: array
description: "User-defined tags"
@@ -379,7 +384,7 @@ get_all {
items { type: string }
}
page {
description: "Page number, returns a specific page out of the resulting list of dataviews"
description: "Page number, returns a specific page out of the resulting list of projects"
type: integer
minimum: 0
}
@@ -430,6 +435,29 @@ get_all {
}
}
}
"2.15": ${get_all."2.13"} {
request {
properties {
scroll_id {
type: string
description: "Scroll ID returned from the previos calls to get_all_ex"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of projects to retrieve"
}
}
}
response.properties.scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
}
}
}
get_all_ex {
internal: true
@@ -469,7 +497,7 @@ get_all_ex {
default: false
}
check_own_contents {
description: "If set to 'true' and project ids are passed to the query then for these projects their own tasks, models and dataviews are counted"
description: "If set to 'true' and project ids are passed to the query then for these projects their own tasks and models are counted"
type: boolean
default: false
}
@@ -488,6 +516,72 @@ get_all_ex {
}
}
}
"2.15": ${get_all_ex."2.13"} {
request {
properties {
scroll_id {
type: string
description: "Scroll ID returned from the previos calls to get_all"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of projects to retrieve"
}
}
}
response.properties.scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
}
}
"2.16": ${get_all_ex."2.15"} {
request.properties.stats_with_children {
description: "If include_stats flag is set then this flag contols whether the child projects tasks are taken into statistics or not"
type: boolean
default: true
}
response {
properties {
stats {
properties {
active.properties {
total_tasks {
description: "Number of tasks"
type: integer
}
completed_tasks {
description: "Number of tasks completed in the last 24 hours"
type: integer
}
running_tasks {
description: "Number of running tasks"
type: integer
}
}
archived.properties {
total_tasks {
description: "Number of tasks"
type: integer
}
completed_tasks {
description: "Number of tasks completed in the last 24 hours"
type: integer
}
running_tasks {
description: "Number of running tasks"
type: integer
}
}
}
}
}
}
}
}
update {
"2.1" {
@@ -504,10 +598,6 @@ update {
description: "Project name. Unique within the company."
type: string
}
description {
description: "Project description. "
type: string
}
description {
description: "Project description"
type: string
@@ -594,7 +684,7 @@ merge {
type: object
properties {
moved_entities {
description: "The number of tasks, models and dataviews moved from the merged project into the destination"
description: "The number of tasks and models moved from the merged project into the destination"
type: integer
}
moved_projects {
@@ -605,6 +695,42 @@ merge {
}
}
}
validate_delete {
"2.14" {
description: "Validates that the project existis and can be deleted"
request {
type: object
required: [ project ]
properties {
project {
description: "Project ID"
type: string
}
}
}
response {
type: object
properties {
tasks {
description: "The total number of tasks under the project and all its children"
type: integer
}
non_archived_tasks {
description: "The total number of non-archived tasks under the project and all its children"
type: integer
}
models {
description: "The total number of models under the project and all its children"
type: integer
}
non_archived_models {
description: "The total number of non-archived models under the project and all its children"
type: integer
}
}
}
}
}
delete {
"2.1" {
description: "Deletes a project"
@@ -613,7 +739,7 @@ delete {
required: [ project ]
properties {
project {
description: "Project id"
description: "Project ID"
type: string
}
force {
@@ -803,7 +929,6 @@ get_hyper_parameters {
}
}
}
get_task_tags {
"2.8" {
description: "Get user and system tags used for the tasks under the specified projects"

View File

@@ -115,6 +115,29 @@ get_by_id {
get_all_ex {
internal: true
"2.4": ${get_all."2.4"}
"2.15": ${get_all_ex."2.4"} {
request {
properties {
scroll_id {
type: string
description: "Scroll ID returned from the previos calls to get_all_ex"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of queues to retrieve"
}
}
}
response.properties.scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
}
}
}
get_all {
"2.4" {
@@ -178,6 +201,29 @@ get_all {
}
}
}
"2.15": ${get_all."2.4"} {
request {
properties {
scroll_id {
type: string
description: "Scroll ID returned from the previos calls to get_all"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of queues to retrieve"
}
}
}
response.properties.scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
}
}
}
get_default {
"2.4" {

View File

@@ -534,7 +534,7 @@ _definitions {
container {
description: "Docker container parameters"
type: object
additionalProperties { type: string }
additionalProperties { type: [string, null] }
}
models {
description: "Task models"
@@ -685,6 +685,29 @@ get_all_ex {
}
}
}
"2.15": ${get_all_ex."2.13"} {
request {
properties {
scroll_id {
type: string
description: "Scroll ID returned from the previos calls to get_all_ex"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of tasks to retrieve"
}
}
}
response.properties.scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
}
}
}
get_all {
"2.1" {
@@ -799,6 +822,29 @@ get_all {
}
}
}
"2.15": ${get_all."2.1"} {
request {
properties {
scroll_id {
type: string
description: "Scroll ID returned from the previos calls to get_all"
}
refresh_scroll {
type: boolean
description: "If set then all the data received with this scroll will be requeried"
}
size {
type: integer
minimum: 1
description: "The number of tasks to retrieve"
}
}
}
response.properties.scroll_id {
type: string
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
}
}
}
get_types {
"2.8" {
@@ -935,7 +981,7 @@ clone {
new_task_container {
description: "The docker container properties for the new task. If not provided then taken from the original task"
type: object
additionalProperties { type: string }
additionalProperties { type: [string, null] }
}
}
}
@@ -1113,7 +1159,7 @@ create {
container {
description: "Docker container parameters"
type: object
additionalProperties { type: string }
additionalProperties { type: [string, null] }
}
}
}
@@ -1202,7 +1248,7 @@ validate {
container {
description: "Docker container parameters"
type: object
additionalProperties { type: string }
additionalProperties { type: [string, null] }
}
}
}
@@ -1364,7 +1410,7 @@ edit {
container {
description: "Docker container parameters"
type: object
additionalProperties { type: string }
additionalProperties { type: [string, null] }
}
runtime {
description: "Task runtime mapping"

View File

@@ -4,7 +4,7 @@ from hashlib import md5
from flask import Flask
from flask_compress import Compress
from flask_cors import CORS
from semantic_version import Version
from packaging.version import Version
from apiserver.database import db
from apiserver.bll.statistics.stats_reporter import StatisticsReporter

View File

@@ -1,15 +1,17 @@
from functools import partial
from flask import request, Response, redirect
from werkzeug.datastructures import ImmutableMultiDict
from werkzeug.exceptions import BadRequest
from apiserver.apierrors import APIError
from apiserver.apierrors.base import BaseError
from apiserver.config_repo import config
from apiserver.service_repo import ServiceRepo, APICall
from apiserver.service_repo.auth import AuthType
from apiserver.service_repo.auth import AuthType, Token
from apiserver.service_repo.errors import PathParsingError
from apiserver.utilities import json
from apiserver.utilities.dicts import nested_set
log = config.logger(__file__)
@@ -29,7 +31,7 @@ class RequestHandlers:
try:
call = self._create_api_call(request)
load_data_callback = partial(self._load_call_data, req=request)
content, content_type = ServiceRepo.handle_call(
content, content_type, company = ServiceRepo.handle_call(
call, load_data_callback=load_data_callback
)
@@ -51,20 +53,49 @@ class RequestHandlers:
if call.result.cookies:
for key, value in call.result.cookies.items():
kwargs = config.get("apiserver.auth.cookies")
kwargs = config.get("apiserver.auth.cookies").copy()
if value is None:
kwargs = kwargs.copy()
# Removing a cookie
kwargs["max_age"] = 0
kwargs["expires"] = 0
response.set_cookie(key, "", **kwargs)
else:
response.set_cookie(key, value, **kwargs)
value = ""
elif not company:
# Setting a cookie, let's try to figure out the company
# noinspection PyBroadException
try:
company = Token.decode_identity(value).company
except Exception:
pass
if company:
try:
# use no default value to allow setting a null domain as well
kwargs["domain"] = config.get(f"apiserver.auth.cookies_domain_override.{company}")
except KeyError:
pass
response.set_cookie(key, value, **kwargs)
return response
except Exception as ex:
log.exception(f"Failed processing request {request.url}: {ex}")
return f"Failed processing request {request.url}", 500
@staticmethod
def _apply_multi_dict(body: dict, md: ImmutableMultiDict):
def convert_value(v: str):
if v.replace(".", "", 1).isdigit():
return float(v) if "." in v else int(v)
if v in ("true", "True", "TRUE"):
return True
if v in ("false", "False", "FALSE"):
return False
return v
for k, v in md.lists():
v = [convert_value(x) for x in v] if (len(v) > 1 or k.endswith("[]")) else convert_value(v[0])
nested_set(body, k.rstrip("[]").split("."), v)
def _update_call_data(self, call, req):
""" Use request payload/form to fill call data or batched data """
if req.content_type == "application/json-lines":
@@ -82,23 +113,12 @@ class RequestHandlers:
req.on_json_loading_failed(msg)
call.batched_data = items
else:
json_body = req.get_json(force=True, silent=False) if req.data else None
# merge form and args
form = req.form.copy()
form.update(req.args)
form = form.to_dict()
# convert string numbers to floats
for key in form:
if form[key].replace(".", "", 1).isdigit():
if "." in form[key]:
form[key] = float(form[key])
else:
form[key] = int(form[key])
elif form[key].lower() == "true":
form[key] = True
elif form[key].lower() == "false":
form[key] = False
call.data = json_body or form or {}
body = (req.get_json(force=True, silent=False) if req.data else None) or {}
if req.args:
self._apply_multi_dict(body, req.args)
if req.form:
self._apply_multi_dict(body, req.form)
call.data = body
def _call_or_empty_with_error(self, call, req, msg, code=500, subcode=0):
call = call or APICall(

View File

@@ -310,6 +310,12 @@ class APICall(DataContainer):
_transaction_headers = _get_headers("Trx")
""" Transaction ID """
_redacted_headers = {
HEADER_AUTHORIZATION: " ",
"Cookie": "=",
}
""" Headers whose value should be redacted. Maps header name to partition char """
@property
def HEADER_TRANSACTION(self):
return self._transaction_headers[0]
@@ -584,17 +590,26 @@ class APICall(DataContainer):
def json_flags(self):
return self._json_flags
@property
def extra_meta_fields(self):
return {}
def mark_end(self):
self._end_ts = time.time()
self._duration = int((self._end_ts - self._start_ts) * 1000)
def get_response(self, include_stack: bool = False) -> Tuple[Union[dict, str], str]:
def get_response(self, include_stack: bool = None) -> Tuple[Union[dict, str], str]:
"""
Get the response for this call.
:param include_stack: If True, stack trace stored in this call's result should
be included in the response (default is False)
be included in the response (default follows configuration)
:return: Response data (encoded according to self.content_type) and the data's content type
"""
include_stack = (
include_stack
if include_stack is not None
else config.get("apiserver.return_stack_to_caller", False)
)
def make_version_number(version: PartialVersion) -> Union[None, float, str]:
"""
@@ -629,6 +644,7 @@ class APICall(DataContainer):
"result_msg": self.result.msg,
"error_stack": self.result.traceback if include_stack else None,
"error_data": self.result.error_data,
**self.extra_meta_fields,
},
"data": self.result.data,
}
@@ -663,3 +679,15 @@ class APICall(DataContainer):
error_data=error_data,
cookies=self._result.cookies,
)
def get_redacted_headers(self):
headers = self.headers.copy()
if not self.requires_authorization or self.auth:
# We won't log the authorization header if call shouldn't be authorized, or if it was successfully
# authorized. This means we'll only log authorization header for calls that failed to authorize (hopefully
# this will allow us to debug authorization errors).
for header, sep in self._redacted_headers.items():
if header in headers:
prefix, _, redact = headers[header].partition(sep)
headers[header] = prefix + sep + f"<{len(redact)} bytes redacted>"
return headers

View File

@@ -12,6 +12,9 @@ from .payload import Payload
token_secret = config.get('secure.auth.token_secret')
log = config.logger(__file__)
class Token(Payload):
default_expiration_sec = config.get('apiserver.auth.default_expiration_sec')
@@ -94,3 +97,14 @@ class Token(Payload):
token.exp = now + timedelta(seconds=expiration_sec)
return token.encode(**extra_payload)
@classmethod
def decode_identity(cls, encoded_token):
# noinspection PyBroadException
try:
from ..auth import Identity
decoded = cls.decode(encoded_token, verify=False)
return Identity.from_dict(decoded.get("identity", {}))
except Exception as ex:
log.error(f"Failed parsing identity from encoded token: {ex}")

View File

@@ -1,9 +1,12 @@
import random
import string
sys_random = random.SystemRandom()
def get_random_string(length=12, allowed_chars='abcdefghijklmnopqrstuvwxyz'
'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'):
def get_random_string(
length: int = 12, allowed_chars: str = string.ascii_letters + string.digits
) -> str:
"""
Returns a securely generated random string.
@@ -12,20 +15,20 @@ def get_random_string(length=12, allowed_chars='abcdefghijklmnopqrstuvwxyz'
Taken from the django.utils.crypto module.
"""
return ''.join(sys_random.choice(allowed_chars) for _ in range(length))
return "".join(sys_random.choice(allowed_chars) for _ in range(length))
def get_client_id(length=20):
def get_client_id(length: int = 20) -> str:
"""
Create a random secret key.
Taken from the Django project.
"""
chars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'
chars = string.ascii_uppercase + string.digits
return get_random_string(length, chars)
def get_secret_key(length=50):
def get_secret_key(length: int = 50) -> str:
"""
Create a random secret key.
@@ -33,5 +36,5 @@ def get_secret_key(length=50):
NOTE: asterisk is not supported due to issues with environment variables containing
asterisks (in case the secret key is stored in an environment variable)
"""
chars = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&(-_=+)'
chars = string.ascii_letters + string.digits
return get_random_string(length, chars)

View File

@@ -10,6 +10,7 @@ from apiserver.apierrors import APIError, errors
from apiserver.config_repo import config
from apiserver.utilities.partial_version import PartialVersion
from .apicall import APICall
from .auth import Identity
from .endpoint import Endpoint
from .errors import MalformedPathError, InvalidVersionError, CallFailedError
from .util import parse_return_stack_on_code
@@ -37,7 +38,7 @@ class ServiceRepo(object):
"""If the check is set, parsing will fail for endpoint request with the version that is grater than the current
maximum """
_max_version = PartialVersion("2.13")
_max_version = PartialVersion("2.16")
""" Maximum version number (the highest min_version value across all endpoints) """
_endpoint_exp = (
@@ -233,19 +234,27 @@ class ServiceRepo(object):
return subcode in subcode_list
@classmethod
def _get_company(
def _get_identity(
cls, call: APICall, endpoint: Endpoint = None, ignore_error: bool = False
) -> Optional[str]:
) -> Optional[Identity]:
authorize = endpoint and endpoint.authorize
if ignore_error or not authorize:
try:
return call.identity.company
return call.identity
except Exception:
return None
return call.identity.company
return call.identity
@classmethod
def _get_company(
cls, call: APICall, endpoint: Endpoint = None, ignore_error: bool = False
) -> Optional[str]:
identity = cls._get_identity(call, endpoint=endpoint, ignore_error=ignore_error)
return None if identity is None else identity.company
@classmethod
def handle_call(cls, call: APICall, load_data_callback: Callable = None):
company = None
try:
if call.failed:
raise CallFailedError()
@@ -316,4 +325,4 @@ class ServiceRepo(object):
else:
log.error(console_msg)
return content, content_type
return content, content_type, company

View File

@@ -13,6 +13,7 @@ from apiserver.apimodels.auth import (
CredentialsResponse,
RevokeCredentialsRequest,
EditUserReq,
CreateCredentialsRequest,
)
from apiserver.apimodels.base import UpdateResponse
from apiserver.bll.auth import AuthBLL
@@ -58,9 +59,13 @@ def get_token_for_user(call: APICall, _: str, request: GetTokenForUserRequest):
""" Generates a token based on a requested user and company. INTERNAL. """
if call.identity.role not in Role.get_system_roles():
if call.identity.role != Role.admin and call.identity.user != request.user:
raise errors.bad_request.InvalidUserId("cannot generate token for another user")
raise errors.bad_request.InvalidUserId(
"cannot generate token for another user"
)
if call.identity.company != request.company:
raise errors.bad_request.InvalidId("cannot generate token in another company")
raise errors.bad_request.InvalidId(
"cannot generate token in another company"
)
call.result.data_model = AuthBLL.get_token_for_user(
user_id=request.user,
@@ -93,7 +98,10 @@ def validate_token_endpoint(call: APICall, _, __):
)
def create_user(call: APICall, _, request: CreateUserRequest):
""" Create a user from. INTERNAL. """
if call.identity.role not in Role.get_system_roles() and request.company != call.identity.company:
if (
call.identity.role not in Role.get_system_roles()
and request.company != call.identity.company
):
raise errors.bad_request.InvalidId("cannot create user in another company")
user_id = AuthBLL.create_user(request=request, call=call)
@@ -101,7 +109,7 @@ def create_user(call: APICall, _, request: CreateUserRequest):
@endpoint("auth.create_credentials", response_data_model=CreateCredentialsResponse)
def create_credentials(call: APICall, _, __):
def create_credentials(call: APICall, _, request: CreateCredentialsRequest):
if _is_protected_user(call.identity.user):
raise errors.bad_request.InvalidUserId("protected identity")
@@ -109,6 +117,7 @@ def create_credentials(call: APICall, _, __):
user_id=call.identity.user,
company_id=call.identity.company,
role=call.identity.role,
label=request.label,
)
call.result.data_model = CreateCredentialsResponse(credentials=credentials)
@@ -151,7 +160,9 @@ def get_credentials(call: APICall, _, __):
# we return ONLY the key IDs, never the secrets (want a secret? create new credentials)
call.result.data_model = GetCredentialsResponse(
credentials=[
CredentialsResponse(access_key=c.key, last_used=c.last_used)
CredentialsResponse(
access_key=c.key, last_used=c.last_used, label=c.label
)
for c in user.credentials
]
)

View File

@@ -1,8 +1,12 @@
import itertools
import math
from collections import defaultdict
from operator import itemgetter
from typing import Sequence, Optional
import attr
import jsonmodels.fields
from boltons.iterutils import bucketize
from apiserver.apierrors import errors
from apiserver.apimodels.events import (
@@ -17,12 +21,19 @@ from apiserver.apimodels.events import (
LogOrderEnum,
GetDebugImageSampleRequest,
NextDebugImageSampleRequest,
MetricVariants as ApiMetrics,
TaskPlotsRequest,
TaskEventsRequest,
ScalarMetricsIterRawRequest,
)
from apiserver.bll.event import EventBLL
from apiserver.bll.event.event_common import EventType
from apiserver.bll.event.event_common import EventType, MetricVariants
from apiserver.bll.event.events_iterator import Scroll
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
from apiserver.bll.task import TaskBLL
from apiserver.config_repo import config
from apiserver.service_repo import APICall, endpoint
from apiserver.utilities import json
from apiserver.utilities import json, extract_properties_to_lists
task_bll = TaskBLL()
event_bll = EventBLL()
@@ -36,7 +47,6 @@ def add(call: APICall, company_id, _):
company_id, [data], call.worker, allow_locked_tasks=allow_locked
)
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
call.kpis["events"] = 1
@endpoint("events.add_batch")
@@ -47,7 +57,6 @@ def add_batch(call: APICall, company_id, _):
added, err_count, err_info = event_bll.add_events(company_id, events, call.worker)
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
call.kpis["events"] = len(events)
@endpoint("events.get_task_log", required_fields=["task"])
@@ -110,7 +119,8 @@ def get_task_log(call, company_id, request: LogEventsRequest):
company_id, task_id, allow_public=True, only=("company", "company_origin")
)[0]
res = event_bll.log_events_iterator.get_task_events(
res = event_bll.events_iterator.get_task_events(
event_type=EventType.task_log,
company_id=task.get_index_company(),
task_id=task_id,
batch_size=request.batch_size,
@@ -255,31 +265,94 @@ def vector_metrics_iter_histogram(call, company_id, _):
)
@endpoint("events.get_task_events", required_fields=["task"])
def get_task_events(call, company_id, _):
task_id = call.data["task"]
batch_size = call.data.get("batch_size", 500)
event_type = call.data.get("event_type")
scroll_id = call.data.get("scroll_id")
order = call.data.get("order") or "asc"
class GetTaskEventsScroll(Scroll):
from_key_value = jsonmodels.fields.StringField()
total = jsonmodels.fields.IntField()
request: TaskEventsRequest = jsonmodels.fields.EmbeddedField(TaskEventsRequest)
def make_response(
total: int, returned: int = 0, scroll_id: str = None, **kwargs
) -> dict:
return {
"returned": returned,
"total": total,
"scroll_id": scroll_id,
**kwargs,
}
@endpoint("events.get_task_events", request_data_model=TaskEventsRequest)
def get_task_events(call, company_id, request: TaskEventsRequest):
task_id = request.task
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
company_id, task_id, allow_public=True, only=("company",),
)[0]
result = event_bll.get_task_events(
task.get_index_company(),
task_id,
sort=[{"timestamp": {"order": order}}],
event_type=EventType(event_type) if event_type else EventType.all,
scroll_id=scroll_id,
size=batch_size,
key = ScalarKeyEnum.iter
scalar_key = ScalarKey.resolve(key)
if not request.scroll_id:
from_key_value = None if (request.order == LogOrderEnum.desc) else 0
total = None
else:
try:
scroll = GetTaskEventsScroll.from_scroll_id(request.scroll_id)
except ValueError:
raise errors.bad_request.InvalidScrollId(scroll_id=request.scroll_id)
if scroll.from_key_value is None:
return make_response(
scroll_id=request.scroll_id, total=scroll.total, events=[]
)
from_key_value = scalar_key.cast_value(scroll.from_key_value)
total = scroll.total
scroll.request.batch_size = request.batch_size or scroll.request.batch_size
request = scroll.request
navigate_earlier = request.order == LogOrderEnum.desc
metric_variants = _get_metric_variants_from_request(request.metrics)
if request.count_total and total is None:
total = event_bll.events_iterator.count_task_events(
event_type=request.event_type,
company_id=task.company,
task_id=task_id,
metric_variants=metric_variants,
)
batch_size = min(
request.batch_size,
int(
config.get("services.events.events_retrieval.max_raw_scalars_size", 10_000)
),
)
call.result.data = dict(
events=result.events,
returned=len(result.events),
total=result.total_events,
scroll_id=result.next_scroll_id,
res = event_bll.events_iterator.get_task_events(
event_type=request.event_type,
company_id=task.company,
task_id=task_id,
batch_size=batch_size,
key=ScalarKeyEnum.iter,
navigate_earlier=navigate_earlier,
from_key_value=from_key_value,
metric_variants=metric_variants,
)
scroll = GetTaskEventsScroll(
from_key_value=str(res.events[-1][scalar_key.field]) if res.events else None,
total=total,
request=request,
)
return make_response(
returned=len(res.events),
total=total,
scroll_id=scroll.get_scroll_id(),
events=res.events,
)
@@ -288,6 +361,7 @@ def get_scalar_metric_data(call, company_id, _):
task_id = call.data["task"]
metric = call.data["metric"]
scroll_id = call.data.get("scroll_id")
no_scroll = call.data.get("no_scroll", False)
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
@@ -299,6 +373,7 @@ def get_scalar_metric_data(call, company_id, _):
sort=[{"iter": {"order": "desc"}}],
metric=metric,
scroll_id=scroll_id,
no_scroll=no_scroll,
)
call.result.data = dict(
@@ -321,7 +396,7 @@ def get_task_latest_scalar_values(call, company_id, _):
)
last_iters = event_bll.get_last_iters(
company_id=company_id, event_type=EventType.all, task_id=task_id, iters=1
)
).get(task_id)
call.result.data = dict(
metrics=metrics,
last_iter=last_iters[0] if last_iters else 0,
@@ -421,6 +496,7 @@ def get_multi_task_plots(call, company_id, req_model):
task_ids = call.data["tasks"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
no_scroll = call.data.get("no_scroll", False)
tasks = task_bll.assert_exists(
company_id=call.identity.company,
@@ -442,6 +518,7 @@ def get_multi_task_plots(call, company_id, req_model):
sort=[{"iter": {"order": "desc"}}],
last_iter_count=iters,
scroll_id=scroll_id,
no_scroll=no_scroll,
)
tasks = {t.id: t.name for t in tasks}
@@ -494,11 +571,22 @@ def get_task_plots_v1_7(call, company_id, _):
)
@endpoint("events.get_task_plots", min_version="1.8", required_fields=["task"])
def get_task_plots(call, company_id, _):
task_id = call.data["task"]
iters = call.data.get("iters", 1)
scroll_id = call.data.get("scroll_id")
def _get_metric_variants_from_request(
req_metrics: Sequence[ApiMetrics],
) -> Optional[MetricVariants]:
if not req_metrics:
return None
return {m.metric: m.variants for m in req_metrics}
@endpoint(
"events.get_task_plots", min_version="1.8", request_data_model=TaskPlotsRequest
)
def get_task_plots(call, company_id, request: TaskPlotsRequest):
task_id = request.task
iters = request.iters
scroll_id = request.scroll_id
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company", "company_origin")
@@ -509,6 +597,8 @@ def get_task_plots(call, company_id, _):
sort=[{"iter": {"order": "desc"}}],
last_iterations_per_plot=iters,
scroll_id=scroll_id,
no_scroll=request.no_scroll,
metric_variants=_get_metric_variants_from_request(request.metrics),
)
return_events = result.events
@@ -594,9 +684,9 @@ def get_debug_images_v1_8(call, company_id, _):
response_data_model=DebugImageResponse,
)
def get_debug_images(call, company_id, request: DebugImagesRequest):
task_metrics = defaultdict(set)
task_metrics = defaultdict(dict)
for tm in request.metrics:
task_metrics[tm.task].add(tm.metric)
task_metrics[tm.task][tm.metric] = tm.variants
for metrics in task_metrics.values():
if None in metrics:
metrics.clear()
@@ -734,13 +824,115 @@ def _get_top_iter_unique_events_per_task(events, max_iters, tasks):
def _get_top_iter_unique_events(events, max_iters):
top_unique_events = defaultdict(lambda: [])
for e in events:
key = e.get("metric", "") + e.get("variant", "")
for ev in events:
key = ev.get("metric", "") + ev.get("variant", "")
evs = top_unique_events[key]
if len(evs) < max_iters:
evs.append(e)
evs.append(ev)
unique_events = list(
itertools.chain.from_iterable(list(top_unique_events.values()))
)
unique_events.sort(key=lambda e: e["iter"], reverse=True)
return unique_events
class ScalarMetricsIterRawScroll(Scroll):
from_key_value = jsonmodels.fields.StringField()
total = jsonmodels.fields.IntField()
request: ScalarMetricsIterRawRequest = jsonmodels.fields.EmbeddedField(
ScalarMetricsIterRawRequest
)
@endpoint("events.scalar_metrics_iter_raw", min_version="2.16")
def scalar_metrics_iter_raw(
call: APICall, company_id: str, request: ScalarMetricsIterRawRequest
):
key = request.key or ScalarKeyEnum.iter
scalar_key = ScalarKey.resolve(key)
if request.batch_size and request.batch_size < 0:
raise errors.bad_request.ValidationError(
"batch_size should be non negative number"
)
if not request.scroll_id:
from_key_value = None
total = None
request.batch_size = request.batch_size or 10_000
else:
try:
scroll = ScalarMetricsIterRawScroll.from_scroll_id(request.scroll_id)
except ValueError:
raise errors.bad_request.InvalidScrollId(scroll_id=request.scroll_id)
if scroll.from_key_value is None:
return make_response(
scroll_id=request.scroll_id, total=scroll.total, variants={}
)
from_key_value = scalar_key.cast_value(scroll.from_key_value)
total = scroll.total
request.batch_size = request.batch_size or scroll.request.batch_size
task_id = request.task
task = task_bll.assert_exists(
company_id, task_id, allow_public=True, only=("company",),
)[0]
metric_variants = _get_metric_variants_from_request([request.metric])
if request.count_total and total is None:
total = event_bll.events_iterator.count_task_events(
event_type=EventType.metrics_scalar,
company_id=task.company,
task_id=task_id,
metric_variants=metric_variants,
)
batch_size = min(
request.batch_size,
int(
config.get("services.events.events_retrieval.max_raw_scalars_size", 200_000)
),
)
events = []
for iteration in range(0, math.ceil(batch_size / 10_000)):
res = event_bll.events_iterator.get_task_events(
event_type=EventType.metrics_scalar,
company_id=task.company,
task_id=task_id,
batch_size=min(batch_size, 10_000),
navigate_earlier=False,
from_key_value=from_key_value,
metric_variants=metric_variants,
key=key,
)
if not res.events:
break
events.extend(res.events)
from_key_value = str(events[-1][scalar_key.field])
key = str(key)
variants = {
variant: extract_properties_to_lists(
["value", scalar_key.field], events, target_keys=["y", key]
)
for variant, events in bucketize(events, key=itemgetter("variant")).items()
}
call.kpis["events"] = len(events)
scroll = ScalarMetricsIterRawScroll(
from_key_value=str(events[-1][scalar_key.field]) if events else None,
total=total,
request=request,
)
return make_response(
returned=len(events),
total=total,
scroll_id=scroll.get_scroll_id(),
variants=variants,
)

View File

@@ -124,11 +124,15 @@ def get_all_ex(call: APICall, company_id, _):
with translate_errors_context():
_process_include_subprojects(call.data)
with TimingContext("mongo", "models_get_all_ex"):
ret_params = {}
models = Model.get_many_with_join(
company=company_id, query_dict=call.data, allow_public=True
company=company_id,
query_dict=call.data,
allow_public=True,
ret_params=ret_params,
)
conform_output_tags(call, models)
call.result.data = {"models": models}
call.result.data = {"models": models, **ret_params}
@endpoint("models.get_by_id_ex", required_fields=["id"])
@@ -148,14 +152,16 @@ def get_all(call: APICall, company_id, _):
conform_tag_fields(call, call.data)
with translate_errors_context():
with TimingContext("mongo", "models_get_all"):
ret_params = {}
models = Model.get_many(
company=company_id,
parameters=call.data,
query_dict=call.data,
allow_public=True,
ret_params=ret_params,
)
conform_output_tags(call, models)
call.result.data = {"models": models}
call.result.data = {"models": models, **ret_params}
@endpoint("models.get_frameworks", request_data_model=GetFrameworksRequest)
@@ -183,7 +189,7 @@ create_fields = {
"metadata": list,
}
last_update_fields = ("uri", "framework", "design", "labels", "ready", "metadata")
last_update_fields = ("uri", "framework", "design", "labels", "ready", "metadata", "system_tags", "tags")
def parse_model_fields(call, valid_fields):

View File

@@ -7,7 +7,7 @@ from apiserver.apierrors import errors
from apiserver.apierrors.errors.bad_request import InvalidProjectId
from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse
from apiserver.apimodels.projects import (
GetHyperParamRequest,
GetParamsRequest,
ProjectTagsRequest,
ProjectTaskParentsRequest,
ProjectHyperparamValuesRequest,
@@ -16,11 +16,14 @@ from apiserver.apimodels.projects import (
MoveRequest,
MergeRequest,
ProjectOrNoneRequest,
ProjectRequest,
)
from apiserver.bll.organization import OrgBLL, Tags
from apiserver.bll.project import ProjectBLL
from apiserver.bll.project.project_cleanup import delete_project
from apiserver.bll.task import TaskBLL
from apiserver.bll.project import ProjectBLL, ProjectQueries
from apiserver.bll.project.project_cleanup import (
delete_project,
validate_project_delete,
)
from apiserver.database.errors import translate_errors_context
from apiserver.database.model.project import Project
from apiserver.database.utils import (
@@ -37,8 +40,8 @@ from apiserver.services.utils import (
from apiserver.timing_context import TimingContext
org_bll = OrgBLL()
task_bll = TaskBLL()
project_bll = ProjectBLL()
project_queries = ProjectQueries()
create_fields = {
"name": None,
@@ -107,8 +110,12 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
_adjust_search_parameters(data, shallow_search=request.shallow_search)
ret_params = {}
projects = Project.get_many_with_join(
company=company_id, query_dict=data, allow_public=allow_public,
company=company_id,
query_dict=data,
allow_public=allow_public,
ret_params=ret_params,
)
if request.check_own_contents and requested_ids:
@@ -124,7 +131,7 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
conform_output_tags(call, projects)
if not request.include_stats:
call.result.data = {"projects": projects}
call.result.data = {"projects": projects, **ret_params}
return
project_ids = {project["id"] for project in projects}
@@ -132,13 +139,14 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
company=company_id,
project_ids=list(project_ids),
specific_state=request.stats_for_state,
include_children=request.stats_with_children,
)
for project in projects:
project["stats"] = stats[project["id"]]
project["sub_projects"] = children[project["id"]]
call.result.data = {"projects": projects}
call.result.data = {"projects": projects, **ret_params}
@endpoint("projects.get_all")
@@ -147,15 +155,17 @@ def get_all(call: APICall):
data = call.data
_adjust_search_parameters(data, shallow_search=data.get("shallow_search", False))
with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
ret_params = {}
projects = Project.get_many(
company=call.identity.company,
query_dict=data,
parameters=data,
allow_public=True,
ret_params=ret_params,
)
conform_output_tags(call, projects)
call.result.data = {"projects": projects}
call.result.data = {"projects": projects, **ret_params}
@endpoint(
@@ -230,6 +240,13 @@ def merge(call: APICall, company: str, request: MergeRequest):
}
@endpoint("projects.validate_delete")
def validate_delete(call: APICall, company_id: str, request: ProjectRequest):
call.result.data = validate_project_delete(
company=company_id, project_id=request.project
)
@endpoint("projects.delete", request_data_model=DeleteRequest)
def delete(call: APICall, company_id: str, request: DeleteRequest):
res, affected_projects = delete_project(
@@ -249,7 +266,7 @@ def get_unique_metric_variants(
call: APICall, company_id: str, request: ProjectOrNoneRequest
):
metrics = task_bll.get_unique_metric_variants(
metrics = project_queries.get_unique_metric_variants(
company_id,
[request.project] if request.project else None,
include_subprojects=request.include_subprojects,
@@ -261,11 +278,11 @@ def get_unique_metric_variants(
@endpoint(
"projects.get_hyper_parameters",
min_version="2.9",
request_data_model=GetHyperParamRequest,
request_data_model=GetParamsRequest,
)
def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamRequest):
def get_hyper_parameters(call: APICall, company_id: str, request: GetParamsRequest):
total, remaining, parameters = TaskBLL.get_aggregated_project_parameters(
total, remaining, parameters = project_queries.get_aggregated_project_parameters(
company_id,
project_ids=[request.project] if request.project else None,
include_subprojects=request.include_subprojects,
@@ -288,7 +305,7 @@ def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamR
def get_hyperparam_values(
call: APICall, company_id: str, request: ProjectHyperparamValuesRequest
):
total, values = task_bll.get_hyperparam_distinct_values(
total, values = project_queries.get_hyperparam_distinct_values(
company_id,
project_ids=request.projects,
section=request.section,

View File

@@ -48,21 +48,29 @@ def get_by_id(call: APICall):
@endpoint("queues.get_all_ex", min_version="2.4")
def get_all_ex(call: APICall):
conform_tag_fields(call, call.data)
ret_params = {}
queues = queue_bll.get_queue_infos(
company_id=call.identity.company, query_dict=call.data
company_id=call.identity.company,
query_dict=call.data,
ret_params=ret_params,
)
conform_output_tags(call, queues)
call.result.data = {"queues": queues}
call.result.data = {"queues": queues, **ret_params}
@endpoint("queues.get_all", min_version="2.4")
def get_all(call: APICall):
conform_tag_fields(call, call.data)
queues = queue_bll.get_all(company_id=call.identity.company, query_dict=call.data)
ret_params = {}
queues = queue_bll.get_all(
company_id=call.identity.company,
query_dict=call.data,
ret_params=ret_params,
)
conform_output_tags(call, queues)
call.result.data = {"queues": queues}
call.result.data = {"queues": queues, **ret_params}
@endpoint("queues.create", min_version="2.4", request_data_model=CreateRequest)

View File

@@ -4,7 +4,6 @@ from functools import partial
from typing import Sequence, Union, Tuple
import attr
import dpath
from mongoengine import EmbeddedDocument, Q
from mongoengine.queryset.transform import COMPARISON_OPERATORS
from pymongo import UpdateOne
@@ -220,14 +219,17 @@ def get_all_ex(call: APICall, company_id, _):
call_data = escape_execution_parameters(call)
with translate_errors_context():
with TimingContext("mongo", "task_get_all_ex"):
_process_include_subprojects(call_data)
tasks = Task.get_many_with_join(
company=company_id, query_dict=call_data, allow_public=True,
)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks}
with TimingContext("mongo", "task_get_all_ex"):
_process_include_subprojects(call_data)
ret_params = {}
tasks = Task.get_many_with_join(
company=company_id,
query_dict=call_data,
allow_public=True,
ret_params=ret_params,
)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks, **ret_params}
@endpoint("tasks.get_by_id_ex", required_fields=["id"])
@@ -236,14 +238,13 @@ def get_by_id_ex(call: APICall, company_id, _):
call_data = escape_execution_parameters(call)
with translate_errors_context():
with TimingContext("mongo", "task_get_by_id_ex"):
tasks = Task.get_many_with_join(
company=company_id, query_dict=call_data, allow_public=True,
)
with TimingContext("mongo", "task_get_by_id_ex"):
tasks = Task.get_many_with_join(
company=company_id, query_dict=call_data, allow_public=True,
)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks}
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks}
@endpoint("tasks.get_all", required_fields=[])
@@ -252,16 +253,17 @@ def get_all(call: APICall, company_id, _):
call_data = escape_execution_parameters(call)
with translate_errors_context():
with TimingContext("mongo", "task_get_all"):
tasks = Task.get_many(
company=company_id,
parameters=call_data,
query_dict=call_data,
allow_public=True,
)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks}
with TimingContext("mongo", "task_get_all"):
ret_params = {}
tasks = Task.get_many(
company=company_id,
parameters=call_data,
query_dict=call_data,
allow_public=True,
ret_params=ret_params,
)
unprepare_from_saved(call, tasks)
call.result.data = {"tasks": tasks, **ret_params}
@endpoint("tasks.get_types", request_data_model=GetTypesRequest)
@@ -403,15 +405,12 @@ def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
escape_dict_field(fields, path)
# Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
for field in task_script_stripped_fields:
try:
path = f"script/{field}"
value = dpath.get(fields, path)
script = fields.get("script")
if script:
for field in task_script_stripped_fields:
value = script.get(field)
if isinstance(value, str):
value = value.strip()
dpath.set(fields, path, value)
except KeyError:
pass
script[field] = value.strip()
return fields
@@ -546,10 +545,12 @@ def clone_task(call: APICall, company_id, request: CloneRequest):
}
def prepare_update_fields(call: APICall, task, call_data):
def prepare_update_fields(call: APICall, call_data):
valid_fields = deepcopy(Task.user_set_allowed())
update_fields = {k: v for k, v in create_fields.items() if k in valid_fields}
update_fields["output__error"] = None
update_fields.update(
status=None, status_reason=None, status_message=None, output__error=None
)
t_fields = task_fields
t_fields.add("output__error")
fields = parse_from_call(call_data, update_fields, t_fields)
@@ -569,7 +570,7 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
if not task:
raise errors.bad_request.InvalidTaskId(id=task_id)
partial_update_dict, valid_fields = prepare_update_fields(call, task, call.data)
partial_update_dict, valid_fields = prepare_update_fields(call, call.data)
if not partial_update_dict:
return UpdateResponse(updated=0)
@@ -642,7 +643,7 @@ def update_batch(call: APICall, company_id, _):
updated_projects = set()
for id, data in items.items():
task = tasks[id]
fields, valid_fields = prepare_update_fields(call, task, data)
fields, valid_fields = prepare_update_fields(call, data)
partial_update_dict = Task.get_safe_update_dict(fields)
if not partial_update_dict:
continue
@@ -744,8 +745,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
"tasks.get_hyper_params", request_data_model=GetHyperParamsRequest,
)
def get_hyper_params(call: APICall, company_id, request: GetHyperParamsRequest):
with translate_errors_context():
tasks_params = HyperParams.get_params(company_id, task_ids=request.tasks)
tasks_params = HyperParams.get_params(company_id, task_ids=request.tasks)
call.result.data = {
"params": [{"task": task, **data} for task, data in tasks_params.items()]
@@ -754,39 +754,36 @@ def get_hyper_params(call: APICall, company_id, request: GetHyperParamsRequest):
@endpoint("tasks.edit_hyper_params", request_data_model=EditHyperParamsRequest)
def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest):
with translate_errors_context():
call.result.data = {
"updated": HyperParams.edit_params(
company_id,
task_id=request.task,
hyperparams=request.hyperparams,
replace_hyperparams=request.replace_hyperparams,
force=request.force,
)
}
call.result.data = {
"updated": HyperParams.edit_params(
company_id,
task_id=request.task,
hyperparams=request.hyperparams,
replace_hyperparams=request.replace_hyperparams,
force=request.force,
)
}
@endpoint("tasks.delete_hyper_params", request_data_model=DeleteHyperParamsRequest)
def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsRequest):
with translate_errors_context():
call.result.data = {
"deleted": HyperParams.delete_params(
company_id,
task_id=request.task,
hyperparams=request.hyperparams,
force=request.force,
)
}
call.result.data = {
"deleted": HyperParams.delete_params(
company_id,
task_id=request.task,
hyperparams=request.hyperparams,
force=request.force,
)
}
@endpoint(
"tasks.get_configurations", request_data_model=GetConfigurationsRequest,
)
def get_configurations(call: APICall, company_id, request: GetConfigurationsRequest):
with translate_errors_context():
tasks_params = HyperParams.get_configurations(
company_id, task_ids=request.tasks, names=request.names
)
tasks_params = HyperParams.get_configurations(
company_id, task_ids=request.tasks, names=request.names
)
call.result.data = {
"configurations": [
@@ -801,10 +798,9 @@ def get_configurations(call: APICall, company_id, request: GetConfigurationsRequ
def get_configuration_names(
call: APICall, company_id, request: GetConfigurationNamesRequest
):
with translate_errors_context():
tasks_params = HyperParams.get_configuration_names(
company_id, task_ids=request.tasks, skip_empty=request.skip_empty
)
tasks_params = HyperParams.get_configuration_names(
company_id, task_ids=request.tasks, skip_empty=request.skip_empty
)
call.result.data = {
"configurations": [
@@ -815,31 +811,29 @@ def get_configuration_names(
@endpoint("tasks.edit_configuration", request_data_model=EditConfigurationRequest)
def edit_configuration(call: APICall, company_id, request: EditConfigurationRequest):
with translate_errors_context():
call.result.data = {
"updated": HyperParams.edit_configuration(
company_id,
task_id=request.task,
configuration=request.configuration,
replace_configuration=request.replace_configuration,
force=request.force,
)
}
call.result.data = {
"updated": HyperParams.edit_configuration(
company_id,
task_id=request.task,
configuration=request.configuration,
replace_configuration=request.replace_configuration,
force=request.force,
)
}
@endpoint("tasks.delete_configuration", request_data_model=DeleteConfigurationRequest)
def delete_configuration(
call: APICall, company_id, request: DeleteConfigurationRequest
):
with translate_errors_context():
call.result.data = {
"deleted": HyperParams.delete_configuration(
company_id,
task_id=request.task,
configuration=request.configuration,
force=request.force,
)
}
call.result.data = {
"deleted": HyperParams.delete_configuration(
company_id,
task_id=request.task,
configuration=request.configuration,
force=request.force,
)
}
@endpoint(
@@ -854,6 +848,7 @@ def enqueue(call: APICall, company_id, request: EnqueueRequest):
queue_id=request.queue,
status_message=request.status_message,
status_reason=request.status_reason,
force=request.force,
)
call.result.data_model = EnqueueResponse(queued=queued, **res)
@@ -1061,6 +1056,8 @@ def delete(call: APICall, company_id, request: DeleteRequest):
force=request.force,
return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
status_message=request.status_message,
status_reason=request.status_reason,
)
if deleted:
_reset_cached_tags(company_id, projects=[task.project] if task.project else [])
@@ -1077,6 +1074,8 @@ def delete_many(call: APICall, company_id, request: DeleteManyRequest):
force=request.force,
return_file_urls=request.return_file_urls,
delete_output_models=request.delete_output_models,
status_message=request.status_message,
status_reason=request.status_reason,
),
ids=request.ids,
)
@@ -1169,15 +1168,14 @@ def ping(_, company_id, request: PingRequest):
def add_or_update_artifacts(
call: APICall, company_id, request: AddOrUpdateArtifactsRequest
):
with translate_errors_context():
call.result.data = {
"updated": Artifacts.add_or_update_artifacts(
company_id=company_id,
task_id=request.task,
artifacts=request.artifacts,
force=request.force,
)
}
call.result.data = {
"updated": Artifacts.add_or_update_artifacts(
company_id=company_id,
task_id=request.task,
artifacts=request.artifacts,
force=True,
)
}
@endpoint(
@@ -1186,31 +1184,28 @@ def add_or_update_artifacts(
request_data_model=DeleteArtifactsRequest,
)
def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest):
with translate_errors_context():
call.result.data = {
"deleted": Artifacts.delete_artifacts(
company_id=company_id,
task_id=request.task,
artifact_ids=request.artifacts,
force=request.force,
)
}
call.result.data = {
"deleted": Artifacts.delete_artifacts(
company_id=company_id,
task_id=request.task,
artifact_ids=request.artifacts,
force=True,
)
}
@endpoint("tasks.make_public", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest):
with translate_errors_context():
call.result.data = Task.set_public(
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=True
)
call.result.data = Task.set_public(
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=True
)
@endpoint("tasks.make_private", min_version="2.9", request_data_model=MakePublicRequest)
def make_public(call: APICall, company_id, request: MakePublicRequest):
with translate_errors_context():
call.result.data = Task.set_public(
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False
)
call.result.data = Task.set_public(
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False
)
@endpoint("tasks.move", request_data_model=MoveRequest)

View File

@@ -71,7 +71,7 @@ class TestService(TestCase, TestServiceInterface):
delete_params=delete_params,
)
def setUp(self, version="1.7"):
def setUp(self, version="999.0"):
self._api = APIClient(base_url=f"http://localhost:8008/v{version}")
self._deferred = []
self._version = parse(version)

View File

@@ -38,7 +38,7 @@ class TestEntityOrdering(TestService):
self._assertGetTasksWithOrdering(order_by=order_field, page=0, page_size=20)
field_vals = []
page_size = 2
page_size = 4
num_pages = 5
for page in range(num_pages):
paged_tasks = self._get_page_tasks(

View File

@@ -2,9 +2,6 @@ from apiserver.tests.automated import TestService
class TestOrganization(TestService):
def setUp(self, version="2.12"):
super().setUp(version=version)
def test_get_user_companies(self):
company = self.api.organization.get_user_companies().companies[0]
self.assertEqual(len(company.owners), company.allocated)

View File

@@ -0,0 +1,80 @@
import math
from apiserver.tests.automated import TestService
class TestPagingAndScrolling(TestService):
name_prefix = f"Test paging "
def setUp(self, **kwargs):
super().setUp(**kwargs)
self.task_ids = self._create_tasks()
def _create_tasks(self):
tasks = [
self._temp_task(
name=f"{self.name_prefix}{i}",
hyperparams={
"test": {
"param": {
"section": "test",
"name": "param",
"type": "str",
"value": str(i),
}
}
},
)
for i in range(18)
]
return tasks
def test_paging(self):
page_size = 10
for page in range(0, math.ceil(len(self.task_ids) / page_size)):
start = page * page_size
expected_size = min(page_size, len(self.task_ids) - start)
tasks = self._get_tasks(page=page, page_size=page_size,).tasks
self.assertEqual(len(tasks), expected_size)
for i, t in enumerate(tasks):
self.assertEqual(t.name, f"{self.name_prefix}{start + i}")
def test_scrolling(self):
page_size = 10
scroll_id = None
for page in range(0, math.ceil(len(self.task_ids) / page_size)):
start = page * page_size
expected_size = min(page_size, len(self.task_ids) - start)
res = self._get_tasks(size=page_size, scroll_id=scroll_id,)
self.assertTrue(res.scroll_id)
scroll_id = res.scroll_id
tasks = res.tasks
self.assertEqual(len(tasks), expected_size)
for i, t in enumerate(tasks):
self.assertEqual(t.name, f"{self.name_prefix}{start + i}")
# no more data in this scroll
tasks = self._get_tasks(size=page_size, scroll_id=scroll_id,).tasks
self.assertFalse(tasks)
# refresh brings all
tasks = self._get_tasks(
size=page_size, scroll_id=scroll_id, refresh_scroll=True,
).tasks
self.assertEqual([t.id for t in tasks], self.task_ids)
def _get_tasks(self, **page_params):
return self.api.tasks.get_all_ex(
name="^Test paging ",
order_by=["hyperparams.test.param.value"],
**page_params,
)
def _temp_task(self, name, **kwargs):
return self.create_temp(
"tasks",
name=name,
comment="Test task",
type="testing",
input=dict(view=dict()),
**kwargs,
)

View File

@@ -0,0 +1,54 @@
from apiserver.apierrors import errors
from apiserver.database.model import EntityVisibility
from apiserver.tests.automated import TestService
from apiserver.database.utils import id as db_id
class TestProjectsDelete(TestService):
def setUp(self, version="2.14"):
super().setUp(version=version)
def new_task(self, **kwargs):
return self.create_temp(
"tasks", type="testing", name=db_id(), input=dict(view=dict()), **kwargs
)
def new_model(self, **kwargs):
return self.create_temp("models", uri="file:///a/b", name=db_id(), labels={}, **kwargs)
def new_project(self, **kwargs):
return self.create_temp("projects", name=db_id(), description="", **kwargs)
def test_delete_fails_with_active_task(self):
project = self.new_project()
self.new_task(project=project)
res = self.api.projects.validate_delete(project=project)
self.assertEqual(res.tasks, 1)
self.assertEqual(res.non_archived_tasks, 1)
with self.api.raises(errors.bad_request.ProjectHasTasks):
self.api.projects.delete(project=project)
def test_delete_with_archived_task(self):
project = self.new_project()
self.new_task(project=project, system_tags=[EntityVisibility.archived.value])
res = self.api.projects.validate_delete(project=project)
self.assertEqual(res.tasks, 1)
self.assertEqual(res.non_archived_tasks, 0)
self.api.projects.delete(project=project)
def test_delete_fails_with_active_model(self):
project = self.new_project()
self.new_model(project=project)
res = self.api.projects.validate_delete(project=project)
self.assertEqual(res.models, 1)
self.assertEqual(res.non_archived_models, 1)
with self.api.raises(errors.bad_request.ProjectHasModels):
self.api.projects.delete(project=project)
def test_delete_with_archived_model(self):
project = self.new_project()
self.new_model(project=project, system_tags=[EntityVisibility.archived.value])
res = self.api.projects.validate_delete(project=project)
self.assertEqual(res.models, 1)
self.assertEqual(res.non_archived_models, 0)
self.api.projects.delete(project=project)

View File

@@ -6,9 +6,6 @@ from apiserver.tests.automated import TestService
class TestQueueAndModelMetadata(TestService):
def setUp(self, version="2.13"):
super().setUp(version=version)
meta1 = [{"key": "test_key", "type": "str", "value": "test_value"}]
def test_queue_metas(self):
@@ -72,3 +69,12 @@ class TestQueueAndModelMetadata(TestService):
return self.create_temp(
"models", uri="file://test", name=name, labels={}, **kwargs
)
def temp_project(self, **kwargs) -> str:
self.update_missing(
kwargs,
name="Test models meta",
description="test",
delete_params=dict(force=True),
)
return self.create_temp("projects", **kwargs)

View File

@@ -12,9 +12,6 @@ from apiserver.tests.automated import TestService
class TestSubProjects(TestService):
def setUp(self, **kwargs):
super().setUp(version="2.13")
def test_project_aggregations(self):
"""This test requires user with user_auth_only... credentials in db"""
user2_client = APIClient(
@@ -203,6 +200,9 @@ class TestSubProjects(TestService):
self.assertEqual(res1.stats["active"]["status_count"]["created"], 0)
self.assertEqual(res1.stats["active"]["status_count"]["stopped"], 2)
self.assertEqual(res1.stats["active"]["total_runtime"], 2)
self.assertEqual(res1.stats["active"]["completed_tasks"], 2)
self.assertEqual(res1.stats["active"]["total_tasks"], 2)
self.assertEqual(res1.stats["active"]["running_tasks"], 0)
self.assertEqual(
{sp.name for sp in res1.sub_projects},
{
@@ -215,6 +215,9 @@ class TestSubProjects(TestService):
self.assertEqual(res2.stats["active"]["status_count"]["created"], 0)
self.assertEqual(res2.stats["active"]["status_count"]["stopped"], 0)
self.assertEqual(res2.stats["active"]["total_runtime"], 0)
self.assertEqual(res2.stats["active"]["completed_tasks"], 0)
self.assertEqual(res2.stats["active"]["total_tasks"], 0)
self.assertEqual(res2.stats["active"]["running_tasks"], 0)
self.assertEqual(res2.sub_projects, [])
def _run_tasks(self, *tasks):

View File

@@ -198,6 +198,9 @@ class TestTags(TestService):
def assertProjectStats(self, project: AttrDict):
self.assertEqual(set(project.stats.keys()), {"active"})
self.assertAlmostEqual(project.stats.active.total_runtime, 1, places=0)
self.assertEqual(project.stats.active.completed_tasks, 1)
self.assertEqual(project.stats.active.total_tasks, 1)
self.assertEqual(project.stats.active.running_tasks, 0)
for status, count in project.stats.active.status_count.items():
self.assertEqual(count, 1 if status == "stopped" else 0)

View File

@@ -82,14 +82,10 @@ class TestTasksArtifacts(TestService):
# test edit running task
self.api.tasks.started(task=task)
with self.api.raises(InvalidTaskStatus):
self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit)
self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit, force=True)
self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit)
res = self.api.tasks.get_all_ex(id=[task]).tasks[0]
self._assertTaskArtifacts(artifacts, res)
with self.api.raises(InvalidTaskStatus):
self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": artifacts[-1]["key"]}])
self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": artifacts[-1]["key"]}], force=True)
self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": artifacts[-1]["key"]}])
res = self.api.tasks.get_all_ex(id=[task]).tasks[0]
self._assertTaskArtifacts(artifacts[0: len(artifacts) - 1], res)

View File

@@ -12,9 +12,6 @@ from apiserver.tests.automated import TestService
class TestTaskEvents(TestService):
def setUp(self, version="2.9"):
super().setUp(version=version)
def _temp_task(self, name="test task events"):
task_input = dict(
name=name, type="training", input=dict(mapping={}, view=dict(entries=[])),
@@ -257,6 +254,45 @@ class TestTaskEvents(TestService):
self.assertEqual(len(task_data["x"]), iterations)
self.assertEqual(len(task_data["y"]), iterations)
def test_task_metric_raw(self):
metric = "Metric1"
variant = "Variant1"
iter_count = 100
task = self._temp_task()
events = [
{
**self._create_task_event("training_stats_scalar", task, iteration),
"metric": metric,
"variant": variant,
"value": iteration,
}
for iteration in range(iter_count)
]
self.send_batch(events)
batch_size = 15
metric_param = {"metric": metric, "variants": [variant]}
res = self.api.events.scalar_metrics_iter_raw(
task=task, batch_size=batch_size, metric=metric_param, count_total=True
)
self.assertEqual(res.total, len(events))
self.assertTrue(res.scroll_id)
res_iters = []
res_ys = []
calls = 0
while res.returned or calls > 10:
calls += 1
res_iters.extend(res.variants[variant]["iter"])
res_ys.extend(res.variants[variant]["y"])
scroll_id = res.scroll_id
res = self.api.events.scalar_metrics_iter_raw(
task=task, metric=metric_param, scroll_id=scroll_id
)
self.assertEqual(calls, len(events) // batch_size + 1)
self.assertEqual(res_iters, [ev["iter"] for ev in events])
self.assertEqual(res_ys, [ev["value"] for ev in events])
def test_task_metric_value_intervals(self):
metric = "Metric1"
variant = "Variant1"

View File

@@ -1,4 +1,5 @@
import os
from datetime import timedelta, datetime
from threading import Thread
from time import sleep
from typing import Optional
@@ -10,6 +11,7 @@ from semantic_version import Version
from apiserver.config_repo import config
from apiserver.config.info import get_version
from apiserver.database.model.settings import Settings
from apiserver.redis_manager import redman
from apiserver.utilities.threads_manager import ThreadsManager
log = config.logger(__name__)
@@ -17,6 +19,8 @@ log = config.logger(__name__)
class CheckUpdatesThread(Thread):
_enabled = bool(config.get("apiserver.check_for_updates.enabled", True))
_lock_name = "check_updates"
_redis = redman.connection("apiserver")
@attr.s(auto_attribs=True)
class _VersionResponse:
@@ -29,6 +33,19 @@ class CheckUpdatesThread(Thread):
target=self._check_updates, daemon=True
)
@property
def update_interval(self):
return timedelta(
seconds=max(
float(
config.get(
"apiserver.check_for_updates.check_interval_sec", 60 * 60 * 24,
)
),
60 * 5,
)
)
def start(self) -> None:
if not self._enabled:
log.info("Checking for updates is disabled")
@@ -37,12 +54,13 @@ class CheckUpdatesThread(Thread):
@property
def component_name(self) -> str:
return config.get("apiserver.check_for_updates.component_name", "clearml-server")
return config.get(
"apiserver.check_for_updates.component_name", "clearml-server"
)
def _check_new_version_available(self) -> Optional[_VersionResponse]:
url = config.get(
"apiserver.check_for_updates.url",
"https://updates.clear.ml/updates",
"apiserver.check_for_updates.url", "https://updates.clear.ml/updates",
)
uid = Settings.get_by_key("server.uuid")
@@ -81,34 +99,31 @@ class CheckUpdatesThread(Thread):
)
def _check_updates(self):
update_interval_sec = max(
float(
config.get(
"apiserver.check_for_updates.check_interval_sec",
60 * 60 * 24,
)
),
60 * 5,
)
while not ThreadsManager.terminating:
# noinspection PyBroadException
try:
response = self._check_new_version_available()
if response:
if response.patch_upgrade:
log.info(
f"{self.component_name.upper()} new package available: upgrade to v{response.version} "
f"is recommended!\nRelease Notes:\n{os.linesep.join(response.description)}"
)
else:
log.info(
f"{self.component_name.upper()} new version available: upgrade to v{response.version}"
f" is recommended!"
)
except Exception:
log.exception("Failed obtaining updates")
if self._redis.set(
self._lock_name,
value=datetime.utcnow().isoformat(),
ex=self.update_interval - timedelta(seconds=60),
nx=True,
):
response = self._check_new_version_available()
if response:
if response.patch_upgrade:
log.info(
f"{self.component_name.upper()} new package available: upgrade to v{response.version} "
f"is recommended!\nRelease Notes:\n{os.linesep.join(response.description)}"
)
else:
log.info(
f"{self.component_name.upper()} new version available: upgrade to v{response.version}"
f" is recommended!"
)
except Exception as ex:
log.exception("Failed obtaining updates: " + str(ex))
sleep(update_interval_sec)
sleep(self.update_interval.total_seconds())
check_updates_thread = CheckUpdatesThread()

View File

@@ -10,6 +10,7 @@ def extract_properties_to_lists(
key_names: Sequence[str],
data: Sequence[dict],
extract_func: Optional[Callable[[dict], Tuple]] = None,
target_keys: Optional[Sequence[str]] = None,
) -> dict:
"""
Given a list of dictionaries and names of dictionary keys
@@ -20,9 +21,10 @@ def extract_properties_to_lists(
:param extract_func: the optional callable that extracts properties
from a dictionary and put them in a tuple in the order corresponding to
key_names. If not specified then properties are extracted according to key_names
:param target_keys: optional alternative keys to use in the target dictionary. must be equal in length to key_names.
"""
if not data:
return {k: [] for k in key_names}
value_sequences = zip(*map(extract_func or itemgetter(*key_names), data))
return dict(zip(key_names, map(list, value_sequences)))
return dict(zip((target_keys or key_names), map(list, value_sequences)))

View File

@@ -37,15 +37,12 @@ def deep_merge(source: dict, override: dict) -> dict:
def nested_get(
dictionary: Mapping,
path: Union[Sequence[str], str],
path: Sequence[str],
default: Optional[Union[Any, Callable]] = None,
) -> Any:
if isinstance(path, str):
path = [path]
node = dictionary
for key in path:
if key not in node:
if not node or key not in node:
if callable(default):
return default()
return default

View File

@@ -1 +1 @@
__version__ = "1.0.2"
__version__ = "1.2.0"

35
docker/build/Dockerfile Normal file
View File

@@ -0,0 +1,35 @@
FROM centos/nodejs-12-centos7 AS webapp
USER root
WORKDIR /opt
RUN git clone https://github.com/allegroai/clearml-web.git
RUN mv clearml-web /opt/open-webapp
COPY --chmod=744 docker/build/internal_files/build_webapp.sh /tmp/internal_files/
RUN /bin/bash -c '/tmp/internal_files/build_webapp.sh'
FROM centos:7 AS staging_image
COPY --chmod=744 docker/build/internal_files/entrypoint.sh /opt/clearml/
COPY fileserver /opt/clearml/fileserver/
COPY apiserver /opt/clearml/apiserver/
FROM centos:7
COPY --from=staging_image /opt/clearml/ /opt/clearml/
COPY --chmod=744 docker/build/internal_files/final_image_preparation.sh /tmp/internal_files/
COPY docker/build/internal_files/clearml.conf.template /tmp/internal_files/
RUN /bin/bash -c '/tmp/internal_files/final_image_preparation.sh'
COPY --from=webapp /opt/open-webapp/build /usr/share/nginx/html
EXPOSE 8080
EXPOSE 8008
EXPOSE 8081
ARG VERSION
ARG BUILD
ENV CLEARML_SERVER_VERSION=${VERSION}
ENV CLEARML_SERVER_BUILD=${BUILD}
WORKDIR /opt/clearml/
ENTRYPOINT ["/opt/clearml/entrypoint.sh"]

View File

@@ -0,0 +1,9 @@
#!/usr/bin/env bash
set -x
set -e
cd /opt/open-webapp/
npm ci --unsafe-perm node-sass
cd /opt/open-webapp/
npm run build

View File

@@ -0,0 +1,97 @@
# For more information on configuration, see:
# * Official English Documentation: http://nginx.org/en/docs/
# * Official Russian Documentation: http://nginx.org/ru/docs/
user nginx;
worker_processes auto;
error_log /var/log/nginx/error.log;
pid /run/nginx.pid;
# Load dynamic modules. See /usr/share/doc/nginx/README.dynamic.
include /usr/share/nginx/modules/*.conf;
events {
worker_connections 1024;
}
http {
log_format main '$remote_addr - $remote_user [$time_local] "$request" '
'$status $body_bytes_sent "$http_referer" '
'"$http_user_agent" "$http_x_forwarded_for"';
access_log /var/log/nginx/access.log main;
sendfile on;
tcp_nopush on;
tcp_nodelay on;
keepalive_timeout 65;
types_hash_max_size 2048;
include /etc/nginx/mime.types;
default_type application/octet-stream;
# Load modular configuration files from the /etc/nginx/conf.d directory.
# See http://nginx.org/en/docs/ngx_core_module.html#include
# for more information.
include /etc/nginx/conf.d/*.conf;
server {
listen 80 default_server;
listen [::]:80 default_server;
server_name _;
root /usr/share/nginx/html;
proxy_http_version 1.1;
# comppression
gzip on;
gzip_comp_level 9;
gzip_http_version 1.0;
gzip_min_length 512;
gzip_proxied expired no-cache no-store private auth;
gzip_types text/plain
text/css
application/json
application/javascript
application/x-javascript
text/xml application/xml
application/xml+rss
text/javascript
application/x-font-ttf
font/woff2
image/svg+xml
image/x-icon;
# Load configuration files for the default server block.
include /etc/nginx/default.d/*.conf;
location / {
try_files $uri$args $uri$args/ $uri index.html /index.html;
}
location /version.json {
add_header Cache-Control 'no-cache';
}
location /api {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header Host $host;
proxy_pass ${NGINX_APISERVER_ADDR};
rewrite /api/(.*) /$1 break;
}
location /files {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header Host $host;
proxy_pass ${NGINX_FILESERVER_ADDR};
rewrite /files/(.*) /$1 break;
}
error_page 404 /404.html;
location = /40x.html {
}
error_page 500 502 503 504 /50x.html;
location = /50x.html {
}
}
}

View File

@@ -0,0 +1,67 @@
#!/usr/bin/env bash
set -e
mkdir -p /var/log/clearml
SERVER_TYPE=$1
if (( $# < 1 )) ; then
echo "The server type was not stated. It should be either apiserver, webserver or fileserver."
sleep 60
exit 1
elif [[ ${SERVER_TYPE} == "apiserver" ]]; then
cd /opt/clearml/
python3 -m apiserver.apierrors_generator
if [[ -n $CLEARML_USE_GUNICORN ]]; then
MAX_REQUESTS=
if [[ -n $CLEARML_GUNICORN_MAX_REQUESTS ]]; then
MAX_REQUESTS="--max-requests $CLEARML_GUNICORN_MAX_REQUESTS"
if [[ -n $CLEARML_GUNICORN_MAX_REQUESTS_JITTER ]]; then
MAX_REQUESTS="$MAX_REQUESTS --max-requests-jitter $CLEARML_GUNICORN_MAX_REQUESTS_JITTER"
fi
fi
export GUNICORN_CMD_ARGS=${CLEARML_GUNICORN_CMD_ARGS}
# Note: don't be tempted to "fix" $MAX_REQUESTS with "$MAX_REQUESTS" as this produces an empty arg which fucks up gunicorn
gunicorn \
-w "${CLEARML_GUNICORN_WORKERS:-8}" \
-t "${CLEARML_GUNICORN_TIMEOUT:-600}" --bind="${CLEARML_GUNICORN_BIND:-0.0.0.0:8008}" \
$MAX_REQUESTS apiserver.server:app
else
python3 -m apiserver.server
fi
elif [[ ${SERVER_TYPE} == "webserver" ]]; then
if [[ "${USER_KEY}" != "" ]] || [[ "${USER_SECRET}" != "" ]] || [[ "${COMPANY_ID}" != "" ]]; then
cat << EOF > /usr/share/nginx/html/credentials.json
{
"userKey": "${USER_KEY}",
"userSecret": "${USER_SECRET}",
"companyID": "${COMPANY_ID}"
}
EOF
fi
export NGINX_APISERVER_ADDR=${NGINX_APISERVER_ADDRESS:-http://apiserver:8008}
export NGINX_FILESERVER_ADDR=${NGINX_FILESERVER_ADDRESS:-http://fileserver:8081}
envsubst '${NGINX_APISERVER_ADDR} ${NGINX_FILESERVER_ADDR}' < /etc/nginx/clearml.conf.template > /etc/nginx/nginx.conf
#start the server
/usr/sbin/nginx -g "daemon off;"
elif [[ ${SERVER_TYPE} == "fileserver" ]]; then
cd /opt/clearml/fileserver/
if [ "$FILESERVER_USE_GUNICORN" = true ] ; then
gunicorn -t 600 --bind=0.0.0.0:8081 fileserver:app
else
python3 fileserver.py
fi
else
echo "Server type ${SERVER_TYPE} is invalid. Please choose either apiserver, webserver or fileserver."
fi

View File

@@ -0,0 +1,18 @@
#!/usr/bin/env bash
set -o errexit
set -o nounset
set -o pipefail
yum update -y
yum install -y https://dl.fedoraproject.org/pub/epel/epel-release-latest-7.noarch.rpm
yum install -y python36 python36-pip nginx gcc python3-devel gettext
yum -y upgrade
python3 -m pip install -r /opt/clearml/fileserver/requirements.txt
python3 -m pip install -r /opt/clearml/apiserver/requirements.txt
mkdir -p /opt/clearml/log
mkdir -p /opt/clearml/config
ln -s /dev/stdout /var/log/nginx/access.log
ln -s /dev/stderr /var/log/nginx/error.log
mv /etc/nginx/nginx.conf /etc/nginx/nginx.conf.orig
mv /tmp/internal_files/clearml.conf.template /etc/nginx/clearml.conf.template
yum clean all

View File

@@ -19,6 +19,7 @@ services:
environment:
CLEARML_ELASTIC_SERVICE_HOST: elasticsearch
CLEARML_ELASTIC_SERVICE_PORT: 9200
CLEARML_ELASTIC_SERVICE_PASSWORD: ${ELASTIC_PASSWORD}
CLEARML_MONGODB_SERVICE_HOST: mongo
CLEARML_MONGODB_SERVICE_PORT: 27017
CLEARML_REDIS_SERVICE_HOST: redis
@@ -38,7 +39,8 @@ services:
- backend
container_name: clearml-elastic
environment:
ES_JAVA_OPTS: -Xms2g -Xmx2g
ES_JAVA_OPTS: -Xms2g -Xmx2g -Dlog4j2.formatMsgNoLookups=true
ELASTIC_PASSWORD: ${ELASTIC_PASSWORD}
bootstrap.memory_lock: "true"
cluster.name: clearml
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
@@ -60,7 +62,7 @@ services:
nofile:
soft: 65536
hard: 65536
image: docker.elastic.co/elasticsearch/elasticsearch:7.6.2
image: docker.elastic.co/elasticsearch/elasticsearch:7.16.2
restart: unless-stopped
volumes:
- c:/opt/clearml/data/elastic_7:/usr/share/elasticsearch/data
@@ -87,7 +89,7 @@ services:
networks:
- backend
container_name: clearml-mongo
image: mongo:3.6.5
image: mongo:3.6.23
restart: unless-stopped
command: --setParameter internalQueryExecMaxBlockingSortBytes=196100200
volumes:

View File

@@ -19,6 +19,7 @@ services:
environment:
CLEARML_ELASTIC_SERVICE_HOST: elasticsearch
CLEARML_ELASTIC_SERVICE_PORT: 9200
CLEARML_ELASTIC_SERVICE_PASSWORD: ${ELASTIC_PASSWORD}
CLEARML_MONGODB_SERVICE_HOST: mongo
CLEARML_MONGODB_SERVICE_PORT: 27017
CLEARML_REDIS_SERVICE_HOST: redis
@@ -38,7 +39,8 @@ services:
- backend
container_name: clearml-elastic
environment:
ES_JAVA_OPTS: -Xms2g -Xmx2g
ES_JAVA_OPTS: -Xms2g -Xmx2g -Dlog4j2.formatMsgNoLookups=true
ELASTIC_PASSWORD: ${ELASTIC_PASSWORD}
bootstrap.memory_lock: "true"
cluster.name: clearml
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
@@ -60,7 +62,7 @@ services:
nofile:
soft: 65536
hard: 65536
image: docker.elastic.co/elasticsearch/elasticsearch:7.6.2
image: docker.elastic.co/elasticsearch/elasticsearch:7.16.2
restart: unless-stopped
volumes:
- /opt/clearml/data/elastic_7:/usr/share/elasticsearch/data
@@ -86,7 +88,7 @@ services:
networks:
- backend
container_name: clearml-mongo
image: mongo:3.6.5
image: mongo:3.6.23
restart: unless-stopped
command: --setParameter internalQueryExecMaxBlockingSortBytes=196100200
volumes:
@@ -121,7 +123,9 @@ services:
- backend
container_name: clearml-agent-services
image: allegroai/clearml-agent-services:latest
restart: unless-stopped
deploy:
restart_policy:
condition: on-failure
privileged: true
environment:
CLEARML_HOST_IP: ${CLEARML_HOST_IP}
@@ -132,7 +136,7 @@ services:
CLEARML_API_SECRET_KEY: ${CLEARML_API_SECRET_KEY:-}
CLEARML_AGENT_GIT_USER: ${CLEARML_AGENT_GIT_USER}
CLEARML_AGENT_GIT_PASS: ${CLEARML_AGENT_GIT_PASS}
CLEARML_AGENT_UPDATE_VERSION: ${CLEARML_AGENT_UPDATE_VERSION:->=0.17.0}
CLEARML_AGENT_UPDATE_VERSION: ${CLEARML_AGENT_UPDATE_VERSION:-">=0.17.0"}
CLEARML_AGENT_DEFAULT_BASE_DOCKER: "ubuntu:18.04"
AWS_ACCESS_KEY_ID: ${AWS_ACCESS_KEY_ID:-}
AWS_SECRET_ACCESS_KEY: ${AWS_SECRET_ACCESS_KEY:-}
@@ -142,6 +146,7 @@ services:
GOOGLE_APPLICATION_CREDENTIALS: ${GOOGLE_APPLICATION_CREDENTIALS:-}
CLEARML_WORKER_ID: "clearml-services"
CLEARML_AGENT_DOCKER_HOST_MOUNT: "/opt/clearml/agent:/root/.clearml"
SHUTDOWN_IF_NO_ACCESS_KEY: 1
volumes:
- /var/run/docker.sock:/var/run/docker.sock
- /opt/clearml/agent:/root/.clearml

Binary file not shown.

After

Width:  |  Height:  |  Size: 155 KiB

View File

@@ -1,5 +1,6 @@
""" A Simple file server for uploading and downloading files """
import json
import mimetypes
import os
from argparse import ArgumentParser
from pathlib import Path
@@ -48,8 +49,12 @@ def upload():
@app.route("/<path:path>", methods=["GET"])
def download(path):
as_attachment = "download" in request.args
_, encoding = mimetypes.guess_type(os.path.basename(path))
mimetype = "application/octet-stream" if encoding == "gzip" else None
response = send_from_directory(
app.config["UPLOAD_FOLDER"], path, as_attachment=as_attachment
app.config["UPLOAD_FOLDER"], path, as_attachment=as_attachment, mimetype=mimetype
)
if config.get("fileserver.download.disable_browser_caching", False):
headers = response.headers