Commit af9d9076 authored by da-woods's avatar da-woods Committed by GitHub

Fix infinite recursion in binops code (GH-4204)

Closes https://github.com/cython/cython/issues/4172
parent 875584d1
...@@ -2247,9 +2247,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -2247,9 +2247,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
"right, left" if reverse else "left, right", "right, left" if reverse else "left, right",
extra_arg) extra_arg)
else: else:
return '%s_maybe_call_slot(%s, left, right %s)' % ( return '%s_maybe_call_slot(%s->tp_base, left, right %s)' % (
func_name, func_name,
'Py_TYPE(right)->tp_base' if reverse else 'Py_TYPE(left)->tp_base', scope.parent_type.typeptr_cname,
extra_arg) extra_arg)
if get_slot_method_cname(slot.left_slot.method_name) and not get_slot_method_cname(slot.right_slot.method_name): if get_slot_method_cname(slot.left_slot.method_name) and not get_slot_method_cname(slot.right_slot.method_name):
...@@ -2260,13 +2260,16 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -2260,13 +2260,16 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
slot.right_slot.method_name, slot.right_slot.method_name,
)) ))
overloads_left = int(bool(get_slot_method_cname(slot.left_slot.method_name)))
overloads_right = int(bool(get_slot_method_cname(slot.right_slot.method_name)))
code.putln( code.putln(
TempitaUtilityCode.load_as_string( TempitaUtilityCode.load_as_string(
"BinopSlot", "ExtensionTypes.c", "BinopSlot", "ExtensionTypes.c",
context={ context={
"func_name": func_name, "func_name": func_name,
"slot_name": slot.slot_name, "slot_name": slot.slot_name,
"overloads_left": int(bool(get_slot_method_cname(slot.left_slot.method_name))), "overloads_left": overloads_left,
"overloads_right": overloads_right,
"call_left": call_slot_method(slot.left_slot.method_name, reverse=False), "call_left": call_slot_method(slot.left_slot.method_name, reverse=False),
"call_right": call_slot_method(slot.right_slot.method_name, reverse=True), "call_right": call_slot_method(slot.right_slot.method_name, reverse=True),
"type_cname": scope.parent_type.typeptr_cname, "type_cname": scope.parent_type.typeptr_cname,
......
...@@ -510,7 +510,7 @@ static PyObject *{{func_name}}(PyObject *left, PyObject *right {{extra_arg_decl} ...@@ -510,7 +510,7 @@ static PyObject *{{func_name}}(PyObject *left, PyObject *right {{extra_arg_decl}
} }
if (maybe_self_is_left) { if (maybe_self_is_left) {
PyObject *res; PyObject *res;
if (maybe_self_is_right && !({{overloads_left}})) { if (maybe_self_is_right && {{overloads_right}} && !({{overloads_left}})) {
res = {{call_right}}; res = {{call_right}};
if (res != Py_NotImplemented) return res; if (res != Py_NotImplemented) return res;
Py_DECREF(res); Py_DECREF(res);
......
...@@ -9,6 +9,15 @@ class Base(object): ...@@ -9,6 +9,15 @@ class Base(object):
>>> 2 + Base() >>> 2 + Base()
'Base.__radd__(Base(), 2)' 'Base.__radd__(Base(), 2)'
>>> Base(implemented=False) + 2 #doctest: +ELLIPSIS
Traceback (most recent call last):
...
TypeError: unsupported operand type...
>>> 2 + Base(implemented=False) #doctest: +ELLIPSIS
Traceback (most recent call last):
...
TypeError: unsupported operand type...
>>> Base() ** 2 >>> Base() ** 2
'Base.__pow__(Base(), 2, None)' 'Base.__pow__(Base(), 2, None)'
>>> 2 ** Base() >>> 2 ** Base()
...@@ -16,17 +25,34 @@ class Base(object): ...@@ -16,17 +25,34 @@ class Base(object):
>>> pow(Base(), 2, 100) >>> pow(Base(), 2, 100)
'Base.__pow__(Base(), 2, 100)' 'Base.__pow__(Base(), 2, 100)'
""" """
implemented: cython.bint
def __init__(self, *, implemented=True):
self.implemented = implemented
def __add__(self, other): def __add__(self, other):
return "Base.__add__(%s, %s)" % (self, other) if (<Base>self).implemented:
return "Base.__add__(%s, %s)" % (self, other)
else:
return NotImplemented
def __radd__(self, other): def __radd__(self, other):
return "Base.__radd__(%s, %s)" % (self, other) if (<Base>self).implemented:
return "Base.__radd__(%s, %s)" % (self, other)
else:
return NotImplemented
def __pow__(self, other, mod): def __pow__(self, other, mod):
return "Base.__pow__(%s, %s, %s)" % (self, other, mod) if (<Base>self).implemented:
return "Base.__pow__(%s, %s, %s)" % (self, other, mod)
else:
return NotImplemented
def __rpow__(self, other, mod): def __rpow__(self, other, mod):
return "Base.__rpow__(%s, %s, %s)" % (self, other, mod) if (<Base>self).implemented:
return "Base.__rpow__(%s, %s, %s)" % (self, other, mod)
else:
return NotImplemented
def __repr__(self): def __repr__(self):
return "%s()" % (self.__class__.__name__) return "%s()" % (self.__class__.__name__)
...@@ -44,9 +70,27 @@ class OverloadLeft(Base): ...@@ -44,9 +70,27 @@ class OverloadLeft(Base):
'OverloadLeft.__add__(OverloadLeft(), Base())' 'OverloadLeft.__add__(OverloadLeft(), Base())'
>>> Base() + OverloadLeft() >>> Base() + OverloadLeft()
'Base.__add__(Base(), OverloadLeft())' 'Base.__add__(Base(), OverloadLeft())'
>>> OverloadLeft(implemented=False) + Base(implemented=False) #doctest: +ELLIPSIS
Traceback (most recent call last):
...
TypeError: unsupported operand type...
>>> Base(implemented=False) + OverloadLeft(implemented=False) #doctest: +ELLIPSIS
Traceback (most recent call last):
...
TypeError: unsupported operand type...
""" """
derived_implemented: cython.bint
def __init__(self, *, implemented=True):
super().__init__(implemented=implemented)
self.derived_implemented = implemented
def __add__(self, other): def __add__(self, other):
return "OverloadLeft.__add__(%s, %s)" % (self, other) if (<OverloadLeft>self).derived_implemented:
return "OverloadLeft.__add__(%s, %s)" % (self, other)
else:
return NotImplemented
@cython.c_api_binop_methods(False) @cython.c_api_binop_methods(False)
...@@ -62,9 +106,27 @@ class OverloadRight(Base): ...@@ -62,9 +106,27 @@ class OverloadRight(Base):
'Base.__add__(OverloadRight(), Base())' 'Base.__add__(OverloadRight(), Base())'
>>> Base() + OverloadRight() >>> Base() + OverloadRight()
'OverloadRight.__radd__(OverloadRight(), Base())' 'OverloadRight.__radd__(OverloadRight(), Base())'
>>> OverloadRight(implemented=False) + Base(implemented=False) #doctest: +ELLIPSIS
Traceback (most recent call last):
...
TypeError: unsupported operand type...
>>> Base(implemented=False) + OverloadRight(implemented=False) #doctest: +ELLIPSIS
Traceback (most recent call last):
...
TypeError: unsupported operand type...
""" """
derived_implemented: cython.bint
def __init__(self, *, implemented=True):
super().__init__(implemented=implemented)
self.derived_implemented = implemented
def __radd__(self, other): def __radd__(self, other):
return "OverloadRight.__radd__(%s, %s)" % (self, other) if (<OverloadRight>self).derived_implemented:
return "OverloadRight.__radd__(%s, %s)" % (self, other)
else:
return NotImplemented
@cython.c_api_binop_methods(True) @cython.c_api_binop_methods(True)
@cython.cclass @cython.cclass
...@@ -79,7 +141,30 @@ class OverloadCApi(Base): ...@@ -79,7 +141,30 @@ class OverloadCApi(Base):
'OverloadCApi.__add__(OverloadCApi(), Base())' 'OverloadCApi.__add__(OverloadCApi(), Base())'
>>> Base() + OverloadCApi() >>> Base() + OverloadCApi()
'OverloadCApi.__add__(Base(), OverloadCApi())' 'OverloadCApi.__add__(Base(), OverloadCApi())'
>>> OverloadCApi(derived_implemented=False) + 2 #doctest: +ELLIPSIS
Traceback (most recent call last):
...
TypeError: unsupported operand type...
>>> 2 + OverloadCApi(derived_implemented=False) #doctest: +ELLIPSIS
Traceback (most recent call last):
...
TypeError: unsupported operand type...
""" """
derived_implemented: cython.bint
def __init__(self, *, derived_implemented=True):
super().__init__(implemented=True)
self.derived_implemented = derived_implemented
def __add__(self, other): def __add__(self, other):
return "OverloadCApi.__add__(%s, %s)" % (self, other) if isinstance(self, OverloadCApi):
derived_implemented = (<OverloadCApi>self).derived_implemented
else:
derived_implemented = (<OverloadCApi>other).derived_implemented
if derived_implemented:
return "OverloadCApi.__add__(%s, %s)" % (self, other)
else:
return NotImplemented
# TODO: Test a class that only defines the `__r...__()` methods.
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment