"""
This module contains :py:class:`~grunnur.Module` factories
which are used to compensate for the lack of complex number operations in OpenCL,
and for the lack of C++ syntax that would allow one to write them directly.
"""
from typing import Optional, Any
from warnings import warn
import numpy
from .template import Template
from . import dtypes
from .modules import Module
# Template associated with this module's file (located by ``Template.from_associated_file``).
# The factories below render its defs: "cast", "add_or_mul", "div", "conj",
# "polar_unit", "norm", "exp", "pow" and "polar".
TEMPLATE = Template.from_associated_file(__file__)
def check_information_loss(
    out_dtype: "numpy.dtype[Any]", expected_dtype: "numpy.dtype[Any]"
) -> None:
    """
    Warns if storing a value of ``expected_dtype`` as ``out_dtype``
    would silently drop its imaginary part.

    A ``ComplexWarning`` is emitted when ``expected_dtype`` is complex
    and ``out_dtype`` is not; otherwise nothing happens.
    """
    if not dtypes.is_complex(expected_dtype) or dtypes.is_complex(out_dtype):
        return

    # ``ComplexWarning`` lives in ``numpy.exceptions`` since NumPy 1.25, and the
    # top-level ``numpy.ComplexWarning`` alias was removed in NumPy 2.0,
    # so referencing it directly raises AttributeError on modern NumPy.
    try:
        from numpy.exceptions import ComplexWarning
    except ImportError:  # NumPy < 1.25 has no ``numpy.exceptions`` module
        ComplexWarning = numpy.ComplexWarning  # type: ignore

    warn(
        f"Imaginary part ignored during the downcast from {expected_dtype!s} to {out_dtype!s}",
        ComplexWarning,
    )
def derive_out_dtype(
    *in_dtypes: "numpy.dtype[Any]", out_dtype: Optional["numpy.dtype[Any]"] = None
) -> "numpy.dtype[Any]":
    """
    Picks the return dtype for a function combining values of ``in_dtypes``.

    Without an explicit ``out_dtype``, the common result type of ``in_dtypes`` is used.
    With one, it is returned unchanged, but a warning is issued first
    if it would discard the imaginary part of that common result type.
    """
    promoted = dtypes.result_type(*in_dtypes)
    if out_dtype is not None:
        check_information_loss(out_dtype, promoted)
        return out_dtype
    return promoted
def cast(in_dtype: "numpy.dtype[Any]", out_dtype: "numpy.dtype[Any]") -> Module:
    """
    Returns a :py:class:`~grunnur.Module` whose single-argument function
    converts a value of type ``in_dtype`` into a value of type ``out_dtype``.
    """
    template_def = TEMPLATE.get_def("cast")
    globals_for_render = {"dtypes": dtypes, "out_dtype": out_dtype, "in_dtype": in_dtype}
    return Module(template_def, render_globals=globals_for_render)
def add(
    *in_dtypes: "numpy.dtype[Any]", out_dtype: Optional["numpy.dtype[Any]"] = None
) -> Module:
    """
    Returns a :py:class:`~grunnur.Module` with a function taking
    ``len(in_dtypes)`` arguments (of types ``in_dtypes``) and returning their sum.

    If ``out_dtype`` is given, it becomes the function's return type;
    otherwise the common result type of ``in_dtypes`` is used.
    This matters because on some platforms complex numbers are 2-vectors,
    so the plain ``+`` operator applied to a complex and a real number
    gives ``(a.x + b, a.y + b)`` rather than the expected ``(a.x + b, a.y)``.
    """
    result_dtype = derive_out_dtype(*in_dtypes, out_dtype=out_dtype)
    globals_for_render = {
        "dtypes": dtypes,
        "op": "add",
        "out_dtype": result_dtype,
        "in_dtypes": in_dtypes,
    }
    return Module(TEMPLATE.get_def("add_or_mul"), render_globals=globals_for_render)
def mul(
    *in_dtypes: "numpy.dtype[Any]", out_dtype: Optional["numpy.dtype[Any]"] = None
) -> Module:
    """
    Returns a :py:class:`~grunnur.Module` with a function taking
    ``len(in_dtypes)`` arguments (of types ``in_dtypes``) and returning their product.

    If ``out_dtype`` is given, it becomes the function's return type;
    otherwise the common result type of ``in_dtypes`` is used.
    """
    result_dtype = derive_out_dtype(*in_dtypes, out_dtype=out_dtype)
    globals_for_render = {
        "dtypes": dtypes,
        "op": "mul",
        "out_dtype": result_dtype,
        "in_dtypes": in_dtypes,
    }
    return Module(TEMPLATE.get_def("add_or_mul"), render_globals=globals_for_render)
def div(
    dividend_dtype: "numpy.dtype[Any]",
    divisor_dtype: "numpy.dtype[Any]",
    out_dtype: Optional["numpy.dtype[Any]"] = None,
) -> Module:
    """
    Returns a :py:class:`~grunnur.Module` with a two-argument function
    dividing a value of type ``dividend_dtype`` by a value of type ``divisor_dtype``.

    If ``out_dtype`` is given, it becomes the function's return type;
    otherwise the common result type of the two inputs is used.
    """
    result_dtype = derive_out_dtype(dividend_dtype, divisor_dtype, out_dtype=out_dtype)
    globals_for_render = {
        "dtypes": dtypes,
        "out_dtype": result_dtype,
        "dividend_dtype": dividend_dtype,
        "divisor_dtype": divisor_dtype,
    }
    return Module(TEMPLATE.get_def("div"), render_globals=globals_for_render)
def conj(dtype: "numpy.dtype[Any]") -> Module:
    """
    Returns a :py:class:`~grunnur.Module` with a single-argument function
    producing the complex conjugate of a value of type ``dtype``;
    values of non-complex types are passed through unchanged.
    """
    template_def = TEMPLATE.get_def("conj")
    return Module(template_def, render_globals={"dtypes": dtypes, "dtype": dtype})
def polar_unit(dtype: "numpy.dtype[Any]") -> Module:
    """
    Returns a :py:class:`~grunnur.Module` with a single-argument function
    mapping an angle ``theta`` of type ``dtype`` to the complex number
    ``exp(i * theta) == (cos(theta), sin(theta))``.

    Raises ``ValueError`` unless ``dtype`` is a real data type.
    """
    if dtypes.is_real(dtype):
        template_def = TEMPLATE.get_def("polar_unit")
        return Module(template_def, render_globals={"dtypes": dtypes, "dtype": dtype})
    raise ValueError("polar_unit() can only be applied to real dtypes")
def norm(dtype: "numpy.dtype[Any]") -> Module:
    """
    Returns a :py:class:`~grunnur.Module` with a single-argument function
    computing the squared magnitude of a value of type ``dtype``:
    the product with its complex conjugate for complex values,
    and the plain square otherwise.
    """
    template_def = TEMPLATE.get_def("norm")
    return Module(template_def, render_globals={"dtypes": dtypes, "dtype": dtype})
def exp(dtype: "numpy.dtype[Any]") -> Module:
    """
    Returns a :py:class:`~grunnur.Module` with a single-argument function
    computing the exponent of a value of type ``dtype``
    (which must be a real or a complex data type).

    Raises ``ValueError`` for integer data types.
    """
    # Integer inputs would need an explicit output type, which this factory does not accept.
    if dtypes.is_integer(dtype):
        raise ValueError(f"exp() of {dtype} is not supported")

    # The complex case is rendered with the help of ``polar_unit``; the real case needs no helper.
    polar_unit_ = None if dtypes.is_real(dtype) else polar_unit(dtypes.real_for(dtype))

    globals_for_render = {"dtypes": dtypes, "dtype": dtype, "polar_unit_": polar_unit_}
    return Module(TEMPLATE.get_def("exp"), render_globals=globals_for_render)
def pow(
    base_dtype: "numpy.dtype[Any]",
    exponent_dtype: Optional["numpy.dtype[Any]"] = None,
    out_dtype: Optional["numpy.dtype[Any]"] = None,
) -> Module:
    """
    Returns a :py:class:`~grunnur.Module` with a two-argument function
    raising its first argument (of type ``base_dtype``) to the power of
    its second argument (of type ``exponent_dtype``, an integer or real data type).

    ``exponent_dtype`` and ``out_dtype`` both default to ``base_dtype`` when omitted.
    If ``out_dtype`` differs from ``base_dtype``, the base is cast to ``out_dtype``
    *before* exponentiation.
    A real ``exponent_dtype`` is replaced with ``out_dtype``
    (or with its real counterpart when ``out_dtype`` is complex).

    Raises ``ValueError`` if ``exponent_dtype`` is complex, or if it is real
    while ``out_dtype`` is neither real nor complex (e.g. an integer type).
    """
    exponent_dtype = base_dtype if exponent_dtype is None else exponent_dtype
    out_dtype = base_dtype if out_dtype is None else out_dtype

    if dtypes.is_complex(exponent_dtype):
        raise ValueError("pow() with a complex exponent is not supported")

    if dtypes.is_real(exponent_dtype):
        # A floating-point exponent is matched to the precision of the output.
        if dtypes.is_complex(out_dtype):
            exponent_dtype = dtypes.real_for(out_dtype)
        elif dtypes.is_real(out_dtype):
            exponent_dtype = out_dtype
        else:
            raise ValueError("pow(integer, float): integer is not supported")

    # Helper modules are only built when the template actually needs them.
    cast_ = cast(base_dtype, out_dtype) if out_dtype != base_dtype else None

    needs_mul_and_div = dtypes.is_integer(exponent_dtype) and not dtypes.is_real(out_dtype)
    mul_ = mul(out_dtype, out_dtype) if needs_mul_and_div else None
    div_ = div(out_dtype, out_dtype) if needs_mul_and_div else None

    polar_ = polar(dtypes.real_for(out_dtype)) if dtypes.is_complex(out_dtype) else None

    globals_for_render = {
        "dtypes": dtypes,
        "base_dtype": base_dtype,
        "exponent_dtype": exponent_dtype,
        "out_dtype": out_dtype,
        "div_": div_,
        "mul_": mul_,
        "cast_": cast_,
        "polar_": polar_,
    }
    return Module(TEMPLATE.get_def("pow"), render_globals=globals_for_render)
def polar(dtype: "numpy.dtype[Any]") -> Module:
    """
    Returns a :py:class:`~grunnur.Module` with a two-argument function
    building the complex number ``rho * exp(i * theta)`` from
    ``rho`` and ``theta`` of type ``dtype``.

    Raises ``ValueError`` unless ``dtype`` is a real data type.
    """
    if dtypes.is_real(dtype):
        unit_module = polar_unit(dtype)
        globals_for_render = {"dtypes": dtypes, "dtype": dtype, "polar_unit_": unit_module}
        return Module(TEMPLATE.get_def("polar"), render_globals=globals_for_render)
    raise ValueError(f"polar() of {dtype!s} is not supported")