"""
:py:class:`~grunnur.Module` factories
which are used to compensate for the lack of complex number operations in OpenCL,
and for the lack of C++ syntax that would allow one to write them directly.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from warnings import warn
import numpy
from . import dtypes
from ._modules import Module
from ._template import Template
if TYPE_CHECKING: # pragma: no cover
from numpy.typing import DTypeLike
# Template object holding the definitions rendered by the factories in this module
# ("cast", "add_or_mul", "div", "conj", "polar_unit", "norm", "exp", "pow", "polar"
# are looked up in it via ``get_def()``).
# NOTE(review): presumably loaded from a template file associated with this module's
# path (``__file__``) — confirm against ``Template.from_associated_file``.
TEMPLATE = Template.from_associated_file(__file__)
def _check_information_loss(out_dtype: DTypeLike, expected_dtype: DTypeLike) -> None:
    """
    Emits a ``ComplexWarning`` if a complex ``expected_dtype`` is narrowed
    to a non-complex ``out_dtype`` (the imaginary part would be ignored).

    This helper is reached as ``add()``/``mul()``/``div()`` -> ``_derive_out_dtype()``
    -> here, so ``stacklevel=4`` attributes the warning to the user's call
    of the public factory instead of to this module's internals.
    """
    if dtypes.is_complex(expected_dtype) and not dtypes.is_complex(out_dtype):
        warn(
            f"Imaginary part ignored during the downcast from {expected_dtype} to {out_dtype}",
            numpy.exceptions.ComplexWarning,
            # 1: this function, 2: _derive_out_dtype(), 3: add()/mul()/div(), 4: the caller.
            stacklevel=4,
        )
def _derive_out_dtype(
    *in_dtypes: DTypeLike, out_dtype: DTypeLike | None = None
) -> tuple[list[numpy.dtype[Any]], numpy.dtype[Any]]:
    """
    Normalizes ``in_dtypes`` and determines the output dtype of an arithmetic operation.

    Returns a tuple of:

    - the list of ``in_dtypes`` converted to ``numpy.dtype`` objects;
    - the output dtype: ``dtypes.result_type()`` of the inputs if ``out_dtype`` is ``None``,
      otherwise ``out_dtype`` converted to ``numpy.dtype``.

    In the latter case a ``ComplexWarning`` is emitted if a complex result
    would be narrowed to a non-complex ``out_dtype``.
    """
    in_dtypes_normalized = [numpy.dtype(dtype) for dtype in in_dtypes]
    expected_dtype = dtypes.result_type(*in_dtypes_normalized)
    if out_dtype is None:
        return in_dtypes_normalized, expected_dtype

    # Normalize before the check so that the warning message shows a readable
    # dtype name (e.g. ``float32`` rather than ``<class 'numpy.float32'>``).
    result = numpy.dtype(out_dtype)
    _check_information_loss(result, expected_dtype)
    return in_dtypes_normalized, result
def cast(in_dtype: DTypeLike, out_dtype: DTypeLike) -> Module:
    """
    Returns a :py:class:`~grunnur.Module` with a function of one argument
    that converts a value of ``in_dtype`` into ``out_dtype``.

    Casts within the same space (non-complex to non-complex, complex to complex)
    and upcasts from a non-complex type to a complex one are supported;
    a complex to non-complex cast raises ``ValueError``.
    """
    src = numpy.dtype(in_dtype)
    dst = numpy.dtype(out_dtype)
    src_is_complex = dtypes.is_complex(src)
    dst_is_complex = dtypes.is_complex(dst)

    upcast_to_complex = not src_is_complex and dst_is_complex
    same_space = dst_is_complex == src_is_complex

    # The only rejected combination is complex -> non-complex.
    if not (upcast_to_complex or same_space):
        raise ValueError(f"cast from {src} to {dst} is not supported")

    render_globals = {
        "dtypes": dtypes,
        "out_dtype": dst,
        "in_dtype": src,
        "upcast_to_complex": upcast_to_complex,
        "same_space": same_space,
    }
    return Module(TEMPLATE.get_def("cast"), render_globals=render_globals)
def add(*in_dtypes: DTypeLike, out_dtype: DTypeLike | None = None) -> Module:
    """
    Returns a :py:class:`~grunnur.Module` with a function of
    ``len(in_dtypes)`` arguments that sums values of types ``in_dtypes``.

    If ``out_dtype`` is given, it is used as the return type of the function;
    otherwise the result type of ``in_dtypes`` is used.
    An explicit factory is needed because on some platforms complex numbers are
    based on 2-vectors, and therefore the ``+`` operator for a complex and a real number
    works in an unexpected way (returning ``(a.x + b, a.y + b)`` instead of ``(a.x + b, a.y)``).
    """
    normalized, result_dtype = _derive_out_dtype(*in_dtypes, out_dtype=out_dtype)
    render_globals = {
        "dtypes": dtypes,
        "op": "add",
        "out_dtype": result_dtype,
        "in_dtypes": normalized,
    }
    return Module(TEMPLATE.get_def("add_or_mul"), render_globals=render_globals)
def mul(*in_dtypes: DTypeLike, out_dtype: DTypeLike | None = None) -> Module:
    """
    Returns a :py:class:`~grunnur.Module` with a function of
    ``len(in_dtypes)`` arguments that multiplies values of types ``in_dtypes``.

    If ``out_dtype`` is given, it is used as the return type of the function;
    otherwise the result type of ``in_dtypes`` is used.
    """
    normalized, result_dtype = _derive_out_dtype(*in_dtypes, out_dtype=out_dtype)
    render_globals = {
        "dtypes": dtypes,
        "op": "mul",
        "out_dtype": result_dtype,
        "in_dtypes": normalized,
    }
    return Module(TEMPLATE.get_def("add_or_mul"), render_globals=render_globals)
def div(
    dividend_dtype: DTypeLike,
    divisor_dtype: DTypeLike,
    out_dtype: DTypeLike | None = None,
) -> Module:
    """
    Returns a :py:class:`~grunnur.Module` with a function of two arguments
    that divides a value of type ``dividend_dtype`` by a value of type ``divisor_dtype``.

    If ``out_dtype`` is given, it is used as the return type of the function;
    otherwise the result type of the two inputs is used.
    """
    (dividend, divisor), result_dtype = _derive_out_dtype(
        dividend_dtype, divisor_dtype, out_dtype=out_dtype
    )
    # Hand the normalized ``numpy.dtype`` objects to the template (as ``add()``/``mul()`` do),
    # not the raw ``DTypeLike`` arguments, which may be strings or scalar type objects.
    return Module(
        TEMPLATE.get_def("div"),
        render_globals=dict(
            dtypes=dtypes,
            out_dtype=result_dtype,
            dividend_dtype=dividend,
            divisor_dtype=divisor,
        ),
    )
def conj(dtype: DTypeLike) -> Module:
    """
    Returns a :py:class:`~grunnur.Module` with a function of one argument
    that returns the complex conjugate of a value of type ``dtype``
    (for a non-complex data type the value is returned unchanged).
    """
    definition = TEMPLATE.get_def("conj")
    normalized = numpy.dtype(dtype)
    return Module(definition, render_globals={"dtypes": dtypes, "dtype": normalized})
def polar_unit(dtype: DTypeLike) -> Module:
    """
    Returns a :py:class:`~grunnur.Module` with a function of one argument
    that maps a value ``theta`` of type ``dtype`` (which must be a real data type)
    to the complex number ``exp(i * theta) == (cos(theta), sin(theta))``.

    Raises ``ValueError`` if ``dtype`` is not real.
    """
    real_dtype = numpy.dtype(dtype)
    if dtypes.is_real(real_dtype):
        return Module(
            TEMPLATE.get_def("polar_unit"),
            render_globals={"dtypes": dtypes, "dtype": real_dtype},
        )
    raise ValueError("polar_unit() can only be applied to real dtypes")
def norm(dtype: DTypeLike) -> Module:
    """
    Returns a :py:class:`~grunnur.Module` with a function of one argument
    that computes the 2-norm of a value of type ``dtype``:
    the product with its complex conjugate for complex values,
    and the square otherwise.
    """
    definition = TEMPLATE.get_def("norm")
    normalized = numpy.dtype(dtype)
    return Module(definition, render_globals={"dtypes": dtypes, "dtype": normalized})
def exp(dtype: DTypeLike) -> Module:
    """
    Returns a :py:class:`~grunnur.Module` with a function of one argument
    that computes the exponent of a value of type ``dtype``
    (which must be a real or a complex data type).

    Raises ``ValueError`` for integer data types.
    """
    normalized = numpy.dtype(dtype)

    # An integer input would need an explicitly specified output type,
    # which this factory does not accept.
    if dtypes.is_integer(normalized):
        raise ValueError(f"exp() of {normalized} is not supported")

    # Complex inputs additionally get a polar_unit() module
    # for the corresponding real data type.
    if dtypes.is_real(normalized):
        polar_unit_ = None
    else:
        polar_unit_ = polar_unit(dtypes.real_for(normalized))

    render_globals = {"dtypes": dtypes, "dtype": normalized, "polar_unit_": polar_unit_}
    return Module(TEMPLATE.get_def("exp"), render_globals=render_globals)
def pow(  # noqa: A001
    base_dtype: DTypeLike,
    exponent_dtype: DTypeLike | None = None,
    out_dtype: DTypeLike | None = None,
) -> Module:
    """
    Returns a :py:class:`~grunnur.Module` with a function of two arguments
    that raises the first argument of type ``base_dtype``
    to the power of the second argument of type ``exponent_dtype``
    (an integer or real data type; a complex exponent raises ``ValueError``).

    If ``exponent_dtype`` or ``out_dtype`` are not given, they default to ``base_dtype``.
    If ``base_dtype`` is not the same as ``out_dtype``,
    the input is cast to ``out_dtype`` *before* exponentiation.
    A real ``exponent_dtype`` is replaced by ``out_dtype``
    (or by its real counterpart if ``out_dtype`` is complex).
    If ``exponent_dtype`` is real but ``out_dtype`` is integer
    (which includes the case of both ``base_dtype`` and ``out_dtype`` being integer),
    a ``ValueError`` is raised.
    """
    base = numpy.dtype(base_dtype)
    exponent = base if exponent_dtype is None else numpy.dtype(exponent_dtype)
    out = base if out_dtype is None else numpy.dtype(out_dtype)

    if dtypes.is_complex(exponent):
        raise ValueError("pow() with a complex exponent is not supported")

    if dtypes.is_real(exponent):
        # A real exponent takes the precision of the output.
        if dtypes.is_complex(out):
            exponent = dtypes.real_for(out)
        elif dtypes.is_real(out):
            exponent = out
        else:
            raise ValueError("pow(integer, float): integer is not supported")

    # Helper modules are created in the same order as before: cast, mul/div, polar.
    cast_ = cast(base, out) if out != base else None

    mul_ = None
    div_ = None
    if dtypes.is_integer(exponent) and not dtypes.is_real(out):
        mul_ = mul(out, out)
        div_ = div(out, out)

    polar_ = polar(dtypes.real_for(out)) if dtypes.is_complex(out) else None

    render_globals: dict[str, Any] = {
        "dtypes": dtypes,
        "base_dtype": base,
        "exponent_dtype": exponent,
        "out_dtype": out,
        "div_": div_,
        "mul_": mul_,
        "cast_": cast_,
        "polar_": polar_,
    }
    return Module(TEMPLATE.get_def("pow"), render_globals=render_globals)
def polar(dtype: DTypeLike) -> Module:
    """
    Returns a :py:class:`~grunnur.Module` with a function of two arguments
    that builds the complex value ``rho * exp(i * theta)``
    from ``rho`` and ``theta`` of type ``dtype`` (which must be a real data type).

    Raises ``ValueError`` if ``dtype`` is not real.
    """
    real_dtype = numpy.dtype(dtype)
    if not dtypes.is_real(real_dtype):
        raise ValueError(f"polar() of {real_dtype} is not supported")
    definition = TEMPLATE.get_def("polar")
    unit = polar_unit(real_dtype)
    return Module(
        definition,
        render_globals={"dtypes": dtypes, "dtype": real_dtype, "polar_unit_": unit},
    )