Commit bd7284c2 authored by Mark Florisson's avatar Mark Florisson

Improve fused specialization strings

parent d96dfdbb
...@@ -4443,9 +4443,13 @@ class AttributeNode(ExprNode): ...@@ -4443,9 +4443,13 @@ class AttributeNode(ExprNode):
if entry: if entry:
if obj_type.is_extension_type and entry.name == "__weakref__": if obj_type.is_extension_type and entry.name == "__weakref__":
error(self.pos, "Illegal use of special attribute __weakref__") error(self.pos, "Illegal use of special attribute __weakref__")
# methods need the normal attribute lookup
# def methods need the normal attribute lookup
# because they do not have struct entries # because they do not have struct entries
if entry.is_variable or entry.is_cmethod: # fused function go through assignment synthesis
# (foo = pycfunction(foo_func_obj)) and need to go through
# regular Python lookup as well
if (entry.is_variable and not entry.fused_cfunction) or entry.is_cmethod:
self.type = entry.type self.type = entry.type
self.member = entry.cname self.member = entry.cname
return return
...@@ -7151,7 +7155,7 @@ class CythonArrayNode(ExprNode): ...@@ -7151,7 +7155,7 @@ class CythonArrayNode(ExprNode):
else: else:
return error() return error()
if not base_type.same_as(array_dtype): if not (base_type.same_as(array_dtype) or base_type.is_void):
return error(self.operand.pos, ERR_BASE_TYPE) return error(self.operand.pos, ERR_BASE_TYPE)
elif self.operand.type.is_array and len(array_dimension_sizes) != ndim: elif self.operand.type.is_array and len(array_dimension_sizes) != ndim:
return error(self.operand.pos, return error(self.operand.pos,
......
...@@ -880,7 +880,8 @@ context = { ...@@ -880,7 +880,8 @@ context = {
memviewslice_declare_code = load_memview_c_utility( memviewslice_declare_code = load_memview_c_utility(
"MemviewSliceStruct", "MemviewSliceStruct",
proto_block='utility_code_proto_before_types', proto_block='utility_code_proto_before_types',
context=context) context=context,
requires=[])
atomic_utility = load_memview_c_utility("Atomics", context, atomic_utility = load_memview_c_utility("Atomics", context,
proto_block='utility_code_proto_before_types') proto_block='utility_code_proto_before_types')
...@@ -923,4 +924,5 @@ view_utility_whitelist = ('array', 'memoryview', 'array_cwrapper', ...@@ -923,4 +924,5 @@ view_utility_whitelist = ('array', 'memoryview', 'array_cwrapper',
'generic', 'strided', 'indirect', 'contiguous', 'generic', 'strided', 'indirect', 'contiguous',
'indirect_contiguous') 'indirect_contiguous')
memviewslice_declare_code.requires.append(view_utility_code)
copy_contents_new_utility.requires.append(view_utility_code) copy_contents_new_utility.requires.append(view_utility_code)
\ No newline at end of file
...@@ -1038,8 +1038,8 @@ class FusedTypeNode(CBaseTypeNode): ...@@ -1038,8 +1038,8 @@ class FusedTypeNode(CBaseTypeNode):
else: else:
types.append(type) types.append(type)
if len(self.types) == 1: # if len(self.types) == 1:
return types[0] # return types[0]
return PyrexTypes.FusedType(types, name=self.name) return PyrexTypes.FusedType(types, name=self.name)
...@@ -2286,7 +2286,6 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2286,7 +2286,6 @@ class FusedCFuncDefNode(StatListNode):
else: else:
node.py_func.fused_py_func = self.py_func node.py_func.fused_py_func = self.py_func
node.entry.as_variable = self.py_func.entry node.entry.as_variable = self.py_func.entry
# Copy the nodes as AnalyseDeclarationsTransform will prepend # Copy the nodes as AnalyseDeclarationsTransform will prepend
# self.py_func to self.stats, as we only want specialized # self.py_func to self.stats, as we only want specialized
# CFuncDefNodes in self.nodes # CFuncDefNodes in self.nodes
...@@ -2304,9 +2303,6 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2304,9 +2303,6 @@ class FusedCFuncDefNode(StatListNode):
fused_compound_types = PyrexTypes.unique( fused_compound_types = PyrexTypes.unique(
[arg.type for arg in self.node.args if arg.type.is_fused]) [arg.type for arg in self.node.args if arg.type.is_fused])
permutations = PyrexTypes.get_all_specialized_permutations(fused_compound_types) permutations = PyrexTypes.get_all_specialized_permutations(fused_compound_types)
fused_types = [fused_type
for fused_compound_type in fused_compound_types
for fused_type in fused_compound_type.get_fused_types()]
if self.node.entry in env.pyfunc_entries: if self.node.entry in env.pyfunc_entries:
env.pyfunc_entries.remove(self.node.entry) env.pyfunc_entries.remove(self.node.entry)
...@@ -2321,7 +2317,7 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2321,7 +2317,7 @@ class FusedCFuncDefNode(StatListNode):
copied_node.analyse_declarations(env) copied_node.analyse_declarations(env)
self.create_new_local_scope(copied_node, env, fused_to_specific) self.create_new_local_scope(copied_node, env, fused_to_specific)
self.specialize_copied_def(copied_node, cname, self.node.entry, self.specialize_copied_def(copied_node, cname, self.node.entry,
fused_to_specific, fused_types) fused_to_specific, fused_compound_types)
PyrexTypes.specialize_entry(copied_node.entry, cname) PyrexTypes.specialize_entry(copied_node.entry, cname)
copied_node.entry.used = True copied_node.entry.used = True
...@@ -2430,11 +2426,9 @@ class FusedCFuncDefNode(StatListNode): ...@@ -2430,11 +2426,9 @@ class FusedCFuncDefNode(StatListNode):
"""Specialize the copy of a DefNode given the copied node, """Specialize the copy of a DefNode given the copied node,
the specialization cname and the original DefNode entry""" the specialization cname and the original DefNode entry"""
type_strings = [ type_strings = [
fused_type.specialize(f2s).typeof_name() PyrexTypes.specialization_signature_string(fused_type, f2s)
for fused_type in fused_types for fused_type in fused_types
] ]
#type_strings = [f2s[fused_type].typeof_name()
# for fused_type in fused_types]
node.specialized_signature_string = ', '.join(type_strings) node.specialized_signature_string = ', '.join(type_strings)
......
...@@ -2610,8 +2610,8 @@ class ReplaceFusedTypeChecks(VisitorTransform): ...@@ -2610,8 +2610,8 @@ class ReplaceFusedTypeChecks(VisitorTransform):
else: else:
types = PyrexTypes.get_specialized_types(type2) types = PyrexTypes.get_specialized_types(type2)
for specific_type in types: for specialized_type in types:
if type1.same_as(specific_type): if type1.same_as(specialized_type):
if op == 'in': if op == 'in':
return true_node return true_node
else: else:
......
...@@ -781,13 +781,26 @@ class BufferType(BaseType): ...@@ -781,13 +781,26 @@ class BufferType(BaseType):
return "<BufferType %r>" % self.base return "<BufferType %r>" % self.base
def __str__(self): def __str__(self):
# avoid ', ', as fused functions split the signature string on ', '
if self.cast: if self.cast:
cast_str = ',cast=True' cast_str = ',cast=True'
else: else:
cast_str = '' cast_str = ''
return "%s[%s,ndim=%d%s]" % (self.base, self.dtype, self.ndim, return "%s[%s,ndim=%d%s]" % (self.base, self.dtype, self.ndim,
cast_str) cast_str)
def assignable_from(self, other_type):
return self.same_as(other_type) or other_type.is_pyobject
def same_as(self, other_type):
return (other_type.is_buffer and
self.dtype.same_as(other_type.dtype) and
self.ndim == other_type.ndim and
self.mode == other_type.mode and
self.cast == other_type.cast)
class PyObjectType(PyrexType): class PyObjectType(PyrexType):
# #
# Base class for all Python object types (reference-counted). # Base class for all Python object types (reference-counted).
...@@ -2645,6 +2658,31 @@ def _get_all_specialized_permutations(fused_types, id="", f2s=()): ...@@ -2645,6 +2658,31 @@ def _get_all_specialized_permutations(fused_types, id="", f2s=()):
return result return result
def specialization_signature_string(fused_compound_type, fused_to_specific):
"""
Return the signature for a specialization of a fused type. e.g.
floating[:] ->
'float' or 'double'
cdef fused ft:
float[:]
double[:]
ft ->
'float[:]' or 'double[:]'
integral func(floating) ->
'int (*func)(float)' or ...
"""
fused_types = fused_compound_type.get_fused_types()
if len(fused_types) == 1:
fused_type = fused_types[0]
else:
fused_type = fused_compound_type
return fused_type.specialize(fused_to_specific).typeof_name()
def get_specialized_types(type): def get_specialized_types(type):
""" """
Return a list of specialized types sorted in reverse order in accordance Return a list of specialized types sorted in reverse order in accordance
...@@ -2654,10 +2692,15 @@ def get_specialized_types(type): ...@@ -2654,10 +2692,15 @@ def get_specialized_types(type):
if isinstance(type, FusedType): if isinstance(type, FusedType):
result = type.types result = type.types
for specialized_type in result:
specialized_type.specialization_string = specialized_type.typeof_name()
else: else:
result = [] result = []
for cname, f2s in get_all_specialized_permutations(type.get_fused_types()): for cname, f2s in get_all_specialized_permutations(type.get_fused_types()):
result.append(type.specialize(f2s)) specialized_type = type.specialize(f2s)
specialized_type.specialization_string = (
specialization_signature_string(type, f2s))
result.append(specialized_type)
return sorted(result) return sorted(result)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment