Commit b17acad6 authored by Georg Brandl's avatar Georg Brandl

Make db modules' error classes inherit IOError.

Stop dbm from importing every dbm module when imported.
parent e81f5ef1
...@@ -48,27 +48,26 @@ class error(Exception): ...@@ -48,27 +48,26 @@ class error(Exception):
pass pass
_names = ['dbm.bsd', 'dbm.gnu', 'dbm.ndbm', 'dbm.dumb'] _names = ['dbm.bsd', 'dbm.gnu', 'dbm.ndbm', 'dbm.dumb']
_errors = [error]
_defaultmod = None _defaultmod = None
_modules = {} _modules = {}
for _name in _names: error = (error, IOError)
try:
_mod = __import__(_name, fromlist=['open'])
except ImportError:
continue
if not _defaultmod:
_defaultmod = _mod
_modules[_name] = _mod
_errors.append(_mod.error)
if not _defaultmod:
raise ImportError("no dbm clone found; tried %s" % _names)
error = tuple(_errors)
def open(file, flag = 'r', mode = 0o666): def open(file, flag = 'r', mode = 0o666):
global _defaultmod
if _defaultmod is None:
for name in _names:
try:
mod = __import__(name, fromlist=['open'])
except ImportError:
continue
if not _defaultmod:
_defaultmod = mod
_modules[name] = mod
if not _defaultmod:
raise ImportError("no dbm clone found; tried %s" % _names)
# guess the type of an existing database # guess the type of an existing database
result = whichdb(file) result = whichdb(file)
if result is None: if result is None:
...@@ -81,19 +80,14 @@ def open(file, flag = 'r', mode = 0o666): ...@@ -81,19 +80,14 @@ def open(file, flag = 'r', mode = 0o666):
elif result == "": elif result == "":
# db type cannot be determined # db type cannot be determined
raise error("db type could not be determined") raise error("db type could not be determined")
elif result not in _modules:
raise error("db type is {0}, but the module is not "
"available".format(result))
else: else:
mod = _modules[result] mod = _modules[result]
return mod.open(file, flag, mode) return mod.open(file, flag, mode)
try:
from dbm import ndbm
_dbmerror = ndbm.error
except ImportError:
ndbm = None
# just some sort of valid exception which might be raised in the ndbm test
_dbmerror = IOError
def whichdb(filename): def whichdb(filename):
"""Guess which db package to use to open a db file. """Guess which db package to use to open a db file.
...@@ -129,7 +123,7 @@ def whichdb(filename): ...@@ -129,7 +123,7 @@ def whichdb(filename):
d = ndbm.open(filename) d = ndbm.open(filename)
d.close() d.close()
return "dbm.ndbm" return "dbm.ndbm"
except (IOError, _dbmerror): except IOError:
pass pass
# Check for dumbdbm next -- this has a .dir and a .dat file # Check for dumbdbm next -- this has a .dir and a .dat file
......
...@@ -4,7 +4,8 @@ import bsddb ...@@ -4,7 +4,8 @@ import bsddb
__all__ = ["error", "open"] __all__ = ["error", "open"]
error = bsddb.error class error(bsddb.error, IOError):
pass
def open(file, flag = 'r', mode=0o666): def open(file, flag = 'r', mode=0o666):
return bsddb.hashopen(file, flag, mode) return bsddb.hashopen(file, flag, mode)
...@@ -14,11 +14,13 @@ _fname = test.support.TESTFN ...@@ -14,11 +14,13 @@ _fname = test.support.TESTFN
# setting dbm to use each in turn, and yielding that module # setting dbm to use each in turn, and yielding that module
# #
def dbm_iterator(): def dbm_iterator():
old_default = dbm._defaultmod for name in dbm._names:
for module in dbm._modules.values(): try:
dbm._defaultmod = module mod = __import__(name, fromlist=['open'])
yield module except ImportError:
dbm._defaultmod = old_default continue
dbm._modules[name] = mod
yield mod
# #
# Clean up all scratch databases we might have created during testing # Clean up all scratch databases we might have created during testing
...@@ -40,8 +42,20 @@ class AnyDBMTestCase(unittest.TestCase): ...@@ -40,8 +42,20 @@ class AnyDBMTestCase(unittest.TestCase):
'g': b'intended', 'g': b'intended',
} }
def __init__(self, *args): def init_db(self):
unittest.TestCase.__init__(self, *args) f = dbm.open(_fname, 'n')
for k in self._dict:
f[k.encode("ascii")] = self._dict[k]
f.close()
def keys_helper(self, f):
keys = sorted(k.decode("ascii") for k in f.keys())
dkeys = sorted(self._dict.keys())
self.assertEqual(keys, dkeys)
return keys
def test_error(self):
self.assert_(issubclass(self.module.error, IOError))
def test_anydbm_creation(self): def test_anydbm_creation(self):
f = dbm.open(_fname, 'c') f = dbm.open(_fname, 'c')
...@@ -83,22 +97,11 @@ class AnyDBMTestCase(unittest.TestCase): ...@@ -83,22 +97,11 @@ class AnyDBMTestCase(unittest.TestCase):
for key in self._dict: for key in self._dict:
self.assertEqual(self._dict[key], f[key.encode("ascii")]) self.assertEqual(self._dict[key], f[key.encode("ascii")])
def init_db(self):
f = dbm.open(_fname, 'n')
for k in self._dict:
f[k.encode("ascii")] = self._dict[k]
f.close()
def keys_helper(self, f):
keys = sorted(k.decode("ascii") for k in f.keys())
dkeys = sorted(self._dict.keys())
self.assertEqual(keys, dkeys)
return keys
def tearDown(self): def tearDown(self):
delete_files() delete_files()
def setUp(self): def setUp(self):
dbm._defaultmod = self.module
delete_files() delete_files()
...@@ -137,11 +140,11 @@ class WhichDBTestCase(unittest.TestCase): ...@@ -137,11 +140,11 @@ class WhichDBTestCase(unittest.TestCase):
def test_main(): def test_main():
try: classes = [WhichDBTestCase]
for module in dbm_iterator(): for mod in dbm_iterator():
test.support.run_unittest(AnyDBMTestCase, WhichDBTestCase) classes.append(type("TestCase-" + mod.__name__, (AnyDBMTestCase,),
finally: {'module': mod}))
delete_files() test.support.run_unittest(*classes)
if __name__ == "__main__": if __name__ == "__main__":
test_main() test_main()
...@@ -401,7 +401,8 @@ init_dbm(void) { ...@@ -401,7 +401,8 @@ init_dbm(void) {
return; return;
d = PyModule_GetDict(m); d = PyModule_GetDict(m);
if (DbmError == NULL) if (DbmError == NULL)
DbmError = PyErr_NewException("_dbm.error", NULL, NULL); DbmError = PyErr_NewException("_dbm.error",
PyExc_IOError, NULL);
s = PyUnicode_FromString(which_dbm); s = PyUnicode_FromString(which_dbm);
if (s != NULL) { if (s != NULL) {
PyDict_SetItemString(d, "library", s); PyDict_SetItemString(d, "library", s);
......
...@@ -523,7 +523,7 @@ init_gdbm(void) { ...@@ -523,7 +523,7 @@ init_gdbm(void) {
if (m == NULL) if (m == NULL)
return; return;
d = PyModule_GetDict(m); d = PyModule_GetDict(m);
DbmError = PyErr_NewException("_gdbm.error", NULL, NULL); DbmError = PyErr_NewException("_gdbm.error", PyExc_IOError, NULL);
if (DbmError != NULL) { if (DbmError != NULL) {
PyDict_SetItemString(d, "error", DbmError); PyDict_SetItemString(d, "error", DbmError);
s = PyUnicode_FromString(dbmmodule_open_flags); s = PyUnicode_FromString(dbmmodule_open_flags);
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment