# Source code for mars.core

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 1999-2020 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import builtins
import enum
import functools
import itertools
from concurrent.futures import ThreadPoolExecutor
from operator import attrgetter
from typing import List
from weakref import WeakKeyDictionary, WeakSet, ref

import numpy as np

from .serialize import HasKey, HasData, ValueType, ProviderType, Serializable, AttributeAsDict, \
    TupleField, ListField, DictField, KeyField, BoolField, StringField
from .tiles import Tileable, handler
from .utils import tokenize, AttributeDict, on_serialize_shape, \
    on_deserialize_shape, on_serialize_nsplits, enter_mode, is_build_mode


class Base(HasKey):
    """Base object identified by a content-derived key and a per-instance id.

    ``_key`` is a hash of the class name plus attribute values (equal objects
    share a key); ``_id`` identifies this particular instance and is never
    copied to clones.
    """
    __slots__ = ()
    _no_copy_attrs_ = {'_id'}
    _init_update_key_ = True

    def __init__(self, *args, **kwargs):
        # positional arguments map onto __slots__ in declaration order
        for name, value in zip(self.__slots__, args):
            object.__setattr__(self, name, value)
        for name, value in kwargs.items():
            object.__setattr__(self, name, value)

        if self._init_update_key_ and not getattr(self, '_key', None):
            self._update_key()
        if not getattr(self, '_id', None):
            self._id = str(id(self))

    @property
    def _keys_(self):
        # sorted slot names, cached on the concrete class under a
        # per-class-name attribute
        cls = type(self)
        cache_attr = '__keys_' + cls.__name__
        try:
            return getattr(cls, cache_attr)
        except AttributeError:
            names = sorted(self.__slots__)
            setattr(cls, cache_attr, names)
            return names

    @property
    def _values_(self):
        # attribute values participating in the key, in _keys_ order
        skip = self._no_copy_attrs_
        return [getattr(self, name, None)
                for name in self._keys_ if name not in skip]

    def __mars_tokenize__(self):
        # prefer the precomputed key; otherwise tokenize the raw values
        try:
            return self._key
        except AttributeError:
            return (type(self), *self._values_)

    def _obj_set(self, k, v):
        object.__setattr__(self, k, v)

    def _update_key(self):
        self._obj_set('_key', tokenize(type(self).__name__, *self._values_))
        return self

    def reset_key(self):
        self._obj_set('_key', None)
        return self

    def __copy__(self):
        return self.copy()

    def copy(self):
        return self.copy_to(type(self)(_key=self.key))

    def copy_to(self, target):
        """Copy all copyable slot values onto ``target`` and return it."""
        no_copy = self._no_copy_attrs_
        for name in self.__slots__:
            dunder = name.startswith('__') and name.endswith('__')
            if dunder or name in no_copy:
                # we don't copy id to identify that the copied one is new
                continue
            try:
                value = getattr(self, name)
            except AttributeError:
                continue
            setattr(target, name, value)
        return target

    def copy_from(self, obj):
        obj.copy_to(self)

    @property
    def key(self):
        return self._key

    @property
    def id(self):
        return self._id


class Entity(HasData):
    """User-facing wrapper that delegates almost everything to ``_data``."""
    __slots__ = ()
    _allow_data_type_ = ()

    def __init__(self, data):
        self._check_data(data)
        self._data = data

    def __dir__(self):
        names = object.__dir__(self)
        if self._data is not None:
            names = sorted(set(dir(self._data) + names))
        return names

    def __str__(self):
        return self._data.__str__()

    def __repr__(self):
        return self._data.__repr__()

    def _check_data(self, data):
        # None is always acceptable; anything else must match the whitelist
        allowed = self._allow_data_type_
        if data is not None and not isinstance(data, allowed):
            raise TypeError(f'Expect {allowed}, got {type(data)}')

    @property
    def data(self):
        return self._data

    @data.setter
    def data(self, new_data):
        self._check_data(new_data)
        self._data = new_data

    def __copy__(self):
        return self.copy()

    def copy(self):
        return self.copy_to(type(self)(None))

    def copy_to(self, target):
        target.data = self._data
        return target

    def copy_from(self, obj):
        self.data = obj.data

    def tiles(self):
        # tile into a new wrapper; the original stays coarse
        tiled = self.copy()
        tiled.data = handler.tiles(self.data)
        return tiled

    def _inplace_tile(self):
        return handler.inplace_tile(self)

    def __getattr__(self, attr):
        # unknown attributes fall through to the wrapped data object
        return getattr(self._data, attr)

    def __setattr__(self, key, value):
        try:
            object.__setattr__(self, key, value)
        except AttributeError:
            # not a slot on the wrapper, so set it on the data object
            return setattr(self._data, key, value)


class SerializableWithKey(Base, Serializable):
    """``Base`` that can be serialized; key and id become declared fields."""
    _key = StringField('key')
    _id = StringField('id')


class AttributeAsDictKey(Base, AttributeAsDict):
    """``Base`` serialized attribute-as-dict style; key and id are fields."""
    _key = StringField('key')
    _id = StringField('id')


class EntityData(SerializableWithKey):
    """Data object produced by an operand; base of chunk and tileable data."""
    __slots__ = '__weakref__', '_siblings'

    # required fields
    _op = KeyField('op')  # store key of operand here
    # optional fields
    _extra_params = DictField('extra_params', key_type=ValueType.string, on_deserialize=AttributeDict)

    def __init__(self, *args, **kwargs):
        # Split kwargs: anything without a backing slot is collected into
        # ``_extra_params``. NOTE(review): this relies on the serialization
        # machinery folding declared fields (e.g. '_op') into ``__slots__`` —
        # confirm against the Serializable metaclass.
        extras = AttributeDict((k, kwargs.pop(k)) for k in set(kwargs) - set(self.__slots__))
        kwargs['_extra_params'] = kwargs.pop('_extra_params', extras)
        super().__init__(*args, **kwargs)

    @property
    def op(self):
        # the operand that produces this data, or None if not set yet
        return getattr(self, '_op', None)

    @property
    def inputs(self):
        return self.op.inputs or []

    @inputs.setter
    def inputs(self, new_inputs):
        self.op.inputs = new_inputs

    def is_sparse(self):
        return self.op.is_sparse()

    # scipy-style alias
    issparse = is_sparse

    @property
    def extra_params(self):
        return self._extra_params


# types recognized as entity-like: raw data or its user-facing wrapper
ENTITY_TYPE = (EntityData, Entity)


class ChunkData(EntityData):
    """Data part of a chunk, the basic unit a tileable is split into."""
    __slots__ = ()

    # optional fields
    _index = TupleField('index', ValueType.uint32)
    _cached = BoolField('cached')

    def __repr__(self):
        op_name = type(self.op).__name__
        stage = self.op.stage
        if stage is None:
            return f'Chunk <op={op_name}, key={self.key}>'
        return f'Chunk <op={op_name}, stage={stage.name}, key={self.key}>'

    @classmethod
    def cls(cls, provider):
        # protobuf serialization uses the generated ChunkDef message
        if provider.type == ProviderType.protobuf:
            from .serialize.protos.chunk_pb2 import ChunkDef
            return ChunkDef
        return super().cls(provider)

    @property
    def index(self):
        # position of this chunk inside its tileable, if assigned
        return getattr(self, '_index', None)

    @property
    def cached(self):
        return getattr(self, '_cached', None)

    @property
    def device(self):
        return self.op.device

    def _update_key(self):
        # the chunk key must not depend on its position, so skip '_index'
        values = (getattr(self, name, None)
                  for name in self._keys_ if name != '_index')
        object.__setattr__(self, '_key', tokenize(type(self).__name__, *values))


class Chunk(Entity):
    """User-level wrapper of ``ChunkData``."""
    __slots__ = ()
    _allow_data_type_ = (ChunkData,)


# types recognized as chunk-like: raw chunk data or its wrapper
CHUNK_TYPE = (ChunkData, Chunk)


class ObjectChunkData(ChunkData):
    """Chunk data holding an arbitrary serializable Python object."""
    # chunk whose data could be any serializable
    __slots__ = ()

    def __init__(self, op=None, index=None, **kw):
        super().__init__(_op=op, _index=index, **kw)

    @classmethod
    def cls(cls, provider):
        # protobuf serialization uses the generated ObjectChunkDef message
        if provider.type == ProviderType.protobuf:
            from .serialize.protos.object_pb2 import ObjectChunkDef
            return ObjectChunkDef
        return super().cls(provider)

    @property
    def params(self):
        # params return the properties which useful to rebuild a new chunk
        return {
            'index': self.index,
        }


class ObjectChunk(Chunk):
    """User-level wrapper of ``ObjectChunkData``."""
    __slots__ = ()
    _allow_data_type_ = (ObjectChunkData,)


class FuseChunkData(ChunkData):
    """Chunk representing a fused subgraph of chunks, delegating to ``_chunk``."""
    __slots__ = '_inited',

    _chunk = KeyField('chunk',
                      on_serialize=lambda x: x.data if hasattr(x, 'data') else x)

    def __init__(self, *args, **kwargs):
        # must be set before super().__init__: __getattr__ consults ``_inited``
        # and may be triggered while the base class is still initializing
        self._inited = False
        super().__init__(*args, **kwargs)
        self._extra_params = {}
        self._inited = True

    @property
    def chunk(self):
        # the tail chunk this fused chunk stands in for
        return self._chunk

    @property
    def composed(self):
        # for compatibility, just return the topological ordering,
        # once we apply optimization on the subgraph,
        # `composed` is not needed any more and should be removed then.
        assert getattr(self._op, 'fuse_graph', None) is not None
        fuse_graph = self._op.fuse_graph
        return list(fuse_graph.topological_iter())

    @classmethod
    def cls(cls, provider):
        # protobuf serialization uses the generated FuseChunkDef message
        if provider.type == ProviderType.protobuf:
            from .serialize.protos.fusechunk_pb2 import FuseChunkDef
            return FuseChunkDef
        return super().cls(provider)

    def __getattr__(self, attr):
        # while still initializing, never delegate (attributes may be half-set)
        if not self._inited:
            return object.__getattribute__(self, attr)
        if attr in self._extra_params:
            return self._extra_params[attr]
        # otherwise fall through to the wrapped chunk first
        try:
            return getattr(self._chunk, attr)
        except AttributeError:
            return object.__getattribute__(self, attr)

    @property
    def nbytes(self):
        # byte size computed from shape/dtype (resolved via __getattr__)
        return np.prod(self.shape) * self.dtype.itemsize


class FuseChunk(Chunk):
    """User-level wrapper of ``FuseChunkData``."""
    __slots__ = ()
    _allow_data_type_ = (FuseChunkData,)


# types recognized as fused-chunk-like
FUSE_CHUNK_TYPE = (FuseChunkData, FuseChunk)


class _ExecutableMixin:
    __slots__ = ()

    def execute(self, session=None, **kw):
        from .session import Session

        if 'fetch' in kw and kw['fetch']:
            raise ValueError('Does not support fetch=True for `.execute()`,'
                             'please use `.fetch()` instead')
        else:
            kw['fetch'] = False

        wait = kw.pop('wait', True)

        if session is None:
            session = Session.default_or_local()

        def run():
            # no more fetch, thus just fire run
            session.run(self, **kw)
            # return Tileable or ExecutableTuple itself
            return self

        if wait:
            return run()
        else:
            # leverage ThreadPoolExecutor to submit task,
            # return a concurrent.future.Future
            thread_executor = ThreadPoolExecutor(1)
            return thread_executor.submit(run)

    def _get_session(self, session=None):
        from .session import Session

        if session is None and len(self._executed_sessions) > 0:
            session = self._executed_sessions[-1]
        if session is None:
            session = Session.default

        return session

    def _check_session(self, session, action):
        if session is None:
            if isinstance(self, tuple):
                key = self[0].key
            else:
                key = self.key
            raise ValueError(
                f'Tileable object {key} must be executed first before {action}')

    def fetch(self, session=None, **kw):
        session = self._get_session(session)
        self._check_session(session, 'fetch')
        return session.fetch(self, **kw)

    def fetch_log(self, session=None, offsets=None, sizes=None):
        session = self._get_session(session)
        self._check_session(session, 'fetch_log')
        return session.fetch_log([self], offsets=offsets, sizes=sizes)[0]

    def _attach_session(self, session):
        _cleaner.register(self, session)
        self._executed_sessions.append(session)


class _ExecuteAndFetchMixin:
    __slots__ = ()

    def _execute_and_fetch(self, session=None, **kw):
        if session is None and len(self._executed_sessions) > 0:
            session = self._executed_sessions[-1]
        try:
            # fetch first, to reduce the potential cost of submitting a graph
            return self.fetch(session=session)
        except ValueError:
            # not execute before
            wait = kw.pop('wait', True)

            def run():
                return self.execute(session=session, **kw).fetch(session=session)

            if wait:
                return run()
            else:
                thread_executor = ThreadPoolExecutor(1)
                return thread_executor.submit(run)


class _ToObjectMixin(_ExecuteAndFetchMixin):
    """Mixin exposing ``to_object``: execute if needed, then fetch the value."""
    __slots__ = ()

    def to_object(self, session=None, **kw):
        # execute (if not already executed) and fetch the concrete object
        return self._execute_and_fetch(session=session, **kw)


class TileableData(EntityData, Tileable, _ExecutableMixin):
    """Data part of a tileable: an object that can be split into chunks."""
    __slots__ = '_cix', '_entities', '_executed_sessions'
    _no_copy_attrs_ = SerializableWithKey._no_copy_attrs_ | {'_cix'}

    # optional fields
    # `nsplits` means the sizes of chunks for each dimension
    _nsplits = TupleField('nsplits', ValueType.tuple(ValueType.uint64),
                          on_serialize=on_serialize_nsplits)

    def __init__(self, *args, **kwargs):
        nsplits = kwargs.get('_nsplits')
        if nsplits is not None:
            # normalize to a tuple of tuples
            kwargs['_nsplits'] = tuple(tuple(split) for split in nsplits)

        super().__init__(*args, **kwargs)

        chunks = getattr(self, '_chunks', None)
        if chunks:
            self._chunks = sorted(chunks, key=attrgetter('index'))

        self._entities = WeakSet()
        self._executed_sessions = []

    @property
    def chunk_shape(self):
        # number of chunks along each dimension
        nsplits = getattr(self, '_nsplits', None)
        if nsplits is not None:
            return tuple(len(split) for split in nsplits)

    @property
    def chunks(self) -> List["Chunk"]:
        return getattr(self, '_chunks', None)

    @property
    def nsplits(self):
        return getattr(self, '_nsplits', None)

    @nsplits.setter
    def nsplits(self, new_nsplits):
        self._nsplits = new_nsplits

    @property
    def params(self) -> dict:
        # overridden by subclasses that expose rebuildable properties
        return dict()

    @property
    def cix(self):
        if self.ndim == 0:
            return ChunksIndexer(self)

        try:
            if getattr(self, '_cix', None) is None:
                self._cix = ChunksIndexer(self)
            return self._cix
        except (TypeError, ValueError):
            # fall back to a fresh indexer when caching misbehaves
            return ChunksIndexer(self)

    @property
    def entities(self):
        return self._entities

    def is_coarse(self):
        # coarse means not yet tiled into chunks
        chunks = getattr(self, '_chunks', None)
        return not chunks

    @enter_mode(build=True)
    def attach(self, entity):
        self._entities.add(entity)

    @enter_mode(build=True)
    def detach(self, entity):
        self._entities.discard(entity)


class TileableEntity(Entity):
    """Entity wrapper for tileable data; tracks attachment and view relations."""
    __slots__ = '__weakref__',

    def __init__(self, data):
        super().__init__(data)
        if self._data is not None:
            self._data.attach(self)
            if self._data.op.create_view:
                # a view must observe modifications of its source data
                entity_view_handler.add_observer(self._data.inputs[0], self)

    def __copy__(self):
        # copy.copy() produces a view, not a full copy
        return self._view()

    def _view(self):
        return super().copy()

    def copy(self):
        """Return a genuine copy (never a view) of this tileable."""
        new_op = self.op.copy()
        if new_op.create_view:
            # if the operand is a view, make it a copy
            new_op._create_view = False
        params = []
        for o in self.op.outputs:
            param = o.params
            param['_key'] = o.key
            param.update(o.extra_params)
            params.append(param)
        new_outs = new_op.new_tileables(self.op.inputs, kws=params,
                                        output_limit=len(params))
        pos = -1
        for i, out in enumerate(self.op.outputs):
            # create a ref to copied one
            new_out = new_outs[i]
            if not hasattr(new_out.data, '_siblings'):
                new_out.data._siblings = []
            new_out.data._siblings.append(out)

            # NOTE: the loop breaks at this entity's own output, so outputs
            # after ``self._data`` get no sibling reference recorded
            if self._data is out:
                pos = i
                break
        assert pos >= 0
        return new_outs[pos]

    @Entity.data.setter
    def data(self, new_data):
        self._check_data(new_data)
        if self._data is None:
            self._data = new_data
            self._data.attach(self)
        else:
            # replacing existing data goes through the modification handler
            entity_view_handler.data_changed(self._data, new_data)


