Commit e3f5343f authored by Stefan Behnel's avatar Stefan Behnel

prevent fallback to absolute cimport when relative cimport is not found;...

prevent fallback to absolute cimport when relative cimport is not found; generally clean up relative cimport code
parent cdd572ef
...@@ -110,7 +110,8 @@ class Context(object): ...@@ -110,7 +110,8 @@ class Context(object):
def nonfatal_error(self, exc): def nonfatal_error(self, exc):
return Errors.report_error(exc) return Errors.report_error(exc)
def find_module(self, module_name, relative_to=None, pos=None, need_pxd=1, check_module_name=True): def find_module(self, module_name, relative_to=None, pos=None, need_pxd=1,
check_module_name=True, absolute_fallback=True):
# Finds and returns the module scope corresponding to # Finds and returns the module scope corresponding to
# the given relative or absolute module name. If this # the given relative or absolute module name. If this
# is the first time the module has been requested, finds # is the first time the module has been requested, finds
...@@ -125,25 +126,39 @@ class Context(object): ...@@ -125,25 +126,39 @@ class Context(object):
scope = None scope = None
pxd_pathname = None pxd_pathname = None
if check_module_name and not module_name_pattern.match(module_name): if relative_to:
if pos is None: if module_name:
pos = (module_name, 0, 0) # from .module import ...
raise CompileError(pos, "'%s' is not a valid module name" % module_name) qualified_name = relative_to.qualify_name(module_name)
else:
# from . import ...
qualified_name = relative_to.qualified_name
scope = relative_to
relative_to = None
else:
qualified_name = module_name
if check_module_name and not module_name_pattern.match(qualified_name):
raise CompileError(pos or (module_name, 0, 0),
"'%s' is not a valid module name" % module_name)
if relative_to: if relative_to:
if debug_find_module: if debug_find_module:
print("...trying relative import") print("...trying relative import")
scope = relative_to.lookup_submodule(module_name) scope = relative_to.lookup_submodule(module_name)
if not scope: if not scope:
qualified_name = relative_to.qualify_name(module_name)
pxd_pathname = self.find_pxd_file(qualified_name, pos) pxd_pathname = self.find_pxd_file(qualified_name, pos)
if pxd_pathname: if pxd_pathname:
scope = relative_to.find_submodule(module_name) scope = relative_to.find_submodule(module_name)
if not scope: if not scope:
if debug_find_module: if debug_find_module:
print("...trying absolute import") print("...trying absolute import")
if absolute_fallback:
qualified_name = module_name
scope = self scope = self
for name in module_name.split("."): for name in qualified_name.split("."):
scope = scope.find_submodule(name) scope = scope.find_submodule(name)
if debug_find_module: if debug_find_module:
print("...scope = %s" % scope) print("...scope = %s" % scope)
if not scope.pxd_file_loaded: if not scope.pxd_file_loaded:
...@@ -153,15 +168,15 @@ class Context(object): ...@@ -153,15 +168,15 @@ class Context(object):
if not pxd_pathname: if not pxd_pathname:
if debug_find_module: if debug_find_module:
print("...looking for pxd file") print("...looking for pxd file")
pxd_pathname = self.find_pxd_file(module_name, pos) pxd_pathname = self.find_pxd_file(qualified_name, pos)
if debug_find_module: if debug_find_module:
print("......found %s" % pxd_pathname) print("......found %s" % pxd_pathname)
if not pxd_pathname and need_pxd: if not pxd_pathname and need_pxd:
package_pathname = self.search_include_directories(module_name, ".py", pos) package_pathname = self.search_include_directories(qualified_name, ".py", pos)
if package_pathname and package_pathname.endswith('__init__.py'): if package_pathname and package_pathname.endswith('__init__.py'):
pass pass
else: else:
error(pos, "'%s.pxd' not found" % module_name.replace('.', os.sep)) error(pos, "'%s.pxd' not found" % qualified_name.replace('.', os.sep))
if pxd_pathname: if pxd_pathname:
try: try:
if debug_find_module: if debug_find_module:
...@@ -170,7 +185,7 @@ class Context(object): ...@@ -170,7 +185,7 @@ class Context(object):
if not pxd_pathname.endswith(rel_path): if not pxd_pathname.endswith(rel_path):
rel_path = pxd_pathname # safety measure to prevent printing incorrect paths rel_path = pxd_pathname # safety measure to prevent printing incorrect paths
source_desc = FileSourceDescriptor(pxd_pathname, rel_path) source_desc = FileSourceDescriptor(pxd_pathname, rel_path)
err, result = self.process_pxd(source_desc, scope, module_name) err, result = self.process_pxd(source_desc, scope, qualified_name)
if err: if err:
raise err raise err
(pxd_codenodes, pxd_scope) = result (pxd_codenodes, pxd_scope) = result
......
...@@ -7224,6 +7224,7 @@ class FromCImportStatNode(StatNode): ...@@ -7224,6 +7224,7 @@ class FromCImportStatNode(StatNode):
return return
if self.relative_level and self.relative_level > env.qualified_name.count('.'): if self.relative_level and self.relative_level > env.qualified_name.count('.'):
error(self.pos, "relative cimport beyond main package is not allowed") error(self.pos, "relative cimport beyond main package is not allowed")
return
module_scope = env.find_module(self.module_name, self.pos, relative_level=self.relative_level) module_scope = env.find_module(self.module_name, self.pos, relative_level=self.relative_level)
module_name = module_scope.qualified_name module_name = module_scope.qualified_name
env.add_imported_module(module_scope) env.add_imported_module(module_scope)
...@@ -7244,7 +7245,8 @@ class FromCImportStatNode(StatNode): ...@@ -7244,7 +7245,8 @@ class FromCImportStatNode(StatNode):
elif kind == 'class': elif kind == 'class':
entry = module_scope.declare_c_class(name, pos=pos, module_name=module_name) entry = module_scope.declare_c_class(name, pos=pos, module_name=module_name)
else: else:
submodule_scope = env.context.find_module(name, relative_to=module_scope, pos=self.pos) submodule_scope = env.context.find_module(
name, relative_to=module_scope, pos=self.pos, absolute_fallback=False)
if submodule_scope.parent_module is module_scope: if submodule_scope.parent_module is module_scope:
env.declare_module(as_name or name, submodule_scope, self.pos) env.declare_module(as_name or name, submodule_scope, self.pos)
else: else:
......
...@@ -1127,14 +1127,20 @@ class ModuleScope(Scope): ...@@ -1127,14 +1127,20 @@ class ModuleScope(Scope):
# relative imports relative to this module's parent. # relative imports relative to this module's parent.
# Finds and parses the module's .pxd file if the module # Finds and parses the module's .pxd file if the module
# has not been referenced before. # has not been referenced before.
module_scope = self.global_scope() relative_to = None
if relative_level is not None and relative_level > 0: if relative_level is not None and relative_level > 0:
# merge current absolute module name and relative import name into qualified name absolute_fallback = False
current_module = module_scope.qualified_name.split('.') # error of going beyond top-level is handled in cimport node
base_package = current_module[:-relative_level] relative_to = self
module_name = '.'.join(base_package + (module_name.split('.') if module_name else [])) while relative_level > 0 and relative_to:
relative_to = relative_to.parent_module
relative_level -= 1
else:
absolute_fallback = relative_level != 0 # might be None!
module_scope = self.global_scope()
return module_scope.context.find_module( return module_scope.context.find_module(
module_name, relative_to=None if relative_level == 0 else self.parent_module, pos=pos) module_name, relative_to=relative_to, pos=pos, absolute_fallback=absolute_fallback)
def find_submodule(self, name): def find_submodule(self, name):
# Find and return scope for a submodule of this module, # Find and return scope for a submodule of this module,
......
...@@ -34,7 +34,9 @@ class NonManglingModuleScope(Symtab.ModuleScope): ...@@ -34,7 +34,9 @@ class NonManglingModuleScope(Symtab.ModuleScope):
class CythonUtilityCodeContext(StringParseContext): class CythonUtilityCodeContext(StringParseContext):
scope = None scope = None
def find_module(self, module_name, relative_to=None, pos=None, need_pxd=True): def find_module(self, module_name, relative_to=None, pos=None, need_pxd=True, absolute_fallback=True):
if relative_to:
raise AssertionError("Relative imports not supported in utility code.")
if module_name != self.module_name: if module_name != self.module_name:
if module_name not in self.modules: if module_name not in self.modules:
raise AssertionError("Only the cython cimport is supported.") raise AssertionError("Only the cython cimport is supported.")
......
...@@ -2,16 +2,14 @@ ...@@ -2,16 +2,14 @@
# tag: cimport # tag: cimport
from ..relative_cimport cimport some_name from ..relative_cimport cimport some_name
from .e_relative_cimport cimport some_name
from ..cython cimport declare from ..cython cimport declare
from . cimport e_relative_cimport
_ERRORS=""" _ERRORS="""
4:0: 'relative_cimport.pxd' not found
4:0: 'some_name.pxd' not found
4:0: relative cimport beyond main package is not allowed 4:0: relative cimport beyond main package is not allowed
4:32: Name 'some_name' not declared in module 'relative_cimport' 5:0: relative cimport beyond main package is not allowed
6:0: 'declare.pxd' not found
6:0: relative cimport beyond main package is not allowed 6:0: relative cimport beyond main package is not allowed
6:22: Name 'declare' not declared in module 'cython' 7:0: relative cimport beyond main package is not allowed
""" """
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment