Commit 3e76b40f authored by Robert Bradshaw's avatar Robert Bradshaw

Obviate the need for forward-declaring structs/unions/enums/cdef classes.

parent 6aa0d3cf
...@@ -103,7 +103,8 @@ class Context(object): ...@@ -103,7 +103,8 @@ class Context(object):
def create_pipeline(self, pxd, py=False): def create_pipeline(self, pxd, py=False):
from Visitor import PrintTree from Visitor import PrintTree
from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, PxdPostParse from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, PxdPostParse
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform from ParseTreeTransforms import ForwardDeclareTypes, AnalyseDeclarationsTransform
from ParseTreeTransforms import AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
from ParseTreeTransforms import ExpandInplaceOperators, ParallelRangeTransform from ParseTreeTransforms import ExpandInplaceOperators, ParallelRangeTransform
...@@ -146,6 +147,7 @@ class Context(object): ...@@ -146,6 +147,7 @@ class Context(object):
FlattenInListTransform(), FlattenInListTransform(),
WithTransform(self), WithTransform(self),
DecoratorTransform(self), DecoratorTransform(self),
ForwardDeclareTypes(self),
AnalyseDeclarationsTransform(self), AnalyseDeclarationsTransform(self),
AutoTestDictTransform(self), AutoTestDictTransform(self),
EmbedSignature(self), EmbedSignature(self),
......
...@@ -427,20 +427,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -427,20 +427,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
for entry in module.type_entries: for entry in module.type_entries:
if entry.defined_in_pxd: if entry.defined_in_pxd:
type_entries.append(entry) type_entries.append(entry)
for entry in type_entries: self.generate_type_header_code(type_entries, code)
if not entry.in_cinclude:
#print "generate_type_header_code:", entry.name, repr(entry.type) ###
type = entry.type
if type.is_typedef: # Must test this first!
self.generate_typedef(entry, code)
elif type.is_struct_or_union:
self.generate_struct_union_definition(entry, code)
elif type.is_enum:
self.generate_enum_definition(entry, code)
elif type.is_extension_type and entry not in vtabslot_entries:
self.generate_objstruct_definition(type, code)
for entry in vtabslot_list: for entry in vtabslot_list:
self.generate_objstruct_definition(entry.type, code) # self.generate_objstruct_definition(entry.type, code)
self.generate_typeobj_predeclaration(entry, code) self.generate_typeobj_predeclaration(entry, code)
for entry in vtab_list: for entry in vtab_list:
self.generate_typeobj_predeclaration(entry, code) self.generate_typeobj_predeclaration(entry, code)
...@@ -782,17 +771,28 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -782,17 +771,28 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
def generate_type_header_code(self, type_entries, code): def generate_type_header_code(self, type_entries, code):
# Generate definitions of structs/unions/enums/typedefs/objstructs. # Generate definitions of structs/unions/enums/typedefs/objstructs.
#self.generate_gcc33_hack(env, code) # Is this still needed? #self.generate_gcc33_hack(env, code) # Is this still needed?
#for entry in env.type_entries: # Forward declarations
for entry in type_entries: for entry in type_entries:
if not entry.in_cinclude: if not entry.in_cinclude:
#print "generate_type_header_code:", entry.name, repr(entry.type) ### #print "generate_type_header_code:", entry.name, repr(entry.type) ###
type = entry.type type = entry.type
if type.is_typedef: # Must test this first! if type.is_typedef: # Must test this first!
self.generate_typedef(entry, code) pass
elif type.is_struct_or_union: elif type.is_struct_or_union:
self.generate_struct_union_definition(entry, code) self.generate_struct_union_predeclaration(entry, code)
elif type.is_extension_type:
self.generate_objstruct_predeclaration(type, code)
# Actual declarations
for entry in type_entries:
if not entry.in_cinclude:
#print "generate_type_header_code:", entry.name, repr(entry.type) ###
type = entry.type
if type.is_typedef: # Must test this first!
self.generate_typedef(entry, code)
elif type.is_enum: elif type.is_enum:
self.generate_enum_definition(entry, code) self.generate_enum_definition(entry, code)
elif type.is_struct_or_union:
self.generate_struct_union_definition(entry, code)
elif type.is_extension_type: elif type.is_extension_type:
self.generate_objstruct_definition(type, code) self.generate_objstruct_definition(type, code)
...@@ -822,13 +822,21 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -822,13 +822,21 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
writer.mark_pos(entry.pos) writer.mark_pos(entry.pos)
writer.putln("typedef %s;" % base_type.declaration_code(entry.cname)) writer.putln("typedef %s;" % base_type.declaration_code(entry.cname))
def sue_header_footer(self, type, kind, name): def sue_predeclaration(self, type, kind, name):
if type.typedef_flag: if type.typedef_flag:
header = "typedef %s {" % kind return "%s %s;\ntypedef %s %s %s;" % (
footer = "} %s;" % name kind, name,
kind, name, name)
else: else:
header = "%s %s {" % (kind, name) return "%s %s;" % (kind, name)
footer = "};"
def generate_struct_union_predeclaration(self, entry, code):
type = entry.type
code.putln(self.sue_predeclaration(type, type.kind, type.cname))
def sue_header_footer(self, type, kind, name):
header = "%s %s {" % (kind, name)
footer = "};"
return header, footer return header, footer
def generate_struct_union_definition(self, entry, code): def generate_struct_union_definition(self, entry, code):
...@@ -897,6 +905,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -897,6 +905,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
value_code += "," value_code += ","
code.putln(value_code) code.putln(value_code)
code.putln(footer) code.putln(footer)
if entry.type.typedef_flag:
# Not pre-declared.
code.putln("typedef enum %s %s;" % (name, name))
def generate_typeobj_predeclaration(self, entry, code): def generate_typeobj_predeclaration(self, entry, code):
code.putln("") code.putln("")
...@@ -946,6 +957,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -946,6 +957,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
type.vtabstruct_cname, type.vtabstruct_cname,
type.vtabptr_cname)) type.vtabptr_cname))
def generate_objstruct_predeclaration(self, type, code):
if not type.scope:
return
code.putln(self.sue_predeclaration(type, "struct", type.objstruct_cname))
def generate_objstruct_definition(self, type, code): def generate_objstruct_definition(self, type, code):
code.mark_pos(type.pos) code.mark_pos(type.pos)
# Generate object struct definition for an # Generate object struct definition for an
......
...@@ -979,45 +979,32 @@ class CStructOrUnionDefNode(StatNode): ...@@ -979,45 +979,32 @@ class CStructOrUnionDefNode(StatNode):
# packed boolean # packed boolean
child_attrs = ["attributes"] child_attrs = ["attributes"]
def analyse_declarations(self, env): def declare(self, env, scope=None):
scope = None if self.visibility == 'extern' and self.packed and not scope:
if self.visibility == 'extern' and self.packed:
error(self.pos, "Cannot declare extern struct as 'packed'") error(self.pos, "Cannot declare extern struct as 'packed'")
if self.attributes is not None:
scope = StructOrUnionScope(self.name)
self.entry = env.declare_struct_or_union( self.entry = env.declare_struct_or_union(
self.name, self.kind, scope, self.typedef_flag, self.pos, self.name, self.kind, scope, self.typedef_flag, self.pos,
self.cname, visibility = self.visibility, api = self.api, self.cname, visibility = self.visibility, api = self.api,
packed = self.packed) packed = self.packed)
def analyse_declarations(self, env):
scope = None
if self.attributes is not None:
scope = StructOrUnionScope(self.name)
self.declare(env, scope)
if self.attributes is not None: if self.attributes is not None:
if self.in_pxd and not env.in_cinclude: if self.in_pxd and not env.in_cinclude:
self.entry.defined_in_pxd = 1 self.entry.defined_in_pxd = 1
for attr in self.attributes: for attr in self.attributes:
attr.analyse_declarations(env, scope) attr.analyse_declarations(env, scope)
if self.visibility != 'extern': if self.visibility != 'extern':
need_typedef_indirection = False
for attr in scope.var_entries: for attr in scope.var_entries:
type = attr.type type = attr.type
while type.is_array: while type.is_array:
type = type.base_type type = type.base_type
if type == self.entry.type: if type == self.entry.type:
error(attr.pos, "Struct cannot contain itself as a member.") error(attr.pos, "Struct cannot contain itself as a member.")
if self.typedef_flag:
while type.is_ptr:
type = type.base_type
if type == self.entry.type:
need_typedef_indirection = True
if need_typedef_indirection:
# C can't handle typedef structs that refer to themselves.
struct_entry = self.entry
self.entry = env.declare_typedef(
self.name, struct_entry.type, self.pos,
cname = self.cname, visibility='ignore')
struct_entry.type.typedef_flag = False
# FIXME: this might be considered a hack ;-)
struct_entry.cname = struct_entry.type.cname = \
'_' + self.entry.type.typedef_cname
def analyse_expressions(self, env): def analyse_expressions(self, env):
pass pass
...@@ -1037,6 +1024,15 @@ class CppClassNode(CStructOrUnionDefNode): ...@@ -1037,6 +1024,15 @@ class CppClassNode(CStructOrUnionDefNode):
# base_classes [string] # base_classes [string]
# templates [string] or None # templates [string] or None
def declare(self, env):
if self.templates is None:
template_types = None
else:
template_types = [PyrexTypes.TemplatePlaceholderType(template_name) for template_name in self.templates]
self.entry = env.declare_cpp_class(
self.name, None, self.pos,
self.cname, base_classes = [], visibility = self.visibility, templates = template_types)
def analyse_declarations(self, env): def analyse_declarations(self, env):
scope = None scope = None
if self.attributes is not None: if self.attributes is not None:
...@@ -1078,10 +1074,12 @@ class CEnumDefNode(StatNode): ...@@ -1078,10 +1074,12 @@ class CEnumDefNode(StatNode):
child_attrs = ["items"] child_attrs = ["items"]
def declare(self, env):
self.entry = env.declare_enum(self.name, self.pos,
cname = self.cname, typedef_flag = self.typedef_flag,
visibility = self.visibility, api = self.api)
def analyse_declarations(self, env): def analyse_declarations(self, env):
self.entry = env.declare_enum(self.name, self.pos,
cname = self.cname, typedef_flag = self.typedef_flag,
visibility = self.visibility, api = self.api)
if self.items is not None: if self.items is not None:
if self.in_pxd and not env.in_cinclude: if self.in_pxd and not env.in_cinclude:
self.entry.defined_in_pxd = 1 self.entry.defined_in_pxd = 1
...@@ -3352,18 +3350,42 @@ class CClassDefNode(ClassDefNode): ...@@ -3352,18 +3350,42 @@ class CClassDefNode(ClassDefNode):
decorators = None decorators = None
shadow = False shadow = False
def analyse_declarations(self, env): def declare(self, env):
#print "CClassDefNode.analyse_declarations:", self.class_name if self.module_name and self.visibility != 'extern':
#print "...visibility =", self.visibility module_path = self.module_name.split(".")
#print "...module_name =", self.module_name home_scope = env.find_imported_module(module_path, self.pos)
if not home_scope:
return None
else:
home_scope = env
import Buffer import Buffer
if self.buffer_defaults_node: if self.buffer_defaults_node:
buffer_defaults = Buffer.analyse_buffer_options(self.buffer_defaults_pos, self.buffer_defaults = Buffer.analyse_buffer_options(self.buffer_defaults_pos,
env, [], self.buffer_defaults_node, env, [], self.buffer_defaults_node,
need_complete=False) need_complete=False)
else: else:
buffer_defaults = None self.buffer_defaults = None
self.entry = home_scope.declare_c_class(
name = self.class_name,
pos = self.pos,
defining = 0,
implementing = 0,
module_name = self.module_name,
base_type = None,
objstruct_cname = self.objstruct_name,
typeobj_cname = self.typeobj_name,
visibility = self.visibility,
typedef_flag = self.typedef_flag,
api = self.api,
buffer_defaults = self.buffer_defaults,
shadow = self.shadow)
def analyse_declarations(self, env):
#print "CClassDefNode.analyse_declarations:", self.class_name
#print "...visibility =", self.visibility
#print "...module_name =", self.module_name
if env.in_cinclude and not self.objstruct_name: if env.in_cinclude and not self.objstruct_name:
error(self.pos, "Object struct name specification required for " error(self.pos, "Object struct name specification required for "
...@@ -3441,7 +3463,7 @@ class CClassDefNode(ClassDefNode): ...@@ -3441,7 +3463,7 @@ class CClassDefNode(ClassDefNode):
visibility = self.visibility, visibility = self.visibility,
typedef_flag = self.typedef_flag, typedef_flag = self.typedef_flag,
api = self.api, api = self.api,
buffer_defaults = buffer_defaults, buffer_defaults = self.buffer_defaults,
shadow = self.shadow) shadow = self.shadow)
if self.shadow: if self.shadow:
home_scope.lookup(self.class_name).as_variable = self.entry home_scope.lookup(self.class_name).as_variable = self.entry
......
...@@ -1263,6 +1263,44 @@ class DecoratorTransform(CythonTransform, SkipDeclarations): ...@@ -1263,6 +1263,44 @@ class DecoratorTransform(CythonTransform, SkipDeclarations):
rhs = decorator_result) rhs = decorator_result)
return [node, reassignment] return [node, reassignment]
class ForwardDeclareTypes(CythonTransform):
def visit_CompilerDirectivesNode(self, node):
env = self.module_scope
old = env.directives
env.directives = node.directives
self.visitchildren(node)
env.directives = old
return node
def visit_ModuleNode(self, node):
self.module_scope = node.scope
self.module_scope.directives = node.directives
self.visitchildren(node)
return node
def visit_CDefExternNode(self, node):
old_cinclude_flag = self.module_scope.in_cinclude
self.module_scope.in_cinclude = 1
self.visitchildren(node)
self.module_scope.in_cinclude = old_cinclude_flag
return node
def visit_CEnumDefNode(self, node):
node.declare(self.module_scope)
return node
def visit_CStructOrUnionDefNode(self, node):
if node.name not in self.module_scope.entries:
node.declare(self.module_scope)
return node
def visit_CClassDefNode(self, node):
if node.class_name not in self.module_scope.entries:
node.declare(self.module_scope)
return node
class AnalyseDeclarationsTransform(CythonTransform): class AnalyseDeclarationsTransform(CythonTransform):
......
...@@ -426,8 +426,6 @@ class Scope(object): ...@@ -426,8 +426,6 @@ class Scope(object):
if scope: if scope:
entry.type.scope = scope entry.type.scope = scope
self.type_entries.append(entry) self.type_entries.append(entry)
if not scope and not entry.type.scope:
self.check_for_illegal_incomplete_ctypedef(typedef_flag, pos)
return entry return entry
def declare_cpp_class(self, name, scope, def declare_cpp_class(self, name, scope,
...@@ -453,15 +451,26 @@ class Scope(object): ...@@ -453,15 +451,26 @@ class Scope(object):
if scope: if scope:
entry.type.scope = scope entry.type.scope = scope
self.type_entries.append(entry) self.type_entries.append(entry)
if templates is not None: if base_classes:
if entry.type.base_classes and not entry.type.base_classes == base_classes:
error(pos, "Base type does not match previous declaration")
else:
entry.type.base_classes = base_classes
if templates or entry.type.templates:
if templates != entry.type.templates:
error(pos, "Template parameters do not match previous declaration")
if templates is not None and entry.type.scope is not None:
for T in templates: for T in templates:
template_entry = entry.type.scope.declare(T.name, T.name, T, None, 'extern') template_entry = entry.type.scope.declare(T.name, T.name, T, None, 'extern')
template_entry.is_type = 1 template_entry.is_type = 1
def declare_inherited_attributes(entry, base_classes): def declare_inherited_attributes(entry, base_classes):
for base_class in base_classes: for base_class in base_classes:
declare_inherited_attributes(entry, base_class.base_classes) if base_class.scope is None:
entry.type.scope.declare_inherited_cpp_attributes(base_class.scope) error(pos, "Cannot inherit from incomplete type")
else:
declare_inherited_attributes(entry, base_class.base_classes)
entry.type.scope.declare_inherited_cpp_attributes(base_class.scope)
if entry.type.scope: if entry.type.scope:
declare_inherited_attributes(entry, base_classes) declare_inherited_attributes(entry, base_classes)
if self.is_cpp_class_scope: if self.is_cpp_class_scope:
...@@ -1171,8 +1180,6 @@ class ModuleScope(Scope): ...@@ -1171,8 +1180,6 @@ class ModuleScope(Scope):
scope.declare_inherited_c_attributes(base_type.scope) scope.declare_inherited_c_attributes(base_type.scope)
type.set_scope(scope) type.set_scope(scope)
self.type_entries.append(entry) self.type_entries.append(entry)
else:
self.check_for_illegal_incomplete_ctypedef(typedef_flag, pos)
else: else:
if defining and type.scope.defined: if defining and type.scope.defined:
error(pos, "C class '%s' already defined" % name) error(pos, "C class '%s' already defined" % name)
...@@ -1203,10 +1210,6 @@ class ModuleScope(Scope): ...@@ -1203,10 +1210,6 @@ class ModuleScope(Scope):
# #
return entry return entry
def check_for_illegal_incomplete_ctypedef(self, typedef_flag, pos):
if typedef_flag and not self.in_cinclude:
error(pos, "Forward-referenced type must use 'cdef', not 'ctypedef'")
def allocate_vtable_names(self, entry): def allocate_vtable_names(self, entry):
# If extension type has a vtable, allocate vtable struct and # If extension type has a vtable, allocate vtable struct and
# slot names for it. # slot names for it.
......
# mode: compile
ctypedef enum MyEnum:
Value1
Value2
Value3 = 100
cdef MyEnum my_enum = Value3
ctypedef struct StructA:
StructA *a
StructB *b
cdef struct StructB:
StructA *a
StructB *b
cdef class ClassA:
cdef ClassB b
ctypedef public class ClassB [ object ClassB, type TypeB ]:
cdef ClassA a
cdef StructA struct_a
cdef StructB struct_b
struct_a.a = &struct_a
struct_a.b = &struct_b
struct_b.a = &struct_a
struct_b.b = &struct_b
cdef ClassA class_a = ClassA()
cdef ClassB class_b = ClassB()
class_a.a = class_a
class_a.b = class_b
class_b.a = class_a
class_b.b = class_b
# mode: error
ctypedef struct Spam
cdef extern from *:
ctypedef struct Ham
ctypedef struct Spam:
int i
ctypedef struct Spam
_ERRORS = u"""
3:0: Forward-referenced type must use 'cdef', not 'ctypedef'
"""
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment