Commit d83f2b0e authored by Stefan Behnel's avatar Stefan Behnel

merge

parents a30df984 53045f65
...@@ -8,7 +8,6 @@ compile-time values. ...@@ -8,7 +8,6 @@ compile-time values.
from Nodes import * from Nodes import *
from ExprNodes import * from ExprNodes import *
from Visitor import BasicVisitor
from Errors import CompileError from Errors import CompileError
......
...@@ -843,33 +843,6 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform): ...@@ -843,33 +843,6 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
return replace_in(arg) return replace_in(arg)
return node return node
PyObject_GetAttr2_func_type = PyrexTypes.CFuncType(
PyrexTypes.py_object_type, [
PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
])
PyObject_GetAttr3_func_type = PyrexTypes.CFuncType(
PyrexTypes.py_object_type, [
PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
])
def _handle_simple_function_getattr(self, node, pos_args):
if len(pos_args) == 2:
self._inject_capi_function(
node, "PyObject_GetAttr",
self.PyObject_GetAttr2_func_type)
elif len(pos_args) == 3:
self._inject_capi_function(
node, "__Pyx_GetAttr3",
self.PyObject_GetAttr3_func_type,
Builtin.getattr3_utility_code)
else:
self._error_wrong_arg_count('getattr', node, pos_args, '2 or 3')
return node
Pyx_Type_func_type = PyrexTypes.CFuncType( Pyx_Type_func_type = PyrexTypes.CFuncType(
Builtin.type_type, [ Builtin.type_type, [
PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None) PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
...@@ -1160,6 +1133,35 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform): ...@@ -1160,6 +1133,35 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
### builtin functions ### builtin functions
PyObject_GetAttr2_func_type = PyrexTypes.CFuncType(
PyrexTypes.py_object_type, [
PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
])
PyObject_GetAttr3_func_type = PyrexTypes.CFuncType(
PyrexTypes.py_object_type, [
PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
])
def _handle_simple_function_getattr(self, node, pos_args):
if len(pos_args) == 2:
return ExprNodes.PythonCapiCallNode(
node.pos, "PyObject_GetAttr", self.PyObject_GetAttr2_func_type,
args = pos_args,
is_temp = node.is_temp)
elif len(pos_args) == 3:
return ExprNodes.PythonCapiCallNode(
node.pos, "__Pyx_GetAttr3", self.PyObject_GetAttr3_func_type,
args = pos_args,
is_temp = node.is_temp,
utility_code = Builtin.getattr3_utility_code)
else:
self._error_wrong_arg_count('getattr', node, pos_args, '2 or 3')
return node
Pyx_strlen_func_type = PyrexTypes.CFuncType( Pyx_strlen_func_type = PyrexTypes.CFuncType(
PyrexTypes.c_size_t_type, [ PyrexTypes.c_size_t_type, [
PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None) PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None)
......
cimport cython
cdef class BasicVisitor: cdef class BasicVisitor:
cdef dict dispatch_table cdef dict dispatch_table
cpdef visit(self, obj) cpdef visit(self, obj)
cpdef find_handler(self, obj)
cdef class TreeVisitor(BasicVisitor): cdef class TreeVisitor(BasicVisitor):
cdef public list access_path cdef public list access_path
cpdef visitchild(self, child, parent, attrname, idx) cpdef visitchild(self, child, parent, attrname, idx)
@cython.locals(idx=int)
cpdef dict _visitchildren(self, parent, attrs)
# cpdef visitchildren(self, parent, attrs=*) # cpdef visitchildren(self, parent, attrs=*)
cdef class VisitorTransform(TreeVisitor): cdef class VisitorTransform(TreeVisitor):
......
# cython: infer_types=True
# #
# Tree visitor and transform framework # Tree visitor and transform framework
# #
...@@ -8,7 +10,6 @@ import ExprNodes ...@@ -8,7 +10,6 @@ import ExprNodes
import Naming import Naming
import Errors import Errors
import DebugFlags import DebugFlags
from StringEncoding import EncodedString
class BasicVisitor(object): class BasicVisitor(object):
"""A generic visitor base class which can be used for visiting any kind of object.""" """A generic visitor base class which can be used for visiting any kind of object."""
...@@ -19,10 +20,15 @@ class BasicVisitor(object): ...@@ -19,10 +20,15 @@ class BasicVisitor(object):
self.dispatch_table = {} self.dispatch_table = {}
def visit(self, obj): def visit(self, obj):
cls = type(obj)
try: try:
handler_method = self.dispatch_table[cls] handler_method = self.dispatch_table[type(obj)]
except KeyError: except KeyError:
handler_method = self.find_handler(obj)
self.dispatch_table[type(obj)] = handler_method
return handler_method(obj)
def find_handler(self, obj):
cls = type(obj)
#print "Cache miss for class %s in visitor %s" % ( #print "Cache miss for class %s in visitor %s" % (
# cls.__name__, type(self).__name__) # cls.__name__, type(self).__name__)
# Must resolve, try entire hierarchy # Must resolve, try entire hierarchy
...@@ -34,7 +40,7 @@ class BasicVisitor(object): ...@@ -34,7 +40,7 @@ class BasicVisitor(object):
handler_method = getattr(self, pattern % mro_cls.__name__) handler_method = getattr(self, pattern % mro_cls.__name__)
break break
if handler_method is None: if handler_method is None:
print type(self), type(obj) print type(self), cls
if hasattr(self, 'access_path') and self.access_path: if hasattr(self, 'access_path') and self.access_path:
print self.access_path print self.access_path
if self.access_path: if self.access_path:
...@@ -42,8 +48,7 @@ class BasicVisitor(object): ...@@ -42,8 +48,7 @@ class BasicVisitor(object):
print self.access_path[-1][0].__dict__ print self.access_path[-1][0].__dict__
raise RuntimeError("Visitor does not accept object: %s" % obj) raise RuntimeError("Visitor does not accept object: %s" % obj)
#print "Caching " + cls.__name__ #print "Caching " + cls.__name__
self.dispatch_table[cls] = handler_method return handler_method
return handler_method(obj)
class TreeVisitor(BasicVisitor): class TreeVisitor(BasicVisitor):
""" """
...@@ -144,16 +149,8 @@ class TreeVisitor(BasicVisitor): ...@@ -144,16 +149,8 @@ class TreeVisitor(BasicVisitor):
stacktrace = stacktrace.tb_next stacktrace = stacktrace.tb_next
return (last_traceback, nodes) return (last_traceback, nodes)
def visitchild(self, child, parent, attrname, idx): def _raise_compiler_error(self, child, e):
self.access_path.append((parent, attrname, idx))
try:
result = self.visit(child)
except Errors.CompileError:
raise
except Exception, e:
import sys import sys
if DebugFlags.debug_no_exception_intercept:
raise
trace = [''] trace = ['']
for parent, attribute, index in self.access_path: for parent, attribute, index in self.access_path:
node = getattr(parent, attribute) node = getattr(parent, attribute)
...@@ -174,10 +171,24 @@ class TreeVisitor(BasicVisitor): ...@@ -174,10 +171,24 @@ class TreeVisitor(BasicVisitor):
raise Errors.CompilerCrash( raise Errors.CompilerCrash(
last_node.pos, self.__class__.__name__, last_node.pos, self.__class__.__name__,
u'\n'.join(trace), e, stacktrace) u'\n'.join(trace), e, stacktrace)
def visitchild(self, child, parent, attrname, idx):
self.access_path.append((parent, attrname, idx))
try:
result = self.visit(child)
except Errors.CompileError:
raise
except Exception, e:
if DebugFlags.debug_no_exception_intercept:
raise
self._raise_compiler_error(child, e)
self.access_path.pop() self.access_path.pop()
return result return result
def visitchildren(self, parent, attrs=None): def visitchildren(self, parent, attrs=None):
return self._visitchildren(parent, attrs)
def _visitchildren(self, parent, attrs):
""" """
Visits the children of the given parent. If parent is None, returns Visits the children of the given parent. If parent is None, returns
immediately (returning None). immediately (returning None).
...@@ -223,8 +234,7 @@ class VisitorTransform(TreeVisitor): ...@@ -223,8 +234,7 @@ class VisitorTransform(TreeVisitor):
are within a StatListNode or similar before doing this.) are within a StatListNode or similar before doing this.)
""" """
def visitchildren(self, parent, attrs=None): def visitchildren(self, parent, attrs=None):
result = cython.declare(dict) result = self._visitchildren(parent, attrs)
result = TreeVisitor.visitchildren(self, parent, attrs)
for attr, newnode in result.iteritems(): for attr, newnode in result.iteritems():
if not type(newnode) is list: if not type(newnode) is list:
setattr(parent, attr, newnode) setattr(parent, attr, newnode)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment