#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2020 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools

import numpy as np

from ... import opcodes as OperandDef
from ...serialize import KeyField, StringField
from ...utils import check_chunks_unknown_shape
from ...tiles import TilesError
from ..utils import unify_chunks, broadcast_shape
from ..operands import TensorOperand, TensorOperandMixin
from ..datasource import tensor as astensor
from .broadcast_to import broadcast_to
from ..array_utils import as_same_device, device


class TensorCopyTo(TensorOperand, TensorOperandMixin):
    """Operand backing :func:`copyto`.

    Its inputs are ``[src, dst]`` or ``[src, dst, where]``. The single output
    takes the shape and order of ``dst``, and its dtype / gpu / sparse flags
    are copied from ``dst`` when the operand is called.
    """

    _op_type_ = OperandDef.COPYTO

    _src = KeyField('src')
    _dst = KeyField('dest')
    _casting = StringField('casting')
    _where = KeyField('where')

    def __init__(self, casting=None, dtype=None, gpu=None, sparse=None, **kw):
        super().__init__(_casting=casting, _dtype=dtype, _gpu=gpu,
                         _sparse=sparse, **kw)

    @property
    def src(self):
        return self._src

    @property
    def dst(self):
        return self._dst

    @property
    def casting(self):
        return self._casting

    @property
    def where(self):
        return self._where

    def check_inputs(self, inputs):
        """Validate that there are two (src, dst) or three (+ where) inputs."""
        if not 2 <= len(inputs) <= 3:
            raise ValueError("inputs' length must be 2 or 3")

    def _set_inputs(self, inputs):
        super()._set_inputs(inputs)
        self._src = self._inputs[0]
        self._dst = self._inputs[1]
        if len(self._inputs) > 2:
            self._where = self._inputs[2]

    @staticmethod
    def _extract_inputs(inputs):
        """Split positional call arguments into ``(src, dst, where)``.

        ``where`` is ``None`` when it is not given or is literally ``True``
        (meaning "copy every element"); otherwise it is converted with
        ``astensor``.
        """
        if len(inputs) == 2:
            (src, dst), where = inputs, None
        else:
            src, dst, where = inputs
            if where is True:
                where = None
            else:
                where = astensor(where)

        return src, dst, where

    def __call__(self, *inputs):
        """Build the copy node and rebind ``dst`` to its output in place.

        Raises
        ------
        TypeError
            If ``dst`` is not a Tensor, or if ``src``'s dtype cannot be cast
            to ``dst``'s dtype under ``self.casting``.
        ValueError
            If ``src`` or ``where`` cannot be broadcast to ``dst.shape``.
        """
        from ..core import Tensor

        src, dst, where = self._extract_inputs(inputs)

        if not isinstance(dst, Tensor):
            raise TypeError('dst has to be a Tensor')

        # ``copyto`` documents ``src`` as array_like, so convert lists,
        # scalars and ndarrays the same way ``where`` is converted;
        # otherwise ``src.dtype`` below fails for non-tensor sources.
        src = astensor(src)

        # The result replaces ``dst``, so it inherits dst's metadata.
        self._dtype = dst.dtype
        self._gpu = dst.op.gpu
        self._sparse = dst.issparse()

        if not np.can_cast(src.dtype, dst.dtype, casting=self.casting):
            raise TypeError(f'Cannot cast array from {src.dtype!r} to {dst.dtype!r} '
                            f'according to the rule {self.casting!s}')

        # broadcast_to is only used here to validate shapes; the result
        # is discarded.
        try:
            broadcast_to(src, dst.shape)
        except ValueError:
            raise ValueError('could not broadcast input array '
                             f'from shape {src.shape!r} into shape {dst.shape!r}')

        inps = [src, dst]
        # ``where`` is a tensor at this point, so test for presence
        # explicitly instead of evaluating the tensor's truth value.
        if where is not None:
            try:
                broadcast_to(where, dst.shape)
            except ValueError:
                raise ValueError('could not broadcast where mask '
                                 f'from shape {where.shape!r} into shape {dst.shape!r}')
            inps.append(where)

        ret = self.new_tensor(inps, dst.shape, order=dst.order)
        # Rebind dst to the new node so the copy is observed through dst.
        dst.data = ret.data

    @classmethod
    def tile(cls, op):
        """Tile the copy into one output chunk per broadcast chunk index.

        Inputs are first passed through ``unify_chunks`` with their axes
        listed in reverse order (presumably to align them from the trailing
        axis, broadcast-style -- confirm against ``unify_chunks``). An input
        that has a single chunk along an axis reuses chunk index 0 on that
        axis. Each output chunk takes the shape of the matching ``dst`` chunk.
        """
        check_chunks_unknown_shape(op.inputs, TilesError)
        inputs = unify_chunks(*[(inp, list(range(inp.ndim))[::-1])
                                for inp in op.inputs])
        output = op.outputs[0]

        chunk_shapes = [t.chunk_shape if hasattr(t, 'chunk_shape') else t
                        for t in inputs]
        out_chunk_shape = broadcast_shape(*chunk_shapes)

        out_chunks = []
        nsplits = [[None] * shape for shape in out_chunk_shape]

        def get_index(idx, t):
            # Map an output chunk index onto t's chunk grid: axes with a
            # single chunk always use chunk 0.
            return tuple(0 if t.nsplits[i] == (1,) else ix
                         for i, ix in enumerate(idx))

        for out_idx in itertools.product(*(map(range, out_chunk_shape))):
            in_chunks = [t.cix[get_index(out_idx[-t.ndim:], t)]
                         if t.ndim != 0 else t.chunks[0]
                         for t in inputs]
            out_chunk = op.copy().reset_key().new_chunk(
                in_chunks, shape=in_chunks[1].shape, order=output.order,
                index=out_idx)
            out_chunks.append(out_chunk)
            for i, idx, s in zip(itertools.count(0), out_idx, out_chunk.shape):
                nsplits[i][idx] = s

        new_op = op.copy()
        return new_op.new_tensors(op.inputs, output.shape, order=output.order,
                                  chunks=out_chunks, nsplits=nsplits)

    @classmethod
    def execute(cls, ctx, op):
        """Run ``xp.copyto`` on one chunk and store the result in ``ctx``."""
        inputs, device_id, xp = as_same_device(
            [ctx[c.key] for c in op.inputs], device=op.device, ret_extra=True)

        with device(device_id):
            # Copy into a fresh buffer so the dst chunk held in ctx is
            # left unmodified.
            dst = inputs[1].copy()
            src = inputs[0]
            where = inputs[2] if len(inputs) > 2 else True

            xp.copyto(dst, src, casting=op.casting, where=where)

            ctx[op.outputs[0].key] = dst


def copyto(dst, src, casting='same_kind', where=True):
    """
    Copies values from one array to another, broadcasting as necessary.

    Raises a TypeError if the `casting` rule is violated. If `where` is
    provided, it selects which elements to copy.

    `dst` is updated in place; the function returns None.

    Parameters
    ----------
    dst : Tensor
        The tensor into which values are copied.
    src : array_like
        The tensor from which values are copied.
    casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
        Controls what kind of data casting may occur when copying.

          * 'no' means the data types should not be cast at all.
          * 'equiv' means only byte-order changes are allowed.
          * 'safe' means only casts which can preserve values are allowed.
          * 'same_kind' means only safe casts or casts within a kind,
            like float64 to float32, are allowed.
          * 'unsafe' means any data conversions may be done.
    where : array_like of bool, optional
        A boolean tensor which is broadcast to match the dimensions of
        `dst`. Elements are copied from `src` to `dst` wherever it contains
        the value True.
    """
    op = TensorCopyTo(casting=casting)
    return op(src, dst, where)