Commit 5bdfbcb5 authored by Stefan Behnel's avatar Stefan Behnel

Cast integers to their expected type in Pythran expressions to enforce template type matches.

parent 49851cae
...@@ -11042,10 +11042,11 @@ class NumBinopNode(BinopNode): ...@@ -11042,10 +11042,11 @@ class NumBinopNode(BinopNode):
self.operand2.result(), self.operand2.result(),
self.overflow_bit_node.overflow_bit) self.overflow_bit_node.overflow_bit)
elif self.type.is_cpp_class or self.infix: elif self.type.is_cpp_class or self.infix:
return "(%s %s %s)" % ( if is_pythran_expr(self.type):
self.operand1.result(), result1, result2 = self.operand1.pythran_result(), self.operand2.pythran_result()
self.operator, else:
self.operand2.result()) result1, result2 = self.operand1.result(), self.operand2.result()
return "(%s %s %s)" % (result1, self.operator, result2)
else: else:
func = self.type.binary_op(self.operator) func = self.type.binary_op(self.operator)
if func is None: if func is None:
...@@ -12385,18 +12386,19 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -12385,18 +12386,19 @@ class PrimaryCmpNode(ExprNode, CmpNode):
return self.operand1.check_const() and self.operand2.check_const() return self.operand1.check_const() and self.operand2.check_const()
def calculate_result_code(self): def calculate_result_code(self):
if self.operand1.type.is_complex: operand1, operand2 = self.operand1, self.operand2
if operand1.type.is_complex:
if self.operator == "!=": if self.operator == "!=":
negation = "!" negation = "!"
else: else:
negation = "" negation = ""
return "(%s%s(%s, %s))" % ( return "(%s%s(%s, %s))" % (
negation, negation,
self.operand1.type.binary_op('=='), operand1.type.binary_op('=='),
self.operand1.result(), operand1.result(),
self.operand2.result()) operand2.result())
elif self.is_c_string_contains(): elif self.is_c_string_contains():
if self.operand2.type is unicode_type: if operand2.type is unicode_type:
method = "__Pyx_UnicodeContainsUCS4" method = "__Pyx_UnicodeContainsUCS4"
else: else:
method = "__Pyx_BytesContains" method = "__Pyx_BytesContains"
...@@ -12407,13 +12409,15 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -12407,13 +12409,15 @@ class PrimaryCmpNode(ExprNode, CmpNode):
return "(%s%s(%s, %s))" % ( return "(%s%s(%s, %s))" % (
negation, negation,
method, method,
self.operand2.result(), operand2.result(),
self.operand1.result()) operand1.result())
else:
if is_pythran_expr(self.type):
result1, result2 = operand1.pythran_result(), operand2.pythran_result()
else: else:
result1 = self.operand1.result() result1, result2 = operand1.result(), operand2.result()
result2 = self.operand2.result()
if self.is_memslice_nonecheck: if self.is_memslice_nonecheck:
if self.operand1.type.is_memoryviewslice: if operand1.type.is_memoryviewslice:
result1 = "((PyObject *) %s.memview)" % result1 result1 = "((PyObject *) %s.memview)" % result1
else: else:
result2 = "((PyObject *) %s.memview)" % result2 result2 = "((PyObject *) %s.memview)" % result2
......
...@@ -75,7 +75,7 @@ def pythran_indexing_code(indices): ...@@ -75,7 +75,7 @@ def pythran_indexing_code(indices):
func = "slice" func = "slice"
return "pythonic::types::%s(%s)" % (func,",".join((v.pythran_result() for v in values))) return "pythonic::types::%s(%s)" % (func,",".join((v.pythran_result() for v in values)))
elif idx.type.is_int: elif idx.type.is_int:
return idx.result() return to_pythran(idx)
elif idx.type.is_pythran_expr: elif idx.type.is_pythran_expr:
return idx.pythran_result() return idx.pythran_result()
raise ValueError("unsupported indice type %s!" % str(idx.type)) raise ValueError("unsupported indice type %s!" % str(idx.type))
...@@ -85,10 +85,12 @@ def pythran_func_type(func, args): ...@@ -85,10 +85,12 @@ def pythran_func_type(func, args):
args = ",".join(("std::declval<%s>()" % pythran_type(a.type) for a in args)) args = ",".join(("std::declval<%s>()" % pythran_type(a.type) for a in args))
return "decltype(pythonic::numpy::functor::%s{}(%s))" % (func, args) return "decltype(pythonic::numpy::functor::%s{}(%s))" % (func, args)
def to_pythran(op,ptype=None): def to_pythran(op, ptype=None):
op_type = op.type op_type = op.type
if is_type(op_type,["is_pythran_expr", "is_int", "is_numeric", "is_float", if op_type.is_int:
"is_complex"]): # Make sure that integer literals always have exactly the type that the templates expect.
return op_type.cast_code(op.result())
if is_type(op_type, ["is_pythran_expr", "is_numeric", "is_float", "is_complex"]):
return op.result() return op.result()
if op.is_none: if op.is_none:
return "pythonic::__builtin__::None" return "pythonic::__builtin__::None"
...@@ -111,8 +113,7 @@ def is_pythran_supported_node_or_none(node): ...@@ -111,8 +113,7 @@ def is_pythran_supported_node_or_none(node):
def is_pythran_supported_type(type_): def is_pythran_supported_type(type_):
pythran_supported = ( pythran_supported = (
"is_pythran_expr", "is_int", "is_numeric", "is_float", "is_none", "is_pythran_expr", "is_int", "is_numeric", "is_float", "is_none", "is_complex")
"is_complex")
return is_type(type_, pythran_supported) or is_pythran_expr(type_) return is_type(type_, pythran_supported) or is_pythran_expr(type_)
def is_pythran_supported_operation_type(type_): def is_pythran_supported_operation_type(type_):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment