mirror of
https://github.com/clearml/clearml
synced 2025-06-26 18:16:07 +00:00
Initial beta version
This commit is contained in:
218
README.md
Normal file
@@ -0,0 +1,218 @@
# TRAINS - Magic Version Control & Experiment Manager for AI

<p style="font-size:1.2rem; font-weight:700;">"Because it’s a jungle out there"</p>

Behind every great scientist are great repeatable methods. Sadly, this is easier said than done.

When talented scientists, engineers, or developers work on their own, the resulting mess may still be
manageable. However, as time passes and more people join the project,
managing the clutter takes its toll on productivity.
As your project moves toward production,
visibility and provenance become a must for scaling your deep-learning efforts, but both
suffer as your team grows.

For teams or entire companies, TRAINS logs everything in one central server and takes on the responsibility for visibility and provenance,
so productivity does not suffer.
TRAINS records and manages a wide variety of deep learning research workloads, and does so with remarkably small integration costs.

TRAINS is an auto-magical experiment manager that you can use productively with minimal integration while
preserving your existing methods and practices. Use it daily to boost collaboration and visibility,
or use it to automatically collect your experimentation logs, outputs, and data into one centralized server for provenance.

![Alt Text](https://github.com/allegroai/trains/blob/master/docs/webapp_screenshots.gif?raw=true)

## Why Should I Use TRAINS?

TRAINS is our solution to a problem we share with countless other researchers and developers in the
machine learning/deep learning universe.
Training production-grade deep learning models is a glorious but messy process.
We built TRAINS to solve that problem: it tracks and controls the process by associating code version control, research projects, performance metrics, and model provenance.
TRAINS removes the mess but leaves the glory.

Choose TRAINS because...

* Sharing experiments with the team is difficult, and gets even more difficult further up the chain.
* Like all of us, you lost a model and are left with no repeatable process.
* You set up a central location for TensorBoard and it exploded with a gazillion experiments.
* You accidentally threw away important results while trying to manually clean up the clutter.
* You do not associate the training code commit with the model or the TensorBoard logs.
* You are storing model parameters in the checkpoint filename.
* You cannot find any other tool for comparing results, hyper-parameters, and code commits.
* TRAINS requires **only two lines of code** for full integration.
* TRAINS is **free**.

## Main Features

* Seamless integration with leading frameworks, including PyTorch, TensorFlow, Keras, and others coming soon!
* Track everything with only two lines of code.
* Model logging that automatically associates models with the code and the parameters used to train them, including initial weights logging.
* Multi-user process tracking and collaboration.
* Management capabilities, including project management, filter-by-metric, and detailed experiment comparison.
* Centralized server for aggregating logs, records, and general bookkeeping.
* Automatic copying of models to centralized storage (TRAINS supports shared folders, S3, and GS; Azure is coming soon!).
* Support for Jupyter Notebook (see the [trains-jupyter-plugin]()) and PyCharm remote debugging (see the [trains-pycharm-plugin]()).
* A field-tested, feature-rich SDK for your on-the-fly customization needs.

## TRAINS Magically Logs

TRAINS magically logs the following:

* Git repository, branch, and commit ID
* Hyper-parameters, including:
    * ArgParser command-line parameters with their currently used values
    * TensorFlow Defines (absl-py)
    * Manually passed parameter dictionaries
* Initial model weights file
* Model snapshots
* stdout and stderr
* TensorBoard scalars, metrics, histograms, and images (tensorboardX is also supported; audio is coming soon)
* Matplotlib

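The ArgParser hook above can be sketched in a few lines. This is a minimal illustration: the parameter names are made up, and the `Task.init` call is shown commented out so the snippet runs without a trains server.

```python
from argparse import ArgumentParser

# from trains import Task
# task = Task.init(project_name='examples', task_name='argparse logging')
# With the two lines above enabled, every value parsed below would be
# recorded automatically as a hyper-parameter of the experiment.

parser = ArgumentParser(description='training configuration')
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--learning_rate', type=float, default=0.001)
args = parser.parse_args([])  # in a real script, use parse_args()

print(args.batch_size, args.learning_rate)
```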
## See for Yourself

We have a demo server up and running at [https://demoapp.trainsai.io](https://demoapp.trainsai.io) (it resets every 24 hours and all of the data is deleted).
You can test your code with it:

1. Install TRAINS:

        pip install trains

1. Add the following lines to your code:

        from trains import Task
        task = Task.init(project_name="my_project", task_name="my_task")

1. Run your code. When TRAINS connects to the server, it prints a direct link to the experiment page.

1. In the Web-App, view your experiment's parameters, model, and TensorBoard metrics.

![GIF screen-shot here. If the Gif looks bad, a few png screen grabs:
Home Page
Projects Page
Experiment Page with experiment open tab execution
Experiment Page with experiment open tab model
Experiment Page with experiment open tab results
Results Page
Comparison Page
Parameters
Graphs
Images
Experiment Models Page]

## How TRAINS Works

TRAINS is composed of the following:

* the [trains-server]()
* the [Web-App]() (web user interface)
* the Python SDK (auto-magically connects your code; see [Using TRAINS (Example)](#using-trains-example))

The following diagram illustrates the interaction of the TRAINS-server and a GPU machine:

<pre>
TRAINS-server
+-----------------------------------------------------------+
|  Server Docker           Elastic Docker     Mongo Docker  |
|  (Pythonic Server)                                        |
|   WEB server  (8080)     ElasticSearch      MongoDB       |
|   API server  (8008) --> ElasticSearch, MongoDB           |
|   File Server (8081) --> Host Storage                     |
+-----------------------------+-----------------------------+
                              | HTTP
+-----------------------------+-----------------------------+
|  GPU Machine                                              |
|  Training Code --> TRAINS SDK --> ~/trains.conf           |
+-----------------------------------------------------------+
</pre>

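The three ports in the diagram above are the endpoints the SDK talks to. The sketch below shows how they might appear in `~/trains.conf`; the file is normally generated for you by `trains-init`, so treat the field names and credential values as placeholders that may differ between versions.

```
# ~/trains.conf -- sketch only; generated for you by trains-init
api {
    web_server: "http://localhost:8080"     # Web-App
    api_server: "http://localhost:8008"     # API server
    files_server: "http://localhost:8081"   # File server
    credentials {
        access_key: "YOUR_ACCESS_KEY"       # placeholder
        secret_key: "YOUR_SECRET_KEY"       # placeholder
    }
}
```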
## Installing and Configuring TRAINS

1. Install the trains-server docker (see [Installing the TRAINS Server](../trains_server)).

1. Install the TRAINS package:

        pip install trains

1. Run the initial configuration wizard to set up the trains-server connection (ip:port and user credentials):

        trains-init

After installing and configuring, your configuration file is `~/trains.conf`. View a sample configuration file [here]([link to git]).

## Using TRAINS (Example)

Add these two lines of code to your script:

    from trains import Task
    task = Task.init(project_name, task_name)

* If no project name is provided, then the repository name is used.
* If no task (experiment) name is provided, then the main filename is used as the experiment name.

Executing your script prints a direct link to the currently running experiment page, for example:

```bash
TRAINS Metrics page:

https://demoapp.trainsai.io/projects/76e5e2d45e914f52880621fe64601e85/experiments/241f06ae0f5c4b27b8ce8b64890ce152/output/log
```

*[Add GIF screenshots here]*

For more examples and use cases, see [examples](link docs/examples/).

## Who Supports TRAINS?

The people behind *allegro.ai*.
We build deep learning pipelines and infrastructure for enterprise companies.
We built TRAINS to track and control the glorious
but messy process of training production-grade deep learning models.
We are committed to vigorously supporting and expanding the capabilities of TRAINS,
because not only is it our beloved creation, we also use it daily.

## Why Are We Releasing TRAINS?

We believe TRAINS is ground-breaking. We wish to establish new standards of experiment management in
machine learning and deep learning.
Only the greater community can help us do that.

We promise to always be backward compatible. If you start working with TRAINS today, even though this code is still in the beta stage, your logs and data will always upgrade with you.

## License

Apache License, Version 2.0 (see the [LICENSE](https://www.apache.org/licenses/LICENSE-2.0.html) for more information)

## Guidelines for Contributing

See the TRAINS [Guidelines for Contributing](contributing.md).

## FAQ

See the TRAINS [FAQ](faq.md).

<p style="font-size:0.9rem; font-weight:700; font-style:italic">May the force (and the goddess of learning rates) be with you!</p>

54
docs/contributing.md
Normal file
@@ -0,0 +1,54 @@
# Guidelines for Contributing

First of all, we thank you for taking the time to contribute!

The following is a set of guidelines for contributing to TRAINS.
These are primarily guidelines, not rules.
Use your best judgment, and feel free to propose changes to this document in a pull request.

## Reporting Bugs

This section guides you through submitting a bug report for TRAINS.
By following these guidelines, you
help maintainers and the community understand your report, reproduce the behavior, and find related reports.

Before creating a bug report, please check whether the bug you want to report already appears [here](link to issues).
You may discover that you do not need to create a bug report.
When you do create a bug report, please include as much detail as possible.

**Note**: If you find a **Closed** issue that seems to be the same issue you are currently experiencing,
open a **New** issue and include a link to the original (Closed) issue in the body of your new one.

Explain the problem and include additional details to help maintainers reproduce it:

* **Use a clear and descriptive title** for the issue to identify the problem.
* **Describe the exact steps necessary to reproduce the problem** in as much detail as possible. Please do not just summarize what you did; make sure to explain how you did it.
* **Provide your specific environment setup.** Include the `pip freeze` output, specific environment variables, the Python version, and other relevant information.
* **Provide specific examples to demonstrate the steps.** Include links to files or GitHub projects, or copy/paste snippets which you use in those examples.
* **If you are reporting a TRAINS crash,** include a crash report with a stack trace from the operating system. Make sure to add the crash report to the issue in a [code block](https://help.github.com/en/articles/getting-started-with-writing-and-formatting-on-github#multiple-lines),
a [file attachment](https://help.github.com/articles/file-attachments-on-issues-and-pull-requests/), or a [gist](https://gist.github.com/) (and provide the link to that gist).
* **Describe the behavior you observed after following the steps** and the exact problem with that behavior.
* **Explain which behavior you expected to see instead, and why.**
* **For Web-App issues, please include screenshots and animated GIFs** which recreate the described steps and clearly demonstrate the problem. You can use [LICEcap](https://www.cockos.com/licecap/) to record GIFs on macOS and Windows, and [silentcast](https://github.com/colinkeenan/silentcast) or [byzanz](https://github.com/threedaymonk/byzanz) on Linux.

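The environment details above can be collected in one go with standard commands (the output filename is arbitrary):

```shell
# collect the environment details a bug report should include
python3 --version > environment.txt 2>&1
pip3 freeze >> environment.txt 2>/dev/null || true
uname -a >> environment.txt
```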
## Suggesting Enhancements

This section guides you through submitting an enhancement suggestion for TRAINS, including
completely new features and minor improvements to existing functionality.
By following these guidelines, you help maintainers and the community understand your suggestion and find related suggestions.

Enhancement suggestions are tracked as GitHub issues. After you determine which repository your enhancement suggestion is related to, create an issue on that repository and provide the following:

* **A clear and descriptive title** for the issue to identify the suggestion.
* **A step-by-step description of the suggested enhancement** in as much detail as possible.
* **Specific examples to demonstrate the steps.** Include copy/pasteable snippets which you use in those examples as [Markdown code blocks](https://help.github.com/articles/markdown-basics/#multiple-lines).
* **A description of the current behavior**, along with an explanation of which behavior you expected to see instead and why.
* **Screenshots or animated GIFs** which help demonstrate the steps or point out the part of TRAINS the suggestion relates to. You can use [LICEcap](https://www.cockos.com/licecap/) to record GIFs on macOS and Windows, and [silentcast](https://github.com/colinkeenan/silentcast) or [byzanz](https://github.com/threedaymonk/byzanz) on Linux.

160
docs/faq.md
Normal file
@@ -0,0 +1,160 @@
# FAQ

**Can I store more information on the models? For example, can I store an enumeration of classes?**

YES!

Use the SDK `set_model_label_enumeration` method:

```python
Task.current_task().set_model_label_enumeration({'label': int(0), })
```

**Can I store the model configuration file as well?**

YES!

Use the SDK `set_model_design` method:

```python
Task.current_task().set_model_design('a very long text of the configuration file content')
```

**I want to add more graphs, not just with TensorBoard. Is this supported?**

YES!

Use an SDK [Logger](link to git) object. An instance can always be retrieved with `Task.current_task().get_logger()`:

```python
logger = Task.current_task().get_logger()
logger.report_scalar("loss", "classification", iteration=42, value=1.337)
```

TRAINS supports scalars, plots, 2d/3d scatter diagrams, histograms, surface diagrams, confusion matrices, images, and text logging.

An example can be found [here](docs/manual_log.py).

**I noticed that all of my experiments appear as "Training". Are there other options?**

YES!

When creating experiments and calling `Task.init`, you can pass an experiment type.
The currently supported types are `Task.TaskTypes.training` and `Task.TaskTypes.testing`:

```python
task = Task.init(project_name, task_name, Task.TaskTypes.testing)
```

If you feel we should add a few more, let us know in the [issues]() section.

**I keep getting the message "warning: uncommitted code". What does it mean?**

TRAINS not only detects your current repository and git commit,
it also warns you when you are running uncommitted code. TRAINS does this
because uncommitted code makes it difficult to reproduce the experiment.

**Is there something you can do about that uncommitted code?**

YES!

TRAINS currently stores the git diff together with the project.
The Web-App will present the git diff as well. This is coming very soon!

**I read that there is a feature for centralized model storage. How do I use it?**

Pass the `output_uri` parameter to `Task.init`, for example:

```python
Task.init(project_name, task_name, output_uri='/mnt/shared/folder')
```

All of the stored snapshots are copied into a subfolder whose name contains the task ID, for example:

`/mnt/shared/folder/task_6ea4f0b56d994320a713aeaf13a86d9d/models/`

Other options include:

```python
Task.init(project_name, task_name, output_uri='s3://bucket/folder')
```

```python
Task.init(project_name, task_name, output_uri='gs://bucket/folder')
```

These options require configuring the cloud storage credentials in `~/trains.conf` (see an [example]()).

**I am training multiple models at the same time, but I only see one of them. What happened?**

This will be fixed in a future version. TRAINS does support multiple models
from the same task/experiment, and you can find all of them in the project's Models tab.
In the Task view, we currently present only the last one.

**Can I log input and output models manually?**

YES!

See [InputModel]() and [OutputModel]().

For example:

```python
input_model = InputModel.import_model(link_to_initial_model_file)
Task.current_task().connect(input_model)
OutputModel(Task.current_task()).update_weights(link_to_new_model_file_here)
```

**I am using Jupyter Notebook. Is this supported?**

YES!

Jupyter Notebook is supported.

**I do not use ArgParser for hyper-parameters. Do you have a solution?**

YES!

TRAINS supports using a Python dictionary for hyper-parameter logging:

```python
parameters_dict = Task.current_task().connect(parameters_dict)
```

From this point onward, not only are the dictionary key/value pairs stored, but any change to the dictionary is also stored automatically.

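As a sketch of the dictionary mechanics described above (the parameter names are illustrative, and the `connect` call is commented out because it needs a running trains task):

```python
parameters = {
    'learning_rate': 0.001,
    'batch_size': 128,
    'use_augmentation': True,
}

# from trains import Task
# parameters = Task.current_task().connect(parameters)

# Once connect() is enabled, an in-place change like this one is
# recorded automatically as an updated hyper-parameter value:
parameters['learning_rate'] = 0.0005

print(parameters['learning_rate'])
```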
**Git is not well supported in Jupyter. We just gave up on properly committing our code. Do you have a solution?**

YES!

Check out our [trains-jupyter-plugin](). It is a Jupyter plugin that allows you to commit your notebook directly from Jupyter. It also saves the Python version of your code and creates an updated `requirements.txt`, so you know which packages you were using.

**Can I use TRAINS with scikit-learn?**

YES!

scikit-learn is supported. Everything you do is logged, with the caveat that models are not logged automatically.
Models are not logged automatically because, in most cases, scikit-learn simply pickles the object to a file, so there is no underlying framework to hook into.

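Since scikit-learn models are plain pickled objects, one workaround is to store the pickle yourself and register the file manually. The sketch below uses a stand-in dictionary instead of a fitted estimator, and the `OutputModel` call is commented out because it needs a running trains task (the exact keyword arguments may differ between versions):

```python
import pickle

# stand-in for a fitted scikit-learn estimator
model = {'coef_': [0.5, -1.2], 'intercept_': 0.1}

# store the pickle yourself, as scikit-learn would
with open('model.pkl', 'wb') as f:
    pickle.dump(model, f)

# from trains import Task, OutputModel
# OutputModel(task=Task.current_task()).update_weights(weights_filename='model.pkl')

# round-trip check: the stored file restores to the same object
with open('model.pkl', 'rb') as f:
    restored = pickle.load(f)
print(restored['coef_'])
```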
**I am working with PyCharm and remotely debugging a machine, but the git repo is not detected. Do you have a solution?**

YES!

This is such a common occurrence that we created a PyCharm plugin that allows a remote debugger to grab your local repository / commit ID. See the [trains-pycharm-plugin]() repository for instructions and the [latest release]().

**How do I know when a new version comes out?**

Unfortunately, TRAINS currently does not support auto-update checks. We hope to add this soon.

**Sometimes I see experiments as running when they are not. What is going on?**

When the Python process exits in an orderly fashion, TRAINS closes the experiment.
When a process crashes, the stop signal is sometimes missed. In that case, you can safely right-click the experiment in the Web-App and stop it.

**In the experiment log tab, I'm missing the first log lines. Where are they?**

Unfortunately, due to speed/optimization constraints, we opted to display only the last several hundred log lines. You can download the full log from the Web-App.

43
examples/absl_example.py
Normal file
@@ -0,0 +1,43 @@
# TRAINS - example code, absl logging
#
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys

from absl import app
from absl import flags
from absl import logging

from trains import Task


FLAGS = flags.FLAGS

flags.DEFINE_string('echo', None, 'Text to echo.')
flags.DEFINE_string('another_str', 'My string', 'A string', module_name='test')

task = Task.init(project_name='examples', task_name='absl example')

flags.DEFINE_integer('echo3', 3, 'Text to echo.')
flags.DEFINE_string('echo5', '5', 'Text to echo.', module_name='test')


parameters = {
    'list': [1, 2, 3],
    'dict': {'a': 1, 'b': 2},
    'int': 3,
    'float': 2.2,
    'string': 'my string',
}
parameters = task.connect(parameters)


def main(_):
    print('Running under Python {0[0]}.{0[1]}.{0[2]}'.format(sys.version_info), file=sys.stderr)
    logging.info('echo is %s.', FLAGS.echo)


if __name__ == '__main__':
    app.run(main)
160
examples/jupyter.ipynb
Normal file
File diff suppressed because one or more lines are too long
113
examples/keras_tensorboard.py
Normal file
@@ -0,0 +1,113 @@
# TRAINS - Keras with Tensorboard example code, automatic logging model and Tensorboard outputs
#
# Train a simple deep NN on the MNIST dataset.
# Gets to 98.40% test accuracy after 20 epochs
# (there is *a lot* of margin for parameter tuning).
# 2 seconds per epoch on a K520 GPU.
from __future__ import print_function

import numpy as np
import tensorflow

from keras.callbacks import TensorBoard, ModelCheckpoint
from keras.datasets import mnist
from keras.models import Sequential, Model
from keras.layers.core import Dense, Dropout, Activation
from keras.optimizers import SGD, Adam, RMSprop
from keras.utils import np_utils
from keras.models import load_model, save_model, model_from_json

from trains import Task


class TensorBoardImage(TensorBoard):
    @staticmethod
    def make_image(tensor):
        import tensorflow as tf
        from PIL import Image
        tensor = np.stack((tensor, tensor, tensor), axis=2)
        height, width, channels = tensor.shape
        image = Image.fromarray(tensor)
        import io
        output = io.BytesIO()
        image.save(output, format='PNG')
        image_string = output.getvalue()
        output.close()
        return tf.Summary.Image(height=height,
                                width=width,
                                colorspace=channels,
                                encoded_image_string=image_string)

    def on_epoch_end(self, epoch, logs={}):
        super().on_epoch_end(epoch, logs)
        import tensorflow as tf
        images = self.validation_data[0]  # 0 - data; 1 - labels
        img = (255 * images[0].reshape(28, 28)).astype('uint8')

        image = self.make_image(img)
        summary = tf.Summary(value=[tf.Summary.Value(tag='image', image=image)])
        self.writer.add_summary(summary, epoch)


batch_size = 128
nb_classes = 10
nb_epoch = 6

# the data, shuffled and split between train and test sets
(X_train, y_train), (X_test, y_test) = mnist.load_data()

X_train = X_train.reshape(60000, 784)
X_test = X_test.reshape(10000, 784)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255.
X_test /= 255.
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
Y_train = np_utils.to_categorical(y_train, nb_classes)
Y_test = np_utils.to_categorical(y_test, nb_classes)

model = Sequential()
model.add(Dense(512, input_shape=(784,)))
model.add(Activation('relu'))
# model.add(Dropout(0.2))
model.add(Dense(512))
model.add(Activation('relu'))
# model.add(Dropout(0.2))
model.add(Dense(10))
model.add(Activation('softmax'))

model2 = Sequential()
model2.add(Dense(512, input_shape=(784,)))
model2.add(Activation('relu'))

model.summary()

model.compile(loss='categorical_crossentropy',
              optimizer=RMSprop(),
              metrics=['accuracy'])

# Connecting TRAINS
task = Task.init(project_name='examples', task_name='Keras with TensorBoard example')
# setting model outputs
labels = dict(('digit_%d' % i, i) for i in range(10))
task.set_model_label_enumeration(labels)

board = TensorBoard(histogram_freq=1, log_dir='/tmp/histogram_example', write_images=False)
model_store = ModelCheckpoint(filepath='/tmp/histogram_example/weight.{epoch}.hdf5')

# load previous model, if it is there
try:
    model.load_weights('/tmp/histogram_example/weight.1.hdf5')
except Exception:
    pass

history = model.fit(X_train, Y_train,
                    batch_size=batch_size, epochs=nb_epoch,
                    callbacks=[board, model_store],
                    verbose=1, validation_data=(X_test, Y_test))
score = model.evaluate(X_test, Y_test, verbose=0)
print('Test score:', score[0])
print('Test accuracy:', score[1])
29
examples/manual_model_config.py
Normal file
@@ -0,0 +1,29 @@
# TRAINS - Example of manual model configuration
#
import torch
from trains import Task


task = Task.init(project_name='examples', task_name='Manual model configuration')

# create a model
model = torch.nn.Module()

# store a dictionary of definitions for a specific network design
model_config_dict = {
    'value': 13.37,
    'dict': {'sub_value': 'string'},
    'list_of_ints': [1, 2, 3, 4],
}
task.set_model_config(config_dict=model_config_dict)

# or read from a config file (this will override the previous configuration dictionary)
# task.set_model_config(config_text='this is just a blob\nof text from a configuration file')

# store the label enumeration the model is training for
task.set_model_label_enumeration({'background': 0, 'cat': 1, 'dog': 2})
print('Any model stored from this point onwards will contain both model_config and label_enumeration')

# storing the model; it will have the task network configuration and label enumeration
torch.save(model, '/tmp/model')
print('Model saved')
51
examples/manual_reporting.py
Normal file
@@ -0,0 +1,51 @@
# TRAINS - Example of manual graphs and statistics reporting
#
import numpy as np
import logging
from trains import Task


task = Task.init(project_name='examples', task_name='Manual reporting')

# example Python logger
logging.getLogger().setLevel('DEBUG')
logging.debug('This is a debug message')
logging.info('This is an info message')
logging.warning('This is a warning message')
logging.error('This is an error message')
logging.critical('This is a critical message')

# get the TRAINS logger object for any metrics / reports
logger = task.get_logger()

# log text
logger.console("hello")

# report scalar values
logger.report_scalar("example_scalar", "series A", iteration=0, value=100)
logger.report_scalar("example_scalar", "series A", iteration=1, value=200)

# report histogram
histogram = np.random.randint(10, size=10)
logger.report_vector("example_histogram", "random histogram", iteration=1, values=histogram)

# report confusion matrix
confusion = np.random.randint(10, size=(10, 10))
logger.report_matrix("example_confusion", "ignored", iteration=1, matrix=confusion)

# report 2d scatter plot
scatter2d = np.hstack((np.atleast_2d(np.arange(0, 10)).T, np.random.randint(10, size=(10, 1))))
logger.report_scatter2d("example_scatter", "series_xy", iteration=1, scatter=scatter2d)

# report 3d scatter plot
scatter3d = np.random.randint(10, size=(10, 3))
logger.report_scatter3d("example_scatter_3d", "series_xyz", iteration=1, scatter=scatter3d)

# report image
m = np.eye(256, 256, dtype=np.uint8) * 255
logger.report_image_and_upload("fail cases", "image uint", iteration=1, matrix=m)
m = np.eye(256, 256, dtype=np.float)
logger.report_image_and_upload("fail cases", "image float", iteration=1, matrix=m)

# flush reports (otherwise they are flushed in the background every couple of seconds)
logger.flush()
36
examples/matplotlib_example.py
Normal file
@@ -0,0 +1,36 @@
# TRAINS - Example of Matplotlib integration and reporting
#
import numpy as np
import matplotlib.pyplot as plt
from trains import Task


task = Task.init(project_name='examples', task_name='Matplotlib example')

# create plot
N = 50
x = np.random.rand(N)
y = np.random.rand(N)
colors = np.random.rand(N)
area = (30 * np.random.rand(N))**2  # 0 to 15 point radii
plt.scatter(x, y, s=area, c=colors, alpha=0.5)
plt.show()

# create another plot - with a name
x = np.linspace(0, 10, 30)
y = np.sin(x)
plt.plot(x, y, 'o', color='black')
plt.show()

# create image plot
m = np.eye(256, 256, dtype=np.uint8)
plt.imshow(m)
plt.show()

# create image plot - with a name
m = np.eye(256, 256, dtype=np.uint8)
plt.imshow(m)
plt.title('Image Title')
plt.show()

print('This is a Matplotlib example')
479
examples/pytorch_matplotlib.py
Normal file
@@ -0,0 +1,479 @@
# TRAINS - Example of Pytorch and matplotlib integration and reporting
#
"""
Neural Transfer Using PyTorch
=============================
**Author**: `Alexis Jacq <https://alexis-jacq.github.io>`_

**Edited by**: `Winston Herring <https://github.com/winston6>`_

Introduction
------------
This tutorial explains how to implement the `Neural-Style algorithm <https://arxiv.org/abs/1508.06576>`__
developed by Leon A. Gatys, Alexander S. Ecker and Matthias Bethge.
Neural-Style, or Neural-Transfer, allows you to take an image and
reproduce it with a new artistic style. The algorithm takes three images,
an input image, a content-image, and a style-image, and changes the input
to resemble the content of the content-image and the artistic style of the style-image.

.. figure:: /_static/img/neural-style/neuralstyle.png
   :alt: content1
"""

######################################################################
# Underlying Principle
# --------------------
#
# The principle is simple: we define two distances, one for the content
# (:math:`D_C`) and one for the style (:math:`D_S`). :math:`D_C` measures how different the content
# is between two images while :math:`D_S` measures how different the style is
# between two images. Then, we take a third image, the input, and
# transform it to minimize both its content-distance with the
# content-image and its style-distance with the style-image. Now we can
# import the necessary packages and begin the neural transfer.
#
# Importing Packages and Selecting a Device
# -----------------------------------------
# Below is a list of the packages needed to implement the neural transfer.
#
# - ``torch``, ``torch.nn``, ``numpy`` (indispensable packages for
#   neural networks with PyTorch)
# - ``torch.optim`` (efficient gradient descents)
# - ``PIL``, ``PIL.Image``, ``matplotlib.pyplot`` (load and display
#   images)
# - ``torchvision.transforms`` (transform PIL images into tensors)
# - ``torchvision.models`` (train or load pre-trained models)
# - ``copy`` (to deep copy the models; system package)

from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from PIL import Image
import matplotlib.pyplot as plt

import torchvision.transforms as transforms
import torchvision.models as models

import copy
from trains import Task


task = Task.init(project_name='examples', task_name='pytorch with matplotlib example', task_type=Task.TaskTypes.testing)


######################################################################
# Next, we need to choose which device to run the network on and import the
# content and style images. Running the neural transfer algorithm on large
# images takes longer and will go much faster when running on a GPU. We can
# use ``torch.cuda.is_available()`` to detect if there is a GPU available.
# Next, we set the ``torch.device`` for use throughout the tutorial. Also, the ``.to(device)``
# method is used to move tensors or modules to a desired device.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

######################################################################
# Loading the Images
# ------------------
#
# Now we will import the style and content images. The original PIL images have values between 0 and 255, but when
# transformed into torch tensors, their values are converted to be between
# 0 and 1. The images also need to be resized to have the same dimensions.
# An important detail to note is that neural networks from the
# torch library are trained with tensor values ranging from 0 to 1. If you
# try to feed the networks with 0 to 255 tensor images, then the activated
# feature maps will be unable to sense the intended content and style.
# However, pre-trained networks from the Caffe library are trained with 0
# to 255 tensor images.
#
#
# .. Note::
#     Here are links to download the images required to run the tutorial:
#     `picasso.jpg <https://pytorch.org/tutorials/_static/img/neural-style/picasso.jpg>`__ and
#     `dancing.jpg <https://pytorch.org/tutorials/_static/img/neural-style/dancing.jpg>`__.
#     Download these two images and add them to a directory
#     with name ``images`` in your current working directory.

# desired size of the output image
imsize = 512 if torch.cuda.is_available() else 128  # use small size if no gpu

loader = transforms.Compose([
    transforms.Resize(imsize),  # scale imported image
    transforms.ToTensor()])  # transform it into a torch tensor


def image_loader(image_name):
    image = Image.open(image_name)
    # fake batch dimension required to fit network's input dimensions
    image = loader(image).unsqueeze(0)
    return image.to(device, torch.float)


style_img = image_loader("./samples/picasso.jpg")
content_img = image_loader("./samples/dancing.jpg")

assert style_img.size() == content_img.size(), \
    "we need to import style and content images of the same size"

######################################################################
# Now, let's create a function that displays an image by reconverting a
# copy of it to PIL format and displaying the copy using
# ``plt.imshow``. We will try displaying the content and style images
# to ensure they were imported correctly.

unloader = transforms.ToPILImage()  # reconvert into PIL image

plt.ion()


def imshow(tensor, title=None):
    image = tensor.cpu().clone()  # we clone the tensor so we do not change it
    image = image.squeeze(0)  # remove the fake batch dimension
    image = unloader(image)
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


plt.figure()
imshow(style_img, title='Style Image')

plt.figure()
imshow(content_img, title='Content Image')


######################################################################
# Loss Functions
# --------------
# Content Loss
# ~~~~~~~~~~~~
#
# The content loss is a function that represents a weighted version of the
# content distance for an individual layer. The function takes the feature
# maps :math:`F_{XL}` of a layer :math:`L` in a network processing input :math:`X` and returns the
# weighted content distance :math:`w_{CL}.D_C^L(X,C)` between the image :math:`X` and the
# content image :math:`C`. The feature maps of the content image (:math:`F_{CL}`) must be
# known by the function in order to calculate the content distance. We
# implement this function as a torch module with a constructor that takes
# :math:`F_{CL}` as an input. The distance :math:`\|F_{XL} - F_{CL}\|^2` is the mean square error
# between the two sets of feature maps, and can be computed using ``nn.MSELoss``.
#
# We will add this content loss module directly after the convolution
# layer(s) that are being used to compute the content distance. This way
# each time the network is fed an input image the content losses will be
# computed at the desired layers, and because of autograd, all the
# gradients will be computed. Now, in order to make the content loss layer
# transparent we must define a ``forward`` method that computes the content
# loss and then returns the layer's input. The computed loss is saved as a
# parameter of the module.
#

class ContentLoss(nn.Module):

    def __init__(self, target):
        super(ContentLoss, self).__init__()
        # we 'detach' the target content from the tree used
        # to dynamically compute the gradient: this is a stated value,
        # not a variable. Otherwise the forward method of the criterion
        # will throw an error.
        self.target = target.detach()

    def forward(self, input):
        self.loss = F.mse_loss(input, self.target)
        return input


######################################################################
# .. Note::
#    **Important detail**: although this module is named ``ContentLoss``, it
#    is not a true PyTorch Loss function. If you want to define your content
#    loss as a PyTorch Loss function, you have to create a PyTorch autograd function
#    to recompute/implement the gradient manually in the ``backward``
#    method.
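As a minimal, framework-free sketch (the class name `TransparentLoss` and the plain-list inputs are illustrative, not part of the tutorial), the "transparent layer" idea above boils down to a callable that records a side value but returns its input unchanged:

```python
# Sketch of the "transparent layer" pattern used by ContentLoss above:
# the forward pass stores a mean-squared-error side value, then passes
# the input through untouched so the rest of the network is unaffected.
class TransparentLoss:
    def __init__(self, target):
        self.target = target
        self.loss = None

    def __call__(self, x):
        # mean squared error against the stored target
        self.loss = sum((a - b) ** 2 for a, b in zip(x, self.target)) / len(x)
        return x  # return the input unchanged

layer = TransparentLoss([1.0, 2.0, 3.0])
out = layer([1.0, 2.0, 5.0])
assert out == [1.0, 2.0, 5.0]          # input flows through unmodified
assert abs(layer.loss - 4.0 / 3.0) < 1e-9  # (0 + 0 + 4) / 3
```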
######################################################################
|
||||
# Style Loss
|
||||
# ~~~~~~~~~~
|
||||
#
|
||||
# The style loss module is implemented similarly to the content loss
|
||||
# module. It will act as a transparent layer in a
|
||||
# network that computes the style loss of that layer. In order to
|
||||
# calculate the style loss, we need to compute the gram matrix :math:`G_{XL}`. A gram
|
||||
# matrix is the result of multiplying a given matrix by its transposed
|
||||
# matrix. In this application the given matrix is a reshaped version of
|
||||
# the feature maps :math:`F_{XL}` of a layer :math:`L`. :math:`F_{XL}` is reshaped to form :math:`\hat{F}_{XL}`, a :math:`K`\ x\ :math:`N`
|
||||
# matrix, where :math:`K` is the number of feature maps at layer :math:`L` and :math:`N` is the
|
||||
# length of any vectorized feature map :math:`F_{XL}^k`. For example, the first line
|
||||
# of :math:`\hat{F}_{XL}` corresponds to the first vectorized feature map :math:`F_{XL}^1`.
|
||||
#
|
||||
# Finally, the gram matrix must be normalized by dividing each element by
|
||||
# the total number of elements in the matrix. This normalization is to
|
||||
# counteract the fact that :math:`\hat{F}_{XL}` matrices with a large :math:`N` dimension yield
|
||||
# larger values in the Gram matrix. These larger values will cause the
|
||||
# first layers (before pooling layers) to have a larger impact during the
|
||||
# gradient descent. Style features tend to be in the deeper layers of the
|
||||
# network so this normalization step is crucial.
|
||||
#
|
||||
|
||||
def gram_matrix(input):
|
||||
a, b, c, d = input.size() # a=batch size(=1)
|
||||
# b=number of feature maps
|
||||
# (c,d)=dimensions of a f. map (N=c*d)
|
||||
|
||||
features = input.view(a * b, c * d) # resise F_XL into \hat F_XL
|
||||
|
||||
G = torch.mm(features, features.t()) # compute the gram product
|
||||
|
||||
# we 'normalize' the values of the gram matrix
|
||||
# by dividing by the number of element in each feature maps.
|
||||
return G.div(a * b * c * d)
|
||||
|
||||
|
||||
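The reshape-multiply-normalize recipe above can be sketched in plain NumPy (the shapes below are illustrative, not taken from the tutorial):

```python
import numpy as np

# Pure-NumPy sketch of the normalized gram matrix described above:
# flatten a (a, b, c, d) feature tensor to a K x N matrix, multiply by
# its transpose, then divide by the total number of elements.
a, b, c, d = 1, 3, 4, 4                      # batch, feature maps, height, width
fmaps = np.random.rand(a, b, c, d)
features = fmaps.reshape(a * b, c * d)       # \hat{F}_{XL}: K x N
G = features @ features.T / (a * b * c * d)  # normalize by element count

assert G.shape == (b, b)                     # one row/column per feature map
assert np.allclose(G, G.T)                   # gram matrices are symmetric
```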
######################################################################
|
||||
# Now the style loss module looks almost exactly like the content loss
|
||||
# module. The style distance is also computed using the mean square
|
||||
# error between :math:`G_{XL}` and :math:`G_{SL}`.
|
||||
#
|
||||
|
||||
class StyleLoss(nn.Module):
|
||||
|
||||
def __init__(self, target_feature):
|
||||
super(StyleLoss, self).__init__()
|
||||
self.target = gram_matrix(target_feature).detach()
|
||||
|
||||
def forward(self, input):
|
||||
G = gram_matrix(input)
|
||||
self.loss = F.mse_loss(G, self.target)
|
||||
return input
|
||||
|
||||
|
||||
######################################################################
|
||||
# Importing the Model
|
||||
# -------------------
|
||||
#
|
||||
# Now we need to import a pre-trained neural network. We will use a 19
|
||||
# layer VGG network like the one used in the paper.
|
||||
#
|
||||
# PyTorch’s implementation of VGG is a module divided into two child
|
||||
# ``Sequential`` modules: ``features`` (containing convolution and pooling layers),
|
||||
# and ``classifier`` (containing fully connected layers). We will use the
|
||||
# ``features`` module because we need the output of the individual
|
||||
# convolution layers to measure content and style loss. Some layers have
|
||||
# different behavior during training than evaluation, so we must set the
|
||||
# network to evaluation mode using ``.eval()``.
|
||||
#
|
||||
|
||||
cnn = models.vgg19(pretrained=True).features.to(device).eval()
|
||||
|
||||
######################################################################
|
||||
# Additionally, VGG networks are trained on images with each channel
|
||||
# normalized by mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
|
||||
# We will use them to normalize the image before sending it into the network.
|
||||
#
|
||||
|
||||
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
|
||||
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
|
||||
|
||||
|
||||
# create a module to normalize input image so we can easily put it in a
|
||||
# nn.Sequential
|
||||
class Normalization(nn.Module):
|
||||
def __init__(self, mean, std):
|
||||
super(Normalization, self).__init__()
|
||||
# .view the mean and std to make them [C x 1 x 1] so that they can
|
||||
# directly work with image Tensor of shape [B x C x H x W].
|
||||
# B is batch size. C is number of channels. H is height and W is width.
|
||||
self.mean = torch.tensor(mean).view(-1, 1, 1)
|
||||
self.std = torch.tensor(std).view(-1, 1, 1)
|
||||
|
||||
def forward(self, img):
|
||||
# normalize img
|
||||
return (img - self.mean) / self.std
|
||||
|
||||
|
||||
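The `[C x 1 x 1]` reshape in `Normalization` works because of broadcasting; a small NumPy sketch of the same trick (the dummy all-ones "image" is purely illustrative):

```python
import numpy as np

# Reshaping per-channel mean/std to (C, 1, 1) lets them broadcast
# against a (B, C, H, W) image tensor, normalizing each channel.
mean = np.array([0.485, 0.456, 0.406]).reshape(-1, 1, 1)
std = np.array([0.229, 0.224, 0.225]).reshape(-1, 1, 1)

img = np.ones((1, 3, 2, 2))          # dummy all-ones "image"
normed = (img - mean) / std

assert normed.shape == (1, 3, 2, 2)  # shape is preserved
# first channel uses the first mean/std pair
assert np.isclose(normed[0, 0, 0, 0], (1 - 0.485) / 0.229)
```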
######################################################################
|
||||
# A ``Sequential`` module contains an ordered list of child modules. For
|
||||
# instance, ``vgg19.features`` contains a sequence (Conv2d, ReLU, MaxPool2d,
|
||||
# Conv2d, ReLU…) aligned in the right order of depth. We need to add our
|
||||
# content loss and style loss layers immediately after the convolution
|
||||
# layer they are detecting. To do this we must create a new ``Sequential``
|
||||
# module that has content loss and style loss modules correctly inserted.
|
||||
#
|
||||
|
||||
# desired depth layers to compute style/content losses :
|
||||
content_layers_default = ['conv_4']
|
||||
style_layers_default = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
|
||||
|
||||
|
||||
def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
|
||||
style_img, content_img,
|
||||
content_layers=content_layers_default,
|
||||
style_layers=style_layers_default):
|
||||
cnn = copy.deepcopy(cnn)
|
||||
|
||||
# normalization module
|
||||
normalization = Normalization(normalization_mean, normalization_std).to(device)
|
||||
|
||||
# just in order to have an iterable access to or list of content/syle
|
||||
# losses
|
||||
content_losses = []
|
||||
style_losses = []
|
||||
|
||||
# assuming that cnn is a nn.Sequential, so we make a new nn.Sequential
|
||||
# to put in modules that are supposed to be activated sequentially
|
||||
model = nn.Sequential(normalization)
|
||||
|
||||
i = 0 # increment every time we see a conv
|
||||
for layer in cnn.children():
|
||||
if isinstance(layer, nn.Conv2d):
|
||||
i += 1
|
||||
name = 'conv_{}'.format(i)
|
||||
elif isinstance(layer, nn.ReLU):
|
||||
name = 'relu_{}'.format(i)
|
||||
# The in-place version doesn't play very nicely with the ContentLoss
|
||||
# and StyleLoss we insert below. So we replace with out-of-place
|
||||
# ones here.
|
||||
layer = nn.ReLU(inplace=False)
|
||||
elif isinstance(layer, nn.MaxPool2d):
|
||||
name = 'pool_{}'.format(i)
|
||||
elif isinstance(layer, nn.BatchNorm2d):
|
||||
name = 'bn_{}'.format(i)
|
||||
else:
|
||||
raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))
|
||||
|
||||
model.add_module(name, layer)
|
||||
|
||||
if name in content_layers:
|
||||
# add content loss:
|
||||
target = model(content_img).detach()
|
||||
content_loss = ContentLoss(target)
|
||||
model.add_module("content_loss_{}".format(i), content_loss)
|
||||
content_losses.append(content_loss)
|
||||
|
||||
if name in style_layers:
|
||||
# add style loss:
|
||||
target_feature = model(style_img).detach()
|
||||
style_loss = StyleLoss(target_feature)
|
||||
model.add_module("style_loss_{}".format(i), style_loss)
|
||||
style_losses.append(style_loss)
|
||||
|
||||
# now we trim off the layers after the last content and style losses
|
||||
for i in range(len(model) - 1, -1, -1):
|
||||
if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
|
||||
break
|
||||
|
||||
model = model[:(i + 1)]
|
||||
|
||||
return model, style_losses, content_losses
|
||||
|
||||
|
||||
######################################################################
|
||||
# Next, we select the input image. You can use a copy of the content image
|
||||
# or white noise.
|
||||
#
|
||||
|
||||
input_img = content_img.clone()
|
||||
# if you want to use white noise instead uncomment the below line:
|
||||
# input_img = torch.randn(content_img.data.size(), device=device)
|
||||
|
||||
# add the original input image to the figure:
|
||||
plt.figure()
|
||||
imshow(input_img, title='Input Image')
|
||||
|
||||
|
||||
######################################################################
|
||||
# Gradient Descent
|
||||
# ----------------
|
||||
#
|
||||
# As Leon Gatys, the author of the algorithm, suggested `here <https://discuss.pytorch.org/t/pytorch-tutorial-for-neural-transfert-of-artistic-style/336/20?u=alexis-jacq>`__, we will use
|
||||
# L-BFGS algorithm to run our gradient descent. Unlike training a network,
|
||||
# we want to train the input image in order to minimise the content/style
|
||||
# losses. We will create a PyTorch L-BFGS optimizer ``optim.LBFGS`` and pass
|
||||
# our image to it as the tensor to optimize.
|
||||
#
|
||||
|
||||
def get_input_optimizer(input_img):
|
||||
# this line to show that input is a parameter that requires a gradient
|
||||
optimizer = optim.LBFGS([input_img.requires_grad_()])
|
||||
return optimizer
|
||||
|
||||
|
||||
######################################################################
|
||||
# Finally, we must define a function that performs the neural transfer. For
|
||||
# each iteration of the networks, it is fed an updated input and computes
|
||||
# new losses. We will run the ``backward`` methods of each loss module to
|
||||
# dynamicaly compute their gradients. The optimizer requires a “closure”
|
||||
# function, which reevaluates the modul and returns the loss.
|
||||
#
|
||||
# We still have one final constraint to address. The network may try to
|
||||
# optimize the input with values that exceed the 0 to 1 tensor range for
|
||||
# the image. We can address this by correcting the input values to be
|
||||
# between 0 to 1 each time the network is run.
|
||||
#
|
||||
|
||||
def run_style_transfer(cnn, normalization_mean, normalization_std,
|
||||
content_img, style_img, input_img, num_steps=300,
|
||||
style_weight=1000000, content_weight=1):
|
||||
"""Run the style transfer."""
|
||||
print('Building the style transfer model..')
|
||||
model, style_losses, content_losses = get_style_model_and_losses(cnn,
|
||||
normalization_mean, normalization_std, style_img,
|
||||
content_img)
|
||||
optimizer = get_input_optimizer(input_img)
|
||||
|
||||
print('Optimizing..')
|
||||
run = [0]
|
||||
while run[0] <= num_steps:
|
||||
|
||||
def closure():
|
||||
# correct the values of updated input image
|
||||
input_img.data.clamp_(0, 1)
|
||||
|
||||
optimizer.zero_grad()
|
||||
model(input_img)
|
||||
style_score = 0
|
||||
content_score = 0
|
||||
|
||||
for sl in style_losses:
|
||||
style_score += sl.loss
|
||||
for cl in content_losses:
|
||||
content_score += cl.loss
|
||||
|
||||
style_score *= style_weight
|
||||
content_score *= content_weight
|
||||
|
||||
loss = style_score + content_score
|
||||
loss.backward()
|
||||
|
||||
run[0] += 1
|
||||
if run[0] % 50 == 0:
|
||||
print("run {}:".format(run))
|
||||
print('Style Loss : {:4f} Content Loss: {:4f}'.format(
|
||||
style_score.item(), content_score.item()))
|
||||
print()
|
||||
|
||||
return style_score + content_score
|
||||
|
||||
optimizer.step(closure)
|
||||
|
||||
# a last correction...
|
||||
input_img.data.clamp_(0, 1)
|
||||
|
||||
return input_img
|
||||
|
||||
|
||||
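The "closure" contract above — a function the optimizer can call repeatedly, which re-evaluates the objective, updates state, and returns the loss — can be sketched without PyTorch. Here a toy hand-derived 1-D gradient step stands in for `optim.LBFGS`; the variable names are illustrative only:

```python
# Sketch of the closure pattern: the optimizer-side loop only sees a
# zero-argument callable that returns the current loss, mirroring how
# optimizer.step(closure) may evaluate the objective several times.
x = [5.0]                      # parameter wrapped in a list, like run = [0]

def closure():
    loss = (x[0] - 2.0) ** 2   # objective: minimum at x = 2
    grad = 2.0 * (x[0] - 2.0)  # hand-derived gradient
    x[0] -= 0.1 * grad         # update the parameter in place
    return loss

for _ in range(100):           # stand-in for repeated optimizer steps
    closure()

assert abs(x[0] - 2.0) < 1e-3  # the parameter converged to the minimum
```

Wrapping the mutable state in a list (`x = [5.0]`, like `run = [0]` above) lets the inner closure rebind it without `nonlocal`, which is why the tutorial uses the same idiom for its step counter.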
######################################################################
|
||||
# Finally, we can run the algorithm.
|
||||
#
|
||||
|
||||
output = run_style_transfer(cnn, cnn_normalization_mean, cnn_normalization_std,
|
||||
content_img, style_img, input_img)
|
||||
|
||||
plt.figure()
|
||||
imshow(output, title='Output Image')
|
||||
|
||||
# sphinx_gallery_thumbnail_number = 4
|
||||
plt.ioff()
|
||||
plt.show()
|
||||
124
examples/pytorch_mnist.py
Normal file
@@ -0,0 +1,124 @@
# TRAINS - Example of Pytorch mnist training integration
#
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from trains import Task
task = Task.init(project_name='examples', task_name='pytorch mnist train')


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 20, 5, 1)
        self.conv2 = nn.Conv2d(20, 50, 5, 1)
        # 28x28 input -> conv5 -> 24x24 -> pool -> 12x12 -> conv5 -> 8x8 -> pool -> 4x4
        self.fc1 = nn.Linear(4 * 4 * 50, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 50)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def train(args, model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def test(args, model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


def main():
    # Training settings
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                        help='input batch size for testing (default: 1000)')
    parser.add_argument('--epochs', type=int, default=10, metavar='N',
                        help='number of epochs to train (default: 10)')
    parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                        help='SGD momentum (default: 0.5)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training')
    parser.add_argument('--seed', type=int, default=1, metavar='S',
                        help='random seed (default: 1)')
    parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                        help='how many batches to wait before logging training status')

    parser.add_argument('--save-model', action='store_true', default=True,
                        help='For Saving the current Model')
    args = parser.parse_args()
    use_cuda = not args.no_cuda and torch.cuda.is_available()

    torch.manual_seed(args.seed)

    device = torch.device("cuda" if use_cuda else "cpu")

    kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {}
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=True, download=True,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ])),
        batch_size=args.batch_size, shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('../data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])),
        batch_size=args.test_batch_size, shuffle=True, **kwargs)

    model = Net().to(device)
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

    for epoch in range(1, args.epochs + 1):
        train(args, model, device, train_loader, optimizer, epoch)
        test(args, model, device, test_loader)

    if args.save_model:
        torch.save(model.state_dict(), "/tmp/mnist_cnn.pt")


if __name__ == '__main__':
    main()
126
examples/pytorch_tensorboard.py
Normal file
@@ -0,0 +1,126 @@
# TRAINS - Example of pytorch with tensorboard>=v1.14
#
from __future__ import print_function

import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter

from trains import Task
task = Task.init(project_name='examples', task_name='pytorch with tensorboard')


writer = SummaryWriter('runs')
writer.add_text('lstm', 'This is an lstm', 0)
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=2, metavar='N',
                    help='number of epochs to train (default: 2)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                    help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}
train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True,
                                                          transform=transforms.Compose([
                                                              transforms.ToTensor(),
                                                              transforms.Normalize((0.1307,), (0.3081,))])),
                                           batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=False,
                                                         transform=transforms.Compose([
                                                             transforms.ToTensor(),
                                                             transforms.Normalize((0.1307,), (0.3081,))])),
                                          batch_size=args.batch_size, shuffle=True, **kwargs)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = Net()
if args.cuda:
    model.cuda()
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)


def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.data.item()))
            niter = epoch * len(train_loader) + batch_idx
            writer.add_scalar('Train/Loss', loss.data.item(), niter)


def test():
    model.eval()
    test_loss = 0
    correct = 0
    for niter, (data, target) in enumerate(test_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)
        test_loss += F.nll_loss(output, target, size_average=False).data.item()  # sum up batch loss
        pred = output.data.max(1)[1]  # get the index of the max log-probability
        pred = pred.eq(target.data).cpu().sum()
        writer.add_scalar('Test/Loss', pred, niter)
        correct += pred

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


for epoch in range(1, args.epochs + 1):
    train(epoch)
    torch.save(model, '/tmp/model{}'.format(epoch))
    test()
126
examples/pytorch_tensorboardX.py
Normal file
@@ -0,0 +1,126 @@
# TRAINS - Example of pytorch with tensorboardX
#
from __future__ import print_function

import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from tensorboardX import SummaryWriter

from trains import Task
task = Task.init(project_name='examples', task_name='pytorch with tensorboardX')


writer = SummaryWriter('runs')
writer.add_text('lstm', 'This is an lstm', 0)
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=2, metavar='N',
                    help='number of epochs to train (default: 2)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                    help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))])),
    batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))])),
    batch_size=args.test_batch_size, shuffle=True, **kwargs)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = Net()
if args.cuda:
    model.cuda()
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)


def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data), Variable(target)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.data.item()))
            niter = epoch * len(train_loader) + batch_idx
            writer.add_scalar('Train/Loss', loss.data.item(), niter)


def test():
    model.eval()
    test_loss = 0
    correct = 0
    for niter, (data, target) in enumerate(test_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        data, target = Variable(data, volatile=True), Variable(target)
        output = model(data)
        test_loss += F.nll_loss(output, target, size_average=False).data.item()  # sum up batch loss
        pred = output.data.max(1)[1]  # get the index of the max log-probability
        pred = pred.eq(target.data).cpu().sum()
        writer.add_scalar('Test/Loss', pred, niter)
        correct += pred

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))


for epoch in range(1, args.epochs + 1):
    train(epoch)
    torch.save(model, '/tmp/model{}'.format(epoch))
    test()
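The `transforms.Normalize((0.1307,), (0.3081,))` calls in the data loaders above standardize each pixel with the MNIST dataset's mean and standard deviation. A quick plain-Python sketch of that arithmetic, using the constants from the code above:

```python
MEAN, STD = 0.1307, 0.3081  # MNIST per-channel mean and std from the example

def normalize(pixel):
    # Same formula Normalize applies channel-wise: (x - mean) / std
    return (pixel - MEAN) / STD

# A pixel exactly at the dataset mean maps to 0.
print(normalize(0.1307))          # → 0.0
# A fully-on pixel (1.0 after ToTensor) maps to roughly 2.82.
print(round(normalize(1.0), 2))   # → 2.82
```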
BIN
examples/samples/dancing.jpg
Normal file
Binary file not shown. After Width: | Height: | Size: 40 KiB
BIN
examples/samples/picasso.jpg
Normal file
Binary file not shown. After Width: | Height: | Size: 112 KiB
237
examples/tensorboard_pr_curve.py
Normal file
@@ -0,0 +1,237 @@
# TRAINS - Example of new tensorboard pr_curves model
#
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Create sample PR curve summary data.

We have 3 classes: R, G, and B. We generate colors within RGB space from 3
normal distributions (1 at each corner of the color triangle: [255, 0, 0],
[0, 255, 0], and [0, 0, 255]).

The true label of each random color is associated with the normal distribution
that generated it.

Using 3 other normal distributions (over the distance each color is from a
corner of the color triangle - RGB), we then compute the probability that each
color belongs to the class. We use those probabilities to generate PR curves.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path

from absl import app
from absl import flags
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorboard.plugins.pr_curve import summary
from trains import Task

task = Task.init(project_name='examples', task_name='tensorboard pr_curve')

tf.compat.v1.disable_v2_behavior()
FLAGS = flags.FLAGS

flags.DEFINE_string('logdir', '/tmp/pr_curve_demo',
                    'Directory into which to write TensorBoard data.')

flags.DEFINE_integer('steps', 10,
                     'Number of steps to generate for each PR curve.')


def start_runs(
        logdir,
        steps,
        run_name,
        thresholds,
        mask_every_other_prediction=False):
    """Generate a PR curve with precision and recall evenly weighted.

    Arguments:
      logdir: The directory into which to store all the runs' data.
      steps: The number of steps to run for.
      run_name: The name of the run.
      thresholds: The number of thresholds to use for PR curves.
      mask_every_other_prediction: Whether to mask every other prediction by
        alternating weights between 0 and 1.
    """
    tf.compat.v1.reset_default_graph()
    tf.compat.v1.set_random_seed(42)

    # Create a normal distribution layer used to generate true color labels.
    distribution = tf.compat.v1.distributions.Normal(loc=0., scale=142.)

    # Sample the distribution to generate colors. Let's generate different
    # numbers of each color. The first dimension is the count of examples.

    # The calls to sample() are given fixed random seed values that are "magic"
    # in that they correspond to the default seeds for those ops when the PR
    # curve test (which depends on this code) was written. We've pinned these
    # instead of continuing to use the defaults since the defaults are based on
    # node IDs from the sequence of nodes added to the graph, which can silently
    # change when this code or any TF op implementations it uses are modified.

    # TODO(nickfelt): redo the PR curve test to avoid reliance on random seeds.

    # Generate reds.
    number_of_reds = 100
    true_reds = tf.clip_by_value(
        tf.concat([
            255 - tf.abs(distribution.sample([number_of_reds, 1], seed=11)),
            tf.abs(distribution.sample([number_of_reds, 2], seed=34))
        ], axis=1),
        0, 255)

    # Generate greens.
    number_of_greens = 200
    true_greens = tf.clip_by_value(
        tf.concat([
            tf.abs(distribution.sample([number_of_greens, 1], seed=61)),
            255 - tf.abs(distribution.sample([number_of_greens, 1], seed=82)),
            tf.abs(distribution.sample([number_of_greens, 1], seed=105))
        ], axis=1),
        0, 255)

    # Generate blues.
    number_of_blues = 150
    true_blues = tf.clip_by_value(
        tf.concat([
            tf.abs(distribution.sample([number_of_blues, 2], seed=132)),
            255 - tf.abs(distribution.sample([number_of_blues, 1], seed=153))
        ], axis=1),
        0, 255)

    # Assign each color a vector of 3 booleans based on its true label.
    labels = tf.concat([
        tf.tile(tf.constant([[True, False, False]]), (number_of_reds, 1)),
        tf.tile(tf.constant([[False, True, False]]), (number_of_greens, 1)),
        tf.tile(tf.constant([[False, False, True]]), (number_of_blues, 1)),
    ], axis=0)

    # We introduce 3 normal distributions. They are used to predict whether a
    # color falls under a certain class (based on distances from corners of the
    # color triangle). The distributions vary per color. We have the
    # distributions narrow over time.
    initial_standard_deviations = [v + FLAGS.steps for v in (158, 200, 242)]
    iteration = tf.compat.v1.placeholder(tf.int32, shape=[])
    red_predictor = tf.compat.v1.distributions.Normal(
        loc=0.,
        scale=tf.cast(
            initial_standard_deviations[0] - iteration,
            dtype=tf.float32))
    green_predictor = tf.compat.v1.distributions.Normal(
        loc=0.,
        scale=tf.cast(
            initial_standard_deviations[1] - iteration,
            dtype=tf.float32))
    blue_predictor = tf.compat.v1.distributions.Normal(
        loc=0.,
        scale=tf.cast(
            initial_standard_deviations[2] - iteration,
            dtype=tf.float32))

    # Make predictions (assign 3 probabilities to each color based on each
    # color's distance to each of the 3 corners). We seek double the area in
    # the right tail of the normal distribution.
    examples = tf.concat([true_reds, true_greens, true_blues], axis=0)
    probabilities_colors_are_red = (1 - red_predictor.cdf(
        tf.norm(tensor=examples - tf.constant([255., 0, 0]), axis=1))) * 2
    probabilities_colors_are_green = (1 - green_predictor.cdf(
        tf.norm(tensor=examples - tf.constant([0, 255., 0]), axis=1))) * 2
    probabilities_colors_are_blue = (1 - blue_predictor.cdf(
        tf.norm(tensor=examples - tf.constant([0, 0, 255.]), axis=1))) * 2

    predictions = (
        probabilities_colors_are_red,
        probabilities_colors_are_green,
        probabilities_colors_are_blue
    )

    # This is the crucial piece. We write data required for generating PR
    # curves. We create 1 summary per class because we create 1 PR curve per
    # class.
    for i, color in enumerate(('red', 'green', 'blue')):
        description = ('The probabilities used to create this PR curve are '
                       'generated from a normal distribution. Its standard '
                       'deviation is initially %0.0f and decreases over time.' %
                       initial_standard_deviations[i])

        weights = None
        if mask_every_other_prediction:
            # Assign a weight of 0 to every even-indexed prediction. Odd-indexed
            # predictions are assigned a default weight of 1.
            consecutive_indices = tf.reshape(
                tf.range(tf.size(input=predictions[i])),
                tf.shape(input=predictions[i]))
            weights = tf.cast(consecutive_indices % 2, dtype=tf.float32)

        summary.op(
            name=color,
            labels=labels[:, i],
            predictions=predictions[i],
            num_thresholds=thresholds,
            weights=weights,
            display_name='classifying %s' % color,
            description=description)
    merged_summary_op = tf.compat.v1.summary.merge_all()
    events_directory = os.path.join(logdir, run_name)
    sess = tf.compat.v1.Session()
    writer = tf.compat.v1.summary.FileWriter(events_directory, sess.graph)

    for step in xrange(steps):
        feed_dict = {
            iteration: step,
        }
        merged_summary = sess.run(merged_summary_op, feed_dict=feed_dict)
        writer.add_summary(merged_summary, step)

    writer.close()


def run_all(logdir, steps, thresholds, verbose=False):
    """Generate PR curve summaries.

    Arguments:
      logdir: The directory into which to store all the runs' data.
      steps: The number of steps to run for.
      thresholds: The number of thresholds to use for PR curves.
      verbose: Whether to print the names of runs into stdout during execution.
    """
    # First, we generate data for a PR curve that assigns even weights for
    # predictions of all classes.
    run_name = 'colors'
    if verbose:
        print('--- Running: %s' % run_name)
    start_runs(
        logdir=logdir,
        steps=steps,
        run_name=run_name,
        thresholds=thresholds)

    # Next, we generate data for a PR curve that assigns arbitrary weights to
    # predictions.
    run_name = 'mask_every_other_prediction'
    if verbose:
        print('--- Running: %s' % run_name)
    start_runs(
        logdir=logdir,
        steps=steps,
        run_name=run_name,
        thresholds=thresholds,
        mask_every_other_prediction=True)


def main(_):
    print('Saving output to %s.' % FLAGS.logdir)
    run_all(FLAGS.logdir, FLAGS.steps, 50, verbose=True)
    print('Done. Output saved to %s.' % FLAGS.logdir)


if __name__ == '__main__':
    app.run(main)
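The predictors in the example above score each color as twice the right-tail mass of a zero-mean normal, evaluated at the color's distance from a corner of the RGB triangle, so a distance of 0 yields probability 1. A small standard-library sketch of that formula (a standalone illustration, not the TensorFlow ops themselves):

```python
import math

def tail_probability(distance, scale):
    # CDF of a zero-mean normal with the given scale, via the error function.
    cdf = 0.5 * (1.0 + math.erf(distance / (scale * math.sqrt(2.0))))
    # Double the right-tail area, mirroring (1 - predictor.cdf(d)) * 2.
    return (1.0 - cdf) * 2.0

# A color sitting exactly on its corner (distance 0) gets probability 1.
print(tail_probability(0.0, scale=158.0))  # → 1.0
# Larger distances shrink the probability toward 0.
print(tail_probability(400.0, scale=158.0) < tail_probability(100.0, scale=158.0))  # → True
```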
76
examples/tensorboard_toy.py
Normal file
@@ -0,0 +1,76 @@
# TRAINS - Example of tensorboard with tensorflow (without any actual training)
#
import tensorflow as tf
import numpy as np
import cv2
from time import sleep
# import tensorflow.compat.v1 as tf
# tf.disable_v2_behavior()

from trains import Task
task = Task.init(project_name='examples', task_name='tensorboard toy example')


k = tf.placeholder(tf.float32)

# Make a normal distribution, with a shifting mean
mean_moving_normal = tf.random_normal(shape=[1000], mean=(5*k), stddev=1)
# Record that distribution into a histogram summary
tf.summary.histogram("normal/moving_mean", mean_moving_normal)
tf.summary.scalar("normal/value", mean_moving_normal[-1])

# Make a normal distribution with shrinking variance
variance_shrinking_normal = tf.random_normal(shape=[1000], mean=0, stddev=1-(k))
# Record that distribution too
tf.summary.histogram("normal/shrinking_variance", variance_shrinking_normal)
tf.summary.scalar("normal/variance_shrinking_normal", variance_shrinking_normal[-1])

# Let's combine both of those distributions into one dataset
normal_combined = tf.concat([mean_moving_normal, variance_shrinking_normal], 0)
# We add another histogram summary to record the combined distribution
tf.summary.histogram("normal/bimodal", normal_combined)
tf.summary.scalar("normal/normal_combined", normal_combined[0])

# Add a gamma distribution
gamma = tf.random_gamma(shape=[1000], alpha=k)
tf.summary.histogram("gamma", gamma)

# And a poisson distribution
poisson = tf.random_poisson(shape=[1000], lam=k)
tf.summary.histogram("poisson", poisson)

# And a uniform distribution
uniform = tf.random_uniform(shape=[1000], maxval=k*10)
tf.summary.histogram("uniform", uniform)

# Finally, combine everything together!
all_distributions = [mean_moving_normal, variance_shrinking_normal,
                     gamma, poisson, uniform]
all_combined = tf.concat(all_distributions, 0)
tf.summary.histogram("all_combined", all_combined)

# convert to 4d [batch, col, row, RGB-channels]
image = cv2.imread('./samples/picasso.jpg')
image = image[:, :, 0][np.newaxis, :, :, np.newaxis]
# image = image[np.newaxis, :, :, :]  # test greyscale image

# add image reporting
tf.summary.image("test", image, max_outputs=10)

# Set up a session and summary writer
summaries = tf.summary.merge_all()
sess = tf.Session()

logger = task.get_logger()

# Use the original FileWriter for comparison; run:
# % tensorboard --logdir=/tmp/histogram_example
writer = tf.summary.FileWriter("/tmp/histogram_example")

# Set up a loop and write the summaries to disk
N = 40
for step in range(N):
    k_val = step / float(N)
    summ = sess.run(summaries, feed_dict={k: k_val})
    writer.add_summary(summ, global_step=step)

print('Done!')
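The loop above feeds `k` values that sweep `[0, 1)`, which drives the shifting-mean normal (`mean = 5 * k`) and the shrinking-variance normal (`stddev = 1 - k`). A tiny sketch of that schedule on its own, without TensorFlow:

```python
# Reproduce the feed schedule from the loop above: k sweeps [0, 1).
N = 40
ks = [step / float(N) for step in range(N)]

# The shifting-mean normal uses mean = 5 * k; the shrinking-variance
# normal uses stddev = 1 - k, so the mean grows while the spread narrows.
means = [5 * k for k in ks]
stddevs = [1 - k for k in ks]
print(means[0], round(means[-1], 3))      # → 0.0 4.875
print(stddevs[0], round(stddevs[-1], 3))  # → 1.0 0.025
```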
358
examples/tensorflow_eager.py
Normal file
@@ -0,0 +1,358 @@
# TRAINS - Example of tensorflow eager mode, model logging and tensorboard
#
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A deep MNIST classifier using convolutional layers.

Sample usage:
  python mnist.py --help
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import sys
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from trains import Task

tf.enable_eager_execution()

task = Task.init(project_name='examples', task_name='Tensorflow eager mode')


FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('data_num', 100, """Flag of type integer""")
tf.app.flags.DEFINE_string('img_path', './img', """Flag of type string""")


layers = tf.keras.layers
FLAGS = None


class Discriminator(tf.keras.Model):
    """GAN Discriminator.

    A network to differentiate between generated and real handwritten digits.
    """

    def __init__(self, data_format):
        """Creates a model for discriminating between real and generated digits.

        Args:
          data_format: Either 'channels_first' or 'channels_last'.
            'channels_first' is typically faster on GPUs while 'channels_last'
            is typically faster on CPUs. See
            https://www.tensorflow.org/performance/performance_guide#data_formats
        """
        super(Discriminator, self).__init__(name='')
        if data_format == 'channels_first':
            self._input_shape = [-1, 1, 28, 28]
        else:
            assert data_format == 'channels_last'
            self._input_shape = [-1, 28, 28, 1]
        self.conv1 = layers.Conv2D(
            64, 5, padding='SAME', data_format=data_format, activation=tf.tanh)
        self.pool1 = layers.AveragePooling2D(2, 2, data_format=data_format)
        self.conv2 = layers.Conv2D(
            128, 5, data_format=data_format, activation=tf.tanh)
        self.pool2 = layers.AveragePooling2D(2, 2, data_format=data_format)
        self.flatten = layers.Flatten()
        self.fc1 = layers.Dense(1024, activation=tf.tanh)
        self.fc2 = layers.Dense(1, activation=None)

    def call(self, inputs):
        """Return two logits per image estimating input authenticity.

        Users should invoke __call__ to run the network, which delegates to
        this method (and not call this method directly).

        Args:
          inputs: A batch of images as a Tensor with shape
            [batch_size, 28, 28, 1] or [batch_size, 1, 28, 28]

        Returns:
          A Tensor with shape [batch_size] containing logits estimating
          the probability that corresponding digit is real.
        """
        x = tf.reshape(inputs, self._input_shape)
        x = self.conv1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        return x


class Generator(tf.keras.Model):
    """Generator of handwritten digits similar to the ones in the MNIST dataset.
    """

    def __init__(self, data_format):
        """Creates a model for generating handwritten digits.

        Args:
          data_format: Either 'channels_first' or 'channels_last'.
            'channels_first' is typically faster on GPUs while 'channels_last'
            is typically faster on CPUs. See
            https://www.tensorflow.org/performance/performance_guide#data_formats
        """
        super(Generator, self).__init__(name='')
        self.data_format = data_format
        # We are using 128 6x6 channels as input to the first deconvolution
        # layer
        if data_format == 'channels_first':
            self._pre_conv_shape = [-1, 128, 6, 6]
        else:
            assert data_format == 'channels_last'
            self._pre_conv_shape = [-1, 6, 6, 128]
        self.fc1 = layers.Dense(6 * 6 * 128, activation=tf.tanh)

        # In call(), we reshape the output of fc1 to _pre_conv_shape

        # Deconvolution layer. Resulting image shape: (batch, 14, 14, 64)
        self.conv1 = layers.Conv2DTranspose(
            64, 4, strides=2, activation=None, data_format=data_format)

        # Deconvolution layer. Resulting image shape: (batch, 28, 28, 1)
        self.conv2 = layers.Conv2DTranspose(
            1, 2, strides=2, activation=tf.nn.sigmoid, data_format=data_format)

    def call(self, inputs):
        """Return a batch of generated images.

        Users should invoke __call__ to run the network, which delegates to
        this method (and not call this method directly).

        Args:
          inputs: A batch of noise vectors as a Tensor with shape
            [batch_size, length of noise vectors].

        Returns:
          A Tensor containing generated images. If data_format is
          'channels_last', the shape of returned images is
          [batch_size, 28, 28, 1], else [batch_size, 1, 28, 28]
        """
        x = self.fc1(inputs)
        x = tf.reshape(x, shape=self._pre_conv_shape)
        x = self.conv1(x)
        x = self.conv2(x)
        return x


def discriminator_loss(discriminator_real_outputs, discriminator_gen_outputs):
    """Original discriminator loss for GANs, with label smoothing.

    See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for
    more details.

    Args:
      discriminator_real_outputs: Discriminator output on real data.
      discriminator_gen_outputs: Discriminator output on generated data.
        Expected to be in the range of (-inf, inf).

    Returns:
      A scalar loss Tensor.
    """
    loss_on_real = tf.losses.sigmoid_cross_entropy(
        tf.ones_like(discriminator_real_outputs),
        discriminator_real_outputs,
        label_smoothing=0.25)
    loss_on_generated = tf.losses.sigmoid_cross_entropy(
        tf.zeros_like(discriminator_gen_outputs), discriminator_gen_outputs)
    loss = loss_on_real + loss_on_generated
    tf.contrib.summary.scalar('discriminator_loss', loss)
    return loss


def generator_loss(discriminator_gen_outputs):
    """Original generator loss for GANs.

    L = -log(sigmoid(D(G(z))))
    See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661)
    for more details.

    Args:
      discriminator_gen_outputs: Discriminator output on generated data.
        Expected to be in the range of (-inf, inf).

    Returns:
      A scalar loss Tensor.
    """
    loss = tf.losses.sigmoid_cross_entropy(
        tf.ones_like(discriminator_gen_outputs), discriminator_gen_outputs)
    tf.contrib.summary.scalar('generator_loss', loss)
    return loss


def train_one_epoch(generator, discriminator, generator_optimizer,
                    discriminator_optimizer, dataset, step_counter,
                    log_interval, noise_dim):
    """Train `generator` and `discriminator` models on `dataset`.

    Args:
      generator: Generator model.
      discriminator: Discriminator model.
      generator_optimizer: Optimizer to use for generator.
      discriminator_optimizer: Optimizer to use for discriminator.
      dataset: Dataset of images to train on.
      step_counter: An integer variable, used to write summaries regularly.
      log_interval: How many steps to wait between logging and collecting
        summaries.
      noise_dim: Dimension of noise vector to use.
    """
    total_generator_loss = 0.0
    total_discriminator_loss = 0.0
    for (batch_index, images) in enumerate(dataset):
        with tf.device('/cpu:0'):
            tf.assign_add(step_counter, 1)

        with tf.contrib.summary.record_summaries_every_n_global_steps(
                log_interval, global_step=step_counter):
            current_batch_size = images.shape[0]
            noise = tf.random_uniform(
                shape=[current_batch_size, noise_dim],
                minval=-1.,
                maxval=1.,
                seed=batch_index)

            # We can use 2 tapes or a single persistent tape.
            # Using two tapes is memory efficient since intermediate tensors
            # can be released between the two .gradient() calls below.
            with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
                generated_images = generator(noise)
                tf.contrib.summary.image(
                    'generated_images',
                    tf.reshape(generated_images, [-1, 28, 28, 1]),
                    max_images=10)

                discriminator_gen_outputs = discriminator(generated_images)
                discriminator_real_outputs = discriminator(images)
                discriminator_loss_val = discriminator_loss(
                    discriminator_real_outputs, discriminator_gen_outputs)
                total_discriminator_loss += discriminator_loss_val

                generator_loss_val = generator_loss(discriminator_gen_outputs)
                total_generator_loss += generator_loss_val

            generator_grad = gen_tape.gradient(generator_loss_val,
                                               generator.variables)
            discriminator_grad = disc_tape.gradient(discriminator_loss_val,
                                                    discriminator.variables)

            generator_optimizer.apply_gradients(
                zip(generator_grad, generator.variables))
            discriminator_optimizer.apply_gradients(
                zip(discriminator_grad, discriminator.variables))

            if log_interval and batch_index > 0 and batch_index % log_interval == 0:
                print('Batch #%d\tAverage Generator Loss: %.6f\t'
                      'Average Discriminator Loss: %.6f' %
                      (batch_index, total_generator_loss / batch_index,
                       total_discriminator_loss / batch_index))


def main(_):
    (device, data_format) = ('/gpu:0', 'channels_first')
    if FLAGS.no_gpu or tf.contrib.eager.num_gpus() <= 0:
        (device, data_format) = ('/cpu:0', 'channels_last')
    print('Using device %s, and data format %s.' % (device, data_format))

    # Load the datasets
    data = input_data.read_data_sets(FLAGS.data_dir)
    dataset = (
        tf.data.Dataset.from_tensor_slices(data.train.images[:1280]).shuffle(60000)
        .batch(FLAGS.batch_size))

    # Create the models and optimizers.
    model_objects = {
        'generator': Generator(data_format),
        'discriminator': Discriminator(data_format),
        'generator_optimizer': tf.train.AdamOptimizer(FLAGS.lr),
        'discriminator_optimizer': tf.train.AdamOptimizer(FLAGS.lr),
        'step_counter': tf.train.get_or_create_global_step(),
    }

    # Prepare summary writer and checkpoint info
    summary_writer = tf.contrib.summary.create_file_writer(
        FLAGS.output_dir, flush_millis=1000)
    checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt')
    latest_cpkt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    if latest_cpkt:
        print('Using latest checkpoint at ' + latest_cpkt)
    checkpoint = tf.train.Checkpoint(**model_objects)
    # Restore variables on creation if a checkpoint exists.
    checkpoint.restore(latest_cpkt)

    with tf.device(device):
        for _ in range(3):
            start = time.time()
            with summary_writer.as_default():
                train_one_epoch(dataset=dataset, log_interval=FLAGS.log_interval,
                                noise_dim=FLAGS.noise, **model_objects)
            end = time.time()
            checkpoint.save(checkpoint_prefix)
            print('\nTrain time for epoch #%d (step %d): %f' %
                  (checkpoint.save_counter.numpy(),
                   checkpoint.step_counter.numpy(),
                   end - start))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--data-dir',
        type=str,
        default='/tmp/tensorflow/mnist/input_data',
        help=('Directory for storing input data (default '
              '/tmp/tensorflow/mnist/input_data)'))
    parser.add_argument(
        '--batch-size',
        type=int,
        default=16,
        metavar='N',
        help='input batch size for training (default: 16)')
    parser.add_argument(
        '--log-interval',
        type=int,
        default=1,
        metavar='N',
        help=('number of batches between logging and writing summaries '
              '(default: 1)'))
    parser.add_argument(
        '--output_dir',
        type=str,
        default='/tmp/tensorflow/',
        metavar='DIR',
        help='Directory to write TensorBoard summaries (defaults to none)')
    parser.add_argument(
        '--checkpoint_dir',
        type=str,
        default='/tmp/tensorflow/mnist/checkpoints/',
        metavar='DIR',
        help=('Directory to save checkpoints in (once per epoch) (default '
              '/tmp/tensorflow/mnist/checkpoints/)'))
    parser.add_argument(
        '--lr',
        type=float,
        default=0.001,
        metavar='LR',
        help='learning rate (default: 0.001)')
    parser.add_argument(
        '--noise',
        type=int,
        default=100,
        metavar='N',
        help='Length of noise vector for generator input (default: 100)')
    parser.add_argument(
        '--no-gpu',
        action='store_true',
        default=False,
        help='disables GPU usage even if a GPU is available')

    FLAGS, unparsed = parser.parse_known_args()

    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
171  examples/tensorflow_mnist.py  Normal file
@@ -0,0 +1,171 @@
# TRAINS - Example of tensorflow mnist training model logging
#
# Save and Restore a model using TensorFlow.
# This example is using the MNIST database of handwritten digits
# (http://yann.lecun.com/exdb/mnist/)
#
# Author: Aymeric Damien
# Project: https://github.com/aymericdamien/TensorFlow-Examples/

from __future__ import print_function

from os.path import exists

import numpy as np
import tensorflow as tf
from trains import Task

MODEL_PATH = "/tmp/module_no_signatures"
task = Task.init(project_name='examples', task_name='Tensorflow mnist example')

## block
X_train = np.random.rand(100, 3)
y_train = np.random.rand(100, 1)
model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
model.compile(loss='categorical_crossentropy',
              optimizer=tf.keras.optimizers.SGD(),
              metrics=['accuracy'])
model.fit(X_train, y_train, steps_per_epoch=1, epochs=1)

with tf.Session(graph=tf.Graph()) as sess:
    if exists(MODEL_PATH):
        try:
            tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], MODEL_PATH)
            m2 = tf.saved_model.load(sess, [tf.saved_model.tag_constants.SERVING], MODEL_PATH)
        except Exception:
            pass
    tf.train.Checkpoint
## block end

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

# Parameters
parameters = {
    'learning_rate': 0.001,
    'batch_size': 100,
    'display_step': 1,
    'model_path': "/tmp/model.ckpt",

    # Network Parameters
    'n_hidden_1': 256,  # 1st layer number of features
    'n_hidden_2': 256,  # 2nd layer number of features
    'n_input': 784,  # MNIST data input (img shape: 28*28)
    'n_classes': 10,  # MNIST total classes (0-9 digits)
}
# TRAINS: connect parameters with the experiment/task for logging
parameters = task.connect(parameters)

# tf Graph input
x = tf.placeholder("float", [None, parameters['n_input']])
y = tf.placeholder("float", [None, parameters['n_classes']])


# Create model
def multilayer_perceptron(x, weights, biases):
    # Hidden layer with RELU activation
    layer_1 = tf.add(tf.matmul(x, weights['h1']), biases['b1'])
    layer_1 = tf.nn.relu(layer_1)
    # Hidden layer with RELU activation
    layer_2 = tf.add(tf.matmul(layer_1, weights['h2']), biases['b2'])
    layer_2 = tf.nn.relu(layer_2)
    # Output layer with linear activation
    out_layer = tf.matmul(layer_2, weights['out']) + biases['out']
    return out_layer


# Store layers weight & bias
weights = {
    'h1': tf.Variable(tf.random_normal([parameters['n_input'], parameters['n_hidden_1']])),
    'h2': tf.Variable(tf.random_normal([parameters['n_hidden_1'], parameters['n_hidden_2']])),
    'out': tf.Variable(tf.random_normal([parameters['n_hidden_2'], parameters['n_classes']]))
}
biases = {
    'b1': tf.Variable(tf.random_normal([parameters['n_hidden_1']])),
    'b2': tf.Variable(tf.random_normal([parameters['n_hidden_2']])),
    'out': tf.Variable(tf.random_normal([parameters['n_classes']]))
}

# Construct model
pred = multilayer_perceptron(x, weights, biases)

# Define loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=parameters['learning_rate']).minimize(cost)

# Initialize the variables (i.e. assign their default value)
init = tf.global_variables_initializer()

# 'Saver' op to save and restore all the variables
saver = tf.train.Saver()

# Running first session
print("Starting 1st session...")
with tf.Session() as sess:

    # Run the initializer
    sess.run(init)

    # Training cycle
    for epoch in range(3):
        avg_cost = 0.
        total_batch = int(mnist.train.num_examples / parameters['batch_size'])
        # Loop over all batches
        for i in range(total_batch):
            batch_x, batch_y = mnist.train.next_batch(parameters['batch_size'])
            # Run optimization op (backprop) and cost op (to get loss value)
            _, c = sess.run([optimizer, cost], feed_dict={x: batch_x,
                                                          y: batch_y})
            # Compute average loss
            avg_cost += c / total_batch
        # Display logs per epoch step
        if epoch % parameters['display_step'] == 0:
            print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(avg_cost))
        save_path = saver.save(sess, parameters['model_path'])

    print("First Optimization Finished!")

    # Test model
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    # Calculate accuracy
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))

    # Save model weights to disk
    save_path = saver.save(sess, parameters['model_path'])
    print("Model saved in file: %s" % save_path)

# Running a new session
print("Starting 2nd session...")
with tf.Session() as sess:
    # Initialize variables
    sess.run(init)

    # Restore model weights from previously saved model
    saver.restore(sess, parameters['model_path'])
    print("Model restored from file: %s" % save_path)

    # Resume training
    for epoch in range(7):
        avg_cost = 0.
        total_batch = int(mnist.train.num_examples / parameters['batch_size'])
        # Loop over all batches
        for i in range(total_batch):
            batch_x, batch_y = mnist.train.next_batch(parameters['batch_size'])
            # Run optimization op (backprop) and cost op (to get loss value)
            _, c = sess.run([optimizer, cost], feed_dict={x: batch_x,
                                                          y: batch_y})
            # Compute average loss
            avg_cost += c / total_batch
        # Display logs per epoch step
        if epoch % parameters['display_step'] == 0:
            print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(avg_cost))
    print("Second Optimization Finished!")

    # Test model
    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    # Calculate accuracy
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    print("Accuracy:", accuracy.eval(
        {x: mnist.test.images, y: mnist.test.labels}))
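The whole TRAINS integration in the MNIST example above is two calls: `Task.init(...)` registers the run, and `task.connect(parameters)` hands the hyperparameter dict to the server, which can return overridden values when the experiment is re-run remotely. A rough, server-free sketch of that override semantics (`connect_parameters` is a hypothetical stand-in, not part of the TRAINS API):

```python
def connect_parameters(defaults, overrides=None):
    """Local stand-in for the task.connect() pattern: start from the
    script's default dict, then apply externally supplied overrides
    (in TRAINS these would come from the server on a remote re-run)."""
    params = dict(defaults)
    for key, value in (overrides or {}).items():
        if key in params:
            # cast the override back to the declared type of the default
            params[key] = type(params[key])(value)
    return params


defaults = {'learning_rate': 0.001, 'batch_size': 100, 'n_hidden_1': 256}
params = connect_parameters(defaults, overrides={'batch_size': '64'})
print(params['batch_size'])  # → 64
```

The point of the pattern is that the training code below `task.connect` never needs to know whether it is running with the local defaults or with server-side overrides.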
131  examples/trains.conf  Normal file
@@ -0,0 +1,131 @@
# TRAINS SDK configuration file
api {
    host: http://localhost:8008
    credentials {"access_key": "EGRTCO8JMSIGI6S39GTP43NFWXDQOW", "secret_key": "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"}
}
sdk {
    # TRAINS - default SDK configuration

    storage {
        cache {
            # Defaults to system temp folder / cache
            default_base_dir: "~/.trains/cache"
        }
    }

    metrics {
        # History size for debug files per metric/variant. For each metric/variant combination with an attached file
        # (e.g. debug image event), file names for the uploaded files will be recycled in such a way that no more than
        # X files are stored in the upload destination for each metric/variant combination.
        file_history_size: 100

        # Settings for generated debug images
        images {
            format: JPEG
            quality: 87
            subsampling: 0
        }
    }

    network {
        metrics {
            # Number of threads allocated to uploading files (typically debug images) when transmitting metrics for
            # a specific iteration
            file_upload_threads: 4

            # Warn about upload starvation if no uploads were made in specified period while file-bearing events keep
            # being sent for upload
            file_upload_starvation_warning_sec: 120
        }

        iteration {
            # Max number of retries when getting frames if the server returned an error (http code 500)
            max_retries_on_server_error: 5
            # Backoff factor for consecutive retry attempts.
            # SDK will wait for {backoff factor} * (2 ^ ({number of total retries} - 1)) between retries.
            retry_backoff_factor_sec: 10
        }
    }
    aws {
        s3 {
            # S3 credentials, used for read/write access by various SDK elements

            # default, used for any bucket not specified below
            key: ""
            secret: ""
            region: ""

            credentials: [
                # specifies key/secret credentials to use when handling s3 urls (read or write)
                # {
                #     bucket: "my-bucket-name"
                #     key: "my-access-key"
                #     secret: "my-secret-key"
                # },
                # {
                #     # This will apply to all buckets in this host (unless key/value is specifically provided for a given bucket)
                #     host: "my-minio-host:9000"
                #     key: "12345678"
                #     secret: "12345678"
                #     multipart: false
                #     secure: false
                # }
            ]
        }
        boto3 {
            pool_connections: 512
            max_multipart_concurrency: 16
        }
    }
    google.storage {
        # # Default project and credentials file
        # # Will be used when no bucket configuration is found
        # project: "trains"
        # credentials_json: "/path/to/credentials.json"

        # # Specific credentials per bucket and sub directory
        # credentials = [
        #     {
        #         bucket: "my-bucket"
        #         subdir: "path/in/bucket"  # Not required
        #         project: "trains"
        #         credentials_json: "/path/to/credentials.json"
        #     },
        # ]
    }

    log {
        # debugging feature: set this to true to make null log propagate messages to root logger (so they appear in stdout)
        null_log_propagate: False
        task_log_buffer_capacity: 66

        # disable urllib info and lower levels
        disable_urllib3_info: True
    }

    development {
        # Development-mode options

        # dev task reuse window
        task_reuse_time_window_in_hours: 72.0

        # Run VCS repository detection asynchronously
        vcs_repo_detect_async: False

        # Store uncommitted git/hg source code diff in experiment manifest when training in development mode
        # This stores "git diff" or "hg diff" into the experiment's "script.requirements.diff" section
        store_uncommitted_code_diff_on_train: True

        # Support stopping an experiment in case it was externally stopped, status was changed or task was reset
        support_stopping: True

        # Development mode worker
        worker {
            # Status report period in seconds
            report_period_sec: 2

            # Log all stdout & stderr
            log_stdout: True
        }
    }
}
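The file above is HOCON, which the SDK parses with pyhocon (listed in requirements.txt); settings are conceptually addressed by dotted paths such as `sdk.metrics.file_history_size`. A stdlib-only sketch of that dotted-path lookup over a parsed tree (`get_path` is a hypothetical helper for illustration, not the SDK API):

```python
def get_path(tree, dotted, default=None):
    """Minimal dotted-path lookup over a nested dict, mirroring how
    keys like 'sdk.metrics.file_history_size' in trains.conf are
    addressed (the real SDK uses pyhocon's ConfigTree)."""
    node = tree
    for part in dotted.split('.'):
        if not isinstance(node, dict) or part not in node:
            return default
        node = node[part]
    return node


conf = {'sdk': {'metrics': {'file_history_size': 100,
                            'images': {'format': 'JPEG', 'quality': 87}}}}
print(get_path(conf, 'sdk.metrics.file_history_size'))   # → 100
print(get_path(conf, 'sdk.network.missing', default=0))  # → 0
```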
31  requirements.txt  Normal file
@@ -0,0 +1,31 @@
apache-libcloud>=2.2.1
attrs>=18.0
backports.functools-lru-cache>=1.0.2 ; python_version < '3'
boto3>=1.9
botocore>=1.12
colorama>=0.4.1
coloredlogs>=10.0
enum34>=0.9
funcsigs>=1.0
furl>=2.0.0
future>=0.16.0
futures>=3.0.5 ; python_version < '3'
google-cloud-storage>=1.13.2
humanfriendly>=2.1
jsonmodels>=2.2
jsonschema>=2.6.0
numpy>=1.10
opencv-python>=3.2.0.8
pathlib2>=2.3.0
psutil>=3.4.2
pyhocon>=0.3.38
python-dateutil>=2.6.1
PyYAML>=3.12
requests-file>=1.4.2
requests>=2.18.4
six>=1.11.0
tqdm>=4.19.5
urllib3>=1.22
watchdog>=0.8.0
pyjwt>=1.6.4
plotly>=3.9.0
4  setup.cfg  Normal file
@@ -0,0 +1,4 @@
[bdist_wheel]
# Build a universal (py2.py3) wheel,
# supporting both Python 2 and Python 3
universal=1
77  setup.py  Normal file
@@ -0,0 +1,77 @@
"""
|
||||
TRAINS - Artificial Intelligence Version Control
|
||||
https://github.com/allegroai/trains
|
||||
"""
|
||||
|
||||
# Always prefer setuptools over distutils
|
||||
from setuptools import setup, find_packages
|
||||
from six import exec_
|
||||
from pathlib2 import Path
|
||||
|
||||
|
||||
here = Path(__file__).resolve().parent
|
||||
|
||||
# Get the long description from the README file
|
||||
long_description = (here / 'README.md').read_text()
|
||||
|
||||
|
||||
def read_version_string():
|
||||
result = {}
|
||||
exec_((here / 'trains/version.py').read_text(), result)
|
||||
return result['__version__']
|
||||
|
||||
|
||||
version = read_version_string()
|
||||
|
||||
requirements = (here / 'requirements.txt').read_text().splitlines()
|
||||
|
||||
setup(
|
||||
name='trains',
|
||||
version=version,
|
||||
description='TRAINS - Magic Version Control & Experiment Manager for AI',
|
||||
long_description=long_description,
|
||||
long_description_content_type='text/markdown',
|
||||
# The project's main homepage.
|
||||
url='https://github.com/allegroai/trains',
|
||||
author='Allegroai',
|
||||
author_email='trains@allegro.ai',
|
||||
license='Apache License 2.0',
|
||||
classifiers=[
|
||||
# How mature is this project? Common values are
|
||||
# 3 - Alpha
|
||||
# 4 - Beta
|
||||
# 5 - Production/Stable
|
||||
'Development Status :: 4 - Beta',
|
||||
'Intended Audience :: Developers',
|
||||
'Intended Audience :: Science/Research',
|
||||
'Operating System :: POSIX :: Linux',
|
||||
'Operating System :: MacOS :: MacOS X',
|
||||
'Operating System :: Microsoft',
|
||||
'Topic :: Scientific/Engineering :: Artificial Intelligence',
|
||||
'Topic :: Software Development',
|
||||
'Topic :: Software Development :: Version Control',
|
||||
'Topic :: System :: Logging',
|
||||
'Topic :: System :: Monitoring',
|
||||
'Programming Language :: Python :: 2.7',
|
||||
'Programming Language :: Python :: 3.5',
|
||||
'Programming Language :: Python :: 3.6',
|
||||
'Programming Language :: Python :: 3.7',
|
||||
'License :: OSI Approved :: Apache Software License',
|
||||
],
|
||||
keywords='trains development machine deep learning version control machine-learning machinelearning '
|
||||
'deeplearning deep-learning experiment-manager experimentmanager',
|
||||
packages=find_packages(exclude=['contrib', 'docs', 'data', 'examples', 'tests']),
|
||||
install_requires=requirements,
|
||||
package_data={
|
||||
'trains': ['config/default/*.conf', 'backend_api/config/default/*.conf']
|
||||
},
|
||||
include_package_data=True,
|
||||
# To provide executable scripts, use entry points in preference to the
|
||||
# "scripts" keyword. Entry points provide cross-platform support and allow
|
||||
# pip to create the appropriate form of executable for the target platform.
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
'trains-init = trains.config.default.__main__:main',
|
||||
],
|
||||
},
|
||||
)
|
||||
7  trains/__init__.py  Normal file
@@ -0,0 +1,7 @@
""" TRAINS open SDK """
|
||||
|
||||
from .version import __version__
|
||||
from .task import Task
|
||||
from .model import InputModel, OutputModel
|
||||
from .logger import Logger
|
||||
from .errors import UsageError
|
||||
3  trains/backend_api/__init__.py  Normal file
@@ -0,0 +1,3 @@
from .version import __version__
from .session import Session, CallResult, TimeoutExpiredError, ResultNotReadyError
from .config import load as load_config
16  trains/backend_api/config/__init__.py  Normal file
@@ -0,0 +1,16 @@
from ...backend_config import Config
from pathlib2 import Path


def load(*additional_module_paths):
    # type: (str) -> Config
    """
    Load configuration with the API defaults, using the additional module paths provided

    :param additional_module_paths: Additional config paths for modules whose default
        configurations should be loaded as well
    :return: Config object
    """
    config = Config(verbose=False)
    this_module_path = Path(__file__).parent
    config.load_relative_to(this_module_path, *additional_module_paths)
    return config
41  trains/backend_api/config/default/api.conf  Normal file
@@ -0,0 +1,41 @@
{
    version: 1.5
    host: https://demoapi.trainsai.io

    # default version assigned to requests with no specific version. this is not expected to change
    # as it keeps us backwards compatible.
    default_version: 1.5

    http {
        max_req_size = 15728640  # request size limit (smaller than that configured in api server)

        retries {
            # retry values (int, 0 means fail on first retry)
            total: 240     # Total number of retries to allow. Takes precedence over other counts.
            connect: 240   # How many times to retry on connection-related errors (never reached server)
            read: 240      # How many times to retry on read errors (waiting for server)
            redirect: 240  # How many redirects to perform (HTTP response with a status code 301, 302, 303, 307 or 308)
            status: 240    # How many times to retry on bad status codes

            # backoff parameters
            # timeout between retries is min({backoff_max}, {backoff factor} * (2 ^ ({number of total retries} - 1)))
            backoff_factor: 1.0
            backoff_max: 300.0
        }

        wait_on_maintenance_forever: true

        pool_maxsize: 512
        pool_connections: 512
    }

    credentials {
        access_key: ""
        secret_key: ""
    }

    auth {
        # When creating a request, if the token will expire in less than this value, try to refresh the token
        token_expiration_threshold_sec = 360
    }
}
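The retry comment in api.conf describes a standard capped exponential backoff: the wait before retry *n* is `min(backoff_max, backoff_factor * 2^(n-1))`. A small sketch of that schedule, using the defaults from the file above:

```python
def retry_backoff(retry_number, backoff_factor=1.0, backoff_max=300.0):
    """Seconds to wait before a given retry, per the formula in api.conf:
    min(backoff_max, backoff_factor * 2 ** (retry_number - 1))."""
    return min(backoff_max, backoff_factor * (2 ** (retry_number - 1)))


# Early retries back off exponentially; from the 10th retry on,
# backoff_max (300s) caps the wait:
print([retry_backoff(n) for n in (1, 2, 3, 10, 20)])
# → [1.0, 2.0, 4.0, 300.0, 300.0]
```

With `total: 240` retries, the cap matters: without `backoff_max` the waits would grow without bound, while the capped schedule settles at 5-minute intervals.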
9  trains/backend_api/config/default/logging.conf  Normal file
@@ -0,0 +1,9 @@
{
    version: 1
    loggers {
        urllib3 {
            level: ERROR
        }
    }
}
0  trains/backend_api/schema/__init__.py  Normal file
38  trains/backend_api/schema/action.py  Normal file
@@ -0,0 +1,38 @@
import re
from functools import partial

import attr
from attr.converters import optional as optional_converter
from attr.validators import instance_of, optional, and_
from six import string_types

# noinspection PyTypeChecker
sequence = instance_of((list, tuple))


def sequence_of(types):
    def validator(_, attrib, value):
        assert all(isinstance(x, types) for x in value), attrib.name

    return and_(sequence, validator)


@attr.s
class Action(object):
    name = attr.ib()
    version = attr.ib()
    service = attr.ib()
    definitions_keys = attr.ib(validator=sequence)
    authorize = attr.ib(validator=instance_of(bool), default=True)
    log_data = attr.ib(validator=instance_of(bool), default=True)
    log_result_data = attr.ib(validator=instance_of(bool), default=True)
    internal = attr.ib(default=False)
    allow_roles = attr.ib(default=None, validator=optional(sequence_of(string_types)))
    request = attr.ib(validator=optional(instance_of(dict)), default=None)
    batch_request = attr.ib(validator=optional(instance_of(dict)), default=None)
    response = attr.ib(validator=optional(instance_of(dict)), default=None)
    method = attr.ib(default=None)
    description = attr.ib(
        default=None,
        validator=optional(instance_of(string_types)),
    )
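`sequence_of` in action.py is a validator combinator: it first checks the value is a list or tuple (via `sequence`), then that every element matches the given types. A stdlib-only sketch of the same check, without the attrs machinery (the real version composes attrs validators with `and_` and raises via `assert`):

```python
def sequence_of(types):
    """Accept only a list/tuple whose elements are all instances of
    `types`, mirroring the combinator in action.py."""
    def validator(name, value):
        if not isinstance(value, (list, tuple)):
            raise TypeError('%s: expected a list or tuple' % name)
        if not all(isinstance(x, types) for x in value):
            raise TypeError('%s: wrong element type' % name)
        return value
    return validator


roles = sequence_of(str)
print(roles('allow_roles', ['admin', 'user']))  # passes through unchanged
try:
    roles('allow_roles', ['admin', 42])
except TypeError as e:
    print('rejected:', e)
```

In the `Action` class this guards fields like `allow_roles`, where a stray non-string element would otherwise surface much later, deep inside request handling.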
201  trains/backend_api/schema/service.py  Normal file
@@ -0,0 +1,201 @@
import itertools
import re

import attr
import six

import pyhocon

from .action import Action


class Service(object):
    """ Service schema handler """

    __jsonschema_ref_ex = re.compile("^#/definitions/(.*)$")

    @property
    def default(self):
        return self._default

    @property
    def actions(self):
        return self._actions

    @property
    def definitions(self):
        """ Raw service definitions (each might be dependent on some of its siblings) """
        return self._definitions

    @property
    def definitions_refs(self):
        return self._definitions_refs

    @property
    def name(self):
        return self._name

    @property
    def doc(self):
        return self._doc

    def __init__(self, name, service_config):
        self._name = name
        self._default = None
        self._actions = []
        self._definitions = None
        self._definitions_refs = None
        self._doc = None
        self.parse(service_config)

    @classmethod
    def get_ref_name(cls, ref_string):
        m = cls.__jsonschema_ref_ex.match(ref_string)
        if m:
            return m.group(1)

    def parse(self, service_config):
        self._default = service_config.get(
            "_default", pyhocon.ConfigTree()
        ).as_plain_ordered_dict()

        self._doc = '{} service'.format(self.name)
        description = service_config.get('_description', '')
        if description:
            self._doc += '\n\n{}'.format(description)
        self._definitions = service_config.get(
            "_definitions", pyhocon.ConfigTree()
        ).as_plain_ordered_dict()
        self._definitions_refs = {
            k: self._get_schema_references(v) for k, v in self._definitions.items()
        }
        all_refs = set(itertools.chain(*self.definitions_refs.values()))
        if not all_refs.issubset(self.definitions):
            raise ValueError(
                "Unresolved references (%s) in %s/definitions"
                % (", ".join(all_refs.difference(self.definitions)), self.name)
            )

        actions = {
            k: v.as_plain_ordered_dict()
            for k, v in service_config.items()
            if not k.startswith("_")
        }
        self._actions = {
            action_name: action
            for action_name, action in (
                (action_name, self._parse_action_versions(action_name, action_versions))
                for action_name, action_versions in actions.items()
            )
            if action
        }

    def _parse_action_versions(self, action_name, action_versions):
        def parse_version(action_version):
            try:
                return float(action_version)
            except (ValueError, TypeError) as ex:
                raise ValueError(
                    "Failed parsing version number {} ({}) in {}/{}".format(
                        action_version, ex.args[0], self.name, action_name
                    )
                )

        def add_internal(cfg):
            if "internal" in action_versions:
                cfg.setdefault("internal", action_versions["internal"])
            return cfg

        return {
            parsed_version: action
            for parsed_version, action in (
                (parsed_version, self._parse_action(action_name, parsed_version, add_internal(cfg)))
                for parsed_version, cfg in (
                    (parse_version(version), cfg)
                    for version, cfg in action_versions.items()
                    if version not in ["internal", "allow_roles", "authorize"]
                )
            )
            if action
        }

    def _get_schema_references(self, s):
        refs = set()
        if isinstance(s, dict):
            for k, v in s.items():
                if isinstance(v, six.string_types):
                    m = self.__jsonschema_ref_ex.match(v)
                    if m:
                        refs.add(m.group(1))
                    continue
                elif k in ("oneOf", "anyOf") and isinstance(v, list):
                    refs.update(*map(self._get_schema_references, v))
                refs.update(self._get_schema_references(v))
        return refs

    def _expand_schema_references_with_definitions(self, schema, refs=None):
        definitions = schema.get("definitions", {})
        refs = refs if refs is not None else self._get_schema_references(schema)
        required_refs = set(refs).difference(definitions)
        if not required_refs:
            return required_refs
        if not required_refs.issubset(self.definitions):
            raise ValueError(
                "Unresolved references (%s)"
                % ", ".join(required_refs.difference(self.definitions))
            )

        # update required refs with all sub requirements
        last_required_refs = None
        while last_required_refs != required_refs:
            last_required_refs = required_refs.copy()
            additional_refs = set(
                itertools.chain(
                    *(self.definitions_refs.get(ref, []) for ref in required_refs)
                )
            )
            required_refs.update(additional_refs)
        return required_refs

    def _resolve_schema_references(self, schema, refs=None):
        definitions = schema.get("definitions", {})
        definitions.update({k: v for k, v in self.definitions.items() if k in refs})
        schema["definitions"] = definitions

    def _parse_action(self, action_name, action_version, action_config):
        data = self.default.copy()
        data.update(action_config)

        if not action_config.get("generate", True):
            return None

        definitions_keys = set()
        for schema_key in ("request", "response"):
            if schema_key in action_config:
                try:
                    schema = action_config[schema_key]
                    refs = self._expand_schema_references_with_definitions(schema)
                    self._resolve_schema_references(schema, refs=refs)
                    definitions_keys.update(refs)
                except ValueError as ex:
                    name = "%s.%s/%.1f/%s" % (
                        self.name,
                        action_name,
                        action_version,
                        schema_key,
                    )
                    raise ValueError("%s in %s" % (str(ex), name))

        return Action(
            name=action_name,
            version=action_version,
            definitions_keys=list(definitions_keys),
            service=self.name,
            **(
                {
                    key: value
                    for key, value in data.items()
                    if key in attr.fields_dict(Action)
                }
            )
        )
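The core of `Service._get_schema_references` is a recursive walk over a JSON-schema dict that collects every `#/definitions/<name>` reference, so the parser can verify each action's schema is self-contained. A simplified standalone version of that walk (it recurses into all lists generically, where the original special-cases `oneOf`/`anyOf`):

```python
import re

REF_RE = re.compile(r"^#/definitions/(.*)$")


def get_schema_references(schema):
    """Recursively collect '#/definitions/<name>' reference names
    from a JSON-schema dict."""
    refs = set()
    if isinstance(schema, dict):
        for k, v in schema.items():
            if isinstance(v, str):
                m = REF_RE.match(v)
                if m:
                    refs.add(m.group(1))
                continue
            if isinstance(v, list):
                for item in v:
                    refs.update(get_schema_references(item))
            refs.update(get_schema_references(v))
    return refs


schema = {
    'type': 'object',
    'properties': {
        'call': {'$ref': '#/definitions/call'},
        'jobs': {'type': 'array', 'items': {'$ref': '#/definitions/job'}},
    },
}
print(sorted(get_schema_references(schema)))  # → ['call', 'job']
```

`parse()` then checks the collected names against the service's `_definitions` block and raises `ValueError` on any unresolved reference, which is what catches a typo in a schema file at load time rather than at request time.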
22  trains/backend_api/services/__init__.py  Normal file
@@ -0,0 +1,22 @@
from .v2_1 import async_request
from .v2_1 import auth
from .v2_1 import debug
from .v2_1 import events
from .v2_1 import models
from .v2_1 import news
from .v2_1 import projects
from .v2_1 import storage
from .v2_1 import tasks


__all__ = [
    'async_request',
    'auth',
    'debug',
    'events',
    'models',
    'news',
    'projects',
    'storage',
    'tasks',
]
0  trains/backend_api/services/v2_1/__init__.py  Normal file
414  trains/backend_api/services/v2_1/async_request.py  Normal file
@@ -0,0 +1,414 @@
"""
|
||||
async service
|
||||
|
||||
This service provides support for asynchronous API calls.
|
||||
"""
|
||||
import six
|
||||
import types
|
||||
from datetime import datetime
|
||||
import enum
|
||||
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
|
||||
from ....backend_api.session import Request, BatchRequest, Response, DataModel, NonStrictDataModel, CompoundRequest, schema_property, StringEnum
|
||||
|
||||
|
||||
class Call(NonStrictDataModel):
|
||||
"""
|
||||
:param id: The job ID associated with this call.
|
||||
:type id: str
|
||||
:param status: The job's status.
|
||||
:type status: str
|
||||
:param created: Job creation time.
|
||||
:type created: str
|
||||
:param ended: Job end time.
|
||||
:type ended: str
|
||||
:param enqueued: Job enqueue time.
|
||||
:type enqueued: str
|
||||
:param meta: Metadata for this job, includes endpoint and additional relevant
|
||||
call data.
|
||||
:type meta: dict
|
||||
:param company: The Company this job belongs to.
|
||||
:type company: str
|
||||
:param exec_info: Job execution information.
|
||||
:type exec_info: str
|
||||
"""
|
||||
_schema = {
|
||||
'properties': {
|
||||
'company': {
|
||||
'description': 'The Company this job belongs to.',
|
||||
'type': ['string', 'null'],
|
||||
},
|
||||
'created': {
|
||||
'description': 'Job creation time.',
|
||||
'type': ['string', 'null'],
|
||||
},
|
||||
'ended': {'description': 'Job end time.', 'type': ['string', 'null']},
|
||||
'enqueued': {
|
||||
                'description': 'Job enqueue time.',
                'type': ['string', 'null'],
            },
            'exec_info': {
                'description': 'Job execution information.',
                'type': ['string', 'null'],
            },
            'id': {
                'description': 'The job ID associated with this call.',
                'type': ['string', 'null'],
            },
            'meta': {
                'additionalProperties': True,
                'description': 'Metadata for this job, includes endpoint and additional relevant call data.',
                'type': ['object', 'null'],
            },
            'status': {
                'description': "The job's status.",
                'type': ['string', 'null'],
            },
        },
        'type': 'object',
    }

    def __init__(
            self, id=None, status=None, created=None, ended=None, enqueued=None, meta=None, company=None, exec_info=None, **kwargs):
        super(Call, self).__init__(**kwargs)
        self.id = id
        self.status = status
        self.created = created
        self.ended = ended
        self.enqueued = enqueued
        self.meta = meta
        self.company = company
        self.exec_info = exec_info

    @schema_property('id')
    def id(self):
        return self._property_id

    @id.setter
    def id(self, value):
        if value is None:
            self._property_id = None
            return

        self.assert_isinstance(value, "id", six.string_types)
        self._property_id = value

    @schema_property('status')
    def status(self):
        return self._property_status

    @status.setter
    def status(self, value):
        if value is None:
            self._property_status = None
            return

        self.assert_isinstance(value, "status", six.string_types)
        self._property_status = value

    @schema_property('created')
    def created(self):
        return self._property_created

    @created.setter
    def created(self, value):
        if value is None:
            self._property_created = None
            return

        self.assert_isinstance(value, "created", six.string_types)
        self._property_created = value

    @schema_property('ended')
    def ended(self):
        return self._property_ended

    @ended.setter
    def ended(self, value):
        if value is None:
            self._property_ended = None
            return

        self.assert_isinstance(value, "ended", six.string_types)
        self._property_ended = value

    @schema_property('enqueued')
    def enqueued(self):
        return self._property_enqueued

    @enqueued.setter
    def enqueued(self, value):
        if value is None:
            self._property_enqueued = None
            return

        self.assert_isinstance(value, "enqueued", six.string_types)
        self._property_enqueued = value

    @schema_property('meta')
    def meta(self):
        return self._property_meta

    @meta.setter
    def meta(self, value):
        if value is None:
            self._property_meta = None
            return

        self.assert_isinstance(value, "meta", (dict,))
        self._property_meta = value

    @schema_property('company')
    def company(self):
        return self._property_company

    @company.setter
    def company(self, value):
        if value is None:
            self._property_company = None
            return

        self.assert_isinstance(value, "company", six.string_types)
        self._property_company = value

    @schema_property('exec_info')
    def exec_info(self):
        return self._property_exec_info

    @exec_info.setter
    def exec_info(self, value):
        if value is None:
            self._property_exec_info = None
            return

        self.assert_isinstance(value, "exec_info", six.string_types)
        self._property_exec_info = value


class CallsRequest(Request):
    """
    Get a list of all asynchronous API calls handled by the system.
    This includes previously handled calls, calls currently being executed, and calls waiting in a queue.

    :param status: Return only calls whose status is in this list.
    :type status: Sequence[str]
    :param endpoint: Return only calls handling this endpoint. Supports wildcards.
    :type endpoint: str
    :param task: Return only calls associated with this task ID. Supports
        wildcards.
    :type task: str
    """

    _service = "async"
    _action = "calls"
    _version = "1.5"
    _schema = {
        'definitions': {},
        'properties': {
            'endpoint': {
                'description': 'Return only calls handling this endpoint. Supports wildcards.',
                'type': ['string', 'null'],
            },
            'status': {
                'description': "Return only calls whose status is in this list.",
                'items': {'enum': ['queued', 'in_progress', 'completed'], 'type': 'string'},
                'type': ['array', 'null'],
            },
            'task': {
                'description': 'Return only calls associated with this task ID. Supports wildcards.',
                'type': ['string', 'null'],
            },
        },
        'type': 'object',
    }

    def __init__(
            self, status=None, endpoint=None, task=None, **kwargs):
        super(CallsRequest, self).__init__(**kwargs)
        self.status = status
        self.endpoint = endpoint
        self.task = task

    @schema_property('status')
    def status(self):
        return self._property_status

    @status.setter
    def status(self, value):
        if value is None:
            self._property_status = None
            return

        self.assert_isinstance(value, "status", (list, tuple))

        self.assert_isinstance(value, "status", six.string_types, is_array=True)
        self._property_status = value

    @schema_property('endpoint')
    def endpoint(self):
        return self._property_endpoint

    @endpoint.setter
    def endpoint(self, value):
        if value is None:
            self._property_endpoint = None
            return

        self.assert_isinstance(value, "endpoint", six.string_types)
        self._property_endpoint = value

    @schema_property('task')
    def task(self):
        return self._property_task

    @task.setter
    def task(self, value):
        if value is None:
            self._property_task = None
            return

        self.assert_isinstance(value, "task", six.string_types)
        self._property_task = value


class CallsResponse(Response):
    """
    Response of async.calls endpoint.

    :param calls: A list of the current asynchronous calls handled by the system.
    :type calls: Sequence[Call]
    """
    _service = "async"
    _action = "calls"
    _version = "1.5"

    _schema = {
        'definitions': {
            'call': {
                'properties': {
                    'company': {
                        'description': 'The Company this job belongs to.',
                        'type': ['string', 'null'],
                    },
                    'created': {
                        'description': 'Job creation time.',
                        'type': ['string', 'null'],
                    },
                    'ended': {
                        'description': 'Job end time.',
                        'type': ['string', 'null'],
                    },
                    'enqueued': {
                        'description': 'Job enqueue time.',
                        'type': ['string', 'null'],
                    },
                    'exec_info': {
                        'description': 'Job execution information.',
                        'type': ['string', 'null'],
                    },
                    'id': {
                        'description': 'The job ID associated with this call.',
                        'type': ['string', 'null'],
                    },
                    'meta': {
                        'additionalProperties': True,
                        'description': 'Metadata for this job, includes endpoint and additional relevant call data.',
                        'type': ['object', 'null'],
                    },
                    'status': {
                        'description': "The job's status.",
                        'type': ['string', 'null'],
                    },
                },
                'type': 'object',
            },
        },
        'properties': {
            'calls': {
                'description': 'A list of the current asynchronous calls handled by the system.',
                'items': {'$ref': '#/definitions/call'},
                'type': ['array', 'null'],
            },
        },
        'type': 'object',
    }

    def __init__(
            self, calls=None, **kwargs):
        super(CallsResponse, self).__init__(**kwargs)
        self.calls = calls

    @schema_property('calls')
    def calls(self):
        return self._property_calls

    @calls.setter
    def calls(self, value):
        if value is None:
            self._property_calls = None
            return

        self.assert_isinstance(value, "calls", (list, tuple))
        if any(isinstance(v, dict) for v in value):
            value = [Call.from_dict(v) if isinstance(v, dict) else v for v in value]
        else:
            self.assert_isinstance(value, "calls", Call, is_array=True)
        self._property_calls = value


class ResultRequest(Request):
    """
    Try getting the result of a previously accepted asynchronous API call.
    If execution of the asynchronous call has completed, the complete call response data is returned.
    Otherwise, a 202 code is returned with no data.

    :param id: The id returned by the accepted asynchronous API call.
    :type id: str
    """

    _service = "async"
    _action = "result"
    _version = "1.5"
    _schema = {
        'definitions': {},
        'properties': {
            'id': {
                'description': 'The id returned by the accepted asynchronous API call.',
                'type': ['string', 'null'],
            },
        },
        'type': 'object',
    }

    def __init__(
            self, id=None, **kwargs):
        super(ResultRequest, self).__init__(**kwargs)
        self.id = id

    @schema_property('id')
    def id(self):
        return self._property_id

    @id.setter
    def id(self, value):
        if value is None:
            self._property_id = None
            return

        self.assert_isinstance(value, "id", six.string_types)
        self._property_id = value


class ResultResponse(Response):
    """
    Response of async.result endpoint.

    """
    _service = "async"
    _action = "result"
    _version = "1.5"

    _schema = {'additionalProperties': True, 'definitions': {}, 'type': 'object'}


response_mapping = {
    ResultRequest: ResultResponse,
    CallsRequest: CallsResponse,
}
1112  trains/backend_api/services/v2_1/auth.py  (new file; diff suppressed because it is too large)
194   trains/backend_api/services/v2_1/debug.py  (new file)
@@ -0,0 +1,194 @@
"""
|
||||
debug service
|
||||
|
||||
Debugging utilities
|
||||
"""
|
||||
import six
|
||||
import types
|
||||
from datetime import datetime
|
||||
import enum
|
||||
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
|
||||
from ....backend_api.session import Request, BatchRequest, Response, DataModel, NonStrictDataModel, CompoundRequest, schema_property, StringEnum
|
||||
|
||||
|
||||
class ApiexRequest(Request):
|
||||
"""
|
||||
"""
|
||||
|
||||
_service = "debug"
|
||||
_action = "apiex"
|
||||
_version = "1.5"
|
||||
_schema = {'definitions': {}, 'properties': {}, 'required': [], 'type': 'object'}
|
||||
|
||||
|
||||
class ApiexResponse(Response):
|
||||
"""
|
||||
Response of debug.apiex endpoint.
|
||||
|
||||
"""
|
||||
_service = "debug"
|
||||
_action = "apiex"
|
||||
_version = "1.5"
|
||||
|
||||
_schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
|
||||
|
||||
|
||||
class EchoRequest(Request):
|
||||
"""
|
||||
Return request data
|
||||
|
||||
"""
|
||||
|
||||
_service = "debug"
|
||||
_action = "echo"
|
||||
_version = "1.5"
|
||||
_schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
|
||||
|
||||
|
||||
class EchoResponse(Response):
|
||||
"""
|
||||
Response of debug.echo endpoint.
|
||||
|
||||
"""
|
||||
_service = "debug"
|
||||
_action = "echo"
|
||||
_version = "1.5"
|
||||
|
||||
_schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
|
||||
|
||||
|
||||
class ExRequest(Request):
|
||||
"""
|
||||
"""
|
||||
|
||||
_service = "debug"
|
||||
_action = "ex"
|
||||
_version = "1.5"
|
||||
_schema = {'definitions': {}, 'properties': {}, 'required': [], 'type': 'object'}
|
||||
|
||||
|
||||
class ExResponse(Response):
|
||||
"""
|
||||
Response of debug.ex endpoint.
|
||||
|
||||
"""
|
||||
_service = "debug"
|
||||
_action = "ex"
|
||||
_version = "1.5"
|
||||
|
||||
_schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
|
||||
|
||||
|
||||
class PingRequest(Request):
|
||||
"""
|
||||
Return a message. Does not require authorization.
|
||||
|
||||
"""
|
||||
|
||||
_service = "debug"
|
||||
_action = "ping"
|
||||
_version = "1.5"
|
||||
_schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
|
||||
|
||||
|
||||
class PingResponse(Response):
|
||||
"""
|
||||
Response of debug.ping endpoint.
|
||||
|
||||
:param msg: A friendly message
|
||||
:type msg: str
|
||||
"""
|
||||
_service = "debug"
|
||||
_action = "ping"
|
||||
_version = "1.5"
|
||||
|
||||
_schema = {
|
||||
'definitions': {},
|
||||
'properties': {
|
||||
'msg': {
|
||||
'description': 'A friendly message',
|
||||
'type': ['string', 'null'],
|
||||
},
|
||||
},
|
||||
'type': 'object',
|
||||
}
|
||||
def __init__(
|
||||
self, msg=None, **kwargs):
|
||||
super(PingResponse, self).__init__(**kwargs)
|
||||
self.msg = msg
|
||||
|
||||
@schema_property('msg')
|
||||
def msg(self):
|
||||
return self._property_msg
|
||||
|
||||
@msg.setter
|
||||
def msg(self, value):
|
||||
if value is None:
|
||||
self._property_msg = None
|
||||
return
|
||||
|
||||
self.assert_isinstance(value, "msg", six.string_types)
|
||||
self._property_msg = value
|
||||
|
||||
|
||||
class PingAuthRequest(Request):
|
||||
"""
|
||||
Return a message. Requires authorization.
|
||||
|
||||
"""
|
||||
|
||||
_service = "debug"
|
||||
_action = "ping_auth"
|
||||
_version = "1.5"
|
||||
_schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
|
||||
|
||||
|
||||
class PingAuthResponse(Response):
|
||||
"""
|
||||
Response of debug.ping_auth endpoint.
|
||||
|
||||
:param msg: A friendly message
|
||||
:type msg: str
|
||||
"""
|
||||
_service = "debug"
|
||||
_action = "ping_auth"
|
||||
_version = "1.5"
|
||||
|
||||
_schema = {
|
||||
'definitions': {},
|
||||
'properties': {
|
||||
'msg': {
|
||||
'description': 'A friendly message',
|
||||
'type': ['string', 'null'],
|
||||
},
|
||||
},
|
||||
'type': 'object',
|
||||
}
|
||||
def __init__(
|
||||
self, msg=None, **kwargs):
|
||||
super(PingAuthResponse, self).__init__(**kwargs)
|
||||
self.msg = msg
|
||||
|
||||
@schema_property('msg')
|
||||
def msg(self):
|
||||
return self._property_msg
|
||||
|
||||
@msg.setter
|
||||
def msg(self, value):
|
||||
if value is None:
|
||||
self._property_msg = None
|
||||
return
|
||||
|
||||
self.assert_isinstance(value, "msg", six.string_types)
|
||||
self._property_msg = value
|
||||
|
||||
|
||||
response_mapping = {
|
||||
EchoRequest: EchoResponse,
|
||||
PingRequest: PingResponse,
|
||||
PingAuthRequest: PingAuthResponse,
|
||||
ApiexRequest: ApiexResponse,
|
||||
ExRequest: ExResponse,
|
||||
}
|
||||
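Each generated service module ends with a module-level `response_mapping` dict keyed by request class. A client can use it to pick the response type that matches a request it just sent. A minimal sketch with stand-in classes (the class names mirror the debug service above; the `build_response` helper is an illustrative assumption, not part of the generated code):

```python
# Stand-ins for the generated Request/Response classes.
class PingRequest:
    pass

class PingResponse:
    def __init__(self, msg=None):
        self.msg = msg

# Mirrors the shape of the module-level mapping in each service file.
response_mapping = {PingRequest: PingResponse}

def build_response(request, payload):
    """Look up the response class for `request` and instantiate it from payload."""
    response_cls = response_mapping[type(request)]
    return response_cls(**payload)
```

The real client would deserialize the server's JSON into `payload` first; only the lookup-by-request-type pattern is the point here.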
2846  trains/backend_api/services/v2_1/events.py  (new file; diff suppressed because it is too large)
2675  trains/backend_api/services/v2_1/models.py  (new file; diff suppressed because it is too large)
70    trains/backend_api/services/v2_1/news.py  (new file)
@@ -0,0 +1,70 @@
"""
|
||||
news service
|
||||
|
||||
This service provides platform news.
|
||||
"""
|
||||
import six
|
||||
import types
|
||||
from datetime import datetime
|
||||
import enum
|
||||
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
|
||||
from ....backend_api.session import Request, BatchRequest, Response, DataModel, NonStrictDataModel, CompoundRequest, schema_property, StringEnum
|
||||
|
||||
|
||||
class GetRequest(Request):
|
||||
"""
|
||||
Gets latest news link
|
||||
|
||||
"""
|
||||
|
||||
_service = "news"
|
||||
_action = "get"
|
||||
_version = "1.5"
|
||||
_schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
|
||||
|
||||
|
||||
class GetResponse(Response):
|
||||
"""
|
||||
Response of news.get endpoint.
|
||||
|
||||
:param url: URL to news html file
|
||||
:type url: str
|
||||
"""
|
||||
_service = "news"
|
||||
_action = "get"
|
||||
_version = "1.5"
|
||||
|
||||
_schema = {
|
||||
'definitions': {},
|
||||
'properties': {
|
||||
'url': {
|
||||
'description': 'URL to news html file',
|
||||
'type': ['string', 'null'],
|
||||
},
|
||||
},
|
||||
'type': 'object',
|
||||
}
|
||||
def __init__(
|
||||
self, url=None, **kwargs):
|
||||
super(GetResponse, self).__init__(**kwargs)
|
||||
self.url = url
|
||||
|
||||
@schema_property('url')
|
||||
def url(self):
|
||||
return self._property_url
|
||||
|
||||
@url.setter
|
||||
def url(self, value):
|
||||
if value is None:
|
||||
self._property_url = None
|
||||
return
|
||||
|
||||
self.assert_isinstance(value, "url", six.string_types)
|
||||
self._property_url = value
|
||||
|
||||
|
||||
response_mapping = {
|
||||
GetRequest: GetResponse,
|
||||
}
|
||||
1847  trains/backend_api/services/v2_1/projects.py  (new file; diff suppressed because it is too large)
681   trains/backend_api/services/v2_1/storage.py  (new file)
@@ -0,0 +1,681 @@
"""
|
||||
storage service
|
||||
|
||||
Provides a management API for customer-associated storage locations
|
||||
"""
|
||||
import six
|
||||
import types
|
||||
from datetime import datetime
|
||||
import enum
|
||||
|
||||
from dateutil.parser import parse as parse_datetime
|
||||
|
||||
from ....backend_api.session import Request, BatchRequest, Response, DataModel, NonStrictDataModel, CompoundRequest, schema_property, StringEnum
|
||||
|
||||
|
||||
class Credentials(NonStrictDataModel):
|
||||
"""
|
||||
:param access_key: Credentials access key
|
||||
:type access_key: str
|
||||
:param secret_key: Credentials secret key
|
||||
:type secret_key: str
|
||||
"""
|
||||
_schema = {
|
||||
'properties': {
|
||||
'access_key': {
|
||||
'description': 'Credentials access key',
|
||||
'type': ['string', 'null'],
|
||||
},
|
||||
'secret_key': {
|
||||
'description': 'Credentials secret key',
|
||||
'type': ['string', 'null'],
|
||||
},
|
||||
},
|
||||
'type': 'object',
|
||||
}
|
||||
def __init__(
|
||||
self, access_key=None, secret_key=None, **kwargs):
|
||||
super(Credentials, self).__init__(**kwargs)
|
||||
self.access_key = access_key
|
||||
self.secret_key = secret_key
|
||||
|
||||
@schema_property('access_key')
|
||||
def access_key(self):
|
||||
return self._property_access_key
|
||||
|
||||
@access_key.setter
|
||||
def access_key(self, value):
|
||||
if value is None:
|
||||
self._property_access_key = None
|
||||
return
|
||||
|
||||
self.assert_isinstance(value, "access_key", six.string_types)
|
||||
self._property_access_key = value
|
||||
|
||||
@schema_property('secret_key')
|
||||
def secret_key(self):
|
||||
return self._property_secret_key
|
||||
|
||||
@secret_key.setter
|
||||
def secret_key(self, value):
|
||||
if value is None:
|
||||
self._property_secret_key = None
|
||||
return
|
||||
|
||||
self.assert_isinstance(value, "secret_key", six.string_types)
|
||||
self._property_secret_key = value
|
||||
|
||||
|
||||
class Storage(NonStrictDataModel):
|
||||
"""
|
||||
:param id: Entry ID
|
||||
:type id: str
|
||||
:param name: Entry name
|
||||
:type name: str
|
||||
:param company: Company ID
|
||||
:type company: str
|
||||
:param created: Entry creation time
|
||||
:type created: datetime.datetime
|
||||
:param uri: Storage URI
|
||||
:type uri: str
|
||||
:param credentials: Credentials required for accessing the storage
|
||||
:type credentials: Credentials
|
||||
"""
|
||||
_schema = {
|
||||
'properties': {
|
||||
'company': {'description': 'Company ID', 'type': ['string', 'null']},
|
||||
'created': {
|
||||
'description': 'Entry creation time',
|
||||
'format': 'date-time',
|
||||
'type': ['string', 'null'],
|
||||
},
|
||||
'credentials': {
|
||||
'description': 'Credentials required for accessing the storage',
|
||||
'oneOf': [{'$ref': '#/definitions/credentials'}, {'type': 'null'}],
|
||||
},
|
||||
'id': {'description': 'Entry ID', 'type': ['string', 'null']},
|
||||
'name': {'description': 'Entry name', 'type': ['string', 'null']},
|
||||
'uri': {'description': 'Storage URI', 'type': ['string', 'null']},
|
||||
},
|
||||
'type': 'object',
|
||||
}
|
||||
def __init__(
|
||||
self, id=None, name=None, company=None, created=None, uri=None, credentials=None, **kwargs):
|
||||
super(Storage, self).__init__(**kwargs)
|
||||
self.id = id
|
||||
self.name = name
|
||||
self.company = company
|
||||
self.created = created
|
||||
self.uri = uri
|
||||
self.credentials = credentials
|
||||
|
||||
@schema_property('id')
|
||||
def id(self):
|
||||
return self._property_id
|
||||
|
||||
@id.setter
|
||||
def id(self, value):
|
||||
if value is None:
|
||||
self._property_id = None
|
||||
return
|
||||
|
||||
self.assert_isinstance(value, "id", six.string_types)
|
||||
self._property_id = value
|
||||
|
||||
@schema_property('name')
|
||||
def name(self):
|
||||
return self._property_name
|
||||
|
||||
@name.setter
|
||||
def name(self, value):
|
||||
if value is None:
|
||||
self._property_name = None
|
||||
return
|
||||
|
||||
self.assert_isinstance(value, "name", six.string_types)
|
||||
self._property_name = value
|
||||
|
||||
@schema_property('company')
|
||||
def company(self):
|
||||
return self._property_company
|
||||
|
||||
@company.setter
|
||||
def company(self, value):
|
||||
if value is None:
|
||||
self._property_company = None
|
||||
return
|
||||
|
||||
self.assert_isinstance(value, "company", six.string_types)
|
||||
self._property_company = value
|
||||
|
||||
@schema_property('created')
|
||||
def created(self):
|
||||
return self._property_created
|
||||
|
||||
@created.setter
|
||||
def created(self, value):
|
||||
if value is None:
|
||||
self._property_created = None
|
||||
return
|
||||
|
||||
self.assert_isinstance(value, "created", six.string_types + (datetime,))
|
||||
if not isinstance(value, datetime):
|
||||
value = parse_datetime(value)
|
||||
self._property_created = value
|
||||
|
||||
@schema_property('uri')
|
||||
def uri(self):
|
||||
return self._property_uri
|
||||
|
||||
@uri.setter
|
||||
def uri(self, value):
|
||||
if value is None:
|
||||
self._property_uri = None
|
||||
return
|
||||
|
||||
self.assert_isinstance(value, "uri", six.string_types)
|
||||
self._property_uri = value
|
||||
|
||||
@schema_property('credentials')
|
||||
def credentials(self):
|
||||
return self._property_credentials
|
||||
|
||||
@credentials.setter
|
||||
def credentials(self, value):
|
||||
if value is None:
|
||||
self._property_credentials = None
|
||||
return
|
||||
if isinstance(value, dict):
|
||||
value = Credentials.from_dict(value)
|
||||
else:
|
||||
self.assert_isinstance(value, "credentials", Credentials)
|
||||
self._property_credentials = value
|
||||
|
||||
|
||||
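Note the `created` setter above: it accepts either a `datetime` or a date string and normalizes to `datetime` via `dateutil.parser.parse`. A standalone sketch of that coercion, substituting stdlib `datetime.fromisoformat` for dateutil so the example has no third-party dependency (the real setter also routes the type check through `assert_isinstance`):

```python
from datetime import datetime

def coerce_created(value):
    """Normalize a `created` value the way the generated setter does.

    Accepts None, a datetime, or an ISO-format string; strings are parsed
    into datetimes. (The generated code uses dateutil.parser.parse, which
    accepts a wider range of date formats than fromisoformat.)
    """
    if value is None:
        return None
    if not isinstance(value, (str, datetime)):
        raise TypeError("created must be a string or datetime")
    if not isinstance(value, datetime):
        value = datetime.fromisoformat(value)
    return value
```

This pattern lets callers pass server JSON (strings) or locally constructed datetimes interchangeably.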
class CreateRequest(Request):
    """
    Create a new storage entry

    :param name: Storage name
    :type name: str
    :param uri: Storage URI
    :type uri: str
    :param credentials: Credentials required for accessing the storage
    :type credentials: Credentials
    :param company: Company under which to add this storage. Only valid for users
        with the root or system role, otherwise the calling user's company will be
        used.
    :type company: str
    """

    _service = "storage"
    _action = "create"
    _version = "1.5"
    _schema = {
        'definitions': {
            'credentials': {
                'properties': {
                    'access_key': {
                        'description': 'Credentials access key',
                        'type': ['string', 'null'],
                    },
                    'secret_key': {
                        'description': 'Credentials secret key',
                        'type': ['string', 'null'],
                    },
                },
                'type': 'object',
            },
        },
        'properties': {
            'company': {
                'description': "Company under which to add this storage. Only valid for users with the root or system role, otherwise the calling user's company will be used.",
                'type': 'string',
            },
            'credentials': {
                '$ref': '#/definitions/credentials',
                'description': 'Credentials required for accessing the storage',
            },
            'name': {'description': 'Storage name', 'type': ['string', 'null']},
            'uri': {'description': 'Storage URI', 'type': 'string'},
        },
        'required': ['uri'],
        'type': 'object',
    }

    def __init__(
            self, uri, name=None, credentials=None, company=None, **kwargs):
        super(CreateRequest, self).__init__(**kwargs)
        self.name = name
        self.uri = uri
        self.credentials = credentials
        self.company = company

    @schema_property('name')
    def name(self):
        return self._property_name

    @name.setter
    def name(self, value):
        if value is None:
            self._property_name = None
            return

        self.assert_isinstance(value, "name", six.string_types)
        self._property_name = value

    @schema_property('uri')
    def uri(self):
        return self._property_uri

    @uri.setter
    def uri(self, value):
        if value is None:
            self._property_uri = None
            return

        self.assert_isinstance(value, "uri", six.string_types)
        self._property_uri = value

    @schema_property('credentials')
    def credentials(self):
        return self._property_credentials

    @credentials.setter
    def credentials(self, value):
        if value is None:
            self._property_credentials = None
            return
        if isinstance(value, dict):
            value = Credentials.from_dict(value)
        else:
            self.assert_isinstance(value, "credentials", Credentials)
        self._property_credentials = value

    @schema_property('company')
    def company(self):
        return self._property_company

    @company.setter
    def company(self, value):
        if value is None:
            self._property_company = None
            return

        self.assert_isinstance(value, "company", six.string_types)
        self._property_company = value


class CreateResponse(Response):
    """
    Response of storage.create endpoint.

    :param id: New storage ID
    :type id: str
    """
    _service = "storage"
    _action = "create"
    _version = "1.5"

    _schema = {
        'definitions': {},
        'properties': {
            'id': {'description': 'New storage ID', 'type': ['string', 'null']},
        },
        'type': 'object',
    }

    def __init__(
            self, id=None, **kwargs):
        super(CreateResponse, self).__init__(**kwargs)
        self.id = id

    @schema_property('id')
    def id(self):
        return self._property_id

    @id.setter
    def id(self, value):
        if value is None:
            self._property_id = None
            return

        self.assert_isinstance(value, "id", six.string_types)
        self._property_id = value


class DeleteRequest(Request):
    """
    Deletes a storage entry

    :param storage: Storage entry ID
    :type storage: str
    """

    _service = "storage"
    _action = "delete"
    _version = "1.5"
    _schema = {
        'definitions': {},
        'properties': {
            'storage': {'description': 'Storage entry ID', 'type': 'string'},
        },
        'required': ['storage'],
        'type': 'object',
    }

    def __init__(
            self, storage, **kwargs):
        super(DeleteRequest, self).__init__(**kwargs)
        self.storage = storage

    @schema_property('storage')
    def storage(self):
        return self._property_storage

    @storage.setter
    def storage(self, value):
        if value is None:
            self._property_storage = None
            return

        self.assert_isinstance(value, "storage", six.string_types)
        self._property_storage = value


class DeleteResponse(Response):
    """
    Response of storage.delete endpoint.

    :param deleted: Number of storage entries deleted (0 or 1)
    :type deleted: int
    """
    _service = "storage"
    _action = "delete"
    _version = "1.5"

    _schema = {
        'definitions': {},
        'properties': {
            'deleted': {
                'description': 'Number of storage entries deleted (0 or 1)',
                'type': ['integer', 'null'],
            },
        },
        'type': 'object',
    }

    def __init__(
            self, deleted=None, **kwargs):
        super(DeleteResponse, self).__init__(**kwargs)
        self.deleted = deleted

    @schema_property('deleted')
    def deleted(self):
        return self._property_deleted

    @deleted.setter
    def deleted(self, value):
        if value is None:
            self._property_deleted = None
            return
        if isinstance(value, float) and value.is_integer():
            value = int(value)

        self.assert_isinstance(value, "deleted", six.integer_types)
        self._property_deleted = value


class GetAllRequest(Request):
    """
    Get all storage entries

    :param name: Get only storage entries whose name matches this pattern (python
        regular expression syntax)
    :type name: str
    :param id: List of Storage IDs used to filter results
    :type id: Sequence[str]
    :param page: Page number, returns a specific page out of the resulting list of
        results.
    :type page: int
    :param page_size: Page size, specifies the number of results returned in each
        page (last page may contain fewer results)
    :type page_size: int
    :param order_by: List of field names to order by. When search_text is used,
        '@text_score' can be used as a field representing the text score of returned
        documents. Use '-' prefix to specify descending order. Optional, recommended
        when using page
    :type order_by: Sequence[str]
    :param only_fields: List of document field names (nesting is supported using
        '.', e.g. execution.model_labels). If provided, this list defines the query's
        projection (only these fields will be returned for each result entry)
    :type only_fields: Sequence[str]
    """

    _service = "storage"
    _action = "get_all"
    _version = "1.5"
    _schema = {
        'definitions': {},
        'properties': {
            'id': {
                'description': 'List of Storage IDs used to filter results',
                'items': {'type': 'string'},
                'type': ['array', 'null'],
            },
            'name': {
                'description': 'Get only storage entries whose name matches this pattern (python regular expression syntax)',
                'type': ['string', 'null'],
            },
            'only_fields': {
                'description': "List of document field names (nesting is supported using '.', e.g. execution.model_labels). If provided, this list defines the query's projection (only these fields will be returned for each result entry)",
                'items': {'type': 'string'},
                'type': ['array', 'null'],
            },
            'order_by': {
                'description': "List of field names to order by. When search_text is used, '@text_score' can be used as a field representing the text score of returned documents. Use '-' prefix to specify descending order. Optional, recommended when using page",
                'items': {'type': 'string'},
                'type': ['array', 'null'],
            },
            'page': {
                'description': 'Page number, returns a specific page out of the resulting list of results.',
                'minimum': 0,
                'type': ['integer', 'null'],
            },
            'page_size': {
                'description': 'Page size, specifies the number of results returned in each page (last page may contain fewer results)',
                'minimum': 1,
                'type': ['integer', 'null'],
            },
        },
        'type': 'object',
    }

    def __init__(
            self, name=None, id=None, page=None, page_size=None, order_by=None, only_fields=None, **kwargs):
        super(GetAllRequest, self).__init__(**kwargs)
        self.name = name
        self.id = id
        self.page = page
        self.page_size = page_size
        self.order_by = order_by
        self.only_fields = only_fields

    @schema_property('name')
    def name(self):
        return self._property_name

    @name.setter
    def name(self, value):
        if value is None:
            self._property_name = None
            return

        self.assert_isinstance(value, "name", six.string_types)
        self._property_name = value

    @schema_property('id')
    def id(self):
        return self._property_id

    @id.setter
    def id(self, value):
        if value is None:
            self._property_id = None
            return

        self.assert_isinstance(value, "id", (list, tuple))

        self.assert_isinstance(value, "id", six.string_types, is_array=True)
        self._property_id = value

    @schema_property('page')
    def page(self):
        return self._property_page

    @page.setter
    def page(self, value):
        if value is None:
            self._property_page = None
            return
        if isinstance(value, float) and value.is_integer():
            value = int(value)

        self.assert_isinstance(value, "page", six.integer_types)
        self._property_page = value

    @schema_property('page_size')
    def page_size(self):
        return self._property_page_size

    @page_size.setter
    def page_size(self, value):
        if value is None:
            self._property_page_size = None
            return
        if isinstance(value, float) and value.is_integer():
            value = int(value)

        self.assert_isinstance(value, "page_size", six.integer_types)
        self._property_page_size = value

    @schema_property('order_by')
    def order_by(self):
        return self._property_order_by

    @order_by.setter
    def order_by(self, value):
        if value is None:
            self._property_order_by = None
            return

        self.assert_isinstance(value, "order_by", (list, tuple))

        self.assert_isinstance(value, "order_by", six.string_types, is_array=True)
        self._property_order_by = value

    @schema_property('only_fields')
    def only_fields(self):
        return self._property_only_fields

    @only_fields.setter
    def only_fields(self, value):
        if value is None:
            self._property_only_fields = None
            return

        self.assert_isinstance(value, "only_fields", (list, tuple))

        self.assert_isinstance(value, "only_fields", six.string_types, is_array=True)
        self._property_only_fields = value
class GetAllResponse(Response):
|
||||
"""
|
||||
Response of storage.get_all endpoint.
|
||||
|
||||
:param results: Storage entries list
|
||||
:type results: Sequence[Storage]
|
||||
"""
|
||||
_service = "storage"
|
||||
_action = "get_all"
|
||||
_version = "1.5"
|
||||
|
||||
_schema = {
|
||||
'definitions': {
|
||||
'credentials': {
|
||||
'properties': {
|
||||
'access_key': {
|
||||
'description': 'Credentials access key',
|
||||
'type': ['string', 'null'],
|
||||
},
|
||||
'secret_key': {
|
||||
'description': 'Credentials secret key',
|
||||
'type': ['string', 'null'],
|
||||
},
|
||||
},
|
||||
'type': 'object',
|
||||
},
|
||||
'storage': {
|
||||
'properties': {
|
||||
'company': {
|
||||
'description': 'Company ID',
|
||||
'type': ['string', 'null'],
|
||||
},
|
||||
'created': {
|
||||
'description': 'Entry creation time',
|
||||
'format': 'date-time',
|
||||
'type': ['string', 'null'],
|
||||
},
|
||||
'credentials': {
|
||||
'description': 'Credentials required for accessing the storage',
|
||||
'oneOf': [
|
||||
{'$ref': '#/definitions/credentials'},
|
||||
{'type': 'null'},
|
||||
],
|
||||
},
|
||||
'id': {'description': 'Entry ID', 'type': ['string', 'null']},
|
||||
'name': {
|
||||
'description': 'Entry name',
|
||||
'type': ['string', 'null'],
|
||||
},
|
||||
'uri': {
|
||||
'description': 'Storage URI',
|
||||
'type': ['string', 'null'],
|
||||
},
|
||||
},
|
||||
'type': 'object',
|
||||
},
|
||||
},
|
||||
'properties': {
|
||||
'results': {
|
||||
'description': 'Storage entries list',
|
||||
'items': {'$ref': '#/definitions/storage'},
|
||||
'type': ['array', 'null'],
|
||||
},
|
||||
},
|
||||
'type': 'object',
|
||||
}
|
||||
def __init__(
|
||||
self, results=None, **kwargs):
|
||||
super(GetAllResponse, self).__init__(**kwargs)
|
||||
self.results = results
|
||||
|
||||
@schema_property('results')
|
||||
def results(self):
|
||||
return self._property_results
|
||||
|
||||
@results.setter
|
||||
def results(self, value):
|
||||
if value is None:
|
||||
self._property_results = None
|
||||
return
|
||||
|
||||
self.assert_isinstance(value, "results", (list, tuple))
|
||||
if any(isinstance(v, dict) for v in value):
|
||||
value = [Storage.from_dict(v) if isinstance(v, dict) else v for v in value]
|
||||
else:
|
||||
self.assert_isinstance(value, "results", Storage, is_array=True)
|
||||
self._property_results = value
|
||||
|
||||
|
||||
response_mapping = {
|
||||
GetAllRequest: GetAllResponse,
|
||||
CreateRequest: CreateResponse,
|
||||
DeleteRequest: DeleteResponse,
|
||||
}
|
||||
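The `response_mapping` above is how the session layer resolves which response class should parse the payload returned for a given request class. A minimal, self-contained sketch of that dispatch (the classes here are bare stand-ins for illustration, not the real `Request`/`Response` hierarchy):

```python
# Minimal stand-ins for the generated request/response classes
# (hypothetical, for illustration -- the real classes carry schemas and validation).
class GetAllRequest:
    pass


class GetAllResponse:
    def __init__(self, **kwargs):
        self.results = kwargs.get('results')


response_mapping = {GetAllRequest: GetAllResponse}


def get_response_cls(request_cls):
    # Resolve the response class registered for a request class
    return response_mapping.get(request_cls)


resp_cls = get_response_cls(GetAllRequest)
resp = resp_cls(results=[{'id': 'abc', 'name': 'default'}])
```

The real `CallResult.from_result` (shown further below in `callresult.py`) performs exactly this lookup via `get_response_cls` before instantiating the parsed response.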
8460
trains/backend_api/services/v2_1/tasks.py
Normal file
File diff suppressed because it is too large
7
trains/backend_api/session/__init__.py
Normal file
@@ -0,0 +1,7 @@
from .session import Session
from .datamodel import DataModel, NonStrictDataModel, schema_property, StringEnum
from .request import Request, BatchRequest, CompoundRequest
from .response import Response
from .token_manager import TokenManager
from .errors import TimeoutExpiredError, ResultNotReadyError
from .callresult import CallResult
8
trains/backend_api/session/apimodel.py
Normal file
@@ -0,0 +1,8 @@
from .datamodel import DataModel


class ApiModel(DataModel):
    """ API-related data model """
    _service = None
    _action = None
    _version = None
131
trains/backend_api/session/callresult.py
Normal file
@@ -0,0 +1,131 @@
import sys
import time

from ...backend_api.utils import get_response_cls

from .response import ResponseMeta, Response
from .errors import ResultNotReadyError, TimeoutExpiredError


class CallResult(object):
    @property
    def meta(self):
        return self.__meta

    @property
    def response(self):
        return self.__response

    @property
    def response_data(self):
        return self.__response_data

    @property
    def async_accepted(self):
        return self.meta.result_code == 202

    @property
    def request_cls(self):
        return self.__request_cls

    def __init__(self, meta, response=None, response_data=None, request_cls=None, session=None):
        assert isinstance(meta, ResponseMeta)
        if response and not isinstance(response, Response):
            raise ValueError('response should be an instance of %s' % Response.__name__)
        elif response_data and not isinstance(response_data, dict):
            raise TypeError('data should be an instance of {}'.format(dict.__name__))

        self.__meta = meta
        self.__response = response
        self.__request_cls = request_cls
        self.__session = session
        self.__async_result = None

        if response_data is not None:
            self.__response_data = response_data
        elif response is not None:
            try:
                self.__response_data = response.to_dict()
            except AttributeError:
                raise TypeError('response should be an instance of {}'.format(Response.__name__))
        else:
            self.__response_data = None

    @classmethod
    def from_result(cls, res, request_cls=None, logger=None, service=None, action=None, session=None):
        """ Create a CallResult from a requests result """
        response_cls = get_response_cls(request_cls)
        try:
            data = res.json()
        except ValueError:
            service = service or (request_cls._service if request_cls else 'unknown')
            action = action or (request_cls._action if request_cls else 'unknown')
            return cls(request_cls=request_cls, meta=ResponseMeta.from_raw_data(
                status_code=res.status_code, text=res.text, endpoint='%(service)s.%(action)s' % locals()))
        if 'meta' not in data:
            raise ValueError('Missing meta section in response payload')
        try:
            meta = ResponseMeta(**data['meta'])
            # TODO: validate meta?
            # meta.validate()
        except Exception as ex:
            raise ValueError('Failed parsing meta section in response payload (data=%s, error=%s)' % (data, ex))

        response = None
        response_data = None
        try:
            response_data = data.get('data', {})
            if response_cls:
                response = response_cls(**response_data)
                # TODO: validate response?
                # response.validate()
        except Exception as e:
            if logger:
                logger.warn('Failed parsing response: %s' % str(e))
        return cls(meta=meta, response=response, response_data=response_data, request_cls=request_cls, session=session)

    def ok(self):
        return self.meta.result_code == 200

    def ready(self):
        if not self.async_accepted:
            return True
        session = self.__session
        res = session.send_request(service='async', action='result', json=dict(id=self.meta.id), async_enable=False)
        if res.status_code != session._async_status_code:
            self.__async_result = CallResult.from_result(res=res, request_cls=self.request_cls, logger=session._logger)
            return True

    def result(self):
        if not self.async_accepted:
            return self
        if self.__async_result is None:
            raise ResultNotReadyError(self._format_msg('Result is not ready'), call_id=self.meta.id)
        return self.__async_result

    def wait(self, timeout=None, poll_interval=5, verbose=False):
        if not self.async_accepted:
            return self
        session = self.__session
        poll_interval = max(1, poll_interval)
        remaining = max(0, timeout) if timeout else sys.maxsize
        while remaining > 0:
            if not self.ready():
                # Still pending, log and continue
                if verbose and session._logger:
                    progress = ('waiting forever'
                                if timeout is False
                                else '%.1f/%.1f seconds remaining' % (remaining, float(timeout or 0)))
                    session._logger.info('Waiting for asynchronous call %s (%s)'
                                         % (self.request_cls.__name__, progress))
                time.sleep(poll_interval)
                remaining -= poll_interval
                continue
            # We've got something (good or bad, we don't know), create a call result and return
            return self.result()

        # Timeout expired, raise an error (we've got nothing better to report)
        raise TimeoutExpiredError(self._format_msg('Timeout expired'), call_id=self.meta.id)

    def _format_msg(self, msg):
        return msg + ' for call %s (%s)' % (self.request_cls.__name__, self.meta.id)
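The async flow above works as follows: a 202 status (`_async_status_code`) marks the call as accepted-but-pending, `ready()` polls the `async.result` endpoint until a non-202 status arrives, and `wait()` sleeps `poll_interval` seconds between polls until `timeout` runs out. A stripped-down, runnable sketch of that polling loop (the mock session and its method names are illustrative, not the real API):

```python
import time


class MockSession:
    """Hypothetical stand-in for Session: reports 'pending' twice, then done."""
    _async_status_code = 202

    def __init__(self):
        self.polls = 0

    def poll_async_result(self):
        # Stands in for send_request(service='async', action='result', ...)
        self.polls += 1
        return 202 if self.polls < 3 else 200


def wait_for_result(session, timeout=10, poll_interval=0.01):
    remaining = timeout
    while remaining > 0:
        status = session.poll_async_result()
        if status != session._async_status_code:
            # We've got something (good or bad), stop polling
            return status
        time.sleep(poll_interval)
        remaining -= poll_interval
    raise TimeoutError('Timeout expired')


session = MockSession()
status = wait_for_result(session)
```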
145
trains/backend_api/session/datamodel.py
Normal file
@@ -0,0 +1,145 @@
import keyword

import enum
import json
import warnings
from datetime import datetime

import jsonschema
from enum import Enum

import six


def format_date(obj):
    if isinstance(obj, datetime):
        return str(obj)


class SchemaProperty(property):
    def __init__(self, name=None, *args, **kwargs):
        super(SchemaProperty, self).__init__(*args, **kwargs)
        self.name = name

    def setter(self, fset):
        return type(self)(self.name, self.fget, fset, self.fdel, self.__doc__)


def schema_property(name):
    def init(*args, **kwargs):
        return SchemaProperty(name, *args, **kwargs)
    return init


class DataModel(object):
    """ Data Model """
    _schema = None
    _data_props_list = None

    @classmethod
    def _get_data_props(cls):
        props = cls._data_props_list
        if props is None:
            props = {}
            for c in cls.__mro__:
                props.update({k: getattr(v, 'name', k) for k, v in vars(c).items()
                              if isinstance(v, property)})
            cls._data_props_list = props
        return props.copy()

    @classmethod
    def _to_base_type(cls, value):
        if isinstance(value, DataModel):
            return value.to_dict()
        elif isinstance(value, enum.Enum):
            return value.value
        elif isinstance(value, list):
            return [cls._to_base_type(model) for model in value]
        return value

    def to_dict(self, only=None, except_=None):
        prop_values = {v: getattr(self, k) for k, v in self._get_data_props().items()}
        return {
            k: self._to_base_type(v)
            for k, v in prop_values.items()
            if v is not None and (not only or k in only) and (not except_ or k not in except_)
        }

    def validate(self, schema=None):
        jsonschema.validate(
            self.to_dict(),
            schema or self._schema,
            types=dict(array=(list, tuple), integer=six.integer_types),
        )

    def __repr__(self):
        return '<{}.{}: {}>'.format(
            self.__module__.split('.')[-1],
            type(self).__name__,
            json.dumps(
                self.to_dict(),
                indent=4,
                default=format_date,
            )
        )

    @staticmethod
    def assert_isinstance(value, field_name, expected, is_array=False):
        if not is_array:
            if not isinstance(value, expected):
                raise TypeError("Expected %s of type %s, got %s" % (field_name, expected, type(value).__name__))
            return

        if not all(isinstance(x, expected) for x in value):
            raise TypeError(
                "Expected %s of type list[%s], got %s" % (
                    field_name,
                    expected,
                    ", ".join(set(type(x).__name__ for x in value)),
                )
            )

    @staticmethod
    def normalize_key(prop_key):
        if keyword.iskeyword(prop_key):
            prop_key += '_'
        return prop_key.replace('.', '__')

    @classmethod
    def from_dict(cls, dct, strict=False):
        """
        Create an instance from a dictionary while ignoring unnecessary keys
        """
        allowed_keys = cls._get_data_props().values()
        invalid_keys = set(dct).difference(allowed_keys)
        if strict and invalid_keys:
            raise ValueError("Invalid keys %s" % tuple(invalid_keys))
        return cls(**{cls.normalize_key(key): value for key, value in dct.items() if key not in invalid_keys})


class UnusedKwargsWarning(UserWarning):
    pass


class NonStrictDataModelMixin(object):
    """
    NonStrictDataModelMixin

    :summary: supplies an __init__ method that warns about unused keywords
    """
    def __init__(self, **kwargs):
        unexpected = [key for key in kwargs if not key.startswith('_')]
        if unexpected:
            message = '{}: unused keyword argument(s) {}' \
                .format(type(self).__name__, unexpected)
            warnings.warn(message, UnusedKwargsWarning)


class NonStrictDataModel(DataModel, NonStrictDataModelMixin):
    pass


class StringEnum(Enum):

    def __str__(self):
        return self.value
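The `schema_property` decorator above stores the schema field name on the property object itself, so `DataModel._get_data_props` can map Python attribute names back to schema keys (which may contain dots or be Python keywords). A condensed, runnable sketch of that mechanism (the `Point` model and its schema key are hypothetical):

```python
class SchemaProperty(property):
    def __init__(self, name=None, *args, **kwargs):
        super(SchemaProperty, self).__init__(*args, **kwargs)
        self.name = name  # schema field name carried on the property object


def schema_property(name):
    def init(*args, **kwargs):
        return SchemaProperty(name, *args, **kwargs)
    return init


class Point(object):
    def __init__(self, x):
        self._x = x

    @schema_property('execution.x')  # schema key differs from the attribute name
    def x(self):
        return self._x


# Collect attribute-name -> schema-name pairs, as DataModel._get_data_props does
props = {k: getattr(v, 'name', k) for k, v in vars(Point).items()
         if isinstance(v, property)}
print(props)  # {'x': 'execution.x'}
```

`to_dict` then uses this mapping in reverse: it reads `self.x` but emits the value under the schema key `execution.x`.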
7
trains/backend_api/session/defs.py
Normal file
@@ -0,0 +1,7 @@
from ...backend_config import EnvEntry


ENV_HOST = EnvEntry("TRAINS_API_HOST", "ALG_API_HOST")
ENV_ACCESS_KEY = EnvEntry("TRAINS_API_ACCESS_KEY", "ALG_API_ACCESS_KEY")
ENV_SECRET_KEY = EnvEntry("TRAINS_API_SECRET_KEY", "ALG_API_SECRET_KEY")
ENV_VERBOSE = EnvEntry("TRAINS_API_VERBOSE", "ALG_API_VERBOSE", type=bool, default=False)
17
trains/backend_api/session/errors.py
Normal file
@@ -0,0 +1,17 @@
class SessionError(Exception):
    pass


class AsyncError(SessionError):
    def __init__(self, msg, *args, **kwargs):
        super(AsyncError, self).__init__(msg, *args)
        for k, v in kwargs.items():
            setattr(self, k, v)


class TimeoutExpiredError(AsyncError):
    pass


class ResultNotReadyError(AsyncError):
    pass
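`AsyncError` attaches any extra keyword arguments as attributes on the exception instance; this is how callers of `CallResult.wait` can read the `call_id` passed when the error is raised. A quick self-contained sketch (the message and id values are made up):

```python
class SessionError(Exception):
    pass


class AsyncError(SessionError):
    def __init__(self, msg, *args, **kwargs):
        super(AsyncError, self).__init__(msg, *args)
        for k, v in kwargs.items():
            setattr(self, k, v)  # e.g. call_id ends up as err.call_id


try:
    raise AsyncError('Timeout expired for call GetAllRequest', call_id='abc123')
except AsyncError as e:
    caught = e
```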
76
trains/backend_api/session/request.py
Normal file
@@ -0,0 +1,76 @@
import abc

import jsonschema
import six

from .apimodel import ApiModel
from .datamodel import DataModel


class Request(ApiModel):
    _method = 'get'

    def __init__(self, **kwargs):
        if kwargs:
            raise ValueError('Unsupported keyword arguments: %s' % ', '.join(kwargs.keys()))


@six.add_metaclass(abc.ABCMeta)
class BatchRequest(Request):

    _batched_request_cls = abc.abstractproperty()

    _schema_errors = (jsonschema.SchemaError, jsonschema.ValidationError, jsonschema.FormatError,
                      jsonschema.RefResolutionError)

    def __init__(self, requests, validate_requests=False, allow_raw_requests=True, **kwargs):
        super(BatchRequest, self).__init__(**kwargs)
        self._validate_requests = validate_requests
        self._allow_raw_requests = allow_raw_requests
        self._property_requests = None
        self.requests = requests

    @property
    def requests(self):
        return self._property_requests

    @requests.setter
    def requests(self, value):
        assert issubclass(self._batched_request_cls, Request)
        assert isinstance(value, (list, tuple))
        if not self._allow_raw_requests:
            if any(isinstance(x, dict) for x in value):
                value = [self._batched_request_cls(**x) if isinstance(x, dict) else x for x in value]
            assert all(isinstance(x, self._batched_request_cls) for x in value)

        self._property_requests = value

    def validate(self):
        if not self._validate_requests or self._allow_raw_requests:
            return
        for i, req in enumerate(self.requests):
            try:
                req.validate()
            except self._schema_errors as e:
                raise Exception('Validation error in batch item #%d: %s' % (i, str(e)))

    def get_json(self):
        return [r if isinstance(r, dict) else r.to_dict() for r in self.requests]


class CompoundRequest(Request):
    _item_prop_name = 'item'

    def _get_item(self):
        item = getattr(self, self._item_prop_name, None)
        if item is None:
            raise ValueError('Item property is empty or missing')
        assert isinstance(item, DataModel)
        return item

    def to_dict(self):
        return self._get_item().to_dict()

    def validate(self):
        return self._get_item().validate(self._schema)
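`BatchRequest.get_json` above serializes each batched item, leaving raw dicts as-is; the session later joins the items into a json-lines payload. A minimal sketch of that serialization (the `AddItemRequest` stand-in is hypothetical, not a real service request):

```python
import json


class AddItemRequest:
    # Hypothetical batched request with a single field
    def __init__(self, name):
        self.name = name

    def to_dict(self):
        return {'name': self.name}


batch = [AddItemRequest('a'), {'name': 'b'}]  # models and raw dicts may be mixed

# Same expression as BatchRequest.get_json
payload = [r if isinstance(r, dict) else r.to_dict() for r in batch]

# The session then builds the application/json-lines body like this
json_lines = "\n".join(json.dumps(x) for x in payload)
```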
49
trains/backend_api/session/response.py
Normal file
@@ -0,0 +1,49 @@
import requests

import jsonmodels.models
import jsonmodels.fields
import jsonmodels.errors

from .apimodel import ApiModel
from .datamodel import NonStrictDataModelMixin


class Response(ApiModel, NonStrictDataModelMixin):
    pass


class _ResponseEndpoint(jsonmodels.models.Base):
    name = jsonmodels.fields.StringField()
    requested_version = jsonmodels.fields.FloatField()
    actual_version = jsonmodels.fields.FloatField()


class ResponseMeta(jsonmodels.models.Base):
    @property
    def is_valid(self):
        return self._is_valid

    @classmethod
    def from_raw_data(cls, status_code, text, endpoint=None):
        return cls(is_valid=False, result_code=status_code, result_subcode=0, result_msg=text,
                   endpoint=_ResponseEndpoint(name=(endpoint or 'unknown')))

    def __init__(self, is_valid=True, **kwargs):
        super(ResponseMeta, self).__init__(**kwargs)
        self._is_valid = is_valid

    id = jsonmodels.fields.StringField(required=True)
    trx = jsonmodels.fields.StringField(required=True)
    endpoint = jsonmodels.fields.EmbeddedField([_ResponseEndpoint], required=True)
    result_code = jsonmodels.fields.IntField(required=True)
    result_subcode = jsonmodels.fields.IntField()
    result_msg = jsonmodels.fields.StringField(required=True)
    error_stack = jsonmodels.fields.StringField()

    def __str__(self):
        if self.result_code == requests.codes.ok:
            return "<%d: %s/v%.1f>" % (self.result_code, self.endpoint.name, self.endpoint.actual_version)
        elif self._is_valid:
            return "<%d/%d: %s/v%.1f (%s)>" % (self.result_code, self.result_subcode, self.endpoint.name,
                                               self.endpoint.actual_version, self.result_msg)
        return "<%d/%d: %s (%s)>" % (self.result_code, self.result_subcode, self.endpoint.name, self.result_msg)
425
trains/backend_api/session/session.py
Normal file
@@ -0,0 +1,425 @@
import json as json_lib
import sys
import types
from socket import gethostname

import requests
import six
from pyhocon import ConfigTree
from requests.auth import HTTPBasicAuth

from .callresult import CallResult
from .defs import ENV_VERBOSE, ENV_HOST, ENV_ACCESS_KEY, ENV_SECRET_KEY
from .request import Request, BatchRequest
from .token_manager import TokenManager
from ..config import load
from ..utils import get_http_session_with_retry
from ..version import __version__


class LoginError(Exception):
    pass


class Session(TokenManager):
    """ TRAINS API Session class. """

    _AUTHORIZATION_HEADER = "Authorization"
    _WORKER_HEADER = "X-Trains-Worker"
    _ASYNC_HEADER = "X-Trains-Async"
    _CLIENT_HEADER = "X-Trains-Client"

    _async_status_code = 202
    _session_requests = 0
    _session_initial_timeout = (1.0, 10)
    _session_timeout = (5.0, None)

    # TODO: add requests.codes.gateway_timeout once we support async commits
    _retry_codes = [
        requests.codes.bad_gateway,
        requests.codes.service_unavailable,
        requests.codes.bandwidth_limit_exceeded,
        requests.codes.too_many_requests,
    ]

    @property
    def access_key(self):
        return self.__access_key

    @property
    def secret_key(self):
        return self.__secret_key

    @property
    def host(self):
        return self.__host

    @property
    def worker(self):
        return self.__worker

    def __init__(
        self,
        worker=None,
        api_key=None,
        secret_key=None,
        host=None,
        logger=None,
        verbose=None,
        initialize_logging=True,
        client=None,
        config=None,
        **kwargs
    ):

        if config is not None:
            self.config = config
        else:
            self.config = load()
        if initialize_logging:
            self.config.initialize_logging()

        token_expiration_threshold_sec = self.config.get(
            "auth.token_expiration_threshold_sec", 60
        )

        super(Session, self).__init__(
            token_expiration_threshold_sec=token_expiration_threshold_sec, **kwargs
        )

        self._verbose = verbose if verbose is not None else ENV_VERBOSE.get()
        self._logger = logger

        self.__access_key = api_key or ENV_ACCESS_KEY.get(
            default=(self.config.get("api.credentials.access_key", None) or
                     "EGRTCO8JMSIGI6S39GTP43NFWXDQOW")
        )
        if not self.access_key:
            raise ValueError(
                "Missing access_key. Please set in configuration file or pass in session init."
            )

        self.__secret_key = secret_key or ENV_SECRET_KEY.get(
            default=(self.config.get("api.credentials.secret_key", None) or
                     "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8")
        )
        if not self.secret_key:
            raise ValueError(
                "Missing secret_key. Please set in configuration file or pass in session init."
            )

        host = host or ENV_HOST.get(default=self.config.get("api.host"))
        if not host:
            raise ValueError("host is required in init or config")

        self.__host = host.strip("/")
        http_retries_config = self.config.get(
            "api.http.retries", ConfigTree()
        ).as_plain_ordered_dict()
        http_retries_config["status_forcelist"] = self._retry_codes
        self.__http_session = get_http_session_with_retry(**http_retries_config)

        self.__worker = worker or gethostname()

        self.__max_req_size = self.config.get("api.http.max_req_size")
        if not self.__max_req_size:
            raise ValueError("missing max request size")

        self.client = client or "api-{}".format(__version__)

        self.refresh_token()

    def _send_request(
        self,
        service,
        action,
        version=None,
        method="get",
        headers=None,
        auth=None,
        data=None,
        json=None,
        refresh_token_if_unauthorized=True,
    ):
        """ Internal implementation for making a raw API request.
            - Constructs the api endpoint name
            - Injects the worker id into the headers
            - Allows custom authorization using a requests auth object
            - Intercepts `Unauthorized` responses and automatically attempts to refresh the session token in this
              case (only once). This is done since permissions are embedded in the token, and this addresses a case
              where server-side permissions have changed but are not reflected in the current token. Refreshing the
              token will generate a token with the updated permissions.
        """
        host = self.host
        headers = headers.copy() if headers else {}
        headers[self._WORKER_HEADER] = self.worker
        headers[self._CLIENT_HEADER] = self.client

        token_refreshed_on_error = False
        url = (
            "{host}/v{version}/{service}.{action}"
            if version
            else "{host}/{service}.{action}"
        ).format(**locals())
        while True:
            res = self.__http_session.request(
                method, url, headers=headers, auth=auth, data=data, json=json,
                timeout=self._session_initial_timeout if self._session_requests < 1 else self._session_timeout,
            )
            if (
                refresh_token_if_unauthorized
                and res.status_code == requests.codes.unauthorized
                and not token_refreshed_on_error
            ):
                # it seems we're unauthorized, so we'll try to refresh our token once in case permissions changed
                # since the last time we got the token, and try again
                self.refresh_token()
                token_refreshed_on_error = True
                # try again
                continue
            if (
                res.status_code == requests.codes.service_unavailable
                and self.config.get("api.http.wait_on_maintenance_forever", True)
            ):
                if self._logger:
                    self._logger.warn(
                        "Service unavailable: {} is undergoing maintenance, retrying...".format(
                            host
                        )
                    )
                continue
            break
        self._session_requests += 1
        return res

    def send_request(
        self,
        service,
        action,
        version=None,
        method="get",
        headers=None,
        data=None,
        json=None,
        async_enable=False,
    ):
        """
        Send a raw API request.
        :param service: service name
        :param action: action name
        :param version: version number (default is the preconfigured api version)
        :param method: method type (default is 'get')
        :param headers: request headers (authorization and content type headers will be automatically added)
        :param json: json to send in the request body (jsonable object or builtin types construct. if used,
            content type will be application/json)
        :param data: Dictionary, bytes, or file-like object to send in the request body
        :param async_enable: whether request is asynchronous
        :return: requests Response instance
        """
        headers = headers.copy() if headers else {}
        headers[self._AUTHORIZATION_HEADER] = "Bearer {}".format(self.token)
        if async_enable:
            headers[self._ASYNC_HEADER] = "1"
        return self._send_request(
            service=service,
            action=action,
            version=version,
            method=method,
            headers=headers,
            data=data,
            json=json,
        )

    def send_request_batch(
        self,
        service,
        action,
        version=None,
        headers=None,
        data=None,
        json=None,
        method="get",
    ):
        """
        Send a raw batch API request. Batch requests always use application/json-lines content type.
        :param service: service name
        :param action: action name
        :param version: version number (default is the preconfigured api version)
        :param headers: request headers (authorization and content type headers will be automatically added)
        :param json: iterable of json items (batched items, jsonable objects or builtin types constructs). These
            will be sent as a multi-line payload in the request body.
        :param data: iterable of bytes objects (batched items). These will be sent as a multi-line payload in the
            request body.
        :param method: HTTP method
        :return: requests Response instance
        """
        if not all(
            isinstance(x, (list, tuple, type(None), types.GeneratorType))
            for x in (data, json)
        ):
            raise ValueError("Expecting list, tuple or generator in 'data' or 'json'")

        if not data and not json:
            raise ValueError(
                "Missing data (data or json), batch requests are meaningless without it."
            )

        headers = headers.copy() if headers else {}
        headers["Content-Type"] = "application/json-lines"

        if data:
            req_data = "\n".join(data)
        else:
            req_data = "\n".join(json_lib.dumps(x) for x in json)

        cur = 0
        results = []
        while True:
            size = self.__max_req_size
            slice = req_data[cur:cur + size]
            if not slice:
                break
            if len(slice) < size:
                # this is the remainder, no need to search for a newline
                pass
            elif slice[-1] != "\n":
                # search for the last newline in order to send a coherent request
                size = slice.rfind("\n") + 1
                # readjust the slice
                slice = req_data[cur:cur + size]
            res = self.send_request(
                method=method,
                service=service,
                action=action,
                data=slice,
                headers=headers,
                version=version,
            )
            results.append(res)
            if res.status_code != requests.codes.ok:
                break
            cur += size
        return results
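`send_request_batch` above splits the json-lines payload into chunks no larger than `max_req_size`, backing up to the last newline in each chunk so that no record is cut in half across requests. The slicing logic in isolation, with a hypothetical 20-byte limit and made-up records:

```python
# Six 8-byte json-lines records joined with newlines
req_data = "\n".join('{"n": %d}' % i for i in range(6))
max_req_size = 20  # hypothetical limit; the real value comes from api.http.max_req_size

cur = 0
chunks = []
while True:
    size = max_req_size
    part = req_data[cur:cur + size]
    if not part:
        break
    if len(part) >= size and part[-1] != "\n":
        # back up to the last newline so a record is never split across requests
        size = part.rfind("\n") + 1
        part = req_data[cur:cur + size]
    chunks.append(part)
    cur += size
```

Note that, as in the method above, a single record longer than the size limit is not handled (the backtrack would find no newline); the sketch assumes records fit within the limit.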

    def validate_request(self, req_obj):
        """ Validate an API request against the current version and the request's schema """

        try:
            # make sure we're using a compatible version for this request
            # validate the request (checks required fields and specific field version restrictions)
            validate = req_obj.validate
        except AttributeError:
            raise TypeError(
                '"req_obj" parameter must be a backend_api.session.Request object'
            )

        validate()

    def send_async(self, req_obj):
        """
        Asynchronously sends an API request using a request object.
        :param req_obj: The request object
        :type req_obj: Request
        :return: CallResult object containing the raw response, response metadata and parsed response object.
        """
        return self.send(req_obj=req_obj, async_enable=True)

    def send(self, req_obj, async_enable=False, headers=None):
        """
        Sends an API request using a request object.
        :param req_obj: The request object
        :type req_obj: Request
        :param async_enable: Request this method be executed in an asynchronous manner
        :param headers: Additional headers to send with request
        :return: CallResult object containing the raw response, response metadata and parsed response object.
        """
        self.validate_request(req_obj)

        if isinstance(req_obj, BatchRequest):
            # TODO: support async for batch requests as well
            if async_enable:
                raise NotImplementedError(
                    "Async behavior is currently not implemented for batch requests"
                )

            json_data = req_obj.get_json()
            res = self.send_request_batch(
                service=req_obj._service,
                action=req_obj._action,
                version=req_obj._version,
                json=json_data,
                method=req_obj._method,
                headers=headers,
            )
            # TODO: handle multiple results in this case
            try:
                res = next(r for r in res if r.status_code != 200)
            except StopIteration:
                # all are 200
                res = res[0]
        else:
            res = self.send_request(
                service=req_obj._service,
                action=req_obj._action,
                version=req_obj._version,
                json=req_obj.to_dict(),
                method=req_obj._method,
                async_enable=async_enable,
                headers=headers,
            )

        call_result = CallResult.from_result(
            res=res,
            request_cls=req_obj.__class__,
            logger=self._logger,
            service=req_obj._service,
            action=req_obj._action,
            session=self,
        )

        return call_result

    def _do_refresh_token(self, old_token, exp=None):
        """ TokenManager abstract method implementation.
            Here we ignore the old token and simply obtain a new token.
        """
        verbose = self._verbose and self._logger
        if verbose:
            self._logger.info(
                "Refreshing token from {} (access_key={}, exp={})".format(
                    self.host, self.access_key, exp
                )
            )

        auth = HTTPBasicAuth(self.access_key, self.secret_key)
|
||||
try:
|
||||
data = {"expiration_sec": exp} if exp else {}
|
||||
res = self._send_request(
|
||||
service="auth",
|
||||
action="login",
|
||||
auth=auth,
|
||||
json=data,
|
||||
refresh_token_if_unauthorized=False,
|
||||
)
|
||||
try:
|
||||
resp = res.json()
|
||||
except ValueError:
|
||||
resp = {}
|
||||
if res.status_code != 200:
|
||||
msg = resp.get("meta", {}).get("result_msg", res.reason)
|
||||
raise LoginError(
|
||||
"Failed getting token (error {} from {}): {}".format(
|
||||
res.status_code, self.host, msg
|
||||
)
|
||||
)
|
||||
if verbose:
|
||||
self._logger.info("Received new token")
|
||||
return resp["data"]["token"]
|
||||
except LoginError:
|
||||
six.reraise(*sys.exc_info())
|
||||
except Exception as ex:
|
||||
raise LoginError(str(ex))
|
||||
|
||||
def __str__(self):
|
||||
return "{self.__class__.__name__}[{self.host}, {self.access_key}/{secret_key}]".format(
|
||||
self=self, secret_key=self.secret_key[:5] + "*" * (len(self.secret_key) - 5)
|
||||
)
|
||||
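The chunking loop in `send_request_batch` above splits a newline-delimited payload into pieces of at most `__max_req_size` bytes, backing up to the last newline so no individual record is ever cut in half. A standalone sketch of the same idea (the function name is illustrative, not part of the API; it assumes each record plus its newline fits within `max_size`):

```python
def split_on_newlines(payload, max_size):
    """Split a newline-delimited payload into chunks of at most max_size
    characters, cutting only on newline boundaries (except the final
    remainder, which needs no trailing newline)."""
    chunks = []
    cur = 0
    while True:
        size = max_size
        chunk = payload[cur:cur + size]
        if not chunk:
            break
        if len(chunk) == size and chunk[-1] != "\n":
            # Full-sized chunk that cuts a record in half: back up to the
            # last newline so the chunk ends on a record boundary.
            size = chunk.rfind("\n") + 1
            chunk = payload[cur:cur + size]
        chunks.append(chunk)
        cur += size
    return chunks


records = "\n".join('{"id": %d}' % i for i in range(5))
parts = split_on_newlines(records, max_size=25)
assert "".join(parts) == records  # nothing is lost or reordered
```

Each chunk can then be sent as its own request, exactly as the loop above does with `send_request`.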
trains/backend_api/session/token_manager.py (new file, 95 lines)
@@ -0,0 +1,95 @@
import sys
from abc import ABCMeta, abstractmethod
from time import time

import jwt
import six


@six.add_metaclass(ABCMeta)
class TokenManager(object):

    @property
    def token_expiration_threshold_sec(self):
        return self.__token_expiration_threshold_sec

    @token_expiration_threshold_sec.setter
    def token_expiration_threshold_sec(self, value):
        self.__token_expiration_threshold_sec = value

    @property
    def req_token_expiration_sec(self):
        """ Token expiration sec requested when refreshing token """
        return self.__req_token_expiration_sec

    @req_token_expiration_sec.setter
    def req_token_expiration_sec(self, value):
        assert isinstance(value, (type(None), int))
        self.__req_token_expiration_sec = value

    @property
    def token_expiration_sec(self):
        return self.__token_expiration_sec

    @property
    def token(self):
        return self._get_token()

    @property
    def raw_token(self):
        return self.__token

    def __init__(
        self,
        token=None,
        req_token_expiration_sec=None,
        token_history=None,
        token_expiration_threshold_sec=60,
        **kwargs
    ):
        super(TokenManager, self).__init__()
        assert isinstance(token_history, (type(None), dict))
        self.token_expiration_threshold_sec = token_expiration_threshold_sec
        self.req_token_expiration_sec = req_token_expiration_sec
        self._set_token(token)

    def _calc_token_valid_period_sec(self, token, exp=None, at_least_sec=None):
        if token:
            try:
                exp = exp or self._get_token_exp(token)
                if at_least_sec:
                    at_least_sec = max(at_least_sec, self.token_expiration_threshold_sec)
                else:
                    at_least_sec = self.token_expiration_threshold_sec
                return max(0, (exp - time() - at_least_sec))
            except Exception:
                pass
        return 0

    @classmethod
    def _get_token_exp(cls, token):
        """ Get token expiration time. If not present, assume forever """
        return jwt.decode(token, verify=False).get('exp', sys.maxsize)

    def _set_token(self, token):
        if token:
            self.__token = token
            self.__token_expiration_sec = self._get_token_exp(token)
        else:
            self.__token = None
            self.__token_expiration_sec = 0

    def get_token_valid_period_sec(self):
        return self._calc_token_valid_period_sec(self.__token, self.token_expiration_sec)

    def _get_token(self):
        if self.get_token_valid_period_sec() <= 0:
            self.refresh_token()
        return self.__token

    @abstractmethod
    def _do_refresh_token(self, old_token, exp=None):
        pass

    def refresh_token(self):
        self._set_token(self._do_refresh_token(self.__token, exp=self.req_token_expiration_sec))
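The refresh policy in `TokenManager` above boils down to one calculation: a token is still usable for `exp - now - threshold` seconds, where the threshold is a safety margin that forces a refresh shortly *before* the JWT's actual `exp` claim. A minimal, deterministic sketch of that calculation (the function name and the fixed `now` argument are illustrative additions for testability):

```python
import time


def token_valid_period_sec(exp, threshold_sec=60, now=None):
    """Seconds a token may still be used, keeping a safety threshold
    before the real 'exp' timestamp so refresh happens early."""
    now = time.time() if now is None else now
    return max(0, exp - now - threshold_sec)


# A token expiring at t=1000, checked at t=400 with a 60s threshold,
# is considered usable for another 540 seconds.
assert token_valid_period_sec(exp=1000, threshold_sec=60, now=400) == 540
# Inside the threshold window the token counts as expired -> refresh now.
assert token_valid_period_sec(exp=1000, threshold_sec=60, now=950) == 0
```

When this hits zero, `_get_token` calls `refresh_token`, which delegates to the abstract `_do_refresh_token` (implemented by the session via the `auth.login` endpoint shown earlier).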
trains/backend_api/utils.py (new file, 86 lines)
@@ -0,0 +1,86 @@
import ssl
import sys

import requests
from requests.adapters import HTTPAdapter
# from requests.packages.urllib3.util.retry import Retry
from urllib3.util import Retry
from urllib3 import PoolManager
import six

if six.PY3:
    from functools import lru_cache
elif six.PY2:
    # python 2 support
    from backports.functools_lru_cache import lru_cache


@lru_cache()
def get_config():
    from ..backend_config import Config
    config = Config(verbose=False)
    config.reload()
    return config


class TLSv1HTTPAdapter(HTTPAdapter):
    def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs):
        self.poolmanager = PoolManager(num_pools=connections,
                                       maxsize=maxsize,
                                       block=block,
                                       ssl_version=ssl.PROTOCOL_TLSv1_2)


def get_http_session_with_retry(
        total=0,
        connect=None,
        read=None,
        redirect=None,
        status=None,
        status_forcelist=None,
        backoff_factor=0,
        backoff_max=None,
        pool_connections=None,
        pool_maxsize=None):
    if not all(isinstance(x, (int, type(None))) for x in (total, connect, read, redirect, status)):
        raise ValueError('Bad configuration. All retry count values must be null or int')

    if status_forcelist and not all(isinstance(x, int) for x in status_forcelist):
        raise ValueError('Bad configuration. Retry status_forcelist must be null or list of ints')

    pool_maxsize = (
        pool_maxsize
        if pool_maxsize is not None
        else get_config().get('api.http.pool_maxsize', 512)
    )

    pool_connections = (
        pool_connections
        if pool_connections is not None
        else get_config().get('api.http.pool_connections', 512)
    )

    session = requests.Session()

    if backoff_max is not None:
        Retry.BACKOFF_MAX = backoff_max

    retry = Retry(
        total=total, connect=connect, read=read, redirect=redirect, status=status,
        status_forcelist=status_forcelist, backoff_factor=backoff_factor)

    adapter = TLSv1HTTPAdapter(max_retries=retry, pool_connections=pool_connections, pool_maxsize=pool_maxsize)
    session.mount('http://', adapter)
    session.mount('https://', adapter)
    return session


def get_response_cls(request_cls):
    """ Extract a request's response class using the mapping found in the module defining the request's service """
    for req_cls in request_cls.mro():
        module = sys.modules[req_cls.__module__]
        if hasattr(module, 'action_mapping'):
            return module.action_mapping[(request_cls._action, request_cls._version)][1]
        elif hasattr(module, 'response_mapping'):
            return module.response_mapping[req_cls]
    raise TypeError('no response class!')
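The `Retry` object configured in `get_http_session_with_retry` above makes urllib3 sleep between attempts following an exponential-backoff schedule derived from `backoff_factor` and capped by `BACKOFF_MAX`. Roughly, the n-th retry waits `backoff_factor * 2**(n-1)` seconds (exact first-interval behavior varies between urllib3 versions, so treat this as an approximation). A pure-Python sketch of that schedule (the helper name is illustrative):

```python
def backoff_delays(backoff_factor, retries, backoff_max=120):
    """Approximate urllib3-style sleep intervals between retry attempts:
    backoff_factor * 2**(n-1) for retry n, capped at backoff_max."""
    return [
        min(backoff_max, backoff_factor * (2 ** (n - 1)))
        for n in range(1, retries + 1)
    ]


# With backoff_factor=0.5 and 5 retries, the waits grow geometrically.
assert backoff_delays(0.5, 5) == [0.5, 1.0, 2.0, 4.0, 8.0]
```

This is why the function exposes both `backoff_factor` (the growth rate) and `backoff_max` (the ceiling, applied here by patching `Retry.BACKOFF_MAX`).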
trains/backend_api/version.py (new file, 1 line)
@@ -0,0 +1 @@
__version__ = '2.0.0'
trains/backend_config/__init__.py (new file, 4 lines)
@@ -0,0 +1,4 @@
from .defs import Environment
from .config import Config, ConfigEntry
from .errors import ConfigurationError
from .environment import EnvEntry
trains/backend_config/bucket_config.py (new file, 291 lines)
@@ -0,0 +1,291 @@
import abc
import warnings
from operator import itemgetter

import furl
import six
from attr import attrib, attrs


def _none_to_empty_string(maybe_string):
    return maybe_string if maybe_string is not None else ""


def _url_stripper(bucket):
    bucket = _none_to_empty_string(bucket)
    bucket = bucket.strip("\"'").rstrip("/")
    return bucket


@attrs
class S3BucketConfig(object):
    bucket = attrib(type=str, converter=_url_stripper, default="")
    host = attrib(type=str, converter=_none_to_empty_string, default="")
    key = attrib(type=str, converter=_none_to_empty_string, default="")
    secret = attrib(type=str, converter=_none_to_empty_string, default="")
    multipart = attrib(type=bool, default=True)
    acl = attrib(type=str, converter=_none_to_empty_string, default="")
    secure = attrib(type=bool, default=True)
    region = attrib(type=str, converter=_none_to_empty_string, default="")

    def update(self, key, secret, multipart=True, region=None):
        self.key = key
        self.secret = secret
        self.multipart = multipart
        self.region = region

    def is_valid(self):
        return self.key and self.secret

    def get_bucket_host(self):
        return self.bucket, self.host

    @classmethod
    def from_list(cls, dict_list, log=None):
        if not isinstance(dict_list, (tuple, list)) or not all(
            isinstance(x, dict) for x in dict_list
        ):
            raise ValueError("Expecting a list of configuration dictionaries")
        configs = [cls(**entry) for entry in dict_list]
        valid_configs = [conf for conf in configs if conf.is_valid()]
        if log and len(valid_configs) < len(configs):
            log.warn(
                "Invalid bucket configurations detected for {}".format(
                    ", ".join(
                        "/".join((config.host, config.bucket))
                        for config in configs
                        if config not in valid_configs
                    )
                )
            )
        return configs


BucketConfig = S3BucketConfig


@six.add_metaclass(abc.ABCMeta)
class BaseBucketConfigurations(object):
    def __init__(self, buckets=None, *_, **__):
        self._buckets = buckets or []
        self._prefixes = None

    def _update_prefixes(self, refresh=True):
        if self._prefixes and not refresh:
            return
        prefixes = (
            (config, self._get_prefix_from_bucket_config(config))
            for config in self._buckets
        )
        self._prefixes = sorted(prefixes, key=itemgetter(1), reverse=True)

    @abc.abstractmethod
    def _get_prefix_from_bucket_config(self, config):
        pass


class S3BucketConfigurations(BaseBucketConfigurations):
    def __init__(
        self, buckets=None, default_key="", default_secret="", default_region=""
    ):
        super(S3BucketConfigurations, self).__init__()
        self._buckets = buckets if buckets else list()
        self._default_key = default_key
        self._default_secret = default_secret
        self._default_region = default_region
        self._default_multipart = True

    @classmethod
    def from_config(cls, s3_configuration):
        config_list = S3BucketConfig.from_list(
            s3_configuration.get("credentials", default=None)
        )

        default_key = s3_configuration.get("key", default="")
        default_secret = s3_configuration.get("secret", default="")
        default_region = s3_configuration.get("region", default="")

        default_key = _none_to_empty_string(default_key)
        default_secret = _none_to_empty_string(default_secret)
        default_region = _none_to_empty_string(default_region)

        return cls(config_list, default_key, default_secret, default_region)

    def add_config(self, bucket_config):
        self._buckets.insert(0, bucket_config)
        self._prefixes = None

    def remove_config(self, bucket_config):
        self._buckets.remove(bucket_config)
        self._prefixes = None

    def get_config_by_bucket(self, bucket, host=None):
        try:
            return next(
                bucket_config
                for bucket_config in self._buckets
                if (bucket, host) == bucket_config.get_bucket_host()
            )
        except StopIteration:
            pass

        return None

    def update_config_with_defaults(self, bucket_config):
        bucket_config.update(
            key=self._default_key,
            secret=self._default_secret,
            region=bucket_config.region or self._default_region,
            multipart=bucket_config.multipart or self._default_multipart,
        )

    def _get_prefix_from_bucket_config(self, config):
        scheme = "s3"
        prefix = furl.furl()

        if config.host:
            prefix.set(
                scheme=scheme,
                netloc=config.host.lower(),
                path=config.bucket.lower() if config.bucket else "",
            )
        else:
            prefix.set(scheme=scheme, path=config.bucket.lower())
            bucket = prefix.path.segments[0]
            prefix.path.segments.pop(0)
            prefix.set(netloc=bucket)

        return str(prefix)

    def get_config_by_uri(self, uri):
        """
        Get the credentials for an AWS S3 bucket from the config
        :param uri: URI of bucket, directory or file
        :return: bucket config
        :rtype: S3BucketConfig
        """

        def find_match(uri):
            self._update_prefixes(refresh=False)
            uri = uri.lower()
            res = (
                config
                for config, prefix in self._prefixes
                if prefix is not None and uri.startswith(prefix)
            )

            try:
                return next(res)
            except StopIteration:
                return None

        match = find_match(uri)

        if match:
            return match

        parsed = furl.furl(uri)

        if parsed.port:
            host = parsed.netloc
            parts = parsed.path.segments
            bucket = parts[0] if parts else None
        else:
            host = None
            bucket = parsed.netloc

        return S3BucketConfig(
            key=self._default_key,
            secret=self._default_secret,
            region=self._default_region,
            multipart=True,
            bucket=bucket,
            host=host,
        )


BucketConfigurations = S3BucketConfigurations


@attrs
class GSBucketConfig(object):
    bucket = attrib(type=str)
    subdir = attrib(type=str, converter=_url_stripper, default="")
    project = attrib(type=str, default=None)
    credentials_json = attrib(type=str, default=None)

    def update(self, **kwargs):
        for item in kwargs:
            if not hasattr(self, item):
                warnings.warn("Unexpected argument {} for update. Ignored".format(item))
            else:
                setattr(self, item, kwargs[item])


class GSBucketConfigurations(BaseBucketConfigurations):
    def __init__(self, buckets=None, default_project=None, default_credentials=None):
        super(GSBucketConfigurations, self).__init__(buckets)
        self._default_project = default_project
        self._default_credentials = default_credentials

        self._update_prefixes()

    @classmethod
    def from_config(cls, gs_configuration):
        if gs_configuration is None:
            return cls()

        config_list = gs_configuration.get("credentials", default=list())
        buckets_configs = [GSBucketConfig(**entry) for entry in config_list]

        default_project = gs_configuration.get("project", default=None)
        default_credentials = gs_configuration.get("credentials_json", default=None)

        return cls(buckets_configs, default_project, default_credentials)

    def add_config(self, bucket_config):
        self._buckets.insert(0, bucket_config)
        self._update_prefixes()

    def remove_config(self, bucket_config):
        self._buckets.remove(bucket_config)
        self._update_prefixes()

    def update_config_with_defaults(self, bucket_config):
        bucket_config.update(
            project=bucket_config.project or self._default_project,
            credentials_json=bucket_config.credentials_json
            or self._default_credentials,
        )

    def get_config_by_uri(self, uri):
        """
        Get the credentials for a Google Storage bucket from the config
        :param uri: URI of bucket, directory or file
        :return: bucket config
        :rtype: GSBucketConfig
        """

        res = (
            config
            for config, prefix in self._prefixes
            if prefix is not None and uri.lower().startswith(prefix)
        )

        try:
            return next(res)
        except StopIteration:
            pass

        parsed = furl.furl(uri)

        return GSBucketConfig(
            bucket=parsed.netloc,
            subdir=str(parsed.path),
            project=self._default_project,
            credentials_json=self._default_credentials,
        )

    def _get_prefix_from_bucket_config(self, config):
        prefix = furl.furl(scheme="gs", netloc=config.bucket, path=config.subdir)
        return str(prefix)
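The lookup in `get_config_by_uri` above works by building a prefix per bucket configuration and returning the first prefix the (lower-cased) URI starts with; the table is sorted in reverse order so that, when one prefix extends another, the more specific entry wins. A simplified sketch of that matching idea (here sorting by length rather than reverse string order, which has the same effect whenever one prefix extends another; all names are illustrative):

```python
def build_prefix_table(prefixes):
    """Order candidate prefixes longest-first so the most specific match wins."""
    return sorted(prefixes, key=len, reverse=True)


def match_uri(uri, table):
    """Return the first (most specific) prefix that the URI starts with."""
    uri = uri.lower()
    return next((p for p in table if uri.startswith(p)), None)


table = build_prefix_table(["s3://my-bucket", "s3://my-bucket/team-a"])
assert match_uri("s3://my-bucket/team-a/model.bin", table) == "s3://my-bucket/team-a"
assert match_uri("s3://my-bucket/other/model.bin", table) == "s3://my-bucket"
assert match_uri("gs://elsewhere/model.bin", table) is None
```

When no prefix matches, the real code falls back to constructing a config from the parsed URI plus the default key/secret/region, as shown above.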
trains/backend_config/config.py (new file, 412 lines)
@@ -0,0 +1,412 @@
from __future__ import print_function

import functools
import json
import os
import sys
import warnings
from fnmatch import fnmatch
from logging import Logger
from os.path import expanduser
from typing import Any, Text

import pyhocon
import six
from pathlib2 import Path
from pyhocon import ConfigTree
from pyparsing import (
    ParseFatalException,
    ParseException,
    RecursiveGrammarException,
    ParseSyntaxException,
)
from six.moves.urllib.parse import urlparse
from watchdog.observers import Observer

from .bucket_config import S3BucketConfig
from .defs import (
    Environment,
    DEFAULT_CONFIG_FOLDER,
    LOCAL_CONFIG_PATHS,
    ENV_CONFIG_PATHS,
    LOCAL_CONFIG_FILES,
    LOCAL_CONFIG_FILE_OVERRIDE_VAR,
    ENV_CONFIG_PATH_OVERRIDE_VAR,
)
from .defs import is_config_file
from .entry import Entry, NotSet
from .errors import ConfigurationError
from .log import initialize as initialize_log, logger
from .reloader import ConfigReloader
from .utils import get_options

log = logger(__file__)


class ConfigEntry(Entry):
    logger = None

    def __init__(self, config, *keys, **kwargs):
        # type: (Config, Text, Any) -> None
        super(ConfigEntry, self).__init__(*keys, **kwargs)
        self.config = config

    def _get(self, key):
        # type: (Text) -> Any
        return self.config.get(key, NotSet)

    def error(self, message):
        # type: (Text) -> None
        log.error(message.capitalize())


class Config(object):
    """
    Represents a server configuration.
    If watch=True, will watch configuration folders for changes and reload itself.
    NOTE: will not watch folders that were created after initialization.
    """

    # used in place of None in Config.get as default value because None is a valid value
    _MISSING = object()

    def __init__(
        self,
        config_folder=None,
        env=None,
        verbose=True,
        relative_to=None,
        app=None,
        watch=False,
        is_server=False,
        **_
    ):
        self._app = app
        self._verbose = verbose
        self._folder_name = config_folder or DEFAULT_CONFIG_FOLDER
        self._roots = []
        self._config = ConfigTree()
        self._env = env or os.environ.get("TRAINS_ENV", Environment.default)
        self.config_paths = set()
        self.watch = watch
        self.is_server = is_server
        if watch:
            self.observer = Observer()
            self.observer.start()
            self.handler = ConfigReloader(self)

        if self._verbose:
            print("Config env:%s" % str(self._env))

        if not self._env:
            raise ValueError(
                "Missing environment in either init or environment variable"
            )
        if self._env not in get_options(Environment):
            raise ValueError("Invalid environment %s" % env)
        if relative_to is not None:
            self.load_relative_to(relative_to)

    @property
    def root(self):
        return self.roots[0] if self.roots else None

    @property
    def roots(self):
        return self._roots

    @roots.setter
    def roots(self, value):
        self._roots = value

    @property
    def env(self):
        return self._env

    def logger(self, path=None):
        return logger(path)

    def load_relative_to(self, *module_paths):
        def normalize(p):
            return Path(os.path.abspath(str(p))).with_name(self._folder_name)

        self.roots = list(map(normalize, module_paths))
        self.reload()
        if self.watch:
            for path in self.config_paths:
                self.observer.schedule(self.handler, str(path), recursive=True)

    def _reload(self):
        env = self._env
        config = self._config.copy()

        if self.is_server:
            env_config_paths = ENV_CONFIG_PATHS
        else:
            env_config_paths = []

        env_config_path_override = os.environ.get(ENV_CONFIG_PATH_OVERRIDE_VAR)
        if env_config_path_override:
            env_config_paths = [expanduser(env_config_path_override)]

        # merge configuration from root and other environment config paths
        config = functools.reduce(
            lambda cfg, path: ConfigTree.merge_configs(
                cfg,
                self._read_recursive_for_env(path, env, verbose=self._verbose),
                copy_trees=True,
            ),
            self.roots + env_config_paths,
            config,
        )

        # merge configuration from local configuration paths
        config = functools.reduce(
            lambda cfg, path: ConfigTree.merge_configs(
                cfg, self._read_recursive(path, verbose=self._verbose), copy_trees=True
            ),
            LOCAL_CONFIG_PATHS,
            config,
        )

        local_config_files = LOCAL_CONFIG_FILES
        local_config_override = os.environ.get(LOCAL_CONFIG_FILE_OVERRIDE_VAR)
        if local_config_override:
            local_config_files = [expanduser(local_config_override)]

        # merge configuration from local configuration files
        config = functools.reduce(
            lambda cfg, file_path: ConfigTree.merge_configs(
                cfg,
                self._read_single_file(file_path, verbose=self._verbose),
                copy_trees=True,
            ),
            local_config_files,
            config,
        )

        config["env"] = env
        return config

    def replace(self, config):
        self._config = config

    def reload(self):
        self.replace(self._reload())

    def initialize_logging(self):
        logging_config = self._config.get("logging", None)
        if not logging_config:
            return False

        # handle incomplete file handlers
        deleted = []
        handlers = logging_config.get("handlers", {})
        for name, handler in list(handlers.items()):
            cls = handler.get("class", None)
            is_file = cls and "FileHandler" in cls
            if cls is None or (is_file and "filename" not in handler):
                deleted.append(name)
                del handlers[name]
            elif is_file:
                file = Path(handler.get("filename"))
                if not file.is_file():
                    file.parent.mkdir(parents=True, exist_ok=True)
                    file.touch()

        # remove dependency in deleted handlers
        root_logger = logging_config.get("root", None)
        loggers = list(logging_config.get("loggers", {}).values()) + (
            [root_logger] if root_logger else []
        )
        for logger in loggers:
            handlers = logger.get("handlers", None)
            if not handlers:
                continue
            logger["handlers"] = [h for h in handlers if h not in deleted]

        extra = None
        if self._app:
            extra = {"app": self._app}
        initialize_log(logging_config, extra=extra)
        return True

    def __getitem__(self, key):
        return self._config[key]

    def get(self, key, default=_MISSING):
        value = self._config.get(key, default)
        if value is self._MISSING and not default:
            raise KeyError(
                "Unable to find value for key '{}' and default value was not provided.".format(
                    key
                )
            )
        return value

    def to_dict(self):
        return self._config.as_plain_ordered_dict()

    def as_json(self):
        return json.dumps(self.to_dict(), indent=2)

    def _read_recursive_for_env(self, root_path_str, env, verbose=True):
        root_path = Path(root_path_str)
        if root_path.exists():
            default_config = self._read_recursive(
                root_path / Environment.default, verbose=verbose
            )
            env_config = self._read_recursive(
                root_path / env, verbose=verbose
            )  # None is ok, will return empty config
            config = ConfigTree.merge_configs(default_config, env_config, True)
        else:
            config = ConfigTree()

        return config

    def _read_recursive(self, conf_root, verbose=True):
        conf = ConfigTree()
        if not conf_root:
            return conf
        conf_root = Path(conf_root)

        if not conf_root.exists():
            if verbose:
                print("No config in %s" % str(conf_root))
            return conf

        if self.watch:
            self.config_paths.add(conf_root)

        if verbose:
            print("Loading config from %s" % str(conf_root))
        for root, dirs, files in os.walk(str(conf_root)):

            rel_dir = str(Path(root).relative_to(conf_root))
            if rel_dir == ".":
                rel_dir = ""
            prefix = rel_dir.replace("/", ".")

            for filename in files:
                if not is_config_file(filename):
                    continue

                if prefix != "":
                    key = prefix + "." + Path(filename).stem
                else:
                    key = Path(filename).stem

                file_path = str(Path(root) / filename)

                conf.put(key, self._read_single_file(file_path, verbose=verbose))

        return conf

    @staticmethod
    def _read_single_file(file_path, verbose=True):
        if not file_path or not Path(file_path).is_file():
            return ConfigTree()

        if verbose:
            print("Loading config from file %s" % file_path)

        try:
            return pyhocon.ConfigFactory.parse_file(file_path)
        except ParseSyntaxException as ex:
            msg = "Failed parsing {0} ({1.__class__.__name__}): (at char {1.loc}, line:{1.lineno}, col:{1.column})".format(
                file_path, ex
            )
            six.reraise(
                ConfigurationError,
                ConfigurationError(msg, file_path=file_path),
                sys.exc_info()[2],
            )
        except (ParseException, ParseFatalException, RecursiveGrammarException) as ex:
            msg = "Failed parsing {0} ({1.__class__.__name__}): {1}".format(
                file_path, ex
            )
            six.reraise(ConfigurationError, ConfigurationError(msg), sys.exc_info()[2])
        except Exception as ex:
            print("Failed loading %s: %s" % (file_path, ex))
            raise

    def get_config_for_bucket(self, base_url, extra_configurations=None):
        """
        Get the credentials for an AWS S3 bucket from the config
        :param base_url: URL of bucket
        :param extra_configurations:
        :return: bucket config
        :rtype: bucket config
        """

        warnings.warn(
            "Use backend_config.bucket_config.BucketList.get_config_for_uri",
            DeprecationWarning,
        )
        configs = S3BucketConfig.from_list(self.get("sdk.aws.s3.credentials", []))
        if extra_configurations:
            configs.extend(extra_configurations)

        def find_match(host=None, bucket=None):
            if not host and not bucket:
                raise ValueError("host or bucket required")
            try:
                if host:
                    res = {
                        config
                        for config in configs
                        if (config.host and fnmatch(host, config.host))
                        and (
                            not bucket
                            or not config.bucket
                            or fnmatch(bucket.lower(), config.bucket.lower())
                        )
                    }
                else:
                    res = {
                        config
                        for config in configs
                        if config.bucket
                        and fnmatch(bucket.lower(), config.bucket.lower())
                    }
                return next(iter(res))
            except StopIteration:
                pass

        parsed = urlparse(base_url)
        parts = Path(parsed.path.strip("/")).parts
        if parsed.netloc:
            # We have a netloc (either an actual hostname or an AWS bucket name).
            # First, we'll try with the netloc as host, but if we don't find anything, we'll try without a host and
            # with the netloc as the bucket name
            match = None
            if parts:
                # try host/bucket only if path parts contain any element
                match = find_match(host=parsed.netloc, bucket=parts[0])
            if not match:
                # no path parts or no config found for host/bucket, try netloc as bucket
                match = find_match(bucket=parsed.netloc)
        else:
            # No netloc, so we'll simply search by bucket
            match = find_match(bucket=parts[0])

        if match:
            return match

        non_aws_s3_host_suffix = ":9000"
        if parsed.netloc.endswith(non_aws_s3_host_suffix):
            host = parsed.netloc
            bucket = parts[0] if parts else None
        else:
            host = None
            bucket = parsed.netloc

        return S3BucketConfig(
            key=self.get("sdk.aws.s3.key", None),
            secret=self.get("sdk.aws.s3.secret", None),
            region=self.get("sdk.aws.s3.region", None),
            multipart=True,
            bucket=bucket,
            host=host,
        )
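The `_reload` method above layers configuration sources with `functools.reduce` over `ConfigTree.merge_configs`: environment/root paths first, then local config paths, then local config files, so later sources override earlier ones key by key. A minimal sketch of that layering with plain dicts (the `merge` helper is an illustrative stand-in for pyhocon's `ConfigTree.merge_configs`):

```python
import functools


def merge(base, override):
    """Recursively merge 'override' into 'base'; later sources win,
    mirroring how ConfigTree.merge_configs layers configuration."""
    out = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = merge(out[key], value)
        else:
            out[key] = value
    return out


defaults = {"api": {"host": "https://demoapi", "verify": True}}
env_cfg = {"api": {"host": "https://myserver"}}
local = {"api": {"timeout": 30}}

# Same reduce pattern as _reload: each later layer overrides the previous.
config = functools.reduce(merge, [defaults, env_cfg, local])
assert config == {"api": {"host": "https://myserver", "verify": True, "timeout": 30}}
```

This ordering is what lets a user's local config file override server-wide defaults without touching them.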
trains/backend_config/converters.py (new file, 46 lines)
@@ -0,0 +1,46 @@
|
||||
import base64
|
||||
from distutils.util import strtobool
|
||||
from typing import Union, Optional, Text, Any, TypeVar, Callable, Tuple
|
||||
|
||||
import six
|
||||
|
||||
ConverterType = TypeVar("ConverterType", bound=Callable[[Any], Any])
|
||||
|
||||
|
||||
def base64_to_text(value):
|
||||
# type: (Any) -> Text
|
||||
return base64.b64decode(value).decode("utf-8")
|
||||
|
||||
|
||||
def text_to_bool(value):
|
||||
# type: (Text) -> bool
|
||||
return bool(strtobool(value))
|
||||
|
||||
|
||||
def any_to_bool(value):
|
||||
# type: (Optional[Union[int, float, Text]]) -> bool
|
||||
if isinstance(value, six.text_type):
|
||||
return text_to_bool(value)
|
||||
return bool(value)
|
||||
|
||||
|
||||
def or_(*converters, **kwargs):
|
||||
# type: (ConverterType, Tuple[Exception, ...]) -> ConverterType
|
||||
"""
|
||||
Wrapper that implements an "optional converter" pattern. Allows specifying a converter
|
||||
for which a set of exceptions is ignored (and the original value is returned)
|
||||
:param converter: A converter callable
|
||||
:param exceptions: A tuple of exception types to ignore
|
||||
"""
|
||||
# noinspection PyUnresolvedReferences
|
||||
exceptions = kwargs.get("exceptions", (ValueError, TypeError))
|
||||
|
||||
def wrapper(value):
|
||||
for converter in converters:
|
||||
try:
|
||||
return converter(value)
|
||||
except exceptions:
|
||||
pass
|
||||
return value
|
||||
|
||||
return wrapper
|
||||
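The `or_` wrapper above chains converters, falling back to the original value when every converter raises one of the ignored exceptions. A standalone sketch of the same pattern (the wrapper is restated here so the example runs on its own):

```python
def or_(*converters, **kwargs):
    # Mirrors the converter-chaining pattern above: try each converter in order,
    # swallow the listed exceptions, and return the original value if all fail.
    exceptions = kwargs.get("exceptions", (ValueError, TypeError))

    def wrapper(value):
        for converter in converters:
            try:
                return converter(value)
            except exceptions:
                pass
        return value

    return wrapper

to_number = or_(int, float)
print(to_number("42"))      # int conversion succeeds -> 42
print(to_number("3.5"))     # int fails, float succeeds -> 3.5
print(to_number("n/a"))     # both fail, original value returned -> 'n/a'
```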
trains/backend_config/defs.py (new file)
@@ -0,0 +1,53 @@
from os.path import expanduser

from pathlib2 import Path

ENV_VAR = 'TRAINS_ENV'
""" Name of system environment variable that can be used to specify the config environment name """

DEFAULT_CONFIG_FOLDER = 'config'
""" Default config folder to search for when loading relative to a given path """

ENV_CONFIG_PATHS = [
]
""" Environment-related config paths """

LOCAL_CONFIG_PATHS = [
    '/etc/opt/trains',              # used by servers for docker-generated configuration
    expanduser('~/.trains/config'),
]
""" Local config paths, not related to environment """

LOCAL_CONFIG_FILES = [
    expanduser('~/trains.conf'),    # used for workstation configuration (end-users, workers)
]
""" Local config files (not paths) """

LOCAL_CONFIG_FILE_OVERRIDE_VAR = 'TRAINS_CONFIG_FILE'
""" Local config file override environment variable. If this is set, no other local config files will be used. """

ENV_CONFIG_PATH_OVERRIDE_VAR = 'TRAINS_CONFIG_PATH'
"""
Environment-related config path override environment variable. If this is set, no other env config path will be used.
"""


class Environment(object):
    """ Supported environment names """
    default = 'default'
    demo = 'demo'
    local = 'local'


CONFIG_FILE_EXTENSION = '.conf'


def is_config_file(path):
    return Path(path).suffix == CONFIG_FILE_EXTENSION
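`is_config_file` only compares the file suffix against `CONFIG_FILE_EXTENSION`. A quick standalone check of that rule — using stdlib `pathlib` here instead of `pathlib2` so the sketch runs anywhere:

```python
from pathlib import Path

CONFIG_FILE_EXTENSION = '.conf'

def is_config_file(path):
    # True only when the path's final suffix matches the expected config extension
    return Path(path).suffix == CONFIG_FILE_EXTENSION

print(is_config_file('~/trains.conf'))     # True
print(is_config_file('/etc/opt/trains'))   # False: no extension
print(is_config_file('notes.conf.bak'))    # False: final suffix is '.bak'
```

Note the last case: `Path.suffix` returns only the final extension, so a backup copy of a config file is not treated as one.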
trains/backend_config/entry.py (new file)
@@ -0,0 +1,96 @@
import abc
from typing import Optional, Any, Tuple, Text, Callable, Dict

import six

from .converters import any_to_bool

NotSet = object()

Converter = Callable[[Any], Any]


@six.add_metaclass(abc.ABCMeta)
class Entry(object):
    """
    Configuration entry definition
    """

    @classmethod
    def default_conversions(cls):
        # type: () -> Dict[Any, Converter]
        return {
            bool: any_to_bool,
            six.text_type: lambda s: six.text_type(s).strip(),
        }

    def __init__(self, key, *more_keys, **kwargs):
        # type: (Text, *Text, **Any) -> None
        """
        :param key: Entry's key (at least one).
        :param more_keys: More alternate keys for this entry.
        :param type: Value type. If provided, will be used when choosing a default conversion or
            (if none exists) for casting the environment value.
        :param converter: Value converter. If provided, will be used to convert the environment value.
        :param default: Default value. If provided, will be used as the default value on calls to get() and get_pair()
            in case no value is found for any key and no specific default value was provided in the call.
            Default value is None.
        :param help: Help text describing this entry
        """
        self.keys = (key,) + more_keys
        self.type = kwargs.pop("type", six.text_type)
        self.converter = kwargs.pop("converter", None)
        self.default = kwargs.pop("default", None)
        self.help = kwargs.pop("help", None)

    def __str__(self):
        return str(self.key)

    @property
    def key(self):
        return self.keys[0]

    def convert(self, value, converter=None):
        # type: (Any, Converter) -> Optional[Any]
        converter = converter or self.converter
        if not converter:
            converter = self.default_conversions().get(self.type, self.type)
        return converter(value)

    def get_pair(self, default=NotSet, converter=None):
        # type: (Any, Converter) -> Optional[Tuple[Text, Any]]
        for key in self.keys:
            value = self._get(key)
            if value is NotSet:
                continue
            try:
                value = self.convert(value, converter)
            except Exception as ex:
                self.error("invalid value {key}={value}: {ex}".format(**locals()))
                break
            return key, value
        result = self.default if default is NotSet else default
        return self.key, result

    def get(self, default=NotSet, converter=None):
        # type: (Any, Converter) -> Optional[Any]
        return self.get_pair(default=default, converter=converter)[1]

    def set(self, value):
        # type: (Any) -> None
        key, _ = self.get_pair(default=None, converter=None)
        self._set(key, str(value))

    def _set(self, key, value):
        # type: (Text, Text) -> None
        pass

    @abc.abstractmethod
    def _get(self, key):
        # type: (Text) -> Any
        pass

    @abc.abstractmethod
    def error(self, message):
        # type: (Text) -> None
        pass
trains/backend_config/environment.py (new file)
@@ -0,0 +1,25 @@
from os import getenv, environ

from .converters import text_to_bool
from .entry import Entry, NotSet


class EnvEntry(Entry):
    @classmethod
    def default_conversions(cls):
        conversions = super(EnvEntry, cls).default_conversions().copy()
        conversions[bool] = text_to_bool
        return conversions

    def _get(self, key):
        value = getenv(key, "").strip()
        return value or NotSet

    def _set(self, key, value):
        environ[key] = value

    def __str__(self):
        return "env:{}".format(super(EnvEntry, self).__str__())

    def error(self, message):
        print("Environment configuration: {}".format(message))
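`EnvEntry._get` above treats empty or whitespace-only environment values as unset, using the `NotSet` sentinel so that falsy-but-present values (like `"0"`) still reach the converter. A standalone sketch of that flow — `MY_FLAG` is a hypothetical variable set here only for the demo, and `text_to_bool` is a minimal stand-in for the converters module:

```python
import os

NotSet = object()  # sentinel distinguishing "unset" from falsy values, as in entry.py

_TRUE = ('y', 'yes', 't', 'true', 'on', '1')
_FALSE = ('n', 'no', 'f', 'false', 'off', '0')

def text_to_bool(value):
    # Minimal stand-in for the strtobool-based converter in converters.py
    value = value.strip().lower()
    if value in _TRUE:
        return True
    if value in _FALSE:
        return False
    raise ValueError('invalid truth value %r' % value)

def env_get(key):
    # Mirrors EnvEntry._get: empty/whitespace-only values count as unset
    value = os.getenv(key, "").strip()
    return value if value else NotSet

os.environ["MY_FLAG"] = "yes"   # hypothetical variable, set here for the demo
raw = env_get("MY_FLAG")
print(text_to_bool(raw) if raw is not NotSet else False)   # True
```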
trains/backend_config/errors.py (new file)
@@ -0,0 +1,5 @@
class ConfigurationError(Exception):

    def __init__(self, msg, file_path=None, *args):
        super(ConfigurationError, self).__init__(msg, *args)
        self.file_path = file_path
trains/backend_config/log.py (new file)
@@ -0,0 +1,30 @@
import logging.config

from pathlib2 import Path


def logger(path=None):
    name = "trains"
    if path:
        p = Path(path)
        module = (p.parent if p.stem.startswith('_') else p).stem
        name = "trains.%s" % module
    return logging.getLogger(name)


def initialize(logging_config=None, extra=None):
    if extra is not None:
        from logging import Logger

        class _Logger(Logger):
            __extra = extra.copy()

            def _log(self, level, msg, args, exc_info=None, extra=None, **kwargs):
                extra = extra or {}
                extra.update(self.__extra)
                super(_Logger, self)._log(level, msg, args, exc_info=exc_info, extra=extra, **kwargs)

        Logger.manager.loggerClass = _Logger

    if logging_config is not None:
        logging.config.dictConfig(dict(logging_config))
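`logger()` above derives a child logger name from a module path, collapsing private modules (names starting with `_`) to their parent package name. A standalone sketch of just that naming rule, using stdlib `pathlib` instead of `pathlib2`:

```python
import logging
from pathlib import Path

def logger(path=None):
    # Same naming rule as log.py: private modules (leading '_') use the parent package name
    name = "trains"
    if path:
        p = Path(path)
        module = (p.parent if p.stem.startswith('_') else p).stem
        name = "trains.%s" % module
    return logging.getLogger(name)

print(logger('/repo/trains/backend_config/reloader.py').name)   # trains.reloader
print(logger('/repo/trains/backend_config/__init__.py').name)   # trains.backend_config
```

This is why `reloader.py` can simply call `logger(__file__)` and get a sensibly named child of the `trains` logger.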
trains/backend_config/reloader.py (new file)
@@ -0,0 +1,32 @@
import logging

from watchdog.events import FileSystemEventHandler, FileCreatedEvent, FileDeletedEvent, FileModifiedEvent, \
    FileMovedEvent

from .defs import is_config_file
from .log import logger

log = logger(__file__)
log.setLevel(logging.DEBUG)


class ConfigReloader(FileSystemEventHandler):

    def __init__(self, config):
        self.config = config

    def reload(self):
        try:
            self.config.reload()
        except Exception as ex:
            log.warning('failed loading configuration: %s: %s', type(ex), ex)

    def on_any_event(self, event):
        if not (
            is_config_file(event.src_path) and
            isinstance(event, (FileCreatedEvent, FileDeletedEvent, FileModifiedEvent, FileMovedEvent))
        ):
            return
        log.debug('reloading configuration - triggered by %s', event)
        self.reload()
trains/backend_config/utils.py (new file)
@@ -0,0 +1,9 @@
def get_items(cls):
    """ get key/value items from an enum-like class (members represent enumeration key/value) """
    return {k: v for k, v in vars(cls).items() if not k.startswith('_')}


def get_options(cls):
    """ get options from an enum-like class (members represent enumeration key/value) """
    return get_items(cls).values()
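The two helpers above introspect enum-like classes such as `Environment` from defs.py: public class attributes become the enumeration's key/value pairs, while dunder attributes are filtered out. A self-contained sketch:

```python
def get_items(cls):
    # Public class attributes as a dict (members represent enumeration key/value);
    # names starting with '_' (including __module__ etc.) are filtered out.
    return {k: v for k, v in vars(cls).items() if not k.startswith('_')}

class Environment(object):
    default = 'default'
    demo = 'demo'
    local = 'local'

print(get_items(Environment))
print(sorted(get_items(Environment).values()))   # ['default', 'demo', 'local']
```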
trains/backend_interface/__init__.py (new file)
@@ -0,0 +1,2 @@
""" High-level abstractions for backend API """
from .task import Task, TaskStatusEnum, TaskEntry
trains/backend_interface/base.py (new file)
@@ -0,0 +1,147 @@
import abc

import requests.exceptions
import six

from ..backend_api import Session
from ..backend_api.session import BatchRequest
from ..backend_api.version import __version__
from ..config import config_obj
from ..config.defs import LOG_LEVEL_ENV_VAR, API_ACCESS_KEY, API_SECRET_KEY
from ..debugging import get_logger
from .session import SendError, SessionInterface


class InterfaceBase(SessionInterface):
    """ Base class for a backend manager class """
    _default_session = None

    @property
    def session(self):
        return self._session

    @property
    def log(self):
        return self._log

    def __init__(self, session=None, log=None, **kwargs):
        super(InterfaceBase, self).__init__()
        self._session = session or self._get_default_session()
        self._log = log or self._create_log()

    def _create_log(self):
        log = get_logger(str(self.__class__.__name__))
        try:
            log.setLevel(LOG_LEVEL_ENV_VAR.get(default=log.level))
        except TypeError as ex:
            raise ValueError('Invalid log level defined in environment variable `%s`: %s' % (LOG_LEVEL_ENV_VAR, ex))
        return log

    @classmethod
    def _send(cls, session, req, ignore_errors=False, raise_on_errors=True, log=None, async_enable=False):
        """ Convenience send() method providing standardized error reporting """
        while True:
            try:
                res = session.send(req, async_enable=async_enable)
                if res.meta.result_code in (200, 202) or ignore_errors:
                    return res

                if isinstance(req, BatchRequest):
                    error_msg = 'Action failed %s' % res.meta
                else:
                    error_msg = 'Action failed %s (%s)' \
                                % (res.meta, ', '.join('%s=%s' % p for p in req.to_dict().items()))
                if log:
                    log.error(error_msg)

                if res.meta.result_code <= 500:
                    # Proper backend error/bad status code - raise or return
                    if raise_on_errors:
                        raise SendError(res, error_msg)
                    return res

            except requests.exceptions.BaseHTTPError as e:
                if log:
                    log.error('failed sending %s: %s' % (str(req), str(e)))

            # Infrastructure error
            if log:
                log.info('retrying request %s' % str(req))

    def send(self, req, ignore_errors=False, raise_on_errors=True, async_enable=False):
        return self._send(session=self.session, req=req, ignore_errors=ignore_errors, raise_on_errors=raise_on_errors,
                          log=self.log, async_enable=async_enable)

    @classmethod
    def _get_default_session(cls):
        if not InterfaceBase._default_session:
            InterfaceBase._default_session = Session(
                initialize_logging=False,
                client='sdk-%s' % __version__,
                config=config_obj,
                api_key=API_ACCESS_KEY.get(),
                secret_key=API_SECRET_KEY.get(),
            )
        return InterfaceBase._default_session

    @classmethod
    def _set_default_session(cls, session):
        """
        Set a new default session for the system

        Warning: use only for debugging and testing
        :param session: The new default session
        """
        InterfaceBase._default_session = session

    @property
    def default_session(self):
        if hasattr(self, '_session'):
            return self._session
        return self._get_default_session()


@six.add_metaclass(abc.ABCMeta)
class IdObjectBase(InterfaceBase):

    def __init__(self, id, session=None, log=None, **kwargs):
        super(IdObjectBase, self).__init__(session, log, **kwargs)
        self._data = None
        self._id = None
        self.id = self.normalize_id(id)

    @property
    def id(self):
        return self._id

    @id.setter
    def id(self, value):
        should_reload = value is not None and value != self._id
        self._id = value
        if should_reload:
            self.reload()

    @property
    def data(self):
        if self._data is None:
            self.reload()
        return self._data

    @abc.abstractmethod
    def _reload(self):
        pass

    def reload(self):
        if not self.id:
            raise ValueError('Failed reloading %s: missing id' % type(self).__name__)
        self._data = self._reload()

    @classmethod
    def normalize_id(cls, id):
        return id.strip() if id else None

    @classmethod
    def resolve_id(cls, obj):
        if isinstance(obj, cls):
            return obj.id
        return obj
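`IdObjectBase` above implements a lazy-reload pattern: assigning a new id through the property setter triggers a reload, and `data` fetches on first access. A minimal standalone sketch of just that pattern, with a dict standing in for the backend fetch:

```python
class LazyObject(object):
    # Minimal sketch of IdObjectBase's id/data pattern (no backend session involved)
    def __init__(self, id):
        self._data = None
        self._id = None
        self.id = id.strip() if id else None   # normalize, then trigger reload via the setter

    @property
    def id(self):
        return self._id

    @id.setter
    def id(self, value):
        should_reload = value is not None and value != self._id
        self._id = value
        if should_reload:
            self.reload()

    @property
    def data(self):
        if self._data is None:
            self.reload()
        return self._data

    def reload(self):
        # Stand-in for a backend fetch keyed by id
        self._data = {'id': self._id}

obj = LazyObject('task-123 ')
print(obj.data)   # {'id': 'task-123'}
obj.id = 'task-456'
print(obj.data)   # {'id': 'task-456'}
```

Note the setter only reloads when the id actually changes, so reassigning the same id is a no-op.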
trains/backend_interface/metrics/__init__.py (new file)
@@ -0,0 +1,4 @@
""" Metrics management and batching support """
from .interface import Metrics
from .reporter import Reporter
from .events import ScalarEvent, VectorEvent, PlotEvent, ImageEvent
trains/backend_interface/metrics/events.py (new file)
@@ -0,0 +1,258 @@
import abc
import time
from threading import Lock

import attr
import cv2
import numpy as np
import pathlib2
import six
from six.moves.urllib.parse import urlparse, urlunparse

from ...backend_api.services import events
from ...config import config


@six.add_metaclass(abc.ABCMeta)
class MetricsEventAdapter(object):
    """
    Adapter providing all the base attributes required by a metrics event and defining an interface used by the
    metrics manager when batching and writing events.
    """

    _default_nan_value = 0.
    """ Default value used when a np.nan value is encountered """

    @attr.attrs(cmp=False, slots=True)
    class FileEntry(object):
        """ File entry used to report on file data that needs to be uploaded prior to sending the event """

        event = attr.attrib()

        name = attr.attrib()
        """ File name """

        stream = attr.attrib()
        """ File-like object containing the file's data """

        url_prop = attr.attrib()
        """ Property name that should be updated with the uploaded url """

        key_prop = attr.attrib()

        upload_uri = attr.attrib()

        url = attr.attrib(default=None)

        exception = attr.attrib(default=None)

        def set_exception(self, exp):
            self.exception = exp
            self.event.upload_exception = exp

    @property
    def metric(self):
        return self._metric

    @metric.setter
    def metric(self, value):
        self._metric = value

    @property
    def variant(self):
        return self._variant

    def __init__(self, metric, variant, iter=None, timestamp=None, task=None, gen_timestamp_if_none=True):
        if not timestamp and gen_timestamp_if_none:
            timestamp = int(time.time() * 1000)
        self._metric = metric
        self._variant = variant
        self._iter = iter
        self._timestamp = timestamp
        self._task = task

        # Try creating an event just to trigger validation
        _ = self.get_api_event()
        self.upload_exception = None

    @abc.abstractmethod
    def get_api_event(self):
        """ Get an API event instance """
        pass

    def get_file_entry(self):
        """ Get information for a file that should be uploaded before this event is sent """
        pass

    def update(self, task=None, **kwargs):
        """ Update event properties """
        if task:
            self._task = task

    def _get_base_dict(self):
        """ Get a dict with the base attributes """
        res = dict(
            task=self._task,
            timestamp=self._timestamp,
            metric=self._metric,
            variant=self._variant
        )
        if self._iter is not None:
            res.update(iter=self._iter)
        return res

    @classmethod
    def _convert_np_nan(cls, val):
        if np.isnan(val) or np.isinf(val):
            return cls._default_nan_value
        return val


class ScalarEvent(MetricsEventAdapter):
    """ Scalar event adapter """

    def __init__(self, metric, variant, value, iter, **kwargs):
        self._value = self._convert_np_nan(value)
        super(ScalarEvent, self).__init__(metric=metric, variant=variant, iter=iter, **kwargs)

    def get_api_event(self):
        return events.MetricsScalarEvent(
            value=self._value,
            **self._get_base_dict())


class VectorEvent(MetricsEventAdapter):
    """ Vector event adapter """

    def __init__(self, metric, variant, values, iter, **kwargs):
        self._values = [self._convert_np_nan(v) for v in values]
        super(VectorEvent, self).__init__(metric=metric, variant=variant, iter=iter, **kwargs)

    def get_api_event(self):
        return events.MetricsVectorEvent(
            values=self._values,
            **self._get_base_dict())


class PlotEvent(MetricsEventAdapter):
    """ Plot event adapter """

    def __init__(self, metric, variant, plot_str, iter=None, **kwargs):
        self._plot_str = plot_str
        super(PlotEvent, self).__init__(metric=metric, variant=variant, iter=iter, **kwargs)

    def get_api_event(self):
        return events.MetricsPlotEvent(
            plot_str=self._plot_str,
            **self._get_base_dict())


class ImageEventNoUpload(MetricsEventAdapter):

    def __init__(self, metric, variant, src, iter=0, **kwargs):
        parts = urlparse(src)
        self._url = urlunparse((parts.scheme, parts.netloc, '', '', '', ''))
        self._key = urlunparse(('', '', parts.path, parts.params, parts.query, parts.fragment))
        super(ImageEventNoUpload, self).__init__(metric, variant, iter=iter, **kwargs)

    def get_api_event(self):
        return events.MetricsImageEvent(
            url=self._url,
            key=self._key,
            **self._get_base_dict())


class ImageEvent(MetricsEventAdapter):
    """ Image event adapter """
    _format = '.' + str(config.get('metrics.images.format', 'JPEG')).upper().lstrip('.')
    _quality = int(config.get('metrics.images.quality', 87))
    _subsampling = int(config.get('metrics.images.subsampling', 0))

    _metric_counters = {}
    _metric_counters_lock = Lock()
    _image_file_history_size = int(config.get('metrics.file_history_size', 5))

    def __init__(self, metric, variant, image_data, iter=0, upload_uri=None,
                 image_file_history_size=None, **kwargs):
        if not hasattr(image_data, 'shape'):
            raise ValueError('Image must have a shape attribute')
        self._image_data = image_data
        self._url = None
        self._key = None
        self._count = self._get_metric_count(metric, variant)
        if not image_file_history_size:
            image_file_history_size = self._image_file_history_size
        if image_file_history_size < 1:
            self._filename = '%s_%s_%08d' % (metric, variant, self._count)
        else:
            self._filename = '%s_%s_%08d' % (metric, variant, self._count % image_file_history_size)
        self._upload_uri = upload_uri
        super(ImageEvent, self).__init__(metric, variant, iter=iter, **kwargs)

    @classmethod
    def _get_metric_count(cls, metric, variant, next=True):
        """ Returns the next count number for the given metric/variant (rotates every few calls) """
        counters = cls._metric_counters
        key = '%s_%s' % (metric, variant)
        try:
            cls._metric_counters_lock.acquire()
            value = counters.get(key, -1)
            if next:
                value = counters[key] = value + 1
            return value
        finally:
            cls._metric_counters_lock.release()

    def get_api_event(self):
        return events.MetricsImageEvent(
            url=self._url,
            key=self._key,
            **self._get_base_dict())

    def update(self, url=None, key=None, **kwargs):
        super(ImageEvent, self).update(**kwargs)
        if url is not None:
            self._url = url
        if key is not None:
            self._key = key

    def get_file_entry(self):
        # don't provide a file in case this event is out of the history window
        last_count = self._get_metric_count(self.metric, self.variant, next=False)
        if abs(self._count - last_count) > self._image_file_history_size:
            output = None
        else:
            image_data = self._image_data
            if not isinstance(image_data, np.ndarray):
                # try conversion, if it fails we'll leave it to the user.
                image_data = np.array(image_data, dtype=np.uint8)
            image_data = np.atleast_3d(image_data)
            if image_data.dtype != np.uint8:
                if np.issubdtype(image_data.dtype, np.floating) and image_data.max() <= 1.0:
                    image_data = (image_data * 255).astype(np.uint8)
                else:
                    image_data = image_data.astype(np.uint8)
            shape = image_data.shape
            height, width, channel = shape[:3]
            if channel == 1:
                image_data = np.reshape(image_data, (height, width))

            # serialize image
            _, img_bytes = cv2.imencode(
                self._format, image_data,
                params=(cv2.IMWRITE_JPEG_QUALITY, self._quality),
            )

            output = six.BytesIO(img_bytes.tostring())
            output.seek(0)

        filename = str(pathlib2.Path(self._filename).with_suffix(self._format.lower()))

        return self.FileEntry(
            event=self,
            name=filename,
            stream=output,
            url_prop='url',
            key_prop='key',
            upload_uri=self._upload_uri
        )
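`ImageEvent` rotates image file names through a fixed history window using a thread-safe per-metric counter plus a modulo in the file name, so only the last N files per metric/variant are kept on storage. A standalone sketch of that naming scheme:

```python
from threading import Lock

_counters = {}
_lock = Lock()

def next_count(metric, variant):
    # Thread-safe per-metric/variant counter, as in ImageEvent._get_metric_count
    key = '%s_%s' % (metric, variant)
    with _lock:
        value = _counters.get(key, -1) + 1
        _counters[key] = value
        return value

def image_filename(metric, variant, history_size=5):
    count = next_count(metric, variant)
    # rotate within the history window so only the last `history_size` files are kept
    return '%s_%s_%08d' % (metric, variant, count % history_size)

names = [image_filename('val', 'loss', history_size=3) for _ in range(4)]
print(names)   # the 4th name wraps around and reuses the first slot
```

Later uploads to a recycled name overwrite the older file, which is the point: a bounded number of files per metric/variant.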
192
trains/backend_interface/metrics/interface.py
Normal file
192
trains/backend_interface/metrics/interface.py
Normal file
@@ -0,0 +1,192 @@
|
||||
from functools import partial
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from threading import Lock
|
||||
from time import time
|
||||
|
||||
from humanfriendly import format_timespan
|
||||
from ...backend_api.services import events as api_events
|
||||
from ..base import InterfaceBase
|
||||
from ...config import config
|
||||
from ...debugging import get_logger
|
||||
from ...storage import StorageHelper
|
||||
|
||||
from .events import MetricsEventAdapter
|
||||
|
||||
|
||||
upload_pool = ThreadPool(processes=1)
|
||||
file_upload_pool = ThreadPool(processes=config.get('network.metrics.file_upload_threads', 4))
|
||||
|
||||
log = get_logger('metrics')
|
||||
|
||||
|
||||
class Metrics(InterfaceBase):
|
||||
""" Metrics manager and batch writer """
|
||||
_storage_lock = Lock()
|
||||
_file_upload_starvation_warning_sec = config.get('network.metrics.file_upload_starvation_warning_sec', None)
|
||||
|
||||
@property
|
||||
def storage_key_prefix(self):
|
||||
return self._storage_key_prefix
|
||||
|
||||
def _get_storage(self, storage_uri=None):
|
||||
""" Storage helper used to upload files """
|
||||
try:
|
||||
# use a lock since this storage object will be requested by thread pool threads, so we need to make sure
|
||||
# any singleton initialization will occur only once
|
||||
self._storage_lock.acquire()
|
||||
storage_uri = storage_uri or self._storage_uri
|
||||
return StorageHelper.get(storage_uri)
|
||||
except Exception as e:
|
||||
log.error('Failed getting storage helper for %s: %s' % (storage_uri, str(e)))
|
||||
finally:
|
||||
self._storage_lock.release()
|
||||
|
||||
def __init__(self, session, task_id, storage_uri, storage_uri_suffix='metrics', log=None):
|
||||
super(Metrics, self).__init__(session, log=log)
|
||||
self._task_id = task_id
|
||||
self._storage_uri = storage_uri.rstrip('/') if storage_uri else None
|
||||
self._storage_key_prefix = storage_uri_suffix.strip('/') if storage_uri_suffix else None
|
||||
self._file_related_event_time = None
|
||||
self._file_upload_time = None
|
||||
|
||||
def write_events(self, events, async_enable=True, callback=None, **kwargs):
|
||||
"""
|
||||
Write events to the backend, uploading any required files.
|
||||
:param events: A list of event objects
|
||||
:param async_enable: If True, upload is performed asynchronously and an AsyncResult object is returned, otherwise a
|
||||
blocking call is made and the upload result is returned.
|
||||
:param callback: A optional callback called when upload was completed in case async is True
|
||||
:return: .backend_api.session.CallResult if async is False otherwise AsyncResult. Note that if no events were
|
||||
sent, None will be returned.
|
||||
"""
|
||||
if not events:
|
||||
return
|
||||
|
||||
storage_uri = kwargs.pop('storage_uri', self._storage_uri)
|
||||
|
||||
if not async_enable:
|
||||
return self._do_write_events(events, storage_uri)
|
||||
|
||||
def safe_call(*args, **kwargs):
|
||||
try:
|
||||
return self._do_write_events(*args, **kwargs)
|
||||
except Exception as e:
|
||||
return e
|
||||
|
||||
return upload_pool.apply_async(
|
||||
safe_call,
|
||||
args=(events, storage_uri),
|
||||
callback=partial(self._callback_wrapper, callback))
|
||||
|
||||
def _callback_wrapper(self, callback, res):
|
||||
""" A wrapper for the async callback for handling common errors """
|
||||
if not res:
|
||||
# no result yet
|
||||
return
|
||||
elif isinstance(res, Exception):
|
||||
# error
|
||||
self.log.error('Error trying to send metrics: %s' % str(res))
|
||||
elif not res.ok():
|
||||
# bad result, log error
|
||||
self.log.error('Failed reporting metrics: %s' % str(res.meta))
|
||||
# call callback, even if we received an error
|
||||
if callback:
|
||||
callback(res)
|
||||
|
||||
def _do_write_events(self, events, storage_uri=None):
|
||||
""" Sends an iterable of events as a series of batch operations. note: metric send does not raise on error"""
|
||||
assert isinstance(events, (list, tuple))
|
||||
assert all(isinstance(x, MetricsEventAdapter) for x in events)
|
||||
|
||||
# def event_key(ev):
|
||||
# return (ev.metric, ev.variant)
|
||||
#
|
||||
# events = sorted(events, key=event_key)
|
||||
# multiple_events_for = [k for k, v in groupby(events, key=event_key) if len(list(v)) > 1]
|
||||
# if multiple_events_for:
|
||||
# log.warning(
|
||||
# 'More than one metrics event sent for these metric/variant combinations in a report: %s' %
|
||||
# ', '.join('%s/%s' % k for k in multiple_events_for))
|
||||
|
||||
storage_uri = storage_uri or self._storage_uri
|
||||
|
||||
now = time()
|
||||
|
||||
def update_and_get_file_entry(ev):
|
||||
entry = ev.get_file_entry()
|
||||
kwargs = {}
|
||||
if entry:
|
||||
e_storage_uri = entry.upload_uri or storage_uri
|
||||
self._file_related_event_time = now
|
||||
# if we have an entry (with or without a stream), we'll generate the URL and store it in the event
|
||||
filename = entry.name
|
||||
key = '/'.join(x for x in (self._storage_key_prefix, ev.metric, ev.variant, filename.strip('/')) if x)
|
||||
url = '/'.join(x.strip('/') for x in (e_storage_uri, key))
|
||||
kwargs[entry.key_prop] = key
|
||||
kwargs[entry.url_prop] = url
|
||||
if not entry.stream:
|
||||
# if entry has no stream, we won't upload it
|
||||
                entry = None
            else:
                if not hasattr(entry.stream, 'read'):
                    raise ValueError('Invalid file object %s' % entry.stream)
                entry.url = url
            ev.update(task=self._task_id, **kwargs)
            return entry

        # prepare events needing file upload
        entries = []
        for ev in events:
            try:
                e = update_and_get_file_entry(ev)
                if e:
                    entries.append(e)
            except Exception as ex:
                log.warning(str(ex))

        # upload the needed files
        if entries:
            # upload files
            def upload(e):
                upload_uri = e.upload_uri or storage_uri

                try:
                    storage = self._get_storage(upload_uri)
                    url = storage.upload_from_stream(e.stream, e.url)
                    e.event.update(url=url)
                except Exception as exp:
                    log.debug("Failed uploading to {} ({})".format(
                        upload_uri if upload_uri else "(Could not calculate upload uri)",
                        exp,
                    ))

                    e.set_exception(exp)

            res = file_upload_pool.map_async(upload, entries)
            res.wait()

            # remember the last time we uploaded a file
            self._file_upload_time = time()

        t_f, t_u, t_ref = \
            (self._file_related_event_time, self._file_upload_time, self._file_upload_starvation_warning_sec)
        if t_f and t_u and t_ref and (t_f - t_u) > t_ref:
            log.warning('Possible metrics file upload starvation: files were not uploaded for %s' %
                        format_timespan(t_ref))

        # send the events in a batched request
        good_events = [ev for ev in events if ev.upload_exception is None]
        error_events = [ev for ev in events if ev.upload_exception is not None]

        if error_events:
            log.error("Not uploading {}/{} events because the data upload failed".format(
                len(error_events),
                len(events),
            ))

        if good_events:
            batched_requests = [api_events.AddRequest(event=ev.get_api_event()) for ev in good_events]
            req = api_events.AddBatchRequest(requests=batched_requests)
            return self.send(req, raise_on_errors=False)

        return None
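The upload-then-batch flow above can be sketched independently of the backend: a thread pool uploads entries, a failure is recorded on the event rather than raised, and events are then partitioned into good/error sets. This is a minimal stand-in sketch; `FakeEvent` and `upload` are hypothetical, not part of the trains API.

```python
from multiprocessing.pool import ThreadPool

class FakeEvent:
    # hypothetical stand-in for a metrics event carrying an upload result
    def __init__(self, name, fail=False):
        self.name = name
        self.fail = fail
        self.upload_exception = None

def upload(ev):
    # simulate an upload; record the failure on the event instead of raising
    try:
        if ev.fail:
            raise IOError('upload failed for %s' % ev.name)
    except Exception as exp:
        ev.upload_exception = exp

events = [FakeEvent('a'), FakeEvent('b', fail=True), FakeEvent('c')]
pool = ThreadPool(2)
res = pool.map_async(upload, events)
res.wait()

# partition exactly as write_events() does before batching the API request
good_events = [ev for ev in events if ev.upload_exception is None]
error_events = [ev for ev in events if ev.upload_exception is not None]
assert [e.name for e in good_events] == ['a', 'c']
assert len(error_events) == 1
```

Recording the exception on the event (instead of letting it propagate out of the pool worker) is what lets the caller still send the successful events in a single batched request.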
457
trains/backend_interface/metrics/reporter.py
Normal file
@@ -0,0 +1,457 @@
import collections
import json

import cv2
import six

from ..base import InterfaceBase
from ..setupuploadmixin import SetupUploadMixin
from ...utilities.async_manager import AsyncManagerMixin
from ...utilities.plotly import create_2d_histogram_plot, create_value_matrix, create_3d_surface, \
    create_2d_scatter_series, create_3d_scatter_series, create_line_plot, plotly_scatter3d_layout_dict
from ...utilities.py3_interop import AbstractContextManager
from .events import ScalarEvent, VectorEvent, ImageEvent, PlotEvent, ImageEventNoUpload


class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncManagerMixin):
    """
    A simple metrics reporter class.
    This class caches reports and supports both explicit flushing and context-based flushing. To ensure reports are
    sent to the backend, please use (assuming an instance of Reporter named 'reporter'):
    - the context manager feature (which will automatically flush when exiting the context):
        with reporter:
            reporter.report...
            ...
    - an explicit call to flush:
        reporter.report...
        ...
        reporter.flush()
    """

    def __init__(self, metrics, flush_threshold=10, async_enable=False):
        """
        Create a reporter
        :param metrics: A Metrics manager instance that handles actual reporting, uploads etc.
        :type metrics: .backend_interface.metrics.Metrics
        :param flush_threshold: Events flush threshold. This determines the threshold over which cached reported events
            are flushed and sent to the backend.
        :type flush_threshold: int
        """
        log = metrics.log.getChild('reporter')
        log.setLevel(log.level)
        super(Reporter, self).__init__(session=metrics.session, log=log)
        self._metrics = metrics
        self._flush_threshold = flush_threshold
        self._events = []
        self._bucket_config = None
        self._storage_uri = None
        self._async_enable = async_enable

    def _set_storage_uri(self, value):
        value = '/'.join(x for x in (value.rstrip('/'), self._metrics.storage_key_prefix) if x)
        self._storage_uri = value

    storage_uri = property(None, _set_storage_uri)

    @property
    def flush_threshold(self):
        return self._flush_threshold

    @flush_threshold.setter
    def flush_threshold(self, value):
        self._flush_threshold = max(0, value)

    @property
    def async_enable(self):
        return self._async_enable

    @async_enable.setter
    def async_enable(self, value):
        self._async_enable = bool(value)
    def _report(self, ev):
        self._events.append(ev)
        if len(self._events) >= self._flush_threshold:
            self._write()

    def _write(self):
        if not self._events:
            return
        # print('reporting %d events' % len(self._events))
        res = self._metrics.write_events(self._events, async_enable=self._async_enable, storage_uri=self._storage_uri)
        if self._async_enable:
            self._add_async_result(res)
        self._events = []

    def flush(self):
        """
        Flush cached reports to backend.
        """
        self._write()
        # wait for all reports
        if self.get_num_results() > 0:
            self.wait_for_results()

    def report_scalar(self, title, series, value, iter):
        """
        Report a scalar value
        :param title: Title (AKA metric)
        :type title: str
        :param series: Series (AKA variant)
        :type series: str
        :param value: Reported value
        :type value: float
        :param iter: Iteration number
        :type iter: int
        """
        ev = ScalarEvent(metric=self._normalize_name(title),
                         variant=self._normalize_name(series), value=value, iter=iter)
        self._report(ev)

    def report_vector(self, title, series, values, iter):
        """
        Report a vector of values
        :param title: Title (AKA metric)
        :type title: str
        :param series: Series (AKA variant)
        :type series: str
        :param values: Reported values
        :type values: [float]
        :param iter: Iteration number
        :type iter: int
        """
        if not isinstance(values, collections.Iterable):
            raise ValueError('values: expected an iterable')
        ev = VectorEvent(metric=self._normalize_name(title),
                         variant=self._normalize_name(series), values=values, iter=iter)
        self._report(ev)

    def report_plot(self, title, series, plot, iter):
        """
        Report a Plotly chart
        :param title: Title (AKA metric)
        :type title: str
        :param series: Series (AKA variant)
        :type series: str
        :param plot: A JSON describing a plotly chart (see https://help.plot.ly/json-chart-schema/)
        :type plot: str or dict
        :param iter: Iteration number
        :type iter: int
        """
        if isinstance(plot, dict):
            plot = json.dumps(plot)
        elif not isinstance(plot, six.string_types):
            raise ValueError('Plot should be a string or a dict')
        ev = PlotEvent(metric=self._normalize_name(title),
                       variant=self._normalize_name(series), plot_str=plot, iter=iter)
        self._report(ev)

    def report_image(self, title, series, src, iter):
        """
        Report an image.
        :param title: Title (AKA metric)
        :type title: str
        :param series: Series (AKA variant)
        :type series: str
        :param src: Image source URI. This URI will be used by the webapp and workers when trying to obtain the image
            for presentation or processing. Currently only http(s), file and s3 schemes are supported.
        :type src: str
        :param iter: Iteration number
        :type iter: int
        """
        ev = ImageEventNoUpload(metric=self._normalize_name(title),
                                variant=self._normalize_name(series), iter=iter, src=src)
        self._report(ev)
    def report_image_and_upload(self, title, series, iter, path=None, matrix=None, upload_uri=None,
                                max_image_history=None):
        """
        Report an image and upload its contents. Image is uploaded to a preconfigured bucket (see setup_upload()) with
        a key (filename) describing the task ID, title, series and iteration.
        :param title: Title (AKA metric)
        :type title: str
        :param series: Series (AKA variant)
        :type series: str
        :param iter: Iteration number
        :type iter: int
        :param path: A path to an image file. Required unless matrix is provided.
        :type path: str
        :param matrix: A 3D numpy.ndarray object containing image data (BGR). Required unless path is provided.
        :type matrix: numpy.ndarray
        :param upload_uri: Destination storage URI for the uploaded image (overrides the preconfigured storage URI)
        :type upload_uri: str
        :param max_image_history: maximum number of images to store per metric/variant combination
            use a negative value for unlimited. default is set in global configuration (default=5)
        :type max_image_history: int
        """
        if not self._storage_uri and not upload_uri:
            raise ValueError('Upload configuration is required (use setup_upload())')
        if len([x for x in (path, matrix) if x is not None]) != 1:
            raise ValueError('Expected only one of [path, matrix]')
        kwargs = dict(metric=self._normalize_name(title),
                      variant=self._normalize_name(series), iter=iter, image_file_history_size=max_image_history)
        if matrix is None:
            matrix = cv2.imread(path)
        ev = ImageEvent(image_data=matrix, upload_uri=upload_uri, **kwargs)
        self._report(ev)

    def report_histogram(self, title, series, histogram, iter, labels=None, xlabels=None, comment=None):
        """
        Report a histogram bar plot
        :param title: Title (AKA metric)
        :type title: str
        :param series: Series (AKA variant)
        :type series: str
        :param histogram: The histogram data.
            A row for each dataset (bar in a bar group). A column for each bucket.
        :type histogram: numpy array
        :param iter: Iteration number
        :type iter: int
        :param labels: The labels for each bar group.
        :type labels: list of strings.
        :param xlabels: The labels of the x axis.
        :type xlabels: list of strings.
        :param comment: comment underneath the title
        :type comment: str
        """
        plotly_dict = create_2d_histogram_plot(
            np_row_wise=histogram,
            title=title,
            labels=labels,
            series=series,
            xlabels=xlabels,
            comment=comment,
        )

        return self.report_plot(
            title=self._normalize_name(title),
            series=self._normalize_name(series),
            plot=plotly_dict,
            iter=iter,
        )
    def report_line_plot(self, title, series, iter, xtitle, ytitle, mode='lines', reverse_xaxis=False, comment=None):
        """
        Report a (possibly multiple) line plot.

        :param title: Title (AKA metric)
        :type title: str
        :param series: All the series' data, one for each line in the plot.
        :type series: An iterable of LineSeriesInfo.
        :param iter: Iteration number
        :type iter: int
        :param xtitle: x-axis title
        :type xtitle: str
        :param ytitle: y-axis title
        :type ytitle: str
        :param mode: 'lines' / 'markers' / 'lines+markers'
        :type mode: str
        :param reverse_xaxis: If true X axis will be displayed from high to low (reversed)
        :type reverse_xaxis: bool
        :param comment: comment underneath the title
        :type comment: str
        """

        plotly_dict = create_line_plot(
            title=title,
            series=series,
            xtitle=xtitle,
            ytitle=ytitle,
            mode=mode,
            reverse_xaxis=reverse_xaxis,
            comment=comment,
        )

        return self.report_plot(
            title=self._normalize_name(title),
            series='',
            plot=plotly_dict,
            iter=iter,
        )

    def report_2d_scatter(self, title, series, data, iter, mode='lines', xtitle=None, ytitle=None, labels=None,
                          comment=None):
        """
        Report a 2d scatter graph (with lines)

        :param title: Title (AKA metric)
        :type title: str
        :param series: Series (AKA variant)
        :type series: str
        :param data: The scattered data: pairs of x,y as rows in a numpy array
        :type data: ndarray
        :param iter: Iteration number
        :type iter: int
        :param mode: (type str) 'lines'/'markers'/'lines+markers'
        :param xtitle: optional x-axis title
        :param ytitle: optional y-axis title
        :param labels: label (text) per point in the scatter (in the same order)
        :param comment: comment underneath the title
        :type comment: str
        """
        plotly_dict = create_2d_scatter_series(
            np_row_wise=data,
            title=title,
            series_name=series,
            mode=mode,
            xtitle=xtitle,
            ytitle=ytitle,
            labels=labels,
            comment=comment,
        )

        return self.report_plot(
            title=self._normalize_name(title),
            series=self._normalize_name(series),
            plot=plotly_dict,
            iter=iter,
        )
    def report_3d_scatter(self, title, series, data, iter, labels=None, mode='lines', color=((217, 217, 217, 0.14),),
                          marker_size=5, line_width=0.8, xtitle=None, ytitle=None, ztitle=None, fill=None,
                          comment=None):
        """
        Report a 3d scatter graph (with markers)

        :param title: Title (AKA metric)
        :type title: str
        :param series: Series (AKA variant)
        :type series: str
        :param data: The scattered data: pairs of x,y,z as rows in a numpy array, or a list of numpy arrays
        :type data: ndarray
        :param iter: Iteration number
        :type iter: int
        :param labels: label (text) per point in the scatter (in the same order)
        :type labels: list of strings
        :param mode: (type str) 'lines'/'markers'/'lines+markers'
        :param color: list of RGBA colors [(217, 217, 217, 0.14),]
        :param marker_size: marker size in px
        :param line_width: line width in px
        :param xtitle: optional x-axis title
        :param ytitle: optional y-axis title
        :param ztitle: optional z-axis title
        :param comment: comment underneath the title
        """
        data_series = data if isinstance(data, list) else [data]

        def get_labels(i):
            if labels and isinstance(labels, list):
                try:
                    item = labels[i]
                except IndexError:
                    item = labels[-1]
                if isinstance(item, list):
                    return item
            return labels

        plotly_obj = plotly_scatter3d_layout_dict(
            title=title,
            xaxis_title=xtitle,
            yaxis_title=ytitle,
            zaxis_title=ztitle,
            comment=comment,
        )

        for i, values in enumerate(data_series):
            plotly_obj = create_3d_scatter_series(
                np_row_wise=values,
                title=title,
                series_name=series[i] if isinstance(series, list) else None,
                labels=get_labels(i),
                plotly_obj=plotly_obj,
                mode=mode,
                line_width=line_width,
                marker_size=marker_size,
                color=color,
                fill_axis=fill,
            )

        return self.report_plot(
            title=self._normalize_name(title),
            series=self._normalize_name(series) if not isinstance(series, list) else None,
            plot=plotly_obj,
            iter=iter,
        )
    def report_value_matrix(self, title, series, data, iter, xlabels=None, ylabels=None, comment=None):
        """
        Report a heat-map matrix

        :param title: Title (AKA metric)
        :type title: str
        :param series: Series (AKA variant)
        :type series: str
        :param data: A heat-map matrix (example: confusion matrix)
        :type data: ndarray
        :param iter: Iteration number
        :type iter: int
        :param xlabels: optional label per column of the matrix
        :param ylabels: optional label per row of the matrix
        :param comment: comment underneath the title
        """

        plotly_dict = create_value_matrix(
            np_value_matrix=data,
            title=title,
            xlabels=xlabels,
            ylabels=ylabels,
            series=series,
            comment=comment,
        )

        return self.report_plot(
            title=self._normalize_name(title),
            series=self._normalize_name(series),
            plot=plotly_dict,
            iter=iter,
        )

    def report_value_surface(self, title, series, data, iter, xlabels=None, ylabels=None,
                             xtitle=None, ytitle=None, ztitle=None, camera=None, comment=None):
        """
        Report a 3d surface (same data as heat-map matrix, only presented differently)

        :param title: Title (AKA metric)
        :type title: str
        :param series: Series (AKA variant)
        :type series: str
        :param data: A heat-map matrix (example: confusion matrix)
        :type data: ndarray
        :param iter: Iteration number
        :type iter: int
        :param xlabels: optional label per column of the matrix
        :param ylabels: optional label per row of the matrix
        :param xtitle: optional x-axis title
        :param ytitle: optional y-axis title
        :param ztitle: optional z-axis title
        :param camera: X,Y,Z camera position. default: (1, 1, 1)
        :param comment: comment underneath the title
        """

        plotly_dict = create_3d_surface(
            np_value_matrix=data,
            title=title + '/' + series,
            xlabels=xlabels,
            ylabels=ylabels,
            series=series,
            xtitle=xtitle,
            ytitle=ytitle,
            ztitle=ztitle,
            camera=camera,
            comment=comment,
        )

        return self.report_plot(
            title=self._normalize_name(title),
            series=self._normalize_name(series),
            plot=plotly_dict,
            iter=iter,
        )
    @classmethod
    def _normalize_name(cls, name):
        if not name:
            return name
        return name.replace('$', '/').replace('.', '/')

    def __exit__(self, exc_type, exc_val, exc_tb):
        # don't flush in case an exception was raised
        if not exc_type:
            self.flush()
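The Reporter's core behavior — cache events, write them out once flush_threshold is reached, and flush any remainder on context exit — boils down to the following pattern. This is a minimal self-contained sketch (a plain list stands in for the backend call; `BufferedReporter` is not the trains class):

```python
class BufferedReporter:
    # minimal sketch of Reporter's flush-threshold caching (not the trains class)
    def __init__(self, flush_threshold=3):
        self._flush_threshold = flush_threshold
        self._events = []
        self.written = []  # stands in for Metrics.write_events

    def _report(self, ev):
        self._events.append(ev)
        if len(self._events) >= self._flush_threshold:
            self._write()

    def _write(self):
        if not self._events:
            return
        self.written.append(list(self._events))
        self._events = []

    def flush(self):
        self._write()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # don't flush in case an exception was raised
        if not exc_type:
            self.flush()

with BufferedReporter(flush_threshold=3) as r:
    for i in range(4):
        r._report(i)

# 3 events flushed at the threshold, the 4th flushed on context exit
assert r.written == [[0, 1, 2], [3]]
```

Skipping the flush when an exception escapes the context keeps partial, possibly inconsistent reports out of the backend, matching Reporter.__exit__ above.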
408
trains/backend_interface/model.py
Normal file
@@ -0,0 +1,408 @@
from collections import namedtuple
from functools import partial

import six
from pathlib2 import Path

from ..backend_api.services import models
from .base import IdObjectBase
from .util import make_message
from ..storage import StorageHelper
from ..utilities.async_manager import AsyncManagerMixin

ModelPackage = namedtuple('ModelPackage', 'weights design')


class ModelDoesNotExistError(Exception):
    pass


class _StorageUriMixin(object):
    @property
    def upload_storage_uri(self):
        """ A URI into which models are uploaded """
        return self._upload_storage_uri

    @upload_storage_uri.setter
    def upload_storage_uri(self, value):
        self._upload_storage_uri = value.rstrip('/') if value else None


class DummyModel(models.Model, _StorageUriMixin):
    def __init__(self, upload_storage_uri=None, *args, **kwargs):
        super(DummyModel, self).__init__(*args, **kwargs)
        self.upload_storage_uri = upload_storage_uri

    def update(self, **kwargs):
        for k, v in kwargs.items():
            setattr(self, k, v)


class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
    """ Manager for backend model objects """

    _EMPTY_MODEL_ID = 'empty'

    @property
    def model_id(self):
        return self.id

    @property
    def storage(self):
        return StorageHelper.get(self.upload_storage_uri)

    def __init__(self, upload_storage_uri, cache_dir, model_id=None,
                 upload_storage_suffix='models', session=None, log=None):
        super(Model, self).__init__(id=model_id, session=session, log=log)
        self._upload_storage_suffix = upload_storage_suffix
        if model_id == self._EMPTY_MODEL_ID:
            # Set an empty data object
            self._data = models.Model()
        else:
            self._data = None
        self._cache_dir = cache_dir
        self.upload_storage_uri = upload_storage_uri
    def publish(self):
        self.send(models.SetReadyRequest(model=self.id, publish_task=False))
        self.reload()

    def _reload(self):
        """ Reload the model object """
        if self.id == self._EMPTY_MODEL_ID:
            return
        res = self.send(models.GetByIdRequest(model=self.id))
        return res.response.model

    def _upload_model(self, model_file, async_enable=False, target_filename=None, cb=None):
        if not self.upload_storage_uri:
            raise ValueError('Model has no storage URI defined (nowhere to upload to)')
        helper = self.storage
        target_filename = target_filename or Path(model_file).name
        dest_path = '/'.join((self.upload_storage_uri, self._upload_storage_suffix or '.', target_filename))
        result = helper.upload(
            src_path=model_file,
            dest_path=dest_path,
            async_enable=async_enable,
            cb=partial(self._upload_callback, cb=cb),
        )
        if async_enable:
            def msg(num_results):
                self.log.info("Waiting for previous model to upload (%d pending, %s)" % (num_results, dest_path))

            self._add_async_result(result, wait_on_max_results=2, wait_cb=msg)
        return dest_path

    def _upload_callback(self, res, cb=None):
        if res is None:
            self.log.debug('Starting model upload')
        elif res is False:
            self.log.info('Failed model upload')
        else:
            self.log.info('Completed model upload to %s' % res)
        if cb:
            cb(res)
    @staticmethod
    def _wrap_design(design):
        """
        Wrap design text with a dictionary.

        In the backend, the design is a dictionary with a 'design' key in it.
        For the client, it is a text. This function wraps a design string with
        the proper dictionary.

        :param design: If it is a dictionary, it must have a 'design' key in it.
            In that case, return design as-is.
            If it is a string, return the dictionary {'design': design}.
            If it is None (or any False value), return the dictionary {'design': ''}

        :return: A proper design dictionary according to design parameter.
        """
        if isinstance(design, dict):
            if 'design' not in design:
                raise ValueError('design dictionary must have \'design\' key in it')

            return design

        return {'design': design if design else ''}

    @staticmethod
    def _unwrap_design(design):
        """
        Unwrap design text from a dictionary.

        In the backend, the design is a dictionary with a 'design' key in it.
        For the client, it is a text. This function unwraps a design string from
        the dictionary.

        :param design: If it is a dictionary with a 'design' key in it, return
            design['design'].
            If it is a dictionary without 'design' key, return the first value
            in its values list.
            If it is an empty dictionary, None, or any other False value,
            return an empty string.
            If it is a string, return design as-is.

        :return: The design string according to design parameter.
        """
        if not design:
            return ''

        if isinstance(design, six.string_types):
            return design

        if isinstance(design, dict):
            if 'design' in design:
                return design['design']

            return list(design.values())[0]

        raise ValueError('design must be a string or a dictionary with at least one value')
    def update(self, model_file=None, design=None, labels=None, name=None, comment=None, tags=None,
               task_id=None, project_id=None, parent_id=None, uri=None, framework=None,
               upload_storage_uri=None, target_filename=None, iteration=None):
        """ Update model weights file and various model properties """

        if self.id is None:
            if upload_storage_uri:
                self.upload_storage_uri = upload_storage_uri
            self._create_empty_model(self.upload_storage_uri)

        # upload model file if needed and get uri
        uri = uri or (self._upload_model(model_file, target_filename=target_filename) if model_file else self.data.uri)
        # update fields
        design = self._wrap_design(design) if design else self.data.design
        name = name or self.data.name
        comment = comment or self.data.comment
        tags = tags or self.data.tags
        labels = labels or self.data.labels
        task = task_id or self.data.task
        project = project_id or self.data.project
        parent = parent_id or self.data.parent

        self.send(models.EditRequest(
            model=self.id,
            uri=uri,
            name=name,
            comment=comment,
            tags=tags,
            labels=labels,
            design=design,
            task=task,
            project=project,
            parent=parent,
            framework=framework or self.data.framework,
            iteration=iteration,
        ))
        self.reload()
    def update_and_upload(self, model_file, design=None, labels=None, name=None, comment=None,
                          tags=None, task_id=None, project_id=None, parent_id=None, framework=None, async_enable=False,
                          target_filename=None, cb=None, iteration=None):
        """ Upload the model file and update the model's properties """
        if async_enable:
            def callback(uploaded_uri):
                if uploaded_uri is None:
                    return

                # If not successful, mark model as failed_uploading
                if uploaded_uri is False:
                    uploaded_uri = '{}/failed_uploading'.format(self._upload_storage_uri)

                self.update(
                    uri=uploaded_uri,
                    task_id=task_id,
                    name=name,
                    comment=comment,
                    tags=tags,
                    design=design,
                    labels=labels,
                    project_id=project_id,
                    parent_id=parent_id,
                    framework=framework,
                    iteration=iteration,
                )

                if cb:
                    cb(model_file)

            uri = self._upload_model(model_file, async_enable=async_enable, target_filename=target_filename, cb=callback)
            return uri
        else:
            uri = self._upload_model(model_file, async_enable=async_enable, target_filename=target_filename)
            self.update(
                uri=uri,
                task_id=task_id,
                name=name,
                comment=comment,
                tags=tags,
                design=design,
                labels=labels,
                project_id=project_id,
                parent_id=parent_id,
                framework=framework,
            )

            return uri
    def _complete_update_for_task(self, uri, task_id=None, name=None, comment=None, tags=None, override_model_id=None,
                                  cb=None):
        if self._data:
            name = name or self.data.name
            comment = comment or self.data.comment
            tags = tags or self.data.tags
            uri = (uri or self.data.uri) if not override_model_id else None

        res = self.send(
            models.UpdateForTaskRequest(task=task_id, uri=uri, name=name, comment=comment, tags=tags,
                                        override_model_id=override_model_id))
        if self.id is None:
            # update the model id. in case it was just created, this will trigger a reload of the model object
            self.id = res.response.id
        else:
            self.reload()
        try:
            if cb:
                cb(uri)
        except Exception as ex:
            self.log.warning('Failed calling callback on complete_update_for_task: %s' % str(ex))

    def update_for_task_and_upload(
            self, model_file, task_id, name=None, comment=None, tags=None, override_model_id=None, target_filename=None,
            async_enable=False, cb=None, iteration=None):
        """ Update the given model for a given task ID """
        if async_enable:
            callback = partial(
                self._complete_update_for_task, task_id=task_id, name=name, comment=comment, tags=tags,
                override_model_id=override_model_id, cb=cb)
            uri = self._upload_model(model_file, target_filename=target_filename, async_enable=async_enable, cb=callback)
            return uri
        else:
            uri = self._upload_model(model_file, target_filename=target_filename, async_enable=async_enable)
            self._complete_update_for_task(uri, task_id, name, comment, tags, override_model_id)
            _ = self.send(models.UpdateForTaskRequest(task=task_id, uri=uri, name=name, comment=comment, tags=tags,
                                                      override_model_id=override_model_id, iteration=iteration))
            return uri

    def update_for_task(self, task_id, uri=None, name=None, comment=None, tags=None, override_model_id=None):
        self._complete_update_for_task(uri, task_id, name, comment, tags, override_model_id)
    @property
    def model_design(self):
        """ Get the model design. For now, this is stored as a single key in the design dict. """
        try:
            return self._unwrap_design(self.data.design)
        except ValueError:
            # no design is yet specified
            return None

    @property
    def labels(self):
        try:
            return self.data.labels
        except ValueError:
            # no labels are yet specified
            return None

    @property
    def name(self):
        try:
            return self.data.name
        except ValueError:
            # no name is yet specified
            return None

    @property
    def comment(self):
        try:
            return self.data.comment
        except ValueError:
            # no comment is yet specified
            return None

    @property
    def tags(self):
        return self.data.tags

    @property
    def locked(self):
        if self.id is None:
            return False
        return bool(self.data.ready)

    def download_model_weights(self):
        """ Download the model weights into a local file in our cache """
        uri = self.data.uri
        helper = StorageHelper.get(uri, logger=self._log, verbose=True)
        return helper.download_to_file(uri, force_cache=True)

    @property
    def cache_dir(self):
        return self._cache_dir
    def save_model_design_file(self):
        """ Download model description file into a local file in our cache_dir """
        design = self.model_design
        filename = self.data.name + '.txt'
        p = Path(self.cache_dir) / filename
        # we always write the original model design to file, to prevent any mishaps
        # if p.is_file():
        #     return str(p)
        p.parent.mkdir(parents=True, exist_ok=True)
        p.write_text(six.text_type(design))
        return str(p)

    def get_model_package(self):
        """ Get a named tuple containing the model's weights and design """
        return ModelPackage(weights=self.download_model_weights(), design=self.save_model_design_file())

    def get_model_design(self):
        """ Get model description (text) """
        return self.model_design

    @classmethod
    def get_all(cls, session, log=None, **kwargs):
        req = models.GetAllRequest(**kwargs)
        res = cls._send(session=session, req=req, log=log)
        return res
    def clone(self, name, comment=None, child=True, tags=None, task=None, ready=True):
        """
        Clone this model into a new model.
        :param name: Name for the new model
        :param comment: Optional comment for the new model
        :param child: Should the new model be a child of this model? (default True)
        :param tags: Optional tags for the new model (defaults to this model's tags)
        :param task: Optional task ID to associate with the new model
        :param ready: Should the new model be created in the ready (published) state? (default True)
        :return: The new model's ID
        """
        data = self.data
        assert isinstance(data, models.Model)
        parent = self.id if child else None
        req = models.CreateRequest(
            uri=data.uri,
            name=name,
            labels=data.labels,
            comment=comment or data.comment,
            tags=tags or data.tags,
            framework=data.framework,
            design=data.design,
            ready=ready,
            project=data.project,
            parent=parent,
            task=task,
        )
        res = self.send(req)
        return res.response.id

    def _create_empty_model(self, upload_storage_uri=None):
        upload_storage_uri = upload_storage_uri or self.upload_storage_uri
        name = make_message('Anonymous model %(time)s')
        uri = '{}/uploading_file'.format(upload_storage_uri or 'file://')
        req = models.CreateRequest(uri=uri, name=name, labels={})
        res = self.send(req)
        if not res:
            return False
        self.id = res.response.id
        return True
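The design wrap/unwrap helpers above are pure functions, so their round-trip behavior can be checked in isolation. The sketch below copies their logic into standalone functions (using plain `str` instead of six.string_types, since it is Python-3 only; the function names are local stand-ins for the static methods):

```python
def wrap_design(design):
    # backend stores the design as {'design': <text>}
    if isinstance(design, dict):
        if 'design' not in design:
            raise ValueError("design dictionary must have 'design' key in it")
        return design
    return {'design': design if design else ''}

def unwrap_design(design):
    # accept a string, a {'design': ...} dict, or any single-value dict
    if not design:
        return ''
    if isinstance(design, str):
        return design
    if isinstance(design, dict):
        if 'design' in design:
            return design['design']
        return list(design.values())[0]
    raise ValueError('design must be a string or a dictionary with at least one value')

assert wrap_design('conv1') == {'design': 'conv1'}
assert wrap_design(None) == {'design': ''}
assert unwrap_design(wrap_design('conv1')) == 'conv1'
assert unwrap_design({'prototxt': 'layer {}'}) == 'layer {}'
```

The fallback to the first dict value in unwrap_design is what lets framework-specific design dicts (e.g. a prototxt entry) still yield a usable text for the client.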
28
trains/backend_interface/session.py
Normal file
@@ -0,0 +1,28 @@
from abc import ABCMeta, abstractmethod
|
||||
|
||||
import six
|
||||
|
||||
|
||||
class SendError(Exception):
|
||||
""" A session send() error class """
|
||||
@property
|
||||
def result(self):
|
||||
return self._result
|
||||
|
||||
def __init__(self, result, *args, **kwargs):
|
||||
super(SendError, self).__init__(*args, **kwargs)
|
||||
self._result = result
|
||||
|
||||
|
||||
@six.add_metaclass(ABCMeta)
|
||||
class SessionInterface(object):
|
||||
""" Session wrapper interface providing a session property and a send convenience method """
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def session(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def send(self, req, ignore_errors=False, raise_on_errors=True, async_enable=False):
|
||||
pass
|
||||
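A hedged sketch of how a concrete class might satisfy this interface. It restates the interface with Python 3's native `abc.ABC` instead of `six` for brevity; `EchoSession` and its canned response are invented for illustration, not part of the library:

```python
from abc import ABC, abstractmethod


class SessionInterface(ABC):
    """Same shape as the interface above, Python 3 ABC syntax."""

    @property
    @abstractmethod
    def session(self):
        pass

    @abstractmethod
    def send(self, req, ignore_errors=False, raise_on_errors=True, async_enable=False):
        pass


class EchoSession(SessionInterface):
    """Hypothetical concrete session that just echoes requests back."""

    @property
    def session(self):
        return self

    def send(self, req, ignore_errors=False, raise_on_errors=True, async_enable=False):
        # a real session would serialize req and POST it to the backend
        return {"request": req, "ok": True}


s = EchoSession()
result = s.send({"endpoint": "tasks.ping"})
```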
43  trains/backend_interface/setupuploadmixin.py  Normal file
@@ -0,0 +1,43 @@
from abc import abstractproperty

from ..backend_config.bucket_config import S3BucketConfig
from ..storage import StorageHelper


class SetupUploadMixin(object):
    log = abstractproperty()
    storage_uri = abstractproperty()

    def setup_upload(
            self, bucket_name, host=None, access_key=None, secret_key=None, region=None, multipart=True, https=True):
        """
        Set up upload options (currently only S3 is supported)
        :param bucket_name: AWS bucket name
        :type bucket_name: str
        :param host: Hostname. Only required when using a non-AWS S3 solution such as a local Minio server
        :type host: str
        :param access_key: AWS access key. If not provided, we'll attempt to obtain the key from the
            configuration file (bucket-specific, then global)
        :type access_key: str
        :param secret_key: AWS secret key. If not provided, we'll attempt to obtain the secret from the
            configuration file (bucket-specific, then global)
        :type secret_key: str
        :param multipart: Server supports multipart. Only required when using a non-AWS S3 solution that doesn't
            support multipart.
        :type multipart: bool
        :param https: Server supports HTTPS. Only required when using a non-AWS S3 solution that only supports HTTPS.
        :type https: bool
        :param region: Bucket region. Required if the bucket doesn't reside in the default region (us-east-1)
        :type region: str
        """
        self._bucket_config = S3BucketConfig(
            bucket=bucket_name,
            host=host,
            key=access_key,
            secret=secret_key,
            multipart=multipart,
            secure=https,
            region=region
        )
        self.storage_uri = ('s3://%(host)s/%(bucket_name)s' if host else 's3://%(bucket_name)s') % locals()
        StorageHelper.add_configuration(self._bucket_config, log=self.log)
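The `storage_uri` line above builds the destination URI with a %-format over `locals()`: custom endpoints (e.g. Minio) get the host prepended, plain AWS buckets do not. The same construction written as a hypothetical standalone helper:

```python
def build_s3_uri(bucket_name, host=None):
    """Return s3://host/bucket when a custom endpoint is set, else s3://bucket."""
    # %(name)s keys are looked up in the function's local variables
    return ('s3://%(host)s/%(bucket_name)s' if host else 's3://%(bucket_name)s') % locals()


aws_uri = build_s3_uri('my-bucket')
minio_uri = build_s3_uri('my-bucket', host='minio.local:9000')
```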
1  trains/backend_interface/task/__init__.py  Normal file
@@ -0,0 +1 @@
from .task import Task, TaskEntry, TaskStatusEnum
85  trains/backend_interface/task/access.py  Normal file
@@ -0,0 +1,85 @@
import itertools
import operator

from abc import abstractproperty

import six
from pathlib2 import Path


class AccessMixin(object):
    """ A mixin providing task fields access functionality """
    session = abstractproperty()
    data = abstractproperty()
    cache_dir = abstractproperty()
    log = abstractproperty()

    def _get_task_property(self, prop_path, raise_on_error=True, log_on_error=True, default=None):
        obj = self.data
        props = prop_path.split('.')
        for i in range(len(props)):
            obj = getattr(obj, props[i], None)
            if obj is None:
                msg = 'Task has no %s section defined' % '.'.join(props[:i + 1])
                if log_on_error:
                    self.log.info(msg)
                if raise_on_error:
                    raise ValueError(msg)
                return default
        return obj

    def _set_task_property(self, prop_path, value, raise_on_error=True, log_on_error=True):
        props = prop_path.split('.')
        if len(props) > 1:
            obj = self._get_task_property('.'.join(props[:-1]), raise_on_error=raise_on_error,
                                          log_on_error=log_on_error)
        else:
            obj = self.data
        setattr(obj, props[-1], value)

    def save_exec_model_design_file(self, filename='model_design.txt', use_cache=False):
        """ Save execution model design to file """
        p = Path(self.cache_dir) / filename
        if use_cache and p.is_file():
            return str(p)
        desc = self._get_task_property('execution.model_desc')
        try:
            design = six.next(six.itervalues(desc))
        except StopIteration:
            design = None
        if not design:
            raise ValueError('Task has no design in execution.model_desc')
        p.parent.mkdir(parents=True, exist_ok=True)
        p.write_text('%s' % design)
        return str(p)

    def get_parameters(self):
        return self._get_task_property('execution.parameters')

    def get_label_num_description(self):
        """ Get a dict of label number to a string representing all labels associated with this number on the
            model labels
        """
        model_labels = self._get_task_property('execution.model_labels')
        label_getter = operator.itemgetter(0)
        num_getter = operator.itemgetter(1)
        # materialize each group so it can be iterated more than once
        groups = [(key, list(group)) for key, group in
                  itertools.groupby(sorted(model_labels.items(), key=num_getter), key=num_getter)]
        if any(len(set(label_getter(x) for x in group)) > 1 for _, group in groups):
            raise ValueError("Multiple labels mapped to same model index not supported")
        return {key: ','.join(label_getter(x) for x in group) for key, group in groups}

    def get_output_destination(self, extra_path=None, **kwargs):
        """ Get the task's output destination, with an optional suffix """
        return self._get_task_property('output.destination', **kwargs)

    def get_num_of_classes(self):
        """ number of classes based on the task's labels """
        model_labels = self.data.execution.model_labels
        expected_num_of_classes = 0
        for labels, index in model_labels.items():
            expected_num_of_classes += 1 if int(index) > 0 else 0
        num_of_classes = int(max(model_labels.values()))
        if num_of_classes != expected_num_of_classes:
            self.log.warn('The highest label index is %d, while there are %d non-bg labels' %
                          (num_of_classes, expected_num_of_classes))
        return num_of_classes + 1  # +1 is meant for bg!
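`get_label_num_description` joins all label names that map to the same numeric index. The core grouping step in isolation (the sample labels are made up): sort the `(label, index)` pairs by index so `itertools.groupby` sees equal indices adjacently, then join the label names per group.

```python
import itertools
import operator

model_labels = {'cat': 1, 'kitten': 1, 'dog': 2}

label_getter = operator.itemgetter(0)
num_getter = operator.itemgetter(1)

# groupby only merges adjacent items, hence the sort by index first
grouped = {
    key: ','.join(label_getter(item) for item in group)
    for key, group in itertools.groupby(
        sorted(model_labels.items(), key=num_getter), key=num_getter)
}
```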
314  trains/backend_interface/task/args.py  Normal file
@@ -0,0 +1,314 @@
import yaml

from six import PY2
from argparse import _StoreAction, ArgumentError, _StoreConstAction, _SubParsersAction, SUPPRESS
from copy import copy

from ...utilities.args import call_original_argparser


class _Arguments(object):
    _prefix_sep = '/'
    # TODO: separate dict and argparse after we add UI support
    _prefix_dict = 'dict' + _prefix_sep
    _prefix_args = 'argparse' + _prefix_sep
    _prefix_tf_defines = 'TF_DEFINE' + _prefix_sep

    class _ProxyDictWrite(dict):
        """ Dictionary wrapper that updates an arguments instance on any item set in the dictionary """
        def __init__(self, arguments, *args, **kwargs):
            super(_Arguments._ProxyDictWrite, self).__init__(*args, **kwargs)
            self._arguments = arguments

        def __setitem__(self, key, value):
            super(_Arguments._ProxyDictWrite, self).__setitem__(key, value)
            if self._arguments:
                self._arguments.copy_from_dict(self)

    class _ProxyDictReadOnly(dict):
        """ Dictionary wrapper that prevents modifications to the dictionary """
        def __init__(self, *args, **kwargs):
            super(_Arguments._ProxyDictReadOnly, self).__init__(*args, **kwargs)

        def __setitem__(self, key, value):
            pass

    def __init__(self, task):
        super(_Arguments, self).__init__()
        self._task = task

    def set_defaults(self, *dicts, **kwargs):
        self._task.set_parameters(*dicts, **kwargs)

    def add_argument(self, option_strings, type=None, default=None, help=None):
        if not option_strings:
            raise Exception('Expected at least one argument name (option string)')
        name = option_strings[0].strip('- \t') if isinstance(option_strings, list) else option_strings.strip('- \t')
        # TODO: add argparse prefix
        # name = self._prefix_args + name
        self._task.set_parameter(name=name, value=default, description=help)

    def connect(self, parser):
        self._task.connect_argparse(parser)

    @classmethod
    def _add_to_defaults(cls, a_parser, defaults, a_args=None, a_namespace=None, a_parsed_args=None):
        actions = [
            a for a in a_parser._actions
            if isinstance(a, _StoreAction) or isinstance(a, _StoreConstAction)
        ]
        args_dict = {}
        try:
            if isinstance(a_parsed_args, dict):
                args_dict = a_parsed_args
            else:
                if a_parsed_args:
                    args_dict = a_parsed_args.__dict__
                else:
                    args_dict = call_original_argparser(a_parser, args=a_args, namespace=a_namespace).__dict__
            defaults_ = {
                a.dest: args_dict.get(a.dest) if (args_dict.get(a.dest) is not None) else ''
                for a in actions
            }
        except Exception:
            # don't crash us if we failed parsing the inputs
            defaults_ = {
                a.dest: a.default if a.default is not None else ''
                for a in actions
            }

        full_args_dict = copy(defaults)
        full_args_dict.update(args_dict)
        defaults.update(defaults_)

        # deal with sub parsers
        sub_parsers = [
            a for a in a_parser._actions
            if isinstance(a, _SubParsersAction)
        ]
        for sub_parser in sub_parsers:
            if sub_parser.dest and sub_parser.dest != SUPPRESS:
                defaults[sub_parser.dest] = full_args_dict.get(sub_parser.dest)
            for choice in sub_parser.choices.values():
                # recursively parse
                defaults = cls._add_to_defaults(
                    a_parser=choice,
                    defaults=defaults,
                    a_parsed_args=a_parsed_args or full_args_dict
                )

        return defaults

    def copy_defaults_from_argparse(self, parser, args=None, namespace=None, parsed_args=None):
        task_defaults = {}
        self._add_to_defaults(parser, task_defaults, args, namespace, parsed_args)

        # Make sure we didn't miss anything
        if parsed_args:
            for k, v in parsed_args.__dict__.items():
                if k not in task_defaults:
                    if v is None:
                        task_defaults[k] = ''
                    elif type(v) in (str, int, float, bool, list):
                        task_defaults[k] = v

        # Verify arguments (iterate over a copy, since we may delete entries)
        for k, v in list(task_defaults.items()):
            try:
                if type(v) is list:
                    task_defaults[k] = '[' + ', '.join(map("{0}".format, v)) + ']'
                elif type(v) not in (str, int, float, bool):
                    task_defaults[k] = str(v)
            except Exception:
                del task_defaults[k]
        # Add prefix, TODO: add argparse prefix
        # task_defaults = dict([(self._prefix_args + k, v) for k, v in task_defaults.items()])
        task_defaults = dict([(k, v) for k, v in task_defaults.items()])
        # Store to task
        self._task.update_parameters(task_defaults)

    @classmethod
    def _find_parser_action(cls, a_parser, name):
        # find by name
        _actions = [(a_parser, a) for a in a_parser._actions if a.dest == name]
        if _actions:
            return _actions
        # iterate over subparsers
        _actions = []
        sub_parsers = [a for a in a_parser._actions if isinstance(a, _SubParsersAction)]
        for sub_parser in sub_parsers:
            for choice in sub_parser.choices.values():
                # recursively parse
                _action = cls._find_parser_action(choice, name)
                if _action:
                    _actions.extend(_action)
        return _actions

    def copy_to_parser(self, parser, parsed_args):
        # todo: change to argparse prefix only
        # task_arguments = dict([(k[len(self._prefix_args):], v) for k, v in self._task.get_parameters().items()
        #                        if k.startswith(self._prefix_args)])
        task_arguments = dict([(k, v) for k, v in self._task.get_parameters().items()
                               if not k.startswith(self._prefix_tf_defines)])
        for k, v in task_arguments.items():
            # if we have a StoreTrueAction and the value is either False or Empty or 0 change the default to False
            # with the rest we have to make sure the type is correct
            matched_actions = self._find_parser_action(parser, k)
            for parent_parser, current_action in matched_actions:
                if current_action and isinstance(current_action, _StoreConstAction):
                    # make the default value boolean
                    # first check if False value
                    const_value = current_action.const if current_action.const is not None else (
                        current_action.default if current_action.default is not None else True)
                    const_type = type(const_value)
                    strip_v = str(v).lower().strip()
                    if const_type == bool:
                        if strip_v == 'false' or not strip_v:
                            const_value = False
                        elif strip_v == 'true':
                            const_value = True
                        else:
                            # first try to cast to integer
                            try:
                                const_value = int(strip_v)
                            except ValueError:
                                pass
                    else:
                        const_value = strip_v
                    # then cast to const type (might be boolean)
                    try:
                        const_value = const_type(const_value)
                        current_action.const = const_value
                    except ValueError:
                        pass
                    task_arguments[k] = const_value
                elif current_action and current_action.nargs == '+':
                    try:
                        v = yaml.load(v.strip())
                        if current_action.type:
                            v = [current_action.type(a) for a in v]
                        elif current_action.default:
                            v_type = type(current_action.default[0])
                            v = [v_type(a) for a in v]
                        task_arguments[k] = v
                    except Exception:
                        pass
                elif current_action and not current_action.type:
                    # cast manually if there is no type
                    var_type = type(current_action.default)
                    # if we have an int, we should cast to float, because it is more generic
                    if var_type == int:
                        var_type = float
                    elif var_type == type(None):
                        var_type = str
                    # now we should try and cast the value if we can
                    try:
                        v = var_type(v)
                        task_arguments[k] = v
                    except Exception:
                        pass
                # add as default
                try:
                    if current_action and isinstance(current_action, _SubParsersAction):
                        current_action.default = v
                        current_action.required = False
                    elif current_action and isinstance(current_action, _StoreAction):
                        current_action.default = v
                        current_action.required = False
                        # python2 doesn't support defaults for positional arguments, unless used with nargs=?
                        if PY2 and not current_action.nargs:
                            current_action.nargs = '?'
                    else:
                        parent_parser.add_argument(
                            '--%s' % k,
                            default=v,
                            type=type(v),
                            required=False,
                            help='Task parameter %s (default %s)' % (k, v),
                        )
                except ArgumentError:
                    pass
                except Exception:
                    pass
        # if we already have an instance of parsed args, we should update its values
        if parsed_args:
            for k, v in task_arguments.items():
                setattr(parsed_args, k, v)
        parser.set_defaults(**task_arguments)

    def copy_from_dict(self, dictionary, prefix=None):
        # TODO: add dict prefix
        prefix = prefix or ''  # self._prefix_dict
        if prefix:
            prefix_dictionary = dict([(prefix + k, v) for k, v in dictionary.items()])
            cur_params = dict([(k, v) for k, v in self._task.get_parameters().items() if not k.startswith(prefix)])
            cur_params.update(prefix_dictionary)
            self._task.set_parameters(cur_params)
        else:
            self._task.update_parameters(dictionary)
        if not isinstance(dictionary, self._ProxyDictWrite):
            return self._ProxyDictWrite(self, **dictionary)
        return dictionary

    def copy_to_dict(self, dictionary, prefix=None):
        # iterate over keys and merge values according to parameter type in dictionary
        # TODO: add dict prefix
        prefix = prefix or ''  # self._prefix_dict
        if prefix:
            parameters = dict([(k[len(prefix):], v) for k, v in self._task.get_parameters().items()
                               if k.startswith(prefix)])
        else:
            parameters = dict([(k, v) for k, v in self._task.get_parameters().items()
                               if not k.startswith(self._prefix_tf_defines)])

        for k, v in dictionary.items():
            param = parameters.get(k, None)
            if param is None:
                continue
            v_type = type(v)
            # assume more general purpose type int -> float
            if v_type == int:
                v_type = float
            elif v_type == bool:
                # cast based on string or int
                try:
                    param = bool(float(param))
                except ValueError:
                    try:
                        param = str(param).lower().strip() == 'true'
                    except ValueError:
                        self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' %
                                               (str(k), str(param), str(k), str(v)))
                        continue
            elif v_type == list:
                try:
                    p = str(param).strip()
                    param = yaml.load(p)
                except Exception:
                    self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' %
                                           (str(k), str(param), str(k), str(v)))
                    continue
            elif v_type == dict:
                try:
                    p = str(param).strip()
                    param = yaml.load(p)
                except Exception:
                    self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' %
                                           (str(k), str(param), str(k), str(v)))
            elif v_type == type(None):
                v_type = str

            try:
                dictionary[k] = v_type(param)
            except ValueError:
                self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' %
                                       (str(k), str(param), str(k), str(v)))
                continue
        # add missing parameters to dictionary
        for k, v in parameters.items():
            if k not in dictionary:
                dictionary[k] = v

        if not isinstance(dictionary, self._ProxyDictReadOnly):
            return self._ProxyDictReadOnly(**dictionary)
        return dictionary
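`copy_to_dict` above casts string-typed task parameters back to the type of each local default: an `int` default is widened to `float` (more general), and booleans are parsed either from a numeric string or from `'true'`/`'false'`. A standalone sketch of that casting rule, with the same precedence:

```python
def cast_param(stored, default):
    """Cast a stored (string) parameter back to the type of the local default."""
    v_type = type(default)
    if v_type is int:
        v_type = float  # int is widened to the more general float
    if v_type is bool:
        try:
            return bool(float(stored))  # '1' / '0.0' style values
        except ValueError:
            return str(stored).lower().strip() == 'true'
    return v_type(stored)


a = cast_param('0.1', 3)        # int default -> float result
b = cast_param('true', False)   # bool default -> parsed boolean
c = cast_param('7', 'x')        # str default -> kept as string
```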
48  trains/backend_interface/task/development/stop_signal.py  Normal file
@@ -0,0 +1,48 @@
from ....config import config
from ....backend_interface import Task, TaskStatusEnum


class TaskStopReason(object):
    stopped = "stopped"
    reset = "reset"
    status_changed = "status_changed"


class TaskStopSignal(object):
    enabled = bool(config.get('development.support_stopping', False))

    _number_of_consecutive_reset_tests = 4

    _unexpected_statuses = (
        TaskStatusEnum.closed,
        TaskStatusEnum.stopped,
        TaskStatusEnum.failed,
        TaskStatusEnum.published,
    )

    def __init__(self, task):
        assert isinstance(task, Task)
        self.task = task
        self._task_reset_state_counter = 0

    def test(self):
        status = self.task.status
        message = self.task.data.status_message

        if status == TaskStatusEnum.in_progress and "stopping" in message:
            return TaskStopReason.stopped

        if status in self._unexpected_statuses and "worker" not in message:
            return TaskStopReason.status_changed

        if status == TaskStatusEnum.created:
            self._task_reset_state_counter += 1

            if self._task_reset_state_counter >= self._number_of_consecutive_reset_tests:
                return TaskStopReason.reset

            self.task.get_logger().warning(
                "Task {} was reset! if state is consistent we shall terminate.".format(self.task.id),
            )
        else:
            self._task_reset_state_counter = 0
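Note how `test()` only reports a reset after four consecutive polls observe the `created` status; a single transient observation merely increments the counter, and any other status clears it. That debounce logic, isolated into a sketch (status strings are stand-ins for the real enum):

```python
class ResetDebouncer:
    """Report 'reset' only after N consecutive 'created' observations."""
    threshold = 4

    def __init__(self):
        self.counter = 0

    def observe(self, status):
        if status == 'created':
            self.counter += 1
            if self.counter >= self.threshold:
                return 'reset'
        else:
            self.counter = 0  # any other status clears the streak
        return None


d = ResetDebouncer()
results = [d.observe(s) for s in ['created', 'created', 'in_progress',
                                  'created', 'created', 'created', 'created']]
# only the fourth consecutive 'created' triggers the reset
```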
26  trains/backend_interface/task/development/worker.py  Normal file
@@ -0,0 +1,26 @@
from socket import gethostname

import attr

from ....config import config, running_remotely, dev_worker_name


@attr.s
class DevWorker(object):
    prefix = attr.ib(type=str, default="MANUAL:")

    report_period = float(config.get('development.worker.report_period_sec', 30.))
    report_stdout = bool(config.get('development.worker.log_stdout', True))

    @classmethod
    def is_enabled(cls, model_updated=False):
        return False

    def status_report(self, timestamp=None):
        return True

    def register(self):
        return True

    def unregister(self):
        return True
110  trains/backend_interface/task/log.py  Normal file
@@ -0,0 +1,110 @@
import time
from logging import LogRecord, getLogger, basicConfig
from logging.handlers import BufferingHandler

from ...backend_api.services import events
from ...config import config

buffer_capacity = config.get('log.task_log_buffer_capacity', 100)


class TaskHandler(BufferingHandler):
    __flush_max_history_seconds = 30.
    __once = False

    @property
    def task_id(self):
        return self._task_id

    @task_id.setter
    def task_id(self, value):
        self._task_id = value

    def __init__(self, session, task_id, capacity=buffer_capacity):
        super(TaskHandler, self).__init__(capacity)
        self.task_id = task_id
        self.session = session
        self.last_timestamp = 0
        self.counter = 1
        self._last_event = None

    def shouldFlush(self, record):
        """
        Should the handler flush its buffer?

        Returns true if the buffer is up to capacity. This method can be
        overridden to implement custom flushing strategies.
        """

        # Notice! protect against infinite loops, i.e. flush while sending previous records
        # if self.lock._is_owned():
        #     return False

        # if we need to add handlers to the base_logger,
        # it will not automatically create stream one when first used, so we must manually configure it.
        if not TaskHandler.__once:
            base_logger = getLogger()
            if len(base_logger.handlers) == 1 and isinstance(base_logger.handlers[0], TaskHandler):
                if record.name != 'console' and not record.name.startswith('trains.'):
                    base_logger.removeHandler(self)
                    basicConfig()
                    base_logger.addHandler(self)
                    TaskHandler.__once = True
            else:
                TaskHandler.__once = True

        # if we passed the max buffer
        if len(self.buffer) >= self.capacity:
            return True

        # if the first entry in the log was too long ago.
        if len(self.buffer) and (time.time() - self.buffer[0].created) > self.__flush_max_history_seconds:
            return True

        return False

    def _record_to_event(self, record):
        # type: (LogRecord) -> events.TaskLogEvent
        timestamp = int(record.created * 1000)
        if timestamp == self.last_timestamp:
            timestamp += self.counter
            self.counter += 1
        else:
            self.last_timestamp = timestamp
            self.counter = 1

        # unite all records in a single second
        if self._last_event and timestamp - self._last_event.timestamp < 1000 and \
                record.levelname.lower() == str(self._last_event.level):
            # ignore backspaces (they are often used)
            self._last_event.msg += '\n' + record.getMessage().replace('\x08', '')
            return None

        self._last_event = events.TaskLogEvent(
            task=self.task_id,
            timestamp=timestamp,
            level=record.levelname.lower(),
            worker=self.session.worker,
            msg=record.getMessage().replace('\x08', '')  # ignore backspaces (they are often used)
        )
        return self._last_event

    def flush(self):
        if not self.buffer:
            return
        self.acquire()
        buffer = self.buffer
        try:
            if not buffer:
                return
            self.buffer = []
            record_events = [self._record_to_event(record) for record in buffer]
            self._last_event = None
            requests = [events.AddRequest(e) for e in record_events if e]
            res = self.session.send(events.AddBatchRequest(requests=requests))
            if not res.ok():
                print("Failed logging task to backend ({:d} lines, {})".format(len(buffer), str(res.meta)))
        except Exception:
            print("Failed logging task to backend ({:d} lines)".format(len(buffer)))
        finally:
            self.release()
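`_record_to_event` keeps millisecond timestamps strictly increasing: records created within the same millisecond are bumped by a running counter (1, 2, ...) so the backend never sees duplicates. That bookkeeping, extracted into a small sketch that mirrors the logic above (including the detail that `last_timestamp` is not advanced on a bump):

```python
class TimestampDeduper:
    """Make repeated millisecond timestamps strictly increasing."""

    def __init__(self):
        self.last_timestamp = 0
        self.counter = 1

    def next(self, created_sec):
        timestamp = int(created_sec * 1000)
        if timestamp == self.last_timestamp:
            timestamp += self.counter  # bump repeated stamps by 1, 2, ...
            self.counter += 1
        else:
            self.last_timestamp = timestamp
            self.counter = 1
        return timestamp


d = TimestampDeduper()
stamps = [d.next(1.0), d.next(1.0), d.next(1.0), d.next(2.0)]
```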
2  trains/backend_interface/task/repo/__init__.py  Normal file
@@ -0,0 +1,2 @@
from .scriptinfo import ScriptInfo
from .freeze import pip_freeze
248
trains/backend_interface/task/repo/detectors.py
Normal file
248
trains/backend_interface/task/repo/detectors.py
Normal file
@@ -0,0 +1,248 @@
|
||||
import abc
|
||||
import os
|
||||
from subprocess import call, CalledProcessError
|
||||
|
||||
import attr
|
||||
import six
|
||||
from pathlib2 import Path
|
||||
|
||||
from ....config.defs import (
|
||||
VCS_REPO_TYPE,
|
||||
VCS_DIFF,
|
||||
VCS_STATUS,
|
||||
VCS_ROOT,
|
||||
VCS_BRANCH,
|
||||
VCS_COMMIT_ID,
|
||||
VCS_REPOSITORY_URL,
|
||||
)
|
||||
from ....debugging import get_logger
|
||||
from .util import get_command_output
|
||||
|
||||
_logger = get_logger("Repository Detection")
|
||||
|
||||
|
||||
class DetectionError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
@attr.s
|
||||
class Result(object):
|
||||
"""" Repository information as queried by a detector """
|
||||
|
||||
url = attr.ib(default="")
|
||||
branch = attr.ib(default="")
|
||||
commit = attr.ib(default="")
|
||||
root = attr.ib(default="")
|
||||
status = attr.ib(default="")
|
||||
diff = attr.ib(default="")
|
||||
modified = attr.ib(default=False, type=bool, converter=bool)
|
||||
|
||||
def is_empty(self):
|
||||
return not any(attr.asdict(self).values())
|
||||
|
||||
|
||||
@six.add_metaclass(abc.ABCMeta)
|
||||
class Detector(object):
|
||||
""" Base class for repository detection """
|
||||
|
||||
"""
|
||||
Commands are represented using the result class, where each attribute contains
|
||||
the command used to obtain the value of the same attribute in the actual result.
|
||||
"""
|
||||
|
||||
@attr.s
|
||||
class Commands(object):
|
||||
"""" Repository information as queried by a detector """
|
||||
|
||||
url = attr.ib(default=None, type=list)
|
||||
branch = attr.ib(default=None, type=list)
|
||||
commit = attr.ib(default=None, type=list)
|
||||
root = attr.ib(default=None, type=list)
|
||||
status = attr.ib(default=None, type=list)
|
||||
diff = attr.ib(default=None, type=list)
|
||||
modified = attr.ib(default=None, type=list)
|
||||
|
||||
def __init__(self, type_name, name=None):
|
||||
self.type_name = type_name
|
||||
self.name = name or type_name
|
||||
|
||||
def _get_commands(self):
|
||||
""" Returns a RepoInfo instance containing a command for each info attribute """
|
||||
return self.Commands()
|
||||
|
||||
def _get_command_output(self, path, name, command):
|
||||
""" Run a command and return its output """
|
||||
try:
|
||||
return get_command_output(command, path)
|
||||
|
||||
except (CalledProcessError, UnicodeDecodeError) as ex:
|
||||
_logger.warning(
|
||||
"Can't get {} information for {} repo in {}: {}".format(
|
||||
name, self.type_name, path, str(ex)
|
||||
)
|
||||
)
|
||||
return ""
|
||||
|
||||
def _get_info(self, path, include_diff=False):
|
||||
"""
|
||||
Get repository information.
|
||||
:param path: Path to repository
|
||||
:param include_diff: Whether to include the diff command's output (if available)
|
||||
:return: RepoInfo instance
|
||||
"""
|
||||
path = str(path)
|
||||
commands = self._get_commands()
|
||||
if not include_diff:
|
||||
commands.diff = None
|
||||
|
||||
info = Result(
|
||||
**{
|
||||
name: self._get_command_output(path, name, command)
|
||||
for name, command in attr.asdict(commands).items()
|
||||
if command
|
||||
}
|
||||
)
|
||||
|
||||
return info
|
||||
|
||||
def _post_process_info(self, info):
|
||||
# check if there are uncommitted changes in the current repository
|
||||
return info
|
||||
|
||||
def get_info(self, path, include_diff=False):
|
||||
"""
|
||||
Get repository information.
|
||||
:param path: Path to repository
|
||||
:param include_diff: Whether to include the diff command's output (if available)
|
||||
:return: RepoInfo instance
|
||||
"""
|
||||
info = self._get_info(path, include_diff)
|
||||
return self._post_process_info(info)
|
||||
|
||||
def _is_repo_type(self, script_path):
|
||||
try:
|
||||
with open(os.devnull, "wb") as devnull:
|
||||
return (
|
||||
call(
|
||||
[self.type_name, "status"],
|
||||
stderr=devnull,
|
||||
stdout=devnull,
|
||||
cwd=str(script_path),
|
||||
)
|
||||
== 0
|
||||
)
|
||||
except CalledProcessError:
|
||||
_logger.warning("Can't get {} status".format(self.type_name))
|
||||
except (OSError, EnvironmentError, IOError):
|
||||
# File not found or can't be executed
|
||||
pass
|
||||
return False
|
||||
|
||||
def exists(self, script_path):
|
||||
"""
|
||||
Test whether the given script resides in
|
||||
a repository type represented by this plugin.
|
||||
"""
|
||||
return self._is_repo_type(script_path)
|
||||
|
||||
|
||||
class HgDetector(Detector):
|
||||
def __init__(self):
|
||||
super(HgDetector, self).__init__("hg")
|
||||
|
||||
def _get_commands(self):
|
||||
return self.Commands(
|
||||
url=["hg", "paths", "--verbose"],
|
||||
branch=["hg", "--debug", "id", "-b"],
|
||||
commit=["hg", "--debug", "id", "-i"],
|
||||
root=["hg", "root"],
|
||||
status=["hg", "status"],
|
||||
diff=["hg", "diff"],
|
||||
modified=["hg", "status", "-m"],
|
||||
)
|
||||
|
||||
def _post_process_info(self, info):
|
||||
if info.url:
|
||||
info.url = info.url.split(" = ")[1]
|
||||
|
||||
if info.commit:
|
||||
info.commit = info.commit.rstrip("+")
|
||||
|
||||
return info
|
||||
|
||||
|
||||
class GitDetector(Detector):
|
||||
def __init__(self):
|
||||
super(GitDetector, self).__init__("git")
|
||||
|
||||
def _get_commands(self):
|
||||
return self.Commands(
|
||||
url=["git", "remote", "get-url", "origin"],
|
||||
branch=["git", "rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}"],
|
||||
commit=["git", "rev-parse", "HEAD"],
|
||||
root=["git", "rev-parse", "--show-toplevel"],
|
||||
status=["git", "status", "-s"],
|
||||
diff=["git", "diff"],
|
||||
modified=["git", "ls-files", "-m"],
|
||||
)
|
||||
|
||||
def _post_process_info(self, info):
|
||||
if info.url and not info.url.endswith(".git"):
|
||||
info.url += ".git"
|
||||
|
||||
if (info.branch or "").startswith("origin/"):
|
||||
info.branch = info.branch[len("origin/") :]
|
||||
|
||||
return info
|
||||
|
||||
|
class EnvDetector(Detector):
    def __init__(self, type_name):
        super(EnvDetector, self).__init__(type_name, "{} environment".format(type_name))

    def _is_repo_type(self, script_path):
        return VCS_REPO_TYPE.get(default="").lower() == self.type_name and bool(
            VCS_REPOSITORY_URL.get()
        )

    @staticmethod
    def _normalize_root(root):
        """
        Get the absolute location of the parent folder (where .git resides)
        """
        root_parts = list(reversed(Path(root).parts))
        cwd_abs = list(reversed(Path.cwd().parts))
        count = len(cwd_abs)
        for i, p in enumerate(cwd_abs):
            if i >= len(root_parts):
                break
            if p == root_parts[i]:
                count -= 1
        cwd_abs.reverse()
        root_abs_path = Path().joinpath(*cwd_abs[:count])
        return str(root_abs_path)

    def _get_info(self, _, include_diff=False):
        repository_url = VCS_REPOSITORY_URL.get()

        if not repository_url:
            raise DetectionError("No VCS environment data")

        return Result(
            url=repository_url,
            branch=VCS_BRANCH.get(),
            commit=VCS_COMMIT_ID.get(),
            root=VCS_ROOT.get(converter=self._normalize_root),
            status=VCS_STATUS.get(),
            diff=VCS_DIFF.get(),
        )


class GitEnvDetector(EnvDetector):
    def __init__(self):
        super(GitEnvDetector, self).__init__("git")


class HgEnvDetector(EnvDetector):
    def __init__(self):
        super(HgEnvDetector, self).__init__("hg")
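The `_normalize_root` logic above is compact but hard to follow; restated as a standalone function that takes the working directory as an explicit argument (an adaptation for illustration, the original reads `Path.cwd()` directly):

```python
from pathlib import Path


def normalize_root(root, cwd):
    # Trim the trailing components of cwd that match the (possibly relative)
    # root path, yielding the absolute folder the root refers to
    root_parts = list(reversed(Path(root).parts))
    cwd_parts = list(reversed(Path(cwd).parts))
    count = len(cwd_parts)
    for i, p in enumerate(cwd_parts):
        if i >= len(root_parts):
            break
        if p == root_parts[i]:
            count -= 1
    cwd_parts.reverse()
    return str(Path().joinpath(*cwd_parts[:count]))


# running inside /home/user/proj/src with a relative root of "src"
print(normalize_root("src", "/home/user/proj/src"))
```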
11
trains/backend_interface/task/repo/freeze.py
Normal file
@@ -0,0 +1,11 @@
import sys

from .util import get_command_output


def pip_freeze():
    try:
        return get_command_output([sys.executable, "-m", "pip", "freeze"]).splitlines()
    except Exception as ex:
        print('Failed calling "pip freeze": {}'.format(str(ex)))
        return []
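`pip_freeze` shells out to the current interpreter's pip via `python -m pip freeze`; the same pattern can be reproduced with only the standard library:

```python
import sys
from subprocess import check_output, CalledProcessError


def installed_packages():
    # Ask the running interpreter's pip for its requirements list;
    # fall back to an empty list when pip is unavailable
    try:
        out = check_output([sys.executable, "-m", "pip", "freeze"])
        return out.decode().strip().splitlines()
    except (CalledProcessError, OSError):
        return []


packages = installed_packages()
print(len(packages))
```

Using `sys.executable -m pip` (rather than a bare `pip`) guarantees the listing comes from the same environment the script runs in.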
162
trains/backend_interface/task/repo/scriptinfo.py
Normal file
@@ -0,0 +1,162 @@
import os
import sys

import attr
from furl import furl
from pathlib2 import Path

from ....debugging import get_logger
from .detectors import GitEnvDetector, GitDetector, HgEnvDetector, HgDetector, Result as DetectionResult

_logger = get_logger("Repository Detection")


class ScriptInfoError(Exception):
    pass


class ScriptInfo(object):

    plugins = [GitEnvDetector(), HgEnvDetector(), HgDetector(), GitDetector()]
    """ Script info detection plugins, in order of priority """

    @classmethod
    def _get_jupyter_notebook_filename(cls):
        if not sys.argv[0].endswith('/ipykernel_launcher.py') or len(sys.argv) < 3 or not sys.argv[2].endswith('.json'):
            return None

        # we can safely assume that we can import the notebook package here
        try:
            from notebook.notebookapp import list_running_servers
            import requests
            current_kernel = sys.argv[2].split('/')[-1].replace('kernel-', '').replace('.json', '')
            server_info = next(list_running_servers())
            r = requests.get(
                url=server_info['url'] + 'api/sessions',
                headers={'Authorization': 'token {}'.format(server_info.get('token', '')), })
            r.raise_for_status()
            notebooks = r.json()

            cur_notebook = None
            for n in notebooks:
                if n['kernel']['id'] == current_kernel:
                    cur_notebook = n
                    break

            notebook_path = cur_notebook['notebook']['path']
            entry_point_filename = notebook_path.split('/')[-1]

            # now we should try to find the actual file
            entry_point = (Path.cwd() / entry_point_filename).absolute()
            if not entry_point.is_file():
                entry_point = (Path.cwd() / notebook_path).absolute()

            # now replace the .ipynb with .py
            # we assume we will have that file available with the Jupyter notebook plugin
            entry_point = entry_point.with_suffix('.py')

            return entry_point.as_posix()
        except Exception:
            return None
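The kernel ID extraction above is simple string surgery on the kernel connection-file path passed in `sys.argv[2]`; for example:

```python
def kernel_id_from_connection_file(path):
    # e.g. "/run/user/1000/jupyter/kernel-abcd-1234.json" -> "abcd-1234"
    return path.split('/')[-1].replace('kernel-', '').replace('.json', '')


print(kernel_id_from_connection_file('/run/user/1000/jupyter/kernel-abcd-1234.json'))
```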
    @classmethod
    def _get_entry_point(cls, repo_root, script_path):
        repo_root = Path(repo_root).absolute()

        try:
            # Use os.path.relpath as it calculates up dir movements (../)
            entry_point = os.path.relpath(str(script_path), str(Path.cwd()))
        except ValueError:
            # Working directory not under repository root
            entry_point = script_path.relative_to(repo_root)

        return Path(entry_point).as_posix()

    @classmethod
    def _get_working_dir(cls, repo_root):
        repo_root = Path(repo_root).absolute()

        try:
            return Path.cwd().relative_to(repo_root).as_posix()
        except ValueError:
            # Working directory not under repository root
            return os.path.curdir

    @classmethod
    def _get_script_info(cls, filepath, check_uncommitted=False, log=None):
        jupyter_filepath = cls._get_jupyter_notebook_filename()
        if jupyter_filepath:
            script_path = Path(os.path.normpath(jupyter_filepath)).absolute()
        else:
            script_path = Path(os.path.normpath(filepath)).absolute()
            if not script_path.is_file():
                raise ScriptInfoError(
                    "Script file [{}] could not be found".format(filepath)
                )

        script_dir = script_path.parent

        def _log(msg, *args, **kwargs):
            if not log:
                return
            log.warning(
                "Failed auto-detecting task repository: {}".format(
                    msg.format(*args, **kwargs)
                )
            )

        plugin = next((p for p in cls.plugins if p.exists(script_dir)), None)
        repo_info = DetectionResult()
        if not plugin:
            _log("expected one of: {}", ", ".join((p.name for p in cls.plugins)))
        else:
            try:
                repo_info = plugin.get_info(str(script_dir), include_diff=check_uncommitted)
            except Exception as ex:
                _log("no info for {} ({})", script_dir, ex)
            else:
                if repo_info.is_empty():
                    _log("no info for {}", script_dir)

        repo_root = repo_info.root or script_dir
        working_dir = cls._get_working_dir(repo_root)
        entry_point = cls._get_entry_point(repo_root, script_path)

        script_info = dict(
            repository=furl(repo_info.url).remove(username=True, password=True).tostr(),
            branch=repo_info.branch,
            version_num=repo_info.commit,
            entry_point=entry_point,
            working_dir=working_dir,
            diff=repo_info.diff,
        )

        messages = []
        if repo_info.modified:
            messages.append(
                "======> WARNING! UNCOMMITTED CHANGES IN REPOSITORY {} <======".format(
                    script_info.get("repository", "")
                )
            )

        if not any(script_info.values()):
            script_info = None

        return ScriptInfoResult(script=script_info, warning_messages=messages)

    @classmethod
    def get(cls, filepath=sys.argv[0], check_uncommitted=False, log=None):
        try:
            return cls._get_script_info(
                filepath=filepath, check_uncommitted=check_uncommitted, log=log
            )
        except Exception as ex:
            if log:
                log.warning("Failed auto-detecting task repository: {}".format(ex))
            return ScriptInfoResult()


@attr.s
class ScriptInfoResult(object):
    script = attr.ib(default=None)
    warning_messages = attr.ib(factory=list)
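`_get_entry_point` and `_get_working_dir` reduce to relative-path computations against the repository root and working directory; a standalone sketch with explicit, hypothetical paths (the class methods use `Path.cwd()` instead):

```python
import os
from pathlib import PurePosixPath


def entry_point(script_path, cwd):
    # os.path.relpath handles up-dir movements ("../"), as the code above notes,
    # which pathlib's relative_to() would refuse
    return PurePosixPath(os.path.relpath(script_path, cwd)).as_posix()


print(entry_point("/repo/src/train.py", "/repo"))
print(entry_point("/repo/train.py", "/repo/notebooks"))
```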
12
trains/backend_interface/task/repo/util.py
Normal file
@@ -0,0 +1,12 @@
import os
from subprocess import check_output


def get_command_output(command, path=None):
    """
    Run a command and return its output
    :raises CalledProcessError: when command execution fails
    :raises UnicodeDecodeError: when output decoding fails
    """
    with open(os.devnull, "wb") as devnull:
        return check_output(command, cwd=path, stderr=devnull).decode().strip()
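The helper above silences stderr and strips the decoded output; for example, invoking the current interpreter:

```python
import os
import sys
from subprocess import check_output


def get_command_output(command, path=None):
    # Same shape as the helper above: discard stderr, decode, strip
    with open(os.devnull, "wb") as devnull:
        return check_output(command, cwd=path, stderr=devnull).decode().strip()


print(get_command_output([sys.executable, "-c", "print('  hello  ')"]))
```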
811
trains/backend_interface/task/task.py
Normal file
@@ -0,0 +1,811 @@
""" Backend task management support """
import collections
import itertools
import logging
from copy import copy
from six.moves.urllib.parse import urlparse, urlunparse

import six

from ...backend_interface.task.development.worker import DevWorker
from ...backend_api import Session
from ...backend_api.services import tasks, models, events, projects
from pathlib2 import Path
from pyhocon import ConfigTree, ConfigFactory

from ..base import IdObjectBase
from ..metrics import Metrics, Reporter
from ..model import Model
from ..setupuploadmixin import SetupUploadMixin
from ..util import make_message, get_or_create_project, get_single_result, \
    exact_match_regex
from ...config import get_config_for_bucket, get_remote_task_id, TASK_ID_ENV_VAR, get_log_to_backend, \
    running_remotely, get_cache_dir, config_obj
from ...debugging import get_logger
from ...debugging.log import LoggerRoot
from ...storage import StorageHelper
from ...storage.helper import StorageError
from .access import AccessMixin
from .log import TaskHandler
from .repo import ScriptInfo
from ...config import config

TaskStatusEnum = tasks.TaskStatusEnum


class TaskEntry(tasks.CreateRequest):
    pass
class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
    """ Task manager providing task object access and management. Includes read/write access to task-associated
        frames and models.
    """

    _anonymous_dataview_id = '__anonymous__'

    def __init__(self, session=None, task_id=None, log=None, project_name=None,
                 task_name=None, task_type=tasks.TaskTypeEnum.training, log_to_backend=True,
                 raise_on_validation_errors=True, force_create=False):
        """
        Create a new task instance.
        :param session: Optional API Session instance. If not provided, a default session based on the system's
            configuration will be used.
        :type session: Session
        :param task_id: Optional task ID. If not provided, a new task will be created using the API
            and its information reflected in the resulting instance.
        :type task_id: string
        :param log: Optional log to be used. If not provided, an internal log shared with all backend objects will be
            used instead.
        :type log: logging.Logger
        :param project_name: Optional project name, used only if a new task is created. The new task will be associated
            with a project by this name. If no such project exists, a new project will be created using the API.
        :type project_name: str
        :param task_name: Optional task name, used only if a new task is created.
        :type task_name: str
        :param task_type: Optional task type, used only if a new task is created. Default is a training task.
        :type task_type: str (see tasks.TaskTypeEnum)
        :param log_to_backend: If True, all calls to the task's log will be logged to the backend using the API.
            This value can be overridden using the environment variable TRAINS_LOG_TASK_TO_BACKEND.
        :type log_to_backend: bool
        :param force_create: If True, a new task will always be created (task_id, if provided, will be ignored)
        :type force_create: bool
        """
        task_id = self._resolve_task_id(task_id, log=log) if not force_create else None
        super(Task, self).__init__(id=task_id, session=session, log=log)
        self._storage_uri = None
        self._input_model = None
        self._output_model = None
        self._metrics_manager = None
        self._reporter = None
        self._curr_label_stats = {}
        self._raise_on_validation_errors = raise_on_validation_errors
        self._parameters_allowed_types = (
            six.string_types + six.integer_types + (six.text_type, float, list, dict, type(None))
        )

        if not task_id:
            # generate a new task
            self.id = self._auto_generate(project_name=project_name, task_name=task_name, task_type=task_type)
        else:
            # this is an existing task, let's try to verify stuff
            self._validate()

        if running_remotely() or DevWorker.report_stdout:
            log_to_backend = False
        self._log_to_backend = log_to_backend
        self._setup_log(default_log_to_backend=log_to_backend)
    def _setup_log(self, default_log_to_backend=None, replace_existing=False):
        """
        Setup logging facilities for this task.
        :param default_log_to_backend: Should this task log to the backend. If not specified, the value for this option
            will be obtained from the environment, with this value acting as a default in case configuration for this is
            missing.
            If the value for this option is false, we won't touch the current logger configuration regarding TaskHandler(s)
        :param replace_existing: If True and another task is already logging to the backend, replace the handler with
            a handler for this task.
        """
        # Make sure urllib is never in debug/info
        disable_urllib3_info = config.get('log.disable_urllib3_info', True)
        if disable_urllib3_info and logging.getLogger('urllib3').isEnabledFor(logging.INFO):
            logging.getLogger('urllib3').setLevel(logging.WARNING)

        log_to_backend = get_log_to_backend(default=default_log_to_backend) or self._log_to_backend
        if not log_to_backend:
            return

        # Handle the root logger and our own logger. We use set() to make sure we create no duplicates
        # in case these are the same logger...
        loggers = {logging.getLogger(), LoggerRoot.get_base_logger()}

        # Find all TaskHandler handlers for these loggers
        handlers = {logger: h for logger in loggers for h in logger.handlers if isinstance(h, TaskHandler)}

        if handlers and not replace_existing:
            # Handlers exist and we shouldn't replace them
            return

        # Remove all handlers, we'll add new ones
        for logger, handler in handlers.items():
            logger.removeHandler(handler)

        # Create a handler that will be used in all loggers. Since our handler is a buffering handler, using more
        # than one instance to report to the same task will result in out-of-order log reports (grouped by whichever
        # handler instance handled them)
        backend_handler = TaskHandler(self.session, self.task_id)

        # Add the backend handler to both loggers:
        # 1. to the root logger
        # 2. to our own logger as well, since our logger is not propagated to the root logger
        #    (if we propagate, our logger will be caught by the root handlers as well, and
        #    we do not want that)
        for logger in loggers:
            logger.addHandler(backend_handler)
    def _validate(self, check_output_dest_credentials=True):
        raise_errors = self._raise_on_validation_errors
        output_dest = self.get_output_destination(raise_on_error=False, log_on_error=False)
        if output_dest and check_output_dest_credentials:
            try:
                self.log.info('Validating output destination')
                conf = get_config_for_bucket(base_url=output_dest)
                if not conf:
                    msg = 'Failed resolving output destination (no credentials found for %s)' % output_dest
                    self.log.warn(msg)
                    if raise_errors:
                        raise Exception(msg)
                else:
                    StorageHelper._test_bucket_config(conf=conf, log=self.log, raise_on_error=raise_errors)
            except StorageError:
                raise
            except Exception as ex:
                self.log.error('Failed trying to verify output destination: %s' % ex)

    @classmethod
    def _resolve_task_id(cls, task_id, log=None):
        if not task_id:
            task_id = cls.normalize_id(get_remote_task_id())
            if task_id:
                log = log or get_logger('task')
                log.info('Using task ID from env %s=%s' % (TASK_ID_ENV_VAR[0], task_id))
        return task_id

    def _update_repository(self):
        result = ScriptInfo.get(log=self.log)
        for msg in result.warning_messages:
            self.get_logger().console(msg)

        self.data.script = result.script
        # Since we might run asynchronously, don't use self.data (lest someone else
        # overwrite it before we have a chance to call edit)
        self._edit(script=result.script)

    def _auto_generate(self, project_name=None, task_name=None, task_type=tasks.TaskTypeEnum.training):
        created_msg = make_message('Auto-generated at %(time)s by %(user)s@%(host)s')

        project_id = None
        if project_name:
            project_id = get_or_create_project(self, project_name, created_msg)

        tags = ['development'] if not running_remotely() else []

        req = tasks.CreateRequest(
            name=task_name or make_message('Anonymous task (%(user)s@%(host)s %(time)s)'),
            type=task_type,
            comment=created_msg,
            project=project_id,
            input={'view': {}},
            tags=tags,
        )
        res = self.send(req)

        return res.response.id
    def _set_storage_uri(self, value):
        value = value.rstrip('/')
        self._storage_uri = StorageHelper.conform_url(value)
        self.data.output.destination = self._storage_uri
        self._edit(output_dest=self._storage_uri)
        self.output_model.upload_storage_uri = self._storage_uri

    @property
    def storage_uri(self):
        if self._storage_uri:
            return self._storage_uri
        if running_remotely():
            return self.data.output.destination
        else:
            return None

    @storage_uri.setter
    def storage_uri(self, value):
        self._set_storage_uri(value)

    @property
    def task_id(self):
        return self.id

    @property
    def name(self):
        return self.data.name

    @property
    def task_type(self):
        return self.data.type

    @property
    def project(self):
        return self.data.project

    @property
    def input_model_id(self):
        return self.data.execution.model

    @property
    def output_model_id(self):
        return self.data.output.model

    @property
    def comment(self):
        return self.data.comment

    @property
    def cache_dir(self):
        """ Cache dir used to store task related files """
        return Path(get_cache_dir()) / self.id

    @property
    def status(self):
        """ The task's status. In order to stay updated, we always reload the task info when this value is accessed. """
        self.reload()
        return self._status

    @property
    def _status(self):
        """ Return the task's cached status (don't reload if we don't have to) """
        return self.data.status

    @property
    def input_model(self):
        """ A model manager used to handle the input model object """
        model_id = self._get_task_property('execution.model', raise_on_error=False)
        if not model_id:
            return None
        if self._input_model is None:
            self._input_model = Model(
                session=self.session,
                model_id=model_id,
                cache_dir=self.cache_dir,
                log=self.log,
                upload_storage_uri=None)
        return self._input_model

    @property
    def output_model(self):
        """ A model manager used to manage the output model object """
        if self._output_model is None:
            self._output_model = self._get_output_model(upload_required=True)
        return self._output_model

    def create_output_model(self):
        return self._get_output_model(upload_required=False, force=True)

    def _get_output_model(self, upload_required=True, force=False):
        return Model(
            session=self.session,
            model_id=None if force else self._get_task_property(
                'output.model', raise_on_error=False, log_on_error=False),
            cache_dir=self.cache_dir,
            upload_storage_uri=self.storage_uri or self.get_output_destination(
                raise_on_error=upload_required, log_on_error=upload_required),
            upload_storage_suffix=self._get_output_destination_suffix('models'),
            log=self.log)

    @property
    def metrics_manager(self):
        """ A metrics manager used to manage the metrics related to this task """
        return self._get_metrics_manager(self.get_output_destination())

    @property
    def reporter(self):
        """
        Returns a simple metrics reporter instance
        """
        if self._reporter is None:
            try:
                storage_uri = self.get_output_destination(log_on_error=False)
            except ValueError:
                storage_uri = None
            self._reporter = Reporter(self._get_metrics_manager(storage_uri=storage_uri))
        return self._reporter

    def _get_metrics_manager(self, storage_uri):
        if self._metrics_manager is None:
            self._metrics_manager = Metrics(
                session=self.session,
                task_id=self.id,
                storage_uri=storage_uri,
                storage_uri_suffix=self._get_output_destination_suffix('metrics')
            )
        return self._metrics_manager

    def _get_output_destination_suffix(self, extra_path=None):
        return '/'.join(x for x in ('task_%s' % self.data.id, extra_path) if x)

    def _reload(self):
        """ Reload the task object from the backend """
        res = self.send(tasks.GetByIdRequest(task=self.id))
        return res.response.task

    def reset(self, set_started_on_success=True):
        """ Reset the task. Task will be reloaded following a successful reset. """
        self.send(tasks.ResetRequest(task=self.id))
        if set_started_on_success:
            self.started()
        self.reload()

    def started(self, ignore_errors=True):
        """ Signal that this task has started """
        return self.send(tasks.StartedRequest(self.id), ignore_errors=ignore_errors)

    def stopped(self, ignore_errors=True):
        """ Signal that this task has stopped """
        return self.send(tasks.StoppedRequest(self.id), ignore_errors=ignore_errors)
    def mark_failed(self, ignore_errors=True, status_reason=None, status_message=None):
        """ Signal that this task has failed """
        return self.send(tasks.FailedRequest(self.id, status_reason=status_reason, status_message=status_message),
                         ignore_errors=ignore_errors)

    def publish(self, ignore_errors=True):
        """ Signal that this task will be published """
        if self.status != tasks.TaskStatusEnum.stopped:
            raise ValueError("Can't publish, Task is not stopped")
        resp = self.send(tasks.PublishRequest(self.id), ignore_errors=ignore_errors)
        assert isinstance(resp.response, tasks.PublishResponse)
        return resp
    def update_model_desc(self, new_model_desc_file=None):
        """ Change the task's model_desc """
        execution = self._get_task_property('execution')
        p = Path(new_model_desc_file)
        if not p.is_file():
            raise IOError('model_desc file %s cannot be found' % new_model_desc_file)
        new_model_desc = p.read_text()
        model_desc_key = list(execution.model_desc.keys())[0] if execution.model_desc else 'design'
        execution.model_desc[model_desc_key] = new_model_desc

        res = self._edit(execution=execution)
        return res.response
    def update_output_model(self, model_uri, name=None, comment=None, tags=None):
        """
        Update the task's output model.
        Note that this method only updates the model's metadata using the API and does not upload any data. Use this
        method to update the output model when you have a local model URI (e.g. storing the weights file locally and
        providing a file://path/to/file URI)
        :param model_uri: URI for the updated model weights file
        :type model_uri: str
        :param name: Optional updated model name
        :type name: str
        :param comment: Optional updated model description
        :type comment: str
        :param tags: Optional updated model tags
        :type tags: [str]
        """
        self._conditionally_start_task()
        self._get_output_model(upload_required=False).update_for_task(model_uri, self.id, name, comment, tags)

    def update_output_model_and_upload(
            self, model_file, name=None, comment=None, tags=None, async_enable=False, cb=None, iteration=None):
        """
        Update the task's output model weights file. The file is first uploaded to the preconfigured output
        destination (see the task's output.destination property or call setup_upload()), then the model object
        associated with the task is updated using an API call with the URI of the uploaded file (and other values
        provided by additional arguments)
        :param model_file: Path to the updated model weights file
        :type model_file: str
        :param name: Optional updated model name
        :type name: str
        :param comment: Optional updated model description
        :type comment: str
        :param tags: Optional updated model tags
        :type tags: [str]
        :param async_enable: Request asynchronous upload. If False, the call blocks until upload is completed and the
            API call updating the model returns. If True, the call returns immediately, while upload and update are
            scheduled in another thread. Default is False.
        :type async_enable: bool
        :param cb: Asynchronous callback. If async_enable=True, this callback will be invoked once the asynchronous
            upload and update have completed.
        :return: The URI of the uploaded weights file. If async_enable=True, this is the expected URI, as the upload
            is probably still in progress.
        """
        self._conditionally_start_task()
        uri = self.output_model.update_for_task_and_upload(
            model_file, self.id, name=name, comment=comment, tags=tags, async_enable=async_enable, cb=cb,
            iteration=iteration
        )
        return uri

    def _conditionally_start_task(self):
        if self.status == TaskStatusEnum.created:
            self.started()
    @property
    def labels_stats(self):
        """ Get accumulated label stats for the current/last frames iteration """
        return self._curr_label_stats

    def _accumulate_label_stats(self, roi_stats, reset=False):
        if reset:
            self._curr_label_stats = {}
        for label in roi_stats:
            if label in self._curr_label_stats:
                self._curr_label_stats[label] += roi_stats[label]
            else:
                self._curr_label_stats[label] = roi_stats[label]
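The accumulation above is a plain dict merge that sums counts for labels seen before; as a standalone function:

```python
def accumulate(stats, roi_stats, reset=False):
    # Merge per-label counts, summing counts for labels already present
    if reset:
        stats = {}
    for label, count in roi_stats.items():
        stats[label] = stats.get(label, 0) + count
    return stats


s = accumulate({}, {'car': 2, 'person': 1})
s = accumulate(s, {'car': 3})
print(s)
```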
    def set_input_model(self, model_id=None, model_name=None, update_task_design=True, update_task_labels=True):
        """
        Set a new input model for this task. Model must be 'ready' in order to be used as the Task's input model.
        :param model_id: ID for a model that exists in the backend. Required if model_name is not provided.
        :param model_name: Model name. Required if model_id is not provided. If provided, this name will be used to
            locate an existing model in the backend.
        :param update_task_design: if True, the task's model design will be copied from the input model
        :param update_task_labels: if True, the task's label enumeration will be copied from the input model
        """
        if model_id is None and not model_name:
            raise ValueError('Expected one of [model_id, model_name]')

        if model_name:
            # Try getting the model by name. Limit to 10 results.
            res = self.send(
                models.GetAllRequest(
                    name=exact_match_regex(model_name),
                    ready=True,
                    page=0,
                    page_size=10,
                    order_by='-created',
                    only_fields=['id']
                )
            )
            model = get_single_result(entity='model', query=model_name, results=res.response.models, log=self.log)
            model_id = model.id

        if model_id:
            res = self.send(models.GetByIdRequest(model=model_id))
            model = res.response.model
            if not model.ready:
                # raise ValueError('Model %s is not published (not ready)' % model_id)
                self.log.debug('Model %s [%s] is not published yet (not ready)' % (model_id, model.uri))
        else:
            # clear the input model
            model = None
            model_id = ''

        # store model id
        self.data.execution.model = model_id

        # Auto populate input field from model, if they are empty
        if update_task_design and not self.data.execution.model_desc:
            self.data.execution.model_desc = model.design if model else ''
        if update_task_labels and not self.data.execution.model_labels:
            self.data.execution.model_labels = model.labels if model else {}

        self._edit(execution=self.data.execution)
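`exact_match_regex` (imported from `..util`) presumably anchors and escapes the model name so the server-side regex search matches only that exact name; a plausible sketch of such a helper (an assumption for illustration, not the actual implementation):

```python
import re


def exact_match_regex(name):
    # Anchor and escape so "my model" matches only itself, not "my model v2"
    return '^{}$'.format(re.escape(name))


pattern = exact_match_regex('resnet50 (v1.2)')
print(bool(re.match(pattern, 'resnet50 (v1.2)')))
print(bool(re.match(pattern, 'resnet50 (v1.2) extra')))
```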
    def set_parameters(self, *args, **kwargs):
        """
        Set parameters for this task. This allows setting a complete set of key/value parameters, but does not support
        parameter descriptions (as the input is a dictionary or key/value pairs).
        :param args: Positional arguments (one or more dictionary or (key, value) iterable). These will be merged into
            a single key/value dictionary.
        :param kwargs: Key/value pairs, merged into the parameters dictionary created from `args`.
        """
        if not all(isinstance(x, (dict, collections.Iterable)) for x in args):
            raise ValueError('only dict or iterable are supported as positional arguments')

        update = kwargs.pop('__update', False)

        parameters = dict() if not update else self.get_parameters()
        parameters.update(itertools.chain.from_iterable(x.items() if isinstance(x, dict) else x for x in args))
        parameters.update(kwargs)

        not_allowed = {
            k: type(v).__name__
            for k, v in parameters.items()
            if not isinstance(v, self._parameters_allowed_types)
        }
        if not_allowed:
            raise ValueError(
                "Only builtin types ({}) are allowed for values (got {})".format(
                    ', '.join(t.__name__ for t in self._parameters_allowed_types),
                    ', '.join('%s=>%s' % p for p in not_allowed.items())),
            )

        # force cast all variables to strings (so that we can later edit them in UI)
        parameters = {k: str(v) if v is not None else "" for k, v in parameters.items()}

        execution = self.data.execution
        if execution is None:
            execution = tasks.Execution(parameters=parameters)
        else:
            execution.parameters = parameters
        self._edit(execution=execution)
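The merge-and-stringify step in `set_parameters` can be reproduced standalone (the helper name is hypothetical):

```python
import itertools


def merge_parameters(*args, **kwargs):
    # Flatten dicts and (key, value) iterables into one dict, then cast every
    # value to str (None becomes "") so values stay editable in the UI
    parameters = {}
    parameters.update(itertools.chain.from_iterable(
        x.items() if isinstance(x, dict) else x for x in args))
    parameters.update(kwargs)
    return {k: str(v) if v is not None else "" for k, v in parameters.items()}


print(merge_parameters({'lr': 0.1}, [('epochs', 10)], seed=None))
```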
    def set_parameter(self, name, value, description=None):
        """
        Set a single task parameter. This overrides any previous value for this parameter.
        :param name: Parameter name
        :param value: Parameter value
        :param description: Parameter description (unused for now)
        """
        params = self.get_parameters()
        params[name] = value
        self.set_parameters(params)

    def get_parameter(self, name, default=None):
        """
        Get a value for a parameter.
        :param name: Parameter name
        :param default: Default value
        :return: Parameter value (or default value if parameter is not defined)
        """
        params = self.get_parameters()
        return params.get(name, default)

    def update_parameters(self, *args, **kwargs):
        """
        Update parameters for this task.

        This allows updating a complete set of key/value parameters, but does not support
        parameter descriptions (as the input is a dictionary or key/value pairs).

        :param args: Positional arguments (one or more dictionary or (key, value) iterable). These will be merged into
            a single key/value dictionary.
        :param kwargs: Key/value pairs, merged into the parameters dictionary created from `args`.
        """
        self.set_parameters(__update=True, *args, **kwargs)
    def set_model_label_enumeration(self, enumeration=None):
        enumeration = enumeration or {}
        execution = self.data.execution
        if enumeration is None:
            return
        if not (isinstance(enumeration, dict)
                and all(isinstance(k, six.string_types) and isinstance(v, int) for k, v in enumeration.items())):
            raise ValueError('Expected label to be a dict[str => int]')
        execution.model_labels = enumeration
        self._edit(execution=execution)

    def _set_model_design(self, design=None):
        execution = self.data.execution
        if design is not None:
            execution.model_desc = Model._wrap_design(design)

        self._edit(execution=execution)

    def get_labels_enumeration(self):
        """
        Return a dictionary of labels (text) to ids (integers) {str(label): integer(id)}
        :return:
        """
        if not self.data or not self.data.execution:
            return {}
        return self.data.execution.model_labels

    def get_model_design(self):
        """
        Returns the model configuration as a blob of text
        :return:
        """
        design = self._get_task_property("execution.model_desc", default={}, raise_on_error=False, log_on_error=False)
        return Model._unwrap_design(design)

    def set_output_model_id(self, model_id):
        self.data.output.model = str(model_id)
        self._edit(output=self.data.output)

    def get_random_seed(self):
        # fixed seed for the time being
        return 1337

    def set_random_seed(self, random_seed):
        # fixed seed for the time being
        pass

    def set_project(self, project_id):
        assert isinstance(project_id, six.string_types)
        self._set_task_property("project", project_id)
        self._edit(project=project_id)

    def get_project_name(self):
        if self.project is None:
            return None

        res = self.send(projects.GetByIdRequest(project=self.project), raise_on_errors=False)
        return res.response.project.name

    def get_tags(self):
        return self._get_task_property("tags")

    def set_tags(self, tags):
        assert isinstance(tags, (list, tuple))
        self._set_task_property("tags", tags)
        self._edit(tags=self.data.tags)

    def _get_default_report_storage_uri(self):
        app_host = self._get_app_server()
        parsed = urlparse(app_host)
        if parsed.port:
parsed = parsed._replace(netloc=parsed.netloc.replace(':%d' % parsed.port, ':8081'))
|
||||
else:
|
||||
parsed = parsed._replace(netloc=parsed.netloc+':8081')
|
||||
return urlunparse(parsed)
|
||||
|
||||
def _get_app_server(self):
|
||||
host = config_obj.get('api.host')
|
||||
if '://demoapi.' in host:
|
||||
return host.replace('://demoapi.', '://demoapp.')
|
||||
if '://api.' in host:
|
||||
return host.replace('://api.', '://app.')
|
||||
|
||||
parsed = urlparse(host)
|
||||
if parsed.port == 8008:
|
||||
return host.replace(':8008', ':8080')
|
||||
|
||||
def _edit(self, **kwargs):
|
||||
# Since we ae using forced update, make sure he task status is valid
|
||||
if not self._data or (self.data.status not in (TaskStatusEnum.created, TaskStatusEnum.in_progress)):
|
||||
raise ValueError('Task object can only be updated if created or in_progress')
|
||||
|
||||
res = self.send(tasks.EditRequest(task=self.id, force=True, **kwargs), raise_on_errors=False)
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def create_new_task(cls, session, task_entry, log=None):
|
||||
"""
|
||||
Create a new task
|
||||
:param session: Session object used for sending requests to the API
|
||||
:type session: Session
|
||||
:param task_entry: A task entry instance
|
||||
:type task_entry: TaskEntry
|
||||
:param log: Optional log
|
||||
:type log: logging.Logger
|
||||
:return: A new Task instance
|
||||
"""
|
||||
if isinstance(task_entry, dict):
|
||||
task_entry = TaskEntry(**task_entry)
|
||||
|
||||
assert isinstance(task_entry, TaskEntry)
|
||||
res = cls._send(session=session, req=task_entry, log=log)
|
||||
return cls(session, task_id=res.response.id)
|
||||
|
||||
@classmethod
|
||||
def clone_task(cls, cloned_task_id, name, comment=None, execution_overrides=None,
|
||||
tags=None, parent=None, project=None, log=None, session=None):
|
||||
"""
|
||||
Clone a task
|
||||
:param session: Session object used for sending requests to the API
|
||||
:type session: Session
|
||||
:param cloned_task_id: Task ID for the task to be cloned
|
||||
:type cloned_task_id: str
|
||||
:param name: New for the new task
|
||||
:type name: str
|
||||
:param comment: Optional comment for the new task
|
||||
:type comment: str
|
||||
:param execution_overrides: Task execution overrides. Applied over the cloned task's execution
|
||||
section, useful for overriding values in the cloned task.
|
||||
:type execution_overrides: dict
|
||||
:param tags: Optional updated model tags
|
||||
:type tags: [str]
|
||||
:param parent: Optional parent ID of the new task.
|
||||
:type parent: str
|
||||
:param project: Optional project ID of the new task.
|
||||
If None, the new task will inherit the cloned task's project.
|
||||
:type parent: str
|
||||
:param log: Log object used by the infrastructure.
|
||||
:type log: logging.Logger
|
||||
:return: The new tasks's ID
|
||||
"""
|
||||
|
||||
session = session if session else cls._get_default_session()
|
||||
|
||||
res = cls._send(session=session, log=log, req=tasks.GetByIdRequest(task=cloned_task_id))
|
||||
task = res.response.task
|
||||
output_dest = None
|
||||
if task.output:
|
||||
output_dest = task.output.destination
|
||||
execution = task.execution.to_dict() if task.execution else {}
|
||||
execution = ConfigTree.merge_configs(ConfigFactory.from_dict(execution),
|
||||
ConfigFactory.from_dict(execution_overrides or {}))
|
||||
req = tasks.CreateRequest(
|
||||
name=name,
|
||||
type=task.type,
|
||||
input=task.input,
|
||||
tags=tags if tags is not None else task.tags,
|
||||
comment=comment or task.comment,
|
||||
parent=parent,
|
||||
project=project if project else task.project,
|
||||
output_dest=output_dest,
|
||||
execution=execution.as_plain_ordered_dict(),
|
||||
script=task.script
|
||||
)
|
||||
res = cls._send(session=session, log=log, req=req)
|
||||
return res.response.id
|
||||
|
||||
@classmethod
|
||||
def enqueue_task(cls, task_id, session=None, queue_id=None, log=None):
|
||||
"""
|
||||
Enqueue a task for execution
|
||||
:param session: Session object used for sending requests to the API
|
||||
:type session: Session
|
||||
:param task_id: ID of the task to be enqueued
|
||||
:type task_id: str
|
||||
:param queue_id: ID of the queue in which to enqueue the task. If not provided, the default queue will be used.
|
||||
:type queue_id: str
|
||||
:param log: Log object
|
||||
:type log: logging.Logger
|
||||
:return: enqueue response
|
||||
"""
|
||||
assert isinstance(task_id, six.string_types)
|
||||
req = tasks.EnqueueRequest(task=task_id, queue=queue_id)
|
||||
res = cls._send(session=session, req=req, log=log)
|
||||
resp = res.response
|
||||
return resp
|
||||
|
||||
@classmethod
|
||||
def get_all(cls, session, log=None, **kwargs):
|
||||
"""
|
||||
Get all tasks
|
||||
:param session: Session object used for sending requests to the API
|
||||
:type session: Session
|
||||
:param log: Log object
|
||||
:type log: logging.Logger
|
||||
:param kwargs: Keyword args passed to the GetAllRequest (see .backend_api.services.tasks.GetAllRequest)
|
||||
:type kwargs: dict
|
||||
:return: API response
|
||||
"""
|
||||
req = tasks.GetAllRequest(**kwargs)
|
||||
res = cls._send(session=session, req=req, log=log)
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def get_by_name(cls, task_name):
|
||||
res = cls._send(cls._get_default_session(), tasks.GetAllRequest(name=exact_match_regex(task_name)))
|
||||
|
||||
task = get_single_result(entity='task', query=task_name, results=res.response.tasks)
|
||||
return cls(task_id=task.id)
|
||||
|
||||
def _get_all_events(self, max_events=100):
|
||||
"""
|
||||
Get a list of all reported events.
|
||||
|
||||
Warning: Debug only. Do not use outside of testing.
|
||||
|
||||
:param max_events: The maximum events the function will return. Pass None
|
||||
to return all the reported events.
|
||||
:return: A list of events from the task.
|
||||
"""
|
||||
|
||||
log_events = self.send(events.GetTaskEventsRequest(
|
||||
task=self.id,
|
||||
order='asc',
|
||||
batch_size=max_events,
|
||||
))
|
||||
|
||||
events_list = log_events.response.events
|
||||
total_events = log_events.response.total
|
||||
scroll = log_events.response.scroll_id
|
||||
|
||||
while len(events_list) < total_events and (max_events is None or len(events_list) < max_events):
|
||||
log_events = self.send(events.GetTaskEventsRequest(
|
||||
task=self.id,
|
||||
order='asc',
|
||||
batch_size=max_events,
|
||||
scroll_id=scroll,
|
||||
))
|
||||
events_list.extend(log_events.response.events)
|
||||
scroll = log_events.response.scroll_id
|
||||
|
||||
return events_list
|
||||
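The `update_parameters` docstring above describes merging one or more positional dictionaries or (key, value) iterables with keyword arguments into a single key/value dictionary. A minimal standalone sketch of that merge semantics (the helper name `merge_params` is hypothetical, not part of the TRAINS API):

```python
# Hypothetical helper illustrating the merge described in update_parameters:
# positional dicts / (key, value) iterables are merged in order, then kwargs
# are applied on top, overriding earlier values.
def merge_params(*args, **kwargs):
    merged = {}
    for arg in args:
        merged.update(dict(arg))
    merged.update(kwargs)
    return merged

params = merge_params({'lr': 0.1}, [('epochs', 10)], lr=0.01)
# kwargs win over positional values: {'lr': 0.01, 'epochs': 10}
```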
77	trains/backend_interface/util.py	Normal file
@@ -0,0 +1,77 @@
import getpass
import re
from _socket import gethostname
from datetime import datetime

from ..backend_api.services import projects
from ..debugging.log import get_logger


def make_message(s, **kwargs):
    args = dict(
        user=getpass.getuser(),
        host=gethostname(),
        time=datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')
    )
    args.update(kwargs)
    return s % args


def get_or_create_project(session, project_name, description=None):
    res = session.send(projects.GetAllRequest(name=exact_match_regex(project_name)))
    if res.response.projects:
        return res.response.projects[0].id
    res = session.send(projects.CreateRequest(name=project_name, description=description))
    return res.response.id


def get_single_result(entity, query, results, log=None, show_results=10, raise_on_error=True):
    if not results:
        if not raise_on_error:
            return None

        raise ValueError('No {entity}s found when searching for `{query}`'.format(**locals()))

    if not log:
        log = get_logger()

    if len(results) > 1:
        log.warn('More than one {entity} found when searching for `{query}`'
                 ' (showing first {show_results} {entity}s follow)'.format(**locals()))
        for obj in (o if isinstance(o, dict) else o.to_dict() for o in results[:show_results]):
            log.warn('Found {entity} `{obj[name]}` (id={obj[id]})'.format(**locals()))

        if raise_on_error:
            raise ValueError('More than one {entity} found when searching for `{query}`'.format(**locals()))

    return results[0]


def at_least_one(_exception_cls=Exception, **kwargs):
    actual = [k for k, v in kwargs.items() if v]
    if len(actual) < 1:
        raise _exception_cls('At least one of (%s) is required' % ', '.join(kwargs.keys()))


def mutually_exclusive(_exception_cls=Exception, _require_at_least_one=True, **kwargs):
    """ Helper for checking mutually exclusive options """
    actual = [k for k, v in kwargs.items() if v]
    if _require_at_least_one:
        at_least_one(_exception_cls=_exception_cls, **kwargs)
    if len(actual) > 1:
        raise _exception_cls('Only one of (%s) is allowed' % ', '.join(kwargs.keys()))


def validate_dict(obj, key_types, value_types, desc=''):
    if not isinstance(obj, dict):
        raise ValueError('%sexpected a dictionary' % ('%s: ' % desc if desc else ''))
    if not all(isinstance(k, key_types) for k in obj.keys()):
        raise ValueError('%skeys must all be strings' % ('%s ' % desc if desc else ''))
    if not all(isinstance(v, value_types) for v in obj.values()):
        raise ValueError('%svalues must all be integers' % ('%s ' % desc if desc else ''))


def exact_match_regex(name):
    """ Convert a string to a regex representing an exact match """
    return '^%s$' % re.escape(name)
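`exact_match_regex` in util.py is what turns a task or project name into an anchored, escaped pattern, so regex metacharacters in the name match literally. A self-contained sketch of the same function with its effect:

```python
import re

# exact_match_regex escapes the name and anchors it with ^...$, so a name
# containing regex metacharacters (e.g. '.') only matches itself exactly.
def exact_match_regex(name):
    """ Convert a string to a regex representing an exact match """
    return '^%s$' % re.escape(name)

pattern = exact_match_regex("my.task")
# '.' is escaped, so "myXtask" does not match; anchoring rejects extra text.
```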
64	trains/config/__init__.py	Normal file
@@ -0,0 +1,64 @@
""" Configuration module. Uses backend_config to load system configuration. """
import logging
from os.path import expandvars, expanduser

from ..backend_api import load_config
from ..backend_config.bucket_config import S3BucketConfigurations

from .defs import *
from .remote import running_remotely_task_id as _running_remotely_task_id

config_obj = load_config(Path(__file__).parent)
config_obj.initialize_logging()
config = config_obj.get("sdk")
""" Configuration object reflecting the merged SDK section of all available configuration files """


def get_cache_dir():
    cache_base_dir = Path(
        expandvars(
            expanduser(
                config.get("storage.cache.default_base_dir") or DEFAULT_CACHE_DIR
            )
        )
    )
    return cache_base_dir


def get_config_for_bucket(base_url, extra_configurations=None):
    config_list = S3BucketConfigurations.from_config(config.get("aws.s3"))

    for configuration in extra_configurations or []:
        config_list.add_config(configuration)

    return config_list.get_config_by_uri(base_url)


def get_remote_task_id():
    return None


def running_remotely():
    return False


def get_log_to_backend(default=None):
    return LOG_TO_BACKEND_ENV_VAR.get(default=default)


def get_node_id(default=0):
    return NODE_ID_ENV_VAR.get(default=default)


def get_log_redirect_level():
    """ Returns which log level (and up) should be redirected to stderr. None means no redirection. """
    value = LOG_STDERR_REDIRECT_LEVEL.get()
    try:
        if value:
            return logging._checkLevel(value)
    except (ValueError, TypeError):
        pass


def dev_worker_name():
    return DEV_WORKER_NAME.get()
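`get_cache_dir` above applies `expanduser` first (handling `~`) and then `expandvars` (handling `$VAR` references) to the configured base directory. A small standalone sketch of that expansion order; `TRAINS_DEMO_SUBDIR` is a hypothetical variable set here only for the demonstration:

```python
import os
from os.path import expanduser, expandvars

# Demonstrate the expansion order used by get_cache_dir: expanduser resolves
# "~" to the home directory, then expandvars substitutes environment variables.
os.environ['TRAINS_DEMO_SUBDIR'] = 'trains_cache'  # hypothetical, demo only
cache_dir = expandvars(expanduser('~/$TRAINS_DEMO_SUBDIR'))
```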
40	trains/config/cache.py	Normal file
@@ -0,0 +1,40 @@
import json

from . import get_cache_dir
from .defs import SESSION_CACHE_FILE


class SessionCache(object):
    """
    Handle the SDK session cache.
    TODO: Improve error handling to something like "except (FileNotFoundError, PermissionError, JSONDecodeError)"
    TODO: that's both six-compatible and tested
    """
    @classmethod
    def _load_cache(cls):
        try:
            with (get_cache_dir() / SESSION_CACHE_FILE).open("rt") as fp:
                return json.load(fp)
        except Exception:
            return {}

    @classmethod
    def _store_cache(cls, cache):
        try:
            get_cache_dir().mkdir(parents=True, exist_ok=True)
            with (get_cache_dir() / SESSION_CACHE_FILE).open("wt") as fp:
                json.dump(cache, fp)
        except Exception:
            pass

    @classmethod
    def store_dict(cls, unique_cache_name, dict_object):
        # type: (str, dict) -> None
        cache = cls._load_cache()
        cache[unique_cache_name] = dict_object
        cls._store_cache(cache)

    @classmethod
    def load_dict(cls, unique_cache_name):
        # type: (str) -> dict
        cache = cls._load_cache()
        return cache.get(unique_cache_name, {}) if cache else {}
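The `SessionCache` pattern above is a read-modify-write of one JSON file, where any read or write failure silently degrades to an empty cache. A standalone sketch of that round trip, using a temp file instead of the real cache location (`get_cache_dir() / SESSION_CACHE_FILE`):

```python
import json
import os
import tempfile

# Temp-file stand-in for the real session cache path.
cache_path = os.path.join(tempfile.mkdtemp(), 'session.json')

def load_cache():
    try:
        with open(cache_path, 'rt') as fp:
            return json.load(fp)
    except Exception:
        return {}  # missing or corrupt cache degrades to an empty dict

def store_dict(name, dict_object):
    cache = load_cache()          # read-modify-write of the whole file
    cache[name] = dict_object
    with open(cache_path, 'wt') as fp:
        json.dump(cache, fp)

store_dict('api', {'token': 'abc'})
```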
1	trains/config/default/__init__.py	Normal file
@@ -0,0 +1 @@
132	trains/config/default/__main__.py	Normal file
@@ -0,0 +1,132 @@
from pyhocon import ConfigFactory
from pathlib2 import Path
from six.moves.urllib.parse import urlparse, urlunparse

from trains.backend_api.session.defs import ENV_HOST
from trains.backend_config.defs import LOCAL_CONFIG_FILES
from trains.config import config_obj


description = """
Please create new key/secret credentials using {}/admin

Copy/Paste credentials here: """

try:
    def_host = ENV_HOST.get(default=config_obj.get("api.host"))
except Exception:
    def_host = 'http://localhost:8080'

host_description = """
Editing configuration file: {CONFIG_FILE}
Enter your trains-server host [{HOST}]: """.format(
    CONFIG_FILE=LOCAL_CONFIG_FILES[0],
    HOST=def_host,
)


def main():
    print('TRAINS SDK setup process')
    conf_file = Path(LOCAL_CONFIG_FILES[0]).absolute()
    if conf_file.exists() and conf_file.is_file() and conf_file.stat().st_size > 0:
        print('Configuration file already exists: {}'.format(str(conf_file)))
        print('Leaving setup, feel free to edit the configuration file.')
        return

    print(host_description, end='')
    parsed_host = None
    while not parsed_host:
        parse_input = input()
        if not parse_input:
            parse_input = def_host
        try:
            parsed_host = urlparse(parse_input)
            if parsed_host.scheme not in ('http', 'https'):
                parsed_host = None
        except Exception:
            parsed_host = None
            print('Could not parse url {}\nEnter your trains-server host: '.format(parse_input), end='')

    if parsed_host.port == 8080:
        # this is a docker setup: 8080 is the web address, we need the api address, which is on 8008
        print('Port 8080 is the web port, we need the api port. Replacing 8080 with 8008')
        api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008') + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
    elif parsed_host.netloc.startswith('demoapp.'):
        print('{} is the web server, we need the api server. Replacing \'demoapp.\' with \'demoapi.\''.format(
            parsed_host.netloc))
        # this is our demo server
        api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.') + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
    elif parsed_host.netloc.startswith('app.'):
        print('{} is the web server, we need the api server. Replacing \'app.\' with \'api.\''.format(
            parsed_host.netloc))
        # this is our application server
        api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'api.') + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
    elif parsed_host.port == 8008:
        api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080') + parsed_host.path
    elif parsed_host.netloc.startswith('demoapi.'):
        api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demoapp.') + parsed_host.path
    elif parsed_host.netloc.startswith('api.'):
        api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.') + parsed_host.path
    else:
        api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
        web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path

    print('Host configured to: {}'.format(api_host))

    print(description.format(web_host), end='')
    parse_input = input()
    # check if these are valid credentials
    credentials = None
    try:
        parsed = ConfigFactory.parse_string(parse_input)
        if parsed:
            credentials = parsed.get("credentials", None)
    except Exception:
        credentials = None

    if not credentials or set(credentials) != {"access_key", "secret_key"}:
        print('Could not parse user credentials, try again one after the other.')
        credentials = {}
        # parse individually
        print('Enter user access key: ', end='')
        credentials['access_key'] = input()
        print('Enter user secret: ', end='')
        credentials['secret_key'] = input()

    print('Detected credentials key="{}" secret="{}"'.format(credentials['access_key'],
                                                             credentials['secret_key'], ))

    try:
        default_sdk_conf = Path(__file__).parent.absolute() / 'sdk.conf'
        with open(str(default_sdk_conf), 'rt') as f:
            default_sdk = f.read()
    except Exception:
        print('Error! Could not read default configuration file')
        return

    try:
        with open(str(conf_file), 'wt') as f:
            header = '# TRAINS SDK configuration file\n' \
                     'api {\n' \
                     '    host: %s\n' \
                     '    credentials {"access_key": "%s", "secret_key": "%s"}\n' \
                     '}\n' \
                     'sdk ' % (api_host, credentials['access_key'], credentials['secret_key'])
            f.write(header)
            f.write(default_sdk)
    except Exception:
        print('Error! Could not write configuration file at: {}'.format(str(conf_file)))
        return

    print('\nNew configuration stored in {}'.format(str(conf_file)))
    print('TRAINS setup completed successfully.')


if __name__ == '__main__':
    main()
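The setup flow above maps the web-server address the user types to the matching api-server address: web port 8080 maps to api port 8008, and an `app.` host prefix maps to `api.`. A condensed sketch of that mapping (Python 3 `urllib.parse` is used here instead of `six.moves` for brevity):

```python
from urllib.parse import urlparse

# Condensed version of the web-host -> api-host branches in main():
# docker setups expose the web UI on 8080 and the api on 8008, while hosted
# setups use an 'app.' / 'api.' subdomain pair.
def to_api_host(url):
    p = urlparse(url)
    netloc = p.netloc
    if p.port == 8080:
        netloc = netloc.replace(':8080', ':8008')
    elif netloc.startswith('app.'):
        netloc = netloc.replace('app.', 'api.', 1)
    return p.scheme + '://' + netloc + p.path
```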
27	trains/config/default/logging.conf	Normal file
@@ -0,0 +1,27 @@
{
    version: 1
    disable_existing_loggers: 0
    loggers {
        trains {
            level: INFO
        }
        boto {
            level: WARNING
        }
        "boto.perf" {
            level: WARNING
        }
        botocore {
            level: WARNING
        }
        boto3 {
            level: WARNING
        }
        google {
            level: WARNING
        }
        urllib3 {
            level: WARNING
        }
    }
}
126	trains/config/default/sdk.conf	Normal file
@@ -0,0 +1,126 @@
{
    # TRAINS - default SDK configuration

    storage {
        cache {
            # Defaults to system temp folder / cache
            default_base_dir: "~/.trains/cache"
        }
    }

    metrics {
        # History size for debug files per metric/variant. For each metric/variant combination with an attached file
        # (e.g. debug image event), file names for the uploaded files will be recycled in such a way that no more than
        # X files are stored in the upload destination for each metric/variant combination.
        file_history_size: 100

        # Settings for generated debug images
        images {
            format: JPEG
            quality: 87
            subsampling: 0
        }
    }

    network {
        metrics {
            # Number of threads allocated to uploading files (typically debug images) when transmitting metrics for
            # a specific iteration
            file_upload_threads: 4

            # Warn about upload starvation if no uploads were made in the specified period while file-bearing events
            # keep being sent for upload
            file_upload_starvation_warning_sec: 120
        }

        iteration {
            # Max number of retries when getting frames if the server returned an error (http code 500)
            max_retries_on_server_error: 5
            # Backoff factor for consecutive retry attempts.
            # SDK will wait for {backoff factor} * (2 ^ ({number of total retries} - 1)) between retries.
            retry_backoff_factor_sec: 10
        }
    }
    aws {
        s3 {
            # S3 credentials, used for read/write access by various SDK elements

            # default, used for any bucket not specified below
            key: ""
            secret: ""
            region: ""

            credentials: [
                # specifies key/secret credentials to use when handling s3 urls (read or write)
                # {
                #     bucket: "my-bucket-name"
                #     key: "my-access-key"
                #     secret: "my-secret-key"
                # },
                # {
                #     # This will apply to all buckets in this host (unless key/value is specifically provided
                #     # for a given bucket)
                #     host: "my-minio-host:9000"
                #     key: "12345678"
                #     secret: "12345678"
                #     multipart: false
                #     secure: false
                # }
            ]
        }
        boto3 {
            pool_connections: 512
            max_multipart_concurrency: 16
        }
    }
    google.storage {
        # # Default project and credentials file
        # # Will be used when no bucket configuration is found
        # project: "trains"
        # credentials_json: "/path/to/credentials.json"

        # # Specific credentials per bucket and sub directory
        # credentials = [
        #     {
        #         bucket: "my-bucket"
        #         subdir: "path/in/bucket"  # Not required
        #         project: "trains"
        #         credentials_json: "/path/to/credentials.json"
        #     },
        # ]
    }

    log {
        # debugging feature: set this to true to make null log propagate messages to root logger (so they appear in stdout)
        null_log_propagate: False
        task_log_buffer_capacity: 66

        # disable urllib info and lower levels
        disable_urllib3_info: True
    }

    development {
        # Development-mode options

        # dev task reuse window
        task_reuse_time_window_in_hours: 72.0

        # Run VCS repository detection asynchronously
        vcs_repo_detect_async: False

        # Store uncommitted git/hg source code diff in experiment manifest when training in development mode
        # This stores "git diff" or "hg diff" into the experiment's "script.requirements.diff" section
        store_uncommitted_code_diff_on_train: True

        # Support stopping an experiment in case it was externally stopped, status was changed or task was reset
        support_stopping: True

        # Development mode worker
        worker {
            # Status report period in seconds
            report_period_sec: 2

            # Log all stdout & stderr
            log_stdout: True
        }
    }
}
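The `retry_backoff_factor_sec` comment above defines the wait between retries as {backoff factor} * (2 ^ ({number of total retries} - 1)). A direct sketch of that formula:

```python
# Exponential backoff as documented for network.iteration in sdk.conf:
# wait = backoff_factor_sec * 2 ** (total_retries - 1)
def retry_wait_sec(backoff_factor_sec, total_retries):
    return backoff_factor_sec * (2 ** (total_retries - 1))

# With the default factor of 10, retries 1..4 wait 10, 20, 40, 80 seconds.
```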
31	trains/config/defs.py	Normal file
@@ -0,0 +1,31 @@
import tempfile

from ..backend_config import EnvEntry
from ..backend_config.converters import base64_to_text, or_
from pathlib2 import Path

SESSION_CACHE_FILE = ".session.json"
DEFAULT_CACHE_DIR = str(Path(tempfile.gettempdir()) / "trains_cache")


TASK_ID_ENV_VAR = EnvEntry("TRAINS_TASK_ID", "ALG_TASK_ID")
LOG_TO_BACKEND_ENV_VAR = EnvEntry("TRAINS_LOG_TASK_TO_BACKEND", "ALG_LOG_TASK_TO_BACKEND", type=bool)
NODE_ID_ENV_VAR = EnvEntry("TRAINS_NODE_ID", "ALG_NODE_ID", type=int)
PROC_MASTER_ID_ENV_VAR = EnvEntry("TRAINS_PROC_MASTER_ID", "ALG_PROC_MASTER_ID", type=int)
LOG_STDERR_REDIRECT_LEVEL = EnvEntry("TRAINS_LOG_STDERR_REDIRECT_LEVEL", "ALG_LOG_STDERR_REDIRECT_LEVEL")
DEV_WORKER_NAME = EnvEntry("TRAINS_WORKER_NAME", "ALG_WORKER_NAME")

LOG_LEVEL_ENV_VAR = EnvEntry("TRAINS_LOG_LEVEL", "ALG_LOG_LEVEL", converter=or_(int, str))

# Repository detection
VCS_REPO_TYPE = EnvEntry("TRAINS_VCS_REPO_TYPE", "ALG_VCS_REPO_TYPE", default="git")
VCS_REPOSITORY_URL = EnvEntry("TRAINS_VCS_REPO_URL", "ALG_VCS_REPO_URL")
VCS_COMMIT_ID = EnvEntry("TRAINS_VCS_COMMIT_ID", "ALG_VCS_COMMIT_ID")
VCS_BRANCH = EnvEntry("TRAINS_VCS_BRANCH", "ALG_VCS_BRANCH")
VCS_ROOT = EnvEntry("TRAINS_VCS_ROOT", "ALG_VCS_ROOT")
VCS_STATUS = EnvEntry("TRAINS_VCS_STATUS", "ALG_VCS_STATUS", converter=base64_to_text)
VCS_DIFF = EnvEntry("TRAINS_VCS_DIFF", "ALG_VCS_DIFF", converter=base64_to_text)

# User credentials
API_ACCESS_KEY = EnvEntry("TRAINS_API_ACCESS_KEY", "ALG_API_ACCESS_KEY", help="API Access Key")
API_SECRET_KEY = EnvEntry("TRAINS_API_SECRET_KEY", "ALG_API_SECRET_KEY", help="API Secret Key")
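Each `EnvEntry` above lists a primary `TRAINS_*` variable name with an `ALG_*` fallback. A minimal sketch of that first-set-wins lookup; `EnvEntry` itself lives in `trains.backend_config` and has more features (type conversion, defaults), and the variable names set below are hypothetical, used only for the demo:

```python
import os

# First-set-wins environment lookup, mirroring the EnvEntry name list idea.
def env_lookup(*names, **kwargs):
    default = kwargs.get('default')
    for name in names:
        if name in os.environ:
            return os.environ[name]
    return default

os.environ['ALG_DEMO_ID'] = '42'  # hypothetical variable, set for the demo only
value = env_lookup('TRAINS_DEMO_ID', 'ALG_DEMO_ID')
```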
17	trains/config/remote.py	Normal file
@@ -0,0 +1,17 @@
from .defs import TASK_ID_ENV_VAR

running_remotely_task_id = TASK_ID_ENV_VAR.get()


def override_current_task_id(task_id):
    """
    Override the current task id to simulate remote running with a specific task.

    Use for testing and debug only.

    :param task_id: The task's id to use as the remote task.
        Pass None to simulate local execution.
    """

    global running_remotely_task_id
    running_remotely_task_id = task_id
4	trains/debugging/__init__.py	Normal file
@@ -0,0 +1,4 @@
""" Debugging module """
from .timer import Timer
from .log import get_logger, get_null_logger, TqdmLog, add_options as add_log_options, \
    apply_args as parse_log_args, add_rotating_file_handler, add_time_rotating_file_handler
181	trains/debugging/log.py	Normal file
@@ -0,0 +1,181 @@
|
||||
""" Logging convenience functions and wrappers """
|
||||
import inspect
|
||||
import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
import sys
|
||||
from platform import system
|
||||
|
||||
import colorama
|
||||
from ..config import config, get_log_redirect_level
|
||||
from coloredlogs import ColoredFormatter
|
||||
from pathlib2 import Path
|
||||
from six import BytesIO
|
||||
from tqdm import tqdm
|
||||
|
||||
default_level = logging.INFO
|
||||
|
||||
|
||||
class _LevelRangeFilter(logging.Filter):
|
||||
|
||||
def __init__(self, min_level, max_level, name=''):
|
||||
super(_LevelRangeFilter, self).__init__(name)
|
||||
self.min_level = min_level
|
||||
self.max_level = max_level
|
||||
|
||||
def filter(self, record):
|
||||
return self.min_level <= record.levelno <= self.max_level
|
||||
|
||||
|
||||
class LoggerRoot(object):
|
||||
__base_logger = None
|
||||
|
||||
@classmethod
|
||||
    def _make_stream_handler(cls, level=None, stream=sys.stdout, colored=False):
        ch = logging.StreamHandler(stream=stream)
        ch.setLevel(level)
        if colored:
            colorama.init()
            formatter = ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        else:
            formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        ch.setFormatter(formatter)
        return ch

    @classmethod
    def get_base_logger(cls, level=None, stream=sys.stdout, colored=False):
        if LoggerRoot.__base_logger:
            return LoggerRoot.__base_logger
        LoggerRoot.__base_logger = logging.getLogger('trains')
        level = level if level is not None else default_level
        LoggerRoot.__base_logger.setLevel(level)

        redirect_level = get_log_redirect_level()

        # Do not redirect to stderr if the target stream is already stderr
        if redirect_level is not None and stream not in (None, sys.stderr):
            # Adjust redirect level in case the requested level is higher (e.g. logger is requested
            # for CRITICAL and redirect is set for ERROR, in which case we redirect from CRITICAL)
            redirect_level = max(level, redirect_level)
            LoggerRoot.__base_logger.addHandler(
                cls._make_stream_handler(redirect_level, sys.stderr, colored)
            )

            if level < redirect_level:
                # Not all levels were redirected, remaining levels are sent to the requested stream
                handler = cls._make_stream_handler(level, stream, colored)
                handler.addFilter(_LevelRangeFilter(min_level=level, max_level=redirect_level - 1))
                LoggerRoot.__base_logger.addHandler(handler)
        else:
            LoggerRoot.__base_logger.addHandler(
                cls._make_stream_handler(level, stream, colored)
            )

        LoggerRoot.__base_logger.propagate = False
        return LoggerRoot.__base_logger

    @classmethod
    def flush(cls):
        if LoggerRoot.__base_logger:
            for h in LoggerRoot.__base_logger.handlers:
                h.flush()

def add_options(parser):
    """ Add logging options to an argparse.ArgumentParser object """
    level = logging.getLevelName(default_level)
    parser.add_argument(
        '--log-level', '-l', default=level, help='Log level (default is %s)' % level)


def apply_args(args):
    """ Apply logging args from an argparse.ArgumentParser parsed args """
    global default_level
    default_level = logging.getLevelName(args.log_level.upper())


def get_logger(path=None, level=None, stream=None, colored=False):
    """ Get a python logging object named using the provided filename and preconfigured with a
    color-formatted stream handler
    """
    path = path or os.path.abspath((inspect.stack()[1])[1])
    root_log = LoggerRoot.get_base_logger(level=default_level, stream=sys.stdout, colored=colored)
    log = root_log.getChild(Path(path).stem)
    level = level if level is not None else root_log.level
    log.setLevel(level)
    if stream:
        ch = logging.StreamHandler(stream=stream)
        ch.setLevel(level)
        # attach the handler so the extra stream actually receives records
        log.addHandler(ch)
    log.propagate = True
    return log

def _add_file_handler(logger, log_dir, fh, formatter=None):
    """ Adds a file handler to a logger """
    Path(log_dir).mkdir(parents=True, exist_ok=True)
    if not formatter:
        log_format = '%(asctime)s %(name)s x_x[%(levelname)s] %(message)s'
        formatter = logging.Formatter(log_format)
    fh.setFormatter(formatter)
    logger.addHandler(fh)


def add_rotating_file_handler(logger, log_dir, log_file_prefix, max_bytes=10 * 1024 * 1024, backup_count=20,
                              formatter=None):
    """ Create and add a rotating file handler to a logger """
    fh = logging.handlers.RotatingFileHandler(
        str(Path(log_dir) / ('%s.log' % log_file_prefix)), maxBytes=max_bytes, backupCount=backup_count)
    _add_file_handler(logger, log_dir, fh, formatter)


def add_time_rotating_file_handler(logger, log_dir, log_file_prefix, when='midnight', formatter=None):
    """
    Create and add a time rotating file handler to a logger.
    Possible values for when are 'midnight', weekdays ('W0'-'W6', where 0 is Monday), and 'S', 'M', 'H' and 'D' for
    seconds, minutes, hours and days respectively (case-insensitive)
    """
    fh = logging.handlers.TimedRotatingFileHandler(
        str(Path(log_dir) / ('%s.log' % log_file_prefix)), when=when)
    _add_file_handler(logger, log_dir, fh, formatter)


def get_null_logger(name=None):
    """ Get a logger with a null handler """
    log = logging.getLogger(name if name else 'null')
    if not log.handlers:
        log.addHandler(logging.NullHandler())
        log.propagate = config.get("log.null_log_propagate", False)
    return log

class TqdmLog(object):
    """ Tqdm (progressbar) wrapped logging class """

    class _TqdmIO(BytesIO):
        """ IO wrapper class for Tqdm """

        def __init__(self, level=20, logger=None, *args, **kwargs):
            self._log = logger or get_null_logger()
            self._level = level
            self._buf = ''  # initialized so flush() is safe even before the first write()
            BytesIO.__init__(self, *args, **kwargs)

        def write(self, buf):
            self._buf = buf.strip('\r\n\t ')

        def flush(self):
            self._log.log(self._level, self._buf)

    def __init__(self, total, desc='', log_level=20, ascii=False, logger=None, smoothing=0, mininterval=5, initial=0):
        self._io = self._TqdmIO(level=log_level, logger=logger)
        self._tqdm = tqdm(total=total, desc=desc, file=self._io, ascii=ascii if not system() == 'Windows' else True,
                          smoothing=smoothing,
                          mininterval=mininterval, initial=initial)

    def update(self, n=None):
        if n is not None:
            self._tqdm.update(n=n)
        else:
            self._tqdm.update()

    def close(self):
        self._tqdm.close()
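The logger helpers above follow a common pattern: one root logger named `trains` with a stream handler, and per-file child loggers that propagate records up to it. The sketch below restates that wiring using only the standard library; `build_trains_root` and the `trains_demo` logger name are illustrative stand-ins, not the shipped API.

```python
import io
import logging


def build_trains_root(stream):
    """Minimal sketch (assumption, not the shipped implementation) of what
    LoggerRoot.get_base_logger sets up: a single root logger with one
    stream handler and upward propagation disabled."""
    root = logging.getLogger('trains_demo')  # stands in for the real 'trains' logger
    root.handlers = []  # keep the sketch idempotent across repeated calls
    handler = logging.StreamHandler(stream=stream)
    handler.setFormatter(logging.Formatter('%(name)s - %(levelname)s - %(message)s'))
    root.addHandler(handler)
    root.setLevel(logging.INFO)
    root.propagate = False
    return root


buf = io.StringIO()
root = build_trains_root(buf)
# get_logger() hands back a child named after the calling file's stem;
# its records propagate up to the root's handler
child = root.getChild('my_module')
child.info('hello')
print(buf.getvalue().strip())  # → trains_demo.my_module - INFO - hello
```

Note that `root.propagate = False` only stops records from climbing past the `trains` logger; children still reach it, which is what lets one handler serve every module.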
112  trains/debugging/timer.py  Normal file
@@ -0,0 +1,112 @@
""" Timing support """
import sys
import time

import six


class Timer(object):
    """ A class implementing a simple timer, with a reset option """

    def __init__(self):
        self._start_time = 0.
        self._diff = 0.
        self._total_time = 0.
        self._average_time = 0.
        self._calls = 0
        self.tic()

    def reset(self):
        self._start_time = 0.
        self._diff = 0.
        self.reset_average()

    def reset_average(self):
        """ Reset average counters (does not change current timer) """
        self._total_time = 0
        self._average_time = 0
        self._calls = 0

    def tic(self):
        try:
            # using time.time instead of time.clock because time.clock
            # does not normalize for multi-threading
            self._start_time = time.time()
        except Exception:
            pass

    def toc(self, average=True):
        self._diff = time.time() - self._start_time
        self._total_time += self._diff
        self._calls += 1
        self._average_time = self._total_time / self._calls
        if average:
            return self._average_time
        else:
            return self._diff

    @property
    def average_time(self):
        return self._average_time

    @property
    def total_time(self):
        return self._total_time

    def toc_with_reset(self, average=True, reset_if_calls=1000):
        """ Enable toc with reset (slightly inaccurate if a reset event occurs) """
        if self._calls > reset_if_calls:
            last_diff = time.time() - self._start_time
            self._start_time = time.time()
            self._total_time = last_diff
            self._average_time = 0
            self._calls = 0

        return self.toc(average=average)

class TimersMixin(object):
    def __init__(self):
        self._timers = {}

    def add_timers(self, *names):
        for name in names:
            self.add_timer(name)

    def add_timer(self, name, timer=None):
        if name in self._timers:
            raise ValueError('timer %s already exists' % name)
        timer = timer or Timer()
        self._timers[name] = timer
        return timer

    def get_timer(self, name, default=None):
        return self._timers.get(name, default)

    def get_timers(self):
        return self._timers

    def _call_timer(self, name, func, silent_fail=False):
        try:
            return func(self._timers[name])
        except KeyError:
            if not silent_fail:
                six.reraise(*sys.exc_info())

    def reset_timers(self, *names):
        for name in names:
            self._call_timer(name, lambda t: t.reset())

    def reset_average_timers(self, *names):
        for name in names:
            self._call_timer(name, lambda t: t.reset_average())

    def tic_timers(self, *names):
        for name in names:
            self._call_timer(name, lambda t: t.tic())

    def toc_timers(self, *names):
        return [self._call_timer(name, lambda t: t.toc()) for name in names]

    def toc_with_reset_timer(self, name, average=True, reset_if_calls=1000):
        return self._call_timer(name, lambda t: t.toc_with_reset(average, reset_if_calls))
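The Timer above boils down to two operations: `tic()` records a start timestamp, and `toc()` accumulates the elapsed span while maintaining a running average across calls. A condensed, self-contained restatement (`MiniTimer` is an illustrative sketch, not the shipped class):

```python
import time


class MiniTimer(object):
    """Condensed sketch of the Timer pattern: tic() marks a start time,
    toc() accumulates elapsed time and keeps a running average."""

    def __init__(self):
        self._total_time = 0.0
        self._calls = 0
        self.tic()

    def tic(self):
        # wall-clock time; normalizes correctly across threads
        self._start_time = time.time()

    def toc(self, average=True):
        diff = time.time() - self._start_time
        self._total_time += diff
        self._calls += 1
        average_time = self._total_time / self._calls
        return average_time if average else diff


timer = MiniTimer()
time.sleep(0.01)
elapsed = timer.toc(average=False)
print(elapsed >= 0.01)  # → True
```

`toc_with_reset` adds one wrinkle on top of this: once the call count exceeds a threshold, the accumulators are re-seeded from the latest interval so the running average tracks recent behavior instead of the whole history.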
3  trains/errors.py  Normal file
@@ -0,0 +1,3 @@
class UsageError(RuntimeError):
    """ An exception raised for illegal usage of trains objects """
    pass
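Since `UsageError` subclasses `RuntimeError`, callers can catch either type. A brief hedged sketch of the intended pattern (the `require_task_name` helper is hypothetical, purely for illustration):

```python
class UsageError(RuntimeError):
    """ An exception raised for illegal usage of trains objects """
    pass


def require_task_name(name):
    # hypothetical guard illustrating how UsageError signals misuse
    if not name:
        raise UsageError('a task name is required')
    return name


try:
    require_task_name('')
except UsageError as e:
    print('caught: %s' % e)  # → caught: a task name is required
```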
Some files were not shown because too many files have changed in this diff.