Commit 4a25f1c1 authored by Stefan Behnel's avatar Stefan Behnel

optimise string-to-something comparisons also when we know that the result is...

optimise string-to-something comparisons also when we know that the result is boolean (the cmp functions have a proper fallback for non-strings)
parent 4227971d
...@@ -9930,11 +9930,11 @@ class CmpNode(object): ...@@ -9930,11 +9930,11 @@ class CmpNode(object):
return (container_type.is_ptr or container_type.is_array) \ return (container_type.is_ptr or container_type.is_array) \
and not container_type.is_string and not container_type.is_string
def find_special_bool_compare_function(self, env, operand1): def find_special_bool_compare_function(self, env, operand1, result_is_bool=False):
# note: currently operand1 must get coerced to a Python object if we succeed here! # note: currently operand1 must get coerced to a Python object if we succeed here!
if self.operator in ('==', '!='): if self.operator in ('==', '!='):
type1, type2 = operand1.type, self.operand2.type type1, type2 = operand1.type, self.operand2.type
if type1.is_builtin_type and type2.is_builtin_type: if result_is_bool or (type1.is_builtin_type and type2.is_builtin_type):
if type1 is Builtin.unicode_type or type2 is Builtin.unicode_type: if type1 is Builtin.unicode_type or type2 is Builtin.unicode_type:
self.special_bool_cmp_utility_code = UtilityCode.load_cached("UnicodeEquals", "StringTools.c") self.special_bool_cmp_utility_code = UtilityCode.load_cached("UnicodeEquals", "StringTools.c")
self.special_bool_cmp_function = "__Pyx_PyUnicode_Equals" self.special_bool_cmp_function = "__Pyx_PyUnicode_Equals"
...@@ -10184,6 +10184,7 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -10184,6 +10184,7 @@ class PrimaryCmpNode(ExprNode, CmpNode):
else: else:
self.operand1 = self.operand1.coerce_to(func_type.args[0].type, env) self.operand1 = self.operand1.coerce_to(func_type.args[0].type, env)
self.operand2 = self.operand2.coerce_to(func_type.args[1].type, env) self.operand2 = self.operand2.coerce_to(func_type.args[1].type, env)
self.is_pycmp = False
self.type = func_type.return_type self.type = func_type.return_type
def analyse_memoryviewslice_comparison(self, env): def analyse_memoryviewslice_comparison(self, env):
...@@ -10199,6 +10200,23 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -10199,6 +10200,23 @@ class PrimaryCmpNode(ExprNode, CmpNode):
return False return False
def coerce_to_boolean(self, env):
if self.is_pycmp:
# coercing to bool => may allow for more efficient comparison code
if self.find_special_bool_compare_function(
env, self.operand1, result_is_bool=True):
self.is_pycmp = False
self.type = PyrexTypes.c_bint_type
self.is_temp = 1
if self.cascade:
operand2 = self.cascade.optimise_comparison(
self.operand2, env, result_is_bool=True)
if operand2 is not self.operand2:
self.coerced_operand2 = operand2
return self
# TODO: check if we can optimise parts of the cascade here
return ExprNode.coerce_to_boolean(self, env)
def has_python_operands(self): def has_python_operands(self):
return (self.operand1.type.is_pyobject return (self.operand1.type.is_pyobject
or self.operand2.type.is_pyobject) or self.operand2.type.is_pyobject)
...@@ -10320,12 +10338,14 @@ class CascadedCmpNode(Node, CmpNode): ...@@ -10320,12 +10338,14 @@ class CascadedCmpNode(Node, CmpNode):
def has_python_operands(self): def has_python_operands(self):
return self.operand2.type.is_pyobject return self.operand2.type.is_pyobject
def optimise_comparison(self, operand1, env): def optimise_comparison(self, operand1, env, result_is_bool=False):
if self.find_special_bool_compare_function(env, operand1): if self.find_special_bool_compare_function(env, operand1, result_is_bool):
self.is_pycmp = False
self.type = PyrexTypes.c_bint_type
if not operand1.type.is_pyobject: if not operand1.type.is_pyobject:
operand1 = operand1.coerce_to_pyobject(env) operand1 = operand1.coerce_to_pyobject(env)
if self.cascade: if self.cascade:
operand2 = self.cascade.optimise_comparison(self.operand2, env) operand2 = self.cascade.optimise_comparison(self.operand2, env, result_is_bool)
if operand2 is not self.operand2: if operand2 is not self.operand2:
self.coerced_operand2 = operand2 self.coerced_operand2 = operand2
return operand1 return operand1
......
This diff is collapsed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment