Commit 52097ca3 authored by Stefan Behnel's avatar Stefan Behnel

Give equality comparisons to integer constants a dedicated implementation that...

Give equality comparisons to integer constants a dedicated implementation that minimises branching and comparisons.
parent d556575c
...@@ -3297,6 +3297,7 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, ...@@ -3297,6 +3297,7 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
# mixed old-/new-style division is not currently optimised for integers # mixed old-/new-style division is not currently optimised for integers
return node return node
elif abs(numval.constant_result) > 2**30: elif abs(numval.constant_result) > 2**30:
# Cut off at an integer border that is still safe for all operations.
return node return node
args = list(args) args = list(args)
...@@ -3307,7 +3308,8 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, ...@@ -3307,7 +3308,8 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
args.append(ExprNodes.BoolNode(node.pos, value=inplace, constant_result=inplace)) args.append(ExprNodes.BoolNode(node.pos, value=inplace, constant_result=inplace))
utility_code = TempitaUtilityCode.load_cached( utility_code = TempitaUtilityCode.load_cached(
"PyFloatBinop" if is_float else "PyIntBinop", "Optimize.c", "PyFloatBinop" if is_float else "PyIntCompare" if operator in ('Eq', 'Ne') else "PyIntBinop",
"Optimize.c",
context=dict(op=operator, order=arg_order, ret_type=ret_type)) context=dict(op=operator, order=arg_order, ret_type=ret_type))
call_node = self._substitute_method_call( call_node = self._substitute_method_call(
......
...@@ -698,6 +698,90 @@ fallback: ...@@ -698,6 +698,90 @@ fallback:
} }
/////////////// PyIntCompare.proto ///////////////
{{py: c_ret_type = 'PyObject*' if ret_type.is_pyobject else 'int'}}
static CYTHON_INLINE {{c_ret_type}} __Pyx_PyInt_{{'' if ret_type.is_pyobject else 'Bool'}}{{op}}{{order}}(PyObject *op1, PyObject *op2, long intval, long inplace); /*proto*/
/////////////// PyIntCompare ///////////////
{{py: pyval, ival = ('op2', 'b') if order == 'CObj' else ('op1', 'a') }}
{{py: c_ret_type = 'PyObject*' if ret_type.is_pyobject else 'int'}}
{{py: return_true = 'Py_RETURN_TRUE' if ret_type.is_pyobject else 'return 1'}}
{{py: return_false = 'Py_RETURN_FALSE' if ret_type.is_pyobject else 'return 0'}}
{{py: slot_name = op.lower() }}
{{py: c_op = {'Eq': '==', 'Ne': '!='}[op] }}
{{py:
return_compare = (
(lambda a,b,c_op: "if ({a} {c_op} {b}) {return_true}; else {return_false};".format(
a=a, b=b, c_op=c_op, return_true=return_true, return_false=return_false))
if ret_type.is_pyobject else
(lambda a,b,c_op: "return ({a} {c_op} {b});".format(a=a, b=b, c_op=c_op))
)
}}
static CYTHON_INLINE {{c_ret_type}} __Pyx_PyInt_{{'' if ret_type.is_pyobject else 'Bool'}}{{op}}{{order}}(PyObject *op1, PyObject *op2, CYTHON_UNUSED long intval, CYTHON_UNUSED long inplace) {
if (op1 == op2) {
{{return_true if op == 'Eq' else return_false}};
}
#if PY_MAJOR_VERSION < 3
if (likely(PyInt_CheckExact({{pyval}}))) {
const long {{'a' if order == 'CObj' else 'b'}} = intval;
long {{ival}} = PyInt_AS_LONG({{pyval}});
{{return_compare('a', 'b', c_op)}}
}
#endif
#if CYTHON_USE_PYLONG_INTERNALS
if (likely(PyLong_CheckExact({{pyval}}))) {
int unequal;
unsigned long uintval;
Py_ssize_t size = Py_SIZE({{pyval}});
const digit* digits = ((PyLongObject*){{pyval}})->ob_digit;
if (intval == 0) {
// == 0 => Py_SIZE(pyval) == 0
{{return_compare('size', '0', c_op)}}
} else if (intval < 0) {
// < 0 => Py_SIZE(pyval) < 0
if (size >= 0)
{{return_false if op == 'Eq' else return_true}};
// both are negative => can use absolute values now.
intval = -intval;
size = -size;
} else {
// > 0 => Py_SIZE(pyval) > 0
if (size <= 0)
{{return_false if op == 'Eq' else return_true}};
}
// After checking that the sign is the same (and excluding 0), now compare the absolute values.
// When inlining, the C compiler should select exactly one line from this unrolled loop.
uintval = (unsigned long) intval;
if ((0));
{{for _size in range(4, 1, -1)}}
#if PyLong_SHIFT * {{_size}} < SIZEOF_LONG*8
else if (uintval >= {{_size-1}}UL * (unsigned long) PyLong_BASE)
unequal = (size != {{_size}}) || (digits[0] != (uintval & PyLong_MASK))
{{for _i in range(1, _size)}} | (digits[{{_i}}] != ((uintval >> ({{_i}} * PyLong_SHIFT)) & PyLong_MASK)){{endfor}};
#endif
{{endfor}}
else unequal = (size != 1) || (digits[0] != (uintval & PyLong_MASK));
{{return_compare('unequal', '0', c_op)}}
}
#endif
if (PyFloat_CheckExact({{pyval}})) {
const long {{'a' if order == 'CObj' else 'b'}} = intval;
double {{ival}} = PyFloat_AS_DOUBLE({{pyval}});
{{return_compare('(double)a', '(double)b', c_op)}}
}
return {{'' if ret_type.is_pyobject else '__Pyx_PyObject_IsTrueAndDecref'}}(
PyObject_RichCompare(op1, op2, Py_{{op.upper()}}));
}
/////////////// PyIntBinop.proto /////////////// /////////////// PyIntBinop.proto ///////////////
{{py: c_ret_type = 'PyObject*' if ret_type.is_pyobject else 'int'}} {{py: c_ret_type = 'PyObject*' if ret_type.is_pyobject else 'int'}}
......
...@@ -238,7 +238,10 @@ def mixed_int(obj2): ...@@ -238,7 +238,10 @@ def mixed_int(obj2):
@cython.test_assert_path_exists('//PythonCapiCallNode') @cython.test_assert_path_exists('//PythonCapiCallNode')
@cython.test_fail_if_path_exists('//IntBinopNode') @cython.test_fail_if_path_exists(
'//IntBinopNode',
'//PrimaryCmpNode',
)
def equals(obj2): def equals(obj2):
""" """
>>> equals(2) >>> equals(2)
...@@ -253,7 +256,10 @@ def equals(obj2): ...@@ -253,7 +256,10 @@ def equals(obj2):
@cython.test_assert_path_exists('//PythonCapiCallNode') @cython.test_assert_path_exists('//PythonCapiCallNode')
@cython.test_fail_if_path_exists('//IntBinopNode') @cython.test_fail_if_path_exists(
'//IntBinopNode',
'//PrimaryCmpNode',
)
def not_equals(obj2): def not_equals(obj2):
""" """
>>> not_equals(2) >>> not_equals(2)
...@@ -268,7 +274,182 @@ def not_equals(obj2): ...@@ -268,7 +274,182 @@ def not_equals(obj2):
@cython.test_assert_path_exists('//PythonCapiCallNode') @cython.test_assert_path_exists('//PythonCapiCallNode')
@cython.test_fail_if_path_exists('//IntBinopNode') @cython.test_assert_path_exists('//PrimaryCmpNode')
def equals_many(obj2):
"""
>>> equals_many(-2)
(False, False, False, False, False, False, False, False, False, False, False, False, False, False, False)
>>> equals_many(0)
(True, False, False, False, False, False, False, False, False, False, False, False, False, False, False)
>>> equals_many(1)
(False, True, False, False, False, False, False, False, False, False, False, False, False, False, False)
>>> equals_many(-1)
(False, False, True, False, False, False, False, False, False, False, False, False, False, False, False)
>>> equals_many(2**30)
(False, False, False, True, False, False, False, False, False, False, False, False, False, False, False)
>>> equals_many(-2**30)
(False, False, False, False, True, False, False, False, False, False, False, False, False, False, False)
>>> equals_many(2**30-1)
(False, False, False, False, False, True, False, False, False, False, False, False, False, False, False)
>>> equals_many(-2**30+1)
(False, False, False, False, False, False, True, False, False, False, False, False, False, False, False)
>>> equals_many(2**32)
(False, False, False, False, False, False, False, True, False, False, False, False, False, False, False)
>>> equals_many(-2**32)
(False, False, False, False, False, False, False, False, True, False, False, False, False, False, False)
>>> equals_many(2**45-1)
(False, False, False, False, False, False, False, False, False, True, False, False, False, False, False)
>>> equals_many(-2**45+1)
(False, False, False, False, False, False, False, False, False, False, True, False, False, False, False)
>>> equals_many(2**64)
(False, False, False, False, False, False, False, False, False, False, False, True, False, False, False)
>>> equals_many(-2**64)
(False, False, False, False, False, False, False, False, False, False, False, False, True, False, False)
>>> equals_many(2**64-1)
(False, False, False, False, False, False, False, False, False, False, False, False, False, True, False)
>>> equals_many(-2**64+1)
(False, False, False, False, False, False, False, False, False, False, False, False, False, False, True)
"""
cdef bint x, a, b, c, d, e, f, g, h, i, j, k, l, m, n, o
a = obj2 == 0
x = 0 == obj2
assert a == x
b = obj2 == 1
x = 1 == obj2
assert b == x
c = obj2 == -1
x = -1 == obj2
assert c == x
d = obj2 == 2**30
x = 2**30 == obj2
assert d == x
e = obj2 == -2**30
x = -2**30 == obj2
assert e == x
f = obj2 == 2**30-1
x = 2**30-1 == obj2
assert f == x
g = obj2 == -2**30+1
x = -2**30+1 == obj2
assert g == x
h = obj2 == 2**32
x = 2**32 == obj2
assert h == x
i = obj2 == -2**32
x = -2**32 == obj2
assert i == x
j = obj2 == 2**45-1
x = 2**45-1 == obj2
assert j == x
k = obj2 == -2**45+1
x = -2**45+1 == obj2
assert k == x
l = obj2 == 2**64
x = 2**64 == obj2
assert l == x
m = obj2 == -2**64
x = -2**64 == obj2
assert m == x
n = obj2 == 2**64-1
x = 2**64-1 == obj2
assert n == x
o = obj2 == -2**64+1
x = -2**64+1 == obj2
assert o == x
return (a, b, c, d, e, f, g, h, i, j, k, l, m, n, o)
@cython.test_assert_path_exists('//PythonCapiCallNode')
@cython.test_assert_path_exists('//PrimaryCmpNode')
def not_equals_many(obj2):
"""
>>> not_equals_many(-2)
(False, False, False, False, False, False, False, False, False, False, False, False, False, False, False)
>>> not_equals_many(0)
(True, False, False, False, False, False, False, False, False, False, False, False, False, False, False)
>>> not_equals_many(1)
(False, True, False, False, False, False, False, False, False, False, False, False, False, False, False)
>>> not_equals_many(-1)
(False, False, True, False, False, False, False, False, False, False, False, False, False, False, False)
>>> not_equals_many(2**30)
(False, False, False, True, False, False, False, False, False, False, False, False, False, False, False)
>>> not_equals_many(-2**30)
(False, False, False, False, True, False, False, False, False, False, False, False, False, False, False)
>>> not_equals_many(2**30-1)
(False, False, False, False, False, True, False, False, False, False, False, False, False, False, False)
>>> not_equals_many(-2**30+1)
(False, False, False, False, False, False, True, False, False, False, False, False, False, False, False)
>>> not_equals_many(2**32)
(False, False, False, False, False, False, False, True, False, False, False, False, False, False, False)
>>> not_equals_many(-2**32)
(False, False, False, False, False, False, False, False, True, False, False, False, False, False, False)
>>> not_equals_many(2**45-1)
(False, False, False, False, False, False, False, False, False, True, False, False, False, False, False)
>>> not_equals_many(-2**45+1)
(False, False, False, False, False, False, False, False, False, False, True, False, False, False, False)
>>> not_equals_many(2**64)
(False, False, False, False, False, False, False, False, False, False, False, True, False, False, False)
>>> not_equals_many(-2**64)
(False, False, False, False, False, False, False, False, False, False, False, False, True, False, False)
>>> not_equals_many(2**64-1)
(False, False, False, False, False, False, False, False, False, False, False, False, False, True, False)
>>> not_equals_many(-2**64+1)
(False, False, False, False, False, False, False, False, False, False, False, False, False, False, True)
"""
cdef bint a, b, c, d, e, f, g, h, i, j, k, l, m, n, o
a = obj2 != 0
x = 0 != obj2
assert a == x
b = obj2 != 1
x = 1 != obj2
assert b == x
c = obj2 != -1
x = -1 != obj2
assert c == x
d = obj2 != 2**30
x = 2**30 != obj2
assert d == x
e = obj2 != -2**30
x = -2**30 != obj2
assert e == x
f = obj2 != 2**30-1
x = 2**30-1 != obj2
assert f == x
g = obj2 != -2**30+1
x = -2**30+1 != obj2
assert g == x
h = obj2 != 2**32
x = 2**32 != obj2
assert h == x
i = obj2 != -2**32
x = -2**32 != obj2
assert i == x
j = obj2 != 2**45-1
x = 2**45-1 != obj2
assert j == x
k = obj2 != -2**45+1
x = -2**45+1 != obj2
assert k == x
l = obj2 != 2**64
x = 2**64 != obj2
assert l == x
m = obj2 != -2**64
x = -2**64 != obj2
assert m == x
n = obj2 != 2**64-1
x = 2**64-1 != obj2
assert n == x
o = obj2 != -2**64+1
x = -2**64+1 != obj2
assert o == x
return tuple(not x for x in (a, b, c, d, e, f, g, h, i, j, k, l, m, n, o))
@cython.test_assert_path_exists('//PythonCapiCallNode')
@cython.test_fail_if_path_exists(
'//IntBinopNode',
'//PrimaryCmpNode',
)
def equals_zero(obj2): def equals_zero(obj2):
""" """
>>> equals_zero(2) >>> equals_zero(2)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment