Commit 4a25f1c1 authored by Stefan Behnel's avatar Stefan Behnel

optimise string-to-something comparisons also when we know that the result is...

optimise string-to-something comparisons also when we know that the result is boolean (the cmp functions have a proper fallback for non-strings)
parent 4227971d
......@@ -9930,11 +9930,11 @@ class CmpNode(object):
return (container_type.is_ptr or container_type.is_array) \
and not container_type.is_string
def find_special_bool_compare_function(self, env, operand1):
def find_special_bool_compare_function(self, env, operand1, result_is_bool=False):
# note: currently operand1 must get coerced to a Python object if we succeed here!
if self.operator in ('==', '!='):
type1, type2 = operand1.type, self.operand2.type
if type1.is_builtin_type and type2.is_builtin_type:
if result_is_bool or (type1.is_builtin_type and type2.is_builtin_type):
if type1 is Builtin.unicode_type or type2 is Builtin.unicode_type:
self.special_bool_cmp_utility_code = UtilityCode.load_cached("UnicodeEquals", "StringTools.c")
self.special_bool_cmp_function = "__Pyx_PyUnicode_Equals"
......@@ -10184,6 +10184,7 @@ class PrimaryCmpNode(ExprNode, CmpNode):
else:
self.operand1 = self.operand1.coerce_to(func_type.args[0].type, env)
self.operand2 = self.operand2.coerce_to(func_type.args[1].type, env)
self.is_pycmp = False
self.type = func_type.return_type
def analyse_memoryviewslice_comparison(self, env):
......@@ -10199,6 +10200,23 @@ class PrimaryCmpNode(ExprNode, CmpNode):
return False
def coerce_to_boolean(self, env):
if self.is_pycmp:
# coercing to bool => may allow for more efficient comparison code
if self.find_special_bool_compare_function(
env, self.operand1, result_is_bool=True):
self.is_pycmp = False
self.type = PyrexTypes.c_bint_type
self.is_temp = 1
if self.cascade:
operand2 = self.cascade.optimise_comparison(
self.operand2, env, result_is_bool=True)
if operand2 is not self.operand2:
self.coerced_operand2 = operand2
return self
# TODO: check if we can optimise parts of the cascade here
return ExprNode.coerce_to_boolean(self, env)
def has_python_operands(self):
return (self.operand1.type.is_pyobject
or self.operand2.type.is_pyobject)
......@@ -10320,12 +10338,14 @@ class CascadedCmpNode(Node, CmpNode):
def has_python_operands(self):
return self.operand2.type.is_pyobject
def optimise_comparison(self, operand1, env):
if self.find_special_bool_compare_function(env, operand1):
def optimise_comparison(self, operand1, env, result_is_bool=False):
if self.find_special_bool_compare_function(env, operand1, result_is_bool):
self.is_pycmp = False
self.type = PyrexTypes.c_bint_type
if not operand1.type.is_pyobject:
operand1 = operand1.coerce_to_pyobject(env)
if self.cascade:
operand2 = self.cascade.optimise_comparison(self.operand2, env)
operand2 = self.cascade.optimise_comparison(self.operand2, env, result_is_bool)
if operand2 is not self.operand2:
self.coerced_operand2 = operand2
return operand1
......
......@@ -15,6 +15,10 @@ ustring2 = u"1234567"
# unicode
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
)
def unicode_eq(unicode s1, unicode s2):
"""
>>> unicode_eq(ustring1, ustring1)
......@@ -26,6 +30,10 @@ def unicode_eq(unicode s1, unicode s2):
"""
return s1 == s2
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
)
def unicode_neq(unicode s1, unicode s2):
"""
>>> unicode_neq(ustring1, ustring1)
......@@ -37,6 +45,10 @@ def unicode_neq(unicode s1, unicode s2):
"""
return s1 != s2
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
)
def unicode_literal_eq(unicode s):
"""
>>> unicode_literal_eq(ustring1)
......@@ -48,6 +60,10 @@ def unicode_literal_eq(unicode s):
"""
return s == u"abcdefg"
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
)
def unicode_literal_neq(unicode s):
"""
>>> unicode_literal_neq(ustring1)
......@@ -59,6 +75,15 @@ def unicode_literal_neq(unicode s):
"""
return s != u"abcdefg"
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
"//CascadedCmpNode"
)
@cython.test_fail_if_path_exists(
"//CascadedCmpNode[@is_pycmp = True]",
"//PrimaryCmpNode[@is_pycmp = True]",
)
def unicode_cascade(unicode s1, unicode s2):
"""
>>> unicode_cascade(ustring1, ustring1)
......@@ -70,6 +95,10 @@ def unicode_cascade(unicode s1, unicode s2):
"""
return s1 == s2 == u"abcdefg"
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
)
def unicode_cascade_untyped_end(unicode s1, unicode s2):
"""
>>> unicode_cascade_untyped_end(ustring1, ustring1)
......@@ -81,8 +110,31 @@ def unicode_cascade_untyped_end(unicode s1, unicode s2):
"""
return s1 == s2 == u"abcdefg" == (<object>ustring1) == ustring1
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
)
def unicode_cascade_untyped_end_bool(unicode s1, unicode s2):
"""
>>> unicode_cascade_untyped_end_bool(ustring1, ustring1)
True
>>> unicode_cascade_untyped_end_bool(ustring1, (ustring1+ustring2)[:len(ustring1)])
True
>>> unicode_cascade_untyped_end_bool(ustring1, ustring2)
False
"""
if s1 == s2 == u"abcdefg" == (<object>ustring1) == ustring1:
return True
else:
return False
# str
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
)
def str_eq(str s1, str s2):
"""
>>> str_eq(string1, string1)
......@@ -94,6 +146,10 @@ def str_eq(str s1, str s2):
"""
return s1 == s2
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
)
def str_neq(str s1, str s2):
"""
>>> str_neq(string1, string1)
......@@ -105,6 +161,10 @@ def str_neq(str s1, str s2):
"""
return s1 != s2
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
)
def str_literal_eq(str s):
"""
>>> str_literal_eq(string1)
......@@ -116,6 +176,10 @@ def str_literal_eq(str s):
"""
return s == "abcdefg"
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
)
def str_literal_neq(str s):
"""
>>> str_literal_neq(string1)
......@@ -127,6 +191,14 @@ def str_literal_neq(str s):
"""
return s != "abcdefg"
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
)
@cython.test_fail_if_path_exists(
"//CascadedCmpNode[@is_pycmp = True]",
"//PrimaryCmpNode[@is_pycmp = True]",
)
def str_cascade(str s1, str s2):
"""
>>> str_cascade(string1, string1)
......@@ -138,6 +210,10 @@ def str_cascade(str s1, str s2):
"""
return s1 == s2 == "abcdefg"
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
)
def str_cascade_untyped_end(str s1, str s2):
"""
>>> str_cascade_untyped_end(string1, string1)
......@@ -151,6 +227,10 @@ def str_cascade_untyped_end(str s1, str s2):
# bytes
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
)
def bytes_eq(bytes s1, bytes s2):
"""
>>> bytes_eq(bstring1, bstring1)
......@@ -162,6 +242,10 @@ def bytes_eq(bytes s1, bytes s2):
"""
return s1 == s2
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
)
def bytes_neq(bytes s1, bytes s2):
"""
>>> bytes_neq(bstring1, bstring1)
......@@ -173,6 +257,10 @@ def bytes_neq(bytes s1, bytes s2):
"""
return s1 != s2
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
)
def bytes_literal_eq(bytes s):
"""
>>> bytes_literal_eq(bstring1)
......@@ -184,6 +272,10 @@ def bytes_literal_eq(bytes s):
"""
return s == b"abcdefg"
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
)
def bytes_literal_neq(bytes s):
"""
>>> bytes_literal_neq(bstring1)
......@@ -195,6 +287,14 @@ def bytes_literal_neq(bytes s):
"""
return s != b"abcdefg"
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
)
@cython.test_fail_if_path_exists(
"//CascadedCmpNode[@is_pycmp = True]",
"//PrimaryCmpNode[@is_pycmp = True]",
)
def bytes_cascade(bytes s1, bytes s2):
"""
>>> bytes_cascade(bstring1, bstring1)
......@@ -206,6 +306,10 @@ def bytes_cascade(bytes s1, bytes s2):
"""
return s1 == s2 == b"abcdefg"
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
)
def bytes_cascade_untyped_end(bytes s1, bytes s2):
"""
>>> bytes_cascade_untyped_end(bstring1, bstring1)
......@@ -218,6 +322,288 @@ def bytes_cascade_untyped_end(bytes s1, bytes s2):
return s1 == s2 == b"abcdefg" == (<object>bstring1) == bstring1
# basestring
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
)
def basestring_eq(basestring s1, basestring s2):
"""
>>> basestring_eq(string1, string1)
True
>>> basestring_eq(string1, ustring1)
True
>>> basestring_eq(string1+string2, string1+string2)
True
>>> basestring_eq(string1+ustring2, ustring1+string2)
True
>>> basestring_eq(string1, string2)
False
>>> basestring_eq(string1, ustring2)
False
>>> basestring_eq(ustring1, string2)
False
"""
return s1 == s2
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
)
def basestring_neq(basestring s1, basestring s2):
"""
>>> basestring_neq(string1, string1)
False
>>> basestring_neq(string1+string2, string1+string2)
False
>>> basestring_neq(string1+ustring2, ustring1+string2)
False
>>> basestring_neq(string1, string2)
True
>>> basestring_neq(string1, ustring2)
True
>>> basestring_neq(ustring1, string2)
True
"""
return s1 != s2
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
)
def basestring_str_literal_eq(basestring s):
"""
>>> basestring_str_literal_eq(string1)
True
>>> basestring_str_literal_eq((string1+string2)[:len(string1)])
True
>>> basestring_str_literal_eq(string2)
False
"""
return s == "abcdefg"
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
)
def basestring_unicode_literal_eq(basestring s):
"""
>>> basestring_unicode_literal_eq(string1)
True
>>> basestring_unicode_literal_eq((string1+string2)[:len(string1)])
True
>>> basestring_unicode_literal_eq(string2)
False
"""
return s == u"abcdefg"
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
)
def basestring_str_literal_neq(basestring s):
"""
>>> basestring_str_literal_neq(string1)
False
>>> basestring_str_literal_neq((string1+string2)[:len(string1)])
False
>>> basestring_str_literal_neq(string2)
True
"""
return s != "abcdefg"
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
)
def basestring_unicode_literal_neq(basestring s):
"""
>>> basestring_unicode_literal_neq(string1)
False
>>> basestring_unicode_literal_neq((string1+string2)[:len(string1)])
False
>>> basestring_unicode_literal_neq(string2)
True
"""
return s != u"abcdefg"
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
"//CascadedCmpNode[@is_pycmp = False]",
)
@cython.test_fail_if_path_exists(
"//CascadedCmpNode[@is_pycmp = True]",
"//PrimaryCmpNode[@is_pycmp = True]",
)
def basestring_cascade_str(basestring s1, basestring s2):
"""
>>> basestring_cascade_str(string1, string1)
True
>>> basestring_cascade_str(string1, (string1+string2)[:len(string1)])
True
>>> basestring_cascade_str(string1, string2)
False
"""
return s1 == s2 == "abcdefg"
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
"//CascadedCmpNode[@is_pycmp = False]",
)
@cython.test_fail_if_path_exists(
"//CascadedCmpNode[@is_pycmp = True]",
"//PrimaryCmpNode[@is_pycmp = True]",
)
def basestring_cascade_unicode(basestring s1, basestring s2):
"""
>>> basestring_cascade_unicode(string1, string1)
True
>>> basestring_cascade_unicode(ustring1, string1)
True
>>> basestring_cascade_unicode(string1, ustring1)
True
>>> basestring_cascade_unicode(string1, (string1+string2)[:len(string1)])
True
>>> basestring_cascade_unicode(string1, string2)
False
>>> basestring_cascade_unicode(ustring1, string2)
False
>>> basestring_cascade_unicode(string1, ustring2)
False
"""
return s1 == s2 == u"abcdefg"
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
)
def basestring_cascade_untyped_end(basestring s1, basestring s2):
"""
>>> basestring_cascade_untyped_end(string1, string1)
True
>>> basestring_cascade_untyped_end(string1, (string1+string2)[:len(string1)])
True
>>> basestring_cascade_untyped_end(string1, string2)
False
"""
return s1 == s2 == "abcdefg" == (<object>string1) == string1
# untyped/literal comparison
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
)
def untyped_unicode_literal_eq_bool(s):
"""
>>> untyped_unicode_literal_eq_bool(string1)
True
>>> untyped_unicode_literal_eq_bool(ustring1)
True
>>> untyped_unicode_literal_eq_bool((string1+string2)[:len(string1)])
True
>>> untyped_unicode_literal_eq_bool(string2)
False
>>> untyped_unicode_literal_eq_bool(ustring2)
False
"""
return True if s == u"abcdefg" else False
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
)
def untyped_str_literal_eq_bool(s):
"""
>>> untyped_str_literal_eq_bool(string1)
True
>>> untyped_str_literal_eq_bool(ustring1)
True
>>> untyped_str_literal_eq_bool((string1+string2)[:len(string1)])
True
>>> untyped_str_literal_eq_bool(string2)
False
>>> untyped_str_literal_eq_bool(ustring2)
False
"""
return True if s == "abcdefg" else False
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = True]",
"//CascadedCmpNode",
"//CascadedCmpNode[@is_pycmp = False]",
)
@cython.test_fail_if_path_exists(
"//CascadedCmpNode[@is_pycmp = True]",
"//PrimaryCmpNode[@is_pycmp = False]",
)
def untyped_unicode_cascade(s1, unicode s2):
"""
>>> untyped_unicode_cascade(ustring1, ustring1)
True
>>> untyped_unicode_cascade(ustring1, (ustring1+ustring2)[:len(ustring1)])
True
>>> untyped_unicode_cascade(ustring1, ustring2)
False
"""
return s1 == s2 == u"abcdefg"
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = False]",
"//CascadedCmpNode",
"//CascadedCmpNode[@is_pycmp = False]",
)
@cython.test_fail_if_path_exists(
"//CascadedCmpNode[@is_pycmp = True]",
"//PrimaryCmpNode[@is_pycmp = True]",
)
def untyped_unicode_cascade_bool(s1, unicode s2):
"""
>>> untyped_unicode_cascade_bool(ustring1, ustring1)
True
>>> untyped_unicode_cascade_bool(ustring1, (ustring1+ustring2)[:len(ustring1)])
True
>>> untyped_unicode_cascade_bool(ustring1, ustring2)
False
"""
return True if s1 == s2 == u"abcdefg" else False
@cython.test_assert_path_exists(
"//PrimaryCmpNode",
"//PrimaryCmpNode[@is_pycmp = True]",
"//CascadedCmpNode",
# "//CascadedCmpNode[@is_pycmp = False]",
)
@cython.test_fail_if_path_exists(
"//CascadedCmpNode[@is_pycmp = True]",
"//PrimaryCmpNode[@is_pycmp = False]",
)
def untyped_untyped_unicode_cascade_bool(s1, s2):
"""
>>> untyped_untyped_unicode_cascade_bool(ustring1, ustring1)
True
>>> untyped_untyped_unicode_cascade_bool(ustring1, (ustring1+ustring2)[:len(ustring1)])
True
>>> untyped_untyped_unicode_cascade_bool(ustring1, ustring2)
False
>>> untyped_untyped_unicode_cascade_bool(string1, string2)
False
>>> untyped_untyped_unicode_cascade_bool(1, 2)
False
>>> untyped_untyped_unicode_cascade_bool(1, 1)
False
"""
return True if s1 == s2 == u"abcdefg" else False
# bytes/str comparison
@cython.test_assert_path_exists(
'//CondExprNode',
'//CondExprNode//PrimaryCmpNode',
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment