Commit 7756f651 authored by Jason R. Coombs's avatar Jason R. Coombs

Allow get_unpatched to be called to get unpatched version of a class or...

Allow get_unpatched to be called to get unpatched version of a class or function, further harmonizing the interfaces.
parent b7b9cb23
......@@ -5,6 +5,7 @@ Monkey patching of distutils.
import sys
import distutils.filelist
import platform
import types
from .py26compat import import_module
......@@ -18,7 +19,16 @@ if you think you need this functionality.
"""
def get_unpatched(cls):
def get_unpatched(item):
lookup = (
get_unpatched_class if isinstance(item, type) else
get_unpatched_function if isinstance(item, types.FunctionType) else
lambda item: None
)
return lookup(item)
def get_unpatched_class(cls):
"""Protect against re-patching the distutils if reloaded
Also ensures that no other distutils extension monkeypatched the distutils
......@@ -117,7 +127,7 @@ def patch_func(replacement, original):
setattr(target_mod, original.__name__, replacement)
def get_unpatched_func(candidate):
def get_unpatched_function(candidate):
return getattr(candidate, 'unpatched')
......
......@@ -24,7 +24,7 @@ from pkg_resources.extern.packaging.version import LegacyVersion
from setuptools.extern.six.moves import filterfalse
from .monkey import get_unpatched_func
from .monkey import get_unpatched
if platform.system() == 'Windows':
from setuptools.extern.six.moves import winreg
......@@ -87,7 +87,7 @@ def msvc9_find_vcvarsall(version):
if os.path.isfile(vcvarsall):
return vcvarsall
return get_unpatched_func(msvc9_find_vcvarsall)(version)
return get_unpatched(msvc9_find_vcvarsall)(version)
def msvc9_query_vcvarsall(ver, arch='x86', *args, **kwargs):
......@@ -120,7 +120,7 @@ def msvc9_query_vcvarsall(ver, arch='x86', *args, **kwargs):
"""
# Try to get environement from vcvarsall.bat (Classical way)
try:
orig = get_unpatched_func(msvc9_query_vcvarsall)
orig = get_unpatched(msvc9_query_vcvarsall)
return orig(ver, arch, *args, **kwargs)
except distutils.errors.DistutilsPlatformError:
# Pass error if Vcvarsall.bat is missing
......@@ -160,7 +160,7 @@ def msvc14_get_vc_env(plat_spec):
"""
# Try to get environment from vcvarsall.bat (Classical way)
try:
return get_unpatched_func(msvc14_get_vc_env)(plat_spec)
return get_unpatched(msvc14_get_vc_env)(plat_spec)
except distutils.errors.DistutilsPlatformError:
# Pass error Vcvarsall.bat is missing
pass
......@@ -183,7 +183,7 @@ def msvc14_gen_lib_options(*args, **kwargs):
import numpy as np
if LegacyVersion(np.__version__) < LegacyVersion('1.11.2'):
return np.distutils.ccompiler.gen_lib_options(*args, **kwargs)
return get_unpatched_func(msvc14_gen_lib_options)(*args, **kwargs)
return get_unpatched(msvc14_gen_lib_options)(*args, **kwargs)
def _augment_exception(exc, version, arch=''):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment