# Copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
from typing import List, Dict, Optional, Set, Type, TypeVar
from .... import oscar as mo
from ....lib.aio import alru_cache
from ....resource import Resource
from ....typing import BandType
from ...core import NodeRole
from ..core import (
watch_method,
NodeStatus,
WorkerSlotInfo,
QuotaInfo,
DiskInfo,
StorageInfo,
)
from .core import AbstractClusterAPI
# Type variable bound to ClusterAPI; the ``create`` classmethods below annotate
# ``cls`` as ``Type[APIType]`` and their return value as ``APIType``.
APIType = TypeVar("APIType", bound="ClusterAPI")
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class ClusterAPI(AbstractClusterAPI):
    """
    Cluster API bound to the actor pool at a single address.

    Instances hold references to the supervisor locator actor and the node
    info uploader actor living at ``address``; supervisor-side actors (node
    info collector, node allocator) are resolved on demand through the
    locator. Use :meth:`create` rather than the constructor so that the
    actor references are resolved before the object is used.
    """

    def __init__(self, address: str):
        """
        Store the address; actor references stay ``None`` until ``_init`` runs.

        Parameters
        ----------
        address : str
            address of the actor pool this API talks to
        """
        self._address = address
        self._locator_ref = None
        self._uploader_ref = None

    async def _init(self):
        """
        Resolve the locator and uploader actor references at ``self._address``.
        """
        # imported lazily inside the method, as in the rest of this class
        from ..locator import SupervisorLocatorActor
        from ..uploader import NodeInfoUploaderActor

        self._locator_ref = await mo.actor_ref(
            SupervisorLocatorActor.default_uid(), address=self._address
        )
        self._uploader_ref = await mo.actor_ref(
            NodeInfoUploaderActor.default_uid(), address=self._address
        )

    @classmethod
    @alru_cache(cache_exceptions=False)
    async def create(cls: Type[APIType], address: str) -> APIType:
        """
        Build an API object for ``address`` and resolve its actor references.

        Parameters
        ----------
        address : str
            address of the actor pool this API talks to

        Returns
        -------
        out
            initialized API object
        """
        # NOTE(review): wrapped with ``alru_cache(cache_exceptions=False)``;
        # presumably repeated calls with equal arguments share one instance
        # and failures are not cached -- confirm against ``lib.aio``.
        api_obj = cls(address)
        await api_obj._init()
        return api_obj

    @alru_cache(cache_exceptions=False)
    async def _get_node_info_ref(self):
        """
        Resolve the reference of ``NodeInfoCollectorActor`` on its supervisor.
        """
        from ..supervisor.node_info import NodeInfoCollectorActor

        [node_info_ref] = await self.get_supervisor_refs(
            [NodeInfoCollectorActor.default_uid()]
        )
        return node_info_ref

    async def get_supervisors(self, filter_ready: bool = True) -> List[str]:
        """
        Get supervisor addresses from the locator actor.

        Parameters
        ----------
        filter_ready : bool
            forwarded unchanged to the locator's ``get_supervisors``

        Returns
        -------
        out
            addresses of supervisors
        """
        return await self._locator_ref.get_supervisors(filter_ready=filter_ready)

    @watch_method
    async def watch_supervisors(self, version: Optional[int] = None):
        """
        Forward to the locator's ``watch_supervisors`` with ``version``.
        """
        return await self._locator_ref.watch_supervisors(version=version)

    async def get_supervisors_by_keys(self, keys: List[str]) -> List[str]:
        """
        Get supervisor addresses hosting the specified keys

        Parameters
        ----------
        keys
            keys to look up, one supervisor address is returned per key

        Returns
        -------
        out
            addresses of the supervisors, in the order of ``keys``
        """
        get_supervisor = self._locator_ref.get_supervisor
        # one batched actor call covering all keys instead of one call per key
        return await get_supervisor.batch(*(get_supervisor.delay(k) for k in keys))

    @watch_method
    async def watch_supervisors_by_keys(
        self, keys: List[str], version: Optional[int] = None
    ):
        """
        Forward to the locator's ``watch_supervisors_by_keys``.

        Parameters
        ----------
        keys
            keys to look up
        version
            forwarded unchanged to the locator
        """
        return await self._locator_ref.watch_supervisors_by_keys(keys, version=version)

    async def get_supervisor_refs(self, uids: List[str]) -> List[mo.ActorRef]:
        """
        Get actor references hosting the specified actor uid

        Parameters
        ----------
        uids
            uids of the actors; each uid is also used as the key to locate
            the supervisor address hosting it

        Returns
        -------
        out : List[mo.ActorRef]
            references of the actors

        Raises
        ------
        mo.ActorNotExist
            when no supervisor address can be found for one of the uids
        """
        addrs = await self.get_supervisors_by_keys(uids)
        if any(addr is None for addr in addrs):
            # report the first uid that could not be mapped to a supervisor
            none_uid = next(uid for addr, uid in zip(addrs, uids) if addr is None)
            raise mo.ActorNotExist(f"Actor {none_uid} not exist as no supervisors")
        return await asyncio.gather(
            *[mo.actor_ref(uid, address=addr) for addr, uid in zip(addrs, uids)]
        )

    async def watch_supervisor_refs(self, uids: List[str]):
        """
        Async generator yielding actor references for every address update.

        For each address list produced by ``watch_supervisors_by_keys`` a
        list of actor references, one per uid, is yielded.

        Parameters
        ----------
        uids
            uids of the actors to resolve
        """
        # NOTE(review): unlike ``get_supervisor_refs`` there is no check for
        # ``None`` addresses here -- confirm whether that is intended.
        async for addrs in self.watch_supervisors_by_keys(uids):
            yield await asyncio.gather(
                *[mo.actor_ref(uid, address=addr) for addr, uid in zip(addrs, uids)]
            )

    @watch_method
    async def watch_nodes(
        self,
        role: NodeRole,
        env: bool = False,
        resource: bool = False,
        detail: bool = False,
        version: Optional[int] = None,
        statuses: Optional[Set[NodeStatus]] = None,
        exclude_statuses: Optional[Set[NodeStatus]] = None,
    ) -> List[Dict[str, Dict]]:
        """
        Forward to the node info collector's ``watch_nodes``.

        ``statuses`` and ``exclude_statuses`` are first combined by
        ``self._calc_statuses``; all other arguments are passed through.
        """
        statuses = self._calc_statuses(statuses, exclude_statuses)
        node_info_ref = await self._get_node_info_ref()
        return await node_info_ref.watch_nodes(
            role,
            env=env,
            resource=resource,
            detail=detail,
            statuses=statuses,
            version=version,
        )

    async def get_nodes_info(
        self,
        nodes: Optional[List[str]] = None,
        role: Optional[NodeRole] = None,
        env: bool = False,
        resource: bool = False,
        detail: bool = False,
        statuses: Optional[Set[NodeStatus]] = None,
        exclude_statuses: Optional[Set[NodeStatus]] = None,
    ) -> Dict[str, Dict]:
        """
        Forward to the node info collector's ``get_nodes_info``.

        ``statuses`` and ``exclude_statuses`` are first combined by
        ``self._calc_statuses``; all other arguments are passed through.
        """
        statuses = self._calc_statuses(statuses, exclude_statuses)
        node_info_ref = await self._get_node_info_ref()
        return await node_info_ref.get_nodes_info(
            nodes=nodes,
            role=role,
            env=env,
            resource=resource,
            detail=detail,
            statuses=statuses,
        )

    async def set_node_status(self, node: str, role: NodeRole, status: NodeStatus):
        """
        Set status of node

        Parameters
        ----------
        node : str
            address of node
        role: NodeRole
            role of node
        status : NodeStatus
            status of node
        """
        node_info_ref = await self._get_node_info_ref()
        await node_info_ref.update_node_info(node, role, status=status)

    async def get_all_bands(
        self,
        role: Optional[NodeRole] = None,
        statuses: Optional[Set[NodeStatus]] = None,
        exclude_statuses: Optional[Set[NodeStatus]] = None,
    ) -> Dict[BandType, Resource]:
        """
        Forward to the node info collector's ``get_all_bands``.

        ``statuses`` and ``exclude_statuses`` are first combined by
        ``self._calc_statuses``.
        """
        statuses = self._calc_statuses(statuses, exclude_statuses)
        node_info_ref = await self._get_node_info_ref()
        return await node_info_ref.get_all_bands(role, statuses=statuses)

    @watch_method
    async def watch_all_bands(
        self,
        role: Optional[NodeRole] = None,
        version: Optional[int] = None,
        statuses: Optional[Set[NodeStatus]] = None,
        exclude_statuses: Optional[Set[NodeStatus]] = None,
    ):
        """
        Forward to the node info collector's ``watch_all_bands``.

        ``statuses`` and ``exclude_statuses`` are first combined by
        ``self._calc_statuses``.
        """
        statuses = self._calc_statuses(statuses, exclude_statuses)
        node_info_ref = await self._get_node_info_ref()
        return await node_info_ref.watch_all_bands(
            role, statuses=statuses, version=version
        )

    async def get_mars_versions(self) -> List[str]:
        """
        Return the result of the node info collector's ``get_mars_versions``.
        """
        node_info_ref = await self._get_node_info_ref()
        return await node_info_ref.get_mars_versions()

    async def get_bands(self) -> Dict:
        """
        Get bands that can be used for computation on current node.

        Returns
        -------
        band_to_resource : dict
            Band to resource.
        """
        return await self._uploader_ref.get_bands()

    async def mark_node_ready(self):
        """
        Mark current node ready for work loads
        """
        await self._uploader_ref.mark_node_ready()

    async def wait_node_ready(self):
        """
        Wait current node to be ready
        """
        await self._uploader_ref.wait_node_ready()

    async def wait_all_supervisors_ready(self):
        """
        Wait till all expected supervisors are ready
        """
        await self._locator_ref.wait_all_supervisors_ready()

    async def set_band_slot_infos(
        self, band_name: str, slot_infos: List[WorkerSlotInfo]
    ):
        """
        Send slot infos of a band to the uploader actor via ``tell``.
        """
        await self._uploader_ref.set_band_slot_infos.tell(band_name, slot_infos)

    async def set_band_quota_info(self, band_name: str, quota_info: QuotaInfo):
        """
        Send quota info of a band to the uploader actor via ``tell``.
        """
        await self._uploader_ref.set_band_quota_info.tell(band_name, quota_info)

    async def set_node_disk_info(self, disk_info: List[DiskInfo]):
        """
        Pass disk info of the current node to the uploader actor.
        """
        await self._uploader_ref.set_node_disk_info(disk_info)

    @mo.extensible
    async def set_band_storage_info(self, band_name: str, storage_info: StorageInfo):
        """
        Pass storage info of a band to the uploader actor.
        """
        await self._uploader_ref.set_band_storage_info(band_name, storage_info)

    async def request_worker(
        self,
        worker_cpu: Optional[int] = None,
        worker_mem: Optional[int] = None,
        timeout: Optional[int] = None,
    ) -> str:
        """
        Ask the node allocator actor for a worker.

        Parameters
        ----------
        worker_cpu : int
            forwarded unchanged to the allocator
        worker_mem : int
            forwarded unchanged to the allocator
        timeout : int
            forwarded unchanged to the allocator

        Returns
        -------
        address : str
            value returned by the allocator's ``request_worker``
        """
        node_allocator_ref = await self._get_node_allocator_ref()
        address = await node_allocator_ref.request_worker(
            worker_cpu, worker_mem, timeout
        )
        return address

    async def release_worker(self, address: str):
        """
        Release a worker and record it as stopped.

        Parameters
        ----------
        address : str
            address of the worker to release
        """
        node_allocator_ref = await self._get_node_allocator_ref()
        await node_allocator_ref.release_worker(address)
        # after the allocator released the worker, update its node status
        node_info_ref = await self._get_node_info_ref()
        await node_info_ref.update_node_info(
            address, NodeRole.WORKER, status=NodeStatus.STOPPED
        )

    async def reconstruct_worker(self, address: str):
        """
        Ask the node allocator actor to reconstruct the worker at ``address``.
        """
        node_allocator_ref = await self._get_node_allocator_ref()
        await node_allocator_ref.reconstruct_worker(address)

    @alru_cache(cache_exceptions=False)
    async def _get_node_allocator_ref(self):
        """
        Resolve the reference of ``NodeAllocatorActor`` on its supervisor.
        """
        from ..supervisor.node_allocator import NodeAllocatorActor

        [node_allocator_ref] = await self.get_supervisor_refs(
            [NodeAllocatorActor.default_uid()]
        )
        return node_allocator_ref

    async def _get_process_info_manager_ref(self, address: Optional[str] = None):
        """
        Get the ``ProcessInfoManagerActor`` reference at ``address``.

        Falls back to ``self._address`` when ``address`` is falsy.
        """
        from ..procinfo import ProcessInfoManagerActor

        return await mo.actor_ref(
            ProcessInfoManagerActor.default_uid(), address=address or self._address
        )

    async def get_node_pool_configs(self, address: Optional[str] = None) -> List[Dict]:
        """
        Return ``get_pool_configs()`` of the process info manager at ``address``.

        Parameters
        ----------
        address : str
            address of the node, the address of this API when not given
        """
        ref = await self._get_process_info_manager_ref(address)
        return await ref.get_pool_configs()

    async def get_node_thread_stacks(
        self, address: Optional[str] = None
    ) -> List[Dict[int, List[str]]]:
        """
        Return ``get_thread_stacks()`` of the process info manager at ``address``.

        Parameters
        ----------
        address : str
            address of the node, the address of this API when not given
        """
        ref = await self._get_process_info_manager_ref(address)
        return await ref.get_thread_stacks()
class MockClusterAPI(ClusterAPI):
    """
    ClusterAPI variant that creates the actors it depends on by itself.

    ``create`` spawns the locator, node info collector, node allocator,
    node info uploader and process info manager actors at one address
    before building the API object; ``cleanup`` destroys actors again.
    """

    @classmethod
    async def create(cls: Type[APIType], address: str, **kw) -> APIType:
        """
        Spawn the required actors at ``address`` and return a ready API.

        Parameters
        ----------
        address : str
            address at which the actors are created
        kw
            ``upload_interval``, ``band_to_resource`` and ``use_gpu``
            (default ``False``) are read and handed to the uploader actor

        Returns
        -------
        out
            API object on which ``mark_node_ready`` has been awaited
        """
        from ..procinfo import ProcessInfoManagerActor
        from ..supervisor.locator import SupervisorPeerLocatorActor
        from ..supervisor.node_allocator import NodeAllocatorActor
        from ..supervisor.node_info import NodeInfoCollectorActor
        from ..uploader import NodeInfoUploaderActor

        def _spawn(actor_cls, *args, **options):
            # every actor is registered under its default uid at ``address``
            return mo.create_actor(
                actor_cls,
                *args,
                **options,
                uid=actor_cls.default_uid(),
                address=address,
            )

        pending_creations = [
            _spawn(SupervisorPeerLocatorActor, "fixed", address),
            _spawn(NodeInfoCollectorActor),
            _spawn(NodeAllocatorActor, "fixed", address),
            _spawn(
                NodeInfoUploaderActor,
                NodeRole.WORKER,
                interval=kw.get("upload_interval"),
                band_to_resource=kw.get("band_to_resource"),
                use_gpu=kw.get("use_gpu", False),
            ),
            _spawn(ProcessInfoManagerActor),
        ]
        creation_tasks = [asyncio.ensure_future(c) for c in pending_creations]
        finished, _ = await asyncio.wait(creation_tasks)
        for creation in finished:
            # actors that already exist are tolerated; any other
            # failure is re-raised by ``result()``
            try:
                creation.result()
            except mo.ActorAlreadyExist:  # pragma: no cover
                pass

        api = await super().create(address=address)
        await api.mark_node_ready()
        return api

    @classmethod
    async def cleanup(cls, address: str):
        """
        Destroy the locator, node info collector and uploader at ``address``.

        Parameters
        ----------
        address : str
            address at which the actors were created
        """
        from ..supervisor.locator import SupervisorPeerLocatorActor
        from ..uploader import NodeInfoUploaderActor
        from ..supervisor.node_info import NodeInfoCollectorActor

        # NOTE(review): ``create`` also spawns NodeAllocatorActor and
        # ProcessInfoManagerActor, which are not destroyed here -- confirm
        # whether that is intentional.
        doomed_actor_types = (
            SupervisorPeerLocatorActor,
            NodeInfoCollectorActor,
            NodeInfoUploaderActor,
        )
        await asyncio.gather(
            *(
                mo.destroy_actor(
                    mo.create_actor_ref(uid=actor_cls.default_uid(), address=address)
                )
                for actor_cls in doomed_actor_types
            )
        )