Source code for grunnur.context

from collections.abc import Sequence
from typing import Any, Optional, Union, Iterable, Iterator, Tuple, Sequence, overload

from .adapter_base import DeviceAdapter, ContextAdapter, DeviceAdapter
from .api import API, cuda_api_id
from .device import Device, DeviceFilter
from .device_discovery import select_devices
from .platform import Platform, PlatformFilter
from .utils import normalize_object_sequence, all_same


class BoundDevice(Device):
    """
    A :py:class:`~grunnur.Device` object in a :py:class:`~grunnur.Context`.
    """

    context: "Context"
    """The context this device belongs to."""

    def __init__(self, context: "Context", device_adapter: DeviceAdapter):
        super().__init__(device_adapter)
        self.context = context
        # Only the device adapter participates in the hash. Hashing the context too
        # would require `Context` to be hashable, and `BoundDevice` objects only live
        # in small collections where every device index differs anyway.
        # Any accidental collision is resolved by ``__eq__``.
        self._hash = hash(device_adapter)

    def as_unbound(self) -> "Device":
        """
        :meta private:

        Returns a plain :py:class:`Device` built from the same device adapter,
        without the context binding.
        """
        adapter = self._device_adapter
        return Device(adapter)

    def __eq__(self, other: Any) -> bool:
        # Objects of different types never compare equal.
        if type(self) != type(other):
            return False
        return self.context == other.context and super().__eq__(other)

    def __hash__(self) -> int:
        return self._hash

    def __str__(self) -> str:
        return f"{super().__str__()} in {self.context}"
class BoundMultiDevice(Sequence[BoundDevice]):
    """
    A sequence of bound devices belonging to the same context.
    """

    context: "Context"
    """The context these devices belong to."""

    @classmethod
    def from_bound_devices(cls, devices: Sequence[BoundDevice]) -> "BoundMultiDevice":
        """
        Creates this object from a sequence of bound devices
        (note that a ``BoundMultiDevice`` object itself can serve as such a sequence).

        :param devices: a non-empty sequence of distinct devices bound to the same context.
        :raises ValueError: if ``devices`` is empty, if the devices belong to different
            contexts, or if some of the devices are repeated.
        """
        if len(devices) == 0:
            # The context is taken from the first device, so there must be at least one
            # (an out-of-range slice of a multi-device can produce an empty list here).
            raise ValueError("A multi-device must contain at least one device")
        if not all_same(device.context for device in devices):
            raise ValueError("All devices in a multi-device must belong to the same context")
        if len(set(devices)) != len(devices):
            raise ValueError("All devices in a multi-device must be distinct")
        return cls(devices[0].context, [device._device_adapter for device in devices])

    def __init__(self, context: "Context", device_adapters: Sequence[DeviceAdapter]):
        self.context = context
        self._devices = [BoundDevice(context, device_adapter) for device_adapter in device_adapters]
        # Kept alongside the list for fast subset checks in ``issubset``.
        self._devices_as_set = set(self._devices)

    def __eq__(self, other: Any) -> bool:
        return (
            type(self) == type(other)
            and self.context == other.context
            and self._devices == other._devices
        )

    def issubset(self, devices: "BoundMultiDevice") -> bool:
        """
        Returns ``True`` if every device of this object is also present in ``devices``.
        """
        return self._devices_as_set.issubset(devices._devices_as_set)

    def __iter__(self) -> Iterator[BoundDevice]:
        return iter(self._devices)

    @overload
    def __getitem__(self, idx: int) -> BoundDevice:
        ...

    @overload
    def __getitem__(self, idx: Union[slice, Iterable[int]]) -> "BoundMultiDevice":
        ...

    def __getitem__(
        self, idx: Union[int, slice, Iterable[int]]
    ) -> Union[BoundDevice, "BoundMultiDevice"]:
        """
        Given a single index, returns a single :py:class:`BoundDevice`.
        Given a slice or a sequence of indices, returns a :py:class:`BoundMultiDevice` object
        containing respective devices.

        The indices correspond to the list of devices used to create this context.
        """
        if isinstance(idx, slice):
            return BoundMultiDevice.from_bound_devices(self._devices[idx])
        elif isinstance(idx, Iterable):
            return BoundMultiDevice.from_bound_devices([self._devices[i] for i in idx])
        else:
            return self._devices[idx]

    def __len__(self) -> int:
        return len(self._devices)
class Context:
    """
    GPGPU context.
    """

    platform: Platform
    """The platform this context is based on."""

    api: API
    """The API this context is based on."""

    @staticmethod
    def _as_object_list(objects: Any) -> Sequence[Any]:
        """
        Normalizes a "single or several backend objects" argument into a sequence.

        Sequences are returned unchanged, other iterables are materialized into a list,
        and a non-iterable object is treated as a single object and wrapped into a list.
        """
        if isinstance(objects, Sequence):
            return objects
        if isinstance(objects, Iterable):
            return list(objects)
        return [objects]

    @classmethod
    def from_devices(cls, devices: Sequence["Device"]) -> "Context":
        """
        Creates a context from a device or an iterable of devices.

        :param devices: one or several devices to use.
        :raises ValueError: if the devices belong to different platforms.
        """
        devices = normalize_object_sequence(devices, Device)
        platforms = [device.platform for device in devices]
        if not all_same(platforms):
            raise ValueError("All devices must belong to the same platform")

        platform = platforms[0]
        device_adapters = [device._device_adapter for device in devices]
        api_adapter = platform.api._api_adapter
        context_adapter = api_adapter.make_context_adapter_from_device_adapters(device_adapters)
        return cls(context_adapter)

    @classmethod
    def from_backend_devices(cls, backend_devices: Sequence[Any]) -> "Context":
        """
        Creates a context from a single or several backend device objects.

        :param backend_devices: a single backend device object, or an iterable of them.
        """
        # A single (non-iterable) backend device would otherwise fail in the loop below.
        backend_devices = cls._as_object_list(backend_devices)
        devices = [Device.from_backend_device(backend_device) for backend_device in backend_devices]
        return cls.from_devices(devices)

    @classmethod
    def from_backend_contexts(
        cls, backend_contexts: Sequence[Any], take_ownership: bool = False
    ) -> "Context":
        """
        Creates a context from a single or several backend device contexts.
        If ``take_ownership`` is ``True``, this object will be responsible for the lifetime
        of backend context objects (only important for the CUDA backend).

        :param backend_contexts: a single backend context object, or an iterable of them.
        :param take_ownership: whether this object manages the backend contexts' lifetime.
        :raises ValueError: if no backend contexts were given.
        :raises TypeError: if the backend contexts are not recognized by any available API.
        """
        # A single backend context is not subscriptable, so normalize before indexing.
        backend_contexts = cls._as_object_list(backend_contexts)
        if len(backend_contexts) == 0:
            raise ValueError("At least one backend context must be provided")

        # The API is detected from the first context; the adapter validates the rest.
        first_context = backend_contexts[0]
        for api in API.all_available():
            api_adapter = api._api_adapter
            if api_adapter.isa_backend_context(first_context):
                context_adapter = api_adapter.make_context_adapter_from_backend_contexts(
                    backend_contexts, take_ownership=take_ownership
                )
                return cls(context_adapter)
        raise TypeError(
            f"{type(first_context)} objects were not recognized as contexts by any API"
        )

    @classmethod
    def from_criteria(
        cls,
        api: "API",
        interactive: bool = False,
        devices_num: Optional[int] = 1,
        device_filter: Optional[DeviceFilter] = None,
        platform_filter: Optional[PlatformFilter] = None,
    ) -> "Context":
        """
        Finds devices matching the given criteria and creates a
        :py:class:`Context` object out of them.

        :param api: the API to search devices in; passed to :py:func:`select_devices`.
        :param interactive: passed to :py:func:`select_devices`.
        :param devices_num: passed to :py:func:`select_devices` as ``quantity``.
        :param device_filter: passed to :py:func:`select_devices`.
        :param platform_filter: passed to :py:func:`select_devices`.
        """
        devices = select_devices(
            api,
            interactive=interactive,
            quantity=devices_num,
            device_filter=device_filter,
            platform_filter=platform_filter,
        )
        return cls.from_devices(devices)

    def __init__(self, context_adapter: ContextAdapter):
        self._context_adapter = context_adapter
        self._device_adapters = context_adapter.device_adapters
        # All devices in a context share a platform, so any adapter can supply it.
        self.platform = Platform(next(iter(self._device_adapters.values())).platform_adapter)
        self.api = self.platform.api

    @property
    def devices(self) -> "BoundMultiDevice":
        """
        Returns the :py:class:`~grunnur.context.BoundMultiDevice` encompassing all the devices
        in this context.
        """
        # Need to create it on-demand to avoid a circular reference.
        device_adapters = [
            self._device_adapters[device_idx] for device_idx in self._context_adapter.device_order
        ]
        return BoundMultiDevice(self, device_adapters)

    @property
    def device(self) -> "BoundDevice":
        """
        Returns the only device of a single-device context.

        :raises RuntimeError: if this context contains more than one device.
        """
        if len(self._device_adapters) > 1:
            raise RuntimeError("The `device` shortcut only works for single-device contexts")
        return self.devices[0]

    def deactivate(self) -> None:
        """
        For CUDA API: deactivates this context, popping all the CUDA context objects
        from the stack.
        Other APIs: no effect.

        Only call it if you need to manage CUDA contexts manually, and created this object
        with `take_ownership = False`. If `take_ownership = True` contexts will be deactivated
        automatically in the destructor.
        """
        self._context_adapter.deactivate()