Commit e7b8e43e authored by Evan Simpson's avatar Evan Simpson

Handle augmented assignment correctly.

parent f791e0f8
......@@ -84,7 +84,7 @@
##############################################################################
from __future__ import nested_scopes
__version__='$Revision: 1.2 $'[11:-2]
__version__='$Revision: 1.3 $'[11:-2]
import new
......@@ -105,10 +105,10 @@ for name in ('None', 'abs', 'chr', 'divmod', 'float', 'hash', 'hex', 'int',
def _full_read_guard(g_attr, g_item):
# Nested scope abuse!
# The two arguments are used by class Wrapper
# The arguments are used by class Wrapper
# safetype variable is used by guard()
safetype = {type(()): 1, type([]): 1, type({}): 1, type(''): 1}.has_key
def guard(ob):
def guard(ob, write=None):
# Don't bother wrapping simple types, or objects that claim to
# handle their own read security.
if safetype(type(ob)) or getattr(ob, '_guarded_reads', 0):
......@@ -122,7 +122,12 @@ def _full_read_guard(g_attr, g_item):
return g_attr(ob, name)
def __getitem__(self, i):
# Must handle both item and slice access.
return g_item(ob, i)
return g_item(ob, i)
# Optional, for combined read/write guard
def __setitem__(self, index, val):
write(ob)[index] = val
def __setattr__(self, attr, val):
setattr(write(ob), attr, val)
return Wrapper()
return guard
......@@ -143,7 +148,7 @@ def _write_wrapper():
# Required for slices with negative bounds.
return len(self.ob)
def __init__(self, ob):
self.ob = ob
self.__dict__['ob'] = ob
# Generate class methods
d = Wrapper.__dict__
for name, error_msg in (
......@@ -179,3 +184,6 @@ def guarded_delattr(object, name):
delattr(full_write_guard(object), name)
safe_builtins['delattr'] = guarded_delattr
......@@ -87,7 +87,7 @@ RestrictionMutator modifies a tree produced by
compiler.transformer.Transformer, restricting and enhancing the
code in various ways before sending it to pycodegen.
'''
__version__='$Revision: 1.2 $'[11:-2]
__version__='$Revision: 1.3 $'[11:-2]
from compiler import ast
from compiler.transformer import parse
......@@ -236,7 +236,12 @@ class RestrictionMutator:
def visitGetattr(self, node, walker):
self.checkAttrName(node)
node = walker.defaultVisitNode(node)
node.expr = ast.CallFunc(_read_guard_name, [node.expr])
expr = [node.expr]
if getattr(node, 'use_dual_guard', 0):
# We're in an augmented assignment
expr.append(_write_guard_name)
self.funcinfo._write_used = 1
node.expr = ast.CallFunc(_read_guard_name, expr)
self.funcinfo._read_used = 1
return node
......@@ -244,7 +249,12 @@ class RestrictionMutator:
node = walker.defaultVisitNode(node)
if node.flags == OP_APPLY:
# get subscript or slice
node.expr = ast.CallFunc(_read_guard_name, [node.expr])
expr = [node.expr]
if getattr(node, 'use_dual_guard', 0):
# We're in an augmented assignment
expr.append(_write_guard_name)
self.funcinfo._write_used = 1
node.expr = ast.CallFunc(_read_guard_name, expr)
self.funcinfo._read_used = 1
elif node.flags in (OP_DELETE, OP_ASSIGN):
# set or remove subscript or slice
......@@ -274,6 +284,10 @@ class RestrictionMutator:
self.prepBody(node.node.nodes)
return node
def visitAugAssign(self, node, walker):
node.node.use_dual_guard = 1
return walker.defaultVisitNode(node)
if __name__ == '__main__':
# A minimal test.
......
......@@ -34,20 +34,49 @@ def allowed_simple():
def allowed_write(ob):
ob.writeable = 1
ob.writeable += 1
[1 for ob.writeable in 1,2]
ob['safe'] = 2
ob['safe'] += 2
[1 for ob['safe'] in 1,2]
def denied_getattr(ob):
ob.disallowed += 1
return ob.disallowed
def denied_setattr(ob):
ob.allowed = -1
def denied_setattr2(ob):
ob.allowed += -1
def denied_setattr3(ob):
[1 for ob.allowed in 1,2]
def denied_getitem(ob):
ob[1]
def denied_getitem2(ob):
ob[1] += 1
def denied_setitem(ob):
ob['x'] = 2
def denied_setitem2(ob):
ob[0] += 2
def denied_setitem3(ob):
[1 for ob['x'] in 1,2]
def denied_setslice(ob):
ob[0:1] = 'a'
def denied_setslice2(ob):
ob[0:1] += 'a'
def denied_setslice3(ob):
[1 for ob[0:1] in 1,2]
def strange_attribute():
# If a guard has attributes with names that don't start with an
# underscore, those attributes appear to be an attribute of
......
......@@ -89,7 +89,7 @@ class RestrictedObject:
def __getitem__(self, idx):
if idx == 'protected':
raise AccessDenied
elif idx == 0:
elif idx == 0 or idx == 'safe':
return 1
elif idx == 1:
return DisallowedObject
......@@ -114,7 +114,7 @@ class RestrictedObject:
class TestGuard:
'''A guard class'''
def __init__(self, _ob):
def __init__(self, _ob, write=None):
self.__dict__['_ob'] = _ob
# Read guard methods
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment