Commit c238e5b0 authored by Stefan Behnel's avatar Stefan Behnel

allow subtypes in right operand to override left side parent class behaviour...

allow subtypes in right operand to override left side parent class behaviour for "@" operator (Python tries "__rmatmul__()" first in this case), generally clean up code duplication while at it
parent 8bb090b0
......@@ -1398,31 +1398,30 @@ bad:
return result;
}
#define __Pyx_TryMatrixMethod(x, y, py_method_name) { \
PyObject *func = __Pyx_PyObject_GetAttrStr(x, py_method_name); \
if (func) { \
PyObject *result = __Pyx_PyObject_CallMatrixMethod(func, y); \
if (result != Py_NotImplemented) \
return result; \
Py_DECREF(result); \
} else { \
if (!PyErr_ExceptionMatches(PyExc_AttributeError)) \
return NULL; \
PyErr_Clear(); \
} \
}
static PyObject* __Pyx__PyNumber_MatrixMultiply(PyObject* x, PyObject* y, const char* op_name) {
PyObject *func;
// FIXME: make subtype aware
int right_is_subtype = PyObject_IsSubclass((PyObject*)Py_TYPE(y), (PyObject*)Py_TYPE(x));
if (right_is_subtype) {
// to allow subtypes to override parent behaviour, try reversed operation first
// see note at https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types
func = __Pyx_PyObject_GetAttrStr(x, PYIDENT("__matmul__"));
if (func) {
PyObject *result = __Pyx_PyObject_CallMatrixMethod(func, y);
if (result != Py_NotImplemented)
return result;
Py_DECREF(result);
} else {
if (!PyErr_ExceptionMatches(PyExc_AttributeError))
return NULL;
PyErr_Clear();
__Pyx_TryMatrixMethod(y, x, PYIDENT("__rmatmul__"))
}
func = __Pyx_PyObject_GetAttrStr(y, PYIDENT("__rmatmul__"));
if (func) {
PyObject *result = __Pyx_PyObject_CallMatrixMethod(func, x);
if (result != Py_NotImplemented)
return result;
Py_DECREF(result);
} else {
if (!PyErr_ExceptionMatches(PyExc_AttributeError))
return NULL;
PyErr_Clear();
__Pyx_TryMatrixMethod(x, y, PYIDENT("__matmul__"))
if (!right_is_subtype) {
__Pyx_TryMatrixMethod(y, x, PYIDENT("__rmatmul__"))
}
PyErr_Format(PyExc_TypeError,
"unsupported operand type(s) for %.2s: '%.100s' and '%.100s'",
......@@ -1433,18 +1432,9 @@ static PyObject* __Pyx__PyNumber_MatrixMultiply(PyObject* x, PyObject* y, const
}
static PyObject* __Pyx_PyNumber_InPlaceMatrixMultiply(PyObject* x, PyObject* y) {
PyObject *func;
func = __Pyx_PyObject_GetAttrStr(x, PYIDENT("__imatmul__"));
if (func) {
PyObject *result = __Pyx_PyObject_CallMatrixMethod(func, y);
if (result != Py_NotImplemented)
return result;
Py_DECREF(result);
} else {
if (!PyErr_ExceptionMatches(PyExc_AttributeError))
return NULL;
PyErr_Clear();
}
__Pyx_TryMatrixMethod(x, y, PYIDENT("__imatmul__"))
return __Pyx__PyNumber_MatrixMultiply(x, y, "@=");
}
#undef __Pyx_TryMatrixMethod
#endif
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment