Commit ed24bfcd authored by Jason R. Coombs's avatar Jason R. Coombs

Use mock to patch msvc9compiler.Reg.

parent 56d5fbca
...@@ -11,6 +11,7 @@ import tempfile ...@@ -11,6 +11,7 @@ import tempfile
import distutils.errors import distutils.errors
import pytest import pytest
import mock
from . import contexts from . import contexts
...@@ -19,48 +20,41 @@ __import__('setuptools') ...@@ -19,48 +20,41 @@ __import__('setuptools')
pytest.importorskip("distutils.msvc9compiler") pytest.importorskip("distutils.msvc9compiler")
class MockReg:
"""Mock for distutils.msvc9compiler.Reg. We patch it
with an instance of this class that mocks out the
functions that access the registry.
"""
def __init__(self, hkey_local_machine={}, hkey_current_user={}):
self.hklm = hkey_local_machine
self.hkcu = hkey_current_user
def __enter__(self): def mock_reg(hkcu=None, hklm=None):
self.original_read_keys = distutils.msvc9compiler.Reg.read_keys """
self.original_read_values = distutils.msvc9compiler.Reg.read_values Return a mock for distutils.msvc9compiler.Reg, patched
to mock out the functions that access the registry.
"""
_winreg = getattr(distutils.msvc9compiler, '_winreg', None) _winreg = getattr(distutils.msvc9compiler, '_winreg', None)
winreg = getattr(distutils.msvc9compiler, 'winreg', _winreg) winreg = getattr(distutils.msvc9compiler, 'winreg', _winreg)
hives = { hives = {
winreg.HKEY_CURRENT_USER: self.hkcu, winreg.HKEY_CURRENT_USER: hkcu or {},
winreg.HKEY_LOCAL_MACHINE: self.hklm, winreg.HKEY_LOCAL_MACHINE: hklm or {},
} }
@classmethod
def read_keys(cls, base, key): def read_keys(cls, base, key):
"""Return list of registry keys.""" """Return list of registry keys."""
hive = hives.get(base, {}) hive = hives.get(base, {})
return [k.rpartition('\\')[2] return [
for k in hive if k.startswith(key.lower())] k.rpartition('\\')[2]
for k in hive if k.startswith(key.lower())
]
@classmethod
def read_values(cls, base, key): def read_values(cls, base, key):
"""Return dict of registry keys and values.""" """Return dict of registry keys and values."""
hive = hives.get(base, {}) hive = hives.get(base, {})
return dict((k.rpartition('\\')[2], hive[k]) return dict(
for k in hive if k.startswith(key.lower())) (k.rpartition('\\')[2], hive[k])
for k in hive if k.startswith(key.lower())
)
distutils.msvc9compiler.Reg.read_keys = classmethod(read_keys) return mock.patch.multiple(distutils.msvc9compiler.Reg,
distutils.msvc9compiler.Reg.read_values = classmethod(read_values) read_keys=read_keys, read_values=read_values)
return self
def __exit__(self, exc_type, exc_value, exc_tb):
distutils.msvc9compiler.Reg.read_keys = self.original_read_keys
distutils.msvc9compiler.Reg.read_values = self.original_read_values
class TestMSVC9Compiler: class TestMSVC9Compiler:
...@@ -75,7 +69,7 @@ class TestMSVC9Compiler: ...@@ -75,7 +69,7 @@ class TestMSVC9Compiler:
# No registry entries or environment variable means we should # No registry entries or environment variable means we should
# not find anything # not find anything
with contexts.environment(VS90COMNTOOLS=None): with contexts.environment(VS90COMNTOOLS=None):
with MockReg(): with mock_reg():
assert find_vcvarsall(9.0) is None assert find_vcvarsall(9.0) is None
expected = distutils.errors.DistutilsPlatformError expected = distutils.errors.DistutilsPlatformError
...@@ -96,29 +90,33 @@ class TestMSVC9Compiler: ...@@ -96,29 +90,33 @@ class TestMSVC9Compiler:
open(mock_vcvarsall_bat_2, 'w').close() open(mock_vcvarsall_bat_2, 'w').close()
try: try:
# Ensure we get the current user's setting first # Ensure we get the current user's setting first
with MockReg( reg = mock_reg(
hkey_current_user={key_32: mock_installdir_1}, hkcu={
hkey_local_machine={ key_32: mock_installdir_1,
},
hklm={
key_32: mock_installdir_2, key_32: mock_installdir_2,
key_64: mock_installdir_2, key_64: mock_installdir_2,
} },
): )
with reg:
assert mock_vcvarsall_bat_1 == find_vcvarsall(9.0) assert mock_vcvarsall_bat_1 == find_vcvarsall(9.0)
# Ensure we get the local machine value if it's there # Ensure we get the local machine value if it's there
with MockReg(hkey_local_machine={key_32: mock_installdir_2}): with mock_reg(hklm={key_32: mock_installdir_2}):
assert mock_vcvarsall_bat_2 == find_vcvarsall(9.0) assert mock_vcvarsall_bat_2 == find_vcvarsall(9.0)
# Ensure we prefer the 64-bit local machine key # Ensure we prefer the 64-bit local machine key
# (*not* the Wow6432Node key) # (*not* the Wow6432Node key)
with MockReg( reg = mock_reg(
hkey_local_machine={ hklm={
# This *should* only exist on 32-bit machines # This *should* only exist on 32-bit machines
key_32: mock_installdir_1, key_32: mock_installdir_1,
# This *should* only exist on 64-bit machines # This *should* only exist on 64-bit machines
key_64: mock_installdir_2, key_64: mock_installdir_2,
} }
): )
with reg:
assert mock_vcvarsall_bat_1 == find_vcvarsall(9.0) assert mock_vcvarsall_bat_1 == find_vcvarsall(9.0)
finally: finally:
shutil.rmtree(mock_installdir_1) shutil.rmtree(mock_installdir_1)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment