Source code for mars.services.cluster.api.oscar

# Copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import logging
from typing import List, Dict, Optional, Set, Type, TypeVar

from .... import oscar as mo
from ....lib.aio import alru_cache
from ....resource import Resource
from ....typing import BandType
from ...core import NodeRole
from ..core import (
    watch_method,
    NodeStatus,
    WorkerSlotInfo,
    QuotaInfo,
    DiskInfo,
    StorageInfo,
)
from .core import AbstractClusterAPI

# Module-wide logger, named after this module.
logger = logging.getLogger(__name__)
# Return type of classmethod constructors such as ``create``: lets a subclass
# calling ``create`` be typed as returning that subclass.
APIType = TypeVar("APIType", bound="ClusterAPI")


class ClusterAPI(AbstractClusterAPI):
    """
    Cluster service API bound to a single actor-pool address.

    Node-local calls go to the ``SupervisorLocatorActor`` and
    ``NodeInfoUploaderActor`` living at ``address``; supervisor-side actors
    (``NodeInfoCollectorActor``, ``NodeAllocatorActor``) are located through
    the locator.
    """

    def __init__(self, address: str):
        self._address = address
        # populated by ``_init``; ``create`` is the intended constructor
        self._locator_ref = None
        self._uploader_ref = None

    async def _init(self):
        """Resolve references to the locator and uploader actors at ``self._address``."""
        from ..locator import SupervisorLocatorActor
        from ..uploader import NodeInfoUploaderActor

        self._locator_ref = await mo.actor_ref(
            SupervisorLocatorActor.default_uid(), address=self._address
        )
        self._uploader_ref = await mo.actor_ref(
            NodeInfoUploaderActor.default_uid(), address=self._address
        )

    @classmethod
    @alru_cache(cache_exceptions=False)
    async def create(cls: Type[APIType], address: str) -> APIType:
        """
        Create and initialize an API object for ``address``.

        The result is memoized by ``alru_cache``; failed creations are not
        cached (``cache_exceptions=False``), so a later call can retry.
        """
        api_obj = cls(address)
        await api_obj._init()
        return api_obj

    @alru_cache(cache_exceptions=False)
    async def _get_node_info_ref(self):
        """Locate the supervisor-side ``NodeInfoCollectorActor`` (memoized)."""
        from ..supervisor.node_info import NodeInfoCollectorActor

        [node_info_ref] = await self.get_supervisor_refs(
            [NodeInfoCollectorActor.default_uid()]
        )
        return node_info_ref

    async def get_supervisors(self, filter_ready: bool = True) -> List[str]:
        """
        Get supervisor addresses known to the local locator.

        Parameters
        ----------
        filter_ready : bool
            forwarded to the locator's ``get_supervisors``

        Returns
        -------
        out : List[str]
            supervisor addresses
        """
        return await self._locator_ref.get_supervisors(filter_ready=filter_ready)

    @watch_method
    async def watch_supervisors(self, version: Optional[int] = None):
        """Watch supervisor addresses; ``version`` is forwarded to the locator."""
        return await self._locator_ref.watch_supervisors(version=version)

    async def get_supervisors_by_keys(self, keys: List[str]) -> List[str]:
        """
        Get supervisor addresses hosting the specified keys.

        All lookups are sent to the locator in a single batch.

        Parameters
        ----------
        keys : List[str]
            keys to map onto supervisor addresses

        Returns
        -------
        out : List[str]
            addresses of the supervisors, in the same order as ``keys``
        """
        get_supervisor = self._locator_ref.get_supervisor
        return await get_supervisor.batch(*(get_supervisor.delay(k) for k in keys))

    @watch_method
    async def watch_supervisors_by_keys(
        self, keys: List[str], version: Optional[int] = None
    ):
        """Watch supervisor addresses hosting ``keys``; ``version`` is forwarded to the locator."""
        return await self._locator_ref.watch_supervisors_by_keys(keys, version=version)

    async def get_supervisor_refs(self, uids: List[str]) -> List[mo.ActorRef]:
        """
        Get actor references hosting the specified actor uids.

        Parameters
        ----------
        uids : List[str]
            uids of the actors; each uid is also used as the key to find
            the supervisor hosting it

        Returns
        -------
        out : List[mo.ActorRef]
            references of the actors, in the same order as ``uids``

        Raises
        ------
        mo.ActorNotExist
            if no supervisor address is available for one of the uids
        """
        addrs = await self.get_supervisors_by_keys(uids)
        if any(addr is None for addr in addrs):
            none_uid = next(uid for addr, uid in zip(addrs, uids) if addr is None)
            raise mo.ActorNotExist(f"Actor {none_uid} not exist as no supervisors")
        return await asyncio.gather(
            *[mo.actor_ref(uid, address=addr) for addr, uid in zip(addrs, uids)]
        )

    async def watch_supervisor_refs(self, uids: List[str]):
        """
        Async generator yielding actor references for ``uids``.

        A new list of references is yielded every time
        ``watch_supervisors_by_keys`` yields a new list of addresses.
        """
        async for addrs in self.watch_supervisors_by_keys(uids):
            yield await asyncio.gather(
                *[mo.actor_ref(uid, address=addr) for addr, uid in zip(addrs, uids)]
            )

    @watch_method
    async def watch_nodes(
        self,
        role: NodeRole,
        env: bool = False,
        resource: bool = False,
        detail: bool = False,
        version: Optional[int] = None,
        statuses: Optional[Set[NodeStatus]] = None,
        exclude_statuses: Optional[Set[NodeStatus]] = None,
    ) -> List[Dict[str, Dict]]:
        """
        Watch info of nodes with the given role.

        The call goes through the supervisor ``NodeInfoCollectorActor``.
        ``statuses`` and ``exclude_statuses`` are combined by ``_calc_statuses``
        before being forwarded.
        """
        statuses = self._calc_statuses(statuses, exclude_statuses)
        node_info_ref = await self._get_node_info_ref()
        return await node_info_ref.watch_nodes(
            role,
            env=env,
            resource=resource,
            detail=detail,
            statuses=statuses,
            version=version,
        )

    async def get_nodes_info(
        self,
        nodes: Optional[List[str]] = None,
        role: Optional[NodeRole] = None,
        env: bool = False,
        resource: bool = False,
        detail: bool = False,
        statuses: Optional[Set[NodeStatus]] = None,
        exclude_statuses: Optional[Set[NodeStatus]] = None,
    ) -> Dict[str, Dict]:
        """
        Get info of nodes from the supervisor ``NodeInfoCollectorActor``.

        ``statuses`` and ``exclude_statuses`` are combined by ``_calc_statuses``
        before being forwarded; the other arguments are passed through as-is.
        """
        statuses = self._calc_statuses(statuses, exclude_statuses)
        node_info_ref = await self._get_node_info_ref()
        return await node_info_ref.get_nodes_info(
            nodes=nodes,
            role=role,
            env=env,
            resource=resource,
            detail=detail,
            statuses=statuses,
        )

    async def set_node_status(self, node: str, role: NodeRole, status: NodeStatus):
        """
        Set status of node

        Parameters
        ----------
        node : str
            address of node
        role: NodeRole
            role of node
        status : NodeStatus
            status of node
        """
        node_info_ref = await self._get_node_info_ref()
        await node_info_ref.update_node_info(node, role, status=status)

    async def get_all_bands(
        self,
        role: Optional[NodeRole] = None,
        statuses: Optional[Set[NodeStatus]] = None,
        exclude_statuses: Optional[Set[NodeStatus]] = None,
    ) -> Dict[BandType, Resource]:
        """Get bands and their resources from the supervisor ``NodeInfoCollectorActor``."""
        statuses = self._calc_statuses(statuses, exclude_statuses)
        node_info_ref = await self._get_node_info_ref()
        return await node_info_ref.get_all_bands(role, statuses=statuses)

    @watch_method
    async def watch_all_bands(
        self,
        role: Optional[NodeRole] = None,
        version: Optional[int] = None,
        statuses: Optional[Set[NodeStatus]] = None,
        exclude_statuses: Optional[Set[NodeStatus]] = None,
    ):
        """Watch bands through the supervisor ``NodeInfoCollectorActor``."""
        statuses = self._calc_statuses(statuses, exclude_statuses)
        node_info_ref = await self._get_node_info_ref()
        return await node_info_ref.watch_all_bands(
            role, statuses=statuses, version=version
        )

    async def get_mars_versions(self) -> List[str]:
        """Get Mars versions reported by the supervisor ``NodeInfoCollectorActor``."""
        node_info_ref = await self._get_node_info_ref()
        return await node_info_ref.get_mars_versions()

    async def get_bands(self) -> Dict:
        """
        Get bands that can be used for computation on current node.

        Returns
        -------
        band_to_resource : dict
            Band to resource.
        """
        return await self._uploader_ref.get_bands()

    async def mark_node_ready(self):
        """
        Mark current node ready for work loads
        """
        await self._uploader_ref.mark_node_ready()

    async def wait_node_ready(self):
        """
        Wait current node to be ready
        """
        await self._uploader_ref.wait_node_ready()

    async def wait_all_supervisors_ready(self):
        """
        Wait till all expected supervisors are ready
        """
        await self._locator_ref.wait_all_supervisors_ready()

    async def set_band_slot_infos(
        self, band_name: str, slot_infos: List[WorkerSlotInfo]
    ):
        """Send slot infos of a band to the local uploader via ``tell``."""
        # ``tell`` presumably does not wait for a result -- confirm in oscar docs
        await self._uploader_ref.set_band_slot_infos.tell(band_name, slot_infos)

    async def set_band_quota_info(self, band_name: str, quota_info: QuotaInfo):
        """Send quota info of a band to the local uploader via ``tell``."""
        await self._uploader_ref.set_band_quota_info.tell(band_name, quota_info)

    async def set_node_disk_info(self, disk_info: List[DiskInfo]):
        """Record disk info of the current node in the local uploader."""
        await self._uploader_ref.set_node_disk_info(disk_info)

    @mo.extensible
    async def set_band_storage_info(self, band_name: str, storage_info: StorageInfo):
        """Record storage info of a band in the local uploader."""
        await self._uploader_ref.set_band_storage_info(band_name, storage_info)

    async def request_worker(
        self,
        worker_cpu: Optional[int] = None,
        worker_mem: Optional[int] = None,
        timeout: Optional[int] = None,
    ) -> str:
        """
        Request a new worker from the supervisor ``NodeAllocatorActor``.

        Returns
        -------
        address : str
            address returned by the allocator
        """
        node_allocator_ref = await self._get_node_allocator_ref()
        return await node_allocator_ref.request_worker(worker_cpu, worker_mem, timeout)

    async def release_worker(self, address: str):
        """
        Release a worker through the allocator and mark it ``STOPPED``.

        The worker is released first; only then is its node info updated.
        """
        node_allocator_ref = await self._get_node_allocator_ref()
        await node_allocator_ref.release_worker(address)
        node_info_ref = await self._get_node_info_ref()
        await node_info_ref.update_node_info(
            address, NodeRole.WORKER, status=NodeStatus.STOPPED
        )

    async def reconstruct_worker(self, address: str):
        """Ask the supervisor ``NodeAllocatorActor`` to reconstruct the worker at ``address``."""
        node_allocator_ref = await self._get_node_allocator_ref()
        await node_allocator_ref.reconstruct_worker(address)

    @alru_cache(cache_exceptions=False)
    async def _get_node_allocator_ref(self):
        """Locate the supervisor-side ``NodeAllocatorActor`` (memoized)."""
        from ..supervisor.node_allocator import NodeAllocatorActor

        [node_allocator_ref] = await self.get_supervisor_refs(
            [NodeAllocatorActor.default_uid()]
        )
        return node_allocator_ref

    async def _get_process_info_manager_ref(self, address: Optional[str] = None):
        """Get the ``ProcessInfoManagerActor`` at ``address`` (defaults to own address)."""
        from ..procinfo import ProcessInfoManagerActor

        return await mo.actor_ref(
            ProcessInfoManagerActor.default_uid(), address=address or self._address
        )

    async def get_node_pool_configs(self, address: Optional[str] = None) -> List[Dict]:
        """Get pool configs reported by the process-info manager of a node."""
        ref = await self._get_process_info_manager_ref(address)
        return await ref.get_pool_configs()

    async def get_node_thread_stacks(
        self, address: Optional[str] = None
    ) -> List[Dict[int, List[str]]]:
        """Get thread stacks reported by the process-info manager of a node."""
        ref = await self._get_process_info_manager_ref(address)
        return await ref.get_thread_stacks()
class MockClusterAPI(ClusterAPI):
    """
    Test helper that hosts both supervisor- and worker-side cluster actors
    on a single address and returns a ``ClusterAPI`` bound to it.
    """

    @classmethod
    async def create(cls: Type[APIType], address: str, **kw) -> APIType:
        """
        Create all cluster actors on ``address`` and return a ready API object.

        Parameters
        ----------
        address : str
            address of the pool hosting every actor
        kw
            ``upload_interval``, ``band_to_resource`` and ``use_gpu``
            (default ``False``) are forwarded to ``NodeInfoUploaderActor``
        """
        from ..procinfo import ProcessInfoManagerActor
        from ..supervisor.locator import SupervisorPeerLocatorActor
        from ..supervisor.node_allocator import NodeAllocatorActor
        from ..supervisor.node_info import NodeInfoCollectorActor
        from ..uploader import NodeInfoUploaderActor

        create_actor_coros = [
            mo.create_actor(
                SupervisorPeerLocatorActor,
                "fixed",
                address,
                uid=SupervisorPeerLocatorActor.default_uid(),
                address=address,
            ),
            mo.create_actor(
                NodeInfoCollectorActor,
                uid=NodeInfoCollectorActor.default_uid(),
                address=address,
            ),
            mo.create_actor(
                NodeAllocatorActor,
                "fixed",
                address,
                uid=NodeAllocatorActor.default_uid(),
                address=address,
            ),
            mo.create_actor(
                NodeInfoUploaderActor,
                NodeRole.WORKER,
                interval=kw.get("upload_interval"),
                band_to_resource=kw.get("band_to_resource"),
                use_gpu=kw.get("use_gpu", False),
                uid=NodeInfoUploaderActor.default_uid(),
                address=address,
            ),
            mo.create_actor(
                ProcessInfoManagerActor,
                uid=ProcessInfoManagerActor.default_uid(),
                address=address,
            ),
        ]
        dones, _ = await asyncio.wait(
            [asyncio.ensure_future(coro) for coro in create_actor_coros]
        )

        # Re-raise any creation failure except "actor already exists".
        for task in dones:
            try:
                task.result()
            except mo.ActorAlreadyExist:  # pragma: no cover
                pass

        api = await super().create(address=address)
        await api.mark_node_ready()
        return api

    @classmethod
    async def cleanup(cls, address: str):
        """
        Destroy every actor that ``create`` spawned on ``address``.

        Previously only the locator, node-info collector and uploader were
        destroyed, leaving ``NodeAllocatorActor`` and
        ``ProcessInfoManagerActor`` alive in the pool.
        """
        from ..procinfo import ProcessInfoManagerActor
        from ..supervisor.locator import SupervisorPeerLocatorActor
        from ..supervisor.node_allocator import NodeAllocatorActor
        from ..supervisor.node_info import NodeInfoCollectorActor
        from ..uploader import NodeInfoUploaderActor

        # keep in sync with the actors created in ``create``
        actor_classes = [
            SupervisorPeerLocatorActor,
            NodeInfoCollectorActor,
            NodeAllocatorActor,
            NodeInfoUploaderActor,
            ProcessInfoManagerActor,
        ]
        await asyncio.gather(
            *[
                mo.destroy_actor(
                    mo.create_actor_ref(uid=actor_cls.default_uid(), address=address)
                )
                for actor_cls in actor_classes
            ]
        )