Skip to content

Commit 51a5821

Browse files
committed
refactor arithmetic power
1 parent ada4a4a commit 51a5821

File tree

4 files changed

+235
-84
lines changed

4 files changed

+235
-84
lines changed

mathics/builtin/arithfns/basic.py

Lines changed: 61 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
77
"""
88

9+
import sympy
10+
911
from mathics.builtin.arithmetic import _MPMathFunction, create_infix
1012
from mathics.builtin.base import BinaryOperator, Builtin, PrefixOperator, SympyFunction
1113
from mathics.core.atoms import (
@@ -38,7 +40,6 @@
3840
Symbol,
3941
SymbolDivide,
4042
SymbolHoldForm,
41-
SymbolNull,
4243
SymbolPower,
4344
SymbolTimes,
4445
)
@@ -49,10 +50,17 @@
4950
SymbolInfix,
5051
SymbolLeft,
5152
SymbolMinus,
53+
SymbolOverflow,
5254
SymbolPattern,
53-
SymbolSequence,
5455
)
55-
from mathics.eval.arithmetic import eval_Plus, eval_Times
56+
from mathics.eval.arithmetic import (
57+
associate_powers,
58+
eval_Exponential,
59+
eval_Plus,
60+
eval_Power_inexact,
61+
eval_Power_number,
62+
eval_Times,
63+
)
5664
from mathics.eval.nevaluator import eval_N
5765
from mathics.eval.numerify import numerify
5866

@@ -520,6 +528,8 @@ class Power(BinaryOperator, _MPMathFunction):
520528
rules = {
521529
"Power[]": "1",
522530
"Power[x_]": "x",
531+
"Power[I,-1]": "-I",
532+
"Power[-1, 1/2]": "I",
523533
}
524534

525535
summary_text = "exponentiate"
@@ -528,15 +538,15 @@ class Power(BinaryOperator, _MPMathFunction):
528538
# Remember to up sympy doc link when this is corrected
529539
sympy_name = "Pow"
530540

541+
def eval_exp(self, x, evaluation):
542+
"Power[E, x]"
543+
return eval_Exponential(x)
544+
531545
def eval_check(self, x, y, evaluation):
532546
"Power[x_, y_]"
533-
534-
# Power uses _MPMathFunction but does some error checking first
535-
if isinstance(x, Number) and x.is_zero:
536-
if isinstance(y, Number):
537-
y_err = y
538-
else:
539-
y_err = eval_N(y, evaluation)
547+
# if x is zero
548+
if x.is_zero:
549+
y_err = y if isinstance(y, Number) else eval_N(y, evaluation)
540550
if isinstance(y_err, Number):
541551
py_y = y_err.round_to_float(permit_complex=True).real
542552
if py_y > 0:
@@ -550,17 +560,47 @@ def eval_check(self, x, y, evaluation):
550560
evaluation.message(
551561
"Power", "infy", Expression(SymbolPower, x, y_err)
552562
)
553-
return SymbolComplexInfinity
554-
if isinstance(x, Complex) and x.real.is_zero:
555-
yhalf = Expression(SymbolTimes, y, RationalOneHalf)
556-
factor = self.eval(Expression(SymbolSequence, x.imag, y), evaluation)
557-
return Expression(
558-
SymbolTimes, factor, Expression(SymbolPower, IntegerM1, yhalf)
559-
)
560-
561-
result = self.eval(Expression(SymbolSequence, x, y), evaluation)
562-
if result is None or result != SymbolNull:
563-
return result
563+
return SymbolComplexInfinity
564+
565+
# If x and y are inexact numbers, use the numerical function
566+
567+
if x.is_inexact() and y.is_inexact():
568+
try:
569+
return eval_Power_inexact(x, y)
570+
except OverflowError:
571+
evaluation.message("General", "ovfl")
572+
return Expression(SymbolOverflow)
573+
574+
# Tries to associate powers a^b^c-> a^(b*c)
575+
assoc = associate_powers(x, y)
576+
if not assoc.has_form("Power", 2):
577+
return assoc
578+
579+
assoc = numerify(assoc, evaluation)
580+
x, y = assoc.elements
581+
# If x and y are numbers
582+
if isinstance(x, Number) and isinstance(y, Number):
583+
try:
584+
return eval_Power_number(x, y)
585+
except OverflowError:
586+
evaluation.message("General", "ovfl")
587+
return Expression(SymbolOverflow)
588+
589+
# if x or y are inexact, leave the expression
590+
# as it is:
591+
if x.is_inexact() or y.is_inexact():
592+
return assoc
593+
594+
# Finally, try to convert to sympy
595+
base_sp, exp_sp = x.to_sympy(), y.to_sympy()
596+
if base_sp is None or exp_sp is None:
597+
# If base or exp can not be converted to sympy,
598+
# returns the result of applying the associative
599+
# rule.
600+
return assoc
601+
602+
result = from_sympy(sympy.Pow(base_sp, exp_sp))
603+
return result.evaluate_elements(evaluation)
564604

565605

566606
class Sqrt(SympyFunction):
@@ -788,7 +828,6 @@ def inverse(item):
788828
and isinstance(item.elements[1], (Integer, Rational, Real))
789829
and item.elements[1].to_sympy() < 0
790830
): # nopep8
791-
792831
negative.append(inverse(item))
793832
elif isinstance(item, Rational):
794833
numerator = item.numerator()

mathics/eval/arithmetic.py

Lines changed: 137 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# -*- coding: utf-8 -*-
22

33
"""
4-
arithmetic-related evaluation functions.
4+
helper functions for arithmetic evaluation, which do not
5+
depends on the evaluation context. Conversions to Sympy are
6+
used just as a last resource.
57
68
Many of these do do depend on the evaluation context. Conversions to Sympy are
79
used just as a last resource.
@@ -319,48 +321,27 @@ def eval_complex_sign(n: BaseElement) -> Optional[BaseElement]:
319321
sign = eval_RealSign(expr)
320322
return sign or eval_complex_sign(expr)
321323

322-
if expr.has_form("Power", 2):
323-
base, exp = expr.elements
324-
if exp.is_zero:
325-
return Integer1
326-
if isinstance(exp, (Integer, Real, Rational)):
327-
sign = eval_Sign(base) or Expression(SymbolSign, base)
328-
return Expression(SymbolPower, sign, exp)
329-
if isinstance(exp, Complex):
330-
sign = eval_Sign(base) or Expression(SymbolSign, base)
331-
return Expression(SymbolPower, sign, exp.real)
332-
if test_arithmetic_expr(exp):
333-
sign = eval_Sign(base) or Expression(SymbolSign, base)
334-
return Expression(SymbolPower, sign, exp)
335-
return None
336-
if expr.get_head() is SymbolTimes:
337-
abs_value = eval_Abs(eval_multiply_numbers(*expr.elements))
338-
if abs_value is Integer1:
339-
return expr
340-
if abs_value is None:
341-
return None
342-
criteria = eval_add_numbers(abs_value, IntegerM1)
343-
if test_zero_arithmetic_expr(criteria, numeric=True):
344-
return expr
345-
return None
346-
if expr.get_head() is SymbolPlus:
347-
abs_value = eval_Abs(eval_add_numbers(*expr.elements))
348-
if abs_value is Integer1:
349-
return expr
350-
if abs_value is None:
351-
return None
352-
criteria = eval_add_numbers(abs_value, IntegerM1)
353-
if test_zero_arithmetic_expr(criteria, numeric=True):
354-
return expr
355-
return None
356324

357-
if test_arithmetic_expr(expr):
358-
if test_zero_arithmetic_expr(expr):
359-
return Integer0
360-
if test_positive_arithmetic_expr(expr):
361-
return Integer1
362-
if test_negative_arithmetic_expr(expr):
363-
return IntegerM1
325+
def eval_Sign_number(n: Number) -> Number:
326+
"""
327+
Evals the absolute value of a number.
328+
"""
329+
if n.is_zero:
330+
return Integer0
331+
if isinstance(n, (Integer, Rational, Real)):
332+
return Integer1 if n.value > 0 else IntegerM1
333+
if isinstance(n, Complex):
334+
abs_sq = eval_add_numbers(
335+
*(eval_multiply_numbers(x, x) for x in (n.real, n.imag))
336+
)
337+
criteria = eval_add_numbers(abs_sq, IntegerM1)
338+
if test_zero_arithmetic_expr(criteria):
339+
return n
340+
if n.is_inexact():
341+
return eval_multiply_numbers(n, eval_Power_number(abs_sq, RealM0p5))
342+
if test_zero_arithmetic_expr(criteria, numeric=True):
343+
return n
344+
return eval_multiply_numbers(n, eval_Power_number(abs_sq, RationalMOneHalf))
364345

365346

366347
def eval_mpmath_function(
@@ -390,6 +371,31 @@ def eval_mpmath_function(
390371
return call_mpmath(mpmath_function, tuple(mpmath_args), prec)
391372

392373

374+
def eval_Exponential(exp: BaseElement) -> BaseElement:
375+
"""
376+
Eval E^exp
377+
"""
378+
# If both base and exponent are exact quantities,
379+
# use sympy.
380+
381+
if not exp.is_inexact():
382+
exp_sp = exp.to_sympy()
383+
if exp_sp is None:
384+
return None
385+
return from_sympy(sympy.Exp(exp_sp))
386+
387+
prec = exp.get_precision()
388+
if prec is not None:
389+
if exp.is_machine_precision():
390+
number = mpmath.exp(exp.to_mpmath())
391+
result = from_mpmath(number)
392+
return result
393+
else:
394+
with mpmath.workprec(prec):
395+
number = mpmath.exp(exp.to_mpmath())
396+
return from_mpmath(number, prec)
397+
398+
393399
def eval_Plus(*items: BaseElement) -> BaseElement:
394400
"evaluate Plus for general elements"
395401
numbers, items_tuple = segregate_numbers_from_sorted_list(*items)
@@ -457,6 +463,13 @@ def append_last():
457463
elements_properties=ElementsProperties(False, False, True),
458464
)
459465

466+
elements.sort()
467+
return Expression(
468+
SymbolPlus,
469+
*elements,
470+
elements_properties=ElementsProperties(False, False, True),
471+
)
472+
460473

461474
def eval_Power_number(base: Number, exp: Number) -> Optional[Number]:
462475
"""
@@ -688,8 +701,88 @@ def eval_Times(*items: BaseElement) -> BaseElement:
688701
)
689702

690703

704+
# Here I used the convention of calling eval_* to functions that can produce a new expression, or None
705+
# if the result can not be evaluated, or is trivial. For example, if we call eval_Power_number(Integer2, RationalOneHalf)
706+
# it returns ``None`` instead of ``Expression(SymbolPower, Integer2, RationalOneHalf)``.
707+
# The reason is that these functions are written to be part of replacement rules, to be applied during the evaluation process.
708+
# In that process, a rule is considered applied if produces an expression that is different from the original one, or
709+
# if the replacement function returns (Python's) ``None``.
710+
#
711+
# For example, when the expression ``Power[4, 1/2]`` is evaluated, a (Builtin) rule ``Power[base_, exp_]->eval_repl_rule(base, expr)``
712+
# is applied. If the rule matches, `repl_rule` is called with arguments ``(4, 1/2)`` and produces `2`. As `Integer2.sameQ(Power[4, 1/2])`
713+
# is False, then no new rules for `Power` are checked, and a new round of evaluation is atempted.
714+
#
715+
# On the other hand, if ``Power[3, 1/2]``, ``repl_rule`` can do two possible things: one is return ``Power[3, 1/2]``. If it does,
716+
# the rule is considered applied. Then, the evaluation method checks if `Power[3, 1/2].sameQ(Power[3, 1/2])`. In this case it is true,
717+
# and then the expression is kept as it is.
718+
# The other possibility is to return (Python's) `None`. In that case, the evaluator considers that the rule failed to be applied,
719+
# and look for another rule associated to ``Power``. To return ``None`` produces then a faster evaluation, since no ``sameQ`` call is needed,
720+
# and do not prevent that other rules are attempted.
721+
#
722+
# The bad part of using ``None`` as a return is that I would expect that ``eval`` produces always a valid Expression, so if at some point of
723+
# the code I call ``eval_Power_number(Integer3, RationalOneHalf)`` I get ``Expression(SymbolPower, Integer3, RationalOneHalf)``.
724+
#
725+
# From my point of view, it would make more sense to use the following convention:
726+
# * if the method has signature ``eval_method(...)->BaseElement:`` then use the prefix ``eval_``
727+
# * if the method has the siguature ``apply_method(...)->Optional[BaseElement]`` use the prefix ``apply_`` or maybe ``repl_``.
728+
#
729+
# In any case, let's keep the current convention.
730+
#
731+
#
732+
733+
734+
def associate_powers(expr: BaseElement, power: BaseElement = Integer1) -> BaseElement:
735+
"""
736+
base^a^b^c^...^power -> base^(a*b*c*...power)
737+
provided one of the following cases
738+
* `a`, `b`, ... `power` are all integer numbers
739+
* `a`, `b`,... are Rational/Real number with absolute value <=1,
740+
and the other powers are not integer numbers.
741+
* `a` is not a Rational/Real number, and b, c, ... power are all
742+
integer numbers.
743+
"""
744+
powers = []
745+
base = expr
746+
if power is not Integer1:
747+
powers.append(power)
748+
749+
while base.has_form("Power", 2):
750+
previous_base, outer_power = base, power
751+
base, power = base.elements
752+
if len(powers) == 0:
753+
if power is not Integer1:
754+
powers.append(power)
755+
continue
756+
if power is IntegerM1:
757+
powers.append(power)
758+
continue
759+
if isinstance(power, (Rational, Real)):
760+
if abs(power.value) < 1:
761+
powers.append(power)
762+
continue
763+
# power is not rational/real and outer_power is integer,
764+
elif isinstance(outer_power, Integer):
765+
if power is not Integer1:
766+
powers.append(power)
767+
if isinstance(power, Integer):
768+
continue
769+
else:
770+
break
771+
# in any other case, use the previous base and
772+
# exit the loop
773+
base = previous_base
774+
break
775+
776+
if len(powers) == 0:
777+
return base
778+
elif len(powers) == 1:
779+
return Expression(SymbolPower, base, powers[0])
780+
result = Expression(SymbolPower, base, Expression(SymbolTimes, *powers))
781+
return result
782+
783+
691784
def eval_add_numbers(
692-
*numbers: Number,
785+
*numbers: List[Number],
693786
) -> BaseElement:
694787
"""
695788
Add the elements in ``numbers``.
@@ -736,7 +829,7 @@ def eval_inverse_number(n: Number) -> Number:
736829
return eval_Power_number(n, IntegerM1)
737830

738831

739-
def eval_multiply_numbers(*numbers: Number) -> Number:
832+
def eval_multiply_numbers(*numbers: Number) -> BaseElement:
740833
"""
741834
Multiply the elements in ``numbers``.
742835
"""

test/builtin/arithmetic/test_basic.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,8 @@ def test_multiply(str_expr, str_expected, msg):
153153
("a b DirectedInfinity[q]", "a b (q Infinity)", ""),
154154
# Failing tests
155155
# Problem with formatting. Parenthezise are missing...
156-
# ("a b DirectedInfinity[-I]", "a b (-I Infinity)", ""),
157-
# ("a b DirectedInfinity[-3]", "a b (-Infinity)", ""),
156+
("a b DirectedInfinity[-I]", "a b (-I Infinity)", ""),
157+
("a b DirectedInfinity[-3]", "a b (-Infinity)", ""),
158158
],
159159
)
160160
def test_directed_infinity_precedence(str_expr, str_expected, msg):
@@ -197,7 +197,7 @@ def test_directed_infinity_precedence(str_expr, str_expected, msg):
197197
("I^(2/3)", "(-1) ^ (1 / 3)", None),
198198
# In WMA, the next test would return ``-(-I)^(2/3)``
199199
# which is less compact and elegant...
200-
# ("(-I)^(2/3)", "(-1) ^ (-1 / 3)", None),
200+
("(-I)^(2/3)", "(-1) ^ (-1 / 3)", None),
201201
("(2+3I)^3", "-46 + 9 I", None),
202202
("(1.+3. I)^.6", "1.46069 + 1.35921 I", None),
203203
("3^(1+2 I)", "3 ^ (1 + 2 I)", None),
@@ -208,15 +208,15 @@ def test_directed_infinity_precedence(str_expr, str_expected, msg):
208208
# sympy, which produces the result
209209
("(3/Pi)^(-I)", "(3 / Pi) ^ (-I)", None),
210210
# Association rules
211-
# ('(a^"w")^2', 'a^(2 "w")', "Integer power of a power with string exponent"),
211+
('(a^"w")^2', 'a^(2 "w")', "Integer power of a power with string exponent"),
212212
('(a^2)^"w"', '(a ^ 2) ^ "w"', None),
213213
('(a^2)^"w"', '(a ^ 2) ^ "w"', None),
214214
("(a^2)^(1/2)", "Sqrt[a ^ 2]", None),
215215
("(a^(1/2))^2", "a", None),
216216
("(a^(1/2))^2", "a", None),
217217
("(a^(3/2))^3.", "(a ^ (3 / 2)) ^ 3.", None),
218-
# ("(a^(1/2))^3.", "a ^ 1.5", "Power associativity rational, real"),
219-
# ("(a^(.3))^3.", "a ^ 0.9", "Power associativity for real powers"),
218+
("(a^(1/2))^3.", "a ^ 1.5", "Power associativity rational, real"),
219+
("(a^(.3))^3.", "a ^ 0.9", "Power associativity for real powers"),
220220
("(a^(1.3))^3.", "(a ^ 1.3) ^ 3.", None),
221221
# Exponentials involving expressions
222222
("(a^(p-2 q))^3", "a ^ (3 p - 6 q)", None),

0 commit comments

Comments
 (0)