Skip to content

Commit f88d3c2

Browse files
authored
Convert RootSum to SymPy (#1136)
I split this out into a separate PR as I think it's going to require some more work. We already had a SymPy -> Mathics conversion for RootSum, this adds one going the other way too. However, it results in some tests failing as it means that Simplify automatically expands RootSums now, not sure if we want to add some hints to prevent that from happening.
1 parent 6212284 commit f88d3c2

File tree

6 files changed

+89
-13
lines changed

6 files changed

+89
-13
lines changed

mathics/builtin/functional/application.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,19 @@
1010

1111
from itertools import chain
1212

13+
import sympy
14+
15+
from mathics.core.atoms import Integer, Integer1
1316
from mathics.core.attributes import A_HOLD_ALL, A_N_HOLD_ALL, A_PROTECTED
14-
from mathics.core.builtin import Builtin, PostfixOperator
17+
from mathics.core.builtin import Builtin, PostfixOperator, SympyFunction
1518
from mathics.core.convert.sympy import SymbolFunction
1619
from mathics.core.evaluation import Evaluation
1720
from mathics.core.expression import Expression
18-
from mathics.core.symbols import Symbol
21+
from mathics.core.symbols import Symbol, sympy_slot_prefix
22+
from mathics.core.systemsymbols import SymbolSlot
1923

2024

21-
class Function(PostfixOperator):
25+
class Function(PostfixOperator, SympyFunction):
2226
"""
2327
<dl>
2428
<dt>'Function[$body$]'
@@ -119,9 +123,11 @@ def eval_named(self, vars, body, args, evaluation: Evaluation):
119123
# this is not included in WL, and here does not have any impact, but it is needed for
120124
# translating the function to a compiled version.
121125
var_names = (
122-
var.get_name()
123-
if isinstance(var, Symbol)
124-
else var.elements[0].get_name()
126+
(
127+
var.get_name()
128+
if isinstance(var, Symbol)
129+
else var.elements[0].get_name()
130+
)
125131
for var in vars
126132
)
127133
vars = dict(list(zip(var_names, args[: len(vars)])))
@@ -148,8 +154,17 @@ def eval_named_attr(self, vars, body, attr, args, evaluation: Evaluation):
148154
except Exception:
149155
return
150156

157+
def to_sympy(self, expr: Expression, **kwargs):
158+
if len(expr.elements) == 1:
159+
body = expr.elements[0]
160+
slot = Expression(SymbolSlot, Integer1)
161+
return sympy.Lambda(slot.to_sympy(), body.to_sympy())
162+
else:
163+
# TODO: Handle multiple and/or named arguments
164+
raise NotImplementedError
165+
151166

152-
class Slot(Builtin):
167+
class Slot(SympyFunction):
153168
"""
154169
<dl>
155170
<dt>'#$n$'
@@ -184,6 +199,10 @@ class Slot(Builtin):
184199
}
185200
summary_text = "one argument of a pure function"
186201

202+
def to_sympy(self, expr: Expression, **kwargs):
203+
index: Integer = expr.elements[0]
204+
return sympy.Symbol(f"{sympy_slot_prefix}{index.get_int_value()}")
205+
187206

188207
class SlotSequence(Builtin):
189208
"""

mathics/builtin/list/constructing.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,10 +188,12 @@ class Normal(Builtin):
188188

189189
summary_text = "convert objects to normal expressions"
190190

191-
def eval_general(self, expr, evaluation: Evaluation):
191+
def eval_general(self, expr: Expression, evaluation: Evaluation):
192192
"Normal[expr_]"
193193
if isinstance(expr, Atom):
194194
return
195+
if expr.has_form("RootSum", 2):
196+
return from_sympy(expr.to_sympy().doit(roots=True))
195197
return Expression(
196198
expr.get_head(),
197199
*[Expression(SymbolNormal, element) for element in expr.elements],

mathics/builtin/numbers/calculus.py

Lines changed: 55 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,12 @@
4141
from mathics.core.convert.expression import to_expression, to_mathics_list
4242
from mathics.core.convert.function import expression_to_callable_and_args
4343
from mathics.core.convert.python import from_python
44-
from mathics.core.convert.sympy import SympyExpression, from_sympy, sympy_symbol_prefix
44+
from mathics.core.convert.sympy import (
45+
SymbolRootSum,
46+
SympyExpression,
47+
from_sympy,
48+
sympy_symbol_prefix,
49+
)
4550
from mathics.core.evaluation import Evaluation
4651
from mathics.core.expression import Expression
4752
from mathics.core.list import ListExpression
@@ -63,6 +68,7 @@
6368
SymbolConditionalExpression,
6469
SymbolD,
6570
SymbolDerivative,
71+
SymbolFunction,
6672
SymbolIndeterminate,
6773
SymbolInfinity,
6874
SymbolInfix,
@@ -76,6 +82,7 @@
7682
SymbolSeries,
7783
SymbolSeriesData,
7884
SymbolSimplify,
85+
SymbolSlot,
7986
SymbolUndefined,
8087
)
8188
from mathics.eval.makeboxes import format_element
@@ -1627,7 +1634,7 @@ class Root(SympyFunction):
16271634
16281635
Roots that can't be represented by radicals:
16291636
>> Root[#1 ^ 5 + 2 #1 + 1&, 2]
1630-
= Root[#1 ^ 5 + 2 #1 + 1&, 2]
1637+
= Root[1 + #1 ^ 5 + 2 #1&, 2]
16311638
"""
16321639

16331640
messages = {
@@ -1691,6 +1698,52 @@ def to_sympy(self, expr, **kwargs):
16911698
return None
16921699

16931700

1701+
class RootSum(SympyFunction):
1702+
"""
1703+
<url>:WMA link: https://reference.wolfram.com/language/ref/RootSum.html</url>
1704+
1705+
<dl>
1706+
<dt>'RootSum[$f$, $form$]'
1707+
<dd>sums $form[x]$ for all roots of the polynomial $f[x]$.
1708+
</dl>
1709+
1710+
>> Integrate[1/(x^5 + 11 x + 1), {x, 1, 3}]
1711+
= RootSum[-1 - 212960 #1 ^ 3 - 9680 #1 ^ 2 - 165 #1 + 41232181 #1 ^ 5&, (Log[3749971 - 3512322106304 #1 ^ 4 + 453522741 #1 + 16326568676 #1 ^ 2 + 79825502416 #1 ^ 3] - 4 Log[5]) #1&] - RootSum[-1 - 212960 #1 ^ 3 - 9680 #1 ^ 2 - 165 #1 + 41232181 #1 ^ 5&, (Log[3748721 - 3512322106304 #1 ^ 4 + 453522741 #1 + 16326568676 #1 ^ 2 + 79825502416 #1 ^ 3] - 4 Log[5]) #1&]
1712+
>> N[%, 50]
1713+
= 0.051278805184286949884270940103072421286139857550894
1714+
1715+
>> RootSum[#^5 - 11 # + 1 &, (#^2 - 1)/(#^3 - 2 # + c) &]
1716+
= (538 - 88 c + 396 c ^ 2 + 5 c ^ 3 - 5 c ^ 4) / (97 - 529 c - 53 c ^ 2 + 88 c ^ 3 + c ^ 5)
1717+
1718+
>> RootSum[#^5 - 3 # - 7 &, Sin] //N//Chop
1719+
= 0.292188
1720+
1721+
Use Normal to expand RootSum:
1722+
>> RootSum[1+#+#^2+#^3+#^4 &, Log[x + #] &]
1723+
= RootSum[1 + #1 ^ 2 + #1 ^ 3 + #1 ^ 4 + #1&, Log[x + #1]&]
1724+
>> %//Normal
1725+
= Log[-1 / 4 - Sqrt[5] / 4 - I Sqrt[5 / 8 - Sqrt[5] / 8] + x] + Log[-1 / 4 - Sqrt[5] / 4 + I Sqrt[5 / 8 - Sqrt[5] / 8] + x] + Log[-1 / 4 - I Sqrt[5 / 8 + Sqrt[5] / 8] + Sqrt[5] / 4 + x] + Log[-1 / 4 + I Sqrt[5 / 8 + Sqrt[5] / 8] + Sqrt[5] / 4 + x]
1726+
"""
1727+
1728+
summary_text = "sum polynomial roots"
1729+
1730+
def eval(self, f, form, evaluation: Evaluation): # type: ignore[override]
1731+
"RootSum[f_, form_]"
1732+
return from_sympy(Expression(SymbolRootSum, f, form).to_sympy())
1733+
1734+
def to_sympy(self, expr: Expression, **kwargs):
1735+
func = expr.elements[1]
1736+
if not isinstance(func.to_sympy(), sympy.Lambda):
1737+
# eta conversion
1738+
func = Expression(
1739+
SymbolFunction, Expression(func, Expression(SymbolSlot, Integer1))
1740+
)
1741+
1742+
poly = expr.elements[0].to_sympy()
1743+
poly_x = sympy.Symbol("poly_x")
1744+
return sympy.RootSum(poly(poly_x), func.to_sympy(), x=poly_x)
1745+
1746+
16941747
class Series(Builtin):
16951748
"""
16961749
<url>:WMA link:https://reference.wolfram.com/language/ref/Series.html</url>

mathics/core/convert/sympy.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def __new__(cls, *exprs):
139139
if all(isinstance(expr, BasicSympy) for expr in exprs):
140140
# called with SymPy arguments
141141
obj = super().__new__(cls, *exprs)
142+
obj.expr = None
142143
elif len(exprs) == 1 and isinstance(exprs[0], Expression):
143144
# called with Mathics argument
144145
expr = exprs[0]
@@ -460,7 +461,7 @@ def old_from_sympy(expr) -> BaseElement:
460461
result.append(Expression(SymbolTimes, *factors))
461462
else:
462463
result.append(Integer1)
463-
return Expression(SymbolFunction, Expression(SymbolPlus, *result))
464+
return Expression(SymbolFunction, Expression(SymbolPlus, *sorted(result)))
464465
if isinstance(expr, sympy.CRootOf):
465466
try:
466467
e_root, indx = expr.args

mathics/eval/nevaluator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def eval_NValues(
9696

9797
# Special case for the Root builtin
9898
# This should be implemented as an NValue
99-
if expr.has_form("Root", 2):
99+
if expr.has_form("Root", 2) or expr.has_form("RootSum", 2):
100100
return from_sympy(sympy.N(expr.to_sympy(), d))
101101

102102
# Here we look for the NValues associated to the

mathics/eval/numbers/algebra/simplify.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@ def _default_complexity_function(x):
8282

8383
# At this point, ``complexity_function`` is a function that takes a
8484
# sympy expression and returns an integer.
85-
sympy_result = simplify(sympy_expr, measure=complexity_function)
85+
sympy_result = simplify(sympy_expr, measure=complexity_function, doit=False)
86+
sympy_result = sympy_result.doit(roots=False) # Don't expand RootSum
8687

8788
# and bring it back
8889
result = from_sympy(sympy_result).evaluate(evaluation)

0 commit comments

Comments
 (0)