Commit 992219fc authored by Stefan Behnel's avatar Stefan Behnel

speed up tree visitor somewhat by moving code out of the critical methods

parent d6bc4573
cimport cython
cdef class BasicVisitor: cdef class BasicVisitor:
cdef dict dispatch_table cdef dict dispatch_table
cpdef visit(self, obj) cpdef visit(self, obj)
cpdef find_handler(self, obj)
cdef class TreeVisitor(BasicVisitor): cdef class TreeVisitor(BasicVisitor):
cdef public list access_path cdef public list access_path
cpdef visitchild(self, child, parent, attrname, idx) cpdef visitchild(self, child, parent, attrname, idx)
@cython.locals(idx=int)
cpdef dict _visitchildren(self, parent, attrs)
# cpdef visitchildren(self, parent, attrs=*) # cpdef visitchildren(self, parent, attrs=*)
cdef class VisitorTransform(TreeVisitor): cdef class VisitorTransform(TreeVisitor):
......
# cython: infer_types=True
# #
# Tree visitor and transform framework # Tree visitor and transform framework
# #
...@@ -19,32 +21,36 @@ class BasicVisitor(object): ...@@ -19,32 +21,36 @@ class BasicVisitor(object):
self.dispatch_table = {} self.dispatch_table = {}
def visit(self, obj): def visit(self, obj):
cls = type(obj)
try: try:
handler_method = self.dispatch_table[cls] handler_method = self.dispatch_table[type(obj)]
except KeyError: except KeyError:
#print "Cache miss for class %s in visitor %s" % ( handler_method = self.find_handler(obj)
# cls.__name__, type(self).__name__) self.dispatch_table[type(obj)] = handler_method
# Must resolve, try entire hierarchy
pattern = "visit_%s"
mro = inspect.getmro(cls)
handler_method = None
for mro_cls in mro:
if hasattr(self, pattern % mro_cls.__name__):
handler_method = getattr(self, pattern % mro_cls.__name__)
break
if handler_method is None:
print type(self), type(obj)
if hasattr(self, 'access_path') and self.access_path:
print self.access_path
if self.access_path:
print self.access_path[-1][0].pos
print self.access_path[-1][0].__dict__
raise RuntimeError("Visitor does not accept object: %s" % obj)
#print "Caching " + cls.__name__
self.dispatch_table[cls] = handler_method
return handler_method(obj) return handler_method(obj)
def find_handler(self, obj):
cls = type(obj)
#print "Cache miss for class %s in visitor %s" % (
# cls.__name__, type(self).__name__)
# Must resolve, try entire hierarchy
pattern = "visit_%s"
mro = inspect.getmro(cls)
handler_method = None
for mro_cls in mro:
if hasattr(self, pattern % mro_cls.__name__):
handler_method = getattr(self, pattern % mro_cls.__name__)
break
if handler_method is None:
print type(self), cls
if hasattr(self, 'access_path') and self.access_path:
print self.access_path
if self.access_path:
print self.access_path[-1][0].pos
print self.access_path[-1][0].__dict__
raise RuntimeError("Visitor does not accept object: %s" % obj)
#print "Caching " + cls.__name__
return handler_method
class TreeVisitor(BasicVisitor): class TreeVisitor(BasicVisitor):
""" """
Base class for writing visitors for a Cython tree, contains utilities for Base class for writing visitors for a Cython tree, contains utilities for
...@@ -144,6 +150,29 @@ class TreeVisitor(BasicVisitor): ...@@ -144,6 +150,29 @@ class TreeVisitor(BasicVisitor):
stacktrace = stacktrace.tb_next stacktrace = stacktrace.tb_next
return (last_traceback, nodes) return (last_traceback, nodes)
def _raise_compiler_error(self, child, e):
import sys
trace = ['']
for parent, attribute, index in self.access_path:
node = getattr(parent, attribute)
if index is None:
index = ''
else:
node = node[index]
index = u'[%d]' % index
trace.append(u'%s.%s%s = %s' % (
parent.__class__.__name__, attribute, index,
self.dump_node(node)))
stacktrace, called_nodes = self._find_node_path(sys.exc_info()[2])
last_node = child
for node, method_name, pos in called_nodes:
last_node = node
trace.append(u"File '%s', line %d, in %s: %s" % (
pos[0], pos[1], method_name, self.dump_node(node)))
raise Errors.CompilerCrash(
last_node.pos, self.__class__.__name__,
u'\n'.join(trace), e, stacktrace)
def visitchild(self, child, parent, attrname, idx): def visitchild(self, child, parent, attrname, idx):
self.access_path.append((parent, attrname, idx)) self.access_path.append((parent, attrname, idx))
try: try:
...@@ -151,33 +180,16 @@ class TreeVisitor(BasicVisitor): ...@@ -151,33 +180,16 @@ class TreeVisitor(BasicVisitor):
except Errors.CompileError: except Errors.CompileError:
raise raise
except Exception, e: except Exception, e:
import sys
if DebugFlags.debug_no_exception_intercept: if DebugFlags.debug_no_exception_intercept:
raise raise
trace = [''] self._raise_compiler_error(child, e)
for parent, attribute, index in self.access_path:
node = getattr(parent, attribute)
if index is None:
index = ''
else:
node = node[index]
index = u'[%d]' % index
trace.append(u'%s.%s%s = %s' % (
parent.__class__.__name__, attribute, index,
self.dump_node(node)))
stacktrace, called_nodes = self._find_node_path(sys.exc_info()[2])
last_node = child
for node, method_name, pos in called_nodes:
last_node = node
trace.append(u"File '%s', line %d, in %s: %s" % (
pos[0], pos[1], method_name, self.dump_node(node)))
raise Errors.CompilerCrash(
last_node.pos, self.__class__.__name__,
u'\n'.join(trace), e, stacktrace)
self.access_path.pop() self.access_path.pop()
return result return result
def visitchildren(self, parent, attrs=None): def visitchildren(self, parent, attrs=None):
return self._visitchildren(parent, attrs)
def _visitchildren(self, parent, attrs):
""" """
Visits the children of the given parent. If parent is None, returns Visits the children of the given parent. If parent is None, returns
immediately (returning None). immediately (returning None).
...@@ -223,8 +235,7 @@ class VisitorTransform(TreeVisitor): ...@@ -223,8 +235,7 @@ class VisitorTransform(TreeVisitor):
are within a StatListNode or similar before doing this.) are within a StatListNode or similar before doing this.)
""" """
def visitchildren(self, parent, attrs=None): def visitchildren(self, parent, attrs=None):
result = cython.declare(dict) result = self._visitchildren(parent, attrs)
result = TreeVisitor.visitchildren(self, parent, attrs)
for attr, newnode in result.iteritems(): for attr, newnode in result.iteritems():
if not type(newnode) is list: if not type(newnode) is list:
setattr(parent, attr, newnode) setattr(parent, attr, newnode)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment