Commit 9b003173 authored by Stefan Behnel's avatar Stefan Behnel

reimplement SimplifyCalls transform in-place in GeneralCallNode.analyse_types()

parent f9c385e0
...@@ -4313,7 +4313,6 @@ class GeneralCallNode(CallNode): ...@@ -4313,7 +4313,6 @@ class GeneralCallNode(CallNode):
# keyword_args ExprNode or None Dict of keyword arguments # keyword_args ExprNode or None Dict of keyword arguments
type = py_object_type type = py_object_type
is_simple_call = False
subexprs = ['function', 'positional_args', 'keyword_args'] subexprs = ['function', 'positional_args', 'keyword_args']
...@@ -4344,8 +4343,10 @@ class GeneralCallNode(CallNode): ...@@ -4344,8 +4343,10 @@ class GeneralCallNode(CallNode):
self.type = error_type self.type = error_type
return self return self
if hasattr(self.function, 'entry'): if hasattr(self.function, 'entry'):
self.map_keywords_to_posargs() node = self.map_keywords_to_posargs(env)
if not self.is_simple_call: if node is not None:
return node.analyse_types(env)
else:
if self.function.entry.as_variable: if self.function.entry.as_variable:
self.function = self.function.coerce_to_pyobject(env) self.function = self.function.coerce_to_pyobject(env)
else: else:
...@@ -4370,17 +4371,17 @@ class GeneralCallNode(CallNode): ...@@ -4370,17 +4371,17 @@ class GeneralCallNode(CallNode):
self.is_temp = 1 self.is_temp = 1
return self return self
def map_keywords_to_posargs(self): def map_keywords_to_posargs(self, env):
if not isinstance(self.positional_args, TupleNode): if not isinstance(self.positional_args, TupleNode):
# has starred argument # has starred argument
return return None
if not isinstance(self.keyword_args, DictNode): if not isinstance(self.keyword_args, DictNode):
# nothing to do here # nothing to do here
return return None
function = self.function function = self.function
entry = getattr(function, 'entry', None) entry = getattr(function, 'entry', None)
if not entry or not entry.is_cfunction: if not entry or not entry.is_cfunction:
return return None
args = self.positional_args.args args = self.positional_args.args
kwargs = self.keyword_args kwargs = self.keyword_args
...@@ -4389,7 +4390,7 @@ class GeneralCallNode(CallNode): ...@@ -4389,7 +4390,7 @@ class GeneralCallNode(CallNode):
# will lead to an error elsewhere # will lead to an error elsewhere
error(self.pos, "function call got too many positional arguments, " error(self.pos, "function call got too many positional arguments, "
"expected %d, got %s" % (len(declared_args), len(args))) "expected %d, got %s" % (len(declared_args), len(args)))
return return None
matched_pos_args = set([arg.name for arg in declared_args[:len(args)]]) matched_pos_args = set([arg.name for arg in declared_args[:len(args)]])
unmatched_args = declared_args[len(args):] unmatched_args = declared_args[len(args):]
...@@ -4401,7 +4402,7 @@ class GeneralCallNode(CallNode): ...@@ -4401,7 +4402,7 @@ class GeneralCallNode(CallNode):
name = arg.key.value name = arg.key.value
if name in matched_pos_args: if name in matched_pos_args:
error(arg.pos, "keyword argument '%s' passed twice" % name) error(arg.pos, "keyword argument '%s' passed twice" % name)
return return None
if decl_arg.name == name: if decl_arg.name == name:
matched_kwargs.add(name) matched_kwargs.add(name)
args.append(arg.value) args.append(arg.value)
...@@ -4427,15 +4428,17 @@ class GeneralCallNode(CallNode): ...@@ -4427,15 +4428,17 @@ class GeneralCallNode(CallNode):
# into ordered temps if necessary # into ordered temps if necessary
if not matched_kwargs: if not matched_kwargs:
return return None
self.positional_args.args = args self.positional_args.args = args
if len(kwargs.key_value_pairs) == len(matched_kwargs): if len(kwargs.key_value_pairs) == len(matched_kwargs):
self.keyword_args = None # all keywords mapped => only positional arguments left
self.is_simple_call = True return SimpleCallNode(
else: self.pos, function=function, args=args)
kwargs.key_value_pairs = [ kwargs.key_value_pairs = [
item for item in kwargs.key_value_pairs item for item in kwargs.key_value_pairs
if item.key.value not in matched_kwargs ] if item.key.value not in matched_kwargs ]
return None
def generate_result_code(self, code): def generate_result_code(self, code):
if self.type.is_error: return if self.type.is_error: return
......
...@@ -1173,27 +1173,6 @@ class DropRefcountingTransform(Visitor.VisitorTransform): ...@@ -1173,27 +1173,6 @@ class DropRefcountingTransform(Visitor.VisitorTransform):
return (base.name, index_val) return (base.name, index_val)
class SimplifyCalls(Visitor.EnvTransform):
"""
Replace GeneralCallNode by SimpleCallNode if possible.
"""
def visit_GeneralCallNode(self, node):
self.visitchildren(node)
if not node.is_simple_call:
return node
args = [ unwrap_coerced_node(arg)
for arg in node.positional_args.args ]
call_node = ExprNodes.SimpleCallNode(
node.pos,
function=node.function,
args=args)
call_node = call_node.analyse_types(self.current_env())
if node.type != call_node.type:
call_node = call_node.coerce_to(
node.type, self.current_env())
return call_node
class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
"""Optimize some common calls to builtin types *before* the type """Optimize some common calls to builtin types *before* the type
analysis phase and *after* the declarations analysis phase. analysis phase and *after* the declarations analysis phase.
......
...@@ -141,7 +141,7 @@ def create_pipeline(context, mode, exclude_classes=()): ...@@ -141,7 +141,7 @@ def create_pipeline(context, mode, exclude_classes=()):
from AutoDocTransforms import EmbedSignature from AutoDocTransforms import EmbedSignature
from Optimize import FlattenInListTransform, SwitchTransform, IterationTransform from Optimize import FlattenInListTransform, SwitchTransform, IterationTransform
from Optimize import EarlyReplaceBuiltinCalls, OptimizeBuiltinCalls from Optimize import EarlyReplaceBuiltinCalls, OptimizeBuiltinCalls
from Optimize import InlineDefNodeCalls, SimplifyCalls from Optimize import InlineDefNodeCalls
from Optimize import ConstantFolding, FinalOptimizePhase from Optimize import ConstantFolding, FinalOptimizePhase
from Optimize import DropRefcountingTransform from Optimize import DropRefcountingTransform
from Optimize import ConsolidateOverflowCheck from Optimize import ConsolidateOverflowCheck
...@@ -193,7 +193,6 @@ def create_pipeline(context, mode, exclude_classes=()): ...@@ -193,7 +193,6 @@ def create_pipeline(context, mode, exclude_classes=()):
_check_c_declarations, _check_c_declarations,
InlineDefNodeCalls(context), InlineDefNodeCalls(context),
AnalyseExpressionsTransform(context), AnalyseExpressionsTransform(context),
SimplifyCalls(context),
FindInvalidUseOfFusedTypes(context), FindInvalidUseOfFusedTypes(context),
CreateClosureClasses(context), ## After all lookups and type inference CreateClosureClasses(context), ## After all lookups and type inference
ExpandInplaceOperators(context), ExpandInplaceOperators(context),
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment