Commit 538dd366 authored by Julien Muchembled's avatar Julien Muchembled

CategoryTool._setCategoryMemberShip: bugfixes and cleanup

- duplicated default categories were dropped when the list of categories was
  passed without their base categories
- accept any iterable for 'category_list'
parent 257306cd
No related merge requests found
......@@ -30,7 +30,7 @@
"""\
ERP portal_categories tool.
"""
from collections import deque
from OFS.Folder import Folder
from Products.CMFCore.utils import UniqueObject
from Products.ERP5Type.Globals import InitializeClass, DTMLFile
......@@ -633,84 +633,54 @@ class CategoryTool( UniqueObject, Folder, Base ):
category_list = (category_list, )
elif category_list is None:
category_list = ()
elif isinstance(category_list, (list, tuple)):
pass
else:
__traceback_info__ = (base_category_list, category_list)
raise TypeError('Category must be of string, tuple of string '
'or list of string type.')
if isinstance(base_category_list, str):
base_category_list = (base_category_list, )
# Build the ckecked_permission filter
if checked_permission is not None:
checkPermission = self.portal_membership.checkPermission
def permissionFilter(obj):
if checkPermission(checked_permission, obj):
return 0
else:
return 1
new_category_list = []
default_dict = {}
spec_len = len(spec)
new_category_list = deque()
default_base_category_set = set()
default_category_set = set()
for path in self._getCategoryList(context):
my_base_id = self.getBaseCategoryId(path)
if my_base_id not in base_category_list:
# Keep each membership which is not in the
# specified list of base_category ids
new_category_list.append(path)
else:
keep_it = 0
if spec_len != 0 or (checked_permission is not None):
if my_base_id in base_category_list:
if spec or checked_permission is not None:
obj = self.unrestrictedTraverse(path, None)
if obj is not None:
if spec_len != 0:
# If spec is (), then we should keep nothing
# Everything will be replaced
# If spec is not (), Only keep this if not in our spec
my_type = obj.portal_type
keep_it = (my_type not in spec)
if (not keep_it) and (checked_permission is not None):
keep_it = permissionFilter(obj)
if keep_it:
new_category_list.append(path)
elif keep_default:
# We must remember the default value
# for each replaced category
if not default_dict.has_key(my_base_id):
default_dict[my_base_id] = path
# We now create a list of default category values
default_new_category_list = []
for path in default_dict.values():
if base or len(base_category_list) > 1:
if path in category_list:
default_new_category_list.append(path)
# If spec is (), then we should keep nothing
# Everything will be replaced
# If spec is not (), Only keep this if not in our spec
if (spec and obj.portal_type not in spec) or not (
checked_permission is None or
checkPermission(checked_permission, obj)):
new_category_list.append(path)
continue
# We must remember the default value for each replaced category
if keep_default and my_base_id not in default_base_category_set:
default_base_category_set.add(my_base_id)
default_category_set.add(path)
else:
if path[len(base_category_list[0])+1:] in category_list:
default_new_category_list.append(path)
# Keep each membership which is not in the
# specified list of base_category ids
new_category_list.append(path)
# Before we append new category values (except default values)
# We must make sure however that multiple links are possible
default_path_found = {}
base = '' if base or len(base_category_list) > 1 \
else base_category_list[0] + '/'
for path in category_list:
if not path in ('', None):
if base or len(base_category_list) > 1:
# Only keep path which are member of base_category_list
if self.getBaseCategoryId(path) in base_category_list:
if path not in default_new_category_list or default_path_found.has_key(path):
default_path_found[path] = 1
new_category_list.append(path)
if path not in ('', None):
if base:
path = base + path
elif self.getBaseCategoryId(path) not in base_category_list:
continue
if path in default_category_set:
default_category_set.remove(path)
new_category_list.appendleft(path)
else:
new_path = '%s/%s' % (base_category_list[0], path)
if new_path not in default_new_category_list:
new_category_list.append(new_path)
# LOG("CategoryTool, setCategoryMembership", 0 ,
# 'new_category_list: %s' % str(new_category_list))
# LOG("CategoryTool, setCategoryMembership", 0 ,
# 'default_new_category_list: %s' % str(default_new_category_list))
self._setCategoryList(context, tuple(default_new_category_list + new_category_list))
new_category_list.append(path)
self._setCategoryList(context, new_category_list)
security.declareProtected( Permissions.AccessContentsInformation, 'setDefaultCategoryMembership' )
......
......@@ -1075,6 +1075,27 @@ class TestCMFCategory(ERP5TypeTestCase):
# Check indexation
self.tic()
def test_setCategoryMemberShip(self):
person = self.getPersonModule().newContent(portal_type='Person')
category_tool = self.getCategoriesTool()
bc = category_tool.newContent()
bc.newContent('a')
bc.newContent('b')
base = (bc.id + '/').__add__
def get(*args, **kw):
return category_tool.getCategoryMembershipList(person, *args, **kw)
def _set(*args, **kw):
return category_tool._setCategoryMembership(person, *args, **kw)
_set(bc.id, list('aa'))
self.assertEqual(get(bc.id), list('aa'))
_set(bc.id, list('baa'))
self.assertEqual(get(bc.id), list('aba'))
_set(bc.id, map(base, 'bb'), 1)
self.assertEqual(get(bc.id), list('bb'))
_set(bc.id, map(base, 'abb'), 1)
self.assertEqual(get(bc.id), list('bab'))
_set(bc.id, ())
def test_suite():
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(TestCMFCategory))
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment