Commit c7dbcf9a authored by Stefan Behnel's avatar Stefan Behnel

split builtin type call optimisations into pre and post type analysis phase

parent 21442aba
......@@ -2641,28 +2641,32 @@ class SimpleCallNode(CallNode):
class PythonCapiFunctionNode(ExprNode):
subexprs = []
def __init__(self, pos, name, func_type, utility_code = None):
def __init__(self, pos, py_name, cname, func_type, utility_code = None):
self.pos = pos
self.name = name
self.name = py_name
self.cname = cname
self.type = func_type
self.utility_code = utility_code
def analyse_types(self, env):
pass
def generate_result_code(self, code):
if self.utility_code:
code.globalstate.use_utility_code(self.utility_code)
def calculate_result_code(self):
return self.name
return self.cname
class PythonCapiCallNode(SimpleCallNode):
# Python C-API Function call (only created in transforms)
def __init__(self, pos, function_name, func_type,
utility_code = None, **kwargs):
utility_code = None, py_name=None, **kwargs):
self.type = func_type.return_type
self.result_ctype = self.type
self.function = PythonCapiFunctionNode(
pos, function_name, func_type,
pos, py_name, function_name, func_type,
utility_code = utility_code)
# call this last so that we can override the constructed
# attributes above with explicit keyword arguments if required
......
......@@ -92,7 +92,8 @@ class Context(object):
from AnalysedTreeTransforms import AutoTestDictTransform
from AutoDocTransforms import EmbedSignature
from Optimize import FlattenInListTransform, SwitchTransform, IterationTransform
from Optimize import OptimizeBuiltinCalls, ConstantFolding, FinalOptimizePhase
from Optimize import EarlyReplaceBuiltinCalls, OptimizeBuiltinCalls
from Optimize import ConstantFolding, FinalOptimizePhase
from Optimize import DropRefcountingTransform
from Buffer import IntroduceBufferAuxiliaryVars
from ModuleNode import check_c_declarations, check_c_declarations_pxd
......@@ -131,6 +132,7 @@ class Context(object):
AnalyseDeclarationsTransform(self),
AutoTestDictTransform(self),
EmbedSignature(self),
EarlyReplaceBuiltinCalls(self),
MarkAssignments(self),
TransformBuiltinMethods(self),
IntroduceBufferAuxiliaryVars(self),
......
......@@ -723,9 +723,190 @@ class DropRefcountingTransform(Visitor.VisitorTransform):
return (base.name, index_val)
class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
"""Optimize some common calls to builtin types *before* the type
analysis phase and *after* the declarations analysis phase.
This transform cannot make use of any argument types, but it can
restructure the tree in a way that the type analysis phase can
respond to.
"""
# only intercept on call nodes
visit_Node = Visitor.VisitorTransform.recurse_to_children
def visit_SimpleCallNode(self, node):
self.visitchildren(node)
function = node.function
if not self._function_is_builtin_name(function):
return node
return self._dispatch_to_handler(node, function, node.args)
def visit_GeneralCallNode(self, node):
self.visitchildren(node)
function = node.function
if not self._function_is_builtin_name(function):
return node
arg_tuple = node.positional_args
if not isinstance(arg_tuple, ExprNodes.TupleNode):
return node
args = arg_tuple.args
return self._dispatch_to_handler(
node, function, args, node.keyword_args)
def _function_is_builtin_name(self, function):
if not function.is_name:
return False
entry = self.env_stack[-1].lookup(function.name)
if not entry or getattr(entry, 'scope', None) is not Builtin.builtin_scope:
return False
return True
def _dispatch_to_handler(self, node, function, args, kwargs=None):
if kwargs is None:
handler_name = '_handle_simple_function_%s' % function.name
else:
handler_name = '_handle_general_function_%s' % function.name
handle_call = getattr(self, handler_name, None)
if handle_call is not None:
if kwargs is None:
return handle_call(node, args)
else:
return handle_call(node, args, kwargs)
return node
def _inject_capi_function(self, node, cname, func_type, utility_code=None):
node.function = ExprNodes.PythonCapiFunctionNode(
node.function.pos, node.function.name, cname, func_type,
utility_code = utility_code)
def _error_wrong_arg_count(self, function_name, node, args, expected=None):
if not expected: # None or 0
arg_str = ''
elif isinstance(expected, basestring) or expected > 1:
arg_str = '...'
elif expected == 1:
arg_str = 'x'
else:
arg_str = ''
if expected is not None:
expected_str = 'expected %s, ' % expected
else:
expected_str = ''
error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
function_name, arg_str, expected_str, len(args)))
# specific handlers for simple call nodes
def _handle_simple_function_set(self, node, pos_args):
"""Replace set([a,b,...]) by a literal set {a,b,...} and
set([ x for ... ]) by a literal { x for ... }.
"""
arg_count = len(pos_args)
if arg_count == 0:
return ExprNodes.SetNode(node.pos, args=[],
type=Builtin.set_type)
if arg_count > 1:
return node
iterable = pos_args[0]
if isinstance(iterable, (ExprNodes.ListNode, ExprNodes.TupleNode)):
return ExprNodes.SetNode(node.pos, args=iterable.args)
elif isinstance(iterable, ExprNodes.ComprehensionNode) and \
isinstance(iterable.target, (ExprNodes.ListNode,
ExprNodes.SetNode)):
iterable.target = ExprNodes.SetNode(node.pos, args=[])
iterable.pos = node.pos
return iterable
else:
return node
def _handle_simple_function_dict(self, node, pos_args):
"""Replace dict([ (a,b) for ... ]) by a literal { a:b for ... }.
"""
if len(pos_args) != 1:
return node
arg = pos_args[0]
if isinstance(arg, ExprNodes.ComprehensionNode) and \
isinstance(arg.target, (ExprNodes.ListNode,
ExprNodes.SetNode)):
append_node = arg.append
if isinstance(append_node.expr, (ExprNodes.TupleNode, ExprNodes.ListNode)) and \
len(append_node.expr.args) == 2:
key_node, value_node = append_node.expr.args
target_node = ExprNodes.DictNode(
pos=arg.target.pos, key_value_pairs=[])
new_append_node = ExprNodes.DictComprehensionAppendNode(
append_node.pos, target=target_node,
key_expr=key_node, value_expr=value_node)
arg.target = target_node
arg.type = target_node.type
replace_in = Visitor.RecursiveNodeReplacer(append_node, new_append_node)
return replace_in(arg)
return node
PyObject_GetAttr2_func_type = PyrexTypes.CFuncType(
PyrexTypes.py_object_type, [
PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
])
PyObject_GetAttr3_func_type = PyrexTypes.CFuncType(
PyrexTypes.py_object_type, [
PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
])
def _handle_simple_function_getattr(self, node, pos_args):
if len(pos_args) == 2:
self._inject_capi_function(
node, "PyObject_GetAttr",
self.PyObject_GetAttr2_func_type)
elif len(pos_args) == 3:
self._inject_capi_function(
node, "__Pyx_GetAttr3",
self.PyObject_GetAttr3_func_type,
Builtin.getattr3_utility_code)
else:
self._error_wrong_arg_count('getattr', node, pos_args, '2 or 3')
return node
Pyx_Type_func_type = PyrexTypes.CFuncType(
Builtin.type_type, [
PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
])
def _handle_simple_function_type(self, node, pos_args):
if len(pos_args) != 1:
return node
self._inject_capi_function(
node, "__Pyx_Type",
self.Pyx_Type_func_type,
pytype_utility_code)
return node
# specific handlers for general call nodes
def _handle_general_function_dict(self, node, pos_args, kwargs):
"""Replace dict(a=b,c=d,...) by the underlying keyword dict
construction which is done anyway.
"""
if len(pos_args) > 0:
return node
if not isinstance(kwargs, ExprNodes.DictNode):
return node
if node.starstar_arg:
# we could optimize this by updating the kw dict instead
return node
return kwargs
class OptimizeBuiltinCalls(Visitor.EnvTransform):
"""Optimize some common methods calls and instantiation patterns
for builtin types.
for builtin types *after* the type analysis phase.
Running after type analysis, this transform can only perform
function replacements that do not alter the function return type
in a way that was not anticipated by the type analysis.
"""
# only intercept on call nodes
visit_Node = Visitor.VisitorTransform.recurse_to_children
......@@ -896,27 +1077,13 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
### builtin types
def _handle_general_function_dict(self, node, pos_args, kwargs):
"""Replace dict(a=b,c=d,...) by the underlying keyword dict
construction which is done anyway.
"""
if len(pos_args) > 0:
return node
if not isinstance(kwargs, ExprNodes.DictNode):
return node
if node.starstar_arg:
# we could optimize this by updating the kw dict instead
return node
return kwargs
PyDict_Copy_func_type = PyrexTypes.CFuncType(
Builtin.dict_type, [
PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)
])
def _handle_simple_function_dict(self, node, pos_args):
"""Replace dict(some_dict) by PyDict_Copy(some_dict) and
dict([ (a,b) for ... ]) by a literal { a:b for ... }.
"""Replace dict(some_dict) by PyDict_Copy(some_dict).
"""
if len(pos_args) != 1:
return node
......@@ -929,48 +1096,8 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
args = [arg],
is_temp = node.is_temp
)
elif isinstance(arg, ExprNodes.ComprehensionNode) and \
arg.type is Builtin.list_type:
append_node = arg.append
if isinstance(append_node.expr, (ExprNodes.TupleNode, ExprNodes.ListNode)) and \
len(append_node.expr.args) == 2:
key_node, value_node = append_node.expr.args
target_node = ExprNodes.DictNode(
pos=arg.target.pos, key_value_pairs=[], is_temp=1)
new_append_node = ExprNodes.DictComprehensionAppendNode(
append_node.pos, target=target_node,
key_expr=key_node, value_expr=value_node,
is_temp=1)
arg.target = target_node
arg.type = target_node.type
replace_in = Visitor.RecursiveNodeReplacer(append_node, new_append_node)
return replace_in(arg)
return node
def _handle_simple_function_set(self, node, pos_args):
"""Replace set([a,b,...]) by a literal set {a,b,...} and
set([ x for ... ]) by a literal { x for ... }.
"""
arg_count = len(pos_args)
if arg_count == 0:
return ExprNodes.SetNode(node.pos, args=[],
type=Builtin.set_type, is_temp=1)
if arg_count > 1:
return node
iterable = pos_args[0]
if isinstance(iterable, (ExprNodes.ListNode, ExprNodes.TupleNode)):
return ExprNodes.SetNode(node.pos, args=iterable.args,
type=Builtin.set_type, is_temp=1)
elif isinstance(iterable, ExprNodes.ComprehensionNode) and \
iterable.type is Builtin.list_type:
iterable.target = ExprNodes.SetNode(
node.pos, args=[], type=Builtin.set_type, is_temp=1)
iterable.type = Builtin.set_type
iterable.pos = node.pos
return iterable
else:
return node
PyList_AsTuple_func_type = PyrexTypes.CFuncType(
Builtin.tuple_type, [
PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
......@@ -998,53 +1125,6 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
### builtin functions
PyObject_GetAttr2_func_type = PyrexTypes.CFuncType(
PyrexTypes.py_object_type, [
PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
])
PyObject_GetAttr3_func_type = PyrexTypes.CFuncType(
PyrexTypes.py_object_type, [
PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
])
def _handle_simple_function_getattr(self, node, pos_args):
if len(pos_args) == 2:
node = ExprNodes.PythonCapiCallNode(
node.pos, "PyObject_GetAttr", self.PyObject_GetAttr2_func_type,
args = pos_args,
is_temp = node.is_temp
)
elif len(pos_args) == 3:
node = ExprNodes.PythonCapiCallNode(
node.pos, "__Pyx_GetAttr3", self.PyObject_GetAttr3_func_type,
utility_code = Builtin.getattr3_utility_code,
args = pos_args,
is_temp = node.is_temp
)
else:
self._error_wrong_arg_count('getattr', node, pos_args, '2 or 3')
return node
Pyx_Type_func_type = PyrexTypes.CFuncType(
Builtin.type_type, [
PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
])
def _handle_simple_function_type(self, node, pos_args):
if len(pos_args) != 1:
return node
node = ExprNodes.PythonCapiCallNode(
node.pos, "__Pyx_Type", self.Pyx_Type_func_type,
args = pos_args,
is_temp = node.is_temp,
utility_code = pytype_utility_code,
)
return node
Pyx_strlen_func_type = PyrexTypes.CFuncType(
PyrexTypes.c_size_t_type, [
PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None)
......@@ -1065,10 +1145,10 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
return node
node = ExprNodes.PythonCapiCallNode(
node.pos, "strlen", self.Pyx_strlen_func_type,
args = [arg],
is_temp = node.is_temp,
utility_code = include_string_h_utility_code,
)
args = [arg],
is_temp = node.is_temp,
utility_code = include_string_h_utility_code
)
return node
### special methods
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment