# Source code for grunnur._array_metadata
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any
from .dtypes import _normalize_type
if TYPE_CHECKING: # pragma: no cover
import numpy
from numpy.typing import DTypeLike, NDArray
from ._array import Array
class AsArrayMetadata(ABC):
    """An abstract class for any object allowing conversion to :py:class:`ArrayMetadata`."""

    @abstractmethod
    def as_array_metadata(self) -> ArrayMetadata:
        """Returns array metadata representing this object."""
        ...
class ArrayMetadata(AsArrayMetadata):
    """
    A helper object for array-like classes that handles shape/strides/buffer size checks
    without actual data attached to it.

    All strides, offsets and sizes are measured in bytes.
    """

    shape: tuple[int, ...]
    """Array shape."""

    dtype: numpy.dtype[Any]
    """Array item data type."""

    strides: tuple[int, ...]
    """Array strides."""

    buffer_size: int
    """The size of the buffer this array resides in."""

    span: int
    """The minimum size of the buffer that fits all the elements described by this metadata."""

    min_offset: int
    """The minimum offset of an array element described by this metadata."""

    first_element_offset: int
    """The offset of the first element (that is, the one with all indices equal to 0)."""

    is_contiguous: bool
    """If ``True``, means that array's data forms a continuous chunk of memory."""

    @classmethod
    def from_arraylike(cls, array_like: AsArrayMetadata | NDArray[Any]) -> ArrayMetadata:
        """
        Returns metadata describing ``array_like``.

        If ``array_like`` implements :py:class:`AsArrayMetadata`, its own conversion is used;
        otherwise the metadata is built from its ``shape``, ``dtype`` and ``strides``
        attributes (e.g. those of a ``numpy`` array).
        """
        if isinstance(array_like, AsArrayMetadata):
            return array_like.as_array_metadata()
        return cls(shape=array_like.shape, dtype=array_like.dtype, strides=array_like.strides)

    def __init__(
        self,
        shape: Iterable[int] | int,
        dtype: DTypeLike,
        *,
        strides: Iterable[int] | None = None,
        first_element_offset: int | None = None,
        buffer_size: int | None = None,
    ):
        """
        :param shape: array shape; a single integer is treated as a one-dimensional shape.
        :param dtype: array item data type (normalized with ``_normalize_type``).
        :param strides: array strides; if ``None``, row-major contiguous strides are used.
        :param first_element_offset: the offset of the first element in the buffer;
            if ``None``, the smallest offset that keeps every element at a non-negative
            offset is used.
        :param buffer_size: the size of the buffer; if ``None``, the minimum size
            that fits every element is used.
        :raises ValueError: if ``shape`` is an empty sequence, or if ``first_element_offset``
            or ``buffer_size`` are too small to fit every element into the buffer.
        """
        shape = tuple(shape) if isinstance(shape, Iterable) else (shape,)
        if len(shape) == 0:
            raise ValueError("Array shape cannot be an empty sequence")

        dtype = _normalize_type(dtype)
        default_strides = _get_strides(shape, dtype.itemsize)
        strides = default_strides if strides is None else tuple(strides)
        has_default_strides = strides == default_strides

        self.shape = shape
        self.dtype = dtype
        self.strides = strides

        # Note that these are minimum and maximum offsets
        # when the first element offset is 0.
        min_offset, max_offset = _get_range(shape, dtype.itemsize, strides)
        self.span = max_offset - min_offset

        # With negative strides some elements lie *before* the first element,
        # so the first element has to be shifted far enough into the buffer.
        if first_element_offset is None:
            first_element_offset = -min_offset
        elif first_element_offset < -min_offset:
            raise ValueError(f"First element offset is smaller than the minimum {-min_offset}")
        self.first_element_offset = first_element_offset
        self.min_offset = first_element_offset + min_offset

        min_buffer_size = self.first_element_offset + max_offset
        if buffer_size is None:
            buffer_size = min_buffer_size
        elif buffer_size < min_buffer_size:
            raise ValueError(f"Buffer size is smaller than the minimum {min_buffer_size}")
        self.buffer_size = buffer_size

        # Technically, an array with non-default (e.g., overlapping) strides
        # can be contiguous, but that's too hard to determine.
        self.is_contiguous = has_default_strides
        self._default_strides = has_default_strides

    def with_(self, dtype: DTypeLike | None = None) -> ArrayMetadata:
        """
        Replaces a property of the metadata and returns a new metadata object.

        Only ``dtype`` can currently be replaced. Strides, the first element offset
        and the buffer size are carried over unchanged, so a dtype with a larger
        item size may fail the buffer size check and raise ``ValueError``.
        """
        return ArrayMetadata(
            shape=self.shape,
            dtype=self.dtype if dtype is None else dtype,
            strides=self.strides,
            first_element_offset=self.first_element_offset,
            buffer_size=self.buffer_size,
        )

    def _basis(self) -> tuple[numpy.dtype[Any], tuple[int, ...], tuple[int, ...], int, int]:
        # The full set of independent parameters; everything else is derived from them.
        return (
            self.dtype,
            self.shape,
            self.strides,
            self.first_element_offset,
            self.buffer_size,
        )

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, ArrayMetadata):
            # Let the other operand decide (falls back to identity comparison).
            return NotImplemented
        return self._basis() == other._basis()

    def __hash__(self) -> int:
        return hash((type(self), self._basis()))

    def get_sub_region(self, origin: int, size: int) -> ArrayMetadata:
        """
        Returns the same metadata shape-wise, but for the given subregion
        of the original buffer.

        :param origin: the offset of the subregion start in the original buffer.
        :param size: the size of the subregion.
        :raises ValueError: if the elements do not fit into the subregion.
        """
        # The size errors will be checked by ArrayMetadata constructor
        return ArrayMetadata(
            shape=self.shape,
            dtype=self.dtype,
            strides=self.strides,
            first_element_offset=self.first_element_offset - origin,
            buffer_size=size,
        )

    def __getitem__(self, slices: slice | tuple[slice, ...]) -> ArrayMetadata:
        """
        Returns the view of this metadata with the given ranges,
        with the offsets and buffer size corresponding to the original buffer.

        Trailing dimensions without a slice are taken in full.
        Passing more slices than there are dimensions raises ``ValueError``.
        """
        if isinstance(slices, slice):
            slices = (slices,)
        if len(slices) < len(self.shape):
            slices += (slice(None),) * (len(self.shape) - len(slices))
        offset, new_shape, new_strides = _get_view(self.shape, self.strides, slices)
        return ArrayMetadata(
            shape=new_shape,
            dtype=self.dtype,
            strides=new_strides,
            first_element_offset=self.first_element_offset + offset,
            buffer_size=self.buffer_size,
        )

    def __repr__(self) -> str:
        # Only parameters that differ from what the constructor would derive are shown
        # (``first_element_offset`` is shown whenever it is nonzero).
        args = [f"dtype={self.dtype}", f"shape={self.shape}"]
        if not self._default_strides:
            args.append(f"strides={self.strides}")
        if self.first_element_offset != 0:
            args.append(f"first_element_offset={self.first_element_offset}")
        if self.buffer_size != self.min_offset + self.span:
            args.append(f"buffer_size={self.buffer_size}")
        args_str = ", ".join(args)
        return f"ArrayMetadata({args_str})"
def _get_strides(shape: tuple[int, ...], itemsize: int) -> tuple[int, ...]:
# Constructs strides for a contiguous array of shape ``shape`` and item size ``itemsize``.
strides = []
stride = itemsize
for length in reversed(shape):
strides.append(stride)
stride *= length
return tuple(reversed(strides))
def _normalize_slice(length: int, stride: int, slice_: slice) -> tuple[int, int, int]:
"""
Given a slice over an array of length ``length`` with the stride ``stride`` between elements,
return a tuple ``(offset, last, stride)`` where ``offset`` is the offset of the first
element of the resulting view, ``last`` is the index of the last element of the view (0-based),
and ``stride`` is the new stride between elements.
"""
start, stop, step = slice_.indices(length)
offset = start * stride
total_elems = abs(stop - start)
length = (total_elems - 1) // abs(step) + 1
new_stride = stride * step
return offset, length, new_stride
def _get_view(
shape: tuple[int, ...], strides: tuple[int, ...], slices: tuple[slice, ...]
) -> tuple[int, tuple[int, ...], tuple[int, ...]]:
"""
Given an array shape and strides, and a sequence of slices defining a view,
returns a tuple of three elements: the offset of the first element of the view,
the view shape and the view strides.
"""
if len(strides) != len(shape):
raise ValueError("Shape and strides must have the same length")
if len(slices) != len(shape):
raise ValueError("Shape and slices must have the same length")
offsets, lengths, strides = zip(
*[
_normalize_slice(length, stride, slice_)
for length, stride, slice_ in zip(shape, strides, slices, strict=True)
],
strict=True,
)
return sum(offsets), tuple(lengths), tuple(strides)
def _get_range(shape: tuple[int, ...], itemsize: int, strides: tuple[int, ...]) -> tuple[int, int]:
"""
Given an array shape, item size (in bytes), and a sequence of strides,
returns a pair ``(min_offset, max_offset)``,
where ``min_offset`` is the minimum byte offset of an array element,
and ``max_offset`` is the maximum byte offset of an array element plus itemsize.
"""
if len(strides) != len(shape):
raise ValueError("Shape and strides must have the same length")
# Now the address of an element (i1, i2, ...) of the resulting view is
# addr = i1 * stride1 + i2 * stride2 + ...,
# where 0 <= i_k <= length_k - 1
# We want to find the minimum and the maximum value of addr,
# keeping in mind that strides may be negative.
# Since it is a linear function of each index, the extrema will be located
# at the ends of intervals, so we can find minima and maxima for each term separately.
# Since we separated the offsets already, for each dimension the address
# of the first element is 0. We calculate the address of the last byte in each dimension.
last_addrs = [(length - 1) * stride for length, stride in zip(shape, strides, strict=True)]
# Sort the pairs (0, last_addr)
pairs = [(0, last_addr) if last_addr > 0 else (last_addr, 0) for last_addr in last_addrs]
minima, maxima = zip(*pairs, strict=True)
min_offset = sum(minima)
max_offset = sum(maxima) + itemsize
return min_offset, max_offset