# Copyright 1999-2020 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import operator
import os
import tempfile
from collections.abc import Iterable

import numpy as np

from ... import opcodes as OperandDef
from ...serialize import AnyField, BoolField, StringField, TupleField, SliceField
from ..array_utils import device, as_same_device
from ..utils import validate_axis, unify_chunks
from ..datasource import tensor as astensor
from ..operands import TensorOperand, TensorOperandMixin
from ..indexing.slice import TensorSlice

def _get_index(chunk):
    """Return the chunk index of ``chunk``.

    If ``chunk`` has no ``index`` attribute but was produced by a
    ``TensorSlice`` op, the index of that op's first input is returned instead.

    Raises
    ------
    AttributeError
        If ``chunk`` has no ``index`` and its op is not a ``TensorSlice``.
    """
    try:
        return chunk.index
    except AttributeError:
        if isinstance(chunk.op, TensorSlice):
            return chunk.inputs[0].index
        # Do not fall through and return None: callers slice the result
        # (``t[0][:-1]``) and would fail later with a far less clear error.
        raise

def _norm_axis(axis):
    """Normalize a concatenation axis specification.

    Returns an ``(axis, single_axis)`` pair:

    * an integer (Python ``int`` or NumPy integer) -> ``(axis, True)``;
    * an iterable of axes -> its sorted list; a one-element iterable collapses
      to ``(that_axis, True)``, otherwise ``(sorted_list, False)``;
    * ``None`` -> ``(None, False)``.

    Raises
    ------
    TypeError
        If ``axis`` is none of the above.  (This used to be an ``assert``,
        which is stripped under ``python -O`` and then silently returned
        ``(None, False)`` for bad input.)
    """
    if isinstance(axis, (int, np.integer)):
        return axis, True
    if isinstance(axis, Iterable):
        axes = sorted(axis)
        if len(axes) == 1:
            return axes[0], True
        return axes, False
    if axis is None:
        return None, False
    raise TypeError('axis must be an int, an iterable of ints or None, '
                    'got {!r}'.format(axis))

class TensorConcatenate(TensorOperand, TensorOperandMixin):
    """Operand that joins a sequence of tensors along an existing axis.

    ``execute`` dispatches between two paths:

    * the regular in-memory concatenation (``_execute``);
    * a memory-mapped path (``_execute_with_mmap``), selected by ``mmap``.
      There, an op with ``create_mmap_file`` set creates a backing file of
      ``total_shape``. Otherwise the op writes its input array into
      ``partition_slice`` of that file.
    """

    _op_type_ = OperandDef.CONCATENATE

    # axis specification; _norm_axis accepts an int, an iterable of ints or None
    _axis = AnyField('axis')

    # for mmap
    _mmap = BoolField('mmap')
    _file_prefix = StringField('file_prefix')
    _create_mmap_file = BoolField('create_mmap_file')
    _partition_slice = SliceField('partition_slice')
    _total_shape = TupleField('total_shape')

    def __init__(self, axis=None, dtype=None, mmap=None, file_prefix=None, create_mmap_file=None,
                 partition_slice=None, total_shape=None, sparse=False, gpu=None, **kw):
        # Every mmap-related argument must reach its serializable field.
        # Previously, file_prefix / create_mmap_file / partition_slice /
        # total_shape were accepted here but silently dropped.
        super().__init__(_axis=axis, _dtype=dtype, _mmap=mmap, _file_prefix=file_prefix,
                         _create_mmap_file=create_mmap_file, _partition_slice=partition_slice,
                         _total_shape=total_shape, _gpu=gpu, _sparse=sparse, **kw)

    @property
    def axis(self):
        """Axis specification to concatenate along, or ``None`` if unset."""
        return getattr(self, '_axis', None)

    @property
    def mmap(self):
        """Whether execution uses the memory-mapped path."""
        return self._mmap

    @property
    def file_prefix(self):
        """Prefix for the temporary file created by the mmap path."""
        return self._file_prefix

    @property
    def create_mmap_file(self):
        """If true, the mmap path creates the backing file instead of writing into it."""
        return self._create_mmap_file

    @property
    def partition_slice(self):
        """Region of the mmap file that this op's input array is written to."""
        return self._partition_slice

    @property
    def total_shape(self):
        """Shape of the whole memory-mapped array."""
        return self._total_shape

    def __call__(self, tensors):
        """Build the output tensor for ``tensors``.

        Assumes ``self._axis`` is a single, already-validated int here.

        Raises
        ------
        ValueError
            If the inputs differ in ``ndim``, differ in any dimension other
            than the concatenation axis, or have an unknown (NaN) size outside
            that axis.
        """
        if len(set(t.ndim for t in tensors)) != 1:
            raise ValueError('all the input tensors must have same number of dimensions')

        axis = self._axis
        shapes = [t.shape[:axis] + t.shape[axis + 1:] for t in tensors]
        if len(set(shapes)) != 1:
            raise ValueError('all the input tensor dimensions '
                             'except for the concatenation axis must match exactly')

        shape = [0 if i == axis else tensors[0].shape[i] for i in range(tensors[0].ndim)]
        shape[axis] = sum(t.shape[axis] for t in tensors)

        # NaN marks an unknown size; it is tolerated only along the concat axis.
        if any(np.isnan(s) for i, s in enumerate(shape) if i != axis):
            raise ValueError('cannot concatenate tensor with unknown shape')

        return self.new_tensor(tensors, shape=tuple(shape))

    @classmethod
    def tile(cls, op):
        """Tile the concatenation into output chunks.

        The inputs are first unified on every axis except the concatenation
        axis. The output chunks along ``axis`` are then the inputs' chunks
        laid end to end, re-indexed to their position in the output.
        """
        inputs = op.inputs
        output = op.outputs[0]
        axis = op.axis

        # Give each input's concat axis a fresh, unique label (>= ndim) so
        # that unify_chunks presumably leaves that axis alone and aligns only
        # the shared axes 0..ndim-1 across inputs.
        c = itertools.count(inputs[0].ndim)
        tensor_axes = [(t, tuple(i if i != axis else next(c) for i in range(t.ndim)))
                       for t in inputs]
        inputs = unify_chunks(*tensor_axes)

        out_chunk_shape = [0 if i == axis else inputs[0].chunk_shape[i]
                           for i in range(inputs[0].ndim)]
        out_chunk_shape[axis] = sum(t.chunk_shape[axis] for t in inputs)
        out_nsplits = [None if i == axis else inputs[0].nsplits[i]
                       for i in range(inputs[0].ndim)]
        out_nsplits[axis] = tuple(itertools.chain.from_iterable(t.nsplits[axis] for t in inputs))

        out_chunks = []
        # Running total of chunk counts along `axis`. It maps an output chunk
        # position back to (which input, position inside that input).
        axis_cum_chunk_shape = np.cumsum([t.chunk_shape[axis] for t in inputs])
        for out_idx in itertools.product(*[range(s) for s in out_chunk_shape]):
            axis_index = np.searchsorted(axis_cum_chunk_shape, out_idx[axis], side='right')
            t = inputs[axis_index]
            axis_inner_index = out_idx[axis] - \
                (0 if axis_index < 1 else axis_cum_chunk_shape[axis_index - 1])
            idx = out_idx[:axis] + (axis_inner_index,) + out_idx[axis + 1:]
            in_chunk = t.cix[idx]
            if idx == out_idx:
                # if index is the same, just use the input chunk
                out_chunks.append(in_chunk)
            else:
                # Otherwise wrap it in a full (no-op) slice so the new chunk
                # can carry its position within the output tensor.
                chunk_op = TensorSlice(slices=[slice(None) for _ in range(in_chunk.ndim)],
                                       dtype=in_chunk.dtype, sparse=in_chunk.op.sparse)
                out_chunk = chunk_op.new_chunk([in_chunk], shape=in_chunk.shape,
                                               index=out_idx, order=output.order)
                out_chunks.append(out_chunk)

        new_op = op.copy()
        return new_op.new_tensors(op.inputs, output.shape, order=output.order,
                                  nsplits=out_nsplits, chunks=out_chunks)

    @staticmethod
    def _ensure_order(result, order):
        """Return ``result`` laid out in ``order.value`` without copying when it already is.

        ``order.value`` is presumably a NumPy order code such as 'C' or 'F'.
        """
        return result.astype(result.dtype, order=order.value, copy=False)

    @classmethod
    def execute(cls, ctx, op):
        """Run ``op``, storing its result in ``ctx`` under the output chunk key."""
        if op.mmap:  # pragma: no cover
            cls._execute_with_mmap(ctx, op)
        else:
            cls._execute(ctx, op)

    @classmethod
    def _execute(cls, ctx, op):
        """In-memory concatenation of the input chunks' data."""
        def _base_concatenate(chunk, datas):
            datas, device_id, xp = as_same_device(datas, device=chunk.op.device, ret_extra=True)

            axis, single_axis = _norm_axis(chunk.op.axis)
            if single_axis:
                with device(device_id):
                    res = xp.concatenate(tuple(datas), axis=axis)
            else:
                # Multi-axis merge: repeatedly group chunks that share every
                # index component but the last, and join each group, until
                # one level is left.
                # NOTE(review): itertools.groupby only merges *adjacent* items,
                # so this relies on chunk.inputs being in index order
                # (presumably row-major). Also, the per-step axis
                # `len(axes) - i - 1` equals axes[len(axes) - i - 1] only when
                # axes == range(ndim). Confirm the behavior for partial axis
                # tuples.
                axes = axis or list(range(chunk.ndim))
                chunks = [(_get_index(inp), data) for inp, data in zip(chunk.inputs, datas)]
                with device(device_id):
                    for i in range(len(axes) - 1):
                        new_chunks = []
                        for idx, cs in itertools.groupby(chunks, key=lambda t: t[0][:-1]):
                            cs = list(map(operator.itemgetter(1), cs))
                            new_chunks.append((idx, xp.concatenate(cs, axis=len(axes) - i - 1)))
                        chunks = new_chunks
                    res = xp.concatenate(list(map(operator.itemgetter(1), chunks)), axis=axes[0])
            return res

        chunk = op.outputs[0]
        inputs = [ctx[inp.key] for inp in op.inputs]

        if isinstance(inputs[0], tuple):
            # Each input holds a tuple of arrays; concatenate position-wise.
            ctx[chunk.key] = \
                tuple(cls._ensure_order(_base_concatenate(chunk, [inp[i] for inp in inputs]),
                                        chunk.order)
                      for i in range(len(inputs[0])))
        else:
            ctx[chunk.key] = cls._ensure_order(_base_concatenate(chunk, inputs), chunk.order)

    @classmethod
    def _execute_with_mmap(cls, ctx, op):  # pragma: no cover
        """Memory-mapped execution: either create the backing file or write one partition.

        Create mode stores the new file path in ``ctx``. Write mode reads the
        path from input 0 and the array from input 1, writes the array into
        ``partition_slice``, and passes the path through.
        """
        if op.create_mmap_file:
            fd, path = tempfile.mkstemp(prefix=op.file_prefix, suffix='.dat')
            # mkstemp returns an open OS-level descriptor. Only the path is
            # needed, so close the descriptor right away instead of leaking it.
            os.close(fd)
            # Opening in 'w+' mode with an explicit shape sizes the file.
            np.memmap(path, dtype=op.dtype, mode='w+', shape=op.total_shape)
            ctx[op.outputs[0].key] = path
        else:
            path = ctx[op.inputs[0].key]
            array = ctx[op.inputs[1].key]
            fp = np.memmap(path, dtype=op.dtype, mode='r+', shape=op.total_shape)
            fp[op.partition_slice] = array
            # Push the written partition to the file before handing the path on.
            fp.flush()
            ctx[op.outputs[0].key] = path

