Commit 831e3114 authored by serge-sans-paille's avatar serge-sans-paille

[pythran] properly handle power operator in pythran expression

parent 6178a049
...@@ -11115,6 +11115,13 @@ class BinopNode(ExprNode): ...@@ -11115,6 +11115,13 @@ class BinopNode(ExprNode):
if self.type.is_pythran_expr: if self.type.is_pythran_expr:
code.putln("// Pythran binop") code.putln("// Pythran binop")
code.putln("__Pyx_call_destructor(%s);" % self.result()) code.putln("__Pyx_call_destructor(%s);" % self.result())
if self.operator == '**':
code.putln("new (&%s) decltype(%s){pythonic::numpy::functor::power{}(%s, %s)};" % (
self.result(),
self.result(),
self.operand1.pythran_result(),
self.operand2.pythran_result()))
else:
code.putln("new (&%s) decltype(%s){%s %s %s};" % ( code.putln("new (&%s) decltype(%s){%s %s %s};" % (
self.result(), self.result(),
self.result(), self.result(),
......
...@@ -60,6 +60,10 @@ def type_remove_ref(ty): ...@@ -60,6 +60,10 @@ def type_remove_ref(ty):
def pythran_binop_type(op, tA, tB): def pythran_binop_type(op, tA, tB):
if op == '**':
return 'decltype(pythonic::numpy::functor::power{}(std::declval<%s>(), std::declval<%s>()))' % (
pythran_type(tA), pythran_type(tB))
else:
return "decltype(std::declval<%s>() %s std::declval<%s>())" % ( return "decltype(std::declval<%s>() %s std::declval<%s>())" % (
pythran_type(tA), op, pythran_type(tB)) pythran_type(tA), op, pythran_type(tB))
...@@ -209,6 +213,7 @@ def include_pythran_generic(env): ...@@ -209,6 +213,7 @@ def include_pythran_generic(env):
env.add_include_file("pythonic/python/core.hpp") env.add_include_file("pythonic/python/core.hpp")
env.add_include_file("pythonic/types/bool.hpp") env.add_include_file("pythonic/types/bool.hpp")
env.add_include_file("pythonic/types/ndarray.hpp") env.add_include_file("pythonic/types/ndarray.hpp")
env.add_include_file("pythonic/numpy/power.hpp")
env.add_include_file("<new>") # for placement new env.add_include_file("<new>") # for placement new
for i in (8, 16, 32, 64): for i in (8, 16, 32, 64):
......
...@@ -13,3 +13,13 @@ def trigo(np.ndarray[double, ndim=1] angles): ...@@ -13,3 +13,13 @@ def trigo(np.ndarray[double, ndim=1] angles):
array([ 1., -1., 1.]) array([ 1., -1., 1.])
""" """
return np.cos(angles) return np.cos(angles)
def power(np.ndarray[double, ndim=1] values):
"""
>>> a = np.array([0., 1., 2.])
>>> res = power(a)
>>> res[0], res[1], res[2]
(0.0, 1.0, 8.0)
"""
return values ** 3
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment