Source code for grunnur.array_metadata

from typing import Any, Protocol, Tuple, Optional, Sequence, Union, runtime_checkable

import numpy
from numpy.typing import DTypeLike

from .dtypes import _normalize_type


@runtime_checkable
class ArrayMetadataLike(Protocol):
    """
    Structural type for any object that exposes array metadata,
    namely a ``shape`` and a ``dtype``.
    Both :py:class:`numpy.ndarray` and :py:class:`Array` satisfy it.
    """

    @property
    def shape(self) -> Tuple[int, ...]:
        """The shape of the array."""
        ...

    @property
    def dtype(self) -> "numpy.dtype[Any]":
        """The data type of a single array element."""
        ...
class ArrayMetadata:
    """
    A helper object for array-like classes that handles shape/strides/buffer size checks
    without actual data attached to it.
    """

    shape: Tuple[int, ...]
    """Array shape."""

    dtype: "numpy.dtype[Any]"
    """Array item data type."""

    strides: Tuple[int, ...]
    """Array strides."""

    is_contiguous: bool
    """If ``True``, means that array's data forms a continuous chunk of memory."""

    contiguous: bool
    """Same value as :py:attr:`is_contiguous`, kept for backward compatibility."""

    first_element_offset: int
    """Byte offset of the element with all-zero indices from the start of the buffer."""

    buffer_size: int
    """Size of the buffer (in bytes) the array data is placed in."""

    @classmethod
    def from_arraylike(cls, array: "ArrayMetadataLike") -> "ArrayMetadata":
        """
        Creates metadata from an object following the ``ArrayMetadataLike`` protocol.
        If the object has a ``strides`` attribute (e.g. a :py:class:`numpy.ndarray`),
        it is used; otherwise default (C-contiguous) strides are assumed.
        """
        return cls(array.shape, array.dtype, strides=getattr(array, "strides", None))

    def __init__(
        self,
        shape: Sequence[int],
        dtype: DTypeLike,
        strides: Optional[Sequence[int]] = None,
        first_element_offset: int = 0,
        buffer_size: Optional[int] = None,
    ):
        """
        :param shape: array shape.
        :param dtype: array item data type (normalized with ``_normalize_type``).
        :param strides: array strides in bytes. If ``None``, C-contiguous strides are used.
        :param first_element_offset: byte offset of the element with all-zero indices
            from the start of the buffer.
        :param buffer_size: buffer size in bytes. If ``None``, it is set to
            ``first_element_offset`` plus the end of the furthest element.
        :raises ValueError: if some element would start before the beginning of the buffer
            or extend past ``buffer_size``.
        """
        shape = tuple(shape)
        dtype = _normalize_type(dtype)
        default_strides = get_strides(shape, dtype.itemsize)

        if strides is None:
            strides = default_strides
            is_contiguous = True
        else:
            strides = tuple(strides)
            # Technically, an array with non-default (e.g., overlapping) strides
            # can be contiguous, but that's too hard to determine.
            is_contiguous = strides == default_strides

        min_offset, max_offset = get_range(shape, dtype.itemsize, strides)

        if buffer_size is None:
            buffer_size = first_element_offset + max_offset

        full_min_offset = first_element_offset + min_offset
        if full_min_offset < 0 or full_min_offset + dtype.itemsize > buffer_size:
            raise ValueError(
                f"The minimum offset for given strides ({full_min_offset}) "
                f"is outside the given buffer range ({buffer_size})"
            )

        full_max_offset = first_element_offset + max_offset
        if full_max_offset > buffer_size:
            raise ValueError(
                f"The maximum offset for given strides ({full_max_offset}) "
                f"is outside the given buffer range ({buffer_size})"
            )

        self.shape = shape
        self.dtype = dtype
        self.strides = strides
        self.first_element_offset = first_element_offset
        self.is_contiguous = is_contiguous
        # The documented attribute is ``is_contiguous``; ``contiguous`` is what earlier
        # versions of this class set, so it is kept for existing callers.
        self.contiguous = is_contiguous
        self._full_min_offset = full_min_offset
        self._full_max_offset = full_max_offset
        self.buffer_size = buffer_size

    def minimal_subregion(self) -> Tuple[int, int, "ArrayMetadata"]:
        """
        Returns a tuple ``(origin, size, metadata)``, where ``origin`` is the byte offset
        (in the current buffer) of the minimal subregion that fits all the data in this view,
        ``size`` is the length of that subregion in bytes, and ``metadata`` describes
        the same view placed in a buffer consisting of just that subregion.
        """
        subregion_origin = self._full_min_offset
        subregion_size = self._full_max_offset - self._full_min_offset
        new_metadata = ArrayMetadata(
            self.shape,
            self.dtype,
            strides=self.strides,
            first_element_offset=self.first_element_offset - self._full_min_offset,
            buffer_size=subregion_size,
        )
        return subregion_origin, subregion_size, new_metadata

    def __getitem__(self, slices: Union[slice, Tuple[slice, ...]]) -> "ArrayMetadata":
        """
        Returns the metadata of a view defined by ``slices`` (one slice per dimension;
        missing trailing dimensions are taken in full). The first element offset
        of the result is expressed relative to the start of this array's buffer.
        """
        if isinstance(slices, slice):
            slices = (slices,)
        if len(slices) < len(self.shape):
            slices += (slice(None),) * (len(self.shape) - len(slices))
        view_offset, new_shape, new_strides = get_view(self.shape, self.strides, slices)
        return ArrayMetadata(
            new_shape,
            self.dtype,
            strides=new_strides,
            # ``get_view()`` returns the offset relative to this array's first element,
            # so it must be shifted by where that element lives in the buffer.
            first_element_offset=self.first_element_offset + view_offset,
        )


def get_strides(shape: Sequence[int], itemsize: int) -> Tuple[int, ...]:
    """
    Returns the strides of a C-contiguous (row-major) array
    of shape ``shape`` and item size ``itemsize``.
    """
    strides = []
    stride = itemsize
    for length in reversed(shape):
        strides.append(stride)
        stride *= length
    return tuple(reversed(strides))


def normalize_slice(length: int, stride: int, slice_: slice) -> Tuple[int, int, int]:
    """
    Given a slice over an array dimension of length ``length`` with the stride ``stride``
    between elements, returns a tuple ``(offset, new_length, new_stride)``, where ``offset``
    is the offset (in units of ``stride``) of the first element of the resulting view
    relative to the first element of the dimension, ``new_length`` is the number of elements
    in the view, and ``new_stride`` is the new stride between elements.
    """
    start, stop, step = slice_.indices(length)
    # ``len(range(...))`` handles negative steps as well as slices whose bounds run
    # against the step direction (e.g. ``slice(5, 2)``), which select nothing.
    new_length = len(range(start, stop, step))
    # An empty view addresses no elements; anchor it at the start of the dimension
    # so it never points past the end of the original data.
    offset = start * stride if new_length > 0 else 0
    new_stride = stride * step
    return offset, new_length, new_stride


def get_view(
    shape: Sequence[int], strides: Sequence[int], slices: Sequence[slice]
) -> Tuple[int, Tuple[int, ...], Tuple[int, ...]]:
    """
    Given an array shape and strides, and a sequence of slices defining a view,
    returns a tuple of three elements: the offset of the first element of the view
    (relative to the first element of the original array), the view shape
    and the view strides.
    """
    assert len(slices) == len(shape)
    assert len(strides) == len(shape)

    # An explicit loop (instead of ``zip(*...)`` unpacking) also works for 0-d arrays.
    offset = 0
    new_shape = []
    new_strides = []
    for length, stride, slice_ in zip(shape, strides, slices):
        dim_offset, new_length, new_stride = normalize_slice(length, stride, slice_)
        offset += dim_offset
        new_shape.append(new_length)
        new_strides.append(new_stride)

    return offset, tuple(new_shape), tuple(new_strides)


def get_range(shape: Sequence[int], itemsize: int, strides: Sequence[int]) -> Tuple[int, int]:
    """
    Given an array shape, item size (in bytes), and a sequence of strides,
    returns a pair ``(min_offset, max_offset)``, where ``min_offset`` is the minimum
    byte offset of an array element, and ``max_offset`` is the maximum byte offset
    of an array element plus itemsize.
    For a 0-d array (``shape == ()``) the result is ``(0, itemsize)``.
    A dimension of length 0 is treated like a dimension of length 1,
    so the range stays well-formed (it over-approximates an empty array by one item).
    """
    assert len(strides) == len(shape)

    # The address of an element (i1, i2, ...) relative to the first element is
    #   addr = i1 * stride1 + i2 * stride2 + ...,
    # where 0 <= i_k <= length_k - 1.
    # We want to find the minimum and the maximum value of addr,
    # keeping in mind that strides may be negative.
    # Since it is a linear function of each index, the extrema are located
    # at the ends of the intervals, so minima and maxima can be found for each term separately:
    # in each dimension the candidates are 0 (the first element) and the offset
    # of the last element.
    last_addrs = [max(length - 1, 0) * stride for length, stride in zip(shape, strides)]

    min_offset = sum(min(0, last_addr) for last_addr in last_addrs)
    max_offset = sum(max(0, last_addr) for last_addr in last_addrs) + itemsize
    return min_offset, max_offset