def concatenate(tensors, axis=0):
    """
    Join a sequence of arrays along an existing axis.

    Parameters
    ----------
    tensors : sequence of array_like
        The tensors must have the same shape, except in the dimension
        corresponding to `axis` (the first, by default).
    axis : int, optional
        The axis along which the tensors will be joined.  Default is 0.
        ``None`` is also treated as 0 (unlike NumPy, which flattens the
        inputs when ``axis=None``).

    Returns
    -------
    res : Tensor
        The concatenated tensor.

    Raises
    ------
    ValueError
        Raised in these cases:

        * `tensors` is empty;
        * the inputs differ in number of dimensions;
        * the inputs differ in any dimension other than `axis`.

    See Also
    --------
    array_split : Split a tensor into multiple sub-arrays of equal or
                  near-equal size.
    split : Split tensor into a list of multiple sub-tensors of equal size.
    hsplit : Split tensor into multiple sub-tensors horizontally (column wise)
    vsplit : Split tensor into multiple sub-tensors vertically (row wise)
    dsplit : Split tensor into multiple sub-tensors along the 3rd axis (depth).
    stack : Stack a sequence of tensors along a new axis.
    hstack : Stack tensors in sequence horizontally (column wise)
    vstack : Stack tensors in sequence vertically (row wise)
    dstack : Stack tensors in sequence depth wise (along third dimension)

    Examples
    --------
    >>> import mars.tensor as mt

    >>> a = mt.array([[1, 2], [3, 4]])
    >>> b = mt.array([[5, 6]])
    >>> mt.concatenate((a, b), axis=0).execute()
    array([[1, 2],
           [3, 4],
           [5, 6]])
    >>> mt.concatenate((a, b.T), axis=1).execute()
    array([[1, 2, 5],
           [3, 4, 6]])
    """
    if axis is None:
        axis = 0
    tensors = [astensor(t) for t in tensors]
    if not tensors:
        # Match NumPy instead of failing with an IndexError on tensors[0].
        raise ValueError('need at least one array to concatenate')

    axis = validate_axis(tensors[0].ndim, axis)
    dtype = np.result_type(*(t.dtype for t in tensors))
    sparse = all(t.issparse() for t in tensors)

    op = TensorConcatenate(axis=axis, dtype=dtype, sparse=sparse)
    return op(tensors)