Commit 53bf9d3e authored by Hardik Juneja's avatar Hardik Juneja

[erp5_joblib] Add tests and cleanup

parent f0e1af12
......@@ -27,6 +27,5 @@ def test(self, active_process_path):
tic = time.time()
with parallel_backend('CMFActivity', n_jobs=2, active_process=active_process):
clf.fit(X, y)
log("I am here", time.time()-tic)
return 'ok', sklearn.__version__, joblib.__version__, time.time() - tic
\ No newline at end of file
......@@ -45,7 +45,9 @@
<item>
<key> <string>text_content_warning_message</string> </key>
<value>
<tuple/>
<tuple>
<string>W: 5, 0: Unused log imported from Products.ERP5Type.Log (unused-import)</string>
</tuple>
</value>
</item>
<item>
......
from copy import copy
import numpy as np
from Products.ERP5Type.Log import log
from Products.CMFActivity.ActiveResult import ActiveResult
from sklearn.base import clone
from sklearn.utils import check_random_state
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.externals import joblib
from sklearn.externals.joblib.parallel import parallel_backend
def combine(all_ensembles):
final_ensemble = copy(all_ensembles[0])
final_ensemble.estimators_ = []
for ensemble in all_ensembles:
final_ensemble.estimators_ += ensemble.estimators_
return final_ensemble
def train_model(model, X, y, sample_weight=None, random_state=None):
model.set_params(random_state=random_state)
if sample_weight is not None:
model.fit(X, y, sample_weight=sample_weight)
else:
model.fit(X, y)
return model
def grow_ensemble(base_model, X, y, sample_weight=None, n_estimators=1,
n_jobs=1, random_state=None):
random_state = check_random_state(random_state)
max_seed = np.iinfo('uint32').max
random_states = random_state.randint(max_seed + 1, size=n_estimators)
results = joblib.Parallel(n_jobs=n_jobs)(
joblib.delayed(train_model)(
clone(base_model), X, y,
sample_weight=sample_weight, random_state=rs)
for rs in random_states)
return combine(results)
def test_function(self, active_process_path):
from sklearn.datasets import load_digits
digits = load_digits()
X_train, X_test, y_train, y_test = train_test_split(
digits.data, digits.target, random_state=0)
# Create an active process
active_process = self.portal_activities.unrestrictedTraverse(active_process_path)
# Use CMFActivity as a backend for joblib
with parallel_backend('CMFActivity', n_jobs=2, active_process=active_process):
final_model = grow_ensemble(RandomForestClassifier(), X_train, y_train,
n_estimators=10, n_jobs=2, random_state=42)
score = final_model.score(X_test, y_test)
# Set result value and an id to the active result and post it
result = ActiveResult(result=score)
result.sig = 123
active_process.postResult(result)
log('ok', len(final_model.estimators_))
return 'ok', len(final_model.estimators_), score
......@@ -14,7 +14,7 @@
</item>
<item>
<key> <string>default_reference</string> </key>
<value> <string>joblibFunction</string> </value>
<value> <string>joblibRandomForest</string> </value>
</item>
<item>
<key> <string>description</string> </key>
......@@ -24,7 +24,7 @@
</item>
<item>
<key> <string>id</string> </key>
<value> <string>extension.erp5.joblibFunction</string> </value>
<value> <string>extension.erp5.joblibRandomForest</string> </value>
</item>
<item>
<key> <string>portal_type</string> </key>
......@@ -45,9 +45,7 @@
<item>
<key> <string>text_content_warning_message</string> </key>
<value>
<tuple>
<string>W: 3, 0: Unused ActiveResult imported from Products.CMFActivity.ActiveResult (unused-import)</string>
</tuple>
<tuple/>
</value>
</item>
<item>
......
from sklearn.externals.joblib.parallel import parallel_backend, Parallel, delayed
from Products.ERP5Type.Log import log
from Products.CMFActivity.ActiveResult import ActiveResult
import time
from math import sqrt
def abc(num):
def sleepAndSqrt(num):
time.sleep(2)
return sqrt(num)
def test(self, active_process_path):
active_process = self.portal_activities.unrestrictedTraverse(active_process_path)
# Use CMFActivity as a backend for joblob
with parallel_backend('CMFActivity', active_process=active_process):
result = Parallel(n_jobs=2, pre_dispatch='all', timeout=30, verbose=30)(delayed(abc)(i**2) for i in range(20))
log("I am here", result)
result = Parallel(n_jobs=2, pre_dispatch='all', timeout=30, verbose=30)(delayed(sleepAndSqrt)(i**2) for i in range(5))
# Set result value and an id to the active result and post it
result = ActiveResult(result=result)
result.sig = 12345
active_process.postResult(result)
return result
\ No newline at end of file
<?xml version="1.0"?>
<ZopeData>
<record id="1" aka="AAAAAAAAAAE=">
<pickle>
<global name="Extension Component" module="erp5.portal_type"/>
</pickle>
<pickle>
<dictionary>
<item>
<key> <string>_recorded_property_dict</string> </key>
<value>
<persistent> <string encoding="base64">AAAAAAAAAAI=</string> </persistent>
</value>
</item>
<item>
<key> <string>default_reference</string> </key>
<value> <string>joblibSimpleFunction</string> </value>
</item>
<item>
<key> <string>description</string> </key>
<value>
<none/>
</value>
</item>
<item>
<key> <string>id</string> </key>
<value> <string>extension.erp5.joblibSimpleFunction</string> </value>
</item>
<item>
<key> <string>portal_type</string> </key>
<value> <string>Extension Component</string> </value>
</item>
<item>
<key> <string>sid</string> </key>
<value>
<none/>
</value>
</item>
<item>
<key> <string>text_content_error_message</string> </key>
<value>
<tuple/>
</value>
</item>
<item>
<key> <string>text_content_warning_message</string> </key>
<value>
<tuple/>
</value>
</item>
<item>
<key> <string>version</string> </key>
<value> <string>erp5</string> </value>
</item>
<item>
<key> <string>workflow_history</string> </key>
<value>
<persistent> <string encoding="base64">AAAAAAAAAAM=</string> </persistent>
</value>
</item>
</dictionary>
</pickle>
</record>
<record id="2" aka="AAAAAAAAAAI=">
<pickle>
<global name="PersistentMapping" module="Persistence.mapping"/>
</pickle>
<pickle>
<dictionary>
<item>
<key> <string>data</string> </key>
<value>
<dictionary/>
</value>
</item>
</dictionary>
</pickle>
</record>
<record id="3" aka="AAAAAAAAAAM=">
<pickle>
<global name="PersistentMapping" module="Persistence.mapping"/>
</pickle>
<pickle>
<dictionary>
<item>
<key> <string>data</string> </key>
<value>
<dictionary>
<item>
<key> <string>component_validation_workflow</string> </key>
<value>
<persistent> <string encoding="base64">AAAAAAAAAAQ=</string> </persistent>
</value>
</item>
</dictionary>
</value>
</item>
</dictionary>
</pickle>
</record>
<record id="4" aka="AAAAAAAAAAQ=">
<pickle>
<global name="WorkflowHistoryList" module="Products.ERP5Type.patches.WorkflowTool"/>
</pickle>
<pickle>
<tuple>
<none/>
<list>
<dictionary>
<item>
<key> <string>action</string> </key>
<value> <string>validate</string> </value>
</item>
<item>
<key> <string>validation_state</string> </key>
<value> <string>validated</string> </value>
</item>
</dictionary>
</list>
</tuple>
</pickle>
</record>
</ZopeData>
......@@ -4,10 +4,6 @@ if REQUEST is not None:
from Products.ERP5Type.Log import log
from Products.CMFActivity.ActiveResult import ActiveResult
log("Executing batch_function on %s" % context.getRelativeUrl())
result = batch_function()
log("Result of batch_function on %s: %s" % (context.getRelativeUrl(), result))
return ActiveResult(result=result, sig=hash)
import time
from Products.ERP5Type.Log import log
timeout = 10
active_process = context.portal_activities.newActiveProcess()
active_process.useBTree()
active_process_id = active_process.getId()
path = active_process.getPhysicalPath()
context.portal_activities.activate(activity="SQLQueue", after_method_id="Base_callSafeFunction", active_process=active_process, tag='abc').Base_joblibRandomForestFunction(path)
context.portal_activities.activate(activity="SQLQueue", after_method_id="Base_callSafeFunction", active_process=active_process).Base_joblibRandomForestFunction(path)
return path
<?xml version="1.0"?>
<ZopeData>
<record id="1" aka="AAAAAAAAAAE=">
<pickle>
<global name="PythonScript" module="Products.PythonScripts.PythonScript"/>
</pickle>
<pickle>
<dictionary>
<item>
<key> <string>Script_magic</string> </key>
<value> <int>3</int> </value>
</item>
<item>
<key> <string>_bind_names</string> </key>
<value>
<object>
<klass>
<global name="NameAssignments" module="Shared.DC.Scripts.Bindings"/>
</klass>
<tuple/>
<state>
<dictionary>
<item>
<key> <string>_asgns</string> </key>
<value>
<dictionary>
<item>
<key> <string>name_container</string> </key>
<value> <string>container</string> </value>
</item>
<item>
<key> <string>name_context</string> </key>
<value> <string>context</string> </value>
</item>
<item>
<key> <string>name_m_self</string> </key>
<value> <string>script</string> </value>
</item>
<item>
<key> <string>name_subpath</string> </key>
<value> <string>traverse_subpath</string> </value>
</item>
</dictionary>
</value>
</item>
</dictionary>
</state>
</object>
</value>
</item>
<item>
<key> <string>_params</string> </key>
<value> <string>REQUEST=None</string> </value>
</item>
<item>
<key> <string>id</string> </key>
<value> <string>Base_driverScriptRandomForest</string> </value>
</item>
</dictionary>
</pickle>
</record>
</ZopeData>
import time
active_process = context.portal_activities.newActiveProcess()
active_process.useBTree()
active_process_id = active_process.getId()
path = active_process.getPhysicalPath()
context.portal_activities.activate(activity="SQLQueue", after_method_id="Base_callSafeFunction", active_process=active_process).Base_joblibSimpleFunction(path)
return path
......@@ -54,7 +54,7 @@
</item>
<item>
<key> <string>id</string> </key>
<value> <string>Base_driverScript</string> </value>
<value> <string>Base_driverScriptSquareRoot</string> </value>
</item>
</dictionary>
</pickle>
......
......@@ -12,11 +12,11 @@
</item>
<item>
<key> <string>_module</string> </key>
<value> <string>joblibFunction</string> </value>
<value> <string>joblibSimpleFunction</string> </value>
</item>
<item>
<key> <string>id</string> </key>
<value> <string>Base_joblibFunction</string> </value>
<value> <string>Base_joblibSimpleFunction</string> </value>
</item>
<item>
<key> <string>title</string> </key>
......
##############################################################################
#
# Copyright (c) 2002-2012 Nexedi SA and Contributors. All Rights Reserved.
#
# WARNING: This program as such is intended to be used by professional
# programmers who take the whole responsibility of assessing all potential
# consequences resulting from its eventual inadequacies and bugs
# End users who are looking for a ready-to-use solution with commercial
# guarantees and support are strongly adviced to contract a Free Software
# Service Company
#
# This program is Free Software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
#
##############################################################################
from Products.ERP5Type.tests.ERP5TypeTestCase import ERP5TypeTestCase
class Test(ERP5TypeTestCase):
"""
Test joblib usecases with CMFActivity
"""
def getTitle(self):
return "TestJoblibUsecases"
def test_randomForest(self):
path = self.portal.Base_driverScriptRandomForest()
self.tic(1)
active_process = self.portal.portal_activities.unrestrictedTraverse(path)
result = active_process.getResult(123)
self.assertEquals(0.98444444444444446, result.result)
def test_UnderRootOfSquaresFunction(self):
path = self.portal.Base_driverScriptSquareRoot()
self.tic(1)
active_process = self.portal.portal_activities.unrestrictedTraverse(path)
result = active_process.getResult(12345)
self.assertEquals([0.0, 1.0, 2.0, 3.0, 4.0], result.result)
<?xml version="1.0"?>
<ZopeData>
<record id="1" aka="AAAAAAAAAAE=">
<pickle>
<global name="Test Component" module="erp5.portal_type"/>
</pickle>
<pickle>
<dictionary>
<item>
<key> <string>_recorded_property_dict</string> </key>
<value>
<persistent> <string encoding="base64">AAAAAAAAAAI=</string> </persistent>
</value>
</item>
<item>
<key> <string>default_reference</string> </key>
<value> <string>testJoblibActivityUsecases</string> </value>
</item>
<item>
<key> <string>description</string> </key>
<value>
<none/>
</value>
</item>
<item>
<key> <string>id</string> </key>
<value> <string>test.erp5.testJoblibActivityUsecases</string> </value>
</item>
<item>
<key> <string>portal_type</string> </key>
<value> <string>Test Component</string> </value>
</item>
<item>
<key> <string>sid</string> </key>
<value>
<none/>
</value>
</item>
<item>
<key> <string>text_content_error_message</string> </key>
<value>
<tuple/>
</value>
</item>
<item>
<key> <string>text_content_warning_message</string> </key>
<value>
<tuple/>
</value>
</item>
<item>
<key> <string>version</string> </key>
<value> <string>erp5</string> </value>
</item>
<item>
<key> <string>workflow_history</string> </key>
<value>
<persistent> <string encoding="base64">AAAAAAAAAAM=</string> </persistent>
</value>
</item>
</dictionary>
</pickle>
</record>
<record id="2" aka="AAAAAAAAAAI=">
<pickle>
<global name="PersistentMapping" module="Persistence.mapping"/>
</pickle>
<pickle>
<dictionary>
<item>
<key> <string>data</string> </key>
<value>
<dictionary/>
</value>
</item>
</dictionary>
</pickle>
</record>
<record id="3" aka="AAAAAAAAAAM=">
<pickle>
<global name="PersistentMapping" module="Persistence.mapping"/>
</pickle>
<pickle>
<dictionary>
<item>
<key> <string>data</string> </key>
<value>
<dictionary>
<item>
<key> <string>component_validation_workflow</string> </key>
<value>
<persistent> <string encoding="base64">AAAAAAAAAAQ=</string> </persistent>
</value>
</item>
</dictionary>
</value>
</item>
</dictionary>
</pickle>
</record>
<record id="4" aka="AAAAAAAAAAQ=">
<pickle>
<global name="WorkflowHistoryList" module="Products.ERP5Type.patches.WorkflowTool"/>
</pickle>
<pickle>
<tuple>
<none/>
<list>
<dictionary>
<item>
<key> <string>action</string> </key>
<value> <string>validate</string> </value>
</item>
<item>
<key> <string>validation_state</string> </key>
<value> <string>validated</string> </value>
</item>
</dictionary>
</list>
</tuple>
</pickle>
</record>
</ZopeData>
extension.erp5.joblibFunction
extension.erp5.joblibGridSearch
\ No newline at end of file
extension.erp5.joblibGridSearch
extension.erp5.joblibRandomForest
extension.erp5.joblibSimpleFunction
\ No newline at end of file
test.erp5.testJoblibActivityUsecases
\ No newline at end of file
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment