Commit b75d2942 authored by Stefan Behnel's avatar Stefan Behnel

Generate "tp_richcompare" slot when "__eq__" and/or its friends are defined...

Generate "tp_richcompare" slot when "__eq__" and/or its friends are defined but "__richcmp__" is not.
Closes #690.
parent 0ee04208
...@@ -27,6 +27,10 @@ Features added ...@@ -27,6 +27,10 @@ Features added
types. This can be disabled with the directive ``annotation_typing=False``. types. This can be disabled with the directive ``annotation_typing=False``.
(Github issue #1850) (Github issue #1850)
* Extension types (also in pure Python mode) can implement the normal special methods
``__eq__``, ``__lt__`` etc. for comparisons instead of the low-level ``__richcmp__``
method. (Github issue #690)
* New decorator ``@cython.exceptval(x=None, check=False)`` that makes the signature * New decorator ``@cython.exceptval(x=None, check=False)`` that makes the signature
declarations ``except x``, ``except? x`` and ``except *`` available to pure Python declarations ``except x``, ``except? x`` and ``except *`` available to pure Python
code. Original patch by Antonio Cuni. (Github issue #1653) code. Original patch by Antonio Cuni. (Github issue #1653)
......
...@@ -1223,6 +1223,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1223,6 +1223,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
self.generate_descr_set_function(scope, code) self.generate_descr_set_function(scope, code)
if scope.defines_any(["__dict__"]): if scope.defines_any(["__dict__"]):
self.generate_dict_getter_function(scope, code) self.generate_dict_getter_function(scope, code)
if scope.defines_any(TypeSlots.richcmp_special_methods):
self.generate_richcmp_function(scope, code)
self.generate_property_accessors(scope, code) self.generate_property_accessors(scope, code)
self.generate_method_table(scope, code) self.generate_method_table(scope, code)
self.generate_getset_table(scope, code) self.generate_getset_table(scope, code)
...@@ -1790,6 +1792,73 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1790,6 +1792,73 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln( code.putln(
"}") "}")
def generate_richcmp_function(self, scope, code):
if scope.lookup_here("__richcmp__"):
# user implemented, nothing to do
return
# otherwise, we have to generate it from the Python special methods
richcmp_cfunc = scope.mangle_internal("tp_richcompare")
code.putln("")
code.putln("static PyObject *%s(PyObject *o1, PyObject *o2, int op) {" % richcmp_cfunc)
code.putln("switch (op) {")
class_scopes = []
cls = scope.parent_type
while cls is not None and not cls.entry.visibility == 'extern':
class_scopes.append(cls.scope)
cls = cls.scope.parent_type.base_type
assert scope in class_scopes
extern_parent = None
if cls and cls.entry.visibility == 'extern':
# need to call up into base classes as we may not know all implemented comparison methods
extern_parent = cls if cls.typeptr_cname else scope.parent_type.base_type
eq_entry = None
has_ne = False
for cmp_method in TypeSlots.richcmp_special_methods:
for class_scope in class_scopes:
entry = class_scope.lookup_here(cmp_method)
if entry is not None:
break
else:
continue
cmp_type = cmp_method.strip('_').upper() # e.g. "__eq__" -> EQ
code.putln("case Py_%s: {" % cmp_type)
if cmp_method == '__eq__':
eq_entry = entry
code.putln("if (o1 == o2) return __Pyx_NewRef(Py_True);")
elif cmp_method == '__ne__':
has_ne = True
code.putln("if (o1 == o2) return __Pyx_NewRef(Py_False);")
code.putln("return %s(o1, o2);" % entry.func_cname)
code.putln("}")
if eq_entry and not has_ne and not extern_parent:
code.putln("case Py_NE: {")
code.putln("PyObject *ret;")
code.putln("if (o1 == o2) return __Pyx_NewRef(Py_False);")
code.putln("ret = %s(o1, o2);" % eq_entry.func_cname)
code.putln("if (likely(ret && ret != Py_NotImplemented)) {")
code.putln("int b = __Pyx_PyObject_IsTrue(ret); Py_DECREF(ret);")
code.putln("if (unlikely(b < 0)) return NULL;")
code.putln("ret = (b) ? Py_False : Py_True;")
code.putln("Py_INCREF(ret);")
code.putln("}")
code.putln("return ret;")
code.putln("}")
code.putln("default: {")
if extern_parent and extern_parent.typeptr_cname:
code.putln("if (likely(%s->tp_richcompare)) return %s->tp_richcompare(o1, o2, op);" % (
extern_parent.typeptr_cname, extern_parent.typeptr_cname))
code.putln("return __Pyx_NewRef(Py_NotImplemented);")
code.putln("}")
code.putln("}") # switch
code.putln("}")
def generate_getattro_function(self, scope, code): def generate_getattro_function(self, scope, code):
# First try to get the attribute using __getattribute__, if defined, or # First try to get the attribute using __getattribute__, if defined, or
# PyObject_GenericGetAttr. # PyObject_GenericGetAttr.
......
...@@ -18,9 +18,9 @@ from .StringEncoding import EncodedString ...@@ -18,9 +18,9 @@ from .StringEncoding import EncodedString
from . import Options, Naming from . import Options, Naming
from . import PyrexTypes from . import PyrexTypes
from .PyrexTypes import py_object_type, unspecified_type from .PyrexTypes import py_object_type, unspecified_type
from .TypeSlots import \ from .TypeSlots import (
pyfunction_signature, pymethod_signature, \ pyfunction_signature, pymethod_signature, richcmp_special_methods,
get_special_method_signature, get_property_accessor_signature get_special_method_signature, get_property_accessor_signature)
from . import Code from . import Code
...@@ -2060,8 +2060,13 @@ class CClassScope(ClassScope): ...@@ -2060,8 +2060,13 @@ class CClassScope(ClassScope):
def declare_pyfunction(self, name, pos, allow_redefine=False): def declare_pyfunction(self, name, pos, allow_redefine=False):
# Add an entry for a method. # Add an entry for a method.
if name in ('__eq__', '__ne__', '__lt__', '__gt__', '__le__', '__ge__'): if name in richcmp_special_methods:
error(pos, "Special method %s must be implemented via __richcmp__" % name) if self.lookup_here('__richcmp__'):
error(pos, "Cannot define both % and __richcmp__" % name)
elif name == '__richcmp__':
for n in richcmp_special_methods:
if self.lookup_here(n):
error(pos, "Cannot define both % and __richcmp__" % n)
if name == "__new__": if name == "__new__":
error(pos, "__new__ method of extension type will change semantics " error(pos, "__new__ method of extension type will change semantics "
"in a future version of Pyrex and Cython. Use __cinit__ instead.") "in a future version of Pyrex and Cython. Use __cinit__ instead.")
......
...@@ -12,6 +12,8 @@ from .Errors import error ...@@ -12,6 +12,8 @@ from .Errors import error
invisible = ['__cinit__', '__dealloc__', '__richcmp__', invisible = ['__cinit__', '__dealloc__', '__richcmp__',
'__nonzero__', '__bool__'] '__nonzero__', '__bool__']
richcmp_special_methods = ['__eq__', '__ne__', '__lt__', '__gt__', '__le__', '__ge__']
class Signature(object): class Signature(object):
# Method slot signature descriptor. # Method slot signature descriptor.
...@@ -400,6 +402,17 @@ class SyntheticSlot(InternalMethodSlot): ...@@ -400,6 +402,17 @@ class SyntheticSlot(InternalMethodSlot):
return self.default_value return self.default_value
class RichcmpSlot(SlotDescriptor):
def slot_code(self, scope):
entry = scope.lookup_here("__richcmp__")
if entry and entry.func_cname:
return entry.func_cname
elif scope.defines_any(richcmp_special_methods):
return scope.mangle_internal(self.slot_name)
else:
return "0"
class TypeFlagsSlot(SlotDescriptor): class TypeFlagsSlot(SlotDescriptor):
# Descriptor for the type flags slot. # Descriptor for the type flags slot.
...@@ -823,8 +836,7 @@ slot_table = ( ...@@ -823,8 +836,7 @@ slot_table = (
GCDependentSlot("tp_traverse"), GCDependentSlot("tp_traverse"),
GCClearReferencesSlot("tp_clear"), GCClearReferencesSlot("tp_clear"),
# Later -- synthesize a method to split into separate ops? RichcmpSlot("tp_richcompare", inherited=False), # Py3 checks for __hash__
MethodSlot(richcmpfunc, "tp_richcompare", "__richcmp__", inherited=False), # Py3 checks for __hash__
EmptySlot("tp_weaklistoffset"), EmptySlot("tp_weaklistoffset"),
......
...@@ -268,26 +268,29 @@ Arithmetic Methods ...@@ -268,26 +268,29 @@ Arithmetic Methods
Rich Comparisons Rich Comparisons
================ ================
.. note:: There are no separate methods for individual rich comparison operations. * Starting with Cython 0.27, the Python special methods ``__eq__``, ``__lt__``, etc. can be implemented.
In previous versions, ``__richcmp__`` was the only way to implement rich comparisons.
* A single special method called ``__richcmp__()`` replaces all the individual rich compare, special method types. * A single special method called ``__richcmp__()`` can be used to implement all the individual
* ``__richcmp__()`` takes an integer argument, indicating which operation is to be performed as shown in the table below. rich compare, special method types.
* ``__richcmp__()`` takes an integer argument, indicating which operation is to be performed
+-----+-----+ as shown in the table below.
| < | 0 |
+-----+-----+ +-----+-----+-------+
| == | 2 | | < | 0 | Py_LT |
+-----+-----+ +-----+-----+-------+
| > | 4 | | == | 2 | Py_EQ |
+-----+-----+ +-----+-----+-------+
| <= | 1 | | > | 4 | Py_GT |
+-----+-----+ +-----+-----+-------+
| != | 3 | | <= | 1 | Py_LE |
+-----+-----+ +-----+-----+-------+
| >= | 5 | | != | 3 | Py_NE |
+-----+-----+ +-----+-----+-------+
| >= | 5 | Py_GE |
+-----+-----+-------+
The named constants can be cimported from the ``cpython.object`` module.
They should generally be preferred over plain integers to improve readabilty.
The ``__next__()`` Method The ``__next__()`` Method
......
...@@ -25,8 +25,6 @@ General ...@@ -25,8 +25,6 @@ General
+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+ +-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
| __cmp__ |x, y | int | 3-way comparison | | __cmp__ |x, y | int | 3-way comparison |
+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+ +-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
| __richcmp__ |x, y, int op | object | Rich comparison (no direct Python equivalent) |
+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
| __str__ |self | object | str(self) | | __str__ |self | object | str(self) |
+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+ +-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
| __repr__ |self | object | repr(self) | | __repr__ |self | object | repr(self) |
...@@ -46,6 +44,25 @@ General ...@@ -46,6 +44,25 @@ General
| __delattr__ |self, name | | Delete attribute | | __delattr__ |self, name | | Delete attribute |
+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+ +-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
Rich comparison operators
^^^^^^^^^^^^^^^^^^^^^^^^^
+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
| __richcmp__ |x, y, int op | object | Rich comparison (no direct Python equivalent) |
+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
| __eq__ |x, y | object | x == y |
+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
| __ne__ |x, y | object | x != y (falls back to ``__eq__`` if not available) |
+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
| __lt__ |x, y | object | x < y |
+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
| __gt__ |x, y | object | x > y |
+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
| __le__ |x, y | object | x <= y |
+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
| __ge__ |x, y | object | x >= y |
+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
Arithmetic operators Arithmetic operators
^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^
......
...@@ -127,10 +127,9 @@ take `self` as the first argument. ...@@ -127,10 +127,9 @@ take `self` as the first argument.
Rich comparisons Rich comparisons
----------------- -----------------
There are no separate methods for the individual rich comparison operations Starting with Cython 0.27, the Python special methods :meth:``__eq__``, :meth:``__lt__``, etc.
(:meth:`__eq__`, :meth:`__le__`, etc.) Instead there is a single method can be implemented. In previous versions, :meth:``__richcmp__`` was the only way to implement
:meth:`__richcmp__` which takes an integer indicating which operation is to be rich comparisons. It takes an integer indicating which operation is to be performed, as follows:
performed, as follows:
+-----+-----+-------+ +-----+-----+-------+
| < | 0 | Py_LT | | < | 0 | Py_LT |
...@@ -182,8 +181,6 @@ General ...@@ -182,8 +181,6 @@ General
+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+ +-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
| __cmp__ |x, y | int | 3-way comparison | | __cmp__ |x, y | int | 3-way comparison |
+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+ +-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
| __richcmp__ |x, y, int op | object | Rich comparison (no direct Python equivalent) |
+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
| __str__ |self | object | str(self) | | __str__ |self | object | str(self) |
+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+ +-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
| __repr__ |self | object | repr(self) | | __repr__ |self | object | repr(self) |
...@@ -203,6 +200,25 @@ General ...@@ -203,6 +200,25 @@ General
| __delattr__ |self, name | | Delete attribute | | __delattr__ |self, name | | Delete attribute |
+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+ +-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
Rich comparison operators
^^^^^^^^^^^^^^^^^^^^^^^^^
+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
| __richcmp__ |x, y, int op | object | Rich comparison (no direct Python equivalent) |
+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
| __eq__ |x, y | object | x == y |
+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
| __ne__ |x, y | object | x != y (falls back to ``__eq__`` if not available) |
+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
| __lt__ |x, y | object | x < y |
+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
| __gt__ |x, y | object | x > y |
+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
| __le__ |x, y | object | x <= y |
+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
| __ge__ |x, y | object | x >= y |
+-----------------------+---------------------------------------+-------------+-----------------------------------------------------+
Arithmetic operators Arithmetic operators
^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^
......
# mode: run
import cython
compiled = cython.compiled
import sys
IS_PY2 = sys.version_info[0] == 2
@cython.cclass
class X(object):
x = cython.declare(cython.int, visibility="public")
def __init__(self, x):
self.x = x
def __repr__(self):
return "<%d>" % self.x
@cython.cclass
class ClassEq(X):
"""
>>> a = ClassEq(1)
>>> b = ClassEq(2)
>>> c = ClassEq(1)
>>> a == a
True
>>> a != a
False
>>> a == b
False
>>> a != b
True
>>> a == c
True
>>> if IS_PY2 and not compiled: a is c
... else: a != c
False
>>> b == c
False
>>> b != c
True
>>> c == a
True
>>> if IS_PY2 and not compiled: c is a
... else: c != a
False
>>> b == a
False
>>> b != a
True
>>> if IS_PY2: raise TypeError # doctest: +ELLIPSIS
... else: a < b
Traceback (most recent call last):
TypeError...
>>> if IS_PY2: raise TypeError # doctest: +ELLIPSIS
... else: a > b
Traceback (most recent call last):
TypeError...
>>> if IS_PY2: raise TypeError # doctest: +ELLIPSIS
... else: a <= b
Traceback (most recent call last):
TypeError...
>>> if IS_PY2: raise TypeError # doctest: +ELLIPSIS
... else: a >= b
Traceback (most recent call last):
TypeError...
"""
def __eq__(self, other):
if isinstance(self, X):
if isinstance(other, X):
return self.x == other.x
return NotImplemented
@cython.cclass
class ClassEqNe(ClassEq):
"""
>>> a = ClassEqNe(1)
>>> b = ClassEqNe(2)
>>> c = ClassEqNe(1)
>>> a == a
True
>>> a != a
False
>>> a == b
False
>>> a != b
True
>>> a == c
True
>>> a != c
False
>>> b == c
False
>>> b != c
True
>>> c == a
True
>>> c != a
False
>>> b == a
False
>>> b != a
True
>>> if IS_PY2: raise TypeError # doctest: +ELLIPSIS
... else: a < b
Traceback (most recent call last):
TypeError...
>>> if IS_PY2: raise TypeError # doctest: +ELLIPSIS
... else: a > b
Traceback (most recent call last):
TypeError...
>>> if IS_PY2: raise TypeError # doctest: +ELLIPSIS
... else: a <= b
Traceback (most recent call last):
TypeError...
>>> if IS_PY2: raise TypeError # doctest: +ELLIPSIS
... else: a >= b
Traceback (most recent call last):
TypeError...
"""
def __ne__(self, other):
if isinstance(self, X):
if isinstance(other, X):
return self.x != other.x
return NotImplemented
@cython.cclass
class ClassEqNeGe(ClassEqNe):
"""
>>> a = ClassEqNeGe(1)
>>> b = ClassEqNeGe(2)
>>> c = ClassEqNeGe(1)
>>> a == a
True
>>> a != a
False
>>> a >= a
True
>>> a <= a
True
>>> a == b
False
>>> a != b
True
>>> a >= b
False
>>> b <= a
False
>>> a == c
True
>>> a != c
False
>>> a >= c
True
>>> c <= a
True
>>> b == c
False
>>> b != c
True
>>> b >= c
True
>>> c <= b
True
>>> c == a
True
>>> c != a
False
>>> c >= a
True
>>> a <= c
True
>>> b == a
False
>>> b != a
True
>>> b >= a
True
>>> a <= b
True
>>> if IS_PY2: raise TypeError # doctest: +ELLIPSIS
... else: a < b
Traceback (most recent call last):
TypeError...
>>> if IS_PY2: raise TypeError # doctest: +ELLIPSIS
... else: a > b
Traceback (most recent call last):
TypeError...
"""
def __ge__(self, other):
if isinstance(self, X):
if isinstance(other, X):
return self.x >= other.x
return NotImplemented
@cython.cclass
class ClassLe(X):
"""
>>> a = ClassLe(1)
>>> b = ClassLe(2)
>>> c = ClassLe(1)
>>> a <= b
True
>>> b >= a
True
>>> b <= a
False
>>> a >= b
False
>>> a <= c
True
>>> c >= a
True
>>> c <= a
True
>>> a >= c
True
>>> b <= c
False
>>> c >= b
False
>>> c <= b
True
>>> b >= c
True
"""
def __le__(self, other):
if isinstance(self, X):
if isinstance(other, X):
return self.x <= other.x
return NotImplemented
@cython.cclass
class ClassLt(X):
"""
>>> a = ClassLt(1)
>>> b = ClassLt(2)
>>> c = ClassLt(1)
>>> a < b
True
>>> b > a
True
>>> b < a
False
>>> a > b
False
>>> a < c
False
>>> c > a
False
>>> c < a
False
>>> a > c
False
>>> b < c
False
>>> c > b
False
>>> c < b
True
>>> b > c
True
>>> sorted([a, b, c])
[<1>, <1>, <2>]
>>> sorted([b, a, c])
[<1>, <1>, <2>]
"""
def __lt__(self, other):
if isinstance(self, X):
if isinstance(other, X):
return self.x < other.x
return NotImplemented
@cython.cclass
class ClassLtGtInherited(X):
"""
>>> a = ClassLtGtInherited(1)
>>> b = ClassLtGtInherited(2)
>>> c = ClassLtGtInherited(1)
>>> a < b
True
>>> b > a
True
>>> b < a
False
>>> a > b
False
>>> a < c
False
>>> c > a
False
>>> c < a
False
>>> a > c
False
>>> b < c
False
>>> c > b
False
>>> c < b
True
>>> b > c
True
>>> sorted([a, b, c])
[<1>, <1>, <2>]
>>> sorted([b, a, c])
[<1>, <1>, <2>]
"""
def __gt__(self, other):
if isinstance(self, X):
if isinstance(other, X):
return self.x > other.x
return NotImplemented
@cython.cclass
class ClassLtGt(X):
"""
>>> a = ClassLtGt(1)
>>> b = ClassLtGt(2)
>>> c = ClassLtGt(1)
>>> a < b
True
>>> b > a
True
>>> b < a
False
>>> a > b
False
>>> a < c
False
>>> c > a
False
>>> c < a
False
>>> a > c
False
>>> b < c
False
>>> c > b
False
>>> c < b
True
>>> b > c
True
>>> sorted([a, b, c])
[<1>, <1>, <2>]
>>> sorted([b, a, c])
[<1>, <1>, <2>]
"""
def __lt__(self, other):
if isinstance(self, X):
if isinstance(other, X):
return self.x < other.x
return NotImplemented
def __gt__(self, other):
if isinstance(self, X):
if isinstance(other, X):
return self.x > other.x
return NotImplemented
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment