Commit 480f6687 authored by Mark Florisson's avatar Mark Florisson

Fix ndarray specialization

parent 1eabb1c6
...@@ -623,7 +623,7 @@ class ExprNode(Node): ...@@ -623,7 +623,7 @@ class ExprNode(Node):
return self return self
if src_type.is_fused: if src_type.is_fused:
error(self.pos, "Type is not specific") error(self.pos, "Type is not specialized")
else: else:
error(self.pos, "Cannot coerce to a type that is not specialized") error(self.pos, "Cannot coerce to a type that is not specialized")
......
...@@ -963,6 +963,8 @@ class TemplatedTypeNode(CBaseTypeNode): ...@@ -963,6 +963,8 @@ class TemplatedTypeNode(CBaseTypeNode):
for name, value in options.items() ]) for name, value in options.items() ])
self.type = PyrexTypes.BufferType(base_type, **options) self.type = PyrexTypes.BufferType(base_type, **options)
if self.type.is_fused and env.fused_to_specific:
self.type = self.type.specialize(env.fused_to_specific)
else: else:
# Array # Array
......
...@@ -1495,12 +1495,15 @@ if VALUE is not None: ...@@ -1495,12 +1495,15 @@ if VALUE is not None:
self.fused_function = None self.fused_function = None
if node.py_func: if node.py_func:
# Create PyCFunction nodes for each specialization
node.stats.insert(0, node.py_func) node.stats.insert(0, node.py_func)
node.py_func = self.visit(node.py_func) node.py_func = self.visit(node.py_func)
pycfunc = ExprNodes.PyCFunctionNode.from_defnode(node.py_func, pycfunc = ExprNodes.PyCFunctionNode.from_defnode(node.py_func,
True) True)
pycfunc = ExprNodes.ProxyNode(pycfunc.coerce_to_temp(env)) pycfunc = ExprNodes.ProxyNode(pycfunc.coerce_to_temp(env))
node.resulting_fused_function = pycfunc node.resulting_fused_function = pycfunc
# Create assignment node for our def function
node.fused_func_assignment = self._create_assignment( node.fused_func_assignment = self._create_assignment(
node.py_func, ExprNodes.CloneNode(pycfunc), env) node.py_func, ExprNodes.CloneNode(pycfunc), env)
else: else:
......
...@@ -767,6 +767,13 @@ class BufferType(BaseType): ...@@ -767,6 +767,13 @@ class BufferType(BaseType):
def as_argument_type(self): def as_argument_type(self):
return self return self
def specialize(self, values):
dtype = self.dtype.specialize(values)
if dtype is not self.dtype:
return BufferType(self.base, dtype, self.ndim, self.mode,
self.negative_indices, self.cast)
return self
def __getattr__(self, name): def __getattr__(self, name):
return getattr(self.base, name) return getattr(self.base, name)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment