Source code for

# Copyright 1999-2021 Alibaba Group Holding Ltd.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
from typing import Any, List, Tuple, Type, TypeVar, Union

from .... import oscar as mo
from ....lib.aio import alru_cache
from import StorageLevel, StorageFileObject
from ...cluster import StorageInfo
from ..core import (
from ..handler import StorageHandlerActor
from .core import AbstractStorageAPI

_is_windows = sys.platform.lower().startswith("win")
APIType = TypeVar("APIType", bound="StorageAPI")

[docs]class StorageAPI(AbstractStorageAPI): _storage_handler_ref: mo.ActorRefType[StorageHandlerActor] _data_manager_ref: mo.ActorRefType[DataManagerActor]
[docs] def __init__(self, address: str, session_id: str, band_name: str): self._address = address self._session_id = session_id self._band_name = band_name
async def _init(self): self._storage_handler_ref = await mo.actor_ref( self._address, StorageHandlerActor.gen_uid(self._band_name) ) self._data_manager_ref = await mo.actor_ref( self._address, DataManagerActor.default_uid() ) @classmethod @alru_cache(cache_exceptions=False) async def create( cls: Type[APIType], session_id: str, address: str, band_name: str = "numa-0", **kwargs, ) -> APIType: """ Create storage API. Parameters ---------- session_id: str session id address: str worker address band_name: str name of band, default as 'numa-0' Returns ------- storage_api Storage api. """ if kwargs: # pragma: no cover raise TypeError(f'Got unexpected arguments: {",".join(kwargs)}') api = StorageAPI(address, session_id, band_name) await api._init() return api async def is_seekable(self, storage_level: StorageLevel = None) -> bool: """ If storage backend is seekable. """ return await self._storage_handler_ref.is_seekable(storage_level) @mo.extensible async def get( self, data_key: str, conditions: List = None, error: str = "raise" ) -> Any: return await self._storage_handler_ref.get( self._session_id, data_key, conditions, error ) @get.batch async def batch_get(self, args_list, kwargs_list): gets = [] for args, kwargs in zip(args_list, kwargs_list): gets.append( self._storage_handler_ref.get.delay(self._session_id, *args, **kwargs) ) return await self._storage_handler_ref.get.batch(*gets) @mo.extensible async def put( self, data_key: str, obj: object, level: StorageLevel = None ) -> DataInfo: return await self._storage_handler_ref.put( self._session_id, data_key, obj, level ) @put.batch async def batch_put(self, args_list, kwargs_list): puts = [] for args, kwargs in zip(args_list, kwargs_list): puts.append( self._storage_handler_ref.put.delay(self._session_id, *args, **kwargs) ) return await self._storage_handler_ref.put.batch(*puts) @mo.extensible async def get_infos(self, data_key: str) -> List[DataInfo]: """ Get data information items for specific data key Parameters ---------- data_key Returns ------- out List of information for specified key """ return await self._data_manager_ref.get_data_infos( self._session_id, data_key, self._band_name ) @mo.extensible async def delete(self, data_key: str, error: str = "raise"): """ Delete object. Parameters ---------- data_key: str object key to delete error: str raise or ignore """ await self._storage_handler_ref.delete(self._session_id, data_key, error=error) @delete.batch async def batch_delete(self, args_list, kwargs_list): deletes = [] for args, kwargs in zip(args_list, kwargs_list): deletes.append( self._storage_handler_ref.delete.delay( self._session_id, *args, **kwargs ) ) return await self._storage_handler_ref.delete.batch(*deletes) @mo.extensible async def fetch( self, data_key: Union[str, Tuple], level: StorageLevel = None, band_name: str = None, remote_address: str = None, error: str = "raise", ): """ Fetch object from remote worker or load object from disk. Parameters ---------- data_key: str or tuple data key(tuple when is shuffle key) to fetch to current worker with specific level. level: StorageLevel the storage level to put into, MEMORY as default band_name: BandType put data on specific band remote_address: remote address that stores the data error: str raise or ignore """ await self._storage_handler_ref.fetch_batch( self._session_id, [data_key], level, band_name, remote_address, error ) @fetch.batch async def batch_fetch(self, args_list, kwargs_list): extracted_args = [] data_keys = [] for args, kwargs in zip(args_list, kwargs_list): data_key, level, band_name, dest_address, error = self.fetch.bind( *args, **kwargs ) if extracted_args: assert extracted_args == (level, band_name, dest_address, error) extracted_args = (level, band_name, dest_address, error) data_keys.append(data_key) await self._storage_handler_ref.fetch_batch( self._session_id, data_keys, *extracted_args ) @mo.extensible async def unpin(self, data_key: str, error: str = "raise"): """ Unpin the data, allow storage to release the data. Parameters ---------- data_key: str data key to unpin error: str raise or ignore """ await self._storage_handler_ref.unpin(self._session_id, data_key, error) @unpin.batch async def batch_unpin(self, args_list, kwargs_list): unpins = [] for args, kwargs in zip(args_list, kwargs_list): data_key, error = self.unpin.bind(*args, **kwargs) unpins.append( self._storage_handler_ref.unpin.delay(self._session_id, data_key, error) ) return await self._storage_handler_ref.unpin.batch(*unpins) async def open_reader(self, data_key: str) -> StorageFileObject: """ Return a file-like object for reading. Parameters ---------- data_key: str data key Returns ------- return a file-like object. """ return await self._storage_handler_ref.open_reader(self._session_id, data_key) async def open_writer( self, data_key: Union[Tuple, str], size: int, level: StorageLevel = None ) -> WrappedStorageFileObject: """ Return a file-like object for writing data. Parameters ---------- data_key: str or tuple data key size: int the total size of data level: StorageLevel the storage level to write Returns ------- return a file-like object. """ return await self._storage_handler_ref.open_writer( self._session_id, data_key, size, level ) async def list(self, level: StorageLevel) -> List: """ List all stored data_keys in storage. Parameters ---------- level: StorageLevel the storage level to list all objects Returns ------- list of data keys """ return await self._storage_handler_ref.list(level=level) async def get_storage_level_info(self, level: StorageLevel) -> StorageInfo: """ Get storage level's info. Parameters ---------- level : StorageLevel Storage level. Returns ------- storage_level_info : StorageInfo """ return await self._storage_handler_ref.get_storage_level_info(level) async def get_storage_info(self, level: StorageLevel) -> dict: """ Get the customized storage backend info of requested storage backend. Parameters ---------- level: StorageLevel the storage level to fetch the backend info. Returns ------- info : dict Customized storage backend info dict. """ return await self._storage_handler_ref.get_storage_backend_info(level)
class MockStorageAPI(StorageAPI): @classmethod async def create( cls: Type[APIType], session_id: str, address: str, **kwargs ) -> APIType: from ..core import StorageManagerActor storage_configs = kwargs.get("storage_configs") if not storage_configs: storage_configs = {"shared_memory": {}} storage_handler_cls = kwargs.pop("storage_handler_cls", StorageHandlerActor) await mo.create_actor( StorageManagerActor, storage_configs, storage_handler_cls=storage_handler_cls, uid=StorageManagerActor.default_uid(), address=address, ) return await super().create(address=address, session_id=session_id) @classmethod async def cleanup(cls: Type[APIType], address: str): await mo.destroy_actor( await mo.actor_ref(address, StorageManagerActor.default_uid()) )