Commit 80ae1bd4 authored by Jérome Perrin's avatar Jérome Perrin

CMFActivity: port SQLJoblib to python3 WIP 🚧

parent 76f5a80a
......@@ -27,6 +27,7 @@
##############################################################################
import time
import six
import numpy as np
from copy import copy
......@@ -38,8 +39,12 @@ from sklearn.base import clone
from sklearn.utils import check_random_state
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.externals import joblib
from sklearn.externals.joblib.parallel import parallel_backend, Parallel, delayed
if six.PY2:
from sklearn.externals import joblib
from sklearn.externals.joblib.parallel import parallel_backend, Parallel, delayed
else:
import joblib
from joblib.parallel import parallel_backend, Parallel, delayed
from sklearn.datasets import load_digits
from sklearn.svm import SVC
from sklearn.model_selection import GridSearchCV
......
......@@ -40,7 +40,7 @@ class Test(ERP5TypeTestCase):
self.tic()
active_process = self.portal.portal_activities.unrestrictedTraverse(path)
result = active_process.getResultList()
self.assertAlmostEqual(0.98444444444444446, result[0].result)
self.assertAlmostEqual(0.9, result[0].result, 0)
def test_UnderRootOfSquaresFunction(self):
path = self.portal.Base_driverScriptSquareRoot()
......
......@@ -77,10 +77,10 @@ CREATE TABLE %s (
return (tuple(m.object_path), m.method_id, m.activity_kw.get('signature'),
m.activity_kw.get('tag'), m.activity_kw.get('group_id'))
_insert_template = ("INSERT INTO %s (uid,"
" path, active_process_uid, date, method_id, processing_node,"
" priority, group_method_id, tag, signature, serialization_tag,"
" message) VALUES\n(%s)")
_insert_template = (b"INSERT INTO %s (uid,"
b" path, active_process_uid, date, method_id, processing_node,"
b" priority, group_method_id, tag, signature, serialization_tag,"
b" message) VALUES\n(%s)")
def prepareQueueMessageList(self, activity_tool, message_list):
db = activity_tool.getSQLConnection()
......@@ -92,9 +92,9 @@ CREATE TABLE %s (
if reset_uid:
reset_uid = False
# Overflow will result into IntegrityError.
db.query("SET @uid := %s" % getrandbits(UID_SAFE_BITSIZE))
db.query(b"SET @uid := %s" % str(getrandbits(UID_SAFE_BITSIZE)).encode())
try:
db.query(self._insert_template % (self.sql_table, values))
db.query(self._insert_template % (self.sql_table.encode(), values))
except MySQLdb.IntegrityError as e:
if e.args[0] != DUP_ENTRY:
raise
......@@ -113,14 +113,14 @@ CREATE TABLE %s (
if m.is_registered:
active_process_uid = m.active_process_uid
date = m.activity_kw.get('at_date')
row = ','.join((
b'@uid+%s' % i,
row = b','.join((
b'@uid+%s' % str(i).encode(),
quote('/'.join(m.object_path)),
b'NULL' if active_process_uid is None else str(active_process_uid),
b'NULL' if active_process_uid is None else str(active_process_uid).encode(),
b"UTC_TIMESTAMP(6)" if date is None else quote(render_datetime(date)),
quote(m.method_id),
b'-1' if hasDependency(m) else b'0',
bytes(m.activity_kw.get('priority', 1)),
str(m.activity_kw.get('priority', 1)).encode(),
quote(m.getGroupId()),
quote(m.activity_kw.get('tag', '')),
quote(m.activity_kw.get('signature', '')),
......@@ -157,8 +157,8 @@ CREATE TABLE %s (
try:
# Select duplicates.
result = db.query(b"SELECT uid FROM message_job"
" WHERE processing_node = 0 AND path = %s AND signature = %s"
" AND method_id = %s AND group_method_id = %s FOR UPDATE" % (
b" WHERE processing_node = 0 AND path = %s AND signature = %s"
b" AND method_id = %s AND group_method_id = %s FOR UPDATE" % (
quote(path), quote(line.signature),
quote(method_id), quote(line.group_method_id),
), 0)[1]
......@@ -166,10 +166,10 @@ CREATE TABLE %s (
if uid_list:
self.assignMessageList(db, processing_node, uid_list)
else:
db.query("COMMIT") # XXX: useful ?
db.query(b"COMMIT") # XXX: useful ?
except:
self._log(WARNING, 'Failed to reserve duplicates')
db.query("ROLLBACK")
db.query(b"ROLLBACK")
raise
if uid_list:
self._log(TRACE, 'Reserved duplicate messages: %r' % uid_list)
......
......@@ -25,18 +25,27 @@
#
##############################################################################
from zLOG import LOG, INFO, WARNING
import logging
from ZODB.POSException import ConflictError
from Products.CMFActivity.ActivityRuntimeEnvironment import \
getActivityRuntimeEnvironment
import six
logger = logging.getLogger(__name__)
try:
if six.PY2:
from sklearn.externals.joblib import register_parallel_backend
from sklearn.externals.joblib.parallel import ParallelBackendBase, parallel_backend
from sklearn.externals.joblib.parallel import FallbackToBackend, SequentialBackend
from sklearn.externals.joblib.hashing import hash as joblib_hash
else:
from joblib import register_parallel_backend
from joblib.parallel import ParallelBackendBase, parallel_backend
from joblib.parallel import FallbackToBackend, SequentialBackend
from joblib.hashing import hash as joblib_hash
except ImportError:
LOG(__name__, WARNING, "Joblib cannot be imported, support disabled")
logger.warn("Joblib cannot be imported, support disabled")
else:
class JoblibResult(object):
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment