Commit f0298152 authored by Stefan Behnel's avatar Stefan Behnel

enable MethodDispatcherTransform() to dispatch also on special method calls triggered by operators

parent 23350e5a
...@@ -29,7 +29,10 @@ cdef class EnvTransform(CythonTransform): ...@@ -29,7 +29,10 @@ cdef class EnvTransform(CythonTransform):
cdef class MethodDispatcherTransform(EnvTransform): cdef class MethodDispatcherTransform(EnvTransform):
cdef _find_handler(self, match_name, bint has_kwargs) cdef _find_handler(self, match_name, bint has_kwargs)
cdef _dispatch_to_handler(self, node, function, arg_list, kwargs=*) cdef _dispatch_to_handler(self, node, function, arg_list, kwargs)
cdef _dispatch_to_method_handler(self, attr_name, self_arg,
is_unbound_method, type_name,
node, arg_list, kwargs)
cdef class RecursiveNodeReplacer(VisitorTransform): cdef class RecursiveNodeReplacer(VisitorTransform):
cdef public orig_node cdef public orig_node
......
...@@ -420,11 +420,43 @@ class NodeRefCleanupMixin(object): ...@@ -420,11 +420,43 @@ class NodeRefCleanupMixin(object):
return replacement return replacement
find_special_method_for_binary_operator = {
'<': '__lt__',
'<=': '__le__',
'==': '__eq__',
'!=': '__ne__',
'>=': '__ge__',
'>': '__gt__',
'+': '__add__',
'&': '__and__',
'/': '__truediv__',
'//': '__floordiv__',
'<<': '__lshift__',
'%': '__mod__',
'*': '__mul__',
'|': '__or__',
'**': '__pow__',
'>>': '__rshift__',
'-': '__sub__',
'^': '__xor__',
'in': '__contains__',
}.get
find_special_method_for_unary_operator = {
'not': '__not__',
'~': '__inv__',
'-': '__neg__',
'+': '__pos__',
}.get
class MethodDispatcherTransform(EnvTransform): class MethodDispatcherTransform(EnvTransform):
""" """
Base class for transformations that want to intercept on specific Base class for transformations that want to intercept on specific
builtin functions or methods of builtin types. Must run after builtin functions or methods of builtin types, including special
declaration analysis when entries were assigned. methods triggered by Python operators. Must run after declaration
analysis when entries were assigned.
Naming pattern for handler methods is as follows: Naming pattern for handler methods is as follows:
...@@ -432,7 +464,7 @@ class MethodDispatcherTransform(EnvTransform): ...@@ -432,7 +464,7 @@ class MethodDispatcherTransform(EnvTransform):
* builtin methods: _handle_(general|simple|any)_method_TYPENAME_METHODNAME * builtin methods: _handle_(general|simple|any)_method_TYPENAME_METHODNAME
""" """
# only visit call nodes # only visit call nodes and Python operations
def visit_GeneralCallNode(self, node): def visit_GeneralCallNode(self, node):
self.visitchildren(node) self.visitchildren(node)
function = node.function function = node.function
...@@ -446,8 +478,7 @@ class MethodDispatcherTransform(EnvTransform): ...@@ -446,8 +478,7 @@ class MethodDispatcherTransform(EnvTransform):
# can't handle **kwargs # can't handle **kwargs
return node return node
args = arg_tuple.args args = arg_tuple.args
return self._dispatch_to_handler( return self._dispatch_to_handler(node, function, args, keyword_args)
node, function, args, keyword_args)
def visit_SimpleCallNode(self, node): def visit_SimpleCallNode(self, node):
self.visitchildren(node) self.visitchildren(node)
...@@ -459,8 +490,40 @@ class MethodDispatcherTransform(EnvTransform): ...@@ -459,8 +490,40 @@ class MethodDispatcherTransform(EnvTransform):
args = arg_tuple.args args = arg_tuple.args
else: else:
args = node.args args = node.args
return self._dispatch_to_handler( return self._dispatch_to_handler(node, function, args, None)
node, function, args)
def visit_BinopNode(self, node):
self.visitchildren(node)
# FIXME: could special case 'not_in'
special_method_name = find_special_method_for_binary_operator(node.operator)
if special_method_name:
operand1, operand2 = node.operand1, node.operand2
if special_method_name == '__contains__':
operand1, operand2 = operand2, operand1
obj_type = operand1.type
if obj_type.is_builtin_type:
type_name = obj_type.name
else:
type_name = "object" # safety measure
node = self._dispatch_to_method_handler(
special_method_name, None, False, type_name,
node, [operand1, operand2], None)
return node
def visit_UnopNode(self, node):
self.visitchildren(node)
special_method_name = find_special_method_for_unary_operator(node.operator)
if special_method_name:
operand = node.operand
obj_type = operand.type
if obj_type.is_builtin_type:
type_name = obj_type.name
else:
type_name = "object" # safety measure
node = self._dispatch_to_method_handler(
special_method_name, None, False, type_name,
node, [operand], None)
return node
### dispatch to specific handlers ### dispatch to specific handlers
...@@ -500,31 +563,38 @@ class MethodDispatcherTransform(EnvTransform): ...@@ -500,31 +563,38 @@ class MethodDispatcherTransform(EnvTransform):
arg_list and arg_list[0].type.is_pyobject): arg_list and arg_list[0].type.is_pyobject):
# calling an unbound method like 'list.append(L,x)' # calling an unbound method like 'list.append(L,x)'
# (ignoring 'type.mro()' here ...) # (ignoring 'type.mro()' here ...)
type_name = function.obj.name type_name = self_arg.name
self_arg = None self_arg = None
is_unbound_method = True is_unbound_method = True
else: else:
type_name = obj_type.name type_name = obj_type.name
else: else:
type_name = "object" # safety measure type_name = "object" # safety measure
method_handler = self._find_handler( return self._dispatch_to_method_handler(
"method_%s_%s" % (type_name, attr_name), kwargs) attr_name, self_arg, is_unbound_method, type_name,
if method_handler is None: node, arg_list, kwargs)
if attr_name in TypeSlots.method_name_to_slot\
or attr_name == '__new__':
method_handler = self._find_handler(
"slot%s" % attr_name, kwargs)
if method_handler is None:
return node
if self_arg is not None:
arg_list = [self_arg] + list(arg_list)
if kwargs:
return method_handler(node, arg_list, kwargs, is_unbound_method)
else:
return method_handler(node, arg_list, is_unbound_method)
else: else:
return node return node
def _dispatch_to_method_handler(self, attr_name, self_arg,
is_unbound_method, type_name,
node, arg_list, kwargs):
method_handler = self._find_handler(
"method_%s_%s" % (type_name, attr_name), kwargs)
if method_handler is None:
if (attr_name in TypeSlots.method_name_to_slot
or attr_name == '__new__'):
method_handler = self._find_handler(
"slot%s" % attr_name, kwargs)
if method_handler is None:
return node
if self_arg is not None:
arg_list = [self_arg] + list(arg_list)
if kwargs:
return method_handler(node, arg_list, kwargs, is_unbound_method)
else:
return method_handler(node, arg_list, is_unbound_method)
class RecursiveNodeReplacer(VisitorTransform): class RecursiveNodeReplacer(VisitorTransform):
""" """
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment