Source code for grunnur.vsize

from __future__ import annotations

from collections import Counter
import itertools
from math import floor, ceil, sqrt
from typing import (
    NamedTuple,
    Mapping,
    List,
    Dict,
    Iterable,
    Optional,
    Tuple,
    Iterator,
    Sequence,
    Callable,
    Any,
)

from .template import Template
from .modules import Module, RenderableModule, Snippet
from .utils import min_blocks, prod


# Template whose defs (``local_id``, ``group_id``, ``skip`` and others) are wrapped
# into modules by ``VsizeModules.from_shape_data()``.
# NOTE(review): presumably loaded from a template file shipped alongside this module —
# confirm against ``Template.from_associated_file``.
TEMPLATE = Template.from_associated_file(__file__)


def factorize(num: int) -> List[int]:
    """
    Returns the prime factors of ``num`` in non-decreasing order, with repetitions
    (e.g. ``12 -> [2, 2, 3]``).

    Uses trial division by 2 (for even numbers) or 3, and then by numbers
    of the form ``6k +- 1`` (5, 7, 11, 13, ...) up to the square root.
    For ``num`` equal to 0 or 1 the result is ``[num]``.
    """
    result: List[int] = []
    remaining = num
    while True:
        limit = int(floor(sqrt(remaining)))
        divisor = 2 if remaining % 2 == 0 else 3
        k = 1
        # Walk through 5, 7, 11, 13, 17, 19, ... until a divisor is found
        # or the candidates pass the square root.
        while divisor <= limit and remaining % divisor != 0:
            divisor = 1 + 4 * k - 2 * (k // 2)
            k += 1
        if divisor > limit:
            # No divisor up to sqrt(remaining): what is left is the last factor.
            result.append(remaining)
            return result
        result.append(divisor)
        remaining //= divisor


class PrimeFactors:
    """
    Contains a natural number's decomposition into prime factors,
    stored as a mapping ``{prime: exponent}``.
    """

    def __init__(self, factors: Mapping[int, int]):
        self.factors = factors

    @classmethod
    def decompose(cls, num: int) -> "PrimeFactors":
        """
        Builds the decomposition of ``num`` from the list of its prime factors
        returned by ``factorize()``.
        """
        factors_list = factorize(num)
        factors = Counter(factors_list)
        return cls(dict(factors))

    def get_value(self) -> int:
        """
        Returns the number this decomposition represents (1 for an empty decomposition).
        """
        res = 1
        for base, exp in self.factors.items():
            res *= base**exp
        return res

    def get_arrays(self) -> Tuple[List[int], List[int]]:
        """
        Returns a pair of lists ``(bases, exponents)`` in matching order.
        """
        bases = list(self.factors.keys())
        exponents = [self.factors[base] for base in bases]
        return bases, exponents

    def div_by(self, other: "PrimeFactors") -> "PrimeFactors":
        """
        Returns the decomposition of ``self / other``.

        Raises ``ValueError`` if ``self`` is not a multiple of ``other``.
        """
        factors = dict(self.factors)
        for o_base, o_exp in other.factors.items():
            # `get()` lets a zero exponent for a base absent in `self` pass through
            # instead of raising a bare `KeyError`.
            remaining = factors.get(o_base, 0) - o_exp
            if remaining < 0:
                # An explicit exception rather than an `assert`: with asserts stripped (`-O`)
                # a negative exponent would silently turn `get_value()` into a float.
                raise ValueError(f"{self!r} is not divisible by {other!r}")
            if remaining == 0:
                factors.pop(o_base, None)
            else:
                factors[o_base] = remaining
        return PrimeFactors(factors)

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, PrimeFactors) and self.factors == other.factors

    def __repr__(self) -> str:
        return f"{type(self).__name__}({dict(self.factors)!r})"


def _get_decompositions(num_factors: PrimeFactors, parts: int) -> Iterator[List[int]]:
    """
    Helper recursive function for ``get_decompositions()``.
    Iterates over all possible ordered decompositions of ``num_factors`` into ``parts`` factors.

    The first factor is built by choosing, for every prime base, an exponent from its full
    exponent down to 0 (all combinations, via ``itertools.product``). The rest of the number
    is then decomposed recursively into ``parts - 1`` factors.
    """
    if parts == 1:
        # Nothing left to split: the whole remaining number is the last factor.
        yield [num_factors.get_value()]
        return

    bases, exponents = num_factors.get_arrays()
    exponent_choices = [range(exp, -1, -1) for exp in exponents]

    for chosen in itertools.product(*exponent_choices):
        # Only bases with a positive chosen exponent go into the head factor.
        head_factors = PrimeFactors(
            {base: exp for base, exp in zip(bases, chosen) if exp > 0}
        )
        head = head_factors.get_value()
        rest = num_factors.div_by(head_factors)
        yield from ([head] + tail for tail in _get_decompositions(rest, parts - 1))


def get_decompositions(num: int, parts: int) -> Iterator[List[int]]:
    """
    Iterates over all possible ordered decompositions of ``num`` into ``parts`` factors,
    each decomposition being a list of ``parts`` integers whose product is ``num``.

    The prime decomposition of ``num`` is computed eagerly, when this function is called.
    """
    return _get_decompositions(PrimeFactors.decompose(num), parts)


def find_local_size_decomposition(
    global_size: Sequence[int], flat_local_size: int, threshold: float = 0.05
) -> List[int]:
    """
    Returns a list of the same length as ``global_size``,
    with the product equal to ``flat_local_size``,
    and minimal difference between ``product(global_size)``
    and ``product(min_blocks(gs, ls) * ls for gs, ls in zip(global_size, local_size))``
    (i.e. tries to minimize the amount of empty threads).

    If ``flat_local_size`` is not smaller than ``product(global_size)``,
    ``global_size`` itself is returned (as a list).

    The search stops at the first candidate whose relative amount of empty threads
    (``empty / product(global_size)``) is below ``threshold``. Otherwise the candidate
    with the smallest such ratio is returned (the earliest one on ties).
    """
    total_threads = prod(global_size)
    if flat_local_size >= total_threads:
        return list(global_size)

    best: Optional[Tuple[float, List[int]]] = None

    for candidate in get_decompositions(flat_local_size, len(global_size)):
        padded_shape = [ls * min_blocks(gs, ls) for gs, ls in zip(global_size, candidate)]
        waste = (prod(padded_shape) - total_threads) / total_threads

        # Good enough: stop early, since there may be a lot of decompositions
        # and a perfect solution is not required.
        if waste < threshold:
            return candidate

        if best is None or waste < best[0]:
            best = (waste, candidate)

    # Since flat_local_size <= product(global_size) here, there is always
    # at least one decomposition (flat_local_size, 1, 1, ...), so `best` is set.
    assert best is not None  # sanity check to catch a possible bug early

    return best[1]


def _group_dimensions(
    vdim: int, virtual_shape: Sequence[int], adim: int, available_shape: Sequence[int]
) -> Tuple[List[List[int]], List[List[int]]]:
    """
    Recursive helper for ``group_dimensions()``.

    Greedily builds the first group of virtual dimensions together with the group of
    available dimensions that hosts it, then recurses on the remaining dimensions.

    ``vdim`` and ``adim`` are used for the absolute addressing of dimensions during recursive calls:
    they are the indices, in the caller's full shapes, of ``virtual_shape[0]``
    and ``available_shape[0]`` respectively.

    Returns two lists of groups (each group is a list of absolute dimension indices):
    the virtual groups and the matching available groups.
    """
    # Base case: all virtual dimensions have been assigned.
    if len(virtual_shape) == 0:
        return [], []

    vdim_group = 1  # number of currently grouped virtual dimensions
    adim_group = 1  # number of currently grouped available dimensions

    # NOTE(review): termination relies on `prod(virtual_shape) <= prod(available_shape)`
    # (asserted by `group_dimensions()` for the top-level call) and, presumably,
    # on all sizes being positive — confirm for any other callers.
    while True:
        # If we have more elements in the virtual group than there is in the available group,
        # extend the available group by one dimension.
        if prod(virtual_shape[:vdim_group]) > prod(available_shape[:adim_group]):
            adim_group += 1
            continue

        # If the remaining available dimensions cannot accommodate the remaining virtual dimensions,
        # we try to fit one more virtual dimension in the virtual group.
        if prod(virtual_shape[vdim_group:]) > prod(available_shape[adim_group:]):
            vdim_group += 1
            continue

        # If we are here, it means that:
        # 1) the current available group can accommodate the current virtual group;
        # 2) the remaining available dimensions can accommodate the remaining virtual dimensions.
        # This means we can make a recursive call now.

        # Attach any following trivial virtual dimensions (of size 1) to this group
        # This will help to avoid unassigned trivial dimensions with no real dimensions left.
        while vdim_group < len(virtual_shape) and virtual_shape[vdim_group] == 1:
            vdim_group += 1

        # Absolute indices of the dimensions in this group.
        v_res = list(range(vdim, vdim + vdim_group))
        a_res = list(range(adim, adim + adim_group))

        v_remainder, a_remainder = _group_dimensions(
            vdim + vdim_group,
            virtual_shape[vdim_group:],
            adim + adim_group,
            available_shape[adim_group:],
        )
        return [v_res] + v_remainder, [a_res] + a_remainder


def group_dimensions(
    virtual_shape: Sequence[int], available_shape: Sequence[int]
) -> Tuple[List[List[int]], List[List[int]]]:
    """
    Determines which available dimensions the virtual dimensions can be embedded into.
    Prefers using the maximum number of available dimensions, since in that case
    fewer divisions will be needed to calculate virtual indices based on the real ones.

    Returns two lists: one of lists with indices of grouped virtual dimensions, the other
    one of lists with indices of the corresponding groups of available dimensions,
    such that for any group of virtual dimensions the total number of elements they cover
    does not exceed the number of elements covered by the
    corresponding group of available dimensions.

    Dimensions are grouped in order, so the groups in both lists, if concatenated,
    give ``(0 ... len(virtual_shape)-1)`` and ``(0 ... n)``, where ``n < len(available_shape)``.
    """
    # The virtual shape as a whole must fit into the available one,
    # otherwise no grouping exists.
    assert prod(virtual_shape) <= prod(available_shape)
    v_groups, a_groups = _group_dimensions(
        vdim=0, virtual_shape=virtual_shape, adim=0, available_shape=available_shape
    )
    return v_groups, a_groups


def find_bounding_shape(virtual_size: int, available_shape: Sequence[int]) -> List[int]:
    """
    Finds a list of the same length as ``available_shape``, with every element
    not greater than the corresponding element of ``available_shape``,
    and product not lower than ``virtual_size`` (trying to minimize that difference).

    How it works:
    - The remaining size is spread evenly (via a rounded-up root) over the dimensions
      that are still free.
    - Every free dimension whose limit is below that even share is pinned to its limit
      and excluded from further passes.
    - The remaining size is divided (rounding up) by the product of the newly pinned
      dimensions.
    - This repeats until a pass pins nothing, or every dimension is pinned.
    """

    # TODO: in most cases it is possible to find such a tuple that `prod(result) == virtual_size`,
    # but the current algorithm does not guarantee it. Finding such a tuple would
    # eliminate some empty threads.

    assert virtual_size <= prod(available_shape)

    free_size = virtual_size
    free_dims = set(range(len(available_shape)))
    bounding_shape = [0] * len(available_shape)

    # The loop terminates: every pass either pins at least one dimension (shrinking
    # `free_dims`) or pins nothing and exits. If every dimension ends up pinned,
    # the result is just `available_shape` itself, which fits by the assertion above.
    # The `while free_dims` condition also prevents a `1 / 0` below in that case.
    while free_dims:
        guess = ceil(free_size ** (1 / len(free_dims)))
        fixed_size = 1
        pinned_any = False
        for fdim in list(free_dims):
            if guess > available_shape[fdim]:
                fixed_size *= available_shape[fdim]
                free_dims.remove(fdim)
                bounding_shape[fdim] = available_shape[fdim]
                pinned_any = True
            else:
                bounding_shape[fdim] = guess

        # Test whether anything was pinned rather than `fixed_size == 1`:
        # a dimension with limit 1 is pinned without changing `fixed_size`.
        # Stopping then would leave the other dimensions with a share computed
        # for one dimension too many, and the product could fall short of `virtual_size`
        # (e.g. `virtual_size=10`, `available_shape=[1, 100]` would give `[1, 4]`).
        if not pinned_any:
            break

        free_size = min_blocks(free_size, fixed_size)

    return bounding_shape


class ShapeGroups:
    """
    Describes how a virtual shape is packed into a (possibly differently shaped) real one.

    ``virtual_shape`` is split into groups of consecutive dimensions with ``group_dimensions()``.
    Each group is mapped onto a group of consecutive real dimensions whose sizes are chosen
    by ``find_bounding_shape()`` within the limits of ``available_shape``.
    """

    def __init__(self, virtual_shape: Sequence[int], available_shape: Sequence[int]):
        # A mapping from a dimension in the virtual shape to a tuple of dimensions
        # in the real shape it uses (and possibly shares with other virtual dimensions).
        self.real_dims: Dict[int, List[int]] = {}

        # A mapping from a dimension in the virtual shape to a tuple of strides
        # used to get a flat index in the group of real dimensions it uses.
        self.real_strides: Dict[int, List[int]] = {}

        # A mapping from a dimension in the virtual shape to the stride that is used to extract it
        # from the flat index obtained from the corresponding group of real dimensions.
        self.virtual_strides: Dict[int, int] = {}

        # A mapping from a dimension in the virtual shape to the major dimension
        # (the one with the largest stride) in the group of virtual dimensions it belongs to
        # (the group includes all virtual dimensions using a certain subset of real dimensions).
        self.major_vdims: Dict[int, int] = {}

        # The actual shape used to enqueue the kernel.
        self.bounding_shape: List[int] = []

        # A list of tuples `(threshold, stride_info)` used for skipping unused threads.
        # `stride_info` is a list of 2-tuples `(real_dim, stride)` used to construct
        # a flat index from several real dimensions, and then compare it with the threshold.
        self.skip_thresholds: List[Tuple[int, List[Tuple[int, int]]]] = []

        v_groups, a_groups = group_dimensions(virtual_shape, available_shape)

        for v_group, a_group in zip(v_groups, a_groups):
            virtual_subshape = virtual_shape[v_group[0] : v_group[-1] + 1]
            virtual_subsize = prod(virtual_subshape)

            bounding_subshape = find_bounding_shape(
                virtual_subsize, available_shape[a_group[0] : a_group[-1] + 1]
            )

            self.bounding_shape += bounding_subshape

            # The real group covers more elements than the virtual group needs:
            # record the threshold (and the strides to build the flat index)
            # for the unused-thread check.
            if virtual_subsize < prod(bounding_subshape):
                strides = [(adim, prod(bounding_subshape[:i])) for i, adim in enumerate(a_group)]
                self.skip_thresholds.append((virtual_subsize, strides))

            # The major virtual dimension (the one that does not require
            # modulus operation when extracting its index from the flat index)
            # is the last non-trivial one (not of size 1); if all are trivial, the first one.
            # Modulus will not be optimized away by the compiler,
            # but we know that all threads outside of the virtual group will be
            # filtered out by VIRTUAL_SKIP_THREADS.
            # It only depends on the group, so it is computed once per group.
            major_vdim = next(
                (vd for vd in reversed(v_group) if virtual_shape[vd] > 1), v_group[0]
            )

            for vdim in v_group:
                self.real_dims[vdim] = a_group
                self.real_strides[vdim] = [
                    prod(self.bounding_shape[a_group[0] : adim]) for adim in a_group
                ]
                self.virtual_strides[vdim] = prod(virtual_shape[v_group[0] : vdim])
                self.major_vdims[vdim] = major_vdim


class VsizeModules(NamedTuple):
    """
    A collection of modules passed to :py:class:`grunnur.StaticKernel`.
    Should be used instead of regular group/thread id functions.
    """

    local_id: Module
    """
    Provides the function ``VSIZE_T ${local_id}(int dim)``
    returning the local id of the current thread.
    """

    local_size: Module
    """
    Provides the function ``VSIZE_T ${local_size}(int dim)``
    returning the size of the current group.
    """

    group_id: Module
    """
    Provides the function ``VSIZE_T ${group_id}(int dim)``
    returning the group id of the current thread.
    """

    num_groups: Module
    """
    Provides the function ``VSIZE_T ${num_groups}(int dim)``
    returning the number of groups in dimension ``dim``.
    """

    global_id: Module
    """
    Provides the function ``VSIZE_T ${global_id}(int dim)``
    returning the global id of the current thread.
    """

    global_size: Module
    """
    Provides the function ``VSIZE_T ${global_size}(int dim)``
    returning the global size along dimension ``dim``.
    """

    global_flat_id: Module
    """
    Provides the function ``VSIZE_T ${global_flat_id}()``
    returning the global id of the current thread with all dimensions flattened.
    """

    global_flat_size: Module
    """
    Provides the function ``VSIZE_T ${global_flat_size}()``
    returning the global size with all dimensions flattened.
    """

    skip: Module
    """
    Provides the function ``bool ${skip}()``
    that should be used at the start of a static kernel function
    to see if the current thread/work item is inside the padding area and needs to be skipped.
    Usually one would write ``if (${skip}()) return;``.
    """

    def __process_modules__(
        self, process: Callable[[Module], RenderableModule]
    ) -> "VsizeRenderableModules":
        """
        Applies ``process`` to every module of the collection
        and returns the results as a ``VsizeRenderableModules``.
        """
        return VsizeRenderableModules(
            local_id=process(self.local_id),
            local_size=process(self.local_size),
            group_id=process(self.group_id),
            num_groups=process(self.num_groups),
            global_id=process(self.global_id),
            global_size=process(self.global_size),
            global_flat_id=process(self.global_flat_id),
            global_flat_size=process(self.global_flat_size),
            skip=process(self.skip),
        )

    @classmethod
    def from_shape_data(
        cls,
        virtual_global_size: Sequence[int],
        virtual_local_size: Sequence[int],
        bounding_global_size: Sequence[int],
        virtual_grid_size: Sequence[int],
        local_groups: ShapeGroups,
        grid_groups: ShapeGroups,
    ) -> "VsizeModules":
        """
        Builds the collection from the shape data computed by ``VirtualSizes``,
        wrapping the corresponding defs of ``TEMPLATE`` into modules.
        """
        local_id_mod = Module(
            TEMPLATE.get_def("local_id"),
            render_globals=dict(virtual_local_size=virtual_local_size, local_groups=local_groups),
        )
        local_size_mod = Module(
            TEMPLATE.get_def("local_size"),
            render_globals=dict(virtual_local_size=virtual_local_size),
        )
        group_id_mod = Module(
            TEMPLATE.get_def("group_id"),
            render_globals=dict(virtual_grid_size=virtual_grid_size, grid_groups=grid_groups),
        )
        num_groups_mod = Module(
            TEMPLATE.get_def("num_groups"), render_globals=dict(virtual_grid_size=virtual_grid_size)
        )

        # The global id module is built on top of the local id, group id and local size modules.
        global_id_mod = Module(
            TEMPLATE.get_def("global_id"),
            render_globals=dict(
                local_id_mod=local_id_mod, group_id_mod=group_id_mod, local_size_mod=local_size_mod
            ),
        )
        global_size_mod = Module(
            TEMPLATE.get_def("global_size"),
            render_globals=dict(virtual_global_size=virtual_global_size),
        )
        global_flat_id_mod = Module(
            TEMPLATE.get_def("global_flat_id"),
            render_globals=dict(
                virtual_global_size=virtual_global_size, global_id_mod=global_id_mod, prod=prod
            ),
        )
        global_flat_size_mod = Module(
            TEMPLATE.get_def("global_flat_size"),
            render_globals=dict(
                global_size_mod=global_size_mod, virtual_global_size=virtual_global_size
            ),
        )

        # The three partial skip checks are combined by the final `skip` module.
        skip_local_threads_mod = Module(
            TEMPLATE.get_def("skip_local_threads"), render_globals=dict(local_groups=local_groups)
        )
        skip_groups_mod = Module(
            TEMPLATE.get_def("skip_groups"), render_globals=dict(grid_groups=grid_groups)
        )
        skip_global_threads_mod = Module(
            TEMPLATE.get_def("skip_global_threads"),
            render_globals=dict(
                virtual_global_size=virtual_global_size,
                bounding_global_size=bounding_global_size,
                global_id_mod=global_id_mod,
            ),
        )
        skip = Module(
            TEMPLATE.get_def("skip"),
            render_globals=dict(
                skip_local_threads_mod=skip_local_threads_mod,
                skip_groups_mod=skip_groups_mod,
                skip_global_threads_mod=skip_global_threads_mod,
            ),
        )

        return cls(
            local_id=local_id_mod,
            local_size=local_size_mod,
            group_id=group_id_mod,
            num_groups=num_groups_mod,
            global_id=global_id_mod,
            global_size=global_size_mod,
            global_flat_id=global_flat_id_mod,
            global_flat_size=global_flat_size_mod,
            skip=skip,
        )
class VsizeRenderableModules(NamedTuple):
    """
    The counterpart of ``VsizeModules`` holding the modules after processing
    (as returned by ``VsizeModules.__process_modules__()``).
    """

    local_id: RenderableModule
    local_size: RenderableModule
    group_id: RenderableModule
    num_groups: RenderableModule
    global_id: RenderableModule
    global_size: RenderableModule
    global_flat_id: RenderableModule
    global_flat_size: RenderableModule
    skip: RenderableModule


class VirtualSizeError(Exception):
    """
    Raised when a virtual size cannot be found due to device limitations.
    """

    pass


class VirtualSizes:
    """
    Maps a (row-major) virtual global size, and optionally a virtual local size,
    onto real global and local sizes that respect the given device limits,
    and builds the ``VsizeModules`` kernels use to query virtual ids and sizes.
    """

    # Real local size, in column-major (device) order.
    real_local_size: Tuple[int, ...]
    # Real global size, in column-major (device) order.
    real_global_size: Tuple[int, ...]
    # Modules built by `VsizeModules.from_shape_data()` for this configuration.
    vsize_modules: VsizeModules

    def __init__(
        self,
        max_total_local_size: int,
        max_local_sizes: Sequence[int],
        max_num_groups: Sequence[int],
        local_size_multiple: int,
        virtual_global_size: Sequence[int],
        virtual_local_size: Optional[Sequence[int]] = None,
    ):
        """
        :param max_total_local_size: the maximum number of threads in a group
            (may be lower than the device maximum).
        :param max_local_sizes: per-dimension limits of the local size, in device order.
        :param max_num_groups: per-dimension limits of the number of groups, in device order.
        :param local_size_multiple: when the local size is chosen automatically,
            its flat size is rounded down to a multiple of this value
            (unless the whole global size is smaller than ``max_total_local_size``).
        :param virtual_global_size: the requested global size, row-major.
        :param virtual_local_size: the requested local size, row-major;
            if ``None``, it is chosen automatically.
        :raises ValueError: if ``virtual_local_size`` and ``virtual_global_size``
            have different lengths.
        :raises VirtualSizeError: if the requested local size exceeds ``max_total_local_size``,
            or the resulting number of groups exceeds ``prod(max_num_groups)``.
        """
        if virtual_local_size is not None:
            if len(virtual_local_size) != len(virtual_global_size):
                raise ValueError(
                    "Global size and local size must have the same number of dimensions"
                )

        # Since the device uses column-major ordering of sizes, while we get
        # row-major ordered shapes, we invert our shapes
        # to facilitate internal handling.
        virtual_global_size = list(reversed(virtual_global_size))
        if virtual_local_size is not None:
            virtual_local_size = list(reversed(virtual_local_size))

        # In device parameters `max_total_local_size` is >= any of `max_local_sizes`,
        # but it can be overridden to get a kernel that uses less resources.
        max_local_sizes = [min(max_total_local_size, mls) for mls in max_local_sizes]

        assert max_total_local_size <= prod(max_local_sizes)  # sanity check

        if virtual_local_size is None:
            # FIXME: we can obtain better results by taking occupancy into account here,
            # but for now we will assume that the more threads, the better.
            flat_global_size = prod(virtual_global_size)

            if flat_global_size < max_total_local_size:
                flat_local_size = flat_global_size
            else:
                # A sanity check - it would be very strange if a device had a local size multiple
                # so big you can't actually launch that many threads.
                # NOTE(review): `max_total_local_size` may be overridden by the caller (see above),
                # so this may be reachable with user-provided values; an `assert` is stripped
                # under `-O` (leaving `flat_local_size == 0`) — consider raising
                # `VirtualSizeError` instead.
                assert max_total_local_size >= local_size_multiple
                flat_local_size = local_size_multiple * (
                    max_total_local_size // local_size_multiple
                )

            # product(virtual_local_size) == flat_local_size <= max_total_local_size
            # Note: it's ok if local size elements are greater
            # than the corresponding global size elements as long as it minimizes the total
            # number of skipped threads.
            virtual_local_size = find_local_size_decomposition(virtual_global_size, flat_local_size)
        else:
            if prod(virtual_local_size) > max_total_local_size:
                raise VirtualSizeError(
                    f"Requested local size is greater than the maximum {max_total_local_size}"
                )

        # Global and local sizes supported by CUDA or OpenCL have a restricted number
        # of dimensions, each of which may have a limited size,
        # so we need to pack our multidimensional sizes into them.

        virtual_grid_size = [
            min_blocks(gs, ls) for gs, ls in zip(virtual_global_size, virtual_local_size)
        ]
        bounding_global_size = [grs * ls for grs, ls in zip(virtual_grid_size, virtual_local_size)]

        if prod(virtual_grid_size) > prod(max_num_groups):
            # Report the bounding size in reversed form so that it matches the provided
            # virtual global size.
            raise VirtualSizeError(
                f"Bounding global size {list(reversed(bounding_global_size))} is too large"
            )

        local_groups = ShapeGroups(virtual_local_size, max_local_sizes)
        grid_groups = ShapeGroups(virtual_grid_size, max_num_groups)

        # These can have different lengths, since the grouping of dimensions
        # (and the expansion into multiple real dimensions done by find_bounding_shape())
        # is performed separately for the local and the grid shapes.
        real_local_size = local_groups.bounding_shape
        real_grid_size = grid_groups.bounding_shape

        # Pad the shorter one with trivial dimensions so that both have the same length
        # (`[1] * n` with a negative `n` is an empty list).
        diff = len(real_local_size) - len(real_grid_size)
        real_local_size = real_local_size + [1] * (-diff)
        real_grid_size = real_grid_size + [1] * diff

        # The modules receive the internal (column-major) shapes computed above.
        vsize_modules = VsizeModules.from_shape_data(
            virtual_local_size=virtual_local_size,
            virtual_global_size=virtual_global_size,
            bounding_global_size=bounding_global_size,
            virtual_grid_size=virtual_grid_size,
            local_groups=local_groups,
            grid_groups=grid_groups,
        )

        # For testing purposes
        # (Note that these will have column-major order, same as real_global/local_size)
        self._virtual_local_size = virtual_local_size
        self._virtual_global_size = virtual_global_size
        self._bounding_global_size = bounding_global_size
        self._virtual_grid_size = virtual_grid_size

        self.real_local_size = tuple(real_local_size)
        self.real_global_size = tuple(gs * ls for gs, ls in zip(real_grid_size, real_local_size))
        self.vsize_modules = vsize_modules