Commit 2812f9b2 authored by Stefan Behnel's avatar Stefan Behnel

clean up code generation in CmpNode and fix bug #784: conflicting result types...

clean up code generation in CmpNode and fix bug #784: conflicting result types for cascaded comparisons
parent 2914af2e
...@@ -8515,6 +8515,9 @@ richcmp_constants = { ...@@ -8515,6 +8515,9 @@ richcmp_constants = {
"<>": "Py_NE", "<>": "Py_NE",
">" : "Py_GT", ">" : "Py_GT",
">=": "Py_GE", ">=": "Py_GE",
# the following are faked by special compare functions
"in" : "Py_EQ",
"not_in": "Py_NE",
} }
class CmpNode(object): class CmpNode(object):
...@@ -8522,6 +8525,7 @@ class CmpNode(object): ...@@ -8522,6 +8525,7 @@ class CmpNode(object):
# and CascadedCmpNodes. # and CascadedCmpNodes.
special_bool_cmp_function = None special_bool_cmp_function = None
special_bool_cmp_utility_code = None
def infer_type(self, env): def infer_type(self, env):
# TODO: Actually implement this (after merging with -unstable). # TODO: Actually implement this (after merging with -unstable).
...@@ -8701,33 +8705,50 @@ class CmpNode(object): ...@@ -8701,33 +8705,50 @@ class CmpNode(object):
and not container_type.is_string and not container_type.is_string
def find_special_bool_compare_function(self, env, operand1): def find_special_bool_compare_function(self, env, operand1):
# note: currently operand1 must get coerced to a Python object if we succeed here!
if self.operator in ('==', '!='): if self.operator in ('==', '!='):
type1, type2 = operand1.type, self.operand2.type type1, type2 = operand1.type, self.operand2.type
if type1.is_pyobject and type2.is_pyobject: if type1.is_pyobject and type2.is_pyobject:
if type1 is Builtin.unicode_type or type2 is Builtin.unicode_type: if type1 is Builtin.unicode_type or type2 is Builtin.unicode_type:
env.use_utility_code(UtilityCode.load_cached("UnicodeEquals", "StringTools.c")) self.special_bool_cmp_utility_code = UtilityCode.load_cached("UnicodeEquals", "StringTools.c")
self.special_bool_cmp_function = "__Pyx_PyUnicode_Equals" self.special_bool_cmp_function = "__Pyx_PyUnicode_Equals"
return True return True
elif type1 is Builtin.bytes_type or type2 is Builtin.bytes_type: elif type1 is Builtin.bytes_type or type2 is Builtin.bytes_type:
env.use_utility_code(UtilityCode.load_cached("BytesEquals", "StringTools.c")) self.special_bool_cmp_utility_code = UtilityCode.load_cached("BytesEquals", "StringTools.c")
self.special_bool_cmp_function = "__Pyx_PyBytes_Equals" self.special_bool_cmp_function = "__Pyx_PyBytes_Equals"
return True return True
elif type1 is Builtin.str_type or type2 is Builtin.str_type: elif type1 is Builtin.str_type or type2 is Builtin.str_type:
env.use_utility_code(UtilityCode.load_cached("StrEquals", "StringTools.c")) self.special_bool_cmp_utility_code = UtilityCode.load_cached("StrEquals", "StringTools.c")
self.special_bool_cmp_function = "__Pyx_PyString_Equals" self.special_bool_cmp_function = "__Pyx_PyString_Equals"
return True return True
elif self.operator in ('in', 'not_in'):
if self.operand2.type is Builtin.dict_type:
self.operand2 = self.operand2.as_none_safe_node("'NoneType' object is not iterable")
self.special_bool_cmp_utility_code = UtilityCode.load_cached("PyDictContains", "ObjectHandling.c")
self.special_bool_cmp_function = "__Pyx_PyDict_Contains"
return True
elif self.operand2.type.is_pyobject:
self.special_bool_cmp_utility_code = UtilityCode.load_cached("PySequenceContains", "ObjectHandling.c")
self.special_bool_cmp_function = "__Pyx_PySequence_Contains"
return True
return False return False
def generate_operation_code(self, code, result_code, def generate_operation_code(self, code, result_code,
operand1, op , operand2): operand1, op , operand2):
if self.type.is_pyobject: if self.type.is_pyobject:
error_clause = code.error_goto_if_null
got_ref = "__Pyx_XGOTREF(%s); " % result_code
if self.special_bool_cmp_function:
code.globalstate.use_utility_code(
UtilityCode.load_cached("PyBoolOrNullFromLong", "ObjectHandling.c"))
coerce_result = "__Pyx_PyBoolOrNull_FromLong"
else:
coerce_result = "__Pyx_PyBool_FromLong" coerce_result = "__Pyx_PyBool_FromLong"
else: else:
error_clause = code.error_goto_if_neg
got_ref = ""
coerce_result = "" coerce_result = ""
if 'not' in op:
negation = "!"
else:
negation = ""
if self.special_bool_cmp_function: if self.special_bool_cmp_function:
if operand1.type.is_pyobject: if operand1.type.is_pyobject:
result1 = operand1.py_result() result1 = operand1.py_result()
...@@ -8737,60 +8758,36 @@ class CmpNode(object): ...@@ -8737,60 +8758,36 @@ class CmpNode(object):
result2 = operand2.py_result() result2 = operand2.py_result()
else: else:
result2 = operand2.result() result2 = operand2.result()
code.putln("%s = %s(%s, %s, %s); %s" % ( if self.special_bool_cmp_utility_code:
result_code, code.globalstate.use_utility_code(self.special_bool_cmp_utility_code)
self.special_bool_cmp_function,
result1,
result2,
richcmp_constants[op],
code.error_goto_if_neg(result_code, self.pos)))
elif op == 'in' or op == 'not_in':
code.globalstate.use_utility_code(contains_utility_code)
if self.type.is_pyobject:
coerce_result = "__Pyx_PyBoolOrNull_FromLong"
if op == 'not_in':
negation = "__Pyx_NegateNonNeg"
if operand2.type is dict_type:
method = "PyDict_Contains"
else:
method = "PySequence_Contains"
if self.type.is_pyobject:
error_clause = code.error_goto_if_null
got_ref = "__Pyx_XGOTREF(%s); " % result_code
else:
error_clause = code.error_goto_if_neg
got_ref = ""
code.putln( code.putln(
"%s = %s(%s(%s(%s, %s))); %s%s" % ( "%s = %s(%s(%s, %s, %s)); %s%s" % (
result_code, result_code,
coerce_result, coerce_result,
negation, self.special_bool_cmp_function,
method, result1, result2, richcmp_constants[op],
operand2.py_result(),
operand1.py_result(),
got_ref, got_ref,
error_clause(result_code, self.pos))) error_clause(result_code, self.pos)))
elif (operand1.type.is_pyobject
and op not in ('is', 'is_not')): elif operand1.type.is_pyobject and op not in ('is', 'is_not'):
code.putln("%s = PyObject_RichCompare(%s, %s, %s); %s" % ( assert op not in ('in', 'not_in'), op
code.putln("%s = PyObject_RichCompare(%s, %s, %s); %s%s" % (
result_code, result_code,
operand1.py_result(), operand1.py_result(),
operand2.py_result(), operand2.py_result(),
richcmp_constants[op], richcmp_constants[op],
code.error_goto_if_null(result_code, self.pos))) got_ref,
code.put_gotref(result_code) error_clause(result_code, self.pos)))
elif operand1.type.is_complex: elif operand1.type.is_complex:
if op == "!=":
negation = "!"
else:
negation = ""
code.putln("%s = %s(%s%s(%s, %s));" % ( code.putln("%s = %s(%s%s(%s, %s));" % (
result_code, result_code,
coerce_result, coerce_result,
negation, op == "!=" and "!" or "",
operand1.type.unary_op('eq'), operand1.type.unary_op('eq'),
operand1.result(), operand1.result(),
operand2.result())) operand2.result()))
else: else:
type1 = operand1.type type1 = operand1.type
type2 = operand2.type type2 = operand2.type
...@@ -8818,17 +8815,6 @@ class CmpNode(object): ...@@ -8818,17 +8815,6 @@ class CmpNode(object):
else: else:
return op return op
contains_utility_code = UtilityCode(
proto="""
static CYTHON_INLINE int __Pyx_NegateNonNeg(int b) {
return unlikely(b < 0) ? b : !b;
}
static CYTHON_INLINE PyObject* __Pyx_PyBoolOrNull_FromLong(long b) {
return unlikely(b < 0) ? NULL : __Pyx_PyBool_FromLong(b);
}
""")
class PrimaryCmpNode(ExprNode, CmpNode): class PrimaryCmpNode(ExprNode, CmpNode):
# Non-cascaded comparison or first comparison of # Non-cascaded comparison or first comparison of
# a cascaded sequence. # a cascaded sequence.
...@@ -8900,12 +8886,17 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -8900,12 +8886,17 @@ class PrimaryCmpNode(ExprNode, CmpNode):
self.type = PyrexTypes.c_bint_type self.type = PyrexTypes.c_bint_type
# Will be transformed by IterationTransform # Will be transformed by IterationTransform
return return
elif self.find_special_bool_compare_function(env, self.operand1):
if not self.operand1.type.is_pyobject:
self.operand1 = self.operand1.coerce_to_pyobject(env)
common_type = None # if coercion needed, the method call above has already done it
self.is_pycmp = False # result is bint
else: else:
if self.operand2.type is dict_type:
self.operand2 = self.operand2.as_none_safe_node("'NoneType' object is not iterable")
common_type = py_object_type common_type = py_object_type
self.is_pycmp = True self.is_pycmp = True
elif self.find_special_bool_compare_function(env, self.operand1): elif self.find_special_bool_compare_function(env, self.operand1):
if not self.operand1.type.is_pyobject:
self.operand1 = self.operand1.coerce_to_pyobject(env)
common_type = None # if coercion needed, the method call above has already done it common_type = None # if coercion needed, the method call above has already done it
self.is_pycmp = False # result is bint self.is_pycmp = False # result is bint
else: else:
...@@ -8918,8 +8909,8 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -8918,8 +8909,8 @@ class PrimaryCmpNode(ExprNode, CmpNode):
self.coerce_operands_to(common_type, env) self.coerce_operands_to(common_type, env)
if self.cascade: if self.cascade:
self.operand2 = self.operand2.coerce_to_simple(env) self.operand2 = self.cascade.optimise_comparison(
self.cascade.optimise_comparison(env, self.operand2) self.operand2.coerce_to_simple(env), env)
self.cascade.coerce_cascaded_operands_to_temp(env) self.cascade.coerce_cascaded_operands_to_temp(env)
if self.is_python_result(): if self.is_python_result():
self.type = PyrexTypes.py_object_type self.type = PyrexTypes.py_object_type
...@@ -9084,10 +9075,13 @@ class CascadedCmpNode(Node, CmpNode): ...@@ -9084,10 +9075,13 @@ class CascadedCmpNode(Node, CmpNode):
def has_python_operands(self): def has_python_operands(self):
return self.operand2.type.is_pyobject return self.operand2.type.is_pyobject
def optimise_comparison(self, env, operand1): def optimise_comparison(self, operand1, env):
self.find_special_bool_compare_function(env, operand1) if self.find_special_bool_compare_function(env, operand1):
if not operand1.type.is_pyobject:
operand1 = operand1.coerce_to_pyobject(env)
if self.cascade: if self.cascade:
self.cascade.optimise_comparison(env, self.operand2) self.operand2 = self.cascade.optimise_comparison(self.operand2, env)
return operand1
def coerce_operands_to_pyobjects(self, env): def coerce_operands_to_pyobjects(self, env):
self.operand2 = self.operand2.coerce_to_pyobject(env) self.operand2 = self.operand2.coerce_to_pyobject(env)
......
...@@ -575,3 +575,23 @@ static CYTHON_INLINE int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type) { ...@@ -575,3 +575,23 @@ static CYTHON_INLINE int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type) {
#else #else
#define __Pyx_PyCallable_Check(obj) PyCallable_Check(obj) #define __Pyx_PyCallable_Check(obj) PyCallable_Check(obj)
#endif #endif
/////////////// PyDictContains.proto ///////////////
static CYTHON_INLINE int __Pyx_PyDict_Contains(PyObject* item, PyObject* dict, int eq) {
int result = PyDict_Contains(dict, item);
return unlikely(result < 0) ? result : (result == (eq == Py_EQ));
}
/////////////// PySequenceContains.proto ///////////////
static CYTHON_INLINE int __Pyx_PySequence_Contains(PyObject* item, PyObject* seq, int eq) {
int result = PySequence_Contains(seq, item);
return unlikely(result < 0) ? result : (result == (eq == Py_EQ));
}
/////////////// PyBoolOrNullFromLong.proto ///////////////
static CYTHON_INLINE PyObject* __Pyx_PyBoolOrNull_FromLong(long b) {
return unlikely(b < 0) ? NULL : __Pyx_PyBool_FromLong(b);
}
...@@ -65,7 +65,6 @@ def unicode_cascade(unicode s1, unicode s2): ...@@ -65,7 +65,6 @@ def unicode_cascade(unicode s1, unicode s2):
""" """
return s1 == s2 == u"abcdefg" return s1 == s2 == u"abcdefg"
''' # NOTE: currently crashes
def unicode_cascade_untyped_end(unicode s1, unicode s2): def unicode_cascade_untyped_end(unicode s1, unicode s2):
""" """
>>> unicode_cascade_untyped_end(ustring1, ustring1) >>> unicode_cascade_untyped_end(ustring1, ustring1)
...@@ -76,7 +75,6 @@ def unicode_cascade_untyped_end(unicode s1, unicode s2): ...@@ -76,7 +75,6 @@ def unicode_cascade_untyped_end(unicode s1, unicode s2):
False False
""" """
return s1 == s2 == u"abcdefg" == (<object>ustring1) == ustring1 return s1 == s2 == u"abcdefg" == (<object>ustring1) == ustring1
'''
# str # str
...@@ -135,6 +133,17 @@ def str_cascade(str s1, str s2): ...@@ -135,6 +133,17 @@ def str_cascade(str s1, str s2):
""" """
return s1 == s2 == "abcdefg" return s1 == s2 == "abcdefg"
def str_cascade_untyped_end(str s1, str s2):
"""
>>> str_cascade_untyped_end(string1, string1)
True
>>> str_cascade_untyped_end(string1, (string1+string2)[:len(string1)])
True
>>> str_cascade_untyped_end(string1, string2)
False
"""
return s1 == s2 == "abcdefg" == (<object>string1) == string1
# bytes # bytes
def bytes_eq(bytes s1, bytes s2): def bytes_eq(bytes s1, bytes s2):
...@@ -191,3 +200,14 @@ def bytes_cascade(bytes s1, bytes s2): ...@@ -191,3 +200,14 @@ def bytes_cascade(bytes s1, bytes s2):
False False
""" """
return s1 == s2 == b"abcdefg" return s1 == s2 == b"abcdefg"
def bytes_cascade_untyped_end(bytes s1, bytes s2):
"""
>>> bytes_cascade_untyped_end(bstring1, bstring1)
True
>>> bytes_cascade_untyped_end(bstring1, (bstring1+bstring2)[:len(bstring1)])
True
>>> bytes_cascade_untyped_end(bstring1, bstring2)
False
"""
return s1 == s2 == b"abcdefg" == (<object>bstring1) == bstring1
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment