Commit 685dbe0d authored by Robert Bradshaw's avatar Robert Bradshaw

Use PyObject_RichCompare rather than PyObject_Cmp

This is what the interpreter does, and allows one to
get at the actual object (rather than just its truth
value).
parent ebbac140
...@@ -2836,6 +2836,16 @@ class CondExprNode(ExprNode): ...@@ -2836,6 +2836,16 @@ class CondExprNode(ExprNode):
code.putln("}") code.putln("}")
self.test.generate_disposal_code(code) self.test.generate_disposal_code(code)
richcmp_constants = {
"<" : "Py_LT",
"<=": "Py_LE",
"==": "Py_EQ",
"!=": "Py_NE",
"<>": "Py_NE",
">" : "Py_GT",
">=": "Py_GE",
}
class CmpNode: class CmpNode:
# Mixin class containing code common to PrimaryCmpNodes # Mixin class containing code common to PrimaryCmpNodes
# and CascadedCmpNodes. # and CascadedCmpNodes.
...@@ -2845,6 +2855,10 @@ class CmpNode: ...@@ -2845,6 +2855,10 @@ class CmpNode:
or (self.cascade and self.cascade.is_python_comparison()) or (self.cascade and self.cascade.is_python_comparison())
or self.operator in ('in', 'not_in')) or self.operator in ('in', 'not_in'))
def is_python_result(self):
return ((self.has_python_operands() and self.operator not in ('is', 'is_not', 'in', 'not_in'))
or (self.cascade and self.cascade.is_python_result()))
def check_types(self, env, operand1, op, operand2): def check_types(self, env, operand1, op, operand2):
if not self.types_okay(operand1, op, operand2): if not self.types_okay(operand1, op, operand2):
error(self.pos, "Invalid types for '%s' (%s, %s)" % error(self.pos, "Invalid types for '%s' (%s, %s)" %
...@@ -2871,30 +2885,33 @@ class CmpNode: ...@@ -2871,30 +2885,33 @@ class CmpNode:
def generate_operation_code(self, code, result_code, def generate_operation_code(self, code, result_code,
operand1, op , operand2): operand1, op , operand2):
if self.type is PyrexTypes.py_object_type:
coerce_result = "__Pyx_PyBool_FromLong"
else:
coerce_result = ""
if 'not' in op: negation = "!"
else: negation = ""
if op == 'in' or op == 'not_in': if op == 'in' or op == 'not_in':
code.putln( code.putln(
"%s = PySequence_Contains(%s, %s); %s" % ( "%s = %s(%sPySequence_Contains(%s, %s)); %s" % (
result_code, result_code,
coerce_result,
negation,
operand2.py_result(), operand2.py_result(),
operand1.py_result(), operand1.py_result(),
code.error_goto_if_neg(result_code, self.pos))) code.error_goto_if_neg(result_code, self.pos)))
if op == 'not_in':
code.putln(
"%s = !%s;" % (
result_code, result_code))
elif (operand1.type.is_pyobject elif (operand1.type.is_pyobject
and op not in ('is', 'is_not')): and op not in ('is', 'is_not')):
code.put_error_if_neg(self.pos, code.putln("%s = PyObject_RichCompare(%s, %s, %s); %s" % (
"PyObject_Cmp(%s, %s, &%s)" % ( result_code,
operand1.py_result(), operand1.py_result(),
operand2.py_result(), operand2.py_result(),
result_code)) richcmp_constants[op],
code.putln( code.error_goto_if_null(result_code, self.pos)))
"%s = %s %s 0;" % (
result_code, result_code, op))
else: else:
code.putln("%s = %s %s %s;" % ( code.putln("%s = %s(%s %s %s);" % (
result_code, result_code,
coerce_result,
operand1.result_code, operand1.result_code,
self.c_operator(op), self.c_operator(op),
operand2.result_code)) operand2.result_code))
...@@ -2937,7 +2954,14 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -2937,7 +2954,14 @@ class PrimaryCmpNode(ExprNode, CmpNode):
self.operand2 = self.operand2.coerce_to_simple(env) self.operand2 = self.operand2.coerce_to_simple(env)
self.cascade.coerce_cascaded_operands_to_temp(env) self.cascade.coerce_cascaded_operands_to_temp(env)
self.check_operand_types(env) self.check_operand_types(env)
self.type = PyrexTypes.c_bint_type if self.is_python_result():
self.type = PyrexTypes.py_object_type
else:
self.type = PyrexTypes.c_bint_type
cdr = self.cascade
while cdr:
cdr.type = self.type
cdr = cdr.cascade
if self.is_pycmp or self.cascade: if self.is_pycmp or self.cascade:
self.is_temp = 1 self.is_temp = 1
...@@ -3048,7 +3072,10 @@ class CascadedCmpNode(Node, CmpNode): ...@@ -3048,7 +3072,10 @@ class CascadedCmpNode(Node, CmpNode):
self.cascade.release_subexpr_temps(env) self.cascade.release_subexpr_temps(env)
def generate_evaluation_code(self, code, result, operand1): def generate_evaluation_code(self, code, result, operand1):
code.putln("if (%s) {" % result) if self.type.is_pyobject:
code.putln("if (__Pyx_PyObject_IsTrue(%s)) {" % result)
else:
code.putln("if (%s) {" % result)
self.operand2.generate_evaluation_code(code) self.operand2.generate_evaluation_code(code)
self.generate_operation_code(code, result, self.generate_operation_code(code, result,
operand1, self.operator, self.operand2) operand1, self.operator, self.operand2)
...@@ -3242,7 +3269,7 @@ class CoerceToBooleanNode(CoercionNode): ...@@ -3242,7 +3269,7 @@ class CoerceToBooleanNode(CoercionNode):
def generate_result_code(self, code): def generate_result_code(self, code):
if self.arg.type.is_pyobject: if self.arg.type.is_pyobject:
code.putln( code.putln(
"%s = PyObject_IsTrue(%s); %s" % ( "%s = __Pyx_PyObject_IsTrue(%s); %s" % (
self.result_code, self.result_code,
self.arg.py_result(), self.arg.py_result(),
code.error_goto_if_neg(self.result_code, self.pos))) code.error_goto_if_neg(self.result_code, self.pos)))
......
...@@ -346,8 +346,6 @@ class CIntType(CNumericType): ...@@ -346,8 +346,6 @@ class CIntType(CNumericType):
class CBIntType(CIntType): class CBIntType(CIntType):
# TODO: this should be a macro "(__ ? Py_True : Py_False)"
# and no error checking should be needed (just an incref).
to_py_function = "__Pyx_PyBool_FromLong" to_py_function = "__Pyx_PyBool_FromLong"
from_py_function = "__Pyx_PyObject_IsTrue" from_py_function = "__Pyx_PyObject_IsTrue"
exception_check = 0 exception_check = 0
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment