Commit 1a5fdeeb authored by serge-sans-paille's avatar serge-sans-paille Committed by Stefan Behnel

[pythran] properly handle power operator in pythran expression

parent e35311ab
...@@ -11115,12 +11115,19 @@ class BinopNode(ExprNode): ...@@ -11115,12 +11115,19 @@ class BinopNode(ExprNode):
if self.type.is_pythran_expr: if self.type.is_pythran_expr:
code.putln("// Pythran binop") code.putln("// Pythran binop")
code.putln("__Pyx_call_destructor(%s);" % self.result()) code.putln("__Pyx_call_destructor(%s);" % self.result())
code.putln("new (&%s) decltype(%s){%s %s %s};" % ( if self.operator == '**':
self.result(), code.putln("new (&%s) decltype(%s){pythonic::numpy::functor::power{}(%s, %s)};" % (
self.result(), self.result(),
self.operand1.pythran_result(), self.result(),
self.operator, self.operand1.pythran_result(),
self.operand2.pythran_result())) self.operand2.pythran_result()))
else:
code.putln("new (&%s) decltype(%s){%s %s %s};" % (
self.result(),
self.result(),
self.operand1.pythran_result(),
self.operator,
self.operand2.pythran_result()))
elif self.operand1.type.is_pyobject: elif self.operand1.type.is_pyobject:
function = self.py_operation_function(code) function = self.py_operation_function(code)
if self.operator == '**': if self.operator == '**':
......
...@@ -60,8 +60,12 @@ def type_remove_ref(ty): ...@@ -60,8 +60,12 @@ def type_remove_ref(ty):
def pythran_binop_type(op, tA, tB): def pythran_binop_type(op, tA, tB):
return "decltype(std::declval<%s>() %s std::declval<%s>())" % ( if op == '**':
pythran_type(tA), op, pythran_type(tB)) return 'decltype(pythonic::numpy::functor::power{}(std::declval<%s>(), std::declval<%s>()))' % (
pythran_type(tA), pythran_type(tB))
else:
return "decltype(std::declval<%s>() %s std::declval<%s>())" % (
pythran_type(tA), op, pythran_type(tB))
def pythran_unaryop_type(op, type_): def pythran_unaryop_type(op, type_):
...@@ -209,6 +213,7 @@ def include_pythran_generic(env): ...@@ -209,6 +213,7 @@ def include_pythran_generic(env):
env.add_include_file("pythonic/python/core.hpp") env.add_include_file("pythonic/python/core.hpp")
env.add_include_file("pythonic/types/bool.hpp") env.add_include_file("pythonic/types/bool.hpp")
env.add_include_file("pythonic/types/ndarray.hpp") env.add_include_file("pythonic/types/ndarray.hpp")
env.add_include_file("pythonic/numpy/power.hpp")
env.add_include_file("<new>") # for placement new env.add_include_file("<new>") # for placement new
for i in (8, 16, 32, 64): for i in (8, 16, 32, 64):
......
...@@ -13,3 +13,13 @@ def trigo(np.ndarray[double, ndim=1] angles): ...@@ -13,3 +13,13 @@ def trigo(np.ndarray[double, ndim=1] angles):
array([ 1., -1., 1.]) array([ 1., -1., 1.])
""" """
return np.cos(angles) return np.cos(angles)
def power(np.ndarray[double, ndim=1] values):
"""
>>> a = np.array([0., 1., 2.])
>>> res = power(a)
>>> res[0], res[1], res[2]
(0.0, 1.0, 8.0)
"""
return values ** 3
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment