Commit 70cb2e18 authored by Jérome Perrin's avatar Jérome Perrin

joblib: python3

parent df160681
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
############################################################################## ##############################################################################
import time import time
import six
import numpy as np import numpy as np
from copy import copy from copy import copy
...@@ -38,8 +39,12 @@ from sklearn.base import clone ...@@ -38,8 +39,12 @@ from sklearn.base import clone
from sklearn.utils import check_random_state from sklearn.utils import check_random_state
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier from sklearn.ensemble import RandomForestClassifier
from sklearn.externals import joblib if six.PY2:
from sklearn.externals.joblib.parallel import parallel_backend, Parallel, delayed from sklearn.externals import joblib
from sklearn.externals.joblib.parallel import parallel_backend, Parallel, delayed
else:
import joblib
from joblib.parallel import parallel_backend, Parallel, delayed
from sklearn.datasets import load_digits from sklearn.datasets import load_digits
from sklearn.svm import SVC from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV from sklearn.model_selection import GridSearchCV
......
...@@ -40,7 +40,7 @@ class Test(ERP5TypeTestCase): ...@@ -40,7 +40,7 @@ class Test(ERP5TypeTestCase):
self.tic() self.tic()
active_process = self.portal.portal_activities.unrestrictedTraverse(path) active_process = self.portal.portal_activities.unrestrictedTraverse(path)
result = active_process.getResultList() result = active_process.getResultList()
self.assertAlmostEqual(0.98444444444444446, result[0].result) self.assertAlmostEqual(0.9, result[0].result, 0)
def test_UnderRootOfSquaresFunction(self): def test_UnderRootOfSquaresFunction(self):
path = self.portal.Base_driverScriptSquareRoot() path = self.portal.Base_driverScriptSquareRoot()
......
...@@ -25,18 +25,27 @@ ...@@ -25,18 +25,27 @@
# #
############################################################################## ##############################################################################
from zLOG import LOG, INFO, WARNING import logging
from ZODB.POSException import ConflictError from ZODB.POSException import ConflictError
from Products.CMFActivity.ActivityRuntimeEnvironment import \ from Products.CMFActivity.ActivityRuntimeEnvironment import \
getActivityRuntimeEnvironment getActivityRuntimeEnvironment
import six
logger = logging.getLogger(__name__)
try: try:
from sklearn.externals.joblib import register_parallel_backend if six.PY2:
from sklearn.externals.joblib.parallel import ParallelBackendBase, parallel_backend from sklearn.externals.joblib import register_parallel_backend
from sklearn.externals.joblib.parallel import FallbackToBackend, SequentialBackend from sklearn.externals.joblib.parallel import ParallelBackendBase, parallel_backend
from sklearn.externals.joblib.hashing import hash as joblib_hash from sklearn.externals.joblib.parallel import FallbackToBackend, SequentialBackend
from sklearn.externals.joblib.hashing import hash as joblib_hash
else:
from joblib import register_parallel_backend
from joblib.parallel import ParallelBackendBase, parallel_backend
from joblib.parallel import FallbackToBackend, SequentialBackend
from joblib.hashing import hash as joblib_hash
except ImportError: except ImportError:
LOG(__name__, WARNING, "Joblib cannot be imported, support disabled") logger.warn("Joblib cannot be imported, support disabled")
else: else:
class JoblibResult(object): class JoblibResult(object):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment