diff --git a/trains/backend_api/session/client/client.py b/trains/backend_api/session/client/client.py index 6c394388..1eefa199 100644 --- a/trains/backend_api/session/client/client.py +++ b/trains/backend_api/session/client/client.py @@ -2,6 +2,7 @@ from __future__ import unicode_literals import abc import os +import types from argparse import Namespace from collections import OrderedDict from enum import Enum @@ -520,9 +521,11 @@ class APIClient(object): models = None # type: Any projects = None # type: Any - def __init__(self, session=None, api_version=None): + def __init__(self, session=None, api_version=None, **kwargs): self.session = session or StrictSession() + _api_services = kwargs.pop("api_services", api_services) + def import_(*args, **kwargs): try: return import_module(*args, **kwargs) @@ -536,16 +539,20 @@ class APIClient(object): for name, mod in ( ( name, - import_(".".join((api_services.__name__, api_version, name))), + import_(".".join((_api_services.__name__, api_version, name))), ) - for name in api_services.__all__ + for name in _api_services.__all__ ) if mod ) else: services = OrderedDict( - (name, getattr(api_services, name)) for name in api_services.__all__ + (name, getattr(_api_services, name)) for name in _api_services.__all__ ) + self._update_services(services) + + def _update_services(self, services): + # type: (Dict[str, types.ModuleType]) -> () self.__dict__.update( dict( { @@ -553,4 +560,4 @@ class APIClient(object): for name, module in services.items() }, ) - ) + ) \ No newline at end of file