Commit 887248bd authored by Dylan Trotter's avatar Dylan Trotter Committed by Dylan Trotter

Simplify future features a little bit

parent dd626b10
...@@ -109,8 +109,9 @@ class ImportVisitor(algorithm.Visitor): ...@@ -109,8 +109,9 @@ class ImportVisitor(algorithm.Visitor):
# pylint: disable=invalid-name,missing-docstring,no-init # pylint: disable=invalid-name,missing-docstring,no-init
def __init__(self, path): def __init__(self, path, future_node=None):
self.path = path self.path = path
self.future_node = future_node
self.imports = [] self.imports = []
def visit_Import(self, node): def visit_Import(self, node):
...@@ -134,6 +135,8 @@ class ImportVisitor(algorithm.Visitor): ...@@ -134,6 +135,8 @@ class ImportVisitor(algorithm.Visitor):
raise util.ImportError(node, msg) raise util.ImportError(node, msg)
if node.module == '__future__': if node.module == '__future__':
if node != self.future_node:
raise util.LateFutureError(node)
return return
if node.module.startswith(_NATIVE_MODULE_PREFIX): if node.module.startswith(_NATIVE_MODULE_PREFIX):
...@@ -168,88 +171,60 @@ class ImportVisitor(algorithm.Visitor): ...@@ -168,88 +171,60 @@ class ImportVisitor(algorithm.Visitor):
return Import(full_name) return Import(full_name)
# Parser flags, set on 'from __future__ import *', see parser_flags on _FUTURE_FEATURES = (
# StatementVisitor below. Note these have the same values as CPython. 'absolute_import',
FUTURE_DIVISION = 0x2000 'division',
FUTURE_ABSOLUTE_IMPORT = 0x4000 'print_function',
FUTURE_PRINT_FUNCTION = 0x10000 'unicode_literals',
FUTURE_UNICODE_LITERALS = 0x20000 )
# Names for future features in 'from __future__ import *'. Map from name in the _IMPLEMENTED_FUTURE_FEATURES = ('print_function',)
# import statement to a tuple of the flag for parser, and whether we've (grumpy)
# implemented the feature yet.
future_features = {
"division": (FUTURE_DIVISION, False),
"absolute_import": (FUTURE_ABSOLUTE_IMPORT, False),
"print_function": (FUTURE_PRINT_FUNCTION, True),
"unicode_literals": (FUTURE_UNICODE_LITERALS, False),
}
# These future features are already in the language proper as of 2.6, so # These future features are already in the language proper as of 2.6, so
# importing them via __future__ has no effect. # importing them via __future__ has no effect.
redundant_future_features = ["generators", "with_statement", "nested_scopes"] _REDUNDANT_FUTURE_FEATURES = ('generators', 'with_statement', 'nested_scopes')
class FutureFeatures(object):
def __init__(self):
for name in _FUTURE_FEATURES:
setattr(self, name, False)
def import_from_future(node): def _make_future_features(node):
"""Processes a future import statement, returning set of flags it defines.""" """Processes a future import statement, returning set of flags it defines."""
assert isinstance(node, ast.ImportFrom) assert isinstance(node, ast.ImportFrom)
assert node.module == '__future__' assert node.module == '__future__'
flags = 0 features = FutureFeatures()
for alias in node.names: for alias in node.names:
name = alias.name name = alias.name
if name in future_features: if name in _FUTURE_FEATURES:
flag, implemented = future_features[name] if name not in _IMPLEMENTED_FUTURE_FEATURES:
if not implemented:
msg = 'future feature {} not yet implemented by grumpy'.format(name) msg = 'future feature {} not yet implemented by grumpy'.format(name)
raise util.ParseError(node, msg) raise util.ParseError(node, msg)
flags |= flag setattr(features, name, True)
elif name == 'braces': elif name == 'braces':
raise util.ParseError(node, 'not a chance') raise util.ParseError(node, 'not a chance')
elif name not in redundant_future_features: elif name not in _REDUNDANT_FUTURE_FEATURES:
msg = 'future feature {} is not defined'.format(name) msg = 'future feature {} is not defined'.format(name)
raise util.ParseError(node, msg) raise util.ParseError(node, msg)
return flags return features
class FutureFeatures(object):
def __init__(self):
self.parser_flags = 0
self.future_lineno = 0
def visit_future(node): def parse_future_features(mod):
"""Accumulates a set of compiler flags for the compiler __future__ imports. """Accumulates a set of flags for the compiler __future__ imports."""
assert isinstance(mod, ast.Module)
Returns an instance of FutureFeatures which encapsulates the flags and the
line number of the last valid future import parsed. A downstream parser can
use the latter to detect invalid future imports that appear too late in the
file.
"""
# If this is the module node, do an initial pass through the module body's
# statements to detect future imports and process their directives (i.e.,
# set compiler flags), and detect ones that don't appear at the beginning of
# the file. The only things that can proceed a future statement are other
# future statements and/or a doc string.
assert isinstance(node, ast.Module)
ff = FutureFeatures()
done = False
found_docstring = False found_docstring = False
for node in node.body: for node in mod.body:
if isinstance(node, ast.ImportFrom): if isinstance(node, ast.ImportFrom):
modname = node.module if node.module == '__future__':
if modname == '__future__': return node, _make_future_features(node)
if done: break
raise util.LateFutureError(node)
ff.parser_flags |= import_from_future(node)
ff.future_lineno = node.lineno
else:
done = True
elif isinstance(node, ast.Expr) and not found_docstring: elif isinstance(node, ast.Expr) and not found_docstring:
e = node.value if not isinstance(node.value, ast.Str):
if not isinstance(e, ast.Str): # pylint: disable=simplifiable-if-statement break
done = True
else:
found_docstring = True found_docstring = True
else: else:
done = True break
return ff return None, FutureFeatures()
...@@ -232,31 +232,75 @@ class ImportVisitorTest(unittest.TestCase): ...@@ -232,31 +232,75 @@ class ImportVisitorTest(unittest.TestCase):
[imp.__dict__ for imp in got]) [imp.__dict__ for imp in got])
class VisitFutureTest(unittest.TestCase): class MakeFutureFeaturesTest(unittest.TestCase):
def testImportFromFuture(self):
print_function_features = imputil.FutureFeatures()
print_function_features.print_function = True
testcases = [
('from __future__ import print_function',
print_function_features),
('from __future__ import generators', imputil.FutureFeatures()),
('from __future__ import generators, print_function',
print_function_features),
]
for tc in testcases:
source, want = tc
mod = pythonparser.parse(textwrap.dedent(source))
node = mod.body[0]
got = imputil._make_future_features(node) # pylint: disable=protected-access
self.assertEqual(want.__dict__, got.__dict__)
def testImportFromFutureParseError(self):
testcases = [
# NOTE: move this group to testImportFromFuture as they are implemented
# by grumpy
('from __future__ import absolute_import',
r'future feature \w+ not yet implemented'),
('from __future__ import division',
r'future feature \w+ not yet implemented'),
('from __future__ import unicode_literals',
r'future feature \w+ not yet implemented'),
('from __future__ import braces', 'not a chance'),
('from __future__ import nonexistant_feature',
r'future feature \w+ is not defined'),
]
for tc in testcases:
source, want_regexp = tc
mod = pythonparser.parse(source)
node = mod.body[0]
self.assertRaisesRegexp(util.ParseError, want_regexp,
imputil._make_future_features, node) # pylint: disable=protected-access
class ParseFutureFeaturesTest(unittest.TestCase):
def testVisitFuture(self): def testVisitFuture(self):
print_function_features = imputil.FutureFeatures()
print_function_features.print_function = True
testcases = [ testcases = [
('from __future__ import print_function', ('from __future__ import print_function',
imputil.FUTURE_PRINT_FUNCTION, 1), print_function_features),
("""\ ("""\
"module docstring" "module docstring"
from __future__ import print_function from __future__ import print_function
""", imputil.FUTURE_PRINT_FUNCTION, 3), """, print_function_features),
("""\ ("""\
"module docstring" "module docstring"
from __future__ import print_function, with_statement from __future__ import print_function, with_statement
from __future__ import nested_scopes from __future__ import nested_scopes
""", imputil.FUTURE_PRINT_FUNCTION, 4), """, print_function_features),
] ]
for tc in testcases: for tc in testcases:
source, flags, lineno = tc source, want = tc
mod = pythonparser.parse(textwrap.dedent(source)) mod = pythonparser.parse(textwrap.dedent(source))
future_features = imputil.visit_future(mod) _, got = imputil.parse_future_features(mod)
self.assertEqual(future_features.parser_flags, flags) self.assertEqual(want.__dict__, got.__dict__)
self.assertEqual(future_features.future_lineno, lineno)
def testVisitFutureLate(self): def testVisitFutureLate(self):
testcases = [ testcases = [
...@@ -274,4 +318,5 @@ class VisitFutureTest(unittest.TestCase): ...@@ -274,4 +318,5 @@ class VisitFutureTest(unittest.TestCase):
for source in testcases: for source in testcases:
mod = pythonparser.parse(textwrap.dedent(source)) mod = pythonparser.parse(textwrap.dedent(source))
self.assertRaises(util.LateFutureError, imputil.visit_future, mod) self.assertRaises(util.LateFutureError,
imputil.parse_future_features, mod)
...@@ -49,10 +49,9 @@ class StatementVisitor(algorithm.Visitor): ...@@ -49,10 +49,9 @@ class StatementVisitor(algorithm.Visitor):
# pylint: disable=invalid-name,missing-docstring # pylint: disable=invalid-name,missing-docstring
def __init__(self, block_): def __init__(self, block_, future_node=None):
self.block = block_ self.block = block_
self.future_features = (self.block.root.future_features or self.future_node = future_node
imputil.FutureFeatures())
self.writer = util.Writer() self.writer = util.Writer()
self.expr_visitor = expr_visitor.ExprVisitor(self) self.expr_visitor = expr_visitor.ExprVisitor(self)
...@@ -108,7 +107,7 @@ class StatementVisitor(algorithm.Visitor): ...@@ -108,7 +107,7 @@ class StatementVisitor(algorithm.Visitor):
if v.type == block.Var.TYPE_GLOBAL} if v.type == block.Var.TYPE_GLOBAL}
# Visit all the statements inside body of the class definition. # Visit all the statements inside body of the class definition.
body_visitor = StatementVisitor(block.ClassBlock( body_visitor = StatementVisitor(block.ClassBlock(
self.block, node.name, global_vars)) self.block, node.name, global_vars), self.future_node)
# Indent so that the function body is aligned with the goto labels. # Indent so that the function body is aligned with the goto labels.
with body_visitor.writer.indent_block(): with body_visitor.writer.indent_block():
body_visitor._visit_each(node.body) # pylint: disable=protected-access body_visitor._visit_each(node.body) # pylint: disable=protected-access
...@@ -293,7 +292,7 @@ class StatementVisitor(algorithm.Visitor): ...@@ -293,7 +292,7 @@ class StatementVisitor(algorithm.Visitor):
def visit_ImportFrom(self, node): def visit_ImportFrom(self, node):
self._write_py_context(node.lineno) self._write_py_context(node.lineno)
visitor = imputil.ImportVisitor(self.block.root.path) visitor = imputil.ImportVisitor(self.block.root.path, self.future_node)
visitor.visit(node) visitor.visit(node)
for imp in visitor.imports: for imp in visitor.imports:
if imp.is_native: if imp.is_native:
...@@ -315,13 +314,7 @@ class StatementVisitor(algorithm.Visitor): ...@@ -315,13 +314,7 @@ class StatementVisitor(algorithm.Visitor):
mod.expr, self.block.root.intern(name)) mod.expr, self.block.root.intern(name))
self.block.bind_var( self.block.bind_var(
self.writer, binding.alias, member.expr) self.writer, binding.alias, member.expr)
elif node.module == '__future__': elif node.module != '__future__':
# At this stage all future imports are done in an initial pass (see
# visit() above), so if they are encountered here after the last valid
# __future__ then it's a syntax error.
if node.lineno > self.future_features.future_lineno:
raise util.LateFutureError(node)
else:
self._import_and_bind(imp) self._import_and_bind(imp)
def visit_Module(self, node): def visit_Module(self, node):
...@@ -331,7 +324,7 @@ class StatementVisitor(algorithm.Visitor): ...@@ -331,7 +324,7 @@ class StatementVisitor(algorithm.Visitor):
self._write_py_context(node.lineno) self._write_py_context(node.lineno)
def visit_Print(self, node): def visit_Print(self, node):
if self.future_features.parser_flags & imputil.FUTURE_PRINT_FUNCTION: if self.block.root.future_features.print_function:
raise util.ParseError(node, 'syntax error (print is not a keyword)') raise util.ParseError(node, 'syntax error (print is not a keyword)')
self._write_py_context(node.lineno) self._write_py_context(node.lineno)
with self.block.alloc_temp('[]*πg.Object') as args: with self.block.alloc_temp('[]*πg.Object') as args:
...@@ -549,7 +542,7 @@ class StatementVisitor(algorithm.Visitor): ...@@ -549,7 +542,7 @@ class StatementVisitor(algorithm.Visitor):
func_visitor.visit(child) func_visitor.visit(child)
func_block = block.FunctionBlock(self.block, node.name, func_visitor.vars, func_block = block.FunctionBlock(self.block, node.name, func_visitor.vars,
func_visitor.is_generator) func_visitor.is_generator)
visitor = StatementVisitor(func_block) visitor = StatementVisitor(func_block, self.future_node)
# Indent so that the function body is aligned with the goto labels. # Indent so that the function body is aligned with the goto labels.
with visitor.writer.indent_block(): with visitor.writer.indent_block():
visitor._visit_each(node.body) # pylint: disable=protected-access visitor._visit_each(node.body) # pylint: disable=protected-access
......
...@@ -335,46 +335,6 @@ class StatementVisitorTest(unittest.TestCase): ...@@ -335,46 +335,6 @@ class StatementVisitorTest(unittest.TestCase):
print '123' print '123'
print 'foo', 'bar'"""))) print 'foo', 'bar'""")))
def testImportFromFuture(self):
testcases = [
('from __future__ import print_function',
imputil.FUTURE_PRINT_FUNCTION),
('from __future__ import generators', 0),
('from __future__ import generators, print_function',
imputil.FUTURE_PRINT_FUNCTION),
]
for i, tc in enumerate(testcases):
source, want_flags = tc
mod = pythonparser.parse(textwrap.dedent(source))
node = mod.body[0]
got = imputil.import_from_future(node)
msg = '#{}: want {}, got {}'.format(i, want_flags, got)
self.assertEqual(want_flags, got, msg=msg)
def testImportFromFutureParseError(self):
testcases = [
# NOTE: move this group to testImportFromFuture as they are implemented
# by grumpy
('from __future__ import absolute_import',
r'future feature \w+ not yet implemented'),
('from __future__ import division',
r'future feature \w+ not yet implemented'),
('from __future__ import unicode_literals',
r'future feature \w+ not yet implemented'),
('from __future__ import braces', 'not a chance'),
('from __future__ import nonexistant_feature',
r'future feature \w+ is not defined'),
]
for tc in testcases:
source, want_regexp = tc
mod = pythonparser.parse(source)
node = mod.body[0]
self.assertRaisesRegexp(util.ParseError, want_regexp,
imputil.import_from_future, node)
def testImportWildcardMemberRaises(self): def testImportWildcardMemberRaises(self):
regexp = r'wildcard member import is not implemented: from foo import *' regexp = r'wildcard member import is not implemented: from foo import *'
self.assertRaisesRegexp(util.ImportError, regexp, _ParseAndVisit, self.assertRaisesRegexp(util.ImportError, regexp, _ParseAndVisit,
...@@ -573,7 +533,7 @@ def _MakeModuleBlock(): ...@@ -573,7 +533,7 @@ def _MakeModuleBlock():
def _ParseAndVisit(source): def _ParseAndVisit(source):
mod = pythonparser.parse(source) mod = pythonparser.parse(source)
future_features = imputil.visit_future(mod) _, future_features = imputil.parse_future_features(mod)
b = block.ModuleBlock(imputil_test.MockPath(), '__main__', b = block.ModuleBlock(imputil_test.MockPath(), '__main__',
'<test>', source, future_features) '<test>', source, future_features)
visitor = stmt.StatementVisitor(b) visitor = stmt.StatementVisitor(b)
......
...@@ -47,7 +47,6 @@ def main(args): ...@@ -47,7 +47,6 @@ def main(args):
if not gopath: if not gopath:
print >> sys.stderr, 'GOPATH not set' print >> sys.stderr, 'GOPATH not set'
return 1 return 1
path = imputil.Path(gopath, args.modname, args.script)
with open(args.script) as py_file: with open(args.script) as py_file:
py_contents = py_file.read() py_contents = py_file.read()
...@@ -60,16 +59,17 @@ def main(args): ...@@ -60,16 +59,17 @@ def main(args):
# Do a pass for compiler directives from `from __future__ import *` statements # Do a pass for compiler directives from `from __future__ import *` statements
try: try:
future_features = imputil.visit_future(mod) future_node, future_features = imputil.parse_future_features(mod)
except util.CompileError as e: except util.CompileError as e:
print >> sys.stderr, str(e) print >> sys.stderr, str(e)
return 2 return 2
path = imputil.Path(gopath, args.modname, args.script)
full_package_name = args.modname.replace('.', '/') full_package_name = args.modname.replace('.', '/')
mod_block = block.ModuleBlock(path, full_package_name, args.script, mod_block = block.ModuleBlock(path, full_package_name, args.script,
py_contents, future_features) py_contents, future_features)
mod_block.add_native_import('grumpy') mod_block.add_native_import('grumpy')
visitor = stmt.StatementVisitor(mod_block) visitor = stmt.StatementVisitor(mod_block, future_node)
# Indent so that the module body is aligned with the goto labels. # Indent so that the module body is aligned with the goto labels.
with visitor.writer.indent_block(): with visitor.writer.indent_block():
try: try:
......
...@@ -47,7 +47,13 @@ def main(args): ...@@ -47,7 +47,13 @@ def main(args):
e.filename, e.lineno, e.text) e.filename, e.lineno, e.text)
return 2 return 2
visitor = imputil.ImportVisitor(path) try:
future_node, _ = imputil.parse_future_features(mod)
except util.CompileError as e:
print >> sys.stderr, str(e)
return 2
visitor = imputil.ImportVisitor(path, future_node)
try: try:
visitor.visit(mod) visitor.visit(mod)
except util.CompileError as e: except util.CompileError as e:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment