Commit 49ef8457 authored by Stefan Behnel's avatar Stefan Behnel

Speed up the recursion to children in Visitor.py by avoiding cpdef call...

Speed up the recursion to children in Visitor.py by avoiding cpdef call overhead (most transforms are Python classes with class dicts that require an override check).
parent 928ec4b0
...@@ -14,8 +14,8 @@ cdef class TreeVisitor: ...@@ -14,8 +14,8 @@ cdef class TreeVisitor:
cpdef visitchildren(self, parent, attrs=*) cpdef visitchildren(self, parent, attrs=*)
cdef class VisitorTransform(TreeVisitor): cdef class VisitorTransform(TreeVisitor):
cdef dict _process_children(self, parent, attrs=*)
cpdef visitchildren(self, parent, attrs=*) cpdef visitchildren(self, parent, attrs=*)
cpdef recurse_to_children(self, node)
cdef class CythonTransform(VisitorTransform): cdef class CythonTransform(VisitorTransform):
cdef public context cdef public context
......
...@@ -245,6 +245,12 @@ class VisitorTransform(TreeVisitor): ...@@ -245,6 +245,12 @@ class VisitorTransform(TreeVisitor):
are within a StatListNode or similar before doing this.) are within a StatListNode or similar before doing this.)
""" """
def visitchildren(self, parent, attrs=None): def visitchildren(self, parent, attrs=None):
# generic def entry point for calls from Python subclasses
return self._process_children(parent, attrs)
@cython.final
def _process_children(self, parent, attrs=None):
# fast cdef entry point for calls from Cython subclasses
result = self._visitchildren(parent, attrs) result = self._visitchildren(parent, attrs)
for attr, newnode in result.items(): for attr, newnode in result.items():
if type(newnode) is not list: if type(newnode) is not list:
...@@ -262,7 +268,7 @@ class VisitorTransform(TreeVisitor): ...@@ -262,7 +268,7 @@ class VisitorTransform(TreeVisitor):
return result return result
def recurse_to_children(self, node): def recurse_to_children(self, node):
self.visitchildren(node) self._process_children(node)
return node return node
def __call__(self, root): def __call__(self, root):
...@@ -288,14 +294,15 @@ class CythonTransform(VisitorTransform): ...@@ -288,14 +294,15 @@ class CythonTransform(VisitorTransform):
def visit_CompilerDirectivesNode(self, node): def visit_CompilerDirectivesNode(self, node):
old = self.current_directives old = self.current_directives
self.current_directives = node.directives self.current_directives = node.directives
self.visitchildren(node) self._process_children(node)
self.current_directives = old self.current_directives = old
return node return node
def visit_Node(self, node): def visit_Node(self, node):
self.visitchildren(node) self._process_children(node)
return node return node
class ScopeTrackingTransform(CythonTransform): class ScopeTrackingTransform(CythonTransform):
# Keeps track of type of scopes # Keeps track of type of scopes
#scope_type: can be either of 'module', 'function', 'cclass', 'pyclass', 'struct' #scope_type: can be either of 'module', 'function', 'cclass', 'pyclass', 'struct'
...@@ -304,14 +311,14 @@ class ScopeTrackingTransform(CythonTransform): ...@@ -304,14 +311,14 @@ class ScopeTrackingTransform(CythonTransform):
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
self.scope_type = 'module' self.scope_type = 'module'
self.scope_node = node self.scope_node = node
self.visitchildren(node) self._process_children(node)
return node return node
def visit_scope(self, node, scope_type): def visit_scope(self, node, scope_type):
prev = self.scope_type, self.scope_node prev = self.scope_type, self.scope_node
self.scope_type = scope_type self.scope_type = scope_type
self.scope_node = node self.scope_node = node
self.visitchildren(node) self._process_children(node)
self.scope_type, self.scope_node = prev self.scope_type, self.scope_node = prev
return node return node
...@@ -354,45 +361,45 @@ class EnvTransform(CythonTransform): ...@@ -354,45 +361,45 @@ class EnvTransform(CythonTransform):
def visit_FuncDefNode(self, node): def visit_FuncDefNode(self, node):
self.enter_scope(node, node.local_scope) self.enter_scope(node, node.local_scope)
self.visitchildren(node) self._process_children(node)
self.exit_scope() self.exit_scope()
return node return node
def visit_GeneratorBodyDefNode(self, node): def visit_GeneratorBodyDefNode(self, node):
self.visitchildren(node) self._process_children(node)
return node return node
def visit_ClassDefNode(self, node): def visit_ClassDefNode(self, node):
self.enter_scope(node, node.scope) self.enter_scope(node, node.scope)
self.visitchildren(node) self._process_children(node)
self.exit_scope() self.exit_scope()
return node return node
def visit_CStructOrUnionDefNode(self, node): def visit_CStructOrUnionDefNode(self, node):
self.enter_scope(node, node.scope) self.enter_scope(node, node.scope)
self.visitchildren(node) self._process_children(node)
self.exit_scope() self.exit_scope()
return node return node
def visit_ScopedExprNode(self, node): def visit_ScopedExprNode(self, node):
if node.expr_scope: if node.expr_scope:
self.enter_scope(node, node.expr_scope) self.enter_scope(node, node.expr_scope)
self.visitchildren(node) self._process_children(node)
self.exit_scope() self.exit_scope()
else: else:
self.visitchildren(node) self._process_children(node)
return node return node
def visit_CArgDeclNode(self, node): def visit_CArgDeclNode(self, node):
# default arguments are evaluated in the outer scope # default arguments are evaluated in the outer scope
if node.default: if node.default:
attrs = [attr for attr in node.child_attrs if attr != 'default'] attrs = [attr for attr in node.child_attrs if attr != 'default']
self.visitchildren(node, attrs) self._process_children(node, attrs)
self.enter_scope(node, self.current_env().outer_scope) self.enter_scope(node, self.current_env().outer_scope)
self.visitchildren(node, ('default',)) self.visitchildren(node, ('default',))
self.exit_scope() self.exit_scope()
else: else:
self.visitchildren(node) self._process_children(node)
return node return node
...@@ -477,7 +484,7 @@ class MethodDispatcherTransform(EnvTransform): ...@@ -477,7 +484,7 @@ class MethodDispatcherTransform(EnvTransform):
""" """
# only visit call nodes and Python operations # only visit call nodes and Python operations
def visit_GeneralCallNode(self, node): def visit_GeneralCallNode(self, node):
self.visitchildren(node) self._process_children(node)
function = node.function function = node.function
if not function.type.is_pyobject: if not function.type.is_pyobject:
return node return node
...@@ -492,7 +499,7 @@ class MethodDispatcherTransform(EnvTransform): ...@@ -492,7 +499,7 @@ class MethodDispatcherTransform(EnvTransform):
return self._dispatch_to_handler(node, function, args, keyword_args) return self._dispatch_to_handler(node, function, args, keyword_args)
def visit_SimpleCallNode(self, node): def visit_SimpleCallNode(self, node):
self.visitchildren(node) self._process_children(node)
function = node.function function = node.function
if function.type.is_pyobject: if function.type.is_pyobject:
arg_tuple = node.arg_tuple arg_tuple = node.arg_tuple
...@@ -506,7 +513,7 @@ class MethodDispatcherTransform(EnvTransform): ...@@ -506,7 +513,7 @@ class MethodDispatcherTransform(EnvTransform):
def visit_PrimaryCmpNode(self, node): def visit_PrimaryCmpNode(self, node):
if node.cascade: if node.cascade:
# not currently handled below # not currently handled below
self.visitchildren(node) self._process_children(node)
return node return node
return self._visit_binop_node(node) return self._visit_binop_node(node)
...@@ -514,7 +521,7 @@ class MethodDispatcherTransform(EnvTransform): ...@@ -514,7 +521,7 @@ class MethodDispatcherTransform(EnvTransform):
return self._visit_binop_node(node) return self._visit_binop_node(node)
def _visit_binop_node(self, node): def _visit_binop_node(self, node):
self.visitchildren(node) self._process_children(node)
# FIXME: could special case 'not_in' # FIXME: could special case 'not_in'
special_method_name = find_special_method_for_binary_operator(node.operator) special_method_name = find_special_method_for_binary_operator(node.operator)
if special_method_name: if special_method_name:
...@@ -535,7 +542,7 @@ class MethodDispatcherTransform(EnvTransform): ...@@ -535,7 +542,7 @@ class MethodDispatcherTransform(EnvTransform):
return node return node
def visit_UnopNode(self, node): def visit_UnopNode(self, node):
self.visitchildren(node) self._process_children(node)
special_method_name = find_special_method_for_unary_operator(node.operator) special_method_name = find_special_method_for_unary_operator(node.operator)
if special_method_name: if special_method_name:
operand = node.operand operand = node.operand
...@@ -690,7 +697,7 @@ class RecursiveNodeReplacer(VisitorTransform): ...@@ -690,7 +697,7 @@ class RecursiveNodeReplacer(VisitorTransform):
return node return node
def visit_Node(self, node): def visit_Node(self, node):
self.visitchildren(node) self._process_children(node)
if node is self.orig_node: if node is self.orig_node:
return self.new_node return self.new_node
else: else:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment