Commit 80ae1bd4 authored by Jérome Perrin's avatar Jérome Perrin

CMFActivity: port SQLJoblib to python3 WIP 🚧

parent 76f5a80a
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
############################################################################## ##############################################################################
import time import time
import six
import numpy as np import numpy as np
from copy import copy from copy import copy
...@@ -38,8 +39,12 @@ from sklearn.base import clone ...@@ -38,8 +39,12 @@ from sklearn.base import clone
from sklearn.utils import check_random_state from sklearn.utils import check_random_state
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier from sklearn.ensemble import RandomForestClassifier
from sklearn.externals import joblib if six.PY2:
from sklearn.externals.joblib.parallel import parallel_backend, Parallel, delayed from sklearn.externals import joblib
from sklearn.externals.joblib.parallel import parallel_backend, Parallel, delayed
else:
import joblib
from joblib.parallel import parallel_backend, Parallel, delayed
from sklearn.datasets import load_digits from sklearn.datasets import load_digits
from sklearn.svm import SVC from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV from sklearn.model_selection import GridSearchCV
......
...@@ -40,7 +40,7 @@ class Test(ERP5TypeTestCase): ...@@ -40,7 +40,7 @@ class Test(ERP5TypeTestCase):
self.tic() self.tic()
active_process = self.portal.portal_activities.unrestrictedTraverse(path) active_process = self.portal.portal_activities.unrestrictedTraverse(path)
result = active_process.getResultList() result = active_process.getResultList()
self.assertAlmostEqual(0.98444444444444446, result[0].result) self.assertAlmostEqual(0.9, result[0].result, 0)
def test_UnderRootOfSquaresFunction(self): def test_UnderRootOfSquaresFunction(self):
path = self.portal.Base_driverScriptSquareRoot() path = self.portal.Base_driverScriptSquareRoot()
......
...@@ -77,10 +77,10 @@ CREATE TABLE %s ( ...@@ -77,10 +77,10 @@ CREATE TABLE %s (
return (tuple(m.object_path), m.method_id, m.activity_kw.get('signature'), return (tuple(m.object_path), m.method_id, m.activity_kw.get('signature'),
m.activity_kw.get('tag'), m.activity_kw.get('group_id')) m.activity_kw.get('tag'), m.activity_kw.get('group_id'))
_insert_template = ("INSERT INTO %s (uid," _insert_template = (b"INSERT INTO %s (uid,"
" path, active_process_uid, date, method_id, processing_node," b" path, active_process_uid, date, method_id, processing_node,"
" priority, group_method_id, tag, signature, serialization_tag," b" priority, group_method_id, tag, signature, serialization_tag,"
" message) VALUES\n(%s)") b" message) VALUES\n(%s)")
def prepareQueueMessageList(self, activity_tool, message_list): def prepareQueueMessageList(self, activity_tool, message_list):
db = activity_tool.getSQLConnection() db = activity_tool.getSQLConnection()
...@@ -92,9 +92,9 @@ CREATE TABLE %s ( ...@@ -92,9 +92,9 @@ CREATE TABLE %s (
if reset_uid: if reset_uid:
reset_uid = False reset_uid = False
# Overflow will result into IntegrityError. # Overflow will result into IntegrityError.
db.query("SET @uid := %s" % getrandbits(UID_SAFE_BITSIZE)) db.query(b"SET @uid := %s" % str(getrandbits(UID_SAFE_BITSIZE)).encode())
try: try:
db.query(self._insert_template % (self.sql_table, values)) db.query(self._insert_template % (self.sql_table.encode(), values))
except MySQLdb.IntegrityError as e: except MySQLdb.IntegrityError as e:
if e.args[0] != DUP_ENTRY: if e.args[0] != DUP_ENTRY:
raise raise
...@@ -113,14 +113,14 @@ CREATE TABLE %s ( ...@@ -113,14 +113,14 @@ CREATE TABLE %s (
if m.is_registered: if m.is_registered:
active_process_uid = m.active_process_uid active_process_uid = m.active_process_uid
date = m.activity_kw.get('at_date') date = m.activity_kw.get('at_date')
row = ','.join(( row = b','.join((
b'@uid+%s' % i, b'@uid+%s' % str(i).encode(),
quote('/'.join(m.object_path)), quote('/'.join(m.object_path)),
b'NULL' if active_process_uid is None else str(active_process_uid), b'NULL' if active_process_uid is None else str(active_process_uid).encode(),
b"UTC_TIMESTAMP(6)" if date is None else quote(render_datetime(date)), b"UTC_TIMESTAMP(6)" if date is None else quote(render_datetime(date)),
quote(m.method_id), quote(m.method_id),
b'-1' if hasDependency(m) else b'0', b'-1' if hasDependency(m) else b'0',
bytes(m.activity_kw.get('priority', 1)), str(m.activity_kw.get('priority', 1)).encode(),
quote(m.getGroupId()), quote(m.getGroupId()),
quote(m.activity_kw.get('tag', '')), quote(m.activity_kw.get('tag', '')),
quote(m.activity_kw.get('signature', '')), quote(m.activity_kw.get('signature', '')),
...@@ -157,8 +157,8 @@ CREATE TABLE %s ( ...@@ -157,8 +157,8 @@ CREATE TABLE %s (
try: try:
# Select duplicates. # Select duplicates.
result = db.query(b"SELECT uid FROM message_job" result = db.query(b"SELECT uid FROM message_job"
" WHERE processing_node = 0 AND path = %s AND signature = %s" b" WHERE processing_node = 0 AND path = %s AND signature = %s"
" AND method_id = %s AND group_method_id = %s FOR UPDATE" % ( b" AND method_id = %s AND group_method_id = %s FOR UPDATE" % (
quote(path), quote(line.signature), quote(path), quote(line.signature),
quote(method_id), quote(line.group_method_id), quote(method_id), quote(line.group_method_id),
), 0)[1] ), 0)[1]
...@@ -166,10 +166,10 @@ CREATE TABLE %s ( ...@@ -166,10 +166,10 @@ CREATE TABLE %s (
if uid_list: if uid_list:
self.assignMessageList(db, processing_node, uid_list) self.assignMessageList(db, processing_node, uid_list)
else: else:
db.query("COMMIT") # XXX: useful ? db.query(b"COMMIT") # XXX: useful ?
except: except:
self._log(WARNING, 'Failed to reserve duplicates') self._log(WARNING, 'Failed to reserve duplicates')
db.query("ROLLBACK") db.query(b"ROLLBACK")
raise raise
if uid_list: if uid_list:
self._log(TRACE, 'Reserved duplicate messages: %r' % uid_list) self._log(TRACE, 'Reserved duplicate messages: %r' % uid_list)
......
...@@ -25,18 +25,27 @@ ...@@ -25,18 +25,27 @@
# #
############################################################################## ##############################################################################
from zLOG import LOG, INFO, WARNING import logging
from ZODB.POSException import ConflictError from ZODB.POSException import ConflictError
from Products.CMFActivity.ActivityRuntimeEnvironment import \ from Products.CMFActivity.ActivityRuntimeEnvironment import \
getActivityRuntimeEnvironment getActivityRuntimeEnvironment
import six
logger = logging.getLogger(__name__)
try: try:
if six.PY2:
from sklearn.externals.joblib import register_parallel_backend from sklearn.externals.joblib import register_parallel_backend
from sklearn.externals.joblib.parallel import ParallelBackendBase, parallel_backend from sklearn.externals.joblib.parallel import ParallelBackendBase, parallel_backend
from sklearn.externals.joblib.parallel import FallbackToBackend, SequentialBackend from sklearn.externals.joblib.parallel import FallbackToBackend, SequentialBackend
from sklearn.externals.joblib.hashing import hash as joblib_hash from sklearn.externals.joblib.hashing import hash as joblib_hash
else:
from joblib import register_parallel_backend
from joblib.parallel import ParallelBackendBase, parallel_backend
from joblib.parallel import FallbackToBackend, SequentialBackend
from joblib.hashing import hash as joblib_hash
except ImportError: except ImportError:
LOG(__name__, WARNING, "Joblib cannot be imported, support disabled") logger.warn("Joblib cannot be imported, support disabled")
else: else:
class JoblibResult(object): class JoblibResult(object):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment