Skip to content

Commit ebe5e21

Browse files
[3.14] gh-143006: Fix and optimize mixed comparison of float and int (GH-143084) (GH-143623)
When comparing negative non-integer float and int with the same number of bits in the integer part, __neg__() in the int subclass returning not an int caused an assertion error. Now the integer is no longer negated. Also, reduced the number of temporary created Python objects. (cherry picked from commit 66bca38) Co-authored-by: Serhiy Storchaka <storchaka@gmail.com>
1 parent 18f9af2 commit ebe5e21

File tree

3 files changed

+52
-47
lines changed

3 files changed

+52
-47
lines changed

Lib/test/test_float.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -651,6 +651,24 @@ class F(float, H):
651651
value = F('nan')
652652
self.assertEqual(hash(value), object.__hash__(value))
653653

654+
def test_issue_gh143006(self):
655+
# When comparing negative non-integer float and int with the
656+
# same number of bits in the integer part, __neg__() in the
657+
# int subclass returning not an int caused an assertion error.
658+
class EvilInt(int):
659+
def __neg__(self):
660+
return ""
661+
662+
i = -1 << 50
663+
f = float(i) - 0.5
664+
i = EvilInt(i)
665+
self.assertFalse(f == i)
666+
self.assertTrue(f != i)
667+
self.assertTrue(f < i)
668+
self.assertTrue(f <= i)
669+
self.assertFalse(f > i)
670+
self.assertFalse(f >= i)
671+
654672

655673
@unittest.skipUnless(hasattr(float, "__getformat__"), "requires __getformat__")
656674
class FormatFunctionsTestCase(unittest.TestCase):
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Fix a possible assertion error when comparing negative non-integer ``float``
2+
and ``int`` with the same number of bits in the integer part.

Objects/floatobject.c

Lines changed: 32 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -442,82 +442,67 @@ float_richcompare(PyObject *v, PyObject *w, int op)
442442
assert(vsign != 0); /* if vsign were 0, then since wsign is
443443
* not 0, we would have taken the
444444
* vsign != wsign branch at the start */
445-
/* We want to work with non-negative numbers. */
446-
if (vsign < 0) {
447-
/* "Multiply both sides" by -1; this also swaps the
448-
* comparator.
449-
*/
450-
i = -i;
451-
op = _Py_SwappedOp[op];
452-
}
453-
assert(i > 0.0);
454445
(void) frexp(i, &exponent);
455446
/* exponent is the # of bits in v before the radix point;
456447
* we know that nbits (the # of bits in w) > 48 at this point
457448
*/
458449
if (exponent < nbits) {
459-
i = 1.0;
460-
j = 2.0;
450+
j = i;
451+
i = 0.0;
461452
goto Compare;
462453
}
463454
if (exponent > nbits) {
464-
i = 2.0;
465-
j = 1.0;
455+
j = 0.0;
466456
goto Compare;
467457
}
468458
/* v and w have the same number of bits before the radix
469-
* point. Construct two ints that have the same comparison
470-
* outcome.
459+
* point. Construct an int from the integer part of v and
460+
* update op if necessary, so comparing two ints has the same outcome.
471461
*/
472462
{
473463
double fracpart;
474464
double intpart;
475465
PyObject *result = NULL;
476466
PyObject *vv = NULL;
477-
PyObject *ww = w;
478467

479-
if (wsign < 0) {
480-
ww = PyNumber_Negative(w);
481-
if (ww == NULL)
482-
goto Error;
468+
fracpart = modf(i, &intpart);
469+
if (fracpart != 0.0) {
470+
switch (op) {
471+
/* Non-integer float never equals to an int. */
472+
case Py_EQ:
473+
Py_RETURN_FALSE;
474+
case Py_NE:
475+
Py_RETURN_TRUE;
476+
/* For non-integer float, v <= w <=> v < w.
477+
* If v > 0: trunc(v) < v < trunc(v) + 1
478+
* v < w => trunc(v) < w
479+
* trunc(v) < w => trunc(v) + 1 <= w => v < w
480+
* If v < 0: trunc(v) - 1 < v < trunc(v)
481+
* v < w => trunc(v) - 1 < w => trunc(v) <= w
482+
* trunc(v) <= w => v < w
483+
*/
484+
case Py_LT:
485+
case Py_LE:
486+
op = vsign > 0 ? Py_LT : Py_LE;
487+
break;
488+
/* The same as above, but with opposite directions. */
489+
case Py_GT:
490+
case Py_GE:
491+
op = vsign > 0 ? Py_GE : Py_GT;
492+
break;
493+
}
483494
}
484-
else
485-
Py_INCREF(ww);
486495

487-
fracpart = modf(i, &intpart);
488496
vv = PyLong_FromDouble(intpart);
489497
if (vv == NULL)
490498
goto Error;
491499

492-
if (fracpart != 0.0) {
493-
/* Shift left, and or a 1 bit into vv
494-
* to represent the lost fraction.
495-
*/
496-
PyObject *temp;
497-
498-
temp = _PyLong_Lshift(ww, 1);
499-
if (temp == NULL)
500-
goto Error;
501-
Py_SETREF(ww, temp);
502-
503-
temp = _PyLong_Lshift(vv, 1);
504-
if (temp == NULL)
505-
goto Error;
506-
Py_SETREF(vv, temp);
507-
508-
temp = PyNumber_Or(vv, _PyLong_GetOne());
509-
if (temp == NULL)
510-
goto Error;
511-
Py_SETREF(vv, temp);
512-
}
513-
514-
r = PyObject_RichCompareBool(vv, ww, op);
500+
r = PyObject_RichCompareBool(vv, w, op);
515501
if (r < 0)
516502
goto Error;
517503
result = PyBool_FromLong(r);
518504
Error:
519505
Py_XDECREF(vv);
520-
Py_XDECREF(ww);
521506
return result;
522507
}
523508
} /* else if (PyLong_Check(w)) */

0 commit comments

Comments
 (0)