# types recognized as tileable-like: wrapper or raw data
TILEABLE_TYPE = (TileableEntity, TileableData)


class HasShapeTileableData(TileableData):
    """Tileable data that exposes a numeric ``shape``."""
    # required fields
    _shape = TupleField('shape', ValueType.int64,
                        on_serialize=on_serialize_shape, on_deserialize=on_deserialize_shape)

    @property
    def ndim(self):
        return len(self.shape)

    def __len__(self):
        try:
            return self.shape[0]
        except IndexError:
            # zero-dim object: tolerated while building a graph only
            if is_build_mode():
                return 0
            raise TypeError('len() of unsized object')

    @property
    def shape(self):
        shape = getattr(self, '_shape', None)
        if shape is not None:
            return shape
        nsplits = getattr(self, '_nsplits', None)
        if nsplits is not None:
            # derive (and cache) the shape by summing chunk sizes per dim
            self._shape = tuple(builtins.sum(split) for split in nsplits)
            return self._shape

    def _update_shape(self, new_shape):
        self._shape = new_shape

    @property
    def size(self):
        return np.prod(self.shape).item()

    @property
    def params(self):
        # properties useful to rebuild a new tileable object
        return {
            'shape': self.shape
        }

    def _equals(self, o):
        # identity comparison only
        return self is o


class HasShapeTileableEnity(TileableEntity):
    """Entity wrapper exposing shape-related attributes of its data."""
    __slots__ = ()

    @property
    def shape(self):
        return self._data.shape

    @property
    def ndim(self):
        return self._data.ndim

    @property
    def size(self):
        return self._data.size

    def execute(self, session=None, **kw):
        """Execute the underlying data; return self (or a Future of self)."""
        wait = kw.pop('wait', True)

        def _run():
            self.data.execute(session, **kw)
            return self

        if wait:
            return _run()
        return ThreadPoolExecutor(1).submit(_run)


class ObjectData(TileableData, _ToObjectMixin):
    """Tileable data representing an arbitrary serializable Python object."""
    __slots__ = ()

    # optional fields
    _chunks = ListField('chunks', ValueType.reference(ObjectChunkData),
                        on_serialize=lambda x: [it.data for it in x] if x is not None else x,
                        on_deserialize=lambda x: [ObjectChunk(it) for it in x] if x is not None else x)

    def __init__(self, op=None, nsplits=None, chunks=None, **kw):
        super().__init__(_op=op, _nsplits=nsplits, _chunks=chunks, **kw)

    def __repr__(self):
        return f'Object <op={type(self.op).__name__}, key={self.key}>'

    @classmethod
    def cls(cls, provider):
        # protobuf serialization uses the generated ObjectDef message
        if provider.type == ProviderType.protobuf:
            from .serialize.protos.object_pb2 import ObjectDef
            return ObjectDef
        return super().cls(provider)

    @property
    def params(self):
        # params return the properties which useful to rebuild a new tileable object
        return {
        }


class Object(Entity, _ToObjectMixin):
    """User-level wrapper of ``ObjectData``."""
    __slots__ = ()
    _allow_data_type_ = (ObjectData,)


# types recognized as object tileables / object chunks
OBJECT_TYPE = (Object, ObjectData)
OBJECT_CHUNK_TYPE = (ObjectChunk, ObjectChunkData)


class ChunksIndexer(object):
    """Index the chunks of a tileable by multi-dimensional position."""
    __slots__ = '_tileable',

    def __init__(self, tileable):
        self._tileable = tileable

    def __getitem__(self, item):
        """Fetch chunk(s) at the given position.

        ``cix[x, y]`` returns a single chunk while ``cix[x, :]`` returns a
        flattened list of chunks.  The index tuple must have exactly ``ndim``
        entries (matching ``chunk_shape``).
        """
        tileable = self._tileable
        if not isinstance(item, tuple):
            raise ValueError(f'Cannot get {type(tileable).__name__} chunk by {item}')

        if len(item) == 0 and tileable.is_scalar():
            return tileable.chunks[0]
        if len(item) != tileable.ndim:
            raise ValueError(f'Cannot get chunk by {item}, expect length {tileable.ndim}')

        singleton = True
        dim_indices = []
        for index, size in zip(item, tileable.chunk_shape):
            if isinstance(index, slice):
                dim_indices.append(range(size)[index])
                singleton = False
            elif np.issubdtype(type(index), np.integer):
                # normalize negative indices against the chunk count
                dim_indices.append([index if index >= 0 else size + index])
            else:
                raise TypeError(f'Cannot get chunk by {index}, invalid value has type {type(index)}')

        # cartesian product of per-dimension choices -> flat chunk positions
        coords = tuple(zip(*itertools.product(*dim_indices)))
        flat = np.ravel_multi_index(coords, tileable.chunk_shape)
        if singleton:
            return tileable._chunks[flat[0]]
        return [tileable._chunks[i] for i in flat]


class ExecutableTuple(tuple, _ExecutableMixin, _ToObjectMixin):
    """Tuple of executables that can be executed/fetched in one shot."""

    def __init__(self, *_):
        super().__init__()
        self._executed_sessions = []

    def execute(self, session=None, **kw):
        if len(self) == 0:
            return self
        return super().execute(session=session, **kw)

    def fetch(self, session=None, **kw):
        if len(self) == 0:
            return tuple()
        return super().fetch(session=session, **kw)

    def fetch_log(self, session=None, offsets=None, sizes=None):
        if len(self) == 0:
            return []
        session = self._get_session(session=session)
        return session.fetch_log(self, offsets=offsets, sizes=sizes)

    def _get_session(self, session=None):
        # fall back to any member's session when the tuple has none
        session = super()._get_session(session=session)
        if session is None:
            for item in self:
                session = item._get_session()
                if session is not None:
                    return session
        return session


class _TileableSession:
    """Weakly ties a tileable to a session: decref when the tileable dies."""

    def __init__(self, tensor, session):
        key = tensor.key, tensor.id

        def cb(_, sess=ref(session)):
            s = sess()
            if s:
                s.decref(key)

        self._tensor = ref(tensor, cb)


class _TileableDataCleaner:
    """Registers tileable->session links so sessions release data on GC."""

    def __init__(self):
        self._tileable_to_sessions = WeakKeyDictionary()

    @enter_mode(build=True)
    def register(self, tensor, session):
        if tensor in self._tileable_to_sessions:
            self._tileable_to_sessions[tensor].append(
                _TileableSession(tensor, session))
        else:
            self._tileable_to_sessions[tensor] = \
                [_TileableSession(tensor, session)]


# we don't use __del__ to avoid potential Circular reference
_cleaner = _TileableDataCleaner()


