Commit 538dd366 authored by Julien Muchembled's avatar Julien Muchembled

CategoryTool._setCategoryMemberShip: bugfixes and cleanup

- duplicated default categories were dropped when the list of categories was
  passed without their base categories
- accept any iterable for 'category_list'
parent 257306cd
......@@ -30,7 +30,7 @@
ERP portal_categories tool.
from collections import deque
from OFS.Folder import Folder
from Products.CMFCore.utils import UniqueObject
from Products.ERP5Type.Globals import InitializeClass, DTMLFile
......@@ -633,84 +633,54 @@ class CategoryTool( UniqueObject, Folder, Base ):
category_list = (category_list, )
elif category_list is None:
category_list = ()
elif isinstance(category_list, (list, tuple)):
__traceback_info__ = (base_category_list, category_list)
raise TypeError('Category must be of string, tuple of string '
'or list of string type.')
if isinstance(base_category_list, str):
base_category_list = (base_category_list, )
# Build the ckecked_permission filter
if checked_permission is not None:
checkPermission = self.portal_membership.checkPermission
def permissionFilter(obj):
if checkPermission(checked_permission, obj):
return 0
return 1
new_category_list = []
default_dict = {}
spec_len = len(spec)
new_category_list = deque()
default_base_category_set = set()
default_category_set = set()
for path in self._getCategoryList(context):
my_base_id = self.getBaseCategoryId(path)
if my_base_id not in base_category_list:
# Keep each membership which is not in the
# specified list of base_category ids
keep_it = 0
if spec_len != 0 or (checked_permission is not None):
if my_base_id in base_category_list:
if spec or checked_permission is not None:
obj = self.unrestrictedTraverse(path, None)
if obj is not None:
if spec_len != 0:
# If spec is (), then we should keep nothing
# Everything will be replaced
# If spec is not (), Only keep this if not in our spec
my_type = obj.portal_type
keep_it = (my_type not in spec)
if (not keep_it) and (checked_permission is not None):
keep_it = permissionFilter(obj)
if keep_it:
if (spec and obj.portal_type not in spec) or not (
checked_permission is None or
checkPermission(checked_permission, obj)):
elif keep_default:
# We must remember the default value
# for each replaced category
if not default_dict.has_key(my_base_id):
default_dict[my_base_id] = path
# We now create a list of default category values
default_new_category_list = []
for path in default_dict.values():
if base or len(base_category_list) > 1:
if path in category_list:
# We must remember the default value for each replaced category
if keep_default and my_base_id not in default_base_category_set:
if path[len(base_category_list[0])+1:] in category_list:
# Keep each membership which is not in the
# specified list of base_category ids
# Before we append new category values (except default values)
# We must make sure however that multiple links are possible
default_path_found = {}
base = '' if base or len(base_category_list) > 1 \
else base_category_list[0] + '/'
for path in category_list:
if not path in ('', None):
if base or len(base_category_list) > 1:
# Only keep path which are member of base_category_list
if self.getBaseCategoryId(path) in base_category_list:
if path not in default_new_category_list or default_path_found.has_key(path):
default_path_found[path] = 1
if path not in ('', None):
if base:
path = base + path
elif self.getBaseCategoryId(path) not in base_category_list:
if path in default_category_set:
new_path = '%s/%s' % (base_category_list[0], path)
if new_path not in default_new_category_list:
# LOG("CategoryTool, setCategoryMembership", 0 ,
# 'new_category_list: %s' % str(new_category_list))
# LOG("CategoryTool, setCategoryMembership", 0 ,
# 'default_new_category_list: %s' % str(default_new_category_list))
self._setCategoryList(context, tuple(default_new_category_list + new_category_list))
self._setCategoryList(context, new_category_list)
security.declareProtected( Permissions.AccessContentsInformation, 'setDefaultCategoryMembership' )
......@@ -1075,6 +1075,27 @@ class TestCMFCategory(ERP5TypeTestCase):
# Check indexation
def test_setCategoryMemberShip(self):
person = self.getPersonModule().newContent(portal_type='Person')
category_tool = self.getCategoriesTool()
bc = category_tool.newContent()
base = ( + '/').__add__
def get(*args, **kw):
return category_tool.getCategoryMembershipList(person, *args, **kw)
def _set(*args, **kw):
return category_tool._setCategoryMembership(person, *args, **kw)
_set(, list('aa'))
self.assertEqual(get(, list('aa'))
_set(, list('baa'))
self.assertEqual(get(, list('aba'))
_set(, map(base, 'bb'), 1)
self.assertEqual(get(, list('bb'))
_set(, map(base, 'abb'), 1)
self.assertEqual(get(, list('bab'))
_set(, ())
def test_suite():
suite = unittest.TestSuite()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment