add read me
This commit is contained in:
24
venv/lib/python3.12/site-packages/sympy/codegen/__init__.py
Normal file
24
venv/lib/python3.12/site-packages/sympy/codegen/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
""" The ``sympy.codegen`` module contains classes and functions for building
|
||||
abstract syntax trees of algorithms. These trees may then be printed by the
|
||||
code-printers in ``sympy.printing``.
|
||||
|
||||
There are several submodules available:
|
||||
- ``sympy.codegen.ast``: AST nodes useful across multiple languages.
|
||||
- ``sympy.codegen.cnodes``: AST nodes useful for the C family of languages.
|
||||
- ``sympy.codegen.fnodes``: AST nodes useful for Fortran.
|
||||
- ``sympy.codegen.cfunctions``: functions specific to C (C99 math functions)
|
||||
- ``sympy.codegen.ffunctions``: functions specific to Fortran (e.g. ``kind``).
|
||||
|
||||
|
||||
|
||||
"""
|
||||
from .ast import (
|
||||
Assignment, aug_assign, CodeBlock, For, Attribute, Variable, Declaration,
|
||||
While, Scope, Print, FunctionPrototype, FunctionDefinition, FunctionCall
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'Assignment', 'aug_assign', 'CodeBlock', 'For', 'Attribute', 'Variable',
|
||||
'Declaration', 'While', 'Scope', 'Print', 'FunctionPrototype',
|
||||
'FunctionDefinition', 'FunctionCall',
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,18 @@
|
||||
"""This module provides containers for python objects that are valid
|
||||
printing targets but are not a subclass of SymPy's Printable.
|
||||
"""
|
||||
|
||||
|
||||
from sympy.core.containers import Tuple
|
||||
|
||||
|
||||
class List(Tuple):
|
||||
"""Represents a (frozen) (Python) list (for code printing purposes)."""
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, list):
|
||||
return self == List(*other)
|
||||
else:
|
||||
return self.args == other
|
||||
|
||||
def __hash__(self):
|
||||
return super().__hash__()
|
||||
180
venv/lib/python3.12/site-packages/sympy/codegen/algorithms.py
Normal file
180
venv/lib/python3.12/site-packages/sympy/codegen/algorithms.py
Normal file
@@ -0,0 +1,180 @@
|
||||
from sympy.core.containers import Tuple
|
||||
from sympy.core.numbers import oo
|
||||
from sympy.core.relational import (Gt, Lt)
|
||||
from sympy.core.symbol import (Dummy, Symbol)
|
||||
from sympy.functions.elementary.complexes import Abs
|
||||
from sympy.functions.elementary.miscellaneous import Min, Max
|
||||
from sympy.logic.boolalg import And
|
||||
from sympy.codegen.ast import (
|
||||
Assignment, AddAugmentedAssignment, break_, CodeBlock, Declaration, FunctionDefinition,
|
||||
Print, Return, Scope, While, Variable, Pointer, real
|
||||
)
|
||||
from sympy.codegen.cfunctions import isnan
|
||||
|
||||
""" This module collects functions for constructing ASTs representing algorithms. """
|
||||
|
||||
def newtons_method(expr, wrt, atol=1e-12, delta=None, *, rtol=4e-16, debug=False,
|
||||
itermax=None, counter=None, delta_fn=lambda e, x: -e/e.diff(x),
|
||||
cse=False, handle_nan=None,
|
||||
bounds=None):
|
||||
""" Generates an AST for Newton-Raphson method (a root-finding algorithm).
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
Returns an abstract syntax tree (AST) based on ``sympy.codegen.ast`` for Netwon's
|
||||
method of root-finding.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
expr : expression
|
||||
wrt : Symbol
|
||||
With respect to, i.e. what is the variable.
|
||||
atol : number or expression
|
||||
Absolute tolerance (stopping criterion)
|
||||
rtol : number or expression
|
||||
Relative tolerance (stopping criterion)
|
||||
delta : Symbol
|
||||
Will be a ``Dummy`` if ``None``.
|
||||
debug : bool
|
||||
Whether to print convergence information during iterations
|
||||
itermax : number or expr
|
||||
Maximum number of iterations.
|
||||
counter : Symbol
|
||||
Will be a ``Dummy`` if ``None``.
|
||||
delta_fn: Callable[[Expr, Symbol], Expr]
|
||||
computes the step, default is newtons method. For e.g. Halley's method
|
||||
use delta_fn=lambda e, x: -2*e*e.diff(x)/(2*e.diff(x)**2 - e*e.diff(x, 2))
|
||||
cse: bool
|
||||
Perform common sub-expression elimination on delta expression
|
||||
handle_nan: Token
|
||||
How to handle occurrence of not-a-number (NaN).
|
||||
bounds: Optional[tuple[Expr, Expr]]
|
||||
Perform optimization within bounds
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import symbols, cos
|
||||
>>> from sympy.codegen.ast import Assignment
|
||||
>>> from sympy.codegen.algorithms import newtons_method
|
||||
>>> x, dx, atol = symbols('x dx atol')
|
||||
>>> expr = cos(x) - x**3
|
||||
>>> algo = newtons_method(expr, x, atol=atol, delta=dx)
|
||||
>>> algo.has(Assignment(dx, -expr/expr.diff(x)))
|
||||
True
|
||||
|
||||
References
|
||||
==========
|
||||
|
||||
.. [1] https://en.wikipedia.org/wiki/Newton%27s_method
|
||||
|
||||
"""
|
||||
|
||||
if delta is None:
|
||||
delta = Dummy()
|
||||
Wrapper = Scope
|
||||
name_d = 'delta'
|
||||
else:
|
||||
Wrapper = lambda x: x
|
||||
name_d = delta.name
|
||||
|
||||
delta_expr = delta_fn(expr, wrt)
|
||||
if cse:
|
||||
from sympy.simplify.cse_main import cse
|
||||
cses, (red,) = cse([delta_expr.factor()])
|
||||
whl_bdy = [Assignment(dum, sub_e) for dum, sub_e in cses]
|
||||
whl_bdy += [Assignment(delta, red)]
|
||||
else:
|
||||
whl_bdy = [Assignment(delta, delta_expr)]
|
||||
if handle_nan is not None:
|
||||
whl_bdy += [While(isnan(delta), CodeBlock(handle_nan, break_))]
|
||||
whl_bdy += [AddAugmentedAssignment(wrt, delta)]
|
||||
if bounds is not None:
|
||||
whl_bdy += [Assignment(wrt, Min(Max(wrt, bounds[0]), bounds[1]))]
|
||||
if debug:
|
||||
prnt = Print([wrt, delta], r"{}=%12.5g {}=%12.5g\n".format(wrt.name, name_d))
|
||||
whl_bdy += [prnt]
|
||||
req = Gt(Abs(delta), atol + rtol*Abs(wrt))
|
||||
declars = [Declaration(Variable(delta, type=real, value=oo))]
|
||||
if itermax is not None:
|
||||
counter = counter or Dummy(integer=True)
|
||||
v_counter = Variable.deduced(counter, 0)
|
||||
declars.append(Declaration(v_counter))
|
||||
whl_bdy.append(AddAugmentedAssignment(counter, 1))
|
||||
req = And(req, Lt(counter, itermax))
|
||||
whl = While(req, CodeBlock(*whl_bdy))
|
||||
blck = declars
|
||||
if debug:
|
||||
blck.append(Print([wrt], r"{}=%12.5g\n".format(wrt.name)))
|
||||
blck += [whl]
|
||||
return Wrapper(CodeBlock(*blck))
|
||||
|
||||
|
||||
def _symbol_of(arg):
|
||||
if isinstance(arg, Declaration):
|
||||
arg = arg.variable.symbol
|
||||
elif isinstance(arg, Variable):
|
||||
arg = arg.symbol
|
||||
return arg
|
||||
|
||||
|
||||
def newtons_method_function(expr, wrt, params=None, func_name="newton", attrs=Tuple(), *, delta=None, **kwargs):
|
||||
""" Generates an AST for a function implementing the Newton-Raphson method.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
expr : expression
|
||||
wrt : Symbol
|
||||
With respect to, i.e. what is the variable
|
||||
params : iterable of symbols
|
||||
Symbols appearing in expr that are taken as constants during the iterations
|
||||
(these will be accepted as parameters to the generated function).
|
||||
func_name : str
|
||||
Name of the generated function.
|
||||
attrs : Tuple
|
||||
Attribute instances passed as ``attrs`` to ``FunctionDefinition``.
|
||||
\\*\\*kwargs :
|
||||
Keyword arguments passed to :func:`sympy.codegen.algorithms.newtons_method`.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import symbols, cos
|
||||
>>> from sympy.codegen.algorithms import newtons_method_function
|
||||
>>> from sympy.codegen.pyutils import render_as_module
|
||||
>>> x = symbols('x')
|
||||
>>> expr = cos(x) - x**3
|
||||
>>> func = newtons_method_function(expr, x)
|
||||
>>> py_mod = render_as_module(func) # source code as string
|
||||
>>> namespace = {}
|
||||
>>> exec(py_mod, namespace, namespace)
|
||||
>>> res = eval('newton(0.5)', namespace)
|
||||
>>> abs(res - 0.865474033102) < 1e-12
|
||||
True
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
sympy.codegen.algorithms.newtons_method
|
||||
|
||||
"""
|
||||
if params is None:
|
||||
params = (wrt,)
|
||||
pointer_subs = {p.symbol: Symbol('(*%s)' % p.symbol.name)
|
||||
for p in params if isinstance(p, Pointer)}
|
||||
if delta is None:
|
||||
delta = Symbol('d_' + wrt.name)
|
||||
if expr.has(delta):
|
||||
delta = None # will use Dummy
|
||||
algo = newtons_method(expr, wrt, delta=delta, **kwargs).xreplace(pointer_subs)
|
||||
if isinstance(algo, Scope):
|
||||
algo = algo.body
|
||||
not_in_params = expr.free_symbols.difference({_symbol_of(p) for p in params})
|
||||
if not_in_params:
|
||||
raise ValueError("Missing symbols in params: %s" % ', '.join(map(str, not_in_params)))
|
||||
declars = tuple(Variable(p, real) for p in params)
|
||||
body = CodeBlock(algo, Return(wrt))
|
||||
return FunctionDefinition(real, func_name, declars, body, attrs=attrs)
|
||||
@@ -0,0 +1,187 @@
|
||||
import math
|
||||
from sympy.sets.sets import Interval
|
||||
from sympy.calculus.singularities import is_increasing, is_decreasing
|
||||
from sympy.codegen.rewriting import Optimization
|
||||
from sympy.core.function import UndefinedFunction
|
||||
|
||||
"""
|
||||
This module collects classes useful for approximate rewriting of expressions.
|
||||
This can be beneficial when generating numeric code for which performance is
|
||||
of greater importance than precision (e.g. for preconditioners used in iterative
|
||||
methods).
|
||||
"""
|
||||
|
||||
class SumApprox(Optimization):
|
||||
"""
|
||||
Approximates sum by neglecting small terms.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
If terms are expressions which can be determined to be monotonic, then
|
||||
bounds for those expressions are added.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
bounds : dict
|
||||
Mapping expressions to length 2 tuple of bounds (low, high).
|
||||
reltol : number
|
||||
Threshold for when to ignore a term. Taken relative to the largest
|
||||
lower bound among bounds.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import exp
|
||||
>>> from sympy.abc import x, y, z
|
||||
>>> from sympy.codegen.rewriting import optimize
|
||||
>>> from sympy.codegen.approximations import SumApprox
|
||||
>>> bounds = {x: (-1, 1), y: (1000, 2000), z: (-10, 3)}
|
||||
>>> sum_approx3 = SumApprox(bounds, reltol=1e-3)
|
||||
>>> sum_approx2 = SumApprox(bounds, reltol=1e-2)
|
||||
>>> sum_approx1 = SumApprox(bounds, reltol=1e-1)
|
||||
>>> expr = 3*(x + y + exp(z))
|
||||
>>> optimize(expr, [sum_approx3])
|
||||
3*(x + y + exp(z))
|
||||
>>> optimize(expr, [sum_approx2])
|
||||
3*y + 3*exp(z)
|
||||
>>> optimize(expr, [sum_approx1])
|
||||
3*y
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, bounds, reltol, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.bounds = bounds
|
||||
self.reltol = reltol
|
||||
|
||||
def __call__(self, expr):
|
||||
return expr.factor().replace(self.query, lambda arg: self.value(arg))
|
||||
|
||||
def query(self, expr):
|
||||
return expr.is_Add
|
||||
|
||||
def value(self, add):
|
||||
for term in add.args:
|
||||
if term.is_number or term in self.bounds or len(term.free_symbols) != 1:
|
||||
continue
|
||||
fs, = term.free_symbols
|
||||
if fs not in self.bounds:
|
||||
continue
|
||||
intrvl = Interval(*self.bounds[fs])
|
||||
if is_increasing(term, intrvl, fs):
|
||||
self.bounds[term] = (
|
||||
term.subs({fs: self.bounds[fs][0]}),
|
||||
term.subs({fs: self.bounds[fs][1]})
|
||||
)
|
||||
elif is_decreasing(term, intrvl, fs):
|
||||
self.bounds[term] = (
|
||||
term.subs({fs: self.bounds[fs][1]}),
|
||||
term.subs({fs: self.bounds[fs][0]})
|
||||
)
|
||||
else:
|
||||
return add
|
||||
|
||||
if all(term.is_number or term in self.bounds for term in add.args):
|
||||
bounds = [(term, term) if term.is_number else self.bounds[term] for term in add.args]
|
||||
largest_abs_guarantee = 0
|
||||
for lo, hi in bounds:
|
||||
if lo <= 0 <= hi:
|
||||
continue
|
||||
largest_abs_guarantee = max(largest_abs_guarantee,
|
||||
min(abs(lo), abs(hi)))
|
||||
new_terms = []
|
||||
for term, (lo, hi) in zip(add.args, bounds):
|
||||
if max(abs(lo), abs(hi)) >= largest_abs_guarantee*self.reltol:
|
||||
new_terms.append(term)
|
||||
return add.func(*new_terms)
|
||||
else:
|
||||
return add
|
||||
|
||||
|
||||
class SeriesApprox(Optimization):
|
||||
""" Approximates functions by expanding them as a series.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
bounds : dict
|
||||
Mapping expressions to length 2 tuple of bounds (low, high).
|
||||
reltol : number
|
||||
Threshold for when to ignore a term. Taken relative to the largest
|
||||
lower bound among bounds.
|
||||
max_order : int
|
||||
Largest order to include in series expansion
|
||||
n_point_checks : int (even)
|
||||
The validity of an expansion (with respect to reltol) is checked at
|
||||
discrete points (linearly spaced over the bounds of the variable). The
|
||||
number of points used in this numerical check is given by this number.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import sin, pi
|
||||
>>> from sympy.abc import x, y
|
||||
>>> from sympy.codegen.rewriting import optimize
|
||||
>>> from sympy.codegen.approximations import SeriesApprox
|
||||
>>> bounds = {x: (-.1, .1), y: (pi-1, pi+1)}
|
||||
>>> series_approx2 = SeriesApprox(bounds, reltol=1e-2)
|
||||
>>> series_approx3 = SeriesApprox(bounds, reltol=1e-3)
|
||||
>>> series_approx8 = SeriesApprox(bounds, reltol=1e-8)
|
||||
>>> expr = sin(x)*sin(y)
|
||||
>>> optimize(expr, [series_approx2])
|
||||
x*(-y + (y - pi)**3/6 + pi)
|
||||
>>> optimize(expr, [series_approx3])
|
||||
(-x**3/6 + x)*sin(y)
|
||||
>>> optimize(expr, [series_approx8])
|
||||
sin(x)*sin(y)
|
||||
|
||||
"""
|
||||
def __init__(self, bounds, reltol, max_order=4, n_point_checks=4, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.bounds = bounds
|
||||
self.reltol = reltol
|
||||
self.max_order = max_order
|
||||
if n_point_checks % 2 == 1:
|
||||
raise ValueError("Checking the solution at expansion point is not helpful")
|
||||
self.n_point_checks = n_point_checks
|
||||
self._prec = math.ceil(-math.log10(self.reltol))
|
||||
|
||||
def __call__(self, expr):
|
||||
return expr.factor().replace(self.query, lambda arg: self.value(arg))
|
||||
|
||||
def query(self, expr):
|
||||
return (expr.is_Function and not isinstance(expr, UndefinedFunction)
|
||||
and len(expr.args) == 1)
|
||||
|
||||
def value(self, fexpr):
|
||||
free_symbols = fexpr.free_symbols
|
||||
if len(free_symbols) != 1:
|
||||
return fexpr
|
||||
symb, = free_symbols
|
||||
if symb not in self.bounds:
|
||||
return fexpr
|
||||
lo, hi = self.bounds[symb]
|
||||
x0 = (lo + hi)/2
|
||||
cheapest = None
|
||||
for n in range(self.max_order+1, 0, -1):
|
||||
fseri = fexpr.series(symb, x0=x0, n=n).removeO()
|
||||
n_ok = True
|
||||
for idx in range(self.n_point_checks):
|
||||
x = lo + idx*(hi - lo)/(self.n_point_checks - 1)
|
||||
val = fseri.xreplace({symb: x})
|
||||
ref = fexpr.xreplace({symb: x})
|
||||
if abs((1 - val/ref).evalf(self._prec)) > self.reltol:
|
||||
n_ok = False
|
||||
break
|
||||
|
||||
if n_ok:
|
||||
cheapest = fseri
|
||||
else:
|
||||
break
|
||||
|
||||
if cheapest is None:
|
||||
return fexpr
|
||||
else:
|
||||
return cheapest
|
||||
1906
venv/lib/python3.12/site-packages/sympy/codegen/ast.py
Normal file
1906
venv/lib/python3.12/site-packages/sympy/codegen/ast.py
Normal file
File diff suppressed because it is too large
Load Diff
558
venv/lib/python3.12/site-packages/sympy/codegen/cfunctions.py
Normal file
558
venv/lib/python3.12/site-packages/sympy/codegen/cfunctions.py
Normal file
@@ -0,0 +1,558 @@
|
||||
"""
|
||||
This module contains SymPy functions mathcin corresponding to special math functions in the
|
||||
C standard library (since C99, also available in C++11).
|
||||
|
||||
The functions defined in this module allows the user to express functions such as ``expm1``
|
||||
as a SymPy function for symbolic manipulation.
|
||||
|
||||
"""
|
||||
from sympy.core.function import ArgumentIndexError, Function
|
||||
from sympy.core.numbers import Rational
|
||||
from sympy.core.power import Pow
|
||||
from sympy.core.singleton import S
|
||||
from sympy.functions.elementary.exponential import exp, log
|
||||
from sympy.functions.elementary.miscellaneous import sqrt
|
||||
from sympy.logic.boolalg import BooleanFunction, true, false
|
||||
|
||||
def _expm1(x):
|
||||
return exp(x) - S.One
|
||||
|
||||
|
||||
class expm1(Function):
|
||||
"""
|
||||
Represents the exponential function minus one.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
The benefit of using ``expm1(x)`` over ``exp(x) - 1``
|
||||
is that the latter is prone to cancellation under finite precision
|
||||
arithmetic when x is close to zero.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x
|
||||
>>> from sympy.codegen.cfunctions import expm1
|
||||
>>> '%.0e' % expm1(1e-99).evalf()
|
||||
'1e-99'
|
||||
>>> from math import exp
|
||||
>>> exp(1e-99) - 1
|
||||
0.0
|
||||
>>> expm1(x).diff(x)
|
||||
exp(x)
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
log1p
|
||||
"""
|
||||
nargs = 1
|
||||
|
||||
def fdiff(self, argindex=1):
|
||||
"""
|
||||
Returns the first derivative of this function.
|
||||
"""
|
||||
if argindex == 1:
|
||||
return exp(*self.args)
|
||||
else:
|
||||
raise ArgumentIndexError(self, argindex)
|
||||
|
||||
def _eval_expand_func(self, **hints):
|
||||
return _expm1(*self.args)
|
||||
|
||||
def _eval_rewrite_as_exp(self, arg, **kwargs):
|
||||
return exp(arg) - S.One
|
||||
|
||||
_eval_rewrite_as_tractable = _eval_rewrite_as_exp
|
||||
|
||||
@classmethod
|
||||
def eval(cls, arg):
|
||||
exp_arg = exp.eval(arg)
|
||||
if exp_arg is not None:
|
||||
return exp_arg - S.One
|
||||
|
||||
def _eval_is_real(self):
|
||||
return self.args[0].is_real
|
||||
|
||||
def _eval_is_finite(self):
|
||||
return self.args[0].is_finite
|
||||
|
||||
|
||||
def _log1p(x):
|
||||
return log(x + S.One)
|
||||
|
||||
|
||||
class log1p(Function):
|
||||
"""
|
||||
Represents the natural logarithm of a number plus one.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
The benefit of using ``log1p(x)`` over ``log(x + 1)``
|
||||
is that the latter is prone to cancellation under finite precision
|
||||
arithmetic when x is close to zero.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x
|
||||
>>> from sympy.codegen.cfunctions import log1p
|
||||
>>> from sympy import expand_log
|
||||
>>> '%.0e' % expand_log(log1p(1e-99)).evalf()
|
||||
'1e-99'
|
||||
>>> from math import log
|
||||
>>> log(1 + 1e-99)
|
||||
0.0
|
||||
>>> log1p(x).diff(x)
|
||||
1/(x + 1)
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
expm1
|
||||
"""
|
||||
nargs = 1
|
||||
|
||||
|
||||
def fdiff(self, argindex=1):
|
||||
"""
|
||||
Returns the first derivative of this function.
|
||||
"""
|
||||
if argindex == 1:
|
||||
return S.One/(self.args[0] + S.One)
|
||||
else:
|
||||
raise ArgumentIndexError(self, argindex)
|
||||
|
||||
|
||||
def _eval_expand_func(self, **hints):
|
||||
return _log1p(*self.args)
|
||||
|
||||
def _eval_rewrite_as_log(self, arg, **kwargs):
|
||||
return _log1p(arg)
|
||||
|
||||
_eval_rewrite_as_tractable = _eval_rewrite_as_log
|
||||
|
||||
@classmethod
|
||||
def eval(cls, arg):
|
||||
if arg.is_Rational:
|
||||
return log(arg + S.One)
|
||||
elif not arg.is_Float: # not safe to add 1 to Float
|
||||
return log.eval(arg + S.One)
|
||||
elif arg.is_number:
|
||||
return log(Rational(arg) + S.One)
|
||||
|
||||
def _eval_is_real(self):
|
||||
return (self.args[0] + S.One).is_nonnegative
|
||||
|
||||
def _eval_is_finite(self):
|
||||
if (self.args[0] + S.One).is_zero:
|
||||
return False
|
||||
return self.args[0].is_finite
|
||||
|
||||
def _eval_is_positive(self):
|
||||
return self.args[0].is_positive
|
||||
|
||||
def _eval_is_zero(self):
|
||||
return self.args[0].is_zero
|
||||
|
||||
def _eval_is_nonnegative(self):
|
||||
return self.args[0].is_nonnegative
|
||||
|
||||
_Two = S(2)
|
||||
|
||||
def _exp2(x):
|
||||
return Pow(_Two, x)
|
||||
|
||||
class exp2(Function):
|
||||
"""
|
||||
Represents the exponential function with base two.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
The benefit of using ``exp2(x)`` over ``2**x``
|
||||
is that the latter is not as efficient under finite precision
|
||||
arithmetic.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x
|
||||
>>> from sympy.codegen.cfunctions import exp2
|
||||
>>> exp2(2).evalf() == 4.0
|
||||
True
|
||||
>>> exp2(x).diff(x)
|
||||
log(2)*exp2(x)
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
log2
|
||||
"""
|
||||
nargs = 1
|
||||
|
||||
|
||||
def fdiff(self, argindex=1):
|
||||
"""
|
||||
Returns the first derivative of this function.
|
||||
"""
|
||||
if argindex == 1:
|
||||
return self*log(_Two)
|
||||
else:
|
||||
raise ArgumentIndexError(self, argindex)
|
||||
|
||||
def _eval_rewrite_as_Pow(self, arg, **kwargs):
|
||||
return _exp2(arg)
|
||||
|
||||
_eval_rewrite_as_tractable = _eval_rewrite_as_Pow
|
||||
|
||||
def _eval_expand_func(self, **hints):
|
||||
return _exp2(*self.args)
|
||||
|
||||
@classmethod
|
||||
def eval(cls, arg):
|
||||
if arg.is_number:
|
||||
return _exp2(arg)
|
||||
|
||||
|
||||
def _log2(x):
|
||||
return log(x)/log(_Two)
|
||||
|
||||
|
||||
class log2(Function):
|
||||
"""
|
||||
Represents the logarithm function with base two.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
The benefit of using ``log2(x)`` over ``log(x)/log(2)``
|
||||
is that the latter is not as efficient under finite precision
|
||||
arithmetic.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x
|
||||
>>> from sympy.codegen.cfunctions import log2
|
||||
>>> log2(4).evalf() == 2.0
|
||||
True
|
||||
>>> log2(x).diff(x)
|
||||
1/(x*log(2))
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
exp2
|
||||
log10
|
||||
"""
|
||||
nargs = 1
|
||||
|
||||
def fdiff(self, argindex=1):
|
||||
"""
|
||||
Returns the first derivative of this function.
|
||||
"""
|
||||
if argindex == 1:
|
||||
return S.One/(log(_Two)*self.args[0])
|
||||
else:
|
||||
raise ArgumentIndexError(self, argindex)
|
||||
|
||||
|
||||
@classmethod
|
||||
def eval(cls, arg):
|
||||
if arg.is_number:
|
||||
result = log.eval(arg, base=_Two)
|
||||
if result.is_Atom:
|
||||
return result
|
||||
elif arg.is_Pow and arg.base == _Two:
|
||||
return arg.exp
|
||||
|
||||
def _eval_evalf(self, *args, **kwargs):
|
||||
return self.rewrite(log).evalf(*args, **kwargs)
|
||||
|
||||
def _eval_expand_func(self, **hints):
|
||||
return _log2(*self.args)
|
||||
|
||||
def _eval_rewrite_as_log(self, arg, **kwargs):
|
||||
return _log2(arg)
|
||||
|
||||
_eval_rewrite_as_tractable = _eval_rewrite_as_log
|
||||
|
||||
|
||||
def _fma(x, y, z):
|
||||
return x*y + z
|
||||
|
||||
|
||||
class fma(Function):
|
||||
"""
|
||||
Represents "fused multiply add".
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
The benefit of using ``fma(x, y, z)`` over ``x*y + z``
|
||||
is that, under finite precision arithmetic, the former is
|
||||
supported by special instructions on some CPUs.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x, y, z
|
||||
>>> from sympy.codegen.cfunctions import fma
|
||||
>>> fma(x, y, z).diff(x)
|
||||
y
|
||||
|
||||
"""
|
||||
nargs = 3
|
||||
|
||||
def fdiff(self, argindex=1):
|
||||
"""
|
||||
Returns the first derivative of this function.
|
||||
"""
|
||||
if argindex in (1, 2):
|
||||
return self.args[2 - argindex]
|
||||
elif argindex == 3:
|
||||
return S.One
|
||||
else:
|
||||
raise ArgumentIndexError(self, argindex)
|
||||
|
||||
|
||||
def _eval_expand_func(self, **hints):
|
||||
return _fma(*self.args)
|
||||
|
||||
def _eval_rewrite_as_tractable(self, arg, limitvar=None, **kwargs):
|
||||
return _fma(arg)
|
||||
|
||||
|
||||
_Ten = S(10)
|
||||
|
||||
|
||||
def _log10(x):
|
||||
return log(x)/log(_Ten)
|
||||
|
||||
|
||||
class log10(Function):
|
||||
"""
|
||||
Represents the logarithm function with base ten.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x
|
||||
>>> from sympy.codegen.cfunctions import log10
|
||||
>>> log10(100).evalf() == 2.0
|
||||
True
|
||||
>>> log10(x).diff(x)
|
||||
1/(x*log(10))
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
log2
|
||||
"""
|
||||
nargs = 1
|
||||
|
||||
def fdiff(self, argindex=1):
|
||||
"""
|
||||
Returns the first derivative of this function.
|
||||
"""
|
||||
if argindex == 1:
|
||||
return S.One/(log(_Ten)*self.args[0])
|
||||
else:
|
||||
raise ArgumentIndexError(self, argindex)
|
||||
|
||||
|
||||
@classmethod
|
||||
def eval(cls, arg):
|
||||
if arg.is_number:
|
||||
result = log.eval(arg, base=_Ten)
|
||||
if result.is_Atom:
|
||||
return result
|
||||
elif arg.is_Pow and arg.base == _Ten:
|
||||
return arg.exp
|
||||
|
||||
def _eval_expand_func(self, **hints):
|
||||
return _log10(*self.args)
|
||||
|
||||
def _eval_rewrite_as_log(self, arg, **kwargs):
|
||||
return _log10(arg)
|
||||
|
||||
_eval_rewrite_as_tractable = _eval_rewrite_as_log
|
||||
|
||||
|
||||
def _Sqrt(x):
|
||||
return Pow(x, S.Half)
|
||||
|
||||
|
||||
class Sqrt(Function): # 'sqrt' already defined in sympy.functions.elementary.miscellaneous
|
||||
"""
|
||||
Represents the square root function.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
The reason why one would use ``Sqrt(x)`` over ``sqrt(x)``
|
||||
is that the latter is internally represented as ``Pow(x, S.Half)`` which
|
||||
may not be what one wants when doing code-generation.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x
|
||||
>>> from sympy.codegen.cfunctions import Sqrt
|
||||
>>> Sqrt(x)
|
||||
Sqrt(x)
|
||||
>>> Sqrt(x).diff(x)
|
||||
1/(2*sqrt(x))
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
Cbrt
|
||||
"""
|
||||
nargs = 1
|
||||
|
||||
def fdiff(self, argindex=1):
|
||||
"""
|
||||
Returns the first derivative of this function.
|
||||
"""
|
||||
if argindex == 1:
|
||||
return Pow(self.args[0], Rational(-1, 2))/_Two
|
||||
else:
|
||||
raise ArgumentIndexError(self, argindex)
|
||||
|
||||
def _eval_expand_func(self, **hints):
|
||||
return _Sqrt(*self.args)
|
||||
|
||||
def _eval_rewrite_as_Pow(self, arg, **kwargs):
|
||||
return _Sqrt(arg)
|
||||
|
||||
_eval_rewrite_as_tractable = _eval_rewrite_as_Pow
|
||||
|
||||
|
||||
def _Cbrt(x):
|
||||
return Pow(x, Rational(1, 3))
|
||||
|
||||
|
||||
class Cbrt(Function): # 'cbrt' already defined in sympy.functions.elementary.miscellaneous
|
||||
"""
|
||||
Represents the cube root function.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
The reason why one would use ``Cbrt(x)`` over ``cbrt(x)``
|
||||
is that the latter is internally represented as ``Pow(x, Rational(1, 3))`` which
|
||||
may not be what one wants when doing code-generation.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x
|
||||
>>> from sympy.codegen.cfunctions import Cbrt
|
||||
>>> Cbrt(x)
|
||||
Cbrt(x)
|
||||
>>> Cbrt(x).diff(x)
|
||||
1/(3*x**(2/3))
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
Sqrt
|
||||
"""
|
||||
nargs = 1
|
||||
|
||||
def fdiff(self, argindex=1):
|
||||
"""
|
||||
Returns the first derivative of this function.
|
||||
"""
|
||||
if argindex == 1:
|
||||
return Pow(self.args[0], Rational(-_Two/3))/3
|
||||
else:
|
||||
raise ArgumentIndexError(self, argindex)
|
||||
|
||||
|
||||
def _eval_expand_func(self, **hints):
|
||||
return _Cbrt(*self.args)
|
||||
|
||||
def _eval_rewrite_as_Pow(self, arg, **kwargs):
|
||||
return _Cbrt(arg)
|
||||
|
||||
_eval_rewrite_as_tractable = _eval_rewrite_as_Pow
|
||||
|
||||
|
||||
def _hypot(x, y):
|
||||
return sqrt(Pow(x, 2) + Pow(y, 2))
|
||||
|
||||
|
||||
class hypot(Function):
|
||||
"""
|
||||
Represents the hypotenuse function.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
The hypotenuse function is provided by e.g. the math library
|
||||
in the C99 standard, hence one may want to represent the function
|
||||
symbolically when doing code-generation.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x, y
|
||||
>>> from sympy.codegen.cfunctions import hypot
|
||||
>>> hypot(3, 4).evalf() == 5.0
|
||||
True
|
||||
>>> hypot(x, y)
|
||||
hypot(x, y)
|
||||
>>> hypot(x, y).diff(x)
|
||||
x/hypot(x, y)
|
||||
|
||||
"""
|
||||
nargs = 2
|
||||
|
||||
def fdiff(self, argindex=1):
|
||||
"""
|
||||
Returns the first derivative of this function.
|
||||
"""
|
||||
if argindex in (1, 2):
|
||||
return 2*self.args[argindex-1]/(_Two*self.func(*self.args))
|
||||
else:
|
||||
raise ArgumentIndexError(self, argindex)
|
||||
|
||||
|
||||
def _eval_expand_func(self, **hints):
|
||||
return _hypot(*self.args)
|
||||
|
||||
def _eval_rewrite_as_Pow(self, arg, **kwargs):
|
||||
return _hypot(arg)
|
||||
|
||||
_eval_rewrite_as_tractable = _eval_rewrite_as_Pow
|
||||
|
||||
|
||||
class isnan(BooleanFunction):
|
||||
nargs = 1
|
||||
|
||||
@classmethod
|
||||
def eval(cls, arg):
|
||||
if arg is S.NaN:
|
||||
return true
|
||||
elif arg.is_number:
|
||||
return false
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class isinf(BooleanFunction):
|
||||
nargs = 1
|
||||
|
||||
@classmethod
|
||||
def eval(cls, arg):
|
||||
if arg.is_infinite:
|
||||
return true
|
||||
elif arg.is_finite:
|
||||
return false
|
||||
else:
|
||||
return None
|
||||
156
venv/lib/python3.12/site-packages/sympy/codegen/cnodes.py
Normal file
156
venv/lib/python3.12/site-packages/sympy/codegen/cnodes.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
AST nodes specific to the C family of languages
|
||||
"""
|
||||
|
||||
from sympy.codegen.ast import (
|
||||
Attribute, Declaration, Node, String, Token, Type, none,
|
||||
FunctionCall, CodeBlock
|
||||
)
|
||||
from sympy.core.basic import Basic
|
||||
from sympy.core.containers import Tuple
|
||||
from sympy.core.sympify import sympify
|
||||
|
||||
void = Type('void')
|
||||
|
||||
restrict = Attribute('restrict') # guarantees no pointer aliasing
|
||||
volatile = Attribute('volatile')
|
||||
static = Attribute('static')
|
||||
|
||||
|
||||
def alignof(arg):
|
||||
""" Generate of FunctionCall instance for calling 'alignof' """
|
||||
return FunctionCall('alignof', [String(arg) if isinstance(arg, str) else arg])
|
||||
|
||||
|
||||
def sizeof(arg):
|
||||
""" Generate of FunctionCall instance for calling 'sizeof'
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.codegen.ast import real
|
||||
>>> from sympy.codegen.cnodes import sizeof
|
||||
>>> from sympy import ccode
|
||||
>>> ccode(sizeof(real))
|
||||
'sizeof(double)'
|
||||
"""
|
||||
return FunctionCall('sizeof', [String(arg) if isinstance(arg, str) else arg])
|
||||
|
||||
|
||||
class CommaOperator(Basic):
|
||||
""" Represents the comma operator in C """
|
||||
def __new__(cls, *args):
|
||||
return Basic.__new__(cls, *[sympify(arg) for arg in args])
|
||||
|
||||
|
||||
class Label(Node):
|
||||
""" Label for use with e.g. goto statement.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import ccode, Symbol
|
||||
>>> from sympy.codegen.cnodes import Label, PreIncrement
|
||||
>>> print(ccode(Label('foo')))
|
||||
foo:
|
||||
>>> print(ccode(Label('bar', [PreIncrement(Symbol('a'))])))
|
||||
bar:
|
||||
++(a);
|
||||
|
||||
"""
|
||||
__slots__ = _fields = ('name', 'body')
|
||||
defaults = {'body': none}
|
||||
_construct_name = String
|
||||
|
||||
@classmethod
|
||||
def _construct_body(cls, itr):
|
||||
if isinstance(itr, CodeBlock):
|
||||
return itr
|
||||
else:
|
||||
return CodeBlock(*itr)
|
||||
|
||||
|
||||
class goto(Token):
|
||||
""" Represents goto in C """
|
||||
__slots__ = _fields = ('label',)
|
||||
_construct_label = Label
|
||||
|
||||
|
||||
class PreDecrement(Basic):
|
||||
""" Represents the pre-decrement operator
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x
|
||||
>>> from sympy.codegen.cnodes import PreDecrement
|
||||
>>> from sympy import ccode
|
||||
>>> ccode(PreDecrement(x))
|
||||
'--(x)'
|
||||
|
||||
"""
|
||||
nargs = 1
|
||||
|
||||
|
||||
class PostDecrement(Basic):
|
||||
""" Represents the post-decrement operator
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x
|
||||
>>> from sympy.codegen.cnodes import PostDecrement
|
||||
>>> from sympy import ccode
|
||||
>>> ccode(PostDecrement(x))
|
||||
'(x)--'
|
||||
|
||||
"""
|
||||
nargs = 1
|
||||
|
||||
|
||||
class PreIncrement(Basic):
|
||||
""" Represents the pre-increment operator
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x
|
||||
>>> from sympy.codegen.cnodes import PreIncrement
|
||||
>>> from sympy import ccode
|
||||
>>> ccode(PreIncrement(x))
|
||||
'++(x)'
|
||||
|
||||
"""
|
||||
nargs = 1
|
||||
|
||||
|
||||
class PostIncrement(Basic):
|
||||
""" Represents the post-increment operator
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.abc import x
|
||||
>>> from sympy.codegen.cnodes import PostIncrement
|
||||
>>> from sympy import ccode
|
||||
>>> ccode(PostIncrement(x))
|
||||
'(x)++'
|
||||
|
||||
"""
|
||||
nargs = 1
|
||||
|
||||
|
||||
class struct(Node):
|
||||
""" Represents a struct in C """
|
||||
__slots__ = _fields = ('name', 'declarations')
|
||||
defaults = {'name': none}
|
||||
_construct_name = String
|
||||
|
||||
@classmethod
|
||||
def _construct_declarations(cls, args):
|
||||
return Tuple(*[Declaration(arg) for arg in args])
|
||||
|
||||
|
||||
class union(struct):
|
||||
""" Represents a union in C """
|
||||
__slots__ = ()
|
||||
@@ -0,0 +1,8 @@
|
||||
from sympy.printing.c import C99CodePrinter
|
||||
|
||||
def render_as_source_file(content, Printer=C99CodePrinter, settings=None):
|
||||
""" Renders a C source file (with required #include statements) """
|
||||
printer = Printer(settings or {})
|
||||
code_str = printer.doprint(content)
|
||||
includes = '\n'.join(['#include <%s>' % h for h in printer.headers])
|
||||
return includes + '\n\n' + code_str
|
||||
14
venv/lib/python3.12/site-packages/sympy/codegen/cxxnodes.py
Normal file
14
venv/lib/python3.12/site-packages/sympy/codegen/cxxnodes.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""
|
||||
AST nodes specific to C++.
|
||||
"""
|
||||
|
||||
from sympy.codegen.ast import Attribute, String, Token, Type, none
|
||||
|
||||
class using(Token):
|
||||
""" Represents a 'using' statement in C++ """
|
||||
__slots__ = _fields = ('type', 'alias')
|
||||
defaults = {'alias': none}
|
||||
_construct_type = Type
|
||||
_construct_alias = String
|
||||
|
||||
constexpr = Attribute('constexpr')
|
||||
658
venv/lib/python3.12/site-packages/sympy/codegen/fnodes.py
Normal file
658
venv/lib/python3.12/site-packages/sympy/codegen/fnodes.py
Normal file
@@ -0,0 +1,658 @@
|
||||
"""
|
||||
AST nodes specific to Fortran.
|
||||
|
||||
The functions defined in this module allows the user to express functions such as ``dsign``
|
||||
as a SymPy function for symbolic manipulation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from sympy.codegen.ast import (
|
||||
Attribute, CodeBlock, FunctionCall, Node, none, String,
|
||||
Token, _mk_Tuple, Variable
|
||||
)
|
||||
from sympy.core.basic import Basic
|
||||
from sympy.core.containers import Tuple
|
||||
from sympy.core.expr import Expr
|
||||
from sympy.core.function import Function
|
||||
from sympy.core.numbers import Float, Integer
|
||||
from sympy.core.symbol import Str
|
||||
from sympy.core.sympify import sympify
|
||||
from sympy.logic import true, false
|
||||
from sympy.utilities.iterables import iterable
|
||||
|
||||
|
||||
|
||||
pure = Attribute('pure')
|
||||
elemental = Attribute('elemental') # (all elemental procedures are also pure)
|
||||
|
||||
intent_in = Attribute('intent_in')
|
||||
intent_out = Attribute('intent_out')
|
||||
intent_inout = Attribute('intent_inout')
|
||||
|
||||
allocatable = Attribute('allocatable')
|
||||
|
||||
class Program(Token):
|
||||
""" Represents a 'program' block in Fortran.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.codegen.ast import Print
|
||||
>>> from sympy.codegen.fnodes import Program
|
||||
>>> prog = Program('myprogram', [Print([42])])
|
||||
>>> from sympy import fcode
|
||||
>>> print(fcode(prog, source_format='free'))
|
||||
program myprogram
|
||||
print *, 42
|
||||
end program
|
||||
|
||||
"""
|
||||
__slots__ = _fields = ('name', 'body')
|
||||
_construct_name = String
|
||||
_construct_body = staticmethod(lambda body: CodeBlock(*body))
|
||||
|
||||
|
||||
class use_rename(Token):
|
||||
""" Represents a renaming in a use statement in Fortran.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.codegen.fnodes import use_rename, use
|
||||
>>> from sympy import fcode
|
||||
>>> ren = use_rename("thingy", "convolution2d")
|
||||
>>> print(fcode(ren, source_format='free'))
|
||||
thingy => convolution2d
|
||||
>>> full = use('signallib', only=['snr', ren])
|
||||
>>> print(fcode(full, source_format='free'))
|
||||
use signallib, only: snr, thingy => convolution2d
|
||||
|
||||
"""
|
||||
__slots__ = _fields = ('local', 'original')
|
||||
_construct_local = String
|
||||
_construct_original = String
|
||||
|
||||
def _name(arg):
|
||||
if hasattr(arg, 'name'):
|
||||
return arg.name
|
||||
else:
|
||||
return String(arg)
|
||||
|
||||
class use(Token):
|
||||
""" Represents a use statement in Fortran.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.codegen.fnodes import use
|
||||
>>> from sympy import fcode
|
||||
>>> fcode(use('signallib'), source_format='free')
|
||||
'use signallib'
|
||||
>>> fcode(use('signallib', [('metric', 'snr')]), source_format='free')
|
||||
'use signallib, metric => snr'
|
||||
>>> fcode(use('signallib', only=['snr', 'convolution2d']), source_format='free')
|
||||
'use signallib, only: snr, convolution2d'
|
||||
|
||||
"""
|
||||
__slots__ = _fields = ('namespace', 'rename', 'only')
|
||||
defaults = {'rename': none, 'only': none}
|
||||
_construct_namespace = staticmethod(_name)
|
||||
_construct_rename = staticmethod(lambda args: Tuple(*[arg if isinstance(arg, use_rename) else use_rename(*arg) for arg in args]))
|
||||
_construct_only = staticmethod(lambda args: Tuple(*[arg if isinstance(arg, use_rename) else _name(arg) for arg in args]))
|
||||
|
||||
|
||||
class Module(Token):
|
||||
""" Represents a module in Fortran.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.codegen.fnodes import Module
|
||||
>>> from sympy import fcode
|
||||
>>> print(fcode(Module('signallib', ['implicit none'], []), source_format='free'))
|
||||
module signallib
|
||||
implicit none
|
||||
<BLANKLINE>
|
||||
contains
|
||||
<BLANKLINE>
|
||||
<BLANKLINE>
|
||||
end module
|
||||
|
||||
"""
|
||||
__slots__ = _fields = ('name', 'declarations', 'definitions')
|
||||
defaults = {'declarations': Tuple()}
|
||||
_construct_name = String
|
||||
|
||||
@classmethod
|
||||
def _construct_declarations(cls, args):
|
||||
args = [Str(arg) if isinstance(arg, str) else arg for arg in args]
|
||||
return CodeBlock(*args)
|
||||
|
||||
_construct_definitions = staticmethod(lambda arg: CodeBlock(*arg))
|
||||
|
||||
|
||||
class Subroutine(Node):
|
||||
""" Represents a subroutine in Fortran.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import fcode, symbols
|
||||
>>> from sympy.codegen.ast import Print
|
||||
>>> from sympy.codegen.fnodes import Subroutine
|
||||
>>> x, y = symbols('x y', real=True)
|
||||
>>> sub = Subroutine('mysub', [x, y], [Print([x**2 + y**2, x*y])])
|
||||
>>> print(fcode(sub, source_format='free', standard=2003))
|
||||
subroutine mysub(x, y)
|
||||
real*8 :: x
|
||||
real*8 :: y
|
||||
print *, x**2 + y**2, x*y
|
||||
end subroutine
|
||||
|
||||
"""
|
||||
__slots__ = ('name', 'parameters', 'body')
|
||||
_fields = __slots__ + Node._fields
|
||||
_construct_name = String
|
||||
_construct_parameters = staticmethod(lambda params: Tuple(*map(Variable.deduced, params)))
|
||||
|
||||
@classmethod
|
||||
def _construct_body(cls, itr):
|
||||
if isinstance(itr, CodeBlock):
|
||||
return itr
|
||||
else:
|
||||
return CodeBlock(*itr)
|
||||
|
||||
class SubroutineCall(Token):
|
||||
""" Represents a call to a subroutine in Fortran.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.codegen.fnodes import SubroutineCall
|
||||
>>> from sympy import fcode
|
||||
>>> fcode(SubroutineCall('mysub', 'x y'.split()))
|
||||
' call mysub(x, y)'
|
||||
|
||||
"""
|
||||
__slots__ = _fields = ('name', 'subroutine_args')
|
||||
_construct_name = staticmethod(_name)
|
||||
_construct_subroutine_args = staticmethod(_mk_Tuple)
|
||||
|
||||
|
||||
class Do(Token):
|
||||
""" Represents a Do loop in in Fortran.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import fcode, symbols
|
||||
>>> from sympy.codegen.ast import aug_assign, Print
|
||||
>>> from sympy.codegen.fnodes import Do
|
||||
>>> i, n = symbols('i n', integer=True)
|
||||
>>> r = symbols('r', real=True)
|
||||
>>> body = [aug_assign(r, '+', 1/i), Print([i, r])]
|
||||
>>> do1 = Do(body, i, 1, n)
|
||||
>>> print(fcode(do1, source_format='free'))
|
||||
do i = 1, n
|
||||
r = r + 1d0/i
|
||||
print *, i, r
|
||||
end do
|
||||
>>> do2 = Do(body, i, 1, n, 2)
|
||||
>>> print(fcode(do2, source_format='free'))
|
||||
do i = 1, n, 2
|
||||
r = r + 1d0/i
|
||||
print *, i, r
|
||||
end do
|
||||
|
||||
"""
|
||||
|
||||
__slots__ = _fields = ('body', 'counter', 'first', 'last', 'step', 'concurrent')
|
||||
defaults = {'step': Integer(1), 'concurrent': false}
|
||||
_construct_body = staticmethod(lambda body: CodeBlock(*body))
|
||||
_construct_counter = staticmethod(sympify)
|
||||
_construct_first = staticmethod(sympify)
|
||||
_construct_last = staticmethod(sympify)
|
||||
_construct_step = staticmethod(sympify)
|
||||
_construct_concurrent = staticmethod(lambda arg: true if arg else false)
|
||||
|
||||
|
||||
class ArrayConstructor(Token):
|
||||
""" Represents an array constructor.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import fcode
|
||||
>>> from sympy.codegen.fnodes import ArrayConstructor
|
||||
>>> ac = ArrayConstructor([1, 2, 3])
|
||||
>>> fcode(ac, standard=95, source_format='free')
|
||||
'(/1, 2, 3/)'
|
||||
>>> fcode(ac, standard=2003, source_format='free')
|
||||
'[1, 2, 3]'
|
||||
|
||||
"""
|
||||
__slots__ = _fields = ('elements',)
|
||||
_construct_elements = staticmethod(_mk_Tuple)
|
||||
|
||||
|
||||
class ImpliedDoLoop(Token):
|
||||
""" Represents an implied do loop in Fortran.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import Symbol, fcode
|
||||
>>> from sympy.codegen.fnodes import ImpliedDoLoop, ArrayConstructor
|
||||
>>> i = Symbol('i', integer=True)
|
||||
>>> idl = ImpliedDoLoop(i**3, i, -3, 3, 2) # -27, -1, 1, 27
|
||||
>>> ac = ArrayConstructor([-28, idl, 28]) # -28, -27, -1, 1, 27, 28
|
||||
>>> fcode(ac, standard=2003, source_format='free')
|
||||
'[-28, (i**3, i = -3, 3, 2), 28]'
|
||||
|
||||
"""
|
||||
__slots__ = _fields = ('expr', 'counter', 'first', 'last', 'step')
|
||||
defaults = {'step': Integer(1)}
|
||||
_construct_expr = staticmethod(sympify)
|
||||
_construct_counter = staticmethod(sympify)
|
||||
_construct_first = staticmethod(sympify)
|
||||
_construct_last = staticmethod(sympify)
|
||||
_construct_step = staticmethod(sympify)
|
||||
|
||||
|
||||
class Extent(Basic):
|
||||
""" Represents a dimension extent.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.codegen.fnodes import Extent
|
||||
>>> e = Extent(-3, 3) # -3, -2, -1, 0, 1, 2, 3
|
||||
>>> from sympy import fcode
|
||||
>>> fcode(e, source_format='free')
|
||||
'-3:3'
|
||||
>>> from sympy.codegen.ast import Variable, real
|
||||
>>> from sympy.codegen.fnodes import dimension, intent_out
|
||||
>>> dim = dimension(e, e)
|
||||
>>> arr = Variable('x', real, attrs=[dim, intent_out])
|
||||
>>> fcode(arr.as_Declaration(), source_format='free', standard=2003)
|
||||
'real*8, dimension(-3:3, -3:3), intent(out) :: x'
|
||||
|
||||
"""
|
||||
def __new__(cls, *args):
|
||||
if len(args) == 2:
|
||||
low, high = args
|
||||
return Basic.__new__(cls, sympify(low), sympify(high))
|
||||
elif len(args) == 0 or (len(args) == 1 and args[0] in (':', None)):
|
||||
return Basic.__new__(cls) # assumed shape
|
||||
else:
|
||||
raise ValueError("Expected 0 or 2 args (or one argument == None or ':')")
|
||||
|
||||
def _sympystr(self, printer):
|
||||
if len(self.args) == 0:
|
||||
return ':'
|
||||
return ":".join(str(arg) for arg in self.args)
|
||||
|
||||
assumed_extent = Extent() # or Extent(':'), Extent(None)
|
||||
|
||||
|
||||
def dimension(*args):
|
||||
""" Creates a 'dimension' Attribute with (up to 7) extents.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import fcode
|
||||
>>> from sympy.codegen.fnodes import dimension, intent_in
|
||||
>>> dim = dimension('2', ':') # 2 rows, runtime determined number of columns
|
||||
>>> from sympy.codegen.ast import Variable, integer
|
||||
>>> arr = Variable('a', integer, attrs=[dim, intent_in])
|
||||
>>> fcode(arr.as_Declaration(), source_format='free', standard=2003)
|
||||
'integer*4, dimension(2, :), intent(in) :: a'
|
||||
|
||||
"""
|
||||
if len(args) > 7:
|
||||
raise ValueError("Fortran only supports up to 7 dimensional arrays")
|
||||
parameters = []
|
||||
for arg in args:
|
||||
if isinstance(arg, Extent):
|
||||
parameters.append(arg)
|
||||
elif isinstance(arg, str):
|
||||
if arg == ':':
|
||||
parameters.append(Extent())
|
||||
else:
|
||||
parameters.append(String(arg))
|
||||
elif iterable(arg):
|
||||
parameters.append(Extent(*arg))
|
||||
else:
|
||||
parameters.append(sympify(arg))
|
||||
if len(args) == 0:
|
||||
raise ValueError("Need at least one dimension")
|
||||
return Attribute('dimension', parameters)
|
||||
|
||||
|
||||
assumed_size = dimension('*')
|
||||
|
||||
def array(symbol, dim, intent=None, *, attrs=(), value=None, type=None):
|
||||
""" Convenience function for creating a Variable instance for a Fortran array.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
symbol : symbol
|
||||
dim : Attribute or iterable
|
||||
If dim is an ``Attribute`` it need to have the name 'dimension'. If it is
|
||||
not an ``Attribute``, then it is passed to :func:`dimension` as ``*dim``
|
||||
intent : str
|
||||
One of: 'in', 'out', 'inout' or None
|
||||
\\*\\*kwargs:
|
||||
Keyword arguments for ``Variable`` ('type' & 'value')
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import fcode
|
||||
>>> from sympy.codegen.ast import integer, real
|
||||
>>> from sympy.codegen.fnodes import array
|
||||
>>> arr = array('a', '*', 'in', type=integer)
|
||||
>>> print(fcode(arr.as_Declaration(), source_format='free', standard=2003))
|
||||
integer*4, dimension(*), intent(in) :: a
|
||||
>>> x = array('x', [3, ':', ':'], intent='out', type=real)
|
||||
>>> print(fcode(x.as_Declaration(value=1), source_format='free', standard=2003))
|
||||
real*8, dimension(3, :, :), intent(out) :: x = 1
|
||||
|
||||
"""
|
||||
if isinstance(dim, Attribute):
|
||||
if str(dim.name) != 'dimension':
|
||||
raise ValueError("Got an unexpected Attribute argument as dim: %s" % str(dim))
|
||||
else:
|
||||
dim = dimension(*dim)
|
||||
|
||||
attrs = list(attrs) + [dim]
|
||||
if intent is not None:
|
||||
if intent not in (intent_in, intent_out, intent_inout):
|
||||
intent = {'in': intent_in, 'out': intent_out, 'inout': intent_inout}[intent]
|
||||
attrs.append(intent)
|
||||
if type is None:
|
||||
return Variable.deduced(symbol, value=value, attrs=attrs)
|
||||
else:
|
||||
return Variable(symbol, type, value=value, attrs=attrs)
|
||||
|
||||
def _printable(arg):
|
||||
return String(arg) if isinstance(arg, str) else sympify(arg)
|
||||
|
||||
|
||||
def allocated(array):
|
||||
""" Creates an AST node for a function call to Fortran's "allocated(...)"
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import fcode
|
||||
>>> from sympy.codegen.fnodes import allocated
|
||||
>>> alloc = allocated('x')
|
||||
>>> fcode(alloc, source_format='free')
|
||||
'allocated(x)'
|
||||
|
||||
"""
|
||||
return FunctionCall('allocated', [_printable(array)])
|
||||
|
||||
|
||||
def lbound(array, dim=None, kind=None):
|
||||
""" Creates an AST node for a function call to Fortran's "lbound(...)"
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
array : Symbol or String
|
||||
dim : expr
|
||||
kind : expr
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import fcode
|
||||
>>> from sympy.codegen.fnodes import lbound
|
||||
>>> lb = lbound('arr', dim=2)
|
||||
>>> fcode(lb, source_format='free')
|
||||
'lbound(arr, 2)'
|
||||
|
||||
"""
|
||||
return FunctionCall(
|
||||
'lbound',
|
||||
[_printable(array)] +
|
||||
([_printable(dim)] if dim else []) +
|
||||
([_printable(kind)] if kind else [])
|
||||
)
|
||||
|
||||
|
||||
def ubound(array, dim=None, kind=None):
|
||||
return FunctionCall(
|
||||
'ubound',
|
||||
[_printable(array)] +
|
||||
([_printable(dim)] if dim else []) +
|
||||
([_printable(kind)] if kind else [])
|
||||
)
|
||||
|
||||
|
||||
def shape(source, kind=None):
|
||||
""" Creates an AST node for a function call to Fortran's "shape(...)"
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
source : Symbol or String
|
||||
kind : expr
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import fcode
|
||||
>>> from sympy.codegen.fnodes import shape
|
||||
>>> shp = shape('x')
|
||||
>>> fcode(shp, source_format='free')
|
||||
'shape(x)'
|
||||
|
||||
"""
|
||||
return FunctionCall(
|
||||
'shape',
|
||||
[_printable(source)] +
|
||||
([_printable(kind)] if kind else [])
|
||||
)
|
||||
|
||||
|
||||
def size(array, dim=None, kind=None):
|
||||
""" Creates an AST node for a function call to Fortran's "size(...)"
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import fcode, Symbol
|
||||
>>> from sympy.codegen.ast import FunctionDefinition, real, Return
|
||||
>>> from sympy.codegen.fnodes import array, sum_, size
|
||||
>>> a = Symbol('a', real=True)
|
||||
>>> body = [Return((sum_(a**2)/size(a))**.5)]
|
||||
>>> arr = array(a, dim=[':'], intent='in')
|
||||
>>> fd = FunctionDefinition(real, 'rms', [arr], body)
|
||||
>>> print(fcode(fd, source_format='free', standard=2003))
|
||||
real*8 function rms(a)
|
||||
real*8, dimension(:), intent(in) :: a
|
||||
rms = sqrt(sum(a**2)*1d0/size(a))
|
||||
end function
|
||||
|
||||
"""
|
||||
return FunctionCall(
|
||||
'size',
|
||||
[_printable(array)] +
|
||||
([_printable(dim)] if dim else []) +
|
||||
([_printable(kind)] if kind else [])
|
||||
)
|
||||
|
||||
|
||||
def reshape(source, shape, pad=None, order=None):
|
||||
""" Creates an AST node for a function call to Fortran's "reshape(...)"
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
source : Symbol or String
|
||||
shape : ArrayExpr
|
||||
|
||||
"""
|
||||
return FunctionCall(
|
||||
'reshape',
|
||||
[_printable(source), _printable(shape)] +
|
||||
([_printable(pad)] if pad else []) +
|
||||
([_printable(order)] if pad else [])
|
||||
)
|
||||
|
||||
|
||||
def bind_C(name=None):
|
||||
""" Creates an Attribute ``bind_C`` with a name.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
name : str
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import fcode, Symbol
|
||||
>>> from sympy.codegen.ast import FunctionDefinition, real, Return
|
||||
>>> from sympy.codegen.fnodes import array, sum_, bind_C
|
||||
>>> a = Symbol('a', real=True)
|
||||
>>> s = Symbol('s', integer=True)
|
||||
>>> arr = array(a, dim=[s], intent='in')
|
||||
>>> body = [Return((sum_(a**2)/s)**.5)]
|
||||
>>> fd = FunctionDefinition(real, 'rms', [arr, s], body, attrs=[bind_C('rms')])
|
||||
>>> print(fcode(fd, source_format='free', standard=2003))
|
||||
real*8 function rms(a, s) bind(C, name="rms")
|
||||
real*8, dimension(s), intent(in) :: a
|
||||
integer*4 :: s
|
||||
rms = sqrt(sum(a**2)/s)
|
||||
end function
|
||||
|
||||
"""
|
||||
return Attribute('bind_C', [String(name)] if name else [])
|
||||
|
||||
class GoTo(Token):
|
||||
""" Represents a goto statement in Fortran
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.codegen.fnodes import GoTo
|
||||
>>> go = GoTo([10, 20, 30], 'i')
|
||||
>>> from sympy import fcode
|
||||
>>> fcode(go, source_format='free')
|
||||
'go to (10, 20, 30), i'
|
||||
|
||||
"""
|
||||
__slots__ = _fields = ('labels', 'expr')
|
||||
defaults = {'expr': none}
|
||||
_construct_labels = staticmethod(_mk_Tuple)
|
||||
_construct_expr = staticmethod(sympify)
|
||||
|
||||
|
||||
class FortranReturn(Token):
|
||||
""" AST node explicitly mapped to a fortran "return".
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
Because a return statement in fortran is different from C, and
|
||||
in order to aid reuse of our codegen ASTs the ordinary
|
||||
``.codegen.ast.Return`` is interpreted as assignment to
|
||||
the result variable of the function. If one for some reason needs
|
||||
to generate a fortran RETURN statement, this node should be used.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy.codegen.fnodes import FortranReturn
|
||||
>>> from sympy import fcode
|
||||
>>> fcode(FortranReturn('x'))
|
||||
' return x'
|
||||
|
||||
"""
|
||||
__slots__ = _fields = ('return_value',)
|
||||
defaults = {'return_value': none}
|
||||
_construct_return_value = staticmethod(sympify)
|
||||
|
||||
|
||||
class FFunction(Function):
|
||||
_required_standard = 77
|
||||
|
||||
def _fcode(self, printer):
|
||||
name = self.__class__.__name__
|
||||
if printer._settings['standard'] < self._required_standard:
|
||||
raise NotImplementedError("%s requires Fortran %d or newer" %
|
||||
(name, self._required_standard))
|
||||
return '{}({})'.format(name, ', '.join(map(printer._print, self.args)))
|
||||
|
||||
|
||||
class F95Function(FFunction):
|
||||
_required_standard = 95
|
||||
|
||||
|
||||
class isign(FFunction):
|
||||
""" Fortran sign intrinsic for integer arguments. """
|
||||
nargs = 2
|
||||
|
||||
|
||||
class dsign(FFunction):
|
||||
""" Fortran sign intrinsic for double precision arguments. """
|
||||
nargs = 2
|
||||
|
||||
|
||||
class cmplx(FFunction):
|
||||
""" Fortran complex conversion function. """
|
||||
nargs = 2 # may be extended to (2, 3) at a later point
|
||||
|
||||
|
||||
class kind(FFunction):
|
||||
""" Fortran kind function. """
|
||||
nargs = 1
|
||||
|
||||
|
||||
class merge(F95Function):
|
||||
""" Fortran merge function """
|
||||
nargs = 3
|
||||
|
||||
|
||||
class _literal(Float):
|
||||
_token: str
|
||||
_decimals: int
|
||||
|
||||
def _fcode(self, printer, *args, **kwargs):
|
||||
mantissa, sgnd_ex = ('%.{}e'.format(self._decimals) % self).split('e')
|
||||
mantissa = mantissa.strip('0').rstrip('.')
|
||||
ex_sgn, ex_num = sgnd_ex[0], sgnd_ex[1:].lstrip('0')
|
||||
ex_sgn = '' if ex_sgn == '+' else ex_sgn
|
||||
return (mantissa or '0') + self._token + ex_sgn + (ex_num or '0')
|
||||
|
||||
|
||||
class literal_sp(_literal):
|
||||
""" Fortran single precision real literal """
|
||||
_token = 'e'
|
||||
_decimals = 9
|
||||
|
||||
|
||||
class literal_dp(_literal):
|
||||
""" Fortran double precision real literal """
|
||||
_token = 'd'
|
||||
_decimals = 17
|
||||
|
||||
|
||||
class sum_(Token, Expr):
|
||||
__slots__ = _fields = ('array', 'dim', 'mask')
|
||||
defaults = {'dim': none, 'mask': none}
|
||||
_construct_array = staticmethod(sympify)
|
||||
_construct_dim = staticmethod(sympify)
|
||||
|
||||
|
||||
class product_(Token, Expr):
|
||||
__slots__ = _fields = ('array', 'dim', 'mask')
|
||||
defaults = {'dim': none, 'mask': none}
|
||||
_construct_array = staticmethod(sympify)
|
||||
_construct_dim = staticmethod(sympify)
|
||||
40
venv/lib/python3.12/site-packages/sympy/codegen/futils.py
Normal file
40
venv/lib/python3.12/site-packages/sympy/codegen/futils.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from itertools import chain
|
||||
from sympy.codegen.fnodes import Module
|
||||
from sympy.core.symbol import Dummy
|
||||
from sympy.printing.fortran import FCodePrinter
|
||||
|
||||
""" This module collects utilities for rendering Fortran code. """
|
||||
|
||||
|
||||
def render_as_module(definitions, name, declarations=(), printer_settings=None):
|
||||
""" Creates a ``Module`` instance and renders it as a string.
|
||||
|
||||
This generates Fortran source code for a module with the correct ``use`` statements.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
definitions : iterable
|
||||
Passed to :class:`sympy.codegen.fnodes.Module`.
|
||||
name : str
|
||||
Passed to :class:`sympy.codegen.fnodes.Module`.
|
||||
declarations : iterable
|
||||
Passed to :class:`sympy.codegen.fnodes.Module`. It will be extended with
|
||||
use statements, 'implicit none' and public list generated from ``definitions``.
|
||||
printer_settings : dict
|
||||
Passed to ``FCodePrinter`` (default: ``{'standard': 2003, 'source_format': 'free'}``).
|
||||
|
||||
"""
|
||||
printer_settings = printer_settings or {'standard': 2003, 'source_format': 'free'}
|
||||
printer = FCodePrinter(printer_settings)
|
||||
dummy = Dummy()
|
||||
if isinstance(definitions, Module):
|
||||
raise ValueError("This function expects to construct a module on its own.")
|
||||
mod = Module(name, chain(declarations, [dummy]), definitions)
|
||||
fstr = printer.doprint(mod)
|
||||
module_use_str = ' %s\n' % ' \n'.join(['use %s, only: %s' % (k, ', '.join(v)) for
|
||||
k, v in printer.module_uses.items()])
|
||||
module_use_str += ' implicit none\n'
|
||||
module_use_str += ' private\n'
|
||||
module_use_str += ' public %s\n' % ', '.join([str(node.name) for node in definitions if getattr(node, 'name', None)])
|
||||
return fstr.replace(printer.doprint(dummy), module_use_str)
|
||||
@@ -0,0 +1,71 @@
|
||||
"""
|
||||
Additional AST nodes for operations on matrices. The nodes in this module
|
||||
are meant to represent optimization of matrix expressions within codegen's
|
||||
target languages that cannot be represented by SymPy expressions.
|
||||
|
||||
As an example, we can use :meth:`sympy.codegen.rewriting.optimize` and the
|
||||
``matin_opt`` optimization provided in :mod:`sympy.codegen.rewriting` to
|
||||
transform matrix multiplication under certain assumptions:
|
||||
|
||||
>>> from sympy import symbols, MatrixSymbol
|
||||
>>> n = symbols('n', integer=True)
|
||||
>>> A = MatrixSymbol('A', n, n)
|
||||
>>> x = MatrixSymbol('x', n, 1)
|
||||
>>> expr = A**(-1) * x
|
||||
>>> from sympy import assuming, Q
|
||||
>>> from sympy.codegen.rewriting import matinv_opt, optimize
|
||||
>>> with assuming(Q.fullrank(A)):
|
||||
... optimize(expr, [matinv_opt])
|
||||
MatrixSolve(A, vector=x)
|
||||
"""
|
||||
|
||||
from .ast import Token
|
||||
from sympy.matrices import MatrixExpr
|
||||
from sympy.core.sympify import sympify
|
||||
|
||||
|
||||
class MatrixSolve(Token, MatrixExpr):
|
||||
"""Represents an operation to solve a linear matrix equation.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
matrix : MatrixSymbol
|
||||
|
||||
Matrix representing the coefficients of variables in the linear
|
||||
equation. This matrix must be square and full-rank (i.e. all columns must
|
||||
be linearly independent) for the solving operation to be valid.
|
||||
|
||||
vector : MatrixSymbol
|
||||
|
||||
One-column matrix representing the solutions to the equations
|
||||
represented in ``matrix``.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import symbols, MatrixSymbol
|
||||
>>> from sympy.codegen.matrix_nodes import MatrixSolve
|
||||
>>> n = symbols('n', integer=True)
|
||||
>>> A = MatrixSymbol('A', n, n)
|
||||
>>> x = MatrixSymbol('x', n, 1)
|
||||
>>> from sympy.printing.numpy import NumPyPrinter
|
||||
>>> NumPyPrinter().doprint(MatrixSolve(A, x))
|
||||
'numpy.linalg.solve(A, x)'
|
||||
>>> from sympy import octave_code
|
||||
>>> octave_code(MatrixSolve(A, x))
|
||||
'A \\\\ x'
|
||||
|
||||
"""
|
||||
__slots__ = _fields = ('matrix', 'vector')
|
||||
|
||||
_construct_matrix = staticmethod(sympify)
|
||||
_construct_vector = staticmethod(sympify)
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self.vector.shape
|
||||
|
||||
def _eval_derivative(self, x):
|
||||
A, b = self.matrix, self.vector
|
||||
return MatrixSolve(A, b.diff(x) - A.diff(x) * MatrixSolve(A, b))
|
||||
177
venv/lib/python3.12/site-packages/sympy/codegen/numpy_nodes.py
Normal file
177
venv/lib/python3.12/site-packages/sympy/codegen/numpy_nodes.py
Normal file
@@ -0,0 +1,177 @@
|
||||
from sympy.core.function import Add, ArgumentIndexError, Function
|
||||
from sympy.core.power import Pow
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.sorting import default_sort_key
|
||||
from sympy.core.sympify import sympify
|
||||
from sympy.functions.elementary.exponential import exp, log
|
||||
from sympy.functions.elementary.miscellaneous import Max, Min
|
||||
from .ast import Token, none
|
||||
|
||||
|
||||
def _logaddexp(x1, x2, *, evaluate=True):
|
||||
return log(Add(exp(x1, evaluate=evaluate), exp(x2, evaluate=evaluate), evaluate=evaluate))
|
||||
|
||||
|
||||
_two = S.One*2
|
||||
_ln2 = log(_two)
|
||||
|
||||
|
||||
def _lb(x, *, evaluate=True):
|
||||
return log(x, evaluate=evaluate)/_ln2
|
||||
|
||||
|
||||
def _exp2(x, *, evaluate=True):
|
||||
return Pow(_two, x, evaluate=evaluate)
|
||||
|
||||
|
||||
def _logaddexp2(x1, x2, *, evaluate=True):
|
||||
return _lb(Add(_exp2(x1, evaluate=evaluate),
|
||||
_exp2(x2, evaluate=evaluate), evaluate=evaluate))
|
||||
|
||||
|
||||
class logaddexp(Function):
|
||||
""" Logarithm of the sum of exponentiations of the inputs.
|
||||
|
||||
Helper class for use with e.g. numpy.logaddexp
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
https://numpy.org/doc/stable/reference/generated/numpy.logaddexp.html
|
||||
"""
|
||||
nargs = 2
|
||||
|
||||
def __new__(cls, *args):
|
||||
return Function.__new__(cls, *sorted(args, key=default_sort_key))
|
||||
|
||||
def fdiff(self, argindex=1):
|
||||
"""
|
||||
Returns the first derivative of this function.
|
||||
"""
|
||||
if argindex == 1:
|
||||
wrt, other = self.args
|
||||
elif argindex == 2:
|
||||
other, wrt = self.args
|
||||
else:
|
||||
raise ArgumentIndexError(self, argindex)
|
||||
return S.One/(S.One + exp(other-wrt))
|
||||
|
||||
def _eval_rewrite_as_log(self, x1, x2, **kwargs):
|
||||
return _logaddexp(x1, x2)
|
||||
|
||||
def _eval_evalf(self, *args, **kwargs):
|
||||
return self.rewrite(log).evalf(*args, **kwargs)
|
||||
|
||||
def _eval_simplify(self, *args, **kwargs):
|
||||
a, b = (x.simplify(**kwargs) for x in self.args)
|
||||
candidate = _logaddexp(a, b)
|
||||
if candidate != _logaddexp(a, b, evaluate=False):
|
||||
return candidate
|
||||
else:
|
||||
return logaddexp(a, b)
|
||||
|
||||
|
||||
class logaddexp2(Function):
|
||||
""" Logarithm of the sum of exponentiations of the inputs in base-2.
|
||||
|
||||
Helper class for use with e.g. numpy.logaddexp2
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
https://numpy.org/doc/stable/reference/generated/numpy.logaddexp2.html
|
||||
"""
|
||||
nargs = 2
|
||||
|
||||
def __new__(cls, *args):
|
||||
return Function.__new__(cls, *sorted(args, key=default_sort_key))
|
||||
|
||||
def fdiff(self, argindex=1):
|
||||
"""
|
||||
Returns the first derivative of this function.
|
||||
"""
|
||||
if argindex == 1:
|
||||
wrt, other = self.args
|
||||
elif argindex == 2:
|
||||
other, wrt = self.args
|
||||
else:
|
||||
raise ArgumentIndexError(self, argindex)
|
||||
return S.One/(S.One + _exp2(other-wrt))
|
||||
|
||||
def _eval_rewrite_as_log(self, x1, x2, **kwargs):
|
||||
return _logaddexp2(x1, x2)
|
||||
|
||||
def _eval_evalf(self, *args, **kwargs):
|
||||
return self.rewrite(log).evalf(*args, **kwargs)
|
||||
|
||||
def _eval_simplify(self, *args, **kwargs):
|
||||
a, b = (x.simplify(**kwargs).factor() for x in self.args)
|
||||
candidate = _logaddexp2(a, b)
|
||||
if candidate != _logaddexp2(a, b, evaluate=False):
|
||||
return candidate
|
||||
else:
|
||||
return logaddexp2(a, b)
|
||||
|
||||
|
||||
class amin(Token):
|
||||
""" Minimum value along an axis.
|
||||
|
||||
Helper class for use with e.g. numpy.amin
|
||||
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
https://numpy.org/doc/stable/reference/generated/numpy.amin.html
|
||||
"""
|
||||
__slots__ = _fields = ('array', 'axis')
|
||||
defaults = {'axis': none}
|
||||
_construct_axis = staticmethod(sympify)
|
||||
|
||||
|
||||
class amax(Token):
|
||||
""" Maximum value along an axis.
|
||||
|
||||
Helper class for use with e.g. numpy.amax
|
||||
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
https://numpy.org/doc/stable/reference/generated/numpy.amax.html
|
||||
"""
|
||||
__slots__ = _fields = ('array', 'axis')
|
||||
defaults = {'axis': none}
|
||||
_construct_axis = staticmethod(sympify)
|
||||
|
||||
|
||||
class maximum(Function):
|
||||
""" Element-wise maximum of array elements.
|
||||
|
||||
Helper class for use with e.g. numpy.maximum
|
||||
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
https://numpy.org/doc/stable/reference/generated/numpy.maximum.html
|
||||
"""
|
||||
|
||||
def _eval_rewrite_as_Max(self, *args):
|
||||
return Max(*self.args)
|
||||
|
||||
|
||||
class minimum(Function):
|
||||
""" Element-wise minimum of array elements.
|
||||
|
||||
Helper class for use with e.g. numpy.minimum
|
||||
|
||||
|
||||
See Also
|
||||
========
|
||||
|
||||
https://numpy.org/doc/stable/reference/generated/numpy.minimum.html
|
||||
"""
|
||||
|
||||
def _eval_rewrite_as_Min(self, *args):
|
||||
return Min(*self.args)
|
||||
11
venv/lib/python3.12/site-packages/sympy/codegen/pynodes.py
Normal file
11
venv/lib/python3.12/site-packages/sympy/codegen/pynodes.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from .abstract_nodes import List as AbstractList
|
||||
from .ast import Token
|
||||
|
||||
|
||||
class List(AbstractList):
|
||||
pass
|
||||
|
||||
|
||||
class NumExprEvaluate(Token):
|
||||
"""represents a call to :class:`numexpr`s :func:`evaluate`"""
|
||||
__slots__ = _fields = ('expr',)
|
||||
24
venv/lib/python3.12/site-packages/sympy/codegen/pyutils.py
Normal file
24
venv/lib/python3.12/site-packages/sympy/codegen/pyutils.py
Normal file
@@ -0,0 +1,24 @@
|
||||
from sympy.printing.pycode import PythonCodePrinter
|
||||
|
||||
""" This module collects utilities for rendering Python code. """
|
||||
|
||||
|
||||
def render_as_module(content, standard='python3'):
|
||||
"""Renders Python code as a module (with the required imports).
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
standard :
|
||||
See the parameter ``standard`` in
|
||||
:meth:`sympy.printing.pycode.pycode`
|
||||
"""
|
||||
|
||||
printer = PythonCodePrinter({'standard':standard})
|
||||
pystr = printer.doprint(content)
|
||||
if printer._settings['fully_qualified_modules']:
|
||||
module_imports_str = '\n'.join('import %s' % k for k in printer.module_imports)
|
||||
else:
|
||||
module_imports_str = '\n'.join(['from %s import %s' % (k, ', '.join(v)) for
|
||||
k, v in printer.module_imports.items()])
|
||||
return module_imports_str + '\n\n' + pystr
|
||||
357
venv/lib/python3.12/site-packages/sympy/codegen/rewriting.py
Normal file
357
venv/lib/python3.12/site-packages/sympy/codegen/rewriting.py
Normal file
@@ -0,0 +1,357 @@
|
||||
"""
|
||||
Classes and functions useful for rewriting expressions for optimized code
|
||||
generation. Some languages (or standards thereof), e.g. C99, offer specialized
|
||||
math functions for better performance and/or precision.
|
||||
|
||||
Using the ``optimize`` function in this module, together with a collection of
|
||||
rules (represented as instances of ``Optimization``), one can rewrite the
|
||||
expressions for this purpose::
|
||||
|
||||
>>> from sympy import Symbol, exp, log
|
||||
>>> from sympy.codegen.rewriting import optimize, optims_c99
|
||||
>>> x = Symbol('x')
|
||||
>>> optimize(3*exp(2*x) - 3, optims_c99)
|
||||
3*expm1(2*x)
|
||||
>>> optimize(exp(2*x) - 1 - exp(-33), optims_c99)
|
||||
expm1(2*x) - exp(-33)
|
||||
>>> optimize(log(3*x + 3), optims_c99)
|
||||
log1p(x) + log(3)
|
||||
>>> optimize(log(2*x + 3), optims_c99)
|
||||
log(2*x + 3)
|
||||
|
||||
The ``optims_c99`` imported above is tuple containing the following instances
|
||||
(which may be imported from ``sympy.codegen.rewriting``):
|
||||
|
||||
- ``expm1_opt``
|
||||
- ``log1p_opt``
|
||||
- ``exp2_opt``
|
||||
- ``log2_opt``
|
||||
- ``log2const_opt``
|
||||
|
||||
|
||||
"""
|
||||
from sympy.core.function import expand_log
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import Wild
|
||||
from sympy.functions.elementary.complexes import sign
|
||||
from sympy.functions.elementary.exponential import (exp, log)
|
||||
from sympy.functions.elementary.miscellaneous import (Max, Min)
|
||||
from sympy.functions.elementary.trigonometric import (cos, sin, sinc)
|
||||
from sympy.assumptions import Q, ask
|
||||
from sympy.codegen.cfunctions import log1p, log2, exp2, expm1
|
||||
from sympy.codegen.matrix_nodes import MatrixSolve
|
||||
from sympy.core.expr import UnevaluatedExpr
|
||||
from sympy.core.power import Pow
|
||||
from sympy.codegen.numpy_nodes import logaddexp, logaddexp2
|
||||
from sympy.codegen.scipy_nodes import cosm1, powm1
|
||||
from sympy.core.mul import Mul
|
||||
from sympy.matrices.expressions.matexpr import MatrixSymbol
|
||||
from sympy.utilities.iterables import sift
|
||||
|
||||
|
||||
class Optimization:
|
||||
""" Abstract base class for rewriting optimization.
|
||||
|
||||
Subclasses should implement ``__call__`` taking an expression
|
||||
as argument.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
cost_function : callable returning number
|
||||
priority : number
|
||||
|
||||
"""
|
||||
def __init__(self, cost_function=None, priority=1):
|
||||
self.cost_function = cost_function
|
||||
self.priority=priority
|
||||
|
||||
def cheapest(self, *args):
|
||||
return min(args, key=self.cost_function)
|
||||
|
||||
|
||||
class ReplaceOptim(Optimization):
|
||||
""" Rewriting optimization calling replace on expressions.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
The instance can be used as a function on expressions for which
|
||||
it will apply the ``replace`` method (see
|
||||
:meth:`sympy.core.basic.Basic.replace`).
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
query :
|
||||
First argument passed to replace.
|
||||
value :
|
||||
Second argument passed to replace.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import Symbol
|
||||
>>> from sympy.codegen.rewriting import ReplaceOptim
|
||||
>>> from sympy.codegen.cfunctions import exp2
|
||||
>>> x = Symbol('x')
|
||||
>>> exp2_opt = ReplaceOptim(lambda p: p.is_Pow and p.base == 2,
|
||||
... lambda p: exp2(p.exp))
|
||||
>>> exp2_opt(2**x)
|
||||
exp2(x)
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, query, value, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.query = query
|
||||
self.value = value
|
||||
|
||||
def __call__(self, expr):
|
||||
return expr.replace(self.query, self.value)
|
||||
|
||||
|
||||
def optimize(expr, optimizations):
|
||||
""" Apply optimizations to an expression.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
expr : expression
|
||||
optimizations : iterable of ``Optimization`` instances
|
||||
The optimizations will be sorted with respect to ``priority`` (highest first).
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import log, Symbol
|
||||
>>> from sympy.codegen.rewriting import optims_c99, optimize
|
||||
>>> x = Symbol('x')
|
||||
>>> optimize(log(x+3)/log(2) + log(x**2 + 1), optims_c99)
|
||||
log1p(x**2) + log2(x + 3)
|
||||
|
||||
"""
|
||||
|
||||
for optim in sorted(optimizations, key=lambda opt: opt.priority, reverse=True):
|
||||
new_expr = optim(expr)
|
||||
if optim.cost_function is None:
|
||||
expr = new_expr
|
||||
else:
|
||||
expr = optim.cheapest(expr, new_expr)
|
||||
return expr
|
||||
|
||||
|
||||
exp2_opt = ReplaceOptim(
|
||||
lambda p: p.is_Pow and p.base == 2,
|
||||
lambda p: exp2(p.exp)
|
||||
)
|
||||
|
||||
|
||||
_d = Wild('d', properties=[lambda x: x.is_Dummy])
|
||||
_u = Wild('u', properties=[lambda x: not x.is_number and not x.is_Add])
|
||||
_v = Wild('v')
|
||||
_w = Wild('w')
|
||||
_n = Wild('n', properties=[lambda x: x.is_number])
|
||||
|
||||
sinc_opt1 = ReplaceOptim(
|
||||
sin(_w)/_w, sinc(_w)
|
||||
)
|
||||
sinc_opt2 = ReplaceOptim(
|
||||
sin(_n*_w)/_w, _n*sinc(_n*_w)
|
||||
)
|
||||
sinc_opts = (sinc_opt1, sinc_opt2)
|
||||
|
||||
log2_opt = ReplaceOptim(_v*log(_w)/log(2), _v*log2(_w), cost_function=lambda expr: expr.count(
|
||||
lambda e: ( # division & eval of transcendentals are expensive floating point operations...
|
||||
e.is_Pow and e.exp.is_negative # division
|
||||
or (isinstance(e, (log, log2)) and not e.args[0].is_number)) # transcendental
|
||||
)
|
||||
)
|
||||
|
||||
log2const_opt = ReplaceOptim(log(2)*log2(_w), log(_w))
|
||||
|
||||
logsumexp_2terms_opt = ReplaceOptim(
|
||||
lambda l: (isinstance(l, log)
|
||||
and l.args[0].is_Add
|
||||
and len(l.args[0].args) == 2
|
||||
and all(isinstance(t, exp) for t in l.args[0].args)),
|
||||
lambda l: (
|
||||
Max(*[e.args[0] for e in l.args[0].args]) +
|
||||
log1p(exp(Min(*[e.args[0] for e in l.args[0].args])))
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class FuncMinusOneOptim(ReplaceOptim):
|
||||
"""Specialization of ReplaceOptim for functions evaluating "f(x) - 1".
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
Numerical functions which go toward one as x go toward zero is often best
|
||||
implemented by a dedicated function in order to avoid catastrophic
|
||||
cancellation. One such example is ``expm1(x)`` in the C standard library
|
||||
which evaluates ``exp(x) - 1``. Such functions preserves many more
|
||||
significant digits when its argument is much smaller than one, compared
|
||||
to subtracting one afterwards.
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
func :
|
||||
The function which is subtracted by one.
|
||||
func_m_1 :
|
||||
The specialized function evaluating ``func(x) - 1``.
|
||||
opportunistic : bool
|
||||
When ``True``, apply the transformation as long as the magnitude of the
|
||||
remaining number terms decreases. When ``False``, only apply the
|
||||
transformation if it completely eliminates the number term.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import symbols, exp
|
||||
>>> from sympy.codegen.rewriting import FuncMinusOneOptim
|
||||
>>> from sympy.codegen.cfunctions import expm1
|
||||
>>> x, y = symbols('x y')
|
||||
>>> expm1_opt = FuncMinusOneOptim(exp, expm1)
|
||||
>>> expm1_opt(exp(x) + 2*exp(5*y) - 3)
|
||||
expm1(x) + 2*expm1(5*y)
|
||||
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, func, func_m_1, opportunistic=True):
|
||||
weight = 10 # <-- this is an arbitrary number (heuristic)
|
||||
super().__init__(lambda e: e.is_Add, self.replace_in_Add,
|
||||
cost_function=lambda expr: expr.count_ops() - weight*expr.count(func_m_1))
|
||||
self.func = func
|
||||
self.func_m_1 = func_m_1
|
||||
self.opportunistic = opportunistic
|
||||
|
||||
def _group_Add_terms(self, add):
|
||||
numbers, non_num = sift(add.args, lambda arg: arg.is_number, binary=True)
|
||||
numsum = sum(numbers)
|
||||
terms_with_func, other = sift(non_num, lambda arg: arg.has(self.func), binary=True)
|
||||
return numsum, terms_with_func, other
|
||||
|
||||
def replace_in_Add(self, e):
|
||||
""" passed as second argument to Basic.replace(...) """
|
||||
numsum, terms_with_func, other_non_num_terms = self._group_Add_terms(e)
|
||||
if numsum == 0:
|
||||
return e
|
||||
substituted, untouched = [], []
|
||||
for with_func in terms_with_func:
|
||||
if with_func.is_Mul:
|
||||
func, coeff = sift(with_func.args, lambda arg: arg.func == self.func, binary=True)
|
||||
if len(func) == 1 and len(coeff) == 1:
|
||||
func, coeff = func[0], coeff[0]
|
||||
else:
|
||||
coeff = None
|
||||
elif with_func.func == self.func:
|
||||
func, coeff = with_func, S.One
|
||||
else:
|
||||
coeff = None
|
||||
|
||||
if coeff is not None and coeff.is_number and sign(coeff) == -sign(numsum):
|
||||
if self.opportunistic:
|
||||
do_substitute = abs(coeff+numsum) < abs(numsum)
|
||||
else:
|
||||
do_substitute = coeff+numsum == 0
|
||||
|
||||
if do_substitute: # advantageous substitution
|
||||
numsum += coeff
|
||||
substituted.append(coeff*self.func_m_1(*func.args))
|
||||
continue
|
||||
untouched.append(with_func)
|
||||
|
||||
return e.func(numsum, *substituted, *untouched, *other_non_num_terms)
|
||||
|
||||
def __call__(self, expr):
|
||||
alt1 = super().__call__(expr)
|
||||
alt2 = super().__call__(expr.factor())
|
||||
return self.cheapest(alt1, alt2)
|
||||
|
||||
|
||||
expm1_opt = FuncMinusOneOptim(exp, expm1)
|
||||
cosm1_opt = FuncMinusOneOptim(cos, cosm1)
|
||||
powm1_opt = FuncMinusOneOptim(Pow, powm1)
|
||||
|
||||
log1p_opt = ReplaceOptim(
|
||||
lambda e: isinstance(e, log),
|
||||
lambda l: expand_log(l.replace(
|
||||
log, lambda arg: log(arg.factor())
|
||||
)).replace(log(_u+1), log1p(_u))
|
||||
)
|
||||
|
||||
def create_expand_pow_optimization(limit, *, base_req=lambda b: b.is_symbol):
|
||||
""" Creates an instance of :class:`ReplaceOptim` for expanding ``Pow``.
|
||||
|
||||
Explanation
|
||||
===========
|
||||
|
||||
The requirements for expansions are that the base needs to be a symbol
|
||||
and the exponent needs to be an Integer (and be less than or equal to
|
||||
``limit``).
|
||||
|
||||
Parameters
|
||||
==========
|
||||
|
||||
limit : int
|
||||
The highest power which is expanded into multiplication.
|
||||
base_req : function returning bool
|
||||
Requirement on base for expansion to happen, default is to return
|
||||
the ``is_symbol`` attribute of the base.
|
||||
|
||||
Examples
|
||||
========
|
||||
|
||||
>>> from sympy import Symbol, sin
|
||||
>>> from sympy.codegen.rewriting import create_expand_pow_optimization
|
||||
>>> x = Symbol('x')
|
||||
>>> expand_opt = create_expand_pow_optimization(3)
|
||||
>>> expand_opt(x**5 + x**3)
|
||||
x**5 + x*x*x
|
||||
>>> expand_opt(x**5 + x**3 + sin(x)**3)
|
||||
x**5 + sin(x)**3 + x*x*x
|
||||
>>> opt2 = create_expand_pow_optimization(3, base_req=lambda b: not b.is_Function)
|
||||
>>> opt2((x+1)**2 + sin(x)**2)
|
||||
sin(x)**2 + (x + 1)*(x + 1)
|
||||
|
||||
"""
|
||||
return ReplaceOptim(
|
||||
lambda e: e.is_Pow and base_req(e.base) and e.exp.is_Integer and abs(e.exp) <= limit,
|
||||
lambda p: (
|
||||
UnevaluatedExpr(Mul(*([p.base]*+p.exp), evaluate=False)) if p.exp > 0 else
|
||||
1/UnevaluatedExpr(Mul(*([p.base]*-p.exp), evaluate=False))
|
||||
))
|
||||
|
||||
# Optimization procedures for turning A**(-1) * x into MatrixSolve(A, x)
|
||||
def _matinv_predicate(expr):
|
||||
# TODO: We should be able to support more than 2 elements
|
||||
if expr.is_MatMul and len(expr.args) == 2:
|
||||
left, right = expr.args
|
||||
if left.is_Inverse and right.shape[1] == 1:
|
||||
inv_arg = left.arg
|
||||
if isinstance(inv_arg, MatrixSymbol):
|
||||
return bool(ask(Q.fullrank(left.arg)))
|
||||
|
||||
return False
|
||||
|
||||
def _matinv_transform(expr):
|
||||
left, right = expr.args
|
||||
inv_arg = left.arg
|
||||
return MatrixSolve(inv_arg, right)
|
||||
|
||||
|
||||
matinv_opt = ReplaceOptim(_matinv_predicate, _matinv_transform)
|
||||
|
||||
|
||||
logaddexp_opt = ReplaceOptim(log(exp(_v)+exp(_w)), logaddexp(_v, _w))
|
||||
logaddexp2_opt = ReplaceOptim(log(Pow(2, _v)+Pow(2, _w)), logaddexp2(_v, _w)*log(2))
|
||||
|
||||
# Collections of optimizations:
|
||||
optims_c99 = (expm1_opt, log1p_opt, exp2_opt, log2_opt, log2const_opt)
|
||||
|
||||
optims_numpy = optims_c99 + (logaddexp_opt, logaddexp2_opt,) + sinc_opts
|
||||
|
||||
optims_scipy = (cosm1_opt, powm1_opt)
|
||||
@@ -0,0 +1,79 @@
|
||||
from sympy.core.function import Add, ArgumentIndexError, Function
|
||||
from sympy.core.power import Pow
|
||||
from sympy.core.singleton import S
|
||||
from sympy.functions.elementary.exponential import log
|
||||
from sympy.functions.elementary.trigonometric import cos, sin
|
||||
|
||||
|
||||
def _cosm1(x, *, evaluate=True):
|
||||
return Add(cos(x, evaluate=evaluate), -S.One, evaluate=evaluate)
|
||||
|
||||
|
||||
class cosm1(Function):
|
||||
""" Minus one plus cosine of x, i.e. cos(x) - 1. For use when x is close to zero.
|
||||
|
||||
Helper class for use with e.g. scipy.special.cosm1
|
||||
See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.cosm1.html
|
||||
"""
|
||||
nargs = 1
|
||||
|
||||
def fdiff(self, argindex=1):
|
||||
"""
|
||||
Returns the first derivative of this function.
|
||||
"""
|
||||
if argindex == 1:
|
||||
return -sin(*self.args)
|
||||
else:
|
||||
raise ArgumentIndexError(self, argindex)
|
||||
|
||||
def _eval_rewrite_as_cos(self, x, **kwargs):
|
||||
return _cosm1(x)
|
||||
|
||||
def _eval_evalf(self, *args, **kwargs):
|
||||
return self.rewrite(cos).evalf(*args, **kwargs)
|
||||
|
||||
def _eval_simplify(self, **kwargs):
|
||||
x, = self.args
|
||||
candidate = _cosm1(x.simplify(**kwargs))
|
||||
if candidate != _cosm1(x, evaluate=False):
|
||||
return candidate
|
||||
else:
|
||||
return cosm1(x)
|
||||
|
||||
|
||||
def _powm1(x, y, *, evaluate=True):
|
||||
return Add(Pow(x, y, evaluate=evaluate), -S.One, evaluate=evaluate)
|
||||
|
||||
|
||||
class powm1(Function):
|
||||
""" Minus one plus x to the power of y, i.e. x**y - 1. For use when x is close to one or y is close to zero.
|
||||
|
||||
Helper class for use with e.g. scipy.special.powm1
|
||||
See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.special.powm1.html
|
||||
"""
|
||||
nargs = 2
|
||||
|
||||
def fdiff(self, argindex=1):
|
||||
"""
|
||||
Returns the first derivative of this function.
|
||||
"""
|
||||
if argindex == 1:
|
||||
return Pow(self.args[0], self.args[1])*self.args[1]/self.args[0]
|
||||
elif argindex == 2:
|
||||
return log(self.args[0])*Pow(*self.args)
|
||||
else:
|
||||
raise ArgumentIndexError(self, argindex)
|
||||
|
||||
def _eval_rewrite_as_Pow(self, x, y, **kwargs):
|
||||
return _powm1(x, y)
|
||||
|
||||
def _eval_evalf(self, *args, **kwargs):
|
||||
return self.rewrite(Pow).evalf(*args, **kwargs)
|
||||
|
||||
def _eval_simplify(self, **kwargs):
|
||||
x, y = self.args
|
||||
candidate = _powm1(x.simplify(**kwargs), y.simplify(**kwargs))
|
||||
if candidate != _powm1(x, y, evaluate=False):
|
||||
return candidate
|
||||
else:
|
||||
return powm1(x, y)
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,14 @@
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.codegen.abstract_nodes import List
|
||||
|
||||
|
||||
def test_List():
|
||||
l = List(2, 3, 4)
|
||||
assert l == List(2, 3, 4)
|
||||
assert str(l) == "[2, 3, 4]"
|
||||
x, y, z = symbols('x y z')
|
||||
l = List(x**2,y**3,z**4)
|
||||
# contrary to python's built-in list, we can call e.g. "replace" on List.
|
||||
m = l.replace(lambda arg: arg.is_Pow and arg.exp>2, lambda p: p.base-p.exp)
|
||||
assert m == [x**2, y-3, z-4]
|
||||
hash(m)
|
||||
@@ -0,0 +1,180 @@
|
||||
import tempfile
|
||||
from sympy import log, Min, Max, sqrt
|
||||
from sympy.core.numbers import Float
|
||||
from sympy.core.symbol import Symbol, symbols
|
||||
from sympy.functions.elementary.trigonometric import cos
|
||||
from sympy.codegen.ast import Assignment, Raise, RuntimeError_, QuotedString
|
||||
from sympy.codegen.algorithms import newtons_method, newtons_method_function
|
||||
from sympy.codegen.cfunctions import expm1
|
||||
from sympy.codegen.fnodes import bind_C
|
||||
from sympy.codegen.futils import render_as_module as f_module
|
||||
from sympy.codegen.pyutils import render_as_module as py_module
|
||||
from sympy.external import import_module
|
||||
from sympy.printing.codeprinter import ccode
|
||||
from sympy.utilities._compilation import compile_link_import_strings, has_c, has_fortran
|
||||
from sympy.utilities._compilation.util import may_xfail
|
||||
from sympy.testing.pytest import skip, raises, skip_under_pyodide
|
||||
|
||||
cython = import_module('cython')
|
||||
wurlitzer = import_module('wurlitzer')
|
||||
|
||||
def test_newtons_method():
|
||||
x, dx, atol = symbols('x dx atol')
|
||||
expr = cos(x) - x**3
|
||||
algo = newtons_method(expr, x, atol, dx)
|
||||
assert algo.has(Assignment(dx, -expr/expr.diff(x)))
|
||||
|
||||
|
||||
@may_xfail
|
||||
def test_newtons_method_function__ccode():
|
||||
x = Symbol('x', real=True)
|
||||
expr = cos(x) - x**3
|
||||
func = newtons_method_function(expr, x)
|
||||
|
||||
if not cython:
|
||||
skip("cython not installed.")
|
||||
if not has_c():
|
||||
skip("No C compiler found.")
|
||||
|
||||
compile_kw = {"std": 'c99'}
|
||||
with tempfile.TemporaryDirectory() as folder:
|
||||
mod, info = compile_link_import_strings([
|
||||
('newton.c', ('#include <math.h>\n'
|
||||
'#include <stdio.h>\n') + ccode(func)),
|
||||
('_newton.pyx', ("#cython: language_level={}\n".format("3") +
|
||||
"cdef extern double newton(double)\n"
|
||||
"def py_newton(x):\n"
|
||||
" return newton(x)\n"))
|
||||
], build_dir=folder, compile_kwargs=compile_kw)
|
||||
assert abs(mod.py_newton(0.5) - 0.865474033102) < 1e-12
|
||||
|
||||
|
||||
@may_xfail
|
||||
def test_newtons_method_function__fcode():
|
||||
x = Symbol('x', real=True)
|
||||
expr = cos(x) - x**3
|
||||
func = newtons_method_function(expr, x, attrs=[bind_C(name='newton')])
|
||||
|
||||
if not cython:
|
||||
skip("cython not installed.")
|
||||
if not has_fortran():
|
||||
skip("No Fortran compiler found.")
|
||||
|
||||
f_mod = f_module([func], 'mod_newton')
|
||||
with tempfile.TemporaryDirectory() as folder:
|
||||
mod, info = compile_link_import_strings([
|
||||
('newton.f90', f_mod),
|
||||
('_newton.pyx', ("#cython: language_level={}\n".format("3") +
|
||||
"cdef extern double newton(double*)\n"
|
||||
"def py_newton(double x):\n"
|
||||
" return newton(&x)\n"))
|
||||
], build_dir=folder)
|
||||
assert abs(mod.py_newton(0.5) - 0.865474033102) < 1e-12
|
||||
|
||||
|
||||
def test_newtons_method_function__pycode():
|
||||
x = Symbol('x', real=True)
|
||||
expr = cos(x) - x**3
|
||||
func = newtons_method_function(expr, x)
|
||||
py_mod = py_module(func)
|
||||
namespace = {}
|
||||
exec(py_mod, namespace, namespace)
|
||||
res = eval('newton(0.5)', namespace)
|
||||
assert abs(res - 0.865474033102) < 1e-12
|
||||
|
||||
|
||||
@may_xfail
|
||||
@skip_under_pyodide("Emscripten does not support process spawning")
|
||||
def test_newtons_method_function__ccode_parameters():
|
||||
args = x, A, k, p = symbols('x A k p')
|
||||
expr = A*cos(k*x) - p*x**3
|
||||
raises(ValueError, lambda: newtons_method_function(expr, x))
|
||||
use_wurlitzer = wurlitzer
|
||||
|
||||
func = newtons_method_function(expr, x, args, debug=use_wurlitzer)
|
||||
|
||||
if not has_c():
|
||||
skip("No C compiler found.")
|
||||
if not cython:
|
||||
skip("cython not installed.")
|
||||
|
||||
compile_kw = {"std": 'c99'}
|
||||
with tempfile.TemporaryDirectory() as folder:
|
||||
mod, info = compile_link_import_strings([
|
||||
('newton_par.c', ('#include <math.h>\n'
|
||||
'#include <stdio.h>\n') + ccode(func)),
|
||||
('_newton_par.pyx', ("#cython: language_level={}\n".format("3") +
|
||||
"cdef extern double newton(double, double, double, double)\n"
|
||||
"def py_newton(x, A=1, k=1, p=1):\n"
|
||||
" return newton(x, A, k, p)\n"))
|
||||
], compile_kwargs=compile_kw, build_dir=folder)
|
||||
|
||||
if use_wurlitzer:
|
||||
with wurlitzer.pipes() as (out, err):
|
||||
result = mod.py_newton(0.5)
|
||||
else:
|
||||
result = mod.py_newton(0.5)
|
||||
|
||||
assert abs(result - 0.865474033102) < 1e-12
|
||||
|
||||
if not use_wurlitzer:
|
||||
skip("C-level output only tested when package 'wurlitzer' is available.")
|
||||
|
||||
out, err = out.read(), err.read()
|
||||
assert err == ''
|
||||
assert out == """\
|
||||
x= 0.5
|
||||
x= 1.1121 d_x= 0.61214
|
||||
x= 0.90967 d_x= -0.20247
|
||||
x= 0.86726 d_x= -0.042409
|
||||
x= 0.86548 d_x= -0.0017867
|
||||
x= 0.86547 d_x= -3.1022e-06
|
||||
x= 0.86547 d_x= -9.3421e-12
|
||||
x= 0.86547 d_x= 3.6902e-17
|
||||
""" # try to run tests with LC_ALL=C if this assertion fails
|
||||
|
||||
|
||||
def test_newtons_method_function__rtol_cse_nan():
|
||||
a, b, c, N_geo, N_tot = symbols('a b c N_geo N_tot', real=True, nonnegative=True)
|
||||
i = Symbol('i', integer=True, nonnegative=True)
|
||||
N_ari = N_tot - N_geo - 1
|
||||
delta_ari = (c-b)/N_ari
|
||||
ln_delta_geo = log(b) + log(-expm1((log(a)-log(b))/N_geo))
|
||||
eqb_log = ln_delta_geo - log(delta_ari)
|
||||
|
||||
def _clamp(low, expr, high):
|
||||
return Min(Max(low, expr), high)
|
||||
|
||||
meth_kw = {
|
||||
'clamped_newton': {'delta_fn': lambda e, x: _clamp(
|
||||
(sqrt(a*x)-x)*0.99,
|
||||
-e/e.diff(x),
|
||||
(sqrt(c*x)-x)*0.99
|
||||
)},
|
||||
'halley': {'delta_fn': lambda e, x: (-2*(e*e.diff(x))/(2*e.diff(x)**2 - e*e.diff(x, 2)))},
|
||||
'halley_alt': {'delta_fn': lambda e, x: (-e/e.diff(x)/(1-e/e.diff(x)*e.diff(x,2)/2/e.diff(x)))},
|
||||
}
|
||||
args = eqb_log, b
|
||||
for use_cse in [False, True]:
|
||||
kwargs = {
|
||||
'params': (b, a, c, N_geo, N_tot), 'itermax': 60, 'debug': True, 'cse': use_cse,
|
||||
'counter': i, 'atol': 1e-100, 'rtol': 2e-16, 'bounds': (a,c),
|
||||
'handle_nan': Raise(RuntimeError_(QuotedString("encountered NaN.")))
|
||||
}
|
||||
func = {k: newtons_method_function(*args, func_name=f"{k}_b", **dict(kwargs, **kw)) for k, kw in meth_kw.items()}
|
||||
py_mod = {k: py_module(v) for k, v in func.items()}
|
||||
namespace = {}
|
||||
root_find_b = {}
|
||||
for k, v in py_mod.items():
|
||||
ns = namespace[k] = {}
|
||||
exec(v, ns, ns)
|
||||
root_find_b[k] = ns[f'{k}_b']
|
||||
ref = Float('13.2261515064168768938151923226496')
|
||||
reftol = {'clamped_newton': 2e-16, 'halley': 2e-16, 'halley_alt': 3e-16}
|
||||
guess = 4.0
|
||||
for meth, func in root_find_b.items():
|
||||
result = func(guess, 1e-2, 1e2, 50, 100)
|
||||
req = ref*reftol[meth]
|
||||
if use_cse:
|
||||
req *= 2
|
||||
assert abs(result - ref) < req
|
||||
@@ -0,0 +1,58 @@
|
||||
# This file contains tests that exercise multiple AST nodes
|
||||
|
||||
import tempfile
|
||||
|
||||
from sympy.external import import_module
|
||||
from sympy.printing.codeprinter import ccode
|
||||
from sympy.utilities._compilation import compile_link_import_strings, has_c
|
||||
from sympy.utilities._compilation.util import may_xfail
|
||||
from sympy.testing.pytest import skip, skip_under_pyodide
|
||||
from sympy.codegen.ast import (
|
||||
FunctionDefinition, FunctionPrototype, Variable, Pointer, real, Assignment,
|
||||
integer, CodeBlock, While
|
||||
)
|
||||
from sympy.codegen.cnodes import void, PreIncrement
|
||||
from sympy.codegen.cutils import render_as_source_file
|
||||
|
||||
cython = import_module('cython')
|
||||
np = import_module('numpy')
|
||||
|
||||
def _mk_func1():
|
||||
declars = n, inp, out = Variable('n', integer), Pointer('inp', real), Pointer('out', real)
|
||||
i = Variable('i', integer)
|
||||
whl = While(i<n, [Assignment(out[i], inp[i]), PreIncrement(i)])
|
||||
body = CodeBlock(i.as_Declaration(value=0), whl)
|
||||
return FunctionDefinition(void, 'our_test_function', declars, body)
|
||||
|
||||
|
||||
def _render_compile_import(funcdef, build_dir):
|
||||
code_str = render_as_source_file(funcdef, settings={"contract": False})
|
||||
declar = ccode(FunctionPrototype.from_FunctionDefinition(funcdef))
|
||||
return compile_link_import_strings([
|
||||
('our_test_func.c', code_str),
|
||||
('_our_test_func.pyx', ("#cython: language_level={}\n".format("3") +
|
||||
"cdef extern {declar}\n"
|
||||
"def _{fname}({typ}[:] inp, {typ}[:] out):\n"
|
||||
" {fname}(inp.size, &inp[0], &out[0])").format(
|
||||
declar=declar, fname=funcdef.name, typ='double'
|
||||
))
|
||||
], build_dir=build_dir)
|
||||
|
||||
|
||||
@may_xfail
|
||||
@skip_under_pyodide("Emscripten does not support process spawning")
|
||||
def test_copying_function():
|
||||
if not np:
|
||||
skip("numpy not installed.")
|
||||
if not has_c():
|
||||
skip("No C compiler found.")
|
||||
if not cython:
|
||||
skip("Cython not found.")
|
||||
|
||||
info = None
|
||||
with tempfile.TemporaryDirectory() as folder:
|
||||
mod, info = _render_compile_import(_mk_func1(), build_dir=folder)
|
||||
inp = np.arange(10.0)
|
||||
out = np.empty_like(inp)
|
||||
mod._our_test_function(inp, out)
|
||||
assert np.allclose(inp, out)
|
||||
@@ -0,0 +1,53 @@
|
||||
import math
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.functions.elementary.exponential import exp
|
||||
from sympy.codegen.rewriting import optimize
|
||||
from sympy.codegen.approximations import SumApprox, SeriesApprox
|
||||
|
||||
|
||||
def test_SumApprox_trivial():
|
||||
x = symbols('x')
|
||||
expr1 = 1 + x
|
||||
sum_approx = SumApprox(bounds={x: (-1e-20, 1e-20)}, reltol=1e-16)
|
||||
apx1 = optimize(expr1, [sum_approx])
|
||||
assert apx1 - 1 == 0
|
||||
|
||||
|
||||
def test_SumApprox_monotone_terms():
|
||||
x, y, z = symbols('x y z')
|
||||
expr1 = exp(z)*(x**2 + y**2 + 1)
|
||||
bnds1 = {x: (0, 1e-3), y: (100, 1000)}
|
||||
sum_approx_m2 = SumApprox(bounds=bnds1, reltol=1e-2)
|
||||
sum_approx_m5 = SumApprox(bounds=bnds1, reltol=1e-5)
|
||||
sum_approx_m11 = SumApprox(bounds=bnds1, reltol=1e-11)
|
||||
assert (optimize(expr1, [sum_approx_m2])/exp(z) - (y**2)).simplify() == 0
|
||||
assert (optimize(expr1, [sum_approx_m5])/exp(z) - (y**2 + 1)).simplify() == 0
|
||||
assert (optimize(expr1, [sum_approx_m11])/exp(z) - (y**2 + 1 + x**2)).simplify() == 0
|
||||
|
||||
|
||||
def test_SeriesApprox_trivial():
|
||||
x, z = symbols('x z')
|
||||
for factor in [1, exp(z)]:
|
||||
x = symbols('x')
|
||||
expr1 = exp(x)*factor
|
||||
bnds1 = {x: (-1, 1)}
|
||||
series_approx_50 = SeriesApprox(bounds=bnds1, reltol=0.50)
|
||||
series_approx_10 = SeriesApprox(bounds=bnds1, reltol=0.10)
|
||||
series_approx_05 = SeriesApprox(bounds=bnds1, reltol=0.05)
|
||||
c = (bnds1[x][1] + bnds1[x][0])/2 # 0.0
|
||||
f0 = math.exp(c) # 1.0
|
||||
|
||||
ref_50 = f0 + x + x**2/2
|
||||
ref_10 = f0 + x + x**2/2 + x**3/6
|
||||
ref_05 = f0 + x + x**2/2 + x**3/6 + x**4/24
|
||||
|
||||
res_50 = optimize(expr1, [series_approx_50])
|
||||
res_10 = optimize(expr1, [series_approx_10])
|
||||
res_05 = optimize(expr1, [series_approx_05])
|
||||
|
||||
assert (res_50/factor - ref_50).simplify() == 0
|
||||
assert (res_10/factor - ref_10).simplify() == 0
|
||||
assert (res_05/factor - ref_05).simplify() == 0
|
||||
|
||||
max_ord3 = SeriesApprox(bounds=bnds1, reltol=0.05, max_order=3)
|
||||
assert optimize(expr1, [max_ord3]) == expr1
|
||||
@@ -0,0 +1,661 @@
|
||||
import math
|
||||
from sympy.core.containers import Tuple
|
||||
from sympy.core.numbers import nan, oo, Float, Integer
|
||||
from sympy.core.relational import Lt
|
||||
from sympy.core.symbol import symbols, Symbol
|
||||
from sympy.functions.elementary.trigonometric import sin
|
||||
from sympy.matrices.dense import Matrix
|
||||
from sympy.matrices.expressions.matexpr import MatrixSymbol
|
||||
from sympy.sets.fancysets import Range
|
||||
from sympy.tensor.indexed import Idx, IndexedBase
|
||||
from sympy.testing.pytest import raises
|
||||
|
||||
|
||||
from sympy.codegen.ast import (
|
||||
Assignment, Attribute, aug_assign, CodeBlock, For, Type, Variable, Pointer, Declaration,
|
||||
AddAugmentedAssignment, SubAugmentedAssignment, MulAugmentedAssignment,
|
||||
DivAugmentedAssignment, ModAugmentedAssignment, value_const, pointer_const,
|
||||
integer, real, complex_, int8, uint8, float16 as f16, float32 as f32,
|
||||
float64 as f64, float80 as f80, float128 as f128, complex64 as c64, complex128 as c128,
|
||||
While, Scope, String, Print, QuotedString, FunctionPrototype, FunctionDefinition, Return,
|
||||
FunctionCall, untyped, IntBaseType, intc, Node, none, NoneToken, Token, Comment
|
||||
)
|
||||
|
||||
x, y, z, t, x0, x1, x2, a, b = symbols("x, y, z, t, x0, x1, x2, a, b")
|
||||
n = symbols("n", integer=True)
|
||||
A = MatrixSymbol('A', 3, 1)
|
||||
mat = Matrix([1, 2, 3])
|
||||
B = IndexedBase('B')
|
||||
i = Idx("i", n)
|
||||
A22 = MatrixSymbol('A22',2,2)
|
||||
B22 = MatrixSymbol('B22',2,2)
|
||||
|
||||
|
||||
def test_Assignment():
|
||||
# Here we just do things to show they don't error
|
||||
Assignment(x, y)
|
||||
Assignment(x, 0)
|
||||
Assignment(A, mat)
|
||||
Assignment(A[1,0], 0)
|
||||
Assignment(A[1,0], x)
|
||||
Assignment(B[i], x)
|
||||
Assignment(B[i], 0)
|
||||
a = Assignment(x, y)
|
||||
assert a.func(*a.args) == a
|
||||
assert a.op == ':='
|
||||
# Here we test things to show that they error
|
||||
# Matrix to scalar
|
||||
raises(ValueError, lambda: Assignment(B[i], A))
|
||||
raises(ValueError, lambda: Assignment(B[i], mat))
|
||||
raises(ValueError, lambda: Assignment(x, mat))
|
||||
raises(ValueError, lambda: Assignment(x, A))
|
||||
raises(ValueError, lambda: Assignment(A[1,0], mat))
|
||||
# Scalar to matrix
|
||||
raises(ValueError, lambda: Assignment(A, x))
|
||||
raises(ValueError, lambda: Assignment(A, 0))
|
||||
# Non-atomic lhs
|
||||
raises(TypeError, lambda: Assignment(mat, A))
|
||||
raises(TypeError, lambda: Assignment(0, x))
|
||||
raises(TypeError, lambda: Assignment(x*x, 1))
|
||||
raises(TypeError, lambda: Assignment(A + A, mat))
|
||||
raises(TypeError, lambda: Assignment(B, 0))
|
||||
|
||||
|
||||
def test_AugAssign():
|
||||
# Here we just do things to show they don't error
|
||||
aug_assign(x, '+', y)
|
||||
aug_assign(x, '+', 0)
|
||||
aug_assign(A, '+', mat)
|
||||
aug_assign(A[1, 0], '+', 0)
|
||||
aug_assign(A[1, 0], '+', x)
|
||||
aug_assign(B[i], '+', x)
|
||||
aug_assign(B[i], '+', 0)
|
||||
|
||||
# Check creation via aug_assign vs constructor
|
||||
for binop, cls in [
|
||||
('+', AddAugmentedAssignment),
|
||||
('-', SubAugmentedAssignment),
|
||||
('*', MulAugmentedAssignment),
|
||||
('/', DivAugmentedAssignment),
|
||||
('%', ModAugmentedAssignment),
|
||||
]:
|
||||
a = aug_assign(x, binop, y)
|
||||
b = cls(x, y)
|
||||
assert a.func(*a.args) == a == b
|
||||
assert a.binop == binop
|
||||
assert a.op == binop + '='
|
||||
|
||||
# Here we test things to show that they error
|
||||
# Matrix to scalar
|
||||
raises(ValueError, lambda: aug_assign(B[i], '+', A))
|
||||
raises(ValueError, lambda: aug_assign(B[i], '+', mat))
|
||||
raises(ValueError, lambda: aug_assign(x, '+', mat))
|
||||
raises(ValueError, lambda: aug_assign(x, '+', A))
|
||||
raises(ValueError, lambda: aug_assign(A[1, 0], '+', mat))
|
||||
# Scalar to matrix
|
||||
raises(ValueError, lambda: aug_assign(A, '+', x))
|
||||
raises(ValueError, lambda: aug_assign(A, '+', 0))
|
||||
# Non-atomic lhs
|
||||
raises(TypeError, lambda: aug_assign(mat, '+', A))
|
||||
raises(TypeError, lambda: aug_assign(0, '+', x))
|
||||
raises(TypeError, lambda: aug_assign(x * x, '+', 1))
|
||||
raises(TypeError, lambda: aug_assign(A + A, '+', mat))
|
||||
raises(TypeError, lambda: aug_assign(B, '+', 0))
|
||||
|
||||
|
||||
def test_Assignment_printing():
|
||||
assignment_classes = [
|
||||
Assignment,
|
||||
AddAugmentedAssignment,
|
||||
SubAugmentedAssignment,
|
||||
MulAugmentedAssignment,
|
||||
DivAugmentedAssignment,
|
||||
ModAugmentedAssignment,
|
||||
]
|
||||
pairs = [
|
||||
(x, 2 * y + 2),
|
||||
(B[i], x),
|
||||
(A22, B22),
|
||||
(A[0, 0], x),
|
||||
]
|
||||
|
||||
for cls in assignment_classes:
|
||||
for lhs, rhs in pairs:
|
||||
a = cls(lhs, rhs)
|
||||
assert repr(a) == '%s(%s, %s)' % (cls.__name__, repr(lhs), repr(rhs))
|
||||
|
||||
|
||||
def test_CodeBlock():
|
||||
c = CodeBlock(Assignment(x, 1), Assignment(y, x + 1))
|
||||
assert c.func(*c.args) == c
|
||||
|
||||
assert c.left_hand_sides == Tuple(x, y)
|
||||
assert c.right_hand_sides == Tuple(1, x + 1)
|
||||
|
||||
def test_CodeBlock_topological_sort():
|
||||
assignments = [
|
||||
Assignment(x, y + z),
|
||||
Assignment(z, 1),
|
||||
Assignment(t, x),
|
||||
Assignment(y, 2),
|
||||
]
|
||||
|
||||
ordered_assignments = [
|
||||
# Note that the unrelated z=1 and y=2 are kept in that order
|
||||
Assignment(z, 1),
|
||||
Assignment(y, 2),
|
||||
Assignment(x, y + z),
|
||||
Assignment(t, x),
|
||||
]
|
||||
c1 = CodeBlock.topological_sort(assignments)
|
||||
assert c1 == CodeBlock(*ordered_assignments)
|
||||
|
||||
# Cycle
|
||||
invalid_assignments = [
|
||||
Assignment(x, y + z),
|
||||
Assignment(z, 1),
|
||||
Assignment(y, x),
|
||||
Assignment(y, 2),
|
||||
]
|
||||
|
||||
raises(ValueError, lambda: CodeBlock.topological_sort(invalid_assignments))
|
||||
|
||||
# Free symbols
|
||||
free_assignments = [
|
||||
Assignment(x, y + z),
|
||||
Assignment(z, a * b),
|
||||
Assignment(t, x),
|
||||
Assignment(y, b + 3),
|
||||
]
|
||||
|
||||
free_assignments_ordered = [
|
||||
Assignment(z, a * b),
|
||||
Assignment(y, b + 3),
|
||||
Assignment(x, y + z),
|
||||
Assignment(t, x),
|
||||
]
|
||||
|
||||
c2 = CodeBlock.topological_sort(free_assignments)
|
||||
assert c2 == CodeBlock(*free_assignments_ordered)
|
||||
|
||||
def test_CodeBlock_free_symbols():
|
||||
c1 = CodeBlock(
|
||||
Assignment(x, y + z),
|
||||
Assignment(z, 1),
|
||||
Assignment(t, x),
|
||||
Assignment(y, 2),
|
||||
)
|
||||
assert c1.free_symbols == set()
|
||||
|
||||
c2 = CodeBlock(
|
||||
Assignment(x, y + z),
|
||||
Assignment(z, a * b),
|
||||
Assignment(t, x),
|
||||
Assignment(y, b + 3),
|
||||
)
|
||||
assert c2.free_symbols == {a, b}
|
||||
|
||||
def test_CodeBlock_cse():
|
||||
c1 = CodeBlock(
|
||||
Assignment(y, 1),
|
||||
Assignment(x, sin(y)),
|
||||
Assignment(z, sin(y)),
|
||||
Assignment(t, x*z),
|
||||
)
|
||||
assert c1.cse() == CodeBlock(
|
||||
Assignment(y, 1),
|
||||
Assignment(x0, sin(y)),
|
||||
Assignment(x, x0),
|
||||
Assignment(z, x0),
|
||||
Assignment(t, x*z),
|
||||
)
|
||||
|
||||
# Multiple assignments to same symbol not supported
|
||||
raises(NotImplementedError, lambda: CodeBlock(
|
||||
Assignment(x, 1),
|
||||
Assignment(y, 1), Assignment(y, 2)
|
||||
).cse())
|
||||
|
||||
# Check auto-generated symbols do not collide with existing ones
|
||||
c2 = CodeBlock(
|
||||
Assignment(x0, sin(y) + 1),
|
||||
Assignment(x1, 2 * sin(y)),
|
||||
Assignment(z, x * y),
|
||||
)
|
||||
assert c2.cse() == CodeBlock(
|
||||
Assignment(x2, sin(y)),
|
||||
Assignment(x0, x2 + 1),
|
||||
Assignment(x1, 2 * x2),
|
||||
Assignment(z, x * y),
|
||||
)
|
||||
|
||||
|
||||
def test_CodeBlock_cse__issue_14118():
|
||||
# see https://github.com/sympy/sympy/issues/14118
|
||||
c = CodeBlock(
|
||||
Assignment(A22, Matrix([[x, sin(y)],[3, 4]])),
|
||||
Assignment(B22, Matrix([[sin(y), 2*sin(y)], [sin(y)**2, 7]]))
|
||||
)
|
||||
assert c.cse() == CodeBlock(
|
||||
Assignment(x0, sin(y)),
|
||||
Assignment(A22, Matrix([[x, x0],[3, 4]])),
|
||||
Assignment(B22, Matrix([[x0, 2*x0], [x0**2, 7]]))
|
||||
)
|
||||
|
||||
def test_For():
|
||||
f = For(n, Range(0, 3), (Assignment(A[n, 0], x + n), aug_assign(x, '+', y)))
|
||||
f = For(n, (1, 2, 3, 4, 5), (Assignment(A[n, 0], x + n),))
|
||||
assert f.func(*f.args) == f
|
||||
raises(TypeError, lambda: For(n, x, (x + y,)))
|
||||
|
||||
|
||||
def test_none():
|
||||
assert none.is_Atom
|
||||
assert none == none
|
||||
class Foo(Token):
|
||||
pass
|
||||
foo = Foo()
|
||||
assert foo != none
|
||||
assert none == None
|
||||
assert none == NoneToken()
|
||||
assert none.func(*none.args) == none
|
||||
|
||||
|
||||
def test_String():
|
||||
st = String('foobar')
|
||||
assert st.is_Atom
|
||||
assert st == String('foobar')
|
||||
assert st.text == 'foobar'
|
||||
assert st.func(**st.kwargs()) == st
|
||||
assert st.func(*st.args) == st
|
||||
|
||||
|
||||
class Signifier(String):
|
||||
pass
|
||||
|
||||
si = Signifier('foobar')
|
||||
assert si != st
|
||||
assert si.text == st.text
|
||||
s = String('foo')
|
||||
assert str(s) == 'foo'
|
||||
assert repr(s) == "String('foo')"
|
||||
|
||||
def test_Comment():
|
||||
c = Comment('foobar')
|
||||
assert c.text == 'foobar'
|
||||
assert str(c) == 'foobar'
|
||||
|
||||
def test_Node():
|
||||
n = Node()
|
||||
assert n == Node()
|
||||
assert n.func(*n.args) == n
|
||||
|
||||
|
||||
def test_Type():
|
||||
t = Type('MyType')
|
||||
assert len(t.args) == 1
|
||||
assert t.name == String('MyType')
|
||||
assert str(t) == 'MyType'
|
||||
assert repr(t) == "Type(String('MyType'))"
|
||||
assert Type(t) == t
|
||||
assert t.func(*t.args) == t
|
||||
t1 = Type('t1')
|
||||
t2 = Type('t2')
|
||||
assert t1 != t2
|
||||
assert t1 == t1 and t2 == t2
|
||||
t1b = Type('t1')
|
||||
assert t1 == t1b
|
||||
assert t2 != t1b
|
||||
|
||||
|
||||
def test_Type__from_expr():
|
||||
assert Type.from_expr(i) == integer
|
||||
u = symbols('u', real=True)
|
||||
assert Type.from_expr(u) == real
|
||||
assert Type.from_expr(n) == integer
|
||||
assert Type.from_expr(3) == integer
|
||||
assert Type.from_expr(3.0) == real
|
||||
assert Type.from_expr(3+1j) == complex_
|
||||
raises(ValueError, lambda: Type.from_expr(sum))
|
||||
|
||||
|
||||
def test_Type__cast_check__integers():
|
||||
# Rounding
|
||||
raises(ValueError, lambda: integer.cast_check(3.5))
|
||||
assert integer.cast_check('3') == 3
|
||||
assert integer.cast_check(Float('3.0000000000000000000')) == 3
|
||||
assert integer.cast_check(Float('3.0000000000000000001')) == 3 # unintuitive maybe?
|
||||
|
||||
# Range
|
||||
assert int8.cast_check(127.0) == 127
|
||||
raises(ValueError, lambda: int8.cast_check(128))
|
||||
assert int8.cast_check(-128) == -128
|
||||
raises(ValueError, lambda: int8.cast_check(-129))
|
||||
|
||||
assert uint8.cast_check(0) == 0
|
||||
assert uint8.cast_check(128) == 128
|
||||
raises(ValueError, lambda: uint8.cast_check(256.0))
|
||||
raises(ValueError, lambda: uint8.cast_check(-1))
|
||||
|
||||
def test_Attribute():
|
||||
noexcept = Attribute('noexcept')
|
||||
assert noexcept == Attribute('noexcept')
|
||||
alignas16 = Attribute('alignas', [16])
|
||||
alignas32 = Attribute('alignas', [32])
|
||||
assert alignas16 != alignas32
|
||||
assert alignas16.func(*alignas16.args) == alignas16
|
||||
|
||||
|
||||
def test_Variable():
|
||||
v = Variable(x, type=real)
|
||||
assert v == Variable(v)
|
||||
assert v == Variable('x', type=real)
|
||||
assert v.symbol == x
|
||||
assert v.type == real
|
||||
assert value_const not in v.attrs
|
||||
assert v.func(*v.args) == v
|
||||
assert str(v) == 'Variable(x, type=real)'
|
||||
|
||||
w = Variable(y, f32, attrs={value_const})
|
||||
assert w.symbol == y
|
||||
assert w.type == f32
|
||||
assert value_const in w.attrs
|
||||
assert w.func(*w.args) == w
|
||||
|
||||
v_n = Variable(n, type=Type.from_expr(n))
|
||||
assert v_n.type == integer
|
||||
assert v_n.func(*v_n.args) == v_n
|
||||
v_i = Variable(i, type=Type.from_expr(n))
|
||||
assert v_i.type == integer
|
||||
assert v_i != v_n
|
||||
|
||||
a_i = Variable.deduced(i)
|
||||
assert a_i.type == integer
|
||||
assert Variable.deduced(Symbol('x', real=True)).type == real
|
||||
assert a_i.func(*a_i.args) == a_i
|
||||
|
||||
v_n2 = Variable.deduced(n, value=3.5, cast_check=False)
|
||||
assert v_n2.func(*v_n2.args) == v_n2
|
||||
assert abs(v_n2.value - 3.5) < 1e-15
|
||||
raises(ValueError, lambda: Variable.deduced(n, value=3.5, cast_check=True))
|
||||
|
||||
v_n3 = Variable.deduced(n)
|
||||
assert v_n3.type == integer
|
||||
assert str(v_n3) == 'Variable(n, type=integer)'
|
||||
assert Variable.deduced(z, value=3).type == integer
|
||||
assert Variable.deduced(z, value=3.0).type == real
|
||||
assert Variable.deduced(z, value=3.0+1j).type == complex_
|
||||
|
||||
|
||||
def test_Pointer():
|
||||
p = Pointer(x)
|
||||
assert p.symbol == x
|
||||
assert p.type == untyped
|
||||
assert value_const not in p.attrs
|
||||
assert pointer_const not in p.attrs
|
||||
assert p.func(*p.args) == p
|
||||
|
||||
u = symbols('u', real=True)
|
||||
pu = Pointer(u, type=Type.from_expr(u), attrs={value_const, pointer_const})
|
||||
assert pu.symbol is u
|
||||
assert pu.type == real
|
||||
assert value_const in pu.attrs
|
||||
assert pointer_const in pu.attrs
|
||||
assert pu.func(*pu.args) == pu
|
||||
|
||||
i = symbols('i', integer=True)
|
||||
deref = pu[i]
|
||||
assert deref.indices == (i,)
|
||||
|
||||
|
||||
def test_Declaration():
|
||||
u = symbols('u', real=True)
|
||||
vu = Variable(u, type=Type.from_expr(u))
|
||||
assert Declaration(vu).variable.type == real
|
||||
vn = Variable(n, type=Type.from_expr(n))
|
||||
assert Declaration(vn).variable.type == integer
|
||||
|
||||
# PR 19107, does not allow comparison between expressions and Basic
|
||||
# lt = StrictLessThan(vu, vn)
|
||||
# assert isinstance(lt, StrictLessThan)
|
||||
|
||||
vuc = Variable(u, Type.from_expr(u), value=3.0, attrs={value_const})
|
||||
assert value_const in vuc.attrs
|
||||
assert pointer_const not in vuc.attrs
|
||||
decl = Declaration(vuc)
|
||||
assert decl.variable == vuc
|
||||
assert isinstance(decl.variable.value, Float)
|
||||
assert decl.variable.value == 3.0
|
||||
assert decl.func(*decl.args) == decl
|
||||
assert vuc.as_Declaration() == decl
|
||||
assert vuc.as_Declaration(value=None, attrs=None) == Declaration(vu)
|
||||
|
||||
vy = Variable(y, type=integer, value=3)
|
||||
decl2 = Declaration(vy)
|
||||
assert decl2.variable == vy
|
||||
assert decl2.variable.value == Integer(3)
|
||||
|
||||
vi = Variable(i, type=Type.from_expr(i), value=3.0)
|
||||
decl3 = Declaration(vi)
|
||||
assert decl3.variable.type == integer
|
||||
assert decl3.variable.value == 3.0
|
||||
|
||||
raises(ValueError, lambda: Declaration(vi, 42))
|
||||
|
||||
|
||||
def test_IntBaseType():
|
||||
assert intc.name == String('intc')
|
||||
assert intc.args == (intc.name,)
|
||||
assert str(IntBaseType('a').name) == 'a'
|
||||
|
||||
|
||||
def test_FloatType():
|
||||
assert f16.dig == 3
|
||||
assert f32.dig == 6
|
||||
assert f64.dig == 15
|
||||
assert f80.dig == 18
|
||||
assert f128.dig == 33
|
||||
|
||||
assert f16.decimal_dig == 5
|
||||
assert f32.decimal_dig == 9
|
||||
assert f64.decimal_dig == 17
|
||||
assert f80.decimal_dig == 21
|
||||
assert f128.decimal_dig == 36
|
||||
|
||||
assert f16.max_exponent == 16
|
||||
assert f32.max_exponent == 128
|
||||
assert f64.max_exponent == 1024
|
||||
assert f80.max_exponent == 16384
|
||||
assert f128.max_exponent == 16384
|
||||
|
||||
assert f16.min_exponent == -13
|
||||
assert f32.min_exponent == -125
|
||||
assert f64.min_exponent == -1021
|
||||
assert f80.min_exponent == -16381
|
||||
assert f128.min_exponent == -16381
|
||||
|
||||
assert abs(f16.eps / Float('0.00097656', precision=16) - 1) < 0.1*10**-f16.dig
|
||||
assert abs(f32.eps / Float('1.1920929e-07', precision=32) - 1) < 0.1*10**-f32.dig
|
||||
assert abs(f64.eps / Float('2.2204460492503131e-16', precision=64) - 1) < 0.1*10**-f64.dig
|
||||
assert abs(f80.eps / Float('1.08420217248550443401e-19', precision=80) - 1) < 0.1*10**-f80.dig
|
||||
assert abs(f128.eps / Float(' 1.92592994438723585305597794258492732e-34', precision=128) - 1) < 0.1*10**-f128.dig
|
||||
|
||||
assert abs(f16.max / Float('65504', precision=16) - 1) < .1*10**-f16.dig
|
||||
assert abs(f32.max / Float('3.40282347e+38', precision=32) - 1) < 0.1*10**-f32.dig
|
||||
assert abs(f64.max / Float('1.79769313486231571e+308', precision=64) - 1) < 0.1*10**-f64.dig # cf. np.finfo(np.float64).max
|
||||
assert abs(f80.max / Float('1.18973149535723176502e+4932', precision=80) - 1) < 0.1*10**-f80.dig
|
||||
assert abs(f128.max / Float('1.18973149535723176508575932662800702e+4932', precision=128) - 1) < 0.1*10**-f128.dig
|
||||
|
||||
# cf. np.finfo(np.float32).tiny
|
||||
assert abs(f16.tiny / Float('6.1035e-05', precision=16) - 1) < 0.1*10**-f16.dig
|
||||
assert abs(f32.tiny / Float('1.17549435e-38', precision=32) - 1) < 0.1*10**-f32.dig
|
||||
assert abs(f64.tiny / Float('2.22507385850720138e-308', precision=64) - 1) < 0.1*10**-f64.dig
|
||||
assert abs(f80.tiny / Float('3.36210314311209350626e-4932', precision=80) - 1) < 0.1*10**-f80.dig
|
||||
assert abs(f128.tiny / Float('3.3621031431120935062626778173217526e-4932', precision=128) - 1) < 0.1*10**-f128.dig
|
||||
|
||||
assert f64.cast_check(0.5) == Float(0.5, 17)
|
||||
assert abs(f64.cast_check(3.7) - 3.7) < 3e-17
|
||||
assert isinstance(f64.cast_check(3), (Float, float))
|
||||
|
||||
assert f64.cast_nocheck(oo) == float('inf')
|
||||
assert f64.cast_nocheck(-oo) == float('-inf')
|
||||
assert f64.cast_nocheck(float(oo)) == float('inf')
|
||||
assert f64.cast_nocheck(float(-oo)) == float('-inf')
|
||||
assert math.isnan(f64.cast_nocheck(nan))
|
||||
|
||||
assert f32 != f64
|
||||
assert f64 == f64.func(*f64.args)
|
||||
|
||||
|
||||
def test_Type__cast_check__floating_point():
|
||||
raises(ValueError, lambda: f32.cast_check(123.45678949))
|
||||
raises(ValueError, lambda: f32.cast_check(12.345678949))
|
||||
raises(ValueError, lambda: f32.cast_check(1.2345678949))
|
||||
raises(ValueError, lambda: f32.cast_check(.12345678949))
|
||||
assert abs(123.456789049 - f32.cast_check(123.456789049) - 4.9e-8) < 1e-8
|
||||
assert abs(0.12345678904 - f32.cast_check(0.12345678904) - 4e-11) < 1e-11
|
||||
|
||||
dcm21 = Float('0.123456789012345670499') # 21 decimals
|
||||
assert abs(dcm21 - f64.cast_check(dcm21) - 4.99e-19) < 1e-19
|
||||
|
||||
f80.cast_check(Float('0.12345678901234567890103', precision=88))
|
||||
raises(ValueError, lambda: f80.cast_check(Float('0.12345678901234567890149', precision=88)))
|
||||
|
||||
v10 = 12345.67894
|
||||
raises(ValueError, lambda: f32.cast_check(v10))
|
||||
assert abs(Float(str(v10), precision=64+8) - f64.cast_check(v10)) < v10*1e-16
|
||||
|
||||
assert abs(f32.cast_check(2147483647) - 2147483650) < 1
|
||||
|
||||
|
||||
def test_Type__cast_check__complex_floating_point():
|
||||
val9_11 = 123.456789049 + 0.123456789049j
|
||||
raises(ValueError, lambda: c64.cast_check(.12345678949 + .12345678949j))
|
||||
assert abs(val9_11 - c64.cast_check(val9_11) - 4.9e-8) < 1e-8
|
||||
|
||||
dcm21 = Float('0.123456789012345670499') + 1e-20j # 21 decimals
|
||||
assert abs(dcm21 - c128.cast_check(dcm21) - 4.99e-19) < 1e-19
|
||||
v19 = Float('0.1234567890123456749') + 1j*Float('0.1234567890123456749')
|
||||
raises(ValueError, lambda: c128.cast_check(v19))
|
||||
|
||||
|
||||
def test_While():
|
||||
xpp = AddAugmentedAssignment(x, 1)
|
||||
whl1 = While(x < 2, [xpp])
|
||||
assert whl1.condition.args[0] == x
|
||||
assert whl1.condition.args[1] == 2
|
||||
assert whl1.condition == Lt(x, 2, evaluate=False)
|
||||
assert whl1.body.args == (xpp,)
|
||||
assert whl1.func(*whl1.args) == whl1
|
||||
|
||||
cblk = CodeBlock(AddAugmentedAssignment(x, 1))
|
||||
whl2 = While(x < 2, cblk)
|
||||
assert whl1 == whl2
|
||||
assert whl1 != While(x < 3, [xpp])
|
||||
|
||||
|
||||
def test_Scope():
|
||||
assign = Assignment(x, y)
|
||||
incr = AddAugmentedAssignment(x, 1)
|
||||
scp = Scope([assign, incr])
|
||||
cblk = CodeBlock(assign, incr)
|
||||
assert scp.body == cblk
|
||||
assert scp == Scope(cblk)
|
||||
assert scp != Scope([incr, assign])
|
||||
assert scp.func(*scp.args) == scp
|
||||
|
||||
|
||||
def test_Print():
|
||||
fmt = "%d %.3f"
|
||||
ps = Print([n, x], fmt)
|
||||
assert str(ps.format_string) == fmt
|
||||
assert ps.print_args == Tuple(n, x)
|
||||
assert ps.args == (Tuple(n, x), QuotedString(fmt), none)
|
||||
assert ps == Print((n, x), fmt)
|
||||
assert ps != Print([x, n], fmt)
|
||||
assert ps.func(*ps.args) == ps
|
||||
|
||||
ps2 = Print([n, x])
|
||||
assert ps2 == Print([n, x])
|
||||
assert ps2 != ps
|
||||
assert ps2.format_string == None
|
||||
|
||||
|
||||
def test_FunctionPrototype_and_FunctionDefinition():
|
||||
vx = Variable(x, type=real)
|
||||
vn = Variable(n, type=integer)
|
||||
fp1 = FunctionPrototype(real, 'power', [vx, vn])
|
||||
assert fp1.return_type == real
|
||||
assert fp1.name == String('power')
|
||||
assert fp1.parameters == Tuple(vx, vn)
|
||||
assert fp1 == FunctionPrototype(real, 'power', [vx, vn])
|
||||
assert fp1 != FunctionPrototype(real, 'power', [vn, vx])
|
||||
assert fp1.func(*fp1.args) == fp1
|
||||
|
||||
|
||||
body = [Assignment(x, x**n), Return(x)]
|
||||
fd1 = FunctionDefinition(real, 'power', [vx, vn], body)
|
||||
assert fd1.return_type == real
|
||||
assert str(fd1.name) == 'power'
|
||||
assert fd1.parameters == Tuple(vx, vn)
|
||||
assert fd1.body == CodeBlock(*body)
|
||||
assert fd1 == FunctionDefinition(real, 'power', [vx, vn], body)
|
||||
assert fd1 != FunctionDefinition(real, 'power', [vx, vn], body[::-1])
|
||||
assert fd1.func(*fd1.args) == fd1
|
||||
|
||||
fp2 = FunctionPrototype.from_FunctionDefinition(fd1)
|
||||
assert fp2 == fp1
|
||||
|
||||
fd2 = FunctionDefinition.from_FunctionPrototype(fp1, body)
|
||||
assert fd2 == fd1
|
||||
|
||||
|
||||
def test_Return():
|
||||
rs = Return(x)
|
||||
assert rs.args == (x,)
|
||||
assert rs == Return(x)
|
||||
assert rs != Return(y)
|
||||
assert rs.func(*rs.args) == rs
|
||||
|
||||
|
||||
def test_FunctionCall():
|
||||
fc = FunctionCall('power', (x, 3))
|
||||
assert fc.function_args[0] == x
|
||||
assert fc.function_args[1] == 3
|
||||
assert len(fc.function_args) == 2
|
||||
assert isinstance(fc.function_args[1], Integer)
|
||||
assert fc == FunctionCall('power', (x, 3))
|
||||
assert fc != FunctionCall('power', (3, x))
|
||||
assert fc != FunctionCall('Power', (x, 3))
|
||||
assert fc.func(*fc.args) == fc
|
||||
|
||||
fc2 = FunctionCall('fma', [2, 3, 4])
|
||||
assert len(fc2.function_args) == 3
|
||||
assert fc2.function_args[0] == 2
|
||||
assert fc2.function_args[1] == 3
|
||||
assert fc2.function_args[2] == 4
|
||||
assert str(fc2) in ( # not sure if QuotedString is a better default...
|
||||
'FunctionCall(fma, function_args=(2, 3, 4))',
|
||||
'FunctionCall("fma", function_args=(2, 3, 4))',
|
||||
)
|
||||
|
||||
def test_ast_replace():
|
||||
x = Variable('x', real)
|
||||
y = Variable('y', real)
|
||||
n = Variable('n', integer)
|
||||
|
||||
pwer = FunctionDefinition(real, 'pwer', [x, n], [pow(x.symbol, n.symbol)])
|
||||
pname = pwer.name
|
||||
pcall = FunctionCall('pwer', [y, 3])
|
||||
|
||||
tree1 = CodeBlock(pwer, pcall)
|
||||
assert str(tree1.args[0].name) == 'pwer'
|
||||
assert str(tree1.args[1].name) == 'pwer'
|
||||
for a, b in zip(tree1, [pwer, pcall]):
|
||||
assert a == b
|
||||
|
||||
tree2 = tree1.replace(pname, String('power'))
|
||||
assert str(tree1.args[0].name) == 'pwer'
|
||||
assert str(tree1.args[1].name) == 'pwer'
|
||||
assert str(tree2.args[0].name) == 'power'
|
||||
assert str(tree2.args[1].name) == 'power'
|
||||
@@ -0,0 +1,186 @@
|
||||
from sympy.core.numbers import (Rational, pi)
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import (Symbol, symbols)
|
||||
from sympy.functions.elementary.exponential import (exp, log)
|
||||
from sympy.codegen.cfunctions import (
|
||||
expm1, log1p, exp2, log2, fma, log10, Sqrt, Cbrt, hypot, isnan, isinf
|
||||
)
|
||||
from sympy.core.function import expand_log
|
||||
|
||||
|
||||
def test_expm1():
|
||||
# Eval
|
||||
assert expm1(0) == 0
|
||||
|
||||
x = Symbol('x', real=True)
|
||||
|
||||
# Expand and rewrite
|
||||
assert expm1(x).expand(func=True) - exp(x) == -1
|
||||
assert expm1(x).rewrite('tractable') - exp(x) == -1
|
||||
assert expm1(x).rewrite('exp') - exp(x) == -1
|
||||
|
||||
# Precision
|
||||
assert not ((exp(1e-10).evalf() - 1) - 1e-10 - 5e-21) < 1e-22 # for comparison
|
||||
assert abs(expm1(1e-10).evalf() - 1e-10 - 5e-21) < 1e-22
|
||||
|
||||
# Properties
|
||||
assert expm1(x).is_real
|
||||
assert expm1(x).is_finite
|
||||
|
||||
# Diff
|
||||
assert expm1(42*x).diff(x) - 42*exp(42*x) == 0
|
||||
assert expm1(42*x).diff(x) - expm1(42*x).expand(func=True).diff(x) == 0
|
||||
|
||||
|
||||
def test_log1p():
|
||||
# Eval
|
||||
assert log1p(0) == 0
|
||||
d = S(10)
|
||||
assert expand_log(log1p(d**-1000) - log(d**1000 + 1) + log(d**1000)) == 0
|
||||
|
||||
x = Symbol('x', real=True)
|
||||
|
||||
# Expand and rewrite
|
||||
assert log1p(x).expand(func=True) - log(x + 1) == 0
|
||||
assert log1p(x).rewrite('tractable') - log(x + 1) == 0
|
||||
assert log1p(x).rewrite('log') - log(x + 1) == 0
|
||||
|
||||
# Precision
|
||||
assert not abs(log(1e-99 + 1).evalf() - 1e-99) < 1e-100 # for comparison
|
||||
assert abs(expand_log(log1p(1e-99)).evalf() - 1e-99) < 1e-100
|
||||
|
||||
# Properties
|
||||
assert log1p(-2**Rational(-1, 2)).is_real
|
||||
|
||||
assert not log1p(-1).is_finite
|
||||
assert log1p(pi).is_finite
|
||||
|
||||
assert not log1p(x).is_positive
|
||||
assert log1p(Symbol('y', positive=True)).is_positive
|
||||
|
||||
assert not log1p(x).is_zero
|
||||
assert log1p(Symbol('z', zero=True)).is_zero
|
||||
|
||||
assert not log1p(x).is_nonnegative
|
||||
assert log1p(Symbol('o', nonnegative=True)).is_nonnegative
|
||||
|
||||
# Diff
|
||||
assert log1p(42*x).diff(x) - 42/(42*x + 1) == 0
|
||||
assert log1p(42*x).diff(x) - log1p(42*x).expand(func=True).diff(x) == 0
|
||||
|
||||
|
||||
def test_exp2():
|
||||
# Eval
|
||||
assert exp2(2) == 4
|
||||
|
||||
x = Symbol('x', real=True)
|
||||
|
||||
# Expand
|
||||
assert exp2(x).expand(func=True) - 2**x == 0
|
||||
|
||||
# Diff
|
||||
assert exp2(42*x).diff(x) - 42*exp2(42*x)*log(2) == 0
|
||||
assert exp2(42*x).diff(x) - exp2(42*x).diff(x) == 0
|
||||
|
||||
|
||||
def test_log2():
|
||||
# Eval
|
||||
assert log2(8) == 3
|
||||
assert log2(pi) != log(pi)/log(2) # log2 should *save* (CPU) instructions
|
||||
|
||||
x = Symbol('x', real=True)
|
||||
assert log2(x) != log(x)/log(2)
|
||||
assert log2(2**x) == x
|
||||
|
||||
# Expand
|
||||
assert log2(x).expand(func=True) - log(x)/log(2) == 0
|
||||
|
||||
# Diff
|
||||
assert log2(42*x).diff() - 1/(log(2)*x) == 0
|
||||
assert log2(42*x).diff() - log2(42*x).expand(func=True).diff(x) == 0
|
||||
|
||||
|
||||
def test_fma():
|
||||
x, y, z = symbols('x y z')
|
||||
|
||||
# Expand
|
||||
assert fma(x, y, z).expand(func=True) - x*y - z == 0
|
||||
|
||||
expr = fma(17*x, 42*y, 101*z)
|
||||
|
||||
# Diff
|
||||
assert expr.diff(x) - expr.expand(func=True).diff(x) == 0
|
||||
assert expr.diff(y) - expr.expand(func=True).diff(y) == 0
|
||||
assert expr.diff(z) - expr.expand(func=True).diff(z) == 0
|
||||
|
||||
assert expr.diff(x) - 17*42*y == 0
|
||||
assert expr.diff(y) - 17*42*x == 0
|
||||
assert expr.diff(z) - 101 == 0
|
||||
|
||||
|
||||
def test_log10():
|
||||
x = Symbol('x')
|
||||
|
||||
# Expand
|
||||
assert log10(x).expand(func=True) - log(x)/log(10) == 0
|
||||
|
||||
# Diff
|
||||
assert log10(42*x).diff(x) - 1/(log(10)*x) == 0
|
||||
assert log10(42*x).diff(x) - log10(42*x).expand(func=True).diff(x) == 0
|
||||
|
||||
|
||||
def test_Cbrt():
|
||||
x = Symbol('x')
|
||||
|
||||
# Expand
|
||||
assert Cbrt(x).expand(func=True) - x**Rational(1, 3) == 0
|
||||
|
||||
# Diff
|
||||
assert Cbrt(42*x).diff(x) - 42*(42*x)**(Rational(1, 3) - 1)/3 == 0
|
||||
assert Cbrt(42*x).diff(x) - Cbrt(42*x).expand(func=True).diff(x) == 0
|
||||
|
||||
|
||||
def test_Sqrt():
|
||||
x = Symbol('x')
|
||||
|
||||
# Expand
|
||||
assert Sqrt(x).expand(func=True) - x**S.Half == 0
|
||||
|
||||
# Diff
|
||||
assert Sqrt(42*x).diff(x) - 42*(42*x)**(S.Half - 1)/2 == 0
|
||||
assert Sqrt(42*x).diff(x) - Sqrt(42*x).expand(func=True).diff(x) == 0
|
||||
|
||||
|
||||
def test_hypot():
|
||||
x, y = symbols('x y')
|
||||
|
||||
# Expand
|
||||
assert hypot(x, y).expand(func=True) - (x**2 + y**2)**S.Half == 0
|
||||
|
||||
# Diff
|
||||
assert hypot(17*x, 42*y).diff(x).expand(func=True) - hypot(17*x, 42*y).expand(func=True).diff(x) == 0
|
||||
assert hypot(17*x, 42*y).diff(y).expand(func=True) - hypot(17*x, 42*y).expand(func=True).diff(y) == 0
|
||||
|
||||
assert hypot(17*x, 42*y).diff(x).expand(func=True) - 2*17*17*x*((17*x)**2 + (42*y)**2)**Rational(-1, 2)/2 == 0
|
||||
assert hypot(17*x, 42*y).diff(y).expand(func=True) - 2*42*42*y*((17*x)**2 + (42*y)**2)**Rational(-1, 2)/2 == 0
|
||||
|
||||
|
||||
def test_isnan_isinf():
|
||||
x = Symbol('x')
|
||||
|
||||
# isinf
|
||||
assert isinf(+S.Infinity) == True
|
||||
assert isinf(-S.Infinity) == True
|
||||
assert isinf(S.Pi) == False
|
||||
isinfx = isinf(x)
|
||||
assert isinfx not in (False, True)
|
||||
assert isinfx.func is isinf
|
||||
assert isinfx.args == (x,)
|
||||
|
||||
# isnan
|
||||
assert isnan(S.NaN) == True
|
||||
assert isnan(S.Pi) == False
|
||||
isnanx = isnan(x)
|
||||
assert isnanx not in (False, True)
|
||||
assert isnanx.func is isnan
|
||||
assert isnanx.args == (x,)
|
||||
@@ -0,0 +1,112 @@
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.printing.codeprinter import ccode
|
||||
from sympy.codegen.ast import Declaration, Variable, float64, int64, String, CodeBlock
|
||||
from sympy.codegen.cnodes import (
|
||||
alignof, CommaOperator, goto, Label, PreDecrement, PostDecrement, PreIncrement, PostIncrement,
|
||||
sizeof, union, struct
|
||||
)
|
||||
|
||||
x, y = symbols('x y')
|
||||
|
||||
|
||||
def test_alignof():
|
||||
ax = alignof(x)
|
||||
assert ccode(ax) == 'alignof(x)'
|
||||
assert ax.func(*ax.args) == ax
|
||||
|
||||
|
||||
def test_CommaOperator():
|
||||
expr = CommaOperator(PreIncrement(x), 2*x)
|
||||
assert ccode(expr) == '(++(x), 2*x)'
|
||||
assert expr.func(*expr.args) == expr
|
||||
|
||||
|
||||
def test_goto_Label():
|
||||
s = 'early_exit'
|
||||
g = goto(s)
|
||||
assert g.func(*g.args) == g
|
||||
assert g != goto('foobar')
|
||||
assert ccode(g) == 'goto early_exit'
|
||||
|
||||
l1 = Label(s)
|
||||
assert ccode(l1) == 'early_exit:'
|
||||
assert l1 == Label('early_exit')
|
||||
assert l1 != Label('foobar')
|
||||
|
||||
body = [PreIncrement(x)]
|
||||
l2 = Label(s, body)
|
||||
assert l2.name == String("early_exit")
|
||||
assert l2.body == CodeBlock(PreIncrement(x))
|
||||
assert ccode(l2) == ("early_exit:\n"
|
||||
"++(x);")
|
||||
|
||||
body = [PreIncrement(x), PreDecrement(y)]
|
||||
l2 = Label(s, body)
|
||||
assert l2.name == String("early_exit")
|
||||
assert l2.body == CodeBlock(PreIncrement(x), PreDecrement(y))
|
||||
assert ccode(l2) == ("early_exit:\n"
|
||||
"{\n ++(x);\n --(y);\n}")
|
||||
|
||||
|
||||
def test_PreDecrement():
|
||||
p = PreDecrement(x)
|
||||
assert p.func(*p.args) == p
|
||||
assert ccode(p) == '--(x)'
|
||||
|
||||
|
||||
def test_PostDecrement():
|
||||
p = PostDecrement(x)
|
||||
assert p.func(*p.args) == p
|
||||
assert ccode(p) == '(x)--'
|
||||
|
||||
|
||||
def test_PreIncrement():
|
||||
p = PreIncrement(x)
|
||||
assert p.func(*p.args) == p
|
||||
assert ccode(p) == '++(x)'
|
||||
|
||||
|
||||
def test_PostIncrement():
|
||||
p = PostIncrement(x)
|
||||
assert p.func(*p.args) == p
|
||||
assert ccode(p) == '(x)++'
|
||||
|
||||
|
||||
def test_sizeof():
|
||||
typename = 'unsigned int'
|
||||
sz = sizeof(typename)
|
||||
assert ccode(sz) == 'sizeof(%s)' % typename
|
||||
assert sz.func(*sz.args) == sz
|
||||
assert not sz.is_Atom
|
||||
assert sz.atoms() == {String('unsigned int'), String('sizeof')}
|
||||
|
||||
|
||||
def test_struct():
|
||||
vx, vy = Variable(x, type=float64), Variable(y, type=float64)
|
||||
s = struct('vec2', [vx, vy])
|
||||
assert s.func(*s.args) == s
|
||||
assert s == struct('vec2', (vx, vy))
|
||||
assert s != struct('vec2', (vy, vx))
|
||||
assert str(s.name) == 'vec2'
|
||||
assert len(s.declarations) == 2
|
||||
assert all(isinstance(arg, Declaration) for arg in s.declarations)
|
||||
assert ccode(s) == (
|
||||
"struct vec2 {\n"
|
||||
" double x;\n"
|
||||
" double y;\n"
|
||||
"}")
|
||||
|
||||
|
||||
def test_union():
|
||||
vx, vy = Variable(x, type=float64), Variable(y, type=int64)
|
||||
u = union('dualuse', [vx, vy])
|
||||
assert u.func(*u.args) == u
|
||||
assert u == union('dualuse', (vx, vy))
|
||||
assert str(u.name) == 'dualuse'
|
||||
assert len(u.declarations) == 2
|
||||
assert all(isinstance(arg, Declaration) for arg in u.declarations)
|
||||
assert ccode(u) == (
|
||||
"union dualuse {\n"
|
||||
" double x;\n"
|
||||
" int64_t y;\n"
|
||||
"}")
|
||||
@@ -0,0 +1,14 @@
|
||||
from sympy.core.symbol import Symbol
|
||||
from sympy.codegen.ast import Type
|
||||
from sympy.codegen.cxxnodes import using
|
||||
from sympy.printing.codeprinter import cxxcode
|
||||
|
||||
x = Symbol('x')
|
||||
|
||||
def test_using():
|
||||
v = Type('std::vector')
|
||||
u1 = using(v)
|
||||
assert cxxcode(u1) == 'using std::vector'
|
||||
|
||||
u2 = using(v, 'vec')
|
||||
assert cxxcode(u2) == 'using vec = std::vector'
|
||||
@@ -0,0 +1,213 @@
|
||||
import os
|
||||
import tempfile
|
||||
from sympy.core.symbol import (Symbol, symbols)
|
||||
from sympy.codegen.ast import (
|
||||
Assignment, Print, Declaration, FunctionDefinition, Return, real,
|
||||
FunctionCall, Variable, Element, integer
|
||||
)
|
||||
from sympy.codegen.fnodes import (
|
||||
allocatable, ArrayConstructor, isign, dsign, cmplx, kind, literal_dp,
|
||||
Program, Module, use, Subroutine, dimension, assumed_extent, ImpliedDoLoop,
|
||||
intent_out, size, Do, SubroutineCall, sum_, array, bind_C
|
||||
)
|
||||
from sympy.codegen.futils import render_as_module
|
||||
from sympy.core.expr import unchanged
|
||||
from sympy.external import import_module
|
||||
from sympy.printing.codeprinter import fcode
|
||||
from sympy.utilities._compilation import has_fortran, compile_run_strings, compile_link_import_strings
|
||||
from sympy.utilities._compilation.util import may_xfail
|
||||
from sympy.testing.pytest import skip, XFAIL
|
||||
|
||||
cython = import_module('cython')
|
||||
np = import_module('numpy')
|
||||
|
||||
|
||||
def test_size():
|
||||
x = Symbol('x', real=True)
|
||||
sx = size(x)
|
||||
assert fcode(sx, source_format='free') == 'size(x)'
|
||||
|
||||
|
||||
@may_xfail
|
||||
def test_size_assumed_shape():
|
||||
if not has_fortran():
|
||||
skip("No fortran compiler found.")
|
||||
a = Symbol('a', real=True)
|
||||
body = [Return((sum_(a**2)/size(a))**.5)]
|
||||
arr = array(a, dim=[':'], intent='in')
|
||||
fd = FunctionDefinition(real, 'rms', [arr], body)
|
||||
render_as_module([fd], 'mod_rms')
|
||||
|
||||
(stdout, stderr), info = compile_run_strings([
|
||||
('rms.f90', render_as_module([fd], 'mod_rms')),
|
||||
('main.f90', (
|
||||
'program myprog\n'
|
||||
'use mod_rms, only: rms\n'
|
||||
'real*8, dimension(4), parameter :: x = [4, 2, 2, 2]\n'
|
||||
'print "(f7.5)", dsqrt(7d0) - rms(x)\n'
|
||||
'end program\n'
|
||||
))
|
||||
], clean=True)
|
||||
assert '0.00000' in stdout
|
||||
assert stderr == ''
|
||||
assert info['exit_status'] == os.EX_OK
|
||||
|
||||
|
||||
@XFAIL # https://github.com/sympy/sympy/issues/20265
|
||||
@may_xfail
|
||||
def test_ImpliedDoLoop():
|
||||
if not has_fortran():
|
||||
skip("No fortran compiler found.")
|
||||
|
||||
a, i = symbols('a i', integer=True)
|
||||
idl = ImpliedDoLoop(i**3, i, -3, 3, 2)
|
||||
ac = ArrayConstructor([-28, idl, 28])
|
||||
a = array(a, dim=[':'], attrs=[allocatable])
|
||||
prog = Program('idlprog', [
|
||||
a.as_Declaration(),
|
||||
Assignment(a, ac),
|
||||
Print([a])
|
||||
])
|
||||
fsrc = fcode(prog, standard=2003, source_format='free')
|
||||
(stdout, stderr), info = compile_run_strings([('main.f90', fsrc)], clean=True)
|
||||
for numstr in '-28 -27 -1 1 27 28'.split():
|
||||
assert numstr in stdout
|
||||
assert stderr == ''
|
||||
assert info['exit_status'] == os.EX_OK
|
||||
|
||||
|
||||
@may_xfail
|
||||
def test_Program():
|
||||
x = Symbol('x', real=True)
|
||||
vx = Variable.deduced(x, 42)
|
||||
decl = Declaration(vx)
|
||||
prnt = Print([x, x+1])
|
||||
prog = Program('foo', [decl, prnt])
|
||||
if not has_fortran():
|
||||
skip("No fortran compiler found.")
|
||||
|
||||
(stdout, stderr), info = compile_run_strings([('main.f90', fcode(prog, standard=90))], clean=True)
|
||||
assert '42' in stdout
|
||||
assert '43' in stdout
|
||||
assert stderr == ''
|
||||
assert info['exit_status'] == os.EX_OK
|
||||
|
||||
|
||||
@may_xfail
|
||||
def test_Module():
|
||||
x = Symbol('x', real=True)
|
||||
v_x = Variable.deduced(x)
|
||||
sq = FunctionDefinition(real, 'sqr', [v_x], [Return(x**2)])
|
||||
mod_sq = Module('mod_sq', [], [sq])
|
||||
sq_call = FunctionCall('sqr', [42.])
|
||||
prg_sq = Program('foobar', [
|
||||
use('mod_sq', only=['sqr']),
|
||||
Print(['"Square of 42 = "', sq_call])
|
||||
])
|
||||
if not has_fortran():
|
||||
skip("No fortran compiler found.")
|
||||
(stdout, stderr), info = compile_run_strings([
|
||||
('mod_sq.f90', fcode(mod_sq, standard=90)),
|
||||
('main.f90', fcode(prg_sq, standard=90))
|
||||
], clean=True)
|
||||
assert '42' in stdout
|
||||
assert str(42**2) in stdout
|
||||
assert stderr == ''
|
||||
|
||||
|
||||
@XFAIL # https://github.com/sympy/sympy/issues/20265
|
||||
@may_xfail
|
||||
def test_Subroutine():
|
||||
# Code to generate the subroutine in the example from
|
||||
# http://www.fortran90.org/src/best-practices.html#arrays
|
||||
r = Symbol('r', real=True)
|
||||
i = Symbol('i', integer=True)
|
||||
v_r = Variable.deduced(r, attrs=(dimension(assumed_extent), intent_out))
|
||||
v_i = Variable.deduced(i)
|
||||
v_n = Variable('n', integer)
|
||||
do_loop = Do([
|
||||
Assignment(Element(r, [i]), literal_dp(1)/i**2)
|
||||
], i, 1, v_n)
|
||||
sub = Subroutine("f", [v_r], [
|
||||
Declaration(v_n),
|
||||
Declaration(v_i),
|
||||
Assignment(v_n, size(r)),
|
||||
do_loop
|
||||
])
|
||||
x = Symbol('x', real=True)
|
||||
v_x3 = Variable.deduced(x, attrs=[dimension(3)])
|
||||
mod = Module('mymod', definitions=[sub])
|
||||
prog = Program('foo', [
|
||||
use(mod, only=[sub]),
|
||||
Declaration(v_x3),
|
||||
SubroutineCall(sub, [v_x3]),
|
||||
Print([sum_(v_x3), v_x3])
|
||||
])
|
||||
|
||||
if not has_fortran():
|
||||
skip("No fortran compiler found.")
|
||||
|
||||
(stdout, stderr), info = compile_run_strings([
|
||||
('a.f90', fcode(mod, standard=90)),
|
||||
('b.f90', fcode(prog, standard=90))
|
||||
], clean=True)
|
||||
ref = [1.0/i**2 for i in range(1, 4)]
|
||||
assert str(sum(ref))[:-3] in stdout
|
||||
for _ in ref:
|
||||
assert str(_)[:-3] in stdout
|
||||
assert stderr == ''
|
||||
|
||||
|
||||
def test_isign():
|
||||
x = Symbol('x', integer=True)
|
||||
assert unchanged(isign, 1, x)
|
||||
assert fcode(isign(1, x), standard=95, source_format='free') == 'isign(1, x)'
|
||||
|
||||
|
||||
def test_dsign():
|
||||
x = Symbol('x')
|
||||
assert unchanged(dsign, 1, x)
|
||||
assert fcode(dsign(literal_dp(1), x), standard=95, source_format='free') == 'dsign(1d0, x)'
|
||||
|
||||
|
||||
def test_cmplx():
|
||||
x = Symbol('x')
|
||||
assert unchanged(cmplx, 1, x)
|
||||
|
||||
|
||||
def test_kind():
|
||||
x = Symbol('x')
|
||||
assert unchanged(kind, x)
|
||||
|
||||
|
||||
def test_literal_dp():
|
||||
assert fcode(literal_dp(0), source_format='free') == '0d0'
|
||||
|
||||
|
||||
@may_xfail
|
||||
def test_bind_C():
|
||||
if not has_fortran():
|
||||
skip("No fortran compiler found.")
|
||||
if not cython:
|
||||
skip("Cython not found.")
|
||||
if not np:
|
||||
skip("NumPy not found.")
|
||||
|
||||
a = Symbol('a', real=True)
|
||||
s = Symbol('s', integer=True)
|
||||
body = [Return((sum_(a**2)/s)**.5)]
|
||||
arr = array(a, dim=[s], intent='in')
|
||||
fd = FunctionDefinition(real, 'rms', [arr, s], body, attrs=[bind_C('rms')])
|
||||
f_mod = render_as_module([fd], 'mod_rms')
|
||||
|
||||
with tempfile.TemporaryDirectory() as folder:
|
||||
mod, info = compile_link_import_strings([
|
||||
('rms.f90', f_mod),
|
||||
('_rms.pyx', (
|
||||
"#cython: language_level={}\n".format("3") +
|
||||
"cdef extern double rms(double*, int*)\n"
|
||||
"def py_rms(double[::1] x):\n"
|
||||
" cdef int s = x.size\n"
|
||||
" return rms(&x[0], &s)\n"))
|
||||
], build_dir=folder)
|
||||
assert abs(mod.py_rms(np.array([2., 4., 2., 2.])) - 7**0.5) < 1e-14
|
||||
@@ -0,0 +1,50 @@
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.core.function import Function
|
||||
from sympy.matrices.dense import Matrix
|
||||
from sympy.matrices.dense import zeros
|
||||
from sympy.simplify.simplify import simplify
|
||||
from sympy.codegen.matrix_nodes import MatrixSolve
|
||||
from sympy.utilities.lambdify import lambdify
|
||||
from sympy.printing.numpy import NumPyPrinter
|
||||
from sympy.testing.pytest import skip
|
||||
from sympy.external import import_module
|
||||
|
||||
|
||||
def test_matrix_solve_issue_24862():
|
||||
A = Matrix(3, 3, symbols('a:9'))
|
||||
b = Matrix(3, 1, symbols('b:3'))
|
||||
hash(MatrixSolve(A, b))
|
||||
|
||||
|
||||
def test_matrix_solve_derivative_exact():
|
||||
q = symbols('q')
|
||||
a11, a12, a21, a22, b1, b2 = (
|
||||
f(q) for f in symbols('a11 a12 a21 a22 b1 b2', cls=Function))
|
||||
A = Matrix([[a11, a12], [a21, a22]])
|
||||
b = Matrix([b1, b2])
|
||||
x_lu = A.LUsolve(b)
|
||||
dxdq_lu = A.LUsolve(b.diff(q) - A.diff(q) * A.LUsolve(b))
|
||||
assert simplify(x_lu.diff(q) - dxdq_lu) == zeros(2, 1)
|
||||
# dxdq_ms is the MatrixSolve equivalent of dxdq_lu
|
||||
dxdq_ms = MatrixSolve(A, b.diff(q) - A.diff(q) * MatrixSolve(A, b))
|
||||
assert MatrixSolve(A, b).diff(q) == dxdq_ms
|
||||
|
||||
|
||||
def test_matrix_solve_derivative_numpy():
|
||||
np = import_module('numpy')
|
||||
if not np:
|
||||
skip("numpy not installed.")
|
||||
q = symbols('q')
|
||||
a11, a12, a21, a22, b1, b2 = (
|
||||
f(q) for f in symbols('a11 a12 a21 a22 b1 b2', cls=Function))
|
||||
A = Matrix([[a11, a12], [a21, a22]])
|
||||
b = Matrix([b1, b2])
|
||||
dx_lu = A.LUsolve(b).diff(q)
|
||||
subs = {a11.diff(q): 0.2, a12.diff(q): 0.3, a21.diff(q): 0.1,
|
||||
a22.diff(q): 0.5, b1.diff(q): 0.4, b2.diff(q): 0.9,
|
||||
a11: 1.3, a12: 0.5, a21: 1.2, a22: 4, b1: 6.2, b2: 3.5}
|
||||
p, p_vals = zip(*subs.items())
|
||||
dx_sm = MatrixSolve(A, b).diff(q)
|
||||
np.testing.assert_allclose(
|
||||
lambdify(p, dx_sm, printer=NumPyPrinter)(*p_vals),
|
||||
lambdify(p, dx_lu, printer=NumPyPrinter)(*p_vals))
|
||||
@@ -0,0 +1,69 @@
|
||||
from itertools import product
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.functions.elementary.exponential import (exp, log)
|
||||
from sympy.functions.elementary.miscellaneous import Max, Min
|
||||
from sympy.printing.repr import srepr
|
||||
from sympy.codegen.numpy_nodes import logaddexp, logaddexp2, minimum, maximum, amax, amin
|
||||
from sympy.testing.pytest import raises
|
||||
|
||||
x, y, z = symbols('x y z')
|
||||
|
||||
def test_logaddexp():
|
||||
lae_xy = logaddexp(x, y)
|
||||
ref_xy = log(exp(x) + exp(y))
|
||||
for wrt, deriv_order in product([x, y, z], range(3)):
|
||||
assert (
|
||||
lae_xy.diff(wrt, deriv_order) -
|
||||
ref_xy.diff(wrt, deriv_order)
|
||||
).rewrite(log).simplify() == 0
|
||||
|
||||
one_third_e = 1*exp(1)/3
|
||||
two_thirds_e = 2*exp(1)/3
|
||||
logThirdE = log(one_third_e)
|
||||
logTwoThirdsE = log(two_thirds_e)
|
||||
lae_sum_to_e = logaddexp(logThirdE, logTwoThirdsE)
|
||||
assert lae_sum_to_e.rewrite(log) == 1
|
||||
assert lae_sum_to_e.simplify() == 1
|
||||
was = logaddexp(2, 3)
|
||||
assert srepr(was) == srepr(was.simplify()) # cannot simplify with 2, 3
|
||||
|
||||
|
||||
def test_logaddexp2():
|
||||
lae2_xy = logaddexp2(x, y)
|
||||
ref2_xy = log(2**x + 2**y)/log(2)
|
||||
for wrt, deriv_order in product([x, y, z], range(3)):
|
||||
assert (
|
||||
lae2_xy.diff(wrt, deriv_order) -
|
||||
ref2_xy.diff(wrt, deriv_order)
|
||||
).rewrite(log).cancel() == 0
|
||||
|
||||
def lb(x):
|
||||
return log(x)/log(2)
|
||||
|
||||
two_thirds = S.One*2/3
|
||||
four_thirds = 2*two_thirds
|
||||
lbTwoThirds = lb(two_thirds)
|
||||
lbFourThirds = lb(four_thirds)
|
||||
lae2_sum_to_2 = logaddexp2(lbTwoThirds, lbFourThirds)
|
||||
assert lae2_sum_to_2.rewrite(log) == 1
|
||||
assert lae2_sum_to_2.simplify() == 1
|
||||
was = logaddexp2(x, y)
|
||||
assert srepr(was) == srepr(was.simplify()) # cannot simplify with x, y
|
||||
|
||||
|
||||
def test_minimum_maximum():
|
||||
for MM, mm in zip([Min, Max], [minimum, maximum]):
|
||||
ref = MM(x, y, z)
|
||||
m = mm(x, y, z)
|
||||
assert m != ref
|
||||
assert m.rewrite(MM) == ref
|
||||
|
||||
|
||||
def test_amin_amax():
|
||||
for am in [amin, amax]:
|
||||
assert am(x).array == x
|
||||
assert am(x).axis == None
|
||||
assert am(x, axis=3).axis == 3
|
||||
with raises(ValueError):
|
||||
am(x, y, z)
|
||||
@@ -0,0 +1,13 @@
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.codegen.pynodes import List
|
||||
|
||||
|
||||
def test_List():
|
||||
l = List(2, 3, 4)
|
||||
assert l == List(2, 3, 4)
|
||||
assert str(l) == "[2, 3, 4]"
|
||||
x, y, z = symbols('x y z')
|
||||
l = List(x**2,y**3,z**4)
|
||||
# contrary to python's built-in list, we can call e.g. "replace" on List.
|
||||
m = l.replace(lambda arg: arg.is_Pow and arg.exp>2, lambda p: p.base-p.exp)
|
||||
assert m == [x**2, y-3, z-4]
|
||||
@@ -0,0 +1,7 @@
|
||||
from sympy.codegen.ast import Print
|
||||
from sympy.codegen.pyutils import render_as_module
|
||||
|
||||
def test_standard():
|
||||
ast = Print('x y'.split(), r"coordinate: %12.5g %12.5g\n")
|
||||
assert render_as_module(ast, standard='python3') == \
|
||||
'\n\nprint("coordinate: %12.5g %12.5g\\n" % (x, y), end="")'
|
||||
@@ -0,0 +1,479 @@
|
||||
import tempfile
|
||||
from sympy.core.numbers import pi, Rational
|
||||
from sympy.core.power import Pow
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import Symbol
|
||||
from sympy.functions.elementary.complexes import Abs
|
||||
from sympy.functions.elementary.exponential import (exp, log)
|
||||
from sympy.functions.elementary.trigonometric import (cos, sin, sinc)
|
||||
from sympy.matrices.expressions.matexpr import MatrixSymbol
|
||||
from sympy.assumptions import assuming, Q
|
||||
from sympy.external import import_module
|
||||
from sympy.printing.codeprinter import ccode
|
||||
from sympy.codegen.matrix_nodes import MatrixSolve
|
||||
from sympy.codegen.cfunctions import log2, exp2, expm1, log1p
|
||||
from sympy.codegen.numpy_nodes import logaddexp, logaddexp2
|
||||
from sympy.codegen.scipy_nodes import cosm1, powm1
|
||||
from sympy.codegen.rewriting import (
|
||||
optimize, cosm1_opt, log2_opt, exp2_opt, expm1_opt, log1p_opt, powm1_opt, optims_c99,
|
||||
create_expand_pow_optimization, matinv_opt, logaddexp_opt, logaddexp2_opt,
|
||||
optims_numpy, optims_scipy, sinc_opts, FuncMinusOneOptim
|
||||
)
|
||||
from sympy.testing.pytest import XFAIL, skip
|
||||
from sympy.utilities import lambdify
|
||||
from sympy.utilities._compilation import compile_link_import_strings, has_c
|
||||
from sympy.utilities._compilation.util import may_xfail
|
||||
|
||||
cython = import_module('cython')
|
||||
numpy = import_module('numpy')
|
||||
scipy = import_module('scipy')
|
||||
|
||||
|
||||
def test_log2_opt():
|
||||
x = Symbol('x')
|
||||
expr1 = 7*log(3*x + 5)/(log(2))
|
||||
opt1 = optimize(expr1, [log2_opt])
|
||||
assert opt1 == 7*log2(3*x + 5)
|
||||
assert opt1.rewrite(log) == expr1
|
||||
|
||||
expr2 = 3*log(5*x + 7)/(13*log(2))
|
||||
opt2 = optimize(expr2, [log2_opt])
|
||||
assert opt2 == 3*log2(5*x + 7)/13
|
||||
assert opt2.rewrite(log) == expr2
|
||||
|
||||
expr3 = log(x)/log(2)
|
||||
opt3 = optimize(expr3, [log2_opt])
|
||||
assert opt3 == log2(x)
|
||||
assert opt3.rewrite(log) == expr3
|
||||
|
||||
expr4 = log(x)/log(2) + log(x+1)
|
||||
opt4 = optimize(expr4, [log2_opt])
|
||||
assert opt4 == log2(x) + log(2)*log2(x+1)
|
||||
assert opt4.rewrite(log) == expr4
|
||||
|
||||
expr5 = log(17)
|
||||
opt5 = optimize(expr5, [log2_opt])
|
||||
assert opt5 == expr5
|
||||
|
||||
expr6 = log(x + 3)/log(2)
|
||||
opt6 = optimize(expr6, [log2_opt])
|
||||
assert str(opt6) == 'log2(x + 3)'
|
||||
assert opt6.rewrite(log) == expr6
|
||||
|
||||
|
||||
def test_exp2_opt():
|
||||
x = Symbol('x')
|
||||
expr1 = 1 + 2**x
|
||||
opt1 = optimize(expr1, [exp2_opt])
|
||||
assert opt1 == 1 + exp2(x)
|
||||
assert opt1.rewrite(Pow) == expr1
|
||||
|
||||
expr2 = 1 + 3**x
|
||||
assert expr2 == optimize(expr2, [exp2_opt])
|
||||
|
||||
|
||||
def test_expm1_opt():
|
||||
x = Symbol('x')
|
||||
|
||||
expr1 = exp(x) - 1
|
||||
opt1 = optimize(expr1, [expm1_opt])
|
||||
assert expm1(x) - opt1 == 0
|
||||
assert opt1.rewrite(exp) == expr1
|
||||
|
||||
expr2 = 3*exp(x) - 3
|
||||
opt2 = optimize(expr2, [expm1_opt])
|
||||
assert 3*expm1(x) == opt2
|
||||
assert opt2.rewrite(exp) == expr2
|
||||
|
||||
expr3 = 3*exp(x) - 5
|
||||
opt3 = optimize(expr3, [expm1_opt])
|
||||
assert 3*expm1(x) - 2 == opt3
|
||||
assert opt3.rewrite(exp) == expr3
|
||||
expm1_opt_non_opportunistic = FuncMinusOneOptim(exp, expm1, opportunistic=False)
|
||||
assert expr3 == optimize(expr3, [expm1_opt_non_opportunistic])
|
||||
assert opt1 == optimize(expr1, [expm1_opt_non_opportunistic])
|
||||
assert opt2 == optimize(expr2, [expm1_opt_non_opportunistic])
|
||||
|
||||
expr4 = 3*exp(x) + log(x) - 3
|
||||
opt4 = optimize(expr4, [expm1_opt])
|
||||
assert 3*expm1(x) + log(x) == opt4
|
||||
assert opt4.rewrite(exp) == expr4
|
||||
|
||||
expr5 = 3*exp(2*x) - 3
|
||||
opt5 = optimize(expr5, [expm1_opt])
|
||||
assert 3*expm1(2*x) == opt5
|
||||
assert opt5.rewrite(exp) == expr5
|
||||
|
||||
expr6 = (2*exp(x) + 1)/(exp(x) + 1) + 1
|
||||
opt6 = optimize(expr6, [expm1_opt])
|
||||
assert opt6.count_ops() <= expr6.count_ops()
|
||||
|
||||
def ev(e):
|
||||
return e.subs(x, 3).evalf()
|
||||
assert abs(ev(expr6) - ev(opt6)) < 1e-15
|
||||
|
||||
y = Symbol('y')
|
||||
expr7 = (2*exp(x) - 1)/(1 - exp(y)) - 1/(1-exp(y))
|
||||
opt7 = optimize(expr7, [expm1_opt])
|
||||
assert -2*expm1(x)/expm1(y) == opt7
|
||||
assert (opt7.rewrite(exp) - expr7).factor() == 0
|
||||
|
||||
expr8 = (1+exp(x))**2 - 4
|
||||
opt8 = optimize(expr8, [expm1_opt])
|
||||
tgt8a = (exp(x) + 3)*expm1(x)
|
||||
tgt8b = 2*expm1(x) + expm1(2*x)
|
||||
# Both tgt8a & tgt8b seem to give full precision (~16 digits for double)
|
||||
# for x=1e-7 (compare with expr8 which only achieves ~8 significant digits).
|
||||
# If we can show that either tgt8a or tgt8b is preferable, we can
|
||||
# change this test to ensure the preferable version is returned.
|
||||
assert (tgt8a - tgt8b).rewrite(exp).factor() == 0
|
||||
assert opt8 in (tgt8a, tgt8b)
|
||||
assert (opt8.rewrite(exp) - expr8).factor() == 0
|
||||
|
||||
expr9 = sin(expr8)
|
||||
opt9 = optimize(expr9, [expm1_opt])
|
||||
tgt9a = sin(tgt8a)
|
||||
tgt9b = sin(tgt8b)
|
||||
assert opt9 in (tgt9a, tgt9b)
|
||||
assert (opt9.rewrite(exp) - expr9.rewrite(exp)).factor().is_zero
|
||||
|
||||
|
||||
def test_expm1_two_exp_terms():
|
||||
x, y = map(Symbol, 'x y'.split())
|
||||
expr1 = exp(x) + exp(y) - 2
|
||||
opt1 = optimize(expr1, [expm1_opt])
|
||||
assert opt1 == expm1(x) + expm1(y)
|
||||
|
||||
|
||||
def test_cosm1_opt():
|
||||
x = Symbol('x')
|
||||
|
||||
expr1 = cos(x) - 1
|
||||
opt1 = optimize(expr1, [cosm1_opt])
|
||||
assert cosm1(x) - opt1 == 0
|
||||
assert opt1.rewrite(cos) == expr1
|
||||
|
||||
expr2 = 3*cos(x) - 3
|
||||
opt2 = optimize(expr2, [cosm1_opt])
|
||||
assert 3*cosm1(x) == opt2
|
||||
assert opt2.rewrite(cos) == expr2
|
||||
|
||||
expr3 = 3*cos(x) - 5
|
||||
opt3 = optimize(expr3, [cosm1_opt])
|
||||
assert 3*cosm1(x) - 2 == opt3
|
||||
assert opt3.rewrite(cos) == expr3
|
||||
cosm1_opt_non_opportunistic = FuncMinusOneOptim(cos, cosm1, opportunistic=False)
|
||||
assert expr3 == optimize(expr3, [cosm1_opt_non_opportunistic])
|
||||
assert opt1 == optimize(expr1, [cosm1_opt_non_opportunistic])
|
||||
assert opt2 == optimize(expr2, [cosm1_opt_non_opportunistic])
|
||||
|
||||
expr4 = 3*cos(x) + log(x) - 3
|
||||
opt4 = optimize(expr4, [cosm1_opt])
|
||||
assert 3*cosm1(x) + log(x) == opt4
|
||||
assert opt4.rewrite(cos) == expr4
|
||||
|
||||
expr5 = 3*cos(2*x) - 3
|
||||
opt5 = optimize(expr5, [cosm1_opt])
|
||||
assert 3*cosm1(2*x) == opt5
|
||||
assert opt5.rewrite(cos) == expr5
|
||||
|
||||
expr6 = 2 - 2*cos(x)
|
||||
opt6 = optimize(expr6, [cosm1_opt])
|
||||
assert -2*cosm1(x) == opt6
|
||||
assert opt6.rewrite(cos) == expr6
|
||||
|
||||
|
||||
def test_cosm1_two_cos_terms():
|
||||
x, y = map(Symbol, 'x y'.split())
|
||||
expr1 = cos(x) + cos(y) - 2
|
||||
opt1 = optimize(expr1, [cosm1_opt])
|
||||
assert opt1 == cosm1(x) + cosm1(y)
|
||||
|
||||
|
||||
def test_expm1_cosm1_mixed():
|
||||
x = Symbol('x')
|
||||
expr1 = exp(x) + cos(x) - 2
|
||||
opt1 = optimize(expr1, [expm1_opt, cosm1_opt])
|
||||
assert opt1 == cosm1(x) + expm1(x)
|
||||
|
||||
|
||||
def _check_num_lambdify(expr, opt, val_subs, approx_ref, lambdify_kw=None, poorness=1e10):
|
||||
""" poorness=1e10 signifies that `expr` loses precision of at least ten decimal digits. """
|
||||
num_ref = expr.subs(val_subs).evalf()
|
||||
eps = numpy.finfo(numpy.float64).eps
|
||||
assert abs(num_ref - approx_ref) < approx_ref*eps
|
||||
f1 = lambdify(list(val_subs.keys()), opt, **(lambdify_kw or {}))
|
||||
args_float = tuple(map(float, val_subs.values()))
|
||||
num_err1 = abs(f1(*args_float) - approx_ref)
|
||||
assert num_err1 < abs(num_ref*eps)
|
||||
f2 = lambdify(list(val_subs.keys()), expr, **(lambdify_kw or {}))
|
||||
num_err2 = abs(f2(*args_float) - approx_ref)
|
||||
assert num_err2 > abs(num_ref*eps*poorness) # this only ensures that the *test* works as intended
|
||||
|
||||
|
||||
def test_cosm1_apart():
|
||||
x = Symbol('x')
|
||||
|
||||
expr1 = 1/cos(x) - 1
|
||||
opt1 = optimize(expr1, [cosm1_opt])
|
||||
assert opt1 == -cosm1(x)/cos(x)
|
||||
if scipy:
|
||||
_check_num_lambdify(expr1, opt1, {x: S(10)**-30}, 5e-61, lambdify_kw={"modules": 'scipy'})
|
||||
|
||||
expr2 = 2/cos(x) - 2
|
||||
opt2 = optimize(expr2, optims_scipy)
|
||||
assert opt2 == -2*cosm1(x)/cos(x)
|
||||
if scipy:
|
||||
_check_num_lambdify(expr2, opt2, {x: S(10)**-30}, 1e-60, lambdify_kw={"modules": 'scipy'})
|
||||
|
||||
expr3 = pi/cos(3*x) - pi
|
||||
opt3 = optimize(expr3, [cosm1_opt])
|
||||
assert opt3 == -pi*cosm1(3*x)/cos(3*x)
|
||||
if scipy:
|
||||
_check_num_lambdify(expr3, opt3, {x: S(10)**-30/3}, float(5e-61*pi), lambdify_kw={"modules": 'scipy'})
|
||||
|
||||
|
||||
def test_powm1():
|
||||
args = x, y = map(Symbol, "xy")
|
||||
|
||||
expr1 = x**y - 1
|
||||
opt1 = optimize(expr1, [powm1_opt])
|
||||
assert opt1 == powm1(x, y)
|
||||
for arg in args:
|
||||
assert expr1.diff(arg) == opt1.diff(arg)
|
||||
if scipy and tuple(map(int, scipy.version.version.split('.')[:3])) >= (1, 10, 0):
|
||||
subs1_a = {x: Rational(*(1.0+1e-13).as_integer_ratio()), y: pi}
|
||||
ref1_f64_a = 3.139081648208105e-13
|
||||
_check_num_lambdify(expr1, opt1, subs1_a, ref1_f64_a, lambdify_kw={"modules": 'scipy'}, poorness=10**11)
|
||||
|
||||
subs1_b = {x: pi, y: Rational(*(1e-10).as_integer_ratio())}
|
||||
ref1_f64_b = 1.1447298859149205e-10
|
||||
_check_num_lambdify(expr1, opt1, subs1_b, ref1_f64_b, lambdify_kw={"modules": 'scipy'}, poorness=10**9)
|
||||
|
||||
|
||||
def test_log1p_opt():
|
||||
x = Symbol('x')
|
||||
expr1 = log(x + 1)
|
||||
opt1 = optimize(expr1, [log1p_opt])
|
||||
assert log1p(x) - opt1 == 0
|
||||
assert opt1.rewrite(log) == expr1
|
||||
|
||||
expr2 = log(3*x + 3)
|
||||
opt2 = optimize(expr2, [log1p_opt])
|
||||
assert log1p(x) + log(3) == opt2
|
||||
assert (opt2.rewrite(log) - expr2).simplify() == 0
|
||||
|
||||
expr3 = log(2*x + 1)
|
||||
opt3 = optimize(expr3, [log1p_opt])
|
||||
assert log1p(2*x) - opt3 == 0
|
||||
assert opt3.rewrite(log) == expr3
|
||||
|
||||
expr4 = log(x+3)
|
||||
opt4 = optimize(expr4, [log1p_opt])
|
||||
assert str(opt4) == 'log(x + 3)'
|
||||
|
||||
|
||||
def test_optims_c99():
|
||||
x = Symbol('x')
|
||||
|
||||
expr1 = 2**x + log(x)/log(2) + log(x + 1) + exp(x) - 1
|
||||
opt1 = optimize(expr1, optims_c99).simplify()
|
||||
assert opt1 == exp2(x) + log2(x) + log1p(x) + expm1(x)
|
||||
assert opt1.rewrite(exp).rewrite(log).rewrite(Pow) == expr1
|
||||
|
||||
expr2 = log(x)/log(2) + log(x + 1)
|
||||
opt2 = optimize(expr2, optims_c99)
|
||||
assert opt2 == log2(x) + log1p(x)
|
||||
assert opt2.rewrite(log) == expr2
|
||||
|
||||
expr3 = log(x)/log(2) + log(17*x + 17)
|
||||
opt3 = optimize(expr3, optims_c99)
|
||||
delta3 = opt3 - (log2(x) + log(17) + log1p(x))
|
||||
assert delta3 == 0
|
||||
assert (opt3.rewrite(log) - expr3).simplify() == 0
|
||||
|
||||
expr4 = 2**x + 3*log(5*x + 7)/(13*log(2)) + 11*exp(x) - 11 + log(17*x + 17)
|
||||
opt4 = optimize(expr4, optims_c99).simplify()
|
||||
delta4 = opt4 - (exp2(x) + 3*log2(5*x + 7)/13 + 11*expm1(x) + log(17) + log1p(x))
|
||||
assert delta4 == 0
|
||||
assert (opt4.rewrite(exp).rewrite(log).rewrite(Pow) - expr4).simplify() == 0
|
||||
|
||||
expr5 = 3*exp(2*x) - 3
|
||||
opt5 = optimize(expr5, optims_c99)
|
||||
delta5 = opt5 - 3*expm1(2*x)
|
||||
assert delta5 == 0
|
||||
assert opt5.rewrite(exp) == expr5
|
||||
|
||||
expr6 = exp(2*x) - 3
|
||||
opt6 = optimize(expr6, optims_c99)
|
||||
assert opt6 in (expm1(2*x) - 2, expr6) # expm1(2*x) - 2 is not better or worse
|
||||
|
||||
expr7 = log(3*x + 3)
|
||||
opt7 = optimize(expr7, optims_c99)
|
||||
delta7 = opt7 - (log(3) + log1p(x))
|
||||
assert delta7 == 0
|
||||
assert (opt7.rewrite(log) - expr7).simplify() == 0
|
||||
|
||||
expr8 = log(2*x + 3)
|
||||
opt8 = optimize(expr8, optims_c99)
|
||||
assert opt8 == expr8
|
||||
|
||||
|
||||
def test_create_expand_pow_optimization():
|
||||
cc = lambda x: ccode(
|
||||
optimize(x, [create_expand_pow_optimization(4)]))
|
||||
x = Symbol('x')
|
||||
assert cc(x**4) == 'x*x*x*x'
|
||||
assert cc(x**4 + x**2) == 'x*x + x*x*x*x'
|
||||
assert cc(x**5 + x**4) == 'pow(x, 5) + x*x*x*x'
|
||||
assert cc(sin(x)**4) == 'pow(sin(x), 4)'
|
||||
# gh issue 15335
|
||||
assert cc(x**(-4)) == '1.0/(x*x*x*x)'
|
||||
assert cc(x**(-5)) == 'pow(x, -5)'
|
||||
assert cc(-x**4) == '-(x*x*x*x)'
|
||||
assert cc(x**4 - x**2) == '-(x*x) + x*x*x*x'
|
||||
i = Symbol('i', integer=True)
|
||||
assert cc(x**i - x**2) == 'pow(x, i) - (x*x)'
|
||||
y = Symbol('y', real=True)
|
||||
assert cc(Abs(exp(y**4))) == "exp(y*y*y*y)"
|
||||
|
||||
# gh issue 20753
|
||||
cc2 = lambda x: ccode(optimize(x, [create_expand_pow_optimization(
|
||||
4, base_req=lambda b: b.is_Function)]))
|
||||
assert cc2(x**3 + sin(x)**3) == "pow(x, 3) + sin(x)*sin(x)*sin(x)"
|
||||
|
||||
|
||||
def test_matsolve():
|
||||
n = Symbol('n', integer=True)
|
||||
A = MatrixSymbol('A', n, n)
|
||||
x = MatrixSymbol('x', n, 1)
|
||||
|
||||
with assuming(Q.fullrank(A)):
|
||||
assert optimize(A**(-1) * x, [matinv_opt]) == MatrixSolve(A, x)
|
||||
assert optimize(A**(-1) * x + x, [matinv_opt]) == MatrixSolve(A, x) + x
|
||||
|
||||
|
||||
def test_logaddexp_opt():
|
||||
x, y = map(Symbol, 'x y'.split())
|
||||
expr1 = log(exp(x) + exp(y))
|
||||
opt1 = optimize(expr1, [logaddexp_opt])
|
||||
assert logaddexp(x, y) - opt1 == 0
|
||||
assert logaddexp(y, x) - opt1 == 0
|
||||
assert opt1.rewrite(log) == expr1
|
||||
|
||||
|
||||
def test_logaddexp2_opt():
|
||||
x, y = map(Symbol, 'x y'.split())
|
||||
expr1 = log(2**x + 2**y)/log(2)
|
||||
opt1 = optimize(expr1, [logaddexp2_opt])
|
||||
assert logaddexp2(x, y) - opt1 == 0
|
||||
assert logaddexp2(y, x) - opt1 == 0
|
||||
assert opt1.rewrite(log) == expr1
|
||||
|
||||
|
||||
def test_sinc_opts():
|
||||
def check(d):
|
||||
for k, v in d.items():
|
||||
assert optimize(k, sinc_opts) == v
|
||||
|
||||
x = Symbol('x')
|
||||
check({
|
||||
sin(x)/x : sinc(x),
|
||||
sin(2*x)/(2*x) : sinc(2*x),
|
||||
sin(3*x)/x : 3*sinc(3*x),
|
||||
x*sin(x) : x*sin(x)
|
||||
})
|
||||
|
||||
y = Symbol('y')
|
||||
check({
|
||||
sin(x*y)/(x*y) : sinc(x*y),
|
||||
y*sin(x/y)/x : sinc(x/y),
|
||||
sin(sin(x))/sin(x) : sinc(sin(x)),
|
||||
sin(3*sin(x))/sin(x) : 3*sinc(3*sin(x)),
|
||||
sin(x)/y : sin(x)/y
|
||||
})
|
||||
|
||||
|
||||
def test_optims_numpy():
|
||||
def check(d):
|
||||
for k, v in d.items():
|
||||
assert optimize(k, optims_numpy) == v
|
||||
|
||||
x = Symbol('x')
|
||||
check({
|
||||
sin(2*x)/(2*x) + exp(2*x) - 1: sinc(2*x) + expm1(2*x),
|
||||
log(x+3)/log(2) + log(x**2 + 1): log1p(x**2) + log2(x+3)
|
||||
})
|
||||
|
||||
|
||||
@XFAIL # room for improvement, ideally this test case should pass.
|
||||
def test_optims_numpy_TODO():
|
||||
def check(d):
|
||||
for k, v in d.items():
|
||||
assert optimize(k, optims_numpy) == v
|
||||
|
||||
x, y = map(Symbol, 'x y'.split())
|
||||
check({
|
||||
log(x*y)*sin(x*y)*log(x*y+1)/(log(2)*x*y): log2(x*y)*sinc(x*y)*log1p(x*y),
|
||||
exp(x*sin(y)/y) - 1: expm1(x*sinc(y))
|
||||
})
|
||||
|
||||
|
||||
@may_xfail
|
||||
def test_compiled_ccode_with_rewriting():
|
||||
if not cython:
|
||||
skip("cython not installed.")
|
||||
if not has_c():
|
||||
skip("No C compiler found.")
|
||||
|
||||
x = Symbol('x')
|
||||
about_two = 2**(58/S(117))*3**(97/S(117))*5**(4/S(39))*7**(92/S(117))/S(30)*pi
|
||||
# about_two: 1.999999999999581826
|
||||
unchanged = 2*exp(x) - about_two
|
||||
xval = S(10)**-11
|
||||
ref = unchanged.subs(x, xval).n(19) # 2.0418173913673213e-11
|
||||
|
||||
rewritten = optimize(2*exp(x) - about_two, [expm1_opt])
|
||||
|
||||
# Unfortunately, we need to call ``.n()`` on our expressions before we hand them
|
||||
# to ``ccode``, and we need to request a large number of significant digits.
|
||||
# In this test, results converged for double precision when the following number
|
||||
# of significant digits were chosen:
|
||||
NUMBER_OF_DIGITS = 25 # TODO: this should ideally be automatically handled.
|
||||
|
||||
func_c = '''
|
||||
#include <math.h>
|
||||
|
||||
double func_unchanged(double x) {
|
||||
return %(unchanged)s;
|
||||
}
|
||||
double func_rewritten(double x) {
|
||||
return %(rewritten)s;
|
||||
}
|
||||
''' % {"unchanged": ccode(unchanged.n(NUMBER_OF_DIGITS)),
|
||||
"rewritten": ccode(rewritten.n(NUMBER_OF_DIGITS))}
|
||||
|
||||
func_pyx = '''
|
||||
#cython: language_level=3
|
||||
cdef extern double func_unchanged(double)
|
||||
cdef extern double func_rewritten(double)
|
||||
def py_unchanged(x):
|
||||
return func_unchanged(x)
|
||||
def py_rewritten(x):
|
||||
return func_rewritten(x)
|
||||
'''
|
||||
with tempfile.TemporaryDirectory() as folder:
|
||||
mod, info = compile_link_import_strings(
|
||||
[('func.c', func_c), ('_func.pyx', func_pyx)],
|
||||
build_dir=folder, compile_kwargs={"std": 'c99'}
|
||||
)
|
||||
err_rewritten = abs(mod.py_rewritten(1e-11) - ref)
|
||||
err_unchanged = abs(mod.py_unchanged(1e-11) - ref)
|
||||
assert 1e-27 < err_rewritten < 1e-25 # highly accurate.
|
||||
assert 1e-19 < err_unchanged < 1e-16 # quite poor.
|
||||
|
||||
# Tolerances used above were determined as follows:
|
||||
# >>> no_opt = unchanged.subs(x, xval.evalf()).evalf()
|
||||
# >>> with_opt = rewritten.n(25).subs(x, 1e-11).evalf()
|
||||
# >>> with_opt - ref, no_opt - ref
|
||||
# (1.1536301877952077e-26, 1.6547074214222335e-18)
|
||||
@@ -0,0 +1,44 @@
|
||||
from itertools import product
|
||||
from sympy.core.power import Pow
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.functions.elementary.exponential import exp, log
|
||||
from sympy.functions.elementary.trigonometric import cos
|
||||
from sympy.core.numbers import pi
|
||||
from sympy.codegen.scipy_nodes import cosm1, powm1
|
||||
|
||||
x, y, z = symbols('x y z')
|
||||
|
||||
|
||||
def test_cosm1():
|
||||
cm1_xy = cosm1(x*y)
|
||||
ref_xy = cos(x*y) - 1
|
||||
for wrt, deriv_order in product([x, y, z], range(3)):
|
||||
assert (
|
||||
cm1_xy.diff(wrt, deriv_order) -
|
||||
ref_xy.diff(wrt, deriv_order)
|
||||
).rewrite(cos).simplify() == 0
|
||||
|
||||
expr_minus2 = cosm1(pi)
|
||||
assert expr_minus2.rewrite(cos) == -2
|
||||
assert cosm1(3.14).simplify() == cosm1(3.14) # cannot simplify with 3.14
|
||||
assert cosm1(pi/2).simplify() == -1
|
||||
assert (1/cos(x) - 1 + cosm1(x)/cos(x)).simplify() == 0
|
||||
|
||||
|
||||
def test_powm1():
|
||||
cases = {
|
||||
powm1(x, y): x**y - 1,
|
||||
powm1(x*y, z): (x*y)**z - 1,
|
||||
powm1(x, y*z): x**(y*z)-1,
|
||||
powm1(x*y*z, x*y*z): (x*y*z)**(x*y*z)-1
|
||||
}
|
||||
for pm1_e, ref_e in cases.items():
|
||||
for wrt, deriv_order in product([x, y, z], range(3)):
|
||||
der = pm1_e.diff(wrt, deriv_order)
|
||||
ref = ref_e.diff(wrt, deriv_order)
|
||||
delta = (der - ref).rewrite(Pow)
|
||||
assert delta.simplify() == 0
|
||||
|
||||
eulers_constant_m1 = powm1(x, 1/log(x))
|
||||
assert eulers_constant_m1.rewrite(Pow) == exp(1) - 1
|
||||
assert eulers_constant_m1.simplify() == exp(1) - 1
|
||||
Reference in New Issue
Block a user