Commit bd7284c2 authored by Mark Florisson's avatar Mark Florisson

Improve fused specialization strings

parent d96dfdbb
......@@ -4443,9 +4443,13 @@ class AttributeNode(ExprNode):
if entry:
if obj_type.is_extension_type and entry.name == "__weakref__":
error(self.pos, "Illegal use of special attribute __weakref__")
# methods need the normal attribute lookup
# def methods need the normal attribute lookup
# because they do not have struct entries
if entry.is_variable or entry.is_cmethod:
# fused function go through assignment synthesis
# (foo = pycfunction(foo_func_obj)) and need to go through
# regular Python lookup as well
if (entry.is_variable and not entry.fused_cfunction) or entry.is_cmethod:
self.type = entry.type
self.member = entry.cname
return
......@@ -7151,7 +7155,7 @@ class CythonArrayNode(ExprNode):
else:
return error()
if not base_type.same_as(array_dtype):
if not (base_type.same_as(array_dtype) or base_type.is_void):
return error(self.operand.pos, ERR_BASE_TYPE)
elif self.operand.type.is_array and len(array_dimension_sizes) != ndim:
return error(self.operand.pos,
......
......@@ -880,7 +880,8 @@ context = {
memviewslice_declare_code = load_memview_c_utility(
"MemviewSliceStruct",
proto_block='utility_code_proto_before_types',
context=context)
context=context,
requires=[])
atomic_utility = load_memview_c_utility("Atomics", context,
proto_block='utility_code_proto_before_types')
......@@ -923,4 +924,5 @@ view_utility_whitelist = ('array', 'memoryview', 'array_cwrapper',
'generic', 'strided', 'indirect', 'contiguous',
'indirect_contiguous')
memviewslice_declare_code.requires.append(view_utility_code)
copy_contents_new_utility.requires.append(view_utility_code)
\ No newline at end of file
......@@ -1038,8 +1038,8 @@ class FusedTypeNode(CBaseTypeNode):
else:
types.append(type)
if len(self.types) == 1:
return types[0]
# if len(self.types) == 1:
# return types[0]
return PyrexTypes.FusedType(types, name=self.name)
......@@ -2286,7 +2286,6 @@ class FusedCFuncDefNode(StatListNode):
else:
node.py_func.fused_py_func = self.py_func
node.entry.as_variable = self.py_func.entry
# Copy the nodes as AnalyseDeclarationsTransform will prepend
# self.py_func to self.stats, as we only want specialized
# CFuncDefNodes in self.nodes
......@@ -2304,9 +2303,6 @@ class FusedCFuncDefNode(StatListNode):
fused_compound_types = PyrexTypes.unique(
[arg.type for arg in self.node.args if arg.type.is_fused])
permutations = PyrexTypes.get_all_specialized_permutations(fused_compound_types)
fused_types = [fused_type
for fused_compound_type in fused_compound_types
for fused_type in fused_compound_type.get_fused_types()]
if self.node.entry in env.pyfunc_entries:
env.pyfunc_entries.remove(self.node.entry)
......@@ -2321,7 +2317,7 @@ class FusedCFuncDefNode(StatListNode):
copied_node.analyse_declarations(env)
self.create_new_local_scope(copied_node, env, fused_to_specific)
self.specialize_copied_def(copied_node, cname, self.node.entry,
fused_to_specific, fused_types)
fused_to_specific, fused_compound_types)
PyrexTypes.specialize_entry(copied_node.entry, cname)
copied_node.entry.used = True
......@@ -2430,11 +2426,9 @@ class FusedCFuncDefNode(StatListNode):
"""Specialize the copy of a DefNode given the copied node,
the specialization cname and the original DefNode entry"""
type_strings = [
fused_type.specialize(f2s).typeof_name()
PyrexTypes.specialization_signature_string(fused_type, f2s)
for fused_type in fused_types
]
#type_strings = [f2s[fused_type].typeof_name()
# for fused_type in fused_types]
node.specialized_signature_string = ', '.join(type_strings)
......
......@@ -2610,8 +2610,8 @@ class ReplaceFusedTypeChecks(VisitorTransform):
else:
types = PyrexTypes.get_specialized_types(type2)
for specific_type in types:
if type1.same_as(specific_type):
for specialized_type in types:
if type1.same_as(specialized_type):
if op == 'in':
return true_node
else:
......
......@@ -781,13 +781,26 @@ class BufferType(BaseType):
return "<BufferType %r>" % self.base
def __str__(self):
# avoid ', ', as fused functions split the signature string on ', '
if self.cast:
cast_str = ',cast=True'
else:
cast_str = ''
return "%s[%s,ndim=%d%s]" % (self.base, self.dtype, self.ndim,
cast_str)
def assignable_from(self, other_type):
return self.same_as(other_type) or other_type.is_pyobject
def same_as(self, other_type):
return (other_type.is_buffer and
self.dtype.same_as(other_type.dtype) and
self.ndim == other_type.ndim and
self.mode == other_type.mode and
self.cast == other_type.cast)
class PyObjectType(PyrexType):
#
# Base class for all Python object types (reference-counted).
......@@ -2645,6 +2658,31 @@ def _get_all_specialized_permutations(fused_types, id="", f2s=()):
return result
def specialization_signature_string(fused_compound_type, fused_to_specific):
"""
Return the signature for a specialization of a fused type. e.g.
floating[:] ->
'float' or 'double'
cdef fused ft:
float[:]
double[:]
ft ->
'float[:]' or 'double[:]'
integral func(floating) ->
'int (*func)(float)' or ...
"""
fused_types = fused_compound_type.get_fused_types()
if len(fused_types) == 1:
fused_type = fused_types[0]
else:
fused_type = fused_compound_type
return fused_type.specialize(fused_to_specific).typeof_name()
def get_specialized_types(type):
"""
Return a list of specialized types sorted in reverse order in accordance
......@@ -2654,10 +2692,15 @@ def get_specialized_types(type):
if isinstance(type, FusedType):
result = type.types
for specialized_type in result:
specialized_type.specialization_string = specialized_type.typeof_name()
else:
result = []
for cname, f2s in get_all_specialized_permutations(type.get_fused_types()):
result.append(type.specialize(f2s))
specialized_type = type.specialize(f2s)
specialized_type.specialization_string = (
specialization_signature_string(type, f2s))
result.append(specialized_type)
return sorted(result)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment