Commit a9b0a7b0 authored by Dylan Trotter's avatar Dylan Trotter

Support module member import. Fix #99.

Modules being imported are now detected at compile time. So it's
possible to determine whether "from x import y" is importing module y or
a member y of module x. This change uses this capability to choose
whether to import a module or a member.

This fixes https://github.com/google/grumpy/issues/99
parent 94accdb9
......@@ -672,6 +672,8 @@ class StatementVisitor(algorithm.Visitor):
imp: Import object representing an import of the form "import x.y.z" or
"from x.y import z". Expects only a single binding.
"""
# Acquire handles to the Code objects in each Go package and call
# ImportModule to initialize all modules.
parts = imp.name.split('.')
code_objs = []
for i in xrange(len(parts)):
......@@ -687,12 +689,20 @@ class StatementVisitor(algorithm.Visitor):
self.writer.write_checked_call2(
mod_slice, 'πg.ImportModule(πF, {}, {})',
util.go_str(imp.name), handles_expr)
# This method only handles simple module imports (i.e. not member
# imports) which always have a single binding.
binding = imp.bindings[0]
# Bind the imported modules or members to variables in the current scope.
for binding in imp.bindings:
self.writer.write('{} = {}[{}]'.format(
mod.name, mod_slice.expr, binding.value))
mod.name, mod_slice.expr, imp.name.count('.')))
if binding.bind_type == util.Import.MODULE:
self.block.bind_var(self.writer, binding.alias, mod.expr)
else:
# Binding a member of the imported module.
with self.block.alloc_temp() as member:
self.writer.write_checked_call2(
member, 'πg.GetAttr(πF, {}, {}, nil)',
mod.expr, self.block.root.intern(binding.value))
self.block.bind_var(self.writer, binding.alias, member.expr)
def _import_native(self, name, values):
reflect_package = self.block.root.add_native_import('reflect')
......
......@@ -297,6 +297,11 @@ class StatementVisitorTest(unittest.TestCase):
import sys
print type(sys.modules)""")))
def testImportMember(self):
self.assertEqual((0, "<type 'dict'>\n"), _GrumpRun(textwrap.dedent("""\
from sys import modules
print type(modules)""")))
def testImportConflictingPackage(self):
self.assertEqual((0, ''), _GrumpRun(textwrap.dedent("""\
import time
......@@ -307,7 +312,7 @@ class StatementVisitorTest(unittest.TestCase):
from __go__.time import Nanosecond, Second
print Nanosecond, Second""")))
def testImportGrump(self):
def testImportGrumpy(self):
self.assertEqual((0, ''), _GrumpRun(textwrap.dedent("""\
from __go__.grumpy import Assert
Assert(__frame__(), True, 'bad')""")))
......
......@@ -170,16 +170,22 @@ class ImportVisitor(algorithm.Visitor):
self.imports.append(imp)
return
# NOTE: Assume that the names being imported are all modules within a
# package. E.g. "from a.b import c" is importing the module c from package
# a.b, not some member of module b. We cannot distinguish between these
# two cases at compile time and the Google style guide forbids the latter
# so we support that use case only.
member_imp = None
for alias in node.names:
imp = self._resolve_import(node, '{}.{}'.format(node.module, alias.name))
imp.add_binding(Import.MODULE, alias.asname or alias.name,
imp.name.count('.'))
asname = alias.asname or alias.name
full_name, _ = self.path.resolve_import(
'{}.{}'.format(node.module, alias.name))
if full_name:
# Imported name is a submodule within a package, so bind that module.
imp = Import(full_name)
imp.add_binding(Import.MODULE, asname, imp.name.count('.'))
self.imports.append(imp)
else:
# A member (not a submodule) is being imported, so bind it.
if not member_imp:
member_imp = self._resolve_import(node, node.module)
self.imports.append(member_imp)
member_imp.add_binding(Import.MEMBER, asname, alias.name)
def _resolve_import(self, node, name):
full_name, _ = self.path.resolve_import(name)
......
......@@ -32,7 +32,12 @@ from grumpy.compiler import stmt
class MockPath(object):
def __init__(self, nonexistent_modules=()):
self.nonexistent_modules = nonexistent_modules
def resolve_import(self, modname):
if modname in self.nonexistent_modules:
return None, None
return modname, modname.replace('.', os.sep)
......@@ -150,6 +155,13 @@ class ImportVisitorTest(unittest.TestCase):
imp.add_binding(util.Import.MODULE, 'bar', 1)
self._assert_imports_equal(imp, self._visit_import('from foo import bar'))
def testImportFromMember(self):
imp = util.Import('foo')
imp.add_binding(util.Import.MEMBER, 'bar', 'bar')
path = MockPath(nonexistent_modules=('foo.bar',))
self._assert_imports_equal(
imp, self._visit_import('from foo import bar', path=path))
def testImportFromMultiple(self):
imp1 = util.Import('foo.bar')
imp1.add_binding(util.Import.MODULE, 'bar', 1)
......@@ -158,12 +170,28 @@ class ImportVisitorTest(unittest.TestCase):
self._assert_imports_equal(
[imp1, imp2], self._visit_import('from foo import bar, baz'))
def testImportFromMixedMembers(self):
imp1 = util.Import('foo')
imp1.add_binding(util.Import.MEMBER, 'bar', 'bar')
imp2 = util.Import('foo.baz')
imp2.add_binding(util.Import.MODULE, 'baz', 1)
path = MockPath(nonexistent_modules=('foo.bar',))
self._assert_imports_equal(
[imp1, imp2], self._visit_import('from foo import bar, baz', path=path))
def testImportFromAs(self):
imp = util.Import('foo.bar')
imp.add_binding(util.Import.MODULE, 'baz', 1)
self._assert_imports_equal(
imp, self._visit_import('from foo import bar as baz'))
def testImportFromAsMembers(self):
imp = util.Import('foo')
imp.add_binding(util.Import.MEMBER, 'baz', 'bar')
path = MockPath(nonexistent_modules=('foo.bar',))
self._assert_imports_equal(
imp, self._visit_import('from foo import bar as baz', path=path))
def testImportFromWildcardRaises(self):
self.assertRaises(util.ImportError, self._visit_import, 'from foo import *')
......@@ -190,8 +218,10 @@ class ImportVisitorTest(unittest.TestCase):
self._assert_imports_equal(
imp, self._visit_import('from __go__.fmt import Printf as foo'))
def _visit_import(self, source):
visitor = util.ImportVisitor(MockPath())
def _visit_import(self, source, path=None):
if not path:
path = MockPath()
visitor = util.ImportVisitor(path)
visitor.visit(pythonparser.parse(source).body[0])
return visitor.imports
......
......@@ -46,7 +46,7 @@ def main(args):
if not gopath:
print >> sys.stderr, 'GOPATH not set'
return 1
path = util.Path(gopath, args.script, args.modname)
path = util.Path(gopath, args.modname, args.script)
with open(args.script) as py_file:
py_contents = py_file.read()
......
......@@ -35,7 +35,7 @@ def main(args):
if not gopath:
print >> sys.stderr, 'GOPATH not set'
return 1
path = util.Path(gopath, args.script, args.modname)
path = util.Path(gopath, args.modname, args.script)
with open(args.script) as py_file:
py_contents = py_file.read()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment