Commit 22206337 authored by Martijn Pieters's avatar Martijn Pieters

Big change

- Make DTML automatically html quote data indirectly taken from REQUEST
  which contain a '<'. Make sure (almost) all string operation preserve the
  taint on this data.

- Fix exceptions that use REQUEST data; quote the data.

- Don't let form and cookie values mask the REQUEST computed values such as
  URL0 and BASE1.
parent 3657c47f
...@@ -77,9 +77,17 @@ Zope Changes ...@@ -77,9 +77,17 @@ Zope Changes
- FileLibrary and GuestBook example applications gave anonymous - FileLibrary and GuestBook example applications gave anonymous
users the Manager proxy role when uploading files - a potential users the Manager proxy role when uploading files - a potential
vulnerability on production servers. vulnerability on production servers.
- Exceptions that use untrusted information from a REQUEST object in
the exception message now html-quote that information.
Features Added Features Added
- <dtml-var name> and &dtml.-name; will now automatically HTML-quote
unsafe data taken implictly from the REQUEST object. Data taken
explicitly from the REQUEST object is not affected, as well as any
other data not originating from REQUEST.
- ZCatalog index management ui is now integrated into ZCatalog rather - ZCatalog index management ui is now integrated into ZCatalog rather
than being a subobject managment screen with different tabs. than being a subobject managment screen with different tabs.
......
...@@ -21,6 +21,7 @@ import ExtensionClass, Acquisition ...@@ -21,6 +21,7 @@ import ExtensionClass, Acquisition
from Permission import pname from Permission import pname
from Owned import UnownableOwner from Owned import UnownableOwner
from Globals import InitializeClass from Globals import InitializeClass
from cgi import escape
class RoleManager: class RoleManager:
def manage_getPermissionMapping(self): def manage_getPermissionMapping(self):
...@@ -64,7 +65,7 @@ class RoleManager: ...@@ -64,7 +65,7 @@ class RoleManager:
raise 'Permission mapping error', ( raise 'Permission mapping error', (
"""Attempted to map a permission to a permission, %s, """Attempted to map a permission to a permission, %s,
that is not valid. This should never happen. (Waaa). that is not valid. This should never happen. (Waaa).
""" % p) """ % escape(p))
setPermissionMapping(name, wrapper, p) setPermissionMapping(name, wrapper, p)
...@@ -118,7 +119,7 @@ class PM(ExtensionClass.Base): ...@@ -118,7 +119,7 @@ class PM(ExtensionClass.Base):
# We want to make sure that any non-explicitly set methods are # We want to make sure that any non-explicitly set methods are
# private! # private!
if name.startswith('_') and name.endswith("_Permission"): return '' if name.startswith('_') and name.endswith("_Permission"): return ''
raise AttributeError, name raise AttributeError, escape(name)
PermissionMapper=PM PermissionMapper=PM
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
############################################################################## ##############################################################################
"""Access control support""" """Access control support"""
__version__='$Revision: 1.53 $'[11:-2] __version__='$Revision: 1.54 $'[11:-2]
from Globals import DTMLFile, MessageDialog, Dictionary from Globals import DTMLFile, MessageDialog, Dictionary
...@@ -20,6 +20,7 @@ from Acquisition import Implicit, Acquired, aq_get ...@@ -20,6 +20,7 @@ from Acquisition import Implicit, Acquired, aq_get
import Globals, ExtensionClass, PermissionMapping, Products import Globals, ExtensionClass, PermissionMapping, Products
from Permission import Permission from Permission import Permission
from App.Common import aq_base from App.Common import aq_base
from cgi import escape
ListType=type([]) ListType=type([])
...@@ -171,7 +172,8 @@ class RoleManager(ExtensionClass.Base, PermissionMapping.RoleManager): ...@@ -171,7 +172,8 @@ class RoleManager(ExtensionClass.Base, PermissionMapping.RoleManager):
return return
raise 'Invalid Permission', ( raise 'Invalid Permission', (
"The permission <em>%s</em> is invalid." % permission_to_manage) "The permission <em>%s</em> is invalid." %
escape(permission_to_manage))
_normal_manage_access=DTMLFile('dtml/access', globals()) _normal_manage_access=DTMLFile('dtml/access', globals())
...@@ -244,7 +246,7 @@ class RoleManager(ExtensionClass.Base, PermissionMapping.RoleManager): ...@@ -244,7 +246,7 @@ class RoleManager(ExtensionClass.Base, PermissionMapping.RoleManager):
valid_roles) valid_roles)
raise 'Invalid Permission', ( raise 'Invalid Permission', (
"The permission <em>%s</em> is invalid." % permission) "The permission <em>%s</em> is invalid." % escape(permission))
def acquiredRolesAreUsedBy(self, permission): def acquiredRolesAreUsedBy(self, permission):
"used by management screen" "used by management screen"
...@@ -256,7 +258,7 @@ class RoleManager(ExtensionClass.Base, PermissionMapping.RoleManager): ...@@ -256,7 +258,7 @@ class RoleManager(ExtensionClass.Base, PermissionMapping.RoleManager):
return type(roles) is ListType and 'CHECKED' or '' return type(roles) is ListType and 'CHECKED' or ''
raise 'Invalid Permission', ( raise 'Invalid Permission', (
"The permission <em>%s</em> is invalid." % permission) "The permission <em>%s</em> is invalid." % escape(permission))
# Local roles support # Local roles support
......
...@@ -38,6 +38,7 @@ import Globals, OFS.Folder, OFS.SimpleItem, os, Acquisition, Products ...@@ -38,6 +38,7 @@ import Globals, OFS.Folder, OFS.SimpleItem, os, Acquisition, Products
import re, zlib, Globals, cPickle, marshal, rotor import re, zlib, Globals, cPickle, marshal, rotor
import ZClasses, ZClasses.ZClass, AccessControl.Owned import ZClasses, ZClasses.ZClass, AccessControl.Owned
from urllib import quote from urllib import quote
from cgi import escape
from OFS.Folder import Folder from OFS.Folder import Folder
from Factory import Factory from Factory import Factory
...@@ -254,14 +255,14 @@ class Product(Folder, PermissionManager): ...@@ -254,14 +255,14 @@ class Product(Folder, PermissionManager):
"Product Distributions" "Product Distributions"
def __bobo_traverse__(self, REQUEST, name): def __bobo_traverse__(self, REQUEST, name):
if name[-7:] != '.tar.gz': raise 'Invalid Name', name if name[-7:] != '.tar.gz': raise 'Invalid Name', escape(name)
l=name.find('-') l=name.find('-')
id, version = name[:l], name[l+1:-7] id, version = name[:l], name[l+1:-7]
product=self.aq_parent product=self.aq_parent
if product.id==id and product.version==version: if product.id==id and product.version==version:
return Distribution(product) return Distribution(product)
raise 'Invalid version or product id', name raise 'Invalid version or product id', escape(name)
Distributions=Distributions() Distributions=Distributions()
......
...@@ -142,6 +142,8 @@ class DTMLFile(Bindings, Explicit, ClassicHTMLFile): ...@@ -142,6 +142,8 @@ class DTMLFile(Bindings, Explicit, ClassicHTMLFile):
# We're first, so get the REQUEST. # We're first, so get the REQUEST.
try: try:
req = self.aq_acquire('REQUEST') req = self.aq_acquire('REQUEST')
if hasattr(req, 'taintWrapper'):
req = req.taintWrapper()
except: pass except: pass
bound_data['REQUEST'] = req bound_data['REQUEST'] = req
ns.this = bound_data['context'] ns.this = bound_data['context']
......
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
# FOR A PARTICULAR PURPOSE # FOR A PARTICULAR PURPOSE
# #
############################################################################## ##############################################################################
"$Id: DT_String.py,v 1.49 2001/11/28 15:50:55 matt Exp $" "$Id: DT_String.py,v 1.50 2002/08/01 16:00:39 mj Exp $"
import thread,re,exceptions,os import thread,re,exceptions,os
...@@ -404,6 +404,7 @@ class String: ...@@ -404,6 +404,7 @@ class String:
# print '============================================================' # print '============================================================'
if mapping is None: mapping = {} if mapping is None: mapping = {}
if hasattr(mapping, 'taintWrapper'): mapping = mapping.taintWrapper()
if not hasattr(self,'_v_cooked'): if not hasattr(self,'_v_cooked'):
try: changed=self.__changed__() try: changed=self.__changed__()
......
...@@ -10,8 +10,8 @@ ...@@ -10,8 +10,8 @@
# FOR A PARTICULAR PURPOSE # FOR A PARTICULAR PURPOSE
# #
############################################################################## ##############################################################################
'''$Id: DT_Util.py,v 1.86 2002/03/27 10:14:02 htrd Exp $''' '''$Id: DT_Util.py,v 1.87 2002/08/01 16:00:39 mj Exp $'''
__version__='$Revision: 1.86 $'[11:-2] __version__='$Revision: 1.87 $'[11:-2]
import re, os import re, os
from html_quote import html_quote, ustr # for import by other modules, dont remove! from html_quote import html_quote, ustr # for import by other modules, dont remove!
...@@ -67,6 +67,49 @@ if LIMITED_BUILTINS: ...@@ -67,6 +67,49 @@ if LIMITED_BUILTINS:
else: else:
d[name] = f d[name] = f
try:
# Wrap the string module so it can deal with TaintedString strings.
from ZPublisher.TaintedString import TaintedString
from types import FunctionType, BuiltinFunctionType, StringType
import string
class StringModuleWrapper:
def __getattr__(self, key):
attr = getattr(string, key)
if (isinstance(attr, FunctionType) or
isinstance(attr, BuiltinFunctionType)):
return StringFunctionWrapper(attr)
else:
return attr
class StringFunctionWrapper:
def __init__(self, method):
self._method = method
def __call__(self, *args, **kw):
tainted = 0
args = list(args)
for i in range(len(args)):
if isinstance(args[i], TaintedString):
tainted = 1
args[i] = str(args[i])
for k, v in kw.items():
if isinstance(v, TaintedString):
tainted = 1
kw[k] = str(v)
args = tuple(args)
retval = self._method(*args, **kw)
if tainted and isinstance(retval, StringType) and '<' in retval:
retval = TaintedString(retval)
return retval
d['string'] = StringModuleWrapper()
except ImportError:
# Use the string module already defined in RestrictedPython.Utilities
pass
# The functions below are meant to bind to the TemplateDict. # The functions below are meant to bind to the TemplateDict.
_marker = [] # Create a new marker object. _marker = [] # Create a new marker object.
......
...@@ -145,8 +145,8 @@ Evaluating expressions without rendering results ...@@ -145,8 +145,8 @@ Evaluating expressions without rendering results
''' # ' ''' # '
__rcs_id__='$Id: DT_Var.py,v 1.53 2002/05/21 14:41:41 andreasjung Exp $' __rcs_id__='$Id: DT_Var.py,v 1.54 2002/08/01 16:00:39 mj Exp $'
__version__='$Revision: 1.53 $'[11:-2] __version__='$Revision: 1.54 $'[11:-2]
from DT_Util import parse_params, name_param, str, ustr from DT_Util import parse_params, name_param, str, ustr
import os, string, re, sys import os, string, re, sys
...@@ -155,6 +155,7 @@ from cgi import escape ...@@ -155,6 +155,7 @@ from cgi import escape
from html_quote import html_quote # for import by other modules, dont remove! from html_quote import html_quote # for import by other modules, dont remove!
from types import StringType from types import StringType
from Acquisition import aq_base from Acquisition import aq_base
from ZPublisher.TaintedString import TaintedString
class Var: class Var:
name='var' name='var'
...@@ -232,9 +233,19 @@ class Var: ...@@ -232,9 +233,19 @@ class Var:
if hasattr(val, fmt): if hasattr(val, fmt):
val = _get(val, fmt)() val = _get(val, fmt)()
elif special_formats.has_key(fmt): elif special_formats.has_key(fmt):
val = special_formats[fmt](val, name, md) if fmt == 'html-quote' and \
isinstance(val, TaintedString):
# TaintedStrings will be quoted by default, don't
# double quote.
pass
else:
val = special_formats[fmt](val, name, md)
elif fmt=='': val='' elif fmt=='': val=''
else: val = fmt % val else:
if isinstance(val, TaintedString):
val = TaintedString(fmt % val)
else:
val = fmt % val
except: except:
t, v= sys.exc_type, sys.exc_value t, v= sys.exc_type, sys.exc_value
if hasattr(sys, 'exc_info'): t, v = sys.exc_info()[:2] if hasattr(sys, 'exc_info'): t, v = sys.exc_info()[:2]
...@@ -247,17 +258,40 @@ class Var: ...@@ -247,17 +258,40 @@ class Var:
if hasattr(val, fmt): if hasattr(val, fmt):
val = _get(val, fmt)() val = _get(val, fmt)()
elif special_formats.has_key(fmt): elif special_formats.has_key(fmt):
val = special_formats[fmt](val, name, md) if fmt == 'html-quote' and \
isinstance(val, TaintedString):
# TaintedStrings will be quoted by default, don't
# double quote.
pass
else:
val = special_formats[fmt](val, name, md)
elif fmt=='': val='' elif fmt=='': val=''
else: val = fmt % val else:
if isinstance(val, TaintedString):
val = TaintedString(fmt % val)
else:
val = fmt % val
# finally, pump it through the actual string format... # finally, pump it through the actual string format...
fmt=self.fmt fmt=self.fmt
if fmt=='s': val=ustr(val) if fmt=='s':
else: val = ('%'+self.fmt) % (val,) # Keep tainted strings as tainted strings here.
if not isinstance(val, TaintedString):
val=str(val)
else:
# Keep tainted strings as tainted strings here.
wastainted = 0
if isinstance(val, TaintedString): wastainted = 1
val = ('%'+self.fmt) % (val,)
if wastainted and '<' in val:
val = TaintedString(val)
# next, look for upper, lower, etc # next, look for upper, lower, etc
for f in self.modifiers: val=f(val) for f in self.modifiers:
if f.__name__ == 'html_quote' and isinstance(val, TaintedString):
# TaintedStrings will be quoted by default, don't double quote.
continue
val=f(val)
if have_arg('size'): if have_arg('size'):
size=args['size'] size=args['size']
...@@ -274,6 +308,9 @@ class Var: ...@@ -274,6 +308,9 @@ class Var:
else: l='...' else: l='...'
val=val+l val=val+l
if isinstance(val, TaintedString):
val = val.quoted()
return val return val
__call__=render __call__=render
...@@ -298,6 +335,9 @@ def url_quote_plus(v, name='(Unknown name)', md={}): ...@@ -298,6 +335,9 @@ def url_quote_plus(v, name='(Unknown name)', md={}):
def newline_to_br(v, name='(Unknown name)', md={}): def newline_to_br(v, name='(Unknown name)', md={}):
# Unsafe data is explicitly quoted here; we don't expect this to be HTML
# quoted later on anyway.
if isinstance(v, TaintedString): v = v.quoted()
v=str(v) v=str(v)
if v.find('\r') >= 0: v=''.join(v.split('\r')) if v.find('\r') >= 0: v=''.join(v.split('\r'))
if v.find('\n') >= 0: v='<br />\n'.join(v.split('\n')) if v.find('\n') >= 0: v='<br />\n'.join(v.split('\n'))
...@@ -368,7 +408,7 @@ def sql_quote(v, name='(Unknown name)', md={}): ...@@ -368,7 +408,7 @@ def sql_quote(v, name='(Unknown name)', md={}):
This is needed to securely insert values into sql This is needed to securely insert values into sql
string literals in templates that generate sql. string literals in templates that generate sql.
""" """
if v.find("'") >= 0: return "''".join(v.split("'")) if v.find("'") >= 0: return v.replace("'", "''")
return v return v
special_formats={ special_formats={
...@@ -389,7 +429,7 @@ special_formats={ ...@@ -389,7 +429,7 @@ special_formats={
} }
def spacify(val): def spacify(val):
if val.find('_') >= 0: val=" ".join(val.split('_')) if val.find('_') >= 0: val=val.replace('_', ' ')
return val return val
modifiers=(html_quote, url_quote, url_quote_plus, newline_to_br, modifiers=(html_quote, url_quote, url_quote_plus, newline_to_br,
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
****************************************************************************/ ****************************************************************************/
static char cDocumentTemplate_module_documentation[] = static char cDocumentTemplate_module_documentation[] =
"" ""
"\n$Id: cDocumentTemplate.c,v 1.46 2002/05/09 13:19:17 htrd Exp $" "\n$Id: cDocumentTemplate.c,v 1.47 2002/08/01 16:00:39 mj Exp $"
; ;
#include "ExtensionClass.h" #include "ExtensionClass.h"
...@@ -22,7 +22,7 @@ static PyObject *py___call__, *py___roles__, *py_AUTHENTICATED_USER; ...@@ -22,7 +22,7 @@ static PyObject *py___call__, *py___roles__, *py_AUTHENTICATED_USER;
static PyObject *py_hasRole, *py__proxy_roles, *py_Unauthorized; static PyObject *py_hasRole, *py__proxy_roles, *py_Unauthorized;
static PyObject *py_Unauthorized_fmt, *py_guarded_getattr; static PyObject *py_Unauthorized_fmt, *py_guarded_getattr;
static PyObject *py__push, *py__pop, *py_aq_base, *py_renderNS; static PyObject *py__push, *py__pop, *py_aq_base, *py_renderNS;
static PyObject *py___class__, *html_quote, *ustr; static PyObject *py___class__, *html_quote, *ustr, *untaint_name;
/* ----------------------------------------------------- */ /* ----------------------------------------------------- */
...@@ -665,6 +665,7 @@ render_blocks_(PyObject *blocks, PyObject *rendered, ...@@ -665,6 +665,7 @@ render_blocks_(PyObject *blocks, PyObject *rendered,
{ {
PyObject *block, *t, *args; PyObject *block, *t, *args;
int l, i, k=0, append; int l, i, k=0, append;
int skip_html_quote = 0;
if ((l=PyList_Size(blocks)) < 0) return -1; if ((l=PyList_Size(blocks)) < 0) return -1;
for (i=0; i < l; i++) for (i=0; i < l; i++)
...@@ -689,6 +690,23 @@ render_blocks_(PyObject *blocks, PyObject *rendered, ...@@ -689,6 +690,23 @@ render_blocks_(PyObject *blocks, PyObject *rendered,
if (t == NULL) return -1; if (t == NULL) return -1;
if (! ( PyString_Check(t) || PyUnicode_Check(t) ) )
{
/* This might be a TaintedString object */
PyObject *untaintmethod = NULL;
untaintmethod = PyObject_GetAttr(t, untaint_name);
if (untaintmethod) {
/* Quote it */
UNLESS_ASSIGN(t,
PyObject_CallObject(untaintmethod, NULL)) return -1;
skip_html_quote = 1;
} else PyErr_Clear();
Py_XDECREF(untaintmethod);
}
if (! ( PyString_Check(t) || PyUnicode_Check(t) ) ) if (! ( PyString_Check(t) || PyUnicode_Check(t) ) )
{ {
args = PyTuple_New(1); args = PyTuple_New(1);
...@@ -700,9 +718,9 @@ render_blocks_(PyObject *blocks, PyObject *rendered, ...@@ -700,9 +718,9 @@ render_blocks_(PyObject *blocks, PyObject *rendered,
UNLESS(t) return -1; UNLESS(t) return -1;
} }
if (PyTuple_GET_SIZE(block) == 3) /* html_quote */ if (skip_html_quote == 0 && PyTuple_GET_SIZE(block) == 3)
/* html_quote */
{ {
int skip_html_quote;
if (PyString_Check(t)) if (PyString_Check(t))
{ {
if (strchr(PyString_AS_STRING(t), '&') || if (strchr(PyString_AS_STRING(t), '&') ||
...@@ -961,6 +979,7 @@ initcDocumentTemplate(void) ...@@ -961,6 +979,7 @@ initcDocumentTemplate(void)
UNLESS(py_isDocTemp=PyString_FromString("isDocTemp")) return; UNLESS(py_isDocTemp=PyString_FromString("isDocTemp")) return;
UNLESS(py_renderNS=PyString_FromString("__render_with_namespace__")) return; UNLESS(py_renderNS=PyString_FromString("__render_with_namespace__")) return;
UNLESS(py_blocks=PyString_FromString("blocks")) return; UNLESS(py_blocks=PyString_FromString("blocks")) return;
UNLESS(untaint_name=PyString_FromString("__untaint__")) return;
UNLESS(py_acquire=PyString_FromString("aq_acquire")) return; UNLESS(py_acquire=PyString_FromString("aq_acquire")) return;
UNLESS(py___call__=PyString_FromString("__call__")) return; UNLESS(py___call__=PyString_FromString("__call__")) return;
UNLESS(py___roles__=PyString_FromString("__roles__")) return; UNLESS(py___roles__=PyString_FromString("__roles__")) return;
......
...@@ -12,9 +12,9 @@ ...@@ -12,9 +12,9 @@
############################################################################## ##############################################################################
__doc__="""Object Manager __doc__="""Object Manager
$Id: ObjectManager.py,v 1.155 2002/07/09 15:14:51 zigg Exp $""" $Id: ObjectManager.py,v 1.156 2002/08/01 16:00:39 mj Exp $"""
__version__='$Revision: 1.155 $'[11:-2] __version__='$Revision: 1.156 $'[11:-2]
import App.Management, Acquisition, Globals, CopySupport, Products import App.Management, Acquisition, Globals, CopySupport, Products
import os, App.FactoryDispatcher, re, Products import os, App.FactoryDispatcher, re, Products
...@@ -34,6 +34,8 @@ import App.Common ...@@ -34,6 +34,8 @@ import App.Common
from AccessControl import getSecurityManager from AccessControl import getSecurityManager
from zLOG import LOG, ERROR from zLOG import LOG, ERROR
import sys,fnmatch,copy import sys,fnmatch,copy
from cgi import escape
from types import StringType, UnicodeType
import XMLExportImport import XMLExportImport
customImporters={ customImporters={
...@@ -51,11 +53,12 @@ def checkValidId(self, id, allow_dup=0): ...@@ -51,11 +53,12 @@ def checkValidId(self, id, allow_dup=0):
# check_valid_id() will be called again later with allow_dup # check_valid_id() will be called again later with allow_dup
# set to false before the object is added. # set to false before the object is added.
if not id or (type(id) != type('')): if not id or not isinstance(id, StringType):
if isinstance(id, UnicodeType): id = escape(id)
raise BadRequestException, ('Empty or invalid id specified', id) raise BadRequestException, ('Empty or invalid id specified', id)
if bad_id(id) is not None: if bad_id(id) is not None:
raise BadRequestException, ( raise BadRequestException, (
'The id "%s" contains characters illegal in URLs.' % id) 'The id "%s" contains characters illegal in URLs.' % escape(id))
if id[0]=='_': raise BadRequestException, ( if id[0]=='_': raise BadRequestException, (
'The id "%s" is invalid - it begins with an underscore.' % id) 'The id "%s" is invalid - it begins with an underscore.' % id)
if id[:3]=='aq_': raise BadRequestException, ( if id[:3]=='aq_': raise BadRequestException, (
...@@ -434,13 +437,13 @@ class ObjectManager( ...@@ -434,13 +437,13 @@ class ObjectManager(
for n in ids: for n in ids:
if n in p: if n in p:
return MessageDialog(title='Not Deletable', return MessageDialog(title='Not Deletable',
message='<EM>%s</EM> cannot be deleted.' % n, message='<EM>%s</EM> cannot be deleted.' % escape(n),
action ='./manage_main',) action ='./manage_main',)
while ids: while ids:
id=ids[-1] id=ids[-1]
v=self._getOb(id, self) v=self._getOb(id, self)
if v is self: if v is self:
raise 'BadRequest', '%s does not exist' % ids[-1] raise 'BadRequest', '%s does not exist' % escape(ids[-1])
self._delObject(id) self._delObject(id)
del ids[-1] del ids[-1]
if REQUEST is not None: if REQUEST is not None:
...@@ -511,7 +514,7 @@ class ObjectManager( ...@@ -511,7 +514,7 @@ class ObjectManager(
"""Import an object from a file""" """Import an object from a file"""
dirname, file=os.path.split(file) dirname, file=os.path.split(file)
if dirname: if dirname:
raise BadRequestException, 'Invalid file name %s' % file raise BadRequestException, 'Invalid file name %s' % escape(file)
instance_home = INSTANCE_HOME instance_home = INSTANCE_HOME
zope_home = ZOPE_HOME zope_home = ZOPE_HOME
...@@ -521,7 +524,7 @@ class ObjectManager( ...@@ -521,7 +524,7 @@ class ObjectManager(
if os.path.exists(filepath): if os.path.exists(filepath):
break break
else: else:
raise BadRequestException, 'File does not exist: %s' % file raise BadRequestException, 'File does not exist: %s' % escape(file)
self._importObjectFromFile(filepath, verify=not not REQUEST, self._importObjectFromFile(filepath, verify=not not REQUEST,
set_owner=set_owner) set_owner=set_owner)
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
############################################################################## ##############################################################################
"""Property management""" """Property management"""
__version__='$Revision: 1.44 $'[11:-2] __version__='$Revision: 1.45 $'[11:-2]
import ExtensionClass, Globals import ExtensionClass, Globals
import ZDOM import ZDOM
...@@ -21,6 +21,7 @@ from ZPublisher.Converters import type_converters ...@@ -21,6 +21,7 @@ from ZPublisher.Converters import type_converters
from Globals import DTMLFile, MessageDialog from Globals import DTMLFile, MessageDialog
from Acquisition import Implicit, aq_base from Acquisition import Implicit, aq_base
from Globals import Persistent from Globals import Persistent
from cgi import escape
...@@ -121,7 +122,7 @@ class PropertyManager(ExtensionClass.Base, ZDOM.ElementWithAttributes): ...@@ -121,7 +122,7 @@ class PropertyManager(ExtensionClass.Base, ZDOM.ElementWithAttributes):
def valid_property_id(self, id): def valid_property_id(self, id):
if not id or id[:1]=='_' or (id[:3]=='aq_') \ if not id or id[:1]=='_' or (id[:3]=='aq_') \
or (' ' in id) or hasattr(aq_base(self), id): or (' ' in id) or hasattr(aq_base(self), id) or escape(id) != id:
return 0 return 0
return 1 return 1
...@@ -188,7 +189,7 @@ class PropertyManager(ExtensionClass.Base, ZDOM.ElementWithAttributes): ...@@ -188,7 +189,7 @@ class PropertyManager(ExtensionClass.Base, ZDOM.ElementWithAttributes):
# the value to the type of the existing property. # the value to the type of the existing property.
self._wrapperCheck(value) self._wrapperCheck(value)
if not self.hasProperty(id): if not self.hasProperty(id):
raise 'Bad Request', 'The property %s does not exist' % id raise 'Bad Request', 'The property %s does not exist' % escape(id)
if type(value)==type(''): if type(value)==type(''):
proptype=self.getPropertyType(id) or 'string' proptype=self.getPropertyType(id) or 'string'
if type_converters.has_key(proptype): if type_converters.has_key(proptype):
...@@ -197,7 +198,7 @@ class PropertyManager(ExtensionClass.Base, ZDOM.ElementWithAttributes): ...@@ -197,7 +198,7 @@ class PropertyManager(ExtensionClass.Base, ZDOM.ElementWithAttributes):
def _delProperty(self, id): def _delProperty(self, id):
if not self.hasProperty(id): if not self.hasProperty(id):
raise ValueError, 'The property %s does not exist' % id raise ValueError, 'The property %s does not exist' % escape(id)
delattr(self,id) delattr(self,id)
self._properties=tuple(filter(lambda i, n=id: i['id'] != n, self._properties=tuple(filter(lambda i, n=id: i['id'] != n,
self._properties)) self._properties))
...@@ -281,7 +282,7 @@ class PropertyManager(ExtensionClass.Base, ZDOM.ElementWithAttributes): ...@@ -281,7 +282,7 @@ class PropertyManager(ExtensionClass.Base, ZDOM.ElementWithAttributes):
for name, value in props.items(): for name, value in props.items():
if self.hasProperty(name): if self.hasProperty(name):
if not 'w' in propdict[name].get('mode', 'wd'): if not 'w' in propdict[name].get('mode', 'wd'):
raise 'BadRequest', '%s cannot be changed' % name raise 'BadRequest', '%s cannot be changed' % escape(name)
self._updateProperty(name, value) self._updateProperty(name, value)
if REQUEST: if REQUEST:
message="Saved changes." message="Saved changes."
...@@ -324,7 +325,7 @@ class PropertyManager(ExtensionClass.Base, ZDOM.ElementWithAttributes): ...@@ -324,7 +325,7 @@ class PropertyManager(ExtensionClass.Base, ZDOM.ElementWithAttributes):
for id in ids: for id in ids:
if not hasattr(aq_base(self), id): if not hasattr(aq_base(self), id):
raise 'BadRequest', ( raise 'BadRequest', (
'The property <em>%s</em> does not exist' % id) 'The property <em>%s</em> does not exist' % escape(id))
if (not 'd' in propdict[id].get('mode', 'wd')) or (id in nd): if (not 'd' in propdict[id].get('mode', 'wd')) or (id in nd):
return MessageDialog( return MessageDialog(
title ='Cannot delete %s' % id, title ='Cannot delete %s' % id,
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
############################################################################## ##############################################################################
"""Property sheets""" """Property sheets"""
__version__='$Revision: 1.84 $'[11:-2] __version__='$Revision: 1.85 $'[11:-2]
import time, App.Management, Globals import time, App.Management, Globals
from webdav.WriteLockInterface import WriteLockInterface from webdav.WriteLockInterface import WriteLockInterface
...@@ -26,6 +26,7 @@ from Globals import Persistent ...@@ -26,6 +26,7 @@ from Globals import Persistent
from Traversable import Traversable from Traversable import Traversable
from Acquisition import aq_base from Acquisition import aq_base
from AccessControl import getSecurityManager from AccessControl import getSecurityManager
from cgi import escape
class View(App.Management.Tabs, Base): class View(App.Management.Tabs, Base):
"""A view of an object, typically used for management purposes """A view of an object, typically used for management purposes
...@@ -141,7 +142,7 @@ class PropertySheet(Traversable, Persistent, Implicit): ...@@ -141,7 +142,7 @@ class PropertySheet(Traversable, Persistent, Implicit):
def valid_property_id(self, id): def valid_property_id(self, id):
if not id or id[:1]=='_' or (id[:3]=='aq_') \ if not id or id[:1]=='_' or (id[:3]=='aq_') \
or (' ' in id): or (' ' in id) or escape(id) != id:
return 0 return 0
return 1 return 1
...@@ -180,7 +181,7 @@ class PropertySheet(Traversable, Persistent, Implicit): ...@@ -180,7 +181,7 @@ class PropertySheet(Traversable, Persistent, Implicit):
# systems. # systems.
self._wrapperCheck(value) self._wrapperCheck(value)
if not self.valid_property_id(id): if not self.valid_property_id(id):
raise 'Bad Request', 'Invalid property id, %s.' % id raise 'Bad Request', 'Invalid property id, %s.' % escape(id)
if not self.property_extensible_schema__(): if not self.property_extensible_schema__():
raise 'Bad Request', ( raise 'Bad Request', (
...@@ -190,7 +191,8 @@ class PropertySheet(Traversable, Persistent, Implicit): ...@@ -190,7 +191,8 @@ class PropertySheet(Traversable, Persistent, Implicit):
if hasattr(aq_base(self),id): if hasattr(aq_base(self),id):
if not (id=='title' and not self.__dict__.has_key(id)): if not (id=='title' and not self.__dict__.has_key(id)):
raise 'Bad Request', ( raise 'Bad Request', (
'Invalid property id, <em>%s</em>. It is in use.' % id) 'Invalid property id, <em>%s</em>. It is in use.' %
escape(id))
if meta is None: meta={} if meta is None: meta={}
prop={'id':id, 'type':type, 'meta':meta} prop={'id':id, 'type':type, 'meta':meta}
pself._properties=pself._properties+(prop,) pself._properties=pself._properties+(prop,)
...@@ -211,10 +213,10 @@ class PropertySheet(Traversable, Persistent, Implicit): ...@@ -211,10 +213,10 @@ class PropertySheet(Traversable, Persistent, Implicit):
# it will used to _replace_ the properties meta data. # it will used to _replace_ the properties meta data.
self._wrapperCheck(value) self._wrapperCheck(value)
if not self.hasProperty(id): if not self.hasProperty(id):
raise 'Bad Request', 'The property %s does not exist.' % id raise 'Bad Request', 'The property %s does not exist.' % escape(id)
propinfo=self.propertyInfo(id) propinfo=self.propertyInfo(id)
if not 'w' in propinfo.get('mode', 'wd'): if not 'w' in propinfo.get('mode', 'wd'):
raise 'Bad Request', '%s cannot be changed.' % id raise 'Bad Request', '%s cannot be changed.' % escape(id)
if type(value)==type(''): if type(value)==type(''):
proptype=propinfo.get('type', 'string') proptype=propinfo.get('type', 'string')
if type_converters.has_key(proptype): if type_converters.has_key(proptype):
...@@ -232,13 +234,13 @@ class PropertySheet(Traversable, Persistent, Implicit): ...@@ -232,13 +234,13 @@ class PropertySheet(Traversable, Persistent, Implicit):
# Delete the property with the given id. If a property with the # Delete the property with the given id. If a property with the
# given id does not exist, a ValueError is raised. # given id does not exist, a ValueError is raised.
if not self.hasProperty(id): if not self.hasProperty(id):
raise 'Bad Request', 'The property %s does not exist.' % id raise 'Bad Request', 'The property %s does not exist.' % escape(id)
vself=self.v_self() vself=self.v_self()
if hasattr(vself, '_reserved_names'): if hasattr(vself, '_reserved_names'):
nd=vself._reserved_names nd=vself._reserved_names
else: nd=() else: nd=()
if (not 'd' in self.propertyInfo(id).get('mode', 'wd')) or (id in nd): if (not 'd' in self.propertyInfo(id).get('mode', 'wd')) or (id in nd):
raise 'Bad Request', '%s cannot be deleted.' % id raise 'Bad Request', '%s cannot be deleted.' % escape(id)
delattr(vself, id) delattr(vself, id)
pself=self.p_self() pself=self.p_self()
pself._properties=tuple(filter(lambda i, n=id: i['id'] != n, pself._properties=tuple(filter(lambda i, n=id: i['id'] != n,
...@@ -262,7 +264,7 @@ class PropertySheet(Traversable, Persistent, Implicit): ...@@ -262,7 +264,7 @@ class PropertySheet(Traversable, Persistent, Implicit):
# Return a mapping containing property meta-data # Return a mapping containing property meta-data
for p in self._propertyMap(): for p in self._propertyMap():
if p['id']==id: return p if p['id']==id: return p
raise ValueError, 'The property %s does not exist.' % id raise ValueError, 'The property %s does not exist.' % escape(id)
def _propertyMap(self): def _propertyMap(self):
# Return a tuple of mappings, giving meta-data for properties. # Return a tuple of mappings, giving meta-data for properties.
...@@ -418,7 +420,7 @@ class PropertySheet(Traversable, Persistent, Implicit): ...@@ -418,7 +420,7 @@ class PropertySheet(Traversable, Persistent, Implicit):
for name, value in props.items(): for name, value in props.items():
if self.hasProperty(name): if self.hasProperty(name):
if not 'w' in propdict[name].get('mode', 'wd'): if not 'w' in propdict[name].get('mode', 'wd'):
raise 'BadRequest', '%s cannot be changed' % name raise 'BadRequest', '%s cannot be changed' % escape(name)
self._updateProperty(name, value) self._updateProperty(name, value)
if REQUEST is not None: if REQUEST is not None:
return MessageDialog( return MessageDialog(
...@@ -487,13 +489,13 @@ class DAVProperties(Virtual, PropertySheet, View): ...@@ -487,13 +489,13 @@ class DAVProperties(Virtual, PropertySheet, View):
return getattr(self, method)() return getattr(self, method)()
def _setProperty(self, id, value, type='string', meta=None): def _setProperty(self, id, value, type='string', meta=None):
raise ValueError, '%s cannot be set.' % id raise ValueError, '%s cannot be set.' % escape(id)
def _updateProperty(self, id, value): def _updateProperty(self, id, value):
raise ValueError, '%s cannot be updated.' % id raise ValueError, '%s cannot be updated.' % escape(id)
def _delProperty(self, id): def _delProperty(self, id):
raise ValueError, '%s cannot be deleted.' % id raise ValueError, '%s cannot be deleted.' % escape(id)
def _propertyMap(self): def _propertyMap(self):
# Only use getlastmodified if returns a value # Only use getlastmodified if returns a value
......
...@@ -18,6 +18,7 @@ from Acquisition import Acquired ...@@ -18,6 +18,7 @@ from Acquisition import Acquired
import Persistence import Persistence
from thread import allocate_lock from thread import allocate_lock
from zLOG import LOG, WARNING from zLOG import LOG, WARNING
from cgi import escape
broken_klasses={} broken_klasses={}
broken_klasses_lock = allocate_lock() broken_klasses_lock = allocate_lock()
...@@ -42,7 +43,7 @@ class BrokenClass(Acquisition.Explicit, SimpleItem.Item, ...@@ -42,7 +43,7 @@ class BrokenClass(Acquisition.Explicit, SimpleItem.Item,
def __getattr__(self, name): def __getattr__(self, name):
if name[:3]=='_p_': if name[:3]=='_p_':
return BrokenClass.inheritedAttribute('__getattr__')(self, name) return BrokenClass.inheritedAttribute('__getattr__')(self, name)
raise AttributeError, name raise AttributeError, escape(name)
manage=manage_main=Globals.DTMLFile('dtml/brokenEdit',globals()) manage=manage_main=Globals.DTMLFile('dtml/brokenEdit',globals())
manage_workspace=manage manage_workspace=manage
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
"""Text Index """Text Index
""" """
__version__ = '$Revision: 1.32 $'[11:-2] __version__ = '$Revision: 1.33 $'[11:-2]
import re import re
...@@ -35,6 +35,7 @@ from BTrees.IIBTree import difference, weightedIntersection ...@@ -35,6 +35,7 @@ from BTrees.IIBTree import difference, weightedIntersection
from Lexicon import Lexicon from Lexicon import Lexicon
from types import * from types import *
from cgi import escape
class Op: class Op:
def __init__(self, name): def __init__(self, name):
...@@ -482,7 +483,7 @@ class TextIndex(Persistent, Implicit, SimpleItem): ...@@ -482,7 +483,7 @@ class TextIndex(Persistent, Implicit, SimpleItem):
query_operator = operator_dict.get(qop) query_operator = operator_dict.get(qop)
if query_operator is None: if query_operator is None:
raise exceptions.RuntimeError, ("Invalid operator '%s' " raise exceptions.RuntimeError, ("Invalid operator '%s' "
"for a TextIndex" % qop) "for a TextIndex" % escape(qop))
r = None r = None
for key in record.keys: for key in record.keys:
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
"""Simple column indices""" """Simple column indices"""
__version__='$Revision: 1.12 $'[11:-2] __version__='$Revision: 1.13 $'[11:-2]
from Globals import Persistent from Globals import Persistent
from Acquisition import Implicit from Acquisition import Implicit
...@@ -30,6 +30,7 @@ import BTrees.Length ...@@ -30,6 +30,7 @@ import BTrees.Length
from Products.PluginIndexes.common.util import parseIndexRequest from Products.PluginIndexes.common.util import parseIndexRequest
import sys import sys
from cgi import escape
_marker = [] _marker = []
...@@ -330,7 +331,7 @@ class UnIndex(Persistent, Implicit, SimpleItem): ...@@ -330,7 +331,7 @@ class UnIndex(Persistent, Implicit, SimpleItem):
# experimental code for specifing the operator # experimental code for specifing the operator
operator = record.get('operator',self.useOperator) operator = record.get('operator',self.useOperator)
if not operator in self.operators : if not operator in self.operators :
raise RuntimeError,"operator not valid: %s" % operator raise RuntimeError,"operator not valid: %s" % escape(operator)
# depending on the operator we use intersection or union # depending on the operator we use intersection or union
if operator=="or": set_func = union if operator=="or": set_func = union
......
...@@ -11,7 +11,7 @@ ...@@ -11,7 +11,7 @@
# #
############################################################################ ############################################################################
__version__='$Revision: 1.10 $'[11:-2] __version__='$Revision: 1.11 $'[11:-2]
import Globals import Globals
from Persistence import Persistent from Persistence import Persistent
from ZODB import TimeStamp from ZODB import TimeStamp
...@@ -26,12 +26,13 @@ import SessionInterfaces ...@@ -26,12 +26,13 @@ import SessionInterfaces
from SessionPermissions import * from SessionPermissions import *
from common import DEBUG from common import DEBUG
import os, time, random, string, binascii, sys, re import os, time, random, string, binascii, sys, re
from cgi import escape
b64_trans = string.maketrans('+/', '-.') b64_trans = string.maketrans('+/', '-.')
b64_untrans = string.maketrans('-.', '+/') b64_untrans = string.maketrans('-.', '+/')
badidnamecharsin = re.compile('[\?&;, ]').search badidnamecharsin = re.compile('[\?&;,<> ]').search
badcookiecharsin = re.compile('[;, ]').search badcookiecharsin = re.compile('[;,<>& ]').search
twodotsin = re.compile('(\w*\.){2,}').search twodotsin = re.compile('(\w*\.){2,}').search
_marker = [] _marker = []
...@@ -119,7 +120,8 @@ class BrowserIdManager(Item, Persistent, Implicit, RoleManager, Owned, Tabs): ...@@ -119,7 +120,8 @@ class BrowserIdManager(Item, Persistent, Implicit, RoleManager, Owned, Tabs):
# somebody screwed with the REQUEST instance during # somebody screwed with the REQUEST instance during
# this request. # this request.
raise BrowserIdManagerErr, ( raise BrowserIdManagerErr, (
'Ill-formed browserid in REQUEST.browser_id_: %s' % bid 'Ill-formed browserid in REQUEST.browser_id_: %s' %
escape(bid)
) )
return bid return bid
# fall through & ck id namespaces if bid is not in request. # fall through & ck id namespaces if bid is not in request.
...@@ -235,7 +237,7 @@ class BrowserIdManager(Item, Persistent, Implicit, RoleManager, Owned, Tabs): ...@@ -235,7 +237,7 @@ class BrowserIdManager(Item, Persistent, Implicit, RoleManager, Owned, Tabs):
def setBrowserIdName(self, k): def setBrowserIdName(self, k):
""" sets browser id name string """ """ sets browser id name string """
if not (type(k) is type('') and k and not badidnamecharsin(k)): if not (type(k) is type('') and k and not badidnamecharsin(k)):
raise BrowserIdManagerErr, 'Bad id name string %s' % repr(k) raise BrowserIdManagerErr, 'Bad id name string %s' % escape(repr(k))
self.browserid_name = k self.browserid_name = k
security.declareProtected(ACCESS_CONTENTS_PERM, 'getBrowserIdName') security.declareProtected(ACCESS_CONTENTS_PERM, 'getBrowserIdName')
...@@ -309,7 +311,7 @@ class BrowserIdManager(Item, Persistent, Implicit, RoleManager, Owned, Tabs): ...@@ -309,7 +311,7 @@ class BrowserIdManager(Item, Persistent, Implicit, RoleManager, Owned, Tabs):
def setCookiePath(self, path=''): def setCookiePath(self, path=''):
""" sets cookie 'path' element for id cookie """ """ sets cookie 'path' element for id cookie """
if not (type(path) is type('') and not badcookiecharsin(path)): if not (type(path) is type('') and not badcookiecharsin(path)):
raise BrowserIdManagerErr, 'Bad cookie path %s' % repr(path) raise BrowserIdManagerErr, 'Bad cookie path %s' % escape(repr(path))
self.cookie_path = path self.cookie_path = path
security.declareProtected(ACCESS_CONTENTS_PERM, 'getCookiePath') security.declareProtected(ACCESS_CONTENTS_PERM, 'getCookiePath')
...@@ -323,7 +325,7 @@ class BrowserIdManager(Item, Persistent, Implicit, RoleManager, Owned, Tabs): ...@@ -323,7 +325,7 @@ class BrowserIdManager(Item, Persistent, Implicit, RoleManager, Owned, Tabs):
if type(days) not in (type(1), type(1.0)): if type(days) not in (type(1), type(1.0)):
raise BrowserIdManagerErr,( raise BrowserIdManagerErr,(
'Bad cookie lifetime in days %s (requires integer value)' 'Bad cookie lifetime in days %s (requires integer value)'
% repr(days) % escape(repr(days))
) )
self.cookie_life_days = int(days) self.cookie_life_days = int(days)
...@@ -337,7 +339,7 @@ class BrowserIdManager(Item, Persistent, Implicit, RoleManager, Owned, Tabs): ...@@ -337,7 +339,7 @@ class BrowserIdManager(Item, Persistent, Implicit, RoleManager, Owned, Tabs):
""" sets cookie 'domain' element for id cookie """ """ sets cookie 'domain' element for id cookie """
if type(domain) is not type(''): if type(domain) is not type(''):
raise BrowserIdManagerErr, ( raise BrowserIdManagerErr, (
'Cookie domain must be string: %s' % repr(domain) 'Cookie domain must be string: %s' % escape(repr(domain))
) )
if not domain: if not domain:
self.cookie_domain = '' self.cookie_domain = ''
...@@ -346,11 +348,11 @@ class BrowserIdManager(Item, Persistent, Implicit, RoleManager, Owned, Tabs): ...@@ -346,11 +348,11 @@ class BrowserIdManager(Item, Persistent, Implicit, RoleManager, Owned, Tabs):
raise BrowserIdManagerErr, ( raise BrowserIdManagerErr, (
'Cookie domain must contain at least two dots (e.g. ' 'Cookie domain must contain at least two dots (e.g. '
'".zope.org" or "www.zope.org") or it must be left blank. : ' '".zope.org" or "www.zope.org") or it must be left blank. : '
'%s' % `domain` '%s' % escape(`domain`)
) )
if badcookiecharsin(domain): if badcookiecharsin(domain):
raise BrowserIdManagerErr, ( raise BrowserIdManagerErr, (
'Bad characters in cookie domain %s' % `domain` 'Bad characters in cookie domain %s' % escape(`domain`)
) )
self.cookie_domain = domain self.cookie_domain = domain
......
...@@ -13,10 +13,10 @@ ...@@ -13,10 +13,10 @@
""" """
Transient Object Container Class ('timeslice'-based design). Transient Object Container Class ('timeslice'-based design).
$Id: Transience.py,v 1.25 2002/06/21 01:51:43 chrism Exp $ $Id: Transience.py,v 1.26 2002/08/01 16:00:41 mj Exp $
""" """
__version__='$Revision: 1.25 $'[11:-2] __version__='$Revision: 1.26 $'[11:-2]
import Globals import Globals
from Globals import HTMLFile from Globals import HTMLFile
...@@ -42,6 +42,7 @@ from TransientObject import TransientObject ...@@ -42,6 +42,7 @@ from TransientObject import TransientObject
import thread import thread
import ThreadLock import ThreadLock
import Queue import Queue
from cgi import escape
_marker = [] _marker = []
...@@ -324,14 +325,14 @@ class TransientObjectContainer(SimpleItem): ...@@ -324,14 +325,14 @@ class TransientObjectContainer(SimpleItem):
def _setTimeout(self, timeout_mins): def _setTimeout(self, timeout_mins):
if type(timeout_mins) is not type(1): if type(timeout_mins) is not type(1):
raise TypeError, (timeout_mins, "Must be integer") raise TypeError, (escape(`timeout_mins`), "Must be integer")
self._timeout_secs = t_secs = timeout_mins * 60 self._timeout_secs = t_secs = timeout_mins * 60
# timeout_slices == fewest number of timeslices that's >= t_secs # timeout_slices == fewest number of timeslices that's >= t_secs
self._timeout_slices=int(math.ceil(float(t_secs)/self._period)) self._timeout_slices=int(math.ceil(float(t_secs)/self._period))
def _setLimit(self, limit): def _setLimit(self, limit):
if type(limit) is not type(1): if type(limit) is not type(1):
raise TypeError, (limit, "Must be integer") raise TypeError, (escape(`limit`), "Must be integer")
self._limit = limit self._limit = limit
security.declareProtected(MGMT_SCREEN_PERM, 'nudge') security.declareProtected(MGMT_SCREEN_PERM, 'nudge')
......
...@@ -10,10 +10,11 @@ ...@@ -10,10 +10,11 @@
# FOR A PARTICULAR PURPOSE # FOR A PARTICULAR PURPOSE
# #
############################################################################## ##############################################################################
__version__='$Revision: 1.16 $'[11:-2] __version__='$Revision: 1.17 $'[11:-2]
import re import re
from types import ListType, TupleType, UnicodeType from types import ListType, TupleType, UnicodeType
from cgi import escape
def field2string(v): def field2string(v):
if hasattr(v,'read'): return v.read() if hasattr(v,'read'): return v.read()
...@@ -53,7 +54,7 @@ def field2int(v): ...@@ -53,7 +54,7 @@ def field2int(v):
try: return int(v) try: return int(v)
except ValueError: except ValueError:
raise ValueError, ( raise ValueError, (
"An integer was expected in the value '%s'" % v "An integer was expected in the value '%s'" % escape(v)
) )
raise ValueError, 'Empty entry when <strong>integer</strong> expected' raise ValueError, 'Empty entry when <strong>integer</strong> expected'
...@@ -65,7 +66,8 @@ def field2float(v): ...@@ -65,7 +66,8 @@ def field2float(v):
try: return float(v) try: return float(v)
except ValueError: except ValueError:
raise ValueError, ( raise ValueError, (
"A floating-point number was expected in the value '%s'" % v "A floating-point number was expected in the value '%s'" %
escape(v)
) )
raise ValueError, ( raise ValueError, (
'Empty entry when <strong>floating-point number</strong> expected') 'Empty entry when <strong>floating-point number</strong> expected')
...@@ -81,7 +83,7 @@ def field2long(v): ...@@ -81,7 +83,7 @@ def field2long(v):
try: return long(v) try: return long(v)
except ValueError: except ValueError:
raise ValueError, ( raise ValueError, (
"A long integer was expected in the value '%s'" % v "A long integer was expected in the value '%s'" % escape(v)
) )
raise ValueError, 'Empty entry when <strong>integer</strong> expected' raise ValueError, 'Empty entry when <strong>integer</strong> expected'
...@@ -100,7 +102,11 @@ def field2lines(v): ...@@ -100,7 +102,11 @@ def field2lines(v):
def field2date(v): def field2date(v):
from DateTime import DateTime from DateTime import DateTime
v = field2string(v) v = field2string(v)
return DateTime(v) try:
v = DateTime(v)
except DateTime.SyntaxError, e:
raise DateTime.SyntaxError, escape(e)
return v
def field2boolean(v): def field2boolean(v):
return not not v return not not v
......
...@@ -11,14 +11,16 @@ ...@@ -11,14 +11,16 @@
# #
############################################################################## ##############################################################################
__version__='$Revision: 1.74 $'[11:-2] __version__='$Revision: 1.75 $'[11:-2]
import re, sys, os, urllib, time, random, cgi, codecs import re, sys, os, urllib, time, random, cgi, codecs
from BaseRequest import BaseRequest from BaseRequest import BaseRequest
from HTTPResponse import HTTPResponse from HTTPResponse import HTTPResponse
from cgi import FieldStorage, escape from cgi import FieldStorage, escape
from urllib import quote, unquote, splittype, splitport from urllib import quote, unquote, splittype, splitport
from copy import deepcopy
from Converters import get_converter from Converters import get_converter
from TaintedString import TaintedString
from maybe_lock import allocate_lock from maybe_lock import allocate_lock
xmlrpc=None # Placeholder for module that we'll import if we have to. xmlrpc=None # Placeholder for module that we'll import if we have to.
...@@ -241,6 +243,7 @@ class HTTPRequest(BaseRequest): ...@@ -241,6 +243,7 @@ class HTTPRequest(BaseRequest):
self.response=response self.response=response
other=self.other={'RESPONSE': response} other=self.other={'RESPONSE': response}
self.form={} self.form={}
self.taintedform={}
self.steps=[] self.steps=[]
self._steps=[] self._steps=[]
self._lazies={} self._lazies={}
...@@ -306,13 +309,22 @@ class HTTPRequest(BaseRequest): ...@@ -306,13 +309,22 @@ class HTTPRequest(BaseRequest):
# vars with the same name - they are more like default values # vars with the same name - they are more like default values
# for names not otherwise specified in the form. # for names not otherwise specified in the form.
cookies={} cookies={}
taintedcookies={}
k=get_env('HTTP_COOKIE','') k=get_env('HTTP_COOKIE','')
if k: if k:
parse_cookie(k, cookies) parse_cookie(k, cookies)
for k,item in cookies.items(): for k, v in cookies.items():
if not other.has_key(k): istainted = 0
other[k]=item if '<' in k:
k = TaintedString(k)
istainted = 1
if '<' in v:
v = TaintedString(v)
istainted = 1
if istainted:
taintedcookies[k] = v
self.cookies=cookies self.cookies=cookies
self.taintedcookies = taintedcookies
def processInputs( def processInputs(
self, self,
...@@ -343,6 +355,7 @@ class HTTPRequest(BaseRequest): ...@@ -343,6 +355,7 @@ class HTTPRequest(BaseRequest):
form=self.form form=self.form
other=self.other other=self.other
taintedform=self.taintedform
meth=None meth=None
fs=FieldStorage(fp=fp,environ=environ,keep_blank_values=1) fs=FieldStorage(fp=fp,environ=environ,keep_blank_values=1)
...@@ -367,10 +380,12 @@ class HTTPRequest(BaseRequest): ...@@ -367,10 +380,12 @@ class HTTPRequest(BaseRequest):
lt=type([]) lt=type([])
CGI_name=isCGI_NAME CGI_name=isCGI_NAME
defaults={} defaults={}
tainteddefaults={}
converter=seqf=None converter=seqf=None
for item in fslist: for item in fslist:
isFileUpload = 0
key=item.name key=item.name
if (hasattr(item,'file') and hasattr(item,'filename') if (hasattr(item,'file') and hasattr(item,'filename')
and hasattr(item,'headers')): and hasattr(item,'headers')):
...@@ -380,11 +395,15 @@ class HTTPRequest(BaseRequest): ...@@ -380,11 +395,15 @@ class HTTPRequest(BaseRequest):
# or 'content-type' in map(lower, item.headers.keys()) # or 'content-type' in map(lower, item.headers.keys())
)): )):
item=FileUpload(item) item=FileUpload(item)
isFileUpload = 1
else: else:
item=item.value item=item.value
flags=0 flags=0
character_encoding = '' character_encoding = ''
# Variables for potentially unsafe values.
tainted = None
converter_type = None
# Loop through the different types and set # Loop through the different types and set
# the appropriate flags # the appropriate flags
...@@ -408,6 +427,7 @@ class HTTPRequest(BaseRequest): ...@@ -408,6 +427,7 @@ class HTTPRequest(BaseRequest):
if c is not None: if c is not None:
converter=c converter=c
converter_type = type_name
flags=flags|CONVERTED flags=flags|CONVERTED
elif type_name == 'list': elif type_name == 'list':
seqf=list seqf=list
...@@ -446,6 +466,11 @@ class HTTPRequest(BaseRequest): ...@@ -446,6 +466,11 @@ class HTTPRequest(BaseRequest):
# Filter out special names from form: # Filter out special names from form:
if CGI_name(key) or key[:5]=='HTTP_': continue if CGI_name(key) or key[:5]=='HTTP_': continue
# If the key is tainted, mark it so as well.
tainted_key = key
if '<' in key:
tainted_key = TaintedString(key)
if flags: if flags:
# skip over empty fields # skip over empty fields
...@@ -455,6 +480,17 @@ class HTTPRequest(BaseRequest): ...@@ -455,6 +480,17 @@ class HTTPRequest(BaseRequest):
if flags&REC: if flags&REC:
key=key.split(".") key=key.split(".")
key, attr=".".join(key[:-1]), key[-1] key, attr=".".join(key[:-1]), key[-1]
# Update the tainted_key if necessary
tainted_key = key
if '<' in key:
tainted_key = TaintedString(key)
# Attributes cannot hold a <.
if '<' in attr:
raise ValueError(
"%s is not a valid record attribute name" %
escape(attr))
# defer conversion # defer conversion
if flags&CONVERTED: if flags&CONVERTED:
...@@ -470,6 +506,23 @@ class HTTPRequest(BaseRequest): ...@@ -470,6 +506,23 @@ class HTTPRequest(BaseRequest):
item = converter(item.encode('latin1')) item = converter(item.encode('latin1'))
else: else:
item=converter(item) item=converter(item)
# Flag potentially unsafe values
if converter_type in ('string', 'required', 'text',
'ustring', 'utext'):
if not isFileUpload and '<' in item:
tainted = TaintedString(item)
elif converter_type in ('tokens', 'lines',
'utokens', 'ulines'):
is_tainted = 0
tainted = item[:]
for i in range(len(tainted)):
if '<' in tainted[i]:
is_tainted = 1
tainted[i] = TaintedString(tainted[i])
if not is_tainted:
tainted = None
except: except:
if (not item and not (flags&DEFAULT) and if (not item and not (flags&DEFAULT) and
defaults.has_key(key)): defaults.has_key(key)):
...@@ -478,14 +531,31 @@ class HTTPRequest(BaseRequest): ...@@ -478,14 +531,31 @@ class HTTPRequest(BaseRequest):
item=getattr(item,attr) item=getattr(item,attr)
if flags&RECORDS: if flags&RECORDS:
item = getattr(item[-1], attr) item = getattr(item[-1], attr)
if tainteddefaults.has_key(tainted_key):
tainted = tainteddefaults[tainted_key]
if flags&RECORD:
tainted = getattr(tainted, attr)
if flags&RECORDS:
tainted = getattr(tainted[-1], attr)
else: else:
raise raise
elif not isFileUpload and '<' in item:
# Flag potentially unsafe values
tainted = TaintedString(item)
# If the key is tainted, we need to store stuff in the
# tainted dict as well, even if the value is safe.
if '<' in tainted_key and tainted is None:
tainted = item
#Determine which dictionary to use #Determine which dictionary to use
if flags&DEFAULT: if flags&DEFAULT:
mapping_object = defaults mapping_object = defaults
tainted_mapping = tainteddefaults
else: else:
mapping_object = form mapping_object = form
tainted_mapping = taintedform
#Insert in dictionary #Insert in dictionary
if mapping_object.has_key(key): if mapping_object.has_key(key):
...@@ -494,6 +564,47 @@ class HTTPRequest(BaseRequest): ...@@ -494,6 +564,47 @@ class HTTPRequest(BaseRequest):
#in the list. reclist is mutable. #in the list. reclist is mutable.
reclist = mapping_object[key] reclist = mapping_object[key]
x = reclist[-1] x = reclist[-1]
if tainted:
# Store a tainted copy as well
if not tainted_mapping.has_key(tainted_key):
tainted_mapping[tainted_key] = deepcopy(
reclist)
treclist = tainted_mapping[tainted_key]
lastrecord = treclist[-1]
if not hasattr(lastrecord, attr):
if flags&SEQUENCE: tainted = [tainted]
setattr(lastrecord, attr, tainted)
else:
if flags&SEQUENCE:
getattr(lastrecord,
attr).append(tainted)
else:
newrec = record()
setattr(newrec, attr, tainted)
treclist.append(newrec)
elif tainted_mapping.has_key(tainted_key):
# If we already put a tainted value into this
# recordset, we need to make sure the whole
# recordset is built.
treclist = tainted_mapping[tainted_key]
lastrecord = treclist[-1]
copyitem = item
if not hasattr(lastrecord, attr):
if flags&SEQUENCE: copyitem = [copyitem]
setattr(lastrecord, attr, copyitem)
else:
if flags&SEQUENCE:
getattr(lastrecord,
attr).append(copyitem)
else:
newrec = record()
setattr(newrec, attr, copyitem)
treclist.append(newrec)
if not hasattr(x,attr): if not hasattr(x,attr):
#If the attribute does not #If the attribute does not
#exist, setit #exist, setit
...@@ -529,9 +640,57 @@ class HTTPRequest(BaseRequest): ...@@ -529,9 +640,57 @@ class HTTPRequest(BaseRequest):
# it is not a sequence so # it is not a sequence so
# set the attribute # set the attribute
setattr(b,attr,item) setattr(b,attr,item)
# Store a tainted copy as well if necessary
if tainted:
if not tainted_mapping.has_key(tainted_key):
tainted_mapping[tainted_key] = deepcopy(
mapping_object[key])
b = tainted_mapping[tainted_key]
if flags&SEQUENCE:
seq = getattr(b, attr, [])
seq.append(tainted)
setattr(b, attr, seq)
else:
setattr(b, attr, tainted)
elif tainted_mapping.has_key(tainted_key):
# If we already put a tainted value into this
# record, we need to make sure the whole record
# is built.
b = tainted_mapping[tainted_key]
if flags&SEQUENCE:
seq = getattr(b, attr, [])
seq.append(item)
setattr(b, attr, seq)
else:
setattr(b, attr, item)
else: else:
# it is not a record or list of records # it is not a record or list of records
found=mapping_object[key] found=mapping_object[key]
if tainted:
# Store a tainted version if necessary
if not tainted_mapping.has_key(tainted_key):
copied = deepcopy(found)
if isinstance(copied, lt):
tainted_mapping[tainted_key] = copied
else:
tainted_mapping[tainted_key] = [copied]
tainted_mapping[tainted_key].append(tainted)
elif tainted_mapping.has_key(tainted_key):
# We may already have encountered a tainted
# value for this key, and the tainted_mapping
# needs to hold all the values.
tfound = tainted_mapping[tainted_key]
if isinstance(tfound, lt):
tainted_mapping[tainted_key].append(item)
else:
tainted_mapping[tainted_key] = [tfound,
item]
if type(found) is lt: if type(found) is lt:
found.append(item) found.append(item)
else: else:
...@@ -546,25 +705,70 @@ class HTTPRequest(BaseRequest): ...@@ -546,25 +705,70 @@ class HTTPRequest(BaseRequest):
if flags&SEQUENCE: item=[item] if flags&SEQUENCE: item=[item]
setattr(a,attr,item) setattr(a,attr,item)
mapping_object[key]=[a] mapping_object[key]=[a]
if tainted:
# Store a tainted copy if necessary
a = record()
if flags&SEQUENCE: tainted = [tainted]
setattr(a, attr, tainted)
tainted_mapping[tainted_key] = [a]
elif flags&RECORD: elif flags&RECORD:
# Create a new record, set its attribute # Create a new record, set its attribute
# and put it in the dictionary # and put it in the dictionary
if flags&SEQUENCE: item=[item] if flags&SEQUENCE: item=[item]
r = mapping_object[key]=record() r = mapping_object[key]=record()
setattr(r,attr,item) setattr(r,attr,item)
if tainted:
# Store a tainted copy if necessary
if flags&SEQUENCE: tainted = [tainted]
r = tainted_mapping[tainted_key] = record()
setattr(r, attr, tainted)
else: else:
# it is not a record or list of records # it is not a record or list of records
if flags&SEQUENCE: item=[item] if flags&SEQUENCE: item=[item]
mapping_object[key]=item mapping_object[key]=item
if tainted:
# Store a tainted copy if necessary
if flags&SEQUENCE: tainted = [tainted]
tainted_mapping[tainted_key] = tainted
else: else:
# This branch is for case when no type was specified. # This branch is for case when no type was specified.
mapping_object = form mapping_object = form
if not isFileUpload and '<' in item:
tainted = TaintedString(item)
elif '<' in key:
tainted = item
#Insert in dictionary #Insert in dictionary
if mapping_object.has_key(key): if mapping_object.has_key(key):
# it is not a record or list of records # it is not a record or list of records
found=mapping_object[key] found=mapping_object[key]
if tainted:
# Store a tainted version if necessary
if not taintedform.has_key(tainted_key):
copied = deepcopy(found)
if isinstance(copied, lt):
taintedform[tainted_key] = copied
else:
taintedform[tainted_key] = [copied]
taintedform[tainted_key].append(tainted)
elif taintedform.has_key(tainted_key):
# We may already have encountered a tainted value
# for this key, and the taintedform needs to hold
# all the values.
tfound = taintedform[tainted_key]
if isinstance(tfound, lt):
taintedform[tainted_key].append(item)
else:
taintedform[tainted_key] = [tfound, item]
if type(found) is lt: if type(found) is lt:
found.append(item) found.append(item)
else: else:
...@@ -572,20 +776,53 @@ class HTTPRequest(BaseRequest): ...@@ -572,20 +776,53 @@ class HTTPRequest(BaseRequest):
mapping_object[key]=found mapping_object[key]=found
else: else:
mapping_object[key]=item mapping_object[key]=item
if tainted:
taintedform[tainted_key] = tainted
#insert defaults into form dictionary #insert defaults into form dictionary
if defaults: if defaults:
for key, value in defaults.items(): for key, value in defaults.items():
tainted_key = key
if '<' in key: tainted_key = TaintedString(key)
if not form.has_key(key): if not form.has_key(key):
# if the form does not have the key, # if the form does not have the key,
# set the default # set the default
form[key]=value form[key]=value
if tainteddefaults.has_key(tainted_key):
taintedform[tainted_key] = \
tainteddefaults[tainted_key]
else: else:
#The form has the key #The form has the key
tdefault = tainteddefaults.get(tainted_key, value)
if isinstance(value, record): if isinstance(value, record):
# if the key is mapped to a record, get the # if the key is mapped to a record, get the
# record # record
r = form[key] r = form[key]
# First deal with tainted defaults.
if taintedform.has_key(tainted_key):
tainted = taintedform[tainted_key]
for k, v in tdefault.__dict__.items():
if not hasattr(tainted, k):
setattr(tainted, k, v)
elif tainteddefaults.has_key(tainted_key):
# Find out if any of the tainted default
# attributes needs to be copied over.
missesdefault = 0
for k, v in tdefault.__dict__.items():
if not hasattr(r, k):
missesdefault = 1
break
if missesdefault:
tainted = deepcopy(r)
for k, v in tdefault.__dict__.items():
if not hasattr(tainted, k):
setattr(tainted, k, v)
taintedform[tainted_key] = tainted
for k, v in value.__dict__.items(): for k, v in value.__dict__.items():
# loop through the attributes and value # loop through the attributes and value
# in the default dictionary # in the default dictionary
...@@ -594,12 +831,61 @@ class HTTPRequest(BaseRequest): ...@@ -594,12 +831,61 @@ class HTTPRequest(BaseRequest):
# the attribute, set it to the default # the attribute, set it to the default
setattr(r,k,v) setattr(r,k,v)
form[key] = r form[key] = r
elif isinstance(value, lt): elif isinstance(value, lt):
# the default value is a list # the default value is a list
l = form[key] l = form[key]
if not isinstance(l, lt): if not isinstance(l, lt):
l = [l] l = [l]
# First deal with tainted copies
if taintedform.has_key(tainted_key):
tainted = taintedform[tainted_key]
if not isinstance(tainted, lt):
tainted = [tainted]
for defitem in tdefault:
if isinstance(defitem, record):
for k, v in defitem.__dict__.items():
for origitem in tainted:
if not hasattr(origitem, k):
setattr(origitem, k, v)
else:
if not defitem in tainted:
tainted.append(defitem)
taintedform[tainted_key] = tainted
elif tainteddefaults.has_key(tainted_key):
missesdefault = 0
for defitem in tdefault:
if isinstance(defitem, record):
try:
for k, v in \
defitem.__dict__.items():
for origitem in l:
if not hasattr(origitem, k):
missesdefault = 1
raise "Break"
except "Break":
break
else:
if not defitem in l:
missesdefault = 1
break
if missesdefault:
tainted = deepcopy(l)
for defitem in tdefault:
if isinstance(defitem, record):
for k, v in defitem.__dict__.items():
for origitem in tainted:
if not hasattr(origitem, k):
setattr(origitem, k, v)
else:
if not defitem in tainted:
tainted.append(defitem)
taintedform[tainted_key] = tainted
for x in value: for x in value:
# for each x in the list
if isinstance(x, record): if isinstance(x, record):
# if the x is a record # if the x is a record
for k, v in x.__dict__.items(): for k, v in x.__dict__.items():
...@@ -643,6 +929,8 @@ class HTTPRequest(BaseRequest): ...@@ -643,6 +929,8 @@ class HTTPRequest(BaseRequest):
attr = new attr = new
if form.has_key(k): if form.has_key(k):
# If the form has the split key get its value # If the form has the split key get its value
tainted_split_key = k
if '<' in k: tainted_split_key = TaintedString(k)
item =form[k] item =form[k]
if isinstance(item, record): if isinstance(item, record):
# if the value is mapped to a record, check if it # if the value is mapped to a record, check if it
...@@ -660,16 +948,34 @@ class HTTPRequest(BaseRequest): ...@@ -660,16 +948,34 @@ class HTTPRequest(BaseRequest):
# convert it to a tuple and set it # convert it to a tuple and set it
value=tuple(getattr(x,attr)) value=tuple(getattr(x,attr))
setattr(x,attr,value) setattr(x,attr,value)
# Do the same for the tainted counterpart
if taintedform.has_key(tainted_split_key):
tainted = taintedform[tainted_split_key]
if isinstance(item, record):
seq = tuple(getattr(tainted, attr))
setattr(tainted, attr, seq)
else:
for trec in tainted:
if hasattr(trec, attr):
seq = getattr(trec, attr)
seq = tuple(seq)
setattr(trec, attr, seq)
else: else:
# the form does not have the split key # the form does not have the split key
tainted_key = key
if '<' in key: tainted_key = TaintedString(key)
if form.has_key(key): if form.has_key(key):
# if it has the original key, get the item # if it has the original key, get the item
# convert it to a tuple # convert it to a tuple
item=form[key] item=form[key]
item=tuple(form[key]) item=tuple(form[key])
form[key]=item form[key]=item
if taintedform.has_key(tainted_key):
tainted = tuple(taintedform[tainted_key])
taintedform[tainted_key] = tainted
other.update(form)
if meth: if meth:
if environ.has_key('PATH_INFO'): if environ.has_key('PATH_INFO'):
path=environ['PATH_INFO'] path=environ['PATH_INFO']
...@@ -754,7 +1060,7 @@ class HTTPRequest(BaseRequest): ...@@ -754,7 +1060,7 @@ class HTTPRequest(BaseRequest):
name='HTTP_%s' % name name='HTTP_%s' % name
return environ.get(name, default) return environ.get(name, default)
def get(self, key, default=None, def get(self, key, default=None, returnTaints=0,
URLmatch=re.compile('URL(PATH)?([0-9]+)$').match, URLmatch=re.compile('URL(PATH)?([0-9]+)$').match,
BASEmatch=re.compile('BASE(PATH)?([0-9]+)$').match, BASEmatch=re.compile('BASE(PATH)?([0-9]+)$').match,
): ):
...@@ -841,16 +1147,42 @@ class HTTPRequest(BaseRequest): ...@@ -841,16 +1147,42 @@ class HTTPRequest(BaseRequest):
del self._lazies[key] del self._lazies[key]
return v return v
# Return tainted data first (marked as suspect)
if returnTaints:
v = self.taintedform.get(key, _marker)
if v is not _marker:
other[key] = v
return v
# Untrusted data *after* trusted data
v = self.form.get(key, _marker)
if v is not _marker:
other[key] = v
return v
# Return tainted data first (marked as suspect)
if returnTaints:
v = self.taintedcookies.get(key, _marker)
if v is not _marker:
other[key] = v
return v
# Untrusted data *after* trusted data
v = self.cookies.get(key, _marker)
if v is not _marker:
other[key] = v
return v
return default return default
def __getitem__(self, key, default=_marker): def __getitem__(self, key, default=_marker, returnTaints=0):
v = self.get(key, default) v = self.get(key, default, returnTaints=returnTaints)
if v is _marker: if v is _marker:
raise KeyError, key raise KeyError, key
return v return v
def __getattr__(self, key, default=_marker): def __getattr__(self, key, default=_marker, returnTaints=0):
v = self.get(key, default) v = self.get(key, default, returnTaints=returnTaints)
if v is _marker: if v is _marker:
raise AttributeError, key raise AttributeError, key
return v return v
...@@ -858,12 +1190,12 @@ class HTTPRequest(BaseRequest): ...@@ -858,12 +1190,12 @@ class HTTPRequest(BaseRequest):
def set_lazy(self, key, callable): def set_lazy(self, key, callable):
self._lazies[key] = callable self._lazies[key] = callable
def has_key(self, key): def has_key(self, key, returnTaints=0):
try: self[key] try: self.__getitem__(key, returnTaints=returnTaints)
except: return 0 except: return 0
else: return 1 else: return 1
def keys(self): def keys(self, returnTaints=0):
keys = {} keys = {}
keys.update(self.common) keys.update(self.common)
keys.update(self._lazies) keys.update(self._lazies)
...@@ -885,6 +1217,10 @@ class HTTPRequest(BaseRequest): ...@@ -885,6 +1217,10 @@ class HTTPRequest(BaseRequest):
if not self.has_key(key): break if not self.has_key(key): break
keys.update(self.other) keys.update(self.other)
keys.update(self.cookies)
if returnTaints: keys.update(self.taintedcookies)
keys.update(self.form)
if returnTaints: keys.update(self.taintedform)
keys=keys.keys() keys=keys.keys()
keys.sort() keys.sort()
...@@ -966,6 +1302,32 @@ class HTTPRequest(BaseRequest): ...@@ -966,6 +1302,32 @@ class HTTPRequest(BaseRequest):
base64.decodestring(auth.split()[-1]).split(':') base64.decodestring(auth.split()[-1]).split(':')
return name, password return name, password
def taintWrapper(self):
return TaintRequestWrapper(self)
class TaintRequestWrapper:
def __init__(self, req):
self._req = req
def __getattr__(self, key):
if key in ('get', '__getitem__', '__getattr__', 'has_key', 'keys'):
return TaintMethodWrapper(getattr(self._req, key))
if not key in self._req.keys():
item = getattr(self._req, key, _marker)
if item is not _marker:
return item
return self._req.__getattr__(key, returnTaints=1)
class TaintMethodWrapper:
def __init__(self, method):
self._method = method
def __call__(self, *args, **kw):
kw['returnTaints'] = 1
return self._method(*args, **kw)
def has_codec(x): def has_codec(x):
try: try:
...@@ -1122,6 +1484,12 @@ class record: ...@@ -1122,6 +1484,12 @@ class record:
return '{%s}' % ', '.join( return '{%s}' % ', '.join(
map(lambda item: "'%s': %s" % (item[0], repr(item[1])), L1)) map(lambda item: "'%s': %s" % (item[0], repr(item[1])), L1))
def __cmp__(self, other):
return (cmp(type(self), type(other)) or
cmp(self.__class__, other.__class__) or
cmp(self.__dict__.items(), other.__dict__.items()))
# Flags # Flags
SEQUENCE=1 SEQUENCE=1
DEFAULT=2 DEFAULT=2
......
##############################################################################
#
# Copyright (c) 2001 Zope Corporation and Contributors. All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.0 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
__version__='$Revision: 1.1 $'[11:-2]
from cgi import escape
# TaintedStrings hold potentially dangerous untrusted data; anything that could
# possibly hold HTML is considered dangerous. DTML code will use the quoted
# value of this tring, and raised exceptions in Zope will use the repr()
# conversion.
class TaintedString:
def __init__(self, value):
self._value = value
def __str__(self):
return self._value
def __repr__(self):
return repr(self.quoted())
def __cmp__(self, o):
return cmp(self._value, o)
def __hash__(self):
return hash(self._value)
def __len__(self):
return len(self._value)
def __getitem__(self, index):
v = self._value[index]
if '<' in v:
v = self.__class__(v)
return v
def __getslice__(self, i, j):
i = max(i, 0)
j = max(j, 0)
v = self._value[i:j]
if '<' in v:
v = self.__class__(v)
return v
def __add__(self, o):
return self.__class__(self._value + o)
def __radd__(self, o):
return self.__class__(o + self._value)
def __mul__(self, o):
return self.__class__(self._value * o)
def __rmul__(self, o):
return self.__class__(o * self._value)
def __mod__(self, o):
return self.__class__(self._value % o)
def __int__(self):
return int(self._value)
def __float__(self):
return float(self._value)
def __long__(self):
return long(self._value)
def __getstate__(self):
# If an object tries to store a TaintedString, it obviously wasn't aware
# that it was playing with untrusted data. Complain acordingly.
raise SystemError("A TaintedString cannot be pickled. Code that "
"caused this TaintedString to be stored should be more careful "
"with untrusted data from the REQUEST.")
def __getattr__(self, a):
# for string methods support other than those defined below
return getattr(self._value, a)
# Python 2.2 only.
def decode(self, *args):
return self.__class__(self._value.decode(*args))
def encode(self, *args):
return self.__class__(self._value.encode(*args))
def expandtabs(self, *args):
return self.__class__(self._value.expandtabs(*args))
def replace(self, *args):
v = self._value.replace(*args)
if '<' in v:
v = self.__class__(v)
return v
def split(self, *args):
r = self._value.split(*args)
return map(lambda v, c=self.__class__: '<' in v and c(v) or v, r)
def splitlines(self, *args):
r = self._value.splitlines(*args)
return map(lambda v, c=self.__class__: '<' in v and c(v) or v, r)
def translate(self, *args):
v = self._value.translate(*args)
if '<' in v:
v = self.__class__(v)
return v
def quoted(self):
return escape(self._value, 1)
# As called by cDocumentTemplate
__untaint__ = quoted
def createSimpleWrapper(func):
return lambda s, f=func: s.__class__(getattr(s._value, f)())
def createOneArgWrapper(func):
return lambda s, a, f=func: s.__class__(getattr(s._value, f)(a))
simpleWrappedMethods = \
"capitalize lower lstrip rstrip strip swapcase title upper".split()
oneArgWrappedMethods = "center join ljust rjust".split()
for f in simpleWrappedMethods:
setattr(TaintedString, f, createSimpleWrapper(f))
for f in oneArgWrappedMethods:
setattr(TaintedString, f, createOneArgWrapper(f))
...@@ -36,6 +36,73 @@ class ProcessInputsTests(unittest.TestCase): ...@@ -36,6 +36,73 @@ class ProcessInputsTests(unittest.TestCase):
req.processInputs() req.processInputs()
return req return req
def _noTaintedValues(self, req):
self.failIf(req.taintedform.keys())
def _valueIsOrHoldsTainted(self, val):
# Recursively searches a structure for a TaintedString and returns 1
# when one is found.
# Also raises an Assertion if a string which *should* have been
# tainted is found, or when a tainted string is not deemed dangerous.
from types import ListType, TupleType, StringType, UnicodeType
from ZPublisher.HTTPRequest import record
from ZPublisher.TaintedString import TaintedString
retval = 0
if isinstance(val, TaintedString):
self.failIf(not '<' in val,
"%r is not dangerous, no taint required." % val)
retval = 1
elif isinstance(val, record):
for attr, value in val.__dict__.items():
rval = self._valueIsOrHoldsTainted(attr)
if rval: retval = 1
rval = self._valueIsOrHoldsTainted(value)
if rval: retval = 1
elif type(val) in (ListType, TupleType):
for entry in val:
rval = self._valueIsOrHoldsTainted(entry)
if rval: retval = 1
elif type(val) in (StringType, UnicodeType):
self.failIf('<' in val,
"'%s' is dangerous and should have been tainted." % val)
return retval
def _noFormValuesInOther(self, req):
for key in req.taintedform.keys():
self.failIf(req.other.has_key(key),
'REQUEST.other should not hold tainted values at first!')
for key in req.form.keys():
self.failIf(req.other.has_key(key),
'REQUEST.other should not hold form values at first!')
def _onlyTaintedformHoldsTaintedStrings(self, req):
for key, val in req.taintedform.items():
self.assert_(self._valueIsOrHoldsTainted(key) or
self._valueIsOrHoldsTainted(val),
'Tainted form holds item %s that is not tainted' % key)
for key, val in req.form.items():
if req.taintedform.has_key(key):
continue
self.failIf(self._valueIsOrHoldsTainted(key) or
self._valueIsOrHoldsTainted(val),
'Normal form holds item %s that is tainted' % key)
def _taintedKeysAlsoInForm(self, req):
for key in req.taintedform.keys():
self.assert_(req.form.has_key(key),
"Found tainted %s not in form" % key)
self.assertEquals(req.form[key], req.taintedform[key],
"Key %s not correctly reproduced in tainted; expected %r, "
"got %r" % (key, req.form[key], req.taintedform[key]))
def testNoMarshalling(self): def testNoMarshalling(self):
inputs = ( inputs = (
('foo', 'bar'), ('spam', 'eggs'), ('foo', 'bar'), ('spam', 'eggs'),
...@@ -43,6 +110,7 @@ class ProcessInputsTests(unittest.TestCase): ...@@ -43,6 +110,7 @@ class ProcessInputsTests(unittest.TestCase):
('spacey key', 'val'), ('key', 'spacey val'), ('spacey key', 'val'), ('key', 'spacey val'),
('multi', '1'), ('multi', '2')) ('multi', '1'), ('multi', '2'))
req = self._processInputs(inputs) req = self._processInputs(inputs)
self._noFormValuesInOther(req)
formkeys = list(req.form.keys()) formkeys = list(req.form.keys())
formkeys.sort() formkeys.sort()
...@@ -53,6 +121,9 @@ class ProcessInputsTests(unittest.TestCase): ...@@ -53,6 +121,9 @@ class ProcessInputsTests(unittest.TestCase):
self.assertEquals(req['spacey key'], 'val') self.assertEquals(req['spacey key'], 'val')
self.assertEquals(req['key'], 'spacey val') self.assertEquals(req['key'], 'spacey val')
self._noTaintedValues(req)
self._onlyTaintedformHoldsTaintedStrings(req)
def testSimpleMarshalling(self): def testSimpleMarshalling(self):
from DateTime import DateTime from DateTime import DateTime
...@@ -64,6 +135,7 @@ class ProcessInputsTests(unittest.TestCase): ...@@ -64,6 +135,7 @@ class ProcessInputsTests(unittest.TestCase):
('multiline:lines', 'one\ntwo'), ('multiline:lines', 'one\ntwo'),
('morewords:text', 'one\ntwo\n')) ('morewords:text', 'one\ntwo\n'))
req = self._processInputs(inputs) req = self._processInputs(inputs)
self._noFormValuesInOther(req)
formkeys = list(req.form.keys()) formkeys = list(req.form.keys())
formkeys.sort() formkeys.sort()
...@@ -80,6 +152,9 @@ class ProcessInputsTests(unittest.TestCase): ...@@ -80,6 +152,9 @@ class ProcessInputsTests(unittest.TestCase):
self.assertEquals(req['num'], 42) self.assertEquals(req['num'], 42)
self.assertEquals(req['words'], 'Some words') self.assertEquals(req['words'], 'Some words')
self._noTaintedValues(req)
self._onlyTaintedformHoldsTaintedStrings(req)
def testUnicodeConversions(self): def testUnicodeConversions(self):
inputs = (('ustring:ustring:utf8', 'test\xc2\xae'), inputs = (('ustring:ustring:utf8', 'test\xc2\xae'),
('utext:utext:utf8', 'test\xc2\xae\ntest\xc2\xae\n'), ('utext:utext:utf8', 'test\xc2\xae\ntest\xc2\xae\n'),
...@@ -88,6 +163,7 @@ class ProcessInputsTests(unittest.TestCase): ...@@ -88,6 +163,7 @@ class ProcessInputsTests(unittest.TestCase):
('nouconverter:string:utf8', 'test\xc2\xae')) ('nouconverter:string:utf8', 'test\xc2\xae'))
req = self._processInputs(inputs) req = self._processInputs(inputs)
self._noFormValuesInOther(req)
formkeys = list(req.form.keys()) formkeys = list(req.form.keys())
formkeys.sort() formkeys.sort()
...@@ -99,8 +175,12 @@ class ProcessInputsTests(unittest.TestCase): ...@@ -99,8 +175,12 @@ class ProcessInputsTests(unittest.TestCase):
self.assertEquals(req['utokens'], [u'test\u00AE', u'test\u00AE']) self.assertEquals(req['utokens'], [u'test\u00AE', u'test\u00AE'])
self.assertEquals(req['ulines'], [u'test\u00AE', u'test\u00AE']) self.assertEquals(req['ulines'], [u'test\u00AE', u'test\u00AE'])
# expect a latin1 encoded version
self.assertEquals(req['nouconverter'], 'test\xae') self.assertEquals(req['nouconverter'], 'test\xae')
self._noTaintedValues(req)
self._onlyTaintedformHoldsTaintedStrings(req)
def testSimpleContainers(self): def testSimpleContainers(self):
inputs = ( inputs = (
('oneitem:list', 'one'), ('oneitem:list', 'one'),
...@@ -111,6 +191,7 @@ class ProcessInputsTests(unittest.TestCase): ...@@ -111,6 +191,7 @@ class ProcessInputsTests(unittest.TestCase):
('setrec.foo:records', 'foo'), ('setrec.bar:records', 'bar'), ('setrec.foo:records', 'foo'), ('setrec.bar:records', 'bar'),
('setrec.foo:records', 'spam'), ('setrec.bar:records', 'eggs')) ('setrec.foo:records', 'spam'), ('setrec.bar:records', 'eggs'))
req = self._processInputs(inputs) req = self._processInputs(inputs)
self._noFormValuesInOther(req)
formkeys = list(req.form.keys()) formkeys = list(req.form.keys())
formkeys.sort() formkeys.sort()
...@@ -129,6 +210,9 @@ class ProcessInputsTests(unittest.TestCase): ...@@ -129,6 +210,9 @@ class ProcessInputsTests(unittest.TestCase):
self.assertEquals(req['setrec'][1].foo, 'spam') self.assertEquals(req['setrec'][1].foo, 'spam')
self.assertEquals(req['setrec'][1].bar, 'eggs') self.assertEquals(req['setrec'][1].bar, 'eggs')
self._noTaintedValues(req)
self._onlyTaintedformHoldsTaintedStrings(req)
def testMarshallIntoSequences(self): def testMarshallIntoSequences(self):
inputs = ( inputs = (
('ilist:int:list', '1'), ('ilist:int:list', '2'), ('ilist:int:list', '1'), ('ilist:int:list', '2'),
...@@ -137,6 +221,7 @@ class ProcessInputsTests(unittest.TestCase): ...@@ -137,6 +221,7 @@ class ProcessInputsTests(unittest.TestCase):
('ftuple:tuple:float', '1.2'), ('ftuple:tuple:float', '1.2'),
('tlist:tokens:list', 'one two'), ('tlist:list:tokens', '3 4')) ('tlist:tokens:list', 'one two'), ('tlist:list:tokens', '3 4'))
req = self._processInputs(inputs) req = self._processInputs(inputs)
self._noFormValuesInOther(req)
formkeys = list(req.form.keys()) formkeys = list(req.form.keys())
formkeys.sort() formkeys.sort()
...@@ -146,6 +231,9 @@ class ProcessInputsTests(unittest.TestCase): ...@@ -146,6 +231,9 @@ class ProcessInputsTests(unittest.TestCase):
self.assertEquals(req['ftuple'], (1.0, 1.1, 1.2)) self.assertEquals(req['ftuple'], (1.0, 1.1, 1.2))
self.assertEquals(req['tlist'], [['one', 'two'], ['3', '4']]) self.assertEquals(req['tlist'], [['one', 'two'], ['3', '4']])
self._noTaintedValues(req)
self._onlyTaintedformHoldsTaintedStrings(req)
def testRecordsWithSequences(self): def testRecordsWithSequences(self):
inputs = ( inputs = (
('onerec.name:record', 'foo'), ('onerec.name:record', 'foo'),
...@@ -164,6 +252,7 @@ class ProcessInputsTests(unittest.TestCase): ...@@ -164,6 +252,7 @@ class ProcessInputsTests(unittest.TestCase):
('setrec.ituple:tuple:int:records', '1'), ('setrec.ituple:tuple:int:records', '1'),
('setrec.ituple:tuple:int:records', '2')) ('setrec.ituple:tuple:int:records', '2'))
req = self._processInputs(inputs) req = self._processInputs(inputs)
self._noFormValuesInOther(req)
formkeys = list(req.form.keys()) formkeys = list(req.form.keys())
formkeys.sort() formkeys.sort()
...@@ -182,6 +271,9 @@ class ProcessInputsTests(unittest.TestCase): ...@@ -182,6 +271,9 @@ class ProcessInputsTests(unittest.TestCase):
self.assertEquals(req['setrec'][i].ilist, [1, 2]) self.assertEquals(req['setrec'][i].ilist, [1, 2])
self.assertEquals(req['setrec'][i].ituple, (1, 2)) self.assertEquals(req['setrec'][i].ituple, (1, 2))
self._noTaintedValues(req)
self._onlyTaintedformHoldsTaintedStrings(req)
def testDefaults(self): def testDefaults(self):
inputs = ( inputs = (
('foo:default:int', '5'), ('foo:default:int', '5'),
...@@ -208,6 +300,7 @@ class ProcessInputsTests(unittest.TestCase): ...@@ -208,6 +300,7 @@ class ProcessInputsTests(unittest.TestCase):
('setrec.foo:records', 'ham'), ('setrec.foo:records', 'ham'),
) )
req = self._processInputs(inputs) req = self._processInputs(inputs)
self._noFormValuesInOther(req)
formkeys = list(req.form.keys()) formkeys = list(req.form.keys())
formkeys.sort() formkeys.sort()
...@@ -227,6 +320,240 @@ class ProcessInputsTests(unittest.TestCase): ...@@ -227,6 +320,240 @@ class ProcessInputsTests(unittest.TestCase):
self.assertEquals(req['setrec'][1].spam, 'eggs') self.assertEquals(req['setrec'][1].spam, 'eggs')
self.assertEquals(req['setrec'][1].foo, 'ham') self.assertEquals(req['setrec'][1].foo, 'ham')
self._noTaintedValues(req)
self._onlyTaintedformHoldsTaintedStrings(req)
def testNoMarshallingWithTaints(self):
inputs = (
('foo', 'bar'), ('spam', 'eggs'),
('number', '1'),
('tainted', '<tainted value>'),
('<tainted key>', 'value'),
('spacey key', 'val'), ('key', 'spacey val'),
('multi', '1'), ('multi', '2'))
req = self._processInputs(inputs)
self._noFormValuesInOther(req)
taintedformkeys = list(req.taintedform.keys())
taintedformkeys.sort()
self.assertEquals(taintedformkeys, ['<tainted key>', 'tainted'])
self._taintedKeysAlsoInForm(req)
self._onlyTaintedformHoldsTaintedStrings(req)
def testSimpleMarshallingWithTaints(self):
inputs = (
('foo', 'bar'), ('spam', 'eggs'),
('number', '1'),
('tainted', '<tainted value>'), ('<tainted key>', 'value'),
('spacey key', 'val'), ('key', 'spacey val'),
('multi', '1'), ('multi', '2'))
req = self._processInputs(inputs)
self._noFormValuesInOther(req)
taintedformkeys = list(req.taintedform.keys())
taintedformkeys.sort()
self.assertEquals(taintedformkeys, ['<tainted key>', 'tainted'])
self._taintedKeysAlsoInForm(req)
self._onlyTaintedformHoldsTaintedStrings(req)
def testUnicodeWithTaints(self):
inputs = (('tustring:ustring:utf8', '<test\xc2\xae>'),
('tutext:utext:utf8', '<test\xc2\xae>\n<test\xc2\xae\n>'),
('tinitutokens:utokens:utf8', '<test\xc2\xae> test\xc2\xae'),
('tinitulines:ulines:utf8', '<test\xc2\xae>\ntest\xc2\xae'),
('tdeferutokens:utokens:utf8', 'test\xc2\xae <test\xc2\xae>'),
('tdeferulines:ulines:utf8', 'test\xc2\xae\n<test\xc2\xae>'),
('tnouconverter:string:utf8', '<test\xc2\xae>'))
req = self._processInputs(inputs)
self._noFormValuesInOther(req)
taintedformkeys = list(req.taintedform.keys())
taintedformkeys.sort()
self.assertEquals(taintedformkeys, ['tdeferulines', 'tdeferutokens',
'tinitulines', 'tinitutokens', 'tnouconverter', 'tustring',
'tutext'])
self._taintedKeysAlsoInForm(req)
self._onlyTaintedformHoldsTaintedStrings(req)
def testSimpleContainersWithTaints(self):
from types import ListType, TupleType
from ZPublisher.HTTPRequest import record
inputs = (
('toneitem:list', '<one>'),
('<tkeyoneitem>:list', 'one'),
('tinitalist:list', '<one>'), ('tinitalist:list', 'two'),
('tdeferalist:list', 'one'), ('tdeferalist:list', '<two>'),
('toneitemtuple:tuple', '<one>'),
('tinitatuple:tuple', '<one>'), ('tinitatuple:tuple', 'two'),
('tdeferatuple:tuple', 'one'), ('tdeferatuple:tuple', '<two>'),
('tinitonerec.foo:record', '<foo>'),
('tinitonerec.bar:record', 'bar'),
('tdeferonerec.foo:record', 'foo'),
('tdeferonerec.bar:record', '<bar>'),
('tinitinitsetrec.foo:records', '<foo>'),
('tinitinitsetrec.bar:records', 'bar'),
('tinitinitsetrec.foo:records', 'spam'),
('tinitinitsetrec.bar:records', 'eggs'),
('tinitdefersetrec.foo:records', 'foo'),
('tinitdefersetrec.bar:records', '<bar>'),
('tinitdefersetrec.foo:records', 'spam'),
('tinitdefersetrec.bar:records', 'eggs'),
('tdeferinitsetrec.foo:records', 'foo'),
('tdeferinitsetrec.bar:records', 'bar'),
('tdeferinitsetrec.foo:records', '<spam>'),
('tdeferinitsetrec.bar:records', 'eggs'),
('tdeferdefersetrec.foo:records', 'foo'),
('tdeferdefersetrec.bar:records', 'bar'),
('tdeferdefersetrec.foo:records', 'spam'),
('tdeferdefersetrec.bar:records', '<eggs>'))
req = self._processInputs(inputs)
self._noFormValuesInOther(req)
taintedformkeys = list(req.taintedform.keys())
taintedformkeys.sort()
self.assertEquals(taintedformkeys, ['<tkeyoneitem>', 'tdeferalist',
'tdeferatuple', 'tdeferdefersetrec', 'tdeferinitsetrec',
'tdeferonerec', 'tinitalist', 'tinitatuple', 'tinitdefersetrec',
'tinitinitsetrec', 'tinitonerec', 'toneitem', 'toneitemtuple'])
self._taintedKeysAlsoInForm(req)
self._onlyTaintedformHoldsTaintedStrings(req)
def testRecordsWithSequencesAndTainted(self):
inputs = (
('tinitonerec.tokens:tokens:record', '<one> two'),
('tdeferonerec.tokens:tokens:record', 'one <two>'),
('tinitsetrec.name:records', 'first'),
('tinitsetrec.ilist:list:records', '<1>'),
('tinitsetrec.ilist:list:records', '2'),
('tinitsetrec.ituple:tuple:int:records', '1'),
('tinitsetrec.ituple:tuple:int:records', '2'),
('tinitsetrec.name:records', 'second'),
('tinitsetrec.ilist:list:records', '1'),
('tinitsetrec.ilist:list:records', '2'),
('tinitsetrec.ituple:tuple:int:records', '1'),
('tinitsetrec.ituple:tuple:int:records', '2'),
('tdeferfirstsetrec.name:records', 'first'),
('tdeferfirstsetrec.ilist:list:records', '1'),
('tdeferfirstsetrec.ilist:list:records', '<2>'),
('tdeferfirstsetrec.ituple:tuple:int:records', '1'),
('tdeferfirstsetrec.ituple:tuple:int:records', '2'),
('tdeferfirstsetrec.name:records', 'second'),
('tdeferfirstsetrec.ilist:list:records', '1'),
('tdeferfirstsetrec.ilist:list:records', '2'),
('tdeferfirstsetrec.ituple:tuple:int:records', '1'),
('tdeferfirstsetrec.ituple:tuple:int:records', '2'),
('tdefersecondsetrec.name:records', 'first'),
('tdefersecondsetrec.ilist:list:records', '1'),
('tdefersecondsetrec.ilist:list:records', '2'),
('tdefersecondsetrec.ituple:tuple:int:records', '1'),
('tdefersecondsetrec.ituple:tuple:int:records', '2'),
('tdefersecondsetrec.name:records', 'second'),
('tdefersecondsetrec.ilist:list:records', '1'),
('tdefersecondsetrec.ilist:list:records', '<2>'),
('tdefersecondsetrec.ituple:tuple:int:records', '1'),
('tdefersecondsetrec.ituple:tuple:int:records', '2'),
)
req = self._processInputs(inputs)
self._noFormValuesInOther(req)
taintedformkeys = list(req.taintedform.keys())
taintedformkeys.sort()
self.assertEquals(taintedformkeys, ['tdeferfirstsetrec', 'tdeferonerec',
'tdefersecondsetrec', 'tinitonerec', 'tinitsetrec'])
self._taintedKeysAlsoInForm(req)
self._onlyTaintedformHoldsTaintedStrings(req)
def testDefaultsWithTaints(self):
inputs = (
('tfoo:default', '<5>'),
('doesnnotapply:default', '<4>'),
('doesnnotapply', '4'),
('tinitlist:default', '3'),
('tinitlist:default', '4'),
('tinitlist:default', '5'),
('tinitlist', '<1>'),
('tinitlist', '2'),
('tdeferlist:default', '3'),
('tdeferlist:default', '<4>'),
('tdeferlist:default', '5'),
('tdeferlist', '1'),
('tdeferlist', '2'),
('tinitbar.spam:record:default', 'eggs'),
('tinitbar.foo:record:default', 'foo'),
('tinitbar.foo:record', '<baz>'),
('tdeferbar.spam:record:default', '<eggs>'),
('tdeferbar.foo:record:default', 'foo'),
('tdeferbar.foo:record', 'baz'),
('rdoesnotapply.spam:record:default', '<eggs>'),
('rdoesnotapply.spam:record', 'eggs'),
('tinitsetrec.spam:records:default', 'eggs'),
('tinitsetrec.foo:records:default', 'foo'),
('tinitsetrec.foo:records', '<baz>'),
('tinitsetrec.foo:records', 'ham'),
('tdefersetrec.spam:records:default', '<eggs>'),
('tdefersetrec.foo:records:default', 'foo'),
('tdefersetrec.foo:records', 'baz'),
('tdefersetrec.foo:records', 'ham'),
('srdoesnotapply.foo:records:default', '<eggs>'),
('srdoesnotapply.foo:records', 'baz'),
('srdoesnotapply.foo:records', 'ham'))
req = self._processInputs(inputs)
self._noFormValuesInOther(req)
taintedformkeys = list(req.taintedform.keys())
taintedformkeys.sort()
self.assertEquals(taintedformkeys, ['tdeferbar', 'tdeferlist',
'tdefersetrec', 'tfoo', 'tinitbar', 'tinitlist', 'tinitsetrec'])
self._taintedKeysAlsoInForm(req)
self._onlyTaintedformHoldsTaintedStrings(req)
def testTaintedAttributeRaises(self):
input = ('taintedattr.here<be<taint:record', 'value',)
self.assertRaises(ValueError, self._processInputs, input)
def testNoTaintedExceptions(self):
# Feed tainted garbage to the conversion methods, and any exception
# returned should be HTML safe
from ZPublisher.Converters import type_converters
from DateTime import DateTime
for type, convert in type_converters.items():
try:
convert('<html garbage>')
except Exception, e:
self.failIf('<' in e.args,
'%s converter does not quote unsafe value!' % type)
except DateTime.SyntaxError, e:
self.failIf('<' in e,
'%s converter does not quote unsafe value!' % type)
def test_suite(): def test_suite():
suite = unittest.TestSuite() suite = unittest.TestSuite()
......
##############################################################################
#
# Copyright (c) 2001 Zope Corporation and Contributors. All Rights Reserved.
#
# This software is subject to the provisions of the Zope Public License,
# Version 2.0 (ZPL). A copy of the ZPL should accompany this distribution.
# THIS SOFTWARE IS PROVIDED "AS IS" AND ANY AND ALL EXPRESS OR IMPLIED
# WARRANTIES ARE DISCLAIMED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF TITLE, MERCHANTABILITY, AGAINST INFRINGEMENT, AND FITNESS
# FOR A PARTICULAR PURPOSE
#
##############################################################################
import unittest
class TestTaintedString(unittest.TestCase):
def setUp(self):
self.unquoted = '<test attr="&">'
self.quoted = '&lt;test attr=&quot;&amp;&quot;&gt;'
self.tainted = self._getClass()(self.unquoted)
def _getClass(self):
from ZPublisher.TaintedString import TaintedString
return TaintedString
def testStr(self):
self.assertEquals(str(self.tainted), self.unquoted)
def testRepr(self):
self.assertEquals(repr(self.tainted), repr(self.quoted))
def testCmp(self):
self.assertEquals(cmp(self.tainted, self.unquoted), 0)
self.assertEquals(cmp(self.tainted, 'a'), -1)
self.assertEquals(cmp(self.tainted, '.'), 1)
def testHash(self):
hash = {}
hash[self.tainted] = self.quoted
hash[self.unquoted] = self.unquoted
self.assertEquals(hash[self.tainted], self.unquoted)
def testLen(self):
self.assertEquals(len(self.tainted), len(self.unquoted))
def testGetItem(self):
self.assert_(isinstance(self.tainted[0], self._getClass()))
self.assertEquals(self.tainted[0], '<')
self.failIf(isinstance(self.tainted[-1], self._getClass()))
self.assertEquals(self.tainted[-1], '>')
def testGetSlice(self):
self.assert_(isinstance(self.tainted[0:1], self._getClass()))
self.assertEquals(self.tainted[0:1], '<')
self.failIf(isinstance(self.tainted[1:], self._getClass()))
self.assertEquals(self.tainted[1:], self.unquoted[1:])
def testConcat(self):
self.assert_(isinstance(self.tainted + 'test', self._getClass()))
self.assertEquals(self.tainted + 'test', self.unquoted + 'test')
self.assert_(isinstance('test' + self.tainted, self._getClass()))
self.assertEquals('test' + self.tainted, 'test' + self.unquoted)
def testMultiply(self):
self.assert_(isinstance(2 * self.tainted, self._getClass()))
self.assertEquals(2 * self.tainted, 2 * self.unquoted)
self.assert_(isinstance(self.tainted * 2, self._getClass()))
self.assertEquals(self.tainted * 2, self.unquoted * 2)
def testInterpolate(self):
tainted = self._getClass()('<%s>')
self.assert_(isinstance(tainted % 'foo', self._getClass()))
self.assertEquals(tainted % 'foo', '<foo>')
tainted = self._getClass()('<%s attr="%s">')
self.assert_(isinstance(tainted % ('foo', 'bar'), self._getClass()))
self.assertEquals(tainted % ('foo', 'bar'), '<foo attr="bar">')
def testStringMethods(self):
simple = "capitalize isalpha isdigit islower isspace istitle isupper" \
" lower lstrip rstrip strip swapcase upper".split()
returnsTainted = "capitalize lower lstrip rstrip strip swapcase upper"
returnsTainted = returnsTainted.split()
unquoted = '\tThis is a test '
tainted = self._getClass()(unquoted)
for f in simple:
v = getattr(tainted, f)()
self.assertEquals(v, getattr(unquoted, f)())
if f in returnsTainted:
self.assert_(isinstance(v, self._getClass()))
else:
self.failIf(isinstance(v, self._getClass()))
justify = "center ljust rjust".split()
for f in justify:
v = getattr(tainted, f)(30)
self.assertEquals(v, getattr(unquoted, f)(30))
self.assert_(isinstance(v, self._getClass()))
searches = "find index rfind rindex endswith startswith".split()
searchraises = "index rindex".split()
for f in searches:
v = getattr(tainted, f)('test')
self.assertEquals(v, getattr(unquoted, f)('test'))
if f in searchraises:
self.assertRaises(ValueError, getattr(tainted, f), 'nada')
self.assertEquals(tainted.count('test', 1, -1),
unquoted.count('test', 1, -1))
self.assertEquals(tainted.encode(), unquoted.encode())
self.assert_(isinstance(tainted.encode(), self._getClass()))
self.assertEquals(tainted.expandtabs(10),
unquoted.expandtabs(10))
self.assert_(isinstance(tainted.expandtabs(), self._getClass()))
self.assertEquals(tainted.replace('test', 'spam'),
unquoted.replace('test', 'spam'))
self.assert_(isinstance(tainted.replace('test', '<'), self._getClass()))
self.failIf(isinstance(tainted.replace('test', 'spam'),
self._getClass()))
self.assertEquals(tainted.split(), unquoted.split())
for part in self._getClass()('< < <').split():
self.assert_(isinstance(part, self._getClass()))
for part in tainted.split():
self.failIf(isinstance(part, self._getClass()))
multiline = 'test\n<tainted>'
lines = self._getClass()(multiline).split()
self.assertEquals(lines, multiline.split())
self.assert_(isinstance(lines[1], self._getClass()))
self.failIf(isinstance(lines[0], self._getClass()))
transtable = ''.join(map(chr, range(256)))
self.assertEquals(tainted.translate(transtable),
unquoted.translate(transtable))
self.assert_(isinstance(self._getClass()('<').translate(transtable),
self._getClass()))
self.failIf(isinstance(self._getClass()('<').translate(transtable, '<'),
self._getClass()))
def testQuoted(self):
self.assertEquals(self.tainted.quoted(), self.quoted)
def test_suite():
suite = unittest.TestSuite()
suite.addTest(unittest.makeSuite(TestTaintedString, 'test'))
return suite
def main():
unittest.TextTestRunner().run(test_suite())
def debug():
test_suite().debug()
def pdebug():
import pdb
pdb.run('debug()')
if __name__=='__main__':
if len(sys.argv) > 1:
globals()[sys.argv[1]]()
else:
main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment