# Copyright 1999-2020 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import operator
import tempfile
from collections.abc import Iterable
import numpy as np
from ... import opcodes as OperandDef
from ...serialize import AnyField, BoolField, StringField, TupleField, SliceField
from ..array_utils import device, as_same_device
from ..utils import validate_axis, unify_chunks
from ..datasource import tensor as astensor
from ..operands import TensorOperand, TensorOperandMixin
from ..indexing.slice import TensorSlice
def _get_index(chunk):
try:
return chunk.index
except AttributeError:
if isinstance(chunk.op, TensorSlice):
return chunk.inputs[0].index
raise
def _norm_axis(axis):
if isinstance(axis, int):
return axis, True
if isinstance(axis, Iterable):
axis = sorted(tuple(axis))
if len(axis) == 1:
return axis[0], True
return axis, False
assert axis is None
return None, False
class TensorConcatenate(TensorOperand, TensorOperandMixin):
    """
    Operand that joins its input tensors along ``axis``.

    Besides the regular in-memory execution path, the operand has an
    mmap-backed path (selected when ``mmap`` is truthy) that either creates a
    backing file or writes one partition into an existing backing file.
    """
    _op_type_ = OperandDef.CONCATENATE

    # axis (or axes) to concatenate along
    _axis = AnyField('axis')
    # for mmap
    # whether to execute through the mmap-backed path
    _mmap = BoolField('mmap')
    # prefix passed to ``tempfile.mkstemp`` when creating the backing file
    _file_prefix = StringField('file_prefix')
    # True: create the backing file; False: write a partition into it
    _create_mmap_file = BoolField('create_mmap_file')
    # region of the memory-mapped array that this operand fills
    _partition_slice = SliceField('partition_slice')
    # shape of the whole memory-mapped array
    _total_shape = TupleField('total_shape')

    def __init__(self, axis=None, dtype=None, mmap=None, file_prefix=None, create_mmap_file=None,
                 partition_slice=None, total_shape=None, sparse=False, gpu=None, **kw):
        super().__init__(_axis=axis, _dtype=dtype, _mmap=mmap,
                         _file_prefix=file_prefix,
                         _create_mmap_file=create_mmap_file,
                         _partition_slice=partition_slice,
                         _total_shape=total_shape,
                         _gpu=gpu, _sparse=sparse, **kw)

    @property
    def axis(self):
        # falls back to None when the field has not been set
        return getattr(self, '_axis', None)

    @property
    def mmap(self):
        return self._mmap

    @property
    def file_prefix(self):
        return self._file_prefix

    @property
    def create_mmap_file(self):
        return self._create_mmap_file

    @property
    def partition_slice(self):
        return self._partition_slice

    @property
    def total_shape(self):
        return self._total_shape

    def __call__(self, tensors):
        """
        Build the output tensor that concatenates ``tensors`` along ``self._axis``.

        Raises
        ------
        ValueError
            If the inputs differ in number of dimensions, if their shapes
            differ anywhere but on the concatenation axis, or if a
            non-concatenation dimension of the result is NaN (unknown).
        """
        if len(set(t.ndim for t in tensors)) != 1:
            raise ValueError('all the input tensors must have same number of dimensions')

        axis = self._axis
        # NOTE(review): the slicing below assumes `axis` is a non-negative int;
        # `concatenate` passes the result of `validate_axis` -- confirm it
        # normalizes negative axes.
        shapes = [t.shape[:axis] + t.shape[axis + 1:] for t in tensors]
        if len(set(shapes)) != 1:
            raise ValueError('all the input tensor dimensions '
                             'except for the concatenation axis must match exactly')

        shape = [0 if i == axis else tensors[0].shape[i] for i in range(tensors[0].ndim)]
        shape[axis] = sum(t.shape[axis] for t in tensors)

        # only the concatenation axis itself may have an unknown (NaN) length
        if any(np.isnan(s) for i, s in enumerate(shape) if i != axis):
            raise ValueError('cannot concatenate tensor with unknown shape')

        return self.new_tensor(tensors, shape=tuple(shape))

    @classmethod
    def tile(cls, op):
        """
        Lay out the chunks of the concatenated tensor.

        Input chunks whose index is unchanged in the output are reused as-is;
        the others are wrapped in a full-range ``TensorSlice`` chunk that
        carries the new output index.
        """
        from ..indexing.slice import TensorSlice

        inputs = op.inputs
        output = op.outputs[0]
        axis = op.axis

        # Every non-concatenation dimension is labelled with its own position
        # (shared by all inputs) while the concatenation dimension of each
        # input gets a fresh label from the counter -- presumably so that
        # unify_chunks aligns only the shared dimensions; confirm against
        # unify_chunks.
        c = itertools.count(inputs[0].ndim)
        tensor_axes = [(t, tuple(i if i != axis else next(c) for i in range(t.ndim)))
                       for t in inputs]
        inputs = unify_chunks(*tensor_axes)

        out_chunk_shape = [0 if i == axis else inputs[0].chunk_shape[i]
                           for i in range(inputs[0].ndim)]
        out_chunk_shape[axis] = sum(t.chunk_shape[axis] for t in inputs)
        out_nsplits = [None if i == axis else inputs[0].nsplits[i]
                       for i in range(inputs[0].ndim)]
        out_nsplits[axis] = tuple(itertools.chain(*[t.nsplits[axis] for t in inputs]))

        out_chunks = []
        # cumulative chunk counts along `axis`, used to map an output chunk
        # position back to (input tensor, position inside that tensor)
        axis_cum_chunk_shape = np.cumsum([t.chunk_shape[axis] for t in inputs])
        for out_idx in itertools.product(*[range(s) for s in out_chunk_shape]):
            axis_index = np.searchsorted(axis_cum_chunk_shape, out_idx[axis], side='right')
            t = inputs[axis_index]
            axis_inner_index = out_idx[axis] - \
                (0 if axis_index < 1 else axis_cum_chunk_shape[axis_index - 1])
            idx = out_idx[:axis] + (axis_inner_index,) + out_idx[axis + 1:]
            in_chunk = t.cix[idx]
            if idx == out_idx:
                # if index is the same, just use the input chunk
                out_chunks.append(in_chunk)
            else:
                # full-range slice: same data, but carrying the output index
                chunk_op = TensorSlice(slices=[slice(None) for _ in range(in_chunk.ndim)],
                                       dtype=in_chunk.dtype, sparse=in_chunk.op.sparse)
                out_chunk = chunk_op.new_chunk([in_chunk], shape=in_chunk.shape,
                                               index=out_idx, order=output.order)
                out_chunks.append(out_chunk)

        new_op = op.copy()
        return new_op.new_tensors(op.inputs, output.shape, order=output.order,
                                  nsplits=out_nsplits, chunks=out_chunks)

    @staticmethod
    def _ensure_order(result, order):
        # `copy=False` lets astype skip the copy when no conversion is needed
        return result.astype(result.dtype, order=order.value, copy=False)

    @classmethod
    def execute(cls, ctx, op):
        """Dispatch to the mmap-backed or the in-memory implementation."""
        if op.mmap:  # pragma: no cover
            cls._execute_with_mmap(ctx, op)
        else:
            cls._execute(ctx, op)

    @classmethod
    def _execute(cls, ctx, op):
        """
        Concatenate the input data found in ``ctx`` and store the result under
        the output chunk's key.

        When the stored inputs are tuples, the concatenation is applied
        position-wise and a tuple of results is stored.
        """
        def _base_concatenate(chunk, inputs):
            inputs, device_id, xp = as_same_device(inputs, device=chunk.op.device, ret_extra=True)

            axis, single_axis = _norm_axis(chunk.op.axis)
            if single_axis:
                with device(device_id):
                    res = xp.concatenate(tuple(inputs), axis=axis)
            else:
                # several axes (or None, meaning every axis): merge the pieces
                # one axis at a time, grouping pieces whose index agrees on
                # all but the last remaining position
                axes = axis or list(range(chunk.ndim))
                chunks = [(_get_index(inp), data) for inp, data in zip(chunk.inputs, inputs)]
                with device(device_id):
                    for i in range(len(axes) - 1):
                        new_chunks = []
                        for idx, cs in itertools.groupby(chunks, key=lambda t: t[0][:-1]):
                            cs = list(map(operator.itemgetter(1), cs))
                            new_chunks.append((idx, xp.concatenate(cs, axis=len(axes) - i - 1)))
                        chunks = new_chunks
                    res = xp.concatenate(list(map(operator.itemgetter(1), chunks)), axis=axes[0])
            return res

        chunk = op.outputs[0]
        inputs = [ctx[inp.key] for inp in op.inputs]

        if isinstance(inputs[0], tuple):
            ctx[chunk.key] = \
                tuple(cls._ensure_order(_base_concatenate(chunk, [inp[i] for inp in inputs]), chunk.order)
                      for i in range(len(inputs[0])))
        else:
            ctx[chunk.key] = cls._ensure_order(_base_concatenate(chunk, inputs), chunk.order)

    @classmethod
    def _execute_with_mmap(cls, ctx, op):  # pragma: no cover
        """
        mmap-backed execution; the stored result is always the backing file path.

        * ``op.create_mmap_file`` truthy: create a temporary ``.dat`` file
          memory-mapped with ``op.total_shape``.
        * otherwise: take the path from the first input and the data from the
          second input, and write the data into ``op.partition_slice`` of the
          memory-mapped array.
        """
        import os

        if op.create_mmap_file:
            # mkstemp returns an already-open OS-level file descriptor along
            # with the path; only the path is needed here, so close the
            # descriptor right away instead of leaking it.
            fd, path = tempfile.mkstemp(prefix=op.file_prefix, suffix='.dat')
            os.close(fd)
            # opening with mode 'w+' creates/overwrites the file with the full shape
            np.memmap(path, dtype=op.dtype, mode='w+', shape=op.total_shape)
            ctx[op.outputs[0].key] = path
        else:
            path = ctx[op.inputs[0].key]
            array = ctx[op.inputs[1].key]
            fp = np.memmap(path, dtype=op.dtype, mode='r+', shape=op.total_shape)
            fp[op.partition_slice] = array
            ctx[op.outputs[0].key] = path
