Commit e7b8e43e authored by Evan Simpson's avatar Evan Simpson

Handle augmented assignment correctly.

parent f791e0f8
...@@ -84,7 +84,7 @@ ...@@ -84,7 +84,7 @@
############################################################################## ##############################################################################
from __future__ import nested_scopes from __future__ import nested_scopes
__version__='$Revision: 1.2 $'[11:-2] __version__='$Revision: 1.3 $'[11:-2]
import new import new
...@@ -105,10 +105,10 @@ for name in ('None', 'abs', 'chr', 'divmod', 'float', 'hash', 'hex', 'int', ...@@ -105,10 +105,10 @@ for name in ('None', 'abs', 'chr', 'divmod', 'float', 'hash', 'hex', 'int',
def _full_read_guard(g_attr, g_item): def _full_read_guard(g_attr, g_item):
# Nested scope abuse! # Nested scope abuse!
# The two arguments are used by class Wrapper # The arguments are used by class Wrapper
# safetype variable is used by guard() # safetype variable is used by guard()
safetype = {type(()): 1, type([]): 1, type({}): 1, type(''): 1}.has_key safetype = {type(()): 1, type([]): 1, type({}): 1, type(''): 1}.has_key
def guard(ob): def guard(ob, write=None):
# Don't bother wrapping simple types, or objects that claim to # Don't bother wrapping simple types, or objects that claim to
# handle their own read security. # handle their own read security.
if safetype(type(ob)) or getattr(ob, '_guarded_reads', 0): if safetype(type(ob)) or getattr(ob, '_guarded_reads', 0):
...@@ -123,6 +123,11 @@ def _full_read_guard(g_attr, g_item): ...@@ -123,6 +123,11 @@ def _full_read_guard(g_attr, g_item):
def __getitem__(self, i): def __getitem__(self, i):
# Must handle both item and slice access. # Must handle both item and slice access.
return g_item(ob, i) return g_item(ob, i)
# Optional, for combined read/write guard
def __setitem__(self, index, val):
write(ob)[index] = val
def __setattr__(self, attr, val):
setattr(write(ob), attr, val)
return Wrapper() return Wrapper()
return guard return guard
...@@ -143,7 +148,7 @@ def _write_wrapper(): ...@@ -143,7 +148,7 @@ def _write_wrapper():
# Required for slices with negative bounds. # Required for slices with negative bounds.
return len(self.ob) return len(self.ob)
def __init__(self, ob): def __init__(self, ob):
self.ob = ob self.__dict__['ob'] = ob
# Generate class methods # Generate class methods
d = Wrapper.__dict__ d = Wrapper.__dict__
for name, error_msg in ( for name, error_msg in (
...@@ -179,3 +184,6 @@ def guarded_delattr(object, name): ...@@ -179,3 +184,6 @@ def guarded_delattr(object, name):
delattr(full_write_guard(object), name) delattr(full_write_guard(object), name)
safe_builtins['delattr'] = guarded_delattr safe_builtins['delattr'] = guarded_delattr
...@@ -87,7 +87,7 @@ RestrictionMutator modifies a tree produced by ...@@ -87,7 +87,7 @@ RestrictionMutator modifies a tree produced by
compiler.transformer.Transformer, restricting and enhancing the compiler.transformer.Transformer, restricting and enhancing the
code in various ways before sending it to pycodegen. code in various ways before sending it to pycodegen.
''' '''
__version__='$Revision: 1.2 $'[11:-2] __version__='$Revision: 1.3 $'[11:-2]
from compiler import ast from compiler import ast
from compiler.transformer import parse from compiler.transformer import parse
...@@ -236,7 +236,12 @@ class RestrictionMutator: ...@@ -236,7 +236,12 @@ class RestrictionMutator:
def visitGetattr(self, node, walker): def visitGetattr(self, node, walker):
self.checkAttrName(node) self.checkAttrName(node)
node = walker.defaultVisitNode(node) node = walker.defaultVisitNode(node)
node.expr = ast.CallFunc(_read_guard_name, [node.expr]) expr = [node.expr]
if getattr(node, 'use_dual_guard', 0):
# We're in an augmented assignment
expr.append(_write_guard_name)
self.funcinfo._write_used = 1
node.expr = ast.CallFunc(_read_guard_name, expr)
self.funcinfo._read_used = 1 self.funcinfo._read_used = 1
return node return node
...@@ -244,7 +249,12 @@ class RestrictionMutator: ...@@ -244,7 +249,12 @@ class RestrictionMutator:
node = walker.defaultVisitNode(node) node = walker.defaultVisitNode(node)
if node.flags == OP_APPLY: if node.flags == OP_APPLY:
# get subscript or slice # get subscript or slice
node.expr = ast.CallFunc(_read_guard_name, [node.expr]) expr = [node.expr]
if getattr(node, 'use_dual_guard', 0):
# We're in an augmented assignment
expr.append(_write_guard_name)
self.funcinfo._write_used = 1
node.expr = ast.CallFunc(_read_guard_name, expr)
self.funcinfo._read_used = 1 self.funcinfo._read_used = 1
elif node.flags in (OP_DELETE, OP_ASSIGN): elif node.flags in (OP_DELETE, OP_ASSIGN):
# set or remove subscript or slice # set or remove subscript or slice
...@@ -274,6 +284,10 @@ class RestrictionMutator: ...@@ -274,6 +284,10 @@ class RestrictionMutator:
self.prepBody(node.node.nodes) self.prepBody(node.node.nodes)
return node return node
def visitAugAssign(self, node, walker):
node.node.use_dual_guard = 1
return walker.defaultVisitNode(node)
if __name__ == '__main__': if __name__ == '__main__':
# A minimal test. # A minimal test.
......
...@@ -34,20 +34,49 @@ def allowed_simple(): ...@@ -34,20 +34,49 @@ def allowed_simple():
def allowed_write(ob): def allowed_write(ob):
ob.writeable = 1 ob.writeable = 1
ob.writeable += 1
[1 for ob.writeable in 1,2]
ob['safe'] = 2 ob['safe'] = 2
ob['safe'] += 2
[1 for ob['safe'] in 1,2]
def denied_getattr(ob): def denied_getattr(ob):
ob.disallowed += 1
return ob.disallowed return ob.disallowed
def denied_setattr(ob): def denied_setattr(ob):
ob.allowed = -1 ob.allowed = -1
def denied_setattr2(ob):
ob.allowed += -1
def denied_setattr3(ob):
[1 for ob.allowed in 1,2]
def denied_getitem(ob):
ob[1]
def denied_getitem2(ob):
ob[1] += 1
def denied_setitem(ob): def denied_setitem(ob):
ob['x'] = 2 ob['x'] = 2
def denied_setitem2(ob):
ob[0] += 2
def denied_setitem3(ob):
[1 for ob['x'] in 1,2]
def denied_setslice(ob): def denied_setslice(ob):
ob[0:1] = 'a' ob[0:1] = 'a'
def denied_setslice2(ob):
ob[0:1] += 'a'
def denied_setslice3(ob):
[1 for ob[0:1] in 1,2]
def strange_attribute(): def strange_attribute():
# If a guard has attributes with names that don't start with an # If a guard has attributes with names that don't start with an
# underscore, those attributes appear to be an attribute of # underscore, those attributes appear to be an attribute of
......
...@@ -89,7 +89,7 @@ class RestrictedObject: ...@@ -89,7 +89,7 @@ class RestrictedObject:
def __getitem__(self, idx): def __getitem__(self, idx):
if idx == 'protected': if idx == 'protected':
raise AccessDenied raise AccessDenied
elif idx == 0: elif idx == 0 or idx == 'safe':
return 1 return 1
elif idx == 1: elif idx == 1:
return DisallowedObject return DisallowedObject
...@@ -114,7 +114,7 @@ class RestrictedObject: ...@@ -114,7 +114,7 @@ class RestrictedObject:
class TestGuard: class TestGuard:
'''A guard class''' '''A guard class'''
def __init__(self, _ob): def __init__(self, _ob, write=None):
self.__dict__['_ob'] = _ob self.__dict__['_ob'] = _ob
# Read guard methods # Read guard methods
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment