Commit b9a0c8a9 authored by Dylan Trotter's avatar Dylan Trotter

Move import logic into util for reuse

Import logic is currently embedded in the StatementVisitor class which
makes it difficult to reuse. This change factors that logic out into an
ImportVisitor class in util.py. The newly exposed logic can now be used
by other tools.
parent de7e668d
......@@ -30,7 +30,6 @@ from grumpy.compiler import expr_visitor
from grumpy.compiler import util
_NATIVE_MODULE_PREFIX = '__go__.'
_NATIVE_TYPE_PREFIX = 'type_'
# Partial list of known vcs for go module import
......@@ -371,58 +370,39 @@ class StatementVisitor(algorithm.Visitor):
def visit_Import(self, node):
self._write_py_context(node.lineno)
for alias in node.names:
if alias.name.startswith(_NATIVE_MODULE_PREFIX):
raise util.ParseError(
node, 'for native imports use "from __go__.xyz import ..." syntax')
with self._import(alias.name, 0) as mod:
asname = alias.asname or alias.name.split('.')[0]
self.block.bind_var(self.writer, asname, mod.expr)
for imp in util.ImportVisitor().visit(node):
self._import_and_bind(imp)
def visit_ImportFrom(self, node):
# Wildcard imports are not yet supported.
for alias in node.names:
if alias.name == '*':
msg = 'wildcard member import is not implemented: from %s import %s' % (
node.module, alias.name)
raise util.ParseError(node, msg)
self._write_py_context(node.lineno)
if node.module.startswith(_NATIVE_MODULE_PREFIX):
values = [alias.name for alias in node.names]
with self._import_native(node.module, values) as mod:
for alias in node.names:
# Strip the 'type_' prefix when populating the module. This means
# that, e.g. 'from __go__.foo import type_Bar' will populate foo with
# a member called Bar, not type_Bar (although the symbol in the
# importing module will still be type_Bar unless aliased). This bends
# the semantics of import but makes native module contents more
# sensible.
name = alias.name
if name.startswith(_NATIVE_TYPE_PREFIX):
name = name[len(_NATIVE_TYPE_PREFIX):]
with self.block.alloc_temp() as member:
self.writer.write_checked_call2(
member, 'πg.GetAttr(πF, {}, {}, nil)',
mod.expr, self.block.root.intern(name))
self.block.bind_var(
self.writer, alias.asname or alias.name, member.expr)
elif node.module == '__future__':
# At this stage all future imports are done in an initial pass (see
# visit() above), so if they are encountered here after the last valid
# __future__ then it's a syntax error.
if node.lineno > self.future_features.future_lineno:
raise util.ParseError(node, late_future)
else:
# NOTE: Assume that the names being imported are all modules within a
# package. E.g. "from a.b import c" is importing the module c from package
# a.b, not some member of module b. We cannot distinguish between these
# two cases at compile time and the Google style guide forbids the latter
# so we support that use case only.
for alias in node.names:
name = '{}.{}'.format(node.module, alias.name)
with self._import(name, name.count('.')) as mod:
asname = alias.asname or alias.name
self.block.bind_var(self.writer, asname, mod.expr)
for imp in util.ImportVisitor().visit(node):
if imp.is_native:
values = [b.value for b in imp.bindings]
with self._import_native(imp.name, values) as mod:
for binding in imp.bindings:
# Strip the 'type_' prefix when populating the module. This means
# that, e.g. 'from __go__.foo import type_Bar' will populate foo
# with a member called Bar, not type_Bar (although the symbol in
# the importing module will still be type_Bar unless aliased). This
# bends the semantics of import but makes native module contents
# more sensible.
name = binding.value
if name.startswith(_NATIVE_TYPE_PREFIX):
name = name[len(_NATIVE_TYPE_PREFIX):]
with self.block.alloc_temp() as member:
self.writer.write_checked_call2(
member, 'πg.GetAttr(πF, {}, {}, nil)',
mod.expr, self.block.root.intern(name))
self.block.bind_var(
self.writer, binding.alias, member.expr)
elif node.module == '__future__':
# At this stage all future imports are done in an initial pass (see
# visit() above), so if they are encountered here after the last valid
# __future__ then it's a syntax error.
if node.lineno > self.future_features.future_lineno:
raise util.ImportError(node, late_future)
else:
self._import_and_bind(imp)
def visit_Module(self, node):
self._visit_each(node.body)
......@@ -681,18 +661,14 @@ class StatementVisitor(algorithm.Visitor):
tmpl = 'πg.TieTarget{Target: &$temp}'
return string.Template(tmpl).substitute(temp=temp.name)
def _import(self, name, index):
"""Returns an expression for a Module object returned from ImportModule.
def _import_and_bind(self, imp):
"""Generates code that imports a module and binds it to a variable.
Args:
name: The fully qualified Python module name, e.g. foo.bar.
index: The element in the list of modules that this expression should
select. E.g. for 'foo.bar', 0 corresponds to the package foo and 1
corresponds to the module bar.
Returns:
A Go expression evaluating to an *Object (upcast from a *Module.)
imp: Import object representing an import of the form "import x.y.z" or
"from x.y import z". Expects only a single binding.
"""
parts = name.split('.')
parts = imp.name.split('.')
code_objs = []
for i in xrange(len(parts)):
package_name = '/'.join(parts[:i + 1])
......@@ -701,27 +677,33 @@ class StatementVisitor(algorithm.Visitor):
code_objs.append('{}.Code'.format(package.alias))
else:
code_objs.append('Code')
mod = self.block.alloc_temp()
with self.block.alloc_temp('[]*πg.Object') as mod_slice:
with self.block.alloc_temp() as mod, \
self.block.alloc_temp('[]*πg.Object') as mod_slice:
handles_expr = '[]*πg.Code{' + ', '.join(code_objs) + '}'
self.writer.write_checked_call2(
mod_slice, 'πg.ImportModule(πF, {}, {})',
util.go_str(name), handles_expr)
util.go_str(imp.name), handles_expr)
# This method only handles simple module imports (i.e. not member
# imports) which always have a single binding.
binding = imp.bindings[0]
if binding.value == util.Import.ROOT:
index = 0
else:
index = len(parts) - 1
self.writer.write('{} = {}[{}]'.format(mod.name, mod_slice.expr, index))
return mod
self.block.bind_var(self.writer, binding.alias, mod.expr)
def _import_native(self, name, values):
reflect_package = self.block.root.add_native_import('reflect')
import_name = name[len(_NATIVE_MODULE_PREFIX):]
# Work-around for importing go module from VCS
# TODO: support bzr|git|hg|svn from any server
package_name = None
for x in _KNOWN_VCS:
if import_name.startswith(x):
package_name = x + import_name[len(x):].replace('.', '/')
if name.startswith(x):
package_name = x + name[len(x):].replace('.', '/')
break
if not package_name:
package_name = import_name.replace('.', '/')
package_name = name.replace('.', '/')
package = self.block.root.add_native_import(package_name)
mod = self.block.alloc_temp()
......
......@@ -313,7 +313,7 @@ class StatementVisitorTest(unittest.TestCase):
def testImportNativeModuleRaises(self):
regexp = r'for native imports use "from __go__\.xyz import \.\.\." syntax'
self.assertRaisesRegexp(util.ParseError, regexp, _ParseAndVisit,
self.assertRaisesRegexp(util.ImportError, regexp, _ParseAndVisit,
'import __go__.foo')
def testImportNativeType(self):
......@@ -368,11 +368,11 @@ class StatementVisitorTest(unittest.TestCase):
def testImportWildcardMemberRaises(self):
regexp = r'wildcard member import is not implemented: from foo import *'
self.assertRaisesRegexp(util.ParseError, regexp, _ParseAndVisit,
self.assertRaisesRegexp(util.ImportError, regexp, _ParseAndVisit,
'from foo import *')
regexp = (r'wildcard member import is not '
r'implemented: from __go__.foo import *')
self.assertRaisesRegexp(util.ParseError, regexp, _ParseAndVisit,
self.assertRaisesRegexp(util.ImportError, regexp, _ParseAndVisit,
'from __go__.foo import *')
def testVisitFuture(self):
......
......@@ -19,12 +19,15 @@
from __future__ import unicode_literals
import codecs
import collections
import contextlib
import cStringIO
import string
import StringIO
import textwrap
from pythonparser import algorithm
_SIMPLE_CHARS = set(string.digits + string.letters + string.punctuation + " ")
_ESCAPES = {'\t': r'\t', '\r': r'\r', '\n': r'\n', '"': r'\"', '\\': r'\\'}
......@@ -34,13 +37,96 @@ _ESCAPES = {'\t': r'\t', '\r': r'\r', '\n': r'\n', '"': r'\"', '\\': r'\\'}
# This should match the number of specializations found in tuple.go.
MAX_DIRECT_TUPLE = 6
_NATIVE_MODULE_PREFIX = '__go__.'
class ParseError(Exception):
class CompileError(Exception):
def __init__(self, node, msg):
if hasattr(node, 'lineno'):
msg = 'line {}: {}'.format(node.lineno, msg)
super(ParseError, self).__init__(msg)
super(CompileError, self).__init__(msg)
class ParseError(CompileError):
pass
class ImportError(CompileError): # pylint: disable=redefined-builtin
pass
class Import(object):
"""Represents a single module import and all its associated bindings.
Each import pertains to a single module that is imported. Thus one import
statement may produce multiple Import objects. E.g. "import foo, bar" makes
an Import object for module foo and another one for module bar.
"""
Binding = collections.namedtuple('Binding', ('bind_type', 'alias', 'value'))
MODULE = "<BindType 'module'>"
MEMBER = "<BindType 'member'>"
ROOT = "<BindValue 'root'>"
LEAF = "<BindValue 'leaf'>"
def __init__(self, name, is_native=False):
self.name = name
self.is_native = is_native
self.bindings = []
def add_binding(self, bind_type, alias, value):
self.bindings.append(Import.Binding(bind_type, alias, value))
class ImportVisitor(algorithm.Visitor):
"""Visits import nodes and produces corresponding Import objects."""
# pylint: disable=invalid-name,missing-docstring,no-init
def visit_Import(self, node):
imports = []
for alias in node.names:
if alias.name.startswith(_NATIVE_MODULE_PREFIX):
raise ImportError(
node, 'for native imports use "from __go__.xyz import ..." syntax')
imp = Import(alias.name)
if alias.asname:
imp.add_binding(Import.MODULE, alias.asname, Import.LEAF)
else:
imp.add_binding(Import.MODULE, alias.name.split('.')[-1], Import.ROOT)
imports.append(imp)
return imports
def visit_ImportFrom(self, node):
if any(a.name == '*' for a in node.names):
msg = 'wildcard member import is not implemented: from %s import *' % (
node.module)
raise ImportError(node, msg)
if node.module == '__future__':
return []
if node.module.startswith(_NATIVE_MODULE_PREFIX):
imp = Import(node.module[len(_NATIVE_MODULE_PREFIX):], is_native=True)
for alias in node.names:
asname = alias.asname or alias.name
imp.add_binding(Import.MEMBER, asname, alias.name)
return [imp]
# NOTE: Assume that the names being imported are all modules within a
# package. E.g. "from a.b import c" is importing the module c from package
# a.b, not some member of module b. We cannot distinguish between these
# two cases at compile time and the Google style guide forbids the latter
# so we support that use case only.
imports = []
for alias in node.names:
imp = Import('{}.{}'.format(node.module, alias.name))
imp.add_binding(Import.MODULE, alias.asname or alias.name, Import.LEAF)
imports.append(imp)
return imports
class Writer(object):
......
......@@ -20,11 +20,91 @@ from __future__ import unicode_literals
import unittest
import pythonparser
from grumpy.compiler import block
from grumpy.compiler import util
from grumpy.compiler import stmt
class ImportVisitorTest(unittest.TestCase):
def testImport(self):
imp = util.Import('foo')
imp.add_binding(util.Import.MODULE, 'foo', util.Import.ROOT)
self._assert_imports_equal(imp, self._visit_import('import foo'))
def testImportMultiple(self):
imp1 = util.Import('foo')
imp1.add_binding(util.Import.MODULE, 'foo', util.Import.ROOT)
imp2 = util.Import('bar')
imp2.add_binding(util.Import.MODULE, 'bar', util.Import.ROOT)
self._assert_imports_equal(
[imp1, imp2], self._visit_import('import foo, bar'))
def testImportAs(self):
imp = util.Import('foo')
imp.add_binding(util.Import.MODULE, 'bar', util.Import.LEAF)
self._assert_imports_equal(imp, self._visit_import('import foo as bar'))
def testImportNativeRaises(self):
self.assertRaises(util.ImportError, self._visit_import, 'import __go__.fmt')
def testImportFrom(self):
imp = util.Import('foo.bar')
imp.add_binding(util.Import.MODULE, 'bar', util.Import.LEAF)
self._assert_imports_equal(imp, self._visit_import('from foo import bar'))
def testImportFromMultiple(self):
imp1 = util.Import('foo.bar')
imp1.add_binding(util.Import.MODULE, 'bar', util.Import.LEAF)
imp2 = util.Import('foo.baz')
imp2.add_binding(util.Import.MODULE, 'baz', util.Import.LEAF)
self._assert_imports_equal(
[imp1, imp2], self._visit_import('from foo import bar, baz'))
def testImportFromAs(self):
imp = util.Import('foo.bar')
imp.add_binding(util.Import.MODULE, 'baz', util.Import.LEAF)
self._assert_imports_equal(
imp, self._visit_import('from foo import bar as baz'))
def testImportFromWildcardRaises(self):
self.assertRaises(util.ImportError, self._visit_import, 'from foo import *')
def testImportFromFuture(self):
result = self._visit_import('from __future__ import print_function')
self.assertEqual([], result)
def testImportFromNative(self):
imp = util.Import('fmt', is_native=True)
imp.add_binding(util.Import.MEMBER, 'Printf', 'Printf')
self._assert_imports_equal(
imp, self._visit_import('from __go__.fmt import Printf'))
def testImportFromNativeMultiple(self):
imp = util.Import('fmt', is_native=True)
imp.add_binding(util.Import.MEMBER, 'Printf', 'Printf')
imp.add_binding(util.Import.MEMBER, 'Println', 'Println')
self._assert_imports_equal(
imp, self._visit_import('from __go__.fmt import Printf, Println'))
def testImportFromNativeAs(self):
imp = util.Import('fmt', is_native=True)
imp.add_binding(util.Import.MEMBER, 'foo', 'Printf')
self._assert_imports_equal(
imp, self._visit_import('from __go__.fmt import Printf as foo'))
def _visit_import(self, source):
return util.ImportVisitor().visit(pythonparser.parse(source).body[0])
def _assert_imports_equal(self, want, got):
if isinstance(want, util.Import):
want = [want]
self.assertEqual([imp.__dict__ for imp in want],
[imp.__dict__ for imp in got])
class WriterTest(unittest.TestCase):
def testIndentBlock(self):
......
......@@ -52,7 +52,7 @@ def main(args):
# Do a pass for compiler directives from `from __future__ import *` statements
try:
future_features = stmt.visit_future(mod)
except util.ParseError as e:
except util.CompileError as e:
print >> sys.stderr, str(e)
return 2
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment