Commit 2520038e authored by Stefan Behnel's avatar Stefan Behnel

fix #480: float() as a type cast for function return values

parent bf325529
...@@ -960,13 +960,15 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform): ...@@ -960,13 +960,15 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
if func_arg.type == node.type: if func_arg.type == node.type:
return func_arg return func_arg
elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float: elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
return ExprNodes.CastNode(func_arg, node.type) return ExprNodes.TypecastNode(
node.pos, operand=func_arg, type=node.type)
elif function.name == 'float': elif function.name == 'float':
if func_arg.type.is_float or node.type.is_float: if func_arg.type.is_float or node.type.is_float:
if func_arg.type == node.type: if func_arg.type == node.type:
return func_arg return func_arg
elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float: elif node.type.assignable_from(func_arg.type) or func_arg.type.is_float:
return ExprNodes.CastNode(func_arg, node.type) return ExprNodes.TypecastNode(
node.pos, operand=func_arg, type=node.type)
return node return node
### dispatch to specific optimisers ### dispatch to specific optimisers
...@@ -1115,7 +1117,8 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform): ...@@ -1115,7 +1117,8 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
if func_arg.type is PyrexTypes.c_double_type: if func_arg.type is PyrexTypes.c_double_type:
return func_arg return func_arg
elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric: elif node.type.assignable_from(func_arg.type) or func_arg.type.is_numeric:
return ExprNodes.CastNode(func_arg, node.type) return ExprNodes.TypecastNode(
node.pos, operand=func_arg, type=node.type)
return ExprNodes.PythonCapiCallNode( return ExprNodes.PythonCapiCallNode(
node.pos, "__Pyx_PyObject_AsDouble", node.pos, "__Pyx_PyObject_AsDouble",
self.PyObject_AsDouble_func_type, self.PyObject_AsDouble_func_type,
......
def f(x): def f(x):
return x return x
def float_len(x): def len_f(x):
""" """
>>> float_len([1,2,3]) >>> len_f([1,2,3])
3
"""
return len(f(x))
def float_len_f(x):
"""
>>> float_len_f([1,2,3])
3.0
"""
return float(len(f(x)))
def cast_len_f(x):
"""
>>> cast_len_f([1,2,3])
3.0 3.0
""" """
float(len(f(x))) return <double>len(f(x))
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment