mirror of
https://github.com/clearml/clearml-server
synced 2025-06-26 23:15:47 +00:00
Compare commits
65 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fc4fd9e61c | ||
|
|
8908c7dcf9 | ||
|
|
b9996e2c1a | ||
|
|
afdc56f37c | ||
|
|
a25cd5dae8 | ||
|
|
447adb9090 | ||
|
|
92fd98d5ad | ||
|
|
c4001b4037 | ||
|
|
970a32287a | ||
|
|
17cd48dada | ||
|
|
ea3b6e955f | ||
|
|
843450bb9b | ||
|
|
e149af58b1 | ||
|
|
604a38035b | ||
|
|
cae38a365b | ||
|
|
e334246b46 | ||
|
|
36e013b40c | ||
|
|
f20cd6536e | ||
|
|
446bd35006 | ||
|
|
a377a7e315 | ||
|
|
3d046ac282 | ||
|
|
a08fa9a0e1 | ||
|
|
5856ed2836 | ||
|
|
d295355d99 | ||
|
|
77350f6119 | ||
|
|
bc2c2ebbfd | ||
|
|
1502e02a1a | ||
|
|
d0e2313a24 | ||
|
|
d8ba1a8ea7 | ||
|
|
ca7937fc4e | ||
|
|
df89bcceef | ||
|
|
cfccbe05c1 | ||
|
|
e352a6a1e7 | ||
|
|
8a3d992aaf | ||
|
|
c37f3d8d5b | ||
|
|
a96870e092 | ||
|
|
6bf1032237 | ||
|
|
3d816c747d | ||
|
|
3f2b96266b | ||
|
|
22b16d12eb | ||
|
|
c55b6f30df | ||
|
|
b7045d3d28 | ||
|
|
e31a404885 | ||
|
|
643588b71a | ||
|
|
a64c4d264d | ||
|
|
567780e188 | ||
|
|
1bc8529d83 | ||
|
|
6b480d7e87 | ||
|
|
083fd315e9 | ||
|
|
ef20e76174 | ||
|
|
8c8910808e | ||
|
|
f6ad379310 | ||
|
|
c5d6ce3e65 | ||
|
|
694dbc31c4 | ||
|
|
6488dc54e6 | ||
|
|
158da9b480 | ||
|
|
ec2e071ab7 | ||
|
|
465e270342 | ||
|
|
6705aff56f | ||
|
|
9069cfe1da | ||
|
|
677bb3ba6d | ||
|
|
cb253cff9e | ||
|
|
39ceb5ac5c | ||
|
|
d4edeaaf1b | ||
|
|
56aea1ffb8 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -12,7 +12,6 @@ test-reports
|
||||
.pytest_cache
|
||||
venv
|
||||
*.noseids
|
||||
build
|
||||
*.egg-info
|
||||
.cache
|
||||
.mypy_cache
|
||||
|
||||
62
README.md
62
README.md
@@ -8,28 +8,43 @@
|
||||
[](https://img.shields.io/badge/license-SSPL-green.svg)
|
||||
[](https://img.shields.io/badge/python-3.6%20%7C%203.7-blue.svg)
|
||||
[](https://img.shields.io/github/release-pre/allegroai/trains-server.svg)
|
||||
[](https://artifacthub.io/packages/search?repo=allegroai)
|
||||
|
||||
</div>
|
||||
|
||||
---
|
||||
<div align="center">
|
||||
|
||||
**v0.16 Upgrade Notice**
|
||||
**Note regarding Apache Log4j2 Remote Code Execution (RCE) Vulnerability - CVE-2021-44228 - ESA-2021-31**
|
||||
|
||||
</div>
|
||||
|
||||
In v0.16, the Elasticsearch subsystem of ClearML Server has been upgraded from version 5.6 to version 7.6. This change necessitates the migration of the database contents to accommodate the change in index structure across the different versions.
|
||||
According to [ElasticSearch's latest report](https://discuss.elastic.co/t/apache-log4j2-remote-code-execution-rce-vulnerability-cve-2021-44228-esa-2021-31/291476),
|
||||
supported versions of Elasticsearch (6.8.9+, 7.8+) used with recent versions of the JDK (JDK9+) **are not susceptible to either remote code execution or information leakage**
|
||||
due to Elasticsearch’s usage of the Java Security Manager.
|
||||
|
||||
Follow [this procedure](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_es7_migration.html) to migrate existing data.
|
||||
**As the latest version of ClearML Server uses Elasticsearch 7.10+ with JDK15, it is not affected by these vulnerabilities.**
|
||||
|
||||
As a precaution, we've upgraded the ES version to 7.16.2 and added the mitigation recommended by ElasticSearch to our latest [docker-compose.yml](https://github.com/allegroai/clearml-server/blob/cfccbe05c158b75e520581f86e9668291da5c70a/docker/docker-compose.yml#L42) file.
|
||||
|
||||
While previous Elasticsearch versions (5.6.11+, 6.4.0+ and 7.0.0+) used by older ClearML Server versions are only susceptible to the information leakage vulnerability
|
||||
(which in any case **does not permit access to data within the Elasticsearch cluster**),
|
||||
we still recommend upgrading to the latest version of ClearML Server. Alternatively, you can apply the mitigation as implemented in our latest
|
||||
[docker-compose.yml](https://github.com/allegroai/clearml-server/blob/cfccbe05c158b75e520581f86e9668291da5c70a/docker/docker-compose.yml#L42) file.
|
||||
|
||||
**Update 15 December**: A further vulnerability (CVE-2021-45046) was disclosed on December 14th.
|
||||
ElasticSearch's guidance for Elasticsearch remains unchanged by this new vulnerability, thus **not affecting ClearML Server**.
|
||||
|
||||
**Update 22 December**: To keep with ElasticSearch's recommendations, we've upgraded the ES version to the newly released 7.16.2
|
||||
|
||||
---
|
||||
|
||||
### ClearML Server
|
||||
## ClearML Server
|
||||
#### *Formerly known as Trains Server*
|
||||
|
||||
The **ClearML Server** is the backend service infrastructure for [ClearML](https://github.com/allegroai/clearml).
|
||||
It allows multiple users to collaborate and manage their experiments.
|
||||
By default, **ClearML** is set up to work with the **ClearML** demo server, which is open to anyone and resets periodically.
|
||||
**ClearML** offers a [free hosted service](https://app.clear.ml/), which is maintained by **ClearML** and open to anyone.
|
||||
In order to host your own server, you will need to launch the **ClearML Server** and point **ClearML** to it.
|
||||
|
||||
The **ClearML Server** contains the following components:
|
||||
@@ -45,7 +60,7 @@ You can quickly [deploy](#launching-the-clearml-server) your **ClearML Server**
|
||||
## System design
|
||||
|
||||
|
||||

|
||||

|
||||
|
||||
The **ClearML Server** has two supported configurations:
|
||||
- Single IP (domain) with the following open ports
|
||||
@@ -78,20 +93,19 @@ For example, to see if port `8080` is in use:
|
||||
|
||||
Launch The **ClearML Server** in any of the following formats:
|
||||
|
||||
- Pre-built [AWS EC2 AMI](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_aws_ec2_ami.html)
|
||||
- Pre-built [GCP Custom Image](hhttps://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_gcp.html)
|
||||
- Pre-built [AWS EC2 AMI](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_aws_ec2_ami)
|
||||
- Pre-built [GCP Custom Image](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_gcp)
|
||||
- Pre-built Docker Image
|
||||
- [Linux](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_linux_mac.html)
|
||||
- [macOS](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_linux_mac.html)
|
||||
- [Windows 10](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_win.html)
|
||||
- [Linux](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_linux_mac)
|
||||
- [macOS](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_linux_mac)
|
||||
- [Windows 10](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_win)
|
||||
- Kubernetes
|
||||
- [Kubernetes Helm](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_kubernetes_helm.html)
|
||||
- Manual [Kubernetes installation](https://allegro.ai/clearml/docs/docs/deploying_clearml/clearml_server_kubernetes.html)
|
||||
- [Kubernetes Helm](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_kubernetes_helm)
|
||||
- Manual [Kubernetes installation](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_kubernetes)
|
||||
|
||||
## Connecting ClearML to your ClearML Server
|
||||
|
||||
By default, the **ClearML** client is set up to work with the [**ClearML** demo server](https://demoapp.demo.clear.ml/).
|
||||
To have the **ClearML** client use your **ClearML Server** instead:
|
||||
In order to set up the **ClearML** client to work with your **ClearML Server**:
|
||||
- Run the `clearml-init` command for an interactive setup.
|
||||
- Or manually edit `~/clearml.conf` file, making sure the server settings (`api_server`, `web_server`, `file_server`) are configured correctly, for example:
|
||||
|
||||
@@ -138,8 +152,8 @@ Do not enqueue training / inference tasks into the `services` queue, as it will
|
||||
|
||||
The **ClearML Server** provides a few additional useful features, which can be manually enabled:
|
||||
|
||||
* [Web login authentication](https://allegro.ai/clearml/docs/deploying_clearml/clearml_server_config/#web-login-authentication)
|
||||
* [Non-responsive experiments watchdog](https://allegro.ai/clearml/docs/deploying_clearml/clearml_server_config/#task_watchdog)
|
||||
* [Web login authentication](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_config#web-login-authentication)
|
||||
* [Non-responsive experiments watchdog](https://clear.ml/docs/latest/docs/deploying_clearml/clearml_server_config#non-responsive-task-watchdog)
|
||||
|
||||
## Restarting ClearML Server
|
||||
|
||||
@@ -189,14 +203,14 @@ To upgrade your existing **ClearML Server** deployment:
|
||||
```
|
||||
|
||||
1. Configure the ClearML-Agent Services (not supported on Windows installation).
|
||||
If `TRAINS_HOST_IP` is not provided, ClearML-Agent Services will use the external
|
||||
public address of the **ClearML Server**. If `TRAINS_AGENT_GIT_USER` / `TRAINS_AGENT_GIT_PASS` are not provided,
|
||||
If `CLEARML_HOST_IP` is not provided, ClearML-Agent Services will use the external
|
||||
public address of the **ClearML Server**. If `CLEARML_AGENT_GIT_USER` / `CLEARML_AGENT_GIT_PASS` are not provided,
|
||||
the ClearML-Agent Services will not be able to access any private repositories for running service tasks.
|
||||
|
||||
```bash
|
||||
export TRAINS_HOST_IP=server_host_ip_here
|
||||
export TRAINS_AGENT_GIT_USER=git_username_here
|
||||
export TRAINS_AGENT_GIT_PASS=git_password_here
|
||||
export CLEARML_HOST_IP=server_host_ip_here
|
||||
export CLEARML_AGENT_GIT_USER=git_username_here
|
||||
export CLEARML_AGENT_GIT_PASS=git_password_here
|
||||
```
|
||||
|
||||
1. Spin up the docker containers, it will automatically pull the latest **ClearML Server** build
|
||||
@@ -205,12 +219,12 @@ To upgrade your existing **ClearML Server** deployment:
|
||||
docker-compose -f docker-compose.yml up
|
||||
```
|
||||
|
||||
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://allegro.ai/clearml/docs/docs/faq/faq.html).**
|
||||
**\* If something went wrong along the way, check our FAQ: [Common Docker Upgrade Errors](https://clear.ml/docs/latest/docs/faq/).**
|
||||
|
||||
|
||||
## Community & Support
|
||||
|
||||
If you have any questions, look to the ClearML [FAQ](https://allegro.ai/clearml/docs/docs/faq/faq.html), or
|
||||
If you have any questions, look to the ClearML [FAQ](https://clear.ml/docs/latest/docs/faq), or
|
||||
tag your questions on [stackoverflow](https://stackoverflow.com/questions/tagged/clearml) with '**clearml**' tag.
|
||||
|
||||
For feature requests or bug reports, please use [GitHub issues](https://github.com/allegroai/clearml-server/issues).
|
||||
|
||||
@@ -71,6 +71,8 @@
|
||||
408: ["cannot_update_project_location", "Cannot update project location. Use projects.move instead"]
|
||||
409: ["project_path_exceeds_max", "Project path exceed the maximum allowed depth"]
|
||||
410: ["project_source_and_destination_are_the_same", "Project has the same source and destination paths"]
|
||||
411: ["project_cannot_be_moved_under_itself", "Project can not be moved under itself in the projects hierarchy"]
|
||||
412: ["project_cannot_be_merged_into_its_child", "Project can not be merged into its own child"]
|
||||
|
||||
# Queues
|
||||
701: ["invalid_queue_id", "invalid queue id"]
|
||||
|
||||
@@ -75,6 +75,7 @@ class CreateUserResponse(Base):
|
||||
class Credentials(Base):
|
||||
access_key = StringField(required=True)
|
||||
secret_key = StringField(required=True)
|
||||
label = StringField()
|
||||
|
||||
|
||||
class CredentialsResponse(Credentials):
|
||||
@@ -82,6 +83,10 @@ class CredentialsResponse(Credentials):
|
||||
last_used = DateTimeField(default=None)
|
||||
|
||||
|
||||
class CreateCredentialsRequest(Base):
|
||||
label = StringField()
|
||||
|
||||
|
||||
class CreateCredentialsResponse(Base):
|
||||
credentials = EmbeddedField(Credentials)
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from enum import auto
|
||||
from typing import Sequence, Optional
|
||||
|
||||
from jsonmodels import validators
|
||||
from jsonmodels.fields import StringField, BoolField
|
||||
from jsonmodels.fields import StringField, BoolField, EmbeddedField
|
||||
from jsonmodels.models import Base
|
||||
from jsonmodels.validators import Length, Min, Max
|
||||
|
||||
@@ -14,12 +14,18 @@ from apiserver.utilities.stringenum import StringEnum
|
||||
|
||||
|
||||
class HistogramRequestBase(Base):
|
||||
samples: int = IntField(default=6000, validators=[Min(1), Max(6000)])
|
||||
samples: int = IntField(default=2000, validators=[Min(1), Max(6000)])
|
||||
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
|
||||
|
||||
|
||||
class MetricVariants(Base):
|
||||
metric: str = StringField(required=True)
|
||||
variants: Sequence[str] = ListField(items_types=str)
|
||||
|
||||
|
||||
class ScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
task: str = StringField(required=True)
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
|
||||
|
||||
class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
@@ -39,6 +45,7 @@ class MultiTaskScalarMetricsIterHistogramRequest(HistogramRequestBase):
|
||||
class TaskMetric(Base):
|
||||
task: str = StringField(required=True)
|
||||
metric: str = StringField(default=None)
|
||||
variants: Sequence[str] = ListField(items_types=str)
|
||||
|
||||
|
||||
class DebugImagesRequest(Base):
|
||||
@@ -59,8 +66,8 @@ class TaskMetricVariant(Base):
|
||||
|
||||
class GetDebugImageSampleRequest(TaskMetricVariant):
|
||||
iteration: Optional[int] = IntField()
|
||||
scroll_id: Optional[str] = StringField()
|
||||
refresh: bool = BoolField(default=False)
|
||||
scroll_id: Optional[str] = StringField()
|
||||
|
||||
|
||||
class NextDebugImageSampleRequest(Base):
|
||||
@@ -74,14 +81,34 @@ class LogOrderEnum(StringEnum):
|
||||
desc = auto()
|
||||
|
||||
|
||||
class LogEventsRequest(Base):
|
||||
class TaskEventsRequestBase(Base):
|
||||
task: str = StringField(required=True)
|
||||
batch_size: int = IntField(default=500)
|
||||
|
||||
|
||||
class TaskEventsRequest(TaskEventsRequestBase):
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
event_type: EventType = ActualEnumField(EventType, default=EventType.all)
|
||||
order: Optional[str] = ActualEnumField(LogOrderEnum, default=LogOrderEnum.asc)
|
||||
scroll_id: str = StringField()
|
||||
count_total: bool = BoolField(default=True)
|
||||
|
||||
|
||||
class LogEventsRequest(TaskEventsRequestBase):
|
||||
batch_size: int = IntField(default=5000)
|
||||
navigate_earlier: bool = BoolField(default=True)
|
||||
from_timestamp: Optional[int] = IntField()
|
||||
order: Optional[str] = ActualEnumField(LogOrderEnum)
|
||||
|
||||
|
||||
class ScalarMetricsIterRawRequest(TaskEventsRequestBase):
|
||||
batch_size: int = IntField()
|
||||
key: ScalarKeyEnum = ActualEnumField(ScalarKeyEnum, default=ScalarKeyEnum.iter)
|
||||
metric: MetricVariants = EmbeddedField(MetricVariants, required=True)
|
||||
count_total: bool = BoolField(default=False)
|
||||
scroll_id: str = StringField()
|
||||
|
||||
|
||||
class IterationEvents(Base):
|
||||
iter: int = IntField()
|
||||
events: Sequence[dict] = ListField(items_types=dict)
|
||||
@@ -102,3 +129,11 @@ class TaskMetricsRequest(Base):
|
||||
items_types=str, validators=[Length(minimum_value=1)]
|
||||
)
|
||||
event_type: EventType = ActualEnumField(EventType, required=True)
|
||||
|
||||
|
||||
class TaskPlotsRequest(Base):
|
||||
task: str = StringField(required=True)
|
||||
iters: int = IntField(default=1)
|
||||
scroll_id: str = StringField()
|
||||
no_scroll: bool = BoolField(default=False)
|
||||
metrics: Sequence[MetricVariants] = ListField(items_types=MetricVariants)
|
||||
|
||||
@@ -27,7 +27,7 @@ class ProjectOrNoneRequest(models.Base):
|
||||
include_subprojects = fields.BoolField(default=True)
|
||||
|
||||
|
||||
class GetHyperParamRequest(ProjectOrNoneRequest):
|
||||
class GetParamsRequest(ProjectOrNoneRequest):
|
||||
page = fields.IntField(default=0)
|
||||
page_size = fields.IntField(default=500)
|
||||
|
||||
@@ -53,6 +53,7 @@ class ProjectHyperparamValuesRequest(MultiProjectRequest):
|
||||
|
||||
class ProjectsGetRequest(models.Base):
|
||||
include_stats = fields.BoolField(default=False)
|
||||
stats_with_children = fields.BoolField(default=True)
|
||||
stats_for_state = ActualEnumField(EntityVisibility, default=EntityVisibility.active)
|
||||
non_public = fields.BoolField(default=False)
|
||||
active_users = fields.ListField(str)
|
||||
|
||||
@@ -2,7 +2,11 @@ from datetime import datetime
|
||||
|
||||
from apiserver import database
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.auth import GetTokenResponse, CreateUserRequest, Credentials as CredModel
|
||||
from apiserver.apimodels.auth import (
|
||||
GetTokenResponse,
|
||||
CreateUserRequest,
|
||||
Credentials as CredModel,
|
||||
)
|
||||
from apiserver.apimodels.users import CreateRequest as Users_CreateRequest
|
||||
from apiserver.bll.user import UserBLL
|
||||
from apiserver.config_repo import config
|
||||
@@ -57,6 +61,7 @@ class AuthBLL:
|
||||
api_version=str(ServiceRepo.max_endpoint_version()),
|
||||
server_version=str(get_version()),
|
||||
server_build=str(get_build_number()),
|
||||
feature_set="basic",
|
||||
)
|
||||
|
||||
return GetTokenResponse(token=token.decode("ascii"))
|
||||
@@ -144,7 +149,7 @@ class AuthBLL:
|
||||
|
||||
@classmethod
|
||||
def create_credentials(
|
||||
cls, user_id: str, company_id: str, role: str = None
|
||||
cls, user_id: str, company_id: str, role: str = None, label: str = None,
|
||||
) -> CredModel:
|
||||
|
||||
with translate_errors_context():
|
||||
@@ -153,7 +158,9 @@ class AuthBLL:
|
||||
if not user:
|
||||
raise errors.bad_request.InvalidUserId(**query)
|
||||
|
||||
cred = CredModel(access_key=get_client_id(), secret_key=get_secret_key())
|
||||
cred = CredModel(
|
||||
access_key=get_client_id(), secret_key=get_secret_key(), label=label
|
||||
)
|
||||
user.credentials.append(
|
||||
Credentials(key=cred.access_key, secret=cred.secret_key)
|
||||
)
|
||||
|
||||
@@ -2,7 +2,7 @@ from concurrent.futures.thread import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Tuple, Optional, Mapping, Set
|
||||
from typing import Sequence, Tuple, Optional, Mapping
|
||||
|
||||
import attr
|
||||
import dpath
|
||||
@@ -18,6 +18,7 @@ from apiserver.bll.event.event_common import (
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
EventType,
|
||||
get_metric_variants_condition,
|
||||
)
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
@@ -74,7 +75,7 @@ class DebugImagesIterator:
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id: str,
|
||||
task_metrics: Mapping[str, Set[str]],
|
||||
task_metrics: Mapping[str, dict],
|
||||
iter_count: int,
|
||||
navigate_earlier: bool = True,
|
||||
refresh: bool = False,
|
||||
@@ -118,7 +119,7 @@ class DebugImagesIterator:
|
||||
self,
|
||||
company_id,
|
||||
state: DebugImageEventsScrollState,
|
||||
task_metrics: Mapping[str, Set[str]],
|
||||
task_metrics: Mapping[str, dict],
|
||||
):
|
||||
"""
|
||||
Determine the metrics for which new debug image events were added
|
||||
@@ -158,11 +159,11 @@ class DebugImagesIterator:
|
||||
task_metrics_to_recalc = {}
|
||||
for task, metrics_times in update_times.items():
|
||||
old_metric_states = task_metric_states[task]
|
||||
metrics_to_recalc = set(
|
||||
m
|
||||
metrics_to_recalc = {
|
||||
m: task_metrics[task].get(m)
|
||||
for m, t in metrics_times.items()
|
||||
if m not in old_metric_states or old_metric_states[m].timestamp < t
|
||||
)
|
||||
}
|
||||
if metrics_to_recalc:
|
||||
task_metrics_to_recalc[task] = metrics_to_recalc
|
||||
|
||||
@@ -196,7 +197,7 @@ class DebugImagesIterator:
|
||||
]
|
||||
|
||||
def _init_task_states(
|
||||
self, company_id: str, task_metrics: Mapping[str, Set[str]]
|
||||
self, company_id: str, task_metrics: Mapping[str, dict]
|
||||
) -> Sequence[TaskScrollState]:
|
||||
"""
|
||||
Returned initialized metric scroll stated for the requested task metrics
|
||||
@@ -213,7 +214,7 @@ class DebugImagesIterator:
|
||||
]
|
||||
|
||||
def _init_metric_states_for_task(
|
||||
self, task_metrics: Tuple[str, Set[str]], company_id: str
|
||||
self, task_metrics: Tuple[str, dict], company_id: str
|
||||
) -> Sequence[MetricState]:
|
||||
"""
|
||||
Return metric scroll states for the task filled with the variant states
|
||||
@@ -222,10 +223,11 @@ class DebugImagesIterator:
|
||||
task, metrics = task_metrics
|
||||
must = [{"term": {"task": task}}, {"exists": {"field": "url"}}]
|
||||
if metrics:
|
||||
must.append({"terms": {"metric": list(metrics)}})
|
||||
must.append(get_metric_variants_condition(metrics))
|
||||
query = {"bool": {"must": must}}
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"query": {"bool": {"must": must}},
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
|
||||
@@ -6,9 +6,8 @@ from collections import defaultdict
|
||||
from contextlib import closing
|
||||
from datetime import datetime
|
||||
from operator import attrgetter
|
||||
from typing import Sequence, Set, Tuple, Optional, Dict
|
||||
from typing import Sequence, Set, Tuple, Optional, List, Mapping, Union
|
||||
|
||||
import six
|
||||
from elasticsearch import helpers
|
||||
from elasticsearch.helpers import BulkIndexError
|
||||
from mongoengine import Q
|
||||
@@ -22,14 +21,16 @@ from apiserver.bll.event.event_common import (
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
delete_company_events,
|
||||
MetricVariants,
|
||||
get_metric_variants_condition,
|
||||
)
|
||||
from apiserver.bll.event.events_iterator import EventsIterator, TaskEventsResult
|
||||
from apiserver.bll.util import parallel_chunked_decorator
|
||||
from apiserver.database import utils as dbutils
|
||||
from apiserver.es_factory import es_factory
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.bll.event.debug_images_iterator import DebugImagesIterator
|
||||
from apiserver.bll.event.event_metrics import EventMetrics
|
||||
from apiserver.bll.event.log_events_iterator import LogEventsIterator, TaskEventsResult
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
@@ -43,8 +44,8 @@ from apiserver.utilities.json import loads
|
||||
# noinspection PyTypeChecker
|
||||
EVENT_TYPES: Set[str] = set(map(attrgetter("value"), EventType))
|
||||
LOCKED_TASK_STATUSES = (TaskStatus.publishing, TaskStatus.published)
|
||||
MAX_LONG = 2**63 - 1
|
||||
MIN_LONG = -2**63
|
||||
MAX_LONG = 2 ** 63 - 1
|
||||
MIN_LONG = -(2 ** 63)
|
||||
|
||||
|
||||
class PlotFields:
|
||||
@@ -72,7 +73,7 @@ class EventBLL(object):
|
||||
self.redis = redis or redman.connection("apiserver")
|
||||
self.debug_images_iterator = DebugImagesIterator(es=self.es, redis=self.redis)
|
||||
self.debug_sample_history = DebugSampleHistory(es=self.es, redis=self.redis)
|
||||
self.log_events_iterator = LogEventsIterator(es=self.es)
|
||||
self.events_iterator = EventsIterator(es=self.es)
|
||||
|
||||
@property
|
||||
def metrics(self) -> EventMetrics:
|
||||
@@ -94,7 +95,7 @@ class EventBLL(object):
|
||||
def add_events(
|
||||
self, company_id, events, worker, allow_locked_tasks=False
|
||||
) -> Tuple[int, int, dict]:
|
||||
actions = []
|
||||
actions: List[dict] = []
|
||||
task_ids = set()
|
||||
task_iteration = defaultdict(lambda: 0)
|
||||
task_last_scalar_events = nested_dict(
|
||||
@@ -197,7 +198,6 @@ class EventBLL(object):
|
||||
|
||||
actions.append(es_action)
|
||||
|
||||
action: Dict[dict]
|
||||
plot_actions = [
|
||||
action["_source"]
|
||||
for action in actions
|
||||
@@ -260,7 +260,8 @@ class EventBLL(object):
|
||||
invalid_iterations_count = errors_per_type.get(invalid_iteration_error)
|
||||
if invalid_iterations_count:
|
||||
raise BulkIndexError(
|
||||
f"{invalid_iterations_count} document(s) failed to index.", [invalid_iteration_error]
|
||||
f"{invalid_iterations_count} document(s) failed to index.",
|
||||
[invalid_iteration_error],
|
||||
)
|
||||
|
||||
if not added:
|
||||
@@ -466,10 +467,16 @@ class EventBLL(object):
|
||||
task_id: str,
|
||||
num_last_iterations: int,
|
||||
event_type: EventType,
|
||||
metric_variants: MetricVariants = None,
|
||||
):
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return []
|
||||
|
||||
must = [{"term": {"task": task_id}}]
|
||||
if metric_variants:
|
||||
must.append(get_metric_variants_condition(metric_variants))
|
||||
query = {"bool": {"must": must}}
|
||||
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"aggs": {
|
||||
@@ -499,7 +506,7 @@ class EventBLL(object):
|
||||
},
|
||||
}
|
||||
},
|
||||
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
|
||||
"query": query,
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext(
|
||||
@@ -527,6 +534,8 @@ class EventBLL(object):
|
||||
sort=None,
|
||||
size: int = 500,
|
||||
scroll_id: str = None,
|
||||
no_scroll: bool = False,
|
||||
metric_variants: MetricVariants = None,
|
||||
):
|
||||
if scroll_id == self.empty_scroll:
|
||||
return TaskEventsResult()
|
||||
@@ -555,6 +564,8 @@ class EventBLL(object):
|
||||
|
||||
if last_iterations_per_plot is None:
|
||||
must.append({"terms": {"task": tasks}})
|
||||
if metric_variants:
|
||||
must.append(get_metric_variants_condition(metric_variants))
|
||||
else:
|
||||
should = []
|
||||
for i, task_id in enumerate(tasks):
|
||||
@@ -563,6 +574,7 @@ class EventBLL(object):
|
||||
task_id=task_id,
|
||||
num_last_iterations=last_iterations_per_plot,
|
||||
event_type=event_type,
|
||||
metric_variants=metric_variants,
|
||||
)
|
||||
if not last_iters:
|
||||
continue
|
||||
@@ -600,7 +612,7 @@ class EventBLL(object):
|
||||
event_type=event_type,
|
||||
body=es_req,
|
||||
ignore=404,
|
||||
scroll="1h",
|
||||
**({} if no_scroll else {"scroll": "1h"}),
|
||||
)
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
@@ -669,19 +681,20 @@ class EventBLL(object):
|
||||
sort=None,
|
||||
size=500,
|
||||
scroll_id=None,
|
||||
):
|
||||
no_scroll=False,
|
||||
) -> TaskEventsResult:
|
||||
if scroll_id == self.empty_scroll:
|
||||
return [], scroll_id, 0
|
||||
return TaskEventsResult()
|
||||
|
||||
if scroll_id:
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
es_res = self.es.scroll(scroll_id=scroll_id, scroll="1h")
|
||||
else:
|
||||
task_ids = [task_id] if isinstance(task_id, six.string_types) else task_id
|
||||
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return TaskEventsResult()
|
||||
|
||||
task_ids = [task_id] if isinstance(task_id, str) else task_id
|
||||
|
||||
must = []
|
||||
if metric:
|
||||
must.append({"term": {"metric": metric}})
|
||||
@@ -691,26 +704,24 @@ class EventBLL(object):
|
||||
if last_iter_count is None:
|
||||
must.append({"terms": {"task": task_ids}})
|
||||
else:
|
||||
should = []
|
||||
for i, task_id in enumerate(task_ids):
|
||||
last_iters = self.get_last_iters(
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
task_id=task_id,
|
||||
iters=last_iter_count,
|
||||
)
|
||||
if not last_iters:
|
||||
continue
|
||||
should.append(
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task_id}},
|
||||
{"terms": {"iter": last_iters}},
|
||||
]
|
||||
}
|
||||
tasks_iters = self.get_last_iters(
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
task_id=task_ids,
|
||||
iters=last_iter_count,
|
||||
)
|
||||
should = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task}},
|
||||
{"terms": {"iter": last_iters}},
|
||||
]
|
||||
}
|
||||
)
|
||||
}
|
||||
for task, last_iters in tasks_iters.items()
|
||||
if last_iters
|
||||
]
|
||||
if not should:
|
||||
return TaskEventsResult()
|
||||
must.append({"bool": {"should": should}})
|
||||
@@ -731,7 +742,7 @@ class EventBLL(object):
|
||||
event_type=event_type,
|
||||
body=es_req,
|
||||
ignore=404,
|
||||
scroll="1h",
|
||||
**({} if no_scroll else {"scroll": "1h"}),
|
||||
)
|
||||
|
||||
events, total_events, next_scroll_id = self._get_events_from_es_res(es_res)
|
||||
@@ -748,6 +759,7 @@ class EventBLL(object):
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return {}
|
||||
|
||||
query = {"bool": {"must": [{"term": {"task": task_id}}]}}
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"aggs": {
|
||||
@@ -768,7 +780,7 @@ class EventBLL(object):
|
||||
},
|
||||
}
|
||||
},
|
||||
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
|
||||
"query": query,
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext(
|
||||
@@ -787,21 +799,24 @@ class EventBLL(object):
|
||||
|
||||
return metrics
|
||||
|
||||
def get_task_latest_scalar_values(self, company_id: str, task_id: str):
|
||||
def get_task_latest_scalar_values(
|
||||
self, company_id, task_id
|
||||
) -> Tuple[Sequence[dict], int]:
|
||||
event_type = EventType.metrics_scalar
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return {}
|
||||
return [], 0
|
||||
|
||||
query = {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"query_string": {"query": "value:>0"}},
|
||||
{"term": {"task": task_id}},
|
||||
]
|
||||
}
|
||||
}
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"query_string": {"query": "value:>0"}},
|
||||
{"term": {"task": task_id}},
|
||||
]
|
||||
}
|
||||
},
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
@@ -905,34 +920,47 @@ class EventBLL(object):
|
||||
return iterations, vectors
|
||||
|
||||
def get_last_iters(
|
||||
self, company_id: str, event_type: EventType, task_id: str, iters: int
|
||||
):
|
||||
self,
|
||||
company_id: str,
|
||||
event_type: EventType,
|
||||
task_id: Union[str, Sequence[str]],
|
||||
iters: int,
|
||||
) -> Mapping[str, Sequence]:
|
||||
if check_empty_data(self.es, company_id=company_id, event_type=event_type):
|
||||
return []
|
||||
return {}
|
||||
|
||||
task_ids = [task_id] if isinstance(task_id, str) else task_id
|
||||
es_req: dict = {
|
||||
"size": 0,
|
||||
"aggs": {
|
||||
"iters": {
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": iters,
|
||||
"order": {"_key": "desc"},
|
||||
}
|
||||
"tasks": {
|
||||
"terms": {"field": "task"},
|
||||
"aggs": {
|
||||
"iters": {
|
||||
"terms": {
|
||||
"field": "iter",
|
||||
"size": iters,
|
||||
"order": {"_key": "desc"},
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
"query": {"bool": {"must": [{"term": {"task": task_id}}]}},
|
||||
"query": {"bool": {"must": [{"terms": {"task": task_ids}}]}},
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "task_last_iter"):
|
||||
es_res = search_company_events(
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req
|
||||
self.es, company_id=company_id, event_type=event_type, body=es_req,
|
||||
)
|
||||
|
||||
if "aggregations" not in es_res:
|
||||
return []
|
||||
return {}
|
||||
|
||||
return [b["key"] for b in es_res["aggregations"]["iters"]["buckets"]]
|
||||
return {
|
||||
tb["key"]: [ib["key"] for ib in tb["iters"]["buckets"]]
|
||||
for tb in es_res["aggregations"]["tasks"]["buckets"]
|
||||
}
|
||||
|
||||
def delete_task_events(self, company_id, task_id, allow_locked=False):
|
||||
with translate_errors_context():
|
||||
@@ -965,7 +993,9 @@ class EventBLL(object):
|
||||
so it should be checked by the calling code
|
||||
"""
|
||||
es_req = {"query": {"terms": {"task": task_ids}}}
|
||||
with translate_errors_context(), TimingContext("es", "delete_multi_tasks_events"):
|
||||
with translate_errors_context(), TimingContext(
|
||||
"es", "delete_multi_tasks_events"
|
||||
):
|
||||
es_res = delete_company_events(
|
||||
es=self.es,
|
||||
company_id=company_id,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from enum import Enum
|
||||
from typing import Union, Sequence
|
||||
from typing import Union, Sequence, Mapping
|
||||
|
||||
from boltons.typeutils import classproperty
|
||||
from elasticsearch import Elasticsearch
|
||||
@@ -16,6 +16,9 @@ class EventType(Enum):
|
||||
all = "*"
|
||||
|
||||
|
||||
MetricVariants = Mapping[str, Sequence[str]]
|
||||
|
||||
|
||||
class EventSettings:
|
||||
@classproperty
|
||||
def max_workers(self):
|
||||
@@ -63,4 +66,31 @@ def delete_company_events(
|
||||
es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
|
||||
) -> dict:
|
||||
es_index = get_index_name(company_id, event_type.value)
|
||||
return es.delete_by_query(index=es_index, body=body, **kwargs)
|
||||
return es.delete_by_query(
|
||||
index=es_index, body=body, conflicts="proceed", **kwargs
|
||||
)
|
||||
|
||||
|
||||
def count_company_events(
|
||||
es: Elasticsearch, company_id: str, event_type: EventType, body: dict, **kwargs
|
||||
) -> dict:
|
||||
es_index = get_index_name(company_id, event_type.value)
|
||||
return es.count(index=es_index, body=body, **kwargs)
|
||||
|
||||
|
||||
def get_metric_variants_condition(metric_variants: MetricVariants,) -> Sequence:
|
||||
conditions = [
|
||||
{
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"metric": metric}},
|
||||
{"terms": {"variant": variants}},
|
||||
]
|
||||
}
|
||||
}
|
||||
if variants
|
||||
else {"term": {"metric": metric}}
|
||||
for metric, variants in metric_variants.items()
|
||||
]
|
||||
|
||||
return {"bool": {"should": conditions}}
|
||||
|
||||
@@ -15,6 +15,8 @@ from apiserver.bll.event.event_common import (
|
||||
EventSettings,
|
||||
search_company_events,
|
||||
check_empty_data,
|
||||
MetricVariants,
|
||||
get_metric_variants_condition,
|
||||
)
|
||||
from apiserver.bll.event.scalar_key import ScalarKey, ScalarKeyEnum
|
||||
from apiserver.config_repo import config
|
||||
@@ -34,7 +36,12 @@ class EventMetrics:
|
||||
self.es = es
|
||||
|
||||
def get_scalar_metrics_average_per_iter(
|
||||
self, company_id: str, task_id: str, samples: int, key: ScalarKeyEnum
|
||||
self,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
samples: int,
|
||||
key: ScalarKeyEnum,
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Get scalar metric histogram per metric and variant
|
||||
@@ -46,7 +53,12 @@ class EventMetrics:
|
||||
return {}
|
||||
|
||||
return self._get_scalar_average_per_iter_core(
|
||||
task_id, company_id, event_type, samples, ScalarKey.resolve(key)
|
||||
task_id=task_id,
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
samples=samples,
|
||||
key=ScalarKey.resolve(key),
|
||||
metric_variants=metric_variants,
|
||||
)
|
||||
|
||||
def _get_scalar_average_per_iter_core(
|
||||
@@ -57,6 +69,7 @@ class EventMetrics:
|
||||
samples: int,
|
||||
key: ScalarKey,
|
||||
run_parallel: bool = True,
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> dict:
|
||||
intervals = self._get_task_metric_intervals(
|
||||
company_id=company_id,
|
||||
@@ -64,6 +77,7 @@ class EventMetrics:
|
||||
task_id=task_id,
|
||||
samples=samples,
|
||||
field=key.field,
|
||||
metric_variants=metric_variants,
|
||||
)
|
||||
if not intervals:
|
||||
return {}
|
||||
@@ -197,6 +211,7 @@ class EventMetrics:
|
||||
task_id: str,
|
||||
samples: int,
|
||||
field: str = "iter",
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> Sequence[MetricInterval]:
|
||||
"""
|
||||
Calculate interval per task metric variant so that the resulting
|
||||
@@ -204,9 +219,14 @@ class EventMetrics:
|
||||
Return the list og metric variant intervals as the following tuple:
|
||||
(metric, variant, interval, samples)
|
||||
"""
|
||||
must = [{"term": {"task": task_id}}]
|
||||
if metric_variants:
|
||||
must.append(get_metric_variants_condition(metric_variants))
|
||||
query = {"bool": {"must": must}}
|
||||
|
||||
es_req = {
|
||||
"size": 0,
|
||||
"query": {"term": {"task": task_id}},
|
||||
"query": query,
|
||||
"aggs": {
|
||||
"metrics": {
|
||||
"terms": {
|
||||
|
||||
205
apiserver/bll/event/events_iterator.py
Normal file
205
apiserver/bll/event/events_iterator.py
Normal file
@@ -0,0 +1,205 @@
|
||||
from typing import Optional, Tuple, Sequence, Any
|
||||
|
||||
import attr
|
||||
import jsonmodels.models
|
||||
import jwt
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from apiserver.bll.event.event_common import (
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
EventType,
|
||||
MetricVariants,
|
||||
get_metric_variants_condition,
|
||||
count_company_events,
|
||||
)
|
||||
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class TaskEventsResult:
|
||||
total_events: int = 0
|
||||
next_scroll_id: str = None
|
||||
events: list = attr.Factory(list)
|
||||
|
||||
|
||||
class EventsIterator:
|
||||
def __init__(self, es: Elasticsearch):
|
||||
self.es = es
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
event_type: EventType,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
batch_size: int,
|
||||
navigate_earlier: bool = True,
|
||||
from_key_value: Optional[Any] = None,
|
||||
metric_variants: MetricVariants = None,
|
||||
key: ScalarKeyEnum = ScalarKeyEnum.timestamp,
|
||||
**kwargs,
|
||||
) -> TaskEventsResult:
|
||||
if check_empty_data(self.es, company_id, event_type):
|
||||
return TaskEventsResult()
|
||||
|
||||
from_key_value = kwargs.pop("from_timestamp", from_key_value)
|
||||
|
||||
res = TaskEventsResult()
|
||||
res.events, res.total_events = self._get_events(
|
||||
event_type=event_type,
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
batch_size=batch_size,
|
||||
navigate_earlier=navigate_earlier,
|
||||
from_key_value=from_key_value,
|
||||
metric_variants=metric_variants,
|
||||
key=ScalarKey.resolve(key),
|
||||
)
|
||||
return res
|
||||
|
||||
def count_task_events(
|
||||
self,
|
||||
event_type: EventType,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> int:
|
||||
query, _ = self._get_initial_query_and_must(task_id, metric_variants)
|
||||
es_req = {
|
||||
"query": query,
|
||||
}
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "count_task_events"):
|
||||
es_result = count_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
body=es_req,
|
||||
routing=task_id,
|
||||
)
|
||||
|
||||
return es_result["count"]
|
||||
|
||||
def _get_events(
|
||||
self,
|
||||
event_type: EventType,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
batch_size: int,
|
||||
navigate_earlier: bool,
|
||||
key: ScalarKey,
|
||||
from_key_value: Optional[Any],
|
||||
metric_variants: MetricVariants = None,
|
||||
) -> Tuple[Sequence[dict], int]:
|
||||
"""
|
||||
Return up to 'batch size' events starting from the previous key-field value (timestamp or iter) either in the
|
||||
direction of earlier events (navigate_earlier=True) or in the direction of later events.
|
||||
If from_key_field is not set then start either from latest or earliest.
|
||||
For the last key-field value all the events are brought (even if the resulting size exceeds batch_size)
|
||||
so that events with this value will not be lost between the calls.
|
||||
"""
|
||||
query, must = self._get_initial_query_and_must(task_id, metric_variants)
|
||||
|
||||
# retrieve the next batch of events
|
||||
es_req = {
|
||||
"size": batch_size,
|
||||
"query": query,
|
||||
"sort": {key.field: "desc" if navigate_earlier else "asc"},
|
||||
}
|
||||
|
||||
if from_key_value:
|
||||
es_req["search_after"] = [from_key_value]
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
es_result = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
body=es_req,
|
||||
routing=task_id,
|
||||
)
|
||||
hits = es_result["hits"]["hits"]
|
||||
hits_total = es_result["hits"]["total"]["value"]
|
||||
if not hits:
|
||||
return [], hits_total
|
||||
|
||||
events = [hit["_source"] for hit in hits]
|
||||
|
||||
# retrieve the events that match the last event timestamp
|
||||
# but did not make it into the previous call due to batch_size limitation
|
||||
es_req = {
|
||||
"size": 10000,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": must + [{"term": {key.field: events[-1][key.field]}}]
|
||||
}
|
||||
},
|
||||
}
|
||||
es_result = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=event_type,
|
||||
body=es_req,
|
||||
routing=task_id,
|
||||
)
|
||||
last_second_hits = es_result["hits"]["hits"]
|
||||
if not last_second_hits or len(last_second_hits) < 2:
|
||||
# if only one element is returned for the last timestamp
|
||||
# then it is already present in the events
|
||||
return events, hits_total
|
||||
|
||||
already_present_ids = set(hit["_id"] for hit in hits)
|
||||
last_second_events = [
|
||||
hit["_source"]
|
||||
for hit in last_second_hits
|
||||
if hit["_id"] not in already_present_ids
|
||||
]
|
||||
|
||||
# return the list merged from original query results +
|
||||
# leftovers from the last timestamp
|
||||
return (
|
||||
[*events, *last_second_events],
|
||||
hits_total,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_initial_query_and_must(
|
||||
task_id: str, metric_variants: MetricVariants = None
|
||||
) -> Tuple[dict, list]:
|
||||
if not metric_variants:
|
||||
must = [{"term": {"task": task_id}}]
|
||||
query = {"term": {"task": task_id}}
|
||||
else:
|
||||
must = [
|
||||
{"term": {"task": task_id}},
|
||||
get_metric_variants_condition(metric_variants),
|
||||
]
|
||||
query = {"bool": {"must": must}}
|
||||
return query, must
|
||||
|
||||
|
||||
class Scroll(jsonmodels.models.Base):
|
||||
def get_scroll_id(self) -> str:
|
||||
return jwt.encode(
|
||||
self.to_struct(),
|
||||
key=config.get(
|
||||
"services.events.events_retrieval.scroll_id_key", "1234567890"
|
||||
),
|
||||
).decode()
|
||||
|
||||
@classmethod
|
||||
def from_scroll_id(cls, scroll_id: str):
|
||||
try:
|
||||
return cls(
|
||||
**jwt.decode(
|
||||
scroll_id,
|
||||
key=config.get(
|
||||
"services.events.events_retrieval.scroll_id_key", "1234567890"
|
||||
),
|
||||
)
|
||||
)
|
||||
except jwt.PyJWTError:
|
||||
raise ValueError("Invalid Scroll ID")
|
||||
@@ -1,127 +0,0 @@
|
||||
from typing import Optional, Tuple, Sequence
|
||||
|
||||
import attr
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from apiserver.bll.event.event_common import (
|
||||
check_empty_data,
|
||||
search_company_events,
|
||||
EventType,
|
||||
)
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class TaskEventsResult:
|
||||
total_events: int = 0
|
||||
next_scroll_id: str = None
|
||||
events: list = attr.Factory(list)
|
||||
|
||||
|
||||
class LogEventsIterator:
|
||||
EVENT_TYPE = EventType.task_log
|
||||
|
||||
def __init__(self, es: Elasticsearch):
|
||||
self.es = es
|
||||
|
||||
def get_task_events(
|
||||
self,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
batch_size: int,
|
||||
navigate_earlier: bool = True,
|
||||
from_timestamp: Optional[int] = None,
|
||||
) -> TaskEventsResult:
|
||||
if check_empty_data(self.es, company_id, self.EVENT_TYPE):
|
||||
return TaskEventsResult()
|
||||
|
||||
res = TaskEventsResult()
|
||||
res.events, res.total_events = self._get_events(
|
||||
company_id=company_id,
|
||||
task_id=task_id,
|
||||
batch_size=batch_size,
|
||||
navigate_earlier=navigate_earlier,
|
||||
from_timestamp=from_timestamp,
|
||||
)
|
||||
return res
|
||||
|
||||
def _get_events(
|
||||
self,
|
||||
company_id: str,
|
||||
task_id: str,
|
||||
batch_size: int,
|
||||
navigate_earlier: bool,
|
||||
from_timestamp: Optional[int],
|
||||
) -> Tuple[Sequence[dict], int]:
|
||||
"""
|
||||
Return up to 'batch size' events starting from the previous timestamp either in the
|
||||
direction of earlier events (navigate_earlier=True) or in the direction of later events.
|
||||
If last_min_timestamp and last_max_timestamp are not set then start either from latest or earliest.
|
||||
For the last timestamp all the events are brought (even if the resulting size
|
||||
exceeds batch_size) so that this timestamp events will not be lost between the calls.
|
||||
In case any events were received update 'last_min_timestamp' and 'last_max_timestamp'
|
||||
"""
|
||||
|
||||
# retrieve the next batch of events
|
||||
es_req = {
|
||||
"size": batch_size,
|
||||
"query": {"term": {"task": task_id}},
|
||||
"sort": {"timestamp": "desc" if navigate_earlier else "asc"},
|
||||
}
|
||||
|
||||
if from_timestamp:
|
||||
es_req["search_after"] = [from_timestamp]
|
||||
|
||||
with translate_errors_context(), TimingContext("es", "get_task_events"):
|
||||
es_result = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.EVENT_TYPE,
|
||||
body=es_req,
|
||||
)
|
||||
hits = es_result["hits"]["hits"]
|
||||
hits_total = es_result["hits"]["total"]["value"]
|
||||
if not hits:
|
||||
return [], hits_total
|
||||
|
||||
events = [hit["_source"] for hit in hits]
|
||||
|
||||
# retrieve the events that match the last event timestamp
|
||||
# but did not make it into the previous call due to batch_size limitation
|
||||
es_req = {
|
||||
"size": 10000,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{"term": {"task": task_id}},
|
||||
{"term": {"timestamp": events[-1]["timestamp"]}},
|
||||
]
|
||||
}
|
||||
},
|
||||
}
|
||||
es_result = search_company_events(
|
||||
self.es,
|
||||
company_id=company_id,
|
||||
event_type=self.EVENT_TYPE,
|
||||
body=es_req,
|
||||
)
|
||||
last_second_hits = es_result["hits"]["hits"]
|
||||
if not last_second_hits or len(last_second_hits) < 2:
|
||||
# if only one element is returned for the last timestamp
|
||||
# then it is already present in the events
|
||||
return events, hits_total
|
||||
|
||||
already_present_ids = set(hit["_id"] for hit in hits)
|
||||
last_second_events = [
|
||||
hit["_source"]
|
||||
for hit in last_second_hits
|
||||
if hit["_id"] not in already_present_ids
|
||||
]
|
||||
|
||||
# return the list merged from original query results +
|
||||
# leftovers from the last timestamp
|
||||
return (
|
||||
[*events, *last_second_events],
|
||||
hits_total,
|
||||
)
|
||||
@@ -4,6 +4,8 @@ Module for polymorphism over different types of X axes in scalar aggregations
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import auto
|
||||
|
||||
from typing import Any
|
||||
|
||||
from apiserver.utilities import extract_properties_to_lists
|
||||
from apiserver.utilities.stringenum import StringEnum
|
||||
from apiserver.config_repo import config
|
||||
@@ -96,6 +98,10 @@ class ScalarKey(ABC):
|
||||
"""
|
||||
return int(iter_data[self.bucket_key_key]), iter_data["avg_val"]["value"]
|
||||
|
||||
def cast_value(self, value: Any) -> Any:
|
||||
"""Cast value to appropriate type"""
|
||||
return value
|
||||
|
||||
|
||||
class TimestampKey(ScalarKey):
|
||||
"""
|
||||
@@ -117,6 +123,9 @@ class TimestampKey(ScalarKey):
|
||||
}
|
||||
}
|
||||
|
||||
def cast_value(self, value: Any) -> int:
|
||||
return int(value)
|
||||
|
||||
|
||||
class IterKey(ScalarKey):
|
||||
"""
|
||||
@@ -134,6 +143,9 @@ class IterKey(ScalarKey):
|
||||
}
|
||||
}
|
||||
|
||||
def cast_value(self, value: Any) -> int:
|
||||
return int(value)
|
||||
|
||||
|
||||
class ISOTimeKey(ScalarKey):
|
||||
"""
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
from .project_bll import ProjectBLL
|
||||
from .project_queries import ProjectQueries
|
||||
from .sub_projects import _ids_with_children as project_ids_with_children
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import itertools
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from functools import reduce
|
||||
from itertools import groupby
|
||||
from operator import itemgetter
|
||||
@@ -57,10 +57,14 @@ class ProjectBLL:
|
||||
with TimingContext("mongo", "move_project"):
|
||||
if source_id == destination_id:
|
||||
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
|
||||
parent=source_id
|
||||
source=source_id
|
||||
)
|
||||
source = Project.get(company, source_id)
|
||||
destination = Project.get(company, destination_id)
|
||||
if source_id in destination.path:
|
||||
raise errors.bad_request.ProjectCannotBeMergedIntoItsChild(
|
||||
source=source_id, destination=destination_id
|
||||
)
|
||||
|
||||
children = _get_sub_projects(
|
||||
[source.id], _only=("id", "name", "parent", "path")
|
||||
@@ -140,7 +144,14 @@ class ProjectBLL:
|
||||
raise errors.bad_request.ProjectSourceAndDestinationAreTheSame(
|
||||
location=new_parent.name if new_parent else ""
|
||||
)
|
||||
|
||||
if (
|
||||
new_parent
|
||||
and project_id == new_parent.id
|
||||
or project_id in new_parent.path
|
||||
):
|
||||
raise errors.bad_request.ProjectCannotBeMovedUnderItself(
|
||||
project=project_id, parent=new_parent.id
|
||||
)
|
||||
moved = _reposition_project_with_children(
|
||||
project, children=children, parent=new_parent
|
||||
)
|
||||
@@ -295,6 +306,7 @@ class ProjectBLL:
|
||||
return project
|
||||
|
||||
archived_tasks_cond = {"$in": [EntityVisibility.archived.value, "$system_tags"]}
|
||||
visibility_states = [EntityVisibility.archived, EntityVisibility.active]
|
||||
|
||||
@classmethod
|
||||
def make_projects_get_all_pipelines(
|
||||
@@ -356,6 +368,26 @@ class ProjectBLL:
|
||||
},
|
||||
]
|
||||
|
||||
def completed_after_subquery(additional_cond, time_thresh: datetime):
|
||||
return {
|
||||
# the sum of
|
||||
"$sum": {
|
||||
# for each task
|
||||
"$cond": {
|
||||
# if completed after the time_thresh
|
||||
"if": {
|
||||
"$and": [
|
||||
"$completed",
|
||||
{"$gt": ["$completed", time_thresh]},
|
||||
additional_cond,
|
||||
]
|
||||
},
|
||||
"then": 1,
|
||||
"else": 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
def runtime_subquery(additional_cond):
|
||||
return {
|
||||
# the sum of
|
||||
@@ -386,16 +418,19 @@ class ProjectBLL:
|
||||
}
|
||||
|
||||
group_step = {"_id": "$project"}
|
||||
|
||||
for state in EntityVisibility:
|
||||
time_thresh = datetime.utcnow() - timedelta(hours=24)
|
||||
for state in cls.visibility_states:
|
||||
if specific_state and state != specific_state:
|
||||
continue
|
||||
if state == EntityVisibility.active:
|
||||
group_step[state.value] = runtime_subquery(
|
||||
{"$not": cls.archived_tasks_cond}
|
||||
)
|
||||
elif state == EntityVisibility.archived:
|
||||
group_step[state.value] = runtime_subquery(cls.archived_tasks_cond)
|
||||
cond = (
|
||||
cls.archived_tasks_cond
|
||||
if state == EntityVisibility.archived
|
||||
else {"$not": cls.archived_tasks_cond}
|
||||
)
|
||||
group_step[state.value] = runtime_subquery(cond)
|
||||
group_step[f"{state.value}_recently_completed"] = completed_after_subquery(
|
||||
cond, time_thresh=time_thresh
|
||||
)
|
||||
|
||||
runtime_pipeline = [
|
||||
# only count run time for these types of tasks
|
||||
@@ -445,11 +480,16 @@ class ProjectBLL:
|
||||
company: str,
|
||||
project_ids: Sequence[str],
|
||||
specific_state: Optional[EntityVisibility] = None,
|
||||
include_children: bool = True,
|
||||
) -> Tuple[Dict[str, dict], Dict[str, dict]]:
|
||||
if not project_ids:
|
||||
return {}, {}
|
||||
|
||||
child_projects = _get_sub_projects(project_ids, _only=("id", "name"))
|
||||
child_projects = (
|
||||
_get_sub_projects(project_ids, _only=("id", "name"))
|
||||
if include_children
|
||||
else {}
|
||||
)
|
||||
project_ids_with_children = set(project_ids) | {
|
||||
c.id for c in itertools.chain.from_iterable(child_projects.values())
|
||||
}
|
||||
@@ -483,8 +523,8 @@ class ProjectBLL:
|
||||
) -> Dict[str, dict]:
|
||||
return {
|
||||
section: {
|
||||
status: nested_get(a, (section, status), 0)
|
||||
+ nested_get(b, (section, status), 0)
|
||||
status: nested_get(a, (section, status), default=0)
|
||||
+ nested_get(b, (section, status), default=0)
|
||||
for status in set(a.get(section, {})) | set(b.get(section, {}))
|
||||
}
|
||||
for section in set(a) | set(b)
|
||||
@@ -518,15 +558,24 @@ class ProjectBLL:
|
||||
)
|
||||
|
||||
def get_status_counts(project_id, section):
|
||||
project_runtime = runtime.get(project_id, {})
|
||||
project_section_statuses = nested_get(
|
||||
status_count, (project_id, section), default=default_counts
|
||||
)
|
||||
return {
|
||||
"total_runtime": nested_get(runtime, (project_id, section), 0),
|
||||
"status_count": nested_get(
|
||||
status_count, (project_id, section), default_counts
|
||||
"status_count": project_section_statuses,
|
||||
"running_tasks": project_section_statuses.get(TaskStatus.in_progress),
|
||||
"total_tasks": sum(project_section_statuses.values()),
|
||||
"total_runtime": project_runtime.get(section, 0),
|
||||
"completed_tasks": project_runtime.get(
|
||||
f"{section}_recently_completed", 0
|
||||
),
|
||||
}
|
||||
|
||||
report_for_states = [
|
||||
s for s in EntityVisibility if not specific_state or specific_state == s
|
||||
s
|
||||
for s in cls.visibility_states
|
||||
if not specific_state or specific_state == s
|
||||
]
|
||||
|
||||
stats = {
|
||||
@@ -554,7 +603,7 @@ class ProjectBLL:
|
||||
user_ids: Optional[Sequence[str]] = None,
|
||||
) -> Set[str]:
|
||||
"""
|
||||
Get the set of user ids that created tasks/models/dataviews in the given projects
|
||||
Get the set of user ids that created tasks/models in the given projects
|
||||
If project_ids is empty then all projects are examined
|
||||
If user_ids are passed then only subset of these users is returned
|
||||
"""
|
||||
@@ -676,8 +725,8 @@ class ProjectBLL:
|
||||
@classmethod
|
||||
def calc_own_contents(cls, company: str, project_ids: Sequence[str]) -> Dict[str, dict]:
|
||||
"""
|
||||
Returns the amount of task/dataviews/models per requested project
|
||||
Use separate aggregation calls on Task/Dataview/Model instead of lookup
|
||||
Returns the amount of task/models per requested project
|
||||
Use separate aggregation calls on Task/Model instead of lookup
|
||||
aggregation on projects in order not to hit memory limits on large tasks
|
||||
"""
|
||||
if not project_ids:
|
||||
|
||||
@@ -30,6 +30,28 @@ class DeleteProjectResult:
|
||||
urls: TaskUrls = None
|
||||
|
||||
|
||||
def validate_project_delete(company: str, project_id: str):
|
||||
project = Project.get_for_writing(
|
||||
company=company, id=project_id, _only=("id", "path")
|
||||
)
|
||||
if not project:
|
||||
raise errors.bad_request.InvalidProjectId(id=project_id)
|
||||
|
||||
project_ids = _ids_with_children([project_id])
|
||||
ret = {}
|
||||
for cls in (Task, Model):
|
||||
ret[f"{cls.__name__.lower()}s"] = cls.objects(
|
||||
project__in=project_ids,
|
||||
).count()
|
||||
for cls in (Task, Model):
|
||||
ret[f"non_archived_{cls.__name__.lower()}s"] = cls.objects(
|
||||
project__in=project_ids,
|
||||
system_tags__nin=[EntityVisibility.archived.value],
|
||||
).count()
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def delete_project(
|
||||
company: str, project_id: str, force: bool, delete_contents: bool
|
||||
) -> Tuple[DeleteProjectResult, Set[str]]:
|
||||
|
||||
241
apiserver/bll/project/project_queries.py
Normal file
241
apiserver/bll/project/project_queries.py
Normal file
@@ -0,0 +1,241 @@
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
from typing import (
|
||||
Sequence,
|
||||
Optional,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
from redis import StrictRedis
|
||||
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.utilities.dicts import nested_get
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
from .sub_projects import _ids_with_children
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class ProjectQueries:
|
||||
def __init__(self, redis=None):
|
||||
self.redis: StrictRedis = redis or redman.connection("apiserver")
|
||||
|
||||
@staticmethod
|
||||
def _get_project_constraint(
|
||||
project_ids: Sequence[str], include_subprojects: bool
|
||||
) -> dict:
|
||||
if include_subprojects:
|
||||
if project_ids is None:
|
||||
return {}
|
||||
project_ids = _ids_with_children(project_ids)
|
||||
|
||||
return {"project": {"$in": project_ids if project_ids is not None else [None]}}
|
||||
|
||||
@staticmethod
|
||||
def _get_company_constraint(company_id: str, allow_public: bool = True) -> dict:
|
||||
if allow_public:
|
||||
return {"company": {"$in": [None, "", company_id]}}
|
||||
|
||||
return {"company": company_id}
|
||||
|
||||
@classmethod
|
||||
def get_aggregated_project_parameters(
|
||||
cls,
|
||||
company_id,
|
||||
project_ids: Sequence[str],
|
||||
include_subprojects: bool,
|
||||
page: int = 0,
|
||||
page_size: int = 500,
|
||||
) -> Tuple[int, int, Sequence[dict]]:
|
||||
page = max(0, page)
|
||||
page_size = max(1, page_size)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**cls._get_company_constraint(company_id),
|
||||
**cls._get_project_constraint(project_ids, include_subprojects),
|
||||
"hyperparams": {"$exists": True, "$gt": {}},
|
||||
}
|
||||
},
|
||||
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
|
||||
{"$unwind": "$sections"},
|
||||
{
|
||||
"$project": {
|
||||
"section": "$sections.k",
|
||||
"names": {"$objectToArray": "$sections.v"},
|
||||
}
|
||||
},
|
||||
{"$unwind": "$names"},
|
||||
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
|
||||
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
|
||||
{"$skip": page * page_size},
|
||||
{"$limit": page_size},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Task.aggregate(pipeline), None)
|
||||
|
||||
total = 0
|
||||
remaining = 0
|
||||
results = []
|
||||
|
||||
if result:
|
||||
total = int(result.get("total", -1))
|
||||
results = [
|
||||
{
|
||||
"section": ParameterKeyEscaper.unescape(
|
||||
nested_get(r, ("_id", "section"))
|
||||
),
|
||||
"name": ParameterKeyEscaper.unescape(
|
||||
nested_get(r, ("_id", "name"))
|
||||
),
|
||||
}
|
||||
for r in result.get("results", [])
|
||||
]
|
||||
remaining = max(0, total - (len(results) + page * page_size))
|
||||
|
||||
return total, remaining, results
|
||||
|
||||
HyperParamValues = Tuple[int, Sequence[str]]
|
||||
|
||||
def _get_cached_hyperparam_values(
|
||||
self, key: str, last_update: datetime
|
||||
) -> Optional[HyperParamValues]:
|
||||
allowed_delta = timedelta(
|
||||
seconds=config.get(
|
||||
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
|
||||
)
|
||||
)
|
||||
try:
|
||||
cached = self.redis.get(key)
|
||||
if not cached:
|
||||
return
|
||||
|
||||
data = json.loads(cached)
|
||||
cached_last_update = datetime.fromtimestamp(data["last_update"])
|
||||
if (last_update - cached_last_update) < allowed_delta:
|
||||
return data["total"], data["values"]
|
||||
except Exception as ex:
|
||||
log.error(f"Error retrieving hyperparam cached values: {str(ex)}")
|
||||
|
||||
def get_hyperparam_distinct_values(
|
||||
self,
|
||||
company_id: str,
|
||||
project_ids: Sequence[str],
|
||||
section: str,
|
||||
name: str,
|
||||
include_subprojects: bool,
|
||||
allow_public: bool = True,
|
||||
) -> HyperParamValues:
|
||||
company_constraint = self._get_company_constraint(company_id, allow_public)
|
||||
project_constraint = self._get_project_constraint(
|
||||
project_ids, include_subprojects
|
||||
)
|
||||
key_path = f"hyperparams.{ParameterKeyEscaper.escape(section)}.{ParameterKeyEscaper.escape(name)}"
|
||||
last_updated_task = (
|
||||
Task.objects(
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
**{f"{key_path.replace('.', '__')}__exists": True},
|
||||
)
|
||||
.only("last_update")
|
||||
.order_by("-last_update")
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
if not last_updated_task:
|
||||
return 0, []
|
||||
|
||||
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
|
||||
last_update = last_updated_task.last_update or datetime.utcnow()
|
||||
cached_res = self._get_cached_hyperparam_values(
|
||||
key=redis_key, last_update=last_update
|
||||
)
|
||||
if cached_res:
|
||||
return cached_res
|
||||
|
||||
max_values = config.get("services.tasks.hyperparam_values.max_count", 100)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
key_path: {"$exists": True},
|
||||
}
|
||||
},
|
||||
{"$project": {"value": f"${key_path}.value"}},
|
||||
{"$group": {"_id": "$value"}},
|
||||
{"$sort": {"_id": 1}},
|
||||
{"$limit": max_values},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT._id"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Task.aggregate(pipeline, collation=Task._numeric_locale), None)
|
||||
if not result:
|
||||
return 0, []
|
||||
|
||||
total = int(result.get("total", 0))
|
||||
values = result.get("results", [])
|
||||
|
||||
ttl = config.get("services.tasks.hyperparam_values.cache_ttl_sec", 86400)
|
||||
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
|
||||
self.redis.setex(redis_key, ttl, json.dumps(cached))
|
||||
|
||||
return total, values
|
||||
|
||||
@classmethod
|
||||
def get_unique_metric_variants(
|
||||
cls, company_id, project_ids: Sequence[str], include_subprojects: bool
|
||||
):
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**cls._get_company_constraint(company_id),
|
||||
**cls._get_project_constraint(project_ids, include_subprojects),
|
||||
}
|
||||
},
|
||||
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
|
||||
{"$unwind": "$metrics"},
|
||||
{
|
||||
"$project": {
|
||||
"metric": "$metrics.k",
|
||||
"variants": {"$objectToArray": "$metrics.v"},
|
||||
}
|
||||
},
|
||||
{"$unwind": "$variants"},
|
||||
{
|
||||
"$group": {
|
||||
"_id": {
|
||||
"metric": "$variants.v.metric",
|
||||
"variant": "$variants.v.variant",
|
||||
},
|
||||
"metrics": {
|
||||
"$addToSet": {
|
||||
"metric": "$variants.v.metric",
|
||||
"metric_hash": "$metric",
|
||||
"variant": "$variants.v.variant",
|
||||
"variant_hash": "$variants.k",
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
|
||||
]
|
||||
|
||||
result = Task.aggregate(pipeline)
|
||||
return [r["metrics"][0] for r in result]
|
||||
@@ -126,14 +126,27 @@ class QueueBLL(object):
|
||||
)
|
||||
queue.delete()
|
||||
|
||||
def get_all(self, company_id: str, query_dict: dict) -> Sequence[dict]:
|
||||
def get_all(
|
||||
self,
|
||||
company_id: str,
|
||||
query_dict: dict,
|
||||
ret_params: dict = None,
|
||||
) -> Sequence[dict]:
|
||||
"""Get all the queues according to the query"""
|
||||
with translate_errors_context():
|
||||
return Queue.get_many(
|
||||
company=company_id, parameters=query_dict, query_dict=query_dict
|
||||
company=company_id,
|
||||
parameters=query_dict,
|
||||
query_dict=query_dict,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
|
||||
def get_queue_infos(self, company_id: str, query_dict: dict) -> Sequence[dict]:
|
||||
def get_queue_infos(
|
||||
self,
|
||||
company_id: str,
|
||||
query_dict: dict,
|
||||
ret_params: dict = None,
|
||||
) -> Sequence[dict]:
|
||||
"""
|
||||
Get infos on all the company queues, including queue tasks and workers
|
||||
"""
|
||||
@@ -143,6 +156,7 @@ class QueueBLL(object):
|
||||
company=company_id,
|
||||
query_dict=query_dict,
|
||||
override_projection=projection,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
|
||||
queue_workers = defaultdict(list)
|
||||
|
||||
@@ -49,6 +49,21 @@ class RedisCacheManager(Generic[T]):
|
||||
def _get_redis_key(self, state_id):
|
||||
return f"{self.state_class}/{state_id}"
|
||||
|
||||
def get_or_create_state_core(
|
||||
self,
|
||||
state_id=None,
|
||||
init_state: Callable[[T], None] = _do_nothing,
|
||||
validate_state: Callable[[T], None] = _do_nothing,
|
||||
) -> T:
|
||||
state = self.get_state(state_id) if state_id else None
|
||||
if state:
|
||||
validate_state(state)
|
||||
else:
|
||||
state = self.state_class(id=database.utils.id())
|
||||
init_state(state)
|
||||
|
||||
return state
|
||||
|
||||
@contextmanager
|
||||
def get_or_create_state(
|
||||
self,
|
||||
@@ -66,12 +81,9 @@ class RedisCacheManager(Generic[T]):
|
||||
:param validate_state: user callback to validate the state if retrieved from cache
|
||||
Should throw an exception if the state is not valid. If not passed then no validation is done
|
||||
"""
|
||||
state = self.get_state(state_id) if state_id else None
|
||||
if state:
|
||||
validate_state(state)
|
||||
else:
|
||||
state = self.state_class(id=database.utils.id())
|
||||
init_state(state)
|
||||
state = self.get_or_create_state_core(
|
||||
state_id=state_id, init_state=init_state, validate_state=validate_state
|
||||
)
|
||||
|
||||
try:
|
||||
yield state
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import itertools
|
||||
from typing import Sequence, Tuple
|
||||
from typing import Sequence, Tuple, Optional
|
||||
|
||||
import dpath
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.tools import safe_get
|
||||
from apiserver.utilities.dicts import nested_get, nested_delete, nested_set
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
|
||||
|
||||
@@ -14,7 +13,7 @@ hyperparams_legacy_type = "legacy"
|
||||
tf_define_section = "TF_DEFINE"
|
||||
|
||||
|
||||
def split_param_name(full_name: str, default_section: str) -> Tuple[str, str]:
|
||||
def split_param_name(full_name: str, default_section: str) -> Tuple[Optional[str], str]:
|
||||
"""
|
||||
Return parameter section and name. The section is either TF_DEFINE or the default one
|
||||
"""
|
||||
@@ -62,7 +61,7 @@ def _remove_legacy_params(data: dict, with_sections: bool = False) -> int:
|
||||
return removed
|
||||
|
||||
|
||||
def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[str]:
|
||||
def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[dict]:
|
||||
"""
|
||||
Remove the legacy params from the data dict and return the number of removed params
|
||||
If the path not found then return 0
|
||||
@@ -71,8 +70,10 @@ def _get_legacy_params(data: dict, with_sections: bool = False) -> Sequence[str]
|
||||
return []
|
||||
|
||||
if with_sections:
|
||||
return itertools.chain.from_iterable(
|
||||
_get_legacy_params(section_data) for section_data in data.values()
|
||||
return list(
|
||||
itertools.chain.from_iterable(
|
||||
_get_legacy_params(section_data) for section_data in data.values()
|
||||
)
|
||||
)
|
||||
|
||||
return [
|
||||
@@ -86,15 +87,15 @@ def params_prepare_for_save(fields: dict, previous_task: Task = None):
|
||||
Escape all the section and param names for hyper params and configuration to make it mongo sage
|
||||
"""
|
||||
for old_params_field, new_params_field, default_section in (
|
||||
("execution/parameters", "hyperparams", hyperparams_default_section),
|
||||
("execution/model_desc", "configuration", None),
|
||||
(("execution", "parameters"), "hyperparams", hyperparams_default_section),
|
||||
(("execution", "model_desc"), "configuration", None),
|
||||
):
|
||||
legacy_params = safe_get(fields, old_params_field)
|
||||
legacy_params = nested_get(fields, old_params_field)
|
||||
if legacy_params is None:
|
||||
continue
|
||||
|
||||
if (
|
||||
not safe_get(fields, new_params_field)
|
||||
not fields.get(new_params_field)
|
||||
and previous_task
|
||||
and previous_task[new_params_field]
|
||||
):
|
||||
@@ -117,11 +118,11 @@ def params_prepare_for_save(fields: dict, previous_task: Task = None):
|
||||
new_param = dict(name=name, type=hyperparams_legacy_type, value=str(value))
|
||||
if section is not None:
|
||||
new_param["section"] = section
|
||||
dpath.new(fields, new_path, new_param)
|
||||
dpath.delete(fields, old_params_field)
|
||||
nested_set(fields, new_path, new_param)
|
||||
nested_delete(fields, old_params_field)
|
||||
|
||||
for param_field in ("hyperparams", "configuration"):
|
||||
params = safe_get(fields, param_field)
|
||||
params = fields.get(param_field)
|
||||
if params:
|
||||
escaped_params = {
|
||||
ParameterKeyEscaper.escape(key): {
|
||||
@@ -131,7 +132,7 @@ def params_prepare_for_save(fields: dict, previous_task: Task = None):
|
||||
else value
|
||||
for key, value in params.items()
|
||||
}
|
||||
dpath.set(fields, param_field, escaped_params)
|
||||
fields[param_field] = escaped_params
|
||||
|
||||
|
||||
def params_unprepare_from_saved(fields, copy_to_legacy=False):
|
||||
@@ -140,7 +141,7 @@ def params_unprepare_from_saved(fields, copy_to_legacy=False):
|
||||
If copy_to_legacy is set then copy hyperparams and configuration data to the legacy location for the old clients
|
||||
"""
|
||||
for param_field in ("hyperparams", "configuration"):
|
||||
params = safe_get(fields, param_field)
|
||||
params = fields.get(param_field)
|
||||
if params:
|
||||
unescaped_params = {
|
||||
ParameterKeyEscaper.unescape(key): {
|
||||
@@ -150,18 +151,18 @@ def params_unprepare_from_saved(fields, copy_to_legacy=False):
|
||||
else value
|
||||
for key, value in params.items()
|
||||
}
|
||||
dpath.set(fields, param_field, unescaped_params)
|
||||
fields[param_field] = unescaped_params
|
||||
|
||||
if copy_to_legacy:
|
||||
for new_params_field, old_params_field, use_sections in (
|
||||
(f"hyperparams", "execution/parameters", True),
|
||||
(f"configuration", "execution/model_desc", False),
|
||||
("hyperparams", ("execution", "parameters"), True),
|
||||
("configuration", ("execution", "model_desc"), False),
|
||||
):
|
||||
legacy_params = _get_legacy_params(
|
||||
safe_get(fields, new_params_field), with_sections=use_sections
|
||||
fields.get(new_params_field), with_sections=use_sections
|
||||
)
|
||||
if legacy_params:
|
||||
dpath.new(
|
||||
nested_set(
|
||||
fields,
|
||||
old_params_field,
|
||||
{_get_full_param_name(p): p["value"] for p in legacy_params},
|
||||
@@ -174,7 +175,7 @@ def _process_path(path: str):
|
||||
Need to unescape and apply a full mongo escaping
|
||||
"""
|
||||
parts = path.split(".")
|
||||
if len(parts) < 2 or len(parts) > 3:
|
||||
if len(parts) < 2 or len(parts) > 4:
|
||||
raise errors.bad_request.ValidationError("invalid task field", path=path)
|
||||
return ".".join(
|
||||
ParameterKeyEscaper.escape(ParameterKeyEscaper.unescape(p)) for p in parts
|
||||
@@ -184,7 +185,7 @@ def _process_path(path: str):
|
||||
def escape_paths(paths: Sequence[str]) -> Sequence[str]:
|
||||
for old_prefix, new_prefix in (
|
||||
("execution.parameters", f"hyperparams.{hyperparams_default_section}"),
|
||||
("execution.model_desc", f"configuration"),
|
||||
("execution.model_desc", "configuration"),
|
||||
("execution.docker_cmd", "container")
|
||||
):
|
||||
path: str
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
import json
|
||||
from collections import OrderedDict
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
from typing import Collection, Sequence, Tuple, Any, Optional, Dict
|
||||
|
||||
import dpath
|
||||
import six
|
||||
from mongoengine import Q
|
||||
from redis import StrictRedis
|
||||
@@ -14,7 +11,7 @@ from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.tasks import TaskInputModel
|
||||
from apiserver.bll.queue import QueueBLL
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL, project_ids_with_children
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.model import Model
|
||||
@@ -39,7 +36,6 @@ from apiserver.redis_manager import redman
|
||||
from apiserver.service_repo import APICall
|
||||
from apiserver.services.utils import validate_tags, escape_dict_field, escape_dict
|
||||
from apiserver.timing_context import TimingContext
|
||||
from apiserver.utilities.parameter_key_escaper import ParameterKeyEscaper
|
||||
from .artifacts import artifacts_prepare_for_save
|
||||
from .param_utils import params_prepare_for_save
|
||||
from .utils import (
|
||||
@@ -350,54 +346,6 @@ class TaskBLL:
|
||||
if validate_models:
|
||||
cls.validate_input_models(task)
|
||||
|
||||
@staticmethod
|
||||
def get_unique_metric_variants(
|
||||
company_id, project_ids: Sequence[str], include_subprojects: bool
|
||||
):
|
||||
if project_ids:
|
||||
if include_subprojects:
|
||||
project_ids = project_ids_with_children(project_ids)
|
||||
project_constraint = {"project": {"$in": project_ids}}
|
||||
else:
|
||||
project_constraint = {}
|
||||
pipeline = [
|
||||
{
|
||||
"$match": dict(
|
||||
company={"$in": [None, "", company_id]}, **project_constraint,
|
||||
)
|
||||
},
|
||||
{"$project": {"metrics": {"$objectToArray": "$last_metrics"}}},
|
||||
{"$unwind": "$metrics"},
|
||||
{
|
||||
"$project": {
|
||||
"metric": "$metrics.k",
|
||||
"variants": {"$objectToArray": "$metrics.v"},
|
||||
}
|
||||
},
|
||||
{"$unwind": "$variants"},
|
||||
{
|
||||
"$group": {
|
||||
"_id": {
|
||||
"metric": "$variants.v.metric",
|
||||
"variant": "$variants.v.variant",
|
||||
},
|
||||
"metrics": {
|
||||
"$addToSet": {
|
||||
"metric": "$variants.v.metric",
|
||||
"metric_hash": "$metric",
|
||||
"variant": "$variants.v.variant",
|
||||
"variant_hash": "$variants.k",
|
||||
}
|
||||
},
|
||||
}
|
||||
},
|
||||
{"$sort": OrderedDict({"_id.metric": 1, "_id.variant": 1})},
|
||||
]
|
||||
|
||||
with translate_errors_context():
|
||||
result = Task.aggregate(pipeline)
|
||||
return [r["metrics"][0] for r in result]
|
||||
|
||||
@staticmethod
|
||||
def set_last_update(
|
||||
task_ids: Collection[str],
|
||||
@@ -494,173 +442,6 @@ class TaskBLL:
|
||||
**extra_updates,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_aggregated_project_parameters(
|
||||
company_id,
|
||||
project_ids: Sequence[str],
|
||||
include_subprojects: bool,
|
||||
page: int = 0,
|
||||
page_size: int = 500,
|
||||
) -> Tuple[int, int, Sequence[dict]]:
|
||||
if project_ids:
|
||||
if include_subprojects:
|
||||
project_ids = project_ids_with_children(project_ids)
|
||||
project_constraint = {"project": {"$in": project_ids}}
|
||||
else:
|
||||
project_constraint = {}
|
||||
page = max(0, page)
|
||||
page_size = max(1, page_size)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
"company": {"$in": [None, "", company_id]},
|
||||
"hyperparams": {"$exists": True, "$gt": {}},
|
||||
**project_constraint,
|
||||
}
|
||||
},
|
||||
{"$project": {"sections": {"$objectToArray": "$hyperparams"}}},
|
||||
{"$unwind": "$sections"},
|
||||
{
|
||||
"$project": {
|
||||
"section": "$sections.k",
|
||||
"names": {"$objectToArray": "$sections.v"},
|
||||
}
|
||||
},
|
||||
{"$unwind": "$names"},
|
||||
{"$group": {"_id": {"section": "$section", "name": "$names.k"}}},
|
||||
{"$sort": OrderedDict({"_id.section": 1, "_id.name": 1})},
|
||||
{"$skip": page * page_size},
|
||||
{"$limit": page_size},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Task.aggregate(pipeline), None)
|
||||
|
||||
total = 0
|
||||
remaining = 0
|
||||
results = []
|
||||
|
||||
if result:
|
||||
total = int(result.get("total", -1))
|
||||
results = [
|
||||
{
|
||||
"section": ParameterKeyEscaper.unescape(
|
||||
dpath.get(r, "_id/section")
|
||||
),
|
||||
"name": ParameterKeyEscaper.unescape(dpath.get(r, "_id/name")),
|
||||
}
|
||||
for r in result.get("results", [])
|
||||
]
|
||||
remaining = max(0, total - (len(results) + page * page_size))
|
||||
|
||||
return total, remaining, results
|
||||
|
||||
HyperParamValues = Tuple[int, Sequence[str]]
|
||||
|
||||
def _get_cached_hyperparam_values(
|
||||
self, key: str, last_update: datetime
|
||||
) -> Optional[HyperParamValues]:
|
||||
allowed_delta = timedelta(
|
||||
seconds=config.get(
|
||||
"services.tasks.hyperparam_values.cache_allowed_outdate_sec", 60
|
||||
)
|
||||
)
|
||||
try:
|
||||
cached = self.redis.get(key)
|
||||
if not cached:
|
||||
return
|
||||
|
||||
data = json.loads(cached)
|
||||
cached_last_update = datetime.fromtimestamp(data["last_update"])
|
||||
if (last_update - cached_last_update) < allowed_delta:
|
||||
return data["total"], data["values"]
|
||||
except Exception as ex:
|
||||
log.error(f"Error retrieving hyperparam cached values: {str(ex)}")
|
||||
|
||||
def get_hyperparam_distinct_values(
|
||||
self,
|
||||
company_id: str,
|
||||
project_ids: Sequence[str],
|
||||
section: str,
|
||||
name: str,
|
||||
include_subprojects: bool,
|
||||
allow_public: bool = True,
|
||||
) -> HyperParamValues:
|
||||
if allow_public:
|
||||
company_constraint = {"company": {"$in": [None, "", company_id]}}
|
||||
else:
|
||||
company_constraint = {"company": company_id}
|
||||
if project_ids:
|
||||
if include_subprojects:
|
||||
project_ids = project_ids_with_children(project_ids)
|
||||
project_constraint = {"project": {"$in": project_ids}}
|
||||
else:
|
||||
project_constraint = {}
|
||||
|
||||
key_path = f"hyperparams.{ParameterKeyEscaper.escape(section)}.{ParameterKeyEscaper.escape(name)}"
|
||||
last_updated_task = (
|
||||
Task.objects(
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
**{f"{key_path.replace('.', '__')}__exists": True},
|
||||
)
|
||||
.only("last_update")
|
||||
.order_by("-last_update")
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
if not last_updated_task:
|
||||
return 0, []
|
||||
|
||||
redis_key = f"hyperparam_values_{company_id}_{'_'.join(project_ids)}_{section}_{name}_{allow_public}"
|
||||
last_update = last_updated_task.last_update or datetime.utcnow()
|
||||
cached_res = self._get_cached_hyperparam_values(
|
||||
key=redis_key, last_update=last_update
|
||||
)
|
||||
if cached_res:
|
||||
return cached_res
|
||||
|
||||
max_values = config.get("services.tasks.hyperparam_values.max_count", 100)
|
||||
pipeline = [
|
||||
{
|
||||
"$match": {
|
||||
**company_constraint,
|
||||
**project_constraint,
|
||||
key_path: {"$exists": True},
|
||||
}
|
||||
},
|
||||
{"$project": {"value": f"${key_path}.value"}},
|
||||
{"$group": {"_id": "$value"}},
|
||||
{"$sort": {"_id": 1}},
|
||||
{"$limit": max_values},
|
||||
{
|
||||
"$group": {
|
||||
"_id": 1,
|
||||
"total": {"$sum": 1},
|
||||
"results": {"$push": "$$ROOT._id"},
|
||||
}
|
||||
},
|
||||
]
|
||||
|
||||
result = next(Task.aggregate(pipeline, collation=Task._numeric_locale), None)
|
||||
if not result:
|
||||
return 0, []
|
||||
|
||||
total = int(result.get("total", 0))
|
||||
values = result.get("results", [])
|
||||
|
||||
ttl = config.get("services.tasks.hyperparam_values.cache_ttl_sec", 86400)
|
||||
cached = dict(last_update=last_update.timestamp(), total=total, values=values)
|
||||
self.redis.setex(redis_key, ttl, json.dumps(cached))
|
||||
|
||||
return total, values
|
||||
|
||||
@classmethod
|
||||
def dequeue_and_change_status(
|
||||
cls, task: Task, company_id: str, status_message: str, status_reason: str,
|
||||
|
||||
@@ -130,14 +130,14 @@ def collect_debug_image_urls(company: str, task: str) -> Set[str]:
|
||||
if not metrics:
|
||||
return set()
|
||||
|
||||
task_metrics = {task: set(metrics)}
|
||||
task_metrics = {task: {m: [] for m in metrics}}
|
||||
scroll_id = None
|
||||
urls = set()
|
||||
while True:
|
||||
res = event_bll.debug_images_iterator.get_task_events(
|
||||
company_id=company,
|
||||
task_metrics=task_metrics,
|
||||
iter_count=100,
|
||||
iter_count=10,
|
||||
state_id=scroll_id,
|
||||
)
|
||||
if not res.metric_events or not any(
|
||||
|
||||
@@ -109,6 +109,7 @@ def enqueue_task(
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
validate: bool = False,
|
||||
force: bool = False,
|
||||
) -> Tuple[int, dict]:
|
||||
if not queue_id:
|
||||
# try to get default queue
|
||||
@@ -128,6 +129,7 @@ def enqueue_task(
|
||||
status_reason=status_reason,
|
||||
status_message=status_message,
|
||||
allow_same_state_transition=False,
|
||||
force=force,
|
||||
).execute(enqueue_status=task.status)
|
||||
|
||||
try:
|
||||
@@ -160,6 +162,8 @@ def delete_task(
|
||||
force: bool,
|
||||
return_file_urls: bool,
|
||||
delete_output_models: bool,
|
||||
status_message: str,
|
||||
status_reason: str,
|
||||
) -> Tuple[int, Task, CleanupResult]:
|
||||
task = TaskBLL.get_task_with_access(
|
||||
task_id, company_id=company_id, requires_write_access=True
|
||||
@@ -177,6 +181,17 @@ def delete_task(
|
||||
current=task.status,
|
||||
)
|
||||
|
||||
try:
|
||||
TaskBLL.dequeue_and_change_status(
|
||||
task,
|
||||
company_id=company_id,
|
||||
status_message=status_message,
|
||||
status_reason=status_reason,
|
||||
)
|
||||
except APIError:
|
||||
# dequeue may fail if the task was not enqueued
|
||||
pass
|
||||
|
||||
cleanup_res = cleanup_task(
|
||||
task,
|
||||
force=force,
|
||||
@@ -352,6 +367,7 @@ def stop_task(
|
||||
"system_tags",
|
||||
"last_worker",
|
||||
"last_update",
|
||||
"execution.queue",
|
||||
),
|
||||
requires_write_access=True,
|
||||
)
|
||||
@@ -365,7 +381,21 @@ def stop_task(
|
||||
and (datetime.utcnow() - t.last_update).total_seconds() < update_timeout
|
||||
)
|
||||
|
||||
if TaskSystemTags.development in task.system_tags or not is_run_by_worker(task):
|
||||
is_queued = task.status == TaskStatus.queued
|
||||
set_stopped = (
|
||||
is_queued
|
||||
or TaskSystemTags.development in task.system_tags
|
||||
or not is_run_by_worker(task)
|
||||
)
|
||||
|
||||
if set_stopped:
|
||||
if is_queued:
|
||||
try:
|
||||
TaskBLL.dequeue(task, company_id=company_id, silent_fail=True)
|
||||
except APIError:
|
||||
# dequeue may fail if the task was not enqueued
|
||||
pass
|
||||
|
||||
new_status = TaskStatus.stopped
|
||||
status_message = f"Stopped by {user_name}"
|
||||
else:
|
||||
|
||||
@@ -258,7 +258,7 @@ class WorkerBLL:
|
||||
tasks_info = {
|
||||
task.id: task
|
||||
for task in Task.objects(id__in=task_ids).only(
|
||||
"name", "started", "last_iteration"
|
||||
"name", "started", "last_iteration", "active_duration"
|
||||
)
|
||||
}
|
||||
|
||||
@@ -283,11 +283,7 @@ class WorkerBLL:
|
||||
if helper.task_id:
|
||||
task = tasks_info.get(helper.task_id, None)
|
||||
if task:
|
||||
worker.task.running_time = (
|
||||
int((datetime.utcnow() - task.started).total_seconds() * 1000)
|
||||
if task.started
|
||||
else 0
|
||||
)
|
||||
worker.task.running_time = (task.active_duration or 0) * 1000
|
||||
worker.task.last_iteration = task.last_iteration
|
||||
|
||||
update_queue_entries(worker.queue)
|
||||
|
||||
@@ -79,6 +79,8 @@ class BasicConfig:
|
||||
def logger(self, name: str) -> logging.Logger:
|
||||
if Path(name).is_file():
|
||||
name = Path(name).stem
|
||||
if name == "__init__" and Path(name).parent.stem:
|
||||
name = Path(name).parent.stem
|
||||
path = ".".join((self.prefix, name))
|
||||
return logging.getLogger(path)
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
debug: false # Debug mode
|
||||
pretty_json: false # prettify json response
|
||||
return_stack: true # return stack trace on error
|
||||
log_calls: true # Log API Calls
|
||||
return_stack_to_caller: true # top-level control on whether to return stack trace in an API response
|
||||
|
||||
# if 'return_stack' is true and error contains a status code, return stack trace only for these status codes
|
||||
# valid values are:
|
||||
@@ -79,6 +79,11 @@
|
||||
max_age: 99999999999
|
||||
}
|
||||
|
||||
# provide a cookie domain override per company
|
||||
# cookies_domain_override {
|
||||
# <company-id>: <domain>
|
||||
# }
|
||||
|
||||
# # A list of fixed users
|
||||
# # Note: password may be bcrypt-hashed (generate using `python -c 'import bcrypt; print(bcrypt.hashpw("password", bcrypt.gensalt()))'`)
|
||||
# fixed_users {
|
||||
|
||||
4
apiserver/config/default/services/_mongo.conf
Normal file
4
apiserver/config/default/services/_mongo.conf
Normal file
@@ -0,0 +1,4 @@
|
||||
max_page_size: 500
|
||||
|
||||
# expiration time in seconds for the redis scroll states in get_many family of apis
|
||||
scroll_state_expiration_seconds: 600
|
||||
@@ -17,6 +17,10 @@ events_retrieval {
|
||||
|
||||
# the max amount of variants to aggregate on
|
||||
max_variants_count: 100
|
||||
|
||||
max_raw_scalars_size: 200000
|
||||
|
||||
scroll_id_key: "cTN5VEtWEC6QrHvUl0FTx9kNyO0CcCK1p57akxma"
|
||||
}
|
||||
|
||||
# if set then plot str will be checked for the valid json on plot add
|
||||
|
||||
@@ -28,6 +28,8 @@ OVERRIDE_PORT_ENV_KEY = (
|
||||
"MONGODB_SERVICE_PORT",
|
||||
)
|
||||
|
||||
OVERRIDE_CONNECTION_STRING_ENV_KEY = "CLEARML_MONGODB_SERVICE_CONNECTION_STRING"
|
||||
|
||||
|
||||
class DatabaseEntry(models.Base):
|
||||
host = StringField(required=True)
|
||||
@@ -47,13 +49,17 @@ class DatabaseFactory:
|
||||
missing = []
|
||||
log.info("Initializing database connections")
|
||||
|
||||
override_connection_string = getenv(OVERRIDE_CONNECTION_STRING_ENV_KEY)
|
||||
override_hostname = first(map(getenv, OVERRIDE_HOST_ENV_KEY), None)
|
||||
if override_hostname:
|
||||
log.info(f"Using override mongodb host {override_hostname}")
|
||||
|
||||
override_port = first(map(getenv, OVERRIDE_PORT_ENV_KEY), None)
|
||||
if override_port:
|
||||
log.info(f"Using override mongodb port {override_port}")
|
||||
|
||||
if override_connection_string:
|
||||
log.info(f"Using override mongodb connection string {override_connection_string}")
|
||||
else:
|
||||
if override_hostname:
|
||||
log.info(f"Using override mongodb host {override_hostname}")
|
||||
if override_port:
|
||||
log.info(f"Using override mongodb port {override_port}")
|
||||
|
||||
for key, alias in get_items(Database).items():
|
||||
if key not in db_entries:
|
||||
@@ -62,11 +68,13 @@ class DatabaseFactory:
|
||||
|
||||
entry = cls._create_db_entry(alias=alias, settings=db_entries.get(key))
|
||||
|
||||
if override_hostname:
|
||||
entry.host = furl(entry.host).set(host=override_hostname).url
|
||||
|
||||
if override_port:
|
||||
entry.host = furl(entry.host).set(port=override_port).url
|
||||
if override_connection_string:
|
||||
entry.host = override_connection_string
|
||||
else:
|
||||
if override_hostname:
|
||||
entry.host = furl(entry.host).set(host=override_hostname).url
|
||||
if override_port:
|
||||
entry.host = furl(entry.host).set(port=override_port).url
|
||||
|
||||
try:
|
||||
entry.validate()
|
||||
|
||||
@@ -48,6 +48,7 @@ class Credentials(EmbeddedDocument):
|
||||
meta = {"strict": False}
|
||||
key = StringField(required=True)
|
||||
secret = StringField(required=True)
|
||||
label = StringField()
|
||||
last_used = DateTimeField()
|
||||
|
||||
|
||||
|
||||
@@ -1,26 +1,41 @@
|
||||
import re
|
||||
from collections import namedtuple
|
||||
from functools import reduce
|
||||
from typing import Collection, Sequence, Union, Optional, Type, Tuple, Mapping, Any
|
||||
from functools import reduce, partial
|
||||
from typing import (
|
||||
Collection,
|
||||
Sequence,
|
||||
Union,
|
||||
Optional,
|
||||
Type,
|
||||
Tuple,
|
||||
Mapping,
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
)
|
||||
|
||||
from boltons.iterutils import first, bucketize, partition
|
||||
from boltons.iterutils import first, partition
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
from mongoengine import Q, Document, ListField, StringField
|
||||
from mongoengine import Q, Document, ListField, StringField, IntField
|
||||
from pymongo.command_cursor import CommandCursor
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apierrors.base import BaseError
|
||||
from apiserver.bll.redis_cache_manager import RedisCacheManager
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.database import Database
|
||||
from apiserver.database.errors import MakeGetAllQueryError
|
||||
from apiserver.database.projection import project_dict, ProjectionHelper
|
||||
from apiserver.database.props import PropsMixin
|
||||
from apiserver.database.query import RegexQ, RegexWrapper
|
||||
from apiserver.database.query import RegexQ, RegexWrapper, RegexQCombination
|
||||
from apiserver.database.utils import (
|
||||
get_company_or_none_constraint,
|
||||
get_fields_choices,
|
||||
field_does_not_exist,
|
||||
field_exists,
|
||||
)
|
||||
from apiserver.redis_manager import redman
|
||||
|
||||
log = config.logger("dbmodel")
|
||||
|
||||
@@ -70,6 +85,9 @@ class GetMixin(PropsMixin):
|
||||
_ordering_key = "order_by"
|
||||
_search_text_key = "search_text"
|
||||
|
||||
_start_key = "start"
|
||||
_size_key = "size"
|
||||
|
||||
_multi_field_param_sep = "__"
|
||||
_multi_field_param_prefix = {
|
||||
("_any_", "_or_"): lambda a, b: a | b,
|
||||
@@ -103,45 +121,106 @@ class GetMixin(PropsMixin):
|
||||
|
||||
class ListFieldBucketHelper:
|
||||
op_prefix = "__$"
|
||||
legacy_exclude_prefix = "-"
|
||||
_legacy_exclude_prefix = "-"
|
||||
_legacy_exclude_mongo_op = "nin"
|
||||
|
||||
_default = "in"
|
||||
default_mongo_op = "in"
|
||||
_ops = {
|
||||
# op -> (mongo_op, sticky)
|
||||
"not": ("nin", False),
|
||||
"nop": (default_mongo_op, False),
|
||||
"all": ("all", True),
|
||||
"and": ("all", True),
|
||||
"any": (default_mongo_op, True),
|
||||
"or": (default_mongo_op, True),
|
||||
}
|
||||
_next = _default
|
||||
_sticky = False
|
||||
|
||||
def __init__(self, legacy=False):
|
||||
self._legacy = legacy
|
||||
self._current_op = None
|
||||
self._sticky = False
|
||||
self._support_legacy = legacy
|
||||
self.allow_empty = False
|
||||
|
||||
def key(self, v):
|
||||
def _get_op(self, v: str, translate: bool = False) -> Optional[str]:
|
||||
op = (
|
||||
v[len(self.op_prefix) :] if v and v.startswith(self.op_prefix) else None
|
||||
)
|
||||
if translate:
|
||||
tup = self._ops.get(op, None)
|
||||
return tup[0] if tup else None
|
||||
return op
|
||||
|
||||
def _key(self, v) -> Optional[Union[str, bool]]:
|
||||
if v is None:
|
||||
self._next = self._default
|
||||
return self._default
|
||||
elif self._legacy and v.startswith(self.legacy_exclude_prefix):
|
||||
self._next = self._default
|
||||
return self._ops["not"][0]
|
||||
elif v.startswith(self.op_prefix):
|
||||
self._next, self._sticky = self._ops.get(
|
||||
v[len(self.op_prefix) :], (self._default, self._sticky)
|
||||
)
|
||||
self.allow_empty = True
|
||||
return None
|
||||
|
||||
next_ = self._next
|
||||
if not self._sticky:
|
||||
self._next = self._default
|
||||
return next_
|
||||
op = self._get_op(v)
|
||||
if op is not None:
|
||||
# operator - set state and return None
|
||||
self._current_op, self._sticky = self._ops.get(
|
||||
op, (self.default_mongo_op, self._sticky)
|
||||
)
|
||||
return None
|
||||
elif self._current_op:
|
||||
current_op = self._current_op
|
||||
if not self._sticky:
|
||||
self._current_op = None
|
||||
return current_op
|
||||
elif self._support_legacy and v.startswith(self._legacy_exclude_prefix):
|
||||
self._current_op = None
|
||||
return False
|
||||
|
||||
def value_transform(self, v):
|
||||
if self._legacy and v and v.startswith(self.legacy_exclude_prefix):
|
||||
return v[len(self.legacy_exclude_prefix) :]
|
||||
return v
|
||||
return self.default_mongo_op
|
||||
|
||||
def get_global_op(self, data: Sequence[str]) -> int:
|
||||
op_to_res = {
|
||||
"in": Q.OR,
|
||||
"all": Q.AND,
|
||||
}
|
||||
data = (x for x in data if x is not None)
|
||||
first_op = (
|
||||
self._get_op(next(data, ""), translate=True) or self.default_mongo_op
|
||||
)
|
||||
return op_to_res.get(first_op, self.default_mongo_op)
|
||||
|
||||
def get_actions(self, data: Sequence[str]) -> Dict[str, List[Union[str, None]]]:
|
||||
actions = {}
|
||||
|
||||
for val in data:
|
||||
key = self._key(val)
|
||||
if key is None:
|
||||
continue
|
||||
elif self._support_legacy and key is False:
|
||||
key = self._legacy_exclude_mongo_op
|
||||
val = val[len(self._legacy_exclude_prefix) :]
|
||||
actions.setdefault(key, []).append(val)
|
||||
|
||||
return actions
|
||||
|
||||
get_all_query_options = QueryParameterOptions()
|
||||
|
||||
class GetManyScrollState(ProperDictMixin, Document):
|
||||
meta = {"db_alias": Database.backend, "strict": False}
|
||||
|
||||
id = StringField(primary_key=True)
|
||||
position = IntField(default=0)
|
||||
|
||||
_cache_manager = None
|
||||
|
||||
@classmethod
|
||||
def get_cache_manager(cls):
|
||||
if not cls._cache_manager:
|
||||
cls._cache_manager = RedisCacheManager(
|
||||
state_class=cls.GetManyScrollState,
|
||||
redis=redman.connection("apiserver"),
|
||||
expiration_interval=config.get(
|
||||
"services._mongo.scroll_state_expiration_seconds", 600
|
||||
),
|
||||
)
|
||||
|
||||
return cls._cache_manager
|
||||
|
||||
@classmethod
|
||||
def get(
|
||||
cls: Union["GetMixin", Document],
|
||||
@@ -240,7 +319,9 @@ class GetMixin(PropsMixin):
|
||||
Prepare a query object based on the provided query dictionary and various fields.
|
||||
|
||||
NOTE: BE VERY CAREFUL WITH THIS CALL, as it allows creating queries that span across companies.
|
||||
|
||||
IMPLEMENTATION NOTE: Make sure that inside this function or the functions it depends on RegexQ is always
|
||||
used instead of Q. Otherwise we can and up with some combination that is not processed according to
|
||||
RegexQ rules
|
||||
:param parameters_options: Specifies options for parsing the parameters (see ParametersOptions)
|
||||
:param parameters: Query dictionary (relevant keys are these specified by the various field names parameters).
|
||||
Supported parameters:
|
||||
@@ -273,10 +354,13 @@ class GetMixin(PropsMixin):
|
||||
).items():
|
||||
query &= cls.get_range_field_query(field, data)
|
||||
|
||||
for field in opts.fields or []:
|
||||
data = parameters.pop(field, None)
|
||||
if data is not None:
|
||||
dict_query[field] = data
|
||||
for field, data in cls._pop_matching_params(
|
||||
patterns=opts.fields or [], parameters=parameters
|
||||
).items():
|
||||
if "._" in field or "_." in field:
|
||||
query &= RegexQ(__raw__={field: data})
|
||||
else:
|
||||
dict_query[field.replace(".", "__")] = data
|
||||
|
||||
for field in opts.datetime_fields or []:
|
||||
data = parameters.pop(field, None)
|
||||
@@ -308,22 +392,31 @@ class GetMixin(PropsMixin):
|
||||
break
|
||||
if any("._" in f for f in data.fields):
|
||||
q = reduce(
|
||||
lambda a, x: func(a, Q(__raw__={x: {"$regex": data.pattern, "$options": "i"}})),
|
||||
lambda a, x: func(
|
||||
a,
|
||||
RegexQ(
|
||||
__raw__={
|
||||
x: {"$regex": data.pattern, "$options": "i"}
|
||||
}
|
||||
),
|
||||
),
|
||||
data.fields,
|
||||
Q()
|
||||
RegexQ(),
|
||||
)
|
||||
else:
|
||||
regex = RegexWrapper(data.pattern, flags=re.IGNORECASE)
|
||||
sep_fields = [f.replace(".", "__") for f in data.fields]
|
||||
q = reduce(
|
||||
lambda a, x: func(a, RegexQ(**{x: regex})), sep_fields, RegexQ()
|
||||
lambda a, x: func(a, RegexQ(**{x: regex})),
|
||||
sep_fields,
|
||||
RegexQ(),
|
||||
)
|
||||
query = query & q
|
||||
|
||||
return query & RegexQ(**dict_query)
|
||||
|
||||
@classmethod
|
||||
def get_range_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
|
||||
def get_range_field_query(cls, field: str, data: Sequence[Optional[str]]) -> RegexQ:
|
||||
"""
|
||||
Return a range query for the provided field. The data should contain min and max values
|
||||
Both intervals are included. For open range queries either min or max can be None
|
||||
@@ -347,14 +440,14 @@ class GetMixin(PropsMixin):
|
||||
if max_val is not None:
|
||||
query[f"{mongoengine_field}__lte"] = max_val
|
||||
|
||||
q = Q(**query)
|
||||
q = RegexQ(**query)
|
||||
if min_val is None:
|
||||
q |= Q(**{mongoengine_field: None})
|
||||
q |= RegexQ(**{mongoengine_field: None})
|
||||
|
||||
return q
|
||||
|
||||
@classmethod
|
||||
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> Q:
|
||||
def get_list_field_query(cls, field: str, data: Sequence[Optional[str]]) -> RegexQ:
|
||||
"""
|
||||
Get a proper mongoengine Q object that represents an "or" query for the provided values
|
||||
with respect to the given list field, with support for "none of empty" in case a None value
|
||||
@@ -366,30 +459,31 @@ class GetMixin(PropsMixin):
|
||||
"""
|
||||
if not isinstance(data, (list, tuple)):
|
||||
data = [data]
|
||||
# raise MakeGetAllQueryError("expected list", field)
|
||||
|
||||
# TODO: backwards compatibility only for older API versions
|
||||
helper = cls.ListFieldBucketHelper(legacy=True)
|
||||
actions = bucketize(
|
||||
data, key=helper.key, value_transform=helper.value_transform
|
||||
)
|
||||
global_op = helper.get_global_op(data)
|
||||
actions = helper.get_actions(data)
|
||||
|
||||
allow_empty = None in actions.get("in", {})
|
||||
mongoengine_field = field.replace(".", "__")
|
||||
|
||||
q = RegexQ()
|
||||
for action in filter(None, actions):
|
||||
q &= RegexQ(
|
||||
**{f"{mongoengine_field}__{action}": list(set(actions[action]))}
|
||||
)
|
||||
queries = [
|
||||
RegexQ(**{f"{mongoengine_field}__{action}": list(set(actions[action]))})
|
||||
for action in filter(None, actions)
|
||||
]
|
||||
|
||||
if not allow_empty:
|
||||
if not queries:
|
||||
q = RegexQ()
|
||||
else:
|
||||
q = RegexQCombination(operation=global_op, children=queries)
|
||||
|
||||
if not helper.allow_empty:
|
||||
return q
|
||||
|
||||
return (
|
||||
q
|
||||
| Q(**{f"{mongoengine_field}__exists": False})
|
||||
| Q(**{mongoengine_field: []})
|
||||
| RegexQ(**{f"{mongoengine_field}__exists": False})
|
||||
| RegexQ(**{mongoengine_field: []})
|
||||
| RegexQ(**{mongoengine_field: None})
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -417,27 +511,41 @@ class GetMixin(PropsMixin):
|
||||
return order_by
|
||||
|
||||
@classmethod
|
||||
def validate_paging(
|
||||
cls, parameters=None, default_page=None, default_page_size=None
|
||||
):
|
||||
""" Validate and extract paging info from from the provided dictionary. Supports default values. """
|
||||
if parameters is None:
|
||||
parameters = {}
|
||||
default_page = parameters.get("page", default_page)
|
||||
if default_page is None:
|
||||
return None, None
|
||||
default_page_size = parameters.get("page_size", default_page_size)
|
||||
if not default_page_size:
|
||||
raise errors.bad_request.MissingRequiredFields(
|
||||
"page_size is required when page is requested", field="page_size"
|
||||
)
|
||||
elif default_page < 0:
|
||||
def validate_paging(cls, parameters=None, default_page=0, default_page_size=None):
|
||||
"""
|
||||
Validate and extract paging info from from the provided dictionary. Supports default values.
|
||||
If page is specified then it should be non-negative, if page size is specified then it should be positive
|
||||
If page size is specified and page is not then 0 page is assumed
|
||||
If page is specified then page size should be specified too
|
||||
"""
|
||||
parameters = parameters or {}
|
||||
|
||||
start = parameters.get(cls._start_key)
|
||||
if start is not None:
|
||||
return start, cls.validate_scroll_size(parameters)
|
||||
|
||||
max_page_size = config.get("services._mongo.max_page_size", 500)
|
||||
page = parameters.get("page", default_page)
|
||||
if page is not None and page < 0:
|
||||
raise errors.bad_request.ValidationError("page must be >=0", field="page")
|
||||
elif default_page_size < 1:
|
||||
|
||||
page_size = parameters.get("page_size", default_page_size or max_page_size)
|
||||
if page_size is not None and page_size < 1:
|
||||
raise errors.bad_request.ValidationError(
|
||||
"page_size must be >0", field="page_size"
|
||||
)
|
||||
return default_page, default_page_size
|
||||
|
||||
if page_size is not None:
|
||||
page = page or 0
|
||||
page_size = min(page_size, max_page_size)
|
||||
return page * page_size, page_size
|
||||
|
||||
if page is not None:
|
||||
raise errors.bad_request.MissingRequiredFields(
|
||||
"page_size is required when page is requested", field="page_size"
|
||||
)
|
||||
|
||||
return None, None
|
||||
|
||||
@classmethod
|
||||
def get_projection(cls, parameters, override_projection=None, **__):
|
||||
@@ -481,6 +589,57 @@ class GetMixin(PropsMixin):
|
||||
def set_default_ordering(cls, parameters: dict, value: Sequence[str]) -> None:
|
||||
cls.set_ordering(parameters, cls.get_ordering(parameters) or value)
|
||||
|
||||
@classmethod
|
||||
def validate_scroll_size(cls, query_dict: dict) -> int:
|
||||
size = query_dict.get(cls._size_key)
|
||||
if not size or not isinstance(size, int) or size < 1:
|
||||
raise errors.bad_request.ValidationError(
|
||||
"Integer size parameter greater than 1 should be provided when working with scroll"
|
||||
)
|
||||
return size
|
||||
|
||||
@classmethod
|
||||
def get_data_with_scroll_and_filter_support(
|
||||
cls,
|
||||
query_dict: dict,
|
||||
data_getter: Callable[[], Sequence[dict]],
|
||||
ret_params: dict,
|
||||
) -> Sequence[dict]:
|
||||
"""
|
||||
Retrieves the data by calling the provided data_getter api
|
||||
If scroll parameters are specified then put the query_dict 'start' parameter to the last
|
||||
scroll position and continue retrievals from that position
|
||||
If refresh_scroll is requested then bring once more the data from the beginning
|
||||
till the current scroll position
|
||||
In the end the scroll position is updated and accumulated frames are returned
|
||||
"""
|
||||
query_dict = query_dict or {}
|
||||
state: Optional[cls.GetManyScrollState] = None
|
||||
if "scroll_id" in query_dict:
|
||||
size = cls.validate_scroll_size(query_dict)
|
||||
state = cls.get_cache_manager().get_or_create_state_core(
|
||||
query_dict.get("scroll_id")
|
||||
)
|
||||
if query_dict.get("refresh_scroll"):
|
||||
query_dict[cls._size_key] = max(state.position, size)
|
||||
state.position = 0
|
||||
query_dict[cls._start_key] = state.position
|
||||
|
||||
data = data_getter()
|
||||
if cls._start_key in query_dict:
|
||||
query_dict[cls._start_key] = query_dict[cls._start_key] + len(data)
|
||||
|
||||
def update_state(returned_len: int):
|
||||
if not state:
|
||||
return
|
||||
state.position = query_dict[cls._start_key]
|
||||
cls.get_cache_manager().set_state(state)
|
||||
if ret_params is not None:
|
||||
ret_params["scroll_id"] = state.id
|
||||
|
||||
update_state(len(data))
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def get_many_with_join(
|
||||
cls,
|
||||
@@ -491,6 +650,7 @@ class GetMixin(PropsMixin):
|
||||
allow_public=False,
|
||||
override_projection=None,
|
||||
expand_reference_ids=True,
|
||||
ret_params: dict = None,
|
||||
):
|
||||
"""
|
||||
Fetch all documents matching a provided query with support for joining referenced documents according to the
|
||||
@@ -526,6 +686,7 @@ class GetMixin(PropsMixin):
|
||||
query=query,
|
||||
query_options=query_options,
|
||||
allow_public=allow_public,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
|
||||
def projection_func(doc_type, projection, ids):
|
||||
@@ -556,6 +717,7 @@ class GetMixin(PropsMixin):
|
||||
allow_public=False,
|
||||
override_projection: Collection[str] = None,
|
||||
return_dicts=True,
|
||||
ret_params: dict = None,
|
||||
):
|
||||
"""
|
||||
Fetch all documents matching a provided query. Supported several built-in options
|
||||
@@ -601,12 +763,16 @@ class GetMixin(PropsMixin):
|
||||
_query = (q & query) if query else q
|
||||
|
||||
if return_dicts:
|
||||
return cls._get_many_override_none_ordering(
|
||||
data_getter = partial(
|
||||
cls._get_many_override_none_ordering,
|
||||
query=_query,
|
||||
parameters=parameters,
|
||||
override_projection=override_projection,
|
||||
override_collation=override_collation,
|
||||
)
|
||||
return cls.get_data_with_scroll_and_filter_support(
|
||||
query_dict=query_dict, data_getter=data_getter, ret_params=ret_params,
|
||||
)
|
||||
|
||||
return cls._get_many_no_company(
|
||||
query=_query,
|
||||
@@ -658,7 +824,7 @@ class GetMixin(PropsMixin):
|
||||
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
|
||||
if order_by and not override_collation:
|
||||
override_collation = cls._get_collation_override(order_by[0])
|
||||
page, page_size = cls.validate_paging(parameters=parameters)
|
||||
start, size = cls.validate_paging(parameters=parameters)
|
||||
include, exclude = cls.split_projection(
|
||||
cls.get_projection(parameters, override_projection)
|
||||
)
|
||||
@@ -679,9 +845,9 @@ class GetMixin(PropsMixin):
|
||||
if exclude:
|
||||
qs = qs.exclude(*exclude)
|
||||
|
||||
if page is not None and page_size:
|
||||
if start is not None and size:
|
||||
# add paging
|
||||
qs = qs.skip(page * page_size).limit(page_size)
|
||||
qs = qs.skip(start).limit(size)
|
||||
|
||||
return qs
|
||||
|
||||
@@ -742,7 +908,10 @@ class GetMixin(PropsMixin):
|
||||
parameters = parameters or {}
|
||||
search_text = parameters.get(cls._search_text_key)
|
||||
order_by = cls.validate_order_by(parameters=parameters, search_text=search_text)
|
||||
page, page_size = cls.validate_paging(parameters=parameters)
|
||||
start, size = cls.validate_paging(parameters=parameters)
|
||||
if size is not None and size <= 0:
|
||||
return []
|
||||
|
||||
include, exclude = cls.split_projection(
|
||||
cls.get_projection(parameters, override_projection)
|
||||
)
|
||||
@@ -774,25 +943,28 @@ class GetMixin(PropsMixin):
|
||||
if exclude:
|
||||
query_sets = [qs.exclude(*exclude) for qs in query_sets]
|
||||
|
||||
if page is None or not page_size:
|
||||
if start is None or not size:
|
||||
return [obj.to_proper_dict(only=include) for qs in query_sets for obj in qs]
|
||||
|
||||
# add paging
|
||||
ret = []
|
||||
start = page * page_size
|
||||
for qs in query_sets:
|
||||
qs_size = qs.count()
|
||||
if qs_size < start:
|
||||
start -= qs_size
|
||||
continue
|
||||
last_set = len(query_sets) - 1
|
||||
for i, qs in enumerate(query_sets):
|
||||
last_size = len(ret)
|
||||
ret.extend(
|
||||
obj.to_proper_dict(only=include)
|
||||
for obj in qs.skip(start).limit(page_size)
|
||||
for obj in (qs.skip(start) if start else qs).limit(size)
|
||||
)
|
||||
if len(ret) >= page_size:
|
||||
added = len(ret) - last_size
|
||||
|
||||
if added > 0:
|
||||
start = 0
|
||||
size = max(0, size - added)
|
||||
elif i != last_set:
|
||||
start -= min(start, qs.count())
|
||||
|
||||
if size <= 0:
|
||||
break
|
||||
start = 0
|
||||
page_size -= len(ret)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from typing import Sequence
|
||||
|
||||
from mongoengine import (
|
||||
Document,
|
||||
StringField,
|
||||
DateTimeField,
|
||||
BooleanField,
|
||||
@@ -14,17 +13,15 @@ from apiserver.database.fields import (
|
||||
SafeDictField,
|
||||
SafeSortedListField,
|
||||
)
|
||||
from apiserver.database.model import DbModelMixin
|
||||
from apiserver.database.model import AttributedDocument
|
||||
from apiserver.database.model.base import GetMixin
|
||||
from apiserver.database.model.metadata import MetadataItem
|
||||
from apiserver.database.model.model_labels import ModelLabels
|
||||
from apiserver.database.model.company import Company
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.model.task.task import Task
|
||||
from apiserver.database.model.user import User
|
||||
|
||||
|
||||
class Model(DbModelMixin, Document):
|
||||
class Model(AttributedDocument):
|
||||
meta = {
|
||||
"db_alias": Database.backend,
|
||||
"strict": strict,
|
||||
@@ -73,8 +70,6 @@ class Model(DbModelMixin, Document):
|
||||
id = StringField(primary_key=True)
|
||||
name = StrippedStringField(user_set_allowed=True, min_length=3)
|
||||
parent = StringField(reference_field="Model", required=False)
|
||||
user = StringField(required=True, reference_field=User)
|
||||
company = StringField(required=True, reference_field=Company)
|
||||
project = StringField(reference_field=Project, user_set_allowed=True)
|
||||
created = DateTimeField(required=True, user_set_allowed=True)
|
||||
task = StringField(reference_field=Task)
|
||||
|
||||
@@ -11,6 +11,7 @@ class Project(AttributedDocument):
|
||||
get_all_query_options = GetMixin.QueryParameterOptions(
|
||||
pattern_fields=("name", "description"),
|
||||
list_fields=("tags", "system_tags", "id", "parent", "path"),
|
||||
range_fields=("last_update",),
|
||||
)
|
||||
|
||||
meta = {
|
||||
|
||||
@@ -219,6 +219,7 @@ class Task(AttributedDocument):
|
||||
"status",
|
||||
"project",
|
||||
"parent",
|
||||
"hyperparams.*",
|
||||
),
|
||||
range_fields=("started", "active_duration", "last_metrics.*", "last_iteration"),
|
||||
datetime_fields=("status_changed", "last_update"),
|
||||
@@ -233,7 +234,7 @@ class Task(AttributedDocument):
|
||||
type = StringField(required=True, choices=get_options(TaskType))
|
||||
status = StringField(default=TaskStatus.created, choices=get_options(TaskStatus))
|
||||
status_reason = StringField()
|
||||
status_message = StringField()
|
||||
status_message = StringField(user_set_allowed=True)
|
||||
status_changed = DateTimeField()
|
||||
comment = StringField(user_set_allowed=True)
|
||||
created = DateTimeField(required=True, user_set_allowed=True)
|
||||
|
||||
@@ -4,7 +4,12 @@ from threading import Lock
|
||||
from typing import Sequence
|
||||
|
||||
import six
|
||||
from mongoengine import EmbeddedDocumentField, EmbeddedDocumentListField
|
||||
from mongoengine import (
|
||||
EmbeddedDocumentField,
|
||||
EmbeddedDocumentListField,
|
||||
EmbeddedDocument,
|
||||
Document,
|
||||
)
|
||||
from mongoengine.base import get_document
|
||||
|
||||
from apiserver.database.fields import (
|
||||
@@ -25,6 +30,13 @@ class PropsMixin(object):
|
||||
__cached_dpath_computed_fields_lock = Lock()
|
||||
__cached_dpath_computed_fields = None
|
||||
|
||||
_document_classes = {}
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
if issubclass(cls, (Document, EmbeddedDocument)):
|
||||
PropsMixin._document_classes[cls._class_name] = cls
|
||||
|
||||
@classmethod
|
||||
def get_fields(cls):
|
||||
if cls.__cached_fields is None:
|
||||
@@ -57,8 +69,14 @@ class PropsMixin(object):
|
||||
def resolve_doc(v):
|
||||
if not isinstance(v, six.string_types):
|
||||
return v
|
||||
if v == 'self':
|
||||
|
||||
if v == "self":
|
||||
return cls_.owner_document
|
||||
|
||||
doc_cls = PropsMixin._document_classes.get(v)
|
||||
if doc_cls:
|
||||
return doc_cls
|
||||
|
||||
return get_document(v)
|
||||
|
||||
fields = {k: resolve_doc(v) for k, v in res.items()}
|
||||
@@ -72,7 +90,7 @@ class PropsMixin(object):
|
||||
).document_type
|
||||
fields.update(
|
||||
{
|
||||
'.'.join((field, subfield)): doc
|
||||
".".join((field, subfield)): doc
|
||||
for subfield, doc in PropsMixin._get_fields_with_attr(
|
||||
embedded_doc_cls, attr
|
||||
).items()
|
||||
@@ -80,10 +98,10 @@ class PropsMixin(object):
|
||||
)
|
||||
|
||||
collect_embedded_docs(EmbeddedDocumentField, lambda x: x)
|
||||
collect_embedded_docs(EmbeddedDocumentListField, attrgetter('field'))
|
||||
collect_embedded_docs(LengthRangeEmbeddedDocumentListField, attrgetter('field'))
|
||||
collect_embedded_docs(UniqueEmbeddedDocumentListField, attrgetter('field'))
|
||||
collect_embedded_docs(EmbeddedDocumentSortedListField, attrgetter('field'))
|
||||
collect_embedded_docs(EmbeddedDocumentListField, attrgetter("field"))
|
||||
collect_embedded_docs(LengthRangeEmbeddedDocumentListField, attrgetter("field"))
|
||||
collect_embedded_docs(UniqueEmbeddedDocumentListField, attrgetter("field"))
|
||||
collect_embedded_docs(EmbeddedDocumentSortedListField, attrgetter("field"))
|
||||
|
||||
return fields
|
||||
|
||||
@@ -94,7 +112,7 @@ class PropsMixin(object):
|
||||
for depth, part in enumerate(parts):
|
||||
if current_cls is None:
|
||||
raise ValueError(
|
||||
'Invalid path (non-document encountered at %s)' % parts[: depth - 1]
|
||||
"Invalid path (non-document encountered at %s)" % parts[: depth - 1]
|
||||
)
|
||||
try:
|
||||
field_name, field = next(
|
||||
@@ -103,7 +121,7 @@ class PropsMixin(object):
|
||||
if k == part
|
||||
)
|
||||
except StopIteration:
|
||||
raise ValueError('Invalid field path %s' % parts[:depth])
|
||||
raise ValueError("Invalid field path %s" % parts[:depth])
|
||||
|
||||
translated_parts.append(part)
|
||||
|
||||
@@ -119,7 +137,7 @@ class PropsMixin(object):
|
||||
),
|
||||
):
|
||||
current_cls = field.field.document_type
|
||||
translated_parts.append('*')
|
||||
translated_parts.append("*")
|
||||
else:
|
||||
current_cls = None
|
||||
|
||||
@@ -128,7 +146,7 @@ class PropsMixin(object):
|
||||
@classmethod
|
||||
def get_reference_fields(cls):
|
||||
if cls.__cached_reference_fields is None:
|
||||
fields = cls._get_fields_with_attr(cls, 'reference_field')
|
||||
fields = cls._get_fields_with_attr(cls, "reference_field")
|
||||
cls.__cached_reference_fields = OrderedDict(sorted(fields.items()))
|
||||
return cls.__cached_reference_fields
|
||||
|
||||
@@ -143,12 +161,12 @@ class PropsMixin(object):
|
||||
@classmethod
|
||||
def get_exclude_fields(cls):
|
||||
if cls.__cached_exclude_fields is None:
|
||||
fields = cls._get_fields_with_attr(cls, 'exclude_by_default')
|
||||
fields = cls._get_fields_with_attr(cls, "exclude_by_default")
|
||||
cls.__cached_exclude_fields = OrderedDict(sorted(fields.items()))
|
||||
return cls.__cached_exclude_fields
|
||||
|
||||
@classmethod
|
||||
def get_dpath_translated_path(cls, path, separator='.'):
|
||||
def get_dpath_translated_path(cls, path, separator="."):
|
||||
if cls.__cached_dpath_computed_fields is None:
|
||||
cls.__cached_dpath_computed_fields = {}
|
||||
if path not in cls.__cached_dpath_computed_fields:
|
||||
|
||||
@@ -5,7 +5,7 @@ Apply elasticsearch mappings to given hosts.
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Optional, Sequence
|
||||
from typing import Optional, Sequence, Tuple
|
||||
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
@@ -13,7 +13,7 @@ HERE = Path(__file__).resolve().parent
|
||||
|
||||
|
||||
def apply_mappings_to_cluster(
|
||||
hosts: Sequence, key: Optional[str] = None, es_args: dict = None
|
||||
hosts: Sequence, key: Optional[str] = None, es_args: dict = None, http_auth: Tuple = None
|
||||
):
|
||||
"""Hosts maybe a sequence of strings or dicts in the form {"host": <host>, "port": <port>}"""
|
||||
|
||||
@@ -30,7 +30,7 @@ def apply_mappings_to_cluster(
|
||||
else:
|
||||
files = p.glob("**/*.json")
|
||||
|
||||
es = Elasticsearch(hosts=hosts, **(es_args or {}))
|
||||
es = Elasticsearch(hosts=hosts, http_auth=http_auth, **(es_args or {}))
|
||||
return [_send_template(f) for f in files]
|
||||
|
||||
|
||||
|
||||
@@ -82,7 +82,11 @@ def check_elastic_empty() -> bool:
|
||||
es_logger.addFilter(log_filter)
|
||||
for retry in range(max_retries):
|
||||
try:
|
||||
es = Elasticsearch(hosts=cluster_conf.get("hosts"))
|
||||
es = Elasticsearch(
|
||||
hosts=cluster_conf.get("hosts", None),
|
||||
http_auth=es_factory.get_credentials("events", cluster_conf),
|
||||
**cluster_conf.get("args", {})
|
||||
)
|
||||
return not es.indices.get_template(name="events*")
|
||||
except exceptions.NotFoundError as ex:
|
||||
log.error(ex)
|
||||
@@ -109,5 +113,7 @@ def init_es_data():
|
||||
|
||||
log.info(f"Applying mappings to ES host: {hosts_config}")
|
||||
args = cluster_conf.get("args", {})
|
||||
res = apply_mappings_to_cluster(hosts_config, name, es_args=args)
|
||||
http_auth = es_factory.get_credentials(name)
|
||||
|
||||
res = apply_mappings_to_cluster(hosts_config, name, es_args=args, http_auth=http_auth)
|
||||
log.info(res)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
{
|
||||
"index_patterns": "events-*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
"number_of_shards": 1,
|
||||
"number_of_replicas": 0
|
||||
},
|
||||
"mappings": {
|
||||
"_source": {
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
{
|
||||
"index_patterns": "queue_metrics_*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
"number_of_shards": 1,
|
||||
"number_of_replicas": 0
|
||||
},
|
||||
"mappings": {
|
||||
"_source": {
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
{
|
||||
"index_patterns": "worker_stats_*",
|
||||
"settings": {
|
||||
"number_of_shards": 1
|
||||
"number_of_shards": 1,
|
||||
"number_of_replicas": 0
|
||||
},
|
||||
"mappings": {
|
||||
"_source": {
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
from os import getenv
|
||||
from typing import Tuple
|
||||
from typing import Tuple, Optional
|
||||
|
||||
from boltons.iterutils import first
|
||||
from elasticsearch import Elasticsearch, Transport
|
||||
from elasticsearch import Elasticsearch
|
||||
|
||||
from apiserver.config_repo import config
|
||||
|
||||
@@ -21,6 +22,10 @@ OVERRIDE_PORT_ENV_KEY = (
|
||||
"ELASTIC_SERVICE_PORT",
|
||||
)
|
||||
|
||||
OVERRIDE_USERNAME_ENV_KEY = ("CLEARML_ELASTIC_SERVICE_USERNAME",)
|
||||
|
||||
OVERRIDE_PASSWORD_ENV_KEY = ("CLEARML_ELASTIC_SERVICE_PASSWORD",)
|
||||
|
||||
OVERRIDE_HOST = first(filter(None, map(getenv, OVERRIDE_HOST_ENV_KEY)))
|
||||
if OVERRIDE_HOST:
|
||||
log.info(f"Using override elastic host {OVERRIDE_HOST}")
|
||||
@@ -29,6 +34,14 @@ OVERRIDE_PORT = first(filter(None, map(getenv, OVERRIDE_PORT_ENV_KEY)))
|
||||
if OVERRIDE_PORT:
|
||||
log.info(f"Using override elastic port {OVERRIDE_PORT}")
|
||||
|
||||
OVERRIDE_USERNAME = first(filter(None, map(getenv, OVERRIDE_USERNAME_ENV_KEY)))
|
||||
if OVERRIDE_USERNAME:
|
||||
log.info(f"Using override elastic username {OVERRIDE_USERNAME}")
|
||||
|
||||
OVERRIDE_PASSWORD = first(filter(None, map(getenv, OVERRIDE_PASSWORD_ENV_KEY)))
|
||||
if OVERRIDE_PASSWORD:
|
||||
log.info("Using override elastic password ********")
|
||||
|
||||
_instances = {}
|
||||
|
||||
|
||||
@@ -48,6 +61,10 @@ class InvalidClusterConfiguration(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class MissingPasswordForElasticUser(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ESFactory:
|
||||
@classmethod
|
||||
def connect(cls, cluster_name):
|
||||
@@ -65,22 +82,45 @@ class ESFactory:
|
||||
if not hosts:
|
||||
raise InvalidClusterConfiguration(cluster_name)
|
||||
|
||||
http_auth = cls.get_credentials(cluster_name)
|
||||
|
||||
args = cluster_config.get("args", {})
|
||||
_instances[cluster_name] = Elasticsearch(
|
||||
hosts=hosts, transport_class=Transport, **args
|
||||
hosts=hosts, http_auth=http_auth, **args
|
||||
)
|
||||
|
||||
return _instances[cluster_name]
|
||||
|
||||
@classmethod
|
||||
def get_credentials(cls, cluster_name: str, cluster_config: dict = None) -> Optional[Tuple[str, str]]:
|
||||
cluster_config = cluster_config or cls.get_cluster_config(cluster_name)
|
||||
if not cluster_config.get("secure", True):
|
||||
return None
|
||||
|
||||
elastic_user = OVERRIDE_USERNAME or config.get("secure.elastic.user", None)
|
||||
if not elastic_user:
|
||||
return None
|
||||
|
||||
elastic_password = OVERRIDE_PASSWORD or config.get(
|
||||
"secure.elastic.password", None
|
||||
)
|
||||
if not elastic_password:
|
||||
raise MissingPasswordForElasticUser(
|
||||
f"cluster={cluster_name}, username={elastic_user}"
|
||||
)
|
||||
|
||||
return elastic_user, elastic_password
|
||||
|
||||
@classmethod
|
||||
def get_all_cluster_names(cls):
|
||||
return list(config.get("hosts.elastic"))
|
||||
|
||||
@classmethod
|
||||
def get_override(cls, cluster_name: str) -> Tuple[str, str]:
|
||||
def get_override_host(cls, cluster_name: str) -> Tuple[str, str]:
|
||||
return OVERRIDE_HOST, OVERRIDE_PORT
|
||||
|
||||
@classmethod
|
||||
@lru_cache()
|
||||
def get_cluster_config(cls, cluster_name):
|
||||
"""
|
||||
Returns cluster config for the specified cluster path
|
||||
@@ -97,7 +137,7 @@ class ESFactory:
|
||||
for entry in cluster_config.get("hosts", []):
|
||||
entry[key] = value
|
||||
|
||||
host, port = cls.get_override(cluster_name)
|
||||
host, port = cls.get_override_host(cluster_name)
|
||||
|
||||
if host:
|
||||
set_host_prop("host", host)
|
||||
|
||||
@@ -298,8 +298,9 @@ class PrePopulate:
|
||||
if company_id is None:
|
||||
company_id = ""
|
||||
|
||||
# Always use a public user for pre-populated data
|
||||
cls.user_cls(id=user_id, name=user_name, company="").save()
|
||||
existing_user = cls.user_cls.objects(id=user_id).only("id").first()
|
||||
if not existing_user:
|
||||
cls.user_cls(id=user_id, name=user_name, company=company_id).save()
|
||||
|
||||
cls._import(zfile, company_id, user_id, metadata)
|
||||
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
import threading
|
||||
from os import getenv
|
||||
from time import sleep
|
||||
|
||||
from boltons.iterutils import first
|
||||
from redis import StrictRedis
|
||||
from redis.sentinel import Sentinel, SentinelConnectionPool
|
||||
from rediscluster import RedisCluster
|
||||
|
||||
from apiserver.apierrors.errors.server_error import ConfigError, GeneralError
|
||||
from apiserver.config_repo import config
|
||||
@@ -21,6 +19,11 @@ OVERRIDE_PORT_ENV_KEY = (
|
||||
"TRAINS_REDIS_SERVICE_PORT",
|
||||
"REDIS_SERVICE_PORT",
|
||||
)
|
||||
OVERRIDE_PASSWORD_ENV_KEY = (
|
||||
"CLEARML_REDIS_SERVICE_PASSWORD",
|
||||
"TRAINS_REDIS_SERVICE_PASSWORD",
|
||||
"REDIS_SERVICE_PASSWORD",
|
||||
)
|
||||
|
||||
OVERRIDE_HOST = first(filter(None, map(getenv, OVERRIDE_HOST_ENV_KEY)))
|
||||
if OVERRIDE_HOST:
|
||||
@@ -30,99 +33,7 @@ OVERRIDE_PORT = first(filter(None, map(getenv, OVERRIDE_PORT_ENV_KEY)))
|
||||
if OVERRIDE_PORT:
|
||||
log.info(f"Using override redis port {OVERRIDE_PORT}")
|
||||
|
||||
|
||||
class MyPubSubWorkerThread(threading.Thread):
|
||||
def __init__(self, sentinel, on_new_master, msg_sleep_time, daemon=True):
|
||||
super(MyPubSubWorkerThread, self).__init__()
|
||||
self.daemon = daemon
|
||||
self.sentinel = sentinel
|
||||
self.on_new_master = on_new_master
|
||||
self.sentinel_host = sentinel.connection_pool.connection_kwargs["host"]
|
||||
self.msg_sleep_time = msg_sleep_time
|
||||
self._running = False
|
||||
self.pubsub = None
|
||||
|
||||
def subscribe(self):
|
||||
if self.pubsub:
|
||||
try:
|
||||
self.pubsub.unsubscribe()
|
||||
self.pubsub.punsubscribe()
|
||||
except Exception:
|
||||
pass
|
||||
finally:
|
||||
self.pubsub = None
|
||||
|
||||
subscriptions = {"+switch-master": self.on_new_master}
|
||||
|
||||
while not self.pubsub or not self.pubsub.subscribed:
|
||||
try:
|
||||
self.pubsub = self.sentinel.pubsub()
|
||||
self.pubsub.subscribe(**subscriptions)
|
||||
except Exception as ex:
|
||||
log.warn(
|
||||
f"Error while subscribing to sentinel at {self.sentinel_host} ({ex.args[0]}) Sleeping and retrying"
|
||||
)
|
||||
sleep(3)
|
||||
log.info(f"Subscribed to sentinel {self.sentinel_host}")
|
||||
|
||||
def run(self):
|
||||
if self._running:
|
||||
return
|
||||
self._running = True
|
||||
|
||||
self.subscribe()
|
||||
|
||||
while self.pubsub.subscribed:
|
||||
try:
|
||||
self.pubsub.get_message(
|
||||
ignore_subscribe_messages=True, timeout=self.msg_sleep_time
|
||||
)
|
||||
except Exception as ex:
|
||||
log.warn(
|
||||
f"Error while getting message from sentinel {self.sentinel_host} ({ex.args[0]}) Resubscribing"
|
||||
)
|
||||
self.subscribe()
|
||||
|
||||
self.pubsub.close()
|
||||
self._running = False
|
||||
|
||||
def stop(self):
|
||||
# stopping simply unsubscribes from all channels and patterns.
|
||||
# the unsubscribe responses that are generated will short circuit
|
||||
# the loop in run(), calling pubsub.close() to clean up the connection
|
||||
self.pubsub.unsubscribe()
|
||||
self.pubsub.punsubscribe()
|
||||
|
||||
|
||||
# todo,future - multi master clusters?
|
||||
class RedisCluster(object):
|
||||
def __init__(self, sentinel_hosts, service_name, **connection_kwargs):
|
||||
self.service_name = service_name
|
||||
self.sentinel = Sentinel(sentinel_hosts, **connection_kwargs)
|
||||
self.master = None
|
||||
self.master_host_port = None
|
||||
self.reconfigure()
|
||||
self.sentinel_threads = {}
|
||||
self.listen()
|
||||
|
||||
def reconfigure(self):
|
||||
try:
|
||||
self.master_host_port = self.sentinel.discover_master(self.service_name)
|
||||
self.master = self.sentinel.master_for(self.service_name)
|
||||
log.info(f"Reconfigured master to {self.master_host_port}")
|
||||
except Exception as ex:
|
||||
log.error(f"Error while reconfiguring. {ex.args[0]}")
|
||||
|
||||
def listen(self):
|
||||
def on_new_master(workerThread):
|
||||
self.reconfigure()
|
||||
|
||||
for sentinel in self.sentinel.sentinels:
|
||||
sentinel_host = sentinel.connection_pool.connection_kwargs["host"]
|
||||
self.sentinel_threads[sentinel_host] = MyPubSubWorkerThread(
|
||||
sentinel, on_new_master, msg_sleep_time=0.001, daemon=True
|
||||
)
|
||||
self.sentinel_threads[sentinel_host].start()
|
||||
OVERRIDE_PASSWORD = first(filter(None, map(getenv, OVERRIDE_PASSWORD_ENV_KEY)))
|
||||
|
||||
|
||||
class RedisManager(object):
|
||||
@@ -131,6 +42,9 @@ class RedisManager(object):
|
||||
for alias, alias_config in redis_config_dict.items():
|
||||
|
||||
alias_config = alias_config.as_plain_ordered_dict()
|
||||
alias_config["password"] = config.get(
|
||||
f"secure.redis.{alias}.password", None
|
||||
)
|
||||
|
||||
is_cluster = alias_config.get("cluster", False)
|
||||
|
||||
@@ -142,34 +56,19 @@ class RedisManager(object):
|
||||
if port:
|
||||
alias_config["port"] = port
|
||||
|
||||
db = alias_config.get("db", 0)
|
||||
password = OVERRIDE_PASSWORD or alias_config.get("password", None)
|
||||
if password:
|
||||
alias_config["password"] = password
|
||||
|
||||
sentinels = alias_config.get("sentinels", None)
|
||||
service_name = alias_config.get("service_name", None)
|
||||
|
||||
if not is_cluster and sentinels:
|
||||
raise ConfigError(
|
||||
"Redis configuration is invalid. mixed regular and cluster mode",
|
||||
alias=alias,
|
||||
)
|
||||
if is_cluster and (not sentinels or not service_name):
|
||||
raise ConfigError(
|
||||
"Redis configuration is invalid. missing sentinels or service_name",
|
||||
alias=alias,
|
||||
)
|
||||
if not is_cluster and (not port or not host):
|
||||
if not port or not host:
|
||||
raise ConfigError(
|
||||
"Redis configuration is invalid. missing port or host", alias=alias
|
||||
)
|
||||
|
||||
if is_cluster:
|
||||
# todo support all redis connection args via sentinel's connection_kwargs
|
||||
del alias_config["sentinels"]
|
||||
del alias_config["cluster"]
|
||||
del alias_config["service_name"]
|
||||
self.aliases[alias] = RedisCluster(
|
||||
sentinels, service_name, **alias_config
|
||||
)
|
||||
del alias_config["db"]
|
||||
self.aliases[alias] = RedisCluster(**alias_config)
|
||||
else:
|
||||
self.aliases[alias] = StrictRedis(**alias_config)
|
||||
|
||||
@@ -177,27 +76,21 @@ class RedisManager(object):
|
||||
obj = self.aliases.get(alias)
|
||||
if not obj:
|
||||
raise GeneralError(f"Invalid Redis alias {alias}")
|
||||
if isinstance(obj, RedisCluster):
|
||||
obj.master.get("health")
|
||||
return obj.master
|
||||
else:
|
||||
obj.get("health")
|
||||
return obj
|
||||
|
||||
obj.get("health")
|
||||
return obj
|
||||
|
||||
def host(self, alias):
|
||||
r = self.connection(alias)
|
||||
pool = r.connection_pool
|
||||
if isinstance(pool, SentinelConnectionPool):
|
||||
connections = pool.connection_kwargs[
|
||||
"connection_pool"
|
||||
]._available_connections
|
||||
if isinstance(r, RedisCluster):
|
||||
connections = first(r.connection_pool._available_connections.values())
|
||||
else:
|
||||
connections = pool._available_connections
|
||||
connections = r.connection_pool._available_connections
|
||||
|
||||
if len(connections) > 0:
|
||||
return connections[0].host
|
||||
else:
|
||||
if not connections:
|
||||
return None
|
||||
|
||||
return connections[0].host
|
||||
|
||||
|
||||
redman = RedisManager(config.get("hosts.redis"))
|
||||
|
||||
@@ -3,7 +3,7 @@ bcrypt>=3.1.4
|
||||
boltons>=19.1.0
|
||||
boto3==1.14.13
|
||||
dpath>=1.4.2,<2.0
|
||||
elasticsearch>=7.0.0,<8.0.0
|
||||
elasticsearch==7.13.3
|
||||
fastjsonschema>=2.8
|
||||
flask-compress>=1.4.0
|
||||
flask-cors>=3.0.5
|
||||
@@ -16,18 +16,19 @@ jinja2==2.11.3
|
||||
jsonmodels>=2.3
|
||||
jsonschema>=2.6.0
|
||||
luqum>=0.10.0
|
||||
mongoengine==0.19.1
|
||||
mongoengine==0.23.1
|
||||
nested_dict>=1.61
|
||||
packaging==20.3
|
||||
psutil>=5.6.5
|
||||
pyhocon>=0.3.35
|
||||
pyjwt<2.0.0
|
||||
pymongo==3.10.1
|
||||
pymongo[srv]==3.12.0
|
||||
python-rapidjson>=0.6.3
|
||||
redis>=2.10.5
|
||||
redis==3.5.3
|
||||
redis-py-cluster>=2.1.3
|
||||
related>=0.7.2
|
||||
requests>=2.13.0
|
||||
semantic_version>=2.8.3,<3
|
||||
six
|
||||
tqdm
|
||||
validators>=0.12.4
|
||||
validators>=0.12.4
|
||||
|
||||
@@ -26,6 +26,10 @@ credentials {
|
||||
type: string
|
||||
description: Credentials secret key
|
||||
}
|
||||
label {
|
||||
type: string
|
||||
description: Optional credentials label
|
||||
}
|
||||
}
|
||||
}
|
||||
batch_operation {
|
||||
|
||||
@@ -15,6 +15,10 @@ _definitions {
|
||||
type: string
|
||||
description: ""
|
||||
}
|
||||
label {
|
||||
type: string
|
||||
description: Optional credentials label
|
||||
}
|
||||
last_used {
|
||||
type: string
|
||||
description: ""
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -199,6 +199,29 @@ get_all_ex {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all_ex."2.13"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previos calls to get_all_ex"
|
||||
}
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of models to retrieve"
|
||||
}
|
||||
}
|
||||
}
|
||||
response.properties.scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all {
|
||||
"2.1" {
|
||||
@@ -302,6 +325,29 @@ get_all {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previos calls to get_all"
|
||||
}
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of models to retrieve"
|
||||
}
|
||||
}
|
||||
}
|
||||
response.properties.scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
|
||||
}
|
||||
}
|
||||
}
|
||||
get_frameworks {
|
||||
"2.8" {
|
||||
|
||||
@@ -152,6 +152,11 @@ _definitions {
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
last_update {
|
||||
description: "Last update time"
|
||||
type: string
|
||||
format: "date-time"
|
||||
}
|
||||
tags {
|
||||
type: array
|
||||
description: "User-defined tags"
|
||||
@@ -379,7 +384,7 @@ get_all {
|
||||
items { type: string }
|
||||
}
|
||||
page {
|
||||
description: "Page number, returns a specific page out of the resulting list of dataviews"
|
||||
description: "Page number, returns a specific page out of the resulting list of projects"
|
||||
type: integer
|
||||
minimum: 0
|
||||
}
|
||||
@@ -430,6 +435,29 @@ get_all {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all."2.13"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previos calls to get_all_ex"
|
||||
}
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of projects to retrieve"
|
||||
}
|
||||
}
|
||||
}
|
||||
response.properties.scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all_ex {
|
||||
internal: true
|
||||
@@ -469,7 +497,7 @@ get_all_ex {
|
||||
default: false
|
||||
}
|
||||
check_own_contents {
|
||||
description: "If set to 'true' and project ids are passed to the query then for these projects their own tasks, models and dataviews are counted"
|
||||
description: "If set to 'true' and project ids are passed to the query then for these projects their own tasks and models are counted"
|
||||
type: boolean
|
||||
default: false
|
||||
}
|
||||
@@ -488,6 +516,72 @@ get_all_ex {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all_ex."2.13"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previos calls to get_all"
|
||||
}
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of projects to retrieve"
|
||||
}
|
||||
}
|
||||
}
|
||||
response.properties.scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
|
||||
}
|
||||
}
|
||||
"2.16": ${get_all_ex."2.15"} {
|
||||
request.properties.stats_with_children {
|
||||
description: "If include_stats flag is set then this flag contols whether the child projects tasks are taken into statistics or not"
|
||||
type: boolean
|
||||
default: true
|
||||
}
|
||||
response {
|
||||
properties {
|
||||
stats {
|
||||
properties {
|
||||
active.properties {
|
||||
total_tasks {
|
||||
description: "Number of tasks"
|
||||
type: integer
|
||||
}
|
||||
completed_tasks {
|
||||
description: "Number of tasks completed in the last 24 hours"
|
||||
type: integer
|
||||
}
|
||||
running_tasks {
|
||||
description: "Number of running tasks"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
archived.properties {
|
||||
total_tasks {
|
||||
description: "Number of tasks"
|
||||
type: integer
|
||||
}
|
||||
completed_tasks {
|
||||
description: "Number of tasks completed in the last 24 hours"
|
||||
type: integer
|
||||
}
|
||||
running_tasks {
|
||||
description: "Number of running tasks"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
update {
|
||||
"2.1" {
|
||||
@@ -504,10 +598,6 @@ update {
|
||||
description: "Project name. Unique within the company."
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "Project description. "
|
||||
type: string
|
||||
}
|
||||
description {
|
||||
description: "Project description"
|
||||
type: string
|
||||
@@ -594,7 +684,7 @@ merge {
|
||||
type: object
|
||||
properties {
|
||||
moved_entities {
|
||||
description: "The number of tasks, models and dataviews moved from the merged project into the destination"
|
||||
description: "The number of tasks and models moved from the merged project into the destination"
|
||||
type: integer
|
||||
}
|
||||
moved_projects {
|
||||
@@ -605,6 +695,42 @@ merge {
|
||||
}
|
||||
}
|
||||
}
|
||||
validate_delete {
|
||||
"2.14" {
|
||||
description: "Validates that the project existis and can be deleted"
|
||||
request {
|
||||
type: object
|
||||
required: [ project ]
|
||||
properties {
|
||||
project {
|
||||
description: "Project ID"
|
||||
type: string
|
||||
}
|
||||
}
|
||||
}
|
||||
response {
|
||||
type: object
|
||||
properties {
|
||||
tasks {
|
||||
description: "The total number of tasks under the project and all its children"
|
||||
type: integer
|
||||
}
|
||||
non_archived_tasks {
|
||||
description: "The total number of non-archived tasks under the project and all its children"
|
||||
type: integer
|
||||
}
|
||||
models {
|
||||
description: "The total number of models under the project and all its children"
|
||||
type: integer
|
||||
}
|
||||
non_archived_models {
|
||||
description: "The total number of non-archived models under the project and all its children"
|
||||
type: integer
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
delete {
|
||||
"2.1" {
|
||||
description: "Deletes a project"
|
||||
@@ -613,7 +739,7 @@ delete {
|
||||
required: [ project ]
|
||||
properties {
|
||||
project {
|
||||
description: "Project id"
|
||||
description: "Project ID"
|
||||
type: string
|
||||
}
|
||||
force {
|
||||
@@ -803,7 +929,6 @@ get_hyper_parameters {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
get_task_tags {
|
||||
"2.8" {
|
||||
description: "Get user and system tags used for the tasks under the specified projects"
|
||||
|
||||
@@ -115,6 +115,29 @@ get_by_id {
|
||||
get_all_ex {
|
||||
internal: true
|
||||
"2.4": ${get_all."2.4"}
|
||||
"2.15": ${get_all_ex."2.4"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previos calls to get_all_ex"
|
||||
}
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of queues to retrieve"
|
||||
}
|
||||
}
|
||||
}
|
||||
response.properties.scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all {
|
||||
"2.4" {
|
||||
@@ -178,6 +201,29 @@ get_all {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all."2.4"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previos calls to get_all"
|
||||
}
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of queues to retrieve"
|
||||
}
|
||||
}
|
||||
}
|
||||
response.properties.scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
|
||||
}
|
||||
}
|
||||
}
|
||||
get_default {
|
||||
"2.4" {
|
||||
|
||||
@@ -534,7 +534,7 @@ _definitions {
|
||||
container {
|
||||
description: "Docker container parameters"
|
||||
type: object
|
||||
additionalProperties { type: string }
|
||||
additionalProperties { type: [string, null] }
|
||||
}
|
||||
models {
|
||||
description: "Task models"
|
||||
@@ -685,6 +685,29 @@ get_all_ex {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all_ex."2.13"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previos calls to get_all_ex"
|
||||
}
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of tasks to retrieve"
|
||||
}
|
||||
}
|
||||
}
|
||||
response.properties.scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all_ex to retrieve more data"
|
||||
}
|
||||
}
|
||||
}
|
||||
get_all {
|
||||
"2.1" {
|
||||
@@ -799,6 +822,29 @@ get_all {
|
||||
}
|
||||
}
|
||||
}
|
||||
"2.15": ${get_all."2.1"} {
|
||||
request {
|
||||
properties {
|
||||
scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID returned from the previos calls to get_all"
|
||||
}
|
||||
refresh_scroll {
|
||||
type: boolean
|
||||
description: "If set then all the data received with this scroll will be requeried"
|
||||
}
|
||||
size {
|
||||
type: integer
|
||||
minimum: 1
|
||||
description: "The number of tasks to retrieve"
|
||||
}
|
||||
}
|
||||
}
|
||||
response.properties.scroll_id {
|
||||
type: string
|
||||
description: "Scroll ID that can be used with the next calls to get_all to retrieve more data"
|
||||
}
|
||||
}
|
||||
}
|
||||
get_types {
|
||||
"2.8" {
|
||||
@@ -935,7 +981,7 @@ clone {
|
||||
new_task_container {
|
||||
description: "The docker container properties for the new task. If not provided then taken from the original task"
|
||||
type: object
|
||||
additionalProperties { type: string }
|
||||
additionalProperties { type: [string, null] }
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1113,7 +1159,7 @@ create {
|
||||
container {
|
||||
description: "Docker container parameters"
|
||||
type: object
|
||||
additionalProperties { type: string }
|
||||
additionalProperties { type: [string, null] }
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1202,7 +1248,7 @@ validate {
|
||||
container {
|
||||
description: "Docker container parameters"
|
||||
type: object
|
||||
additionalProperties { type: string }
|
||||
additionalProperties { type: [string, null] }
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1364,7 +1410,7 @@ edit {
|
||||
container {
|
||||
description: "Docker container parameters"
|
||||
type: object
|
||||
additionalProperties { type: string }
|
||||
additionalProperties { type: [string, null] }
|
||||
}
|
||||
runtime {
|
||||
description: "Task runtime mapping"
|
||||
|
||||
@@ -4,7 +4,7 @@ from hashlib import md5
|
||||
from flask import Flask
|
||||
from flask_compress import Compress
|
||||
from flask_cors import CORS
|
||||
from semantic_version import Version
|
||||
from packaging.version import Version
|
||||
|
||||
from apiserver.database import db
|
||||
from apiserver.bll.statistics.stats_reporter import StatisticsReporter
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
from functools import partial
|
||||
|
||||
from flask import request, Response, redirect
|
||||
from werkzeug.datastructures import ImmutableMultiDict
|
||||
from werkzeug.exceptions import BadRequest
|
||||
|
||||
from apiserver.apierrors import APIError
|
||||
from apiserver.apierrors.base import BaseError
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.service_repo import ServiceRepo, APICall
|
||||
from apiserver.service_repo.auth import AuthType
|
||||
from apiserver.service_repo.auth import AuthType, Token
|
||||
from apiserver.service_repo.errors import PathParsingError
|
||||
from apiserver.utilities import json
|
||||
from apiserver.utilities.dicts import nested_set
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
@@ -29,7 +31,7 @@ class RequestHandlers:
|
||||
try:
|
||||
call = self._create_api_call(request)
|
||||
load_data_callback = partial(self._load_call_data, req=request)
|
||||
content, content_type = ServiceRepo.handle_call(
|
||||
content, content_type, company = ServiceRepo.handle_call(
|
||||
call, load_data_callback=load_data_callback
|
||||
)
|
||||
|
||||
@@ -51,20 +53,49 @@ class RequestHandlers:
|
||||
|
||||
if call.result.cookies:
|
||||
for key, value in call.result.cookies.items():
|
||||
kwargs = config.get("apiserver.auth.cookies")
|
||||
kwargs = config.get("apiserver.auth.cookies").copy()
|
||||
if value is None:
|
||||
kwargs = kwargs.copy()
|
||||
# Removing a cookie
|
||||
kwargs["max_age"] = 0
|
||||
kwargs["expires"] = 0
|
||||
response.set_cookie(key, "", **kwargs)
|
||||
else:
|
||||
response.set_cookie(key, value, **kwargs)
|
||||
value = ""
|
||||
elif not company:
|
||||
# Setting a cookie, let's try to figure out the company
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
company = Token.decode_identity(value).company
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if company:
|
||||
try:
|
||||
# use no default value to allow setting a null domain as well
|
||||
kwargs["domain"] = config.get(f"apiserver.auth.cookies_domain_override.{company}")
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
response.set_cookie(key, value, **kwargs)
|
||||
|
||||
return response
|
||||
except Exception as ex:
|
||||
log.exception(f"Failed processing request {request.url}: {ex}")
|
||||
return f"Failed processing request {request.url}", 500
|
||||
|
||||
@staticmethod
|
||||
def _apply_multi_dict(body: dict, md: ImmutableMultiDict):
|
||||
def convert_value(v: str):
|
||||
if v.replace(".", "", 1).isdigit():
|
||||
return float(v) if "." in v else int(v)
|
||||
if v in ("true", "True", "TRUE"):
|
||||
return True
|
||||
if v in ("false", "False", "FALSE"):
|
||||
return False
|
||||
return v
|
||||
|
||||
for k, v in md.lists():
|
||||
v = [convert_value(x) for x in v] if (len(v) > 1 or k.endswith("[]")) else convert_value(v[0])
|
||||
nested_set(body, k.rstrip("[]").split("."), v)
|
||||
|
||||
def _update_call_data(self, call, req):
|
||||
""" Use request payload/form to fill call data or batched data """
|
||||
if req.content_type == "application/json-lines":
|
||||
@@ -82,23 +113,12 @@ class RequestHandlers:
|
||||
req.on_json_loading_failed(msg)
|
||||
call.batched_data = items
|
||||
else:
|
||||
json_body = req.get_json(force=True, silent=False) if req.data else None
|
||||
# merge form and args
|
||||
form = req.form.copy()
|
||||
form.update(req.args)
|
||||
form = form.to_dict()
|
||||
# convert string numbers to floats
|
||||
for key in form:
|
||||
if form[key].replace(".", "", 1).isdigit():
|
||||
if "." in form[key]:
|
||||
form[key] = float(form[key])
|
||||
else:
|
||||
form[key] = int(form[key])
|
||||
elif form[key].lower() == "true":
|
||||
form[key] = True
|
||||
elif form[key].lower() == "false":
|
||||
form[key] = False
|
||||
call.data = json_body or form or {}
|
||||
body = (req.get_json(force=True, silent=False) if req.data else None) or {}
|
||||
if req.args:
|
||||
self._apply_multi_dict(body, req.args)
|
||||
if req.form:
|
||||
self._apply_multi_dict(body, req.form)
|
||||
call.data = body
|
||||
|
||||
def _call_or_empty_with_error(self, call, req, msg, code=500, subcode=0):
|
||||
call = call or APICall(
|
||||
|
||||
@@ -310,6 +310,12 @@ class APICall(DataContainer):
|
||||
_transaction_headers = _get_headers("Trx")
|
||||
""" Transaction ID """
|
||||
|
||||
_redacted_headers = {
|
||||
HEADER_AUTHORIZATION: " ",
|
||||
"Cookie": "=",
|
||||
}
|
||||
""" Headers whose value should be redacted. Maps header name to partition char """
|
||||
|
||||
@property
|
||||
def HEADER_TRANSACTION(self):
|
||||
return self._transaction_headers[0]
|
||||
@@ -584,17 +590,26 @@ class APICall(DataContainer):
|
||||
def json_flags(self):
|
||||
return self._json_flags
|
||||
|
||||
@property
|
||||
def extra_meta_fields(self):
|
||||
return {}
|
||||
|
||||
def mark_end(self):
|
||||
self._end_ts = time.time()
|
||||
self._duration = int((self._end_ts - self._start_ts) * 1000)
|
||||
|
||||
def get_response(self, include_stack: bool = False) -> Tuple[Union[dict, str], str]:
|
||||
def get_response(self, include_stack: bool = None) -> Tuple[Union[dict, str], str]:
|
||||
"""
|
||||
Get the response for this call.
|
||||
:param include_stack: If True, stack trace stored in this call's result should
|
||||
be included in the response (default is False)
|
||||
be included in the response (default follows configuration)
|
||||
:return: Response data (encoded according to self.content_type) and the data's content type
|
||||
"""
|
||||
include_stack = (
|
||||
include_stack
|
||||
if include_stack is not None
|
||||
else config.get("apiserver.return_stack_to_caller", False)
|
||||
)
|
||||
|
||||
def make_version_number(version: PartialVersion) -> Union[None, float, str]:
|
||||
"""
|
||||
@@ -629,6 +644,7 @@ class APICall(DataContainer):
|
||||
"result_msg": self.result.msg,
|
||||
"error_stack": self.result.traceback if include_stack else None,
|
||||
"error_data": self.result.error_data,
|
||||
**self.extra_meta_fields,
|
||||
},
|
||||
"data": self.result.data,
|
||||
}
|
||||
@@ -663,3 +679,15 @@ class APICall(DataContainer):
|
||||
error_data=error_data,
|
||||
cookies=self._result.cookies,
|
||||
)
|
||||
|
||||
def get_redacted_headers(self):
|
||||
headers = self.headers.copy()
|
||||
if not self.requires_authorization or self.auth:
|
||||
# We won't log the authorization header if call shouldn't be authorized, or if it was successfully
|
||||
# authorized. This means we'll only log authorization header for calls that failed to authorize (hopefully
|
||||
# this will allow us to debug authorization errors).
|
||||
for header, sep in self._redacted_headers.items():
|
||||
if header in headers:
|
||||
prefix, _, redact = headers[header].partition(sep)
|
||||
headers[header] = prefix + sep + f"<{len(redact)} bytes redacted>"
|
||||
return headers
|
||||
|
||||
@@ -12,6 +12,9 @@ from .payload import Payload
|
||||
token_secret = config.get('secure.auth.token_secret')
|
||||
|
||||
|
||||
log = config.logger(__file__)
|
||||
|
||||
|
||||
class Token(Payload):
|
||||
default_expiration_sec = config.get('apiserver.auth.default_expiration_sec')
|
||||
|
||||
@@ -94,3 +97,14 @@ class Token(Payload):
|
||||
token.exp = now + timedelta(seconds=expiration_sec)
|
||||
|
||||
return token.encode(**extra_payload)
|
||||
|
||||
@classmethod
|
||||
def decode_identity(cls, encoded_token):
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
from ..auth import Identity
|
||||
|
||||
decoded = cls.decode(encoded_token, verify=False)
|
||||
return Identity.from_dict(decoded.get("identity", {}))
|
||||
except Exception as ex:
|
||||
log.error(f"Failed parsing identity from encoded token: {ex}")
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import random
|
||||
import string
|
||||
|
||||
sys_random = random.SystemRandom()
|
||||
|
||||
|
||||
def get_random_string(length=12, allowed_chars='abcdefghijklmnopqrstuvwxyz'
|
||||
'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'):
|
||||
def get_random_string(
|
||||
length: int = 12, allowed_chars: str = string.ascii_letters + string.digits
|
||||
) -> str:
|
||||
"""
|
||||
Returns a securely generated random string.
|
||||
|
||||
@@ -12,20 +15,20 @@ def get_random_string(length=12, allowed_chars='abcdefghijklmnopqrstuvwxyz'
|
||||
|
||||
Taken from the django.utils.crypto module.
|
||||
"""
|
||||
return ''.join(sys_random.choice(allowed_chars) for _ in range(length))
|
||||
return "".join(sys_random.choice(allowed_chars) for _ in range(length))
|
||||
|
||||
|
||||
def get_client_id(length=20):
|
||||
def get_client_id(length: int = 20) -> str:
|
||||
"""
|
||||
Create a random secret key.
|
||||
|
||||
Taken from the Django project.
|
||||
"""
|
||||
chars = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'
|
||||
chars = string.ascii_uppercase + string.digits
|
||||
return get_random_string(length, chars)
|
||||
|
||||
|
||||
def get_secret_key(length=50):
|
||||
def get_secret_key(length: int = 50) -> str:
|
||||
"""
|
||||
Create a random secret key.
|
||||
|
||||
@@ -33,5 +36,5 @@ def get_secret_key(length=50):
|
||||
NOTE: asterisk is not supported due to issues with environment variables containing
|
||||
asterisks (in case the secret key is stored in an environment variable)
|
||||
"""
|
||||
chars = 'abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&(-_=+)'
|
||||
chars = string.ascii_letters + string.digits
|
||||
return get_random_string(length, chars)
|
||||
|
||||
@@ -10,6 +10,7 @@ from apiserver.apierrors import APIError, errors
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.utilities.partial_version import PartialVersion
|
||||
from .apicall import APICall
|
||||
from .auth import Identity
|
||||
from .endpoint import Endpoint
|
||||
from .errors import MalformedPathError, InvalidVersionError, CallFailedError
|
||||
from .util import parse_return_stack_on_code
|
||||
@@ -37,7 +38,7 @@ class ServiceRepo(object):
|
||||
"""If the check is set, parsing will fail for endpoint request with the version that is grater than the current
|
||||
maximum """
|
||||
|
||||
_max_version = PartialVersion("2.13")
|
||||
_max_version = PartialVersion("2.16")
|
||||
""" Maximum version number (the highest min_version value across all endpoints) """
|
||||
|
||||
_endpoint_exp = (
|
||||
@@ -233,19 +234,27 @@ class ServiceRepo(object):
|
||||
return subcode in subcode_list
|
||||
|
||||
@classmethod
|
||||
def _get_company(
|
||||
def _get_identity(
|
||||
cls, call: APICall, endpoint: Endpoint = None, ignore_error: bool = False
|
||||
) -> Optional[str]:
|
||||
) -> Optional[Identity]:
|
||||
authorize = endpoint and endpoint.authorize
|
||||
if ignore_error or not authorize:
|
||||
try:
|
||||
return call.identity.company
|
||||
return call.identity
|
||||
except Exception:
|
||||
return None
|
||||
return call.identity.company
|
||||
return call.identity
|
||||
|
||||
@classmethod
|
||||
def _get_company(
|
||||
cls, call: APICall, endpoint: Endpoint = None, ignore_error: bool = False
|
||||
) -> Optional[str]:
|
||||
identity = cls._get_identity(call, endpoint=endpoint, ignore_error=ignore_error)
|
||||
return None if identity is None else identity.company
|
||||
|
||||
@classmethod
|
||||
def handle_call(cls, call: APICall, load_data_callback: Callable = None):
|
||||
company = None
|
||||
try:
|
||||
if call.failed:
|
||||
raise CallFailedError()
|
||||
@@ -316,4 +325,4 @@ class ServiceRepo(object):
|
||||
else:
|
||||
log.error(console_msg)
|
||||
|
||||
return content, content_type
|
||||
return content, content_type, company
|
||||
|
||||
@@ -13,6 +13,7 @@ from apiserver.apimodels.auth import (
|
||||
CredentialsResponse,
|
||||
RevokeCredentialsRequest,
|
||||
EditUserReq,
|
||||
CreateCredentialsRequest,
|
||||
)
|
||||
from apiserver.apimodels.base import UpdateResponse
|
||||
from apiserver.bll.auth import AuthBLL
|
||||
@@ -58,9 +59,13 @@ def get_token_for_user(call: APICall, _: str, request: GetTokenForUserRequest):
|
||||
""" Generates a token based on a requested user and company. INTERNAL. """
|
||||
if call.identity.role not in Role.get_system_roles():
|
||||
if call.identity.role != Role.admin and call.identity.user != request.user:
|
||||
raise errors.bad_request.InvalidUserId("cannot generate token for another user")
|
||||
raise errors.bad_request.InvalidUserId(
|
||||
"cannot generate token for another user"
|
||||
)
|
||||
if call.identity.company != request.company:
|
||||
raise errors.bad_request.InvalidId("cannot generate token in another company")
|
||||
raise errors.bad_request.InvalidId(
|
||||
"cannot generate token in another company"
|
||||
)
|
||||
|
||||
call.result.data_model = AuthBLL.get_token_for_user(
|
||||
user_id=request.user,
|
||||
@@ -93,7 +98,10 @@ def validate_token_endpoint(call: APICall, _, __):
|
||||
)
|
||||
def create_user(call: APICall, _, request: CreateUserRequest):
|
||||
""" Create a user from. INTERNAL. """
|
||||
if call.identity.role not in Role.get_system_roles() and request.company != call.identity.company:
|
||||
if (
|
||||
call.identity.role not in Role.get_system_roles()
|
||||
and request.company != call.identity.company
|
||||
):
|
||||
raise errors.bad_request.InvalidId("cannot create user in another company")
|
||||
|
||||
user_id = AuthBLL.create_user(request=request, call=call)
|
||||
@@ -101,7 +109,7 @@ def create_user(call: APICall, _, request: CreateUserRequest):
|
||||
|
||||
|
||||
@endpoint("auth.create_credentials", response_data_model=CreateCredentialsResponse)
|
||||
def create_credentials(call: APICall, _, __):
|
||||
def create_credentials(call: APICall, _, request: CreateCredentialsRequest):
|
||||
if _is_protected_user(call.identity.user):
|
||||
raise errors.bad_request.InvalidUserId("protected identity")
|
||||
|
||||
@@ -109,6 +117,7 @@ def create_credentials(call: APICall, _, __):
|
||||
user_id=call.identity.user,
|
||||
company_id=call.identity.company,
|
||||
role=call.identity.role,
|
||||
label=request.label,
|
||||
)
|
||||
call.result.data_model = CreateCredentialsResponse(credentials=credentials)
|
||||
|
||||
@@ -151,7 +160,9 @@ def get_credentials(call: APICall, _, __):
|
||||
# we return ONLY the key IDs, never the secrets (want a secret? create new credentials)
|
||||
call.result.data_model = GetCredentialsResponse(
|
||||
credentials=[
|
||||
CredentialsResponse(access_key=c.key, last_used=c.last_used)
|
||||
CredentialsResponse(
|
||||
access_key=c.key, last_used=c.last_used, label=c.label
|
||||
)
|
||||
for c in user.credentials
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
import itertools
|
||||
import math
|
||||
from collections import defaultdict
|
||||
from operator import itemgetter
|
||||
from typing import Sequence, Optional
|
||||
|
||||
import attr
|
||||
import jsonmodels.fields
|
||||
from boltons.iterutils import bucketize
|
||||
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.apimodels.events import (
|
||||
@@ -17,12 +21,19 @@ from apiserver.apimodels.events import (
|
||||
LogOrderEnum,
|
||||
GetDebugImageSampleRequest,
|
||||
NextDebugImageSampleRequest,
|
||||
MetricVariants as ApiMetrics,
|
||||
TaskPlotsRequest,
|
||||
TaskEventsRequest,
|
||||
ScalarMetricsIterRawRequest,
|
||||
)
|
||||
from apiserver.bll.event import EventBLL
|
||||
from apiserver.bll.event.event_common import EventType
|
||||
from apiserver.bll.event.event_common import EventType, MetricVariants
|
||||
from apiserver.bll.event.events_iterator import Scroll
|
||||
from apiserver.bll.event.scalar_key import ScalarKeyEnum, ScalarKey
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.service_repo import APICall, endpoint
|
||||
from apiserver.utilities import json
|
||||
from apiserver.utilities import json, extract_properties_to_lists
|
||||
|
||||
task_bll = TaskBLL()
|
||||
event_bll = EventBLL()
|
||||
@@ -36,7 +47,6 @@ def add(call: APICall, company_id, _):
|
||||
company_id, [data], call.worker, allow_locked_tasks=allow_locked
|
||||
)
|
||||
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
|
||||
call.kpis["events"] = 1
|
||||
|
||||
|
||||
@endpoint("events.add_batch")
|
||||
@@ -47,7 +57,6 @@ def add_batch(call: APICall, company_id, _):
|
||||
|
||||
added, err_count, err_info = event_bll.add_events(company_id, events, call.worker)
|
||||
call.result.data = dict(added=added, errors=err_count, errors_info=err_info)
|
||||
call.kpis["events"] = len(events)
|
||||
|
||||
|
||||
@endpoint("events.get_task_log", required_fields=["task"])
|
||||
@@ -110,7 +119,8 @@ def get_task_log(call, company_id, request: LogEventsRequest):
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
)[0]
|
||||
|
||||
res = event_bll.log_events_iterator.get_task_events(
|
||||
res = event_bll.events_iterator.get_task_events(
|
||||
event_type=EventType.task_log,
|
||||
company_id=task.get_index_company(),
|
||||
task_id=task_id,
|
||||
batch_size=request.batch_size,
|
||||
@@ -255,31 +265,94 @@ def vector_metrics_iter_histogram(call, company_id, _):
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.get_task_events", required_fields=["task"])
|
||||
def get_task_events(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
batch_size = call.data.get("batch_size", 500)
|
||||
event_type = call.data.get("event_type")
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
order = call.data.get("order") or "asc"
|
||||
class GetTaskEventsScroll(Scroll):
|
||||
from_key_value = jsonmodels.fields.StringField()
|
||||
total = jsonmodels.fields.IntField()
|
||||
request: TaskEventsRequest = jsonmodels.fields.EmbeddedField(TaskEventsRequest)
|
||||
|
||||
|
||||
def make_response(
|
||||
total: int, returned: int = 0, scroll_id: str = None, **kwargs
|
||||
) -> dict:
|
||||
return {
|
||||
"returned": returned,
|
||||
"total": total,
|
||||
"scroll_id": scroll_id,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
@endpoint("events.get_task_events", request_data_model=TaskEventsRequest)
|
||||
def get_task_events(call, company_id, request: TaskEventsRequest):
|
||||
task_id = request.task
|
||||
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
company_id, task_id, allow_public=True, only=("company",),
|
||||
)[0]
|
||||
result = event_bll.get_task_events(
|
||||
task.get_index_company(),
|
||||
task_id,
|
||||
sort=[{"timestamp": {"order": order}}],
|
||||
event_type=EventType(event_type) if event_type else EventType.all,
|
||||
scroll_id=scroll_id,
|
||||
size=batch_size,
|
||||
|
||||
key = ScalarKeyEnum.iter
|
||||
scalar_key = ScalarKey.resolve(key)
|
||||
|
||||
if not request.scroll_id:
|
||||
from_key_value = None if (request.order == LogOrderEnum.desc) else 0
|
||||
total = None
|
||||
else:
|
||||
try:
|
||||
scroll = GetTaskEventsScroll.from_scroll_id(request.scroll_id)
|
||||
except ValueError:
|
||||
raise errors.bad_request.InvalidScrollId(scroll_id=request.scroll_id)
|
||||
|
||||
if scroll.from_key_value is None:
|
||||
return make_response(
|
||||
scroll_id=request.scroll_id, total=scroll.total, events=[]
|
||||
)
|
||||
|
||||
from_key_value = scalar_key.cast_value(scroll.from_key_value)
|
||||
total = scroll.total
|
||||
|
||||
scroll.request.batch_size = request.batch_size or scroll.request.batch_size
|
||||
request = scroll.request
|
||||
|
||||
navigate_earlier = request.order == LogOrderEnum.desc
|
||||
metric_variants = _get_metric_variants_from_request(request.metrics)
|
||||
|
||||
if request.count_total and total is None:
|
||||
total = event_bll.events_iterator.count_task_events(
|
||||
event_type=request.event_type,
|
||||
company_id=task.company,
|
||||
task_id=task_id,
|
||||
metric_variants=metric_variants,
|
||||
)
|
||||
|
||||
batch_size = min(
|
||||
request.batch_size,
|
||||
int(
|
||||
config.get("services.events.events_retrieval.max_raw_scalars_size", 10_000)
|
||||
),
|
||||
)
|
||||
|
||||
call.result.data = dict(
|
||||
events=result.events,
|
||||
returned=len(result.events),
|
||||
total=result.total_events,
|
||||
scroll_id=result.next_scroll_id,
|
||||
res = event_bll.events_iterator.get_task_events(
|
||||
event_type=request.event_type,
|
||||
company_id=task.company,
|
||||
task_id=task_id,
|
||||
batch_size=batch_size,
|
||||
key=ScalarKeyEnum.iter,
|
||||
navigate_earlier=navigate_earlier,
|
||||
from_key_value=from_key_value,
|
||||
metric_variants=metric_variants,
|
||||
)
|
||||
|
||||
scroll = GetTaskEventsScroll(
|
||||
from_key_value=str(res.events[-1][scalar_key.field]) if res.events else None,
|
||||
total=total,
|
||||
request=request,
|
||||
)
|
||||
|
||||
return make_response(
|
||||
returned=len(res.events),
|
||||
total=total,
|
||||
scroll_id=scroll.get_scroll_id(),
|
||||
events=res.events,
|
||||
)
|
||||
|
||||
|
||||
@@ -288,6 +361,7 @@ def get_scalar_metric_data(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
metric = call.data["metric"]
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
no_scroll = call.data.get("no_scroll", False)
|
||||
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
@@ -299,6 +373,7 @@ def get_scalar_metric_data(call, company_id, _):
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
metric=metric,
|
||||
scroll_id=scroll_id,
|
||||
no_scroll=no_scroll,
|
||||
)
|
||||
|
||||
call.result.data = dict(
|
||||
@@ -321,7 +396,7 @@ def get_task_latest_scalar_values(call, company_id, _):
|
||||
)
|
||||
last_iters = event_bll.get_last_iters(
|
||||
company_id=company_id, event_type=EventType.all, task_id=task_id, iters=1
|
||||
)
|
||||
).get(task_id)
|
||||
call.result.data = dict(
|
||||
metrics=metrics,
|
||||
last_iter=last_iters[0] if last_iters else 0,
|
||||
@@ -421,6 +496,7 @@ def get_multi_task_plots(call, company_id, req_model):
|
||||
task_ids = call.data["tasks"]
|
||||
iters = call.data.get("iters", 1)
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
no_scroll = call.data.get("no_scroll", False)
|
||||
|
||||
tasks = task_bll.assert_exists(
|
||||
company_id=call.identity.company,
|
||||
@@ -442,6 +518,7 @@ def get_multi_task_plots(call, company_id, req_model):
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
last_iter_count=iters,
|
||||
scroll_id=scroll_id,
|
||||
no_scroll=no_scroll,
|
||||
)
|
||||
|
||||
tasks = {t.id: t.name for t in tasks}
|
||||
@@ -494,11 +571,22 @@ def get_task_plots_v1_7(call, company_id, _):
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.get_task_plots", min_version="1.8", required_fields=["task"])
|
||||
def get_task_plots(call, company_id, _):
|
||||
task_id = call.data["task"]
|
||||
iters = call.data.get("iters", 1)
|
||||
scroll_id = call.data.get("scroll_id")
|
||||
def _get_metric_variants_from_request(
|
||||
req_metrics: Sequence[ApiMetrics],
|
||||
) -> Optional[MetricVariants]:
|
||||
if not req_metrics:
|
||||
return None
|
||||
|
||||
return {m.metric: m.variants for m in req_metrics}
|
||||
|
||||
|
||||
@endpoint(
|
||||
"events.get_task_plots", min_version="1.8", request_data_model=TaskPlotsRequest
|
||||
)
|
||||
def get_task_plots(call, company_id, request: TaskPlotsRequest):
|
||||
task_id = request.task
|
||||
iters = request.iters
|
||||
scroll_id = request.scroll_id
|
||||
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company", "company_origin")
|
||||
@@ -509,6 +597,8 @@ def get_task_plots(call, company_id, _):
|
||||
sort=[{"iter": {"order": "desc"}}],
|
||||
last_iterations_per_plot=iters,
|
||||
scroll_id=scroll_id,
|
||||
no_scroll=request.no_scroll,
|
||||
metric_variants=_get_metric_variants_from_request(request.metrics),
|
||||
)
|
||||
|
||||
return_events = result.events
|
||||
@@ -594,9 +684,9 @@ def get_debug_images_v1_8(call, company_id, _):
|
||||
response_data_model=DebugImageResponse,
|
||||
)
|
||||
def get_debug_images(call, company_id, request: DebugImagesRequest):
|
||||
task_metrics = defaultdict(set)
|
||||
task_metrics = defaultdict(dict)
|
||||
for tm in request.metrics:
|
||||
task_metrics[tm.task].add(tm.metric)
|
||||
task_metrics[tm.task][tm.metric] = tm.variants
|
||||
for metrics in task_metrics.values():
|
||||
if None in metrics:
|
||||
metrics.clear()
|
||||
@@ -734,13 +824,115 @@ def _get_top_iter_unique_events_per_task(events, max_iters, tasks):
|
||||
|
||||
def _get_top_iter_unique_events(events, max_iters):
|
||||
top_unique_events = defaultdict(lambda: [])
|
||||
for e in events:
|
||||
key = e.get("metric", "") + e.get("variant", "")
|
||||
for ev in events:
|
||||
key = ev.get("metric", "") + ev.get("variant", "")
|
||||
evs = top_unique_events[key]
|
||||
if len(evs) < max_iters:
|
||||
evs.append(e)
|
||||
evs.append(ev)
|
||||
unique_events = list(
|
||||
itertools.chain.from_iterable(list(top_unique_events.values()))
|
||||
)
|
||||
unique_events.sort(key=lambda e: e["iter"], reverse=True)
|
||||
return unique_events
|
||||
|
||||
|
||||
class ScalarMetricsIterRawScroll(Scroll):
|
||||
from_key_value = jsonmodels.fields.StringField()
|
||||
total = jsonmodels.fields.IntField()
|
||||
request: ScalarMetricsIterRawRequest = jsonmodels.fields.EmbeddedField(
|
||||
ScalarMetricsIterRawRequest
|
||||
)
|
||||
|
||||
|
||||
@endpoint("events.scalar_metrics_iter_raw", min_version="2.16")
|
||||
def scalar_metrics_iter_raw(
|
||||
call: APICall, company_id: str, request: ScalarMetricsIterRawRequest
|
||||
):
|
||||
key = request.key or ScalarKeyEnum.iter
|
||||
scalar_key = ScalarKey.resolve(key)
|
||||
if request.batch_size and request.batch_size < 0:
|
||||
raise errors.bad_request.ValidationError(
|
||||
"batch_size should be non negative number"
|
||||
)
|
||||
|
||||
if not request.scroll_id:
|
||||
from_key_value = None
|
||||
total = None
|
||||
request.batch_size = request.batch_size or 10_000
|
||||
else:
|
||||
try:
|
||||
scroll = ScalarMetricsIterRawScroll.from_scroll_id(request.scroll_id)
|
||||
except ValueError:
|
||||
raise errors.bad_request.InvalidScrollId(scroll_id=request.scroll_id)
|
||||
|
||||
if scroll.from_key_value is None:
|
||||
return make_response(
|
||||
scroll_id=request.scroll_id, total=scroll.total, variants={}
|
||||
)
|
||||
|
||||
from_key_value = scalar_key.cast_value(scroll.from_key_value)
|
||||
total = scroll.total
|
||||
request.batch_size = request.batch_size or scroll.request.batch_size
|
||||
|
||||
task_id = request.task
|
||||
|
||||
task = task_bll.assert_exists(
|
||||
company_id, task_id, allow_public=True, only=("company",),
|
||||
)[0]
|
||||
|
||||
metric_variants = _get_metric_variants_from_request([request.metric])
|
||||
|
||||
if request.count_total and total is None:
|
||||
total = event_bll.events_iterator.count_task_events(
|
||||
event_type=EventType.metrics_scalar,
|
||||
company_id=task.company,
|
||||
task_id=task_id,
|
||||
metric_variants=metric_variants,
|
||||
)
|
||||
|
||||
batch_size = min(
|
||||
request.batch_size,
|
||||
int(
|
||||
config.get("services.events.events_retrieval.max_raw_scalars_size", 200_000)
|
||||
),
|
||||
)
|
||||
|
||||
events = []
|
||||
for iteration in range(0, math.ceil(batch_size / 10_000)):
|
||||
res = event_bll.events_iterator.get_task_events(
|
||||
event_type=EventType.metrics_scalar,
|
||||
company_id=task.company,
|
||||
task_id=task_id,
|
||||
batch_size=min(batch_size, 10_000),
|
||||
navigate_earlier=False,
|
||||
from_key_value=from_key_value,
|
||||
metric_variants=metric_variants,
|
||||
key=key,
|
||||
)
|
||||
if not res.events:
|
||||
break
|
||||
events.extend(res.events)
|
||||
from_key_value = str(events[-1][scalar_key.field])
|
||||
|
||||
key = str(key)
|
||||
variants = {
|
||||
variant: extract_properties_to_lists(
|
||||
["value", scalar_key.field], events, target_keys=["y", key]
|
||||
)
|
||||
for variant, events in bucketize(events, key=itemgetter("variant")).items()
|
||||
}
|
||||
|
||||
call.kpis["events"] = len(events)
|
||||
|
||||
scroll = ScalarMetricsIterRawScroll(
|
||||
from_key_value=str(events[-1][scalar_key.field]) if events else None,
|
||||
total=total,
|
||||
request=request,
|
||||
)
|
||||
|
||||
return make_response(
|
||||
returned=len(events),
|
||||
total=total,
|
||||
scroll_id=scroll.get_scroll_id(),
|
||||
variants=variants,
|
||||
)
|
||||
|
||||
@@ -124,11 +124,15 @@ def get_all_ex(call: APICall, company_id, _):
|
||||
with translate_errors_context():
|
||||
_process_include_subprojects(call.data)
|
||||
with TimingContext("mongo", "models_get_all_ex"):
|
||||
ret_params = {}
|
||||
models = Model.get_many_with_join(
|
||||
company=company_id, query_dict=call.data, allow_public=True
|
||||
company=company_id,
|
||||
query_dict=call.data,
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
call.result.data = {"models": models}
|
||||
call.result.data = {"models": models, **ret_params}
|
||||
|
||||
|
||||
@endpoint("models.get_by_id_ex", required_fields=["id"])
|
||||
@@ -148,14 +152,16 @@ def get_all(call: APICall, company_id, _):
|
||||
conform_tag_fields(call, call.data)
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "models_get_all"):
|
||||
ret_params = {}
|
||||
models = Model.get_many(
|
||||
company=company_id,
|
||||
parameters=call.data,
|
||||
query_dict=call.data,
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, models)
|
||||
call.result.data = {"models": models}
|
||||
call.result.data = {"models": models, **ret_params}
|
||||
|
||||
|
||||
@endpoint("models.get_frameworks", request_data_model=GetFrameworksRequest)
|
||||
@@ -183,7 +189,7 @@ create_fields = {
|
||||
"metadata": list,
|
||||
}
|
||||
|
||||
last_update_fields = ("uri", "framework", "design", "labels", "ready", "metadata")
|
||||
last_update_fields = ("uri", "framework", "design", "labels", "ready", "metadata", "system_tags", "tags")
|
||||
|
||||
|
||||
def parse_model_fields(call, valid_fields):
|
||||
|
||||
@@ -7,7 +7,7 @@ from apiserver.apierrors import errors
|
||||
from apiserver.apierrors.errors.bad_request import InvalidProjectId
|
||||
from apiserver.apimodels.base import UpdateResponse, MakePublicRequest, IdResponse
|
||||
from apiserver.apimodels.projects import (
|
||||
GetHyperParamRequest,
|
||||
GetParamsRequest,
|
||||
ProjectTagsRequest,
|
||||
ProjectTaskParentsRequest,
|
||||
ProjectHyperparamValuesRequest,
|
||||
@@ -16,11 +16,14 @@ from apiserver.apimodels.projects import (
|
||||
MoveRequest,
|
||||
MergeRequest,
|
||||
ProjectOrNoneRequest,
|
||||
ProjectRequest,
|
||||
)
|
||||
from apiserver.bll.organization import OrgBLL, Tags
|
||||
from apiserver.bll.project import ProjectBLL
|
||||
from apiserver.bll.project.project_cleanup import delete_project
|
||||
from apiserver.bll.task import TaskBLL
|
||||
from apiserver.bll.project import ProjectBLL, ProjectQueries
|
||||
from apiserver.bll.project.project_cleanup import (
|
||||
delete_project,
|
||||
validate_project_delete,
|
||||
)
|
||||
from apiserver.database.errors import translate_errors_context
|
||||
from apiserver.database.model.project import Project
|
||||
from apiserver.database.utils import (
|
||||
@@ -37,8 +40,8 @@ from apiserver.services.utils import (
|
||||
from apiserver.timing_context import TimingContext
|
||||
|
||||
org_bll = OrgBLL()
|
||||
task_bll = TaskBLL()
|
||||
project_bll = ProjectBLL()
|
||||
project_queries = ProjectQueries()
|
||||
|
||||
create_fields = {
|
||||
"name": None,
|
||||
@@ -107,8 +110,12 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
|
||||
_adjust_search_parameters(data, shallow_search=request.shallow_search)
|
||||
|
||||
ret_params = {}
|
||||
projects = Project.get_many_with_join(
|
||||
company=company_id, query_dict=data, allow_public=allow_public,
|
||||
company=company_id,
|
||||
query_dict=data,
|
||||
allow_public=allow_public,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
|
||||
if request.check_own_contents and requested_ids:
|
||||
@@ -124,7 +131,7 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
|
||||
conform_output_tags(call, projects)
|
||||
if not request.include_stats:
|
||||
call.result.data = {"projects": projects}
|
||||
call.result.data = {"projects": projects, **ret_params}
|
||||
return
|
||||
|
||||
project_ids = {project["id"] for project in projects}
|
||||
@@ -132,13 +139,14 @@ def get_all_ex(call: APICall, company_id: str, request: ProjectsGetRequest):
|
||||
company=company_id,
|
||||
project_ids=list(project_ids),
|
||||
specific_state=request.stats_for_state,
|
||||
include_children=request.stats_with_children,
|
||||
)
|
||||
|
||||
for project in projects:
|
||||
project["stats"] = stats[project["id"]]
|
||||
project["sub_projects"] = children[project["id"]]
|
||||
|
||||
call.result.data = {"projects": projects}
|
||||
call.result.data = {"projects": projects, **ret_params}
|
||||
|
||||
|
||||
@endpoint("projects.get_all")
|
||||
@@ -147,15 +155,17 @@ def get_all(call: APICall):
|
||||
data = call.data
|
||||
_adjust_search_parameters(data, shallow_search=data.get("shallow_search", False))
|
||||
with translate_errors_context(), TimingContext("mongo", "projects_get_all"):
|
||||
ret_params = {}
|
||||
projects = Project.get_many(
|
||||
company=call.identity.company,
|
||||
query_dict=data,
|
||||
parameters=data,
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, projects)
|
||||
|
||||
call.result.data = {"projects": projects}
|
||||
call.result.data = {"projects": projects, **ret_params}
|
||||
|
||||
|
||||
@endpoint(
|
||||
@@ -230,6 +240,13 @@ def merge(call: APICall, company: str, request: MergeRequest):
|
||||
}
|
||||
|
||||
|
||||
@endpoint("projects.validate_delete")
|
||||
def validate_delete(call: APICall, company_id: str, request: ProjectRequest):
|
||||
call.result.data = validate_project_delete(
|
||||
company=company_id, project_id=request.project
|
||||
)
|
||||
|
||||
|
||||
@endpoint("projects.delete", request_data_model=DeleteRequest)
|
||||
def delete(call: APICall, company_id: str, request: DeleteRequest):
|
||||
res, affected_projects = delete_project(
|
||||
@@ -249,7 +266,7 @@ def get_unique_metric_variants(
|
||||
call: APICall, company_id: str, request: ProjectOrNoneRequest
|
||||
):
|
||||
|
||||
metrics = task_bll.get_unique_metric_variants(
|
||||
metrics = project_queries.get_unique_metric_variants(
|
||||
company_id,
|
||||
[request.project] if request.project else None,
|
||||
include_subprojects=request.include_subprojects,
|
||||
@@ -261,11 +278,11 @@ def get_unique_metric_variants(
|
||||
@endpoint(
|
||||
"projects.get_hyper_parameters",
|
||||
min_version="2.9",
|
||||
request_data_model=GetHyperParamRequest,
|
||||
request_data_model=GetParamsRequest,
|
||||
)
|
||||
def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamRequest):
|
||||
def get_hyper_parameters(call: APICall, company_id: str, request: GetParamsRequest):
|
||||
|
||||
total, remaining, parameters = TaskBLL.get_aggregated_project_parameters(
|
||||
total, remaining, parameters = project_queries.get_aggregated_project_parameters(
|
||||
company_id,
|
||||
project_ids=[request.project] if request.project else None,
|
||||
include_subprojects=request.include_subprojects,
|
||||
@@ -288,7 +305,7 @@ def get_hyper_parameters(call: APICall, company_id: str, request: GetHyperParamR
|
||||
def get_hyperparam_values(
|
||||
call: APICall, company_id: str, request: ProjectHyperparamValuesRequest
|
||||
):
|
||||
total, values = task_bll.get_hyperparam_distinct_values(
|
||||
total, values = project_queries.get_hyperparam_distinct_values(
|
||||
company_id,
|
||||
project_ids=request.projects,
|
||||
section=request.section,
|
||||
|
||||
@@ -48,21 +48,29 @@ def get_by_id(call: APICall):
|
||||
@endpoint("queues.get_all_ex", min_version="2.4")
|
||||
def get_all_ex(call: APICall):
|
||||
conform_tag_fields(call, call.data)
|
||||
ret_params = {}
|
||||
queues = queue_bll.get_queue_infos(
|
||||
company_id=call.identity.company, query_dict=call.data
|
||||
company_id=call.identity.company,
|
||||
query_dict=call.data,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, queues)
|
||||
|
||||
call.result.data = {"queues": queues}
|
||||
call.result.data = {"queues": queues, **ret_params}
|
||||
|
||||
|
||||
@endpoint("queues.get_all", min_version="2.4")
|
||||
def get_all(call: APICall):
|
||||
conform_tag_fields(call, call.data)
|
||||
queues = queue_bll.get_all(company_id=call.identity.company, query_dict=call.data)
|
||||
ret_params = {}
|
||||
queues = queue_bll.get_all(
|
||||
company_id=call.identity.company,
|
||||
query_dict=call.data,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
conform_output_tags(call, queues)
|
||||
|
||||
call.result.data = {"queues": queues}
|
||||
call.result.data = {"queues": queues, **ret_params}
|
||||
|
||||
|
||||
@endpoint("queues.create", min_version="2.4", request_data_model=CreateRequest)
|
||||
|
||||
@@ -4,7 +4,6 @@ from functools import partial
|
||||
from typing import Sequence, Union, Tuple
|
||||
|
||||
import attr
|
||||
import dpath
|
||||
from mongoengine import EmbeddedDocument, Q
|
||||
from mongoengine.queryset.transform import COMPARISON_OPERATORS
|
||||
from pymongo import UpdateOne
|
||||
@@ -220,14 +219,17 @@ def get_all_ex(call: APICall, company_id, _):
|
||||
|
||||
call_data = escape_execution_parameters(call)
|
||||
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "task_get_all_ex"):
|
||||
_process_include_subprojects(call_data)
|
||||
tasks = Task.get_many_with_join(
|
||||
company=company_id, query_dict=call_data, allow_public=True,
|
||||
)
|
||||
unprepare_from_saved(call, tasks)
|
||||
call.result.data = {"tasks": tasks}
|
||||
with TimingContext("mongo", "task_get_all_ex"):
|
||||
_process_include_subprojects(call_data)
|
||||
ret_params = {}
|
||||
tasks = Task.get_many_with_join(
|
||||
company=company_id,
|
||||
query_dict=call_data,
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
unprepare_from_saved(call, tasks)
|
||||
call.result.data = {"tasks": tasks, **ret_params}
|
||||
|
||||
|
||||
@endpoint("tasks.get_by_id_ex", required_fields=["id"])
|
||||
@@ -236,14 +238,13 @@ def get_by_id_ex(call: APICall, company_id, _):
|
||||
|
||||
call_data = escape_execution_parameters(call)
|
||||
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "task_get_by_id_ex"):
|
||||
tasks = Task.get_many_with_join(
|
||||
company=company_id, query_dict=call_data, allow_public=True,
|
||||
)
|
||||
with TimingContext("mongo", "task_get_by_id_ex"):
|
||||
tasks = Task.get_many_with_join(
|
||||
company=company_id, query_dict=call_data, allow_public=True,
|
||||
)
|
||||
|
||||
unprepare_from_saved(call, tasks)
|
||||
call.result.data = {"tasks": tasks}
|
||||
unprepare_from_saved(call, tasks)
|
||||
call.result.data = {"tasks": tasks}
|
||||
|
||||
|
||||
@endpoint("tasks.get_all", required_fields=[])
|
||||
@@ -252,16 +253,17 @@ def get_all(call: APICall, company_id, _):
|
||||
|
||||
call_data = escape_execution_parameters(call)
|
||||
|
||||
with translate_errors_context():
|
||||
with TimingContext("mongo", "task_get_all"):
|
||||
tasks = Task.get_many(
|
||||
company=company_id,
|
||||
parameters=call_data,
|
||||
query_dict=call_data,
|
||||
allow_public=True,
|
||||
)
|
||||
unprepare_from_saved(call, tasks)
|
||||
call.result.data = {"tasks": tasks}
|
||||
with TimingContext("mongo", "task_get_all"):
|
||||
ret_params = {}
|
||||
tasks = Task.get_many(
|
||||
company=company_id,
|
||||
parameters=call_data,
|
||||
query_dict=call_data,
|
||||
allow_public=True,
|
||||
ret_params=ret_params,
|
||||
)
|
||||
unprepare_from_saved(call, tasks)
|
||||
call.result.data = {"tasks": tasks, **ret_params}
|
||||
|
||||
|
||||
@endpoint("tasks.get_types", request_data_model=GetTypesRequest)
|
||||
@@ -403,15 +405,12 @@ def prepare_for_save(call: APICall, fields: dict, previous_task: Task = None):
|
||||
escape_dict_field(fields, path)
|
||||
|
||||
# Strip all script fields (remove leading and trailing whitespace chars) to avoid unusable names and paths
|
||||
for field in task_script_stripped_fields:
|
||||
try:
|
||||
path = f"script/{field}"
|
||||
value = dpath.get(fields, path)
|
||||
script = fields.get("script")
|
||||
if script:
|
||||
for field in task_script_stripped_fields:
|
||||
value = script.get(field)
|
||||
if isinstance(value, str):
|
||||
value = value.strip()
|
||||
dpath.set(fields, path, value)
|
||||
except KeyError:
|
||||
pass
|
||||
script[field] = value.strip()
|
||||
|
||||
return fields
|
||||
|
||||
@@ -546,10 +545,12 @@ def clone_task(call: APICall, company_id, request: CloneRequest):
|
||||
}
|
||||
|
||||
|
||||
def prepare_update_fields(call: APICall, task, call_data):
|
||||
def prepare_update_fields(call: APICall, call_data):
|
||||
valid_fields = deepcopy(Task.user_set_allowed())
|
||||
update_fields = {k: v for k, v in create_fields.items() if k in valid_fields}
|
||||
update_fields["output__error"] = None
|
||||
update_fields.update(
|
||||
status=None, status_reason=None, status_message=None, output__error=None
|
||||
)
|
||||
t_fields = task_fields
|
||||
t_fields.add("output__error")
|
||||
fields = parse_from_call(call_data, update_fields, t_fields)
|
||||
@@ -569,7 +570,7 @@ def update(call: APICall, company_id, req_model: UpdateRequest):
|
||||
if not task:
|
||||
raise errors.bad_request.InvalidTaskId(id=task_id)
|
||||
|
||||
partial_update_dict, valid_fields = prepare_update_fields(call, task, call.data)
|
||||
partial_update_dict, valid_fields = prepare_update_fields(call, call.data)
|
||||
|
||||
if not partial_update_dict:
|
||||
return UpdateResponse(updated=0)
|
||||
@@ -642,7 +643,7 @@ def update_batch(call: APICall, company_id, _):
|
||||
updated_projects = set()
|
||||
for id, data in items.items():
|
||||
task = tasks[id]
|
||||
fields, valid_fields = prepare_update_fields(call, task, data)
|
||||
fields, valid_fields = prepare_update_fields(call, data)
|
||||
partial_update_dict = Task.get_safe_update_dict(fields)
|
||||
if not partial_update_dict:
|
||||
continue
|
||||
@@ -744,8 +745,7 @@ def edit(call: APICall, company_id, req_model: UpdateRequest):
|
||||
"tasks.get_hyper_params", request_data_model=GetHyperParamsRequest,
|
||||
)
|
||||
def get_hyper_params(call: APICall, company_id, request: GetHyperParamsRequest):
|
||||
with translate_errors_context():
|
||||
tasks_params = HyperParams.get_params(company_id, task_ids=request.tasks)
|
||||
tasks_params = HyperParams.get_params(company_id, task_ids=request.tasks)
|
||||
|
||||
call.result.data = {
|
||||
"params": [{"task": task, **data} for task, data in tasks_params.items()]
|
||||
@@ -754,39 +754,36 @@ def get_hyper_params(call: APICall, company_id, request: GetHyperParamsRequest):
|
||||
|
||||
@endpoint("tasks.edit_hyper_params", request_data_model=EditHyperParamsRequest)
|
||||
def edit_hyper_params(call: APICall, company_id, request: EditHyperParamsRequest):
|
||||
with translate_errors_context():
|
||||
call.result.data = {
|
||||
"updated": HyperParams.edit_params(
|
||||
company_id,
|
||||
task_id=request.task,
|
||||
hyperparams=request.hyperparams,
|
||||
replace_hyperparams=request.replace_hyperparams,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
call.result.data = {
|
||||
"updated": HyperParams.edit_params(
|
||||
company_id,
|
||||
task_id=request.task,
|
||||
hyperparams=request.hyperparams,
|
||||
replace_hyperparams=request.replace_hyperparams,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint("tasks.delete_hyper_params", request_data_model=DeleteHyperParamsRequest)
|
||||
def delete_hyper_params(call: APICall, company_id, request: DeleteHyperParamsRequest):
|
||||
with translate_errors_context():
|
||||
call.result.data = {
|
||||
"deleted": HyperParams.delete_params(
|
||||
company_id,
|
||||
task_id=request.task,
|
||||
hyperparams=request.hyperparams,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
call.result.data = {
|
||||
"deleted": HyperParams.delete_params(
|
||||
company_id,
|
||||
task_id=request.task,
|
||||
hyperparams=request.hyperparams,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint(
|
||||
"tasks.get_configurations", request_data_model=GetConfigurationsRequest,
|
||||
)
|
||||
def get_configurations(call: APICall, company_id, request: GetConfigurationsRequest):
|
||||
with translate_errors_context():
|
||||
tasks_params = HyperParams.get_configurations(
|
||||
company_id, task_ids=request.tasks, names=request.names
|
||||
)
|
||||
tasks_params = HyperParams.get_configurations(
|
||||
company_id, task_ids=request.tasks, names=request.names
|
||||
)
|
||||
|
||||
call.result.data = {
|
||||
"configurations": [
|
||||
@@ -801,10 +798,9 @@ def get_configurations(call: APICall, company_id, request: GetConfigurationsRequ
|
||||
def get_configuration_names(
|
||||
call: APICall, company_id, request: GetConfigurationNamesRequest
|
||||
):
|
||||
with translate_errors_context():
|
||||
tasks_params = HyperParams.get_configuration_names(
|
||||
company_id, task_ids=request.tasks, skip_empty=request.skip_empty
|
||||
)
|
||||
tasks_params = HyperParams.get_configuration_names(
|
||||
company_id, task_ids=request.tasks, skip_empty=request.skip_empty
|
||||
)
|
||||
|
||||
call.result.data = {
|
||||
"configurations": [
|
||||
@@ -815,31 +811,29 @@ def get_configuration_names(
|
||||
|
||||
@endpoint("tasks.edit_configuration", request_data_model=EditConfigurationRequest)
|
||||
def edit_configuration(call: APICall, company_id, request: EditConfigurationRequest):
|
||||
with translate_errors_context():
|
||||
call.result.data = {
|
||||
"updated": HyperParams.edit_configuration(
|
||||
company_id,
|
||||
task_id=request.task,
|
||||
configuration=request.configuration,
|
||||
replace_configuration=request.replace_configuration,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
call.result.data = {
|
||||
"updated": HyperParams.edit_configuration(
|
||||
company_id,
|
||||
task_id=request.task,
|
||||
configuration=request.configuration,
|
||||
replace_configuration=request.replace_configuration,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint("tasks.delete_configuration", request_data_model=DeleteConfigurationRequest)
|
||||
def delete_configuration(
|
||||
call: APICall, company_id, request: DeleteConfigurationRequest
|
||||
):
|
||||
with translate_errors_context():
|
||||
call.result.data = {
|
||||
"deleted": HyperParams.delete_configuration(
|
||||
company_id,
|
||||
task_id=request.task,
|
||||
configuration=request.configuration,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
call.result.data = {
|
||||
"deleted": HyperParams.delete_configuration(
|
||||
company_id,
|
||||
task_id=request.task,
|
||||
configuration=request.configuration,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint(
|
||||
@@ -854,6 +848,7 @@ def enqueue(call: APICall, company_id, request: EnqueueRequest):
|
||||
queue_id=request.queue,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
force=request.force,
|
||||
)
|
||||
call.result.data_model = EnqueueResponse(queued=queued, **res)
|
||||
|
||||
@@ -1061,6 +1056,8 @@ def delete(call: APICall, company_id, request: DeleteRequest):
|
||||
force=request.force,
|
||||
return_file_urls=request.return_file_urls,
|
||||
delete_output_models=request.delete_output_models,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
)
|
||||
if deleted:
|
||||
_reset_cached_tags(company_id, projects=[task.project] if task.project else [])
|
||||
@@ -1077,6 +1074,8 @@ def delete_many(call: APICall, company_id, request: DeleteManyRequest):
|
||||
force=request.force,
|
||||
return_file_urls=request.return_file_urls,
|
||||
delete_output_models=request.delete_output_models,
|
||||
status_message=request.status_message,
|
||||
status_reason=request.status_reason,
|
||||
),
|
||||
ids=request.ids,
|
||||
)
|
||||
@@ -1169,15 +1168,14 @@ def ping(_, company_id, request: PingRequest):
|
||||
def add_or_update_artifacts(
|
||||
call: APICall, company_id, request: AddOrUpdateArtifactsRequest
|
||||
):
|
||||
with translate_errors_context():
|
||||
call.result.data = {
|
||||
"updated": Artifacts.add_or_update_artifacts(
|
||||
company_id=company_id,
|
||||
task_id=request.task,
|
||||
artifacts=request.artifacts,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
call.result.data = {
|
||||
"updated": Artifacts.add_or_update_artifacts(
|
||||
company_id=company_id,
|
||||
task_id=request.task,
|
||||
artifacts=request.artifacts,
|
||||
force=True,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint(
|
||||
@@ -1186,31 +1184,28 @@ def add_or_update_artifacts(
|
||||
request_data_model=DeleteArtifactsRequest,
|
||||
)
|
||||
def delete_artifacts(call: APICall, company_id, request: DeleteArtifactsRequest):
|
||||
with translate_errors_context():
|
||||
call.result.data = {
|
||||
"deleted": Artifacts.delete_artifacts(
|
||||
company_id=company_id,
|
||||
task_id=request.task,
|
||||
artifact_ids=request.artifacts,
|
||||
force=request.force,
|
||||
)
|
||||
}
|
||||
call.result.data = {
|
||||
"deleted": Artifacts.delete_artifacts(
|
||||
company_id=company_id,
|
||||
task_id=request.task,
|
||||
artifact_ids=request.artifacts,
|
||||
force=True,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@endpoint("tasks.make_public", min_version="2.9", request_data_model=MakePublicRequest)
|
||||
def make_public(call: APICall, company_id, request: MakePublicRequest):
|
||||
with translate_errors_context():
|
||||
call.result.data = Task.set_public(
|
||||
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=True
|
||||
)
|
||||
call.result.data = Task.set_public(
|
||||
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=True
|
||||
)
|
||||
|
||||
|
||||
@endpoint("tasks.make_private", min_version="2.9", request_data_model=MakePublicRequest)
|
||||
def make_public(call: APICall, company_id, request: MakePublicRequest):
|
||||
with translate_errors_context():
|
||||
call.result.data = Task.set_public(
|
||||
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False
|
||||
)
|
||||
call.result.data = Task.set_public(
|
||||
company_id, request.ids, invalid_cls=InvalidTaskId, enabled=False
|
||||
)
|
||||
|
||||
|
||||
@endpoint("tasks.move", request_data_model=MoveRequest)
|
||||
|
||||
@@ -71,7 +71,7 @@ class TestService(TestCase, TestServiceInterface):
|
||||
delete_params=delete_params,
|
||||
)
|
||||
|
||||
def setUp(self, version="1.7"):
|
||||
def setUp(self, version="999.0"):
|
||||
self._api = APIClient(base_url=f"http://localhost:8008/v{version}")
|
||||
self._deferred = []
|
||||
self._version = parse(version)
|
||||
|
||||
@@ -38,7 +38,7 @@ class TestEntityOrdering(TestService):
|
||||
self._assertGetTasksWithOrdering(order_by=order_field, page=0, page_size=20)
|
||||
|
||||
field_vals = []
|
||||
page_size = 2
|
||||
page_size = 4
|
||||
num_pages = 5
|
||||
for page in range(num_pages):
|
||||
paged_tasks = self._get_page_tasks(
|
||||
|
||||
@@ -2,9 +2,6 @@ from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestOrganization(TestService):
|
||||
def setUp(self, version="2.12"):
|
||||
super().setUp(version=version)
|
||||
|
||||
def test_get_user_companies(self):
|
||||
company = self.api.organization.get_user_companies().companies[0]
|
||||
self.assertEqual(len(company.owners), company.allocated)
|
||||
|
||||
80
apiserver/tests/automated/test_paging_and_scrolling.py
Normal file
80
apiserver/tests/automated/test_paging_and_scrolling.py
Normal file
@@ -0,0 +1,80 @@
|
||||
import math
|
||||
from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestPagingAndScrolling(TestService):
|
||||
name_prefix = f"Test paging "
|
||||
|
||||
def setUp(self, **kwargs):
|
||||
super().setUp(**kwargs)
|
||||
self.task_ids = self._create_tasks()
|
||||
|
||||
def _create_tasks(self):
|
||||
tasks = [
|
||||
self._temp_task(
|
||||
name=f"{self.name_prefix}{i}",
|
||||
hyperparams={
|
||||
"test": {
|
||||
"param": {
|
||||
"section": "test",
|
||||
"name": "param",
|
||||
"type": "str",
|
||||
"value": str(i),
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
for i in range(18)
|
||||
]
|
||||
return tasks
|
||||
|
||||
def test_paging(self):
|
||||
page_size = 10
|
||||
for page in range(0, math.ceil(len(self.task_ids) / page_size)):
|
||||
start = page * page_size
|
||||
expected_size = min(page_size, len(self.task_ids) - start)
|
||||
tasks = self._get_tasks(page=page, page_size=page_size,).tasks
|
||||
self.assertEqual(len(tasks), expected_size)
|
||||
for i, t in enumerate(tasks):
|
||||
self.assertEqual(t.name, f"{self.name_prefix}{start + i}")
|
||||
|
||||
def test_scrolling(self):
|
||||
page_size = 10
|
||||
scroll_id = None
|
||||
for page in range(0, math.ceil(len(self.task_ids) / page_size)):
|
||||
start = page * page_size
|
||||
expected_size = min(page_size, len(self.task_ids) - start)
|
||||
res = self._get_tasks(size=page_size, scroll_id=scroll_id,)
|
||||
self.assertTrue(res.scroll_id)
|
||||
scroll_id = res.scroll_id
|
||||
tasks = res.tasks
|
||||
self.assertEqual(len(tasks), expected_size)
|
||||
for i, t in enumerate(tasks):
|
||||
self.assertEqual(t.name, f"{self.name_prefix}{start + i}")
|
||||
|
||||
# no more data in this scroll
|
||||
tasks = self._get_tasks(size=page_size, scroll_id=scroll_id,).tasks
|
||||
self.assertFalse(tasks)
|
||||
|
||||
# refresh brings all
|
||||
tasks = self._get_tasks(
|
||||
size=page_size, scroll_id=scroll_id, refresh_scroll=True,
|
||||
).tasks
|
||||
self.assertEqual([t.id for t in tasks], self.task_ids)
|
||||
|
||||
def _get_tasks(self, **page_params):
|
||||
return self.api.tasks.get_all_ex(
|
||||
name="^Test paging ",
|
||||
order_by=["hyperparams.test.param.value"],
|
||||
**page_params,
|
||||
)
|
||||
|
||||
def _temp_task(self, name, **kwargs):
|
||||
return self.create_temp(
|
||||
"tasks",
|
||||
name=name,
|
||||
comment="Test task",
|
||||
type="testing",
|
||||
input=dict(view=dict()),
|
||||
**kwargs,
|
||||
)
|
||||
54
apiserver/tests/automated/test_project_delete.py
Normal file
54
apiserver/tests/automated/test_project_delete.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from apiserver.apierrors import errors
|
||||
from apiserver.database.model import EntityVisibility
|
||||
from apiserver.tests.automated import TestService
|
||||
from apiserver.database.utils import id as db_id
|
||||
|
||||
|
||||
class TestProjectsDelete(TestService):
|
||||
def setUp(self, version="2.14"):
|
||||
super().setUp(version=version)
|
||||
|
||||
def new_task(self, **kwargs):
|
||||
return self.create_temp(
|
||||
"tasks", type="testing", name=db_id(), input=dict(view=dict()), **kwargs
|
||||
)
|
||||
|
||||
def new_model(self, **kwargs):
|
||||
return self.create_temp("models", uri="file:///a/b", name=db_id(), labels={}, **kwargs)
|
||||
|
||||
def new_project(self, **kwargs):
|
||||
return self.create_temp("projects", name=db_id(), description="", **kwargs)
|
||||
|
||||
def test_delete_fails_with_active_task(self):
|
||||
project = self.new_project()
|
||||
self.new_task(project=project)
|
||||
res = self.api.projects.validate_delete(project=project)
|
||||
self.assertEqual(res.tasks, 1)
|
||||
self.assertEqual(res.non_archived_tasks, 1)
|
||||
with self.api.raises(errors.bad_request.ProjectHasTasks):
|
||||
self.api.projects.delete(project=project)
|
||||
|
||||
def test_delete_with_archived_task(self):
|
||||
project = self.new_project()
|
||||
self.new_task(project=project, system_tags=[EntityVisibility.archived.value])
|
||||
res = self.api.projects.validate_delete(project=project)
|
||||
self.assertEqual(res.tasks, 1)
|
||||
self.assertEqual(res.non_archived_tasks, 0)
|
||||
self.api.projects.delete(project=project)
|
||||
|
||||
def test_delete_fails_with_active_model(self):
|
||||
project = self.new_project()
|
||||
self.new_model(project=project)
|
||||
res = self.api.projects.validate_delete(project=project)
|
||||
self.assertEqual(res.models, 1)
|
||||
self.assertEqual(res.non_archived_models, 1)
|
||||
with self.api.raises(errors.bad_request.ProjectHasModels):
|
||||
self.api.projects.delete(project=project)
|
||||
|
||||
def test_delete_with_archived_model(self):
|
||||
project = self.new_project()
|
||||
self.new_model(project=project, system_tags=[EntityVisibility.archived.value])
|
||||
res = self.api.projects.validate_delete(project=project)
|
||||
self.assertEqual(res.models, 1)
|
||||
self.assertEqual(res.non_archived_models, 0)
|
||||
self.api.projects.delete(project=project)
|
||||
@@ -6,9 +6,6 @@ from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestQueueAndModelMetadata(TestService):
|
||||
def setUp(self, version="2.13"):
|
||||
super().setUp(version=version)
|
||||
|
||||
meta1 = [{"key": "test_key", "type": "str", "value": "test_value"}]
|
||||
|
||||
def test_queue_metas(self):
|
||||
@@ -72,3 +69,12 @@ class TestQueueAndModelMetadata(TestService):
|
||||
return self.create_temp(
|
||||
"models", uri="file://test", name=name, labels={}, **kwargs
|
||||
)
|
||||
|
||||
def temp_project(self, **kwargs) -> str:
|
||||
self.update_missing(
|
||||
kwargs,
|
||||
name="Test models meta",
|
||||
description="test",
|
||||
delete_params=dict(force=True),
|
||||
)
|
||||
return self.create_temp("projects", **kwargs)
|
||||
|
||||
@@ -12,9 +12,6 @@ from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestSubProjects(TestService):
|
||||
def setUp(self, **kwargs):
|
||||
super().setUp(version="2.13")
|
||||
|
||||
def test_project_aggregations(self):
|
||||
"""This test requires user with user_auth_only... credentials in db"""
|
||||
user2_client = APIClient(
|
||||
@@ -203,6 +200,9 @@ class TestSubProjects(TestService):
|
||||
self.assertEqual(res1.stats["active"]["status_count"]["created"], 0)
|
||||
self.assertEqual(res1.stats["active"]["status_count"]["stopped"], 2)
|
||||
self.assertEqual(res1.stats["active"]["total_runtime"], 2)
|
||||
self.assertEqual(res1.stats["active"]["completed_tasks"], 2)
|
||||
self.assertEqual(res1.stats["active"]["total_tasks"], 2)
|
||||
self.assertEqual(res1.stats["active"]["running_tasks"], 0)
|
||||
self.assertEqual(
|
||||
{sp.name for sp in res1.sub_projects},
|
||||
{
|
||||
@@ -215,6 +215,9 @@ class TestSubProjects(TestService):
|
||||
self.assertEqual(res2.stats["active"]["status_count"]["created"], 0)
|
||||
self.assertEqual(res2.stats["active"]["status_count"]["stopped"], 0)
|
||||
self.assertEqual(res2.stats["active"]["total_runtime"], 0)
|
||||
self.assertEqual(res2.stats["active"]["completed_tasks"], 0)
|
||||
self.assertEqual(res2.stats["active"]["total_tasks"], 0)
|
||||
self.assertEqual(res2.stats["active"]["running_tasks"], 0)
|
||||
self.assertEqual(res2.sub_projects, [])
|
||||
|
||||
def _run_tasks(self, *tasks):
|
||||
|
||||
@@ -198,6 +198,9 @@ class TestTags(TestService):
|
||||
def assertProjectStats(self, project: AttrDict):
|
||||
self.assertEqual(set(project.stats.keys()), {"active"})
|
||||
self.assertAlmostEqual(project.stats.active.total_runtime, 1, places=0)
|
||||
self.assertEqual(project.stats.active.completed_tasks, 1)
|
||||
self.assertEqual(project.stats.active.total_tasks, 1)
|
||||
self.assertEqual(project.stats.active.running_tasks, 0)
|
||||
for status, count in project.stats.active.status_count.items():
|
||||
self.assertEqual(count, 1 if status == "stopped" else 0)
|
||||
|
||||
|
||||
@@ -82,14 +82,10 @@ class TestTasksArtifacts(TestService):
|
||||
|
||||
# test edit running task
|
||||
self.api.tasks.started(task=task)
|
||||
with self.api.raises(InvalidTaskStatus):
|
||||
self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit)
|
||||
self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit, force=True)
|
||||
self.api.tasks.add_or_update_artifacts(task=task, artifacts=edit)
|
||||
res = self.api.tasks.get_all_ex(id=[task]).tasks[0]
|
||||
self._assertTaskArtifacts(artifacts, res)
|
||||
with self.api.raises(InvalidTaskStatus):
|
||||
self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": artifacts[-1]["key"]}])
|
||||
self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": artifacts[-1]["key"]}], force=True)
|
||||
self.api.tasks.delete_artifacts(task=task, artifacts=[{"key": artifacts[-1]["key"]}])
|
||||
res = self.api.tasks.get_all_ex(id=[task]).tasks[0]
|
||||
self._assertTaskArtifacts(artifacts[0: len(artifacts) - 1], res)
|
||||
|
||||
|
||||
@@ -12,9 +12,6 @@ from apiserver.tests.automated import TestService
|
||||
|
||||
|
||||
class TestTaskEvents(TestService):
|
||||
def setUp(self, version="2.9"):
|
||||
super().setUp(version=version)
|
||||
|
||||
def _temp_task(self, name="test task events"):
|
||||
task_input = dict(
|
||||
name=name, type="training", input=dict(mapping={}, view=dict(entries=[])),
|
||||
@@ -257,6 +254,45 @@ class TestTaskEvents(TestService):
|
||||
self.assertEqual(len(task_data["x"]), iterations)
|
||||
self.assertEqual(len(task_data["y"]), iterations)
|
||||
|
||||
def test_task_metric_raw(self):
|
||||
metric = "Metric1"
|
||||
variant = "Variant1"
|
||||
iter_count = 100
|
||||
task = self._temp_task()
|
||||
events = [
|
||||
{
|
||||
**self._create_task_event("training_stats_scalar", task, iteration),
|
||||
"metric": metric,
|
||||
"variant": variant,
|
||||
"value": iteration,
|
||||
}
|
||||
for iteration in range(iter_count)
|
||||
]
|
||||
self.send_batch(events)
|
||||
|
||||
batch_size = 15
|
||||
metric_param = {"metric": metric, "variants": [variant]}
|
||||
res = self.api.events.scalar_metrics_iter_raw(
|
||||
task=task, batch_size=batch_size, metric=metric_param, count_total=True
|
||||
)
|
||||
self.assertEqual(res.total, len(events))
|
||||
self.assertTrue(res.scroll_id)
|
||||
res_iters = []
|
||||
res_ys = []
|
||||
calls = 0
|
||||
while res.returned or calls > 10:
|
||||
calls += 1
|
||||
res_iters.extend(res.variants[variant]["iter"])
|
||||
res_ys.extend(res.variants[variant]["y"])
|
||||
scroll_id = res.scroll_id
|
||||
res = self.api.events.scalar_metrics_iter_raw(
|
||||
task=task, metric=metric_param, scroll_id=scroll_id
|
||||
)
|
||||
|
||||
self.assertEqual(calls, len(events) // batch_size + 1)
|
||||
self.assertEqual(res_iters, [ev["iter"] for ev in events])
|
||||
self.assertEqual(res_ys, [ev["value"] for ev in events])
|
||||
|
||||
def test_task_metric_value_intervals(self):
|
||||
metric = "Metric1"
|
||||
variant = "Variant1"
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
from datetime import timedelta, datetime
|
||||
from threading import Thread
|
||||
from time import sleep
|
||||
from typing import Optional
|
||||
@@ -10,6 +11,7 @@ from semantic_version import Version
|
||||
from apiserver.config_repo import config
|
||||
from apiserver.config.info import get_version
|
||||
from apiserver.database.model.settings import Settings
|
||||
from apiserver.redis_manager import redman
|
||||
from apiserver.utilities.threads_manager import ThreadsManager
|
||||
|
||||
log = config.logger(__name__)
|
||||
@@ -17,6 +19,8 @@ log = config.logger(__name__)
|
||||
|
||||
class CheckUpdatesThread(Thread):
|
||||
_enabled = bool(config.get("apiserver.check_for_updates.enabled", True))
|
||||
_lock_name = "check_updates"
|
||||
_redis = redman.connection("apiserver")
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class _VersionResponse:
|
||||
@@ -29,6 +33,19 @@ class CheckUpdatesThread(Thread):
|
||||
target=self._check_updates, daemon=True
|
||||
)
|
||||
|
||||
@property
|
||||
def update_interval(self):
|
||||
return timedelta(
|
||||
seconds=max(
|
||||
float(
|
||||
config.get(
|
||||
"apiserver.check_for_updates.check_interval_sec", 60 * 60 * 24,
|
||||
)
|
||||
),
|
||||
60 * 5,
|
||||
)
|
||||
)
|
||||
|
||||
def start(self) -> None:
|
||||
if not self._enabled:
|
||||
log.info("Checking for updates is disabled")
|
||||
@@ -37,12 +54,13 @@ class CheckUpdatesThread(Thread):
|
||||
|
||||
@property
|
||||
def component_name(self) -> str:
|
||||
return config.get("apiserver.check_for_updates.component_name", "clearml-server")
|
||||
return config.get(
|
||||
"apiserver.check_for_updates.component_name", "clearml-server"
|
||||
)
|
||||
|
||||
def _check_new_version_available(self) -> Optional[_VersionResponse]:
|
||||
url = config.get(
|
||||
"apiserver.check_for_updates.url",
|
||||
"https://updates.clear.ml/updates",
|
||||
"apiserver.check_for_updates.url", "https://updates.clear.ml/updates",
|
||||
)
|
||||
|
||||
uid = Settings.get_by_key("server.uuid")
|
||||
@@ -81,34 +99,31 @@ class CheckUpdatesThread(Thread):
|
||||
)
|
||||
|
||||
def _check_updates(self):
|
||||
update_interval_sec = max(
|
||||
float(
|
||||
config.get(
|
||||
"apiserver.check_for_updates.check_interval_sec",
|
||||
60 * 60 * 24,
|
||||
)
|
||||
),
|
||||
60 * 5,
|
||||
)
|
||||
while not ThreadsManager.terminating:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
response = self._check_new_version_available()
|
||||
if response:
|
||||
if response.patch_upgrade:
|
||||
log.info(
|
||||
f"{self.component_name.upper()} new package available: upgrade to v{response.version} "
|
||||
f"is recommended!\nRelease Notes:\n{os.linesep.join(response.description)}"
|
||||
)
|
||||
else:
|
||||
log.info(
|
||||
f"{self.component_name.upper()} new version available: upgrade to v{response.version}"
|
||||
f" is recommended!"
|
||||
)
|
||||
except Exception:
|
||||
log.exception("Failed obtaining updates")
|
||||
if self._redis.set(
|
||||
self._lock_name,
|
||||
value=datetime.utcnow().isoformat(),
|
||||
ex=self.update_interval - timedelta(seconds=60),
|
||||
nx=True,
|
||||
):
|
||||
response = self._check_new_version_available()
|
||||
if response:
|
||||
if response.patch_upgrade:
|
||||
log.info(
|
||||
f"{self.component_name.upper()} new package available: upgrade to v{response.version} "
|
||||
f"is recommended!\nRelease Notes:\n{os.linesep.join(response.description)}"
|
||||
)
|
||||
else:
|
||||
log.info(
|
||||
f"{self.component_name.upper()} new version available: upgrade to v{response.version}"
|
||||
f" is recommended!"
|
||||
)
|
||||
except Exception as ex:
|
||||
log.exception("Failed obtaining updates: " + str(ex))
|
||||
|
||||
sleep(update_interval_sec)
|
||||
sleep(self.update_interval.total_seconds())
|
||||
|
||||
|
||||
check_updates_thread = CheckUpdatesThread()
|
||||
|
||||
@@ -10,6 +10,7 @@ def extract_properties_to_lists(
|
||||
key_names: Sequence[str],
|
||||
data: Sequence[dict],
|
||||
extract_func: Optional[Callable[[dict], Tuple]] = None,
|
||||
target_keys: Optional[Sequence[str]] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Given a list of dictionaries and names of dictionary keys
|
||||
@@ -20,9 +21,10 @@ def extract_properties_to_lists(
|
||||
:param extract_func: the optional callable that extracts properties
|
||||
from a dictionary and put them in a tuple in the order corresponding to
|
||||
key_names. If not specified then properties are extracted according to key_names
|
||||
:param target_keys: optional alternative keys to use in the target dictionary. must be equal in length to key_names.
|
||||
"""
|
||||
if not data:
|
||||
return {k: [] for k in key_names}
|
||||
|
||||
value_sequences = zip(*map(extract_func or itemgetter(*key_names), data))
|
||||
return dict(zip(key_names, map(list, value_sequences)))
|
||||
return dict(zip((target_keys or key_names), map(list, value_sequences)))
|
||||
|
||||
@@ -37,15 +37,12 @@ def deep_merge(source: dict, override: dict) -> dict:
|
||||
|
||||
def nested_get(
|
||||
dictionary: Mapping,
|
||||
path: Union[Sequence[str], str],
|
||||
path: Sequence[str],
|
||||
default: Optional[Union[Any, Callable]] = None,
|
||||
) -> Any:
|
||||
if isinstance(path, str):
|
||||
path = [path]
|
||||
|
||||
node = dictionary
|
||||
for key in path:
|
||||
if key not in node:
|
||||
if not node or key not in node:
|
||||
if callable(default):
|
||||
return default()
|
||||
return default
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "1.0.2"
|
||||
__version__ = "1.2.0"
|
||||
|
||||
35
docker/build/Dockerfile
Normal file
35
docker/build/Dockerfile
Normal file
@@ -0,0 +1,35 @@
|
||||
FROM centos/nodejs-12-centos7 AS webapp
|
||||
|
||||
USER root
|
||||
WORKDIR /opt
|
||||
|
||||
RUN git clone https://github.com/allegroai/clearml-web.git
|
||||
RUN mv clearml-web /opt/open-webapp
|
||||
COPY --chmod=744 docker/build/internal_files/build_webapp.sh /tmp/internal_files/
|
||||
RUN /bin/bash -c '/tmp/internal_files/build_webapp.sh'
|
||||
|
||||
FROM centos:7 AS staging_image
|
||||
COPY --chmod=744 docker/build/internal_files/entrypoint.sh /opt/clearml/
|
||||
COPY fileserver /opt/clearml/fileserver/
|
||||
COPY apiserver /opt/clearml/apiserver/
|
||||
|
||||
FROM centos:7
|
||||
COPY --from=staging_image /opt/clearml/ /opt/clearml/
|
||||
|
||||
COPY --chmod=744 docker/build/internal_files/final_image_preparation.sh /tmp/internal_files/
|
||||
COPY docker/build/internal_files/clearml.conf.template /tmp/internal_files/
|
||||
RUN /bin/bash -c '/tmp/internal_files/final_image_preparation.sh'
|
||||
|
||||
COPY --from=webapp /opt/open-webapp/build /usr/share/nginx/html
|
||||
|
||||
EXPOSE 8080
|
||||
EXPOSE 8008
|
||||
EXPOSE 8081
|
||||
|
||||
ARG VERSION
|
||||
ARG BUILD
|
||||
ENV CLEARML_SERVER_VERSION=${VERSION}
|
||||
ENV CLEARML_SERVER_BUILD=${BUILD}
|
||||
|
||||
WORKDIR /opt/clearml/
|
||||
ENTRYPOINT ["/opt/clearml/entrypoint.sh"]
|
||||
9
docker/build/internal_files/build_webapp.sh
Normal file
9
docker/build/internal_files/build_webapp.sh
Normal file
@@ -0,0 +1,9 @@
|
||||
#!/usr/bin/env bash
|
||||
set -x
|
||||
set -e
|
||||
|
||||
cd /opt/open-webapp/
|
||||
npm ci --unsafe-perm node-sass
|
||||
|
||||
cd /opt/open-webapp/
|
||||
npm run build
|
||||
97
docker/build/internal_files/clearml.conf.template
Normal file
97
docker/build/internal_files/clearml.conf.template
Normal file
@@ -0,0 +1,97 @@
|
||||
# For more information on configuration, see:
|
||||
# * Official English Documentation: http://nginx.org/en/docs/
|
||||
# * Official Russian Documentation: http://nginx.org/ru/docs/
|
||||
|
||||
user nginx;
|
||||
worker_processes auto;
|
||||
error_log /var/log/nginx/error.log;
|
||||
pid /run/nginx.pid;
|
||||
|
||||
# Load dynamic modules. See /usr/share/doc/nginx/README.dynamic.
|
||||
include /usr/share/nginx/modules/*.conf;
|
||||
|
||||
events {
|
||||
worker_connections 1024;
|
||||
}
|
||||
|
||||
http {
|
||||
log_format main '$remote_addr - $remote_user [$time_local] "$request" '
|
||||
'$status $body_bytes_sent "$http_referer" '
|
||||
'"$http_user_agent" "$http_x_forwarded_for"';
|
||||
|
||||
access_log /var/log/nginx/access.log main;
|
||||
|
||||
sendfile on;
|
||||
tcp_nopush on;
|
||||
tcp_nodelay on;
|
||||
keepalive_timeout 65;
|
||||
types_hash_max_size 2048;
|
||||
|
||||
include /etc/nginx/mime.types;
|
||||
default_type application/octet-stream;
|
||||
|
||||
# Load modular configuration files from the /etc/nginx/conf.d directory.
|
||||
# See http://nginx.org/en/docs/ngx_core_module.html#include
|
||||
# for more information.
|
||||
include /etc/nginx/conf.d/*.conf;
|
||||
|
||||
server {
|
||||
listen 80 default_server;
|
||||
listen [::]:80 default_server;
|
||||
server_name _;
|
||||
root /usr/share/nginx/html;
|
||||
proxy_http_version 1.1;
|
||||
|
||||
# comppression
|
||||
gzip on;
|
||||
gzip_comp_level 9;
|
||||
gzip_http_version 1.0;
|
||||
gzip_min_length 512;
|
||||
gzip_proxied expired no-cache no-store private auth;
|
||||
gzip_types text/plain
|
||||
text/css
|
||||
application/json
|
||||
application/javascript
|
||||
application/x-javascript
|
||||
text/xml application/xml
|
||||
application/xml+rss
|
||||
text/javascript
|
||||
application/x-font-ttf
|
||||
font/woff2
|
||||
image/svg+xml
|
||||
image/x-icon;
|
||||
|
||||
# Load configuration files for the default server block.
|
||||
include /etc/nginx/default.d/*.conf;
|
||||
|
||||
location / {
|
||||
try_files $uri$args $uri$args/ $uri index.html /index.html;
|
||||
}
|
||||
|
||||
location /version.json {
|
||||
add_header Cache-Control 'no-cache';
|
||||
}
|
||||
|
||||
location /api {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header Host $host;
|
||||
proxy_pass ${NGINX_APISERVER_ADDR};
|
||||
rewrite /api/(.*) /$1 break;
|
||||
}
|
||||
|
||||
location /files {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header Host $host;
|
||||
proxy_pass ${NGINX_FILESERVER_ADDR};
|
||||
rewrite /files/(.*) /$1 break;
|
||||
}
|
||||
|
||||
error_page 404 /404.html;
|
||||
location = /40x.html {
|
||||
}
|
||||
|
||||
error_page 500 502 503 504 /50x.html;
|
||||
location = /50x.html {
|
||||
}
|
||||
}
|
||||
}
|
||||
67
docker/build/internal_files/entrypoint.sh
Normal file
67
docker/build/internal_files/entrypoint.sh
Normal file
@@ -0,0 +1,67 @@
|
||||
#!/usr/bin/env bash
|
||||
set -e
|
||||
|
||||
mkdir -p /var/log/clearml
|
||||
|
||||
SERVER_TYPE=$1
|
||||
|
||||
if (( $# < 1 )) ; then
|
||||
echo "The server type was not stated. It should be either apiserver, webserver or fileserver."
|
||||
sleep 60
|
||||
exit 1
|
||||
|
||||
elif [[ ${SERVER_TYPE} == "apiserver" ]]; then
|
||||
cd /opt/clearml/
|
||||
python3 -m apiserver.apierrors_generator
|
||||
|
||||
if [[ -n $CLEARML_USE_GUNICORN ]]; then
|
||||
MAX_REQUESTS=
|
||||
if [[ -n $CLEARML_GUNICORN_MAX_REQUESTS ]]; then
|
||||
MAX_REQUESTS="--max-requests $CLEARML_GUNICORN_MAX_REQUESTS"
|
||||
if [[ -n $CLEARML_GUNICORN_MAX_REQUESTS_JITTER ]]; then
|
||||
MAX_REQUESTS="$MAX_REQUESTS --max-requests-jitter $CLEARML_GUNICORN_MAX_REQUESTS_JITTER"
|
||||
fi
|
||||
fi
|
||||
|
||||
export GUNICORN_CMD_ARGS=${CLEARML_GUNICORN_CMD_ARGS}
|
||||
|
||||
# Note: don't be tempted to "fix" $MAX_REQUESTS with "$MAX_REQUESTS" as this produces an empty arg which fucks up gunicorn
|
||||
gunicorn \
|
||||
-w "${CLEARML_GUNICORN_WORKERS:-8}" \
|
||||
-t "${CLEARML_GUNICORN_TIMEOUT:-600}" --bind="${CLEARML_GUNICORN_BIND:-0.0.0.0:8008}" \
|
||||
$MAX_REQUESTS apiserver.server:app
|
||||
else
|
||||
python3 -m apiserver.server
|
||||
fi
|
||||
|
||||
elif [[ ${SERVER_TYPE} == "webserver" ]]; then
|
||||
|
||||
if [[ "${USER_KEY}" != "" ]] || [[ "${USER_SECRET}" != "" ]] || [[ "${COMPANY_ID}" != "" ]]; then
|
||||
cat << EOF > /usr/share/nginx/html/credentials.json
|
||||
{
|
||||
"userKey": "${USER_KEY}",
|
||||
"userSecret": "${USER_SECRET}",
|
||||
"companyID": "${COMPANY_ID}"
|
||||
}
|
||||
EOF
|
||||
fi
|
||||
|
||||
export NGINX_APISERVER_ADDR=${NGINX_APISERVER_ADDRESS:-http://apiserver:8008}
|
||||
export NGINX_FILESERVER_ADDR=${NGINX_FILESERVER_ADDRESS:-http://fileserver:8081}
|
||||
|
||||
envsubst '${NGINX_APISERVER_ADDR} ${NGINX_FILESERVER_ADDR}' < /etc/nginx/clearml.conf.template > /etc/nginx/nginx.conf
|
||||
|
||||
#start the server
|
||||
/usr/sbin/nginx -g "daemon off;"
|
||||
|
||||
elif [[ ${SERVER_TYPE} == "fileserver" ]]; then
|
||||
cd /opt/clearml/fileserver/
|
||||
if [ "$FILESERVER_USE_GUNICORN" = true ] ; then
|
||||
gunicorn -t 600 --bind=0.0.0.0:8081 fileserver:app
|
||||
else
|
||||
python3 fileserver.py
|
||||
fi
|
||||
|
||||
else
|
||||
echo "Server type ${SERVER_TYPE} is invalid. Please choose either apiserver, webserver or fileserver."
|
||||
fi
|
||||
18
docker/build/internal_files/final_image_preparation.sh
Normal file
18
docker/build/internal_files/final_image_preparation.sh
Normal file
@@ -0,0 +1,18 @@
|
||||
#!/usr/bin/env bash
|
||||
set -o errexit
|
||||
set -o nounset
|
||||
set -o pipefail
|
||||
|
||||
yum update -y
|
||||
yum install -y https://dl.fedoraproject.org/pub/epel/epel-release-latest-7.noarch.rpm
|
||||
yum install -y python36 python36-pip nginx gcc python3-devel gettext
|
||||
yum -y upgrade
|
||||
python3 -m pip install -r /opt/clearml/fileserver/requirements.txt
|
||||
python3 -m pip install -r /opt/clearml/apiserver/requirements.txt
|
||||
mkdir -p /opt/clearml/log
|
||||
mkdir -p /opt/clearml/config
|
||||
ln -s /dev/stdout /var/log/nginx/access.log
|
||||
ln -s /dev/stderr /var/log/nginx/error.log
|
||||
mv /etc/nginx/nginx.conf /etc/nginx/nginx.conf.orig
|
||||
mv /tmp/internal_files/clearml.conf.template /etc/nginx/clearml.conf.template
|
||||
yum clean all
|
||||
@@ -19,6 +19,7 @@ services:
|
||||
environment:
|
||||
CLEARML_ELASTIC_SERVICE_HOST: elasticsearch
|
||||
CLEARML_ELASTIC_SERVICE_PORT: 9200
|
||||
CLEARML_ELASTIC_SERVICE_PASSWORD: ${ELASTIC_PASSWORD}
|
||||
CLEARML_MONGODB_SERVICE_HOST: mongo
|
||||
CLEARML_MONGODB_SERVICE_PORT: 27017
|
||||
CLEARML_REDIS_SERVICE_HOST: redis
|
||||
@@ -38,7 +39,8 @@ services:
|
||||
- backend
|
||||
container_name: clearml-elastic
|
||||
environment:
|
||||
ES_JAVA_OPTS: -Xms2g -Xmx2g
|
||||
ES_JAVA_OPTS: -Xms2g -Xmx2g -Dlog4j2.formatMsgNoLookups=true
|
||||
ELASTIC_PASSWORD: ${ELASTIC_PASSWORD}
|
||||
bootstrap.memory_lock: "true"
|
||||
cluster.name: clearml
|
||||
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
|
||||
@@ -60,7 +62,7 @@ services:
|
||||
nofile:
|
||||
soft: 65536
|
||||
hard: 65536
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:7.6.2
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:7.16.2
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- c:/opt/clearml/data/elastic_7:/usr/share/elasticsearch/data
|
||||
@@ -87,7 +89,7 @@ services:
|
||||
networks:
|
||||
- backend
|
||||
container_name: clearml-mongo
|
||||
image: mongo:3.6.5
|
||||
image: mongo:3.6.23
|
||||
restart: unless-stopped
|
||||
command: --setParameter internalQueryExecMaxBlockingSortBytes=196100200
|
||||
volumes:
|
||||
|
||||
@@ -19,6 +19,7 @@ services:
|
||||
environment:
|
||||
CLEARML_ELASTIC_SERVICE_HOST: elasticsearch
|
||||
CLEARML_ELASTIC_SERVICE_PORT: 9200
|
||||
CLEARML_ELASTIC_SERVICE_PASSWORD: ${ELASTIC_PASSWORD}
|
||||
CLEARML_MONGODB_SERVICE_HOST: mongo
|
||||
CLEARML_MONGODB_SERVICE_PORT: 27017
|
||||
CLEARML_REDIS_SERVICE_HOST: redis
|
||||
@@ -38,7 +39,8 @@ services:
|
||||
- backend
|
||||
container_name: clearml-elastic
|
||||
environment:
|
||||
ES_JAVA_OPTS: -Xms2g -Xmx2g
|
||||
ES_JAVA_OPTS: -Xms2g -Xmx2g -Dlog4j2.formatMsgNoLookups=true
|
||||
ELASTIC_PASSWORD: ${ELASTIC_PASSWORD}
|
||||
bootstrap.memory_lock: "true"
|
||||
cluster.name: clearml
|
||||
cluster.routing.allocation.node_initial_primaries_recoveries: "500"
|
||||
@@ -60,7 +62,7 @@ services:
|
||||
nofile:
|
||||
soft: 65536
|
||||
hard: 65536
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:7.6.2
|
||||
image: docker.elastic.co/elasticsearch/elasticsearch:7.16.2
|
||||
restart: unless-stopped
|
||||
volumes:
|
||||
- /opt/clearml/data/elastic_7:/usr/share/elasticsearch/data
|
||||
@@ -86,7 +88,7 @@ services:
|
||||
networks:
|
||||
- backend
|
||||
container_name: clearml-mongo
|
||||
image: mongo:3.6.5
|
||||
image: mongo:3.6.23
|
||||
restart: unless-stopped
|
||||
command: --setParameter internalQueryExecMaxBlockingSortBytes=196100200
|
||||
volumes:
|
||||
@@ -121,7 +123,9 @@ services:
|
||||
- backend
|
||||
container_name: clearml-agent-services
|
||||
image: allegroai/clearml-agent-services:latest
|
||||
restart: unless-stopped
|
||||
deploy:
|
||||
restart_policy:
|
||||
condition: on-failure
|
||||
privileged: true
|
||||
environment:
|
||||
CLEARML_HOST_IP: ${CLEARML_HOST_IP}
|
||||
@@ -132,7 +136,7 @@ services:
|
||||
CLEARML_API_SECRET_KEY: ${CLEARML_API_SECRET_KEY:-}
|
||||
CLEARML_AGENT_GIT_USER: ${CLEARML_AGENT_GIT_USER}
|
||||
CLEARML_AGENT_GIT_PASS: ${CLEARML_AGENT_GIT_PASS}
|
||||
CLEARML_AGENT_UPDATE_VERSION: ${CLEARML_AGENT_UPDATE_VERSION:->=0.17.0}
|
||||
CLEARML_AGENT_UPDATE_VERSION: ${CLEARML_AGENT_UPDATE_VERSION:-">=0.17.0"}
|
||||
CLEARML_AGENT_DEFAULT_BASE_DOCKER: "ubuntu:18.04"
|
||||
AWS_ACCESS_KEY_ID: ${AWS_ACCESS_KEY_ID:-}
|
||||
AWS_SECRET_ACCESS_KEY: ${AWS_SECRET_ACCESS_KEY:-}
|
||||
@@ -142,6 +146,7 @@ services:
|
||||
GOOGLE_APPLICATION_CREDENTIALS: ${GOOGLE_APPLICATION_CREDENTIALS:-}
|
||||
CLEARML_WORKER_ID: "clearml-services"
|
||||
CLEARML_AGENT_DOCKER_HOST_MOUNT: "/opt/clearml/agent:/root/.clearml"
|
||||
SHUTDOWN_IF_NO_ACCESS_KEY: 1
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
- /opt/clearml/agent:/root/.clearml
|
||||
|
||||
BIN
docs/ClearML_Server_Diagram.png
Normal file
BIN
docs/ClearML_Server_Diagram.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 155 KiB |
@@ -1,5 +1,6 @@
|
||||
""" A Simple file server for uploading and downloading files """
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
from pathlib import Path
|
||||
@@ -48,8 +49,12 @@ def upload():
|
||||
@app.route("/<path:path>", methods=["GET"])
|
||||
def download(path):
|
||||
as_attachment = "download" in request.args
|
||||
|
||||
_, encoding = mimetypes.guess_type(os.path.basename(path))
|
||||
mimetype = "application/octet-stream" if encoding == "gzip" else None
|
||||
|
||||
response = send_from_directory(
|
||||
app.config["UPLOAD_FOLDER"], path, as_attachment=as_attachment
|
||||
app.config["UPLOAD_FOLDER"], path, as_attachment=as_attachment, mimetype=mimetype
|
||||
)
|
||||
if config.get("fileserver.download.disable_browser_caching", False):
|
||||
headers = response.headers
|
||||
|
||||
Reference in New Issue
Block a user