Initial beta version

This commit is contained in:
allegroai
2019-06-10 20:00:28 +03:00
parent 3cb9de58c3
commit f595afe6c8
121 changed files with 34975 additions and 0 deletions

README.md Normal file

@@ -0,0 +1,218 @@
# TRAINS - Magic Version Control & Experiment Manager for AI
<p style="font-size:1.2rem; font-weight:700;">"Because it's a jungle out there"</p>
Behind every great scientist are great repeatable methods. Sadly, this is easier said than done.
When talented scientists, engineers, or developers work on their own, a mess may be unavoidable. Yet, it may still be
manageable. However, with time and more people joining your project,
managing the clutter takes its toll on productivity.
As your project moves toward production,
visibility and provenance for scaling your deep-learning efforts are a must, but both
suffer as your team grows.
For teams or entire companies, TRAINS logs everything in one central server and takes on the responsibilities for visibility and provenance
so productivity does not suffer.
TRAINS records and manages various deep learning research workloads and does so with unbelievably small integration costs.
TRAINS is an auto-magical experiment manager that you can use productively with minimal integration and while
preserving your existing methods and practices. Use it on a daily basis to boost collaboration and visibility,
or use it to automatically collect your experimentation logs, outputs, and data to one centralized server for provenance.
![nice image here? Maybe collection of all the projects or main screen, this can also be an inspirational image insinuating clutter](img/woman-3441018_1920.jpg)
## Why Should I Use TRAINS?
TRAINS is our solution to a problem we share with countless other researchers and developers in the
machine learning/deep learning universe.
Training production-grade deep learning models is a glorious but messy process.
We built TRAINS to solve that problem. TRAINS tracks and controls the process by associating code version control, research projects, performance metrics, and model provenance.
TRAINS removes the mess but leaves the glory.
Choose TRAINS because...
* Sharing experiments with the team is difficult and gets even more difficult further up the chain.
* Like all of us, you lost a model and are left with no repeatable process.
* You set up a central location for TensorBoard and it exploded with a gazillion experiments.
* You accidentally threw away important results while trying to manually clean up the clutter.
* You do not associate the training code commit with the model or TensorBoard logs.
* You are storing model parameters in the checkpoint filename.
* You cannot find any other tool for comparing results, hyper-parameters and code commits.
* TRAINS requires **only two lines of code** for full integration.
* TRAINS is **free**.
## Main Features
* Seamless integration with leading frameworks, including: PyTorch, TensorFlow, Keras, and others coming soon!
* Track everything with two lines of code.
* Model logging that automatically associates models with code and the parameters used to train them, including initial weights logging.
* Multi-user process tracking and collaboration.
* Management capabilities including project management, filter-by-metric, and detailed experiment comparison.
* Centralized server for aggregating logs, records, and general bookkeeping.
* Automatically create a copy of models on centralized storage (TRAINS supports shared folders, S3, GS, and Azure is coming soon!).
* Support for Jupyter notebook (see the [trains-jupyter-plugin]()) and PyCharm remote debugging (see the [trains-pycharm-plugin]()).
* A field-tested, feature-rich SDK for your on-the-fly customization needs.
## TRAINS Magically Logs
TRAINS magically logs the following:
* Git repository, branch and commit id
* Hyper-parameters, including:
* ArgParser for command line parameters with currently used values
    * TensorFlow flags (absl-py)
* Manually passed parameter dictionary
* Initial model weights file
* Model snapshots
* stdout and stderr
* TensorBoard scalars, metrics, histograms, and images (tensorboardX is also supported; audio is coming soon)
* Matplotlib
## See for Yourself
We have a demo server up and running at [https://demoapp.trainsai.io](https://demoapp.trainsai.io) (it resets every 24 hours and all of the data is deleted).
You can test your code with it:
1. Install TRAINS:

    ```bash
    pip install trains
    ```

1. Add the following lines to your code:

    ```python
    from trains import Task
    task = Task.init(project_name="my_project", task_name="my_task")
    ```

1. Run your code. When TRAINS connects to the server, a link to the experiment page is printed.

1. In the Web-App, view your experiment's parameters, model, and TensorBoard metrics.
![GIF screen-shot here. If the Gif looks bad, a few png screen grabs:
Home Page
Projects Page
Experiment Page with experiment open tab execution
Experiment Page with experiment open tab model
Experiment Page with experiment open tab results
Results Page
Comparison Page
Parameters
Graphs
Images
Experiment Models Page]
## How TRAINS Works
TRAINS is composed of the following:
* the [trains-server]()
* the [Web-App]() (web user interface)
* the Python SDK (auto-magically connects your code, see [Using TRAINS (Example)](#using-trains-example)).
The following diagram illustrates the interaction of the TRAINS-server and a GPU machine:
<pre>
TRAINS-server
+--------------------------------------------------------------------+
| |
| Server Docker Elastic Docker Mongo Docker |
| +-------------------------+ +---------------+ +------------+ |
| | Pythonic Server | | | | | |
| | +-----------------+ | | ElasticSearch | | MongoDB | |
| | | WEB server | | | | | | |
| | | Port 8080 | | | | | | |
| | +--------+--------+ | | | | | |
| | | | | | | | |
| | +--------+--------+ | | | | | |
| | | API server +----------------------------+ | |
| | | Port 8008 +---------+ | | | |
| | +-----------------+ | +-------+-------+ +-----+------+ |
| | | | | |
| | +-----------------+ | +---+----------------+------+ |
| | | File Server +-------+ | Host Storage | |
| | | Port 8081 | | +-----+ | |
| | +-----------------+ | +---------------------------+ |
| +------------+------------+ |
+---------------|----------------------------------------------------+
|HTTP
+--------+
GPU Machine |
+------------------------|-------------------------------------------+
| +------------------|--------------+ |
| | Training | | +---------------------+ |
| | Code +---+------------+ | | TRAINS configuration| |
| | | TRAINS - SDK | | | ~/trains.conf | |
| | | +------+ | |
| | +----------------+ | +---------------------+ |
| +---------------------------------+ |
+--------------------------------------------------------------------+
</pre>
## Installing and Configuring TRAINS
1. Install the trains-server docker (see [Installing the TRAINS Server](../trains_server)).

1. Install the TRAINS package:

    ```bash
    pip install trains
    ```

1. Run the initial configuration wizard to set up the trains-server connection (ip:port and user credentials):

    ```bash
    trains-init
    ```

After installing and configuring, your configuration file is located in `~/trains.conf`. View a sample configuration file [here]([link to git]).
## Using TRAINS (Example)
Add these two lines of code to your script:

```python
from trains import Task
task = Task.init(project_name, task_name)
```

* If no project name is provided, the repository name is used.
* If no task (experiment) name is provided, the main filename is used as the experiment name.
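The naming fallbacks above can be sketched as plain Python. This is only an illustration of the documented rules, not the actual `Task.init` internals, and whether the filename extension is stripped is an assumption:

```python
# Hypothetical sketch of the documented naming fallbacks:
# missing project name -> repository name,
# missing task name -> main script filename (extension stripped here by assumption).
import os


def resolve_names(project_name, task_name, repo_name, script_path):
    project = project_name or repo_name
    task = task_name or os.path.splitext(os.path.basename(script_path))[0]
    return project, task


print(resolve_names(None, None, 'my_repo', '/home/me/train.py'))
```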
Executing your script prints a direct link to the currently running experiment page, for example:
```bash
TRAINS Metrics page:
https://demoapp.trainsai.io/projects/76e5e2d45e914f52880621fe64601e85/experiments/241f06ae0f5c4b27b8ce8b64890ce152/output/log
```
*[Add GIF screenshots here]*
For more examples and use cases, see [examples](link docs/examples/).
## Who Supports TRAINS?
The people behind *allegro.ai*.
We build deep learning pipelines and infrastructure for enterprise companies.
We built TRAINS to track and control the glorious
but messy process of training production-grade deep learning models.
We are committed to vigorously supporting and expanding the capabilities of TRAINS,
because it is not only our beloved creation, we also use it daily.
## Why Are We Releasing TRAINS?
We believe TRAINS is ground-breaking. We wish to establish new standards of experiment management in
machine- and deep-learning.
Only the greater community can help us do that.
We promise to always remain backward compatible. If you start working with TRAINS today, even though this code is still in beta, your logs and data will always upgrade with you.
## License
Apache License, Version 2.0 (see the [LICENSE](https://www.apache.org/licenses/LICENSE-2.0.html) for more information)
## Guidelines for Contributing
See the TRAINS [Guidelines for Contributing](contributing.md).
## FAQ
See the TRAINS [FAQ](faq.md).
<p style="font-size:0.9rem; font-weight:700; font-style:italic">May the force (and the goddess of learning rates) be with you!</p>

docs/contributing.md Normal file

@@ -0,0 +1,54 @@
# Guidelines for Contributing
Firstly, we thank you for taking the time to contribute!
The following is a set of guidelines for contributing to TRAINS.
These are primarily guidelines, not rules.
Use your best judgment and feel free to propose changes to this document in a pull request.
## Reporting Bugs
This section guides you through submitting a bug report for TRAINS.
By following these guidelines, you
help maintainers and the community understand your report, reproduce the behavior, and find related reports.
Before creating bug reports, please check whether the bug you want to report already appears [here](link to issues).
You may discover that you do not need to create a bug report.
When you are creating a bug report, please include as much detail as possible.
**Note**: If you find a **Closed** issue that may be the same issue which you are currently experiencing,
then open a **New** issue and include a link to the original (Closed) issue in the body of your new one.
Explain the problem and include additional details to help maintainers reproduce the problem:
* **Use a clear and descriptive title** for the issue to identify the problem.
* **Describe the exact steps necessary to reproduce the problem** in as much detail as possible. Please do not just summarize what you did. Make sure to explain how you did it.
* **Provide the specific environment setup.** Include the `pip freeze` output, specific environment variables, Python version, and other relevant information.
* **Provide specific examples to demonstrate the steps.** Include links to files or GitHub projects, or copy/paste snippets which you use in those examples.
* **If you are reporting any TRAINS crash,** include a crash report with a stack trace from the operating system. Make sure to add the crash report in the issue and place it in a [code block](https://help.github.com/en/articles/getting-started-with-writing-and-formatting-on-github#multiple-lines),
a [file attachment](https://help.github.com/articles/file-attachments-on-issues-and-pull-requests/), or just put it in a [gist](https://gist.github.com/) (and provide link to that gist).
* **Describe the behavior you observed after following the steps** and the exact problem with that behavior.
* **Explain which behavior you expected to see and why.**
* **For Web-App issues, please include screenshots and animated GIFs** which recreate the described steps and clearly demonstrate the problem. You can use [LICEcap](https://www.cockos.com/licecap/) to record GIFs on macOS and Windows, and [silentcast](https://github.com/colinkeenan/silentcast) or [byzanz](https://github.com/threedaymonk/byzanz) on Linux.
## Suggesting Enhancements
This section guides you through submitting an enhancement suggestion for TRAINS, including
completely new features and minor improvements to existing functionality.
By following these guidelines, you help maintainers and the community understand your suggestion and find related suggestions.
Enhancement suggestions are tracked as GitHub issues. After you determine which repository your enhancement suggestion is related to, create an issue on that repository and provide the following:
* **A clear and descriptive title** for the issue to identify the suggestion.
* **A step-by-step description of the suggested enhancement** in as much detail as possible.
* **Specific examples to demonstrate the steps.** Include copy/pasteable snippets which you use in those examples as [Markdown code blocks](https://help.github.com/articles/markdown-basics/#multiple-lines).
* **Describe the current behavior and explain which behavior you expected to see instead and why.**
* **Include screenshots or animated GIFs** which help you demonstrate the steps or point out the part of TRAINS which the suggestion is related to. You can use [LICEcap](https://www.cockos.com/licecap/) to record GIFs on macOS and Windows, and [silentcast](https://github.com/colinkeenan/silentcast) or [byzanz](https://github.com/threedaymonk/byzanz) on Linux.

docs/faq.md Normal file

@@ -0,0 +1,160 @@
# FAQ
**Can I store more information on the models? For example, can I store enumeration of classes?**
YES!
Use the SDK `set_model_label_enumeration` method:
```python
Task.current_task().set_model_label_enumeration({'background': 0, 'cat': 1, 'dog': 2})
```
**Can I store the model configuration file as well?**
YES!
Use the SDK `set_model_design` method:
```python
Task.current_task().set_model_design('a very long text of the configuration file content')
```
**I want to add more graphs, not just with Tensorboard. Is this supported?**
YES!
Use an SDK [Logger](link to git) object. An instance can always be retrieved with `Task.current_task().get_logger()`:
```python
logger = Task.current_task().get_logger()
logger.report_scalar("loss", "classification", iteration=42, value=1.337)
```
TRAINS supports scalars, plots, 2d/3d scatter diagrams, histograms, surface diagrams, confusion matrices, images, and text logging.
An example can be found [here](docs/manual_log.py).
**I noticed that all of my experiments appear as "Training". Are there other options?**
YES!
When creating experiments and calling `Task.init`, you can pass an experiment type.
The currently supported types are `Task.TaskTypes.training` and `Task.TaskTypes.testing`:
```python
task = Task.init(project_name, task_name, Task.TaskTypes.testing)
```
If you feel we should add a few more, let us know in the [issues]() section.
**I noticed I keep getting the message "warning: uncommitted code". What does it mean?**
TRAINS not only detects your current repository and git commit,
but it also warns you if you are using uncommitted code. TRAINS does this
because uncommitted code means it will be difficult to reproduce this experiment.
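The idea behind such a check can be illustrated with a simplified stand-in (this is not the TRAINS implementation): any non-empty entry in `git status --porcelain` output means the working tree is dirty, i.e. the code running differs from the last commit:

```python
# Simplified stand-in for an uncommitted-code check (illustration only,
# NOT the TRAINS implementation): a non-empty `git status --porcelain`
# output means the working tree has uncommitted changes.
def has_uncommitted_changes(porcelain_output):
    return any(line.strip() for line in porcelain_output.splitlines())


print(has_uncommitted_changes(' M train.py\n?? new_file.py\n'))  # True
print(has_uncommitted_changes(''))                               # False
```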
**Can anything be done about running uncommitted code?**
YES!
TRAINS currently stores the git diff together with the project.
The Web-App will soon present the git diff as well. This is coming very soon!
**I read that there is a feature for centralized model storage. How do I use it?**
Pass the `output_uri` parameter to `Task.init`, for example:
```python
Task.init(project_name, task_name, output_uri='/mnt/shared/folder')
```
All of the stored snapshots are copied into a subfolder whose name contains the task ID, for example:
`/mnt/shared/folder/task_6ea4f0b56d994320a713aeaf13a86d9d/models/`
Other options include:
```python
Task.init(project_name, task_name, output_uri='s3://bucket/folder')
```
```python
Task.init(project_name, task_name, output_uri='gs://bucket/folder')
```
These require configuring the cloud storage credentials in `~/trains.conf` (see an [example](v)).
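The destination layout described above can be sketched as simple path arithmetic. The folder names mirror the documented example only; this is not actual TRAINS upload code:

```python
# Sketch of the documented snapshot destination layout:
# <output_uri>/task_<task_id>/models/<filename>
# (path arithmetic only; hypothetical helper, not the TRAINS SDK).
import posixpath


def snapshot_destination(output_uri, task_id, filename):
    return posixpath.join(output_uri, 'task_%s' % task_id, 'models', filename)


print(snapshot_destination('/mnt/shared/folder',
                           '6ea4f0b56d994320a713aeaf13a86d9d',
                           'model.ckpt'))
# -> /mnt/shared/folder/task_6ea4f0b56d994320a713aeaf13a86d9d/models/model.ckpt
```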
**I am training multiple models at the same time, but I only see one of them. What happened?**
This will be fixed in a future version. TRAINS does support multiple models
from the same task/experiment, so you can find all of them in the project's Models tab.
In the Task view, however, only the last model is currently presented.
**Can I log input and output models manually?**
YES!
See [InputModel]() and [OutputModel]().
For example:
```python
input_model = InputModel.import_model(link_to_initial_model_file)
Task.current_task().connect(input_model)
OutputModel(Task.current_task()).update_weights(link_to_new_model_file_here)
```
**I am using Jupyter Notebook. Is this supported?**
YES!
Jupyter Notebook is supported.
**I do not use ArgParser for hyper-parameters. Do you have a solution?**
YES!
TRAINS supports using a Python dictionary for hyper-parameter logging.
```python
parameters_dict = Task.current_task().connect(parameters_dict)
```
From this point onward, not only are the dictionary key/value pairs stored, but also any change to the dictionary is automatically stored.
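The connected-dictionary semantics can be illustrated with a toy stand-in (this is NOT the TRAINS implementation): once writes go through a wrapper, every change after the initial connect can be recorded:

```python
# Toy illustration of "connected dictionary" semantics (hypothetical,
# NOT the TRAINS SDK): item assignment is intercepted so that every
# change made after the initial connect can be recorded.
class ConnectedDict(dict):
    def __init__(self, *args, **kwargs):
        super(ConnectedDict, self).__init__(*args, **kwargs)
        self.changes = []  # record of (key, value) writes

    def __setitem__(self, key, value):
        self.changes.append((key, value))
        super(ConnectedDict, self).__setitem__(key, value)


params = ConnectedDict({'batch_size': 128})
params['learning_rate'] = 0.01   # recorded
params['batch_size'] = 64        # recorded as well
print(params.changes)
```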
**Git is not well supported in Jupyter. We just gave up on properly committing our code. Do you have a solution?**
YES!
Check our [trains-jupyter-plugin](). It is a Jupyter plugin that allows you to commit your notebook directly from Jupyter. It also saves the Python version of the code and creates an updated `requirements.txt` so you know which packages you were using.
**Can I use TRAINS with scikit-learn?**
YES!
scikit-learn is supported. Everything you do is logged, with the caveat that models are not logged automatically.
Models are not logged automatically because, in most cases, scikit-learn simply pickles the object to a file, so there is no underlying framework to connect to.
**I am working with PyCharm and remotely debugging a machine, but the git repo is not detected. Do you have a solution?**
YES!
This is such a common occurrence that we created a PyCharm plugin that allows for a remote debugger to grab your local repository / commit ID. See our [trains-pycharm-plugin]() repository for instructions and [latest release]().
**How do I know when a new version comes out?**
Unfortunately, TRAINS currently does not support auto-update checks. We hope to add this soon.
**Sometimes experiments appear as running when they are not. What is going on?**
When the Python process exits in an orderly fashion, TRAINS closes the experiment.
If the process crashes, the stop signal is sometimes missed. In that case, you can safely right-click the experiment in the Web-App and stop it.
**In the experiment log tab, I'm missing the first log lines. Where are they?**
Unfortunately, due to speed/optimization considerations, we display only the last several hundred log lines. The full log can be downloaded from the Web-App.

examples/absl_example.py Normal file

@@ -0,0 +1,43 @@
# TRAINS - example code, absl logging
#
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
from absl import app
from absl import flags
from absl import logging
from trains import Task
FLAGS = flags.FLAGS
flags.DEFINE_string('echo', None, 'Text to echo.')
flags.DEFINE_string('another_str', 'My string', 'A string', module_name='test')
task = Task.init(project_name='examples', task_name='absl example')
flags.DEFINE_integer('echo3', 3, 'Text to echo.')
flags.DEFINE_string('echo5', '5', 'Text to echo.', module_name='test')
parameters = {
    'list': [1, 2, 3],
    'dict': {'a': 1, 'b': 2},
    'int': 3,
    'float': 2.2,
    'string': 'my string',
}
parameters = task.connect(parameters)
def main(_):
    print('Running under Python {0[0]}.{0[1]}.{0[2]}'.format(sys.version_info), file=sys.stderr)
    logging.info('echo is %s.', FLAGS.echo)


if __name__ == '__main__':
    app.run(main)

examples/jupyter.ipynb Normal file

File diff suppressed because one or more lines are too long


@@ -0,0 +1,113 @@
# TRAINS - Keras with Tensorboard example code, automatic logging model and Tensorboard outputs
#
# Train a simple deep NN on the MNIST dataset.
# Gets to 98.40% test accuracy after 20 epochs
# (there is *a lot* of margin for parameter tuning).
# 2 seconds per epoch on a K520 GPU.
from __future__ import print_function
import numpy as np
import tensorflow
from keras.callbacks import TensorBoard, ModelCheckpoint
from keras.datasets import mnist
from keras.models import Sequential, Model
from keras.layers.core import Dense, Dropout, Activation
from keras.optimizers import SGD, Adam, RMSprop
from keras.utils import np_utils
from keras.models import load_model, save_model, model_from_json
from trains import Task
class TensorBoardImage(TensorBoard):
    @staticmethod
    def make_image(tensor):
        import io
        import tensorflow as tf
        from PIL import Image
        tensor = np.stack((tensor, tensor, tensor), axis=2)
        height, width, channels = tensor.shape
        image = Image.fromarray(tensor)
        output = io.BytesIO()
        image.save(output, format='PNG')
        image_string = output.getvalue()
        output.close()
        return tf.Summary.Image(height=height,
                                width=width,
                                colorspace=channels,
                                encoded_image_string=image_string)

    def on_epoch_end(self, epoch, logs={}):
        super().on_epoch_end(epoch, logs)
        import tensorflow as tf
        images = self.validation_data[0]  # 0 - data; 1 - labels
        img = (255 * images[0].reshape(28, 28)).astype('uint8')
        image = self.make_image(img)
        summary = tf.Summary(value=[tf.Summary.Value(tag='image', image=image)])
        self.writer.add_summary(summary, epoch)
batch_size = 128
nb_classes = 10
nb_epoch = 6
# the data, shuffled and split between train and test sets
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(60000, 784)
X_test = X_test.reshape(10000, 784)
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255.
X_test /= 255.
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')
# convert class vectors to binary class matrices
Y_train = np_utils.to_categorical(y_train, nb_classes)
Y_test = np_utils.to_categorical(y_test, nb_classes)
model = Sequential()
model.add(Dense(512, input_shape=(784,)))
model.add(Activation('relu'))
# model.add(Dropout(0.2))
model.add(Dense(512))
model.add(Activation('relu'))
# model.add(Dropout(0.2))
model.add(Dense(10))
model.add(Activation('softmax'))
model2 = Sequential()
model2.add(Dense(512, input_shape=(784,)))
model2.add(Activation('relu'))
model.summary()
model.compile(loss='categorical_crossentropy',
              optimizer=RMSprop(),
              metrics=['accuracy'])
# Connecting TRAINS
task = Task.init(project_name='examples', task_name='Keras with TensorBoard example')
# setting model outputs
labels = dict(('digit_%d' % i, i) for i in range(10))
task.set_model_label_enumeration(labels)
board = TensorBoard(histogram_freq=1, log_dir='/tmp/histogram_example', write_images=False)
model_store = ModelCheckpoint(filepath='/tmp/histogram_example/weight.{epoch}.hdf5')
# load previous model, if it is there
try:
    model.load_weights('/tmp/histogram_example/weight.1.hdf5')
except Exception:
    pass
history = model.fit(X_train, Y_train,
                    batch_size=batch_size, epochs=nb_epoch,
                    callbacks=[board, model_store],
                    verbose=1, validation_data=(X_test, Y_test))
score = model.evaluate(X_test, Y_test, verbose=0)
print('Test score:', score[0])
print('Test accuracy:', score[1])


@@ -0,0 +1,29 @@
# TRAINS - Example of manual model configuration
#
import torch
from trains import Task
task = Task.init(project_name='examples', task_name='Manual model configuration')
# create a model
model = torch.nn.Module()
# store dictionary of definition for a specific network design
model_config_dict = {
    'value': 13.37,
    'dict': {'sub_value': 'string'},
    'list_of_ints': [1, 2, 3, 4],
}
task.set_model_config(config_dict=model_config_dict)
# or read from a config file (this will override the previous configuration dictionary)
# task.set_model_config(config_text='this is just a blob\nof text from a configuration file')
# store the label enumeration the model is training for
task.set_model_label_enumeration({'background': 0, 'cat': 1, 'dog': 2})
print('Any model stored from this point onwards will contain both model_config and label_enumeration')
# store the model; it will include the task's network configuration and label enumeration
torch.save(model, '/tmp/model')
print('Model saved')


@@ -0,0 +1,51 @@
# TRAINS - Example of manual graphs and statistics reporting
#
import numpy as np
import logging
from trains import Task
task = Task.init(project_name='examples', task_name='Manual reporting')
# example python logger
logging.getLogger().setLevel('DEBUG')
logging.debug('This is a debug message')
logging.info('This is an info message')
logging.warning('This is a warning message')
logging.error('This is an error message')
logging.critical('This is a critical message')
# get TRAINS logger object for any metrics / reports
logger = task.get_logger()
# log text
logger.console("hello")
# report scalar values
logger.report_scalar("example_scalar", "series A", iteration=0, value=100)
logger.report_scalar("example_scalar", "series A", iteration=1, value=200)
# report histogram
histogram = np.random.randint(10, size=10)
logger.report_vector("example_histogram", "random histogram", iteration=1, values=histogram)
# report confusion matrix
confusion = np.random.randint(10, size=(10, 10))
logger.report_matrix("example_confusion", "ignored", iteration=1, matrix=confusion)
# report 2d scatter plot
scatter2d = np.hstack((np.atleast_2d(np.arange(0, 10)).T, np.random.randint(10, size=(10, 1))))
logger.report_scatter2d("example_scatter", "series_xy", iteration=1, scatter=scatter2d)
# report 3d scatter plot
scatter3d = np.random.randint(10, size=(10, 3))
logger.report_scatter3d("example_scatter_3d", "series_xyz", iteration=1, scatter=scatter3d)
# report image
m = np.eye(256, 256, dtype=np.uint8)*255
logger.report_image_and_upload("fail cases", "image uint", iteration=1, matrix=m)
m = np.eye(256, 256, dtype=np.float)
logger.report_image_and_upload("fail cases", "image float", iteration=1, matrix=m)
# flush reports (otherwise it will be flushed in the background, every couple of seconds)
logger.flush()


@@ -0,0 +1,36 @@
# TRAINS - Example of Matplotlib integration and reporting
#
import numpy as np
import matplotlib.pyplot as plt
from trains import Task
task = Task.init(project_name='examples', task_name='Matplotlib example')
# create plot
N = 50
x = np.random.rand(N)
y = np.random.rand(N)
colors = np.random.rand(N)
area = (30 * np.random.rand(N))**2 # 0 to 15 point radii
plt.scatter(x, y, s=area, c=colors, alpha=0.5)
plt.show()
# create another plot - with a name
x = np.linspace(0, 10, 30)
y = np.sin(x)
plt.plot(x, y, 'o', color='black')
plt.show()
# create image plot
m = np.eye(256, 256, dtype=np.uint8)
plt.imshow(m)
plt.show()
# create image plot - with a name
m = np.eye(256, 256, dtype=np.uint8)
plt.imshow(m)
plt.title('Image Title')
plt.show()
print('This is a Matplotlib example')


@@ -0,0 +1,479 @@
# TRAINS - Example of Pytorch and matplotlib integration and reporting
#
"""
Neural Transfer Using PyTorch
=============================
**Author**: `Alexis Jacq <https://alexis-jacq.github.io>`_
**Edited by**: `Winston Herring <https://github.com/winston6>`_
Introduction
------------
This tutorial explains how to implement the `Neural-Style algorithm <https://arxiv.org/abs/1508.06576>`__
developed by Leon A. Gatys, Alexander S. Ecker and Matthias Bethge.
Neural-Style, or Neural-Transfer, allows you to take an image and
reproduce it with a new artistic style. The algorithm takes three images,
an input image, a content-image, and a style-image, and changes the input
to resemble the content of the content-image and the artistic style of the style-image.
.. figure:: /_static/img/neural-style/neuralstyle.png
:alt: content1
"""
######################################################################
# Underlying Principle
# --------------------
#
# The principle is simple: we define two distances, one for the content
# (:math:`D_C`) and one for the style (:math:`D_S`). :math:`D_C` measures how different the content
# is between two images while :math:`D_S` measures how different the style is
# between two images. Then, we take a third image, the input, and
# transform it to minimize both its content-distance with the
# content-image and its style-distance with the style-image. Now we can
# import the necessary packages and begin the neural transfer.
#
# Importing Packages and Selecting a Device
# -----------------------------------------
# Below is a list of the packages needed to implement the neural transfer.
#
# - ``torch``, ``torch.nn``, ``numpy`` (indispensable packages for
# neural networks with PyTorch)
# - ``torch.optim`` (efficient gradient descents)
# - ``PIL``, ``PIL.Image``, ``matplotlib.pyplot`` (load and display
# images)
# - ``torchvision.transforms`` (transform PIL images into tensors)
# - ``torchvision.models`` (train or load pre-trained models)
# - ``copy`` (to deep copy the models; system package)
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision.models as models
import copy
from trains import Task
task = Task.init(project_name='examples', task_name='pytorch with matplotlib example', task_type=Task.TaskTypes.testing)
######################################################################
# Next, we need to choose which device to run the network on and import the
# content and style images. Running the neural transfer algorithm on large
# images takes longer and will go much faster when running on a GPU. We can
# use ``torch.cuda.is_available()`` to detect if there is a GPU available.
# Next, we set the ``torch.device`` for use throughout the tutorial. Also the ``.to(device)``
# method is used to move tensors or modules to a desired device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
######################################################################
# Loading the Images
# ------------------
#
# Now we will import the style and content images. The original PIL images have values between 0 and 255, but when
# transformed into torch tensors, their values are converted to be between
# 0 and 1. The images also need to be resized to have the same dimensions.
# An important detail to note is that neural networks from the
# torch library are trained with tensor values ranging from 0 to 1. If you
# try to feed the networks with 0 to 255 tensor images, then the activated
# feature maps will be unable to sense the intended content and style.
# However, pre-trained networks from the Caffe library are trained with 0
# to 255 tensor images.
#
#
# .. Note::
# Here are links to download the images required to run the tutorial:
# `picasso.jpg <https://pytorch.org/tutorials/_static/img/neural-style/picasso.jpg>`__ and
# `dancing.jpg <https://pytorch.org/tutorials/_static/img/neural-style/dancing.jpg>`__.
# Download these two images and add them to a directory
# with name ``images`` in your current working directory.
# desired size of the output image
imsize = 512 if torch.cuda.is_available() else 128 # use small size if no gpu
loader = transforms.Compose([
    transforms.Resize(imsize),  # scale imported image
    transforms.ToTensor()])  # transform it into a torch tensor
def image_loader(image_name):
    image = Image.open(image_name)
    # fake batch dimension required to fit network's input dimensions
    image = loader(image).unsqueeze(0)
    return image.to(device, torch.float)
style_img = image_loader("./samples/picasso.jpg")
content_img = image_loader("./samples/dancing.jpg")
assert style_img.size() == content_img.size(), \
    "we need to import style and content images of the same size"
######################################################################
# Now, let's create a function that displays an image by reconverting a
# copy of it to PIL format and displaying the copy using
# ``plt.imshow``. We will try displaying the content and style images
# to ensure they were imported correctly.
unloader = transforms.ToPILImage() # reconvert into PIL image
plt.ion()
def imshow(tensor, title=None):
    image = tensor.cpu().clone()  # we clone the tensor to not do changes on it
    image = image.squeeze(0)  # remove the fake batch dimension
    image = unloader(image)
    plt.imshow(image)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated
plt.figure()
imshow(style_img, title='Style Image')
plt.figure()
imshow(content_img, title='Content Image')
######################################################################
# Loss Functions
# --------------
# Content Loss
# ~~~~~~~~~~~~
#
# The content loss is a function that represents a weighted version of the
# content distance for an individual layer. The function takes the feature
# maps :math:`F_{XL}` of a layer :math:`L` in a network processing input :math:`X` and returns the
# weighted content distance :math:`w_{CL}.D_C^L(X,C)` between the image :math:`X` and the
# content image :math:`C`. The feature maps of the content image(:math:`F_{CL}`) must be
# known by the function in order to calculate the content distance. We
# implement this function as a torch module with a constructor that takes
# :math:`F_{CL}` as an input. The distance :math:`\|F_{XL} - F_{CL}\|^2` is the mean square error
# between the two sets of feature maps, and can be computed using ``nn.MSELoss``.
#
# We will add this content loss module directly after the convolution
# layer(s) that are being used to compute the content distance. This way,
# each time the network is fed an input image, the content losses will be
# computed at the desired layers, and thanks to autograd, all the
# gradients will be computed. To make the content loss layer
# transparent, we must define a ``forward`` method that computes the content
# loss and then returns the layer's input. The computed loss is saved as a
# parameter of the module.
#
class ContentLoss(nn.Module):
def __init__(self, target, ):
super(ContentLoss, self).__init__()
# we 'detach' the target content from the tree used
# to dynamically compute the gradient: this is a stated value,
# not a variable. Otherwise the forward method of the criterion
# will throw an error.
self.target = target.detach()
def forward(self, input):
self.loss = F.mse_loss(input, self.target)
return input
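A small self-contained sketch (not part of the tutorial) of the "transparent layer" behavior: ``forward`` returns its input unchanged while recording the loss as a side effect. The class below mirrors ``ContentLoss`` with its own imports so it can run standalone.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ContentLossDemo(nn.Module):
    """Mirrors ContentLoss: a pass-through layer that records an MSE loss."""
    def __init__(self, target):
        super(ContentLossDemo, self).__init__()
        self.target = target.detach()

    def forward(self, input):
        self.loss = F.mse_loss(input, self.target)
        return input  # the input flows through unchanged

target = torch.rand(1, 8, 4, 4)
layer = ContentLossDemo(target)
out = layer(target.clone())   # feeding the target itself gives zero loss
```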
######################################################################
# .. Note::
# **Important detail**: although this module is named ``ContentLoss``, it
# is not a true PyTorch Loss function. If you want to define your content
# loss as a PyTorch Loss function, you have to create a PyTorch autograd function
# to recompute/implement the gradient manually in the ``backward``
# method.
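As a hedged sketch of what the note describes (the names here are illustrative, not from the tutorial), a "true" loss can be written as a ``torch.autograd.Function`` with an explicit ``backward``:

```python
import torch

class MSELossFn(torch.autograd.Function):
    """An MSE loss with a hand-written gradient, for illustration only."""
    @staticmethod
    def forward(ctx, input, target):
        ctx.save_for_backward(input, target)
        return ((input - target) ** 2).mean()

    @staticmethod
    def backward(ctx, grad_output):
        input, target = ctx.saved_tensors
        # d/d(input) of mean((input - target)^2) = 2 * (input - target) / N
        grad_input = 2.0 * (input - target) / input.numel() * grad_output
        return grad_input, None  # no gradient w.r.t. the target

x = torch.rand(4, requires_grad=True)
target = torch.rand(4)
MSELossFn.apply(x, target).backward()
```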
######################################################################
# Style Loss
# ~~~~~~~~~~
#
# The style loss module is implemented similarly to the content loss
# module. It will act as a transparent layer in a
# network that computes the style loss of that layer. In order to
# calculate the style loss, we need to compute the gram matrix :math:`G_{XL}`. A gram
# matrix is the result of multiplying a given matrix by its transposed
# matrix. In this application the given matrix is a reshaped version of
# the feature maps :math:`F_{XL}` of a layer :math:`L`. :math:`F_{XL}` is reshaped to form :math:`\hat{F}_{XL}`, a :math:`K`\ x\ :math:`N`
# matrix, where :math:`K` is the number of feature maps at layer :math:`L` and :math:`N` is the
# length of any vectorized feature map :math:`F_{XL}^k`. For example, the first line
# of :math:`\hat{F}_{XL}` corresponds to the first vectorized feature map :math:`F_{XL}^1`.
#
# Finally, the gram matrix must be normalized by dividing each element by
# the total number of elements in the matrix. This normalization is to
# counteract the fact that :math:`\hat{F}_{XL}` matrices with a large :math:`N` dimension yield
# larger values in the Gram matrix. These larger values will cause the
# first layers (before pooling layers) to have a larger impact during the
# gradient descent. Style features tend to be in the deeper layers of the
# network so this normalization step is crucial.
#
def gram_matrix(input):
a, b, c, d = input.size() # a=batch size(=1)
# b=number of feature maps
# (c,d)=dimensions of a f. map (N=c*d)
    features = input.view(a * b, c * d)  # reshape F_XL into \hat F_XL
G = torch.mm(features, features.t()) # compute the gram product
    # we 'normalize' the values of the gram matrix
    # by dividing by the number of elements in each feature map.
return G.div(a * b * c * d)
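A quick sanity check (not part of the tutorial) of ``gram_matrix``: the result is a ``K x K`` matrix, where ``K`` is the number of feature maps, and it is symmetric since :math:`G = \hat{F}_{XL}\hat{F}_{XL}^T`. The function is repeated here so the sketch runs standalone.

```python
import torch

def gram_matrix(input):  # same definition as in the tutorial
    a, b, c, d = input.size()
    features = input.view(a * b, c * d)
    G = torch.mm(features, features.t())
    return G.div(a * b * c * d)

x = torch.rand(1, 5, 4, 4)  # 5 feature maps of size 4x4
G = gram_matrix(x)          # a 5x5 symmetric matrix
```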
######################################################################
# Now the style loss module looks almost exactly like the content loss
# module. The style distance is also computed using the mean square
# error between :math:`G_{XL}` and :math:`G_{SL}`.
#
class StyleLoss(nn.Module):
def __init__(self, target_feature):
super(StyleLoss, self).__init__()
self.target = gram_matrix(target_feature).detach()
def forward(self, input):
G = gram_matrix(input)
self.loss = F.mse_loss(G, self.target)
return input
######################################################################
# Importing the Model
# -------------------
#
# Now we need to import a pre-trained neural network. We will use a 19
# layer VGG network like the one used in the paper.
#
# PyTorch's implementation of VGG is a module divided into two child
# ``Sequential`` modules: ``features`` (containing convolution and pooling layers),
# and ``classifier`` (containing fully connected layers). We will use the
# ``features`` module because we need the output of the individual
# convolution layers to measure content and style loss. Some layers have
# different behavior during training than evaluation, so we must set the
# network to evaluation mode using ``.eval()``.
#
cnn = models.vgg19(pretrained=True).features.to(device).eval()
######################################################################
# Additionally, VGG networks are trained on images with each channel
# normalized by mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225].
# We will use them to normalize the image before sending it into the network.
#
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
# create a module to normalize input image so we can easily put it in a
# nn.Sequential
class Normalization(nn.Module):
def __init__(self, mean, std):
super(Normalization, self).__init__()
# .view the mean and std to make them [C x 1 x 1] so that they can
# directly work with image Tensor of shape [B x C x H x W].
# B is batch size. C is number of channels. H is height and W is width.
self.mean = torch.tensor(mean).view(-1, 1, 1)
self.std = torch.tensor(std).view(-1, 1, 1)
def forward(self, img):
# normalize img
return (img - self.mean) / self.std
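A standalone sketch of the broadcasting the comment describes: a ``[C x 1 x 1]`` mean/std pair broadcasts against a ``[B x C x H x W]`` image, so each channel is normalized independently.

```python
import torch

mean = torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)  # C x 1 x 1
std = torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)

img = torch.full((1, 3, 8, 8), 0.5)   # a flat gray stand-in image
normalized = (img - mean) / std       # shape is preserved: B x C x H x W
```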
######################################################################
# A ``Sequential`` module contains an ordered list of child modules. For
# instance, ``vgg19.features`` contains a sequence (Conv2d, ReLU, MaxPool2d,
# Conv2d, ReLU…) aligned in the right order of depth. We need to add our
# content loss and style loss layers immediately after the convolution
# layer they are detecting. To do this we must create a new ``Sequential``
# module that has content loss and style loss modules correctly inserted.
#
# desired depth layers to compute style/content losses :
content_layers_default = ['conv_4']
style_layers_default = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
style_img, content_img,
content_layers=content_layers_default,
style_layers=style_layers_default):
cnn = copy.deepcopy(cnn)
# normalization module
normalization = Normalization(normalization_mean, normalization_std).to(device)
    # just in order to have iterable access to the content/style losses
content_losses = []
style_losses = []
# assuming that cnn is a nn.Sequential, so we make a new nn.Sequential
# to put in modules that are supposed to be activated sequentially
model = nn.Sequential(normalization)
i = 0 # increment every time we see a conv
for layer in cnn.children():
if isinstance(layer, nn.Conv2d):
i += 1
name = 'conv_{}'.format(i)
elif isinstance(layer, nn.ReLU):
name = 'relu_{}'.format(i)
# The in-place version doesn't play very nicely with the ContentLoss
# and StyleLoss we insert below. So we replace with out-of-place
# ones here.
layer = nn.ReLU(inplace=False)
elif isinstance(layer, nn.MaxPool2d):
name = 'pool_{}'.format(i)
elif isinstance(layer, nn.BatchNorm2d):
name = 'bn_{}'.format(i)
else:
raise RuntimeError('Unrecognized layer: {}'.format(layer.__class__.__name__))
model.add_module(name, layer)
if name in content_layers:
# add content loss:
target = model(content_img).detach()
content_loss = ContentLoss(target)
model.add_module("content_loss_{}".format(i), content_loss)
content_losses.append(content_loss)
if name in style_layers:
# add style loss:
target_feature = model(style_img).detach()
style_loss = StyleLoss(target_feature)
model.add_module("style_loss_{}".format(i), style_loss)
style_losses.append(style_loss)
# now we trim off the layers after the last content and style losses
for i in range(len(model) - 1, -1, -1):
if isinstance(model[i], ContentLoss) or isinstance(model[i], StyleLoss):
break
model = model[:(i + 1)]
return model, style_losses, content_losses
######################################################################
# Next, we select the input image. You can use a copy of the content image
# or white noise.
#
input_img = content_img.clone()
# if you want to use white noise instead uncomment the below line:
# input_img = torch.randn(content_img.data.size(), device=device)
# add the original input image to the figure:
plt.figure()
imshow(input_img, title='Input Image')
######################################################################
# Gradient Descent
# ----------------
#
# As Leon Gatys, the author of the algorithm, suggested `here <https://discuss.pytorch.org/t/pytorch-tutorial-for-neural-transfert-of-artistic-style/336/20?u=alexis-jacq>`__, we will use the
# L-BFGS algorithm to run our gradient descent. Unlike training a network,
# we want to train the input image in order to minimize the content/style
# losses. We will create a PyTorch L-BFGS optimizer ``optim.LBFGS`` and pass
# our image to it as the tensor to optimize.
#
def get_input_optimizer(input_img):
# this line to show that input is a parameter that requires a gradient
optimizer = optim.LBFGS([input_img.requires_grad_()])
return optimizer
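A toy sketch (unrelated to the images themselves) of the closure pattern ``optim.LBFGS`` requires: the closure re-evaluates the loss and gradients on every inner iteration, here "optimizing the input" toward a fixed target.

```python
import torch
import torch.optim as optim

target = torch.tensor([1.0, 2.0, 3.0])
x = torch.zeros(3, requires_grad=True)   # the "input" tensor being optimized
optimizer = optim.LBFGS([x])

def closure():
    optimizer.zero_grad()
    loss = ((x - target) ** 2).sum()
    loss.backward()
    return loss

for _ in range(5):
    optimizer.step(closure)
# x should now be very close to target
```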
######################################################################
# Finally, we must define a function that performs the neural transfer. For
# each iteration of the network, it is fed an updated input and computes
# new losses. We will run the ``backward`` methods of each loss module to
# dynamically compute their gradients. The optimizer requires a “closure”
# function, which reevaluates the module and returns the loss.
#
# We still have one final constraint to address. The network may try to
# optimize the input with values that exceed the 0 to 1 tensor range for
# the image. We can address this by clamping the input values to be
# between 0 and 1 each time the network is run.
#
def run_style_transfer(cnn, normalization_mean, normalization_std,
content_img, style_img, input_img, num_steps=300,
style_weight=1000000, content_weight=1):
"""Run the style transfer."""
print('Building the style transfer model..')
model, style_losses, content_losses = get_style_model_and_losses(cnn,
normalization_mean, normalization_std, style_img,
content_img)
optimizer = get_input_optimizer(input_img)
print('Optimizing..')
run = [0]
while run[0] <= num_steps:
def closure():
# correct the values of updated input image
input_img.data.clamp_(0, 1)
optimizer.zero_grad()
model(input_img)
style_score = 0
content_score = 0
for sl in style_losses:
style_score += sl.loss
for cl in content_losses:
content_score += cl.loss
style_score *= style_weight
content_score *= content_weight
loss = style_score + content_score
loss.backward()
run[0] += 1
if run[0] % 50 == 0:
                print("run {}:".format(run[0]))
print('Style Loss : {:4f} Content Loss: {:4f}'.format(
style_score.item(), content_score.item()))
print()
return style_score + content_score
optimizer.step(closure)
# a last correction...
input_img.data.clamp_(0, 1)
return input_img
######################################################################
# Finally, we can run the algorithm.
#
output = run_style_transfer(cnn, cnn_normalization_mean, cnn_normalization_std,
content_img, style_img, input_img)
plt.figure()
imshow(output, title='Output Image')
# sphinx_gallery_thumbnail_number = 4
plt.ioff()
plt.show()

examples/pytorch_mnist.py Normal file
@@ -0,0 +1,124 @@
# TRAINS - Example of Pytorch mnist training integration
#
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from trains import Task
task = Task.init(project_name='examples', task_name='pytorch mnist train')
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.fc1 = nn.Linear(4 * 4 * 50, 500)
self.fc2 = nn.Linear(500, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4 * 4 * 50)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
def train(args, model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
def test(args, model, device, test_loader):
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
def main():
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--save-model', action='store_true', default=True,
help='For Saving the current Model')
args = parser.parse_args()
use_cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
datasets.MNIST('../data', train=False, transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])),
batch_size=args.test_batch_size, shuffle=True, **kwargs)
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
for epoch in range(1, args.epochs + 1):
train(args, model, device, train_loader, optimizer, epoch)
test(args, model, device, test_loader)
    if args.save_model:
torch.save(model.state_dict(), "/tmp/mnist_cnn.pt")
if __name__ == '__main__':
main()


@@ -0,0 +1,126 @@
# TRAINS - Example of pytorch with tensorboard>=v1.14
#
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torch.utils.tensorboard import SummaryWriter
from trains import Task
task = Task.init(project_name='examples', task_name='pytorch with tensorboard')
writer = SummaryWriter('runs')
writer.add_text('lstm', 'This is an lstm', 0)
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=2, metavar='N',
help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
if args.cuda:
torch.cuda.manual_seed(args.seed)
kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}
train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])),
batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])),
batch_size=args.batch_size, shuffle=True, **kwargs)
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
        return F.log_softmax(x, dim=1)
model = Net()
if args.cuda:
model.cuda()
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
def train(epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
if args.cuda:
data, target = data.cuda(), target.cuda()
data, target = Variable(data), Variable(target)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.data.item()))
niter = epoch*len(train_loader)+batch_idx
writer.add_scalar('Train/Loss', loss.data.item(), niter)
def test():
model.eval()
test_loss = 0
correct = 0
    with torch.no_grad():  # Variable(volatile=True) is deprecated; use no_grad instead
        for niter, (data, target) in enumerate(test_loader):
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            batch_loss = F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            test_loss += batch_loss
            pred = output.max(1)[1]  # get the index of the max log-probability
            writer.add_scalar('Test/Loss', batch_loss, niter)
            correct += pred.eq(target).cpu().sum()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
for epoch in range(1, args.epochs + 1):
train(epoch)
torch.save(model, '/tmp/model{}'.format(epoch))
test()


@@ -0,0 +1,126 @@
# TRAINS - Example of pytorch with tensorboardX
#
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from tensorboardX import SummaryWriter
from trains import Task
task = Task.init(project_name='examples', task_name='pytorch with tensorboardX')
writer = SummaryWriter('runs')
writer.add_text('lstm', 'This is an lstm', 0)
# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=2, metavar='N',
help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
help='SGD momentum (default: 0.5)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
help='how many batches to wait before logging training status')
args = parser.parse_args()
args.cuda = not args.no_cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
if args.cuda:
torch.cuda.manual_seed(args.seed)
kwargs = {'num_workers': 4, 'pin_memory': True} if args.cuda else {}
train_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=True, download=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])),
batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('../data', train=False,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))])),
batch_size=args.batch_size, shuffle=True, **kwargs)
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
self.conv2_drop = nn.Dropout2d()
self.fc1 = nn.Linear(320, 50)
self.fc2 = nn.Linear(50, 10)
def forward(self, x):
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
x = x.view(-1, 320)
x = F.relu(self.fc1(x))
x = F.dropout(x, training=self.training)
x = self.fc2(x)
        return F.log_softmax(x, dim=1)
model = Net()
if args.cuda:
model.cuda()
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
def train(epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
if args.cuda:
data, target = data.cuda(), target.cuda()
data, target = Variable(data), Variable(target)
optimizer.zero_grad()
output = model(data)
loss = F.nll_loss(output, target)
loss.backward()
optimizer.step()
if batch_idx % args.log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.data.item()))
niter = epoch*len(train_loader)+batch_idx
writer.add_scalar('Train/Loss', loss.data.item(), niter)
def test():
model.eval()
test_loss = 0
correct = 0
    with torch.no_grad():  # Variable(volatile=True) is deprecated; use no_grad instead
        for niter, (data, target) in enumerate(test_loader):
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            output = model(data)
            batch_loss = F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            test_loss += batch_loss
            pred = output.max(1)[1]  # get the index of the max log-probability
            writer.add_scalar('Test/Loss', batch_loss, niter)
            correct += pred.eq(target).cpu().sum()
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(test_loader.dataset),
100. * correct / len(test_loader.dataset)))
for epoch in range(1, args.epochs + 1):
train(epoch)
torch.save(model, '/tmp/model{}'.format(epoch))
test()

Binary file not shown.

Binary file not shown.

@@ -0,0 +1,237 @@
# TRAINS - Example of new tensorboard pr_curves model
#
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Create sample PR curve summary data.
We have 3 classes: R, G, and B. We generate colors within RGB space from 3
normal distributions (1 at each corner of the color triangle: [255, 0, 0],
[0, 255, 0], and [0, 0, 255]).
The true label of each random color is associated with the normal distribution
that generated it.
Using 3 other normal distributions (over the distance each color is from a
corner of the color triangle - RGB), we then compute the probability that each
color belongs to the class. We use those probabilities to generate PR curves.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
from absl import app
from absl import flags
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorboard.plugins.pr_curve import summary
from trains import Task
task = Task.init(project_name='examples', task_name='tensorboard pr_curve')
tf.compat.v1.disable_v2_behavior()
FLAGS = flags.FLAGS
flags.DEFINE_string('logdir', '/tmp/pr_curve_demo', 'Directory into which to write TensorBoard data.')
flags.DEFINE_integer('steps', 10,
'Number of steps to generate for each PR curve.')
def start_runs(
logdir,
steps,
run_name,
thresholds,
mask_every_other_prediction=False):
"""Generate a PR curve with precision and recall evenly weighted.
Arguments:
logdir: The directory into which to store all the runs' data.
steps: The number of steps to run for.
run_name: The name of the run.
thresholds: The number of thresholds to use for PR curves.
mask_every_other_prediction: Whether to mask every other prediction by
alternating weights between 0 and 1.
"""
tf.compat.v1.reset_default_graph()
tf.compat.v1.set_random_seed(42)
# Create a normal distribution layer used to generate true color labels.
distribution = tf.compat.v1.distributions.Normal(loc=0., scale=142.)
# Sample the distribution to generate colors. Lets generate different numbers
# of each color. The first dimension is the count of examples.
# The calls to sample() are given fixed random seed values that are "magic"
# in that they correspond to the default seeds for those ops when the PR
# curve test (which depends on this code) was written. We've pinned these
# instead of continuing to use the defaults since the defaults are based on
# node IDs from the sequence of nodes added to the graph, which can silently
# change when this code or any TF op implementations it uses are modified.
# TODO(nickfelt): redo the PR curve test to avoid reliance on random seeds.
# Generate reds.
number_of_reds = 100
true_reds = tf.clip_by_value(
tf.concat([
255 - tf.abs(distribution.sample([number_of_reds, 1], seed=11)),
tf.abs(distribution.sample([number_of_reds, 2], seed=34))
], axis=1),
0, 255)
# Generate greens.
number_of_greens = 200
true_greens = tf.clip_by_value(
tf.concat([
tf.abs(distribution.sample([number_of_greens, 1], seed=61)),
255 - tf.abs(distribution.sample([number_of_greens, 1], seed=82)),
tf.abs(distribution.sample([number_of_greens, 1], seed=105))
], axis=1),
0, 255)
# Generate blues.
number_of_blues = 150
true_blues = tf.clip_by_value(
tf.concat([
tf.abs(distribution.sample([number_of_blues, 2], seed=132)),
255 - tf.abs(distribution.sample([number_of_blues, 1], seed=153))
], axis=1),
0, 255)
# Assign each color a vector of 3 booleans based on its true label.
labels = tf.concat([
tf.tile(tf.constant([[True, False, False]]), (number_of_reds, 1)),
tf.tile(tf.constant([[False, True, False]]), (number_of_greens, 1)),
tf.tile(tf.constant([[False, False, True]]), (number_of_blues, 1)),
], axis=0)
# We introduce 3 normal distributions. They are used to predict whether a
# color falls under a certain class (based on distances from corners of the
# color triangle). The distributions vary per color. We have the distributions
# narrow over time.
initial_standard_deviations = [v + FLAGS.steps for v in (158, 200, 242)]
iteration = tf.compat.v1.placeholder(tf.int32, shape=[])
red_predictor = tf.compat.v1.distributions.Normal(
loc=0.,
scale=tf.cast(
initial_standard_deviations[0] - iteration,
dtype=tf.float32))
green_predictor = tf.compat.v1.distributions.Normal(
loc=0.,
scale=tf.cast(
initial_standard_deviations[1] - iteration,
dtype=tf.float32))
blue_predictor = tf.compat.v1.distributions.Normal(
loc=0.,
scale=tf.cast(
initial_standard_deviations[2] - iteration,
dtype=tf.float32))
# Make predictions (assign 3 probabilities to each color based on each color's
# distance to each of the 3 corners). We seek double the area in the right
# tail of the normal distribution.
examples = tf.concat([true_reds, true_greens, true_blues], axis=0)
probabilities_colors_are_red = (1 - red_predictor.cdf(
tf.norm(tensor=examples - tf.constant([255., 0, 0]), axis=1))) * 2
probabilities_colors_are_green = (1 - green_predictor.cdf(
tf.norm(tensor=examples - tf.constant([0, 255., 0]), axis=1))) * 2
probabilities_colors_are_blue = (1 - blue_predictor.cdf(
tf.norm(tensor=examples - tf.constant([0, 0, 255.]), axis=1))) * 2
predictions = (
probabilities_colors_are_red,
probabilities_colors_are_green,
probabilities_colors_are_blue
)
# This is the crucial piece. We write data required for generating PR curves.
# We create 1 summary per class because we create 1 PR curve per class.
for i, color in enumerate(('red', 'green', 'blue')):
description = ('The probabilities used to create this PR curve are '
'generated from a normal distribution. Its standard '
'deviation is initially %0.0f and decreases over time.' %
initial_standard_deviations[i])
weights = None
if mask_every_other_prediction:
# Assign a weight of 0 to every even-indexed prediction. Odd-indexed
# predictions are assigned a default weight of 1.
consecutive_indices = tf.reshape(
tf.range(tf.size(input=predictions[i])), tf.shape(input=predictions[i]))
weights = tf.cast(consecutive_indices % 2, dtype=tf.float32)
summary.op(
name=color,
labels=labels[:, i],
predictions=predictions[i],
num_thresholds=thresholds,
weights=weights,
display_name='classifying %s' % color,
description=description)
merged_summary_op = tf.compat.v1.summary.merge_all()
events_directory = os.path.join(logdir, run_name)
sess = tf.compat.v1.Session()
writer = tf.compat.v1.summary.FileWriter(events_directory, sess.graph)
for step in xrange(steps):
feed_dict = {
iteration: step,
}
merged_summary = sess.run(merged_summary_op, feed_dict=feed_dict)
writer.add_summary(merged_summary, step)
writer.close()
def run_all(logdir, steps, thresholds, verbose=False):
"""Generate PR curve summaries.
Arguments:
logdir: The directory into which to store all the runs' data.
steps: The number of steps to run for.
verbose: Whether to print the names of runs into stdout during execution.
thresholds: The number of thresholds to use for PR curves.
"""
# First, we generate data for a PR curve that assigns even weights for
# predictions of all classes.
run_name = 'colors'
if verbose:
print('--- Running: %s' % run_name)
start_runs(
logdir=logdir,
steps=steps,
run_name=run_name,
thresholds=thresholds)
# Next, we generate data for a PR curve that assigns arbitrary weights to
# predictions.
run_name = 'mask_every_other_prediction'
if verbose:
print('--- Running: %s' % run_name)
start_runs(
logdir=logdir,
steps=steps,
run_name=run_name,
thresholds=thresholds,
mask_every_other_prediction=True)
def main(_):
print('Saving output to %s.' % FLAGS.logdir)
run_all(FLAGS.logdir, FLAGS.steps, 50, verbose=True)
print('Done. Output saved to %s.' % FLAGS.logdir)
if __name__ == '__main__':
app.run(main)
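As an aside, the every-other-prediction masking used above (weight 0 for even-indexed predictions, 1 for odd ones) can be sketched in plain NumPy; the function name here is illustrative, not part of the demo:

```python
import numpy as np

def mask_every_other(predictions):
    """Return weights of 0 for even-indexed predictions and 1 for odd ones."""
    indices = np.arange(predictions.size).reshape(predictions.shape)
    return (indices % 2).astype(np.float32)

weights = mask_every_other(np.zeros(6))
# weights == [0., 1., 0., 1., 0., 1.]
```

This mirrors the `tf.range(...) % 2` trick in the graph version, only eagerly and without a session.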


@@ -0,0 +1,76 @@
# TRAINS - Example of tensorboard with tensorflow (without any actual training)
#
import tensorflow as tf
import numpy as np
import cv2
from time import sleep
#import tensorflow.compat.v1 as tf
#tf.disable_v2_behavior()
from trains import Task
task = Task.init(project_name='examples', task_name='tensorboard toy example')
k = tf.placeholder(tf.float32)
# Make a normal distribution, with a shifting mean
mean_moving_normal = tf.random_normal(shape=[1000], mean=(5*k), stddev=1)
# Record that distribution into a histogram summary
tf.summary.histogram("normal/moving_mean", mean_moving_normal)
tf.summary.scalar("normal/value", mean_moving_normal[-1])
# Make a normal distribution with shrinking variance
variance_shrinking_normal = tf.random_normal(shape=[1000], mean=0, stddev=1-(k))
# Record that distribution too
tf.summary.histogram("normal/shrinking_variance", variance_shrinking_normal)
tf.summary.scalar("normal/variance_shrinking_normal", variance_shrinking_normal[-1])
# Let's combine both of those distributions into one dataset
normal_combined = tf.concat([mean_moving_normal, variance_shrinking_normal], 0)
# We add another histogram summary to record the combined distribution
tf.summary.histogram("normal/bimodal", normal_combined)
tf.summary.scalar("normal/normal_combined", normal_combined[0])
# Add a gamma distribution
gamma = tf.random_gamma(shape=[1000], alpha=k)
tf.summary.histogram("gamma", gamma)
# And a poisson distribution
poisson = tf.random_poisson(shape=[1000], lam=k)
tf.summary.histogram("poisson", poisson)
# And a uniform distribution
uniform = tf.random_uniform(shape=[1000], maxval=k*10)
tf.summary.histogram("uniform", uniform)
# Finally, combine everything together!
all_distributions = [mean_moving_normal, variance_shrinking_normal, gamma, poisson, uniform]
all_combined = tf.concat(all_distributions, 0)
tf.summary.histogram("all_combined", all_combined)
# convert to 4d [batch, col, row, RGB-channels]
image = cv2.imread('./samples/picasso.jpg')
image = image[:, :, 0][np.newaxis, :, :, np.newaxis]
# image = image[np.newaxis, :, :, :] # test greyscale image
# un-comment to add image reporting
tf.summary.image("test", image, max_outputs=10)
# Setup a session and summary writer
summaries = tf.summary.merge_all()
sess = tf.Session()
logger = task.get_logger()
# Use the original FileWriter for comparison; run:
# % tensorboard --logdir=/tmp/histogram_example
writer = tf.summary.FileWriter("/tmp/histogram_example")
# Setup a loop and write the summaries to disk
N = 40
for step in range(N):
k_val = step/float(N)
summ = sess.run(summaries, feed_dict={k: k_val})
writer.add_summary(summ, global_step=step)
writer.close()
print('Done!')
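The image reshaping earlier in this example (a single channel lifted to the 4-D `[batch, height, width, channels]` layout that `tf.summary.image` expects) can be checked in plain NumPy; the array below is a synthetic stand-in for the `cv2.imread` result:

```python
import numpy as np

# A synthetic H x W x 3 BGR image standing in for the loaded picture
image = np.zeros((480, 640, 3), dtype=np.uint8)

# Take the first channel and lift it to [batch, height, width, channels]
four_d = image[:, :, 0][np.newaxis, :, :, np.newaxis]
print(four_d.shape)  # (1, 480, 640, 1)
```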


@@ -0,0 +1,358 @@
# TRAINS - Example of tensorflow eager mode, model logging and tensorboard
#
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A deep MNIST classifier using convolutional layers.
Sample usage:
python mnist.py --help
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import os
import sys
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from trains import Task
tf.enable_eager_execution()
task = Task.init(project_name='examples', task_name='Tensorflow eager mode')
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('data_num', 100, """Flag of type integer""")
tf.app.flags.DEFINE_string('img_path', './img', """Flag of type string""")
layers = tf.keras.layers
FLAGS = None  # reassigned with the argparse results in __main__ below
class Discriminator(tf.keras.Model):
"""GAN Discriminator.
A network to differentiate between generated and real handwritten digits.
"""
def __init__(self, data_format):
"""Creates a model for discriminating between real and generated digits.
Args:
data_format: Either 'channels_first' or 'channels_last'.
'channels_first' is typically faster on GPUs while 'channels_last' is
typically faster on CPUs. See
https://www.tensorflow.org/performance/performance_guide#data_formats
"""
super(Discriminator, self).__init__(name='')
if data_format == 'channels_first':
self._input_shape = [-1, 1, 28, 28]
else:
assert data_format == 'channels_last'
self._input_shape = [-1, 28, 28, 1]
self.conv1 = layers.Conv2D(
64, 5, padding='SAME', data_format=data_format, activation=tf.tanh)
self.pool1 = layers.AveragePooling2D(2, 2, data_format=data_format)
self.conv2 = layers.Conv2D(
128, 5, data_format=data_format, activation=tf.tanh)
self.pool2 = layers.AveragePooling2D(2, 2, data_format=data_format)
self.flatten = layers.Flatten()
self.fc1 = layers.Dense(1024, activation=tf.tanh)
self.fc2 = layers.Dense(1, activation=None)
def call(self, inputs):
"""Return two logits per image estimating input authenticity.
Users should invoke __call__ to run the network, which delegates to this
method (and not call this method directly).
Args:
inputs: A batch of images as a Tensor with shape [batch_size, 28, 28, 1]
or [batch_size, 1, 28, 28]
Returns:
A Tensor with shape [batch_size] containing logits estimating
the probability that corresponding digit is real.
"""
x = tf.reshape(inputs, self._input_shape)
x = self.conv1(x)
x = self.pool1(x)
x = self.conv2(x)
x = self.pool2(x)
x = self.flatten(x)
x = self.fc1(x)
x = self.fc2(x)
return x
class Generator(tf.keras.Model):
"""Generator of handwritten digits similar to the ones in the MNIST dataset.
"""
def __init__(self, data_format):
"""Creates a model for discriminating between real and generated digits.
Args:
data_format: Either 'channels_first' or 'channels_last'.
'channels_first' is typically faster on GPUs while 'channels_last' is
typically faster on CPUs. See
https://www.tensorflow.org/performance/performance_guide#data_formats
"""
super(Generator, self).__init__(name='')
self.data_format = data_format
# We are using 128 6x6 channels as input to the first deconvolution layer
if data_format == 'channels_first':
self._pre_conv_shape = [-1, 128, 6, 6]
else:
assert data_format == 'channels_last'
self._pre_conv_shape = [-1, 6, 6, 128]
self.fc1 = layers.Dense(6 * 6 * 128, activation=tf.tanh)
# In call(), we reshape the output of fc1 to _pre_conv_shape
# Deconvolution layer. Resulting image shape: (batch, 14, 14, 64)
self.conv1 = layers.Conv2DTranspose(
64, 4, strides=2, activation=None, data_format=data_format)
# Deconvolution layer. Resulting image shape: (batch, 28, 28, 1)
self.conv2 = layers.Conv2DTranspose(
1, 2, strides=2, activation=tf.nn.sigmoid, data_format=data_format)
def call(self, inputs):
"""Return a batch of generated images.
Users should invoke __call__ to run the network, which delegates to this
method (and not call this method directly).
Args:
inputs: A batch of noise vectors as a Tensor with shape
[batch_size, length of noise vectors].
Returns:
A Tensor containing generated images. If data_format is 'channels_last',
the shape of returned images is [batch_size, 28, 28, 1], else
[batch_size, 1, 28, 28]
"""
x = self.fc1(inputs)
x = tf.reshape(x, shape=self._pre_conv_shape)
x = self.conv1(x)
x = self.conv2(x)
return x
def discriminator_loss(discriminator_real_outputs, discriminator_gen_outputs):
"""Original discriminator loss for GANs, with label smoothing.
See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661) for more
details.
Args:
discriminator_real_outputs: Discriminator output on real data.
discriminator_gen_outputs: Discriminator output on generated data. Expected
to be in the range of (-inf, inf).
Returns:
A scalar loss Tensor.
"""
loss_on_real = tf.losses.sigmoid_cross_entropy(
tf.ones_like(discriminator_real_outputs),
discriminator_real_outputs,
label_smoothing=0.25)
loss_on_generated = tf.losses.sigmoid_cross_entropy(
tf.zeros_like(discriminator_gen_outputs), discriminator_gen_outputs)
loss = loss_on_real + loss_on_generated
tf.contrib.summary.scalar('discriminator_loss', loss)
return loss
def generator_loss(discriminator_gen_outputs):
"""Original generator loss for GANs.
L = -log(sigmoid(D(G(z))))
See `Generative Adversarial Nets` (https://arxiv.org/abs/1406.2661)
for more details.
Args:
discriminator_gen_outputs: Discriminator output on generated data. Expected
to be in the range of (-inf, inf).
Returns:
A scalar loss Tensor.
"""
loss = tf.losses.sigmoid_cross_entropy(
tf.ones_like(discriminator_gen_outputs), discriminator_gen_outputs)
tf.contrib.summary.scalar('generator_loss', loss)
return loss
def train_one_epoch(generator, discriminator, generator_optimizer,
discriminator_optimizer, dataset, step_counter,
log_interval, noise_dim):
"""Train `generator` and `discriminator` models on `dataset`.
Args:
generator: Generator model.
discriminator: Discriminator model.
generator_optimizer: Optimizer to use for generator.
discriminator_optimizer: Optimizer to use for discriminator.
dataset: Dataset of images to train on.
step_counter: An integer variable, used to write summaries regularly.
log_interval: How many steps to wait between logging and collecting
summaries.
noise_dim: Dimension of noise vector to use.
"""
total_generator_loss = 0.0
total_discriminator_loss = 0.0
for (batch_index, images) in enumerate(dataset):
with tf.device('/cpu:0'):
tf.assign_add(step_counter, 1)
with tf.contrib.summary.record_summaries_every_n_global_steps(
log_interval, global_step=step_counter):
current_batch_size = images.shape[0]
noise = tf.random_uniform(
shape=[current_batch_size, noise_dim],
minval=-1.,
maxval=1.,
seed=batch_index)
# We can use two tapes or a single persistent tape.
# Using two tapes is memory-efficient since intermediate tensors can be
# released between the two .gradient() calls below.
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
generated_images = generator(noise)
tf.contrib.summary.image(
'generated_images',
tf.reshape(generated_images, [-1, 28, 28, 1]),
max_images=10)
discriminator_gen_outputs = discriminator(generated_images)
discriminator_real_outputs = discriminator(images)
discriminator_loss_val = discriminator_loss(discriminator_real_outputs,
discriminator_gen_outputs)
total_discriminator_loss += discriminator_loss_val
generator_loss_val = generator_loss(discriminator_gen_outputs)
total_generator_loss += generator_loss_val
generator_grad = gen_tape.gradient(generator_loss_val,
generator.variables)
discriminator_grad = disc_tape.gradient(discriminator_loss_val,
discriminator.variables)
generator_optimizer.apply_gradients(
zip(generator_grad, generator.variables))
discriminator_optimizer.apply_gradients(
zip(discriminator_grad, discriminator.variables))
if log_interval and batch_index > 0 and batch_index % log_interval == 0:
print('Batch #%d\tAverage Generator Loss: %.6f\t'
'Average Discriminator Loss: %.6f' %
(batch_index, total_generator_loss / batch_index,
total_discriminator_loss / batch_index))
def main(_):
(device, data_format) = ('/gpu:0', 'channels_first')
if FLAGS.no_gpu or tf.contrib.eager.num_gpus() <= 0:
(device, data_format) = ('/cpu:0', 'channels_last')
print('Using device %s, and data format %s.' % (device, data_format))
# Load the datasets
data = input_data.read_data_sets(FLAGS.data_dir)
dataset = (
tf.data.Dataset.from_tensor_slices(data.train.images[:1280]).shuffle(60000)
.batch(FLAGS.batch_size))
# Create the models and optimizers.
model_objects = {
'generator': Generator(data_format),
'discriminator': Discriminator(data_format),
'generator_optimizer': tf.train.AdamOptimizer(FLAGS.lr),
'discriminator_optimizer': tf.train.AdamOptimizer(FLAGS.lr),
'step_counter': tf.train.get_or_create_global_step(),
}
# Prepare summary writer and checkpoint info
summary_writer = tf.contrib.summary.create_file_writer(
FLAGS.output_dir, flush_millis=1000)
checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt')
latest_cpkt = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
if latest_cpkt:
print('Using latest checkpoint at ' + latest_cpkt)
checkpoint = tf.train.Checkpoint(**model_objects)
# Restore variables on creation if a checkpoint exists.
checkpoint.restore(latest_cpkt)
with tf.device(device):
for _ in range(3):
start = time.time()
with summary_writer.as_default():
train_one_epoch(dataset=dataset, log_interval=FLAGS.log_interval,
noise_dim=FLAGS.noise, **model_objects)
end = time.time()
checkpoint.save(checkpoint_prefix)
print('\nTrain time for epoch #%d (step %d): %f' %
(checkpoint.save_counter.numpy(),
checkpoint.step_counter.numpy(),
end - start))
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--data-dir',
type=str,
default='/tmp/tensorflow/mnist/input_data',
help=('Directory for storing input data (default '
'/tmp/tensorflow/mnist/input_data)'))
parser.add_argument(
'--batch-size',
type=int,
default=16,
metavar='N',
help='input batch size for training (default: 16)')
parser.add_argument(
'--log-interval',
type=int,
default=1,
metavar='N',
help=('number of batches between logging and writing summaries '
'(default: 1)'))
parser.add_argument(
'--output_dir',
type=str,
default='/tmp/tensorflow/',
metavar='DIR',
help='Directory to write TensorBoard summaries (default: /tmp/tensorflow/)')
parser.add_argument(
'--checkpoint_dir',
type=str,
default='/tmp/tensorflow/mnist/checkpoints/',
metavar='DIR',
help=('Directory to save checkpoints in (once per epoch) (default '
'/tmp/tensorflow/mnist/checkpoints/)'))
parser.add_argument(
'--lr',
type=float,
default=0.001,
metavar='LR',
help='learning rate (default: 0.001)')
parser.add_argument(
'--noise',
type=int,
default=100,
metavar='N',
help='Length of noise vector for generator input (default: 100)')
parser.add_argument(
'--no-gpu',
action='store_true',
default=False,
help='disables GPU usage even if a GPU is available')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
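As a reference point, the label-smoothed discriminator loss used in the example above reduces to sigmoid cross-entropy against targets of 0.875 for real images (smoothing 0.25 pulls the label toward 0.5) and 0 for generated ones. A minimal NumPy sketch of that computation, with illustrative names and inputs:

```python
import numpy as np

def sigmoid_xent(labels, logits):
    # Numerically stable sigmoid cross-entropy:
    # max(x, 0) - x * z + log(1 + exp(-|x|)), averaged over the batch
    return np.mean(np.maximum(logits, 0) - logits * labels
                   + np.log1p(np.exp(-np.abs(logits))))

def discriminator_loss_np(real_logits, gen_logits, label_smoothing=0.25):
    # Smoothed target for real examples: 1*(1-s) + 0.5*s = 0.875 for s=0.25
    smoothed = 1.0 * (1 - label_smoothing) + 0.5 * label_smoothing
    loss_on_real = sigmoid_xent(np.full_like(real_logits, smoothed), real_logits)
    loss_on_generated = sigmoid_xent(np.zeros_like(gen_logits), gen_logits)
    return loss_on_real + loss_on_generated
```

With all-zero logits (an undecided discriminator) both terms evaluate to log(2), matching the usual GAN baseline loss of 2·log(2).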


@@ -0,0 +1,171 @@
# TRAINS - Example of tensorflow mnist training model logging
#
# Save and Restore a model using TensorFlow.
# This example is using the MNIST database of handwritten digits
# (http://yann.lecun.com/exdb/mnist/)
#
# Author: Aymeric Damien
# Project: https://github.com/aymericdamien/TensorFlow-Examples/
from __future__ import print_function
from os.path import exists
import numpy as np
import tensorflow as tf
from trains import Task
MODEL_PATH = "/tmp/module_no_signatures"
task = Task.init(project_name='examples', task_name='Tensorflow mnist example')
## block
X_train = np.random.rand(100, 3)
y_train = np.random.rand(100, 1)
model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
model.compile(loss='categorical_crossentropy',
optimizer=tf.keras.optimizers.SGD(),
metrics=['accuracy'])
model.fit(X_train, y_train, steps_per_epoch=1, epochs=1)
with tf.Session(graph=tf.Graph()) as sess:
if exists(MODEL_PATH):
try:
tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], MODEL_PATH)
m2 = tf.saved_model.load(sess, [tf.saved_model.tag_constants.SERVING], MODEL_PATH)
except Exception:
pass
tf.train.Checkpoint
## block end
# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
# Parameters
parameters = {
'learning_rate': 0.001,
'batch_size': 100,
'display_step': 1,
'model_path': "/tmp/model.ckpt",
# Network Parameters
'n_hidden_1': 256, # 1st layer number of features
'n_hidden_2': 256, # 2nd layer number of features
'n_input': 784, # MNIST data input (img shape: 28*28)
'n_classes': 10, # MNIST total classes (0-9 digits)
}
# TRAINS: connect parameters with the experiment/task for logging
parameters = task.connect(parameters)
# tf Graph input
x = tf.placeholder("float", [None, parameters['n_input']])
y = tf.placeholder("float", [None, parameters['n_classes']])
# Create model
def multilayer_perceptron(x, weights, biases):
# Hidden layer with RELU activation
layer_1 = tf.add(tf.matmul(x, weights['h1']), biases['b1'])
layer_1 = tf.nn.relu(layer_1)
# Hidden layer with RELU activation
layer_2 = tf.add(tf.matmul(layer_1, weights['h2']), biases['b2'])
layer_2 = tf.nn.relu(layer_2)
# Output layer with linear activation
out_layer = tf.matmul(layer_2, weights['out']) + biases['out']
return out_layer
# Store layers weight & bias
weights = {
'h1': tf.Variable(tf.random_normal([parameters['n_input'], parameters['n_hidden_1']])),
'h2': tf.Variable(tf.random_normal([parameters['n_hidden_1'], parameters['n_hidden_2']])),
'out': tf.Variable(tf.random_normal([parameters['n_hidden_2'], parameters['n_classes']]))
}
biases = {
'b1': tf.Variable(tf.random_normal([parameters['n_hidden_1']])),
'b2': tf.Variable(tf.random_normal([parameters['n_hidden_2']])),
'out': tf.Variable(tf.random_normal([parameters['n_classes']]))
}
# Construct model
pred = multilayer_perceptron(x, weights, biases)
# Define loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optimizer = tf.train.AdamOptimizer(learning_rate=parameters['learning_rate']).minimize(cost)
# Initialize the variables (i.e. assign their default value)
init = tf.global_variables_initializer()
# 'Saver' op to save and restore all the variables
saver = tf.train.Saver()
# Running first session
print("Starting 1st session...")
with tf.Session() as sess:
# Run the initializer
sess.run(init)
# Training cycle
for epoch in range(3):
avg_cost = 0.
total_batch = int(mnist.train.num_examples/parameters['batch_size'])
# Loop over all batches
for i in range(total_batch):
batch_x, batch_y = mnist.train.next_batch(parameters['batch_size'])
# Run optimization op (backprop) and cost op (to get loss value)
_, c = sess.run([optimizer, cost], feed_dict={x: batch_x,
y: batch_y})
# Compute average loss
avg_cost += c / total_batch
# Display logs per epoch step
if epoch % parameters['display_step'] == 0:
print("Epoch:", '%04d' % (epoch+1), "cost=", \
"{:.9f}".format(avg_cost))
save_path = saver.save(sess, parameters['model_path'])
print("First Optimization Finished!")
# Test model
correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
# Calculate accuracy
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
# Save model weights to disk
save_path = saver.save(sess, parameters['model_path'])
print("Model saved in file: %s" % save_path)
# Running a new session
print("Starting 2nd session...")
with tf.Session() as sess:
# Initialize variables
sess.run(init)
# Restore model weights from previously saved model
saver.restore(sess, parameters['model_path'])
print("Model restored from file: %s" % save_path)
# Resume training
for epoch in range(7):
avg_cost = 0.
total_batch = int(mnist.train.num_examples / parameters['batch_size'])
# Loop over all batches
for i in range(total_batch):
batch_x, batch_y = mnist.train.next_batch(parameters['batch_size'])
# Run optimization op (backprop) and cost op (to get loss value)
_, c = sess.run([optimizer, cost], feed_dict={x: batch_x,
y: batch_y})
# Compute average loss
avg_cost += c / total_batch
# Display logs per epoch step
if epoch % parameters['display_step'] == 0:
print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(avg_cost))
print("Second Optimization Finished!")
# Test model
correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
# Calculate accuracy
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print("Accuracy:", accuracy.eval(
{x: mnist.test.images, y: mnist.test.labels}))
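The `multilayer_perceptron` graph above is a plain two-hidden-layer ReLU network with a linear output. Its forward pass can be sketched in NumPy with shapes mirroring the `parameters` dict; the random weights here are stand-ins, not trained values:

```python
import numpy as np

def mlp_forward(x, weights, biases):
    layer_1 = np.maximum(x @ weights['h1'] + biases['b1'], 0)  # ReLU
    layer_2 = np.maximum(layer_1 @ weights['h2'] + biases['b2'], 0)  # ReLU
    return layer_2 @ weights['out'] + biases['out']  # linear output logits

rng = np.random.default_rng(0)
weights = {'h1': rng.standard_normal((784, 256)),
           'h2': rng.standard_normal((256, 256)),
           'out': rng.standard_normal((256, 10))}
biases = {'b1': np.zeros(256), 'b2': np.zeros(256), 'out': np.zeros(10)}

logits = mlp_forward(rng.standard_normal((5, 784)), weights, biases)
print(logits.shape)  # (5, 10)
```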

examples/trains.conf Normal file

@@ -0,0 +1,131 @@
# TRAINS SDK configuration file
api {
host: http://localhost:8008
credentials {"access_key": "EGRTCO8JMSIGI6S39GTP43NFWXDQOW", "secret_key": "x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8"}
}
sdk {
# TRAINS - default SDK configuration
storage {
cache {
# Defaults to system temp folder / cache
default_base_dir: "~/.trains/cache"
}
}
metrics {
# History size for debug files per metric/variant. For each metric/variant combination with an attached file
# (e.g. debug image event), file names for the uploaded files will be recycled in such a way that no more than
# X files are stored in the upload destination for each metric/variant combination.
file_history_size: 100
# Settings for generated debug images
images {
format: JPEG
quality: 87
subsampling: 0
}
}
network {
metrics {
# Number of threads allocated to uploading files (typically debug images) when transmitting metrics for
# a specific iteration
file_upload_threads: 4
# Warn about upload starvation if no uploads were made in specified period while file-bearing events keep
# being sent for upload
file_upload_starvation_warning_sec: 120
}
iteration {
# Max number of retries when getting frames if the server returned an error (http code 500)
max_retries_on_server_error: 5
# Backoff factor for consecutive retry attempts.
# SDK will wait for {backoff factor} * (2 ^ ({number of total retries} - 1)) between retries.
retry_backoff_factor_sec: 10
}
}
aws {
s3 {
# S3 credentials, used for read/write access by various SDK elements
# default, used for any bucket not specified below
key: ""
secret: ""
region: ""
credentials: [
# specifies key/secret credentials to use when handling s3 urls (read or write)
# {
# bucket: "my-bucket-name"
# key: "my-access-key"
# secret: "my-secret-key"
# },
# {
# # This will apply to all buckets in this host (unless key/value is specifically provided for a given bucket)
# host: "my-minio-host:9000"
# key: "12345678"
# secret: "12345678"
# multipart: false
# secure: false
# }
]
}
boto3 {
pool_connections: 512
max_multipart_concurrency: 16
}
}
google.storage {
# # Default project and credentials file
# # Will be used when no bucket configuration is found
# project: "trains"
# credentials_json: "/path/to/credentials.json"
# # Specific credentials per bucket and sub directory
# credentials = [
# {
# bucket: "my-bucket"
# subdir: "path/in/bucket" # Not required
# project: "trains"
# credentials_json: "/path/to/credentials.json"
# },
# ]
}
log {
# debugging feature: set this to true to make null log propagate messages to root logger (so they appear in stdout)
null_log_propagate: False
task_log_buffer_capacity: 66
# disable urllib info and lower levels
disable_urllib3_info: True
}
development {
# Development-mode options
# dev task reuse window
task_reuse_time_window_in_hours: 72.0
# Run VCS repository detection asynchronously
vcs_repo_detect_async: False
# Store uncommitted git/hg source code diff in experiment manifest when training in development mode
# This stores "git diff" or "hg diff" into the experiment's "script.requirements.diff" section
store_uncommitted_code_diff_on_train: True
# Support stopping an experiment in case it was externally stopped, status was changed or task was reset
support_stopping: True
# Development mode worker
worker {
# Status report period in seconds
report_period_sec: 2
# Log all stdout & stderr
log_stdout: True
}
}
}

requirements.txt Normal file

@@ -0,0 +1,31 @@
apache-libcloud>=2.2.1
attrs>=18.0
backports.functools-lru-cache>=1.0.2 ; python_version < '3'
boto3>=1.9
botocore>=1.12
colorama>=0.4.1
coloredlogs>=10.0
enum34>=0.9
funcsigs>=1.0
furl>=2.0.0
future>=0.16.0
futures>=3.0.5 ; python_version < '3'
google-cloud-storage>=1.13.2
humanfriendly>=2.1
jsonmodels>=2.2
jsonschema>=2.6.0
numpy>=1.10
opencv-python>=3.2.0.8
pathlib2>=2.3.0
psutil>=3.4.2
pyhocon>=0.3.38
python-dateutil>=2.6.1
PyYAML>=3.12
requests-file>=1.4.2
requests>=2.18.4
six>=1.11.0
tqdm>=4.19.5
urllib3>=1.22
watchdog>=0.8.0
pyjwt>=1.6.4
plotly>=3.9.0

setup.cfg Normal file

@@ -0,0 +1,4 @@
[bdist_wheel]
# Build a universal wheel (installable on both Python 2 and Python 3)
universal=1

setup.py Normal file

@@ -0,0 +1,77 @@
"""
TRAINS - Artificial Intelligence Version Control
https://github.com/allegroai/trains
"""
# Always prefer setuptools over distutils
from setuptools import setup, find_packages
from six import exec_
from pathlib2 import Path
here = Path(__file__).resolve().parent
# Get the long description from the README file
long_description = (here / 'README.md').read_text()
def read_version_string():
result = {}
exec_((here / 'trains/version.py').read_text(), result)
return result['__version__']
version = read_version_string()
requirements = (here / 'requirements.txt').read_text().splitlines()
setup(
name='trains',
version=version,
description='TRAINS - Magic Version Control & Experiment Manager for AI',
long_description=long_description,
long_description_content_type='text/markdown',
# The project's main homepage.
url='https://github.com/allegroai/trains',
author='Allegroai',
author_email='trains@allegro.ai',
license='Apache License 2.0',
classifiers=[
# How mature is this project? Common values are
# 3 - Alpha
# 4 - Beta
# 5 - Production/Stable
'Development Status :: 4 - Beta',
'Intended Audience :: Developers',
'Intended Audience :: Science/Research',
'Operating System :: POSIX :: Linux',
'Operating System :: MacOS :: MacOS X',
'Operating System :: Microsoft',
'Topic :: Scientific/Engineering :: Artificial Intelligence',
'Topic :: Software Development',
'Topic :: Software Development :: Version Control',
'Topic :: System :: Logging',
'Topic :: System :: Monitoring',
'Programming Language :: Python :: 2.7',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'License :: OSI Approved :: Apache Software License',
],
keywords='trains development machine deep learning version control machine-learning machinelearning '
'deeplearning deep-learning experiment-manager experimentmanager',
packages=find_packages(exclude=['contrib', 'docs', 'data', 'examples', 'tests']),
install_requires=requirements,
package_data={
'trains': ['config/default/*.conf', 'backend_api/config/default/*.conf']
},
include_package_data=True,
# To provide executable scripts, use entry points in preference to the
# "scripts" keyword. Entry points provide cross-platform support and allow
# pip to create the appropriate form of executable for the target platform.
entry_points={
'console_scripts': [
'trains-init = trains.config.default.__main__:main',
],
},
)

trains/__init__.py Normal file

@@ -0,0 +1,7 @@
""" TRAINS open SDK """
from .version import __version__
from .task import Task
from .model import InputModel, OutputModel
from .logger import Logger
from .errors import UsageError


@@ -0,0 +1,3 @@
from .version import __version__
from .session import Session, CallResult, TimeoutExpiredError, ResultNotReadyError
from .config import load as load_config


@@ -0,0 +1,16 @@
from ...backend_config import Config
from pathlib2 import Path
def load(*additional_module_paths):
# type: (str) -> Config
"""
Load configuration with the API defaults, using the additional module path provided
:param additional_module_paths: Additional config paths for modules whose default
configurations should be loaded as well
:return: Config object
"""
config = Config(verbose=False)
this_module_path = Path(__file__).parent
config.load_relative_to(this_module_path, *additional_module_paths)
return config


@@ -0,0 +1,41 @@
{
version: 1.5
host: https://demoapi.trainsai.io
# default version assigned to requests with no specific version. this is not expected to change
# as it keeps us backwards compatible.
default_version: 1.5
http {
max_req_size = 15728640 # request size limit (smaller than that configured in api server)
retries {
# retry values (int, 0 means fail on first retry)
total: 240 # Total number of retries to allow. Takes precedence over other counts.
connect: 240 # How many times to retry on connection-related errors (never reached server)
read: 240 # How many times to retry on read errors (waiting for server)
redirect: 240 # How many redirects to perform (HTTP response with a status code 301, 302, 303, 307 or 308)
status: 240 # How many times to retry on bad status codes
# backoff parameters
# timeout between retries is min({backoff_max}, {backoff_factor} * (2 ^ ({number of total retries} - 1)))
backoff_factor: 1.0
backoff_max: 300.0
}
wait_on_maintenance_forever: true
pool_maxsize: 512
pool_connections: 512
}
credentials {
access_key: ""
secret_key: ""
}
auth {
# When creating a request, if token will expire in less than this value, try to refresh the token
token_expiration_threshold_sec = 360
}
}
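The retry backoff comment above describes an exponential schedule capped at `backoff_max`. A short sketch of the resulting wait times, assuming the documented formula (function name is illustrative):

```python
def backoff_schedule(retries, factor=1.0, backoff_max=300.0):
    """Wait time before retry n: min(backoff_max, factor * 2**(n - 1))."""
    return [min(backoff_max, factor * 2 ** (n - 1))
            for n in range(1, retries + 1)]

print(backoff_schedule(5))       # [1.0, 2.0, 4.0, 8.0, 16.0]
print(backoff_schedule(12)[-1])  # 300.0 (capped at backoff_max)
```

With the defaults above (factor 1.0, cap 300.0), the wait doubles each attempt and saturates at five minutes from roughly the tenth retry onward.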


@@ -0,0 +1,9 @@
{
version: 1
loggers {
urllib3 {
level: ERROR
}
}
}


@@ -0,0 +1,38 @@
import re
from functools import partial
import attr
from attr.converters import optional as optional_converter
from attr.validators import instance_of, optional, and_
from six import string_types
# noinspection PyTypeChecker
sequence = instance_of((list, tuple))
def sequence_of(types):
def validator(_, attrib, value):
assert all(isinstance(x, types) for x in value), attrib.name
return and_(sequence, validator)
@attr.s
class Action(object):
name = attr.ib()
version = attr.ib()
service = attr.ib()
definitions_keys = attr.ib(validator=sequence)
authorize = attr.ib(validator=instance_of(bool), default=True)
log_data = attr.ib(validator=instance_of(bool), default=True)
log_result_data = attr.ib(validator=instance_of(bool), default=True)
internal = attr.ib(default=False)
allow_roles = attr.ib(default=None, validator=optional(sequence_of(string_types)))
request = attr.ib(validator=optional(instance_of(dict)), default=None)
batch_request = attr.ib(validator=optional(instance_of(dict)), default=None)
response = attr.ib(validator=optional(instance_of(dict)), default=None)
method = attr.ib(default=None)
description = attr.ib(
default=None,
validator=optional(instance_of(string_types)),
)
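The `sequence_of` helper above composes two attrs validators: an `instance_of((list, tuple))` check and an item-type assertion. A standalone sketch of the same pattern (the `Demo` class is a hypothetical stand-in, assuming `attrs` is installed):

```python
import attr
from attr.validators import and_, instance_of

sequence = instance_of((list, tuple))

def sequence_of(types):
    # validate that the value is a list/tuple whose items are all of `types`
    def validator(_, attrib, value):
        assert all(isinstance(x, types) for x in value), attrib.name
    return and_(sequence, validator)

@attr.s
class Demo(object):
    names = attr.ib(validator=sequence_of(str))
```

`Demo(names=["a", "b"])` passes validation; `Demo(names="oops")` raises `TypeError` from the `instance_of` check before the item assertion runs.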


@@ -0,0 +1,201 @@
import itertools
import re
import attr
import six
import pyhocon
from .action import Action
class Service(object):
""" Service schema handler """
__jsonschema_ref_ex = re.compile("^#/definitions/(.*)$")
@property
def default(self):
return self._default
@property
def actions(self):
return self._actions
@property
def definitions(self):
""" Raw service definitions (each might be dependant on some of its siblings) """
return self._definitions
@property
def definitions_refs(self):
return self._definitions_refs
@property
def name(self):
return self._name
@property
def doc(self):
return self._doc
def __init__(self, name, service_config):
self._name = name
self._default = None
self._actions = []
self._definitions = None
self._definitions_refs = None
self._doc = None
self.parse(service_config)
@classmethod
def get_ref_name(cls, ref_string):
m = cls.__jsonschema_ref_ex.match(ref_string)
if m:
return m.group(1)
def parse(self, service_config):
self._default = service_config.get(
"_default", pyhocon.ConfigTree()
).as_plain_ordered_dict()
self._doc = '{} service'.format(self.name)
description = service_config.get('_description', '')
if description:
self._doc += '\n\n{}'.format(description)
self._definitions = service_config.get(
"_definitions", pyhocon.ConfigTree()
).as_plain_ordered_dict()
self._definitions_refs = {
k: self._get_schema_references(v) for k, v in self._definitions.items()
}
all_refs = set(itertools.chain(*self.definitions_refs.values()))
if not all_refs.issubset(self.definitions):
raise ValueError(
"Unresolved references (%s) in %s/definitions"
% (", ".join(all_refs.difference(self.definitions)), self.name)
)
actions = {
k: v.as_plain_ordered_dict()
for k, v in service_config.items()
if not k.startswith("_")
}
self._actions = {
action_name: action
for action_name, action in (
(action_name, self._parse_action_versions(action_name, action_versions))
for action_name, action_versions in actions.items()
)
if action
}
def _parse_action_versions(self, action_name, action_versions):
def parse_version(action_version):
try:
return float(action_version)
except (ValueError, TypeError) as ex:
raise ValueError(
"Failed parsing version number {} ({}) in {}/{}".format(
action_version, ex.args[0], self.name, action_name
)
)
def add_internal(cfg):
if "internal" in action_versions:
cfg.setdefault("internal", action_versions["internal"])
return cfg
return {
parsed_version: action
for parsed_version, action in (
(parsed_version, self._parse_action(action_name, parsed_version, add_internal(cfg)))
for parsed_version, cfg in (
(parse_version(version), cfg)
for version, cfg in action_versions.items()
if version not in ["internal", "allow_roles", "authorize"]
)
)
if action
}
def _get_schema_references(self, s):
refs = set()
if isinstance(s, dict):
for k, v in s.items():
if isinstance(v, six.string_types):
m = self.__jsonschema_ref_ex.match(v)
if m:
refs.add(m.group(1))
continue
elif k in ("oneOf", "anyOf") and isinstance(v, list):
refs.update(*map(self._get_schema_references, v))
refs.update(self._get_schema_references(v))
return refs
def _expand_schema_references_with_definitions(self, schema, refs=None):
definitions = schema.get("definitions", {})
refs = refs if refs is not None else self._get_schema_references(schema)
required_refs = set(refs).difference(definitions)
if not required_refs:
return required_refs
if not required_refs.issubset(self.definitions):
raise ValueError(
"Unresolved references (%s)"
% ", ".join(required_refs.difference(self.definitions))
)
# update required refs with all sub requirements
last_required_refs = None
while last_required_refs != required_refs:
last_required_refs = required_refs.copy()
additional_refs = set(
itertools.chain(
*(self.definitions_refs.get(ref, []) for ref in required_refs)
)
)
required_refs.update(additional_refs)
return required_refs
def _resolve_schema_references(self, schema, refs=None):
definitions = schema.get("definitions", {})
definitions.update({k: v for k, v in self.definitions.items() if k in refs})
schema["definitions"] = definitions
def _parse_action(self, action_name, action_version, action_config):
data = self.default.copy()
data.update(action_config)
if not action_config.get("generate", True):
return None
definitions_keys = set()
for schema_key in ("request", "response"):
if schema_key in action_config:
try:
schema = action_config[schema_key]
refs = self._expand_schema_references_with_definitions(schema)
self._resolve_schema_references(schema, refs=refs)
definitions_keys.update(refs)
except ValueError as ex:
name = "%s.%s/%.1f/%s" % (
self.name,
action_name,
action_version,
schema_key,
)
raise ValueError("%s in %s" % (str(ex), name))
return Action(
name=action_name,
version=action_version,
definitions_keys=list(definitions_keys),
service=self.name,
**(
{
key: value
for key, value in data.items()
if key in attr.fields_dict(Action)
}
)
)
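The while-loop in `_expand_schema_references_with_definitions` above computes the transitive closure of definition references: it keeps folding in each reference's own references until no new ones appear. A standalone sketch of that fixed-point iteration (the `definitions_refs` mapping here is illustrative, not taken from the source):

```python
import itertools

def expand_refs(refs, definitions_refs):
    # iterate until no new references appear (fixed point), mirroring the
    # last_required_refs/required_refs loop above
    required = set(refs)
    last = None
    while last != required:
        last = required.copy()
        required.update(
            itertools.chain(*(definitions_refs.get(r, []) for r in required))
        )
    return required
```

For example, if "task" references "execution" and "execution" references "script", expanding `{"task"}` yields all three names.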


@@ -0,0 +1,22 @@
from .v2_1 import async_request
from .v2_1 import auth
from .v2_1 import debug
from .v2_1 import events
from .v2_1 import models
from .v2_1 import news
from .v2_1 import projects
from .v2_1 import storage
from .v2_1 import tasks
__all__ = [
'async_request',
'auth',
'debug',
'events',
'models',
'news',
'projects',
'storage',
'tasks',
]


@@ -0,0 +1,414 @@
"""
async service
This service provides support for asynchronous API calls.
"""
import six
import types
from datetime import datetime
import enum
from dateutil.parser import parse as parse_datetime
from ....backend_api.session import Request, BatchRequest, Response, DataModel, NonStrictDataModel, CompoundRequest, schema_property, StringEnum
class Call(NonStrictDataModel):
"""
:param id: The job ID associated with this call.
:type id: str
:param status: The job's status.
:type status: str
:param created: Job creation time.
:type created: str
:param ended: Job end time.
:type ended: str
:param enqueued: Job enqueue time.
:type enqueued: str
:param meta: Metadata for this job, including the endpoint and additional relevant call data.
:type meta: dict
:param company: The Company this job belongs to.
:type company: str
:param exec_info: Job execution information.
:type exec_info: str
"""
_schema = {
'properties': {
'company': {
'description': 'The Company this job belongs to.',
'type': ['string', 'null'],
},
'created': {
'description': 'Job creation time.',
'type': ['string', 'null'],
},
'ended': {'description': 'Job end time.', 'type': ['string', 'null']},
'enqueued': {
'description': 'Job enqueue time.',
'type': ['string', 'null'],
},
'exec_info': {
'description': 'Job execution information.',
'type': ['string', 'null'],
},
'id': {
'description': 'The job ID associated with this call.',
'type': ['string', 'null'],
},
'meta': {
'additionalProperties': True,
'description': 'Metadata for this job, includes endpoint and additional relevant call data.',
'type': ['object', 'null'],
},
'status': {
'description': "The job's status.",
'type': ['string', 'null'],
},
},
'type': 'object',
}
def __init__(
self, id=None, status=None, created=None, ended=None, enqueued=None, meta=None, company=None, exec_info=None, **kwargs):
super(Call, self).__init__(**kwargs)
self.id = id
self.status = status
self.created = created
self.ended = ended
self.enqueued = enqueued
self.meta = meta
self.company = company
self.exec_info = exec_info
@schema_property('id')
def id(self):
return self._property_id
@id.setter
def id(self, value):
if value is None:
self._property_id = None
return
self.assert_isinstance(value, "id", six.string_types)
self._property_id = value
@schema_property('status')
def status(self):
return self._property_status
@status.setter
def status(self, value):
if value is None:
self._property_status = None
return
self.assert_isinstance(value, "status", six.string_types)
self._property_status = value
@schema_property('created')
def created(self):
return self._property_created
@created.setter
def created(self, value):
if value is None:
self._property_created = None
return
self.assert_isinstance(value, "created", six.string_types)
self._property_created = value
@schema_property('ended')
def ended(self):
return self._property_ended
@ended.setter
def ended(self, value):
if value is None:
self._property_ended = None
return
self.assert_isinstance(value, "ended", six.string_types)
self._property_ended = value
@schema_property('enqueued')
def enqueued(self):
return self._property_enqueued
@enqueued.setter
def enqueued(self, value):
if value is None:
self._property_enqueued = None
return
self.assert_isinstance(value, "enqueued", six.string_types)
self._property_enqueued = value
@schema_property('meta')
def meta(self):
return self._property_meta
@meta.setter
def meta(self, value):
if value is None:
self._property_meta = None
return
self.assert_isinstance(value, "meta", (dict,))
self._property_meta = value
@schema_property('company')
def company(self):
return self._property_company
@company.setter
def company(self, value):
if value is None:
self._property_company = None
return
self.assert_isinstance(value, "company", six.string_types)
self._property_company = value
@schema_property('exec_info')
def exec_info(self):
return self._property_exec_info
@exec_info.setter
def exec_info(self, value):
if value is None:
self._property_exec_info = None
return
self.assert_isinstance(value, "exec_info", six.string_types)
self._property_exec_info = value
class CallsRequest(Request):
"""
Get a list of all asynchronous API calls handled by the system.
This includes previously handled calls, calls currently being executed, and calls waiting in queue.
:param status: Return only calls whose status is in this list.
:type status: Sequence[str]
:param endpoint: Return only calls handling this endpoint. Supports wildcards.
:type endpoint: str
:param task: Return only calls associated with this task ID. Supports
wildcards.
:type task: str
"""
_service = "async"
_action = "calls"
_version = "1.5"
_schema = {
'definitions': {},
'properties': {
'endpoint': {
'description': 'Return only calls handling this endpoint. Supports wildcards.',
'type': ['string', 'null'],
},
'status': {
'description': "Return only calls whose status is in this list.",
'items': {'enum': ['queued', 'in_progress', 'completed'], 'type': 'string'},
'type': ['array', 'null'],
},
'task': {
'description': 'Return only calls associated with this task ID. Supports wildcards.',
'type': ['string', 'null'],
},
},
'type': 'object',
}
def __init__(
self, status=None, endpoint=None, task=None, **kwargs):
super(CallsRequest, self).__init__(**kwargs)
self.status = status
self.endpoint = endpoint
self.task = task
@schema_property('status')
def status(self):
return self._property_status
@status.setter
def status(self, value):
if value is None:
self._property_status = None
return
self.assert_isinstance(value, "status", (list, tuple))
self.assert_isinstance(value, "status", six.string_types, is_array=True)
self._property_status = value
@schema_property('endpoint')
def endpoint(self):
return self._property_endpoint
@endpoint.setter
def endpoint(self, value):
if value is None:
self._property_endpoint = None
return
self.assert_isinstance(value, "endpoint", six.string_types)
self._property_endpoint = value
@schema_property('task')
def task(self):
return self._property_task
@task.setter
def task(self, value):
if value is None:
self._property_task = None
return
self.assert_isinstance(value, "task", six.string_types)
self._property_task = value
class CallsResponse(Response):
"""
Response of async.calls endpoint.
:param calls: A list of the current asynchronous calls handled by the system.
:type calls: Sequence[Call]
"""
_service = "async"
_action = "calls"
_version = "1.5"
_schema = {
'definitions': {
'call': {
'properties': {
'company': {
'description': 'The Company this job belongs to.',
'type': ['string', 'null'],
},
'created': {
'description': 'Job creation time.',
'type': ['string', 'null'],
},
'ended': {
'description': 'Job end time.',
'type': ['string', 'null'],
},
'enqueued': {
'description': 'Job enqueue time.',
'type': ['string', 'null'],
},
'exec_info': {
'description': 'Job execution information.',
'type': ['string', 'null'],
},
'id': {
'description': 'The job ID associated with this call.',
'type': ['string', 'null'],
},
'meta': {
'additionalProperties': True,
'description': 'Metadata for this job, includes endpoint and additional relevant call data.',
'type': ['object', 'null'],
},
'status': {
'description': "The job's status.",
'type': ['string', 'null'],
},
},
'type': 'object',
},
},
'properties': {
'calls': {
'description': 'A list of the current asynchronous calls handled by the system.',
'items': {'$ref': '#/definitions/call'},
'type': ['array', 'null'],
},
},
'type': 'object',
}
def __init__(
self, calls=None, **kwargs):
super(CallsResponse, self).__init__(**kwargs)
self.calls = calls
@schema_property('calls')
def calls(self):
return self._property_calls
@calls.setter
def calls(self, value):
if value is None:
self._property_calls = None
return
self.assert_isinstance(value, "calls", (list, tuple))
if any(isinstance(v, dict) for v in value):
value = [Call.from_dict(v) if isinstance(v, dict) else v for v in value]
else:
self.assert_isinstance(value, "calls", Call, is_array=True)
self._property_calls = value
class ResultRequest(Request):
"""
Try getting the result of a previously accepted asynchronous API call.
If execution for the asynchronous call has completed, the complete call response data will be returned.
Otherwise, a 202 status code is returned with no data.
:param id: The id returned by the accepted asynchronous API call.
:type id: str
"""
_service = "async"
_action = "result"
_version = "1.5"
_schema = {
'definitions': {},
'properties': {
'id': {
'description': 'The id returned by the accepted asynchronous API call.',
'type': ['string', 'null'],
},
},
'type': 'object',
}
def __init__(
self, id=None, **kwargs):
super(ResultRequest, self).__init__(**kwargs)
self.id = id
@schema_property('id')
def id(self):
return self._property_id
@id.setter
def id(self, value):
if value is None:
self._property_id = None
return
self.assert_isinstance(value, "id", six.string_types)
self._property_id = value
class ResultResponse(Response):
"""
Response of async.result endpoint.
"""
_service = "async"
_action = "result"
_version = "1.5"
_schema = {'additionalProperties': True, 'definitions': {}, 'type': 'object'}
response_mapping = {
ResultRequest: ResultResponse,
CallsRequest: CallsResponse,
}
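Given that `async.result` keeps answering 202 until the call completes, a client typically polls it in a loop. A self-contained sketch of that pattern (the `send` callable is a hypothetical transport stand-in, not an API from this module):

```python
import time

def poll_async_result(send, call_id, interval=2.0, timeout=60.0):
    # `send` is a hypothetical callable returning (status_code, data);
    # 202 means the asynchronous call is still executing
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        code, data = send(call_id)
        if code != 202:
            return data
        time.sleep(interval)
    raise TimeoutError("async call %s did not complete" % call_id)
```

With a stubbed transport that answers 202 twice before returning a payload, the helper returns that payload on the third poll.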

File diff suppressed because it is too large


@@ -0,0 +1,194 @@
"""
debug service
Debugging utilities
"""
import six
import types
from datetime import datetime
import enum
from dateutil.parser import parse as parse_datetime
from ....backend_api.session import Request, BatchRequest, Response, DataModel, NonStrictDataModel, CompoundRequest, schema_property, StringEnum
class ApiexRequest(Request):
"""
"""
_service = "debug"
_action = "apiex"
_version = "1.5"
_schema = {'definitions': {}, 'properties': {}, 'required': [], 'type': 'object'}
class ApiexResponse(Response):
"""
Response of debug.apiex endpoint.
"""
_service = "debug"
_action = "apiex"
_version = "1.5"
_schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
class EchoRequest(Request):
"""
Return request data
"""
_service = "debug"
_action = "echo"
_version = "1.5"
_schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
class EchoResponse(Response):
"""
Response of debug.echo endpoint.
"""
_service = "debug"
_action = "echo"
_version = "1.5"
_schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
class ExRequest(Request):
"""
"""
_service = "debug"
_action = "ex"
_version = "1.5"
_schema = {'definitions': {}, 'properties': {}, 'required': [], 'type': 'object'}
class ExResponse(Response):
"""
Response of debug.ex endpoint.
"""
_service = "debug"
_action = "ex"
_version = "1.5"
_schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
class PingRequest(Request):
"""
Return a message. Does not require authorization.
"""
_service = "debug"
_action = "ping"
_version = "1.5"
_schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
class PingResponse(Response):
"""
Response of debug.ping endpoint.
:param msg: A friendly message
:type msg: str
"""
_service = "debug"
_action = "ping"
_version = "1.5"
_schema = {
'definitions': {},
'properties': {
'msg': {
'description': 'A friendly message',
'type': ['string', 'null'],
},
},
'type': 'object',
}
def __init__(
self, msg=None, **kwargs):
super(PingResponse, self).__init__(**kwargs)
self.msg = msg
@schema_property('msg')
def msg(self):
return self._property_msg
@msg.setter
def msg(self, value):
if value is None:
self._property_msg = None
return
self.assert_isinstance(value, "msg", six.string_types)
self._property_msg = value
class PingAuthRequest(Request):
"""
Return a message. Requires authorization.
"""
_service = "debug"
_action = "ping_auth"
_version = "1.5"
_schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
class PingAuthResponse(Response):
"""
Response of debug.ping_auth endpoint.
:param msg: A friendly message
:type msg: str
"""
_service = "debug"
_action = "ping_auth"
_version = "1.5"
_schema = {
'definitions': {},
'properties': {
'msg': {
'description': 'A friendly message',
'type': ['string', 'null'],
},
},
'type': 'object',
}
def __init__(
self, msg=None, **kwargs):
super(PingAuthResponse, self).__init__(**kwargs)
self.msg = msg
@schema_property('msg')
def msg(self):
return self._property_msg
@msg.setter
def msg(self, value):
if value is None:
self._property_msg = None
return
self.assert_isinstance(value, "msg", six.string_types)
self._property_msg = value
response_mapping = {
EchoRequest: EchoResponse,
PingRequest: PingResponse,
PingAuthRequest: PingAuthResponse,
ApiexRequest: ApiexResponse,
ExRequest: ExResponse,
}
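Each generated service module ends with a `response_mapping` dict pairing request classes with their response classes. A standalone sketch of how such a mapping can drive dispatch (the stand-in classes and `response_class_for` helper are illustrative, not part of the source):

```python
# minimal stand-ins for a generated request/response pair
class PingRequest(object):
    pass

class PingResponse(object):
    pass

response_mapping = {PingRequest: PingResponse}

def response_class_for(request):
    # look up the response type matching a given request instance
    return response_mapping[type(request)]
```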

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -0,0 +1,70 @@
"""
news service
This service provides platform news.
"""
import six
import types
from datetime import datetime
import enum
from dateutil.parser import parse as parse_datetime
from ....backend_api.session import Request, BatchRequest, Response, DataModel, NonStrictDataModel, CompoundRequest, schema_property, StringEnum
class GetRequest(Request):
"""
Get the latest news link
"""
_service = "news"
_action = "get"
_version = "1.5"
_schema = {'definitions': {}, 'properties': {}, 'type': 'object'}
class GetResponse(Response):
"""
Response of news.get endpoint.
:param url: URL to news html file
:type url: str
"""
_service = "news"
_action = "get"
_version = "1.5"
_schema = {
'definitions': {},
'properties': {
'url': {
'description': 'URL to news html file',
'type': ['string', 'null'],
},
},
'type': 'object',
}
def __init__(
self, url=None, **kwargs):
super(GetResponse, self).__init__(**kwargs)
self.url = url
@schema_property('url')
def url(self):
return self._property_url
@url.setter
def url(self, value):
if value is None:
self._property_url = None
return
self.assert_isinstance(value, "url", six.string_types)
self._property_url = value
response_mapping = {
GetRequest: GetResponse,
}

File diff suppressed because it is too large


@@ -0,0 +1,681 @@
"""
storage service
Provides a management API for customer-associated storage locations
"""
import six
import types
from datetime import datetime
import enum
from dateutil.parser import parse as parse_datetime
from ....backend_api.session import Request, BatchRequest, Response, DataModel, NonStrictDataModel, CompoundRequest, schema_property, StringEnum
class Credentials(NonStrictDataModel):
"""
:param access_key: Credentials access key
:type access_key: str
:param secret_key: Credentials secret key
:type secret_key: str
"""
_schema = {
'properties': {
'access_key': {
'description': 'Credentials access key',
'type': ['string', 'null'],
},
'secret_key': {
'description': 'Credentials secret key',
'type': ['string', 'null'],
},
},
'type': 'object',
}
def __init__(
self, access_key=None, secret_key=None, **kwargs):
super(Credentials, self).__init__(**kwargs)
self.access_key = access_key
self.secret_key = secret_key
@schema_property('access_key')
def access_key(self):
return self._property_access_key
@access_key.setter
def access_key(self, value):
if value is None:
self._property_access_key = None
return
self.assert_isinstance(value, "access_key", six.string_types)
self._property_access_key = value
@schema_property('secret_key')
def secret_key(self):
return self._property_secret_key
@secret_key.setter
def secret_key(self, value):
if value is None:
self._property_secret_key = None
return
self.assert_isinstance(value, "secret_key", six.string_types)
self._property_secret_key = value
class Storage(NonStrictDataModel):
"""
:param id: Entry ID
:type id: str
:param name: Entry name
:type name: str
:param company: Company ID
:type company: str
:param created: Entry creation time
:type created: datetime.datetime
:param uri: Storage URI
:type uri: str
:param credentials: Credentials required for accessing the storage
:type credentials: Credentials
"""
_schema = {
'properties': {
'company': {'description': 'Company ID', 'type': ['string', 'null']},
'created': {
'description': 'Entry creation time',
'format': 'date-time',
'type': ['string', 'null'],
},
'credentials': {
'description': 'Credentials required for accessing the storage',
'oneOf': [{'$ref': '#/definitions/credentials'}, {'type': 'null'}],
},
'id': {'description': 'Entry ID', 'type': ['string', 'null']},
'name': {'description': 'Entry name', 'type': ['string', 'null']},
'uri': {'description': 'Storage URI', 'type': ['string', 'null']},
},
'type': 'object',
}
def __init__(
self, id=None, name=None, company=None, created=None, uri=None, credentials=None, **kwargs):
super(Storage, self).__init__(**kwargs)
self.id = id
self.name = name
self.company = company
self.created = created
self.uri = uri
self.credentials = credentials
@schema_property('id')
def id(self):
return self._property_id
@id.setter
def id(self, value):
if value is None:
self._property_id = None
return
self.assert_isinstance(value, "id", six.string_types)
self._property_id = value
@schema_property('name')
def name(self):
return self._property_name
@name.setter
def name(self, value):
if value is None:
self._property_name = None
return
self.assert_isinstance(value, "name", six.string_types)
self._property_name = value
@schema_property('company')
def company(self):
return self._property_company
@company.setter
def company(self, value):
if value is None:
self._property_company = None
return
self.assert_isinstance(value, "company", six.string_types)
self._property_company = value
@schema_property('created')
def created(self):
return self._property_created
@created.setter
def created(self, value):
if value is None:
self._property_created = None
return
self.assert_isinstance(value, "created", six.string_types + (datetime,))
if not isinstance(value, datetime):
value = parse_datetime(value)
self._property_created = value
@schema_property('uri')
def uri(self):
return self._property_uri
@uri.setter
def uri(self, value):
if value is None:
self._property_uri = None
return
self.assert_isinstance(value, "uri", six.string_types)
self._property_uri = value
@schema_property('credentials')
def credentials(self):
return self._property_credentials
@credentials.setter
def credentials(self, value):
if value is None:
self._property_credentials = None
return
if isinstance(value, dict):
value = Credentials.from_dict(value)
else:
self.assert_isinstance(value, "credentials", Credentials)
self._property_credentials = value
class CreateRequest(Request):
"""
Create a new storage entry
:param name: Storage name
:type name: str
:param uri: Storage URI
:type uri: str
:param credentials: Credentials required for accessing the storage
:type credentials: Credentials
:param company: Company under which to add this storage. Only valid for users
with the root or system role, otherwise the calling user's company will be
used.
:type company: str
"""
_service = "storage"
_action = "create"
_version = "1.5"
_schema = {
'definitions': {
'credentials': {
'properties': {
'access_key': {
'description': 'Credentials access key',
'type': ['string', 'null'],
},
'secret_key': {
'description': 'Credentials secret key',
'type': ['string', 'null'],
},
},
'type': 'object',
},
},
'properties': {
'company': {
'description': "Company under which to add this storage. Only valid for users with the root or system role, otherwise the calling user's company will be used.",
'type': 'string',
},
'credentials': {
'$ref': '#/definitions/credentials',
'description': 'Credentials required for accessing the storage',
},
'name': {'description': 'Storage name', 'type': ['string', 'null']},
'uri': {'description': 'Storage URI', 'type': 'string'},
},
'required': ['uri'],
'type': 'object',
}
def __init__(
self, uri, name=None, credentials=None, company=None, **kwargs):
super(CreateRequest, self).__init__(**kwargs)
self.name = name
self.uri = uri
self.credentials = credentials
self.company = company
@schema_property('name')
def name(self):
return self._property_name
@name.setter
def name(self, value):
if value is None:
self._property_name = None
return
self.assert_isinstance(value, "name", six.string_types)
self._property_name = value
@schema_property('uri')
def uri(self):
return self._property_uri
@uri.setter
def uri(self, value):
if value is None:
self._property_uri = None
return
self.assert_isinstance(value, "uri", six.string_types)
self._property_uri = value
@schema_property('credentials')
def credentials(self):
return self._property_credentials
@credentials.setter
def credentials(self, value):
if value is None:
self._property_credentials = None
return
if isinstance(value, dict):
value = Credentials.from_dict(value)
else:
self.assert_isinstance(value, "credentials", Credentials)
self._property_credentials = value
@schema_property('company')
def company(self):
return self._property_company
@company.setter
def company(self, value):
if value is None:
self._property_company = None
return
self.assert_isinstance(value, "company", six.string_types)
self._property_company = value
class CreateResponse(Response):
"""
Response of storage.create endpoint.
:param id: New storage ID
:type id: str
"""
_service = "storage"
_action = "create"
_version = "1.5"
_schema = {
'definitions': {},
'properties': {
'id': {'description': 'New storage ID', 'type': ['string', 'null']},
},
'type': 'object',
}
def __init__(
self, id=None, **kwargs):
super(CreateResponse, self).__init__(**kwargs)
self.id = id
@schema_property('id')
def id(self):
return self._property_id
@id.setter
def id(self, value):
if value is None:
self._property_id = None
return
self.assert_isinstance(value, "id", six.string_types)
self._property_id = value
class DeleteRequest(Request):
"""
Deletes a storage entry
:param storage: Storage entry ID
:type storage: str
"""
_service = "storage"
_action = "delete"
_version = "1.5"
_schema = {
'definitions': {},
'properties': {
'storage': {'description': 'Storage entry ID', 'type': 'string'},
},
'required': ['storage'],
'type': 'object',
}
def __init__(
self, storage, **kwargs):
super(DeleteRequest, self).__init__(**kwargs)
self.storage = storage
@schema_property('storage')
def storage(self):
return self._property_storage
@storage.setter
def storage(self, value):
if value is None:
self._property_storage = None
return
self.assert_isinstance(value, "storage", six.string_types)
self._property_storage = value
class DeleteResponse(Response):
"""
Response of storage.delete endpoint.
:param deleted: Number of storage entries deleted (0 or 1)
:type deleted: int
"""
_service = "storage"
_action = "delete"
_version = "1.5"
_schema = {
'definitions': {},
'properties': {
'deleted': {
'description': 'Number of storage entries deleted (0 or 1)',
'type': ['integer', 'null'],
},
},
'type': 'object',
}
def __init__(
self, deleted=None, **kwargs):
super(DeleteResponse, self).__init__(**kwargs)
self.deleted = deleted
@schema_property('deleted')
def deleted(self):
return self._property_deleted
@deleted.setter
def deleted(self, value):
if value is None:
self._property_deleted = None
return
if isinstance(value, float) and value.is_integer():
value = int(value)
self.assert_isinstance(value, "deleted", six.integer_types)
self._property_deleted = value
class GetAllRequest(Request):
"""
Get all storage entries
:param name: Get only storage entries whose name matches this pattern (python
regular expression syntax)
:type name: str
:param id: List of Storage IDs used to filter results
:type id: Sequence[str]
:param page: Page number; returns a specific page out of the resulting list.
:type page: int
:param page_size: Page size, specifies the number of results returned in each
page (last page may contain fewer results)
:type page_size: int
:param order_by: List of field names to order by. When search_text is used,
'@text_score' can be used as a field representing the text score of returned
documents. Use '-' prefix to specify descending order. Optional, recommended
when using page
:type order_by: Sequence[str]
:param only_fields: List of document field names (nesting is supported using
'.', e.g. execution.model_labels). If provided, this list defines the query's
projection (only these fields will be returned for each result entry)
:type only_fields: Sequence[str]
"""
_service = "storage"
_action = "get_all"
_version = "1.5"
_schema = {
'definitions': {},
'properties': {
'id': {
'description': 'List of Storage IDs used to filter results',
'items': {'type': 'string'},
'type': ['array', 'null'],
},
'name': {
'description': 'Get only storage entries whose name matches this pattern (python regular expression syntax)',
'type': ['string', 'null'],
},
'only_fields': {
'description': "List of document field names (nesting is supported using '.', e.g. execution.model_labels). If provided, this list defines the query's projection (only these fields will be returned for each result entry)",
'items': {'type': 'string'},
'type': ['array', 'null'],
},
'order_by': {
'description': "List of field names to order by. When search_text is used, '@text_score' can be used as a field representing the text score of returned documents. Use '-' prefix to specify descending order. Optional, recommended when using page",
'items': {'type': 'string'},
'type': ['array', 'null'],
},
'page': {
'description': 'Page number; returns a specific page out of the resulting list.',
'minimum': 0,
'type': ['integer', 'null'],
},
'page_size': {
'description': 'Page size, specifies the number of results returned in each page (last page may contain fewer results)',
'minimum': 1,
'type': ['integer', 'null'],
},
},
'type': 'object',
}
def __init__(
self, name=None, id=None, page=None, page_size=None, order_by=None, only_fields=None, **kwargs):
super(GetAllRequest, self).__init__(**kwargs)
self.name = name
self.id = id
self.page = page
self.page_size = page_size
self.order_by = order_by
self.only_fields = only_fields
@schema_property('name')
def name(self):
return self._property_name
@name.setter
def name(self, value):
if value is None:
self._property_name = None
return
self.assert_isinstance(value, "name", six.string_types)
self._property_name = value
@schema_property('id')
def id(self):
return self._property_id
@id.setter
def id(self, value):
if value is None:
self._property_id = None
return
self.assert_isinstance(value, "id", (list, tuple))
self.assert_isinstance(value, "id", six.string_types, is_array=True)
self._property_id = value
@schema_property('page')
def page(self):
return self._property_page
@page.setter
def page(self, value):
if value is None:
self._property_page = None
return
if isinstance(value, float) and value.is_integer():
value = int(value)
self.assert_isinstance(value, "page", six.integer_types)
self._property_page = value
@schema_property('page_size')
def page_size(self):
return self._property_page_size
@page_size.setter
def page_size(self, value):
if value is None:
self._property_page_size = None
return
if isinstance(value, float) and value.is_integer():
value = int(value)
self.assert_isinstance(value, "page_size", six.integer_types)
self._property_page_size = value
@schema_property('order_by')
def order_by(self):
return self._property_order_by
@order_by.setter
def order_by(self, value):
if value is None:
self._property_order_by = None
return
self.assert_isinstance(value, "order_by", (list, tuple))
self.assert_isinstance(value, "order_by", six.string_types, is_array=True)
self._property_order_by = value
@schema_property('only_fields')
def only_fields(self):
return self._property_only_fields
@only_fields.setter
def only_fields(self, value):
if value is None:
self._property_only_fields = None
return
self.assert_isinstance(value, "only_fields", (list, tuple))
self.assert_isinstance(value, "only_fields", six.string_types, is_array=True)
self._property_only_fields = value
class GetAllResponse(Response):
"""
Response of storage.get_all endpoint.
:param results: Storage entries list
:type results: Sequence[Storage]
"""
_service = "storage"
_action = "get_all"
_version = "1.5"
_schema = {
'definitions': {
'credentials': {
'properties': {
'access_key': {
'description': 'Credentials access key',
'type': ['string', 'null'],
},
'secret_key': {
'description': 'Credentials secret key',
'type': ['string', 'null'],
},
},
'type': 'object',
},
'storage': {
'properties': {
'company': {
'description': 'Company ID',
'type': ['string', 'null'],
},
'created': {
'description': 'Entry creation time',
'format': 'date-time',
'type': ['string', 'null'],
},
'credentials': {
'description': 'Credentials required for accessing the storage',
'oneOf': [
{'$ref': '#/definitions/credentials'},
{'type': 'null'},
],
},
'id': {'description': 'Entry ID', 'type': ['string', 'null']},
'name': {
'description': 'Entry name',
'type': ['string', 'null'],
},
'uri': {
'description': 'Storage URI',
'type': ['string', 'null'],
},
},
'type': 'object',
},
},
'properties': {
'results': {
'description': 'Storage entries list',
'items': {'$ref': '#/definitions/storage'},
'type': ['array', 'null'],
},
},
'type': 'object',
}
def __init__(
self, results=None, **kwargs):
super(GetAllResponse, self).__init__(**kwargs)
self.results = results
@schema_property('results')
def results(self):
return self._property_results
@results.setter
def results(self, value):
if value is None:
self._property_results = None
return
self.assert_isinstance(value, "results", (list, tuple))
if any(isinstance(v, dict) for v in value):
value = [Storage.from_dict(v) if isinstance(v, dict) else v for v in value]
else:
self.assert_isinstance(value, "results", Storage, is_array=True)
self._property_results = value
response_mapping = {
GetAllRequest: GetAllResponse,
CreateRequest: CreateResponse,
DeleteRequest: DeleteResponse,
}
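The `response_mapping` dict above is how each generated service module pairs request classes with their response classes (it is consumed by `get_response_cls` in `utils`). A minimal, self-contained sketch of the same dispatch pattern — the `Ping*` classes are hypothetical stand-ins for the generated ones:

```python
# Sketch of the request -> response dispatch pattern used by the generated
# service modules. PingRequest/PingResponse are hypothetical stand-ins.

class Request(object):
    pass

class Response(object):
    pass

class PingRequest(Request):
    pass

class PingResponse(Response):
    def __init__(self, **kwargs):
        self.data = kwargs

response_mapping = {
    PingRequest: PingResponse,
}

def resolve_response_cls(request_cls):
    # walk the MRO so subclasses of a mapped request class resolve as well
    for cls in request_cls.mro():
        if cls in response_mapping:
            return response_mapping[cls]
    raise TypeError('no response class for %s' % request_cls.__name__)
```

The MRO walk mirrors what `get_response_cls` does against the defining module, so a specialized subclass of a generated request still finds its parent's response class.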

File diff suppressed because it is too large


@@ -0,0 +1,7 @@
from .session import Session
from .datamodel import DataModel, NonStrictDataModel, schema_property, StringEnum
from .request import Request, BatchRequest, CompoundRequest
from .response import Response
from .token_manager import TokenManager
from .errors import TimeoutExpiredError, ResultNotReadyError
from .callresult import CallResult


@@ -0,0 +1,8 @@
from .datamodel import DataModel
class ApiModel(DataModel):
""" API-related data model """
_service = None
_action = None
_version = None


@@ -0,0 +1,131 @@
import sys
import time
from ...backend_api.utils import get_response_cls
from .response import ResponseMeta, Response
from .errors import ResultNotReadyError, TimeoutExpiredError
class CallResult(object):
@property
def meta(self):
return self.__meta
@property
def response(self):
return self.__response
@property
def response_data(self):
return self.__response_data
@property
def async_accepted(self):
return self.meta.result_code == 202
@property
def request_cls(self):
return self.__request_cls
def __init__(self, meta, response=None, response_data=None, request_cls=None, session=None):
assert isinstance(meta, ResponseMeta)
if response and not isinstance(response, Response):
raise TypeError('response should be an instance of {}'.format(Response.__name__))
elif response_data and not isinstance(response_data, dict):
raise TypeError('response_data should be an instance of {}'.format(dict.__name__))
self.__meta = meta
self.__response = response
self.__request_cls = request_cls
self.__session = session
self.__async_result = None
if response_data is not None:
self.__response_data = response_data
elif response is not None:
try:
self.__response_data = response.to_dict()
except AttributeError:
raise TypeError('response should be an instance of {}'.format(Response.__name__))
else:
self.__response_data = None
@classmethod
def from_result(cls, res, request_cls=None, logger=None, service=None, action=None, session=None):
""" From requests result """
response_cls = get_response_cls(request_cls)
try:
data = res.json()
except ValueError:
service = service or (request_cls._service if request_cls else 'unknown')
action = action or (request_cls._action if request_cls else 'unknown')
return cls(request_cls=request_cls, meta=ResponseMeta.from_raw_data(
status_code=res.status_code, text=res.text, endpoint='%(service)s.%(action)s' % locals()))
if 'meta' not in data:
raise ValueError('Missing meta section in response payload')
try:
meta = ResponseMeta(**data['meta'])
# TODO: validate meta?
# meta.validate()
except Exception as ex:
raise ValueError('Failed parsing meta section in response payload (data=%s, error=%s)' % (data, ex))
response = None
response_data = None
try:
response_data = data.get('data', {})
if response_cls:
response = response_cls(**response_data)
# TODO: validate response?
# response.validate()
except Exception as e:
if logger:
logger.warning('Failed parsing response: %s' % str(e))
return cls(meta=meta, response=response, response_data=response_data, request_cls=request_cls, session=session)
def ok(self):
return self.meta.result_code == 200
def ready(self):
if not self.async_accepted:
return True
session = self.__session
res = session.send_request(service='async', action='result', json=dict(id=self.meta.id), async_enable=False)
if res.status_code != session._async_status_code:
self.__async_result = CallResult.from_result(res=res, request_cls=self.request_cls, logger=session._logger)
return True
def result(self):
if not self.async_accepted:
return self
if self.__async_result is None:
raise ResultNotReadyError(self._format_msg('Result not ready'), call_id=self.meta.id)
return self.__async_result
def wait(self, timeout=None, poll_interval=5, verbose=False):
if not self.async_accepted:
return self
session = self.__session
poll_interval = max(1, poll_interval)
remaining = max(0, timeout) if timeout else sys.maxsize
while remaining > 0:
if not self.ready():
# Still pending, log and continue
if verbose and session._logger:
progress = ('waiting forever'
if not timeout
else '%.1f/%.1f seconds remaining' % (remaining, float(timeout)))
session._logger.info('Waiting for asynchronous call %s (%s)'
% (self.request_cls.__name__, progress))
time.sleep(poll_interval)
remaining -= poll_interval
continue
# We've got something (good or bad, we don't know), create a call result and return
return self.result()
# Timeout expired, return the asynchronous call's result (we've got nothing better to report)
raise TimeoutExpiredError(self._format_msg('Timeout expired'), call_id=self.meta.id)
def _format_msg(self, msg):
return msg + ' for call %s (%s)' % (self.request_cls.__name__, self.meta.id)


@@ -0,0 +1,145 @@
import keyword
import enum
import json
import warnings
from datetime import datetime
import jsonschema
from enum import Enum
import six
def format_date(obj):
if isinstance(obj, datetime):
return str(obj)
class SchemaProperty(property):
def __init__(self, name=None, *args, **kwargs):
super(SchemaProperty, self).__init__(*args, **kwargs)
self.name = name
def setter(self, fset):
return type(self)(self.name, self.fget, fset, self.fdel, self.__doc__)
def schema_property(name):
def init(*args, **kwargs):
return SchemaProperty(name, *args, **kwargs)
return init
class DataModel(object):
""" Data Model"""
_schema = None
_data_props_list = None
@classmethod
def _get_data_props(cls):
props = cls._data_props_list
if props is None:
props = {}
for c in cls.__mro__:
props.update({k: getattr(v, 'name', k) for k, v in vars(c).items()
if isinstance(v, property)})
cls._data_props_list = props
return props.copy()
@classmethod
def _to_base_type(cls, value):
if isinstance(value, DataModel):
return value.to_dict()
elif isinstance(value, enum.Enum):
return value.value
elif isinstance(value, list):
return [cls._to_base_type(model) for model in value]
return value
def to_dict(self, only=None, except_=None):
prop_values = {v: getattr(self, k) for k, v in self._get_data_props().items()}
return {
k: self._to_base_type(v)
for k, v in prop_values.items()
if v is not None and (not only or k in only) and (not except_ or k not in except_)
}
def validate(self, schema=None):
jsonschema.validate(
self.to_dict(),
schema or self._schema,
types=dict(array=(list, tuple), integer=six.integer_types),
)
def __repr__(self):
return '<{}.{}: {}>'.format(
self.__module__.split('.')[-1],
type(self).__name__,
json.dumps(
self.to_dict(),
indent=4,
default=format_date,
)
)
@staticmethod
def assert_isinstance(value, field_name, expected, is_array=False):
if not is_array:
if not isinstance(value, expected):
raise TypeError("Expected %s of type %s, got %s" % (field_name, expected, type(value).__name__))
return
if not all(isinstance(x, expected) for x in value):
raise TypeError(
"Expected %s of type list[%s], got %s" % (
field_name,
expected,
", ".join(set(type(x).__name__ for x in value)),
)
)
@staticmethod
def normalize_key(prop_key):
if keyword.iskeyword(prop_key):
prop_key += '_'
return prop_key.replace('.', '__')
@classmethod
def from_dict(cls, dct, strict=False):
"""
Create an instance from a dictionary while ignoring unnecessary keys
"""
allowed_keys = cls._get_data_props().values()
invalid_keys = set(dct).difference(allowed_keys)
if strict and invalid_keys:
raise ValueError("Invalid keys %s" % tuple(invalid_keys))
return cls(**{cls.normalize_key(key): value for key, value in dct.items() if key not in invalid_keys})
class UnusedKwargsWarning(UserWarning):
pass
class NonStrictDataModelMixin(object):
"""
NonStrictDataModelMixin
:summary: supplies an __init__ method that warns about unused keywords
"""
def __init__(self, **kwargs):
unexpected = [key for key in kwargs if not key.startswith('_')]
if unexpected:
message = '{}: unused keyword argument(s) {}' \
.format(type(self).__name__, unexpected)
warnings.warn(message, UnusedKwargsWarning)
class NonStrictDataModel(DataModel, NonStrictDataModelMixin):
pass
class StringEnum(Enum):
def __str__(self):
return self.value
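`schema_property` (defined above) is a thin `property` subclass that remembers the wire-level field name, so `to_dict` can map Python attribute names back to schema names. A stripped-down sketch of the idea — the `Item` class and its field are hypothetical:

```python
class SchemaProperty(property):
    # property subclass that carries the schema (wire) field name
    def __init__(self, name=None, *args, **kwargs):
        super(SchemaProperty, self).__init__(*args, **kwargs)
        self.name = name

    def setter(self, fset):
        return type(self)(self.name, self.fget, fset, self.fdel, self.__doc__)

def schema_property(name):
    def init(*args, **kwargs):
        return SchemaProperty(name, *args, **kwargs)
    return init

class Item(object):
    def __init__(self, id_=None):
        self._id = id_

    @schema_property('id')   # wire name 'id', Python attribute name 'id_'
    def id_(self):
        return self._id

    def to_dict(self):
        # map attribute name -> wire name via the property's stored name
        props = {k: getattr(v, 'name', k) for k, v in vars(type(self)).items()
                 if isinstance(v, property)}
        return {wire: getattr(self, attr) for attr, wire in props.items()
                if getattr(self, attr) is not None}
```

This is why `DataModel.normalize_key` appends an underscore to Python keywords: the attribute can be renamed freely while the serialized key stays schema-accurate.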


@@ -0,0 +1,7 @@
from ...backend_config import EnvEntry
ENV_HOST = EnvEntry("TRAINS_API_HOST", "ALG_API_HOST")
ENV_ACCESS_KEY = EnvEntry("TRAINS_API_ACCESS_KEY", "ALG_API_ACCESS_KEY")
ENV_SECRET_KEY = EnvEntry("TRAINS_API_SECRET_KEY", "ALG_API_SECRET_KEY")
ENV_VERBOSE = EnvEntry("TRAINS_API_VERBOSE", "ALG_API_VERBOSE", type=bool, default=False)


@@ -0,0 +1,17 @@
class SessionError(Exception):
pass
class AsyncError(SessionError):
def __init__(self, msg, *args, **kwargs):
super(AsyncError, self).__init__(msg, *args)
for k, v in kwargs.items():
setattr(self, k, v)
class TimeoutExpiredError(SessionError):
pass
class ResultNotReadyError(SessionError):
pass


@@ -0,0 +1,76 @@
import abc
import jsonschema
import six
from .apimodel import ApiModel
from .datamodel import DataModel
class Request(ApiModel):
_method = 'get'
def __init__(self, **kwargs):
if kwargs:
raise ValueError('Unsupported keyword arguments: %s' % ', '.join(kwargs.keys()))
@six.add_metaclass(abc.ABCMeta)
class BatchRequest(Request):
_batched_request_cls = abc.abstractproperty()
_schema_errors = (jsonschema.SchemaError, jsonschema.ValidationError, jsonschema.FormatError,
jsonschema.RefResolutionError)
def __init__(self, requests, validate_requests=False, allow_raw_requests=True, **kwargs):
super(BatchRequest, self).__init__(**kwargs)
self._validate_requests = validate_requests
self._allow_raw_requests = allow_raw_requests
self._property_requests = None
self.requests = requests
@property
def requests(self):
return self._property_requests
@requests.setter
def requests(self, value):
assert issubclass(self._batched_request_cls, Request)
assert isinstance(value, (list, tuple))
if not self._allow_raw_requests:
if any(isinstance(x, dict) for x in value):
value = [self._batched_request_cls(**x) if isinstance(x, dict) else x for x in value]
assert all(isinstance(x, self._batched_request_cls) for x in value)
self._property_requests = value
def validate(self):
if not self._validate_requests or self._allow_raw_requests:
return
for i, req in enumerate(self.requests):
try:
req.validate()
except self._schema_errors as e:
raise Exception('Validation error in batch item #%d: %s' % (i, str(e)))
def get_json(self):
return [r if isinstance(r, dict) else r.to_dict() for r in self.requests]
class CompoundRequest(Request):
_item_prop_name = 'item'
def _get_item(self):
item = getattr(self, self._item_prop_name, None)
if item is None:
raise ValueError('Item property is empty or missing')
assert isinstance(item, DataModel)
return item
def to_dict(self):
return self._get_item().to_dict()
def validate(self):
return self._get_item().validate(self._schema)


@@ -0,0 +1,49 @@
import requests
import jsonmodels.models
import jsonmodels.fields
import jsonmodels.errors
from .apimodel import ApiModel
from .datamodel import NonStrictDataModelMixin
class Response(ApiModel, NonStrictDataModelMixin):
pass
class _ResponseEndpoint(jsonmodels.models.Base):
name = jsonmodels.fields.StringField()
requested_version = jsonmodels.fields.FloatField()
actual_version = jsonmodels.fields.FloatField()
class ResponseMeta(jsonmodels.models.Base):
@property
def is_valid(self):
return self._is_valid
@classmethod
def from_raw_data(cls, status_code, text, endpoint=None):
return cls(is_valid=False, result_code=status_code, result_subcode=0, result_msg=text,
endpoint=_ResponseEndpoint(name=(endpoint or 'unknown')))
def __init__(self, is_valid=True, **kwargs):
super(ResponseMeta, self).__init__(**kwargs)
self._is_valid = is_valid
id = jsonmodels.fields.StringField(required=True)
trx = jsonmodels.fields.StringField(required=True)
endpoint = jsonmodels.fields.EmbeddedField([_ResponseEndpoint], required=True)
result_code = jsonmodels.fields.IntField(required=True)
result_subcode = jsonmodels.fields.IntField()
result_msg = jsonmodels.fields.StringField(required=True)
error_stack = jsonmodels.fields.StringField()
def __str__(self):
if self.result_code == requests.codes.ok:
return "<%d: %s/v%.1f>" % (self.result_code, self.endpoint.name, self.endpoint.actual_version)
elif self._is_valid:
return "<%d/%d: %s/v%.1f (%s)>" % (self.result_code, self.result_subcode, self.endpoint.name,
self.endpoint.actual_version, self.result_msg)
return "<%d/%d: %s (%s)>" % (self.result_code, self.result_subcode, self.endpoint.name, self.result_msg)


@@ -0,0 +1,425 @@
import json as json_lib
import sys
import types
from socket import gethostname
import requests
import six
from pyhocon import ConfigTree
from requests.auth import HTTPBasicAuth
from .callresult import CallResult
from .defs import ENV_VERBOSE, ENV_HOST, ENV_ACCESS_KEY, ENV_SECRET_KEY
from .request import Request, BatchRequest
from .token_manager import TokenManager
from ..config import load
from ..utils import get_http_session_with_retry
from ..version import __version__
class LoginError(Exception):
pass
class Session(TokenManager):
""" TRAINS API Session class. """
_AUTHORIZATION_HEADER = "Authorization"
_WORKER_HEADER = "X-Trains-Worker"
_ASYNC_HEADER = "X-Trains-Async"
_CLIENT_HEADER = "X-Trains-Client"
_async_status_code = 202
_session_requests = 0
_session_initial_timeout = (1.0, 10)
_session_timeout = (5.0, None)
# TODO: add requests.codes.gateway_timeout once we support async commits
_retry_codes = [
requests.codes.bad_gateway,
requests.codes.service_unavailable,
requests.codes.bandwidth_limit_exceeded,
requests.codes.too_many_requests,
]
@property
def access_key(self):
return self.__access_key
@property
def secret_key(self):
return self.__secret_key
@property
def host(self):
return self.__host
@property
def worker(self):
return self.__worker
def __init__(
self,
worker=None,
api_key=None,
secret_key=None,
host=None,
logger=None,
verbose=None,
initialize_logging=True,
client=None,
config=None,
**kwargs
):
if config is not None:
self.config = config
else:
self.config = load()
if initialize_logging:
self.config.initialize_logging()
token_expiration_threshold_sec = self.config.get(
"auth.token_expiration_threshold_sec", 60
)
super(Session, self).__init__(
token_expiration_threshold_sec=token_expiration_threshold_sec, **kwargs
)
self._verbose = verbose if verbose is not None else ENV_VERBOSE.get()
self._logger = logger
self.__access_key = api_key or ENV_ACCESS_KEY.get(
default=(self.config.get("api.credentials.access_key", None) or
"EGRTCO8JMSIGI6S39GTP43NFWXDQOW")
)
if not self.access_key:
raise ValueError(
"Missing access_key. Please set in configuration file or pass in session init."
)
self.__secret_key = secret_key or ENV_SECRET_KEY.get(
default=(self.config.get("api.credentials.secret_key", None) or
"x!XTov_G-#vspE*Y(h$Anm&DIc5Ou-F)jsl$PdOyj5wG1&E!Z8")
)
if not self.secret_key:
raise ValueError(
"Missing secret_key. Please set in configuration file or pass in session init."
)
host = host or ENV_HOST.get(default=self.config.get("api.host"))
if not host:
raise ValueError("host is required in init or config")
self.__host = host.strip("/")
http_retries_config = self.config.get(
"api.http.retries", ConfigTree()
).as_plain_ordered_dict()
http_retries_config["status_forcelist"] = self._retry_codes
self.__http_session = get_http_session_with_retry(**http_retries_config)
self.__worker = worker or gethostname()
self.__max_req_size = self.config.get("api.http.max_req_size")
if not self.__max_req_size:
raise ValueError("missing max request size")
self.client = client or "api-{}".format(__version__)
self.refresh_token()
def _send_request(
self,
service,
action,
version=None,
method="get",
headers=None,
auth=None,
data=None,
json=None,
refresh_token_if_unauthorized=True,
):
""" Internal implementation for making a raw API request.
- Constructs the api endpoint name
- Injects the worker id into the headers
- Allows custom authorization using a requests auth object
- Intercepts `Unauthorized` responses and automatically attempts to refresh the session token once in this
case (only once). This is done since permissions are embedded in the token, and addresses a case where
server-side permissions have changed but are not reflected in the current token. Refreshing the token will
generate a token with the updated permissions.
"""
host = self.host
headers = headers.copy() if headers else {}
headers[self._WORKER_HEADER] = self.worker
headers[self._CLIENT_HEADER] = self.client
token_refreshed_on_error = False
url = (
"{host}/v{version}/{service}.{action}"
if version
else "{host}/{service}.{action}"
).format(**locals())
while True:
res = self.__http_session.request(
method, url, headers=headers, auth=auth, data=data, json=json,
timeout=self._session_initial_timeout if self._session_requests < 1 else self._session_timeout,
)
if (
refresh_token_if_unauthorized
and res.status_code == requests.codes.unauthorized
and not token_refreshed_on_error
):
# it seems we're unauthorized, so we'll try to refresh our token once in case permissions changed since
# the last time we got the token, and try again
self.refresh_token()
token_refreshed_on_error = True
# try again
continue
if (
res.status_code == requests.codes.service_unavailable
and self.config.get("api.http.wait_on_maintenance_forever", True)
):
if self._logger:
self._logger.warning(
"Service unavailable: {} is undergoing maintenance, retrying...".format(host)
)
continue
break
self._session_requests += 1
return res
def send_request(
self,
service,
action,
version=None,
method="get",
headers=None,
data=None,
json=None,
async_enable=False,
):
"""
Send a raw API request.
:param service: service name
:param action: action name
:param version: version number (default is the preconfigured api version)
:param method: method type (default is 'get')
:param headers: request headers (authorization and content type headers will be automatically added)
:param json: json to send in the request body (jsonable object or builtin types construct. if used,
content type will be application/json)
:param data: Dictionary, bytes, or file-like object to send in the request body
:param async_enable: whether request is asynchronous
:return: requests Response instance
"""
headers = headers.copy() if headers else {}
headers[self._AUTHORIZATION_HEADER] = "Bearer {}".format(self.token)
if async_enable:
headers[self._ASYNC_HEADER] = "1"
return self._send_request(
service=service,
action=action,
version=version,
method=method,
headers=headers,
data=data,
json=json,
)
def send_request_batch(
self,
service,
action,
version=None,
headers=None,
data=None,
json=None,
method="get",
):
"""
Send a raw batch API request. Batch requests always use application/json-lines content type.
:param service: service name
:param action: action name
:param version: version number (default is the preconfigured api version)
:param headers: request headers (authorization and content type headers will be automatically added)
:param json: iterable of json items (batched items, jsonable objects or builtin types constructs). These will
be sent as a multi-line payload in the request body.
:param data: iterable of bytes objects (batched items). These will be sent as a multi-line payload in the
request body.
:param method: HTTP method
:return: requests Response instance
"""
if not all(
isinstance(x, (list, tuple, type(None), types.GeneratorType))
for x in (data, json)
):
raise ValueError("Expecting list, tuple or generator in 'data' or 'json'")
if not data and not json:
raise ValueError(
"Missing data (data or json), batch requests are meaningless without it."
)
headers = headers.copy() if headers else {}
headers["Content-Type"] = "application/json-lines"
if data:
req_data = "\n".join(data)
else:
req_data = "\n".join(json_lib.dumps(x) for x in json)
cur = 0
results = []
while True:
size = self.__max_req_size
chunk = req_data[cur : cur + size]
if not chunk:
break
if len(chunk) < size:
# this is the remainder, no need to search for newline
pass
elif chunk[-1] != "\n":
# search for the last newline in order to send a coherent request
size = chunk.rfind("\n") + 1
# readjust the chunk
chunk = req_data[cur : cur + size]
res = self.send_request(
method=method,
service=service,
action=action,
data=chunk,
headers=headers,
version=version,
)
results.append(res)
if res.status_code != requests.codes.ok:
break
cur += size
return results
def validate_request(self, req_obj):
""" Validate an API request against the current version and the request's schema """
try:
# make sure we're using a compatible version for this request
# validate the request (checks required fields and specific field version restrictions)
validate = req_obj.validate
except AttributeError:
raise TypeError(
'"req_obj" parameter must be an backend_api.session.Request object'
)
validate()
def send_async(self, req_obj):
"""
Asynchronously sends an API request using a request object.
:param req_obj: The request object
:type req_obj: Request
:return: CallResult object containing the raw response, response metadata and parsed response object.
"""
return self.send(req_obj=req_obj, async_enable=True)
def send(self, req_obj, async_enable=False, headers=None):
"""
Sends an API request using a request object.
:param req_obj: The request object
:type req_obj: Request
:param async_enable: Request this method be executed in an asynchronous manner
:param headers: Additional headers to send with request
:return: CallResult object containing the raw response, response metadata and parsed response object.
"""
self.validate_request(req_obj)
if isinstance(req_obj, BatchRequest):
# TODO: support async for batch requests as well
if async_enable:
raise NotImplementedError(
"Async behavior is currently not implemented for batch requests"
)
json_data = req_obj.get_json()
res = self.send_request_batch(
service=req_obj._service,
action=req_obj._action,
version=req_obj._version,
json=json_data,
method=req_obj._method,
headers=headers,
)
# TODO: handle multiple results in this case
try:
res = next(r for r in res if r.status_code != requests.codes.ok)
except StopIteration:
# all are 200
res = res[0]
else:
res = self.send_request(
service=req_obj._service,
action=req_obj._action,
version=req_obj._version,
json=req_obj.to_dict(),
method=req_obj._method,
async_enable=async_enable,
headers=headers,
)
call_result = CallResult.from_result(
res=res,
request_cls=req_obj.__class__,
logger=self._logger,
service=req_obj._service,
action=req_obj._action,
session=self,
)
return call_result
def _do_refresh_token(self, old_token, exp=None):
""" TokenManager abstract method implementation.
Here we ignore the old token and simply obtain a new token.
"""
verbose = self._verbose and self._logger
if verbose:
self._logger.info(
"Refreshing token from {} (access_key={}, exp={})".format(
self.host, self.access_key, exp
)
)
auth = HTTPBasicAuth(self.access_key, self.secret_key)
try:
data = {"expiration_sec": exp} if exp else {}
res = self._send_request(
service="auth",
action="login",
auth=auth,
json=data,
refresh_token_if_unauthorized=False,
)
try:
resp = res.json()
except ValueError:
resp = {}
if res.status_code != 200:
msg = resp.get("meta", {}).get("result_msg", res.reason)
raise LoginError(
"Failed getting token (error {} from {}): {}".format(
res.status_code, self.host, msg
)
)
if verbose:
self._logger.info("Received new token")
return resp["data"]["token"]
except LoginError:
six.reraise(*sys.exc_info())
except Exception as ex:
raise LoginError(str(ex))
def __str__(self):
return "{self.__class__.__name__}[{self.host}, {self.access_key}/{secret_key}]".format(
self=self, secret_key=self.secret_key[:5] + "*" * (len(self.secret_key) - 5)
)


@@ -0,0 +1,95 @@
import sys
from abc import ABCMeta, abstractmethod
from time import time
import jwt
import six
@six.add_metaclass(ABCMeta)
class TokenManager(object):
@property
def token_expiration_threshold_sec(self):
return self.__token_expiration_threshold_sec
@token_expiration_threshold_sec.setter
def token_expiration_threshold_sec(self, value):
self.__token_expiration_threshold_sec = value
@property
def req_token_expiration_sec(self):
""" Token expiration sec requested when refreshing token """
return self.__req_token_expiration_sec
@req_token_expiration_sec.setter
def req_token_expiration_sec(self, value):
assert isinstance(value, (type(None), int))
self.__req_token_expiration_sec = value
@property
def token_expiration_sec(self):
return self.__token_expiration_sec
@property
def token(self):
return self._get_token()
@property
def raw_token(self):
return self.__token
def __init__(
self,
token=None,
req_token_expiration_sec=None,
token_history=None,
token_expiration_threshold_sec=60,
**kwargs
):
super(TokenManager, self).__init__()
assert isinstance(token_history, (type(None), dict))
self.token_expiration_threshold_sec = token_expiration_threshold_sec
self.req_token_expiration_sec = req_token_expiration_sec
self._set_token(token)
def _calc_token_valid_period_sec(self, token, exp=None, at_least_sec=None):
if token:
try:
exp = exp or self._get_token_exp(token)
if at_least_sec:
at_least_sec = max(at_least_sec, self.token_expiration_threshold_sec)
else:
at_least_sec = self.token_expiration_threshold_sec
return max(0, (exp - time() - at_least_sec))
except Exception:
pass
return 0
@classmethod
def _get_token_exp(cls, token):
""" Get token expiration time. If not present, assume forever """
return jwt.decode(token, verify=False).get('exp', sys.maxsize)
def _set_token(self, token):
if token:
self.__token = token
self.__token_expiration_sec = self._get_token_exp(token)
else:
self.__token = None
self.__token_expiration_sec = 0
def get_token_valid_period_sec(self):
return self._calc_token_valid_period_sec(self.__token, self.token_expiration_sec)
def _get_token(self):
if self.get_token_valid_period_sec() <= 0:
self.refresh_token()
return self.__token
@abstractmethod
def _do_refresh_token(self, old_token, exp=None):
pass
def refresh_token(self):
self._set_token(self._do_refresh_token(self.__token, exp=self.req_token_expiration_sec))
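`_calc_token_valid_period_sec` above treats a token as effectively expired `token_expiration_threshold_sec` before its real `exp`, so it is refreshed ahead of time rather than failing mid-request. The arithmetic, sketched standalone (the threshold value is just an example):

```python
import time

def token_valid_period_sec(exp, threshold_sec=60, now=None):
    """Seconds the token is still considered usable: its remaining lifetime
    minus a safety threshold, floored at zero (refresh when this hits 0)."""
    now = time.time() if now is None else now
    return max(0, exp - now - threshold_sec)
```

This is why `_get_token` calls `refresh_token` whenever the computed period is zero: the token may still be technically valid, but not for long enough to trust.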


@@ -0,0 +1,86 @@
import ssl
import sys
import requests
from requests.adapters import HTTPAdapter
from urllib3.util import Retry
from urllib3 import PoolManager
import six
if six.PY3:
from functools import lru_cache
elif six.PY2:
# python 2 support
from backports.functools_lru_cache import lru_cache
@lru_cache()
def get_config():
from ..backend_config import Config
config = Config(verbose=False)
config.reload()
return config
class TLSv1HTTPAdapter(HTTPAdapter):
def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs):
self.poolmanager = PoolManager(num_pools=connections,
maxsize=maxsize,
block=block,
ssl_version=ssl.PROTOCOL_TLSv1_2)
def get_http_session_with_retry(
total=0,
connect=None,
read=None,
redirect=None,
status=None,
status_forcelist=None,
backoff_factor=0,
backoff_max=None,
pool_connections=None,
pool_maxsize=None):
if not all(isinstance(x, (int, type(None))) for x in (total, connect, read, redirect, status)):
raise ValueError('Bad configuration. All retry count values must be null or int')
if status_forcelist and not all(isinstance(x, int) for x in status_forcelist):
raise ValueError('Bad configuration. Retry status_forcelist must be null or list of ints')
pool_maxsize = (
pool_maxsize
if pool_maxsize is not None
else get_config().get('api.http.pool_maxsize', 512)
)
pool_connections = (
pool_connections
if pool_connections is not None
else get_config().get('api.http.pool_connections', 512)
)
session = requests.Session()
if backoff_max is not None:
Retry.BACKOFF_MAX = backoff_max
retry = Retry(
total=total, connect=connect, read=read, redirect=redirect, status=status,
status_forcelist=status_forcelist, backoff_factor=backoff_factor)
adapter = TLSv1HTTPAdapter(max_retries=retry, pool_connections=pool_connections, pool_maxsize=pool_maxsize)
session.mount('http://', adapter)
session.mount('https://', adapter)
return session
def get_response_cls(request_cls):
""" Extract a request's response class using the mapping found in the module defining the request's service """
for req_cls in request_cls.mro():
module = sys.modules[req_cls.__module__]
if hasattr(module, 'action_mapping'):
return module.action_mapping[(request_cls._action, request_cls._version)][1]
elif hasattr(module, 'response_mapping'):
return module.response_mapping[req_cls]
raise TypeError('no response class!')


@@ -0,0 +1 @@
__version__ = '2.0.0'


@@ -0,0 +1,4 @@
from .defs import Environment
from .config import Config, ConfigEntry
from .errors import ConfigurationError
from .environment import EnvEntry


@@ -0,0 +1,291 @@
import abc
import warnings
from operator import itemgetter
import furl
import six
from attr import attrib, attrs
def _none_to_empty_string(maybe_string):
return maybe_string if maybe_string is not None else ""
def _url_stripper(bucket):
bucket = _none_to_empty_string(bucket)
bucket = bucket.strip("\"'").rstrip("/")
return bucket
@attrs
class S3BucketConfig(object):
bucket = attrib(type=str, converter=_url_stripper, default="")
host = attrib(type=str, converter=_none_to_empty_string, default="")
key = attrib(type=str, converter=_none_to_empty_string, default="")
secret = attrib(type=str, converter=_none_to_empty_string, default="")
multipart = attrib(type=bool, default=True)
acl = attrib(type=str, converter=_none_to_empty_string, default="")
secure = attrib(type=bool, default=True)
region = attrib(type=str, converter=_none_to_empty_string, default="")
def update(self, key, secret, multipart=True, region=None):
self.key = key
self.secret = secret
self.multipart = multipart
self.region = region
def is_valid(self):
return self.key and self.secret
def get_bucket_host(self):
return self.bucket, self.host
@classmethod
def from_list(cls, dict_list, log=None):
if not isinstance(dict_list, (tuple, list)) or not all(
isinstance(x, dict) for x in dict_list
):
raise ValueError("Expecting a list of configuration dictionaries")
configs = [cls(**entry) for entry in dict_list]
valid_configs = [conf for conf in configs if conf.is_valid()]
if log and len(valid_configs) < len(configs):
log.warning(
"Invalid bucket configurations detected for {}".format(
", ".join(
"/".join((config.host, config.bucket))
for config in configs
if config not in valid_configs
)
)
)
return configs
BucketConfig = S3BucketConfig
@six.add_metaclass(abc.ABCMeta)
class BaseBucketConfigurations(object):
def __init__(self, buckets=None, *_, **__):
self._buckets = buckets or []
self._prefixes = None
def _update_prefixes(self, refresh=True):
if self._prefixes and not refresh:
return
prefixes = (
(config, self._get_prefix_from_bucket_config(config))
for config in self._buckets
)
self._prefixes = sorted(prefixes, key=itemgetter(1), reverse=True)
@abc.abstractmethod
def _get_prefix_from_bucket_config(self, config):
pass
class S3BucketConfigurations(BaseBucketConfigurations):
def __init__(
self, buckets=None, default_key="", default_secret="", default_region=""
):
super(S3BucketConfigurations, self).__init__()
self._buckets = buckets if buckets else list()
self._default_key = default_key
self._default_secret = default_secret
self._default_region = default_region
self._default_multipart = True
@classmethod
def from_config(cls, s3_configuration):
config_list = S3BucketConfig.from_list(
s3_configuration.get("credentials", default=None)
)
default_key = s3_configuration.get("key", default="")
default_secret = s3_configuration.get("secret", default="")
default_region = s3_configuration.get("region", default="")
default_key = _none_to_empty_string(default_key)
default_secret = _none_to_empty_string(default_secret)
default_region = _none_to_empty_string(default_region)
return cls(config_list, default_key, default_secret, default_region)
def add_config(self, bucket_config):
self._buckets.insert(0, bucket_config)
self._prefixes = None
def remove_config(self, bucket_config):
self._buckets.remove(bucket_config)
self._prefixes = None
def get_config_by_bucket(self, bucket, host=None):
try:
return next(
bucket_config
for bucket_config in self._buckets
if (bucket, host) == bucket_config.get_bucket_host()
)
except StopIteration:
pass
return None
def update_config_with_defaults(self, bucket_config):
bucket_config.update(
key=self._default_key,
secret=self._default_secret,
region=bucket_config.region or self._default_region,
multipart=bucket_config.multipart or self._default_multipart,
)
def _get_prefix_from_bucket_config(self, config):
scheme = "s3"
prefix = furl.furl()
if config.host:
prefix.set(
scheme=scheme,
netloc=config.host.lower(),
path=config.bucket.lower() if config.bucket else "",
)
else:
prefix.set(scheme=scheme, path=config.bucket.lower())
bucket = prefix.path.segments[0]
prefix.path.segments.pop(0)
prefix.set(netloc=bucket)
return str(prefix)
def get_config_by_uri(self, uri):
"""
Get the credentials for an AWS S3 bucket from the config
:param uri: URI of bucket, directory or file
:return: bucket config
:rtype: S3BucketConfig
"""
def find_match(uri):
self._update_prefixes(refresh=False)
uri = uri.lower()
res = (
config
for config, prefix in self._prefixes
if prefix is not None and uri.startswith(prefix)
)
try:
return next(res)
except StopIteration:
return None
match = find_match(uri)
if match:
return match
parsed = furl.furl(uri)
if parsed.port:
host = parsed.netloc
parts = parsed.path.segments
bucket = parts[0] if parts else None
else:
host = None
bucket = parsed.netloc
return S3BucketConfig(
key=self._default_key,
secret=self._default_secret,
region=self._default_region,
multipart=True,
bucket=bucket,
host=host,
)
BucketConfigurations = S3BucketConfigurations
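`get_config_by_uri` relies on `_update_prefixes` sorting prefixes longest-first, so the most specific bucket configuration wins. A self-contained illustration of that longest-prefix matching (the example bucket names and credential strings are made up):

```python
# Longest-prefix matching: sort known prefixes longest-first so a more
# specific config ("s3://my-bucket/models") shadows a general one.
configs = {
    "s3://my-bucket": "default-creds",
    "s3://my-bucket/models": "models-creds",
}
prefixes = sorted(configs, key=len, reverse=True)

def find_creds(uri):
    uri = uri.lower()
    for prefix in prefixes:
        if uri.startswith(prefix):
            return configs[prefix]
    return None  # caller falls back to default credentials
```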
@attrs
class GSBucketConfig(object):
bucket = attrib(type=str)
subdir = attrib(type=str, converter=_url_stripper, default="")
project = attrib(type=str, default=None)
credentials_json = attrib(type=str, default=None)
def update(self, **kwargs):
for item in kwargs:
if not hasattr(self, item):
warnings.warn("Unexpected argument {} for update. Ignored".format(item))
else:
setattr(self, item, kwargs[item])
class GSBucketConfigurations(BaseBucketConfigurations):
def __init__(self, buckets=None, default_project=None, default_credentials=None):
super(GSBucketConfigurations, self).__init__(buckets)
self._default_project = default_project
self._default_credentials = default_credentials
self._update_prefixes()
@classmethod
def from_config(cls, gs_configuration):
if gs_configuration is None:
return cls()
config_list = gs_configuration.get("credentials", default=list())
buckets_configs = [GSBucketConfig(**entry) for entry in config_list]
default_project = gs_configuration.get("project", default=None)
default_credentials = gs_configuration.get("credentials_json", default=None)
return cls(buckets_configs, default_project, default_credentials)
def add_config(self, bucket_config):
self._buckets.insert(0, bucket_config)
self._update_prefixes()
def remove_config(self, bucket_config):
self._buckets.remove(bucket_config)
self._update_prefixes()
def update_config_with_defaults(self, bucket_config):
bucket_config.update(
project=bucket_config.project or self._default_project,
credentials_json=bucket_config.credentials_json
or self._default_credentials,
)
def get_config_by_uri(self, uri):
"""
Get the credentials for a Google Storage bucket from the config
:param uri: URI of bucket, directory or file
:return: bucket config
:rtype: GSBucketConfig
"""
res = (
config
for config, prefix in self._prefixes
if prefix is not None and uri.lower().startswith(prefix)
)
try:
return next(res)
except StopIteration:
pass
parsed = furl.furl(uri)
return GSBucketConfig(
bucket=parsed.netloc,
subdir=str(parsed.path),
project=self._default_project,
credentials_json=self._default_credentials,
)
def _get_prefix_from_bucket_config(self, config):
prefix = furl.furl(scheme="gs", netloc=config.bucket, path=config.subdir)
return str(prefix)
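The GS prefix built above is simply `gs://<bucket>/<subdir>`. A standalone sketch of the same construction without the `furl` dependency (the helper name is an assumption):

```python
# Build the "gs://bucket/subdir" prefix that GSBucketConfigurations matches
# URIs against, using plain string handling instead of furl.
def gs_prefix(bucket, subdir=""):
    prefix = "gs://%s" % bucket
    if subdir:
        prefix += "/" + subdir.strip("/")
    return prefix
```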

View File

@@ -0,0 +1,412 @@
from __future__ import print_function
import functools
import json
import os
import sys
import warnings
from fnmatch import fnmatch
from logging import Logger
from os.path import expanduser
from typing import Any, Text
import pyhocon
import six
from pathlib2 import Path
from pyhocon import ConfigTree
from pyparsing import (
ParseFatalException,
ParseException,
RecursiveGrammarException,
ParseSyntaxException,
)
from six.moves.urllib.parse import urlparse
from watchdog.observers import Observer
from .bucket_config import S3BucketConfig
from .defs import (
Environment,
DEFAULT_CONFIG_FOLDER,
LOCAL_CONFIG_PATHS,
ENV_CONFIG_PATHS,
LOCAL_CONFIG_FILES,
LOCAL_CONFIG_FILE_OVERRIDE_VAR,
ENV_CONFIG_PATH_OVERRIDE_VAR,
)
from .defs import is_config_file
from .entry import Entry, NotSet
from .errors import ConfigurationError
from .log import initialize as initialize_log, logger
from .reloader import ConfigReloader
from .utils import get_options
log = logger(__file__)
class ConfigEntry(Entry):
logger = None
def __init__(self, config, *keys, **kwargs):
# type: (Config, Text, Any) -> None
super(ConfigEntry, self).__init__(*keys, **kwargs)
self.config = config
def _get(self, key):
# type: (Text) -> Any
return self.config.get(key, NotSet)
def error(self, message):
# type: (Text) -> None
log.error(message.capitalize())
class Config(object):
"""
Represents a server configuration.
If watch=True, will watch configuration folders for changes and reload itself.
NOTE: will not watch folders that were created after initialization.
"""
# used in place of None in Config.get as default value because None is a valid value
_MISSING = object()
def __init__(
self,
config_folder=None,
env=None,
verbose=True,
relative_to=None,
app=None,
watch=False,
is_server=False,
**_
):
self._app = app
self._verbose = verbose
self._folder_name = config_folder or DEFAULT_CONFIG_FOLDER
self._roots = []
self._config = ConfigTree()
self._env = env or os.environ.get("TRAINS_ENV", Environment.default)
self.config_paths = set()
self.watch = watch
self.is_server = is_server
if watch:
self.observer = Observer()
self.observer.start()
self.handler = ConfigReloader(self)
if self._verbose:
print("Config env:%s" % str(self._env))
if not self._env:
raise ValueError(
"Missing environment in either init or environment variable"
)
if self._env not in get_options(Environment):
raise ValueError("Invalid environment %s" % self._env)
if relative_to is not None:
self.load_relative_to(relative_to)
@property
def root(self):
return self.roots[0] if self.roots else None
@property
def roots(self):
return self._roots
@roots.setter
def roots(self, value):
self._roots = value
@property
def env(self):
return self._env
def logger(self, path=None):
return logger(path)
def load_relative_to(self, *module_paths):
def normalize(p):
return Path(os.path.abspath(str(p))).with_name(self._folder_name)
self.roots = list(map(normalize, module_paths))
self.reload()
if self.watch:
for path in self.config_paths:
self.observer.schedule(self.handler, str(path), recursive=True)
def _reload(self):
env = self._env
config = self._config.copy()
if self.is_server:
env_config_paths = ENV_CONFIG_PATHS
else:
env_config_paths = []
env_config_path_override = os.environ.get(ENV_CONFIG_PATH_OVERRIDE_VAR)
if env_config_path_override:
env_config_paths = [expanduser(env_config_path_override)]
# merge configuration from root and other environment config paths
config = functools.reduce(
lambda cfg, path: ConfigTree.merge_configs(
cfg,
self._read_recursive_for_env(path, env, verbose=self._verbose),
copy_trees=True,
),
self.roots + env_config_paths,
config,
)
# merge configuration from local configuration paths
config = functools.reduce(
lambda cfg, path: ConfigTree.merge_configs(
cfg, self._read_recursive(path, verbose=self._verbose), copy_trees=True
),
LOCAL_CONFIG_PATHS,
config,
)
local_config_files = LOCAL_CONFIG_FILES
local_config_override = os.environ.get(LOCAL_CONFIG_FILE_OVERRIDE_VAR)
if local_config_override:
local_config_files = [expanduser(local_config_override)]
# merge configuration from local configuration files
config = functools.reduce(
lambda cfg, file_path: ConfigTree.merge_configs(
cfg,
self._read_single_file(file_path, verbose=self._verbose),
copy_trees=True,
),
local_config_files,
config,
)
config["env"] = env
return config
def replace(self, config):
self._config = config
def reload(self):
self.replace(self._reload())
def initialize_logging(self):
logging_config = self._config.get("logging", None)
if not logging_config:
return False
# handle incomplete file handlers
deleted = []
handlers = logging_config.get("handlers", {})
for name, handler in list(handlers.items()):
cls = handler.get("class", None)
is_file = cls and "FileHandler" in cls
if cls is None or (is_file and "filename" not in handler):
deleted.append(name)
del handlers[name]
elif is_file:
file = Path(handler.get("filename"))
if not file.is_file():
file.parent.mkdir(parents=True, exist_ok=True)
file.touch()
# remove dependency in deleted handlers
root_logger = logging_config.get("root", None)
loggers = list(logging_config.get("loggers", {}).values()) + (
[root_logger] if root_logger else []
)
for logger in loggers:
handlers = logger.get("handlers", None)
if not handlers:
continue
logger["handlers"] = [h for h in handlers if h not in deleted]
extra = None
if self._app:
extra = {"app": self._app}
initialize_log(logging_config, extra=extra)
return True
def __getitem__(self, key):
return self._config[key]
def get(self, key, default=_MISSING):
value = self._config.get(key, default)
if value is self._MISSING:
raise KeyError(
"Unable to find value for key '{}' and default value was not provided.".format(
key
)
)
return value
def to_dict(self):
return self._config.as_plain_ordered_dict()
def as_json(self):
return json.dumps(self.to_dict(), indent=2)
def _read_recursive_for_env(self, root_path_str, env, verbose=True):
root_path = Path(root_path_str)
if root_path.exists():
default_config = self._read_recursive(
root_path / Environment.default, verbose=verbose
)
env_config = self._read_recursive(
root_path / env, verbose=verbose
) # None is ok, will return empty config
config = ConfigTree.merge_configs(default_config, env_config, True)
else:
config = ConfigTree()
return config
def _read_recursive(self, conf_root, verbose=True):
conf = ConfigTree()
if not conf_root:
return conf
conf_root = Path(conf_root)
if not conf_root.exists():
if verbose:
print("No config in %s" % str(conf_root))
return conf
if self.watch:
self.config_paths.add(conf_root)
if verbose:
print("Loading config from %s" % str(conf_root))
for root, dirs, files in os.walk(str(conf_root)):
rel_dir = str(Path(root).relative_to(conf_root))
if rel_dir == ".":
rel_dir = ""
prefix = rel_dir.replace("/", ".")
for filename in files:
if not is_config_file(filename):
continue
if prefix != "":
key = prefix + "." + Path(filename).stem
else:
key = Path(filename).stem
file_path = str(Path(root) / filename)
conf.put(key, self._read_single_file(file_path, verbose=verbose))
return conf
@staticmethod
def _read_single_file(file_path, verbose=True):
if not file_path or not Path(file_path).is_file():
return ConfigTree()
if verbose:
print("Loading config from file %s" % file_path)
try:
return pyhocon.ConfigFactory.parse_file(file_path)
except ParseSyntaxException as ex:
msg = "Failed parsing {0} ({1.__class__.__name__}): (at char {1.loc}, line:{1.lineno}, col:{1.column})".format(
file_path, ex
)
six.reraise(
ConfigurationError,
ConfigurationError(msg, file_path=file_path),
sys.exc_info()[2],
)
except (ParseException, ParseFatalException, RecursiveGrammarException) as ex:
msg = "Failed parsing {0} ({1.__class__.__name__}): {1}".format(
file_path, ex
)
six.reraise(ConfigurationError, ConfigurationError(msg), sys.exc_info()[2])
except Exception as ex:
print("Failed loading %s: %s" % (file_path, ex))
raise
def get_config_for_bucket(self, base_url, extra_configurations=None):
"""
Get the credentials for an AWS S3 bucket from the config
:param base_url: URL of bucket
:param extra_configurations:
:return: bucket config
:rtype: S3BucketConfig
"""
warnings.warn(
"Use backend_config.bucket_config.BucketConfigurations.get_config_by_uri",
DeprecationWarning,
)
configs = S3BucketConfig.from_list(self.get("sdk.aws.s3.credentials", []))
if extra_configurations:
configs.extend(extra_configurations)
def find_match(host=None, bucket=None):
if not host and not bucket:
raise ValueError("host or bucket required")
try:
if host:
res = {
config
for config in configs
if (config.host and fnmatch(host, config.host))
and (
not bucket
or not config.bucket
or fnmatch(bucket.lower(), config.bucket.lower())
)
}
else:
res = {
config
for config in configs
if config.bucket
and fnmatch(bucket.lower(), config.bucket.lower())
}
return next(iter(res))
except StopIteration:
pass
parsed = urlparse(base_url)
parts = Path(parsed.path.strip("/")).parts
if parsed.netloc:
# We have a netloc (either an actual hostname or an AWS bucket name).
# First, we'll try with the netloc as host, but if we don't find anything, we'll try without a host and
# with the netloc as the bucket name
match = None
if parts:
# try host/bucket only if path parts contain any element
match = find_match(host=parsed.netloc, bucket=parts[0])
if not match:
# no path parts or no config found for host/bucket, try netloc as bucket
match = find_match(bucket=parsed.netloc)
else:
# No netloc, so we'll simply search by bucket
match = find_match(bucket=parts[0])
if match:
return match
non_aws_s3_host_suffix = ":9000"
if parsed.netloc.endswith(non_aws_s3_host_suffix):
host = parsed.netloc
bucket = parts[0] if parts else None
else:
host = None
bucket = parsed.netloc
return S3BucketConfig(
key=self.get("sdk.aws.s3.key", None),
secret=self.get("sdk.aws.s3.secret", None),
region=self.get("sdk.aws.s3.region", None),
multipart=True,
bucket=bucket,
host=host,
)

View File

@@ -0,0 +1,46 @@
import base64
from distutils.util import strtobool
from typing import Union, Optional, Text, Any, TypeVar, Callable, Tuple
import six
ConverterType = TypeVar("ConverterType", bound=Callable[[Any], Any])
def base64_to_text(value):
# type: (Any) -> Text
return base64.b64decode(value).decode("utf-8")
def text_to_bool(value):
# type: (Text) -> bool
return bool(strtobool(value))
def any_to_bool(value):
# type: (Optional[Union[int, float, Text]]) -> bool
if isinstance(value, six.text_type):
return text_to_bool(value)
return bool(value)
def or_(*converters, **kwargs):
# type: (ConverterType, Tuple[Exception, ...]) -> ConverterType
"""
Wrapper that implements an "optional converter" chain: each converter is tried
in order, the listed exceptions cause fall-through to the next converter, and
the original value is returned if none succeed.
:param converters: One or more converter callables, tried in order
:param exceptions: Keyword-only tuple of exception types to ignore (default: (ValueError, TypeError))
"""
# noinspection PyUnresolvedReferences
exceptions = kwargs.get("exceptions", (ValueError, TypeError))
def wrapper(value):
for converter in converters:
try:
return converter(value)
except exceptions:
pass
return value
return wrapper
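For reference, the fall-through behavior can be seen with a copy of `or_` and a numeric-parsing chain (`to_number` is an illustrative name, not from the codebase):

```python
# Copy of the or_() helper above, plus a usage example: int is tried first,
# then float; if both raise, the original value is returned unchanged.
def or_(*converters, **kwargs):
    exceptions = kwargs.get("exceptions", (ValueError, TypeError))
    def wrapper(value):
        for converter in converters:
            try:
                return converter(value)
            except exceptions:
                pass
        return value
    return wrapper

to_number = or_(int, float)
```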

View File

@@ -0,0 +1,53 @@
from os.path import expanduser
from pathlib2 import Path
ENV_VAR = 'TRAINS_ENV'
""" Name of system environment variable that can be used to specify the config environment name """
DEFAULT_CONFIG_FOLDER = 'config'
""" Default config folder to search for when loading relative to a given path """
ENV_CONFIG_PATHS = [
]
""" Environment-related config paths """
LOCAL_CONFIG_PATHS = [
'/etc/opt/trains', # used by servers for docker-generated configuration
expanduser('~/.trains/config'),
]
""" Local config paths, not related to environment """
LOCAL_CONFIG_FILES = [
expanduser('~/trains.conf'), # used for workstation configuration (end-users, workers)
]
""" Local config files (not paths) """
LOCAL_CONFIG_FILE_OVERRIDE_VAR = 'TRAINS_CONFIG_FILE'
""" Local config file override environment variable. If this is set, no other local config files will be used. """
ENV_CONFIG_PATH_OVERRIDE_VAR = 'TRAINS_CONFIG_PATH'
"""
Environment-related config path override environment variable. If this is set, no other env config path will be used.
"""
class Environment(object):
""" Supported environment names """
default = 'default'
demo = 'demo'
local = 'local'
CONFIG_FILE_EXTENSION = '.conf'
def is_config_file(path):
return Path(path).suffix == CONFIG_FILE_EXTENSION

View File

@@ -0,0 +1,96 @@
import abc
from typing import Optional, Any, Tuple, Text, Callable, Dict
import six
from .converters import any_to_bool
NotSet = object()
Converter = Callable[[Any], Any]
@six.add_metaclass(abc.ABCMeta)
class Entry(object):
"""
Configuration entry definition
"""
@classmethod
def default_conversions(cls):
# type: () -> Dict[Any, Converter]
return {
bool: any_to_bool,
six.text_type: lambda s: six.text_type(s).strip(),
}
def __init__(self, key, *more_keys, **kwargs):
# type: (Text, Text, Any) -> None
"""
:param key: Entry's key (at least one).
:param more_keys: More alternate keys for this entry.
:param type: Value type. If provided, will be used for choosing a default conversion or
(if none exists) for casting the environment value.
:param converter: Value converter. If provided, will be used to convert the environment value.
:param default: Default value. If provided, will be used as the default value on calls to get() and get_pair()
in case no value is found for any key and no specific default value was provided in the call.
Default value is None.
:param help: Help text describing this entry
"""
self.keys = (key,) + more_keys
self.type = kwargs.pop("type", six.text_type)
self.converter = kwargs.pop("converter", None)
self.default = kwargs.pop("default", None)
self.help = kwargs.pop("help", None)
def __str__(self):
return str(self.key)
@property
def key(self):
return self.keys[0]
def convert(self, value, converter=None):
# type: (Any, Converter) -> Optional[Any]
converter = converter or self.converter
if not converter:
converter = self.default_conversions().get(self.type, self.type)
return converter(value)
def get_pair(self, default=NotSet, converter=None):
# type: (Any, Converter) -> Optional[Tuple[Text, Any]]
for key in self.keys:
value = self._get(key)
if value is NotSet:
continue
try:
value = self.convert(value, converter)
except Exception as ex:
self.error("invalid value {key}={value}: {ex}".format(**locals()))
break
return key, value
result = self.default if default is NotSet else default
return self.key, result
def get(self, default=NotSet, converter=None):
# type: (Any, Converter) -> Optional[Any]
return self.get_pair(default=default, converter=converter)[1]
def set(self, value):
# type: (Any, Any) -> (Text, Any)
key, _ = self.get_pair(default=None, converter=None)
self._set(key, str(value))
def _set(self, key, value):
# type: (Text, Text) -> None
pass
@abc.abstractmethod
def _get(self, key):
# type: (Text) -> Any
pass
@abc.abstractmethod
def error(self, message):
# type: (Text) -> None
pass
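The key contract of `Entry` is that `_get()` is probed for each key in order and the first hit wins. A minimal dict-backed subclass illustrating that fallback (this simplified class drops the type/converter machinery for clarity and is an assumption, not the real implementation):

```python
# Minimal dict-backed Entry: keys are probed in order, first hit wins,
# and a default is returned when no key resolves.
_NotSet = object()

class DictEntry:
    def __init__(self, values, *keys, default=None):
        self._values = values
        self.keys = keys
        self.default = default

    def get(self):
        for key in self.keys:
            value = self._values.get(key, _NotSet)
            if value is not _NotSet:
                return value
        return self.default
```

This is why entries can be declared with a new key plus legacy aliases: the first key that actually resolves is the one used.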

View File

@@ -0,0 +1,25 @@
from os import getenv, environ
from .converters import text_to_bool
from .entry import Entry, NotSet
class EnvEntry(Entry):
@classmethod
def default_conversions(cls):
conversions = super(EnvEntry, cls).default_conversions().copy()
conversions[bool] = text_to_bool
return conversions
def _get(self, key):
value = getenv(key, "").strip()
return value or NotSet
def _set(self, key, value):
environ[key] = value
def __str__(self):
return "env:{}".format(super(EnvEntry, self).__str__())
def error(self, message):
print("Environment configuration: {}".format(message))
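The resolution rule `EnvEntry._get` implements is: an empty or whitespace-only environment variable counts as unset and falls through to the default. A hedged function-level sketch of that rule (`env_get` is an illustrative name):

```python
# Sketch of EnvEntry's resolution rule: empty/whitespace env vars fall
# through to the default; non-empty values are stripped then converted.
import os

def env_get(key, default=None, convert=str):
    value = os.getenv(key, "").strip()
    return convert(value) if value else default
```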

View File

@@ -0,0 +1,5 @@
class ConfigurationError(Exception):
def __init__(self, msg, file_path=None, *args):
super(ConfigurationError, self).__init__(msg, *args)
self.file_path = file_path

View File

@@ -0,0 +1,30 @@
import logging.config
from pathlib2 import Path
def logger(path=None):
name = "trains"
if path:
p = Path(path)
module = (p.parent if p.stem.startswith('_') else p).stem
name = "trains.%s" % module
return logging.getLogger(name)
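The path-to-logger-name mapping above is worth a worked example: an underscore-prefixed stem (such as `__init__.py`) takes the parent directory name instead of the file stem. A standalone sketch of just the name derivation (the helper name is an assumption):

```python
# Derive "trains.<module>" from a file path, mirroring logger() above:
# files like __init__.py fall back to their parent directory's name.
from pathlib import Path

def logger_name(path):
    p = Path(path)
    module = (p.parent if p.stem.startswith('_') else p).stem
    return "trains.%s" % module
```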
def initialize(logging_config=None, extra=None):
if extra is not None:
from logging import Logger
class _Logger(Logger):
__extra = extra.copy()
def _log(self, level, msg, args, exc_info=None, extra=None, **kwargs):
extra = extra or {}
extra.update(self.__extra)
super(_Logger, self)._log(level, msg, args, exc_info=exc_info, extra=extra, **kwargs)
Logger.manager.loggerClass = _Logger
if logging_config is not None:
logging.config.dictConfig(dict(logging_config))

View File

@@ -0,0 +1,32 @@
import logging
from watchdog.events import FileSystemEventHandler, FileCreatedEvent, FileDeletedEvent, FileModifiedEvent, \
FileMovedEvent
from .defs import is_config_file
from .log import logger
log = logger(__file__)
log.setLevel(logging.DEBUG)
class ConfigReloader(FileSystemEventHandler):
def __init__(self, config):
self.config = config
def reload(self):
try:
self.config.reload()
except Exception as ex:
log.warning('failed loading configuration: %s: %s', type(ex), ex)
def on_any_event(self, event):
if not (
is_config_file(event.src_path) and
isinstance(event, (FileCreatedEvent, FileDeletedEvent, FileModifiedEvent, FileMovedEvent))
):
return
log.debug('reloading configuration - triggered by %s', event)
self.reload()

View File

@@ -0,0 +1,9 @@
def get_items(cls):
""" get key/value items from an enum-like class (members represent enumeration key/value) """
return {k: v for k, v in vars(cls).items() if not k.startswith('_')}
def get_options(cls):
""" get options from an enum-like class (members represent enumeration key/value) """
return get_items(cls).values()
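These helpers expect an "enum-like" class: plain class attributes act as enumeration members and underscore-prefixed names (including Python's own dunders) are skipped. A self-contained example mirroring the `Environment` class in `defs.py`:

```python
# Enum-like class convention: members are plain class attributes,
# underscore-prefixed names (and Python dunders) are filtered out.
def get_items(cls):
    return {k: v for k, v in vars(cls).items() if not k.startswith('_')}

class Environment:
    default = 'default'
    demo = 'demo'
    local = 'local'
```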

View File

@@ -0,0 +1,2 @@
""" High-level abstractions for backend API """
from .task import Task, TaskStatusEnum, TaskEntry

View File

@@ -0,0 +1,147 @@
import abc
import requests.exceptions
import six
from ..backend_api import Session
from ..backend_api.session import BatchRequest
from ..config import config_obj
from ..config.defs import LOG_LEVEL_ENV_VAR, API_ACCESS_KEY, API_SECRET_KEY
from ..debugging import get_logger
from ..backend_api.version import __version__
from .session import SendError, SessionInterface
class InterfaceBase(SessionInterface):
""" Base class for a backend manager class """
_default_session = None
@property
def session(self):
return self._session
@property
def log(self):
return self._log
def __init__(self, session=None, log=None, **kwargs):
super(InterfaceBase, self).__init__()
self._session = session or self._get_default_session()
self._log = log or self._create_log()
def _create_log(self):
log = get_logger(str(self.__class__.__name__))
try:
log.setLevel(LOG_LEVEL_ENV_VAR.get(default=log.level))
except TypeError as ex:
raise ValueError('Invalid log level defined in environment variable `%s`: %s' % (LOG_LEVEL_ENV_VAR, ex))
return log
@classmethod
def _send(cls, session, req, ignore_errors=False, raise_on_errors=True, log=None, async_enable=False):
""" Convenience send() method providing standardized error reporting """
while True:
try:
res = session.send(req, async_enable=async_enable)
if res.meta.result_code in (200, 202) or ignore_errors:
return res
if isinstance(req, BatchRequest):
error_msg = 'Action failed %s' % res.meta
else:
error_msg = 'Action failed %s (%s)' \
% (res.meta, ', '.join('%s=%s' % p for p in req.to_dict().items()))
if log:
log.error(error_msg)
if res.meta.result_code <= 500:
# Proper backend error/bad status code - raise or return
if raise_on_errors:
raise SendError(res, error_msg)
return res
except requests.exceptions.BaseHTTPError as e:
if log:
log.error('failed sending %s: %s' % (str(req), str(e)))
# Infrastructure error
if log:
log.info('retrying request %s' % str(req))
def send(self, req, ignore_errors=False, raise_on_errors=True, async_enable=False):
return self._send(session=self.session, req=req, ignore_errors=ignore_errors, raise_on_errors=raise_on_errors,
log=self.log, async_enable=async_enable)
@classmethod
def _get_default_session(cls):
if not InterfaceBase._default_session:
InterfaceBase._default_session = Session(
initialize_logging=False,
client='sdk-%s' % __version__,
config=config_obj,
api_key=API_ACCESS_KEY.get(),
secret_key=API_SECRET_KEY.get(),
)
return InterfaceBase._default_session
@classmethod
def _set_default_session(cls, session):
"""
Set a new default session for the system
Warning: Use only for debug and testing
:param session: The new default session
"""
InterfaceBase._default_session = session
@property
def default_session(self):
if hasattr(self, '_session'):
return self._session
return self._get_default_session()
@six.add_metaclass(abc.ABCMeta)
class IdObjectBase(InterfaceBase):
def __init__(self, id, session=None, log=None, **kwargs):
super(IdObjectBase, self).__init__(session, log, **kwargs)
self._data = None
self._id = None
self.id = self.normalize_id(id)
@property
def id(self):
return self._id
@id.setter
def id(self, value):
should_reload = value is not None and value != self._id
self._id = value
if should_reload:
self.reload()
@property
def data(self):
if self._data is None:
self.reload()
return self._data
@abc.abstractmethod
def _reload(self):
pass
def reload(self):
if not self.id:
raise ValueError('Failed reloading %s: missing id' % type(self).__name__)
self._data = self._reload()
@classmethod
def normalize_id(cls, id):
return id.strip() if id else None
@classmethod
def resolve_id(cls, obj):
if isinstance(obj, cls):
return obj.id
return obj
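`normalize_id` and `resolve_id` together implement a small convenience pattern: APIs accept either an object of the class or a raw id string. A standalone sketch of the same pattern (the `Thing` class is illustrative, not from the codebase):

```python
# Accept either an instance or a raw id string: resolve_id unwraps
# instances, and ids are normalized (stripped, empty -> None) on init.
class Thing:
    def __init__(self, id):
        self.id = id.strip() if id else None

    @classmethod
    def resolve_id(cls, obj):
        return obj.id if isinstance(obj, cls) else obj
```

Callers can then pass `Thing` instances and plain id strings interchangeably to any API that routes arguments through `resolve_id`.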

View File

@@ -0,0 +1,4 @@
""" Metrics management and batching support """
from .interface import Metrics
from .reporter import Reporter
from .events import ScalarEvent, VectorEvent, PlotEvent, ImageEvent

View File

@@ -0,0 +1,258 @@
import abc
import time
from threading import Lock
import attr
import cv2
import numpy as np
import pathlib2
import six
from ...backend_api.services import events
from six.moves.urllib.parse import urlparse, urlunparse
from ...config import config
@six.add_metaclass(abc.ABCMeta)
class MetricsEventAdapter(object):
"""
Adapter providing all the base attributes required by a metrics event and defining an interface used by the
metrics manager when batching and writing events.
"""
_default_nan_value = 0.
""" Default value used when a np.nan value is encountered """
@attr.attrs(cmp=False, slots=True)
class FileEntry(object):
""" File entry used to report on file data that needs to be uploaded prior to sending the event """
event = attr.attrib()
name = attr.attrib()
""" File name """
stream = attr.attrib()
""" File-like object containing the file's data """
url_prop = attr.attrib()
""" Property name that should be updated with the uploaded url """
key_prop = attr.attrib()
upload_uri = attr.attrib()
url = attr.attrib(default=None)
exception = attr.attrib(default=None)
def set_exception(self, exp):
self.exception = exp
self.event.upload_exception = exp
@property
def metric(self):
return self._metric
@metric.setter
def metric(self, value):
self._metric = value
@property
def variant(self):
return self._variant
def __init__(self, metric, variant, iter=None, timestamp=None, task=None, gen_timestamp_if_none=True):
if not timestamp and gen_timestamp_if_none:
timestamp = int(time.time() * 1000)
self._metric = metric
self._variant = variant
self._iter = iter
self._timestamp = timestamp
self._task = task
# Try creating an event just to trigger validation
_ = self.get_api_event()
self.upload_exception = None
@abc.abstractmethod
def get_api_event(self):
""" Get an API event instance """
pass
def get_file_entry(self):
""" Get information for a file that should be uploaded before this event is sent """
pass
def update(self, task=None, **kwargs):
""" Update event properties """
if task:
self._task = task
def _get_base_dict(self):
""" Get a dict with the base attributes """
res = dict(
task=self._task,
timestamp=self._timestamp,
metric=self._metric,
variant=self._variant
)
if self._iter is not None:
res.update(iter=self._iter)
return res
@classmethod
def _convert_np_nan(cls, val):
if np.isnan(val) or np.isinf(val):
return cls._default_nan_value
return val
class ScalarEvent(MetricsEventAdapter):
""" Scalar event adapter """
def __init__(self, metric, variant, value, iter, **kwargs):
self._value = self._convert_np_nan(value)
super(ScalarEvent, self).__init__(metric=metric, variant=variant, iter=iter, **kwargs)
def get_api_event(self):
return events.MetricsScalarEvent(
value=self._value,
**self._get_base_dict())
class VectorEvent(MetricsEventAdapter):
""" Vector event adapter """
def __init__(self, metric, variant, values, iter, **kwargs):
self._values = [self._convert_np_nan(v) for v in values]
super(VectorEvent, self).__init__(metric=metric, variant=variant, iter=iter, **kwargs)
def get_api_event(self):
return events.MetricsVectorEvent(
values=self._values,
**self._get_base_dict())
class PlotEvent(MetricsEventAdapter):
""" Plot event adapter """
def __init__(self, metric, variant, plot_str, iter=None, **kwargs):
self._plot_str = plot_str
super(PlotEvent, self).__init__(metric=metric, variant=variant, iter=iter, **kwargs)
def get_api_event(self):
return events.MetricsPlotEvent(
plot_str=self._plot_str,
**self._get_base_dict())
class ImageEventNoUpload(MetricsEventAdapter):
def __init__(self, metric, variant, src, iter=0, **kwargs):
parts = urlparse(src)
self._url = urlunparse((parts.scheme, parts.netloc, '', '', '', ''))
self._key = urlunparse(('', '', parts.path, parts.params, parts.query, parts.fragment))
super(ImageEventNoUpload, self).__init__(metric, variant, iter=iter, **kwargs)
def get_api_event(self):
return events.MetricsImageEvent(
url=self._url,
key=self._key,
**self._get_base_dict())
class ImageEvent(MetricsEventAdapter):
""" Image event adapter """
_format = '.' + str(config.get('metrics.images.format', 'JPEG')).upper().lstrip('.')
_quality = int(config.get('metrics.images.quality', 87))
_subsampling = int(config.get('metrics.images.subsampling', 0))
_metric_counters = {}
_metric_counters_lock = Lock()
_image_file_history_size = int(config.get('metrics.file_history_size', 5))
def __init__(self, metric, variant, image_data, iter=0, upload_uri=None,
image_file_history_size=None, **kwargs):
if not hasattr(image_data, 'shape'):
raise ValueError('Image must have a shape attribute')
self._image_data = image_data
self._url = None
self._key = None
self._count = self._get_metric_count(metric, variant)
if not image_file_history_size:
image_file_history_size = self._image_file_history_size
if image_file_history_size < 1:
self._filename = '%s_%s_%08d' % (metric, variant, self._count)
else:
self._filename = '%s_%s_%08d' % (metric, variant, self._count % image_file_history_size)
self._upload_uri = upload_uri
super(ImageEvent, self).__init__(metric, variant, iter=iter, **kwargs)
@classmethod
def _get_metric_count(cls, metric, variant, next=True):
""" Returns the next count number for the given metric/variant (rotates every few calls) """
counters = cls._metric_counters
key = '%s_%s' % (metric, variant)
try:
cls._metric_counters_lock.acquire()
value = counters.get(key, -1)
if next:
value = counters[key] = value + 1
return value
finally:
cls._metric_counters_lock.release()
def get_api_event(self):
return events.MetricsImageEvent(
url=self._url,
key=self._key,
**self._get_base_dict())
def update(self, url=None, key=None, **kwargs):
super(ImageEvent, self).update(**kwargs)
if url is not None:
self._url = url
if key is not None:
self._key = key
def get_file_entry(self):
# don't provide file in case this event is out of the history window
last_count = self._get_metric_count(self.metric, self.variant, next=False)
if abs(self._count - last_count) > self._image_file_history_size:
output = None
else:
image_data = self._image_data
if not isinstance(image_data, np.ndarray):
# try conversion, if it fails we'll leave it to the user.
image_data = np.asarray(image_data, dtype=np.uint8)
image_data = np.atleast_3d(image_data)
if image_data.dtype != np.uint8:
if np.issubdtype(image_data.dtype, np.floating) and image_data.max() <= 1.0:
image_data = (image_data*255).astype(np.uint8)
else:
image_data = image_data.astype(np.uint8)
shape = image_data.shape
height, width, channel = shape[:3]
if channel == 1:
image_data = np.reshape(image_data, (height, width))
# serialize image
_, img_bytes = cv2.imencode(
self._format, image_data,
params=(cv2.IMWRITE_JPEG_QUALITY, self._quality),
)
output = six.BytesIO(img_bytes.tostring())
output.seek(0)
filename = str(pathlib2.Path(self._filename).with_suffix(self._format.lower()))
return self.FileEntry(
event=self,
name=filename,
stream=output,
url_prop='url',
key_prop='key',
upload_uri=self._upload_uri
)
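The dtype/shape normalization that `get_file_entry` applies before JPEG-encoding can be isolated into a standalone sketch. The function name and free-standing form are illustrative; the real logic lives inline in `ImageEvent.get_file_entry`:

```python
import numpy as np

def normalize_image(image_data):
    """Illustrative mirror of the normalization ImageEvent performs
    before handing image data to cv2.imencode."""
    image_data = np.asarray(image_data)     # accept lists/tuples as well as ndarrays
    image_data = np.atleast_3d(image_data)  # guarantee an H x W x C layout
    if image_data.dtype != np.uint8:
        if np.issubdtype(image_data.dtype, np.floating) and image_data.max() <= 1.0:
            # floats in [0, 1] are rescaled to the full 8-bit range
            image_data = (image_data * 255).astype(np.uint8)
        else:
            image_data = image_data.astype(np.uint8)
    height, width, channels = image_data.shape
    if channels == 1:
        # single-channel images are stored as plain 2-D arrays
        image_data = image_data.reshape(height, width)
    return image_data
```

For example, a float image filled with 0.5 comes back as a uint8 image of 127s, and an H x W x 1 grayscale array is squeezed back to H x W.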


@@ -0,0 +1,192 @@
from functools import partial
from multiprocessing.pool import ThreadPool
from threading import Lock
from time import time
from humanfriendly import format_timespan
from ...backend_api.services import events as api_events
from ..base import InterfaceBase
from ...config import config
from ...debugging import get_logger
from ...storage import StorageHelper
from .events import MetricsEventAdapter
upload_pool = ThreadPool(processes=1)
file_upload_pool = ThreadPool(processes=config.get('network.metrics.file_upload_threads', 4))
log = get_logger('metrics')
class Metrics(InterfaceBase):
""" Metrics manager and batch writer """
_storage_lock = Lock()
_file_upload_starvation_warning_sec = config.get('network.metrics.file_upload_starvation_warning_sec', None)
@property
def storage_key_prefix(self):
return self._storage_key_prefix
def _get_storage(self, storage_uri=None):
""" Storage helper used to upload files """
try:
# use a lock since this storage object will be requested by thread pool threads, so we need to make sure
# any singleton initialization will occur only once
self._storage_lock.acquire()
storage_uri = storage_uri or self._storage_uri
return StorageHelper.get(storage_uri)
except Exception as e:
log.error('Failed getting storage helper for %s: %s' % (storage_uri, str(e)))
finally:
self._storage_lock.release()
def __init__(self, session, task_id, storage_uri, storage_uri_suffix='metrics', log=None):
super(Metrics, self).__init__(session, log=log)
self._task_id = task_id
self._storage_uri = storage_uri.rstrip('/') if storage_uri else None
self._storage_key_prefix = storage_uri_suffix.strip('/') if storage_uri_suffix else None
self._file_related_event_time = None
self._file_upload_time = None
def write_events(self, events, async_enable=True, callback=None, **kwargs):
"""
Write events to the backend, uploading any required files.
:param events: A list of event objects
:param async_enable: If True, upload is performed asynchronously and an AsyncResult object is returned, otherwise a
blocking call is made and the upload result is returned.
:param callback: A optional callback called when upload was completed in case async is True
:return: .backend_api.session.CallResult if async is False otherwise AsyncResult. Note that if no events were
sent, None will be returned.
"""
if not events:
return
storage_uri = kwargs.pop('storage_uri', self._storage_uri)
if not async_enable:
return self._do_write_events(events, storage_uri)
def safe_call(*args, **kwargs):
try:
return self._do_write_events(*args, **kwargs)
except Exception as e:
return e
return upload_pool.apply_async(
safe_call,
args=(events, storage_uri),
callback=partial(self._callback_wrapper, callback))
def _callback_wrapper(self, callback, res):
""" A wrapper for the async callback for handling common errors """
if not res:
# no result yet
return
elif isinstance(res, Exception):
# error
self.log.error('Error trying to send metrics: %s' % str(res))
elif not res.ok():
# bad result, log error
self.log.error('Failed reporting metrics: %s' % str(res.meta))
# call callback, even if we received an error
if callback:
callback(res)
def _do_write_events(self, events, storage_uri=None):
""" Send an iterable of events as a batched request. Note: metrics send does not raise on error """
assert isinstance(events, (list, tuple))
assert all(isinstance(x, MetricsEventAdapter) for x in events)
# def event_key(ev):
# return (ev.metric, ev.variant)
#
# events = sorted(events, key=event_key)
# multiple_events_for = [k for k, v in groupby(events, key=event_key) if len(list(v)) > 1]
# if multiple_events_for:
# log.warning(
# 'More than one metrics event sent for these metric/variant combinations in a report: %s' %
# ', '.join('%s/%s' % k for k in multiple_events_for))
storage_uri = storage_uri or self._storage_uri
now = time()
def update_and_get_file_entry(ev):
entry = ev.get_file_entry()
kwargs = {}
if entry:
e_storage_uri = entry.upload_uri or storage_uri
self._file_related_event_time = now
# if we have an entry (with or without a stream), we'll generate the URL and store it in the event
filename = entry.name
key = '/'.join(x for x in (self._storage_key_prefix, ev.metric, ev.variant, filename.strip('/')) if x)
url = '/'.join(x.strip('/') for x in (e_storage_uri, key))
kwargs[entry.key_prop] = key
kwargs[entry.url_prop] = url
if not entry.stream:
# if entry has no stream, we won't upload it
entry = None
else:
if not hasattr(entry.stream, 'read'):
raise ValueError('Invalid file object %s' % entry.stream)
entry.url = url
ev.update(task=self._task_id, **kwargs)
return entry
# prepare event needing file upload
entries = []
for ev in events:
try:
e = update_and_get_file_entry(ev)
if e:
entries.append(e)
except Exception as ex:
log.warning(str(ex))
# upload the needed files
if entries:
# upload files
def upload(e):
upload_uri = e.upload_uri or storage_uri
try:
storage = self._get_storage(upload_uri)
url = storage.upload_from_stream(e.stream, e.url)
e.event.update(url=url)
except Exception as exp:
log.debug("Failed uploading to {} ({})".format(
upload_uri if upload_uri else "(Could not calculate upload uri)",
exp,
))
e.set_exception(exp)
res = file_upload_pool.map_async(upload, entries)
res.wait()
# remember the last time we uploaded a file
self._file_upload_time = time()
t_f, t_u, t_ref = \
(self._file_related_event_time, self._file_upload_time, self._file_upload_starvation_warning_sec)
if t_f and t_u and t_ref and (t_f - t_u) > t_ref:
log.warning('Possible metrics file upload starvation: files were not uploaded for %s' %
format_timespan(t_ref))
# send the events in a batched request
good_events = [ev for ev in events if ev.upload_exception is None]
error_events = [ev for ev in events if ev.upload_exception is not None]
if error_events:
log.error("Not uploading {}/{} events because the data upload failed".format(
len(error_events),
len(events),
))
if good_events:
batched_requests = [api_events.AddRequest(event=ev.get_api_event()) for ev in good_events]
req = api_events.AddBatchRequest(requests=batched_requests)
return self.send(req, raise_on_errors=False)
return None
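The async path above relies on a trap-and-forward pattern: the pool worker never raises, and any exception travels to the result callback as a value, where it is logged instead of killing the pool thread. A self-contained sketch of that pattern (the pool size, names, and the failing `int('not-a-number')` workload are illustrative, not part of the API):

```python
from functools import partial
from multiprocessing.pool import ThreadPool

pool = ThreadPool(processes=1)
errors = []   # stands in for self.log.error(...)
results = []  # collects what the user callback receives

def safe_call(fn, *args):
    try:
        return fn(*args)
    except Exception as e:
        return e  # surfaced to the callback instead of raising in the pool thread

def callback_wrapper(user_callback, res):
    # mirror of _callback_wrapper: log errors, then always invoke the callback
    if isinstance(res, Exception):
        errors.append(res)
    if user_callback:
        user_callback(res)

async_res = pool.apply_async(
    safe_call, args=(int, 'not-a-number'),
    callback=partial(callback_wrapper, results.append))
async_res.wait()
pool.close()
pool.join()
```

After the pool drains, the ValueError raised by the worker shows up in both `errors` and `results`, and no exception ever propagates out of the pool.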


@@ -0,0 +1,457 @@
import collections
import json
import cv2
import six
from ..base import InterfaceBase
from ..setupuploadmixin import SetupUploadMixin
from ...utilities.async_manager import AsyncManagerMixin
from ...utilities.plotly import create_2d_histogram_plot, create_value_matrix, create_3d_surface, \
create_2d_scatter_series, create_3d_scatter_series, create_line_plot, plotly_scatter3d_layout_dict
from ...utilities.py3_interop import AbstractContextManager
from .events import ScalarEvent, VectorEvent, ImageEvent, PlotEvent, ImageEventNoUpload
class Reporter(InterfaceBase, AbstractContextManager, SetupUploadMixin, AsyncManagerMixin):
"""
A simple metrics reporter class.
This class caches reports and supports both explicit flushing and context-based flushing. To ensure reports are
sent to the backend, please use (assuming an instance of Reporter named 'reporter'):
- use the context manager feature (which will automatically flush when exiting the context):
with reporter:
reporter.report...
...
- explicitly call flush:
reporter.report...
...
reporter.flush()
"""
def __init__(self, metrics, flush_threshold=10, async_enable=False):
"""
Create a reporter
:param metrics: A Metrics manager instance that handles actual reporting, uploads etc.
:type metrics: .backend_interface.metrics.Metrics
:param flush_threshold: Events flush threshold. This determines the threshold over which cached reported events
are flushed and sent to the backend.
:type flush_threshold: int
"""
log = metrics.log.getChild('reporter')
log.setLevel(log.level)
super(Reporter, self).__init__(session=metrics.session, log=log)
self._metrics = metrics
self._flush_threshold = flush_threshold
self._events = []
self._bucket_config = None
self._storage_uri = None
self._async_enable = async_enable
def _set_storage_uri(self, value):
value = '/'.join(x for x in (value.rstrip('/'), self._metrics.storage_key_prefix) if x)
self._storage_uri = value
storage_uri = property(None, _set_storage_uri)
@property
def flush_threshold(self):
return self._flush_threshold
@flush_threshold.setter
def flush_threshold(self, value):
self._flush_threshold = max(0, value)
@property
def async_enable(self):
return self._async_enable
@async_enable.setter
def async_enable(self, value):
self._async_enable = bool(value)
def _report(self, ev):
self._events.append(ev)
if len(self._events) >= self._flush_threshold:
self._write()
def _write(self):
if not self._events:
return
# print('reporting %d events' % len(self._events))
res = self._metrics.write_events(self._events, async_enable=self._async_enable, storage_uri=self._storage_uri)
if self._async_enable:
self._add_async_result(res)
self._events = []
def flush(self):
"""
Flush cached reports to backend.
"""
self._write()
# wait for all reports
if self.get_num_results() > 0:
self.wait_for_results()
def report_scalar(self, title, series, value, iter):
"""
Report a scalar value
:param title: Title (AKA metric)
:type title: str
:param series: Series (AKA variant)
:type series: str
:param value: Reported value
:type value: float
:param iter: Iteration number
:type iter: int
"""
ev = ScalarEvent(metric=self._normalize_name(title),
variant=self._normalize_name(series), value=value, iter=iter)
self._report(ev)
def report_vector(self, title, series, values, iter):
"""
Report a vector of values
:param title: Title (AKA metric)
:type title: str
:param series: Series (AKA variant)
:type series: str
:param values: Reported values
:type values: [float]
:param iter: Iteration number
:type iter: int
"""
if not isinstance(values, collections.Iterable):
raise ValueError('values: expected an iterable')
ev = VectorEvent(metric=self._normalize_name(title),
variant=self._normalize_name(series), values=values, iter=iter)
self._report(ev)
def report_plot(self, title, series, plot, iter):
"""
Report a Plotly chart
:param title: Title (AKA metric)
:type title: str
:param series: Series (AKA variant)
:type series: str
:param plot: A JSON describing a plotly chart (see https://help.plot.ly/json-chart-schema/)
:type plot: str or dict
:param iter: Iteration number
:type iter: int
"""
if isinstance(plot, dict):
plot = json.dumps(plot)
elif not isinstance(plot, six.string_types):
raise ValueError('Plot should be a string or a dict')
ev = PlotEvent(metric=self._normalize_name(title),
variant=self._normalize_name(series), plot_str=plot, iter=iter)
self._report(ev)
def report_image(self, title, series, src, iter):
"""
Report an image.
:param title: Title (AKA metric)
:type title: str
:param series: Series (AKA variant)
:type series: str
:param src: Image source URI. This URI will be used by the webapp and workers when trying to obtain the image
for presentation or processing. Currently only http(s), file and s3 schemes are supported.
:type src: str
:param iter: Iteration number
:type iter: int
"""
ev = ImageEventNoUpload(metric=self._normalize_name(title),
variant=self._normalize_name(series), iter=iter, src=src)
self._report(ev)
def report_image_and_upload(self, title, series, iter, path=None, matrix=None, upload_uri=None,
max_image_history=None):
"""
Report an image and upload its contents. Image is uploaded to a preconfigured bucket (see setup_upload()) with
a key (filename) describing the task ID, title, series and iteration.
:param title: Title (AKA metric)
:type title: str
:param series: Series (AKA variant)
:type series: str
:param iter: Iteration number
:type iter: int
:param path: A path to an image file. Required unless matrix is provided.
:type path: str
:param matrix: A 3D numpy.ndarray object containing image data (BGR). Required unless path is provided.
:type matrix: numpy.ndarray
:param max_image_history: Maximum number of images to store per metric/variant combination.
Use a negative value for unlimited. Default is set in the global configuration (default=5).
"""
if not self._storage_uri and not upload_uri:
raise ValueError('Upload configuration is required (use setup_upload())')
if len([x for x in (path, matrix) if x is not None]) != 1:
raise ValueError('Expected exactly one of [path, matrix]')
kwargs = dict(metric=self._normalize_name(title),
variant=self._normalize_name(series), iter=iter, image_file_history_size=max_image_history)
if matrix is None:
matrix = cv2.imread(path)
ev = ImageEvent(image_data=matrix, upload_uri=upload_uri, **kwargs)
self._report(ev)
def report_histogram(self, title, series, histogram, iter, labels=None, xlabels=None, comment=None):
"""
Report a histogram bar plot
:param title: Title (AKA metric)
:type title: str
:param series: Series (AKA variant)
:type series: str
:param histogram: The histogram data.
A row for each dataset (a bar in a bar group); a column for each bucket.
:type histogram: numpy array
:param iter: Iteration number
:type iter: int
:param labels: The labels for each bar group.
:type labels: list of strings.
:param xlabels: The labels of the x axis.
:type xlabels: List of strings.
:param comment: comment underneath the title
:type comment: str
"""
plotly_dict = create_2d_histogram_plot(
np_row_wise=histogram,
title=title,
labels=labels,
series=series,
xlabels=xlabels,
comment=comment,
)
return self.report_plot(
title=self._normalize_name(title),
series=self._normalize_name(series),
plot=plotly_dict,
iter=iter,
)
def report_line_plot(self, title, series, iter, xtitle, ytitle, mode='lines', reverse_xaxis=False, comment=None):
"""
Report a (possibly multiple) line plot.
:param title: Title (AKA metric)
:type title: str
:param series: All the series' data, one for each line in the plot.
:type series: An iterable of LineSeriesInfo.
:param iter: Iteration number
:type iter: int
:param xtitle: x-axis title
:type xtitle: str
:param ytitle: y-axis title
:type ytitle: str
:param mode: 'lines' / 'markers' / 'lines+markers'
:type mode: str
:param reverse_xaxis: If true X axis will be displayed from high to low (reversed)
:type reverse_xaxis: bool
:param comment: comment underneath the title
:type comment: str
"""
plotly_dict = create_line_plot(
title=title,
series=series,
xtitle=xtitle,
ytitle=ytitle,
mode=mode,
reverse_xaxis=reverse_xaxis,
comment=comment,
)
return self.report_plot(
title=self._normalize_name(title),
series='',
plot=plotly_dict,
iter=iter,
)
def report_2d_scatter(self, title, series, data, iter, mode='lines', xtitle=None, ytitle=None, labels=None,
comment=None):
"""
Report a 2d scatter graph (with lines)
:param title: Title (AKA metric)
:type title: str
:param series: Series (AKA variant)
:type series: str
:param data: Scatter data: pairs of x,y as rows in a numpy array
:type data: ndarray
:param iter: Iteration number
:type iter: int
:param mode: (type str) 'lines'/'markers'/'lines+markers'
:param xtitle: optional x-axis title
:param ytitle: optional y-axis title
:param labels: label (text) per point in the scatter (in the same order)
:param comment: comment underneath the title
:type comment: str
"""
plotly_dict = create_2d_scatter_series(
np_row_wise=data,
title=title,
series_name=series,
mode=mode,
xtitle=xtitle,
ytitle=ytitle,
labels=labels,
comment=comment,
)
return self.report_plot(
title=self._normalize_name(title),
series=self._normalize_name(series),
plot=plotly_dict,
iter=iter,
)
def report_3d_scatter(self, title, series, data, iter, labels=None, mode='lines', color=((217, 217, 217, 0.14),),
marker_size=5, line_width=0.8, xtitle=None, ytitle=None, ztitle=None, fill=None,
comment=None):
"""
Report a 3d scatter graph (with markers)
:param title: Title (AKA metric)
:type title: str
:param series: Series (AKA variant)
:type series: str
:param data: Scatter data: triplets of x,y,z as rows in a numpy array, or a list of such arrays
:type data: ndarray or list of ndarray
:param iter: Iteration number
:type iter: int
:param labels: label (text) per point in the scatter (in the same order)
:type labels: list of str
:param mode: (type str) 'lines'/'markers'/'lines+markers'
:param color: list of RGBA colors [(217, 217, 217, 0.14),]
:param marker_size: marker size in px
:param line_width: line width in px
:param xtitle: optional x-axis title
:param ytitle: optional y-axis title
:param ztitle: optional z-axis title
:param comment: comment underneath the title
"""
data_series = data if isinstance(data, list) else [data]
def get_labels(i):
if labels and isinstance(labels, list):
try:
item = labels[i]
except IndexError:
item = labels[-1]
if isinstance(item, list):
return item
return labels
plotly_obj = plotly_scatter3d_layout_dict(
title=title,
xaxis_title=xtitle,
yaxis_title=ytitle,
zaxis_title=ztitle,
comment=comment,
)
for i, values in enumerate(data_series):
plotly_obj = create_3d_scatter_series(
np_row_wise=values,
title=title,
series_name=series[i] if isinstance(series, list) else None,
labels=get_labels(i),
plotly_obj=plotly_obj,
mode=mode,
line_width=line_width,
marker_size=marker_size,
color=color,
fill_axis=fill,
)
return self.report_plot(
title=self._normalize_name(title),
series=self._normalize_name(series) if not isinstance(series, list) else None,
plot=plotly_obj,
iter=iter,
)
def report_value_matrix(self, title, series, data, iter, xlabels=None, ylabels=None, comment=None):
"""
Report a heat-map matrix
:param title: Title (AKA metric)
:type title: str
:param series: Series (AKA variant)
:type series: str
:param data: A heat-map matrix (example: confusion matrix)
:type data: ndarray
:param iter: Iteration number
:type iter: int
:param xlabels: optional label per column of the matrix
:param ylabels: optional label per row of the matrix
:param comment: comment underneath the title
"""
plotly_dict = create_value_matrix(
np_value_matrix=data,
title=title,
xlabels=xlabels,
ylabels=ylabels,
series=series,
comment=comment,
)
return self.report_plot(
title=self._normalize_name(title),
series=self._normalize_name(series),
plot=plotly_dict,
iter=iter,
)
def report_value_surface(self, title, series, data, iter, xlabels=None, ylabels=None,
xtitle=None, ytitle=None, ztitle=None, camera=None, comment=None):
"""
Report a 3d surface (same data as heat-map matrix, only presented differently)
:param title: Title (AKA metric)
:type title: str
:param series: Series (AKA variant)
:type series: str
:param data: A heat-map matrix (example: confusion matrix)
:type data: ndarray
:param iter: Iteration number
:type iter: int
:param xlabels: optional label per column of the matrix
:param ylabels: optional label per row of the matrix
:param xtitle: optional x-axis title
:param ytitle: optional y-axis title
:param ztitle: optional z-axis title
:param camera: X,Y,Z camera position. Default: (1, 1, 1)
:param comment: comment underneath the title
"""
plotly_dict = create_3d_surface(
np_value_matrix=data,
title=title + '/' + series,
xlabels=xlabels,
ylabels=ylabels,
series=series,
xtitle=xtitle,
ytitle=ytitle,
ztitle=ztitle,
camera=camera,
comment=comment,
)
return self.report_plot(
title=self._normalize_name(title),
series=self._normalize_name(series),
plot=plotly_dict,
iter=iter,
)
@classmethod
def _normalize_name(cls, name):
if not name:
return name
return name.replace('$', '/').replace('.', '/')
def __exit__(self, exc_type, exc_val, exc_tb):
# don't flush in case an exception was raised
if not exc_type:
self.flush()
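The caching behaviour the class docstring describes (buffer events until `flush_threshold`, flush on explicit request or on a clean context exit) can be shown with a minimal stand-in. `CachingReporter` and `write_fn` are illustrative names; the real class delegates the write to `Metrics.write_events`:

```python
class CachingReporter(object):
    """Sketch of the Reporter's cache-and-flush behaviour."""

    def __init__(self, write_fn, flush_threshold=10):
        self._write_fn = write_fn          # stands in for Metrics.write_events
        self._flush_threshold = flush_threshold
        self._events = []

    def report(self, ev):
        self._events.append(ev)
        if len(self._events) >= self._flush_threshold:
            self.flush()

    def flush(self):
        if self._events:
            self._write_fn(list(self._events))
            self._events = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # mirror Reporter.__exit__: don't flush if an exception was raised
        if not exc_type:
            self.flush()

batches = []
with CachingReporter(batches.append, flush_threshold=3) as reporter:
    for i in range(4):
        reporter.report(('loss', i))
```

With a threshold of 3, the first three events go out as one batch mid-loop and the fourth is flushed by the context exit, so `batches` ends up holding two batches.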


@@ -0,0 +1,408 @@
from collections import namedtuple
from functools import partial
import six
from pathlib2 import Path
from ..backend_api.services import models
from .base import IdObjectBase
from .util import make_message
from ..storage import StorageHelper
from ..utilities.async_manager import AsyncManagerMixin
ModelPackage = namedtuple('ModelPackage', 'weights design')
class ModelDoesNotExistError(Exception):
pass
class _StorageUriMixin(object):
@property
def upload_storage_uri(self):
""" A URI into which models are uploaded """
return self._upload_storage_uri
@upload_storage_uri.setter
def upload_storage_uri(self, value):
self._upload_storage_uri = value.rstrip('/') if value else None
class DummyModel(models.Model, _StorageUriMixin):
def __init__(self, upload_storage_uri=None, *args, **kwargs):
super(DummyModel, self).__init__(*args, **kwargs)
self.upload_storage_uri = upload_storage_uri
def update(self, **kwargs):
for k, v in kwargs.items():
setattr(self, k, v)
class Model(IdObjectBase, AsyncManagerMixin, _StorageUriMixin):
""" Manager for backend model objects """
_EMPTY_MODEL_ID = 'empty'
@property
def model_id(self):
return self.id
@property
def storage(self):
return StorageHelper.get(self.upload_storage_uri)
def __init__(self, upload_storage_uri, cache_dir, model_id=None,
upload_storage_suffix='models', session=None, log=None):
super(Model, self).__init__(id=model_id, session=session, log=log)
self._upload_storage_suffix = upload_storage_suffix
if model_id == self._EMPTY_MODEL_ID:
# Set an empty data object
self._data = models.Model()
else:
self._data = None
self._cache_dir = cache_dir
self.upload_storage_uri = upload_storage_uri
def publish(self):
self.send(models.SetReadyRequest(model=self.id, publish_task=False))
self.reload()
def _reload(self):
""" Reload the model object """
if self.id == self._EMPTY_MODEL_ID:
return
res = self.send(models.GetByIdRequest(model=self.id))
return res.response.model
def _upload_model(self, model_file, async_enable=False, target_filename=None, cb=None):
if not self.upload_storage_uri:
raise ValueError('Model has no storage URI defined (nowhere to upload to)')
helper = self.storage
target_filename = target_filename or Path(model_file).name
dest_path = '/'.join((self.upload_storage_uri, self._upload_storage_suffix or '.', target_filename))
result = helper.upload(
src_path=model_file,
dest_path=dest_path,
async_enable=async_enable,
cb=partial(self._upload_callback, cb=cb),
)
if async_enable:
def msg(num_results):
self.log.info("Waiting for previous model to upload (%d pending, %s)" % (num_results, dest_path))
self._add_async_result(result, wait_on_max_results=2, wait_cb=msg)
return dest_path
def _upload_callback(self, res, cb=None):
if res is None:
self.log.debug('Starting model upload')
elif res is False:
self.log.info('Failed model upload')
else:
self.log.info('Completed model upload to %s' % res)
if cb:
cb(res)
@staticmethod
def _wrap_design(design):
"""
Wrap design text with a dictionary.
In the backend, the design is a dictionary with a 'design' key in it.
For the client, it is a text. This function wraps a design string with
the proper dictionary.
:param design: If it is a dictionary, it must have a 'design' key in it.
In that case, return design as-is.
If it is a string, return the dictionary {'design': design}.
If it is None (or any False value), return the dictionary {'design': ''}
:return: A proper design dictionary according to design parameter.
"""
if isinstance(design, dict):
if 'design' not in design:
raise ValueError('design dictionary must have \'design\' key in it')
return design
return {'design': design if design else ''}
@staticmethod
def _unwrap_design(design):
"""
Unwrap design text from a dictionary.
In the backend, the design is a dictionary with a 'design' key in it.
For the client, it is a text. This function unwraps a design string from
the dictionary.
:param design: If it is a dictionary with a 'design' key in it, return
design['design'].
If it is a dictionary without 'design' key, return the first value
in its values list.
If it is an empty dictionary, None, or any other False value,
return an empty string.
If it is a string, return design as-is.
:return: The design string according to design parameter.
"""
if not design:
return ''
if isinstance(design, six.string_types):
return design
if isinstance(design, dict):
if 'design' in design:
return design['design']
return list(design.values())[0]
raise ValueError('design must be a string or a dictionary with at least one value')
def update(self, model_file=None, design=None, labels=None, name=None, comment=None, tags=None,
task_id=None, project_id=None, parent_id=None, uri=None, framework=None,
upload_storage_uri=None, target_filename=None, iteration=None):
""" Update model weights file and various model properties """
if self.id is None:
if upload_storage_uri:
self.upload_storage_uri = upload_storage_uri
self._create_empty_model(self.upload_storage_uri)
# upload model file if needed and get uri
uri = uri or (self._upload_model(model_file, target_filename=target_filename) if model_file else self.data.uri)
# update fields
design = self._wrap_design(design) if design else self.data.design
name = name or self.data.name
comment = comment or self.data.comment
tags = tags or self.data.tags
labels = labels or self.data.labels
task = task_id or self.data.task
project = project_id or self.data.project
parent = parent_id or self.data.parent
self.send(models.EditRequest(
model=self.id,
uri=uri,
name=name,
comment=comment,
tags=tags,
labels=labels,
design=design,
task=task,
project=project,
parent=parent,
framework=framework or self.data.framework,
iteration=iteration,
))
self.reload()
def update_and_upload(self, model_file, design=None, labels=None, name=None, comment=None,
tags=None, task_id=None, project_id=None, parent_id=None, framework=None, async_enable=False,
target_filename=None, cb=None, iteration=None):
""" Update the given model for a given task ID """
if async_enable:
def callback(uploaded_uri):
if uploaded_uri is None:
return
# If not successful, mark model as failed_uploading
if uploaded_uri is False:
uploaded_uri = '{}/failed_uploading'.format(self._upload_storage_uri)
self.update(
uri=uploaded_uri,
task_id=task_id,
name=name,
comment=comment,
tags=tags,
design=design,
labels=labels,
project_id=project_id,
parent_id=parent_id,
framework=framework,
iteration=iteration,
)
if cb:
cb(model_file)
uri = self._upload_model(model_file, async_enable=async_enable, target_filename=target_filename, cb=callback)
return uri
else:
uri = self._upload_model(model_file, async_enable=async_enable, target_filename=target_filename)
self.update(
uri=uri,
task_id=task_id,
name=name,
comment=comment,
tags=tags,
design=design,
labels=labels,
project_id=project_id,
parent_id=parent_id,
framework=framework,
)
return uri
def _complete_update_for_task(self, uri, task_id=None, name=None, comment=None, tags=None, override_model_id=None,
cb=None):
if self._data:
name = name or self.data.name
comment = comment or self.data.comment
tags = tags or self.data.tags
uri = (uri or self.data.uri) if not override_model_id else None
res = self.send(
models.UpdateForTaskRequest(task=task_id, uri=uri, name=name, comment=comment, tags=tags,
override_model_id=override_model_id))
if self.id is None:
# update the model ID; in case it was just created, this will trigger a reload of the model object
self.id = res.response.id
else:
self.reload()
try:
if cb:
cb(uri)
except Exception as ex:
self.log.warning('Failed calling callback on complete_update_for_task: %s' % str(ex))
def update_for_task_and_upload(
self, model_file, task_id, name=None, comment=None, tags=None, override_model_id=None, target_filename=None,
async_enable=False, cb=None, iteration=None):
""" Update the given model for a given task ID """
if async_enable:
callback = partial(
self._complete_update_for_task, task_id=task_id, name=name, comment=comment, tags=tags,
override_model_id=override_model_id, cb=cb)
uri = self._upload_model(model_file, target_filename=target_filename, async_enable=async_enable, cb=callback)
return uri
else:
uri = self._upload_model(model_file, target_filename=target_filename, async_enable=async_enable)
self._complete_update_for_task(uri, task_id, name, comment, tags, override_model_id)
_ = self.send(models.UpdateForTaskRequest(task=task_id, uri=uri, name=name, comment=comment, tags=tags,
override_model_id=override_model_id, iteration=iteration))
return uri
def update_for_task(self, task_id, uri=None, name=None, comment=None, tags=None, override_model_id=None):
self._complete_update_for_task(uri, task_id, name, comment, tags, override_model_id)
@property
def model_design(self):
""" Get the model design. For now, this is stored as a single key in the design dict. """
try:
return self._unwrap_design(self.data.design)
except ValueError:
# no design is yet specified
return None
@property
def labels(self):
try:
return self.data.labels
except ValueError:
# no labels is yet specified
return None
@property
def name(self):
try:
return self.data.name
except ValueError:
# no name is yet specified
return None
@property
def comment(self):
try:
return self.data.comment
except ValueError:
# no comment is yet specified
return None
@property
def tags(self):
return self.data.tags
@property
def locked(self):
if self.id is None:
return False
return bool(self.data.ready)
def download_model_weights(self):
""" Download the model weights into a local file in our cache """
uri = self.data.uri
helper = StorageHelper.get(uri, logger=self._log, verbose=True)
return helper.download_to_file(uri, force_cache=True)
@property
def cache_dir(self):
return self._cache_dir
def save_model_design_file(self):
""" Download model description file into a local file in our cache_dir """
design = self.model_design
filename = self.data.name + '.txt'
p = Path(self.cache_dir) / filename
# we always write the original model design to file, to prevent any mishaps
# if p.is_file():
# return str(p)
p.parent.mkdir(parents=True, exist_ok=True)
p.write_text(six.text_type(design))
return str(p)
def get_model_package(self):
""" Get a named tuple containing the model's weights and design """
return ModelPackage(weights=self.download_model_weights(), design=self.save_model_design_file())
def get_model_design(self):
""" Get model description (text) """
return self.model_design
@classmethod
def get_all(cls, session, log=None, **kwargs):
req = models.GetAllRequest(**kwargs)
res = cls._send(session=session, req=req, log=log)
return res
def clone(self, name, comment=None, child=True, tags=None, task=None, ready=True):
"""
Clone this model into a new model.
:param name: Name for the new model
:param comment: Optional comment for the new model
:param child: Should the new model be a child of this model? (default True)
:return: The new model's ID
"""
data = self.data
assert isinstance(data, models.Model)
parent = self.id if child else None
req = models.CreateRequest(
uri=data.uri,
name=name,
labels=data.labels,
comment=comment or data.comment,
tags=tags or data.tags,
framework=data.framework,
design=data.design,
ready=ready,
project=data.project,
parent=parent,
task=task,
)
res = self.send(req)
return res.response.id
def _create_empty_model(self, upload_storage_uri=None):
upload_storage_uri = upload_storage_uri or self.upload_storage_uri
name = make_message('Anonymous model %(time)s')
uri = '{}/uploading_file'.format(upload_storage_uri or 'file://')
req = models.CreateRequest(uri=uri, name=name, labels={})
res = self.send(req)
if not res:
return False
self.id = res.response.id
return True


@@ -0,0 +1,28 @@
from abc import ABCMeta, abstractmethod
import six
class SendError(Exception):
""" A session send() error class """
@property
def result(self):
return self._result
def __init__(self, result, *args, **kwargs):
super(SendError, self).__init__(*args, **kwargs)
self._result = result
@six.add_metaclass(ABCMeta)
class SessionInterface(object):
""" Session wrapper interface providing a session property and a send convenience method """
@property
@abstractmethod
def session(self):
pass
@abstractmethod
def send(self, req, ignore_errors=False, raise_on_errors=True, async_enable=False):
pass


@@ -0,0 +1,43 @@
from abc import abstractproperty
from ..backend_config.bucket_config import S3BucketConfig
from ..storage import StorageHelper
class SetupUploadMixin(object):
log = abstractproperty()
storage_uri = abstractproperty()
def setup_upload(
self, bucket_name, host=None, access_key=None, secret_key=None, region=None, multipart=True, https=True):
"""
Setup upload options (currently only S3 is supported)
:param bucket_name: AWS bucket name
:type bucket_name: str
        :param host: Hostname. Only required when using a non-AWS S3 solution such as a local Minio server
        :type host: str
        :param access_key: AWS access key. If not provided, we'll attempt to obtain the key from the
            configuration file (bucket-specific, then global)
        :type access_key: str
        :param secret_key: AWS secret key. If not provided, we'll attempt to obtain the secret from the
            configuration file (bucket-specific, then global)
        :type secret_key: str
:param multipart: Server supports multipart. Only required when using a Non-AWS S3 solution that doesn't support
multipart.
:type multipart: bool
:param https: Server supports HTTPS. Only required when using a Non-AWS S3 solution that only supports HTTPS.
:type https: bool
:param region: Bucket region. Required if the bucket doesn't reside in the default region (us-east-1)
:type region: str
"""
self._bucket_config = S3BucketConfig(
bucket=bucket_name,
host=host,
key=access_key,
secret=secret_key,
multipart=multipart,
secure=https,
region=region
)
self.storage_uri = ('s3://%(host)s/%(bucket_name)s' if host else 's3://%(bucket_name)s') % locals()
StorageHelper.add_configuration(self._bucket_config, log=self.log)
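The storage URI built by `setup_upload` follows a simple rule that can be illustrated with a small standalone sketch (`build_s3_uri` is a hypothetical helper name, not part of the codebase):

```python
def build_s3_uri(bucket_name, host=None):
    # Mirrors the URI rule used by setup_upload: a custom host (e.g. a local
    # Minio server) is prepended to the bucket path when provided.
    return ('s3://%(host)s/%(bucket_name)s' if host else 's3://%(bucket_name)s') % locals()
```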


@@ -0,0 +1 @@
from .task import Task, TaskEntry, TaskStatusEnum


@@ -0,0 +1,85 @@
import itertools
import operator
from abc import abstractproperty
import six
from pathlib2 import Path
class AccessMixin(object):
""" A mixin providing task fields access functionality """
session = abstractproperty()
data = abstractproperty()
cache_dir = abstractproperty()
log = abstractproperty()
def _get_task_property(self, prop_path, raise_on_error=True, log_on_error=True, default=None):
obj = self.data
props = prop_path.split('.')
for i in range(len(props)):
obj = getattr(obj, props[i], None)
if obj is None:
msg = 'Task has no %s section defined' % '.'.join(props[:i + 1])
if log_on_error:
self.log.info(msg)
if raise_on_error:
raise ValueError(msg)
return default
return obj
def _set_task_property(self, prop_path, value, raise_on_error=True, log_on_error=True):
props = prop_path.split('.')
if len(props) > 1:
obj = self._get_task_property('.'.join(props[:-1]), raise_on_error=raise_on_error,
log_on_error=log_on_error)
else:
obj = self.data
setattr(obj, props[-1], value)
def save_exec_model_design_file(self, filename='model_design.txt', use_cache=False):
""" Save execution model design to file """
p = Path(self.cache_dir) / filename
if use_cache and p.is_file():
return str(p)
desc = self._get_task_property('execution.model_desc')
try:
design = six.next(six.itervalues(desc))
except StopIteration:
design = None
if not design:
raise ValueError('Task has no design in execution.model_desc')
p.parent.mkdir(parents=True, exist_ok=True)
p.write_text('%s' % design)
return str(p)
def get_parameters(self):
return self._get_task_property('execution.parameters')
def get_label_num_description(self):
""" Get a dict of label number to a string representing all labels associated with this number on the
model labels
"""
model_labels = self._get_task_property('execution.model_labels')
label_getter = operator.itemgetter(0)
num_getter = operator.itemgetter(1)
        # materialize each group into a list - groupby iterators are single-use and
        # would otherwise be exhausted before they are read again below
        groups = [(key, list(group)) for key, group in
                  itertools.groupby(sorted(model_labels.items(), key=num_getter), key=num_getter)]
        if any(len(set(label_getter(x) for x in group)) > 1 for _, group in groups):
            raise ValueError("Multiple labels mapped to same model index not supported")
return {key: ','.join(label_getter(x) for x in group) for key, group in groups}
def get_output_destination(self, extra_path=None, **kwargs):
""" Get the task's output destination, with an optional suffix """
return self._get_task_property('output.destination', **kwargs)
def get_num_of_classes(self):
""" number of classes based on the task's labels """
model_labels = self.data.execution.model_labels
expected_num_of_classes = 0
for labels, index in model_labels.items():
expected_num_of_classes += 1 if int(index) > 0 else 0
num_of_classes = int(max(model_labels.values()))
if num_of_classes != expected_num_of_classes:
            self.log.warning('The highest label index is %d, while there are %d non-bg labels' %
                             (num_of_classes, expected_num_of_classes))
return num_of_classes + 1 # +1 is meant for bg!
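The index-to-labels grouping above can be distilled into a standalone sketch (`label_num_description` is a hypothetical name; unlike the method, this sketch joins multiple names mapped to one index instead of raising):

```python
from itertools import groupby

def label_num_description(model_labels):
    # Map each numeric model index to a comma-joined string of the label
    # names assigned to it. Groups are materialized into lists because
    # groupby iterators are single-use.
    items = sorted(model_labels.items(), key=lambda kv: kv[1])
    groups = [(num, [name for name, _ in group])
              for num, group in groupby(items, key=lambda kv: kv[1])]
    return {num: ','.join(names) for num, names in groups}
```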


@@ -0,0 +1,314 @@
import yaml
from six import PY2
from argparse import _StoreAction, ArgumentError, _StoreConstAction, _SubParsersAction, SUPPRESS
from copy import copy
from ...utilities.args import call_original_argparser
class _Arguments(object):
_prefix_sep = '/'
# TODO: separate dict and argparse after we add UI support
_prefix_dict = 'dict' + _prefix_sep
_prefix_args = 'argparse' + _prefix_sep
_prefix_tf_defines = 'TF_DEFINE' + _prefix_sep
class _ProxyDictWrite(dict):
""" Dictionary wrapper that updates an arguments instance on any item set in the dictionary """
def __init__(self, arguments, *args, **kwargs):
super(_Arguments._ProxyDictWrite, self).__init__(*args, **kwargs)
self._arguments = arguments
def __setitem__(self, key, value):
super(_Arguments._ProxyDictWrite, self).__setitem__(key, value)
if self._arguments:
self._arguments.copy_from_dict(self)
class _ProxyDictReadOnly(dict):
""" Dictionary wrapper that prevents modifications to the dictionary """
def __init__(self, *args, **kwargs):
super(_Arguments._ProxyDictReadOnly, self).__init__(*args, **kwargs)
        def __setitem__(self, key, value):
            # read-only: silently ignore any attempt to modify the dictionary
            pass
def __init__(self, task):
super(_Arguments, self).__init__()
self._task = task
def set_defaults(self, *dicts, **kwargs):
self._task.set_parameters(*dicts, **kwargs)
def add_argument(self, option_strings, type=None, default=None, help=None):
if not option_strings:
raise Exception('Expected at least one argument name (option string)')
name = option_strings[0].strip('- \t') if isinstance(option_strings, list) else option_strings.strip('- \t')
# TODO: add argparse prefix
# name = self._prefix_args + name
self._task.set_parameter(name=name, value=default, description=help)
def connect(self, parser):
self._task.connect_argparse(parser)
@classmethod
def _add_to_defaults(cls, a_parser, defaults, a_args=None, a_namespace=None, a_parsed_args=None):
actions = [
a for a in a_parser._actions
if isinstance(a, _StoreAction) or isinstance(a, _StoreConstAction)
]
args_dict = {}
try:
if isinstance(a_parsed_args, dict):
args_dict = a_parsed_args
else:
if a_parsed_args:
args_dict = a_parsed_args.__dict__
else:
args_dict = call_original_argparser(a_parser, args=a_args, namespace=a_namespace).__dict__
defaults_ = {
a.dest: args_dict.get(a.dest) if (args_dict.get(a.dest) is not None) else ''
for a in actions
}
except Exception:
# don't crash us if we failed parsing the inputs
defaults_ = {
a.dest: a.default if a.default is not None else ''
for a in actions
}
full_args_dict = copy(defaults)
full_args_dict.update(args_dict)
defaults.update(defaults_)
# deal with sub parsers
sub_parsers = [
a for a in a_parser._actions
if isinstance(a, _SubParsersAction)
]
for sub_parser in sub_parsers:
if sub_parser.dest and sub_parser.dest != SUPPRESS:
defaults[sub_parser.dest] = full_args_dict.get(sub_parser.dest)
for choice in sub_parser.choices.values():
# recursively parse
defaults = cls._add_to_defaults(
a_parser=choice,
defaults=defaults,
a_parsed_args=a_parsed_args or full_args_dict
)
return defaults
def copy_defaults_from_argparse(self, parser, args=None, namespace=None, parsed_args=None):
task_defaults = {}
self._add_to_defaults(parser, task_defaults, args, namespace, parsed_args)
# Make sure we didn't miss anything
if parsed_args:
for k, v in parsed_args.__dict__.items():
if k not in task_defaults:
                    if v is None:
task_defaults[k] = ''
elif type(v) in (str, int, float, bool, list):
task_defaults[k] = v
# Verify arguments
        # iterate over a copy - keys may be deleted from task_defaults below
        for k, v in list(task_defaults.items()):
try:
if type(v) is list:
task_defaults[k] = '[' + ', '.join(map("{0}".format, v)) + ']'
elif type(v) not in (str, int, float, bool):
task_defaults[k] = str(v)
except Exception:
del task_defaults[k]
# Add prefix, TODO: add argparse prefix
# task_defaults = dict([(self._prefix_args + k, v) for k, v in task_defaults.items()])
task_defaults = dict([(k, v) for k, v in task_defaults.items()])
# Store to task
self._task.update_parameters(task_defaults)
@classmethod
def _find_parser_action(cls, a_parser, name):
# find by name
_actions = [(a_parser, a) for a in a_parser._actions if a.dest == name]
if _actions:
return _actions
# iterate over subparsers
_actions = []
sub_parsers = [a for a in a_parser._actions if isinstance(a, _SubParsersAction)]
for sub_parser in sub_parsers:
for choice in sub_parser.choices.values():
# recursively parse
_action = cls._find_parser_action(choice, name)
if _action:
_actions.extend(_action)
return _actions
def copy_to_parser(self, parser, parsed_args):
# todo: change to argparse prefix only
# task_arguments = dict([(k[len(self._prefix_args):], v) for k, v in self._task.get_parameters().items()
# if k.startswith(self._prefix_args)])
task_arguments = dict([(k, v) for k, v in self._task.get_parameters().items()
if not k.startswith(self._prefix_tf_defines)])
for k, v in task_arguments.items():
# if we have a StoreTrueAction and the value is either False or Empty or 0 change the default to False
# with the rest we have to make sure the type is correct
matched_actions = self._find_parser_action(parser, k)
for parent_parser, current_action in matched_actions:
if current_action and isinstance(current_action, _StoreConstAction):
# make the default value boolean
# first check if False value
const_value = current_action.const if current_action.const is not None else (
current_action.default if current_action.default is not None else True)
const_type = type(const_value)
strip_v = str(v).lower().strip()
if const_type == bool:
if strip_v == 'false' or not strip_v:
const_value = False
elif strip_v == 'true':
const_value = True
else:
# first try to cast to integer
try:
const_value = int(strip_v)
except ValueError:
pass
else:
const_value = strip_v
# then cast to const type (might be boolean)
try:
const_value = const_type(const_value)
current_action.const = const_value
except ValueError:
pass
task_arguments[k] = const_value
elif current_action and current_action.nargs == '+':
try:
                        v = yaml.safe_load(v.strip())
if current_action.type:
v = [current_action.type(a) for a in v]
elif current_action.default:
v_type = type(current_action.default[0])
v = [v_type(a) for a in v]
task_arguments[k] = v
except Exception:
pass
elif current_action and not current_action.type:
# cast manually if there is no type
var_type = type(current_action.default)
# if we have an int, we should cast to float, because it is more generic
if var_type == int:
var_type = float
elif var_type == type(None):
var_type = str
# now we should try and cast the value if we can
try:
v = var_type(v)
task_arguments[k] = v
except Exception:
pass
# add as default
try:
if current_action and isinstance(current_action, _SubParsersAction):
current_action.default = v
current_action.required = False
elif current_action and isinstance(current_action, _StoreAction):
current_action.default = v
current_action.required = False
# python2 doesn't support defaults for positional arguments, unless used with nargs=?
if PY2 and not current_action.nargs:
current_action.nargs = '?'
else:
parent_parser.add_argument(
'--%s' % k,
default=v,
type=type(v),
required=False,
help='Task parameter %s (default %s)' % (k, v),
)
except ArgumentError:
pass
except Exception:
pass
# if we already have an instance of parsed args, we should update its values
if parsed_args:
for k, v in task_arguments.items():
setattr(parsed_args, k, v)
parser.set_defaults(**task_arguments)
def copy_from_dict(self, dictionary, prefix=None):
# TODO: add dict prefix
prefix = prefix or '' # self._prefix_dict
if prefix:
prefix_dictionary = dict([(prefix + k, v) for k, v in dictionary.items()])
cur_params = dict([(k, v) for k, v in self._task.get_parameters().items() if not k.startswith(prefix)])
cur_params.update(prefix_dictionary)
self._task.set_parameters(cur_params)
else:
self._task.update_parameters(dictionary)
if not isinstance(dictionary, self._ProxyDictWrite):
return self._ProxyDictWrite(self, **dictionary)
return dictionary
def copy_to_dict(self, dictionary, prefix=None):
# iterate over keys and merge values according to parameter type in dictionary
# TODO: add dict prefix
prefix = prefix or '' # self._prefix_dict
if prefix:
parameters = dict([(k[len(prefix):], v) for k, v in self._task.get_parameters().items()
if k.startswith(prefix)])
else:
parameters = dict([(k, v) for k, v in self._task.get_parameters().items()
if not k.startswith(self._prefix_tf_defines)])
for k, v in dictionary.items():
param = parameters.get(k, None)
if param is None:
continue
v_type = type(v)
# assume more general purpose type int -> float
if v_type == int:
v_type = float
elif v_type == bool:
# cast based on string or int
try:
param = bool(float(param))
except ValueError:
try:
param = str(param).lower().strip() == 'true'
except ValueError:
self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' %
(str(k), str(param), str(k), str(v)))
continue
elif v_type == list:
try:
p = str(param).strip()
                    param = yaml.safe_load(p)
except Exception:
self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' %
(str(k), str(param), str(k), str(v)))
continue
elif v_type == dict:
            try:
                p = str(param).strip()
                param = yaml.safe_load(p)
            except Exception:
                self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' %
                                       (str(k), str(param), str(k), str(v)))
                continue
elif v_type == type(None):
v_type = str
try:
dictionary[k] = v_type(param)
except ValueError:
self._task.log.warning('Failed parsing task parameter %s=%s keeping default %s=%s' %
(str(k), str(param), str(k), str(v)))
continue
# add missing parameters to dictionary
for k, v in parameters.items():
if k not in dictionary:
dictionary[k] = v
if not isinstance(dictionary, self._ProxyDictReadOnly):
return self._ProxyDictReadOnly(**dictionary)
return dictionary
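The type-coercion rules in `copy_to_dict` can be distilled into a small standalone sketch (`coerce_param` is a hypothetical name; the list/dict cases handled via yaml above are omitted):

```python
def coerce_param(default, stored):
    # Stored parameters come back as strings and are cast to the default's
    # type, widening int to float since the stored value may be fractional.
    v_type = type(default)
    if v_type is int:
        v_type = float
    elif v_type is bool:
        try:
            return bool(float(stored))
        except ValueError:
            return str(stored).lower().strip() == 'true'
    elif default is None:
        v_type = str
    return v_type(stored)
```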


@@ -0,0 +1,48 @@
from ....config import config
from ....backend_interface import Task, TaskStatusEnum
class TaskStopReason(object):
stopped = "stopped"
reset = "reset"
status_changed = "status_changed"
class TaskStopSignal(object):
enabled = bool(config.get('development.support_stopping', False))
_number_of_consecutive_reset_tests = 4
_unexpected_statuses = (
TaskStatusEnum.closed,
TaskStatusEnum.stopped,
TaskStatusEnum.failed,
TaskStatusEnum.published,
)
def __init__(self, task):
assert isinstance(task, Task)
self.task = task
self._task_reset_state_counter = 0
def test(self):
status = self.task.status
        message = self.task.data.status_message or ''
if status == TaskStatusEnum.in_progress and "stopping" in message:
return TaskStopReason.stopped
if status in self._unexpected_statuses and "worker" not in message:
return TaskStopReason.status_changed
if status == TaskStatusEnum.created:
self._task_reset_state_counter += 1
if self._task_reset_state_counter >= self._number_of_consecutive_reset_tests:
return TaskStopReason.reset
            self.task.get_logger().warning(
                "Task {} was reset! If this state persists, the task will be stopped.".format(self.task.id),
            )
else:
self._task_reset_state_counter = 0
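The consecutive-reset test above can be sketched in isolation (`ResetCounter` is a hypothetical name): only N consecutive "created" observations trigger the reset signal, so a transient status glitch does not stop the task.

```python
class ResetCounter(object):
    # Count consecutive "created" observations; any other status resets the count.
    def __init__(self, threshold=4):
        self.threshold = threshold
        self.count = 0

    def observe(self, status):
        if status == 'created':
            self.count += 1
            return self.count >= self.threshold
        self.count = 0
        return False
```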


@@ -0,0 +1,26 @@
from socket import gethostname
import attr
from ....config import config, running_remotely, dev_worker_name
@attr.s
class DevWorker(object):
prefix = attr.ib(type=str, default="MANUAL:")
report_period = float(config.get('development.worker.report_period_sec', 30.))
report_stdout = bool(config.get('development.worker.log_stdout', True))
@classmethod
def is_enabled(cls, model_updated=False):
return False
def status_report(self, timestamp=None):
return True
def register(self):
return True
def unregister(self):
return True


@@ -0,0 +1,110 @@
import time
from logging import LogRecord, getLogger, basicConfig
from logging.handlers import BufferingHandler
from ...backend_api.services import events
from ...config import config
buffer_capacity = config.get('log.task_log_buffer_capacity', 100)
class TaskHandler(BufferingHandler):
__flush_max_history_seconds = 30.
__once = False
@property
def task_id(self):
return self._task_id
@task_id.setter
def task_id(self, value):
self._task_id = value
def __init__(self, session, task_id, capacity=buffer_capacity):
super(TaskHandler, self).__init__(capacity)
self.task_id = task_id
self.session = session
self.last_timestamp = 0
self.counter = 1
self._last_event = None
def shouldFlush(self, record):
"""
Should the handler flush its buffer?
Returns true if the buffer is up to capacity. This method can be
overridden to implement custom flushing strategies.
"""
# Notice! protect against infinite loops, i.e. flush while sending previous records
# if self.lock._is_owned():
# return False
# if we need to add handlers to the base_logger,
# it will not automatically create stream one when first used, so we must manually configure it.
if not TaskHandler.__once:
base_logger = getLogger()
if len(base_logger.handlers) == 1 and isinstance(base_logger.handlers[0], TaskHandler):
if record.name != 'console' and not record.name.startswith('trains.'):
base_logger.removeHandler(self)
basicConfig()
base_logger.addHandler(self)
TaskHandler.__once = True
else:
TaskHandler.__once = True
# if we passed the max buffer
if len(self.buffer) >= self.capacity:
return True
# if the first entry in the log was too long ago.
if len(self.buffer) and (time.time() - self.buffer[0].created) > self.__flush_max_history_seconds:
return True
return False
def _record_to_event(self, record):
# type: (LogRecord) -> events.TaskLogEvent
timestamp = int(record.created * 1000)
if timestamp == self.last_timestamp:
timestamp += self.counter
self.counter += 1
else:
self.last_timestamp = timestamp
self.counter = 1
# unite all records in a single second
if self._last_event and timestamp - self._last_event.timestamp < 1000 and \
record.levelname.lower() == str(self._last_event.level):
# ignore backspaces (they are often used)
self._last_event.msg += '\n' + record.getMessage().replace('\x08', '')
return None
self._last_event = events.TaskLogEvent(
task=self.task_id,
timestamp=timestamp,
level=record.levelname.lower(),
worker=self.session.worker,
msg=record.getMessage().replace('\x08', '') # ignore backspaces (they are often used)
)
return self._last_event
def flush(self):
if not self.buffer:
return
self.acquire()
buffer = self.buffer
try:
if not buffer:
return
self.buffer = []
record_events = [self._record_to_event(record) for record in buffer]
self._last_event = None
requests = [events.AddRequest(e) for e in record_events if e]
res = self.session.send(events.AddBatchRequest(requests=requests))
if not res.ok():
print("Failed logging task to backend ({:d} lines, {})".format(len(buffer), str(res.meta)))
except Exception:
print("Failed logging task to backend ({:d} lines)".format(len(buffer)))
finally:
self.release()
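The millisecond de-duplication in `_record_to_event` can be sketched standalone (`unique_timestamps` is a hypothetical name): log records created within the same millisecond receive an incrementing offset so the backend can keep them ordered.

```python
def unique_timestamps(created_times):
    # created_times: record creation times in seconds (floats)
    last, counter, out = 0, 1, []
    for t in created_times:
        ts = int(t * 1000)
        if ts == last:
            ts += counter
            counter += 1
        else:
            last, counter = ts, 1
        out.append(ts)
    return out
```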


@@ -0,0 +1,2 @@
from .scriptinfo import ScriptInfo
from .freeze import pip_freeze


@@ -0,0 +1,248 @@
import abc
import os
from subprocess import call, CalledProcessError
import attr
import six
from pathlib2 import Path
from ....config.defs import (
VCS_REPO_TYPE,
VCS_DIFF,
VCS_STATUS,
VCS_ROOT,
VCS_BRANCH,
VCS_COMMIT_ID,
VCS_REPOSITORY_URL,
)
from ....debugging import get_logger
from .util import get_command_output
_logger = get_logger("Repository Detection")
class DetectionError(Exception):
pass
@attr.s
class Result(object):
"""" Repository information as queried by a detector """
url = attr.ib(default="")
branch = attr.ib(default="")
commit = attr.ib(default="")
root = attr.ib(default="")
status = attr.ib(default="")
diff = attr.ib(default="")
modified = attr.ib(default=False, type=bool, converter=bool)
def is_empty(self):
return not any(attr.asdict(self).values())
@six.add_metaclass(abc.ABCMeta)
class Detector(object):
""" Base class for repository detection """
"""
Commands are represented using the result class, where each attribute contains
the command used to obtain the value of the same attribute in the actual result.
"""
@attr.s
class Commands(object):
"""" Repository information as queried by a detector """
url = attr.ib(default=None, type=list)
branch = attr.ib(default=None, type=list)
commit = attr.ib(default=None, type=list)
root = attr.ib(default=None, type=list)
status = attr.ib(default=None, type=list)
diff = attr.ib(default=None, type=list)
modified = attr.ib(default=None, type=list)
def __init__(self, type_name, name=None):
self.type_name = type_name
self.name = name or type_name
def _get_commands(self):
""" Returns a RepoInfo instance containing a command for each info attribute """
return self.Commands()
def _get_command_output(self, path, name, command):
""" Run a command and return its output """
try:
return get_command_output(command, path)
except (CalledProcessError, UnicodeDecodeError) as ex:
_logger.warning(
"Can't get {} information for {} repo in {}: {}".format(
name, self.type_name, path, str(ex)
)
)
return ""
def _get_info(self, path, include_diff=False):
"""
Get repository information.
:param path: Path to repository
:param include_diff: Whether to include the diff command's output (if available)
        :return: Result instance
"""
path = str(path)
commands = self._get_commands()
if not include_diff:
commands.diff = None
info = Result(
**{
name: self._get_command_output(path, name, command)
for name, command in attr.asdict(commands).items()
if command
}
)
return info
def _post_process_info(self, info):
# check if there are uncommitted changes in the current repository
return info
def get_info(self, path, include_diff=False):
"""
Get repository information.
:param path: Path to repository
:param include_diff: Whether to include the diff command's output (if available)
        :return: Result instance
"""
info = self._get_info(path, include_diff)
return self._post_process_info(info)
def _is_repo_type(self, script_path):
try:
with open(os.devnull, "wb") as devnull:
return (
call(
[self.type_name, "status"],
stderr=devnull,
stdout=devnull,
cwd=str(script_path),
)
== 0
)
except CalledProcessError:
_logger.warning("Can't get {} status".format(self.type_name))
except (OSError, EnvironmentError, IOError):
# File not found or can't be executed
pass
return False
def exists(self, script_path):
"""
Test whether the given script resides in
a repository type represented by this plugin.
"""
return self._is_repo_type(script_path)
class HgDetector(Detector):
def __init__(self):
super(HgDetector, self).__init__("hg")
def _get_commands(self):
return self.Commands(
url=["hg", "paths", "--verbose"],
branch=["hg", "--debug", "id", "-b"],
commit=["hg", "--debug", "id", "-i"],
root=["hg", "root"],
status=["hg", "status"],
diff=["hg", "diff"],
modified=["hg", "status", "-m"],
)
def _post_process_info(self, info):
if info.url:
info.url = info.url.split(" = ")[1]
if info.commit:
info.commit = info.commit.rstrip("+")
return info
class GitDetector(Detector):
def __init__(self):
super(GitDetector, self).__init__("git")
def _get_commands(self):
return self.Commands(
url=["git", "remote", "get-url", "origin"],
branch=["git", "rev-parse", "--abbrev-ref", "--symbolic-full-name", "@{u}"],
commit=["git", "rev-parse", "HEAD"],
root=["git", "rev-parse", "--show-toplevel"],
status=["git", "status", "-s"],
diff=["git", "diff"],
modified=["git", "ls-files", "-m"],
)
def _post_process_info(self, info):
if info.url and not info.url.endswith(".git"):
info.url += ".git"
if (info.branch or "").startswith("origin/"):
info.branch = info.branch[len("origin/") :]
return info
class EnvDetector(Detector):
def __init__(self, type_name):
super(EnvDetector, self).__init__(type_name, "{} environment".format(type_name))
def _is_repo_type(self, script_path):
return VCS_REPO_TYPE.get(default="").lower() == self.type_name and bool(
VCS_REPOSITORY_URL.get()
)
@staticmethod
def _normalize_root(root):
"""
Get the absolute location of the parent folder (where .git resides)
"""
root_parts = list(reversed(Path(root).parts))
cwd_abs = list(reversed(Path.cwd().parts))
count = len(cwd_abs)
for i, p in enumerate(cwd_abs):
if i >= len(root_parts):
break
if p == root_parts[i]:
count -= 1
cwd_abs.reverse()
root_abs_path = Path().joinpath(*cwd_abs[:count])
return str(root_abs_path)
def _get_info(self, _, include_diff=False):
repository_url = VCS_REPOSITORY_URL.get()
if not repository_url:
raise DetectionError("No VCS environment data")
return Result(
url=repository_url,
branch=VCS_BRANCH.get(),
commit=VCS_COMMIT_ID.get(),
root=VCS_ROOT.get(converter=self._normalize_root),
status=VCS_STATUS.get(),
diff=VCS_DIFF.get(),
)
class GitEnvDetector(EnvDetector):
def __init__(self):
super(GitEnvDetector, self).__init__("git")
class HgEnvDetector(EnvDetector):
def __init__(self):
super(HgEnvDetector, self).__init__("hg")
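The git-specific normalization in `GitDetector._post_process_info` can be sketched standalone (`normalize_git_info` is a hypothetical name): append the canonical `.git` suffix and strip the remote prefix from a tracking branch name.

```python
def normalize_git_info(url, branch):
    if url and not url.endswith('.git'):
        url += '.git'
    if (branch or '').startswith('origin/'):
        branch = branch[len('origin/'):]
    return url, branch
```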


@@ -0,0 +1,11 @@
import sys
from .util import get_command_output
def pip_freeze():
try:
return get_command_output([sys.executable, "-m", "pip", "freeze"]).splitlines()
except Exception as ex:
print('Failed calling "pip freeze": {}'.format(str(ex)))
return []


@@ -0,0 +1,162 @@
import os
import sys
import attr
from furl import furl
from pathlib2 import Path
from ....debugging import get_logger
from .detectors import GitEnvDetector, GitDetector, HgEnvDetector, HgDetector, Result as DetectionResult
_logger = get_logger("Repository Detection")
class ScriptInfoError(Exception):
pass
class ScriptInfo(object):
plugins = [GitEnvDetector(), HgEnvDetector(), HgDetector(), GitDetector()]
""" Script info detection plugins, in order of priority """
@classmethod
def _get_jupyter_notebook_filename(cls):
if not sys.argv[0].endswith('/ipykernel_launcher.py') or len(sys.argv) < 3 or not sys.argv[2].endswith('.json'):
return None
# we can safely assume that we can import the notebook package here
try:
from notebook.notebookapp import list_running_servers
import requests
current_kernel = sys.argv[2].split('/')[-1].replace('kernel-', '').replace('.json', '')
server_info = next(list_running_servers())
r = requests.get(
url=server_info['url'] + 'api/sessions',
headers={'Authorization': 'token {}'.format(server_info.get('token', '')), })
r.raise_for_status()
notebooks = r.json()
cur_notebook = None
for n in notebooks:
if n['kernel']['id'] == current_kernel:
cur_notebook = n
break
notebook_path = cur_notebook['notebook']['path']
entry_point_filename = notebook_path.split('/')[-1]
# now we should try to find the actual file
entry_point = (Path.cwd() / entry_point_filename).absolute()
if not entry_point.is_file():
entry_point = (Path.cwd() / notebook_path).absolute()
# now replace the .ipynb with .py
# we assume we will have that file available with the Jupyter notebook plugin
entry_point = entry_point.with_suffix('.py')
return entry_point.as_posix()
except Exception:
return None
@classmethod
def _get_entry_point(cls, repo_root, script_path):
repo_root = Path(repo_root).absolute()
try:
# Use os.path.relpath as it calculates up dir movements (../)
entry_point = os.path.relpath(str(script_path), str(Path.cwd()))
except ValueError:
# Working directory not under repository root
entry_point = script_path.relative_to(repo_root)
return Path(entry_point).as_posix()
@classmethod
def _get_working_dir(cls, repo_root):
repo_root = Path(repo_root).absolute()
try:
return Path.cwd().relative_to(repo_root).as_posix()
except ValueError:
# Working directory not under repository root
return os.path.curdir
@classmethod
def _get_script_info(cls, filepath, check_uncommitted=False, log=None):
jupyter_filepath = cls._get_jupyter_notebook_filename()
if jupyter_filepath:
script_path = Path(os.path.normpath(jupyter_filepath)).absolute()
else:
script_path = Path(os.path.normpath(filepath)).absolute()
if not script_path.is_file():
raise ScriptInfoError(
"Script file [{}] could not be found".format(filepath)
)
script_dir = script_path.parent
def _log(msg, *args, **kwargs):
if not log:
return
log.warning(
"Failed auto-detecting task repository: {}".format(
msg.format(*args, **kwargs)
)
)
plugin = next((p for p in cls.plugins if p.exists(script_dir)), None)
repo_info = DetectionResult()
if not plugin:
_log("expected one of: {}", ", ".join((p.name for p in cls.plugins)))
else:
try:
repo_info = plugin.get_info(str(script_dir), include_diff=check_uncommitted)
except Exception as ex:
_log("no info for {} ({})", script_dir, ex)
else:
if repo_info.is_empty():
_log("no info for {}", script_dir)
repo_root = repo_info.root or script_dir
working_dir = cls._get_working_dir(repo_root)
entry_point = cls._get_entry_point(repo_root, script_path)
script_info = dict(
repository=furl(repo_info.url).remove(username=True, password=True).tostr(),
branch=repo_info.branch,
version_num=repo_info.commit,
entry_point=entry_point,
working_dir=working_dir,
diff=repo_info.diff,
)
messages = []
if repo_info.modified:
messages.append(
"======> WARNING! UNCOMMITTED CHANGES IN REPOSITORY {} <======".format(
script_info.get("repository", "")
)
)
if not any(script_info.values()):
script_info = None
return ScriptInfoResult(script=script_info, warning_messages=messages)
@classmethod
def get(cls, filepath=sys.argv[0], check_uncommitted=False, log=None):
try:
return cls._get_script_info(
filepath=filepath, check_uncommitted=check_uncommitted, log=log
)
except Exception as ex:
if log:
log.warning("Failed auto-detecting task repository: {}".format(ex))
return ScriptInfoResult()
@attr.s
class ScriptInfoResult(object):
script = attr.ib(default=None)
warning_messages = attr.ib(factory=list)
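The relative-path fallback used by `_get_entry_point` above (relative to the working directory first, then to the repository root) can be sketched standalone. `resolve_entry_point` and its pure-path arguments are illustrative names, not part of the module:

```python
from pathlib import PurePosixPath

def resolve_entry_point(script_path, repo_root, cwd):
    # Prefer a path relative to the working directory; when the script is
    # not under cwd (relative_to raises ValueError), fall back to a path
    # relative to the repository root, mirroring _get_entry_point above.
    script, root, cwd = map(PurePosixPath, (script_path, repo_root, cwd))
    try:
        return str(script.relative_to(cwd))
    except ValueError:
        return str(script.relative_to(root))
```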


@@ -0,0 +1,12 @@
import os
from subprocess import check_output
def get_command_output(command, path=None):
"""
Run a command and return its output
:raises CalledProcessError: when command execution fails
:raises UnicodeDecodeError: when output decoding fails
"""
with open(os.devnull, "wb") as devnull:
return check_output(command, cwd=path, stderr=devnull).decode().strip()
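The same pattern works standalone for quick experimentation; the `git` invocation in the trailing comment is only an illustration and assumes a git checkout:

```python
import os
import sys
from subprocess import check_output

def get_command_output(command, path=None):
    # Discard stderr, decode stdout, and strip surrounding whitespace,
    # exactly as the helper above does.
    with open(os.devnull, "wb") as devnull:
        return check_output(command, cwd=path, stderr=devnull).decode().strip()

# Portable example via the current interpreter (no git required):
print(get_command_output([sys.executable, "-c", "print('  ok  ')"]))  # -> ok

# e.g. querying the current commit in a git checkout:
# head = get_command_output(['git', 'rev-parse', 'HEAD'])
```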


@@ -0,0 +1,811 @@
""" Backend task management support """
import collections
import itertools
import logging
from copy import copy
from six.moves.urllib.parse import urlparse, urlunparse
import six
from ...backend_interface.task.development.worker import DevWorker
from ...backend_api import Session
from ...backend_api.services import tasks, models, events, projects
from pathlib2 import Path
from pyhocon import ConfigTree, ConfigFactory
from ..base import IdObjectBase
from ..metrics import Metrics, Reporter
from ..model import Model
from ..setupuploadmixin import SetupUploadMixin
from ..util import make_message, get_or_create_project, get_single_result, \
exact_match_regex
from ...config import get_config_for_bucket, get_remote_task_id, TASK_ID_ENV_VAR, get_log_to_backend, \
running_remotely, get_cache_dir, config_obj
from ...debugging import get_logger
from ...debugging.log import LoggerRoot
from ...storage import StorageHelper
from ...storage.helper import StorageError
from .access import AccessMixin
from .log import TaskHandler
from .repo import ScriptInfo
from ...config import config
TaskStatusEnum = tasks.TaskStatusEnum
class TaskEntry(tasks.CreateRequest):
pass
class Task(IdObjectBase, AccessMixin, SetupUploadMixin):
""" Task manager providing task object access and management. Includes read/write access to task-associated
frames and models.
"""
_anonymous_dataview_id = '__anonymous__'
def __init__(self, session=None, task_id=None, log=None, project_name=None,
task_name=None, task_type=tasks.TaskTypeEnum.training, log_to_backend=True,
raise_on_validation_errors=True, force_create=False):
"""
Create a new task instance.
:param session: Optional API Session instance. If not provided, a default session based on the system's
configuration will be used.
:type session: Session
:param task_id: Optional task ID. If not provided, a new task will be created using the API
and its information reflected in the resulting instance.
:type task_id: string
        :param log: Optional log to be used. If not provided, an internal log shared with all backend objects will be
            used instead.
:type log: logging.Logger
:param project_name: Optional project name, used only if a new task is created. The new task will be associated
with a project by this name. If no such project exists, a new project will be created using the API.
:type project_name: str
        :param task_name: Optional task name, used only if a new task is created.
        :type task_name: str
        :param task_type: Optional task type, used only if a new task is created. Default is a training task.
        :type task_type: str (see tasks.TaskTypeEnum)
:param log_to_backend: If True, all calls to the task's log will be logged to the backend using the API.
This value can be overridden using the environment variable TRAINS_LOG_TASK_TO_BACKEND.
:type log_to_backend: bool
:param force_create: If True a new task will always be created (task_id, if provided, will be ignored)
:type force_create: bool
"""
task_id = self._resolve_task_id(task_id, log=log) if not force_create else None
super(Task, self).__init__(id=task_id, session=session, log=log)
self._storage_uri = None
self._input_model = None
self._output_model = None
self._metrics_manager = None
self._reporter = None
self._curr_label_stats = {}
self._raise_on_validation_errors = raise_on_validation_errors
self._parameters_allowed_types = (
six.string_types + six.integer_types + (six.text_type, float, list, dict, type(None))
)
if not task_id:
# generate a new task
self.id = self._auto_generate(project_name=project_name, task_name=task_name, task_type=task_type)
else:
# this is an existing task, let's try to verify stuff
self._validate()
if running_remotely() or DevWorker.report_stdout:
log_to_backend = False
self._log_to_backend = log_to_backend
self._setup_log(default_log_to_backend=log_to_backend)
def _setup_log(self, default_log_to_backend=None, replace_existing=False):
"""
Setup logging facilities for this task.
:param default_log_to_backend: Should this task log to the backend. If not specified, value for this option
will be obtained from the environment, with this value acting as a default in case configuration for this is
missing.
If the value for this option is false, we won't touch the current logger configuration regarding TaskHandler(s)
:param replace_existing: If True and another task is already logging to the backend, replace the handler with
a handler for this task.
"""
        # Make sure urllib3 is never in debug/info
disable_urllib3_info = config.get('log.disable_urllib3_info', True)
if disable_urllib3_info and logging.getLogger('urllib3').isEnabledFor(logging.INFO):
logging.getLogger('urllib3').setLevel(logging.WARNING)
log_to_backend = get_log_to_backend(default=default_log_to_backend) or self._log_to_backend
if not log_to_backend:
return
# Handle the root logger and our own logger. We use set() to make sure we create no duplicates
# in case these are the same logger...
loggers = {logging.getLogger(), LoggerRoot.get_base_logger()}
# Find all TaskHandler handlers for these loggers
handlers = {logger: h for logger in loggers for h in logger.handlers if isinstance(h, TaskHandler)}
if handlers and not replace_existing:
# Handlers exist and we shouldn't replace them
return
# Remove all handlers, we'll add new ones
for logger, handler in handlers.items():
logger.removeHandler(handler)
# Create a handler that will be used in all loggers. Since our handler is a buffering handler, using more
# than one instance to report to the same task will result in out-of-order log reports (grouped by whichever
# handler instance handled them)
backend_handler = TaskHandler(self.session, self.task_id)
# Add backend handler to both loggers:
        # 1. to the root logger
        # 2. to our own logger as well, since our logger is not propagated to the root logger
        #    (if we propagate our logger, its records will be caught by the root handlers as well,
        #    and we do not want that)
for logger in loggers:
logger.addHandler(backend_handler)
def _validate(self, check_output_dest_credentials=True):
raise_errors = self._raise_on_validation_errors
output_dest = self.get_output_destination(raise_on_error=False, log_on_error=False)
if output_dest and check_output_dest_credentials:
try:
self.log.info('Validating output destination')
conf = get_config_for_bucket(base_url=output_dest)
if not conf:
msg = 'Failed resolving output destination (no credentials found for %s)' % output_dest
self.log.warn(msg)
if raise_errors:
raise Exception(msg)
else:
StorageHelper._test_bucket_config(conf=conf, log=self.log, raise_on_error=raise_errors)
except StorageError:
raise
except Exception as ex:
self.log.error('Failed trying to verify output destination: %s' % ex)
@classmethod
def _resolve_task_id(cls, task_id, log=None):
if not task_id:
task_id = cls.normalize_id(get_remote_task_id())
if task_id:
log = log or get_logger('task')
log.info('Using task ID from env %s=%s' % (TASK_ID_ENV_VAR[0], task_id))
return task_id
def _update_repository(self):
result = ScriptInfo.get(log=self.log)
for msg in result.warning_messages:
self.get_logger().console(msg)
self.data.script = result.script
# Since we might run asynchronously, don't use self.data (lest someone else
# overwrite it before we have a chance to call edit)
self._edit(script=result.script)
def _auto_generate(self, project_name=None, task_name=None, task_type=tasks.TaskTypeEnum.training):
created_msg = make_message('Auto-generated at %(time)s by %(user)s@%(host)s')
project_id = None
if project_name:
project_id = get_or_create_project(self, project_name, created_msg)
tags = ['development'] if not running_remotely() else []
req = tasks.CreateRequest(
name=task_name or make_message('Anonymous task (%(user)s@%(host)s %(time)s)'),
type=task_type,
comment=created_msg,
project=project_id,
input={'view': {}},
tags=tags,
)
res = self.send(req)
return res.response.id
def _set_storage_uri(self, value):
value = value.rstrip('/')
self._storage_uri = StorageHelper.conform_url(value)
self.data.output.destination = self._storage_uri
self._edit(output_dest=self._storage_uri)
self.output_model.upload_storage_uri = self._storage_uri
@property
def storage_uri(self):
if self._storage_uri:
return self._storage_uri
if running_remotely():
return self.data.output.destination
else:
return None
@storage_uri.setter
def storage_uri(self, value):
self._set_storage_uri(value)
@property
def task_id(self):
return self.id
@property
def name(self):
return self.data.name
@property
def task_type(self):
return self.data.type
@property
def project(self):
return self.data.project
@property
def input_model_id(self):
return self.data.execution.model
@property
def output_model_id(self):
return self.data.output.model
@property
def comment(self):
return self.data.comment
@property
def cache_dir(self):
""" Cache dir used to store task related files """
return Path(get_cache_dir()) / self.id
@property
def status(self):
""" The task's status. In order to stay updated, we always reload the task info when this value is accessed. """
self.reload()
return self._status
@property
def _status(self):
""" Return the task's cached status (don't reload if we don't have to) """
return self.data.status
@property
def input_model(self):
""" A model manager used to handle the input model object """
model_id = self._get_task_property('execution.model', raise_on_error=False)
if not model_id:
return None
if self._input_model is None:
self._input_model = Model(
session=self.session,
model_id=model_id,
cache_dir=self.cache_dir,
log=self.log,
upload_storage_uri=None)
return self._input_model
@property
def output_model(self):
""" A model manager used to manage the output model object """
if self._output_model is None:
self._output_model = self._get_output_model(upload_required=True)
return self._output_model
def create_output_model(self):
return self._get_output_model(upload_required=False, force=True)
def _get_output_model(self, upload_required=True, force=False):
return Model(
session=self.session,
model_id=None if force else self._get_task_property(
'output.model', raise_on_error=False, log_on_error=False),
cache_dir=self.cache_dir,
upload_storage_uri=self.storage_uri or self.get_output_destination(
raise_on_error=upload_required, log_on_error=upload_required),
upload_storage_suffix=self._get_output_destination_suffix('models'),
log=self.log)
@property
def metrics_manager(self):
""" A metrics manager used to manage the metrics related to this task """
return self._get_metrics_manager(self.get_output_destination())
@property
def reporter(self):
"""
Returns a simple metrics reporter instance
"""
if self._reporter is None:
try:
storage_uri = self.get_output_destination(log_on_error=False)
except ValueError:
storage_uri = None
self._reporter = Reporter(self._get_metrics_manager(storage_uri=storage_uri))
return self._reporter
def _get_metrics_manager(self, storage_uri):
if self._metrics_manager is None:
self._metrics_manager = Metrics(
session=self.session,
task_id=self.id,
storage_uri=storage_uri,
storage_uri_suffix=self._get_output_destination_suffix('metrics')
)
return self._metrics_manager
def _get_output_destination_suffix(self, extra_path=None):
return '/'.join(x for x in ('task_%s' % self.data.id, extra_path) if x)
def _reload(self):
""" Reload the task object from the backend """
res = self.send(tasks.GetByIdRequest(task=self.id))
return res.response.task
def reset(self, set_started_on_success=True):
""" Reset the task. Task will be reloaded following a successful reset. """
self.send(tasks.ResetRequest(task=self.id))
if set_started_on_success:
self.started()
self.reload()
def started(self, ignore_errors=True):
""" Signal that this task has started """
return self.send(tasks.StartedRequest(self.id), ignore_errors=ignore_errors)
def stopped(self, ignore_errors=True):
""" Signal that this task has stopped """
return self.send(tasks.StoppedRequest(self.id), ignore_errors=ignore_errors)
def mark_failed(self, ignore_errors=True, status_reason=None, status_message=None):
""" Signal that this task has stopped """
return self.send(tasks.FailedRequest(self.id, status_reason=status_reason, status_message=status_message),
ignore_errors=ignore_errors)
def publish(self, ignore_errors=True):
""" Signal that this task will be published """
if self.status != tasks.TaskStatusEnum.stopped:
raise ValueError("Can't publish, Task is not stopped")
resp = self.send(tasks.PublishRequest(self.id), ignore_errors=ignore_errors)
assert isinstance(resp.response, tasks.PublishResponse)
return resp
def update_model_desc(self, new_model_desc_file=None):
""" Change the task's model_desc """
execution = self._get_task_property('execution')
p = Path(new_model_desc_file)
if not p.is_file():
            raise IOError('model_desc file %s cannot be found' % new_model_desc_file)
new_model_desc = p.read_text()
model_desc_key = list(execution.model_desc.keys())[0] if execution.model_desc else 'design'
execution.model_desc[model_desc_key] = new_model_desc
res = self._edit(execution=execution)
return res.response
def update_output_model(self, model_uri, name=None, comment=None, tags=None):
"""
Update the task's output model.
Note that this method only updates the model's metadata using the API and does not upload any data. Use this
method to update the output model when you have a local model URI (e.g. storing the weights file locally and
providing a file://path/to/file URI)
:param model_uri: URI for the updated model weights file
:type model_uri: str
:param name: Optional updated model name
:type name: str
:param comment: Optional updated model description
:type comment: str
:param tags: Optional updated model tags
:type tags: [str]
"""
self._conditionally_start_task()
self._get_output_model(upload_required=False).update_for_task(model_uri, self.id, name, comment, tags)
def update_output_model_and_upload(
self, model_file, name=None, comment=None, tags=None, async_enable=False, cb=None, iteration=None):
"""
Update the task's output model weights file. File is first uploaded to the preconfigured output destination (see
        task's output.destination property or call setup_upload()), then the model object associated with the task is
updated using an API call with the URI of the uploaded file (and other values provided by additional arguments)
:param model_file: Path to the updated model weights file
:type model_file: str
:param name: Optional updated model name
:type name: str
:param comment: Optional updated model description
:type comment: str
:param tags: Optional updated model tags
:type tags: [str]
:param async_enable: Request asynchronous upload. If False, the call blocks until upload is completed and the
API call updating the model returns. If True, the call returns immediately, while upload and update are
scheduled in another thread. Default is False.
:type async_enable: bool
        :param cb: Asynchronous callback. If async_enable=True, this callback will be invoked once the asynchronous
            upload and update have completed.
        :param iteration: Optional iteration number to associate with the model
        :return: The URI of the uploaded weights file. If async_enable=True, this is the expected URI, as the upload is
            probably still in progress.
"""
self._conditionally_start_task()
uri = self.output_model.update_for_task_and_upload(
model_file, self.id, name=name, comment=comment, tags=tags, async_enable=async_enable, cb=cb,
iteration=iteration
)
return uri
def _conditionally_start_task(self):
if self.status == TaskStatusEnum.created:
self.started()
@property
def labels_stats(self):
""" Get accumulated label stats for the current/last frames iteration """
return self._curr_label_stats
def _accumulate_label_stats(self, roi_stats, reset=False):
if reset:
self._curr_label_stats = {}
for label in roi_stats:
if label in self._curr_label_stats:
self._curr_label_stats[label] += roi_stats[label]
else:
self._curr_label_stats[label] = roi_stats[label]
def set_input_model(self, model_id=None, model_name=None, update_task_design=True, update_task_labels=True):
"""
Set a new input model for this task. Model must be 'ready' in order to be used as the Task's input model.
:param model_id: ID for a model that exists in the backend. Required if model_name is not provided.
:param model_name: Model name. Required if model_id is not provided. If provided, this name will be used to
locate an existing model in the backend.
:param update_task_design: if True, the task's model design will be copied from the input model
:param update_task_labels: if True, the task's label enumeration will be copied from the input model
"""
if model_id is None and not model_name:
raise ValueError('Expected one of [model_id, model_name]')
if model_name:
# Try getting the model by name. Limit to 10 results.
res = self.send(
models.GetAllRequest(
name=exact_match_regex(model_name),
ready=True,
page=0,
page_size=10,
order_by='-created',
only_fields=['id']
)
)
model = get_single_result(entity='model', query=model_name, results=res.response.models, log=self.log)
model_id = model.id
if model_id:
res = self.send(models.GetByIdRequest(model=model_id))
model = res.response.model
if not model.ready:
# raise ValueError('Model %s is not published (not ready)' % model_id)
self.log.debug('Model %s [%s] is not published yet (not ready)' % (model_id, model.uri))
else:
# clear the input model
model = None
model_id = ''
# store model id
self.data.execution.model = model_id
# Auto populate input field from model, if they are empty
if update_task_design and not self.data.execution.model_desc:
self.data.execution.model_desc = model.design if model else ''
if update_task_labels and not self.data.execution.model_labels:
self.data.execution.model_labels = model.labels if model else {}
self._edit(execution=self.data.execution)
def set_parameters(self, *args, **kwargs):
"""
Set parameters for this task. This allows setting a complete set of key/value parameters, but does not support
        parameter descriptions (as the input is a dictionary or key/value pairs).
:param args: Positional arguments (one or more dictionary or (key, value) iterable). These will be merged into
a single key/value dictionary.
:param kwargs: Key/value pairs, merged into the parameters dictionary created from `args`.
"""
if not all(isinstance(x, (dict, collections.Iterable)) for x in args):
raise ValueError('only dict or iterable are supported as positional arguments')
update = kwargs.pop('__update', False)
parameters = dict() if not update else self.get_parameters()
parameters.update(itertools.chain.from_iterable(x.items() if isinstance(x, dict) else x for x in args))
parameters.update(kwargs)
not_allowed = {
k: type(v).__name__
for k, v in parameters.items()
if not isinstance(v, self._parameters_allowed_types)
}
if not_allowed:
raise ValueError(
"Only builtin types ({}) are allowed for values (got {})".format(
', '.join(t.__name__ for t in self._parameters_allowed_types),
', '.join('%s=>%s' % p for p in not_allowed.items())),
)
# force cast all variables to strings (so that we can later edit them in UI)
parameters = {k: str(v) if v is not None else "" for k, v in parameters.items()}
execution = self.data.execution
if execution is None:
execution = tasks.Execution(parameters=parameters)
else:
execution.parameters = parameters
self._edit(execution=execution)
def set_parameter(self, name, value, description=None):
"""
Set a single task parameter. This overrides any previous value for this parameter.
:param name: Parameter name
:param value: Parameter value
:param description: Parameter description (unused for now)
"""
params = self.get_parameters()
params[name] = value
self.set_parameters(params)
def get_parameter(self, name, default=None):
"""
Get a value for a parameter.
:param name: Parameter name
:param default: Default value
:return: Parameter value (or default value if parameter is not defined)
"""
params = self.get_parameters()
return params.get(name, default)
def update_parameters(self, *args, **kwargs):
"""
Update parameters for this task.
        This allows updating a complete set of key/value parameters, but does not support
        parameter descriptions (as the input is a dictionary or key/value pairs).
:param args: Positional arguments (one or more dictionary or (key, value) iterable). These will be merged into
a single key/value dictionary.
:param kwargs: Key/value pairs, merged into the parameters dictionary created from `args`.
"""
self.set_parameters(__update=True, *args, **kwargs)
def set_model_label_enumeration(self, enumeration=None):
        if enumeration is None:
            return
        execution = self.data.execution
if not (isinstance(enumeration, dict)
and all(isinstance(k, six.string_types) and isinstance(v, int) for k, v in enumeration.items())):
raise ValueError('Expected label to be a dict[str => int]')
execution.model_labels = enumeration
self._edit(execution=execution)
def _set_model_design(self, design=None):
execution = self.data.execution
if design is not None:
execution.model_desc = Model._wrap_design(design)
self._edit(execution=execution)
def get_labels_enumeration(self):
"""
Return a dictionary of labels (text) to ids (integers) {str(label): integer(id)}
:return:
"""
if not self.data or not self.data.execution:
return {}
return self.data.execution.model_labels
def get_model_design(self):
"""
Returns the model configuration as blob of text
:return:
"""
design = self._get_task_property("execution.model_desc", default={}, raise_on_error=False, log_on_error=False)
return Model._unwrap_design(design)
def set_output_model_id(self, model_id):
self.data.output.model = str(model_id)
self._edit(output=self.data.output)
def get_random_seed(self):
# fixed seed for the time being
return 1337
def set_random_seed(self, random_seed):
# fixed seed for the time being
pass
def set_project(self, project_id):
assert isinstance(project_id, six.string_types)
self._set_task_property("project", project_id)
self._edit(project=project_id)
def get_project_name(self):
if self.project is None:
return None
res = self.send(projects.GetByIdRequest(project=self.project), raise_on_errors=False)
return res.response.project.name
def get_tags(self):
return self._get_task_property("tags")
def set_tags(self, tags):
assert isinstance(tags, (list, tuple))
self._set_task_property("tags", tags)
self._edit(tags=self.data.tags)
def _get_default_report_storage_uri(self):
app_host = self._get_app_server()
parsed = urlparse(app_host)
if parsed.port:
parsed = parsed._replace(netloc=parsed.netloc.replace(':%d' % parsed.port, ':8081'))
else:
parsed = parsed._replace(netloc=parsed.netloc+':8081')
return urlunparse(parsed)
def _get_app_server(self):
host = config_obj.get('api.host')
if '://demoapi.' in host:
return host.replace('://demoapi.', '://demoapp.')
if '://api.' in host:
return host.replace('://api.', '://app.')
parsed = urlparse(host)
        if parsed.port == 8008:
            return host.replace(':8008', ':8080')
        return host
def _edit(self, **kwargs):
        # Since we are using a forced update, make sure the task status is valid
if not self._data or (self.data.status not in (TaskStatusEnum.created, TaskStatusEnum.in_progress)):
raise ValueError('Task object can only be updated if created or in_progress')
res = self.send(tasks.EditRequest(task=self.id, force=True, **kwargs), raise_on_errors=False)
return res
@classmethod
def create_new_task(cls, session, task_entry, log=None):
"""
Create a new task
:param session: Session object used for sending requests to the API
:type session: Session
:param task_entry: A task entry instance
:type task_entry: TaskEntry
:param log: Optional log
:type log: logging.Logger
:return: A new Task instance
"""
if isinstance(task_entry, dict):
task_entry = TaskEntry(**task_entry)
assert isinstance(task_entry, TaskEntry)
res = cls._send(session=session, req=task_entry, log=log)
return cls(session, task_id=res.response.id)
@classmethod
def clone_task(cls, cloned_task_id, name, comment=None, execution_overrides=None,
tags=None, parent=None, project=None, log=None, session=None):
"""
Clone a task
:param session: Session object used for sending requests to the API
:type session: Session
:param cloned_task_id: Task ID for the task to be cloned
:type cloned_task_id: str
        :param name: Name for the new task
:type name: str
:param comment: Optional comment for the new task
:type comment: str
:param execution_overrides: Task execution overrides. Applied over the cloned task's execution
section, useful for overriding values in the cloned task.
:type execution_overrides: dict
:param tags: Optional updated model tags
:type tags: [str]
:param parent: Optional parent ID of the new task.
:type parent: str
:param project: Optional project ID of the new task.
If None, the new task will inherit the cloned task's project.
        :type project: str
:param log: Log object used by the infrastructure.
:type log: logging.Logger
        :return: The new task's ID
"""
session = session if session else cls._get_default_session()
res = cls._send(session=session, log=log, req=tasks.GetByIdRequest(task=cloned_task_id))
task = res.response.task
output_dest = None
if task.output:
output_dest = task.output.destination
execution = task.execution.to_dict() if task.execution else {}
execution = ConfigTree.merge_configs(ConfigFactory.from_dict(execution),
ConfigFactory.from_dict(execution_overrides or {}))
req = tasks.CreateRequest(
name=name,
type=task.type,
input=task.input,
tags=tags if tags is not None else task.tags,
comment=comment or task.comment,
parent=parent,
project=project if project else task.project,
output_dest=output_dest,
execution=execution.as_plain_ordered_dict(),
script=task.script
)
res = cls._send(session=session, log=log, req=req)
return res.response.id
@classmethod
def enqueue_task(cls, task_id, session=None, queue_id=None, log=None):
"""
Enqueue a task for execution
:param session: Session object used for sending requests to the API
:type session: Session
:param task_id: ID of the task to be enqueued
:type task_id: str
:param queue_id: ID of the queue in which to enqueue the task. If not provided, the default queue will be used.
:type queue_id: str
:param log: Log object
:type log: logging.Logger
:return: enqueue response
"""
assert isinstance(task_id, six.string_types)
req = tasks.EnqueueRequest(task=task_id, queue=queue_id)
res = cls._send(session=session, req=req, log=log)
resp = res.response
return resp
@classmethod
def get_all(cls, session, log=None, **kwargs):
"""
Get all tasks
:param session: Session object used for sending requests to the API
:type session: Session
:param log: Log object
:type log: logging.Logger
:param kwargs: Keyword args passed to the GetAllRequest (see .backend_api.services.tasks.GetAllRequest)
:type kwargs: dict
:return: API response
"""
req = tasks.GetAllRequest(**kwargs)
res = cls._send(session=session, req=req, log=log)
return res
@classmethod
def get_by_name(cls, task_name):
res = cls._send(cls._get_default_session(), tasks.GetAllRequest(name=exact_match_regex(task_name)))
task = get_single_result(entity='task', query=task_name, results=res.response.tasks)
return cls(task_id=task.id)
def _get_all_events(self, max_events=100):
"""
Get a list of all reported events.
Warning: Debug only. Do not use outside of testing.
:param max_events: The maximum events the function will return. Pass None
to return all the reported events.
:return: A list of events from the task.
"""
log_events = self.send(events.GetTaskEventsRequest(
task=self.id,
order='asc',
batch_size=max_events,
))
events_list = log_events.response.events
total_events = log_events.response.total
scroll = log_events.response.scroll_id
while len(events_list) < total_events and (max_events is None or len(events_list) < max_events):
log_events = self.send(events.GetTaskEventsRequest(
task=self.id,
order='asc',
batch_size=max_events,
scroll_id=scroll,
))
events_list.extend(log_events.response.events)
scroll = log_events.response.scroll_id
return events_list
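`_get_all_events` pages through results using a scroll ID until everything (or `max_events`) is collected. The loop can be sketched against any paged source; `fetch_all` and `make_fake_backend` are hypothetical stand-ins for the events API, not part of the SDK:

```python
def fetch_all(fetch_page, max_events=100):
    # fetch_page(scroll_id) -> (page, total, next_scroll_id); scroll_id=None
    # starts a new scroll, mirroring GetTaskEventsRequest above.
    events, total, scroll = fetch_page(None)
    while len(events) < total and (max_events is None or len(events) < max_events):
        page, _, scroll = fetch_page(scroll)
        events.extend(page)
    return events

def make_fake_backend(all_events, page_size):
    # Simulate a scrolling backend: the scroll ID is simply the next offset.
    def fetch_page(scroll_id):
        start = scroll_id or 0
        return all_events[start:start + page_size], len(all_events), start + page_size
    return fetch_page
```

As in the original, the final page may overshoot `max_events` by up to one batch, since the bound is only checked between requests.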


@@ -0,0 +1,77 @@
import getpass
import re
from _socket import gethostname
from datetime import datetime
from ..backend_api.services import projects
from ..debugging.log import get_logger
def make_message(s, **kwargs):
args = dict(
user=getpass.getuser(),
host=gethostname(),
time=datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')
)
args.update(kwargs)
return s % args
def get_or_create_project(session, project_name, description=None):
res = session.send(projects.GetAllRequest(name=exact_match_regex(project_name)))
if res.response.projects:
return res.response.projects[0].id
res = session.send(projects.CreateRequest(name=project_name, description=description))
return res.response.id
def get_single_result(entity, query, results, log=None, show_results=10, raise_on_error=True):
if not results:
if not raise_on_error:
return None
raise ValueError('No {entity}s found when searching for `{query}`'.format(**locals()))
if not log:
log = get_logger()
if len(results) > 1:
log.warn('More than one {entity} found when searching for `{query}`'
' (showing first {show_results} {entity}s follow)'.format(**locals()))
for obj in (o if isinstance(o, dict) else o.to_dict() for o in results[:show_results]):
log.warn('Found {entity} `{obj[name]}` (id={obj[id]})'.format(**locals()))
if raise_on_error:
            raise ValueError('More than one {entity} found when searching for `{query}`'.format(**locals()))
return results[0]
def at_least_one(_exception_cls=Exception, **kwargs):
actual = [k for k, v in kwargs.items() if v]
if len(actual) < 1:
raise _exception_cls('At least one of (%s) is required' % ', '.join(kwargs.keys()))
def mutually_exclusive(_exception_cls=Exception, _require_at_least_one=True, **kwargs):
""" Helper for checking mutually exclusive options """
actual = [k for k, v in kwargs.items() if v]
if _require_at_least_one:
at_least_one(_exception_cls=_exception_cls, **kwargs)
if len(actual) > 1:
raise _exception_cls('Only one of (%s) is allowed' % ', '.join(kwargs.keys()))
def validate_dict(obj, key_types, value_types, desc=''):
if not isinstance(obj, dict):
raise ValueError('%sexpected a dictionary' % ('%s: ' % desc if desc else ''))
    if not all(isinstance(k, key_types) for k in obj.keys()):
        raise ValueError('%skeys must all be of the allowed key types' % ('%s ' % desc if desc else ''))
    if not all(isinstance(v, value_types) for v in obj.values()):
        raise ValueError('%svalues must all be of the allowed value types' % ('%s ' % desc if desc else ''))
def exact_match_regex(name):
""" Convert string to a regex representing an exact match """
return '^%s$' % re.escape(name)
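A quick standalone illustration of `exact_match_regex`; the helper is re-declared here so the snippet runs on its own:

```python
import re

def exact_match_regex(name):
    """ Convert string to a regex representing an exact match """
    return '^%s$' % re.escape(name)

pattern = exact_match_regex('my.model')
assert re.match(pattern, 'my.model')
assert not re.match(pattern, 'myxmodel')     # '.' is escaped, so no wildcard match
assert not re.match(pattern, 'my.model-v2')  # anchored at both ends
```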

trains/config/__init__.py Normal file

@@ -0,0 +1,64 @@
""" Configuration module. Uses backend_config to load system configuration. """
import logging
from os.path import expandvars, expanduser
from ..backend_api import load_config
from ..backend_config.bucket_config import S3BucketConfigurations
from .defs import *
from .remote import running_remotely_task_id as _running_remotely_task_id
config_obj = load_config(Path(__file__).parent)
config_obj.initialize_logging()
config = config_obj.get("sdk")
""" Configuration object reflecting the merged SDK section of all available configuration files """
def get_cache_dir():
cache_base_dir = Path(
expandvars(
expanduser(
config.get("storage.cache.default_base_dir") or DEFAULT_CACHE_DIR
)
)
)
return cache_base_dir
def get_config_for_bucket(base_url, extra_configurations=None):
config_list = S3BucketConfigurations.from_config(config.get("aws.s3"))
for configuration in extra_configurations or []:
config_list.add_config(configuration)
return config_list.get_config_by_uri(base_url)
def get_remote_task_id():
    return _running_remotely_task_id
def running_remotely():
    return bool(_running_remotely_task_id)
def get_log_to_backend(default=None):
return LOG_TO_BACKEND_ENV_VAR.get(default=default)
def get_node_id(default=0):
return NODE_ID_ENV_VAR.get(default=default)
def get_log_redirect_level():
""" Returns which log level (and up) should be redirected to stderr. None means no redirection. """
value = LOG_STDERR_REDIRECT_LEVEL.get()
try:
if value:
return logging._checkLevel(value)
except (ValueError, TypeError):
pass
def dev_worker_name():
return DEV_WORKER_NAME.get()

trains/config/cache.py Normal file

@@ -0,0 +1,40 @@
import json
from . import get_cache_dir
from .defs import SESSION_CACHE_FILE
class SessionCache(object):
"""
Handle SDK session cache.
TODO: Improve error handling to something like "except (FileNotFoundError, PermissionError, JSONDecodeError)"
TODO: that's both six-compatible and tested
"""
@classmethod
def _load_cache(cls):
try:
with (get_cache_dir() / SESSION_CACHE_FILE).open("rt") as fp:
return json.load(fp)
except Exception:
return {}
@classmethod
def _store_cache(cls, cache):
try:
get_cache_dir().mkdir(parents=True, exist_ok=True)
with (get_cache_dir() / SESSION_CACHE_FILE).open("wt") as fp:
json.dump(cache, fp)
except Exception:
pass
@classmethod
def store_dict(cls, unique_cache_name, dict_object):
# type: (str, dict) -> None
cache = cls._load_cache()
cache[unique_cache_name] = dict_object
cls._store_cache(cache)
@classmethod
def load_dict(cls, unique_cache_name):
# type: (str) -> dict
cache = cls._load_cache()
return cache.get(unique_cache_name, {}) if cache else {}
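`SessionCache` is a load-merge-store round trip over one JSON file; below is a minimal self-contained sketch of the same pattern (using stdlib `pathlib` and a hypothetical `cache.json` in a temp directory, not the SDK's `.session.json`):

```python
import json
import tempfile
from pathlib import Path

cache_file = Path(tempfile.mkdtemp()) / "cache.json"

def store_dict(name, obj):
    # Load the whole cache, merge one entry, write it back (mirrors SessionCache.store_dict)
    try:
        cache = json.loads(cache_file.read_text())
    except Exception:
        cache = {}
    cache[name] = obj
    cache_file.write_text(json.dumps(cache))

def load_dict(name):
    # A missing file or entry degrades to an empty dict, like SessionCache.load_dict
    try:
        return json.loads(cache_file.read_text()).get(name, {})
    except Exception:
        return {}

store_dict("api", {"token": "abc"})
print(load_dict("api"))      # {'token': 'abc'}
print(load_dict("missing"))  # {}
```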


@@ -0,0 +1,132 @@
from pyhocon import ConfigFactory
from pathlib2 import Path
from six.moves.urllib.parse import urlparse, urlunparse
from trains.backend_api.session.defs import ENV_HOST
from trains.backend_config.defs import LOCAL_CONFIG_FILES
from trains.config import config_obj
description = """
Please create new key/secret credentials using {}/admin
Copy/Paste credentials here: """
try:
def_host = ENV_HOST.get(default=config_obj.get("api.host"))
except Exception:
def_host = 'http://localhost:8080'
host_description = """
Editing configuration file: {CONFIG_FILE}
Enter your trains-server host [{HOST}]: """.format(
CONFIG_FILE=LOCAL_CONFIG_FILES[0],
HOST=def_host,
)
def main():
print('TRAINS SDK setup process')
conf_file = Path(LOCAL_CONFIG_FILES[0]).absolute()
if conf_file.exists() and conf_file.is_file() and conf_file.stat().st_size > 0:
print('Configuration file already exists: {}'.format(str(conf_file)))
print('Leaving setup, feel free to edit the configuration file.')
return
print(host_description, end='')
parsed_host = None
while not parsed_host:
parse_input = input()
if not parse_input:
parse_input = def_host
try:
parsed_host = urlparse(parse_input)
if parsed_host.scheme not in ('http', 'https'):
parsed_host = None
except Exception:
parsed_host = None
print('Could not parse url {}\nEnter your trains-server host: '.format(parse_input), end='')
if parsed_host.port == 8080:
        # this is a docker install; port 8080 serves the web app, the API is served on port 8008
print('Port 8080 is the web port, we need the api port. Replacing 8080 with 8008')
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8080', ':8008') + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
elif parsed_host.netloc.startswith('demoapp.'):
print('{} is the web server, we need the api server. Replacing \'demoapp.\' with \'demoapi.\''.format(
parsed_host.netloc))
# this is our demo server
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapp.', 'demoapi.') + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
elif parsed_host.netloc.startswith('app.'):
print('{} is the web server, we need the api server. Replacing \'app.\' with \'api.\''.format(
parsed_host.netloc))
# this is our application server
api_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('app.', 'api.') + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
elif parsed_host.port == 8008:
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace(':8008', ':8080') + parsed_host.path
elif parsed_host.netloc.startswith('demoapi.'):
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('demoapi.', 'demoapp.') + parsed_host.path
elif parsed_host.netloc.startswith('api.'):
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc.replace('api.', 'app.') + parsed_host.path
else:
api_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
web_host = parsed_host.scheme + "://" + parsed_host.netloc + parsed_host.path
print('Host configured to: {}'.format(api_host))
print(description.format(web_host), end='')
parse_input = input()
# check if these are valid credentials
credentials = None
try:
parsed = ConfigFactory.parse_string(parse_input)
if parsed:
credentials = parsed.get("credentials", None)
except Exception:
credentials = None
if not credentials or set(credentials) != {"access_key", "secret_key"}:
        print('Could not parse credentials, please enter the access key and secret key one at a time.')
credentials = {}
# parse individual
print('Enter user access key: ', end='')
credentials['access_key'] = input()
print('Enter user secret: ', end='')
credentials['secret_key'] = input()
print('Detected credentials key=\"{}\" secret=\"{}\"'.format(credentials['access_key'],
credentials['secret_key'], ))
try:
default_sdk_conf = Path(__file__).parent.absolute() / 'sdk.conf'
with open(str(default_sdk_conf), 'rt') as f:
default_sdk = f.read()
except Exception:
print('Error! Could not read default configuration file')
return
try:
with open(str(conf_file), 'wt') as f:
header = '# TRAINS SDK configuration file\n' \
'api {\n' \
' host: %s\n' \
' credentials {"access_key": "%s", "secret_key": "%s"}\n' \
'}\n' \
'sdk ' % (api_host, credentials['access_key'], credentials['secret_key'])
f.write(header)
f.write(default_sdk)
except Exception:
print('Error! Could not write configuration file at: {}'.format(str(conf_file)))
return
print('\nNew configuration stored in {}'.format(str(conf_file)))
print('TRAINS setup completed successfully.')
if __name__ == '__main__':
main()
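The chain of branches above maps a web host to its API host. Condensed into one helper for illustration (a sketch of the mapping only, not the wizard's exact code, using stdlib `urllib.parse` rather than `six.moves`):

```python
from urllib.parse import urlparse

def derive_api_host(url):
    # Map a trains web-server URL to the matching API-server URL:
    # docker web port 8080 -> api port 8008, demoapp./app. subdomain -> demoapi./api.
    p = urlparse(url)
    netloc = p.netloc
    if p.port == 8080:
        netloc = netloc.replace(':8080', ':8008')
    elif netloc.startswith('demoapp.'):
        netloc = netloc.replace('demoapp.', 'demoapi.', 1)
    elif netloc.startswith('app.'):
        netloc = netloc.replace('app.', 'api.', 1)
    return '%s://%s%s' % (p.scheme, netloc, p.path)

print(derive_api_host('http://localhost:8080'))              # http://localhost:8008
print(derive_api_host('https://demoapp.trains.allegro.ai'))  # https://demoapi.trains.allegro.ai
```

Note the `demoapp.` branch must be checked before `app.`, exactly as in the wizard, so the demo host is not mangled by the generic rule.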


@@ -0,0 +1,27 @@
{
version: 1
disable_existing_loggers: 0
loggers {
trains {
level: INFO
}
boto {
level: WARNING
}
"boto.perf" {
level: WARNING
}
botocore {
level: WARNING
}
boto3 {
level: WARNING
}
google {
level: WARNING
}
urllib3 {
level: WARNING
}
}
}


@@ -0,0 +1,126 @@
{
# TRAINS - default SDK configuration
storage {
cache {
# Defaults to system temp folder / cache
default_base_dir: "~/.trains/cache"
}
}
metrics {
# History size for debug files per metric/variant. For each metric/variant combination with an attached file
# (e.g. debug image event), file names for the uploaded files will be recycled in such a way that no more than
# X files are stored in the upload destination for each metric/variant combination.
file_history_size: 100
# Settings for generated debug images
images {
format: JPEG
quality: 87
subsampling: 0
}
}
network {
metrics {
# Number of threads allocated to uploading files (typically debug images) when transmitting metrics for
# a specific iteration
file_upload_threads: 4
# Warn about upload starvation if no uploads were made in specified period while file-bearing events keep
# being sent for upload
file_upload_starvation_warning_sec: 120
}
iteration {
# Max number of retries when getting frames if the server returned an error (http code 500)
max_retries_on_server_error: 5
# Backoff factor for consecutive retry attempts.
# SDK will wait for {backoff factor} * (2 ^ ({number of total retries} - 1)) between retries.
retry_backoff_factor_sec: 10
}
}
aws {
s3 {
# S3 credentials, used for read/write access by various SDK elements
# default, used for any bucket not specified below
key: ""
secret: ""
region: ""
credentials: [
# specifies key/secret credentials to use when handling s3 urls (read or write)
# {
# bucket: "my-bucket-name"
# key: "my-access-key"
# secret: "my-secret-key"
# },
# {
# # This will apply to all buckets in this host (unless key/value is specifically provided for a given bucket)
# host: "my-minio-host:9000"
# key: "12345678"
# secret: "12345678"
# multipart: false
# secure: false
# }
]
}
boto3 {
pool_connections: 512
max_multipart_concurrency: 16
}
}
google.storage {
# # Default project and credentials file
# # Will be used when no bucket configuration is found
# project: "trains"
# credentials_json: "/path/to/credentials.json"
# # Specific credentials per bucket and sub directory
# credentials = [
# {
# bucket: "my-bucket"
# subdir: "path/in/bucket" # Not required
# project: "trains"
# credentials_json: "/path/to/credentials.json"
# },
# ]
}
log {
# debugging feature: set this to true to make null log propagate messages to root logger (so they appear in stdout)
null_log_propagate: False
task_log_buffer_capacity: 66
# disable urllib info and lower levels
disable_urllib3_info: True
}
development {
# Development-mode options
# dev task reuse window
task_reuse_time_window_in_hours: 72.0
# Run VCS repository detection asynchronously
vcs_repo_detect_async: False
# Store uncommitted git/hg source code diff in experiment manifest when training in development mode
# This stores "git diff" or "hg diff" into the experiment's "script.requirements.diff" section
store_uncommitted_code_diff_on_train: True
# Support stopping an experiment in case it was externally stopped, status was changed or task was reset
support_stopping: True
# Development mode worker
worker {
# Status report period in seconds
report_period_sec: 2
# Log all stdout & stderr
log_stdout: True
}
}
}
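The `network.iteration` comment above spells out the retry wait formula; evaluating it with the default `retry_backoff_factor_sec: 10` shows the wait progression (a sketch of the formula only, not SDK code):

```python
def retry_wait(backoff_factor_sec, total_retries):
    # wait = {backoff factor} * (2 ^ ({number of total retries} - 1))
    return backoff_factor_sec * (2 ** (total_retries - 1))

# Waits for the first five retries with the default factor of 10 seconds
waits = [retry_wait(10, n) for n in range(1, 6)]
print(waits)  # [10, 20, 40, 80, 160]
```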

trains/config/defs.py Normal file

@@ -0,0 +1,31 @@
import tempfile
from ..backend_config import EnvEntry
from ..backend_config.converters import base64_to_text, or_
from pathlib2 import Path
SESSION_CACHE_FILE = ".session.json"
DEFAULT_CACHE_DIR = str(Path(tempfile.gettempdir()) / "trains_cache")
TASK_ID_ENV_VAR = EnvEntry("TRAINS_TASK_ID", "ALG_TASK_ID")
LOG_TO_BACKEND_ENV_VAR = EnvEntry("TRAINS_LOG_TASK_TO_BACKEND", "ALG_LOG_TASK_TO_BACKEND", type=bool)
NODE_ID_ENV_VAR = EnvEntry("TRAINS_NODE_ID", "ALG_NODE_ID", type=int)
PROC_MASTER_ID_ENV_VAR = EnvEntry("TRAINS_PROC_MASTER_ID", "ALG_PROC_MASTER_ID", type=int)
LOG_STDERR_REDIRECT_LEVEL = EnvEntry("TRAINS_LOG_STDERR_REDIRECT_LEVEL", "ALG_LOG_STDERR_REDIRECT_LEVEL")
DEV_WORKER_NAME = EnvEntry("TRAINS_WORKER_NAME", "ALG_WORKER_NAME")
LOG_LEVEL_ENV_VAR = EnvEntry("TRAINS_LOG_LEVEL", "ALG_LOG_LEVEL", converter=or_(int, str))
# Repository detection
VCS_REPO_TYPE = EnvEntry("TRAINS_VCS_REPO_TYPE", "ALG_VCS_REPO_TYPE", default="git")
VCS_REPOSITORY_URL = EnvEntry("TRAINS_VCS_REPO_URL", "ALG_VCS_REPO_URL")
VCS_COMMIT_ID = EnvEntry("TRAINS_VCS_COMMIT_ID", "ALG_VCS_COMMIT_ID")
VCS_BRANCH = EnvEntry("TRAINS_VCS_BRANCH", "ALG_VCS_BRANCH")
VCS_ROOT = EnvEntry("TRAINS_VCS_ROOT", "ALG_VCS_ROOT")
VCS_STATUS = EnvEntry("TRAINS_VCS_STATUS", "ALG_VCS_STATUS", converter=base64_to_text)
VCS_DIFF = EnvEntry("TRAINS_VCS_DIFF", "ALG_VCS_DIFF", converter=base64_to_text)
# User credentials
API_ACCESS_KEY = EnvEntry("TRAINS_API_ACCESS_KEY", "ALG_API_ACCESS_KEY", help="API Access Key")
API_SECRET_KEY = EnvEntry("TRAINS_API_SECRET_KEY", "ALG_API_SECRET_KEY", help="API Secret Key")

trains/config/remote.py Normal file

@@ -0,0 +1,17 @@
from .defs import TASK_ID_ENV_VAR
running_remotely_task_id = TASK_ID_ENV_VAR.get()
def override_current_task_id(task_id):
"""
Overrides the current task id to simulate remote running with a specific task.
Use for testing and debug only.
:param task_id: The task's id to use as the remote task.
Pass None to simulate local execution.
"""
global running_remotely_task_id
running_remotely_task_id = task_id


@@ -0,0 +1,4 @@
""" Debugging module """
from .timer import Timer
from .log import get_logger, get_null_logger, TqdmLog, add_options as add_log_options, \
apply_args as parse_log_args, add_rotating_file_handler, add_time_rotating_file_handler

trains/debugging/log.py Normal file

@@ -0,0 +1,181 @@
""" Logging convenience functions and wrappers """
import inspect
import logging
import logging.handlers
import os
import sys
from platform import system
import colorama
from ..config import config, get_log_redirect_level
from coloredlogs import ColoredFormatter
from pathlib2 import Path
from six import BytesIO
from tqdm import tqdm
default_level = logging.INFO
class _LevelRangeFilter(logging.Filter):
def __init__(self, min_level, max_level, name=''):
super(_LevelRangeFilter, self).__init__(name)
self.min_level = min_level
self.max_level = max_level
def filter(self, record):
return self.min_level <= record.levelno <= self.max_level
class LoggerRoot(object):
__base_logger = None
@classmethod
def _make_stream_handler(cls, level=None, stream=sys.stdout, colored=False):
ch = logging.StreamHandler(stream=stream)
ch.setLevel(level)
if colored:
colorama.init()
formatter = ColoredFormatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
else:
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
ch.setFormatter(formatter)
return ch
@classmethod
def get_base_logger(cls, level=None, stream=sys.stdout, colored=False):
if LoggerRoot.__base_logger:
return LoggerRoot.__base_logger
LoggerRoot.__base_logger = logging.getLogger('trains')
level = level if level is not None else default_level
LoggerRoot.__base_logger.setLevel(level)
redirect_level = get_log_redirect_level()
# Do not redirect to stderr if the target stream is already stderr
if redirect_level is not None and stream not in (None, sys.stderr):
# Adjust redirect level in case requested level is higher (e.g. logger is requested for CRITICAL
# and redirect is set for ERROR, in which case we redirect from CRITICAL)
redirect_level = max(level, redirect_level)
LoggerRoot.__base_logger.addHandler(
cls._make_stream_handler(redirect_level, sys.stderr, colored)
)
if level < redirect_level:
# Not all levels were redirected, remaining should be sent to requested stream
handler = cls._make_stream_handler(level, stream, colored)
handler.addFilter(_LevelRangeFilter(min_level=level, max_level=redirect_level - 1))
LoggerRoot.__base_logger.addHandler(handler)
else:
LoggerRoot.__base_logger.addHandler(
cls._make_stream_handler(level, stream, colored)
)
LoggerRoot.__base_logger.propagate = False
return LoggerRoot.__base_logger
@classmethod
def flush(cls):
if LoggerRoot.__base_logger:
for h in LoggerRoot.__base_logger.handlers:
h.flush()
def add_options(parser):
""" Add logging options to an argparse.ArgumentParser object """
level = logging.getLevelName(default_level)
parser.add_argument(
'--log-level', '-l', default=level, help='Log level (default is %s)' % level)
def apply_args(args):
""" Apply logging args from an argparse.ArgumentParser parsed args """
global default_level
default_level = logging.getLevelName(args.log_level.upper())
def get_logger(path=None, level=None, stream=None, colored=False):
""" Get a python logging object named using the provided filename and preconfigured with a color-formatted
stream handler
"""
path = path or os.path.abspath((inspect.stack()[1])[1])
root_log = LoggerRoot.get_base_logger(level=default_level, stream=sys.stdout, colored=colored)
log = root_log.getChild(Path(path).stem)
level = level if level is not None else root_log.level
log.setLevel(level)
    if stream:
        ch = logging.StreamHandler(stream=stream)
        ch.setLevel(level)
        log.addHandler(ch)
    log.propagate = True
    return log
def _add_file_handler(logger, log_dir, fh, formatter=None):
""" Adds a file handler to a logger """
Path(log_dir).mkdir(parents=True, exist_ok=True)
if not formatter:
log_format = '%(asctime)s %(name)s x_x[%(levelname)s] %(message)s'
formatter = logging.Formatter(log_format)
fh.setFormatter(formatter)
logger.addHandler(fh)
def add_rotating_file_handler(logger, log_dir, log_file_prefix, max_bytes=10 * 1024 * 1024, backup_count=20,
formatter=None):
""" Create and add a rotating file handler to a logger """
fh = logging.handlers.RotatingFileHandler(
str(Path(log_dir) / ('%s.log' % log_file_prefix)), maxBytes=max_bytes, backupCount=backup_count)
_add_file_handler(logger, log_dir, fh, formatter)
def add_time_rotating_file_handler(logger, log_dir, log_file_prefix, when='midnight', formatter=None):
"""
Create and add a time rotating file handler to a logger.
    Possible values for when are 'midnight', weekdays ('W0'-'W6', where W0 is Monday), and 'S', 'M', 'H' and 'D' for
    seconds, minutes, hours and days respectively (case-insensitive)
"""
fh = logging.handlers.TimedRotatingFileHandler(
str(Path(log_dir) / ('%s.log' % log_file_prefix)), when=when)
_add_file_handler(logger, log_dir, fh, formatter)
def get_null_logger(name=None):
""" Get a logger with a null handler """
log = logging.getLogger(name if name else 'null')
if not log.handlers:
log.addHandler(logging.NullHandler())
log.propagate = config.get("log.null_log_propagate", False)
return log
class TqdmLog(object):
""" Tqdm (progressbar) wrapped logging class """
class _TqdmIO(BytesIO):
""" IO wrapper class for Tqdm """
        def __init__(self, level=20, logger=None, *args, **kwargs):
            self._log = logger or get_null_logger()
            self._level = level
            self._buf = ''
            BytesIO.__init__(self, *args, **kwargs)
def write(self, buf):
self._buf = buf.strip('\r\n\t ')
def flush(self):
self._log.log(self._level, self._buf)
def __init__(self, total, desc='', log_level=20, ascii=False, logger=None, smoothing=0, mininterval=5, initial=0):
self._io = self._TqdmIO(level=log_level, logger=logger)
self._tqdm = tqdm(total=total, desc=desc, file=self._io, ascii=ascii if not system() == 'Windows' else True,
smoothing=smoothing,
mininterval=mininterval, initial=initial)
def update(self, n=None):
if n is not None:
self._tqdm.update(n=n)
else:
self._tqdm.update()
def close(self):
self._tqdm.close()
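The stdout/stderr split in `LoggerRoot.get_base_logger` relies on pairing a handler's level with a `_LevelRangeFilter`: one handler takes `redirect_level` and up on stderr, another takes everything from `level` up to `redirect_level - 1` on stdout. A standalone sketch of that split, with the filter re-declared so it runs on its own (logger name and thresholds are illustrative):

```python
import logging
import sys

class LevelRangeFilter(logging.Filter):
    # Pass only records whose level falls inside [min_level, max_level]
    def __init__(self, min_level, max_level):
        super().__init__()
        self.min_level = min_level
        self.max_level = max_level

    def filter(self, record):
        return self.min_level <= record.levelno <= self.max_level

log = logging.getLogger('split-demo')
log.setLevel(logging.DEBUG)
log.propagate = False

# ERROR and above go to stderr...
err = logging.StreamHandler(sys.stderr)
err.setLevel(logging.ERROR)
log.addHandler(err)

# ...everything from DEBUG up to (but not including) ERROR goes to stdout
out = logging.StreamHandler(sys.stdout)
out.setLevel(logging.DEBUG)
out.addFilter(LevelRangeFilter(logging.DEBUG, logging.ERROR - 1))
log.addHandler(out)

log.info('routed to stdout')
log.error('routed to stderr')
```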

trains/debugging/timer.py Normal file

@@ -0,0 +1,112 @@
""" Timing support """
import sys
import time
import six
class Timer(object):
"""A class implementing a simple timer, with a reset option """
def __init__(self):
self._start_time = 0.
self._diff = 0.
self._total_time = 0.
self._average_time = 0.
self._calls = 0
self.tic()
def reset(self):
self._start_time = 0.
self._diff = 0.
self.reset_average()
def reset_average(self):
""" Reset average counters (does not change current timer) """
self._total_time = 0
self._average_time = 0
self._calls = 0
def tic(self):
try:
            # using time.time instead of time.clock because time.clock
            # does not normalize for multi-threading
self._start_time = time.time()
except Exception:
pass
def toc(self, average=True):
self._diff = time.time() - self._start_time
self._total_time += self._diff
self._calls += 1
self._average_time = self._total_time / self._calls
if average:
return self._average_time
else:
return self._diff
@property
def average_time(self):
return self._average_time
@property
def total_time(self):
return self._total_time
def toc_with_reset(self, average=True, reset_if_calls=1000):
""" Enable toc with reset (slightly inaccurate if reset event occurs) """
if self._calls > reset_if_calls:
last_diff = time.time() - self._start_time
self._start_time = time.time()
self._total_time = last_diff
self._average_time = 0
self._calls = 0
return self.toc(average=average)
class TimersMixin(object):
def __init__(self):
self._timers = {}
def add_timers(self, *names):
for name in names:
self.add_timer(name)
def add_timer(self, name, timer=None):
if name in self._timers:
raise ValueError('timer %s already exists' % name)
timer = timer or Timer()
self._timers[name] = timer
return timer
def get_timer(self, name, default=None):
return self._timers.get(name, default)
def get_timers(self):
return self._timers
    def _call_timer(self, name, func, silent_fail=False):
        try:
            return func(self._timers[name])
except KeyError:
if not silent_fail:
six.reraise(*sys.exc_info())
def reset_timers(self, *names):
for name in names:
self._call_timer(name, lambda t: t.reset())
def reset_average_timers(self, *names):
for name in names:
self._call_timer(name, lambda t: t.reset_average())
def tic_timers(self, *names):
for name in names:
self._call_timer(name, lambda t: t.tic())
def toc_timers(self, *names):
return [self._call_timer(name, lambda t: t.toc()) for name in names]
def toc_with_reset_timer(self, name, average=True, reset_if_calls=1000):
return self._call_timer(name, lambda t: t.toc_with_reset(average, reset_if_calls))
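The `Timer` bookkeeping above boils down to `average = total_time / calls`. The stand-in below reproduces tic/toc with an injected clock so the arithmetic is deterministic (a simplified sketch, not the `Timer` class itself):

```python
class SimpleTimer:
    # Minimal tic/toc timer with a running average; the clock is injected for testability
    def __init__(self, clock):
        self._clock = clock
        self._total = 0.0
        self._calls = 0
        self._start = clock()

    def tic(self):
        self._start = self._clock()

    def toc(self, average=True):
        diff = self._clock() - self._start
        self._total += diff
        self._calls += 1
        return self._total / self._calls if average else diff

# Fake clock returning 0, 0, 2, 2, 5 on successive calls
times = iter([0, 0, 2, 2, 5])
timer = SimpleTimer(lambda: next(times))

timer.tic()
print(timer.toc())  # first interval is 2s, average -> 2.0
timer.tic()
print(timer.toc())  # second interval is 3s, average (2+3)/2 -> 2.5
```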

trains/errors.py Normal file

@@ -0,0 +1,3 @@
class UsageError(RuntimeError):
""" An exception raised for illegal usage of trains objects"""
pass

Some files were not shown because too many files have changed in this diff.