mirror of
https://github.com/clearml/clearml
synced 2025-05-03 04:21:00 +00:00
Add scikit-learn support (joblib) and xgboost support
This commit is contained in:
parent
1bb06c0190
commit
19c5f05912
25
examples/joblib_example.py
Normal file
25
examples/joblib_example.py
Normal file
@ -0,0 +1,25 @@
|
||||
import joblib
|
||||
|
||||
from sklearn import datasets
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
|
||||
from trains import Task
|
||||
|
||||
task = Task.init(project_name="examples", task_name="joblib test")
|
||||
|
||||
iris = datasets.load_iris()
|
||||
X = iris.data
|
||||
y = iris.target
|
||||
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
|
||||
|
||||
model = LogisticRegression() # sklearn LogisticRegression class
|
||||
model.fit(X_train, y_train)
|
||||
|
||||
joblib.dump(model, 'model.pkl', compress=True)
|
||||
|
||||
loaded_model = joblib.load('model.pkl')
|
||||
result = loaded_model.score(X_test, y_test)
|
||||
print(result)
|
@ -6,29 +6,28 @@
|
||||
# 2 seconds per epoch on a K520 GPU.
|
||||
from __future__ import print_function
|
||||
|
||||
import io
|
||||
import numpy as np
|
||||
import tensorflow
|
||||
|
||||
from keras.callbacks import TensorBoard, ModelCheckpoint
|
||||
from keras.datasets import mnist
|
||||
from keras.models import Sequential, Model
|
||||
from keras.layers.core import Dense, Dropout, Activation
|
||||
from keras.optimizers import SGD, Adam, RMSprop
|
||||
from keras.models import Sequential
|
||||
from keras.layers.core import Dense, Activation
|
||||
from keras.optimizers import RMSprop
|
||||
from keras.utils import np_utils
|
||||
# TODO: test these methods binding
|
||||
from keras.models import load_model, save_model, model_from_json
|
||||
|
||||
import tensorflow as tf
|
||||
from trains import Task
|
||||
|
||||
|
||||
class TensorBoardImage(TensorBoard):
|
||||
@staticmethod
|
||||
def make_image(tensor):
|
||||
import tensorflow as tf
|
||||
from PIL import Image
|
||||
tensor = np.stack((tensor, tensor, tensor), axis=2)
|
||||
height, width, channels = tensor.shape
|
||||
image = Image.fromarray(tensor)
|
||||
import io
|
||||
output = io.BytesIO()
|
||||
image.save(output, format='PNG')
|
||||
image_string = output.getvalue()
|
||||
@ -38,9 +37,10 @@ class TensorBoardImage(TensorBoard):
|
||||
colorspace=channels,
|
||||
encoded_image_string=image_string)
|
||||
|
||||
def on_epoch_end(self, epoch, logs={}):
|
||||
def on_epoch_end(self, epoch, logs=None):
|
||||
if logs is None:
|
||||
logs = {}
|
||||
super(TensorBoardImage, self).on_epoch_end(epoch, logs)
|
||||
import tensorflow as tf
|
||||
images = self.validation_data[0] # 0 - data; 1 - labels
|
||||
img = (255 * images[0].reshape(28, 28)).astype('uint8')
|
||||
|
||||
|
15
examples/requirements.txt
Normal file
15
examples/requirements.txt
Normal file
@ -0,0 +1,15 @@
|
||||
absl-py>=0.7.1
|
||||
Keras>=2.2.4
|
||||
joblib>=0.13.2
|
||||
matplotlib>=3.1.1
|
||||
seaborn>=0.9.0
|
||||
sklearn>=0.0
|
||||
tensorboard>=1.14.0
|
||||
tensorboardX>=1.8
|
||||
tensorflow>=1.14.0
|
||||
torch>=1.1.0
|
||||
torchvision>=0.3.0
|
||||
xgboost>=0.90
|
||||
|
||||
# sudo apt-get install graphviz
|
||||
graphviz>=0.8
|
@ -3,9 +3,6 @@
|
||||
import tensorflow as tf
|
||||
import numpy as np
|
||||
import cv2
|
||||
from time import sleep
|
||||
#import tensorflow.compat.v1 as tf
|
||||
#tf.disable_v2_behavior()
|
||||
|
||||
from trains import Task
|
||||
task = Task.init(project_name='examples', task_name='tensorboard toy example')
|
||||
|
59
examples/xgboost_sample.py
Normal file
59
examples/xgboost_sample.py
Normal file
@ -0,0 +1,59 @@
|
||||
import matplotlib.pyplot as plt
|
||||
import xgboost as xgb
|
||||
from sklearn import datasets
|
||||
from sklearn.metrics import accuracy_score
|
||||
from sklearn.model_selection import train_test_split
|
||||
from xgboost import plot_tree
|
||||
|
||||
from trains import Task
|
||||
|
||||
task = Task.init(project_name='examples', task_name='XGBoost simple example')
|
||||
iris = datasets.load_iris()
|
||||
X = iris.data
|
||||
y = iris.target
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
|
||||
dtrain = xgb.DMatrix(X_train, label=y_train)
|
||||
dtest = xgb.DMatrix(X_test, label=y_test)
|
||||
param = {
|
||||
'max_depth': 3, # the maximum depth of each tree
|
||||
'eta': 0.3, # the training step for each iteration
|
||||
'silent': 1, # logging mode - quiet
|
||||
'objective': 'multi:softprob', # error evaluation for multiclass training
|
||||
'num_class': 3} # the number of classes that exist in this datset
|
||||
num_round = 20 # the number of training iterations
|
||||
|
||||
try:
|
||||
# try to load a model
|
||||
bst = xgb.Booster(params=param, model_file='xgb.01.model')
|
||||
bst.load_model('xgb.01.model')
|
||||
except:
|
||||
bst = None
|
||||
|
||||
# if we dont have one train a model
|
||||
if bst is None:
|
||||
bst = xgb.train(param, dtrain, num_round)
|
||||
|
||||
# store trained model model v1
|
||||
bst.save_model('xgb.01.model')
|
||||
bst.dump_model('xgb.01.raw.txt')
|
||||
|
||||
# build classifier
|
||||
model = xgb.XGBClassifier()
|
||||
model.fit(X_train, y_train)
|
||||
|
||||
# store trained classifier model
|
||||
model.save_model('xgb.02.model')
|
||||
|
||||
# make predictions for test data
|
||||
y_pred = model.predict(X_test)
|
||||
predictions = [round(value) for value in y_pred]
|
||||
|
||||
# evaluate predictions
|
||||
accuracy = accuracy_score(y_test, predictions)
|
||||
print("Accuracy: %.2f%%" % (accuracy * 100.0))
|
||||
labels = dtest.get_label()
|
||||
|
||||
# plot results
|
||||
xgb.plot_importance(model)
|
||||
plot_tree(model)
|
||||
plt.show()
|
@ -144,7 +144,11 @@ class Session(TokenManager):
|
||||
|
||||
# update api version from server response
|
||||
try:
|
||||
api_version = jwt.decode(self.token, verify=False).get('api_version', Session.api_version)
|
||||
token_dict = jwt.decode(self.token, verify=False)
|
||||
api_version = token_dict.get('api_version')
|
||||
if not api_version:
|
||||
api_version = '2.2' if token_dict.get('env', '') == 'prod' else Session.api_version
|
||||
|
||||
Session.api_version = str(api_version)
|
||||
except (jwt.DecodeError, ValueError):
|
||||
pass
|
||||
|
@ -36,7 +36,7 @@ def or_(*converters, **kwargs):
|
||||
"""
|
||||
Wrapper that implements an "optional converter" pattern. Allows specifying a converter
|
||||
for which a set of exceptions is ignored (and the original value is returned)
|
||||
:param converter: A converter callable
|
||||
:param converters: A converter callable
|
||||
:param exceptions: A tuple of exception types to ignore
|
||||
"""
|
||||
# noinspection PyUnresolvedReferences
|
||||
|
@ -1,9 +1,8 @@
|
||||
import os
|
||||
import weakref
|
||||
|
||||
import numpy as np
|
||||
import hashlib
|
||||
from tempfile import mkstemp, mkdtemp
|
||||
from tempfile import mkdtemp
|
||||
from threading import Thread, Event
|
||||
from multiprocessing.pool import ThreadPool
|
||||
|
||||
|
52
trains/binding/frameworks/base_bind.py
Normal file
52
trains/binding/frameworks/base_bind.py
Normal file
@ -0,0 +1,52 @@
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
import six
|
||||
|
||||
|
||||
@six.add_metaclass(ABCMeta)
|
||||
class PatchBaseModelIO(object):
|
||||
"""
|
||||
Base class for patched models
|
||||
|
||||
:param __main_task: Task to run (Experiment)
|
||||
:type __main_task: Task
|
||||
:param __patched: True if the model is patched
|
||||
:type __patched: bool
|
||||
"""
|
||||
@property
|
||||
@abstractmethod
|
||||
def __main_task(self):
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def __patched(self):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def update_current_task(task, **kwargs):
|
||||
"""
|
||||
Update the model task to run
|
||||
:param task: the experiment to do
|
||||
:type task: Task
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def _patch_model_io():
|
||||
"""
|
||||
Patching the load and save functions
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def _save(original_fn, obj, f, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def _load(original_fn, f, *args, **kwargs):
|
||||
pass
|
@ -3,13 +3,14 @@ import sys
|
||||
import six
|
||||
from pathlib2 import Path
|
||||
|
||||
from trains.binding.frameworks.base_bind import PatchBaseModelIO
|
||||
from ..frameworks import _patched_call, WeightsFileHandler, _Empty
|
||||
from ..import_bind import PostImportHookPatching
|
||||
from ...config import running_remotely
|
||||
from ...model import Framework
|
||||
|
||||
|
||||
class PatchPyTorchModelIO(object):
|
||||
class PatchPyTorchModelIO(PatchBaseModelIO):
|
||||
__main_task = None
|
||||
__patched = None
|
||||
|
||||
|
@ -9,8 +9,6 @@ from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import six
|
||||
from pathlib2 import Path
|
||||
|
||||
from ..frameworks import _patched_call, WeightsFileHandler, _Empty, TrainsFrameworkAdapter
|
||||
from ..import_bind import PostImportHookPatching
|
||||
|
101
trains/binding/frameworks/xgboost_bind.py
Normal file
101
trains/binding/frameworks/xgboost_bind.py
Normal file
@ -0,0 +1,101 @@
|
||||
import sys
|
||||
|
||||
import six
|
||||
from pathlib2 import Path
|
||||
|
||||
from trains.binding.frameworks.base_bind import PatchBaseModelIO
|
||||
from ..frameworks import _patched_call, WeightsFileHandler, _Empty
|
||||
from ..import_bind import PostImportHookPatching
|
||||
from ...config import running_remotely
|
||||
from ...model import Framework
|
||||
|
||||
|
||||
class PatchXGBoostModelIO(PatchBaseModelIO):
|
||||
__main_task = None
|
||||
__patched = None
|
||||
|
||||
@staticmethod
|
||||
def update_current_task(task, **kwargs):
|
||||
PatchXGBoostModelIO.__main_task = task
|
||||
PatchXGBoostModelIO._patch_model_io()
|
||||
PostImportHookPatching.add_on_import('xgboost', PatchXGBoostModelIO._patch_model_io)
|
||||
|
||||
@staticmethod
|
||||
def _patch_model_io():
|
||||
if PatchXGBoostModelIO.__patched:
|
||||
return
|
||||
|
||||
if 'xgboost' not in sys.modules:
|
||||
return
|
||||
PatchXGBoostModelIO.__patched = True
|
||||
try:
|
||||
import xgboost as xgb
|
||||
bst = xgb.Booster
|
||||
bst.save_model = _patched_call(bst.save_model, PatchXGBoostModelIO._save)
|
||||
bst.load_model = _patched_call(bst.load_model, PatchXGBoostModelIO._load)
|
||||
except ImportError:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def _save(original_fn, obj, f, *args, **kwargs):
|
||||
ret = original_fn(obj, f, *args, **kwargs)
|
||||
if not PatchXGBoostModelIO.__main_task:
|
||||
return ret
|
||||
|
||||
if isinstance(f, six.string_types):
|
||||
filename = f
|
||||
elif hasattr(f, 'name'):
|
||||
filename = f.name
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
f.flush()
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
filename = None
|
||||
|
||||
# give the model a descriptive name based on the file name
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
model_name = Path(filename).stem
|
||||
except Exception:
|
||||
model_name = None
|
||||
WeightsFileHandler.create_output_model(obj, filename, Framework.xgboost, PatchXGBoostModelIO.__main_task,
|
||||
singlefile=True, model_name=model_name)
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def _load(original_fn, f, *args, **kwargs):
|
||||
if isinstance(f, six.string_types):
|
||||
filename = f
|
||||
elif hasattr(f, 'name'):
|
||||
filename = f.name
|
||||
elif len(args) == 1 and isinstance(args[0], six.string_types):
|
||||
filename = args[0]
|
||||
else:
|
||||
filename = None
|
||||
|
||||
if not PatchXGBoostModelIO.__main_task:
|
||||
return original_fn(f, *args, **kwargs)
|
||||
|
||||
# register input model
|
||||
empty = _Empty()
|
||||
if running_remotely():
|
||||
filename = WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
|
||||
PatchXGBoostModelIO.__main_task)
|
||||
model = original_fn(filename or f, *args, **kwargs)
|
||||
else:
|
||||
# try to load model before registering, in case we fail
|
||||
model = original_fn(f, *args, **kwargs)
|
||||
WeightsFileHandler.restore_weights_file(empty, filename, Framework.xgboost,
|
||||
PatchXGBoostModelIO.__main_task)
|
||||
|
||||
if empty.trains_in_model:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
model.trains_in_model = empty.trains_in_model
|
||||
except Exception:
|
||||
pass
|
||||
return model
|
110
trains/binding/joblib_bind.py
Normal file
110
trains/binding/joblib_bind.py
Normal file
@ -0,0 +1,110 @@
|
||||
try:
|
||||
import joblib
|
||||
except ImportError as e:
|
||||
joblib = None
|
||||
|
||||
import six
|
||||
from pathlib2 import Path
|
||||
|
||||
from trains.binding.frameworks import _patched_call, _Empty, WeightsFileHandler
|
||||
from trains.config import running_remotely
|
||||
from trains.debugging.log import LoggerRoot
|
||||
|
||||
|
||||
class PatchedJoblib(object):
|
||||
_patched_original_dump = None
|
||||
_patched_original_load = None
|
||||
_current_task = None
|
||||
_current_framework = None
|
||||
|
||||
@staticmethod
|
||||
def patch_joblib():
|
||||
if PatchedJoblib._patched_original_dump is not None and PatchedJoblib._patched_original_load is not None:
|
||||
# We don't need to patch anything else, so we are done
|
||||
return True
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
joblib.dump = _patched_call(joblib.dump, PatchedJoblib._dump)
|
||||
joblib.load = _patched_call(joblib.load, PatchedJoblib._load)
|
||||
|
||||
except Exception:
|
||||
return False
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def update_current_task(task):
|
||||
if PatchedJoblib.patch_joblib():
|
||||
PatchedJoblib._current_task = task
|
||||
|
||||
@staticmethod
|
||||
def _dump(original_fn, obj, f, *args, **kwargs):
|
||||
ret = original_fn(obj, f, *args, **kwargs)
|
||||
if not PatchedJoblib._current_task:
|
||||
return ret
|
||||
|
||||
if isinstance(f, six.string_types):
|
||||
filename = f
|
||||
elif hasattr(f, 'name'):
|
||||
filename = f.name
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
f.flush()
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
filename = None
|
||||
|
||||
# give the model a descriptive name based on the file name
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
model_name = Path(filename).stem
|
||||
except Exception:
|
||||
model_name = None
|
||||
PatchedJoblib._current_framework = PatchedJoblib.get_model_framework(obj)
|
||||
WeightsFileHandler.create_output_model(obj, filename, PatchedJoblib._current_framework,
|
||||
PatchedJoblib._current_task, singlefile=True, model_name=model_name)
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def _load(original_fn, f, *args, **kwargs):
|
||||
if isinstance(f, six.string_types):
|
||||
filename = f
|
||||
elif hasattr(f, 'name'):
|
||||
filename = f.name
|
||||
else:
|
||||
filename = None
|
||||
|
||||
if not PatchedJoblib._current_task:
|
||||
return original_fn(f, *args, **kwargs)
|
||||
|
||||
# register input model
|
||||
empty = _Empty()
|
||||
if running_remotely():
|
||||
filename = WeightsFileHandler.restore_weights_file(empty, filename, PatchedJoblib._current_framework,
|
||||
PatchedJoblib._current_task)
|
||||
model = original_fn(filename or f, *args, **kwargs)
|
||||
else:
|
||||
# try to load model before registering, in case we fail
|
||||
model = original_fn(f, *args, **kwargs)
|
||||
WeightsFileHandler.restore_weights_file(empty, filename, PatchedJoblib._current_framework,
|
||||
PatchedJoblib._current_task)
|
||||
|
||||
if empty.trains_in_model:
|
||||
# noinspection PyBroadException
|
||||
try:
|
||||
model.trains_in_model = empty.trains_in_model
|
||||
except Exception:
|
||||
pass
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def get_model_framework(obj):
|
||||
object_orig_module = obj.__module__
|
||||
framework = object_orig_module
|
||||
try:
|
||||
framework = object_orig_module.partition(".")[0]
|
||||
except Exception as _:
|
||||
LoggerRoot.get_base_logger().warning(
|
||||
"Can't get model framework, model framework will be: {} ".format(object_orig_module))
|
||||
finally:
|
||||
return framework
|
@ -1,6 +1,5 @@
|
||||
import abc
|
||||
import os
|
||||
import re
|
||||
import tarfile
|
||||
import zipfile
|
||||
from tempfile import mkdtemp, mkstemp
|
||||
@ -40,6 +39,7 @@ class Framework(Options):
|
||||
darknet = 'Darknet'
|
||||
paddlepaddle = 'PaddlePaddle'
|
||||
scikitlearn = 'ScikitLearn'
|
||||
xgboost = 'XGBoost'
|
||||
|
||||
__file_extensions_mapping = {
|
||||
'.pb': (tensorflow, tensorflowjs, onnx, ),
|
||||
@ -59,13 +59,13 @@ class Framework(Options):
|
||||
'.h5': (keras, ),
|
||||
'.hdf5': (keras, ),
|
||||
'.keras': (keras, ),
|
||||
'.model': (mknet, cntk, ),
|
||||
'.model': (mknet, cntk, xgboost),
|
||||
'-symbol.json': (mknet, ),
|
||||
'.cntk': (cntk, ),
|
||||
'.t7': (torch, ),
|
||||
'.cfg': (darknet, ),
|
||||
'__model__': (paddlepaddle, ),
|
||||
'.pkl': (scikitlearn, keras, ),
|
||||
'.pkl': (scikitlearn, keras, xgboost),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
|
@ -10,6 +10,7 @@ from collections import OrderedDict, Callable
|
||||
import psutil
|
||||
import six
|
||||
|
||||
from trains.binding.joblib_bind import PatchedJoblib
|
||||
from .backend_api.services import tasks, projects
|
||||
from .backend_api.session.session import Session
|
||||
from .backend_interface.model import Model as BackendModel
|
||||
@ -34,8 +35,9 @@ from .utilities.args import argparser_parseargs_called, get_argparser_last_args,
|
||||
from .binding.frameworks.pytorch_bind import PatchPyTorchModelIO
|
||||
from .binding.frameworks.tensorflow_bind import PatchSummaryToEventTransformer, PatchTensorFlowEager, \
|
||||
PatchKerasModelIO, PatchTensorflowModelIO
|
||||
from .utilities.resource_monitor import ResourceMonitor
|
||||
from .binding.frameworks.xgboost_bind import PatchXGBoostModelIO
|
||||
from .binding.matplotlib_bind import PatchedMatplotlib
|
||||
from .utilities.resource_monitor import ResourceMonitor
|
||||
from .utilities.seed import make_deterministic
|
||||
|
||||
NotSet = object()
|
||||
@ -118,15 +120,15 @@ class Task(_Task):
|
||||
|
||||
@classmethod
|
||||
def init(
|
||||
cls,
|
||||
project_name=None,
|
||||
task_name=None,
|
||||
task_type=TaskTypes.training,
|
||||
reuse_last_task_id=True,
|
||||
output_uri=None,
|
||||
auto_connect_arg_parser=True,
|
||||
auto_connect_frameworks=True,
|
||||
auto_resource_monitoring=True,
|
||||
cls,
|
||||
project_name=None,
|
||||
task_name=None,
|
||||
task_type=TaskTypes.training,
|
||||
reuse_last_task_id=True,
|
||||
output_uri=None,
|
||||
auto_connect_arg_parser=True,
|
||||
auto_connect_frameworks=True,
|
||||
auto_resource_monitoring=True,
|
||||
):
|
||||
"""
|
||||
Return the Task object for the main execution task (task context).
|
||||
@ -239,14 +241,15 @@ class Task(_Task):
|
||||
# patch OS forking
|
||||
PatchOsFork.patch_fork()
|
||||
if auto_connect_frameworks:
|
||||
PatchedJoblib.update_current_task(task)
|
||||
PatchedMatplotlib.update_current_task(Task.__main_task)
|
||||
PatchAbsl.update_current_task(Task.__main_task)
|
||||
PatchSummaryToEventTransformer.update_current_task(task)
|
||||
# PatchModelCheckPointCallback.update_current_task(task)
|
||||
PatchTensorFlowEager.update_current_task(task)
|
||||
PatchKerasModelIO.update_current_task(task)
|
||||
PatchTensorflowModelIO.update_current_task(task)
|
||||
PatchPyTorchModelIO.update_current_task(task)
|
||||
PatchXGBoostModelIO.update_current_task(task)
|
||||
if auto_resource_monitoring:
|
||||
task._resource_monitor = ResourceMonitor(task)
|
||||
task._resource_monitor.start()
|
||||
@ -277,10 +280,10 @@ class Task(_Task):
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
cls,
|
||||
task_name=None,
|
||||
project_name=None,
|
||||
task_type=TaskTypes.training,
|
||||
cls,
|
||||
task_name=None,
|
||||
project_name=None,
|
||||
task_type=TaskTypes.training,
|
||||
):
|
||||
"""
|
||||
Create a new Task object, regardless of the main execution task (Task.init).
|
||||
@ -345,7 +348,7 @@ class Task(_Task):
|
||||
pass
|
||||
|
||||
# if we force no task reuse from os environment
|
||||
if DEV_TASK_NO_REUSE.get():
|
||||
if DEV_TASK_NO_REUSE.get() or reuse_last_task_id:
|
||||
default_task = None
|
||||
else:
|
||||
# if we have a previous session to use, get the task id from it
|
||||
@ -364,7 +367,6 @@ class Task(_Task):
|
||||
default_task_id = reuse_last_task_id
|
||||
elif not reuse_last_task_id or not cls.__task_is_relevant(default_task):
|
||||
default_task_id = None
|
||||
closed_old_task = cls.__close_timed_out_task(default_task)
|
||||
else:
|
||||
default_task_id = default_task.get('id') if default_task else None
|
||||
|
||||
@ -693,7 +695,7 @@ class Task(_Task):
|
||||
If `config_text` is not None, `config_dict` must not be provided.
|
||||
"""
|
||||
config_text = self.get_model_config_text()
|
||||
return OutputModel._text_to_config_dict(config_text)
|
||||
return OutputModel._text_to_config_dict(config_text)
|
||||
|
||||
def set_model_label_enumeration(self, enumeration=None):
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user