Commit d61de7f1 authored by Benjamin Peterson's avatar Benjamin Peterson

Merged revisions 66653-66654 via svnmerge from

svn+ssh://pythondev@svn.python.org/python/trunk

................
  r66653 | benjamin.peterson | 2008-09-27 16:09:10 -0500 (Sat, 27 Sep 2008) | 49 lines

  Merged revisions 66511,66548-66549,66644,66646-66652 via svnmerge from
  svn+ssh://pythondev@svn.python.org/sandbox/trunk/2to3/lib2to3

  ........
    r66511 | benjamin.peterson | 2008-09-18 21:49:27 -0500 (Thu, 18 Sep 2008) | 1 line

    remove a  useless if __name__ == '__main__'
  ........
    r66548 | benjamin.peterson | 2008-09-21 21:14:14 -0500 (Sun, 21 Sep 2008) | 1 line

    avoid the perils of mutable default arguments
  ........
    r66549 | benjamin.peterson | 2008-09-21 21:26:11 -0500 (Sun, 21 Sep 2008) | 1 line

    some places in RefactoringTool should raise an error instead of logging it
  ........
    r66644 | benjamin.peterson | 2008-09-27 10:45:10 -0500 (Sat, 27 Sep 2008) | 1 line

    fix doctest refactoring
  ........
    r66646 | benjamin.peterson | 2008-09-27 11:40:13 -0500 (Sat, 27 Sep 2008) | 1 line

    don't print to stdout when 2to3 is used as a library
  ........
    r66647 | benjamin.peterson | 2008-09-27 12:28:28 -0500 (Sat, 27 Sep 2008) | 1 line

    let fixer modules and classes have different prefixes
  ........
    r66648 | benjamin.peterson | 2008-09-27 14:02:13 -0500 (Sat, 27 Sep 2008) | 1 line

    raise errors when 2to3 is used as a library
  ........
    r66649 | benjamin.peterson | 2008-09-27 14:03:38 -0500 (Sat, 27 Sep 2008) | 1 line

    fix docstring
  ........
    r66650 | benjamin.peterson | 2008-09-27 14:22:21 -0500 (Sat, 27 Sep 2008) | 1 line

    make use of enumerate
  ........
    r66651 | benjamin.peterson | 2008-09-27 14:24:13 -0500 (Sat, 27 Sep 2008) | 1 line

    revert last revision; it breaks things
  ........
    r66652 | benjamin.peterson | 2008-09-27 16:03:06 -0500 (Sat, 27 Sep 2008) | 1 line

    add tests for lib2to3.refactor
  ........
................
  r66654 | benjamin.peterson | 2008-09-27 16:12:20 -0500 (Sat, 27 Sep 2008) | 1 line

  enable refactor tests
................
parent 027951f1
...@@ -10,6 +10,20 @@ import optparse ...@@ -10,6 +10,20 @@ import optparse
from . import refactor from . import refactor
class StdoutRefactoringTool(refactor.RefactoringTool):
"""
Prints output to stdout.
"""
def log_error(self, msg, *args, **kwargs):
self.errors.append((msg, args, kwargs))
self.logger.error(msg, *args, **kwargs)
def print_output(self, lines):
for line in lines:
print(line)
def main(fixer_pkg, args=None): def main(fixer_pkg, args=None):
"""Main program. """Main program.
...@@ -68,7 +82,7 @@ def main(fixer_pkg, args=None): ...@@ -68,7 +82,7 @@ def main(fixer_pkg, args=None):
fixer_names = avail_names if "all" in options.fix else explicit fixer_names = avail_names if "all" in options.fix else explicit
else: else:
fixer_names = avail_names fixer_names = avail_names
rt = refactor.RefactoringTool(fixer_names, rt_opts, explicit=explicit) rt = StdoutRefactoringTool(fixer_names, rt_opts, explicit=explicit)
# Refactor all files and directories passed as arguments # Refactor all files and directories passed as arguments
if not rt.errors: if not rt.errors:
...@@ -80,7 +94,3 @@ def main(fixer_pkg, args=None): ...@@ -80,7 +94,3 @@ def main(fixer_pkg, args=None):
# Return error status (0 if rt.errors is zero) # Return error status (0 if rt.errors is zero)
return int(bool(rt.errors)) return int(bool(rt.errors))
if __name__ == "__main__":
sys.exit(main())
...@@ -90,11 +90,18 @@ def get_fixers_from_package(pkg_name): ...@@ -90,11 +90,18 @@ def get_fixers_from_package(pkg_name):
for fix_name in get_all_fix_names(pkg_name, False)] for fix_name in get_all_fix_names(pkg_name, False)]
class FixerError(Exception):
"""A fixer could not be loaded."""
class RefactoringTool(object): class RefactoringTool(object):
_default_options = {"print_function": False} _default_options = {"print_function": False}
def __init__(self, fixer_names, options=None, explicit=[]): CLASS_PREFIX = "Fix" # The prefix for fixer classes
FILE_PREFIX = "fix_" # The prefix for modules with a fixer within
def __init__(self, fixer_names, options=None, explicit=None):
"""Initializer. """Initializer.
Args: Args:
...@@ -103,7 +110,7 @@ class RefactoringTool(object): ...@@ -103,7 +110,7 @@ class RefactoringTool(object):
explicit: a list of fixers to run even if they are explicit. explicit: a list of fixers to run even if they are explicit.
""" """
self.fixers = fixer_names self.fixers = fixer_names
self.explicit = explicit self.explicit = explicit or []
self.options = self._default_options.copy() self.options = self._default_options.copy()
if options is not None: if options is not None:
self.options.update(options) self.options.update(options)
...@@ -134,29 +141,17 @@ class RefactoringTool(object): ...@@ -134,29 +141,17 @@ class RefactoringTool(object):
pre_order_fixers = [] pre_order_fixers = []
post_order_fixers = [] post_order_fixers = []
for fix_mod_path in self.fixers: for fix_mod_path in self.fixers:
try:
mod = __import__(fix_mod_path, {}, {}, ["*"]) mod = __import__(fix_mod_path, {}, {}, ["*"])
except ImportError:
self.log_error("Can't load transformation module %s",
fix_mod_path)
continue
fix_name = fix_mod_path.rsplit(".", 1)[-1] fix_name = fix_mod_path.rsplit(".", 1)[-1]
if fix_name.startswith("fix_"): if fix_name.startswith(self.FILE_PREFIX):
fix_name = fix_name[4:] fix_name = fix_name[len(self.FILE_PREFIX):]
parts = fix_name.split("_") parts = fix_name.split("_")
class_name = "Fix" + "".join([p.title() for p in parts]) class_name = self.CLASS_PREFIX + "".join([p.title() for p in parts])
try: try:
fix_class = getattr(mod, class_name) fix_class = getattr(mod, class_name)
except AttributeError: except AttributeError:
self.log_error("Can't find %s.%s", raise FixerError("Can't find %s.%s" % (fix_name, class_name))
fix_name, class_name)
continue
try:
fixer = fix_class(self.options, self.fixer_log) fixer = fix_class(self.options, self.fixer_log)
except Exception as err:
self.log_error("Can't instantiate fixes.fix_%s.%s()",
fix_name, class_name, exc_info=True)
continue
if fixer.explicit and self.explicit is not True and \ if fixer.explicit and self.explicit is not True and \
fix_mod_path not in self.explicit: fix_mod_path not in self.explicit:
self.log_message("Skipping implicit fixer: %s", fix_name) self.log_message("Skipping implicit fixer: %s", fix_name)
...@@ -168,7 +163,7 @@ class RefactoringTool(object): ...@@ -168,7 +163,7 @@ class RefactoringTool(object):
elif fixer.order == "post": elif fixer.order == "post":
post_order_fixers.append(fixer) post_order_fixers.append(fixer)
else: else:
raise ValueError("Illegal fixer order: %r" % fixer.order) raise FixerError("Illegal fixer order: %r" % fixer.order)
key_func = operator.attrgetter("run_order") key_func = operator.attrgetter("run_order")
pre_order_fixers.sort(key=key_func) pre_order_fixers.sort(key=key_func)
...@@ -176,9 +171,8 @@ class RefactoringTool(object): ...@@ -176,9 +171,8 @@ class RefactoringTool(object):
return (pre_order_fixers, post_order_fixers) return (pre_order_fixers, post_order_fixers)
def log_error(self, msg, *args, **kwds): def log_error(self, msg, *args, **kwds):
"""Increments error count and log a message.""" """Called when an error occurs."""
self.errors.append((msg, args, kwds)) raise
self.logger.error(msg, *args, **kwds)
def log_message(self, msg, *args): def log_message(self, msg, *args):
"""Hook to log a message.""" """Hook to log a message."""
...@@ -191,13 +185,17 @@ class RefactoringTool(object): ...@@ -191,13 +185,17 @@ class RefactoringTool(object):
msg = msg % args msg = msg % args
self.logger.debug(msg) self.logger.debug(msg)
def print_output(self, lines):
"""Called with lines of output to give to the user."""
pass
def refactor(self, items, write=False, doctests_only=False): def refactor(self, items, write=False, doctests_only=False):
"""Refactor a list of files and directories.""" """Refactor a list of files and directories."""
for dir_or_file in items: for dir_or_file in items:
if os.path.isdir(dir_or_file): if os.path.isdir(dir_or_file):
self.refactor_dir(dir_or_file, write) self.refactor_dir(dir_or_file, write, doctests_only)
else: else:
self.refactor_file(dir_or_file, write) self.refactor_file(dir_or_file, write, doctests_only)
def refactor_dir(self, dir_name, write=False, doctests_only=False): def refactor_dir(self, dir_name, write=False, doctests_only=False):
"""Descends down a directory and refactor every Python file found. """Descends down a directory and refactor every Python file found.
...@@ -348,12 +346,11 @@ class RefactoringTool(object): ...@@ -348,12 +346,11 @@ class RefactoringTool(object):
if old_text == new_text: if old_text == new_text:
self.log_debug("No changes to %s", filename) self.log_debug("No changes to %s", filename)
return return
diff_texts(old_text, new_text, filename) self.print_output(diff_texts(old_text, new_text, filename))
if not write:
self.log_debug("Not writing changes to %s", filename)
return
if write: if write:
self.write_file(new_text, filename, old_text) self.write_file(new_text, filename, old_text)
else:
self.log_debug("Not writing changes to %s", filename)
def write_file(self, new_text, filename, old_text=None): def write_file(self, new_text, filename, old_text=None):
"""Writes a string to a file. """Writes a string to a file.
...@@ -528,10 +525,9 @@ class RefactoringTool(object): ...@@ -528,10 +525,9 @@ class RefactoringTool(object):
def diff_texts(a, b, filename): def diff_texts(a, b, filename):
"""Prints a unified diff of two strings.""" """Return a unified diff of two strings."""
a = a.splitlines() a = a.splitlines()
b = b.splitlines() b = b.splitlines()
for line in difflib.unified_diff(a, b, filename, filename, return difflib.unified_diff(a, b, filename, filename,
"(original)", "(refactored)", "(original)", "(refactored)",
lineterm=""): lineterm="")
print(line)
from lib2to3.fixer_base import BaseFix
class FixBadOrder(BaseFix):
order = "crazy"
from lib2to3.fixer_base import BaseFix
class FixExplicit(BaseFix):
explicit = True
def match(self): return False
from lib2to3.fixer_base import BaseFix
class FixFirst(BaseFix):
run_order = 1
def match(self, node): return False
from lib2to3.fixer_base import BaseFix
class FixLast(BaseFix):
run_order = 10
def match(self, node): return False
from lib2to3.fixer_base import BaseFix
from lib2to3.fixer_util import Name
class FixParrot(BaseFix):
"""
Change functions named 'parrot' to 'cheese'.
"""
PATTERN = """funcdef < 'def' name='parrot' any* >"""
def transform(self, node, results):
name = results["name"]
name.replace(Name("cheese", name.get_prefix()))
from lib2to3.fixer_base import BaseFix
class FixPreorder(BaseFix):
order = "pre"
def match(self, node): return False
# This is empty so trying to fetch the fixer class gives an AttributeError
"""
Unit tests for refactor.py.
"""
import sys
import os
import operator
import io
import tempfile
import unittest
from lib2to3 import refactor, pygram, fixer_base
from . import support
FIXER_DIR = os.path.join(os.path.dirname(__file__), "data/fixers")
sys.path.append(FIXER_DIR)
try:
_DEFAULT_FIXERS = refactor.get_fixers_from_package("myfixes")
finally:
sys.path.pop()
class TestRefactoringTool(unittest.TestCase):
def setUp(self):
sys.path.append(FIXER_DIR)
def tearDown(self):
sys.path.pop()
def check_instances(self, instances, classes):
for inst, cls in zip(instances, classes):
if not isinstance(inst, cls):
self.fail("%s are not instances of %s" % instances, classes)
def rt(self, options=None, fixers=_DEFAULT_FIXERS, explicit=None):
return refactor.RefactoringTool(fixers, options, explicit)
def test_print_function_option(self):
gram = pygram.python_grammar
save = gram.keywords["print"]
try:
rt = self.rt({"print_function" : True})
self.assertRaises(KeyError, operator.itemgetter("print"),
gram.keywords)
finally:
gram.keywords["print"] = save
def test_fixer_loading_helpers(self):
contents = ["explicit", "first", "last", "parrot", "preorder"]
non_prefixed = refactor.get_all_fix_names("myfixes")
prefixed = refactor.get_all_fix_names("myfixes", False)
full_names = refactor.get_fixers_from_package("myfixes")
self.assertEqual(prefixed, ["fix_" + name for name in contents])
self.assertEqual(non_prefixed, contents)
self.assertEqual(full_names,
["myfixes.fix_" + name for name in contents])
def test_get_headnode_dict(self):
class NoneFix(fixer_base.BaseFix):
PATTERN = None
class FileInputFix(fixer_base.BaseFix):
PATTERN = "file_input< any * >"
no_head = NoneFix({}, [])
with_head = FileInputFix({}, [])
d = refactor.get_headnode_dict([no_head, with_head])
expected = {None: [no_head],
pygram.python_symbols.file_input : [with_head]}
self.assertEqual(d, expected)
def test_fixer_loading(self):
from myfixes.fix_first import FixFirst
from myfixes.fix_last import FixLast
from myfixes.fix_parrot import FixParrot
from myfixes.fix_preorder import FixPreorder
rt = self.rt()
pre, post = rt.get_fixers()
self.check_instances(pre, [FixPreorder])
self.check_instances(post, [FixFirst, FixParrot, FixLast])
def test_naughty_fixers(self):
self.assertRaises(ImportError, self.rt, fixers=["not_here"])
self.assertRaises(refactor.FixerError, self.rt, fixers=["no_fixer_cls"])
self.assertRaises(refactor.FixerError, self.rt, fixers=["bad_order"])
def test_refactor_string(self):
rt = self.rt()
input = "def parrot(): pass\n\n"
tree = rt.refactor_string(input, "<test>")
self.assertNotEqual(str(tree), input)
input = "def f(): pass\n\n"
tree = rt.refactor_string(input, "<test>")
self.assertEqual(str(tree), input)
def test_refactor_stdin(self):
class MyRT(refactor.RefactoringTool):
def print_output(self, lines):
diff_lines.extend(lines)
diff_lines = []
rt = MyRT(_DEFAULT_FIXERS)
save = sys.stdin
sys.stdin = io.StringIO("def parrot(): pass\n\n")
try:
rt.refactor_stdin()
finally:
sys.stdin = save
expected = """--- <stdin> (original)
+++ <stdin> (refactored)
@@ -1,2 +1,2 @@
-def parrot(): pass
+def cheese(): pass""".splitlines()
self.assertEqual(diff_lines[:-1], expected)
def test_refactor_file(self):
test_file = os.path.join(FIXER_DIR, "parrot_example.py")
backup = test_file + ".bak"
old_contents = open(test_file, "r").read()
rt = self.rt()
rt.refactor_file(test_file)
self.assertEqual(old_contents, open(test_file, "r").read())
rt.refactor_file(test_file, True)
try:
self.assertNotEqual(old_contents, open(test_file, "r").read())
self.assertTrue(os.path.exists(backup))
self.assertEqual(old_contents, open(backup, "r").read())
finally:
open(test_file, "w").write(old_contents)
try:
os.unlink(backup)
except OSError:
pass
def test_refactor_docstring(self):
rt = self.rt()
def example():
"""
>>> example()
42
"""
out = rt.refactor_docstring(example.__doc__, "<test>")
self.assertEqual(out, example.__doc__)
def parrot():
"""
>>> def parrot():
... return 43
"""
out = rt.refactor_docstring(parrot.__doc__, "<test>")
self.assertNotEqual(out, parrot.__doc__)
def test_explicit(self):
from myfixes.fix_explicit import FixExplicit
rt = self.rt(fixers=["myfixes.fix_explicit"])
self.assertEqual(len(rt.post_order), 0)
rt = self.rt(explicit=["myfixes.fix_explicit"])
for fix in rt.post_order[None]:
if isinstance(fix, FixExplicit):
break
else:
self.fail("explicit fixer not loaded")
# Skipping test_parser and test_all_fixers # Skipping test_parser and test_all_fixers
# because of running # because of running
from lib2to3.tests import test_fixers, test_pytree, test_util from lib2to3.tests import test_fixers, test_pytree, test_util, test_refactor
import unittest import unittest
from test.support import run_unittest from test.support import run_unittest
def suite(): def suite():
tests = unittest.TestSuite() tests = unittest.TestSuite()
loader = unittest.TestLoader() loader = unittest.TestLoader()
for m in (test_fixers,test_pytree,test_util): for m in (test_fixers,test_pytree,test_util, test_refactor):
tests.addTests(loader.loadTestsFromModule(m)) tests.addTests(loader.loadTestsFromModule(m))
return tests return tests
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment