Commit c709f421 authored by Boxiang Sun

[WIP]Use new implementation based on ast package

parent 5976a7da
......@@ -19,48 +19,50 @@ code in various ways before sending it to pycodegen.
$Revision: 1.13 $
"""
from SelectCompiler import ast, parse, OP_ASSIGN, OP_DELETE, OP_APPLY
import ast
from ast import parse
# These utility functions allow us to generate AST subtrees without
# line number attributes. These trees can then be inserted into other
# trees without affecting line numbers shown in tracebacks, etc.
def rmLineno(node):
    """Strip lineno attributes from a code tree.

    Walks *node* and all of its descendants with ``ast.walk`` and deletes
    the ``lineno`` attribute from every node that actually carries one.
    Nodes without a line number (for example nodes built by hand) are
    left alone, so the function can be applied to the same tree twice.
    """
    for child in ast.walk(node):
        # ``_attributes`` only tells us that the node class declares
        # ``lineno``; the instance may still lack it, in which case a
        # plain ``del`` raises AttributeError.
        try:
            del child.lineno
        except AttributeError:
            pass
def stmtNode(txt):
    """Parse *txt* and return its first statement as an AST node.

    The node is returned exactly as the parser produced it, i.e. it still
    carries its position attributes.
    """
    # TODO: stripping the line numbers with rmLineno() makes later
    # processing fail; the reason still has to be figured out, so the
    # node is deliberately left untouched here.
    first_statement = parse(txt).body[0]
    return first_statement
# The security checks are performed by a set of functions that must be
# provided by the restricted environment.  Each constant below is a
# load-context Name node that the transformer inserts into rewritten code.
_apply_name = ast.Name(id="_apply_", ctx=ast.Load())
_getattr_name = ast.Name(id="_getattr_", ctx=ast.Load())
_getitem_name = ast.Name(id="_getitem_", ctx=ast.Load())
_getiter_name = ast.Name(id="_getiter_", ctx=ast.Load())
_print_target_name = ast.Name(id="_print", ctx=ast.Load())
_write_name = ast.Name(id="_write_", ctx=ast.Load())
_inplacevar_name = ast.Name(id="_inplacevar_", ctx=ast.Load())

# Constants.
_None_const = ast.Name(id='None', ctx=ast.Load())

# The call expression '_print()'; it is the value of the expression
# statement parsed from that text.
_printed_expr = stmtNode("_print()").value

# The statement '_print = _print_()'.
_print_target_node = stmtNode("_print = _print_()")
class FuncInfo:
    """Records whether the code of one scope prints or reads 'printed'."""

    # Both flags start out False; the transformer flips them on an
    # instance while it processes a scope.
    print_used = printed_used = False
class RestrictionMutator:
class RestrictionTransformer(ast.NodeTransformer):
def __init__(self):
self.warnings = []
......@@ -108,7 +110,7 @@ class RestrictionMutator:
this underscore protection is important regardless of the
security policy. Special case: '_' is allowed.
"""
name = node.attrname
name = node.attr
if name.startswith("_") and name != "_":
# Note: "_" *is* allowed.
self.error(node, '"%s" is an invalid attribute name '
......@@ -130,7 +132,7 @@ class RestrictionMutator:
self.warnings.append(
"Doesn't print, but reads 'printed' variable.")
def visitFunction(self, node, walker):
def visit_FunctionDef(self, node):
"""Checks and mutates a function definition.
Checks the name of the function and the argument names using
......@@ -138,32 +140,39 @@ class RestrictionMutator:
beginning of the code suite.
"""
self.checkName(node, node.name)
for argname in node.argnames:
for argname in node.args.args:
if isinstance(argname, str):
self.checkName(node, argname)
else:
for name in argname:
self.checkName(node, name)
walker.visitSequence(node.defaults)
# TODO: check sequence!!!
self.checkName(node, argname.id)
# FuncDef.args.defaults is a list.
# FuncDef.args.args is a list, contains ast.Name
# FuncDef.args.kwarg is a list.
# FuncDef.args.vararg is a list.
for i, arg in enumerate(node.args.defaults):
node.args.defaults[i] = self.visit(arg)
former_funcinfo = self.funcinfo
self.funcinfo = FuncInfo()
node = walker.defaultVisitNode(node, exclude=('defaults',))
self.prepBody(node.code.nodes)
for i, item in enumerate(node.body):
node.body[i] = self.visit(item)
self.prepBody(node.body)
self.funcinfo = former_funcinfo
ast.fix_missing_locations(node)
return node
def visit_Lambda(self, node):
    """Checks and mutates an anonymous function definition.

    Checks every argument name using checkName(): plain parameters,
    names nested inside tuple parameters, and the '*args' / '**kwargs'
    names.  The children of the lambda are then visited as usual.
    """
    pending = list(node.args.args)
    while pending:
        arg = pending.pop(0)
        if isinstance(arg, (ast.Tuple, ast.List)):
            # 'lambda (a, b): ...' -- check the nested names in order.
            pending[0:0] = arg.elts
        else:
            self.checkName(node, arg.id)
    # NOTE(review): assumes vararg/kwarg are plain identifier strings (or
    # None), as in the Python 2 ast grammar this module targets.
    for extra in (node.args.vararg, node.args.kwarg):
        if extra is not None:
            self.checkName(node, extra)
    return self.generic_visit(node)
def visitPrint(self, node, walker):
def visit_Print(self, node):
"""Checks and mutates a print statement.
Adds a target to all print statements. 'print foo' becomes
......@@ -178,7 +187,7 @@ class RestrictionMutator:
templates and scripts; 'write' happens to be the name of the
method that changes them.
"""
node = walker.defaultVisitNode(node)
self.generic_visit(node)
self.funcinfo.print_used = True
if node.dest is None:
node.dest = _print_target_name
......@@ -186,36 +195,44 @@ class RestrictionMutator:
# Pre-validate access to the "write" attribute.
# "print >> ob, x" becomes
# "print >> (_getattr(ob, 'write') and ob), x"
node.dest = ast.And([
ast.CallFunc(_getattr_name, [node.dest, _write_const]),
# node.dest = ast.And([
# ast.CallFunc(_getattr_name, [node.dest, _write_const]),
# node.dest])
call_node = ast.Call(_getattr_name, [node.dest, ast.Str('write')], [], None, None)
and_node = ast.And()
node.dest = ast.BoolOp(and_node, [
call_node,
node.dest])
ast.fix_missing_locations(node)
return node
visitPrintnl = visitPrint
# XXX: Does ast.AST still have Printnl???
visitPrintnl = visit_Print
def visit_Name(self, node):
    """Prevents access to protected names as defined by checkName().

    Also converts reads of the name 'printed' to an expression.  Only
    load-context occurrences are replaced: putting the '_print()' call
    where a name is being bound or deleted would yield a tree that
    cannot be compiled.
    """
    if node.id == 'printed' and isinstance(node.ctx, ast.Load):
        # Replace name lookup with an expression.
        self.funcinfo.printed_used = True
        return ast.fix_missing_locations(_printed_expr)
    self.checkName(node, node.id)
    self.used_names[node.id] = True
    return node
def visitCallFunc(self, node, walker):
def visit_Call(self, node):
"""Checks calls with *-args and **-args.
That's a way of spelling apply(), and needs to use our safe
_apply_ instead.
"""
walked = walker.defaultVisitNode(node)
if node.star_args is None and node.dstar_args is None:
# This is not an extended function call
return walked
self.generic_visit(node)
if node.starargs is None and node.kwargs is None:
return node
# Otherwise transform foo(a, b, c, d=e, f=g, *args, **kws) into a call
# of _apply_(foo, a, b, c, d=e, f=g, *args, **kws). The interesting
# thing here is that _apply_() is defined with just *args and **kws,
......@@ -226,16 +243,16 @@ class RestrictionMutator:
# function to call), wraps args and kws in guarded accessors, then
# calls the function, returning the value.
# Transform foo(...) to _apply(foo, ...)
walked.args.insert(0, walked.node)
walked.node = _apply_name
return walked
# walked.args.insert(0, walked.node)
# walked.node = _apply_name
# walked.args.insert(0, walked.func)
node.args.insert(0, node.func)
node.func = _apply_name
# walked.func = _apply_name
return ast.fix_missing_locations(node)
def visitAssName(self, node, walker):
"""Checks a name assignment using checkName()."""
self.checkName(node, node.name)
return node
def visitFor(self, node, walker):
def visit_For(self, node):
# convert
# for x in expr:
# to
......@@ -246,106 +263,117 @@ class RestrictionMutator:
# [... for x in expr ...]
# to
# [... for x in _getiter(expr) ...]
node = walker.defaultVisitNode(node)
node.list = ast.CallFunc(_getiter_name, [node.list])
self.generic_visit(node)
node.iter = ast.Call(_getiter_name, [node.iter], [], None, None)
ast.fix_missing_locations(node)
return node
visitListCompFor = visitFor
# visitListComp = visitFor
def visit_ListComp(self, node):
    """Visits the children of a list comprehension and returns the node."""
    return self.generic_visit(node)
def visit_comprehension(self, node):
    """Guards the target and the iterable of one comprehension clause.

    '[... for x in expr ...]' becomes '[... for x in _getiter(expr) ...]'.
    A tuple literal used as the iterable is handed to unpackSequence()
    instead of being wrapped directly.  A name target is validated with
    checkName(); an attribute target 'a.b' becomes '_write(a).b'.
    """
    target = node.target
    if isinstance(target, ast.Name):
        self.checkName(node, target.id)
    elif isinstance(target, ast.Attribute):
        # XXX: Exception! An attribute target is rewritten by hand.  The
        # object expression is visited first so that nested accesses are
        # guarded as well.
        self.checkAttrName(target)
        target.value = ast.Call(
            _write_name, [self.visit(target.value)], [], None, None)

    # Always transform the iterable expression itself; previously a tuple
    # literal skipped this step and its elements stayed unguarded.
    iterable = self.visit(node.iter)
    if isinstance(iterable, ast.Tuple):
        iterable = self.unpackSequence(iterable)
    else:
        iterable = ast.Call(_getiter_name, [iterable], [], None, None)
    node.iter = iterable

    node.ifs[:] = [self.visit(condition) for condition in node.ifs]
    return ast.fix_missing_locations(node)
def visit_Attribute(self, node):
    """Converts attribute access to a function call.

    'foo.bar' becomes '_getattr(foo, "bar")'.

    An attribute that is being assigned to or deleted cannot be replaced
    by a call; there the object expression is wrapped instead, so that
    'foo.bar' becomes '_write(foo).bar'.
    """
    self.checkAttrName(node)
    # Guard the object expression first so that every level of a chain
    # such as 'a.b.c' is rewritten, not just the outermost one.
    node.value = self.visit(node.value)
    if not isinstance(node.ctx, ast.Load):
        node.value = ast.Call(_write_name, [node.value], [], None, None)
        return ast.fix_missing_locations(node)
    call = ast.Call(_getattr_name,
                    [node.value, ast.Str(node.attr)], [], None, None)
    # Take the position from the node being replaced, then fill in the
    # freshly created children.
    return ast.fix_missing_locations(ast.copy_location(call, node))
def visit_Subscript(self, node):
    """Checks all kinds of subscripts.

    Reading 'foo[bar]' becomes '_getitem(foo, bar)':

    'a = foo[bar, baz]' becomes 'a = _getitem(foo, (bar, baz))'.
    'a = foo[bar]' becomes 'a = _getitem(foo, bar)'.
    'a = foo[bar:baz]' becomes 'a = _getitem(foo, slice(bar, baz, None))'.
    'a = foo[:baz]' becomes 'a = _getitem(foo, slice(None, baz, None))'.
    'a = foo[bar:]' becomes 'a = _getitem(foo, slice(bar, None, None))'.

    A subscript that is being assigned to or deleted keeps its shape
    and only has its object wrapped: 'foo[bar]' becomes
    '_write(foo)[bar]'.

    The _write function returns a security proxy.
    """
    # Transform the object and the index expressions first so nested
    # attribute and item accesses are guarded too.
    self.generic_visit(node)

    if not isinstance(node.ctx, ast.Load):
        node.value = ast.Call(_write_name, [node.value], [], None, None)
        return ast.fix_missing_locations(node)

    if isinstance(node.slice, ast.Index):
        key = node.slice.value
    elif isinstance(node.slice, ast.Slice):
        bounds = (node.slice.lower, node.slice.upper, node.slice.step)
        key = ast.Call(
            ast.Name('slice', ast.Load()),
            [_None_const if bound is None else bound for bound in bounds],
            [], None, None)
    else:
        # NOTE(review): other slice kinds (e.g. extended slices) are
        # returned without a _getitem guard -- confirm this is intended.
        return node

    call = ast.copy_location(
        ast.Call(_getitem_name, [node.value, key], [], None, None), node)
    return ast.fix_missing_locations(call)
visitSlice = visitSubscript
def visitAssAttr(self, node, walker):
"""Checks and mutates attribute assignment.
'a.b = c' becomes '_write(a).b = c'.
The _write function returns a security proxy.
"""
self.checkAttrName(node)
node = walker.defaultVisitNode(node)
node.expr = ast.CallFunc(_write_name, [node.expr])
return node
def visit_Exec(self, node):
    """Reports every exec statement as an error.

    Nothing is returned, which is how a NodeTransformer visitor asks
    for the visited node to be removed.
    """
    self.error(node, 'Exec statements are not allowed.')
    return None
def visit_Yield(self, node):
    """Reports every yield as an error.

    Nothing is returned, which is how a NodeTransformer visitor asks
    for the visited node to be removed.
    """
    self.error(node, 'Yield statements are not allowed.')
    return None
def visitClass(self, node, walker):
def visit_ClassDef(self, node):
"""Checks the name of a class using checkName().
Should classes be allowed at all? They don't cause security
......@@ -353,19 +381,96 @@ class RestrictionMutator:
code can't assign instance attributes.
"""
self.checkName(node, node.name)
return walker.defaultVisitNode(node)
return node
def visit_Module(self, node):
    """Adds prep code at module scope.

    Zope doesn't make use of this.  The body of Python scripts is
    always at function scope.
    """
    self.generic_visit(node)
    self.prepBody(node.body)
    # Give the module node itself a position, then fill in whatever the
    # rewritten children are still missing.
    node.lineno = node.col_offset = 0
    return ast.fix_missing_locations(node)
def visit_Delete(self, node):
    """Guards the targets of a del statement.

    'del foo[bar]' becomes 'del _write(foo)[bar]'.
    'del foo.bar' becomes 'del _write(foo).bar'.
    Deleted names are validated with checkName().  Tuple and list
    targets are handled element by element.
    """
    # The targets are rewritten by hand rather than through visit():
    # a deleted subscript or attribute must keep its shape and only get
    # its object wrapped in _write().
    pending = list(node.targets)
    while pending:
        target = pending.pop(0)
        if isinstance(target, (ast.Tuple, ast.List)):
            pending[0:0] = target.elts
        elif isinstance(target, ast.Name):
            self.checkName(node, target.id)
        elif isinstance(target, ast.Attribute):
            self.checkAttrName(target)
            target.value = ast.Call(
                _write_name, [self.visit(target.value)], [], None, None)
        elif isinstance(target, ast.Subscript):
            # Guard the index expression and the object expression too.
            target.slice = self.visit(target.slice)
            target.value = ast.Call(
                _write_name, [self.visit(target.value)], [], None, None)
    return ast.fix_missing_locations(node)
def visit_With(self, node):
    """Checks and mutates the attribute access in with statement.

    'with x as x.y' becomes 'with x as _write(x).y'

    The _write function returns a security proxy.
    """
    target = node.optional_vars
    if isinstance(target, ast.Name):
        self.checkName(node, target.id)
    elif isinstance(target, ast.Attribute):
        self.checkAttrName(target)
        target.value = ast.Call(
            _write_name, [self.visit(target.value)], [], None, None)

    node.context_expr = self.visit(node.context_expr)

    # Keep what the visitors return: a statement may be replaced by a
    # new node, by a list of nodes, or dropped altogether (None).
    body = []
    for statement in node.body:
        result = self.visit(statement)
        if result is None:
            continue
        if isinstance(result, ast.AST):
            body.append(result)
        else:
            body.extend(result)
    node.body[:] = body

    return ast.fix_missing_locations(node)
def unpackSequence(self, node):
    """Wraps tuple and list displays in a call to the _getiter_ name.

    Nested tuples and lists are wrapped first, innermost first, and the
    elements are replaced in place.  Any other node is returned
    unchanged.
    """
    if not isinstance(node, (ast.Tuple, ast.List)):
        return node
    node.elts[:] = [self.unpackSequence(element) for element in node.elts]
    return ast.Call(_getiter_name, [node], [], None, None)
def visitAugAssign(self, node, walker):
def visit_Assign(self, node):
    """Checks and mutates the targets of an assignment.

    'a.b = c' becomes '_write(a).b = c'.
    'foo[bar] = a' becomes '_write(foo)[bar] = a'.
    Assigned names are validated with checkName().  Tuple and list
    targets are handled element by element.

    The _write function returns a security proxy.
    """
    def guard(target):
        # Targets are rewritten by hand rather than through visit(): an
        # assigned attribute or subscript must keep its shape and only
        # get its object wrapped in _write().
        if isinstance(target, ast.Name):
            self.checkName(node, target.id)
        elif isinstance(target, ast.Attribute):
            self.checkAttrName(target)
            target.value = ast.Call(
                _write_name, [self.visit(target.value)], [], None, None)
        elif isinstance(target, ast.Subscript):
            target.slice = self.visit(target.slice)
            target.value = ast.Call(
                _write_name, [self.visit(target.value)], [], None, None)
        elif isinstance(target, (ast.Tuple, ast.List)):
            for element in target.elts:
                guard(element)

    for target in node.targets:
        guard(target)

    node.value = self.visit(node.value)

    # Sequence literals on the right-hand side are routed through
    # _getiter only when the first target is a tuple; otherwise nothing
    # is unpacked and the value is left alone (the no_unpack test in
    # before_and_after.py relies on this).
    if isinstance(node.targets[0], ast.Tuple):
        node.value = self.unpackSequence(node.value)

    return ast.fix_missing_locations(node)
def visit_AugAssign(self, node):
"""Makes a note that augmented assignment is in use.
Note that although augmented assignment of attributes and
......@@ -375,30 +480,56 @@ class RestrictionMutator:
This could be a problem if untrusted code got access to a
mutable database object that supports augmented assignment.
"""
if node.node.__class__.__name__ == 'Name':
node = walker.defaultVisitNode(node)
newnode = ast.Assign(
[ast.AssName(node.node.name, OP_ASSIGN)],
ast.CallFunc(
_inplacevar_name,
[ast.Const(node.op),
ast.Name(node.node.name),
node.expr,
]
),
)
newnode.lineno = node.lineno
return newnode
else:
node.node.in_aug_assign = True
return walker.defaultVisitNode(node)
# XXX: This error was originally raised in visitGetattr.
# But ast.AST differs from compiler.ast.Node: there is no Getattr
# node, and the corresponding Attribute node carries no information
# about augmented assignment.
# So the transformer would first convert every foo.bar to
# '_getattr(foo, "bar")' and only then enter this function to process
# the augmented operation.
# In that situation we would need to check for ast.Call rather than
# ast.Attribute.
if isinstance(node.target, ast.Subscript):
self.error(node, 'Augment assignment of '
'object items and slices is not allowed.')
elif isinstance(node.target, ast.Attribute):
self.error(node, 'Augmented assignment of '
'attributes is not allowed.')
if isinstance(node.target, ast.Name):
# 'n += bar' becomes 'n = _inplace_var('+=', n, bar)'
# TODO, may contians serious problem. Do we should use ast.Name???
new_node = ast.Assign([node.target], ast.Call(_inplacevar_name, [ast.Name(node.target.id, ast.Load()), node.value], [], None, None))
if isinstance(node.op, ast.Add):
new_node.value.args.insert(0, ast.Str('+='))
elif isinstance(node.op, ast.Sub):
new_node.value.args.insert(0, ast.Str('-='))
elif isinstance(node.op, ast.Mult):
new_node.value.args.insert(0, ast.Str('*='))
elif isinstance(node.op, ast.Div):
new_node.value.args.insert(0, ast.Str('/='))
elif isinstance(node.op, ast.Mod):
new_node.value.args.insert(0, ast.Str('%='))
elif isinstance(node.op, ast.Pow):
new_node.value.args.insert(0, ast.Str('**='))
elif isinstance(node.op, ast.RShift):
new_node.value.args.insert(0, ast.Str('>>='))
elif isinstance(node.op, ast.LShift):
new_node.value.args.insert(0, ast.Str('<<='))
elif isinstance(node.op, ast.BitAnd):
new_node.value.args.insert(0, ast.Str('&='))
elif isinstance(node.op, ast.BitXor):
new_node.value.args.insert(0, ast.Str('^='))
elif isinstance(node.op, ast.BitOr):
new_node.value.args.insert(0, ast.Str('|='))
ast.fix_missing_locations(new_node)
return new_node
ast.fix_missing_locations(node)
return node
def visit_Import(self, node):
    """Checks names imported using checkName().

    For every alias both the imported name and, when one is given, the
    'as' name are validated.  The node itself is returned unchanged.
    """
    for alias in node.names:
        candidates = [alias.name]
        if alias.asname:
            candidates.append(alias.asname)
        for candidate in candidates:
            self.checkName(node, candidate)
    return node

# 'from x import y' statements are checked the same way.
visit_ImportFrom = visit_Import
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment