Commit d35cbf6e authored by Antoine Pitrou's avatar Antoine Pitrou

Merged revisions 68360-68361 via svnmerge from

svn+ssh://pythondev@svn.python.org/python/trunk

........
  r68360 | antoine.pitrou | 2009-01-06 19:10:47 +0100 (mar., 06 janv. 2009) | 7 lines

  Issue #1180193: When importing a module from a .pyc (or .pyo) file with
  an existing .py counterpart, override the co_filename attributes of all
  code objects if the original filename is obsolete (which can happen if the
  file has been renamed, moved, or if it is accessed through different paths).
  Patch by Ziga Seilnacht and Jean-Paul Calderone.
........
  r68361 | antoine.pitrou | 2009-01-06 19:34:08 +0100 (mar., 06 janv. 2009) | 3 lines

  Use shutil.rmtree rather than os.rmdir.
........
parent 5d1ff00b
......@@ -6,6 +6,7 @@ import sys
import py_compile
import warnings
import imp
import marshal
from test.support import unlink, TESTFN, unload, run_unittest
......@@ -230,6 +231,98 @@ class ImportTest(unittest.TestCase):
else:
self.fail("import by path didn't raise an exception")
class TestPycRewriting(unittest.TestCase):
# Test that the `co_filename` attribute on code objects always points
# to the right file, even when various things happen (e.g. both the .py
# and the .pyc file are renamed).
module_name = "unlikely_module_name"
module_source = """
import sys
code_filename = sys._getframe().f_code.co_filename
module_filename = __file__
constant = 1
def func():
pass
func_filename = func.__code__.co_filename
"""
dir_name = os.path.abspath(TESTFN)
file_name = os.path.join(dir_name, module_name) + os.extsep + "py"
compiled_name = file_name + ("c" if __debug__ else "o")
def setUp(self):
self.sys_path = sys.path[:]
self.orig_module = sys.modules.pop(self.module_name, None)
os.mkdir(self.dir_name)
with open(self.file_name, "w") as f:
f.write(self.module_source)
sys.path.insert(0, self.dir_name)
def tearDown(self):
sys.path[:] = self.sys_path
if self.orig_module is not None:
sys.modules[self.module_name] = self.orig_module
else:
del sys.modules[self.module_name]
for file_name in self.file_name, self.compiled_name:
if os.path.exists(file_name):
os.remove(file_name)
if os.path.exists(self.dir_name):
shutil.rmtree(self.dir_name)
def import_module(self):
ns = globals()
__import__(self.module_name, ns, ns)
return sys.modules[self.module_name]
def test_basics(self):
mod = self.import_module()
self.assertEqual(mod.module_filename, self.file_name)
self.assertEqual(mod.code_filename, self.file_name)
self.assertEqual(mod.func_filename, self.file_name)
del sys.modules[self.module_name]
mod = self.import_module()
self.assertEqual(mod.module_filename, self.file_name)
self.assertEqual(mod.code_filename, self.file_name)
self.assertEqual(mod.func_filename, self.file_name)
def test_incorrect_code_name(self):
py_compile.compile(self.file_name, dfile="another_module.py")
mod = self.import_module()
self.assertEqual(mod.module_filename, self.file_name)
self.assertEqual(mod.code_filename, self.file_name)
self.assertEqual(mod.func_filename, self.file_name)
def test_module_without_source(self):
target = "another_module.py"
py_compile.compile(self.file_name, dfile=target)
os.remove(self.file_name)
mod = self.import_module()
self.assertEqual(mod.module_filename, self.compiled_name)
self.assertEqual(mod.code_filename, target)
self.assertEqual(mod.func_filename, target)
def test_foreign_code(self):
py_compile.compile(self.file_name)
with open(self.compiled_name, "rb") as f:
header = f.read(8)
code = marshal.load(f)
constants = list(code.co_consts)
foreign_code = test_main.__code__
pos = constants.index(1)
constants[pos] = foreign_code
code = type(code)(code.co_argcount, code.co_kwonlyargcount,
code.co_nlocals, code.co_stacksize,
code.co_flags, code.co_code, tuple(constants),
code.co_names, code.co_varnames, code.co_filename,
code.co_name, code.co_firstlineno, code.co_lnotab,
code.co_freevars, code.co_cellvars)
with open(self.compiled_name, "wb") as f:
f.write(header)
marshal.dump(code, f)
mod = self.import_module()
self.assertEqual(mod.constant.co_filename, foreign_code.co_filename)
class PathsTests(unittest.TestCase):
SAMPLES = ('test', 'test\u00e4\u00f6\u00fc\u00df', 'test\u00e9\u00e8',
'test\u00b0\u00b3\u00b2')
......@@ -288,7 +381,7 @@ class RelativeImport(unittest.TestCase):
self.assertRaises(ValueError, check_relative)
def test_main(verbose=None):
run_unittest(ImportTest, PathsTests, RelativeImport)
run_unittest(ImportTest, TestPycRewriting, PathsTests, RelativeImport)
if __name__ == '__main__':
# test needs to be a package, so we can do relative import
......
......@@ -12,6 +12,12 @@ What's New in Python 3.1 alpha 0
Core and Builtins
-----------------
- Issue #1180193: When importing a module from a .pyc (or .pyo) file with
an existing .py counterpart, override the co_filename attributes of all
code objects if the original filename is obsolete (which can happen if the
file has been renamed, moved, or if it is accessed through different paths).
Patch by Ziga Seilnacht and Jean-Paul Calderone.
- Issue #4580: Fix slicing of memoryviews when the item size is greater than
one byte. Also fixes the meaning of len() so that it returns the number of
items, rather than the size in bytes.
......
......@@ -959,6 +959,49 @@ write_compiled_module(PyCodeObject *co, char *cpathname, struct stat *srcstat)
PySys_WriteStderr("# wrote %s\n", cpathname);
}
static void
update_code_filenames(PyCodeObject *co, PyObject *oldname, PyObject *newname)
{
PyObject *constants, *tmp;
Py_ssize_t i, n;
if (PyUnicode_Compare(co->co_filename, oldname))
return;
tmp = co->co_filename;
co->co_filename = newname;
Py_INCREF(co->co_filename);
Py_DECREF(tmp);
constants = co->co_consts;
n = PyTuple_GET_SIZE(constants);
for (i = 0; i < n; i++) {
tmp = PyTuple_GET_ITEM(constants, i);
if (PyCode_Check(tmp))
update_code_filenames((PyCodeObject *)tmp,
oldname, newname);
}
}
static int
update_compiled_module(PyCodeObject *co, char *pathname)
{
PyObject *oldname, *newname;
if (!PyUnicode_CompareWithASCIIString(co->co_filename, pathname))
return 0;
newname = PyUnicode_FromString(pathname);
if (newname == NULL)
return -1;
oldname = co->co_filename;
Py_INCREF(oldname);
update_code_filenames(co, oldname, newname);
Py_DECREF(oldname);
Py_DECREF(newname);
return 1;
}
/* Load a source module from a given file and return its module
object WITH INCREMENTED REFERENCE COUNT. If there's a matching
......@@ -999,6 +1042,8 @@ load_source_module(char *name, char *pathname, FILE *fp)
fclose(fpc);
if (co == NULL)
return NULL;
if (update_compiled_module(co, pathname) < 0)
return NULL;
if (Py_VerboseFlag)
PySys_WriteStderr("import %s # precompiled from %s\n",
name, cpathname);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment