Commit fe0eeeb3 authored by da-woods's avatar da-woods Committed by GitHub

Implement PEP 572: Named/Assignment Expressions (GH-3691)

Closes https://github.com/cython/cython/issues/2636
parent cf88658e
...@@ -615,6 +615,9 @@ class ExprNode(Node): ...@@ -615,6 +615,9 @@ class ExprNode(Node):
def analyse_target_declaration(self, env): def analyse_target_declaration(self, env):
error(self.pos, "Cannot assign to or delete this") error(self.pos, "Cannot assign to or delete this")
def analyse_assignment_expression_target_declaration(self, env):
error(self.pos, "Cannot use anything except a name in an assignment expression")
# ------------- Expression Analysis ---------------- # ------------- Expression Analysis ----------------
def analyse_const_expression(self, env): def analyse_const_expression(self, env):
...@@ -2083,8 +2086,17 @@ class NameNode(AtomicExprNode): ...@@ -2083,8 +2086,17 @@ class NameNode(AtomicExprNode):
return None return None
def analyse_target_declaration(self, env): def analyse_target_declaration(self, env):
return self._analyse_target_declaration(env, is_assignment_expression=False)
def analyse_assignment_expression_target_declaration(self, env):
return self._analyse_target_declaration(env, is_assignment_expression=True)
def _analyse_target_declaration(self, env, is_assignment_expression):
self.is_target = True self.is_target = True
if not self.entry: if not self.entry:
if is_assignment_expression:
self.entry = env.lookup_assignment_expression_target(self.name)
else:
self.entry = env.lookup_here(self.name) self.entry = env.lookup_here(self.name)
if not self.entry and self.annotation is not None: if not self.entry and self.annotation is not None:
# name : type = ... # name : type = ...
...@@ -2096,6 +2108,9 @@ class NameNode(AtomicExprNode): ...@@ -2096,6 +2108,9 @@ class NameNode(AtomicExprNode):
type = unspecified_type type = unspecified_type
else: else:
type = py_object_type type = py_object_type
if is_assignment_expression:
self.entry = env.declare_assignment_expression_target(self.name, type, self.pos)
else:
self.entry = env.declare_var(self.name, type, self.pos) self.entry = env.declare_var(self.name, type, self.pos)
if self.entry.is_declared_generic: if self.entry.is_declared_generic:
self.result_ctype = py_object_type self.result_ctype = py_object_type
...@@ -13715,17 +13730,17 @@ class ProxyNode(CoercionNode): ...@@ -13715,17 +13730,17 @@ class ProxyNode(CoercionNode):
def __init__(self, arg): def __init__(self, arg):
super(ProxyNode, self).__init__(arg) super(ProxyNode, self).__init__(arg)
self.constant_result = arg.constant_result self.constant_result = arg.constant_result
self._proxy_type() self.update_type_and_entry()
def analyse_types(self, env): def analyse_types(self, env):
self.arg = self.arg.analyse_expressions(env) self.arg = self.arg.analyse_expressions(env)
self._proxy_type() self.update_type_and_entry()
return self return self
def infer_type(self, env): def infer_type(self, env):
return self.arg.infer_type(env) return self.arg.infer_type(env)
def _proxy_type(self): def update_type_and_entry(self):
type = getattr(self.arg, 'type', None) type = getattr(self.arg, 'type', None)
if type: if type:
self.type = type self.type = type
...@@ -13989,3 +14004,102 @@ class AnnotationNode(ExprNode): ...@@ -13989,3 +14004,102 @@ class AnnotationNode(ExprNode):
else: else:
warning(annotation.pos, "Unknown type declaration in annotation, ignoring") warning(annotation.pos, "Unknown type declaration in annotation, ignoring")
return base_type, arg_type return base_type, arg_type
class AssignmentExpressionNode(ExprNode):
"""
Also known as a named expression or the walrus operator
Arguments
lhs - NameNode - not stored directly as an attribute of the node
rhs - ExprNode
Attributes
rhs - ExprNode
assignment - SingleAssignmentNode
"""
# subexprs and child_attrs are intentionally different here, because the assignment is not an expression
subexprs = ["rhs"]
child_attrs = ["rhs", "assignment"] # This order is important for control-flow (i.e. xdecref) to be right
is_temp = False
assignment = None
clone_node = None
def __init__(self, pos, lhs, rhs, **kwds):
super(AssignmentExpressionNode, self).__init__(pos, **kwds)
self.rhs = ProxyNode(rhs)
assign_expr_rhs = CloneNode(self.rhs)
self.assignment = SingleAssignmentNode(
pos, lhs=lhs, rhs=assign_expr_rhs, is_assignment_expression=True)
@property
def type(self):
return self.rhs.type
@property
def target_name(self):
return self.assignment.lhs.name
def infer_type(self, env):
return self.rhs.infer_type(env)
def analyse_declarations(self, env):
self.assignment.analyse_declarations(env)
def analyse_types(self, env):
# we're trying to generate code that looks roughly like:
# __pyx_t_1 = rhs
# lhs = __pyx_t_1
# __pyx_t_1
# (plus any reference counting that's needed)
self.rhs = self.rhs.analyse_types(env)
if not self.rhs.arg.is_temp:
if not self.rhs.arg.is_literal:
# for anything but the simplest cases (where it can be used directly)
# we convert rhs to a temp, because CloneNode requires arg to be a temp
self.rhs.arg = self.rhs.arg.coerce_to_temp(env)
else:
# For literals we can optimize by just using the literal twice
#
# We aren't including `self.rhs.is_name` in this optimization
# because that goes wrong for assignment expressions run in
# parallel. e.g. `(a := b) + (b := a + c)`)
# This is a special case of https://github.com/cython/cython/issues/4146
# TODO - once that's fixed general revisit this code and possibly
# use coerce_to_simple
self.assignment.rhs = copy.copy(self.rhs)
# TODO - there's a missed optimization in the code generation stage
# for self.rhs.arg.is_temp: an incref/decref pair can be removed
# (but needs a general mechanism to do that)
self.assignment = self.assignment.analyse_types(env)
return self
def coerce_to(self, dst_type, env):
if dst_type == self.assignment.rhs.type:
# in this quite common case (for example, when both lhs, and self are being coerced to Python)
# we can optimize the coercion out by sharing it between
# this and the assignment
old_rhs_arg = self.rhs.arg
if isinstance(old_rhs_arg, CoerceToTempNode):
old_rhs_arg = old_rhs_arg.arg
rhs_arg = old_rhs_arg.coerce_to(dst_type, env)
if rhs_arg is not old_rhs_arg:
self.rhs.arg = rhs_arg
self.rhs.update_type_and_entry()
# clean up the old coercion node that the assignment has likely generated
if (isinstance(self.assignment.rhs, CoercionNode)
and not isinstance(self.assignment.rhs, CloneNode)):
self.assignment.rhs = self.assignment.rhs.arg
self.assignment.rhs.type = self.assignment.rhs.arg.type
return self
return super(AssignmentExpressionNode, self).coerce_to(dst_type, env)
def calculate_result_code(self):
return self.rhs.result()
def generate_result_code(self, code):
# we have to do this manually because it isn't a subexpression
self.assignment.generate_execution_code(code)
...@@ -590,7 +590,7 @@ def check_definitions(flow, compiler_directives): ...@@ -590,7 +590,7 @@ def check_definitions(flow, compiler_directives):
if (node.allow_null or entry.from_closure if (node.allow_null or entry.from_closure
or entry.is_pyclass_attr or entry.type.is_error): or entry.is_pyclass_attr or entry.type.is_error):
pass # Can be uninitialized here pass # Can be uninitialized here
elif node.cf_is_null: elif node.cf_is_null and not entry.in_closure:
if entry.error_on_uninitialized or ( if entry.error_on_uninitialized or (
Options.error_on_uninitialized and ( Options.error_on_uninitialized and (
entry.type.is_pyobject or entry.type.is_unspecified)): entry.type.is_pyobject or entry.type.is_unspecified)):
...@@ -604,10 +604,12 @@ def check_definitions(flow, compiler_directives): ...@@ -604,10 +604,12 @@ def check_definitions(flow, compiler_directives):
"local variable '%s' referenced before assignment" "local variable '%s' referenced before assignment"
% entry.name) % entry.name)
elif warn_maybe_uninitialized: elif warn_maybe_uninitialized:
msg = "local variable '%s' might be referenced before assignment" % entry.name
if entry.in_closure:
msg += " (maybe initialized inside a closure)"
messages.warning( messages.warning(
node.pos, node.pos,
"local variable '%s' might be referenced before assignment" msg)
% entry.name)
elif Unknown in node.cf_state: elif Unknown in node.cf_state:
# TODO: better cross-closure analysis to know when inner functions # TODO: better cross-closure analysis to know when inner functions
# are being called before a variable is being set, and when # are being called before a variable is being set, and when
......
...@@ -77,7 +77,7 @@ def make_lexicon(): ...@@ -77,7 +77,7 @@ def make_lexicon():
punct = Any(":,;+-*/|&<>=.%`~^?!@") punct = Any(":,;+-*/|&<>=.%`~^?!@")
diphthong = Str("==", "<>", "!=", "<=", ">=", "<<", ">>", "**", "//", diphthong = Str("==", "<>", "!=", "<=", ">=", "<<", ">>", "**", "//",
"+=", "-=", "*=", "/=", "%=", "|=", "^=", "&=", "+=", "-=", "*=", "/=", "%=", "|=", "^=", "&=",
"<<=", ">>=", "**=", "//=", "->", "@=", "&&", "||") "<<=", ">>=", "**=", "//=", "->", "@=", "&&", "||", ':=')
spaces = Rep1(Any(" \t\f")) spaces = Rep1(Any(" \t\f"))
escaped_newline = Str("\\\n") escaped_newline = Str("\\\n")
lineterm = Eol + Opt(Str("\n")) lineterm = Eol + Opt(Str("\n"))
......
...@@ -23,7 +23,7 @@ from . import PyrexTypes ...@@ -23,7 +23,7 @@ from . import PyrexTypes
from . import TypeSlots from . import TypeSlots
from .PyrexTypes import py_object_type, error_type from .PyrexTypes import py_object_type, error_type
from .Symtab import (ModuleScope, LocalScope, ClosureScope, PropertyScope, from .Symtab import (ModuleScope, LocalScope, ClosureScope, PropertyScope,
StructOrUnionScope, PyClassScope, CppClassScope, TemplateScope, StructOrUnionScope, PyClassScope, CppClassScope, TemplateScope, GeneratorExpressionScope,
CppScopedEnumScope, punycodify_name) CppScopedEnumScope, punycodify_name)
from .Code import UtilityCode from .Code import UtilityCode
from .StringEncoding import EncodedString from .StringEncoding import EncodedString
...@@ -1744,6 +1744,7 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1744,6 +1744,7 @@ class FuncDefNode(StatNode, BlockNode):
needs_outer_scope = False needs_outer_scope = False
pymethdef_required = False pymethdef_required = False
is_generator = False is_generator = False
is_generator_expression = False # this can be True alongside is_generator
is_coroutine = False is_coroutine = False
is_asyncgen = False is_asyncgen = False
is_generator_body = False is_generator_body = False
...@@ -1815,7 +1816,8 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -1815,7 +1816,8 @@ class FuncDefNode(StatNode, BlockNode):
while genv.is_py_class_scope or genv.is_c_class_scope: while genv.is_py_class_scope or genv.is_c_class_scope:
genv = genv.outer_scope genv = genv.outer_scope
if self.needs_closure: if self.needs_closure:
lenv = ClosureScope(name=self.entry.name, cls = GeneratorExpressionScope if self.is_generator_expression else ClosureScope
lenv = cls(name=self.entry.name,
outer_scope=genv, outer_scope=genv,
parent_scope=env, parent_scope=env,
scope_name=self.entry.cname) scope_name=self.entry.cname)
...@@ -5748,12 +5750,14 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -5748,12 +5750,14 @@ class SingleAssignmentNode(AssignmentNode):
# rhs ExprNode Right hand side # rhs ExprNode Right hand side
# first bool Is this guaranteed the first assignment to lhs? # first bool Is this guaranteed the first assignment to lhs?
# is_overloaded_assignment bool Is this assignment done via an overloaded operator= # is_overloaded_assignment bool Is this assignment done via an overloaded operator=
# is_assignment_expression bool Internally SingleAssignmentNode is used to implement assignment expressions
# exception_check # exception_check
# exception_value # exception_value
child_attrs = ["lhs", "rhs"] child_attrs = ["lhs", "rhs"]
first = False first = False
is_overloaded_assignment = False is_overloaded_assignment = False
is_assignment_expression = False
declaration_only = False declaration_only = False
def analyse_declarations(self, env): def analyse_declarations(self, env):
...@@ -5837,6 +5841,9 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -5837,6 +5841,9 @@ class SingleAssignmentNode(AssignmentNode):
if self.declaration_only: if self.declaration_only:
return return
else:
if self.is_assignment_expression:
self.lhs.analyse_assignment_expression_target_declaration(env)
else: else:
self.lhs.analyse_target_declaration(env) self.lhs.analyse_target_declaration(env)
......
...@@ -183,6 +183,8 @@ class PostParse(ScopeTrackingTransform): ...@@ -183,6 +183,8 @@ class PostParse(ScopeTrackingTransform):
Note: Currently Parsing.py does a lot of interpretation and Note: Currently Parsing.py does a lot of interpretation and
reorganization that can be refactored into this transform reorganization that can be refactored into this transform
if a more pure Abstract Syntax Tree is wanted. if a more pure Abstract Syntax Tree is wanted.
- Some invalid uses of := assignment expressions are detected
""" """
def __init__(self, context): def __init__(self, context):
super(PostParse, self).__init__(context) super(PostParse, self).__init__(context)
...@@ -215,7 +217,9 @@ class PostParse(ScopeTrackingTransform): ...@@ -215,7 +217,9 @@ class PostParse(ScopeTrackingTransform):
node.def_node = Nodes.DefNode( node.def_node = Nodes.DefNode(
node.pos, name=node.name, doc=None, node.pos, name=node.name, doc=None,
args=[], star_arg=None, starstar_arg=None, args=[], star_arg=None, starstar_arg=None,
body=node.loop, is_async_def=collector.has_await) body=node.loop, is_async_def=collector.has_await,
is_generator_expression=True)
_AssignmentExpressionChecker.do_checks(node.loop, scope_is_class=self.scope_type in ("pyclass", "cclass"))
self.visitchildren(node) self.visitchildren(node)
return node return node
...@@ -226,6 +230,7 @@ class PostParse(ScopeTrackingTransform): ...@@ -226,6 +230,7 @@ class PostParse(ScopeTrackingTransform):
collector.visitchildren(node.loop) collector.visitchildren(node.loop)
if collector.has_await: if collector.has_await:
node.has_local_scope = True node.has_local_scope = True
_AssignmentExpressionChecker.do_checks(node.loop, scope_is_class=self.scope_type in ("pyclass", "cclass"))
self.visitchildren(node) self.visitchildren(node)
return node return node
...@@ -378,6 +383,124 @@ class PostParse(ScopeTrackingTransform): ...@@ -378,6 +383,124 @@ class PostParse(ScopeTrackingTransform):
self.visitchildren(node) self.visitchildren(node)
return node return node
class _AssignmentExpressionTargetNameFinder(TreeVisitor):
def __init__(self):
super(_AssignmentExpressionTargetNameFinder, self).__init__()
self.target_names = {}
def find_target_names(self, target):
if target.is_name:
return [target.name]
elif target.is_sequence_constructor:
names = []
for arg in target.args:
names.extend(self.find_target_names(arg))
return names
# other targets are possible, but it isn't necessary to investigate them here
return []
def visit_ForInStatNode(self, node):
self.target_names[node] = tuple(self.find_target_names(node.target))
self.visitchildren(node)
def visit_ComprehensionNode(self, node):
pass # don't recurse into nested comprehensions
def visit_LambdaNode(self, node):
pass # don't recurse into nested lambdas/generator expressions
def visit_Node(self, node):
self.visitchildren(node)
class _AssignmentExpressionChecker(TreeVisitor):
"""
Enforces rules on AssignmentExpressions within generator expressions and comprehensions
"""
def __init__(self, loop_node, scope_is_class):
super(_AssignmentExpressionChecker, self).__init__()
target_name_finder = _AssignmentExpressionTargetNameFinder()
target_name_finder.visit(loop_node)
self.target_names_dict = target_name_finder.target_names
self.in_iterator = False
self.in_nested_generator = False
self.scope_is_class = scope_is_class
self.current_target_names = ()
self.all_target_names = set()
for names in self.target_names_dict.values():
self.all_target_names.update(names)
def _reset_state(self):
old_state = (self.in_iterator, self.in_nested_generator, self.scope_is_class, self.all_target_names, self.current_target_names)
# note: not resetting self.in_iterator here, see visit_LambdaNode() below
self.in_nested_generator = False
self.scope_is_class = False
self.current_target_names = ()
self.all_target_names = set()
return old_state
def _set_state(self, old_state):
self.in_iterator, self.in_nested_generator, self.scope_is_class, self.all_target_names, self.current_target_names = old_state
@classmethod
def do_checks(cls, loop_node, scope_is_class):
checker = cls(loop_node, scope_is_class)
checker.visit(loop_node)
def visit_ForInStatNode(self, node):
if self.in_nested_generator:
self.visitchildren(node) # once nested, don't do anything special
return
current_target_names = self.current_target_names
target_name = self.target_names_dict.get(node, None)
if target_name:
self.current_target_names += target_name
self.in_iterator = True
self.visit(node.iterator)
self.in_iterator = False
self.visitchildren(node, exclude=("iterator",))
self.current_target_names = current_target_names
def visit_AssignmentExpressionNode(self, node):
if self.in_iterator:
error(node.pos, "assignment expression cannot be used in a comprehension iterable expression")
if self.scope_is_class:
error(node.pos, "assignment expression within a comprehension cannot be used in a class body")
if node.target_name in self.current_target_names:
error(node.pos, "assignment expression cannot rebind comprehension iteration variable '%s'" %
node.target_name)
elif node.target_name in self.all_target_names:
error(node.pos, "comprehension inner loop cannot rebind assignment expression target '%s'" %
node.target_name)
def visit_LambdaNode(self, node):
# Don't reset "in_iterator" - an assignment expression in a lambda in an
# iterator is explicitly tested by the Python testcases and banned.
old_state = self._reset_state()
# the lambda node's "def_node" is not set up at this point, so we need to recurse into it explicitly.
self.visit(node.result_expr)
self._set_state(old_state)
def visit_ComprehensionNode(self, node):
in_nested_generator = self.in_nested_generator
self.in_nested_generator = True
self.visitchildren(node)
self.in_nested_generator = in_nested_generator
def visit_GeneratorExpressionNode(self, node):
in_nested_generator = self.in_nested_generator
self.in_nested_generator = True
# def_node isn't set up yet, so we need to visit the loop directly.
self.visit(node.loop)
self.in_nested_generator = in_nested_generator
def visit_Node(self, node):
self.visitchildren(node)
def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence): def eliminate_rhs_duplicates(expr_list_list, ref_node_sequence):
"""Replace rhs items by LetRefNodes if they appear more than once. """Replace rhs items by LetRefNodes if they appear more than once.
...@@ -2269,6 +2392,11 @@ if VALUE is not None: ...@@ -2269,6 +2392,11 @@ if VALUE is not None:
property.doc = entry.doc property.doc = entry.doc
return property return property
def visit_AssignmentExpressionNode(self, node):
self.visitchildren(node)
node.analyse_declarations(self.current_env())
return node
class CalculateQualifiedNamesTransform(EnvTransform): class CalculateQualifiedNamesTransform(EnvTransform):
""" """
...@@ -2806,7 +2934,8 @@ class MarkClosureVisitor(CythonTransform): ...@@ -2806,7 +2934,8 @@ class MarkClosureVisitor(CythonTransform):
star_arg=node.star_arg, starstar_arg=node.starstar_arg, star_arg=node.star_arg, starstar_arg=node.starstar_arg,
doc=node.doc, decorators=node.decorators, doc=node.doc, decorators=node.decorators,
gbody=gbody, lambda_name=node.lambda_name, gbody=gbody, lambda_name=node.lambda_name,
return_type_annotation=node.return_type_annotation) return_type_annotation=node.return_type_annotation,
is_generator_expression=node.is_generator_expression)
return coroutine return coroutine
def visit_CFuncDefNode(self, node): def visit_CFuncDefNode(self, node):
......
...@@ -23,14 +23,15 @@ cdef tuple p_binop_operator(PyrexScanner s) ...@@ -23,14 +23,15 @@ cdef tuple p_binop_operator(PyrexScanner s)
cdef p_binop_expr(PyrexScanner s, ops, p_sub_expr_func p_sub_expr) cdef p_binop_expr(PyrexScanner s, ops, p_sub_expr_func p_sub_expr)
cdef p_lambdef(PyrexScanner s, bint allow_conditional=*) cdef p_lambdef(PyrexScanner s, bint allow_conditional=*)
cdef p_lambdef_nocond(PyrexScanner s) cdef p_lambdef_nocond(PyrexScanner s)
cdef p_test(PyrexScanner s) cdef p_test(PyrexScanner s, bint allow_assignment_expression=*)
cdef p_test_nocond(PyrexScanner s) cdef p_test_nocond(PyrexScanner s, bint allow_assignment_expression=*)
cdef p_walrus_test(PyrexScanner s, bint allow_assignment_expression=*)
cdef p_or_test(PyrexScanner s) cdef p_or_test(PyrexScanner s)
cdef p_rassoc_binop_expr(PyrexScanner s, unicode op, p_sub_expr_func p_subexpr) cdef p_rassoc_binop_expr(PyrexScanner s, unicode op, p_sub_expr_func p_subexpr)
cdef p_and_test(PyrexScanner s) cdef p_and_test(PyrexScanner s)
cdef p_not_test(PyrexScanner s) cdef p_not_test(PyrexScanner s)
cdef p_comparison(PyrexScanner s) cdef p_comparison(PyrexScanner s)
cdef p_test_or_starred_expr(PyrexScanner s) cdef p_test_or_starred_expr(PyrexScanner s, bint is_expression=*)
cdef p_starred_expr(PyrexScanner s) cdef p_starred_expr(PyrexScanner s)
cdef p_cascaded_cmp(PyrexScanner s) cdef p_cascaded_cmp(PyrexScanner s)
cdef p_cmp_op(PyrexScanner s) cdef p_cmp_op(PyrexScanner s)
...@@ -86,7 +87,7 @@ cdef p_simple_expr_list(PyrexScanner s, expr=*) ...@@ -86,7 +87,7 @@ cdef p_simple_expr_list(PyrexScanner s, expr=*)
cdef p_test_or_starred_expr_list(PyrexScanner s, expr=*) cdef p_test_or_starred_expr_list(PyrexScanner s, expr=*)
cdef p_testlist(PyrexScanner s) cdef p_testlist(PyrexScanner s)
cdef p_testlist_star_expr(PyrexScanner s) cdef p_testlist_star_expr(PyrexScanner s)
cdef p_testlist_comp(PyrexScanner s) cdef p_testlist_comp(PyrexScanner s, bint is_expression=*)
cdef p_genexp(PyrexScanner s, expr) cdef p_genexp(PyrexScanner s, expr)
#------------------------------------------------------- #-------------------------------------------------------
......
...@@ -120,9 +120,9 @@ def p_lambdef(s, allow_conditional=True): ...@@ -120,9 +120,9 @@ def p_lambdef(s, allow_conditional=True):
s, terminator=':', annotated=False) s, terminator=':', annotated=False)
s.expect(':') s.expect(':')
if allow_conditional: if allow_conditional:
expr = p_test(s) expr = p_test(s, allow_assignment_expression=False)
else: else:
expr = p_test_nocond(s) expr = p_test_nocond(s, allow_assignment_expression=False)
return ExprNodes.LambdaNode( return ExprNodes.LambdaNode(
pos, args = args, pos, args = args,
star_arg = star_arg, starstar_arg = starstar_arg, star_arg = star_arg, starstar_arg = starstar_arg,
...@@ -135,14 +135,16 @@ def p_lambdef_nocond(s): ...@@ -135,14 +135,16 @@ def p_lambdef_nocond(s):
#test: or_test ['if' or_test 'else' test] | lambdef #test: or_test ['if' or_test 'else' test] | lambdef
def p_test(s): def p_test(s, allow_assignment_expression=True):
if s.sy == 'lambda': if s.sy == 'lambda':
return p_lambdef(s) return p_lambdef(s)
pos = s.position() pos = s.position()
expr = p_or_test(s) expr = p_walrus_test(s, allow_assignment_expression)
if s.sy == 'if': if s.sy == 'if':
s.next() s.next()
test = p_or_test(s) # Assignment expressions are always allowed here
# even if they wouldn't be allowed in the expression as a whole.
test = p_walrus_test(s)
s.expect('else') s.expect('else')
other = p_test(s) other = p_test(s)
return ExprNodes.CondExprNode(pos, test=test, true_val=expr, false_val=other) return ExprNodes.CondExprNode(pos, test=test, true_val=expr, false_val=other)
...@@ -151,11 +153,26 @@ def p_test(s): ...@@ -151,11 +153,26 @@ def p_test(s):
#test_nocond: or_test | lambdef_nocond #test_nocond: or_test | lambdef_nocond
def p_test_nocond(s): def p_test_nocond(s, allow_assignment_expression=True):
if s.sy == 'lambda': if s.sy == 'lambda':
return p_lambdef_nocond(s) return p_lambdef_nocond(s)
else: else:
return p_or_test(s) return p_walrus_test(s, allow_assignment_expression)
# walrurus_test: IDENT := test | or_test
def p_walrus_test(s, allow_assignment_expression=True):
lhs = p_or_test(s)
if s.sy == ':=':
position = s.position()
if not allow_assignment_expression:
s.error("invalid syntax: assignment expression not allowed in this context")
elif not lhs.is_name:
s.error("Left-hand side of assignment expression must be an identifier")
s.next()
rhs = p_test(s)
return ExprNodes.AssignmentExpressionNode(position, lhs=lhs, rhs=rhs)
return lhs
#or_test: and_test ('or' and_test)* #or_test: and_test ('or' and_test)*
...@@ -210,11 +227,11 @@ def p_comparison(s): ...@@ -210,11 +227,11 @@ def p_comparison(s):
n1.cascade = p_cascaded_cmp(s) n1.cascade = p_cascaded_cmp(s)
return n1 return n1
def p_test_or_starred_expr(s): def p_test_or_starred_expr(s, is_expression=False):
if s.sy == '*': if s.sy == '*':
return p_starred_expr(s) return p_starred_expr(s)
else: else:
return p_test(s) return p_test(s, allow_assignment_expression=is_expression)
def p_starred_expr(s): def p_starred_expr(s):
pos = s.position() pos = s.position()
...@@ -497,7 +514,7 @@ def p_call_parse_args(s, allow_genexp=True): ...@@ -497,7 +514,7 @@ def p_call_parse_args(s, allow_genexp=True):
encoded_name = s.context.intern_ustring(arg.name) encoded_name = s.context.intern_ustring(arg.name)
keyword = ExprNodes.IdentifierStringNode( keyword = ExprNodes.IdentifierStringNode(
arg.pos, value=encoded_name) arg.pos, value=encoded_name)
arg = p_test(s) arg = p_test(s, allow_assignment_expression=False)
keyword_args.append((keyword, arg)) keyword_args.append((keyword, arg))
else: else:
if keyword_args: if keyword_args:
...@@ -675,7 +692,7 @@ def p_atom(s): ...@@ -675,7 +692,7 @@ def p_atom(s):
elif s.sy == 'yield': elif s.sy == 'yield':
result = p_yield_expression(s) result = p_yield_expression(s)
else: else:
result = p_testlist_comp(s) result = p_testlist_comp(s, is_expression=True)
s.expect(')') s.expect(')')
return result return result
elif sy == '[': elif sy == '[':
...@@ -1259,7 +1276,7 @@ def p_list_maker(s): ...@@ -1259,7 +1276,7 @@ def p_list_maker(s):
s.expect(']') s.expect(']')
return ExprNodes.ListNode(pos, args=[]) return ExprNodes.ListNode(pos, args=[])
expr = p_test_or_starred_expr(s) expr = p_test_or_starred_expr(s, is_expression=True)
if s.sy in ('for', 'async'): if s.sy in ('for', 'async'):
if expr.is_starred: if expr.is_starred:
s.error("iterable unpacking cannot be used in comprehension") s.error("iterable unpacking cannot be used in comprehension")
...@@ -1459,7 +1476,7 @@ def p_simple_expr_list(s, expr=None): ...@@ -1459,7 +1476,7 @@ def p_simple_expr_list(s, expr=None):
def p_test_or_starred_expr_list(s, expr=None): def p_test_or_starred_expr_list(s, expr=None):
exprs = expr is not None and [expr] or [] exprs = expr is not None and [expr] or []
while s.sy not in expr_terminators: while s.sy not in expr_terminators:
exprs.append(p_test_or_starred_expr(s)) exprs.append(p_test_or_starred_expr(s, is_expression=(expr is not None)))
if s.sy != ',': if s.sy != ',':
break break
s.next() s.next()
...@@ -1492,9 +1509,9 @@ def p_testlist_star_expr(s): ...@@ -1492,9 +1509,9 @@ def p_testlist_star_expr(s):
# testlist_comp: (test|star_expr) ( comp_for | (',' (test|star_expr))* [','] ) # testlist_comp: (test|star_expr) ( comp_for | (',' (test|star_expr))* [','] )
def p_testlist_comp(s): def p_testlist_comp(s, is_expression=False):
pos = s.position() pos = s.position()
expr = p_test_or_starred_expr(s) expr = p_test_or_starred_expr(s, is_expression)
if s.sy == ',': if s.sy == ',':
s.next() s.next()
exprs = p_test_or_starred_expr_list(s, expr) exprs = p_test_or_starred_expr_list(s, expr)
...@@ -3073,11 +3090,11 @@ def p_c_arg_decl(s, ctx, in_pyfunc, cmethod_flag = 0, nonempty = 0, ...@@ -3073,11 +3090,11 @@ def p_c_arg_decl(s, ctx, in_pyfunc, cmethod_flag = 0, nonempty = 0,
default = ExprNodes.NoneNode(pos) default = ExprNodes.NoneNode(pos)
s.next() s.next()
elif 'inline' in ctx.modifiers: elif 'inline' in ctx.modifiers:
default = p_test(s) default = p_test(s, allow_assignment_expression=False)
else: else:
error(pos, "default values cannot be specified in pxd files, use ? or *") error(pos, "default values cannot be specified in pxd files, use ? or *")
else: else:
default = p_test(s) default = p_test(s, allow_assignment_expression=False)
return Nodes.CArgDeclNode(pos, return Nodes.CArgDeclNode(pos,
base_type = base_type, base_type = base_type,
declarator = declarator, declarator = declarator,
...@@ -3955,5 +3972,5 @@ def p_annotation(s): ...@@ -3955,5 +3972,5 @@ def p_annotation(s):
then it is not a bug. then it is not a bug.
""" """
pos = s.position() pos = s.position()
expr = p_test(s) expr = p_test(s, allow_assignment_expression=False)
return ExprNodes.AnnotationNode(pos, expr=expr) return ExprNodes.AnnotationNode(pos, expr=expr)
...@@ -331,6 +331,7 @@ class Scope(object): ...@@ -331,6 +331,7 @@ class Scope(object):
# is_py_class_scope boolean Is a Python class scope # is_py_class_scope boolean Is a Python class scope
# is_c_class_scope boolean Is an extension type scope # is_c_class_scope boolean Is an extension type scope
# is_closure_scope boolean Is a closure scope # is_closure_scope boolean Is a closure scope
# is_generator_expression_scope boolean A subset of closure scope used for generator expressions
# is_passthrough boolean Outer scope is passed directly # is_passthrough boolean Outer scope is passed directly
# is_cpp_class_scope boolean Is a C++ class scope # is_cpp_class_scope boolean Is a C++ class scope
# is_property_scope boolean Is a extension type property scope # is_property_scope boolean Is a extension type property scope
...@@ -347,6 +348,7 @@ class Scope(object): ...@@ -347,6 +348,7 @@ class Scope(object):
is_py_class_scope = 0 is_py_class_scope = 0
is_c_class_scope = 0 is_c_class_scope = 0
is_closure_scope = 0 is_closure_scope = 0
is_generator_expression_scope = 0
is_comprehension_scope = 0 is_comprehension_scope = 0
is_passthrough = 0 is_passthrough = 0
is_cpp_class_scope = 0 is_cpp_class_scope = 0
...@@ -748,6 +750,11 @@ class Scope(object): ...@@ -748,6 +750,11 @@ class Scope(object):
entry.used = 1 entry.used = 1
return entry return entry
def declare_assignment_expression_target(self, name, type, pos):
# In most cases declares the variable as normal.
# For generator expressions and comprehensions the variable is declared in their parent
return self.declare_var(name, type, pos)
def declare_builtin(self, name, pos): def declare_builtin(self, name, pos):
name = self.mangle_class_private_name(name) name = self.mangle_class_private_name(name)
return self.outer_scope.declare_builtin(name, pos) return self.outer_scope.declare_builtin(name, pos)
...@@ -974,6 +981,11 @@ class Scope(object): ...@@ -974,6 +981,11 @@ class Scope(object):
def lookup_here_unmangled(self, name): def lookup_here_unmangled(self, name):
return self.entries.get(name, None) return self.entries.get(name, None)
def lookup_assignment_expression_target(self, name):
# For most cases behaves like "lookup_here".
# However, it does look outwards for comprehension and generator expression scopes
return self.lookup_here(name)
def lookup_target(self, name): def lookup_target(self, name):
# Look up name in this scope only. Declare as Python # Look up name in this scope only. Declare as Python
# variable if not found. # variable if not found.
...@@ -1893,6 +1905,13 @@ class LocalScope(Scope): ...@@ -1893,6 +1905,13 @@ class LocalScope(Scope):
if entry is None or not entry.from_closure: if entry is None or not entry.from_closure:
error(pos, "no binding for nonlocal '%s' found" % name) error(pos, "no binding for nonlocal '%s' found" % name)
def _create_inner_entry_for_closure(self, name, entry):
entry.in_closure = True
inner_entry = InnerEntry(entry, self)
inner_entry.is_variable = True
self.entries[name] = inner_entry
return inner_entry
def lookup(self, name): def lookup(self, name):
# Look up name in this scope or an enclosing one. # Look up name in this scope or an enclosing one.
# Return None if not found. # Return None if not found.
...@@ -1907,11 +1926,7 @@ class LocalScope(Scope): ...@@ -1907,11 +1926,7 @@ class LocalScope(Scope):
raise InternalError("lookup() after scope class created.") raise InternalError("lookup() after scope class created.")
# The actual c fragment for the different scopes differs # The actual c fragment for the different scopes differs
# on the outside and inside, so we make a new entry # on the outside and inside, so we make a new entry
entry.in_closure = True return self._create_inner_entry_for_closure(name, entry)
inner_entry = InnerEntry(entry, self)
inner_entry.is_variable = True
self.entries[name] = inner_entry
return inner_entry
return entry return entry
def mangle_closure_cnames(self, outer_scope_cname): def mangle_closure_cnames(self, outer_scope_cname):
...@@ -1981,6 +1996,10 @@ class ComprehensionScope(Scope): ...@@ -1981,6 +1996,10 @@ class ComprehensionScope(Scope):
self.entries[name] = entry self.entries[name] = entry
return entry return entry
def declare_assignment_expression_target(self, name, type, pos):
# should be declared in the parent scope instead
return self.parent_scope.declare_var(name, type, pos)
def declare_pyfunction(self, name, pos, allow_redefine=False): def declare_pyfunction(self, name, pos, allow_redefine=False):
return self.outer_scope.declare_pyfunction( return self.outer_scope.declare_pyfunction(
name, pos, allow_redefine) name, pos, allow_redefine)
...@@ -1991,6 +2010,12 @@ class ComprehensionScope(Scope): ...@@ -1991,6 +2010,12 @@ class ComprehensionScope(Scope):
def add_lambda_def(self, def_node): def add_lambda_def(self, def_node):
return self.outer_scope.add_lambda_def(def_node) return self.outer_scope.add_lambda_def(def_node)
def lookup_assignment_expression_target(self, name):
entry = self.lookup_here(name)
if not entry:
entry = self.parent_scope.lookup_assignment_expression_target(name)
return entry
class ClosureScope(LocalScope): class ClosureScope(LocalScope):
...@@ -2012,6 +2037,25 @@ class ClosureScope(LocalScope): ...@@ -2012,6 +2037,25 @@ class ClosureScope(LocalScope):
def declare_pyfunction(self, name, pos, allow_redefine=False): def declare_pyfunction(self, name, pos, allow_redefine=False):
return LocalScope.declare_pyfunction(self, name, pos, allow_redefine, visibility='private') return LocalScope.declare_pyfunction(self, name, pos, allow_redefine, visibility='private')
def declare_assignment_expression_target(self, name, type, pos):
return self.declare_var(name, type, pos)
class GeneratorExpressionScope(ClosureScope):
is_generator_expression_scope = True
def declare_assignment_expression_target(self, name, type, pos):
entry = self.parent_scope.declare_var(name, type, pos)
return self._create_inner_entry_for_closure(name, entry)
def lookup_assignment_expression_target(self, name):
entry = self.lookup_here(name)
if not entry:
entry = self.parent_scope.lookup_assignment_expression_target(name)
if entry:
return self._create_inner_entry_for_closure(name, entry)
return entry
class StructOrUnionScope(Scope): class StructOrUnionScope(Scope):
# Namespace of a C struct or union. # Namespace of a C struct or union.
......
# mode: run
# tag: pure3.8
# These are extra tests for the assignment expression/walrus operator/named expression that cover things
# additional to the standard Python test-suite in tests/run/test_named_expressions.pyx
import cython
import sys
@cython.test_assert_path_exists("//PythonCapiCallNode")
def optimized(x):
"""
x*2 is optimized to a PythonCapiCallNode. The test fails unless the CloneNode is kept up-to-date
(in the event that the optimization changes and test_assert_path_exists fails, the thing to do
is to find another case that's similarly optimized - the test isn't specifically interested in
multiplication)
>>> optimized(5)
10
"""
return (x:=x*2)
# FIXME: currently broken; GH-4146
# Changing x in the assignment expression should not affect the value used on the right-hand side
#def order(x):
# """
# >>> order(5)
# 15
# """
# return x+(x:=x*2)
@cython.test_fail_if_path_exists("//CloneNode")
def optimize_literals1():
"""
There's a small optimization for literals to avoid creating unnecessary temps
>>> optimize_literals1()
10
"""
x = 5
return (x := 10)
@cython.test_fail_if_path_exists("//CloneNode")
def optimize_literals2():
"""
There's a small optimization for literals to avoid creating unnecessary temps
Test is in __doc__ (for Py2 string formatting reasons)
"""
x = 5
return (x := u"a string")
@cython.test_fail_if_path_exists("//CloneNode")
def optimize_literals3():
"""
There's a small optimization for literals to avoid creating unnecessary temps
Test is in __doc__ (for Py2 string formatting reasons)
"""
x = 5
return (x := b"a bytes")
@cython.test_fail_if_path_exists("//CloneNode")
def optimize_literals4():
"""
There's a small optimization for literals to avoid creating unnecessary temps
Test is in __doc__ (for Py2 string formatting reasons)
"""
x = 5
return (x := (u"tuple", 1, 1.0, b"stuff"))
if sys.version_info[0] != 2:
__doc__ = """
>>> optimize_literals2()
'a string'
>>> optimize_literals3()
b'a bytes'
>>> optimize_literals4()
('tuple', 1, 1.0, b'stuff')
"""
else:
__doc__ = """
>>> optimize_literals2()
u'a string'
>>> optimize_literals3()
'a bytes'
>>> optimize_literals4()
(u'tuple', 1, 1.0, 'stuff')
"""
@cython.test_fail_if_path_exists("//CoerceToPyTypeNode//AssignmentExpressionNode")
def avoid_extra_coercion(x : cython.double):
"""
The assignment expression and x are both coerced to PyObject - this should happen only once
rather than to both separately
>>> avoid_extra_coercion(5.)
5.0
"""
y : object = "I'm an object"
return (y := x)
async def async_func():
"""
DW doesn't understand async functions well enough to make it a runtime test, but it was causing
a compile-time failure at one point
"""
if variable := 1:
pass
y_global = 6
class InLambdaInClass:
"""
>>> InLambdaInClass.x1
12
>>> InLambdaInClass.x2
[12, 12]
"""
x1 = (lambda y_global: (y_global := y_global + 1) + y_global)(2) + y_global
x2 = [(lambda y_global: (y_global := y_global + 1) + y_global)(2) + y_global for _ in range(2) ]
def in_lambda_in_list_comprehension1():
"""
>>> in_lambda_in_list_comprehension1()
[[0, 2, 4, 6], [0, 2, 4, 6], [0, 2, 4, 6], [0, 2, 4, 6], [0, 2, 4, 6]]
"""
return [ (lambda x: [(x := y) + x for y in range(4)])(x) for x in range(5) ]
def in_lambda_in_list_comprehension2():
"""
>>> in_lambda_in_list_comprehension2()
[[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]]
"""
return [ (lambda z: [(x := y) + z for y in range(4)])(x) for x in range(5) ]
def in_lambda_in_generator_expression1():
"""
>>> in_lambda_in_generator_expression1()
[(0, 2, 4, 6), (0, 2, 4, 6), (0, 2, 4, 6), (0, 2, 4, 6), (0, 2, 4, 6)]
"""
return [ (lambda x: tuple((x := y) + x for y in range(4)))(x) for x in range(5) ]
def in_lambda_in_generator_expression2():
"""
>>> in_lambda_in_generator_expression2()
[(0, 1, 2, 3), (1, 2, 3, 4), (2, 3, 4, 5), (3, 4, 5, 6), (4, 5, 6, 7)]
"""
return [ (lambda z: tuple((x := y) + z for y in range(4)))(x) for x in range(5) ]
...@@ -1173,11 +1173,10 @@ non-important content ...@@ -1173,11 +1173,10 @@ non-important content
self.assertEqual(f'{0!=1}', 'True') self.assertEqual(f'{0!=1}', 'True')
self.assertEqual(f'{0<=1}', 'True') self.assertEqual(f'{0<=1}', 'True')
self.assertEqual(f'{0>=1}', 'False') self.assertEqual(f'{0>=1}', 'False')
# Walrus not implemented yet, skip self.assertEqual(f'{(x:="5")}', '5')
# self.assertEqual(f'{(x:="5")}', '5') self.assertEqual(x, '5')
# self.assertEqual(x, '5') self.assertEqual(f'{(x:=5)}', '5')
# self.assertEqual(f'{(x:=5)}', '5') self.assertEqual(x, 5)
# self.assertEqual(x, 5)
self.assertEqual(f'{"="}', '=') self.assertEqual(f'{"="}', '=')
x = 20 x = 20
...@@ -1239,13 +1238,9 @@ non-important content ...@@ -1239,13 +1238,9 @@ non-important content
# spec of '=10'. # spec of '=10'.
self.assertEqual(f'{x:=10}', ' 20') self.assertEqual(f'{x:=10}', ' 20')
# Note to anyone going to enable these: please have a look to the test
# above this one for more walrus cases to enable.
"""
# This is an assignment expression, which requires parens. # This is an assignment expression, which requires parens.
self.assertEqual(f'{(x:=10)}', '10') self.assertEqual(f'{(x:=10)}', '10')
self.assertEqual(x, 10) self.assertEqual(x, 10)
"""
def test_invalid_syntax_error_message(self): def test_invalid_syntax_error_message(self):
# with self.assertRaisesRegex(SyntaxError, "f-string: invalid syntax"): # with self.assertRaisesRegex(SyntaxError, "f-string: invalid syntax"):
......
# mode: run
# tag: pure38, no-cpp
# copied from cpython with minimal modifications (mainly exec->cython_inline, and a few exception strings)
# This is not currently run in C++ because all the cython_inline compilations fail for reasons that are unclear
# FIXME pure38 seems to be ignored
# cython: language_level=3
import os
import unittest
import cython
from Cython.Compiler.Main import CompileError
from Cython.Build.Inline import cython_inline
import sys
if cython.compiled:
class StdErrHider:
def __enter__(self):
try:
from StringIO import StringIO
except ImportError:
from io import StringIO
self.old_stderr = sys.stderr
self.new_stderr = StringIO()
sys.stderr = self.new_stderr
return self
def __exit__(self, exc_type, exc_value, traceback):
sys.stderr = self.old_stderr
@property
def stderr_contents(self):
return self.new_stderr.getvalue()
def exec(code, globals_=None, locals_=None):
if locals_ and globals_ and (locals_ is not globals_):
# a hacky attempt to treat as a class definition
code = "class Cls:\n" + "\n".join(
" " + line for line in code.split("\n"))
code += "\nreturn globals(), locals()" # so we can inspect it for changes, overriding the default cython_inline behaviour
try:
with StdErrHider() as stderr_handler:
try:
g, l = cython_inline(code, globals=globals_, locals=locals_)
finally:
err_messages = stderr_handler.stderr_contents
if globals_ is not None:
# because Cython inline bundles everything into a function some values that
# we'd expect to be in globals end up in locals. This isn't quite right but is
# as close as it's possible to get to retrieving the values
globals_.update(l)
globals_.update(g)
except CompileError as exc:
raised_message = str(exc)
if raised_message.endswith(".pyx"):
# unhelpfully Cython sometimes raises a compile error and sometimes just raises the filename
raised_message = []
for line in err_messages.split("\n"):
line = line.split(":",3)
# a usable error message with be filename:line:char: message
if len(line) == 4 and line[0].endswith(".pyx"):
raised_message.append(line[-1])
# output all the errors - we aren't worried about reproducing the exact order CPython
# emits errors in
raised_message = "; ".join(raised_message)
raise SyntaxError(raised_message) from None
if sys.version_info[0] < 3:
# some monkey patching
unittest.TestCase.assertRaisesRegex = unittest.TestCase.assertRaisesRegexp
class FakeSubTest(object):
def __init__(self, *args, **kwds):
pass
def __enter__(self):
pass
def __exit__(self, *args):
pass
unittest.TestCase.subTest = FakeSubTest
class NamedExpressionInvalidTest(unittest.TestCase):
def test_named_expression_invalid_01(self):
code = """x := 0"""
with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
exec(code, {}, {})
def test_named_expression_invalid_02(self):
code = """x = y := 0"""
with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
exec(code, {}, {})
def test_named_expression_invalid_03(self):
code = """y := f(x)"""
with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
exec(code, {}, {})
def test_named_expression_invalid_04(self):
code = """y0 = y1 := f(x)"""
with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
exec(code, {}, {})
def test_named_expression_invalid_06(self):
code = """((a, b) := (1, 2))"""
# TODO Cython correctly generates an error but the message could be better
with self.assertRaisesRegex(SyntaxError, ""):
exec(code, {}, {})
def test_named_expression_invalid_07(self):
code = """def spam(a = b := 42): pass"""
with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
exec(code, {}, {})
def test_named_expression_invalid_08(self):
code = """def spam(a: b := 42 = 5): pass"""
with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
exec(code, {}, {})
def test_named_expression_invalid_09(self):
code = """spam(a=b := 'c')"""
with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
exec(code, {}, {})
def test_named_expression_invalid_10(self):
code = """spam(x = y := f(x))"""
with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
exec(code, {}, {})
def test_named_expression_invalid_11(self):
code = """spam(a=1, b := 2)"""
with self.assertRaisesRegex(SyntaxError,
"follow.* keyword arg"):
exec(code, {}, {})
def test_named_expression_invalid_12(self):
code = """spam(a=1, (b := 2))"""
with self.assertRaisesRegex(SyntaxError,
"follow.* keyword arg"):
exec(code, {}, {})
def test_named_expression_invalid_13(self):
code = """spam(a=1, (b := 2))"""
with self.assertRaisesRegex(SyntaxError,
"follow.* keyword arg"):
exec(code, {}, {})
def test_named_expression_invalid_14(self):
code = """(x := lambda: y := 1)"""
with self.assertRaisesRegex(SyntaxError, "invalid syntax"):
exec(code, {}, {})
def test_named_expression_invalid_15(self):
code = """(lambda: x := 1)"""
# TODO at the moment the error message is valid, but not the same as Python
with self.assertRaisesRegex(SyntaxError,
""):
exec(code, {}, {})
def test_named_expression_invalid_16(self):
code = "[i + 1 for i in i := [1,2]]"
# TODO at the moment the error message is valid, but not the same as Python
with self.assertRaisesRegex(SyntaxError, ""):
exec(code, {}, {})
def test_named_expression_invalid_17(self):
code = "[i := 0, j := 1 for i, j in [(1, 2), (3, 4)]]"
# TODO at the moment the error message is valid, but not the same as Python
with self.assertRaisesRegex(SyntaxError, ""):
exec(code, {}, {})
def test_named_expression_invalid_in_class_body(self):
code = """class Foo():
[(42, 1 + ((( j := i )))) for i in range(5)]
"""
with self.assertRaisesRegex(SyntaxError,
"assignment expression within a comprehension cannot be used in a class body"):
exec(code, {}, {})
def test_named_expression_invalid_rebinding_comprehension_iteration_variable(self):
cases = [
("Local reuse", 'i', "[i := 0 for i in range(5)]"),
("Nested reuse", 'j', "[[(j := 0) for i in range(5)] for j in range(5)]"),
("Reuse inner loop target", 'j', "[(j := 0) for i in range(5) for j in range(5)]"),
("Unpacking reuse", 'i', "[i := 0 for i, j in [(0, 1)]]"),
("Reuse in loop condition", 'i', "[i+1 for i in range(5) if (i := 0)]"),
("Unreachable reuse", 'i', "[False or (i:=0) for i in range(5)]"),
("Unreachable nested reuse", 'i',
"[(i, j) for i in range(5) for j in range(5) if True or (i:=10)]"),
]
for case, target, code in cases:
msg = f"assignment expression cannot rebind comprehension iteration variable '{target}'"
with self.subTest(case=case):
with self.assertRaisesRegex(SyntaxError, msg):
exec(code, {}, {})
def test_named_expression_invalid_rebinding_comprehension_inner_loop(self):
cases = [
("Inner reuse", 'j', "[i for i in range(5) if (j := 0) for j in range(5)]"),
("Inner unpacking reuse", 'j', "[i for i in range(5) if (j := 0) for j, k in [(0, 1)]]"),
]
for case, target, code in cases:
msg = f"comprehension inner loop cannot rebind assignment expression target '{target}'"
with self.subTest(case=case):
with self.assertRaisesRegex(SyntaxError, msg):
exec(code, {}) # Module scope
with self.assertRaisesRegex(SyntaxError, msg):
exec(code, {}, {}) # Class scope
with self.assertRaisesRegex(SyntaxError, msg):
exec(f"lambda: {code}", {}) # Function scope
def test_named_expression_invalid_comprehension_iterable_expression(self):
cases = [
("Top level", "[i for i in (i := range(5))]"),
("Inside tuple", "[i for i in (2, 3, i := range(5))]"),
("Inside list", "[i for i in [2, 3, i := range(5)]]"),
("Different name", "[i for i in (j := range(5))]"),
("Lambda expression", "[i for i in (lambda:(j := range(5)))()]"),
("Inner loop", "[i for i in range(5) for j in (i := range(5))]"),
("Nested comprehension", "[i for i in [j for j in (k := range(5))]]"),
("Nested comprehension condition", "[i for i in [j for j in range(5) if (j := True)]]"),
("Nested comprehension body", "[i for i in [(j := True) for j in range(5)]]"),
]
msg = "assignment expression cannot be used in a comprehension iterable expression"
for case, code in cases:
with self.subTest(case=case):
with self.assertRaisesRegex(SyntaxError, msg):
exec(code, {}) # Module scope - FIXME this test puts it in __invoke in cython_inline
with self.assertRaisesRegex(SyntaxError, msg):
exec(code, {}, {}) # Class scope
with self.assertRaisesRegex(SyntaxError, msg):
exec(f"lambda: {code}", {}) # Function scope
class NamedExpressionAssignmentTest(unittest.TestCase):
def test_named_expression_assignment_01(self):
(a := 10)
self.assertEqual(a, 10)
def test_named_expression_assignment_02(self):
a = 20
(a := a)
self.assertEqual(a, 20)
def test_named_expression_assignment_03(self):
(total := 1 + 2)
self.assertEqual(total, 3)
def test_named_expression_assignment_04(self):
(info := (1, 2, 3))
self.assertEqual(info, (1, 2, 3))
def test_named_expression_assignment_05(self):
(x := 1, 2)
self.assertEqual(x, 1)
def test_named_expression_assignment_06(self):
(z := (y := (x := 0)))
self.assertEqual(x, 0)
self.assertEqual(y, 0)
self.assertEqual(z, 0)
def test_named_expression_assignment_07(self):
(loc := (1, 2))
self.assertEqual(loc, (1, 2))
def test_named_expression_assignment_08(self):
if spam := "eggs":
self.assertEqual(spam, "eggs")
else: self.fail("variable was not assigned using named expression")
def test_named_expression_assignment_09(self):
if True and (spam := True):
self.assertTrue(spam)
else: self.fail("variable was not assigned using named expression")
def test_named_expression_assignment_10(self):
if (match := 10) == 10:
pass
else: self.fail("variable was not assigned using named expression")
def test_named_expression_assignment_11(self):
def spam(a):
return a
input_data = [1, 2, 3]
res = [(x, y, x/y) for x in input_data if (y := spam(x)) > 0]
self.assertEqual(res, [(1, 1, 1.0), (2, 2, 1.0), (3, 3, 1.0)])
def test_named_expression_assignment_12(self):
def spam(a):
return a
res = [[y := spam(x), x/y] for x in range(1, 5)]
self.assertEqual(res, [[1, 1.0], [2, 1.0], [3, 1.0], [4, 1.0]])
def test_named_expression_assignment_13(self):
length = len(lines := [1, 2])
self.assertEqual(length, 2)
self.assertEqual(lines, [1,2])
def test_named_expression_assignment_14(self):
"""
Where all variables are positive integers, and a is at least as large
as the n'th root of x, this algorithm returns the floor of the n'th
root of x (and roughly doubling the number of accurate bits per
iteration):
"""
a = 9
n = 2
x = 3
while a > (d := x // a**(n-1)):
a = ((n-1)*a + d) // n
self.assertEqual(a, 1)
def test_named_expression_assignment_15(self):
while a := False:
pass # This will not run
self.assertEqual(a, False)
def test_named_expression_assignment_16(self):
a, b = 1, 2
fib = {(c := a): (a := b) + (b := a + c) - b for __ in range(6)}
self.assertEqual(fib, {1: 2, 2: 3, 3: 5, 5: 8, 8: 13, 13: 21})
class NamedExpressionScopeTest(unittest.TestCase):
def test_named_expression_scope_01(self):
code = """def spam():
(a := 5)
print(a)"""
# FIXME for some reason the error message raised is a nonsense filename instead of "undeclared name not builtin"
# "name .* not"):
with self.assertRaisesRegex(SyntaxError if cython.compiled else NameError, ""):
exec(code, {}, {})
def test_named_expression_scope_02(self):
total = 0
partial_sums = [total := total + v for v in range(5)]
self.assertEqual(partial_sums, [0, 1, 3, 6, 10])
self.assertEqual(total, 10)
def test_named_expression_scope_03(self):
containsOne = any((lastNum := num) == 1 for num in [1, 2, 3])
self.assertTrue(containsOne)
self.assertEqual(lastNum, 1)
def test_named_expression_scope_04(self):
def spam(a):
return a
res = [[y := spam(x), x/y] for x in range(1, 5)]
self.assertEqual(y, 4)
def test_named_expression_scope_05(self):
def spam(a):
return a
input_data = [1, 2, 3]
res = [(x, y, x/y) for x in input_data if (y := spam(x)) > 0]
self.assertEqual(res, [(1, 1, 1.0), (2, 2, 1.0), (3, 3, 1.0)])
self.assertEqual(y, 3)
def test_named_expression_scope_06(self):
res = [[spam := i for i in range(3)] for j in range(2)]
self.assertEqual(res, [[0, 1, 2], [0, 1, 2]])
self.assertEqual(spam, 2)
def test_named_expression_scope_07(self):
len(lines := [1, 2])
self.assertEqual(lines, [1, 2])
def test_named_expression_scope_08(self):
def spam(a):
return a
def eggs(b):
return b * 2
res = [spam(a := eggs(b := h)) for h in range(2)]
self.assertEqual(res, [0, 2])
self.assertEqual(a, 2)
self.assertEqual(b, 1)
def test_named_expression_scope_09(self):
def spam(a):
return a
def eggs(b):
return b * 2
res = [spam(a := eggs(a := h)) for h in range(2)]
self.assertEqual(res, [0, 2])
self.assertEqual(a, 2)
def test_named_expression_scope_10(self):
res = [b := [a := 1 for i in range(2)] for j in range(2)]
self.assertEqual(res, [[1, 1], [1, 1]])
self.assertEqual(a, 1)
self.assertEqual(b, [1, 1])
def test_named_expression_scope_11(self):
res = [j := i for i in range(5)]
self.assertEqual(res, [0, 1, 2, 3, 4])
self.assertEqual(j, 4)
def test_named_expression_scope_17(self):
b = 0
res = [b := i + b for i in range(5)]
self.assertEqual(res, [0, 1, 3, 6, 10])
self.assertEqual(b, 10)
def test_named_expression_scope_18(self):
def spam(a):
return a
res = spam(b := 2)
self.assertEqual(res, 2)
self.assertEqual(b, 2)
def test_named_expression_scope_19(self):
def spam(a):
return a
res = spam((b := 2))
self.assertEqual(res, 2)
self.assertEqual(b, 2)
def test_named_expression_scope_20(self):
def spam(a):
return a
res = spam(a=(b := 2))
self.assertEqual(res, 2)
self.assertEqual(b, 2)
def test_named_expression_scope_21(self):
def spam(a, b):
return a + b
res = spam(c := 2, b=1)
self.assertEqual(res, 3)
self.assertEqual(c, 2)
def test_named_expression_scope_22(self):
def spam(a, b):
return a + b
res = spam((c := 2), b=1)
self.assertEqual(res, 3)
self.assertEqual(c, 2)
def test_named_expression_scope_23(self):
def spam(a, b):
return a + b
res = spam(b=(c := 2), a=1)
self.assertEqual(res, 3)
self.assertEqual(c, 2)
def test_named_expression_scope_24(self):
a = 10
def spam():
nonlocal a
(a := 20)
spam()
self.assertEqual(a, 20)
def test_named_expression_scope_25(self):
ns = {}
code = """a = 10
def spam():
global a
(a := 20)
spam()"""
exec(code, ns, {})
self.assertEqual(ns["a"], 20)
def test_named_expression_variable_reuse_in_comprehensions(self):
# The compiler is expected to raise syntax error for comprehension
# iteration variables, but should be fine with rebinding of other
# names (e.g. globals, nonlocals, other assignment expressions)
# The cases are all defined to produce the same expected result
# Each comprehension is checked at both function scope and module scope
rebinding = "[x := i for i in range(3) if (x := i) or not x]"
filter_ref = "[x := i for i in range(3) if x or not x]"
body_ref = "[x for i in range(3) if (x := i) or not x]"
nested_ref = "[j for i in range(3) if x or not x for j in range(3) if (x := i)][:-3]"
cases = [
("Rebind global", f"x = 1; result = {rebinding}"),
("Rebind nonlocal", f"result, x = (lambda x=1: ({rebinding}, x))()"),
("Filter global", f"x = 1; result = {filter_ref}"),
("Filter nonlocal", f"result, x = (lambda x=1: ({filter_ref}, x))()"),
("Body global", f"x = 1; result = {body_ref}"),
("Body nonlocal", f"result, x = (lambda x=1: ({body_ref}, x))()"),
("Nested global", f"x = 1; result = {nested_ref}"),
("Nested nonlocal", f"result, x = (lambda x=1: ({nested_ref}, x))()"),
]
for case, code in cases:
with self.subTest(case=case):
ns = {}
exec(code, ns)
self.assertEqual(ns["x"], 2)
self.assertEqual(ns["result"], [0, 1, 2])
if __name__ == "__main__":
unittest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment