Commit 6ec81e81 authored by Dylan Trotter's avatar Dylan Trotter Committed by GitHub

Implement relative imports. (#284)

parent 7be3a9a7
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import collections import collections
import functools
import os import os
from pythonparser import algorithm from pythonparser import algorithm
...@@ -101,10 +102,10 @@ class Importer(algorithm.Visitor): ...@@ -101,10 +102,10 @@ class Importer(algorithm.Visitor):
node.module) node.module)
raise util.ImportError(node, msg) raise util.ImportError(node, msg)
if node.module == '__future__': if not node.level and node.module == '__future__':
return [] return []
if node.module.startswith(_NATIVE_MODULE_PREFIX): if not node.level and node.module.startswith(_NATIVE_MODULE_PREFIX):
imp = Import(node.module[len(_NATIVE_MODULE_PREFIX):], is_native=True) imp = Import(node.module[len(_NATIVE_MODULE_PREFIX):], is_native=True)
for alias in node.names: for alias in node.names:
asname = alias.asname or alias.name asname = alias.asname or alias.name
...@@ -112,16 +113,29 @@ class Importer(algorithm.Visitor): ...@@ -112,16 +113,29 @@ class Importer(algorithm.Visitor):
return [imp] return [imp]
imports = [] imports = []
if not node.module:
# Import of the form 'from .. import foo, bar'. All named imports must be
# modules, not module members.
for alias in node.names:
imp = self._resolve_relative_import(node.level, node, alias.name)
imp.add_binding(Import.MODULE, alias.asname or alias.name,
imp.name.count('.'))
imports.append(imp)
return imports
member_imp = None member_imp = None
for alias in node.names: for alias in node.names:
asname = alias.asname or alias.name asname = alias.asname or alias.name
if node.level:
resolver = functools.partial(self._resolve_relative_import, node.level)
else:
resolver = self._resolve_import
try: try:
imp = self._resolve_import( imp = resolver(node, '{}.{}'.format(node.module, alias.name))
node, '{}.{}'.format(node.module, alias.name))
except util.ImportError: except util.ImportError:
# A member (not a submodule) is being imported, so bind it. # A member (not a submodule) is being imported, so bind it.
if not member_imp: if not member_imp:
member_imp = self._resolve_import(node, node.module) member_imp = resolver(node, node.module)
imports.append(member_imp) imports.append(member_imp)
member_imp.add_binding(Import.MEMBER, asname, alias.name) member_imp.add_binding(Import.MEMBER, asname, alias.name)
else: else:
...@@ -139,6 +153,20 @@ class Importer(algorithm.Visitor): ...@@ -139,6 +153,20 @@ class Importer(algorithm.Visitor):
return Import(modname) return Import(modname)
raise util.ImportError(node, 'no such module: {}'.format(modname)) raise util.ImportError(node, 'no such module: {}'.format(modname))
def _resolve_relative_import(self, level, node, modname):
if not self.package_dir:
raise util.ImportError(node, 'attempted relative import in non-package')
uplevel = level - 1
if uplevel > self.package_name.count('.'):
raise util.ImportError(
node, 'attempted relative import beyond toplevel package')
dirname = os.path.normpath(os.path.join(
self.package_dir, *(['..'] * uplevel)))
if not self._script_exists(dirname, modname):
raise util.ImportError(node, 'no such module: {}'.format(modname))
parts = self.package_name.split('.')
return Import('.'.join(parts[:len(parts)-uplevel]) + '.' + modname)
def _script_exists(self, dirname, name): def _script_exists(self, dirname, name):
prefix = os.path.join(dirname, name.replace('.', os.sep)) prefix = os.path.join(dirname, name.replace('.', os.sep))
return (os.path.isfile(prefix + '.py') or return (os.path.isfile(prefix + '.py') or
......
...@@ -36,6 +36,10 @@ class ImportVisitorTest(unittest.TestCase): ...@@ -36,6 +36,10 @@ class ImportVisitorTest(unittest.TestCase):
'foo.py': None, 'foo.py': None,
'qux.py': None, 'qux.py': None,
'bar/': { 'bar/': {
'fred/': {
'__init__.py': None,
'quux.py': None,
},
'__init__.py': None, '__init__.py': None,
'baz.py': None, 'baz.py': None,
'foo.py': None, 'foo.py': None,
...@@ -50,6 +54,12 @@ class ImportVisitorTest(unittest.TestCase): ...@@ -50,6 +54,12 @@ class ImportVisitorTest(unittest.TestCase):
self.rootdir, {'src/': {'__python__/': self._PATH_SPEC}}) self.rootdir, {'src/': {'__python__/': self._PATH_SPEC}})
foo_script = os.path.join(self.rootdir, 'foo.py') foo_script = os.path.join(self.rootdir, 'foo.py')
self.importer = imputil.Importer(self.rootdir, 'foo', foo_script, False) self.importer = imputil.Importer(self.rootdir, 'foo', foo_script, False)
bar_script = os.path.join(self.pydir, 'bar', '__init__.py')
self.bar_importer = imputil.Importer(
self.rootdir, 'bar', bar_script, False)
fred_script = os.path.join(self.pydir, 'bar', 'fred', '__init__.py')
self.fred_importer = imputil.Importer(
self.rootdir, 'bar.fred', fred_script, False)
def tearDown(self): def tearDown(self):
shutil.rmtree(self.rootdir) shutil.rmtree(self.rootdir)
...@@ -82,9 +92,7 @@ class ImportVisitorTest(unittest.TestCase): ...@@ -82,9 +92,7 @@ class ImportVisitorTest(unittest.TestCase):
def testImportPackageModuleRelative(self): def testImportPackageModuleRelative(self):
imp = imputil.Import('bar.baz') imp = imputil.Import('bar.baz')
imp.add_binding(imputil.Import.MODULE, 'baz', 1) imp.add_binding(imputil.Import.MODULE, 'baz', 1)
bar_script = os.path.join(self.pydir, 'bar', '__init__.py') got = self.bar_importer.visit(pythonparser.parse('import baz').body[0])
importer = imputil.Importer(self.rootdir, 'bar', bar_script, False)
got = importer.visit(pythonparser.parse('import baz').body[0])
self._assert_imports_equal([imp], got) self._assert_imports_equal([imp], got)
def testImportPackageModuleRelativeFromSubModule(self): def testImportPackageModuleRelativeFromSubModule(self):
...@@ -176,6 +184,58 @@ class ImportVisitorTest(unittest.TestCase): ...@@ -176,6 +184,58 @@ class ImportVisitorTest(unittest.TestCase):
imp.add_binding(imputil.Import.MEMBER, 'foo', 'Printf') imp.add_binding(imputil.Import.MEMBER, 'foo', 'Printf')
self._check_imports('from __go__.fmt import Printf as foo', [imp]) self._check_imports('from __go__.fmt import Printf as foo', [imp])
def testRelativeImportNonPackage(self):
self.assertRaises(util.ImportError, self.importer.visit,
pythonparser.parse('from . import bar').body[0])
def testRelativeImportBeyondTopLevel(self):
self.assertRaises(util.ImportError, self.bar_importer.visit,
pythonparser.parse('from .. import qux').body[0])
def testRelativeModuleNoExist(self):
self.assertRaises(util.ImportError, self.bar_importer.visit,
pythonparser.parse('from . import qux').body[0])
def testRelativeModule(self):
imp = imputil.Import('bar.foo')
imp.add_binding(imputil.Import.MODULE, 'foo', 1)
node = pythonparser.parse('from . import foo').body[0]
self._assert_imports_equal([imp], self.bar_importer.visit(node))
def testRelativeModuleFromSubModule(self):
imp = imputil.Import('bar.foo')
imp.add_binding(imputil.Import.MODULE, 'foo', 1)
baz_script = os.path.join(self.pydir, 'bar', 'baz.py')
importer = imputil.Importer(self.rootdir, 'bar.baz', baz_script, False)
node = pythonparser.parse('from . import foo').body[0]
self._assert_imports_equal([imp], importer.visit(node))
def testRelativeModuleMember(self):
imp = imputil.Import('bar.foo')
imp.add_binding(imputil.Import.MEMBER, 'qux', 'qux')
node = pythonparser.parse('from .foo import qux').body[0]
self._assert_imports_equal([imp], self.bar_importer.visit(node))
def testRelativeModuleMemberMixed(self):
imp1 = imputil.Import('bar.fred')
imp1.add_binding(imputil.Import.MEMBER, 'qux', 'qux')
imp2 = imputil.Import('bar.fred.quux')
imp2.add_binding(imputil.Import.MODULE, 'quux', 2)
node = pythonparser.parse('from .fred import qux, quux').body[0]
self._assert_imports_equal([imp1, imp2], self.bar_importer.visit(node))
def testRelativeUpLevel(self):
imp = imputil.Import('bar.foo')
imp.add_binding(imputil.Import.MODULE, 'foo', 1)
node = pythonparser.parse('from .. import foo').body[0]
self._assert_imports_equal([imp], self.fred_importer.visit(node))
def testRelativeUpLevelMember(self):
imp = imputil.Import('bar.foo')
imp.add_binding(imputil.Import.MEMBER, 'qux', 'qux')
node = pythonparser.parse('from ..foo import qux').body[0]
self._assert_imports_equal([imp], self.fred_importer.visit(node))
def _check_imports(self, stmt, want): def _check_imports(self, stmt, want):
got = self.importer.visit(pythonparser.parse(stmt).body[0]) got = self.importer.visit(pythonparser.parse(stmt).body[0])
self._assert_imports_equal(want, got) self._assert_imports_equal(want, got)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment