add read me
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,14 @@
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.codegen.abstract_nodes import List
|
||||
|
||||
|
||||
def test_List():
|
||||
l = List(2, 3, 4)
|
||||
assert l == List(2, 3, 4)
|
||||
assert str(l) == "[2, 3, 4]"
|
||||
x, y, z = symbols('x y z')
|
||||
l = List(x**2,y**3,z**4)
|
||||
# contrary to python's built-in list, we can call e.g. "replace" on List.
|
||||
m = l.replace(lambda arg: arg.is_Pow and arg.exp>2, lambda p: p.base-p.exp)
|
||||
assert m == [x**2, y-3, z-4]
|
||||
hash(m)
|
||||
@@ -0,0 +1,180 @@
|
||||
import tempfile
|
||||
from sympy import log, Min, Max, sqrt
|
||||
from sympy.core.numbers import Float
|
||||
from sympy.core.symbol import Symbol, symbols
|
||||
from sympy.functions.elementary.trigonometric import cos
|
||||
from sympy.codegen.ast import Assignment, Raise, RuntimeError_, QuotedString
|
||||
from sympy.codegen.algorithms import newtons_method, newtons_method_function
|
||||
from sympy.codegen.cfunctions import expm1
|
||||
from sympy.codegen.fnodes import bind_C
|
||||
from sympy.codegen.futils import render_as_module as f_module
|
||||
from sympy.codegen.pyutils import render_as_module as py_module
|
||||
from sympy.external import import_module
|
||||
from sympy.printing.codeprinter import ccode
|
||||
from sympy.utilities._compilation import compile_link_import_strings, has_c, has_fortran
|
||||
from sympy.utilities._compilation.util import may_xfail
|
||||
from sympy.testing.pytest import skip, raises, skip_under_pyodide
|
||||
|
||||
cython = import_module('cython')
|
||||
wurlitzer = import_module('wurlitzer')
|
||||
|
||||
def test_newtons_method():
|
||||
x, dx, atol = symbols('x dx atol')
|
||||
expr = cos(x) - x**3
|
||||
algo = newtons_method(expr, x, atol, dx)
|
||||
assert algo.has(Assignment(dx, -expr/expr.diff(x)))
|
||||
|
||||
|
||||
@may_xfail
|
||||
def test_newtons_method_function__ccode():
|
||||
x = Symbol('x', real=True)
|
||||
expr = cos(x) - x**3
|
||||
func = newtons_method_function(expr, x)
|
||||
|
||||
if not cython:
|
||||
skip("cython not installed.")
|
||||
if not has_c():
|
||||
skip("No C compiler found.")
|
||||
|
||||
compile_kw = {"std": 'c99'}
|
||||
with tempfile.TemporaryDirectory() as folder:
|
||||
mod, info = compile_link_import_strings([
|
||||
('newton.c', ('#include <math.h>\n'
|
||||
'#include <stdio.h>\n') + ccode(func)),
|
||||
('_newton.pyx', ("#cython: language_level={}\n".format("3") +
|
||||
"cdef extern double newton(double)\n"
|
||||
"def py_newton(x):\n"
|
||||
" return newton(x)\n"))
|
||||
], build_dir=folder, compile_kwargs=compile_kw)
|
||||
assert abs(mod.py_newton(0.5) - 0.865474033102) < 1e-12
|
||||
|
||||
|
||||
@may_xfail
|
||||
def test_newtons_method_function__fcode():
|
||||
x = Symbol('x', real=True)
|
||||
expr = cos(x) - x**3
|
||||
func = newtons_method_function(expr, x, attrs=[bind_C(name='newton')])
|
||||
|
||||
if not cython:
|
||||
skip("cython not installed.")
|
||||
if not has_fortran():
|
||||
skip("No Fortran compiler found.")
|
||||
|
||||
f_mod = f_module([func], 'mod_newton')
|
||||
with tempfile.TemporaryDirectory() as folder:
|
||||
mod, info = compile_link_import_strings([
|
||||
('newton.f90', f_mod),
|
||||
('_newton.pyx', ("#cython: language_level={}\n".format("3") +
|
||||
"cdef extern double newton(double*)\n"
|
||||
"def py_newton(double x):\n"
|
||||
" return newton(&x)\n"))
|
||||
], build_dir=folder)
|
||||
assert abs(mod.py_newton(0.5) - 0.865474033102) < 1e-12
|
||||
|
||||
|
||||
def test_newtons_method_function__pycode():
|
||||
x = Symbol('x', real=True)
|
||||
expr = cos(x) - x**3
|
||||
func = newtons_method_function(expr, x)
|
||||
py_mod = py_module(func)
|
||||
namespace = {}
|
||||
exec(py_mod, namespace, namespace)
|
||||
res = eval('newton(0.5)', namespace)
|
||||
assert abs(res - 0.865474033102) < 1e-12
|
||||
|
||||
|
||||
@may_xfail
|
||||
@skip_under_pyodide("Emscripten does not support process spawning")
|
||||
def test_newtons_method_function__ccode_parameters():
|
||||
args = x, A, k, p = symbols('x A k p')
|
||||
expr = A*cos(k*x) - p*x**3
|
||||
raises(ValueError, lambda: newtons_method_function(expr, x))
|
||||
use_wurlitzer = wurlitzer
|
||||
|
||||
func = newtons_method_function(expr, x, args, debug=use_wurlitzer)
|
||||
|
||||
if not has_c():
|
||||
skip("No C compiler found.")
|
||||
if not cython:
|
||||
skip("cython not installed.")
|
||||
|
||||
compile_kw = {"std": 'c99'}
|
||||
with tempfile.TemporaryDirectory() as folder:
|
||||
mod, info = compile_link_import_strings([
|
||||
('newton_par.c', ('#include <math.h>\n'
|
||||
'#include <stdio.h>\n') + ccode(func)),
|
||||
('_newton_par.pyx', ("#cython: language_level={}\n".format("3") +
|
||||
"cdef extern double newton(double, double, double, double)\n"
|
||||
"def py_newton(x, A=1, k=1, p=1):\n"
|
||||
" return newton(x, A, k, p)\n"))
|
||||
], compile_kwargs=compile_kw, build_dir=folder)
|
||||
|
||||
if use_wurlitzer:
|
||||
with wurlitzer.pipes() as (out, err):
|
||||
result = mod.py_newton(0.5)
|
||||
else:
|
||||
result = mod.py_newton(0.5)
|
||||
|
||||
assert abs(result - 0.865474033102) < 1e-12
|
||||
|
||||
if not use_wurlitzer:
|
||||
skip("C-level output only tested when package 'wurlitzer' is available.")
|
||||
|
||||
out, err = out.read(), err.read()
|
||||
assert err == ''
|
||||
assert out == """\
|
||||
x= 0.5
|
||||
x= 1.1121 d_x= 0.61214
|
||||
x= 0.90967 d_x= -0.20247
|
||||
x= 0.86726 d_x= -0.042409
|
||||
x= 0.86548 d_x= -0.0017867
|
||||
x= 0.86547 d_x= -3.1022e-06
|
||||
x= 0.86547 d_x= -9.3421e-12
|
||||
x= 0.86547 d_x= 3.6902e-17
|
||||
""" # try to run tests with LC_ALL=C if this assertion fails
|
||||
|
||||
|
||||
def test_newtons_method_function__rtol_cse_nan():
|
||||
a, b, c, N_geo, N_tot = symbols('a b c N_geo N_tot', real=True, nonnegative=True)
|
||||
i = Symbol('i', integer=True, nonnegative=True)
|
||||
N_ari = N_tot - N_geo - 1
|
||||
delta_ari = (c-b)/N_ari
|
||||
ln_delta_geo = log(b) + log(-expm1((log(a)-log(b))/N_geo))
|
||||
eqb_log = ln_delta_geo - log(delta_ari)
|
||||
|
||||
def _clamp(low, expr, high):
|
||||
return Min(Max(low, expr), high)
|
||||
|
||||
meth_kw = {
|
||||
'clamped_newton': {'delta_fn': lambda e, x: _clamp(
|
||||
(sqrt(a*x)-x)*0.99,
|
||||
-e/e.diff(x),
|
||||
(sqrt(c*x)-x)*0.99
|
||||
)},
|
||||
'halley': {'delta_fn': lambda e, x: (-2*(e*e.diff(x))/(2*e.diff(x)**2 - e*e.diff(x, 2)))},
|
||||
'halley_alt': {'delta_fn': lambda e, x: (-e/e.diff(x)/(1-e/e.diff(x)*e.diff(x,2)/2/e.diff(x)))},
|
||||
}
|
||||
args = eqb_log, b
|
||||
for use_cse in [False, True]:
|
||||
kwargs = {
|
||||
'params': (b, a, c, N_geo, N_tot), 'itermax': 60, 'debug': True, 'cse': use_cse,
|
||||
'counter': i, 'atol': 1e-100, 'rtol': 2e-16, 'bounds': (a,c),
|
||||
'handle_nan': Raise(RuntimeError_(QuotedString("encountered NaN.")))
|
||||
}
|
||||
func = {k: newtons_method_function(*args, func_name=f"{k}_b", **dict(kwargs, **kw)) for k, kw in meth_kw.items()}
|
||||
py_mod = {k: py_module(v) for k, v in func.items()}
|
||||
namespace = {}
|
||||
root_find_b = {}
|
||||
for k, v in py_mod.items():
|
||||
ns = namespace[k] = {}
|
||||
exec(v, ns, ns)
|
||||
root_find_b[k] = ns[f'{k}_b']
|
||||
ref = Float('13.2261515064168768938151923226496')
|
||||
reftol = {'clamped_newton': 2e-16, 'halley': 2e-16, 'halley_alt': 3e-16}
|
||||
guess = 4.0
|
||||
for meth, func in root_find_b.items():
|
||||
result = func(guess, 1e-2, 1e2, 50, 100)
|
||||
req = ref*reftol[meth]
|
||||
if use_cse:
|
||||
req *= 2
|
||||
assert abs(result - ref) < req
|
||||
@@ -0,0 +1,58 @@
|
||||
# This file contains tests that exercise multiple AST nodes
|
||||
|
||||
import tempfile
|
||||
|
||||
from sympy.external import import_module
|
||||
from sympy.printing.codeprinter import ccode
|
||||
from sympy.utilities._compilation import compile_link_import_strings, has_c
|
||||
from sympy.utilities._compilation.util import may_xfail
|
||||
from sympy.testing.pytest import skip, skip_under_pyodide
|
||||
from sympy.codegen.ast import (
|
||||
FunctionDefinition, FunctionPrototype, Variable, Pointer, real, Assignment,
|
||||
integer, CodeBlock, While
|
||||
)
|
||||
from sympy.codegen.cnodes import void, PreIncrement
|
||||
from sympy.codegen.cutils import render_as_source_file
|
||||
|
||||
cython = import_module('cython')
|
||||
np = import_module('numpy')
|
||||
|
||||
def _mk_func1():
|
||||
declars = n, inp, out = Variable('n', integer), Pointer('inp', real), Pointer('out', real)
|
||||
i = Variable('i', integer)
|
||||
whl = While(i<n, [Assignment(out[i], inp[i]), PreIncrement(i)])
|
||||
body = CodeBlock(i.as_Declaration(value=0), whl)
|
||||
return FunctionDefinition(void, 'our_test_function', declars, body)
|
||||
|
||||
|
||||
def _render_compile_import(funcdef, build_dir):
|
||||
code_str = render_as_source_file(funcdef, settings={"contract": False})
|
||||
declar = ccode(FunctionPrototype.from_FunctionDefinition(funcdef))
|
||||
return compile_link_import_strings([
|
||||
('our_test_func.c', code_str),
|
||||
('_our_test_func.pyx', ("#cython: language_level={}\n".format("3") +
|
||||
"cdef extern {declar}\n"
|
||||
"def _{fname}({typ}[:] inp, {typ}[:] out):\n"
|
||||
" {fname}(inp.size, &inp[0], &out[0])").format(
|
||||
declar=declar, fname=funcdef.name, typ='double'
|
||||
))
|
||||
], build_dir=build_dir)
|
||||
|
||||
|
||||
@may_xfail
|
||||
@skip_under_pyodide("Emscripten does not support process spawning")
|
||||
def test_copying_function():
|
||||
if not np:
|
||||
skip("numpy not installed.")
|
||||
if not has_c():
|
||||
skip("No C compiler found.")
|
||||
if not cython:
|
||||
skip("Cython not found.")
|
||||
|
||||
info = None
|
||||
with tempfile.TemporaryDirectory() as folder:
|
||||
mod, info = _render_compile_import(_mk_func1(), build_dir=folder)
|
||||
inp = np.arange(10.0)
|
||||
out = np.empty_like(inp)
|
||||
mod._our_test_function(inp, out)
|
||||
assert np.allclose(inp, out)
|
||||
@@ -0,0 +1,53 @@
|
||||
import math
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.functions.elementary.exponential import exp
|
||||
from sympy.codegen.rewriting import optimize
|
||||
from sympy.codegen.approximations import SumApprox, SeriesApprox
|
||||
|
||||
|
||||
def test_SumApprox_trivial():
|
||||
x = symbols('x')
|
||||
expr1 = 1 + x
|
||||
sum_approx = SumApprox(bounds={x: (-1e-20, 1e-20)}, reltol=1e-16)
|
||||
apx1 = optimize(expr1, [sum_approx])
|
||||
assert apx1 - 1 == 0
|
||||
|
||||
|
||||
def test_SumApprox_monotone_terms():
|
||||
x, y, z = symbols('x y z')
|
||||
expr1 = exp(z)*(x**2 + y**2 + 1)
|
||||
bnds1 = {x: (0, 1e-3), y: (100, 1000)}
|
||||
sum_approx_m2 = SumApprox(bounds=bnds1, reltol=1e-2)
|
||||
sum_approx_m5 = SumApprox(bounds=bnds1, reltol=1e-5)
|
||||
sum_approx_m11 = SumApprox(bounds=bnds1, reltol=1e-11)
|
||||
assert (optimize(expr1, [sum_approx_m2])/exp(z) - (y**2)).simplify() == 0
|
||||
assert (optimize(expr1, [sum_approx_m5])/exp(z) - (y**2 + 1)).simplify() == 0
|
||||
assert (optimize(expr1, [sum_approx_m11])/exp(z) - (y**2 + 1 + x**2)).simplify() == 0
|
||||
|
||||
|
||||
def test_SeriesApprox_trivial():
|
||||
x, z = symbols('x z')
|
||||
for factor in [1, exp(z)]:
|
||||
x = symbols('x')
|
||||
expr1 = exp(x)*factor
|
||||
bnds1 = {x: (-1, 1)}
|
||||
series_approx_50 = SeriesApprox(bounds=bnds1, reltol=0.50)
|
||||
series_approx_10 = SeriesApprox(bounds=bnds1, reltol=0.10)
|
||||
series_approx_05 = SeriesApprox(bounds=bnds1, reltol=0.05)
|
||||
c = (bnds1[x][1] + bnds1[x][0])/2 # 0.0
|
||||
f0 = math.exp(c) # 1.0
|
||||
|
||||
ref_50 = f0 + x + x**2/2
|
||||
ref_10 = f0 + x + x**2/2 + x**3/6
|
||||
ref_05 = f0 + x + x**2/2 + x**3/6 + x**4/24
|
||||
|
||||
res_50 = optimize(expr1, [series_approx_50])
|
||||
res_10 = optimize(expr1, [series_approx_10])
|
||||
res_05 = optimize(expr1, [series_approx_05])
|
||||
|
||||
assert (res_50/factor - ref_50).simplify() == 0
|
||||
assert (res_10/factor - ref_10).simplify() == 0
|
||||
assert (res_05/factor - ref_05).simplify() == 0
|
||||
|
||||
max_ord3 = SeriesApprox(bounds=bnds1, reltol=0.05, max_order=3)
|
||||
assert optimize(expr1, [max_ord3]) == expr1
|
||||
@@ -0,0 +1,661 @@
|
||||
import math
|
||||
from sympy.core.containers import Tuple
|
||||
from sympy.core.numbers import nan, oo, Float, Integer
|
||||
from sympy.core.relational import Lt
|
||||
from sympy.core.symbol import symbols, Symbol
|
||||
from sympy.functions.elementary.trigonometric import sin
|
||||
from sympy.matrices.dense import Matrix
|
||||
from sympy.matrices.expressions.matexpr import MatrixSymbol
|
||||
from sympy.sets.fancysets import Range
|
||||
from sympy.tensor.indexed import Idx, IndexedBase
|
||||
from sympy.testing.pytest import raises
|
||||
|
||||
|
||||
from sympy.codegen.ast import (
|
||||
Assignment, Attribute, aug_assign, CodeBlock, For, Type, Variable, Pointer, Declaration,
|
||||
AddAugmentedAssignment, SubAugmentedAssignment, MulAugmentedAssignment,
|
||||
DivAugmentedAssignment, ModAugmentedAssignment, value_const, pointer_const,
|
||||
integer, real, complex_, int8, uint8, float16 as f16, float32 as f32,
|
||||
float64 as f64, float80 as f80, float128 as f128, complex64 as c64, complex128 as c128,
|
||||
While, Scope, String, Print, QuotedString, FunctionPrototype, FunctionDefinition, Return,
|
||||
FunctionCall, untyped, IntBaseType, intc, Node, none, NoneToken, Token, Comment
|
||||
)
|
||||
|
||||
x, y, z, t, x0, x1, x2, a, b = symbols("x, y, z, t, x0, x1, x2, a, b")
|
||||
n = symbols("n", integer=True)
|
||||
A = MatrixSymbol('A', 3, 1)
|
||||
mat = Matrix([1, 2, 3])
|
||||
B = IndexedBase('B')
|
||||
i = Idx("i", n)
|
||||
A22 = MatrixSymbol('A22',2,2)
|
||||
B22 = MatrixSymbol('B22',2,2)
|
||||
|
||||
|
||||
def test_Assignment():
|
||||
# Here we just do things to show they don't error
|
||||
Assignment(x, y)
|
||||
Assignment(x, 0)
|
||||
Assignment(A, mat)
|
||||
Assignment(A[1,0], 0)
|
||||
Assignment(A[1,0], x)
|
||||
Assignment(B[i], x)
|
||||
Assignment(B[i], 0)
|
||||
a = Assignment(x, y)
|
||||
assert a.func(*a.args) == a
|
||||
assert a.op == ':='
|
||||
# Here we test things to show that they error
|
||||
# Matrix to scalar
|
||||
raises(ValueError, lambda: Assignment(B[i], A))
|
||||
raises(ValueError, lambda: Assignment(B[i], mat))
|
||||
raises(ValueError, lambda: Assignment(x, mat))
|
||||
raises(ValueError, lambda: Assignment(x, A))
|
||||
raises(ValueError, lambda: Assignment(A[1,0], mat))
|
||||
# Scalar to matrix
|
||||
raises(ValueError, lambda: Assignment(A, x))
|
||||
raises(ValueError, lambda: Assignment(A, 0))
|
||||
# Non-atomic lhs
|
||||
raises(TypeError, lambda: Assignment(mat, A))
|
||||
raises(TypeError, lambda: Assignment(0, x))
|
||||
raises(TypeError, lambda: Assignment(x*x, 1))
|
||||
raises(TypeError, lambda: Assignment(A + A, mat))
|
||||
raises(TypeError, lambda: Assignment(B, 0))
|
||||
|
||||
|
||||
def test_AugAssign():
|
||||
# Here we just do things to show they don't error
|
||||
aug_assign(x, '+', y)
|
||||
aug_assign(x, '+', 0)
|
||||
aug_assign(A, '+', mat)
|
||||
aug_assign(A[1, 0], '+', 0)
|
||||
aug_assign(A[1, 0], '+', x)
|
||||
aug_assign(B[i], '+', x)
|
||||
aug_assign(B[i], '+', 0)
|
||||
|
||||
# Check creation via aug_assign vs constructor
|
||||
for binop, cls in [
|
||||
('+', AddAugmentedAssignment),
|
||||
('-', SubAugmentedAssignment),
|
||||
('*', MulAugmentedAssignment),
|
||||
('/', DivAugmentedAssignment),
|
||||
('%', ModAugmentedAssignment),
|
||||
]:
|
||||
a = aug_assign(x, binop, y)
|
||||
b = cls(x, y)
|
||||
assert a.func(*a.args) == a == b
|
||||
assert a.binop == binop
|
||||
assert a.op == binop + '='
|
||||
|
||||
# Here we test things to show that they error
|
||||
# Matrix to scalar
|
||||
raises(ValueError, lambda: aug_assign(B[i], '+', A))
|
||||
raises(ValueError, lambda: aug_assign(B[i], '+', mat))
|
||||
raises(ValueError, lambda: aug_assign(x, '+', mat))
|
||||
raises(ValueError, lambda: aug_assign(x, '+', A))
|
||||
raises(ValueError, lambda: aug_assign(A[1, 0], '+', mat))
|
||||
# Scalar to matrix
|
||||
raises(ValueError, lambda: aug_assign(A, '+', x))
|
||||
raises(ValueError, lambda: aug_assign(A, '+', 0))
|
||||
# Non-atomic lhs
|
||||
raises(TypeError, lambda: aug_assign(mat, '+', A))
|
||||
raises(TypeError, lambda: aug_assign(0, '+', x))
|
||||
raises(TypeError, lambda: aug_assign(x * x, '+', 1))
|
||||
raises(TypeError, lambda: aug_assign(A + A, '+', mat))
|
||||
raises(TypeError, lambda: aug_assign(B, '+', 0))
|
||||
|
||||
|
||||
def test_Assignment_printing():
|
||||
assignment_classes = [
|
||||
Assignment,
|
||||
AddAugmentedAssignment,
|
||||
SubAugmentedAssignment,
|
||||
MulAugmentedAssignment,
|
||||
DivAugmentedAssignment,
|
||||
ModAugmentedAssignment,
|
||||
]
|
||||
pairs = [
|
||||
(x, 2 * y + 2),
|
||||
(B[i], x),
|
||||
(A22, B22),
|
||||
(A[0, 0], x),
|
||||
]
|
||||
|
||||
for cls in assignment_classes:
|
||||
for lhs, rhs in pairs:
|
||||
a = cls(lhs, rhs)
|
||||
assert repr(a) == '%s(%s, %s)' % (cls.__name__, repr(lhs), repr(rhs))
|
||||
|
||||
|
||||
def test_CodeBlock():
|
||||
c = CodeBlock(Assignment(x, 1), Assignment(y, x + 1))
|
||||
assert c.func(*c.args) == c
|
||||
|
||||
assert c.left_hand_sides == Tuple(x, y)
|
||||
assert c.right_hand_sides == Tuple(1, x + 1)
|
||||
|
||||
def test_CodeBlock_topological_sort():
|
||||
assignments = [
|
||||
Assignment(x, y + z),
|
||||
Assignment(z, 1),
|
||||
Assignment(t, x),
|
||||
Assignment(y, 2),
|
||||
]
|
||||
|
||||
ordered_assignments = [
|
||||
# Note that the unrelated z=1 and y=2 are kept in that order
|
||||
Assignment(z, 1),
|
||||
Assignment(y, 2),
|
||||
Assignment(x, y + z),
|
||||
Assignment(t, x),
|
||||
]
|
||||
c1 = CodeBlock.topological_sort(assignments)
|
||||
assert c1 == CodeBlock(*ordered_assignments)
|
||||
|
||||
# Cycle
|
||||
invalid_assignments = [
|
||||
Assignment(x, y + z),
|
||||
Assignment(z, 1),
|
||||
Assignment(y, x),
|
||||
Assignment(y, 2),
|
||||
]
|
||||
|
||||
raises(ValueError, lambda: CodeBlock.topological_sort(invalid_assignments))
|
||||
|
||||
# Free symbols
|
||||
free_assignments = [
|
||||
Assignment(x, y + z),
|
||||
Assignment(z, a * b),
|
||||
Assignment(t, x),
|
||||
Assignment(y, b + 3),
|
||||
]
|
||||
|
||||
free_assignments_ordered = [
|
||||
Assignment(z, a * b),
|
||||
Assignment(y, b + 3),
|
||||
Assignment(x, y + z),
|
||||
Assignment(t, x),
|
||||
]
|
||||
|
||||
c2 = CodeBlock.topological_sort(free_assignments)
|
||||
assert c2 == CodeBlock(*free_assignments_ordered)
|
||||
|
||||
def test_CodeBlock_free_symbols():
|
||||
c1 = CodeBlock(
|
||||
Assignment(x, y + z),
|
||||
Assignment(z, 1),
|
||||
Assignment(t, x),
|
||||
Assignment(y, 2),
|
||||
)
|
||||
assert c1.free_symbols == set()
|
||||
|
||||
c2 = CodeBlock(
|
||||
Assignment(x, y + z),
|
||||
Assignment(z, a * b),
|
||||
Assignment(t, x),
|
||||
Assignment(y, b + 3),
|
||||
)
|
||||
assert c2.free_symbols == {a, b}
|
||||
|
||||
def test_CodeBlock_cse():
|
||||
c1 = CodeBlock(
|
||||
Assignment(y, 1),
|
||||
Assignment(x, sin(y)),
|
||||
Assignment(z, sin(y)),
|
||||
Assignment(t, x*z),
|
||||
)
|
||||
assert c1.cse() == CodeBlock(
|
||||
Assignment(y, 1),
|
||||
Assignment(x0, sin(y)),
|
||||
Assignment(x, x0),
|
||||
Assignment(z, x0),
|
||||
Assignment(t, x*z),
|
||||
)
|
||||
|
||||
# Multiple assignments to same symbol not supported
|
||||
raises(NotImplementedError, lambda: CodeBlock(
|
||||
Assignment(x, 1),
|
||||
Assignment(y, 1), Assignment(y, 2)
|
||||
).cse())
|
||||
|
||||
# Check auto-generated symbols do not collide with existing ones
|
||||
c2 = CodeBlock(
|
||||
Assignment(x0, sin(y) + 1),
|
||||
Assignment(x1, 2 * sin(y)),
|
||||
Assignment(z, x * y),
|
||||
)
|
||||
assert c2.cse() == CodeBlock(
|
||||
Assignment(x2, sin(y)),
|
||||
Assignment(x0, x2 + 1),
|
||||
Assignment(x1, 2 * x2),
|
||||
Assignment(z, x * y),
|
||||
)
|
||||
|
||||
|
||||
def test_CodeBlock_cse__issue_14118():
|
||||
# see https://github.com/sympy/sympy/issues/14118
|
||||
c = CodeBlock(
|
||||
Assignment(A22, Matrix([[x, sin(y)],[3, 4]])),
|
||||
Assignment(B22, Matrix([[sin(y), 2*sin(y)], [sin(y)**2, 7]]))
|
||||
)
|
||||
assert c.cse() == CodeBlock(
|
||||
Assignment(x0, sin(y)),
|
||||
Assignment(A22, Matrix([[x, x0],[3, 4]])),
|
||||
Assignment(B22, Matrix([[x0, 2*x0], [x0**2, 7]]))
|
||||
)
|
||||
|
||||
def test_For():
|
||||
f = For(n, Range(0, 3), (Assignment(A[n, 0], x + n), aug_assign(x, '+', y)))
|
||||
f = For(n, (1, 2, 3, 4, 5), (Assignment(A[n, 0], x + n),))
|
||||
assert f.func(*f.args) == f
|
||||
raises(TypeError, lambda: For(n, x, (x + y,)))
|
||||
|
||||
|
||||
def test_none():
|
||||
assert none.is_Atom
|
||||
assert none == none
|
||||
class Foo(Token):
|
||||
pass
|
||||
foo = Foo()
|
||||
assert foo != none
|
||||
assert none == None
|
||||
assert none == NoneToken()
|
||||
assert none.func(*none.args) == none
|
||||
|
||||
|
||||
def test_String():
|
||||
st = String('foobar')
|
||||
assert st.is_Atom
|
||||
assert st == String('foobar')
|
||||
assert st.text == 'foobar'
|
||||
assert st.func(**st.kwargs()) == st
|
||||
assert st.func(*st.args) == st
|
||||
|
||||
|
||||
class Signifier(String):
|
||||
pass
|
||||
|
||||
si = Signifier('foobar')
|
||||
assert si != st
|
||||
assert si.text == st.text
|
||||
s = String('foo')
|
||||
assert str(s) == 'foo'
|
||||
assert repr(s) == "String('foo')"
|
||||
|
||||
def test_Comment():
|
||||
c = Comment('foobar')
|
||||
assert c.text == 'foobar'
|
||||
assert str(c) == 'foobar'
|
||||
|
||||
def test_Node():
|
||||
n = Node()
|
||||
assert n == Node()
|
||||
assert n.func(*n.args) == n
|
||||
|
||||
|
||||
def test_Type():
|
||||
t = Type('MyType')
|
||||
assert len(t.args) == 1
|
||||
assert t.name == String('MyType')
|
||||
assert str(t) == 'MyType'
|
||||
assert repr(t) == "Type(String('MyType'))"
|
||||
assert Type(t) == t
|
||||
assert t.func(*t.args) == t
|
||||
t1 = Type('t1')
|
||||
t2 = Type('t2')
|
||||
assert t1 != t2
|
||||
assert t1 == t1 and t2 == t2
|
||||
t1b = Type('t1')
|
||||
assert t1 == t1b
|
||||
assert t2 != t1b
|
||||
|
||||
|
||||
def test_Type__from_expr():
|
||||
assert Type.from_expr(i) == integer
|
||||
u = symbols('u', real=True)
|
||||
assert Type.from_expr(u) == real
|
||||
assert Type.from_expr(n) == integer
|
||||
assert Type.from_expr(3) == integer
|
||||
assert Type.from_expr(3.0) == real
|
||||
assert Type.from_expr(3+1j) == complex_
|
||||
raises(ValueError, lambda: Type.from_expr(sum))
|
||||
|
||||
|
||||
def test_Type__cast_check__integers():
|
||||
# Rounding
|
||||
raises(ValueError, lambda: integer.cast_check(3.5))
|
||||
assert integer.cast_check('3') == 3
|
||||
assert integer.cast_check(Float('3.0000000000000000000')) == 3
|
||||
assert integer.cast_check(Float('3.0000000000000000001')) == 3 # unintuitive maybe?
|
||||
|
||||
# Range
|
||||
assert int8.cast_check(127.0) == 127
|
||||
raises(ValueError, lambda: int8.cast_check(128))
|
||||
assert int8.cast_check(-128) == -128
|
||||
raises(ValueError, lambda: int8.cast_check(-129))
|
||||
|
||||
assert uint8.cast_check(0) == 0
|
||||
assert uint8.cast_check(128) == 128
|
||||
raises(ValueError, lambda: uint8.cast_check(256.0))
|
||||
raises(ValueError, lambda: uint8.cast_check(-1))
|
||||
|
||||
def test_Attribute():
|
||||
noexcept = Attribute('noexcept')
|
||||
assert noexcept == Attribute('noexcept')
|
||||
alignas16 = Attribute('alignas', [16])
|
||||
alignas32 = Attribute('alignas', [32])
|
||||
assert alignas16 != alignas32
|
||||
assert alignas16.func(*alignas16.args) == alignas16
|
||||
|
||||
|
||||
def test_Variable():
|
||||
v = Variable(x, type=real)
|
||||
assert v == Variable(v)
|
||||
assert v == Variable('x', type=real)
|
||||
assert v.symbol == x
|
||||
assert v.type == real
|
||||
assert value_const not in v.attrs
|
||||
assert v.func(*v.args) == v
|
||||
assert str(v) == 'Variable(x, type=real)'
|
||||
|
||||
w = Variable(y, f32, attrs={value_const})
|
||||
assert w.symbol == y
|
||||
assert w.type == f32
|
||||
assert value_const in w.attrs
|
||||
assert w.func(*w.args) == w
|
||||
|
||||
v_n = Variable(n, type=Type.from_expr(n))
|
||||
assert v_n.type == integer
|
||||
assert v_n.func(*v_n.args) == v_n
|
||||
v_i = Variable(i, type=Type.from_expr(n))
|
||||
assert v_i.type == integer
|
||||
assert v_i != v_n
|
||||
|
||||
a_i = Variable.deduced(i)
|
||||
assert a_i.type == integer
|
||||
assert Variable.deduced(Symbol('x', real=True)).type == real
|
||||
assert a_i.func(*a_i.args) == a_i
|
||||
|
||||
v_n2 = Variable.deduced(n, value=3.5, cast_check=False)
|
||||
assert v_n2.func(*v_n2.args) == v_n2
|
||||
assert abs(v_n2.value - 3.5) < 1e-15
|
||||
raises(ValueError, lambda: Variable.deduced(n, value=3.5, cast_check=True))
|
||||
|
||||
v_n3 = Variable.deduced(n)
|
||||
assert v_n3.type == integer
|
||||
assert str(v_n3) == 'Variable(n, type=integer)'
|
||||
assert Variable.deduced(z, value=3).type == integer
|
||||
assert Variable.deduced(z, value=3.0).type == real
|
||||
assert Variable.deduced(z, value=3.0+1j).type == complex_
|
||||
|
||||
|
||||
def test_Pointer():
|
||||
p = Pointer(x)
|
||||
assert p.symbol == x
|
||||
assert p.type == untyped
|
||||
assert value_const not in p.attrs
|
||||
assert pointer_const not in p.attrs
|
||||
assert p.func(*p.args) == p
|
||||
|
||||
u = symbols('u', real=True)
|
||||
pu = Pointer(u, type=Type.from_expr(u), attrs={value_const, pointer_const})
|
||||
assert pu.symbol is u
|
||||
assert pu.type == real
|
||||
assert value_const in pu.attrs
|
||||
assert pointer_const in pu.attrs
|
||||
assert pu.func(*pu.args) == pu
|
||||
|
||||
i = symbols('i', integer=True)
|
||||
deref = pu[i]
|
||||
assert deref.indices == (i,)
|
||||
|
||||
|
||||
def test_Declaration():
|
||||
u = symbols('u', real=True)
|
||||
vu = Variable(u, type=Type.from_expr(u))
|
||||
assert Declaration(vu).variable.type == real
|
||||
vn = Variable(n, type=Type.from_expr(n))
|
||||
assert Declaration(vn).variable.type == integer
|
||||
|
||||
# PR 19107, does not allow comparison between expressions and Basic
|
||||
# lt = StrictLessThan(vu, vn)
|
||||
# assert isinstance(lt, StrictLessThan)
|
||||
|
||||
vuc = Variable(u, Type.from_expr(u), value=3.0, attrs={value_const})
|
||||
assert value_const in vuc.attrs
|
||||
assert pointer_const not in vuc.attrs
|
||||
decl = Declaration(vuc)
|
||||
assert decl.variable == vuc
|
||||
assert isinstance(decl.variable.value, Float)
|
||||
assert decl.variable.value == 3.0
|
||||
assert decl.func(*decl.args) == decl
|
||||
assert vuc.as_Declaration() == decl
|
||||
assert vuc.as_Declaration(value=None, attrs=None) == Declaration(vu)
|
||||
|
||||
vy = Variable(y, type=integer, value=3)
|
||||
decl2 = Declaration(vy)
|
||||
assert decl2.variable == vy
|
||||
assert decl2.variable.value == Integer(3)
|
||||
|
||||
vi = Variable(i, type=Type.from_expr(i), value=3.0)
|
||||
decl3 = Declaration(vi)
|
||||
assert decl3.variable.type == integer
|
||||
assert decl3.variable.value == 3.0
|
||||
|
||||
raises(ValueError, lambda: Declaration(vi, 42))
|
||||
|
||||
|
||||
def test_IntBaseType():
|
||||
assert intc.name == String('intc')
|
||||
assert intc.args == (intc.name,)
|
||||
assert str(IntBaseType('a').name) == 'a'
|
||||
|
||||
|
||||
def test_FloatType():
|
||||
assert f16.dig == 3
|
||||
assert f32.dig == 6
|
||||
assert f64.dig == 15
|
||||
assert f80.dig == 18
|
||||
assert f128.dig == 33
|
||||
|
||||
assert f16.decimal_dig == 5
|
||||
assert f32.decimal_dig == 9
|
||||
assert f64.decimal_dig == 17
|
||||
assert f80.decimal_dig == 21
|
||||
assert f128.decimal_dig == 36
|
||||
|
||||
assert f16.max_exponent == 16
|
||||
assert f32.max_exponent == 128
|
||||
assert f64.max_exponent == 1024
|
||||
assert f80.max_exponent == 16384
|
||||
assert f128.max_exponent == 16384
|
||||
|
||||
assert f16.min_exponent == -13
|
||||
assert f32.min_exponent == -125
|
||||
assert f64.min_exponent == -1021
|
||||
assert f80.min_exponent == -16381
|
||||
assert f128.min_exponent == -16381
|
||||
|
||||
assert abs(f16.eps / Float('0.00097656', precision=16) - 1) < 0.1*10**-f16.dig
|
||||
assert abs(f32.eps / Float('1.1920929e-07', precision=32) - 1) < 0.1*10**-f32.dig
|
||||
assert abs(f64.eps / Float('2.2204460492503131e-16', precision=64) - 1) < 0.1*10**-f64.dig
|
||||
assert abs(f80.eps / Float('1.08420217248550443401e-19', precision=80) - 1) < 0.1*10**-f80.dig
|
||||
assert abs(f128.eps / Float(' 1.92592994438723585305597794258492732e-34', precision=128) - 1) < 0.1*10**-f128.dig
|
||||
|
||||
assert abs(f16.max / Float('65504', precision=16) - 1) < .1*10**-f16.dig
|
||||
assert abs(f32.max / Float('3.40282347e+38', precision=32) - 1) < 0.1*10**-f32.dig
|
||||
assert abs(f64.max / Float('1.79769313486231571e+308', precision=64) - 1) < 0.1*10**-f64.dig # cf. np.finfo(np.float64).max
|
||||
assert abs(f80.max / Float('1.18973149535723176502e+4932', precision=80) - 1) < 0.1*10**-f80.dig
|
||||
assert abs(f128.max / Float('1.18973149535723176508575932662800702e+4932', precision=128) - 1) < 0.1*10**-f128.dig
|
||||
|
||||
# cf. np.finfo(np.float32).tiny
|
||||
assert abs(f16.tiny / Float('6.1035e-05', precision=16) - 1) < 0.1*10**-f16.dig
|
||||
assert abs(f32.tiny / Float('1.17549435e-38', precision=32) - 1) < 0.1*10**-f32.dig
|
||||
assert abs(f64.tiny / Float('2.22507385850720138e-308', precision=64) - 1) < 0.1*10**-f64.dig
|
||||
assert abs(f80.tiny / Float('3.36210314311209350626e-4932', precision=80) - 1) < 0.1*10**-f80.dig
|
||||
assert abs(f128.tiny / Float('3.3621031431120935062626778173217526e-4932', precision=128) - 1) < 0.1*10**-f128.dig
|
||||
|
||||
assert f64.cast_check(0.5) == Float(0.5, 17)
|
||||
assert abs(f64.cast_check(3.7) - 3.7) < 3e-17
|
||||
assert isinstance(f64.cast_check(3), (Float, float))
|
||||
|
||||
assert f64.cast_nocheck(oo) == float('inf')
|
||||
assert f64.cast_nocheck(-oo) == float('-inf')
|
||||
assert f64.cast_nocheck(float(oo)) == float('inf')
|
||||
assert f64.cast_nocheck(float(-oo)) == float('-inf')
|
||||
assert math.isnan(f64.cast_nocheck(nan))
|
||||
|
||||
assert f32 != f64
|
||||
assert f64 == f64.func(*f64.args)
|
||||
|
||||
|
||||
def test_Type__cast_check__floating_point():
|
||||
raises(ValueError, lambda: f32.cast_check(123.45678949))
|
||||
raises(ValueError, lambda: f32.cast_check(12.345678949))
|
||||
raises(ValueError, lambda: f32.cast_check(1.2345678949))
|
||||
raises(ValueError, lambda: f32.cast_check(.12345678949))
|
||||
assert abs(123.456789049 - f32.cast_check(123.456789049) - 4.9e-8) < 1e-8
|
||||
assert abs(0.12345678904 - f32.cast_check(0.12345678904) - 4e-11) < 1e-11
|
||||
|
||||
dcm21 = Float('0.123456789012345670499') # 21 decimals
|
||||
assert abs(dcm21 - f64.cast_check(dcm21) - 4.99e-19) < 1e-19
|
||||
|
||||
f80.cast_check(Float('0.12345678901234567890103', precision=88))
|
||||
raises(ValueError, lambda: f80.cast_check(Float('0.12345678901234567890149', precision=88)))
|
||||
|
||||
v10 = 12345.67894
|
||||
raises(ValueError, lambda: f32.cast_check(v10))
|
||||
assert abs(Float(str(v10), precision=64+8) - f64.cast_check(v10)) < v10*1e-16
|
||||
|
||||
assert abs(f32.cast_check(2147483647) - 2147483650) < 1
|
||||
|
||||
|
||||
def test_Type__cast_check__complex_floating_point():
|
||||
val9_11 = 123.456789049 + 0.123456789049j
|
||||
raises(ValueError, lambda: c64.cast_check(.12345678949 + .12345678949j))
|
||||
assert abs(val9_11 - c64.cast_check(val9_11) - 4.9e-8) < 1e-8
|
||||
|
||||
dcm21 = Float('0.123456789012345670499') + 1e-20j # 21 decimals
|
||||
assert abs(dcm21 - c128.cast_check(dcm21) - 4.99e-19) < 1e-19
|
||||
v19 = Float('0.1234567890123456749') + 1j*Float('0.1234567890123456749')
|
||||
raises(ValueError, lambda: c128.cast_check(v19))
|
||||
|
||||
|
||||
def test_While():
|
||||
xpp = AddAugmentedAssignment(x, 1)
|
||||
whl1 = While(x < 2, [xpp])
|
||||
assert whl1.condition.args[0] == x
|
||||
assert whl1.condition.args[1] == 2
|
||||
assert whl1.condition == Lt(x, 2, evaluate=False)
|
||||
assert whl1.body.args == (xpp,)
|
||||
assert whl1.func(*whl1.args) == whl1
|
||||
|
||||
cblk = CodeBlock(AddAugmentedAssignment(x, 1))
|
||||
whl2 = While(x < 2, cblk)
|
||||
assert whl1 == whl2
|
||||
assert whl1 != While(x < 3, [xpp])
|
||||
|
||||
|
||||
def test_Scope():
|
||||
assign = Assignment(x, y)
|
||||
incr = AddAugmentedAssignment(x, 1)
|
||||
scp = Scope([assign, incr])
|
||||
cblk = CodeBlock(assign, incr)
|
||||
assert scp.body == cblk
|
||||
assert scp == Scope(cblk)
|
||||
assert scp != Scope([incr, assign])
|
||||
assert scp.func(*scp.args) == scp
|
||||
|
||||
|
||||
def test_Print():
|
||||
fmt = "%d %.3f"
|
||||
ps = Print([n, x], fmt)
|
||||
assert str(ps.format_string) == fmt
|
||||
assert ps.print_args == Tuple(n, x)
|
||||
assert ps.args == (Tuple(n, x), QuotedString(fmt), none)
|
||||
assert ps == Print((n, x), fmt)
|
||||
assert ps != Print([x, n], fmt)
|
||||
assert ps.func(*ps.args) == ps
|
||||
|
||||
ps2 = Print([n, x])
|
||||
assert ps2 == Print([n, x])
|
||||
assert ps2 != ps
|
||||
assert ps2.format_string == None
|
||||
|
||||
|
||||
def test_FunctionPrototype_and_FunctionDefinition():
|
||||
vx = Variable(x, type=real)
|
||||
vn = Variable(n, type=integer)
|
||||
fp1 = FunctionPrototype(real, 'power', [vx, vn])
|
||||
assert fp1.return_type == real
|
||||
assert fp1.name == String('power')
|
||||
assert fp1.parameters == Tuple(vx, vn)
|
||||
assert fp1 == FunctionPrototype(real, 'power', [vx, vn])
|
||||
assert fp1 != FunctionPrototype(real, 'power', [vn, vx])
|
||||
assert fp1.func(*fp1.args) == fp1
|
||||
|
||||
|
||||
body = [Assignment(x, x**n), Return(x)]
|
||||
fd1 = FunctionDefinition(real, 'power', [vx, vn], body)
|
||||
assert fd1.return_type == real
|
||||
assert str(fd1.name) == 'power'
|
||||
assert fd1.parameters == Tuple(vx, vn)
|
||||
assert fd1.body == CodeBlock(*body)
|
||||
assert fd1 == FunctionDefinition(real, 'power', [vx, vn], body)
|
||||
assert fd1 != FunctionDefinition(real, 'power', [vx, vn], body[::-1])
|
||||
assert fd1.func(*fd1.args) == fd1
|
||||
|
||||
fp2 = FunctionPrototype.from_FunctionDefinition(fd1)
|
||||
assert fp2 == fp1
|
||||
|
||||
fd2 = FunctionDefinition.from_FunctionPrototype(fp1, body)
|
||||
assert fd2 == fd1
|
||||
|
||||
|
||||
def test_Return():
|
||||
rs = Return(x)
|
||||
assert rs.args == (x,)
|
||||
assert rs == Return(x)
|
||||
assert rs != Return(y)
|
||||
assert rs.func(*rs.args) == rs
|
||||
|
||||
|
||||
def test_FunctionCall():
|
||||
fc = FunctionCall('power', (x, 3))
|
||||
assert fc.function_args[0] == x
|
||||
assert fc.function_args[1] == 3
|
||||
assert len(fc.function_args) == 2
|
||||
assert isinstance(fc.function_args[1], Integer)
|
||||
assert fc == FunctionCall('power', (x, 3))
|
||||
assert fc != FunctionCall('power', (3, x))
|
||||
assert fc != FunctionCall('Power', (x, 3))
|
||||
assert fc.func(*fc.args) == fc
|
||||
|
||||
fc2 = FunctionCall('fma', [2, 3, 4])
|
||||
assert len(fc2.function_args) == 3
|
||||
assert fc2.function_args[0] == 2
|
||||
assert fc2.function_args[1] == 3
|
||||
assert fc2.function_args[2] == 4
|
||||
assert str(fc2) in ( # not sure if QuotedString is a better default...
|
||||
'FunctionCall(fma, function_args=(2, 3, 4))',
|
||||
'FunctionCall("fma", function_args=(2, 3, 4))',
|
||||
)
|
||||
|
||||
def test_ast_replace():
|
||||
x = Variable('x', real)
|
||||
y = Variable('y', real)
|
||||
n = Variable('n', integer)
|
||||
|
||||
pwer = FunctionDefinition(real, 'pwer', [x, n], [pow(x.symbol, n.symbol)])
|
||||
pname = pwer.name
|
||||
pcall = FunctionCall('pwer', [y, 3])
|
||||
|
||||
tree1 = CodeBlock(pwer, pcall)
|
||||
assert str(tree1.args[0].name) == 'pwer'
|
||||
assert str(tree1.args[1].name) == 'pwer'
|
||||
for a, b in zip(tree1, [pwer, pcall]):
|
||||
assert a == b
|
||||
|
||||
tree2 = tree1.replace(pname, String('power'))
|
||||
assert str(tree1.args[0].name) == 'pwer'
|
||||
assert str(tree1.args[1].name) == 'pwer'
|
||||
assert str(tree2.args[0].name) == 'power'
|
||||
assert str(tree2.args[1].name) == 'power'
|
||||
@@ -0,0 +1,186 @@
|
||||
from sympy.core.numbers import (Rational, pi)
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import (Symbol, symbols)
|
||||
from sympy.functions.elementary.exponential import (exp, log)
|
||||
from sympy.codegen.cfunctions import (
|
||||
expm1, log1p, exp2, log2, fma, log10, Sqrt, Cbrt, hypot, isnan, isinf
|
||||
)
|
||||
from sympy.core.function import expand_log
|
||||
|
||||
|
||||
def test_expm1():
|
||||
# Eval
|
||||
assert expm1(0) == 0
|
||||
|
||||
x = Symbol('x', real=True)
|
||||
|
||||
# Expand and rewrite
|
||||
assert expm1(x).expand(func=True) - exp(x) == -1
|
||||
assert expm1(x).rewrite('tractable') - exp(x) == -1
|
||||
assert expm1(x).rewrite('exp') - exp(x) == -1
|
||||
|
||||
# Precision
|
||||
assert not ((exp(1e-10).evalf() - 1) - 1e-10 - 5e-21) < 1e-22 # for comparison
|
||||
assert abs(expm1(1e-10).evalf() - 1e-10 - 5e-21) < 1e-22
|
||||
|
||||
# Properties
|
||||
assert expm1(x).is_real
|
||||
assert expm1(x).is_finite
|
||||
|
||||
# Diff
|
||||
assert expm1(42*x).diff(x) - 42*exp(42*x) == 0
|
||||
assert expm1(42*x).diff(x) - expm1(42*x).expand(func=True).diff(x) == 0
|
||||
|
||||
|
||||
def test_log1p():
|
||||
# Eval
|
||||
assert log1p(0) == 0
|
||||
d = S(10)
|
||||
assert expand_log(log1p(d**-1000) - log(d**1000 + 1) + log(d**1000)) == 0
|
||||
|
||||
x = Symbol('x', real=True)
|
||||
|
||||
# Expand and rewrite
|
||||
assert log1p(x).expand(func=True) - log(x + 1) == 0
|
||||
assert log1p(x).rewrite('tractable') - log(x + 1) == 0
|
||||
assert log1p(x).rewrite('log') - log(x + 1) == 0
|
||||
|
||||
# Precision
|
||||
assert not abs(log(1e-99 + 1).evalf() - 1e-99) < 1e-100 # for comparison
|
||||
assert abs(expand_log(log1p(1e-99)).evalf() - 1e-99) < 1e-100
|
||||
|
||||
# Properties
|
||||
assert log1p(-2**Rational(-1, 2)).is_real
|
||||
|
||||
assert not log1p(-1).is_finite
|
||||
assert log1p(pi).is_finite
|
||||
|
||||
assert not log1p(x).is_positive
|
||||
assert log1p(Symbol('y', positive=True)).is_positive
|
||||
|
||||
assert not log1p(x).is_zero
|
||||
assert log1p(Symbol('z', zero=True)).is_zero
|
||||
|
||||
assert not log1p(x).is_nonnegative
|
||||
assert log1p(Symbol('o', nonnegative=True)).is_nonnegative
|
||||
|
||||
# Diff
|
||||
assert log1p(42*x).diff(x) - 42/(42*x + 1) == 0
|
||||
assert log1p(42*x).diff(x) - log1p(42*x).expand(func=True).diff(x) == 0
|
||||
|
||||
|
||||
def test_exp2():
|
||||
# Eval
|
||||
assert exp2(2) == 4
|
||||
|
||||
x = Symbol('x', real=True)
|
||||
|
||||
# Expand
|
||||
assert exp2(x).expand(func=True) - 2**x == 0
|
||||
|
||||
# Diff
|
||||
assert exp2(42*x).diff(x) - 42*exp2(42*x)*log(2) == 0
|
||||
assert exp2(42*x).diff(x) - exp2(42*x).diff(x) == 0
|
||||
|
||||
|
||||
def test_log2():
|
||||
# Eval
|
||||
assert log2(8) == 3
|
||||
assert log2(pi) != log(pi)/log(2) # log2 should *save* (CPU) instructions
|
||||
|
||||
x = Symbol('x', real=True)
|
||||
assert log2(x) != log(x)/log(2)
|
||||
assert log2(2**x) == x
|
||||
|
||||
# Expand
|
||||
assert log2(x).expand(func=True) - log(x)/log(2) == 0
|
||||
|
||||
# Diff
|
||||
assert log2(42*x).diff() - 1/(log(2)*x) == 0
|
||||
assert log2(42*x).diff() - log2(42*x).expand(func=True).diff(x) == 0
|
||||
|
||||
|
||||
def test_fma():
|
||||
x, y, z = symbols('x y z')
|
||||
|
||||
# Expand
|
||||
assert fma(x, y, z).expand(func=True) - x*y - z == 0
|
||||
|
||||
expr = fma(17*x, 42*y, 101*z)
|
||||
|
||||
# Diff
|
||||
assert expr.diff(x) - expr.expand(func=True).diff(x) == 0
|
||||
assert expr.diff(y) - expr.expand(func=True).diff(y) == 0
|
||||
assert expr.diff(z) - expr.expand(func=True).diff(z) == 0
|
||||
|
||||
assert expr.diff(x) - 17*42*y == 0
|
||||
assert expr.diff(y) - 17*42*x == 0
|
||||
assert expr.diff(z) - 101 == 0
|
||||
|
||||
|
||||
def test_log10():
|
||||
x = Symbol('x')
|
||||
|
||||
# Expand
|
||||
assert log10(x).expand(func=True) - log(x)/log(10) == 0
|
||||
|
||||
# Diff
|
||||
assert log10(42*x).diff(x) - 1/(log(10)*x) == 0
|
||||
assert log10(42*x).diff(x) - log10(42*x).expand(func=True).diff(x) == 0
|
||||
|
||||
|
||||
def test_Cbrt():
|
||||
x = Symbol('x')
|
||||
|
||||
# Expand
|
||||
assert Cbrt(x).expand(func=True) - x**Rational(1, 3) == 0
|
||||
|
||||
# Diff
|
||||
assert Cbrt(42*x).diff(x) - 42*(42*x)**(Rational(1, 3) - 1)/3 == 0
|
||||
assert Cbrt(42*x).diff(x) - Cbrt(42*x).expand(func=True).diff(x) == 0
|
||||
|
||||
|
||||
def test_Sqrt():
|
||||
x = Symbol('x')
|
||||
|
||||
# Expand
|
||||
assert Sqrt(x).expand(func=True) - x**S.Half == 0
|
||||
|
||||
# Diff
|
||||
assert Sqrt(42*x).diff(x) - 42*(42*x)**(S.Half - 1)/2 == 0
|
||||
assert Sqrt(42*x).diff(x) - Sqrt(42*x).expand(func=True).diff(x) == 0
|
||||
|
||||
|
||||
def test_hypot():
|
||||
x, y = symbols('x y')
|
||||
|
||||
# Expand
|
||||
assert hypot(x, y).expand(func=True) - (x**2 + y**2)**S.Half == 0
|
||||
|
||||
# Diff
|
||||
assert hypot(17*x, 42*y).diff(x).expand(func=True) - hypot(17*x, 42*y).expand(func=True).diff(x) == 0
|
||||
assert hypot(17*x, 42*y).diff(y).expand(func=True) - hypot(17*x, 42*y).expand(func=True).diff(y) == 0
|
||||
|
||||
assert hypot(17*x, 42*y).diff(x).expand(func=True) - 2*17*17*x*((17*x)**2 + (42*y)**2)**Rational(-1, 2)/2 == 0
|
||||
assert hypot(17*x, 42*y).diff(y).expand(func=True) - 2*42*42*y*((17*x)**2 + (42*y)**2)**Rational(-1, 2)/2 == 0
|
||||
|
||||
|
||||
def test_isnan_isinf():
|
||||
x = Symbol('x')
|
||||
|
||||
# isinf
|
||||
assert isinf(+S.Infinity) == True
|
||||
assert isinf(-S.Infinity) == True
|
||||
assert isinf(S.Pi) == False
|
||||
isinfx = isinf(x)
|
||||
assert isinfx not in (False, True)
|
||||
assert isinfx.func is isinf
|
||||
assert isinfx.args == (x,)
|
||||
|
||||
# isnan
|
||||
assert isnan(S.NaN) == True
|
||||
assert isnan(S.Pi) == False
|
||||
isnanx = isnan(x)
|
||||
assert isnanx not in (False, True)
|
||||
assert isnanx.func is isnan
|
||||
assert isnanx.args == (x,)
|
||||
@@ -0,0 +1,112 @@
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.printing.codeprinter import ccode
|
||||
from sympy.codegen.ast import Declaration, Variable, float64, int64, String, CodeBlock
|
||||
from sympy.codegen.cnodes import (
|
||||
alignof, CommaOperator, goto, Label, PreDecrement, PostDecrement, PreIncrement, PostIncrement,
|
||||
sizeof, union, struct
|
||||
)
|
||||
|
||||
x, y = symbols('x y')
|
||||
|
||||
|
||||
def test_alignof():
|
||||
ax = alignof(x)
|
||||
assert ccode(ax) == 'alignof(x)'
|
||||
assert ax.func(*ax.args) == ax
|
||||
|
||||
|
||||
def test_CommaOperator():
|
||||
expr = CommaOperator(PreIncrement(x), 2*x)
|
||||
assert ccode(expr) == '(++(x), 2*x)'
|
||||
assert expr.func(*expr.args) == expr
|
||||
|
||||
|
||||
def test_goto_Label():
|
||||
s = 'early_exit'
|
||||
g = goto(s)
|
||||
assert g.func(*g.args) == g
|
||||
assert g != goto('foobar')
|
||||
assert ccode(g) == 'goto early_exit'
|
||||
|
||||
l1 = Label(s)
|
||||
assert ccode(l1) == 'early_exit:'
|
||||
assert l1 == Label('early_exit')
|
||||
assert l1 != Label('foobar')
|
||||
|
||||
body = [PreIncrement(x)]
|
||||
l2 = Label(s, body)
|
||||
assert l2.name == String("early_exit")
|
||||
assert l2.body == CodeBlock(PreIncrement(x))
|
||||
assert ccode(l2) == ("early_exit:\n"
|
||||
"++(x);")
|
||||
|
||||
body = [PreIncrement(x), PreDecrement(y)]
|
||||
l2 = Label(s, body)
|
||||
assert l2.name == String("early_exit")
|
||||
assert l2.body == CodeBlock(PreIncrement(x), PreDecrement(y))
|
||||
assert ccode(l2) == ("early_exit:\n"
|
||||
"{\n ++(x);\n --(y);\n}")
|
||||
|
||||
|
||||
def test_PreDecrement():
|
||||
p = PreDecrement(x)
|
||||
assert p.func(*p.args) == p
|
||||
assert ccode(p) == '--(x)'
|
||||
|
||||
|
||||
def test_PostDecrement():
|
||||
p = PostDecrement(x)
|
||||
assert p.func(*p.args) == p
|
||||
assert ccode(p) == '(x)--'
|
||||
|
||||
|
||||
def test_PreIncrement():
|
||||
p = PreIncrement(x)
|
||||
assert p.func(*p.args) == p
|
||||
assert ccode(p) == '++(x)'
|
||||
|
||||
|
||||
def test_PostIncrement():
|
||||
p = PostIncrement(x)
|
||||
assert p.func(*p.args) == p
|
||||
assert ccode(p) == '(x)++'
|
||||
|
||||
|
||||
def test_sizeof():
|
||||
typename = 'unsigned int'
|
||||
sz = sizeof(typename)
|
||||
assert ccode(sz) == 'sizeof(%s)' % typename
|
||||
assert sz.func(*sz.args) == sz
|
||||
assert not sz.is_Atom
|
||||
assert sz.atoms() == {String('unsigned int'), String('sizeof')}
|
||||
|
||||
|
||||
def test_struct():
|
||||
vx, vy = Variable(x, type=float64), Variable(y, type=float64)
|
||||
s = struct('vec2', [vx, vy])
|
||||
assert s.func(*s.args) == s
|
||||
assert s == struct('vec2', (vx, vy))
|
||||
assert s != struct('vec2', (vy, vx))
|
||||
assert str(s.name) == 'vec2'
|
||||
assert len(s.declarations) == 2
|
||||
assert all(isinstance(arg, Declaration) for arg in s.declarations)
|
||||
assert ccode(s) == (
|
||||
"struct vec2 {\n"
|
||||
" double x;\n"
|
||||
" double y;\n"
|
||||
"}")
|
||||
|
||||
|
||||
def test_union():
|
||||
vx, vy = Variable(x, type=float64), Variable(y, type=int64)
|
||||
u = union('dualuse', [vx, vy])
|
||||
assert u.func(*u.args) == u
|
||||
assert u == union('dualuse', (vx, vy))
|
||||
assert str(u.name) == 'dualuse'
|
||||
assert len(u.declarations) == 2
|
||||
assert all(isinstance(arg, Declaration) for arg in u.declarations)
|
||||
assert ccode(u) == (
|
||||
"union dualuse {\n"
|
||||
" double x;\n"
|
||||
" int64_t y;\n"
|
||||
"}")
|
||||
@@ -0,0 +1,14 @@
|
||||
from sympy.core.symbol import Symbol
|
||||
from sympy.codegen.ast import Type
|
||||
from sympy.codegen.cxxnodes import using
|
||||
from sympy.printing.codeprinter import cxxcode
|
||||
|
||||
x = Symbol('x')
|
||||
|
||||
def test_using():
|
||||
v = Type('std::vector')
|
||||
u1 = using(v)
|
||||
assert cxxcode(u1) == 'using std::vector'
|
||||
|
||||
u2 = using(v, 'vec')
|
||||
assert cxxcode(u2) == 'using vec = std::vector'
|
||||
@@ -0,0 +1,213 @@
|
||||
import os
|
||||
import tempfile
|
||||
from sympy.core.symbol import (Symbol, symbols)
|
||||
from sympy.codegen.ast import (
|
||||
Assignment, Print, Declaration, FunctionDefinition, Return, real,
|
||||
FunctionCall, Variable, Element, integer
|
||||
)
|
||||
from sympy.codegen.fnodes import (
|
||||
allocatable, ArrayConstructor, isign, dsign, cmplx, kind, literal_dp,
|
||||
Program, Module, use, Subroutine, dimension, assumed_extent, ImpliedDoLoop,
|
||||
intent_out, size, Do, SubroutineCall, sum_, array, bind_C
|
||||
)
|
||||
from sympy.codegen.futils import render_as_module
|
||||
from sympy.core.expr import unchanged
|
||||
from sympy.external import import_module
|
||||
from sympy.printing.codeprinter import fcode
|
||||
from sympy.utilities._compilation import has_fortran, compile_run_strings, compile_link_import_strings
|
||||
from sympy.utilities._compilation.util import may_xfail
|
||||
from sympy.testing.pytest import skip, XFAIL
|
||||
|
||||
cython = import_module('cython')
|
||||
np = import_module('numpy')
|
||||
|
||||
|
||||
def test_size():
|
||||
x = Symbol('x', real=True)
|
||||
sx = size(x)
|
||||
assert fcode(sx, source_format='free') == 'size(x)'
|
||||
|
||||
|
||||
@may_xfail
|
||||
def test_size_assumed_shape():
|
||||
if not has_fortran():
|
||||
skip("No fortran compiler found.")
|
||||
a = Symbol('a', real=True)
|
||||
body = [Return((sum_(a**2)/size(a))**.5)]
|
||||
arr = array(a, dim=[':'], intent='in')
|
||||
fd = FunctionDefinition(real, 'rms', [arr], body)
|
||||
render_as_module([fd], 'mod_rms')
|
||||
|
||||
(stdout, stderr), info = compile_run_strings([
|
||||
('rms.f90', render_as_module([fd], 'mod_rms')),
|
||||
('main.f90', (
|
||||
'program myprog\n'
|
||||
'use mod_rms, only: rms\n'
|
||||
'real*8, dimension(4), parameter :: x = [4, 2, 2, 2]\n'
|
||||
'print "(f7.5)", dsqrt(7d0) - rms(x)\n'
|
||||
'end program\n'
|
||||
))
|
||||
], clean=True)
|
||||
assert '0.00000' in stdout
|
||||
assert stderr == ''
|
||||
assert info['exit_status'] == os.EX_OK
|
||||
|
||||
|
||||
@XFAIL # https://github.com/sympy/sympy/issues/20265
|
||||
@may_xfail
|
||||
def test_ImpliedDoLoop():
|
||||
if not has_fortran():
|
||||
skip("No fortran compiler found.")
|
||||
|
||||
a, i = symbols('a i', integer=True)
|
||||
idl = ImpliedDoLoop(i**3, i, -3, 3, 2)
|
||||
ac = ArrayConstructor([-28, idl, 28])
|
||||
a = array(a, dim=[':'], attrs=[allocatable])
|
||||
prog = Program('idlprog', [
|
||||
a.as_Declaration(),
|
||||
Assignment(a, ac),
|
||||
Print([a])
|
||||
])
|
||||
fsrc = fcode(prog, standard=2003, source_format='free')
|
||||
(stdout, stderr), info = compile_run_strings([('main.f90', fsrc)], clean=True)
|
||||
for numstr in '-28 -27 -1 1 27 28'.split():
|
||||
assert numstr in stdout
|
||||
assert stderr == ''
|
||||
assert info['exit_status'] == os.EX_OK
|
||||
|
||||
|
||||
@may_xfail
|
||||
def test_Program():
|
||||
x = Symbol('x', real=True)
|
||||
vx = Variable.deduced(x, 42)
|
||||
decl = Declaration(vx)
|
||||
prnt = Print([x, x+1])
|
||||
prog = Program('foo', [decl, prnt])
|
||||
if not has_fortran():
|
||||
skip("No fortran compiler found.")
|
||||
|
||||
(stdout, stderr), info = compile_run_strings([('main.f90', fcode(prog, standard=90))], clean=True)
|
||||
assert '42' in stdout
|
||||
assert '43' in stdout
|
||||
assert stderr == ''
|
||||
assert info['exit_status'] == os.EX_OK
|
||||
|
||||
|
||||
@may_xfail
|
||||
def test_Module():
|
||||
x = Symbol('x', real=True)
|
||||
v_x = Variable.deduced(x)
|
||||
sq = FunctionDefinition(real, 'sqr', [v_x], [Return(x**2)])
|
||||
mod_sq = Module('mod_sq', [], [sq])
|
||||
sq_call = FunctionCall('sqr', [42.])
|
||||
prg_sq = Program('foobar', [
|
||||
use('mod_sq', only=['sqr']),
|
||||
Print(['"Square of 42 = "', sq_call])
|
||||
])
|
||||
if not has_fortran():
|
||||
skip("No fortran compiler found.")
|
||||
(stdout, stderr), info = compile_run_strings([
|
||||
('mod_sq.f90', fcode(mod_sq, standard=90)),
|
||||
('main.f90', fcode(prg_sq, standard=90))
|
||||
], clean=True)
|
||||
assert '42' in stdout
|
||||
assert str(42**2) in stdout
|
||||
assert stderr == ''
|
||||
|
||||
|
||||
@XFAIL # https://github.com/sympy/sympy/issues/20265
|
||||
@may_xfail
|
||||
def test_Subroutine():
|
||||
# Code to generate the subroutine in the example from
|
||||
# http://www.fortran90.org/src/best-practices.html#arrays
|
||||
r = Symbol('r', real=True)
|
||||
i = Symbol('i', integer=True)
|
||||
v_r = Variable.deduced(r, attrs=(dimension(assumed_extent), intent_out))
|
||||
v_i = Variable.deduced(i)
|
||||
v_n = Variable('n', integer)
|
||||
do_loop = Do([
|
||||
Assignment(Element(r, [i]), literal_dp(1)/i**2)
|
||||
], i, 1, v_n)
|
||||
sub = Subroutine("f", [v_r], [
|
||||
Declaration(v_n),
|
||||
Declaration(v_i),
|
||||
Assignment(v_n, size(r)),
|
||||
do_loop
|
||||
])
|
||||
x = Symbol('x', real=True)
|
||||
v_x3 = Variable.deduced(x, attrs=[dimension(3)])
|
||||
mod = Module('mymod', definitions=[sub])
|
||||
prog = Program('foo', [
|
||||
use(mod, only=[sub]),
|
||||
Declaration(v_x3),
|
||||
SubroutineCall(sub, [v_x3]),
|
||||
Print([sum_(v_x3), v_x3])
|
||||
])
|
||||
|
||||
if not has_fortran():
|
||||
skip("No fortran compiler found.")
|
||||
|
||||
(stdout, stderr), info = compile_run_strings([
|
||||
('a.f90', fcode(mod, standard=90)),
|
||||
('b.f90', fcode(prog, standard=90))
|
||||
], clean=True)
|
||||
ref = [1.0/i**2 for i in range(1, 4)]
|
||||
assert str(sum(ref))[:-3] in stdout
|
||||
for _ in ref:
|
||||
assert str(_)[:-3] in stdout
|
||||
assert stderr == ''
|
||||
|
||||
|
||||
def test_isign():
|
||||
x = Symbol('x', integer=True)
|
||||
assert unchanged(isign, 1, x)
|
||||
assert fcode(isign(1, x), standard=95, source_format='free') == 'isign(1, x)'
|
||||
|
||||
|
||||
def test_dsign():
|
||||
x = Symbol('x')
|
||||
assert unchanged(dsign, 1, x)
|
||||
assert fcode(dsign(literal_dp(1), x), standard=95, source_format='free') == 'dsign(1d0, x)'
|
||||
|
||||
|
||||
def test_cmplx():
|
||||
x = Symbol('x')
|
||||
assert unchanged(cmplx, 1, x)
|
||||
|
||||
|
||||
def test_kind():
|
||||
x = Symbol('x')
|
||||
assert unchanged(kind, x)
|
||||
|
||||
|
||||
def test_literal_dp():
|
||||
assert fcode(literal_dp(0), source_format='free') == '0d0'
|
||||
|
||||
|
||||
@may_xfail
|
||||
def test_bind_C():
|
||||
if not has_fortran():
|
||||
skip("No fortran compiler found.")
|
||||
if not cython:
|
||||
skip("Cython not found.")
|
||||
if not np:
|
||||
skip("NumPy not found.")
|
||||
|
||||
a = Symbol('a', real=True)
|
||||
s = Symbol('s', integer=True)
|
||||
body = [Return((sum_(a**2)/s)**.5)]
|
||||
arr = array(a, dim=[s], intent='in')
|
||||
fd = FunctionDefinition(real, 'rms', [arr, s], body, attrs=[bind_C('rms')])
|
||||
f_mod = render_as_module([fd], 'mod_rms')
|
||||
|
||||
with tempfile.TemporaryDirectory() as folder:
|
||||
mod, info = compile_link_import_strings([
|
||||
('rms.f90', f_mod),
|
||||
('_rms.pyx', (
|
||||
"#cython: language_level={}\n".format("3") +
|
||||
"cdef extern double rms(double*, int*)\n"
|
||||
"def py_rms(double[::1] x):\n"
|
||||
" cdef int s = x.size\n"
|
||||
" return rms(&x[0], &s)\n"))
|
||||
], build_dir=folder)
|
||||
assert abs(mod.py_rms(np.array([2., 4., 2., 2.])) - 7**0.5) < 1e-14
|
||||
@@ -0,0 +1,50 @@
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.core.function import Function
|
||||
from sympy.matrices.dense import Matrix
|
||||
from sympy.matrices.dense import zeros
|
||||
from sympy.simplify.simplify import simplify
|
||||
from sympy.codegen.matrix_nodes import MatrixSolve
|
||||
from sympy.utilities.lambdify import lambdify
|
||||
from sympy.printing.numpy import NumPyPrinter
|
||||
from sympy.testing.pytest import skip
|
||||
from sympy.external import import_module
|
||||
|
||||
|
||||
def test_matrix_solve_issue_24862():
|
||||
A = Matrix(3, 3, symbols('a:9'))
|
||||
b = Matrix(3, 1, symbols('b:3'))
|
||||
hash(MatrixSolve(A, b))
|
||||
|
||||
|
||||
def test_matrix_solve_derivative_exact():
|
||||
q = symbols('q')
|
||||
a11, a12, a21, a22, b1, b2 = (
|
||||
f(q) for f in symbols('a11 a12 a21 a22 b1 b2', cls=Function))
|
||||
A = Matrix([[a11, a12], [a21, a22]])
|
||||
b = Matrix([b1, b2])
|
||||
x_lu = A.LUsolve(b)
|
||||
dxdq_lu = A.LUsolve(b.diff(q) - A.diff(q) * A.LUsolve(b))
|
||||
assert simplify(x_lu.diff(q) - dxdq_lu) == zeros(2, 1)
|
||||
# dxdq_ms is the MatrixSolve equivalent of dxdq_lu
|
||||
dxdq_ms = MatrixSolve(A, b.diff(q) - A.diff(q) * MatrixSolve(A, b))
|
||||
assert MatrixSolve(A, b).diff(q) == dxdq_ms
|
||||
|
||||
|
||||
def test_matrix_solve_derivative_numpy():
|
||||
np = import_module('numpy')
|
||||
if not np:
|
||||
skip("numpy not installed.")
|
||||
q = symbols('q')
|
||||
a11, a12, a21, a22, b1, b2 = (
|
||||
f(q) for f in symbols('a11 a12 a21 a22 b1 b2', cls=Function))
|
||||
A = Matrix([[a11, a12], [a21, a22]])
|
||||
b = Matrix([b1, b2])
|
||||
dx_lu = A.LUsolve(b).diff(q)
|
||||
subs = {a11.diff(q): 0.2, a12.diff(q): 0.3, a21.diff(q): 0.1,
|
||||
a22.diff(q): 0.5, b1.diff(q): 0.4, b2.diff(q): 0.9,
|
||||
a11: 1.3, a12: 0.5, a21: 1.2, a22: 4, b1: 6.2, b2: 3.5}
|
||||
p, p_vals = zip(*subs.items())
|
||||
dx_sm = MatrixSolve(A, b).diff(q)
|
||||
np.testing.assert_allclose(
|
||||
lambdify(p, dx_sm, printer=NumPyPrinter)(*p_vals),
|
||||
lambdify(p, dx_lu, printer=NumPyPrinter)(*p_vals))
|
||||
@@ -0,0 +1,69 @@
|
||||
from itertools import product
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.functions.elementary.exponential import (exp, log)
|
||||
from sympy.functions.elementary.miscellaneous import Max, Min
|
||||
from sympy.printing.repr import srepr
|
||||
from sympy.codegen.numpy_nodes import logaddexp, logaddexp2, minimum, maximum, amax, amin
|
||||
from sympy.testing.pytest import raises
|
||||
|
||||
x, y, z = symbols('x y z')
|
||||
|
||||
def test_logaddexp():
|
||||
lae_xy = logaddexp(x, y)
|
||||
ref_xy = log(exp(x) + exp(y))
|
||||
for wrt, deriv_order in product([x, y, z], range(3)):
|
||||
assert (
|
||||
lae_xy.diff(wrt, deriv_order) -
|
||||
ref_xy.diff(wrt, deriv_order)
|
||||
).rewrite(log).simplify() == 0
|
||||
|
||||
one_third_e = 1*exp(1)/3
|
||||
two_thirds_e = 2*exp(1)/3
|
||||
logThirdE = log(one_third_e)
|
||||
logTwoThirdsE = log(two_thirds_e)
|
||||
lae_sum_to_e = logaddexp(logThirdE, logTwoThirdsE)
|
||||
assert lae_sum_to_e.rewrite(log) == 1
|
||||
assert lae_sum_to_e.simplify() == 1
|
||||
was = logaddexp(2, 3)
|
||||
assert srepr(was) == srepr(was.simplify()) # cannot simplify with 2, 3
|
||||
|
||||
|
||||
def test_logaddexp2():
|
||||
lae2_xy = logaddexp2(x, y)
|
||||
ref2_xy = log(2**x + 2**y)/log(2)
|
||||
for wrt, deriv_order in product([x, y, z], range(3)):
|
||||
assert (
|
||||
lae2_xy.diff(wrt, deriv_order) -
|
||||
ref2_xy.diff(wrt, deriv_order)
|
||||
).rewrite(log).cancel() == 0
|
||||
|
||||
def lb(x):
|
||||
return log(x)/log(2)
|
||||
|
||||
two_thirds = S.One*2/3
|
||||
four_thirds = 2*two_thirds
|
||||
lbTwoThirds = lb(two_thirds)
|
||||
lbFourThirds = lb(four_thirds)
|
||||
lae2_sum_to_2 = logaddexp2(lbTwoThirds, lbFourThirds)
|
||||
assert lae2_sum_to_2.rewrite(log) == 1
|
||||
assert lae2_sum_to_2.simplify() == 1
|
||||
was = logaddexp2(x, y)
|
||||
assert srepr(was) == srepr(was.simplify()) # cannot simplify with x, y
|
||||
|
||||
|
||||
def test_minimum_maximum():
|
||||
for MM, mm in zip([Min, Max], [minimum, maximum]):
|
||||
ref = MM(x, y, z)
|
||||
m = mm(x, y, z)
|
||||
assert m != ref
|
||||
assert m.rewrite(MM) == ref
|
||||
|
||||
|
||||
def test_amin_amax():
|
||||
for am in [amin, amax]:
|
||||
assert am(x).array == x
|
||||
assert am(x).axis == None
|
||||
assert am(x, axis=3).axis == 3
|
||||
with raises(ValueError):
|
||||
am(x, y, z)
|
||||
@@ -0,0 +1,13 @@
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.codegen.pynodes import List
|
||||
|
||||
|
||||
def test_List():
|
||||
l = List(2, 3, 4)
|
||||
assert l == List(2, 3, 4)
|
||||
assert str(l) == "[2, 3, 4]"
|
||||
x, y, z = symbols('x y z')
|
||||
l = List(x**2,y**3,z**4)
|
||||
# contrary to python's built-in list, we can call e.g. "replace" on List.
|
||||
m = l.replace(lambda arg: arg.is_Pow and arg.exp>2, lambda p: p.base-p.exp)
|
||||
assert m == [x**2, y-3, z-4]
|
||||
@@ -0,0 +1,7 @@
|
||||
from sympy.codegen.ast import Print
|
||||
from sympy.codegen.pyutils import render_as_module
|
||||
|
||||
def test_standard():
|
||||
ast = Print('x y'.split(), r"coordinate: %12.5g %12.5g\n")
|
||||
assert render_as_module(ast, standard='python3') == \
|
||||
'\n\nprint("coordinate: %12.5g %12.5g\\n" % (x, y), end="")'
|
||||
@@ -0,0 +1,479 @@
|
||||
import tempfile
|
||||
from sympy.core.numbers import pi, Rational
|
||||
from sympy.core.power import Pow
|
||||
from sympy.core.singleton import S
|
||||
from sympy.core.symbol import Symbol
|
||||
from sympy.functions.elementary.complexes import Abs
|
||||
from sympy.functions.elementary.exponential import (exp, log)
|
||||
from sympy.functions.elementary.trigonometric import (cos, sin, sinc)
|
||||
from sympy.matrices.expressions.matexpr import MatrixSymbol
|
||||
from sympy.assumptions import assuming, Q
|
||||
from sympy.external import import_module
|
||||
from sympy.printing.codeprinter import ccode
|
||||
from sympy.codegen.matrix_nodes import MatrixSolve
|
||||
from sympy.codegen.cfunctions import log2, exp2, expm1, log1p
|
||||
from sympy.codegen.numpy_nodes import logaddexp, logaddexp2
|
||||
from sympy.codegen.scipy_nodes import cosm1, powm1
|
||||
from sympy.codegen.rewriting import (
|
||||
optimize, cosm1_opt, log2_opt, exp2_opt, expm1_opt, log1p_opt, powm1_opt, optims_c99,
|
||||
create_expand_pow_optimization, matinv_opt, logaddexp_opt, logaddexp2_opt,
|
||||
optims_numpy, optims_scipy, sinc_opts, FuncMinusOneOptim
|
||||
)
|
||||
from sympy.testing.pytest import XFAIL, skip
|
||||
from sympy.utilities import lambdify
|
||||
from sympy.utilities._compilation import compile_link_import_strings, has_c
|
||||
from sympy.utilities._compilation.util import may_xfail
|
||||
|
||||
cython = import_module('cython')
|
||||
numpy = import_module('numpy')
|
||||
scipy = import_module('scipy')
|
||||
|
||||
|
||||
def test_log2_opt():
|
||||
x = Symbol('x')
|
||||
expr1 = 7*log(3*x + 5)/(log(2))
|
||||
opt1 = optimize(expr1, [log2_opt])
|
||||
assert opt1 == 7*log2(3*x + 5)
|
||||
assert opt1.rewrite(log) == expr1
|
||||
|
||||
expr2 = 3*log(5*x + 7)/(13*log(2))
|
||||
opt2 = optimize(expr2, [log2_opt])
|
||||
assert opt2 == 3*log2(5*x + 7)/13
|
||||
assert opt2.rewrite(log) == expr2
|
||||
|
||||
expr3 = log(x)/log(2)
|
||||
opt3 = optimize(expr3, [log2_opt])
|
||||
assert opt3 == log2(x)
|
||||
assert opt3.rewrite(log) == expr3
|
||||
|
||||
expr4 = log(x)/log(2) + log(x+1)
|
||||
opt4 = optimize(expr4, [log2_opt])
|
||||
assert opt4 == log2(x) + log(2)*log2(x+1)
|
||||
assert opt4.rewrite(log) == expr4
|
||||
|
||||
expr5 = log(17)
|
||||
opt5 = optimize(expr5, [log2_opt])
|
||||
assert opt5 == expr5
|
||||
|
||||
expr6 = log(x + 3)/log(2)
|
||||
opt6 = optimize(expr6, [log2_opt])
|
||||
assert str(opt6) == 'log2(x + 3)'
|
||||
assert opt6.rewrite(log) == expr6
|
||||
|
||||
|
||||
def test_exp2_opt():
|
||||
x = Symbol('x')
|
||||
expr1 = 1 + 2**x
|
||||
opt1 = optimize(expr1, [exp2_opt])
|
||||
assert opt1 == 1 + exp2(x)
|
||||
assert opt1.rewrite(Pow) == expr1
|
||||
|
||||
expr2 = 1 + 3**x
|
||||
assert expr2 == optimize(expr2, [exp2_opt])
|
||||
|
||||
|
||||
def test_expm1_opt():
|
||||
x = Symbol('x')
|
||||
|
||||
expr1 = exp(x) - 1
|
||||
opt1 = optimize(expr1, [expm1_opt])
|
||||
assert expm1(x) - opt1 == 0
|
||||
assert opt1.rewrite(exp) == expr1
|
||||
|
||||
expr2 = 3*exp(x) - 3
|
||||
opt2 = optimize(expr2, [expm1_opt])
|
||||
assert 3*expm1(x) == opt2
|
||||
assert opt2.rewrite(exp) == expr2
|
||||
|
||||
expr3 = 3*exp(x) - 5
|
||||
opt3 = optimize(expr3, [expm1_opt])
|
||||
assert 3*expm1(x) - 2 == opt3
|
||||
assert opt3.rewrite(exp) == expr3
|
||||
expm1_opt_non_opportunistic = FuncMinusOneOptim(exp, expm1, opportunistic=False)
|
||||
assert expr3 == optimize(expr3, [expm1_opt_non_opportunistic])
|
||||
assert opt1 == optimize(expr1, [expm1_opt_non_opportunistic])
|
||||
assert opt2 == optimize(expr2, [expm1_opt_non_opportunistic])
|
||||
|
||||
expr4 = 3*exp(x) + log(x) - 3
|
||||
opt4 = optimize(expr4, [expm1_opt])
|
||||
assert 3*expm1(x) + log(x) == opt4
|
||||
assert opt4.rewrite(exp) == expr4
|
||||
|
||||
expr5 = 3*exp(2*x) - 3
|
||||
opt5 = optimize(expr5, [expm1_opt])
|
||||
assert 3*expm1(2*x) == opt5
|
||||
assert opt5.rewrite(exp) == expr5
|
||||
|
||||
expr6 = (2*exp(x) + 1)/(exp(x) + 1) + 1
|
||||
opt6 = optimize(expr6, [expm1_opt])
|
||||
assert opt6.count_ops() <= expr6.count_ops()
|
||||
|
||||
def ev(e):
|
||||
return e.subs(x, 3).evalf()
|
||||
assert abs(ev(expr6) - ev(opt6)) < 1e-15
|
||||
|
||||
y = Symbol('y')
|
||||
expr7 = (2*exp(x) - 1)/(1 - exp(y)) - 1/(1-exp(y))
|
||||
opt7 = optimize(expr7, [expm1_opt])
|
||||
assert -2*expm1(x)/expm1(y) == opt7
|
||||
assert (opt7.rewrite(exp) - expr7).factor() == 0
|
||||
|
||||
expr8 = (1+exp(x))**2 - 4
|
||||
opt8 = optimize(expr8, [expm1_opt])
|
||||
tgt8a = (exp(x) + 3)*expm1(x)
|
||||
tgt8b = 2*expm1(x) + expm1(2*x)
|
||||
# Both tgt8a & tgt8b seem to give full precision (~16 digits for double)
|
||||
# for x=1e-7 (compare with expr8 which only achieves ~8 significant digits).
|
||||
# If we can show that either tgt8a or tgt8b is preferable, we can
|
||||
# change this test to ensure the preferable version is returned.
|
||||
assert (tgt8a - tgt8b).rewrite(exp).factor() == 0
|
||||
assert opt8 in (tgt8a, tgt8b)
|
||||
assert (opt8.rewrite(exp) - expr8).factor() == 0
|
||||
|
||||
expr9 = sin(expr8)
|
||||
opt9 = optimize(expr9, [expm1_opt])
|
||||
tgt9a = sin(tgt8a)
|
||||
tgt9b = sin(tgt8b)
|
||||
assert opt9 in (tgt9a, tgt9b)
|
||||
assert (opt9.rewrite(exp) - expr9.rewrite(exp)).factor().is_zero
|
||||
|
||||
|
||||
def test_expm1_two_exp_terms():
|
||||
x, y = map(Symbol, 'x y'.split())
|
||||
expr1 = exp(x) + exp(y) - 2
|
||||
opt1 = optimize(expr1, [expm1_opt])
|
||||
assert opt1 == expm1(x) + expm1(y)
|
||||
|
||||
|
||||
def test_cosm1_opt():
|
||||
x = Symbol('x')
|
||||
|
||||
expr1 = cos(x) - 1
|
||||
opt1 = optimize(expr1, [cosm1_opt])
|
||||
assert cosm1(x) - opt1 == 0
|
||||
assert opt1.rewrite(cos) == expr1
|
||||
|
||||
expr2 = 3*cos(x) - 3
|
||||
opt2 = optimize(expr2, [cosm1_opt])
|
||||
assert 3*cosm1(x) == opt2
|
||||
assert opt2.rewrite(cos) == expr2
|
||||
|
||||
expr3 = 3*cos(x) - 5
|
||||
opt3 = optimize(expr3, [cosm1_opt])
|
||||
assert 3*cosm1(x) - 2 == opt3
|
||||
assert opt3.rewrite(cos) == expr3
|
||||
cosm1_opt_non_opportunistic = FuncMinusOneOptim(cos, cosm1, opportunistic=False)
|
||||
assert expr3 == optimize(expr3, [cosm1_opt_non_opportunistic])
|
||||
assert opt1 == optimize(expr1, [cosm1_opt_non_opportunistic])
|
||||
assert opt2 == optimize(expr2, [cosm1_opt_non_opportunistic])
|
||||
|
||||
expr4 = 3*cos(x) + log(x) - 3
|
||||
opt4 = optimize(expr4, [cosm1_opt])
|
||||
assert 3*cosm1(x) + log(x) == opt4
|
||||
assert opt4.rewrite(cos) == expr4
|
||||
|
||||
expr5 = 3*cos(2*x) - 3
|
||||
opt5 = optimize(expr5, [cosm1_opt])
|
||||
assert 3*cosm1(2*x) == opt5
|
||||
assert opt5.rewrite(cos) == expr5
|
||||
|
||||
expr6 = 2 - 2*cos(x)
|
||||
opt6 = optimize(expr6, [cosm1_opt])
|
||||
assert -2*cosm1(x) == opt6
|
||||
assert opt6.rewrite(cos) == expr6
|
||||
|
||||
|
||||
def test_cosm1_two_cos_terms():
|
||||
x, y = map(Symbol, 'x y'.split())
|
||||
expr1 = cos(x) + cos(y) - 2
|
||||
opt1 = optimize(expr1, [cosm1_opt])
|
||||
assert opt1 == cosm1(x) + cosm1(y)
|
||||
|
||||
|
||||
def test_expm1_cosm1_mixed():
|
||||
x = Symbol('x')
|
||||
expr1 = exp(x) + cos(x) - 2
|
||||
opt1 = optimize(expr1, [expm1_opt, cosm1_opt])
|
||||
assert opt1 == cosm1(x) + expm1(x)
|
||||
|
||||
|
||||
def _check_num_lambdify(expr, opt, val_subs, approx_ref, lambdify_kw=None, poorness=1e10):
|
||||
""" poorness=1e10 signifies that `expr` loses precision of at least ten decimal digits. """
|
||||
num_ref = expr.subs(val_subs).evalf()
|
||||
eps = numpy.finfo(numpy.float64).eps
|
||||
assert abs(num_ref - approx_ref) < approx_ref*eps
|
||||
f1 = lambdify(list(val_subs.keys()), opt, **(lambdify_kw or {}))
|
||||
args_float = tuple(map(float, val_subs.values()))
|
||||
num_err1 = abs(f1(*args_float) - approx_ref)
|
||||
assert num_err1 < abs(num_ref*eps)
|
||||
f2 = lambdify(list(val_subs.keys()), expr, **(lambdify_kw or {}))
|
||||
num_err2 = abs(f2(*args_float) - approx_ref)
|
||||
assert num_err2 > abs(num_ref*eps*poorness) # this only ensures that the *test* works as intended
|
||||
|
||||
|
||||
def test_cosm1_apart():
|
||||
x = Symbol('x')
|
||||
|
||||
expr1 = 1/cos(x) - 1
|
||||
opt1 = optimize(expr1, [cosm1_opt])
|
||||
assert opt1 == -cosm1(x)/cos(x)
|
||||
if scipy:
|
||||
_check_num_lambdify(expr1, opt1, {x: S(10)**-30}, 5e-61, lambdify_kw={"modules": 'scipy'})
|
||||
|
||||
expr2 = 2/cos(x) - 2
|
||||
opt2 = optimize(expr2, optims_scipy)
|
||||
assert opt2 == -2*cosm1(x)/cos(x)
|
||||
if scipy:
|
||||
_check_num_lambdify(expr2, opt2, {x: S(10)**-30}, 1e-60, lambdify_kw={"modules": 'scipy'})
|
||||
|
||||
expr3 = pi/cos(3*x) - pi
|
||||
opt3 = optimize(expr3, [cosm1_opt])
|
||||
assert opt3 == -pi*cosm1(3*x)/cos(3*x)
|
||||
if scipy:
|
||||
_check_num_lambdify(expr3, opt3, {x: S(10)**-30/3}, float(5e-61*pi), lambdify_kw={"modules": 'scipy'})
|
||||
|
||||
|
||||
def test_powm1():
|
||||
args = x, y = map(Symbol, "xy")
|
||||
|
||||
expr1 = x**y - 1
|
||||
opt1 = optimize(expr1, [powm1_opt])
|
||||
assert opt1 == powm1(x, y)
|
||||
for arg in args:
|
||||
assert expr1.diff(arg) == opt1.diff(arg)
|
||||
if scipy and tuple(map(int, scipy.version.version.split('.')[:3])) >= (1, 10, 0):
|
||||
subs1_a = {x: Rational(*(1.0+1e-13).as_integer_ratio()), y: pi}
|
||||
ref1_f64_a = 3.139081648208105e-13
|
||||
_check_num_lambdify(expr1, opt1, subs1_a, ref1_f64_a, lambdify_kw={"modules": 'scipy'}, poorness=10**11)
|
||||
|
||||
subs1_b = {x: pi, y: Rational(*(1e-10).as_integer_ratio())}
|
||||
ref1_f64_b = 1.1447298859149205e-10
|
||||
_check_num_lambdify(expr1, opt1, subs1_b, ref1_f64_b, lambdify_kw={"modules": 'scipy'}, poorness=10**9)
|
||||
|
||||
|
||||
def test_log1p_opt():
|
||||
x = Symbol('x')
|
||||
expr1 = log(x + 1)
|
||||
opt1 = optimize(expr1, [log1p_opt])
|
||||
assert log1p(x) - opt1 == 0
|
||||
assert opt1.rewrite(log) == expr1
|
||||
|
||||
expr2 = log(3*x + 3)
|
||||
opt2 = optimize(expr2, [log1p_opt])
|
||||
assert log1p(x) + log(3) == opt2
|
||||
assert (opt2.rewrite(log) - expr2).simplify() == 0
|
||||
|
||||
expr3 = log(2*x + 1)
|
||||
opt3 = optimize(expr3, [log1p_opt])
|
||||
assert log1p(2*x) - opt3 == 0
|
||||
assert opt3.rewrite(log) == expr3
|
||||
|
||||
expr4 = log(x+3)
|
||||
opt4 = optimize(expr4, [log1p_opt])
|
||||
assert str(opt4) == 'log(x + 3)'
|
||||
|
||||
|
||||
def test_optims_c99():
|
||||
x = Symbol('x')
|
||||
|
||||
expr1 = 2**x + log(x)/log(2) + log(x + 1) + exp(x) - 1
|
||||
opt1 = optimize(expr1, optims_c99).simplify()
|
||||
assert opt1 == exp2(x) + log2(x) + log1p(x) + expm1(x)
|
||||
assert opt1.rewrite(exp).rewrite(log).rewrite(Pow) == expr1
|
||||
|
||||
expr2 = log(x)/log(2) + log(x + 1)
|
||||
opt2 = optimize(expr2, optims_c99)
|
||||
assert opt2 == log2(x) + log1p(x)
|
||||
assert opt2.rewrite(log) == expr2
|
||||
|
||||
expr3 = log(x)/log(2) + log(17*x + 17)
|
||||
opt3 = optimize(expr3, optims_c99)
|
||||
delta3 = opt3 - (log2(x) + log(17) + log1p(x))
|
||||
assert delta3 == 0
|
||||
assert (opt3.rewrite(log) - expr3).simplify() == 0
|
||||
|
||||
expr4 = 2**x + 3*log(5*x + 7)/(13*log(2)) + 11*exp(x) - 11 + log(17*x + 17)
|
||||
opt4 = optimize(expr4, optims_c99).simplify()
|
||||
delta4 = opt4 - (exp2(x) + 3*log2(5*x + 7)/13 + 11*expm1(x) + log(17) + log1p(x))
|
||||
assert delta4 == 0
|
||||
assert (opt4.rewrite(exp).rewrite(log).rewrite(Pow) - expr4).simplify() == 0
|
||||
|
||||
expr5 = 3*exp(2*x) - 3
|
||||
opt5 = optimize(expr5, optims_c99)
|
||||
delta5 = opt5 - 3*expm1(2*x)
|
||||
assert delta5 == 0
|
||||
assert opt5.rewrite(exp) == expr5
|
||||
|
||||
expr6 = exp(2*x) - 3
|
||||
opt6 = optimize(expr6, optims_c99)
|
||||
assert opt6 in (expm1(2*x) - 2, expr6) # expm1(2*x) - 2 is not better or worse
|
||||
|
||||
expr7 = log(3*x + 3)
|
||||
opt7 = optimize(expr7, optims_c99)
|
||||
delta7 = opt7 - (log(3) + log1p(x))
|
||||
assert delta7 == 0
|
||||
assert (opt7.rewrite(log) - expr7).simplify() == 0
|
||||
|
||||
expr8 = log(2*x + 3)
|
||||
opt8 = optimize(expr8, optims_c99)
|
||||
assert opt8 == expr8
|
||||
|
||||
|
||||
def test_create_expand_pow_optimization():
|
||||
cc = lambda x: ccode(
|
||||
optimize(x, [create_expand_pow_optimization(4)]))
|
||||
x = Symbol('x')
|
||||
assert cc(x**4) == 'x*x*x*x'
|
||||
assert cc(x**4 + x**2) == 'x*x + x*x*x*x'
|
||||
assert cc(x**5 + x**4) == 'pow(x, 5) + x*x*x*x'
|
||||
assert cc(sin(x)**4) == 'pow(sin(x), 4)'
|
||||
# gh issue 15335
|
||||
assert cc(x**(-4)) == '1.0/(x*x*x*x)'
|
||||
assert cc(x**(-5)) == 'pow(x, -5)'
|
||||
assert cc(-x**4) == '-(x*x*x*x)'
|
||||
assert cc(x**4 - x**2) == '-(x*x) + x*x*x*x'
|
||||
i = Symbol('i', integer=True)
|
||||
assert cc(x**i - x**2) == 'pow(x, i) - (x*x)'
|
||||
y = Symbol('y', real=True)
|
||||
assert cc(Abs(exp(y**4))) == "exp(y*y*y*y)"
|
||||
|
||||
# gh issue 20753
|
||||
cc2 = lambda x: ccode(optimize(x, [create_expand_pow_optimization(
|
||||
4, base_req=lambda b: b.is_Function)]))
|
||||
assert cc2(x**3 + sin(x)**3) == "pow(x, 3) + sin(x)*sin(x)*sin(x)"
|
||||
|
||||
|
||||
def test_matsolve():
|
||||
n = Symbol('n', integer=True)
|
||||
A = MatrixSymbol('A', n, n)
|
||||
x = MatrixSymbol('x', n, 1)
|
||||
|
||||
with assuming(Q.fullrank(A)):
|
||||
assert optimize(A**(-1) * x, [matinv_opt]) == MatrixSolve(A, x)
|
||||
assert optimize(A**(-1) * x + x, [matinv_opt]) == MatrixSolve(A, x) + x
|
||||
|
||||
|
||||
def test_logaddexp_opt():
|
||||
x, y = map(Symbol, 'x y'.split())
|
||||
expr1 = log(exp(x) + exp(y))
|
||||
opt1 = optimize(expr1, [logaddexp_opt])
|
||||
assert logaddexp(x, y) - opt1 == 0
|
||||
assert logaddexp(y, x) - opt1 == 0
|
||||
assert opt1.rewrite(log) == expr1
|
||||
|
||||
|
||||
def test_logaddexp2_opt():
|
||||
x, y = map(Symbol, 'x y'.split())
|
||||
expr1 = log(2**x + 2**y)/log(2)
|
||||
opt1 = optimize(expr1, [logaddexp2_opt])
|
||||
assert logaddexp2(x, y) - opt1 == 0
|
||||
assert logaddexp2(y, x) - opt1 == 0
|
||||
assert opt1.rewrite(log) == expr1
|
||||
|
||||
|
||||
def test_sinc_opts():
|
||||
def check(d):
|
||||
for k, v in d.items():
|
||||
assert optimize(k, sinc_opts) == v
|
||||
|
||||
x = Symbol('x')
|
||||
check({
|
||||
sin(x)/x : sinc(x),
|
||||
sin(2*x)/(2*x) : sinc(2*x),
|
||||
sin(3*x)/x : 3*sinc(3*x),
|
||||
x*sin(x) : x*sin(x)
|
||||
})
|
||||
|
||||
y = Symbol('y')
|
||||
check({
|
||||
sin(x*y)/(x*y) : sinc(x*y),
|
||||
y*sin(x/y)/x : sinc(x/y),
|
||||
sin(sin(x))/sin(x) : sinc(sin(x)),
|
||||
sin(3*sin(x))/sin(x) : 3*sinc(3*sin(x)),
|
||||
sin(x)/y : sin(x)/y
|
||||
})
|
||||
|
||||
|
||||
def test_optims_numpy():
|
||||
def check(d):
|
||||
for k, v in d.items():
|
||||
assert optimize(k, optims_numpy) == v
|
||||
|
||||
x = Symbol('x')
|
||||
check({
|
||||
sin(2*x)/(2*x) + exp(2*x) - 1: sinc(2*x) + expm1(2*x),
|
||||
log(x+3)/log(2) + log(x**2 + 1): log1p(x**2) + log2(x+3)
|
||||
})
|
||||
|
||||
|
||||
@XFAIL # room for improvement, ideally this test case should pass.
|
||||
def test_optims_numpy_TODO():
|
||||
def check(d):
|
||||
for k, v in d.items():
|
||||
assert optimize(k, optims_numpy) == v
|
||||
|
||||
x, y = map(Symbol, 'x y'.split())
|
||||
check({
|
||||
log(x*y)*sin(x*y)*log(x*y+1)/(log(2)*x*y): log2(x*y)*sinc(x*y)*log1p(x*y),
|
||||
exp(x*sin(y)/y) - 1: expm1(x*sinc(y))
|
||||
})
|
||||
|
||||
|
||||
@may_xfail
|
||||
def test_compiled_ccode_with_rewriting():
|
||||
if not cython:
|
||||
skip("cython not installed.")
|
||||
if not has_c():
|
||||
skip("No C compiler found.")
|
||||
|
||||
x = Symbol('x')
|
||||
about_two = 2**(58/S(117))*3**(97/S(117))*5**(4/S(39))*7**(92/S(117))/S(30)*pi
|
||||
# about_two: 1.999999999999581826
|
||||
unchanged = 2*exp(x) - about_two
|
||||
xval = S(10)**-11
|
||||
ref = unchanged.subs(x, xval).n(19) # 2.0418173913673213e-11
|
||||
|
||||
rewritten = optimize(2*exp(x) - about_two, [expm1_opt])
|
||||
|
||||
# Unfortunately, we need to call ``.n()`` on our expressions before we hand them
|
||||
# to ``ccode``, and we need to request a large number of significant digits.
|
||||
# In this test, results converged for double precision when the following number
|
||||
# of significant digits were chosen:
|
||||
NUMBER_OF_DIGITS = 25 # TODO: this should ideally be automatically handled.
|
||||
|
||||
func_c = '''
|
||||
#include <math.h>
|
||||
|
||||
double func_unchanged(double x) {
|
||||
return %(unchanged)s;
|
||||
}
|
||||
double func_rewritten(double x) {
|
||||
return %(rewritten)s;
|
||||
}
|
||||
''' % {"unchanged": ccode(unchanged.n(NUMBER_OF_DIGITS)),
|
||||
"rewritten": ccode(rewritten.n(NUMBER_OF_DIGITS))}
|
||||
|
||||
func_pyx = '''
|
||||
#cython: language_level=3
|
||||
cdef extern double func_unchanged(double)
|
||||
cdef extern double func_rewritten(double)
|
||||
def py_unchanged(x):
|
||||
return func_unchanged(x)
|
||||
def py_rewritten(x):
|
||||
return func_rewritten(x)
|
||||
'''
|
||||
with tempfile.TemporaryDirectory() as folder:
|
||||
mod, info = compile_link_import_strings(
|
||||
[('func.c', func_c), ('_func.pyx', func_pyx)],
|
||||
build_dir=folder, compile_kwargs={"std": 'c99'}
|
||||
)
|
||||
err_rewritten = abs(mod.py_rewritten(1e-11) - ref)
|
||||
err_unchanged = abs(mod.py_unchanged(1e-11) - ref)
|
||||
assert 1e-27 < err_rewritten < 1e-25 # highly accurate.
|
||||
assert 1e-19 < err_unchanged < 1e-16 # quite poor.
|
||||
|
||||
# Tolerances used above were determined as follows:
|
||||
# >>> no_opt = unchanged.subs(x, xval.evalf()).evalf()
|
||||
# >>> with_opt = rewritten.n(25).subs(x, 1e-11).evalf()
|
||||
# >>> with_opt - ref, no_opt - ref
|
||||
# (1.1536301877952077e-26, 1.6547074214222335e-18)
|
||||
@@ -0,0 +1,44 @@
|
||||
from itertools import product
|
||||
from sympy.core.power import Pow
|
||||
from sympy.core.symbol import symbols
|
||||
from sympy.functions.elementary.exponential import exp, log
|
||||
from sympy.functions.elementary.trigonometric import cos
|
||||
from sympy.core.numbers import pi
|
||||
from sympy.codegen.scipy_nodes import cosm1, powm1
|
||||
|
||||
x, y, z = symbols('x y z')
|
||||
|
||||
|
||||
def test_cosm1():
|
||||
cm1_xy = cosm1(x*y)
|
||||
ref_xy = cos(x*y) - 1
|
||||
for wrt, deriv_order in product([x, y, z], range(3)):
|
||||
assert (
|
||||
cm1_xy.diff(wrt, deriv_order) -
|
||||
ref_xy.diff(wrt, deriv_order)
|
||||
).rewrite(cos).simplify() == 0
|
||||
|
||||
expr_minus2 = cosm1(pi)
|
||||
assert expr_minus2.rewrite(cos) == -2
|
||||
assert cosm1(3.14).simplify() == cosm1(3.14) # cannot simplify with 3.14
|
||||
assert cosm1(pi/2).simplify() == -1
|
||||
assert (1/cos(x) - 1 + cosm1(x)/cos(x)).simplify() == 0
|
||||
|
||||
|
||||
def test_powm1():
|
||||
cases = {
|
||||
powm1(x, y): x**y - 1,
|
||||
powm1(x*y, z): (x*y)**z - 1,
|
||||
powm1(x, y*z): x**(y*z)-1,
|
||||
powm1(x*y*z, x*y*z): (x*y*z)**(x*y*z)-1
|
||||
}
|
||||
for pm1_e, ref_e in cases.items():
|
||||
for wrt, deriv_order in product([x, y, z], range(3)):
|
||||
der = pm1_e.diff(wrt, deriv_order)
|
||||
ref = ref_e.diff(wrt, deriv_order)
|
||||
delta = (der - ref).rewrite(Pow)
|
||||
assert delta.simplify() == 0
|
||||
|
||||
eulers_constant_m1 = powm1(x, 1/log(x))
|
||||
assert eulers_constant_m1.rewrite(Pow) == exp(1) - 1
|
||||
assert eulers_constant_m1.simplify() == exp(1) - 1
|
||||
Reference in New Issue
Block a user