Commit 4d1a5241 authored by Stefan Behnel's avatar Stefan Behnel

properly propagate string comparison optimisation into cascaded comparisons

parent 4b858e41
......@@ -8696,9 +8696,9 @@ class CmpNode(object):
return (container_type.is_ptr or container_type.is_array) \
and not container_type.is_string
def find_special_bool_compare_function(self, env):
def find_special_bool_compare_function(self, env, operand1):
if self.operator in ('==', '!='):
type1, type2 = self.operand1.type, self.operand2.type
type1, type2 = operand1.type, self.operand2.type
if type1.is_pyobject and type2.is_pyobject:
if type1 is Builtin.unicode_type or type2 is Builtin.unicode_type:
env.use_utility_code(UtilityCode.load_cached("UnicodeEquals", "StringTools.c"))
......@@ -8901,7 +8901,7 @@ class PrimaryCmpNode(ExprNode, CmpNode):
self.operand2 = self.operand2.as_none_safe_node("'NoneType' object is not iterable")
common_type = py_object_type
self.is_pycmp = True
elif self.find_special_bool_compare_function(env):
elif self.find_special_bool_compare_function(env, self.operand1):
common_type = None # if coercion needed, the method call above has already done it
self.is_pycmp = False # result is bint
self.is_temp = True # must check for error return
......@@ -8916,6 +8916,7 @@ class PrimaryCmpNode(ExprNode, CmpNode):
if self.cascade:
self.operand2 = self.operand2.coerce_to_simple(env)
self.cascade.optimise_comparison(env, self.operand2)
self.cascade.coerce_cascaded_operands_to_temp(env)
if self.is_python_result():
self.type = PyrexTypes.py_object_type
......@@ -9079,6 +9080,11 @@ class CascadedCmpNode(Node, CmpNode):
def has_python_operands(self):
return self.operand2.type.is_pyobject
def optimise_comparison(self, env, operand1):
self.find_special_bool_compare_function(env, operand1)
if self.cascade:
self.cascade.optimise_comparison(env, self.operand2)
def coerce_operands_to_pyobjects(self, env):
self.operand2 = self.operand2.coerce_to_pyobject(env)
if self.operand2.type is dict_type and self.operator in ('in', 'not_in'):
......
bstring1 = b"abcdefg"
bstring2 = b"1234567"
string1 = "abcdefg"
string2 = "1234567"
ustring1 = u"abcdefg"
ustring2 = u"1234567"
# unicode
def unicode_eq(unicode s1, unicode s2):
"""
>>> unicode_eq(ustring1, ustring1)
True
>>> unicode_eq(ustring1+ustring2, ustring1+ustring2)
True
>>> unicode_eq(ustring1, ustring2)
False
"""
return s1 == s2
def unicode_neq(unicode s1, unicode s2):
"""
>>> unicode_neq(ustring1, ustring1)
False
>>> unicode_neq(ustring1+ustring2, ustring1+ustring2)
False
>>> unicode_neq(ustring1, ustring2)
True
"""
return s1 != s2
def unicode_literal_eq(unicode s):
"""
>>> unicode_literal_eq(ustring1)
True
>>> unicode_literal_eq((ustring1+ustring2)[:len(ustring1)])
True
>>> unicode_literal_eq(ustring2)
False
"""
return s == u"abcdefg"
def unicode_literal_neq(unicode s):
"""
>>> unicode_literal_neq(ustring1)
False
>>> unicode_literal_neq((ustring1+ustring2)[:len(ustring1)])
False
>>> unicode_literal_neq(ustring2)
True
"""
return s != u"abcdefg"
def unicode_cascade(unicode s1, unicode s2):
"""
>>> unicode_cascade(ustring1, ustring1)
True
>>> unicode_cascade(ustring1, (ustring1+ustring2)[:len(ustring1)])
True
>>> unicode_cascade(ustring1, ustring2)
False
"""
return s1 == s2 == u"abcdefg"
''' # NOTE: currently crashes
def unicode_cascade_untyped_end(unicode s1, unicode s2):
"""
>>> unicode_cascade_untyped_end(ustring1, ustring1)
True
>>> unicode_cascade_untyped_end(ustring1, (ustring1+ustring2)[:len(ustring1)])
True
>>> unicode_cascade_untyped_end(ustring1, ustring2)
False
"""
return s1 == s2 == u"abcdefg" == (<object>ustring1) == ustring1
'''
# str
def str_eq(str s1, str s2):
"""
>>> str_eq(string1, string1)
True
>>> str_eq(string1+string2, string1+string2)
True
>>> str_eq(string1, string2)
False
"""
return s1 == s2
def str_neq(str s1, str s2):
"""
>>> str_neq(string1, string1)
False
>>> str_neq(string1+string2, string1+string2)
False
>>> str_neq(string1, string2)
True
"""
return s1 != s2
def str_literal_eq(str s):
"""
>>> str_literal_eq(string1)
True
>>> str_literal_eq((string1+string2)[:len(string1)])
True
>>> str_literal_eq(string2)
False
"""
return s == "abcdefg"
def str_literal_neq(str s):
"""
>>> str_literal_neq(string1)
False
>>> str_literal_neq((string1+string2)[:len(string1)])
False
>>> str_literal_neq(string2)
True
"""
return s != "abcdefg"
def str_cascade(str s1, str s2):
"""
>>> str_cascade(string1, string1)
True
>>> str_cascade(string1, (string1+string2)[:len(string1)])
True
>>> str_cascade(string1, string2)
False
"""
return s1 == s2 == "abcdefg"
# bytes
def bytes_eq(bytes s1, bytes s2):
"""
>>> bytes_eq(bstring1, bstring1)
True
>>> bytes_eq(bstring1+bstring2, bstring1+bstring2)
True
>>> bytes_eq(bstring1, bstring2)
False
"""
return s1 == s2
def bytes_neq(bytes s1, bytes s2):
"""
>>> bytes_neq(bstring1, bstring1)
False
>>> bytes_neq(bstring1+bstring2, bstring1+bstring2)
False
>>> bytes_neq(bstring1, bstring2)
True
"""
return s1 != s2
def bytes_literal_eq(bytes s):
"""
>>> bytes_literal_eq(bstring1)
True
>>> bytes_literal_eq((bstring1+bstring2)[:len(bstring1)])
True
>>> bytes_literal_eq(bstring2)
False
"""
return s == b"abcdefg"
def bytes_literal_neq(bytes s):
"""
>>> bytes_literal_neq(bstring1)
False
>>> bytes_literal_neq((bstring1+bstring2)[:len(bstring1)])
False
>>> bytes_literal_neq(bstring2)
True
"""
return s != b"abcdefg"
def bytes_cascade(bytes s1, bytes s2):
"""
>>> bytes_cascade(bstring1, bstring1)
True
>>> bytes_cascade(bstring1, (bstring1+bstring2)[:len(bstring1)])
True
>>> bytes_cascade(bstring1, bstring2)
False
"""
return s1 == s2 == b"abcdefg"
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment