Commit 3b9f95ef authored by Stefan Behnel's avatar Stefan Behnel

implement relative cimports and fix some general issues with relative imports

parent 19fb7521
......@@ -123,7 +123,7 @@ class Context(object):
pos = (module_name, 0, 0)
raise CompileError(pos,
"'%s' is not a valid module name" % module_name)
if "." not in module_name and relative_to:
if relative_to:
if debug_find_module:
print("...trying relative import")
scope = relative_to.lookup_submodule(module_name)
......@@ -139,7 +139,7 @@ class Context(object):
for name in module_name.split("."):
scope = scope.find_submodule(name)
if debug_find_module:
print("...scope =", scope)
print("...scope = %s" % scope)
if not scope.pxd_file_loaded:
if debug_find_module:
print("...pxd not loaded")
......@@ -149,7 +149,7 @@ class Context(object):
print("...looking for pxd file")
pxd_pathname = self.find_pxd_file(module_name, pos)
if debug_find_module:
print("......found ", pxd_pathname)
print("......found %s" % pxd_pathname)
if not pxd_pathname and need_pxd:
package_pathname = self.search_include_directories(module_name, ".py", pos)
if package_pathname and package_pathname.endswith('__init__.py'):
......@@ -162,7 +162,7 @@ class Context(object):
print("Context.find_module: Parsing %s" % pxd_pathname)
rel_path = module_name.replace('.', os.sep) + os.path.splitext(pxd_pathname)[1]
if not pxd_pathname.endswith(rel_path):
rel_path = pxd_pathname # safety measure to prevent printing incorrect paths
rel_path = pxd_pathname # safety measure to prevent printing incorrect paths
source_desc = FileSourceDescriptor(pxd_pathname, rel_path)
err, result = self.process_pxd(source_desc, scope, module_name)
if err:
......
......@@ -6864,15 +6864,22 @@ class FromCImportStatNode(StatNode):
# from ... cimport statement
#
# module_name string Qualified name of module
# relative_level int or None Relative import: number of dots before module_name
# imported_names [(pos, name, as_name, kind)] Names to be imported
child_attrs = []
module_name = None
relative_level = None
imported_names = None
def analyse_declarations(self, env):
if not env.is_module_scope:
error(self.pos, "cimport only allowed at module level")
return
module_scope = env.find_module(self.module_name, self.pos)
if self.relative_level and self.relative_level > env.qualified_name.count('.'):
error(self.pos, "relative cimport beyond main package is not allowed")
module_scope = env.find_module(self.module_name, self.pos, relative_level=self.relative_level)
module_name = module_scope.qualified_name
env.add_imported_module(module_scope)
for pos, name, as_name, kind in self.imported_names:
if name == "*":
......@@ -6886,29 +6893,27 @@ class FromCImportStatNode(StatNode):
entry.used = 1
else:
if kind == 'struct' or kind == 'union':
entry = module_scope.declare_struct_or_union(name,
kind = kind, scope = None, typedef_flag = 0, pos = pos)
entry = module_scope.declare_struct_or_union(
name, kind=kind, scope=None, typedef_flag=0, pos=pos)
elif kind == 'class':
entry = module_scope.declare_c_class(name, pos = pos,
module_name = self.module_name)
entry = module_scope.declare_c_class(name, pos=pos, module_name=module_name)
else:
submodule_scope = env.context.find_module(name, relative_to = module_scope, pos = self.pos)
submodule_scope = env.context.find_module(name, relative_to=module_scope, pos=self.pos)
if submodule_scope.parent_module is module_scope:
env.declare_module(as_name or name, submodule_scope, self.pos)
else:
error(pos, "Name '%s' not declared in module '%s'"
% (name, self.module_name))
error(pos, "Name '%s' not declared in module '%s'" % (name, module_name))
if entry:
local_name = as_name or name
env.add_imported_entry(local_name, entry, pos)
if self.module_name.startswith('cpython'): # enough for now
if self.module_name in utility_code_for_cimports:
if module_name.startswith('cpython'): # enough for now
if module_name in utility_code_for_cimports:
env.use_utility_code(UtilityCode.load_cached(
*utility_code_for_cimports[self.module_name]))
*utility_code_for_cimports[module_name]))
for _, name, _, _ in self.imported_names:
fqname = '%s.%s' % (self.module_name, name)
fqname = '%s.%s' % (module_name, name)
if fqname in utility_code_for_cimports:
env.use_utility_code(UtilityCode.load_cached(
*utility_code_for_cimports[fqname]))
......@@ -6969,7 +6974,7 @@ class FromImportStatNode(StatNode):
env.use_utility_code(UtilityCode.load_cached("ExtTypeTest", "ObjectHandling.c"))
break
else:
entry = env.lookup(target.name)
entry = env.lookup(target.name)
# check whether or not entry is already cimported
if (entry.is_type and entry.type.name == name
and hasattr(entry.type, 'module_name')):
......@@ -6978,8 +6983,8 @@ class FromImportStatNode(StatNode):
continue
try:
# cimported with relative name
module = env.find_module(self.module.module_name.value,
pos=None)
module = env.find_module(self.module.module_name.value, pos=self.pos,
relative_level=self.module.level)
if entry.type.module_name == module.qualified_name:
continue
except AttributeError:
......
......@@ -762,8 +762,8 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
return node
def visit_FromCImportStatNode(self, node):
if (node.module_name == u"cython") or \
node.module_name.startswith(u"cython."):
if not node.relative_level and (
node.module_name == u"cython" or node.module_name.startswith(u"cython.")):
submodule = (node.module_name + u".")[7:]
newimp = []
......
......@@ -1319,24 +1319,19 @@ def p_from_import_statement(s, first_statement = 0):
while s.sy == '.':
level += 1
s.next()
if s.sy == 'cimport':
s.error("Relative cimport is not supported yet")
else:
level = None
if level is not None and s.sy == 'import':
if level is not None and s.sy in ('import', 'cimport'):
# we are dealing with "from .. import foo, bar"
dotted_name_pos, dotted_name = s.position(), ''
elif level is not None and s.sy == 'cimport':
# "from .. cimport"
s.error("Relative cimport is not supported yet")
else:
(dotted_name_pos, _, dotted_name, _) = \
p_dotted_name(s, as_allowed = 0)
if s.sy in ('import', 'cimport'):
kind = s.sy
s.next()
else:
if level is None and Future.absolute_import in s.context.future_directives:
level = 0
(dotted_name_pos, _, dotted_name, _) = p_dotted_name(s, as_allowed=False)
if s.sy not in ('import', 'cimport'):
s.error("Expected 'import' or 'cimport'")
kind = s.sy
s.next()
is_cimport = kind == 'cimport'
is_parenthesized = False
......@@ -1359,7 +1354,7 @@ def p_from_import_statement(s, first_statement = 0):
if dotted_name == '__future__':
if not first_statement:
s.error("from __future__ imports must occur at the beginning of the file")
elif level is not None:
elif level:
s.error("invalid syntax")
else:
for (name_pos, name, as_name, kind) in imported_names:
......@@ -1374,9 +1369,10 @@ def p_from_import_statement(s, first_statement = 0):
s.context.future_directives.add(directive)
return Nodes.PassStatNode(pos)
elif kind == 'cimport':
return Nodes.FromCImportStatNode(pos,
module_name = dotted_name,
imported_names = imported_names)
return Nodes.FromCImportStatNode(
pos, module_name=dotted_name,
relative_level=level,
imported_names=imported_names)
else:
imported_name_strings = []
items = []
......
......@@ -1081,13 +1081,19 @@ class ModuleScope(Scope):
entry.name = name
return entry
def find_module(self, module_name, pos):
def find_module(self, module_name, pos, relative_level=-1):
# Find a module in the import namespace, interpreting
# relative imports relative to this module's parent.
# Finds and parses the module's .pxd file if the module
# has not been referenced before.
return self.global_scope().context.find_module(
module_name, relative_to = self.parent_module, pos = pos)
module_scope = self.global_scope()
if relative_level is not None and relative_level > 0:
# merge current absolute module name and relative import name into qualified name
current_module = module_scope.qualified_name.split('.')
base_package = current_module[:-relative_level]
module_name = '.'.join(base_package + module_name.split('.'))
return module_scope.context.find_module(
module_name, relative_to=None if relative_level == 0 else self.parent_module, pos=pos)
def find_submodule(self, name):
# Find and return scope for a submodule of this module,
......
# mode: error
# tag: cimport
from ..relative_cimport cimport some_name
from ..cython cimport declare
_ERRORS="""
4:0: 'relative_cimport.pxd' not found
4:0: 'some_name.pxd' not found
4:0: relative cimport beyond main package is not allowed
4:32: Name 'some_name' not declared in module 'relative_cimport'
6:0: 'declare.pxd' not found
6:0: relative cimport beyond main package is not allowed
6:22: Name 'declare' not declared in module 'cython'
"""
# mode: run
# tag: cimport
PYTHON setup.py build_ext --inplace
PYTHON -c "from pkg.b import test; assert test() == (1, 2)"
PYTHON -c "from pkg.sub.c import test; assert test() == (1, 2)"
######## setup.py ########
from distutils.core import setup
from Cython.Build import cythonize
from Cython.Distutils.extension import Extension
setup(
ext_modules=cythonize('**/*.pyx'),
)
######## pkg/__init__.py ########
######## pkg/sub/__init__.py ########
######## pkg/a.pyx ########
cdef class test_pxd:
pass
######## pkg/a.pxd ########
cdef class test_pxd:
cdef public int x
cdef public int y
######## pkg/b.pyx ########
from .a cimport test_pxd
def test():
cdef test_pxd obj = test_pxd()
obj.x = 1
obj.y = 2
return (obj.x, obj.y)
######## pkg/sub/c.pyx ########
from ..a cimport test_pxd
def test():
cdef test_pxd obj = test_pxd()
obj.x = 1
obj.y = 2
return (obj.x, obj.y)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment