Source code for grunnur._context

from __future__ import annotations

from collections.abc import Iterable, Iterator, Sequence
from typing import TYPE_CHECKING, Any, overload

from ._api import API, cuda_api_id
from ._device import Device, DeviceFilter
from ._device_discovery import select_devices
from ._platform import Platform, PlatformFilter
from ._utils import all_same, normalize_object_sequence

if TYPE_CHECKING:  # pragma: no cover
    from ._adapter_base import ContextAdapter, DeviceAdapter


class BoundDevice(Device):
    """A :py:class:`~grunnur.Device` object in a :py:class:`~grunnur.Context`."""

    context: Context
    """The context this device belongs to."""

    def __init__(self, context: Context, device_adapter: DeviceAdapter):
        super().__init__(device_adapter)
        self.context = context

        # The hash is computed from the device adapter only.
        # Including the context would require `Context` to be hashable as well,
        # but `BoundDevice` objects are only ever used in small collections
        # and with all the device indices being different.
        # Should a hash collision happen anyway, ``__eq__`` resolves it.
        self._hash = hash(device_adapter)

    def as_unbound(self) -> Device:
        """
        :meta private:

        Returns a plain :py:class:`Device` wrapping the same device adapter,
        without the binding to a context.
        """
        return Device(self._device_adapter)

    def __eq__(self, other: object) -> bool:
        # Only another bound device from the same context can be equal;
        # the device comparison itself is delegated to the parent class.
        if not isinstance(other, BoundDevice):
            return False
        if not (self.context == other.context):
            return False
        return super().__eq__(other)

    def __hash__(self) -> int:
        # Precomputed in ``__init__``.
        return self._hash

    def __str__(self) -> str:
        device_description = super().__str__()
        return f"{device_description} in {self.context}"
class BoundMultiDevice(Sequence[BoundDevice]):
    """A sequence of bound devices belonging to the same context."""

    context: Context
    """The context these devices belong to."""

    @classmethod
    def from_bound_devices(cls, devices: Sequence[BoundDevice]) -> BoundMultiDevice:
        """
        Creates this object from a sequence of bound devices
        (note that a ``BoundMultiDevice`` object itself can serve as such a sequence).

        :param devices: a non-empty sequence of distinct bound devices
            that all belong to the same context.
        :raises ValueError: if ``devices`` is empty, if the devices belong
            to different contexts, or if some device is listed more than once.
        """
        # The context is taken from the first device, so an empty sequence cannot be
        # turned into a multi-device. Report that explicitly instead of letting
        # ``devices[0]`` fail with an unrelated ``IndexError``.
        if not devices:
            raise ValueError("A multi-device must contain at least one device")
        if not all_same(device.context for device in devices):
            raise ValueError("All devices in a multi-device must belong to the same context")
        if len(set(devices)) != len(devices):
            raise ValueError("All devices in a multi-device must be distinct")
        return cls(devices[0].context, [device._device_adapter for device in devices])  # noqa: SLF001

    def __init__(self, context: Context, device_adapters: Sequence[DeviceAdapter]):
        self.context = context
        # Ordered list used for indexing, iteration and equality.
        self._devices = [BoundDevice(context, device_adapter) for device_adapter in device_adapters]
        # Set view of the same devices, used by ``issubset()``.
        self._devices_as_set = set(self._devices)

    def __eq__(self, other: object) -> bool:
        # The comparison is order-sensitive: the device lists must match element by element.
        return (
            isinstance(other, BoundMultiDevice)
            and self.context == other.context
            and self._devices == other._devices
        )

    def issubset(self, devices: BoundMultiDevice) -> bool:
        """Returns ``True`` if every device of this object is also contained in ``devices``."""
        return self._devices_as_set.issubset(devices._devices_as_set)

    def __iter__(self) -> Iterator[BoundDevice]:
        return iter(self._devices)

    @overload
    def __getitem__(self, idx: int) -> BoundDevice: ...

    @overload
    def __getitem__(self, idx: slice | Iterable[int]) -> BoundMultiDevice: ...

    def __getitem__(self, idx: int | slice | Iterable[int]) -> BoundDevice | BoundMultiDevice:
        """
        Given a single index, returns a single :py:class:`BoundDevice`.
        Given a sequence of indices, returns a :py:class:`BoundMultiDevice` object
        containing respective devices.

        The indices correspond to the list of devices used to create this context.

        :raises ValueError: if a slice or a sequence of indices selects no devices,
            or selects the same device more than once.
        """
        if isinstance(idx, Iterable):
            return BoundMultiDevice.from_bound_devices([self._devices[i] for i in idx])
        if isinstance(idx, slice):
            return BoundMultiDevice.from_bound_devices(self._devices[idx])
        return self._devices[idx]

    def __len__(self) -> int:
        return len(self._devices)
class Context:
    """GPGPU context."""

    platform: Platform
    """The platform this context is based on."""

    api: API
    """The API this context is based on."""

    @classmethod
    def from_devices(cls, devices: Sequence[Device]) -> Context:
        """
        Creates a context from a device or an iterable of devices.

        :param devices: one or several devices to use.
        :raises ValueError: if the devices do not all belong to the same platform.
        """
        devices = normalize_object_sequence(devices, Device)

        platforms = [device.platform for device in devices]
        if not all_same(platforms):
            raise ValueError("All devices must belong to the same platform")
        platform = platforms[0]

        device_adapters = [device._device_adapter for device in devices]  # noqa: SLF001
        api_adapter = platform.api._api_adapter  # noqa: SLF001
        context_adapter = api_adapter.make_context_adapter_from_device_adapters(device_adapters)

        return cls(context_adapter)

    @classmethod
    def from_backend_devices(cls, backend_devices: Sequence[Any]) -> Context:
        """Creates a context from a single or several backend device objects."""
        devices = [Device.from_backend_device(backend_device) for backend_device in backend_devices]
        return cls.from_devices(devices)

    @classmethod
    def from_backend_contexts(
        cls, backend_contexts: Sequence[Any], *, take_ownership: bool = False
    ) -> Context:
        """
        Creates a context from a single or several backend device contexts.
        If ``take_ownership`` is ``True``, this object will be responsible for the lifetime
        of backend context objects (only important for the CUDA backend).

        :param backend_contexts: the backend context objects to wrap.
        :param take_ownership: passed on to the API adapter when the context adapter is created.
        :raises TypeError: if none of the available APIs recognizes the first backend context.
        """
        # Only the first backend context is used to find the matching API.
        for api in API.all_available():
            if api._api_adapter.isa_backend_context(backend_contexts[0]):  # noqa: SLF001
                context_adapter = api._api_adapter.make_context_adapter_from_backend_contexts(  # noqa: SLF001
                    backend_contexts, take_ownership=take_ownership
                )
                return cls(context_adapter)

        raise TypeError(
            f"{type(backend_contexts[0])} objects were not recognized as contexts by any API"
        )

    @classmethod
    def from_criteria(
        cls,
        api: API,
        *,
        interactive: bool = False,
        devices_num: int | None = 1,
        device_filter: DeviceFilter | None = None,
        platform_filter: PlatformFilter | None = None,
    ) -> Context:
        """
        Finds devices matching the given criteria and creates a
        :py:class:`Context` object out of them.

        :param api: passed to :py:func:`select_devices`.
        :param interactive: passed to :py:func:`select_devices`.
        :param devices_num: passed to :py:func:`select_devices` as ``quantity``.
        :param device_filter: passed to :py:func:`select_devices`.
        :param platform_filter: passed to :py:func:`select_devices`.
        """
        devices = select_devices(
            api,
            interactive=interactive,
            quantity=devices_num,
            device_filter=device_filter,
            platform_filter=platform_filter,
        )
        return cls.from_devices(devices)

    def __init__(self, context_adapter: ContextAdapter):
        self._context_adapter = context_adapter
        self._device_adapters = context_adapter.device_adapters
        # The platform is taken from the first device adapter in the mapping.
        self.platform = Platform(next(iter(self._device_adapters.values())).platform_adapter)
        self.api = self.platform.api

    @property
    def devices(self) -> BoundMultiDevice:
        """
        Returns the :py:class:`~grunnur._context.BoundMultiDevice`
        encompassing all the devices in this context.
        """
        # Need to create it on-demand to avoid a circular reference.
        device_adapters = [
            self._device_adapters[device_idx] for device_idx in self._context_adapter.device_order
        ]
        return BoundMultiDevice(self, device_adapters)

    @property
    def device(self) -> BoundDevice:
        """
        Returns the first (and only) device of a single-device context.

        :raises RuntimeError: if the context contains more than one device.
        """
        if len(self._device_adapters) > 1:
            raise RuntimeError("The `device` shortcut only works for single-device contexts")

        return self.devices[0]

    def deactivate(self) -> None:
        """
        For CUDA API: deactivates this context, popping all the CUDA context objects from the stack.
        Other APIs: no effect.

        Only call it if you need to manage CUDA contexts manually,
        and created this object with `take_ownership = False`.
        If `take_ownership = True` contexts will be deactivated automatically in the destructor.
        """
        self._context_adapter.deactivate()