class EntityDataModificationHandler:
    """Propagates in-place data modifications through views and observers."""

    def __init__(self):
        self._data_to_entities = WeakKeyDictionary()

    def _add_observer(self, data, entity):
        # only tileable data should be considered
        assert isinstance(data, TileableData)
        assert isinstance(entity, TileableEntity)
        if data not in self._data_to_entities:
            self._data_to_entities[data] = WeakSet()
        self._data_to_entities[data].add(entity)

    @enter_mode(build=True)
    def add_observer(self, data, entity):
        self._add_observer(data, entity)

    def _update_observe_data(self, observer, data, new_data):
        self._data_to_entities.get(data, set()).discard(observer)
        self._add_observer(new_data, observer)

    @staticmethod
    def _set_data(entity, data):
        entity._data.detach(entity)
        entity._data = data
        data.attach(entity)

    @staticmethod
    def _get_data(obj):
        return obj.data if isinstance(obj, Entity) else obj

    @enter_mode(build=True)
    def data_changed(self, old_data, new_data):
        notified = set()
        processed_data = set()
        old_to_new = {old_data: new_data}
        q = [old_data]
        while len(q) > 0:
            data = q.pop()

            # swap data on every entity attached to this data
            for entity in data.entities:
                self._set_data(entity, old_to_new[data])
                notified.add(entity)

            # let observers (views) compute their own new data
            observers = {ob for ob in self._data_to_entities.pop(data, set())
                         if ob not in notified}
            for ob in observers:
                new_data = self._get_data(ob.op.on_input_modify(old_to_new[data]))
                old_data = ob.data
                self._update_observe_data(ob, ob.data, new_data)
                old_to_new[old_data] = new_data
                if old_data not in processed_data:
                    q.append(old_data)
                    processed_data.add(old_data)
                notified.add(ob)

            # a view's modification propagates back to its source
            if data.op.create_view:
                old_input_data = data.inputs[0]
                new_input_data = self._get_data(
                    data.op.on_output_modify(old_to_new[data]))
                old_to_new[old_input_data] = new_input_data
                if old_input_data not in processed_data:
                    q.append(old_input_data)
                    processed_data.add(old_input_data)


entity_view_handler = EntityDataModificationHandler()


class OutputType(enum.Enum):
    """Kinds of objects an operand can output."""
    object = 1
    tensor = 2
    dataframe = 3
    series = 4
    index = 5
    scalar = 6
    categorical = 7
    dataframe_groupby = 8
    series_groupby = 9

    @classmethod
    def serialize_list(cls, output_types):
        return [ot.value for ot in output_types] \
            if output_types is not None else None

    @classmethod
    def deserialize_list(cls, output_types):
        return [cls(ot) for ot in output_types] \
            if output_types is not None else None


# registries mapping an OutputType to its concrete classes
_OUTPUT_TYPE_TO_CHUNK_TYPES = {OutputType.object: OBJECT_CHUNK_TYPE}
_OUTPUT_TYPE_TO_TILEABLE_TYPES = {OutputType.object: OBJECT_TYPE}
_OUTPUT_TYPE_TO_FETCH_CLS = {}


def register_output_types(output_type, tileable_types, chunk_types):
    """Register tileable and chunk classes for an output type."""
    _OUTPUT_TYPE_TO_TILEABLE_TYPES[output_type] = tileable_types
    _OUTPUT_TYPE_TO_CHUNK_TYPES[output_type] = chunk_types


def register_fetch_class(output_type, fetch_cls, fetch_shuffle_cls):
    """Register fetch operand classes for an output type."""
    _OUTPUT_TYPE_TO_FETCH_CLS[output_type] = (fetch_cls, fetch_shuffle_cls)


def get_tileable_types(output_type):
    return _OUTPUT_TYPE_TO_TILEABLE_TYPES[output_type]


def get_chunk_types(output_type):
    return _OUTPUT_TYPE_TO_CHUNK_TYPES[output_type]


def get_fetch_class(output_type):
    return _OUTPUT_TYPE_TO_FETCH_CLS[output_type]


@functools.lru_cache(100)
def _get_output_type_by_cls(cls):
    # reverse lookup: class -> registered OutputType
    for tp in OutputType.__members__.values():
        try:
            tileable_types = _OUTPUT_TYPE_TO_TILEABLE_TYPES[tp]
            chunk_types = _OUTPUT_TYPE_TO_CHUNK_TYPES[tp]
            if issubclass(cls, (tileable_types, chunk_types)):
                return tp
        except KeyError:  # pragma: no cover
            continue
    raise TypeError('Output can only be tensor, dataframe or series')


def get_output_types(*objs, unknown_as=None):
    """Map objects to their OutputType; unknowns become ``unknown_as`` or raise."""
    output_types = []
    for obj in objs:
        if obj is None:
            continue
        elif isinstance(obj, (FuseChunk, FuseChunkData)):
            obj = obj.chunk
        try:
            output_types.append(_get_output_type_by_cls(type(obj)))
        except TypeError:
            if unknown_as is not None:
                output_types.append(unknown_as)
            else:  # pragma: no cover
                raise
    return output_types