def concatenate(tensors, axis=0):
    """
    Join a sequence of arrays along an existing axis.

    Parameters
    ----------
    tensors : sequence of array_like
        The tensors must have the same shape, except in the dimension
        corresponding to `axis` (the first, by default).
    axis : int, optional
        The axis along which the tensors will be joined. Default is 0.
        ``None`` is treated as 0.

    Returns
    -------
    res : Tensor
        The concatenated tensor.

    Raises
    ------
    ValueError
        If `tensors` is empty.

    See Also
    --------
    array_split : Split a tensor into multiple sub-arrays of equal or
                  near-equal size.
    split : Split tensor into a list of multiple sub-tensors of equal size.
    hsplit : Split tensor into multiple sub-tensors horizontally (column wise)
    vsplit : Split tensor into multiple sub-tensors vertically (row wise)
    dsplit : Split tensor into multiple sub-tensors along the 3rd axis (depth).
    stack : Stack a sequence of tensors along a new axis.
    hstack : Stack tensors in sequence horizontally (column wise)
    vstack : Stack tensors in sequence vertically (row wise)
    dstack : Stack tensors in sequence depth wise (along third dimension)

    Examples
    --------
    >>> import mars.tensor as mt

    >>> a = mt.array([[1, 2], [3, 4]])
    >>> b = mt.array([[5, 6]])
    >>> mt.concatenate((a, b), axis=0).execute()
    array([[1, 2],
           [3, 4],
           [5, 6]])
    >>> mt.concatenate((a, b.T), axis=1).execute()
    array([[1, 2, 5],
           [3, 4, 6]])
    """
    if axis is None:
        axis = 0
    tensors = [astensor(t) for t in tensors]
    if not tensors:
        # fail with a clear message instead of an IndexError on tensors[0]
        raise ValueError('need at least one tensor to concatenate')

    axis = validate_axis(tensors[0].ndim, axis)
    dtype = np.result_type(*(t.dtype for t in tensors))
    # the result is sparse only when every input is sparse
    sparse = all(t.issparse() for t in tensors)

    op = TensorConcatenate(axis=axis, dtype=dtype, sparse=sparse)
    return op(tensors)