Commit 9bac9c2e authored by Robert Bradshaw's avatar Robert Bradshaw

Finish porting circular import patch from Pyrex.

parent 5f13fdb2
......@@ -3656,8 +3656,8 @@ class CImportStatNode(StatNode):
class FromCImportStatNode(StatNode):
# from ... cimport statement
#
# module_name string Qualified name of module
# imported_names [(pos, name, as_name)] Names to be imported
# module_name string Qualified name of module
# imported_names [(pos, name, as_name, kind)] Names to be imported
child_attrs = []
......@@ -3667,15 +3667,43 @@ class FromCImportStatNode(StatNode):
return
module_scope = env.find_module(self.module_name, self.pos)
env.add_imported_module(module_scope)
for pos, name, as_name in self.imported_names:
for pos, name, as_name, kind in self.imported_names:
if name == "*":
for local_name, entry in module_scope.entries.items():
env.add_imported_entry(local_name, entry, pos)
else:
entry = module_scope.find(name, pos)
entry = module_scope.lookup(name)
if entry:
if kind and not self.declaration_matches(entry, kind):
entry.redeclared(pos)
else:
if kind == 'struct' or kind == 'union':
entry = module_scope.declare_struct_or_union(name,
kind = kind, scope = None, typedef_flag = 0, pos = pos)
elif kind == 'class':
entry = module_scope.declare_c_class(name, pos = pos,
module_name = self.module_name)
else:
error(pos, "Name '%s' not declared in module '%s'"
% (name, self.module_name))
if entry:
local_name = as_name or name
env.add_imported_entry(local_name, entry, pos)
def declaration_matches(self, entry, kind):
if not entry.is_type:
return 0
type = entry.type
if kind == 'class':
if not type.is_extension_type:
return 0
else:
if not type.is_struct_or_union:
return 0
if kind <> type.kind:
return 0
return 1
def analyse_expressions(self, env):
pass
......
......@@ -959,21 +959,21 @@ def p_from_import_statement(s, first_statement = 0):
s.next()
else:
s.error("Expected 'import' or 'cimport'")
is_cimport = kind == 'cimport'
if s.sy == '*':
# s.error("'import *' not supported")
imported_names = [(s.position(), "*", None)]
imported_names = [(s.position(), "*", None, None)]
s.next()
else:
imported_names = [p_imported_name(s)]
imported_names = [p_imported_name(s, is_cimport)]
while s.sy == ',':
s.next()
imported_names.append(p_imported_name(s))
imported_names.append(p_imported_name(s, is_cimport))
dotted_name = Utils.EncodedString(dotted_name)
if dotted_name == '__future__':
if not first_statement:
s.error("from __future__ imports must occur at the beginning of the file")
else:
for (name_pos, name, as_name) in imported_names:
for (name_pos, name, as_name, kind) in imported_names:
try:
directive = getattr(Future, name)
except AttributeError:
......@@ -982,7 +982,7 @@ def p_from_import_statement(s, first_statement = 0):
s.context.future_directives.add(directive)
return Nodes.PassStatNode(pos)
elif kind == 'cimport':
for (name_pos, name, as_name) in imported_names:
for (name_pos, name, as_name, kind) in imported_names:
local_name = as_name or name
s.add_type_name(local_name)
return Nodes.FromCImportStatNode(pos,
......@@ -991,7 +991,7 @@ def p_from_import_statement(s, first_statement = 0):
else:
imported_name_strings = []
items = []
for (name_pos, name, as_name) in imported_names:
for (name_pos, name, as_name, kind) in imported_names:
encoded_name = Utils.EncodedString(name)
imported_name_strings.append(
ExprNodes.IdentifierStringNode(name_pos, value = encoded_name))
......@@ -1008,11 +1008,17 @@ def p_from_import_statement(s, first_statement = 0):
name_list = import_list),
items = items)
def p_imported_name(s):
imported_name_kinds = ('class', 'struct', 'union')
def p_imported_name(s, is_cimport):
pos = s.position()
kind = None
if is_cimport and s.systring in imported_name_kinds:
kind = s.systring
s.next()
name = p_ident(s)
as_name = p_as_name(s)
return (pos, name, as_name)
return (pos, name, as_name, kind)
def p_dotted_name(s, as_allowed):
pos = s.position()
......
......@@ -123,6 +123,9 @@ class Entry:
self.pos = pos
self.init = init
def redeclared(self, pos):
error(pos, "'%s' does not match previous declaration" % self.name)
error(self.pos, "Previous declaration is here")
class Scope:
# name string Unqualified name
......@@ -310,7 +313,8 @@ class Scope:
visibility = visibility, defining = scope is not None)
self.sue_entries.append(entry)
else:
if not (entry.is_type and entry.type.is_struct_or_union):
if not (entry.is_type and entry.type.is_struct_or_union
and entry.type.kind == kind):
warning(pos, "'%s' redeclared " % name, 0)
elif scope and entry.type.scope:
warning(pos, "'%s' already defined (ignoring second definition)" % name, 0)
......@@ -902,9 +906,9 @@ class ModuleScope(Scope):
return
self.utility_code_used.append(new_code)
def declare_c_class(self, name, pos, defining, implementing,
module_name, base_type, objstruct_cname, typeobj_cname,
visibility, typedef_flag, api):
def declare_c_class(self, name, pos, defining = 0, implementing = 0,
module_name = None, base_type = None, objstruct_cname = None,
typeobj_cname = None, visibility = 'private', typedef_flag = 0, api = 0):
#
# Look for previous declaration as a type
#
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment