Source code for mars.learn.contrib.tensorflow.run_script

# Copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import json
from typing import Any, BinaryIO, Dict, List, Optional, TextIO, Union

import numpy as np

from .... import opcodes as OperandDef
from ....core import recursive_tile
from ....core.context import get_context
from ....remote.run_script import RunScript, _extract_inputs
from ....serialization.serializables import (
    BytesField,
    Int32Field,
    DictField,
    StringField,
)
from ....typing import SessionType, TileableType
from ....utils import to_binary
from ...utils import collect_ports
from ..utils import pick_workers


class RunTensorFlow(RunScript):
    _op_type_ = OperandDef.RUN_TENSORFLOW

    _code = BytesField("code")
    _n_workers = Int32Field("n_workers")
    _n_ps = Int32Field("n_ps")
    _tf_config = DictField("tf_config")
    _port = Int32Field("port")
    # used for chunk op
    _tf_task_type = StringField("tf_task_type")
    _tf_task_index = Int32Field("tf_task_index")

    def __init__(
        self,
        n_workers=None,
        n_ps=None,
        tf_config=None,
        port=None,
        tf_task_type=None,
        tf_task_index=None,
        gpu=None,
        **kw,
    ):
        super().__init__(
            _n_workers=n_workers,
            _n_ps=n_ps,
            _tf_config=tf_config,
            _port=port,
            _tf_task_type=tf_task_type,
            _tf_task_index=tf_task_index,
            gpu=gpu,
            **kw,
        )

    @property
    def n_workers(self):
        return self._n_workers

    @property
    def n_ps(self):
        return self._n_ps or 0

    @property
    def n_roles(self):
        # use the properties so a missing n_ps defaults to 0
        return self.n_workers + self.n_ps

    @property
    def tf_config(self):
        return self._tf_config

    @property
    def port(self):
        return self._port

    @property
    def tf_task_type(self):
        return self._tf_task_type

    @property
    def tf_task_index(self):
        return self._tf_task_index

    @classmethod
    def tile(cls, op):
        ctx = get_context()

        cluster_conf = {"worker": []}
        if op.n_ps > 0:
            cluster_conf["ps"] = []
        n_workers = op.n_workers

        out_chunks = []
        worker_addresses = ctx.get_worker_addresses()
        picked_workers = pick_workers(worker_addresses, op.n_roles)
        data, input_chunks = cls._get_chunk_data(op)

        ports = yield from recursive_tile(collect_ports(worker_addresses))
        yield ports.chunks
        ports = ctx.get_chunks_result([ports.chunks[0].key])[0]

        for i, (worker, port) in enumerate(zip(picked_workers, ports)):
            worker_addr = worker.rsplit(":", 1)[0]
            chunk_op = op.copy().reset_key()
            chunk_op._data = data
            addr = f"{worker_addr}:{port}"
            # tell the graph actor that the chunk should be executed on this exact worker
            chunk_op.expect_worker = worker
            tp = "worker" if i < n_workers else "ps"
            chunk_op._tf_task_type = tp
            idx = i if i < n_workers else i - n_workers
            chunk_op._tf_task_index = idx
            cluster_conf[tp].append(addr)
            chunk_op._tf_config = {
                "cluster": cluster_conf,
                "task": {"type": tp, "index": idx},
            }
            out_chunks.append(chunk_op.new_chunk(input_chunks, index=(i,)))

        new_op = op.copy()
        return new_op.new_tileables(
            op.inputs,
            chunks=out_chunks,
            nsplits=(tuple(np.nan for _ in range(len(out_chunks))),),
        )
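
    # Illustration (not part of the original source): with n_workers=2 and
    # n_ps=1 and hypothetical worker addresses, the chunk acting as the second
    # worker would end up with a _tf_config dict roughly like:
    #
    #     {
    #         "cluster": {
    #             "worker": ["10.0.0.1:2222", "10.0.0.2:2222"],
    #             "ps": ["10.0.0.3:2222"],
    #         },
    #         "task": {"type": "worker", "index": 1},
    #     }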

    @classmethod
    def _build_envs(cls, ctx, op):
        envs = super()._build_envs(ctx, op)
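        # TF_CONFIG is the standard environment variable that TensorFlow's
        # distributed runtime (e.g. tf.distribute cluster resolvers) reads to
        # discover the cluster layout and the role of the current process.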
        envs.update({"TF_CONFIG": json.dumps(op.tf_config)})
        return envs

    @classmethod
    def execute(cls, ctx, op):
        if op.merge:
            return super().execute(ctx, op)

        assert ctx.local_address.split(":")[0] == op.expect_worker.split(":")[0]

        super().execute(ctx, op)

        if op.tf_task_type == "worker" and op.tf_task_index == 0:
            ctx[op.outputs[0].key] = {"status": "ok"}
        else:
            ctx[op.outputs[0].key] = {}


def run_tensorflow_script(
    script: Union[bytes, str, BinaryIO, TextIO],
    n_workers: int,
    n_ps: int = 0,
    data: Dict[str, TileableType] = None,
    gpu: Optional[bool] = None,
    command_argv: List[str] = None,
    retry_when_fail: bool = False,
    session: SessionType = None,
    run_kwargs: Dict[str, Any] = None,
):
    """
    Run a TensorFlow script in a Mars cluster.

    Parameters
    ----------
    script : str or file-like object
        Script to run.
    n_workers : int
        Number of TensorFlow workers.
    n_ps : int
        Number of TensorFlow PS (parameter server) workers.
    data : dict
        Mapping from variable name to data.
    gpu : bool
        Run the TensorFlow script on GPU.
    command_argv : list
        Extra command-line arguments for the script.
    retry_when_fail : bool
        If True, retry when the function fails.
    session
        Mars session; if not provided, the default one will be used.
    run_kwargs : dict
        Extra kwargs for `session.run`.

    Returns
    -------
    status
        {'status': 'ok'} if succeeded; otherwise the error is raised.
    """
    if int(n_workers) <= 0:
        raise ValueError("n_workers should be at least 1")
    if int(n_ps) < 0:
        raise ValueError("n_ps should be at least 0")
    if hasattr(script, "read"):
        code = script.read()
    else:
        with open(os.path.abspath(script), "rb") as f:
            code = f.read()

    inputs = _extract_inputs(data)
    op = RunTensorFlow(
        data=data,
        code=to_binary(code),
        n_workers=int(n_workers),
        n_ps=int(n_ps),
        retry_when_fail=retry_when_fail,
        gpu=gpu,
        command_args=command_argv,
    )
    return op(inputs).execute(session=session, **(run_kwargs or {}))
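
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module). It
# assumes a default Mars session and a local training script named
# "train.py"; the file name and the 2-worker / 1-PS layout are hypothetical.
#
#     import mars
#     from mars.learn.contrib.tensorflow import run_tensorflow_script
#
#     mars.new_session()
#     status = run_tensorflow_script(
#         "train.py", n_workers=2, n_ps=1, command_argv=["--epochs", "5"]
#     )
#     # status == {'status': 'ok'} when the script completes successfully
# ---------------------------------------------------------------------------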