Commit b6e769cb authored by Jérome Perrin's avatar Jérome Perrin

bearer_token: py3

parent a94c38a4
import hashlib import hashlib
import hmac import hmac
from Products.ERP5Type.Cache import DEFAULT_CACHE_SCOPE from Products.ERP5Type.Cache import DEFAULT_CACHE_SCOPE
import six
CACHE_FACTORY_NAME = 'bearer_token_cache_factory' CACHE_FACTORY_NAME = 'bearer_token_cache_factory'
def getHMAC(self, key, body): def getHMAC(self, key, body):
# type: (bytes, bytes) -> str
digest = hmac.new(key, body, digestmod=hashlib.md5) digest = hmac.new(key, body, digestmod=hashlib.md5)
return digest.hexdigest() return digest.hexdigest()
...@@ -22,6 +24,10 @@ def _getCacheFactory(self, cache_factory_name): ...@@ -22,6 +24,10 @@ def _getCacheFactory(self, cache_factory_name):
return cache_tool.getRamCacheRoot().get(cache_factory_name) return cache_tool.getRamCacheRoot().get(cache_factory_name)
def setBearerToken(self, key, body, cache_factory_name=CACHE_FACTORY_NAME): def setBearerToken(self, key, body, cache_factory_name=CACHE_FACTORY_NAME):
# type: (str, Any, str) -> None
if not isinstance(key, six.string_types):
__traceback_info__ = key
raise TypeError("Wrong key type %s" % (type(key)))
cache_factory = _getCacheFactory(self, cache_factory_name) cache_factory = _getCacheFactory(self, cache_factory_name)
cache_duration = cache_factory.cache_duration cache_duration = cache_factory.cache_duration
for cache_plugin in cache_factory.getCachePluginList(): for cache_plugin in cache_factory.getCachePluginList():
...@@ -29,6 +35,8 @@ def setBearerToken(self, key, body, cache_factory_name=CACHE_FACTORY_NAME): ...@@ -29,6 +35,8 @@ def setBearerToken(self, key, body, cache_factory_name=CACHE_FACTORY_NAME):
body, cache_duration=cache_duration) body, cache_duration=cache_duration)
def getBearerToken(self, key, cache_factory_name=CACHE_FACTORY_NAME): def getBearerToken(self, key, cache_factory_name=CACHE_FACTORY_NAME):
# type: (str, Any, str) -> None
assert isinstance(key, six.string_types)
cache_factory = _getCacheFactory(self, cache_factory_name) cache_factory = _getCacheFactory(self, cache_factory_name)
for cache_plugin in cache_factory.getCachePluginList(): for cache_plugin in cache_factory.getCachePluginList():
cache_entry = cache_plugin.get(key, DEFAULT_CACHE_SCOPE) cache_entry = cache_plugin.get(key, DEFAULT_CACHE_SCOPE)
......
...@@ -5,9 +5,9 @@ except KeyError: ...@@ -5,9 +5,9 @@ except KeyError:
# not found # not found
return None return None
key = context.getPortalObject().portal_preferences.getPreferredBearerTokenKey() key = context.getPortalObject().portal_preferences.getPreferredBearerTokenKey().encode()
if context.Base_getHMAC(key, str(token_dict)) != token: if context.Base_getHMAC(key, str(token_dict).encode('utf-8')) != token:
# bizzare, not valid # bizzare, not valid
return None return None
......
if REQUEST is not None: if REQUEST is not None:
# mini security # mini security
return None return None
return context.getPortalObject().portal_preferences.getPreferredBearerTokenKey() return (context.getPortalObject().portal_preferences.getPreferredBearerTokenKey() or '').encode('utf-8')
...@@ -14,7 +14,7 @@ token = { ...@@ -14,7 +14,7 @@ token = {
'remote-addr': context.REQUEST.get('REMOTE_ADDR') 'remote-addr': context.REQUEST.get('REMOTE_ADDR')
} }
hmac = context.Base_getHMAC(key, str(token)) hmac = context.Base_getHMAC(key, str(token).encode('utf-8'))
context.Base_setBearerToken(hmac, token) context.Base_setBearerToken(hmac, token)
......
...@@ -108,7 +108,7 @@ class TestERP5BearerToken(ERP5TypeTestCase): ...@@ -108,7 +108,7 @@ class TestERP5BearerToken(ERP5TypeTestCase):
'remote-addr': self.portal.REQUEST.get('REMOTE_ADDR') 'remote-addr': self.portal.REQUEST.get('REMOTE_ADDR')
} }
hmac = self.portal.Base_getHMAC(self.portal.Base_getBearerTokenKey(), str( hmac = self.portal.Base_getHMAC(self.portal.Base_getBearerTokenKey(), str(
token)) token).encode('utf-8'))
self.portal.Base_setBearerToken(hmac, token) self.portal.Base_setBearerToken(hmac, token)
reference = self.getTokenCredential(self.portal.REQUEST) reference = self.getTokenCredential(self.portal.REQUEST)
self.assertEqual(reference, None) self.assertEqual(reference, None)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment