Commit 844bee64 authored by Vitja Makarov's avatar Vitja Makarov

Evaluate literal defaults only once

parent db36032e
......@@ -6132,6 +6132,8 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
nonliteral_objects.append(arg)
else:
nonliteral_other.append(arg)
else:
arg.default = DefaultLiteralArgNode(arg.pos, arg.default)
default_args.append(arg)
if nonliteral_objects or nonliteral_objects:
module_scope = env.global_scope()
......@@ -6408,13 +6410,43 @@ class CodeObjectNode(ExprNode):
))
class DefaultArgNode(ExprNode):
class DefaultLiteralArgNode(ExprNode):
# CyFunction's literal argument default value
#
# Evaluate literal only once.
subexprs = []
is_literal = True
is_temp = False
def __init__(self, pos, arg):
super(DefaultLiteralArgNode, self).__init__(pos)
self.arg = arg
self.type = self.arg.type
self.evaluated = False
def analyse_types(self, env):
pass
def generate_result_code(self, code):
pass
def generate_evaluation_code(self, code):
if not self.evaluated:
self.arg.generate_evaluation_code(code)
self.evaluated = True
def result(self):
return self.type.cast_code(self.arg.result())
class DefaultNonLiteralArgNode(ExprNode):
# CyFunction's non-literal argument default value
subexprs = []
def __init__(self, pos, arg, defaults_struct):
super(DefaultArgNode, self).__init__(pos)
super(DefaultNonLiteralArgNode, self).__init__(pos)
self.arg = arg
self.defaults_struct = defaults_struct
......@@ -6438,7 +6470,7 @@ class DefaultsTupleNode(TupleNode):
args = []
for arg in defaults:
if not arg.default.is_literal:
arg = DefaultArgNode(pos, arg, defaults_struct)
arg = DefaultNonLiteralArgNode(pos, arg, defaults_struct)
else:
arg = arg.default
args.append(arg)
......
......@@ -5,11 +5,6 @@
import sys
def get_defaults(func):
"""
>>> get_defaults(get_defaults)
>>> hasattr(get_defaults, '__defaults__') and get_defaults.__defaults__
>>> hasattr(get_defaults, 'func_defaults') and get_defaults.func_defaults
"""
if sys.version_info >= (2, 5, 0):
return func.__defaults__
return func.func_defaults
......@@ -19,11 +14,20 @@ def test_defaults_none():
>>> get_defaults(test_defaults_none)
"""
def test_defaults_literal(a=1, b=[], c={}):
def test_defaults_literal(a=1, b=(1,2,3)):
"""
>>> get_defaults(test_defaults_literal) is get_defaults(test_defaults_literal)
True
>>> get_defaults(test_defaults_literal)
(1, [], {})
(1, (1, 2, 3))
>>> a, b = get_defaults(test_defaults_literal)
>>> c, d = test_defaults_literal()
>>> a is c
True
>>> b is d
True
"""
return a, b
def test_defaults_nonliteral():
"""
......@@ -31,15 +35,27 @@ def test_defaults_nonliteral():
>>> get_defaults(f0) is get_defaults(f0) # cached
True
>>> get_defaults(f0)
(0, {})
(0, {}, (1, 2, 3))
>>> a, b = get_defaults(f0)[1:]
>>> c, d = f0(0)
>>> a is c
True
>>> b is d
True
>>> get_defaults(f1) is get_defaults(f1) # cached
True
>>> get_defaults(f1)
(0, [])
(0, [], (1, 2, 3))
>>> a, b = get_defaults(f1)[1:]
>>> c, d = f1(0)
>>> a is c
True
>>> b is d
True
"""
ret = []
for i in {}, []:
def foo(a, b=0, c=i):
pass
def foo(a, b=0, c=i, d=(1,2,3)):
return c, d
ret.append(foo)
return ret
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment