Commit 2812f9b2 authored by Stefan Behnel's avatar Stefan Behnel

clean up code generation in CmpNode and fix bug #784: conflicting result types...

clean up code generation in CmpNode and fix bug #784: conflicting result types for cascaded comparisons
parent 2914af2e
......@@ -8515,6 +8515,9 @@ richcmp_constants = {
"<>": "Py_NE",
">" : "Py_GT",
">=": "Py_GE",
# the following are faked by special compare functions
"in" : "Py_EQ",
"not_in": "Py_NE",
}
class CmpNode(object):
......@@ -8522,6 +8525,7 @@ class CmpNode(object):
# and CascadedCmpNodes.
special_bool_cmp_function = None
special_bool_cmp_utility_code = None
def infer_type(self, env):
# TODO: Actually implement this (after merging with -unstable).
......@@ -8701,33 +8705,50 @@ class CmpNode(object):
and not container_type.is_string
def find_special_bool_compare_function(self, env, operand1):
# note: currently operand1 must get coerced to a Python object if we succeed here!
if self.operator in ('==', '!='):
type1, type2 = operand1.type, self.operand2.type
if type1.is_pyobject and type2.is_pyobject:
if type1 is Builtin.unicode_type or type2 is Builtin.unicode_type:
env.use_utility_code(UtilityCode.load_cached("UnicodeEquals", "StringTools.c"))
self.special_bool_cmp_utility_code = UtilityCode.load_cached("UnicodeEquals", "StringTools.c")
self.special_bool_cmp_function = "__Pyx_PyUnicode_Equals"
return True
elif type1 is Builtin.bytes_type or type2 is Builtin.bytes_type:
env.use_utility_code(UtilityCode.load_cached("BytesEquals", "StringTools.c"))
self.special_bool_cmp_utility_code = UtilityCode.load_cached("BytesEquals", "StringTools.c")
self.special_bool_cmp_function = "__Pyx_PyBytes_Equals"
return True
elif type1 is Builtin.str_type or type2 is Builtin.str_type:
env.use_utility_code(UtilityCode.load_cached("StrEquals", "StringTools.c"))
self.special_bool_cmp_utility_code = UtilityCode.load_cached("StrEquals", "StringTools.c")
self.special_bool_cmp_function = "__Pyx_PyString_Equals"
return True
elif self.operator in ('in', 'not_in'):
if self.operand2.type is Builtin.dict_type:
self.operand2 = self.operand2.as_none_safe_node("'NoneType' object is not iterable")
self.special_bool_cmp_utility_code = UtilityCode.load_cached("PyDictContains", "ObjectHandling.c")
self.special_bool_cmp_function = "__Pyx_PyDict_Contains"
return True
elif self.operand2.type.is_pyobject:
self.special_bool_cmp_utility_code = UtilityCode.load_cached("PySequenceContains", "ObjectHandling.c")
self.special_bool_cmp_function = "__Pyx_PySequence_Contains"
return True
return False
def generate_operation_code(self, code, result_code,
operand1, op , operand2):
if self.type.is_pyobject:
coerce_result = "__Pyx_PyBool_FromLong"
error_clause = code.error_goto_if_null
got_ref = "__Pyx_XGOTREF(%s); " % result_code
if self.special_bool_cmp_function:
code.globalstate.use_utility_code(
UtilityCode.load_cached("PyBoolOrNullFromLong", "ObjectHandling.c"))
coerce_result = "__Pyx_PyBoolOrNull_FromLong"
else:
coerce_result = "__Pyx_PyBool_FromLong"
else:
error_clause = code.error_goto_if_neg
got_ref = ""
coerce_result = ""
if 'not' in op:
negation = "!"
else:
negation = ""
if self.special_bool_cmp_function:
if operand1.type.is_pyobject:
result1 = operand1.py_result()
......@@ -8737,60 +8758,36 @@ class CmpNode(object):
result2 = operand2.py_result()
else:
result2 = operand2.result()
code.putln("%s = %s(%s, %s, %s); %s" % (
result_code,
self.special_bool_cmp_function,
result1,
result2,
richcmp_constants[op],
code.error_goto_if_neg(result_code, self.pos)))
elif op == 'in' or op == 'not_in':
code.globalstate.use_utility_code(contains_utility_code)
if self.type.is_pyobject:
coerce_result = "__Pyx_PyBoolOrNull_FromLong"
if op == 'not_in':
negation = "__Pyx_NegateNonNeg"
if operand2.type is dict_type:
method = "PyDict_Contains"
else:
method = "PySequence_Contains"
if self.type.is_pyobject:
error_clause = code.error_goto_if_null
got_ref = "__Pyx_XGOTREF(%s); " % result_code
else:
error_clause = code.error_goto_if_neg
got_ref = ""
if self.special_bool_cmp_utility_code:
code.globalstate.use_utility_code(self.special_bool_cmp_utility_code)
code.putln(
"%s = %s(%s(%s(%s, %s))); %s%s" % (
"%s = %s(%s(%s, %s, %s)); %s%s" % (
result_code,
coerce_result,
negation,
method,
operand2.py_result(),
self.special_bool_cmp_function,
result1, result2, richcmp_constants[op],
got_ref,
error_clause(result_code, self.pos)))
elif operand1.type.is_pyobject and op not in ('is', 'is_not'):
assert op not in ('in', 'not_in'), op
code.putln("%s = PyObject_RichCompare(%s, %s, %s); %s%s" % (
result_code,
operand1.py_result(),
operand2.py_result(),
richcmp_constants[op],
got_ref,
error_clause(result_code, self.pos)))
elif (operand1.type.is_pyobject
and op not in ('is', 'is_not')):
code.putln("%s = PyObject_RichCompare(%s, %s, %s); %s" % (
result_code,
operand1.py_result(),
operand2.py_result(),
richcmp_constants[op],
code.error_goto_if_null(result_code, self.pos)))
code.put_gotref(result_code)
elif operand1.type.is_complex:
if op == "!=":
negation = "!"
else:
negation = ""
code.putln("%s = %s(%s%s(%s, %s));" % (
result_code,
coerce_result,
negation,
op == "!=" and "!" or "",
operand1.type.unary_op('eq'),
operand1.result(),
operand2.result()))
else:
type1 = operand1.type
type2 = operand2.type
......@@ -8818,17 +8815,6 @@ class CmpNode(object):
else:
return op
contains_utility_code = UtilityCode(
proto="""
static CYTHON_INLINE int __Pyx_NegateNonNeg(int b) {
return unlikely(b < 0) ? b : !b;
}
static CYTHON_INLINE PyObject* __Pyx_PyBoolOrNull_FromLong(long b) {
return unlikely(b < 0) ? NULL : __Pyx_PyBool_FromLong(b);
}
""")
class PrimaryCmpNode(ExprNode, CmpNode):
# Non-cascaded comparison or first comparison of
# a cascaded sequence.
......@@ -8900,12 +8886,17 @@ class PrimaryCmpNode(ExprNode, CmpNode):
self.type = PyrexTypes.c_bint_type
# Will be transformed by IterationTransform
return
elif self.find_special_bool_compare_function(env, self.operand1):
if not self.operand1.type.is_pyobject:
self.operand1 = self.operand1.coerce_to_pyobject(env)
common_type = None # if coercion needed, the method call above has already done it
self.is_pycmp = False # result is bint
else:
if self.operand2.type is dict_type:
self.operand2 = self.operand2.as_none_safe_node("'NoneType' object is not iterable")
common_type = py_object_type
self.is_pycmp = True
elif self.find_special_bool_compare_function(env, self.operand1):
if not self.operand1.type.is_pyobject:
self.operand1 = self.operand1.coerce_to_pyobject(env)
common_type = None # if coercion needed, the method call above has already done it
self.is_pycmp = False # result is bint
else:
......@@ -8918,8 +8909,8 @@ class PrimaryCmpNode(ExprNode, CmpNode):
self.coerce_operands_to(common_type, env)
if self.cascade:
self.operand2 = self.operand2.coerce_to_simple(env)
self.cascade.optimise_comparison(env, self.operand2)
self.operand2 = self.cascade.optimise_comparison(
self.operand2.coerce_to_simple(env), env)
self.cascade.coerce_cascaded_operands_to_temp(env)
if self.is_python_result():
self.type = PyrexTypes.py_object_type
......@@ -9084,10 +9075,13 @@ class CascadedCmpNode(Node, CmpNode):
def has_python_operands(self):
return self.operand2.type.is_pyobject
def optimise_comparison(self, env, operand1):
self.find_special_bool_compare_function(env, operand1)
def optimise_comparison(self, operand1, env):
if self.find_special_bool_compare_function(env, operand1):
if not operand1.type.is_pyobject:
operand1 = operand1.coerce_to_pyobject(env)
if self.cascade:
self.cascade.optimise_comparison(env, self.operand2)
self.operand2 = self.cascade.optimise_comparison(self.operand2, env)
return operand1
def coerce_operands_to_pyobjects(self, env):
self.operand2 = self.operand2.coerce_to_pyobject(env)
......
......@@ -575,3 +575,23 @@ static CYTHON_INLINE int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type) {
#else
#define __Pyx_PyCallable_Check(obj) PyCallable_Check(obj)
#endif
/////////////// PyDictContains.proto ///////////////
static CYTHON_INLINE int __Pyx_PyDict_Contains(PyObject* item, PyObject* dict, int eq) {
int result = PyDict_Contains(dict, item);
return unlikely(result < 0) ? result : (result == (eq == Py_EQ));
}
/////////////// PySequenceContains.proto ///////////////
static CYTHON_INLINE int __Pyx_PySequence_Contains(PyObject* item, PyObject* seq, int eq) {
int result = PySequence_Contains(seq, item);
return unlikely(result < 0) ? result : (result == (eq == Py_EQ));
}
/////////////// PyBoolOrNullFromLong.proto ///////////////
static CYTHON_INLINE PyObject* __Pyx_PyBoolOrNull_FromLong(long b) {
return unlikely(b < 0) ? NULL : __Pyx_PyBool_FromLong(b);
}
......@@ -65,7 +65,6 @@ def unicode_cascade(unicode s1, unicode s2):
"""
return s1 == s2 == u"abcdefg"
''' # NOTE: currently crashes
def unicode_cascade_untyped_end(unicode s1, unicode s2):
"""
>>> unicode_cascade_untyped_end(ustring1, ustring1)
......@@ -76,7 +75,6 @@ def unicode_cascade_untyped_end(unicode s1, unicode s2):
False
"""
return s1 == s2 == u"abcdefg" == (<object>ustring1) == ustring1
'''
# str
......@@ -135,6 +133,17 @@ def str_cascade(str s1, str s2):
"""
return s1 == s2 == "abcdefg"
def str_cascade_untyped_end(str s1, str s2):
"""
>>> str_cascade_untyped_end(string1, string1)
True
>>> str_cascade_untyped_end(string1, (string1+string2)[:len(string1)])
True
>>> str_cascade_untyped_end(string1, string2)
False
"""
return s1 == s2 == "abcdefg" == (<object>string1) == string1
# bytes
def bytes_eq(bytes s1, bytes s2):
......@@ -191,3 +200,14 @@ def bytes_cascade(bytes s1, bytes s2):
False
"""
return s1 == s2 == b"abcdefg"
def bytes_cascade_untyped_end(bytes s1, bytes s2):
"""
>>> bytes_cascade_untyped_end(bstring1, bstring1)
True
>>> bytes_cascade_untyped_end(bstring1, (bstring1+bstring2)[:len(bstring1)])
True
>>> bytes_cascade_untyped_end(bstring1, bstring2)
False
"""
return s1 == s2 == b"abcdefg" == (<object>bstring1) == bstring1
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment