Commit 9fa7ce41 authored by Robert Bradshaw's avatar Robert Bradshaw

Default argument literals, better True/False coercion

parent 36f3f6d0
...@@ -550,6 +550,8 @@ class AtomicExprNode(ExprNode): ...@@ -550,6 +550,8 @@ class AtomicExprNode(ExprNode):
class PyConstNode(AtomicExprNode): class PyConstNode(AtomicExprNode):
# Abstract base class for constant Python values. # Abstract base class for constant Python values.
is_literal = 1
def is_simple(self): def is_simple(self):
return 1 return 1
...@@ -571,6 +573,24 @@ class NoneNode(PyConstNode): ...@@ -571,6 +573,24 @@ class NoneNode(PyConstNode):
def compile_time_value(self, denv): def compile_time_value(self, denv):
return None return None
class BoolNode(PyConstNode):
# The constant value True or False
def compile_time_value(self, denv):
return None
def calculate_result_code(self):
if self.value:
return "Py_True"
else:
return "Py_False"
def coerce_to(self, dst_type, env):
value = self.value
if dst_type.is_numeric:
return IntNode(self.pos, value=self.value).coerce_to(dst_type, env)
else:
return PyConstNode.coerce_to(self, dst_type, env)
class EllipsisNode(PyConstNode): class EllipsisNode(PyConstNode):
# '...' in a subscript list. # '...' in a subscript list.
...@@ -2148,6 +2168,7 @@ class TupleNode(SequenceNode): ...@@ -2148,6 +2168,7 @@ class TupleNode(SequenceNode):
if len(self.args) == 0: if len(self.args) == 0:
self.type = py_object_type self.type = py_object_type
self.is_temp = 0 self.is_temp = 0
self.is_literal = 1
else: else:
SequenceNode.analyse_types(self, env) SequenceNode.analyse_types(self, env)
......
...@@ -339,13 +339,15 @@ class CFuncDeclaratorNode(CDeclaratorNode): ...@@ -339,13 +339,15 @@ class CFuncDeclaratorNode(CDeclaratorNode):
PyrexTypes.CFuncTypeArg(name, type, arg_node.pos)) PyrexTypes.CFuncTypeArg(name, type, arg_node.pos))
if arg_node.default: if arg_node.default:
self.optional_arg_count += 1 self.optional_arg_count += 1
elif self.optional_arg_count:
error(self.pos, "Non-default argument follows default argument")
if self.optional_arg_count: if self.optional_arg_count:
scope = StructOrUnionScope() scope = StructOrUnionScope()
scope.declare_var('n', PyrexTypes.c_int_type, self.pos) scope.declare_var('n', PyrexTypes.c_int_type, self.pos)
for arg in func_type_args[len(func_type_args)-self.optional_arg_count:]: for arg in func_type_args[len(func_type_args)-self.optional_arg_count:]:
scope.declare_var(arg.name, arg.type, arg.pos, allow_pyobject = 1) scope.declare_var(arg.name, arg.type, arg.pos, allow_pyobject = 1)
struct_cname = Naming.opt_arg_prefix + env.mangle(self.base.name) struct_cname = env.mangle(Naming.opt_arg_prefix, self.base.name)
self.op_args_struct = env.global_scope().declare_struct_or_union(name = struct_cname, self.op_args_struct = env.global_scope().declare_struct_or_union(name = struct_cname,
kind = 'struct', kind = 'struct',
scope = scope, scope = scope,
...@@ -396,6 +398,7 @@ class CArgDeclNode(Node): ...@@ -396,6 +398,7 @@ class CArgDeclNode(Node):
# not_none boolean Tagged with 'not None' # not_none boolean Tagged with 'not None'
# default ExprNode or None # default ExprNode or None
# default_entry Symtab.Entry Entry for the variable holding the default value # default_entry Symtab.Entry Entry for the variable holding the default value
# default_result_code string cname or code fragment for default value
# is_self_arg boolean Is the "self" arg of an extension type method # is_self_arg boolean Is the "self" arg of an extension type method
# is_kw_only boolean Is a keyword-only argument # is_kw_only boolean Is a keyword-only argument
...@@ -639,9 +642,14 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -639,9 +642,14 @@ class FuncDefNode(StatNode, BlockNode):
if not hasattr(arg, 'default_entry'): if not hasattr(arg, 'default_entry'):
arg.default.analyse_types(genv) arg.default.analyse_types(genv)
arg.default = arg.default.coerce_to(arg.type, genv) arg.default = arg.default.coerce_to(arg.type, genv)
if arg.default.is_literal:
arg.default_entry = arg.default
arg.default_result_code = arg.default.calculate_result_code()
else:
arg.default.allocate_temps(genv) arg.default.allocate_temps(genv)
arg.default_entry = genv.add_default_value(arg.type) arg.default_entry = genv.add_default_value(arg.type)
arg.default_entry.used = 1 arg.default_entry.used = 1
arg.default_result_code = arg.default_entry.cname
else: else:
error(arg.pos, error(arg.pos,
"This argument cannot have a default value") "This argument cannot have a default value")
...@@ -791,6 +799,7 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -791,6 +799,7 @@ class FuncDefNode(StatNode, BlockNode):
for arg in self.args: for arg in self.args:
default = arg.default default = arg.default
if default: if default:
if not default.is_literal:
default.generate_evaluation_code(code) default.generate_evaluation_code(code)
default.make_owned_reference(code) default.make_owned_reference(code)
code.putln( code.putln(
...@@ -930,7 +939,7 @@ class CFuncDefNode(FuncDefNode): ...@@ -930,7 +939,7 @@ class CFuncDefNode(FuncDefNode):
def generate_argument_declarations(self, env, code): def generate_argument_declarations(self, env, code):
for arg in self.declarator.args: for arg in self.declarator.args:
if arg.default: if arg.default:
code.putln('%s = %s;' % (arg.type.declaration_code(arg.cname), arg.default_entry.cname)) code.putln('%s = %s;' % (arg.type.declaration_code(arg.cname), arg.default_result_code))
def generate_keyword_list(self, code): def generate_keyword_list(self, code):
pass pass
...@@ -943,7 +952,10 @@ class CFuncDefNode(FuncDefNode): ...@@ -943,7 +952,10 @@ class CFuncDefNode(FuncDefNode):
for arg in rev_args: for arg in rev_args:
if arg.default: if arg.default:
code.putln('if (%s->n > %s) {' % (Naming.optional_args_cname, i)) code.putln('if (%s->n > %s) {' % (Naming.optional_args_cname, i))
code.putln('%s = %s->%s;' % (arg.cname, Naming.optional_args_cname, arg.declarator.name)) declarator = arg.declarator
while not hasattr(declarator, 'name'):
declarator = declarator.base
code.putln('%s = %s->%s;' % (arg.cname, Naming.optional_args_cname, declarator.name))
i += 1 i += 1
for _ in range(self.type.optional_arg_count): for _ in range(self.type.optional_arg_count):
code.putln('}') code.putln('}')
...@@ -1307,7 +1319,7 @@ class DefNode(FuncDefNode): ...@@ -1307,7 +1319,7 @@ class DefNode(FuncDefNode):
code.putln( code.putln(
"%s = %s;" % ( "%s = %s;" % (
arg_entry.cname, arg_entry.cname,
arg.default_entry.cname)) arg.default_result_code))
if not default_seen: if not default_seen:
arg_formats.append("|") arg_formats.append("|")
default_seen = 1 default_seen = 1
......
...@@ -468,6 +468,10 @@ def p_atom(s): ...@@ -468,6 +468,10 @@ def p_atom(s):
s.next() s.next()
if name == "None": if name == "None":
return ExprNodes.NoneNode(pos) return ExprNodes.NoneNode(pos)
elif name == "True":
return ExprNodes.BoolNode(pos, value=1)
elif name == "False":
return ExprNodes.BoolNode(pos, value=0)
else: else:
return p_name(s, name) return p_name(s, name)
elif sy == 'NULL': elif sy == 'NULL':
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment