Commit 480f6687 authored by Mark Florisson's avatar Mark Florisson

Fix ndarray specialization

parent 1eabb1c6
......@@ -623,7 +623,7 @@ class ExprNode(Node):
return self
if src_type.is_fused:
error(self.pos, "Type is not specific")
error(self.pos, "Type is not specialized")
else:
error(self.pos, "Cannot coerce to a type that is not specialized")
......
......@@ -963,6 +963,8 @@ class TemplatedTypeNode(CBaseTypeNode):
for name, value in options.items() ])
self.type = PyrexTypes.BufferType(base_type, **options)
if self.type.is_fused and env.fused_to_specific:
self.type = self.type.specialize(env.fused_to_specific)
else:
# Array
......
......@@ -1495,12 +1495,15 @@ if VALUE is not None:
self.fused_function = None
if node.py_func:
# Create PyCFunction nodes for each specialization
node.stats.insert(0, node.py_func)
node.py_func = self.visit(node.py_func)
pycfunc = ExprNodes.PyCFunctionNode.from_defnode(node.py_func,
True)
pycfunc = ExprNodes.ProxyNode(pycfunc.coerce_to_temp(env))
node.resulting_fused_function = pycfunc
# Create assignment node for our def function
node.fused_func_assignment = self._create_assignment(
node.py_func, ExprNodes.CloneNode(pycfunc), env)
else:
......
......@@ -767,6 +767,13 @@ class BufferType(BaseType):
def as_argument_type(self):
return self
def specialize(self, values):
dtype = self.dtype.specialize(values)
if dtype is not self.dtype:
return BufferType(self.base, dtype, self.ndim, self.mode,
self.negative_indices, self.cast)
return self
def __getattr__(self, name):
return getattr(self.base, name)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment