Commit d83f2b0e authored by Stefan Behnel's avatar Stefan Behnel

merge

parents a30df984 53045f65
...@@ -8,7 +8,6 @@ compile-time values. ...@@ -8,7 +8,6 @@ compile-time values.
from Nodes import * from Nodes import *
from ExprNodes import * from ExprNodes import *
from Visitor import BasicVisitor
from Errors import CompileError from Errors import CompileError
......
...@@ -843,33 +843,6 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -843,33 +843,6 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
return replace_in(arg) return replace_in(arg)
return node return node
PyObject_GetAttr2_func_type = PyrexTypes.CFuncType(
PyrexTypes.py_object_type, [
PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
])
PyObject_GetAttr3_func_type = PyrexTypes.CFuncType(
PyrexTypes.py_object_type, [
PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
])
def _handle_simple_function_getattr(self, node, pos_args):
if len(pos_args) == 2:
self._inject_capi_function(
node, "PyObject_GetAttr",
self.PyObject_GetAttr2_func_type)
elif len(pos_args) == 3:
self._inject_capi_function(
node, "__Pyx_GetAttr3",
self.PyObject_GetAttr3_func_type,
Builtin.getattr3_utility_code)
else:
self._error_wrong_arg_count('getattr', node, pos_args, '2 or 3')
return node
Pyx_Type_func_type = PyrexTypes.CFuncType( Pyx_Type_func_type = PyrexTypes.CFuncType(
Builtin.type_type, [ Builtin.type_type, [
PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None) PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
...@@ -1160,6 +1133,35 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform): ...@@ -1160,6 +1133,35 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
### builtin functions ### builtin functions
PyObject_GetAttr2_func_type = PyrexTypes.CFuncType(
PyrexTypes.py_object_type, [
PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
])
PyObject_GetAttr3_func_type = PyrexTypes.CFuncType(
PyrexTypes.py_object_type, [
PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
])
def _handle_simple_function_getattr(self, node, pos_args):
if len(pos_args) == 2:
return ExprNodes.PythonCapiCallNode(
node.pos, "PyObject_GetAttr", self.PyObject_GetAttr2_func_type,
args = pos_args,
is_temp = node.is_temp)
elif len(pos_args) == 3:
return ExprNodes.PythonCapiCallNode(
node.pos, "__Pyx_GetAttr3", self.PyObject_GetAttr3_func_type,
args = pos_args,
is_temp = node.is_temp,
utility_code = Builtin.getattr3_utility_code)
else:
self._error_wrong_arg_count('getattr', node, pos_args, '2 or 3')
return node
Pyx_strlen_func_type = PyrexTypes.CFuncType( Pyx_strlen_func_type = PyrexTypes.CFuncType(
PyrexTypes.c_size_t_type, [ PyrexTypes.c_size_t_type, [
PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None) PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None)
......
cimport cython
cdef class BasicVisitor: cdef class BasicVisitor:
cdef dict dispatch_table cdef dict dispatch_table
cpdef visit(self, obj) cpdef visit(self, obj)
cpdef find_handler(self, obj)
cdef class TreeVisitor(BasicVisitor): cdef class TreeVisitor(BasicVisitor):
cdef public list access_path cdef public list access_path
cpdef visitchild(self, child, parent, attrname, idx) cpdef visitchild(self, child, parent, attrname, idx)
@cython.locals(idx=int)
cpdef dict _visitchildren(self, parent, attrs)
# cpdef visitchildren(self, parent, attrs=*) # cpdef visitchildren(self, parent, attrs=*)
cdef class VisitorTransform(TreeVisitor): cdef class VisitorTransform(TreeVisitor):
......
# cython: infer_types=True
# #
# Tree visitor and transform framework # Tree visitor and transform framework
# #
...@@ -8,7 +10,6 @@ import ExprNodes ...@@ -8,7 +10,6 @@ import ExprNodes
import Naming import Naming
import Errors import Errors
import DebugFlags import DebugFlags
from StringEncoding import EncodedString
class BasicVisitor(object): class BasicVisitor(object):
"""A generic visitor base class which can be used for visiting any kind of object.""" """A generic visitor base class which can be used for visiting any kind of object."""
...@@ -19,32 +20,36 @@ class BasicVisitor(object): ...@@ -19,32 +20,36 @@ class BasicVisitor(object):
self.dispatch_table = {} self.dispatch_table = {}
def visit(self, obj): def visit(self, obj):
cls = type(obj)
try: try:
handler_method = self.dispatch_table[cls] handler_method = self.dispatch_table[type(obj)]
except KeyError: except KeyError:
#print "Cache miss for class %s in visitor %s" % ( handler_method = self.find_handler(obj)
# cls.__name__, type(self).__name__) self.dispatch_table[type(obj)] = handler_method
# Must resolve, try entire hierarchy
pattern = "visit_%s"
mro = inspect.getmro(cls)
handler_method = None
for mro_cls in mro:
if hasattr(self, pattern % mro_cls.__name__):
handler_method = getattr(self, pattern % mro_cls.__name__)
break
if handler_method is None:
print type(self), type(obj)
if hasattr(self, 'access_path') and self.access_path:
print self.access_path
if self.access_path:
print self.access_path[-1][0].pos
print self.access_path[-1][0].__dict__
raise RuntimeError("Visitor does not accept object: %s" % obj)
#print "Caching " + cls.__name__
self.dispatch_table[cls] = handler_method
return handler_method(obj) return handler_method(obj)
def find_handler(self, obj):
cls = type(obj)
#print "Cache miss for class %s in visitor %s" % (
# cls.__name__, type(self).__name__)
# Must resolve, try entire hierarchy
pattern = "visit_%s"
mro = inspect.getmro(cls)
handler_method = None
for mro_cls in mro:
if hasattr(self, pattern % mro_cls.__name__):
handler_method = getattr(self, pattern % mro_cls.__name__)
break
if handler_method is None:
print type(self), cls
if hasattr(self, 'access_path') and self.access_path:
print self.access_path
if self.access_path:
print self.access_path[-1][0].pos
print self.access_path[-1][0].__dict__
raise RuntimeError("Visitor does not accept object: %s" % obj)
#print "Caching " + cls.__name__
return handler_method
class TreeVisitor(BasicVisitor): class TreeVisitor(BasicVisitor):
""" """
Base class for writing visitors for a Cython tree, contains utilities for Base class for writing visitors for a Cython tree, contains utilities for
...@@ -144,6 +149,29 @@ class TreeVisitor(BasicVisitor): ...@@ -144,6 +149,29 @@ class TreeVisitor(BasicVisitor):
stacktrace = stacktrace.tb_next stacktrace = stacktrace.tb_next
return (last_traceback, nodes) return (last_traceback, nodes)
def _raise_compiler_error(self, child, e):
import sys
trace = ['']
for parent, attribute, index in self.access_path:
node = getattr(parent, attribute)
if index is None:
index = ''
else:
node = node[index]
index = u'[%d]' % index
trace.append(u'%s.%s%s = %s' % (
parent.__class__.__name__, attribute, index,
self.dump_node(node)))
stacktrace, called_nodes = self._find_node_path(sys.exc_info()[2])
last_node = child
for node, method_name, pos in called_nodes:
last_node = node
trace.append(u"File '%s', line %d, in %s: %s" % (
pos[0], pos[1], method_name, self.dump_node(node)))
raise Errors.CompilerCrash(
last_node.pos, self.__class__.__name__,
u'\n'.join(trace), e, stacktrace)
def visitchild(self, child, parent, attrname, idx): def visitchild(self, child, parent, attrname, idx):
self.access_path.append((parent, attrname, idx)) self.access_path.append((parent, attrname, idx))
try: try:
...@@ -151,33 +179,16 @@ class TreeVisitor(BasicVisitor): ...@@ -151,33 +179,16 @@ class TreeVisitor(BasicVisitor):
except Errors.CompileError: except Errors.CompileError:
raise raise
except Exception, e: except Exception, e:
import sys
if DebugFlags.debug_no_exception_intercept: if DebugFlags.debug_no_exception_intercept:
raise raise
trace = [''] self._raise_compiler_error(child, e)
for parent, attribute, index in self.access_path:
node = getattr(parent, attribute)
if index is None:
index = ''
else:
node = node[index]
index = u'[%d]' % index
trace.append(u'%s.%s%s = %s' % (
parent.__class__.__name__, attribute, index,
self.dump_node(node)))
stacktrace, called_nodes = self._find_node_path(sys.exc_info()[2])
last_node = child
for node, method_name, pos in called_nodes:
last_node = node
trace.append(u"File '%s', line %d, in %s: %s" % (
pos[0], pos[1], method_name, self.dump_node(node)))
raise Errors.CompilerCrash(
last_node.pos, self.__class__.__name__,
u'\n'.join(trace), e, stacktrace)
self.access_path.pop() self.access_path.pop()
return result return result
def visitchildren(self, parent, attrs=None): def visitchildren(self, parent, attrs=None):
return self._visitchildren(parent, attrs)
def _visitchildren(self, parent, attrs):
""" """
Visits the children of the given parent. If parent is None, returns Visits the children of the given parent. If parent is None, returns
immediately (returning None). immediately (returning None).
...@@ -223,8 +234,7 @@ class VisitorTransform(TreeVisitor): ...@@ -223,8 +234,7 @@ class VisitorTransform(TreeVisitor):
are within a StatListNode or similar before doing this.) are within a StatListNode or similar before doing this.)
""" """
def visitchildren(self, parent, attrs=None): def visitchildren(self, parent, attrs=None):
result = cython.declare(dict) result = self._visitchildren(parent, attrs)
result = TreeVisitor.visitchildren(self, parent, attrs)
for attr, newnode in result.iteritems(): for attr, newnode in result.iteritems():
if not type(newnode) is list: if not type(newnode) is list:
setattr(parent, attr, newnode) setattr(parent, attr, newnode)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment