Commit 844bee64 authored by Vitja Makarov's avatar Vitja Makarov

Evaluate literal defaults only once

parent db36032e
...@@ -6132,6 +6132,8 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin): ...@@ -6132,6 +6132,8 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
nonliteral_objects.append(arg) nonliteral_objects.append(arg)
else: else:
nonliteral_other.append(arg) nonliteral_other.append(arg)
else:
arg.default = DefaultLiteralArgNode(arg.pos, arg.default)
default_args.append(arg) default_args.append(arg)
if nonliteral_objects or nonliteral_objects: if nonliteral_objects or nonliteral_objects:
module_scope = env.global_scope() module_scope = env.global_scope()
...@@ -6408,13 +6410,43 @@ class CodeObjectNode(ExprNode): ...@@ -6408,13 +6410,43 @@ class CodeObjectNode(ExprNode):
)) ))
class DefaultArgNode(ExprNode): class DefaultLiteralArgNode(ExprNode):
# CyFunction's literal argument default value
#
# Evaluate literal only once.
subexprs = []
is_literal = True
is_temp = False
def __init__(self, pos, arg):
super(DefaultLiteralArgNode, self).__init__(pos)
self.arg = arg
self.type = self.arg.type
self.evaluated = False
def analyse_types(self, env):
pass
def generate_result_code(self, code):
pass
def generate_evaluation_code(self, code):
if not self.evaluated:
self.arg.generate_evaluation_code(code)
self.evaluated = True
def result(self):
return self.type.cast_code(self.arg.result())
class DefaultNonLiteralArgNode(ExprNode):
# CyFunction's non-literal argument default value # CyFunction's non-literal argument default value
subexprs = [] subexprs = []
def __init__(self, pos, arg, defaults_struct): def __init__(self, pos, arg, defaults_struct):
super(DefaultArgNode, self).__init__(pos) super(DefaultNonLiteralArgNode, self).__init__(pos)
self.arg = arg self.arg = arg
self.defaults_struct = defaults_struct self.defaults_struct = defaults_struct
...@@ -6438,7 +6470,7 @@ class DefaultsTupleNode(TupleNode): ...@@ -6438,7 +6470,7 @@ class DefaultsTupleNode(TupleNode):
args = [] args = []
for arg in defaults: for arg in defaults:
if not arg.default.is_literal: if not arg.default.is_literal:
arg = DefaultArgNode(pos, arg, defaults_struct) arg = DefaultNonLiteralArgNode(pos, arg, defaults_struct)
else: else:
arg = arg.default arg = arg.default
args.append(arg) args.append(arg)
......
...@@ -5,11 +5,6 @@ ...@@ -5,11 +5,6 @@
import sys import sys
def get_defaults(func): def get_defaults(func):
"""
>>> get_defaults(get_defaults)
>>> hasattr(get_defaults, '__defaults__') and get_defaults.__defaults__
>>> hasattr(get_defaults, 'func_defaults') and get_defaults.func_defaults
"""
if sys.version_info >= (2, 5, 0): if sys.version_info >= (2, 5, 0):
return func.__defaults__ return func.__defaults__
return func.func_defaults return func.func_defaults
...@@ -19,11 +14,20 @@ def test_defaults_none(): ...@@ -19,11 +14,20 @@ def test_defaults_none():
>>> get_defaults(test_defaults_none) >>> get_defaults(test_defaults_none)
""" """
def test_defaults_literal(a=1, b=[], c={}): def test_defaults_literal(a=1, b=(1,2,3)):
""" """
>>> get_defaults(test_defaults_literal) is get_defaults(test_defaults_literal)
True
>>> get_defaults(test_defaults_literal) >>> get_defaults(test_defaults_literal)
(1, [], {}) (1, (1, 2, 3))
>>> a, b = get_defaults(test_defaults_literal)
>>> c, d = test_defaults_literal()
>>> a is c
True
>>> b is d
True
""" """
return a, b
def test_defaults_nonliteral(): def test_defaults_nonliteral():
""" """
...@@ -31,15 +35,27 @@ def test_defaults_nonliteral(): ...@@ -31,15 +35,27 @@ def test_defaults_nonliteral():
>>> get_defaults(f0) is get_defaults(f0) # cached >>> get_defaults(f0) is get_defaults(f0) # cached
True True
>>> get_defaults(f0) >>> get_defaults(f0)
(0, {}) (0, {}, (1, 2, 3))
>>> a, b = get_defaults(f0)[1:]
>>> c, d = f0(0)
>>> a is c
True
>>> b is d
True
>>> get_defaults(f1) is get_defaults(f1) # cached >>> get_defaults(f1) is get_defaults(f1) # cached
True True
>>> get_defaults(f1) >>> get_defaults(f1)
(0, []) (0, [], (1, 2, 3))
>>> a, b = get_defaults(f1)[1:]
>>> c, d = f1(0)
>>> a is c
True
>>> b is d
True
""" """
ret = [] ret = []
for i in {}, []: for i in {}, []:
def foo(a, b=0, c=i): def foo(a, b=0, c=i, d=(1,2,3)):
pass return c, d
ret.append(foo) ret.append(foo)
return ret return ret
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment