diff --git a/examples/k8s_glue_example.py b/examples/k8s_glue_example.py
index 6c1b6bf..c528619 100644
--- a/examples/k8s_glue_example.py
+++ b/examples/k8s_glue_example.py
@@ -21,12 +21,24 @@ def parse_args():
         "--num-of-services", type=int, default=20,
         help="Specify the number of k8s services to be used. Use only with ports-mode."
     )
+    parser.add_argument(
+        "--base-port", type=int,
+        help="If using ports-mode, specifies the base port exposed by the services. "
+             "For pod #X, the port will be base-port + X"
+    )
     return parser.parse_args()
 
 
 def main():
     args = parse_args()
-    k8s = K8sIntegration(ports_mode=args.ports_mode, num_of_services=args.num_of_services)
+
+    user_props_cb = None
+    if args.ports_mode and args.base_port:
+        def user_props_cb(pod_number):
+            """Report the pod's port as a user property: the base port offset by the pod number."""
+            return {"k8s-pod-port": args.base_port + pod_number}
+
+    k8s = K8sIntegration(ports_mode=args.ports_mode, num_of_services=args.num_of_services, user_props_cb=user_props_cb)
     k8s.k8s_daemon(args.queue)
 
 
diff --git a/trains_agent/glue/k8s.py b/trains_agent/glue/k8s.py
index 4feda2a..25aefd3 100644
--- a/trains_agent/glue/k8s.py
+++ b/trains_agent/glue/k8s.py
@@ -54,6 +54,7 @@ class K8sIntegration(Worker):
             debug=False,
             ports_mode=False,
             num_of_services=20,
+            user_props_cb=None,
     ):
         """
         Initialize the k8s integration glue layer daemon
@@ -68,6 +69,9 @@ class K8sIntegration(Worker):
             Requires the `num_of_services` parameter.
         :param int num_of_services: Number of k8s services configured in the cluster. Required if `port_mode` is True.
             (default: 20)
+        :param callable user_props_cb: An optional callable allowing additional user properties to be specified
+            when scheduling a task to run in a pod. Callable can receive an optional pod number and should return
+            a dictionary of user properties (name and value). Signature is Callable[[Optional[int]], Dict[str, str]]
         """
         super(K8sIntegration, self).__init__()
         self.k8s_pending_queue_name = k8s_pending_queue_name or self.K8S_PENDING_QUEUE
@@ -82,6 +86,7 @@ class K8sIntegration(Worker):
         self.ports_mode = ports_mode
         self.num_of_services = num_of_services
         self._edit_hyperparams_support = None
+        self._user_props_cb = user_props_cb
 
     def _set_task_user_properties(self, task_id: str, **properties: str):
         if self._edit_hyperparams_support is not True:
@@ -200,9 +205,18 @@ class K8sIntegration(Worker):
             self.log.error("Running kubectl encountered an error: {}".format(
                 error if isinstance(error, str) else error.decode()))
         elif self.ports_mode:
+            user_props = {"k8s-pod-number": pod_number, "k8s-pod-label": labels[0]}
+            if self._user_props_cb:
+                # best-effort: a failing user callback must not prevent setting the built-in pod properties
+                # noinspection PyBroadException
+                try:
+                    custom_props = self._user_props_cb(pod_number)
+                    user_props.update(custom_props)
+                except Exception as ex:
+                    self.log.error("Failed retrieving custom user properties: {}".format(ex))
             self._set_task_user_properties(
                 task_id=task_id,
-                **{"k8s-pod-number": pod_number, "k8s-pod-label": labels[0]}
+                **user_props,
             )
 
     def run_tasks_loop(self, queues: List[Text], worker_params, **kwargs):