# Source code for mars.services.lifecycle.api.oscar
# Copyright 1999-2021 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List
from .... import oscar as mo
from ....lib.aio import alru_cache
from ..supervisor.tracker import LifecycleTrackerActor
from .core import AbstractLifecycleAPI
class LifecycleAPI(AbstractLifecycleAPI):
    """
    Lifecycle API that forwards every call to a ``LifecycleTrackerActor``
    through the actor reference held by the instance.
    """

    def __init__(
        self,
        session_id: str,
        lifecycle_tracker_ref: mo.ActorRefType[LifecycleTrackerActor],
    ):
        """
        Store the session ID and the tracker actor reference.

        Parameters
        ----------
        session_id : str
            Session ID.
        lifecycle_tracker_ref : mo.ActorRefType[LifecycleTrackerActor]
            Reference to the lifecycle tracker actor that serves the calls.
        """
        self._session_id = session_id
        self._lifecycle_tracker_ref = lifecycle_tracker_ref

    @classmethod
    @alru_cache(cache_exceptions=False)
    async def create(cls, session_id: str, address: str) -> "LifecycleAPI":
        """
        Create Lifecycle API.

        The tracker actor is looked up at ``address`` using the uid produced
        by ``LifecycleTrackerActor.gen_uid(session_id)``.

        Parameters
        ----------
        session_id : str
            Session ID.
        address : str
            Supervisor address.

        Returns
        -------
        lifecycle_api
            Lifecycle API.
        """
        # NOTE(review): wrapped by ``alru_cache(cache_exceptions=False)`` --
        # presumably results are cached per argument set and failed lookups
        # are not cached; confirm against ``mars.lib.aio``.
        lifecycle_tracker_ref = await mo.actor_ref(
            address, LifecycleTrackerActor.gen_uid(session_id)
        )
        return LifecycleAPI(session_id, lifecycle_tracker_ref)

    @mo.extensible
    async def track(self, tileable_key: str, chunk_keys: List[str]):
        """
        Track tileable.

        Parameters
        ----------
        tileable_key : str
            Tileable key.
        chunk_keys : list
            List of chunk keys.
        """
        return await self._lifecycle_tracker_ref.track(tileable_key, chunk_keys)

    @track.batch
    async def batch_track(self, args_list, kwargs_list):
        """
        Batch version of :meth:`track`.

        Each ``(args, kwargs)`` pair is turned into a delayed ``track`` call
        on the tracker actor, and all of them are submitted together.

        Parameters
        ----------
        args_list : list
            Positional arguments of every batched ``track`` call.
        kwargs_list : list
            Keyword arguments of every batched ``track`` call.
        """
        tracks = [
            self._lifecycle_tracker_ref.track.delay(*args, **kwargs)
            for args, kwargs in zip(args_list, kwargs_list)
        ]
        return await self._lifecycle_tracker_ref.track.batch(*tracks)

    async def incref_tileables(
        self, tileable_keys: List[str], counts: List[int] = None
    ):
        """
        Incref tileables.

        Parameters
        ----------
        tileable_keys : list
            List of tileable keys.
        counts: list
            List of ref count.
        """
        return await self._lifecycle_tracker_ref.incref_tileables(
            tileable_keys, counts=counts
        )

    async def decref_tileables(
        self, tileable_keys: List[str], counts: List[int] = None
    ):
        """
        Decref tileables.

        Parameters
        ----------
        tileable_keys : list
            List of tileable keys.
        counts: list
            List of ref count.
        """
        # ``counts`` must be forwarded like in the other incref/decref
        # methods; otherwise caller-supplied counts are silently ignored.
        return await self._lifecycle_tracker_ref.decref_tileables(
            tileable_keys, counts=counts
        )

    async def get_tileable_ref_counts(self, tileable_keys: List[str]) -> List[int]:
        """
        Get ref counts of tileables.

        Parameters
        ----------
        tileable_keys : list
            List of tileable keys.

        Returns
        -------
        ref_counts : list
            List of ref counts.
        """
        return await self._lifecycle_tracker_ref.get_tileable_ref_counts(tileable_keys)

    async def incref_chunks(self, chunk_keys: List[str], counts: List[int] = None):
        """
        Incref chunks.

        Parameters
        ----------
        chunk_keys : list
            List of chunk keys.
        counts: list
            List of ref count.
        """
        return await self._lifecycle_tracker_ref.incref_chunks(
            chunk_keys, counts=counts
        )

    async def decref_chunks(self, chunk_keys: List[str], counts: List[int] = None):
        """
        Decref chunks.

        Parameters
        ----------
        chunk_keys : list
            List of chunk keys.
        counts: list
            List of ref count.
        """
        return await self._lifecycle_tracker_ref.decref_chunks(
            chunk_keys, counts=counts
        )

    async def get_chunk_ref_counts(self, chunk_keys: List[str]) -> List[int]:
        """
        Get ref counts of chunks.

        Parameters
        ----------
        chunk_keys : list
            List of chunk keys.

        Returns
        -------
        ref_counts : list
            List of ref counts.
        """
        return await self._lifecycle_tracker_ref.get_chunk_ref_counts(chunk_keys)

    async def get_all_chunk_ref_counts(self) -> Dict[str, int]:
        """
        Get all chunk keys' ref counts.

        Returns
        -------
        key_to_ref_counts: dict
        """
        return await self._lifecycle_tracker_ref.get_all_chunk_ref_counts()
class MockLifecycleAPI(LifecycleAPI):
    """
    Lifecycle API variant whose ``create`` first creates the session on a
    ``LifecycleSupervisorService`` and then defers to ``LifecycleAPI.create``.
    """

    @classmethod
    async def create(cls, session_id: str, address: str) -> "LifecycleAPI":
        """
        Create the session on a supervisor service, then build the API.

        Parameters
        ----------
        session_id : str
            Session ID.
        address : str
            Supervisor address.

        Returns
        -------
        lifecycle_api
            Lifecycle API returned by the parent ``create``.
        """
        # Imported inside the method, as in the original code.
        from ..supervisor.service import LifecycleSupervisorService

        # The service is constructed with an empty config dict.
        supervisor_service = LifecycleSupervisorService({}, address)
        await supervisor_service.create_session(session_id)

        # Keyword arguments are kept exactly as before when delegating.
        lifecycle_api = await super().create(session_id=session_id, address=address)
        return lifecycle_api