Commit d8ac4d1d authored by Serhiy Storchaka's avatar Serhiy Storchaka Committed by GitHub

bpo-31778: Make ast.literal_eval() more strict. (#4035)

Addition and subtraction of arbitrary numbers no longer allowed.
parent fbb490fd
...@@ -35,8 +35,6 @@ def parse(source, filename='<unknown>', mode='exec'): ...@@ -35,8 +35,6 @@ def parse(source, filename='<unknown>', mode='exec'):
return compile(source, filename, mode, PyCF_ONLY_AST) return compile(source, filename, mode, PyCF_ONLY_AST)
_NUM_TYPES = (int, float, complex)
def literal_eval(node_or_string): def literal_eval(node_or_string):
""" """
Safely evaluate an expression node or a string containing a Python Safely evaluate an expression node or a string containing a Python
...@@ -48,6 +46,21 @@ def literal_eval(node_or_string): ...@@ -48,6 +46,21 @@ def literal_eval(node_or_string):
node_or_string = parse(node_or_string, mode='eval') node_or_string = parse(node_or_string, mode='eval')
if isinstance(node_or_string, Expression): if isinstance(node_or_string, Expression):
node_or_string = node_or_string.body node_or_string = node_or_string.body
def _convert_num(node):
if isinstance(node, Constant):
if isinstance(node.value, (int, float, complex)):
return node.value
elif isinstance(node, Num):
return node.n
raise ValueError('malformed node or string: ' + repr(node))
def _convert_signed_num(node):
if isinstance(node, UnaryOp) and isinstance(node.op, (UAdd, USub)):
operand = _convert_num(node.operand)
if isinstance(node.op, UAdd):
return + operand
else:
return - operand
return _convert_num(node)
def _convert(node): def _convert(node):
if isinstance(node, Constant): if isinstance(node, Constant):
return node.value return node.value
...@@ -62,26 +75,19 @@ def literal_eval(node_or_string): ...@@ -62,26 +75,19 @@ def literal_eval(node_or_string):
elif isinstance(node, Set): elif isinstance(node, Set):
return set(map(_convert, node.elts)) return set(map(_convert, node.elts))
elif isinstance(node, Dict): elif isinstance(node, Dict):
return dict((_convert(k), _convert(v)) for k, v return dict(zip(map(_convert, node.keys),
in zip(node.keys, node.values)) map(_convert, node.values)))
elif isinstance(node, NameConstant): elif isinstance(node, NameConstant):
return node.value return node.value
elif isinstance(node, UnaryOp) and isinstance(node.op, (UAdd, USub)):
operand = _convert(node.operand)
if isinstance(operand, _NUM_TYPES):
if isinstance(node.op, UAdd):
return + operand
else:
return - operand
elif isinstance(node, BinOp) and isinstance(node.op, (Add, Sub)): elif isinstance(node, BinOp) and isinstance(node.op, (Add, Sub)):
left = _convert(node.left) left = _convert_signed_num(node.left)
right = _convert(node.right) right = _convert_num(node.right)
if isinstance(left, _NUM_TYPES) and isinstance(right, _NUM_TYPES): if isinstance(left, (int, float)) and isinstance(right, complex):
if isinstance(node.op, Add): if isinstance(node.op, Add):
return left + right return left + right
else: else:
return left - right return left - right
raise ValueError('malformed node or string: ' + repr(node)) return _convert_signed_num(node)
return _convert(node_or_string) return _convert(node_or_string)
......
...@@ -551,14 +551,37 @@ class ASTHelpers_Test(unittest.TestCase): ...@@ -551,14 +551,37 @@ class ASTHelpers_Test(unittest.TestCase):
self.assertEqual(ast.literal_eval('{1, 2, 3}'), {1, 2, 3}) self.assertEqual(ast.literal_eval('{1, 2, 3}'), {1, 2, 3})
self.assertEqual(ast.literal_eval('b"hi"'), b"hi") self.assertEqual(ast.literal_eval('b"hi"'), b"hi")
self.assertRaises(ValueError, ast.literal_eval, 'foo()') self.assertRaises(ValueError, ast.literal_eval, 'foo()')
self.assertEqual(ast.literal_eval('6'), 6)
self.assertEqual(ast.literal_eval('+6'), 6)
self.assertEqual(ast.literal_eval('-6'), -6) self.assertEqual(ast.literal_eval('-6'), -6)
self.assertEqual(ast.literal_eval('-6j+3'), 3-6j)
self.assertEqual(ast.literal_eval('3.25'), 3.25) self.assertEqual(ast.literal_eval('3.25'), 3.25)
self.assertEqual(ast.literal_eval('+3.25'), 3.25)
def test_literal_eval_issue4907(self): self.assertEqual(ast.literal_eval('-3.25'), -3.25)
self.assertEqual(ast.literal_eval('2j'), 2j) self.assertEqual(repr(ast.literal_eval('-0.0')), '-0.0')
self.assertEqual(ast.literal_eval('10 + 2j'), 10 + 2j) self.assertRaises(ValueError, ast.literal_eval, '++6')
self.assertEqual(ast.literal_eval('1.5 - 2j'), 1.5 - 2j) self.assertRaises(ValueError, ast.literal_eval, '+True')
self.assertRaises(ValueError, ast.literal_eval, '2+3')
def test_literal_eval_complex(self):
# Issue #4907
self.assertEqual(ast.literal_eval('6j'), 6j)
self.assertEqual(ast.literal_eval('-6j'), -6j)
self.assertEqual(ast.literal_eval('6.75j'), 6.75j)
self.assertEqual(ast.literal_eval('-6.75j'), -6.75j)
self.assertEqual(ast.literal_eval('3+6j'), 3+6j)
self.assertEqual(ast.literal_eval('-3+6j'), -3+6j)
self.assertEqual(ast.literal_eval('3-6j'), 3-6j)
self.assertEqual(ast.literal_eval('-3-6j'), -3-6j)
self.assertEqual(ast.literal_eval('3.25+6.75j'), 3.25+6.75j)
self.assertEqual(ast.literal_eval('-3.25+6.75j'), -3.25+6.75j)
self.assertEqual(ast.literal_eval('3.25-6.75j'), 3.25-6.75j)
self.assertEqual(ast.literal_eval('-3.25-6.75j'), -3.25-6.75j)
self.assertEqual(ast.literal_eval('(3+6j)'), 3+6j)
self.assertRaises(ValueError, ast.literal_eval, '-6j+3')
self.assertRaises(ValueError, ast.literal_eval, '-6j+3j')
self.assertRaises(ValueError, ast.literal_eval, '3+-6j')
self.assertRaises(ValueError, ast.literal_eval, '3+(0+6j)')
self.assertRaises(ValueError, ast.literal_eval, '-(3+6j)')
def test_bad_integer(self): def test_bad_integer(self):
# issue13436: Bad error message with invalid numeric values # issue13436: Bad error message with invalid numeric values
...@@ -1077,11 +1100,11 @@ class ConstantTests(unittest.TestCase): ...@@ -1077,11 +1100,11 @@ class ConstantTests(unittest.TestCase):
ast.copy_location(new_left, binop.left) ast.copy_location(new_left, binop.left)
binop.left = new_left binop.left = new_left
new_right = ast.Constant(value=20) new_right = ast.Constant(value=20j)
ast.copy_location(new_right, binop.right) ast.copy_location(new_right, binop.right)
binop.right = new_right binop.right = new_right
self.assertEqual(ast.literal_eval(binop), 30) self.assertEqual(ast.literal_eval(binop), 10+20j)
def main(): def main():
......
...@@ -2074,7 +2074,7 @@ class TestSignatureObject(unittest.TestCase): ...@@ -2074,7 +2074,7 @@ class TestSignatureObject(unittest.TestCase):
self.assertEqual(p('f'), False) self.assertEqual(p('f'), False)
self.assertEqual(p('local'), 3) self.assertEqual(p('local'), 3)
self.assertEqual(p('sys'), sys.maxsize) self.assertEqual(p('sys'), sys.maxsize)
self.assertEqual(p('exp'), sys.maxsize - 1) self.assertNotIn('exp', signature.parameters)
test_callable(object) test_callable(object)
......
ast.literal_eval() is now more strict. Addition and subtraction of
arbitrary numbers no longer allowed.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment