Skip to content

Commit 618a424

Browse files
[3.13] gh-143006: Fix and optimize mixed comparison of float and int (GH-143084) (GH-143624)
When comparing negative non-integer float and int with the same number of bits in the integer part, __neg__() in the int subclass returning not an int caused an assertion error. Now the integer is no longer negated. Also, reduced the number of temporary created Python objects. (cherry picked from commit 66bca38)
1 parent 57c56b0 commit 618a424

File tree

3 files changed

+52
-47
lines changed

3 files changed

+52
-47
lines changed

Lib/test/test_float.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,24 @@ class F(float, H):
618618
value = F('nan')
619619
self.assertEqual(hash(value), object.__hash__(value))
620620

621+
def test_issue_gh143006(self):
622+
# When comparing negative non-integer float and int with the
623+
# same number of bits in the integer part, __neg__() in the
624+
# int subclass returning not an int caused an assertion error.
625+
class EvilInt(int):
626+
def __neg__(self):
627+
return ""
628+
629+
i = -1 << 50
630+
f = float(i) - 0.5
631+
i = EvilInt(i)
632+
self.assertFalse(f == i)
633+
self.assertTrue(f != i)
634+
self.assertTrue(f < i)
635+
self.assertTrue(f <= i)
636+
self.assertFalse(f > i)
637+
self.assertFalse(f >= i)
638+
621639

622640
@unittest.skipUnless(hasattr(float, "__getformat__"), "requires __getformat__")
623641
class FormatFunctionsTestCase(unittest.TestCase):
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
Fix a possible assertion error when comparing negative non-integer ``float``
2+
and ``int`` with the same number of bits in the integer part.

Objects/floatobject.c

Lines changed: 32 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -469,82 +469,67 @@ float_richcompare(PyObject *v, PyObject *w, int op)
469469
assert(vsign != 0); /* if vsign were 0, then since wsign is
470470
* not 0, we would have taken the
471471
* vsign != wsign branch at the start */
472-
/* We want to work with non-negative numbers. */
473-
if (vsign < 0) {
474-
/* "Multiply both sides" by -1; this also swaps the
475-
* comparator.
476-
*/
477-
i = -i;
478-
op = _Py_SwappedOp[op];
479-
}
480-
assert(i > 0.0);
481472
(void) frexp(i, &exponent);
482473
/* exponent is the # of bits in v before the radix point;
483474
* we know that nbits (the # of bits in w) > 48 at this point
484475
*/
485476
if (exponent < 0 || (size_t)exponent < nbits) {
486-
i = 1.0;
487-
j = 2.0;
477+
j = i;
478+
i = 0.0;
488479
goto Compare;
489480
}
490481
if ((size_t)exponent > nbits) {
491-
i = 2.0;
492-
j = 1.0;
482+
j = 0.0;
493483
goto Compare;
494484
}
495485
/* v and w have the same number of bits before the radix
496-
* point. Construct two ints that have the same comparison
497-
* outcome.
486+
* point. Construct an int from the integer part of v and
487+
* update op if necessary, so comparing two ints has the same outcome.
498488
*/
499489
{
500490
double fracpart;
501491
double intpart;
502492
PyObject *result = NULL;
503493
PyObject *vv = NULL;
504-
PyObject *ww = w;
505494

506-
if (wsign < 0) {
507-
ww = PyNumber_Negative(w);
508-
if (ww == NULL)
509-
goto Error;
495+
fracpart = modf(i, &intpart);
496+
if (fracpart != 0.0) {
497+
switch (op) {
498+
/* Non-integer float never equals to an int. */
499+
case Py_EQ:
500+
Py_RETURN_FALSE;
501+
case Py_NE:
502+
Py_RETURN_TRUE;
503+
/* For non-integer float, v <= w <=> v < w.
504+
* If v > 0: trunc(v) < v < trunc(v) + 1
505+
* v < w => trunc(v) < w
506+
* trunc(v) < w => trunc(v) + 1 <= w => v < w
507+
* If v < 0: trunc(v) - 1 < v < trunc(v)
508+
* v < w => trunc(v) - 1 < w => trunc(v) <= w
509+
* trunc(v) <= w => v < w
510+
*/
511+
case Py_LT:
512+
case Py_LE:
513+
op = vsign > 0 ? Py_LT : Py_LE;
514+
break;
515+
/* The same as above, but with opposite directions. */
516+
case Py_GT:
517+
case Py_GE:
518+
op = vsign > 0 ? Py_GE : Py_GT;
519+
break;
520+
}
510521
}
511-
else
512-
Py_INCREF(ww);
513522

514-
fracpart = modf(i, &intpart);
515523
vv = PyLong_FromDouble(intpart);
516524
if (vv == NULL)
517525
goto Error;
518526

519-
if (fracpart != 0.0) {
520-
/* Shift left, and or a 1 bit into vv
521-
* to represent the lost fraction.
522-
*/
523-
PyObject *temp;
524-
525-
temp = _PyLong_Lshift(ww, 1);
526-
if (temp == NULL)
527-
goto Error;
528-
Py_SETREF(ww, temp);
529-
530-
temp = _PyLong_Lshift(vv, 1);
531-
if (temp == NULL)
532-
goto Error;
533-
Py_SETREF(vv, temp);
534-
535-
temp = PyNumber_Or(vv, _PyLong_GetOne());
536-
if (temp == NULL)
537-
goto Error;
538-
Py_SETREF(vv, temp);
539-
}
540-
541-
r = PyObject_RichCompareBool(vv, ww, op);
527+
r = PyObject_RichCompareBool(vv, w, op);
542528
if (r < 0)
543529
goto Error;
544530
result = PyBool_FromLong(r);
545531
Error:
546532
Py_XDECREF(vv);
547-
Py_XDECREF(ww);
548533
return result;
549534
}
550535
} /* else if (PyLong_Check(w)) */

0 commit comments

Comments
 (0)