diff --git a/Include/internal/pycore_optimizer_types.h b/Include/internal/pycore_optimizer_types.h index 7e0dbddce2d6b8..e2323265798156 100644 --- a/Include/internal/pycore_optimizer_types.h +++ b/Include/internal/pycore_optimizer_types.h @@ -76,6 +76,8 @@ typedef struct { typedef enum { JIT_PRED_IS, JIT_PRED_IS_NOT, + JIT_PRED_EQ, + JIT_PRED_NE, } JitOptPredicateKind; typedef struct { diff --git a/Lib/test/test_capi/test_opt.py b/Lib/test/test_capi/test_opt.py index 7c33320e9f1785..e23ea81e7c4565 100644 --- a/Lib/test/test_capi/test_opt.py +++ b/Lib/test/test_capi/test_opt.py @@ -890,6 +890,138 @@ def testfunc(n): self.assertLessEqual(len(guard_nos_unicode_count), 1) self.assertIn("_COMPARE_OP_STR", uops) + def test_compare_int_eq_narrows_to_constant(self): + def f(n): + def return_1(): + return 1 + + hits = 0 + v = return_1() + for _ in range(n): + if v == 1: + if v == 1: + hits += 1 + return hits + + res, ex = self._run_with_optimizer(f, TIER2_THRESHOLD) + self.assertEqual(res, TIER2_THRESHOLD) + self.assertIsNotNone(ex) + uops = get_opnames(ex) + + # Constant narrowing allows constant folding for second comparison + self.assertLessEqual(count_ops(ex, "_COMPARE_OP_INT"), 1) + + def test_compare_int_ne_narrows_to_constant(self): + def f(n): + def return_1(): + return 1 + + hits = 0 + v = return_1() + for _ in range(n): + if v != 1: + hits += 1000 + else: + if v == 1: + hits += v + 1 + return hits + + res, ex = self._run_with_optimizer(f, TIER2_THRESHOLD) + self.assertEqual(res, TIER2_THRESHOLD * 2) + self.assertIsNotNone(ex) + uops = get_opnames(ex) + + # Constant narrowing allows constant folding for second comparison + self.assertLessEqual(count_ops(ex, "_COMPARE_OP_INT"), 1) + + def test_compare_float_eq_narrows_to_constant(self): + def f(n): + def return_tenth(): + return 0.1 + + hits = 0 + v = return_tenth() + for _ in range(n): + if v == 0.1: + if v == 0.1: + hits += 1 + return hits + + res, ex = self._run_with_optimizer(f, TIER2_THRESHOLD) + self.assertEqual(res, TIER2_THRESHOLD) + self.assertIsNotNone(ex) + uops = get_opnames(ex) + + # Constant narrowing allows constant folding for second comparison + self.assertLessEqual(count_ops(ex, "_COMPARE_OP_FLOAT"), 1) + + def test_compare_float_ne_narrows_to_constant(self): + def f(n): + def return_tenth(): + return 0.1 + + hits = 0 + v = return_tenth() + for _ in range(n): + if v != 0.1: + hits += 1000 + else: + if v == 0.1: + hits += 1 + return hits + + res, ex = self._run_with_optimizer(f, TIER2_THRESHOLD) + self.assertEqual(res, TIER2_THRESHOLD) + self.assertIsNotNone(ex) + uops = get_opnames(ex) + + # Constant narrowing allows constant folding for second comparison + self.assertLessEqual(count_ops(ex, "_COMPARE_OP_FLOAT"), 1) + + def test_compare_str_eq_narrows_to_constant(self): + def f(n): + def return_hello(): + return "hello" + + hits = 0 + v = return_hello() + for _ in range(n): + if v == "hello": + if v == "hello": + hits += 1 + return hits + + res, ex = self._run_with_optimizer(f, TIER2_THRESHOLD) + self.assertEqual(res, TIER2_THRESHOLD) + self.assertIsNotNone(ex) + uops = get_opnames(ex) + + # Constant narrowing allows constant folding for second comparison + self.assertLessEqual(count_ops(ex, "_COMPARE_OP_STR"), 1) + + def test_compare_str_ne_narrows_to_constant(self): + def f(n): + def return_hello(): + return "hello" + + hits = 0 + v = return_hello() + for _ in range(n): + if v != "hello": + hits += 1000 + else: + if v == "hello": + hits += 1 + return hits + + res, ex = self._run_with_optimizer(f, TIER2_THRESHOLD) + self.assertEqual(res, TIER2_THRESHOLD) + self.assertIsNotNone(ex) + uops = get_opnames(ex) + + # Constant narrowing allows constant folding for second comparison + self.assertLessEqual(count_ops(ex, "_COMPARE_OP_STR"), 1) + @unittest.skip("gh-139109 WIP") def test_combine_stack_space_checks_sequential(self): def dummy12(x): diff --git a/Python/optimizer_analysis.c b/Python/optimizer_analysis.c index e4e259a81b510f..9e012eb3f64878 100644 --- a/Python/optimizer_analysis.c +++ b/Python/optimizer_analysis.c @@ -250,6 +250,11 @@ add_op(JitOptContext *ctx, _PyUOpInstruction *this_instr, #define sym_new_predicate _Py_uop_sym_new_predicate #define sym_apply_predicate_narrowing _Py_uop_sym_apply_predicate_narrowing +/* Comparison oparg masks */ +#define COMPARE_LT_MASK 2 +#define COMPARE_GT_MASK 4 +#define COMPARE_EQ_MASK 8 + #define JUMP_TO_LABEL(label) goto label; static int diff --git a/Python/optimizer_bytecodes.c b/Python/optimizer_bytecodes.c index 0ccc788dff962d..2fdc0854cc27e6 100644 --- a/Python/optimizer_bytecodes.c +++ b/Python/optimizer_bytecodes.c @@ -514,21 +514,51 @@ dummy_func(void) { } op(_COMPARE_OP_INT, (left, right -- res, l, r)) { - res = sym_new_type(ctx, &PyBool_Type); + int cmp_mask = oparg & (COMPARE_LT_MASK | COMPARE_GT_MASK | COMPARE_EQ_MASK); + + if (cmp_mask == COMPARE_EQ_MASK) { + res = sym_new_predicate(ctx, left, right, JIT_PRED_EQ); + } + else if (cmp_mask == (COMPARE_LT_MASK | COMPARE_GT_MASK)) { + res = sym_new_predicate(ctx, left, right, JIT_PRED_NE); + } + else { + res = sym_new_type(ctx, &PyBool_Type); + } l = left; r = right; REPLACE_OPCODE_IF_EVALUATES_PURE(left, right, res); } op(_COMPARE_OP_FLOAT, (left, right -- res, l, r)) { - res = sym_new_type(ctx, &PyBool_Type); + int cmp_mask = oparg & (COMPARE_LT_MASK | COMPARE_GT_MASK | COMPARE_EQ_MASK); + + if (cmp_mask == COMPARE_EQ_MASK) { + res = sym_new_predicate(ctx, left, right, JIT_PRED_EQ); + } + else if (cmp_mask == (COMPARE_LT_MASK | COMPARE_GT_MASK)) { + res = sym_new_predicate(ctx, left, right, JIT_PRED_NE); + } + else { + res = sym_new_type(ctx, &PyBool_Type); + } l = left; r = right; REPLACE_OPCODE_IF_EVALUATES_PURE(left, right, res); } op(_COMPARE_OP_STR, (left, right -- res, l, r)) { - res = sym_new_type(ctx, &PyBool_Type); + int cmp_mask = oparg & (COMPARE_LT_MASK | COMPARE_GT_MASK | COMPARE_EQ_MASK); + + if (cmp_mask == COMPARE_EQ_MASK) { + res = sym_new_predicate(ctx, left, right, JIT_PRED_EQ); + } + else if (cmp_mask == (COMPARE_LT_MASK | COMPARE_GT_MASK)) { + res = sym_new_predicate(ctx, left, right, JIT_PRED_NE); + } + else { + res = sym_new_type(ctx, &PyBool_Type); + } l = left; r = right; REPLACE_OPCODE_IF_EVALUATES_PURE(left, right, res); diff --git a/Python/optimizer_cases.c.h b/Python/optimizer_cases.c.h index f62e15b987c0eb..f9cb376d4c2ed3 100644 --- a/Python/optimizer_cases.c.h +++ b/Python/optimizer_cases.c.h @@ -2107,7 +2107,16 @@ JitOptRef r; right = stack_pointer[-1]; left = stack_pointer[-2]; - res = sym_new_type(ctx, &PyBool_Type); + int cmp_mask = oparg & (COMPARE_LT_MASK | COMPARE_GT_MASK | COMPARE_EQ_MASK); + if (cmp_mask == COMPARE_EQ_MASK) { + res = sym_new_predicate(ctx, left, right, JIT_PRED_EQ); + } + else if (cmp_mask == (COMPARE_LT_MASK | COMPARE_GT_MASK)) { + res = sym_new_predicate(ctx, left, right, JIT_PRED_NE); + } + else { + res = sym_new_type(ctx, &PyBool_Type); + } l = left; r = right; if ( @@ -2167,7 +2176,16 @@ JitOptRef r; right = stack_pointer[-1]; left = stack_pointer[-2]; - res = sym_new_type(ctx, &PyBool_Type); + int cmp_mask = oparg & (COMPARE_LT_MASK | COMPARE_GT_MASK | COMPARE_EQ_MASK); + if (cmp_mask == COMPARE_EQ_MASK) { + res = sym_new_predicate(ctx, left, right, JIT_PRED_EQ); + } + else if (cmp_mask == (COMPARE_LT_MASK | COMPARE_GT_MASK)) { + res = sym_new_predicate(ctx, left, right, JIT_PRED_NE); + } + else { + res = sym_new_type(ctx, &PyBool_Type); + } l = left; r = right; if ( @@ -2231,7 +2249,16 @@ JitOptRef r; right = stack_pointer[-1]; left = stack_pointer[-2]; - res = sym_new_type(ctx, &PyBool_Type); + int cmp_mask = oparg & (COMPARE_LT_MASK | COMPARE_GT_MASK | COMPARE_EQ_MASK); + if (cmp_mask == COMPARE_EQ_MASK) { + res = sym_new_predicate(ctx, left, right, JIT_PRED_EQ); + } + else if (cmp_mask == (COMPARE_LT_MASK | COMPARE_GT_MASK)) { + res = sym_new_predicate(ctx, left, right, JIT_PRED_NE); + } + else { + res = sym_new_type(ctx, &PyBool_Type); + } l = left; r = right; if ( diff --git a/Python/optimizer_symbols.c b/Python/optimizer_symbols.c index a9640aaa5072c5..51cf6e189f0f49 100644 --- a/Python/optimizer_symbols.c +++ b/Python/optimizer_symbols.c @@ -875,9 +875,11 @@ _Py_uop_sym_apply_predicate_narrowing(JitOptContext *ctx, JitOptRef ref, bool br bool narrow = false; switch(pred.kind) { + case JIT_PRED_EQ: case JIT_PRED_IS: narrow = branch_is_true; break; + case JIT_PRED_NE: case JIT_PRED_IS_NOT: narrow = !branch_is_true; break; @@ -1300,11 +1302,11 @@ _Py_uop_symbols_test(PyObject *Py_UNUSED(self), PyObject *Py_UNUSED(ignored)) TEST_PREDICATE(_Py_uop_sym_is_const(ctx, subject), "predicate narrowing did not const-narrow subject (None)"); TEST_PREDICATE(_Py_uop_sym_get_const(ctx, subject) == Py_None, "predicate narrowing did not narrow subject to None"); - // Test narrowing subject to numerical constant + // Test narrowing subject to numerical constant from is comparison subject = _Py_uop_sym_new_unknown(ctx); PyObject *one_obj = PyLong_FromLong(1); JitOptRef const_one = _Py_uop_sym_new_const(ctx, one_obj); - if (PyJitRef_IsNull(subject) || PyJitRef_IsNull(const_one)) { + if (PyJitRef_IsNull(subject) || one_obj == NULL || PyJitRef_IsNull(const_one)) { goto fail; } ref = _Py_uop_sym_new_predicate(ctx, subject, const_one, JIT_PRED_IS); @@ -1315,6 +1317,160 @@ _Py_uop_symbols_test(PyObject *Py_UNUSED(self), PyObject *Py_UNUSED(ignored)) TEST_PREDICATE(_Py_uop_sym_is_const(ctx, subject), "predicate narrowing did not const-narrow subject (1)"); TEST_PREDICATE(_Py_uop_sym_get_const(ctx, subject) == one_obj, "predicate narrowing did not narrow subject to 1"); + // Test narrowing subject to constant from EQ predicate for int + subject = _Py_uop_sym_new_unknown(ctx); + if (PyJitRef_IsNull(subject)) { + goto fail; + } + ref = _Py_uop_sym_new_predicate(ctx, subject, const_one, JIT_PRED_EQ); + if (PyJitRef_IsNull(ref)) { + goto fail; + } + _Py_uop_sym_apply_predicate_narrowing(ctx, ref, true); + TEST_PREDICATE(_Py_uop_sym_is_const(ctx, subject), "predicate narrowing did not const-narrow subject (1)"); + TEST_PREDICATE(_Py_uop_sym_get_const(ctx, subject) == one_obj, "predicate narrowing did not narrow subject to 1"); + + // Resolving EQ predicate to False should not narrow subject for int + subject = _Py_uop_sym_new_unknown(ctx); + if (PyJitRef_IsNull(subject)) { + goto fail; + } + ref = _Py_uop_sym_new_predicate(ctx, subject, const_one, JIT_PRED_EQ); + if (PyJitRef_IsNull(ref)) { + goto fail; + } + _Py_uop_sym_apply_predicate_narrowing(ctx, ref, false); + TEST_PREDICATE(!_Py_uop_sym_is_const(ctx, subject), "predicate narrowing incorrectly narrowed subject (inverted/true)"); + + // Test narrowing subject to constant from NE predicate for int + subject = _Py_uop_sym_new_unknown(ctx); + if (PyJitRef_IsNull(subject)) { + goto fail; + } + ref = _Py_uop_sym_new_predicate(ctx, subject, const_one, JIT_PRED_NE); + if (PyJitRef_IsNull(ref)) { + goto fail; + } + _Py_uop_sym_apply_predicate_narrowing(ctx, ref, false); + TEST_PREDICATE(_Py_uop_sym_is_const(ctx, subject), "predicate narrowing did not const-narrow subject (1)"); + TEST_PREDICATE(_Py_uop_sym_get_const(ctx, subject) == one_obj, "predicate narrowing did not narrow subject to 1"); + + // Resolving NE predicate to true should not narrow subject for int + subject = _Py_uop_sym_new_unknown(ctx); + if (PyJitRef_IsNull(subject)) { + goto fail; + } + ref = _Py_uop_sym_new_predicate(ctx, subject, const_one, JIT_PRED_NE); + if (PyJitRef_IsNull(ref)) { + goto fail; + } + _Py_uop_sym_apply_predicate_narrowing(ctx, ref, true); + TEST_PREDICATE(!_Py_uop_sym_is_const(ctx, subject), "predicate narrowing incorrectly narrowed subject (inverted/true)"); + + // Test narrowing subject to constant from EQ predicate for float + subject = _Py_uop_sym_new_unknown(ctx); + PyObject *float_tenth_obj = PyFloat_FromDouble(0.1); + JitOptRef const_float_tenth = _Py_uop_sym_new_const(ctx, float_tenth_obj); + if (PyJitRef_IsNull(subject) || float_tenth_obj == NULL || PyJitRef_IsNull(const_float_tenth)) { + goto fail; + } + ref = _Py_uop_sym_new_predicate(ctx, subject, const_float_tenth, JIT_PRED_EQ); + if (PyJitRef_IsNull(ref)) { + goto fail; + } + _Py_uop_sym_apply_predicate_narrowing(ctx, ref, true); + TEST_PREDICATE(_Py_uop_sym_is_const(ctx, subject), "predicate narrowing did not const-narrow subject (float)"); + TEST_PREDICATE(_Py_uop_sym_get_const(ctx, subject) == float_tenth_obj, "predicate narrowing did not narrow subject to 0.1"); + + // Resolving EQ predicate to False should not narrow subject for float + subject = _Py_uop_sym_new_unknown(ctx); + if (PyJitRef_IsNull(subject)) { + goto fail; + } + ref = _Py_uop_sym_new_predicate(ctx, subject, const_float_tenth, JIT_PRED_EQ); + if (PyJitRef_IsNull(ref)) { + goto fail; + } + _Py_uop_sym_apply_predicate_narrowing(ctx, ref, false); + TEST_PREDICATE(!_Py_uop_sym_is_const(ctx, subject), "predicate narrowing incorrectly narrowed subject (inverted/true)"); + + // Test narrowing subject to constant from NE predicate for float + subject = _Py_uop_sym_new_unknown(ctx); + if (PyJitRef_IsNull(subject)) { + goto fail; + } + ref = _Py_uop_sym_new_predicate(ctx, subject, const_float_tenth, JIT_PRED_NE); + if (PyJitRef_IsNull(ref)) { + goto fail; + } + _Py_uop_sym_apply_predicate_narrowing(ctx, ref, false); + TEST_PREDICATE(_Py_uop_sym_is_const(ctx, subject), "predicate narrowing did not const-narrow subject (float)"); + TEST_PREDICATE(_Py_uop_sym_get_const(ctx, subject) == float_tenth_obj, "predicate narrowing did not narrow subject to 0.1"); + + // Resolving NE predicate to true should not narrow subject for float + subject = _Py_uop_sym_new_unknown(ctx); + if (PyJitRef_IsNull(subject)) { + goto fail; + } + ref = _Py_uop_sym_new_predicate(ctx, subject, const_float_tenth, JIT_PRED_NE); + if (PyJitRef_IsNull(ref)) { + goto fail; + } + _Py_uop_sym_apply_predicate_narrowing(ctx, ref, true); + TEST_PREDICATE(!_Py_uop_sym_is_const(ctx, subject), "predicate narrowing incorrectly narrowed subject (inverted/true)"); + + // Test narrowing subject to constant from EQ predicate for str + subject = _Py_uop_sym_new_unknown(ctx); + PyObject *str_hello_obj = PyUnicode_FromString("hello"); + JitOptRef const_str_hello = _Py_uop_sym_new_const(ctx, str_hello_obj); + if (PyJitRef_IsNull(subject) || str_hello_obj == NULL || PyJitRef_IsNull(const_str_hello)) { + goto fail; + } + ref = _Py_uop_sym_new_predicate(ctx, subject, const_str_hello, JIT_PRED_EQ); + if (PyJitRef_IsNull(ref)) { + goto fail; + } + _Py_uop_sym_apply_predicate_narrowing(ctx, ref, true); + TEST_PREDICATE(_Py_uop_sym_is_const(ctx, subject), "predicate narrowing did not const-narrow subject (str)"); + TEST_PREDICATE(_Py_uop_sym_get_const(ctx, subject) == str_hello_obj, "predicate narrowing did not narrow subject to hello"); + + // Resolving EQ predicate to False should not narrow subject for str + subject = _Py_uop_sym_new_unknown(ctx); + if (PyJitRef_IsNull(subject)) { + goto fail; + } + ref = _Py_uop_sym_new_predicate(ctx, subject, const_str_hello, JIT_PRED_EQ); + if (PyJitRef_IsNull(ref)) { + goto fail; + } + _Py_uop_sym_apply_predicate_narrowing(ctx, ref, false); + TEST_PREDICATE(!_Py_uop_sym_is_const(ctx, subject), "predicate narrowing incorrectly narrowed subject (inverted/true)"); + + // Test narrowing subject to constant from NE predicate for str + subject = _Py_uop_sym_new_unknown(ctx); + if (PyJitRef_IsNull(subject)) { + goto fail; + } + ref = _Py_uop_sym_new_predicate(ctx, subject, const_str_hello, JIT_PRED_NE); + if (PyJitRef_IsNull(ref)) { + goto fail; + } + _Py_uop_sym_apply_predicate_narrowing(ctx, ref, false); + TEST_PREDICATE(_Py_uop_sym_is_const(ctx, subject), "predicate narrowing did not const-narrow subject (str)"); + TEST_PREDICATE(_Py_uop_sym_get_const(ctx, subject) == str_hello_obj, "predicate narrowing did not narrow subject to hello"); + + // Resolving NE predicate to true should not narrow subject for str + subject = _Py_uop_sym_new_unknown(ctx); + if (PyJitRef_IsNull(subject)) { + goto fail; + } + ref = _Py_uop_sym_new_predicate(ctx, subject, const_str_hello, JIT_PRED_NE); + if (PyJitRef_IsNull(ref)) { + goto fail; + } + _Py_uop_sym_apply_predicate_narrowing(ctx, ref, true); + TEST_PREDICATE(!_Py_uop_sym_is_const(ctx, subject), "predicate narrowing incorrectly narrowed subject (inverted/true)"); + val_big = PyNumber_Lshift(_PyLong_GetOne(), PyLong_FromLong(66)); if (val_big == NULL) { goto fail;