Commit c238e5b0 authored by Stefan Behnel's avatar Stefan Behnel

allow subtypes in right operand to override left side parent class behaviour...

allow subtypes in right operand to override left side parent class behaviour for "@" operator (Python tries "__rmatmul__()" first in this case), generally clean up code duplication while at it
parent 8bb090b0
...@@ -1398,31 +1398,30 @@ bad: ...@@ -1398,31 +1398,30 @@ bad:
return result; return result;
} }
#define __Pyx_TryMatrixMethod(x, y, py_method_name) { \
PyObject *func = __Pyx_PyObject_GetAttrStr(x, py_method_name); \
if (func) { \
PyObject *result = __Pyx_PyObject_CallMatrixMethod(func, y); \
if (result != Py_NotImplemented) \
return result; \
Py_DECREF(result); \
} else { \
if (!PyErr_ExceptionMatches(PyExc_AttributeError)) \
return NULL; \
PyErr_Clear(); \
} \
}
static PyObject* __Pyx__PyNumber_MatrixMultiply(PyObject* x, PyObject* y, const char* op_name) { static PyObject* __Pyx__PyNumber_MatrixMultiply(PyObject* x, PyObject* y, const char* op_name) {
PyObject *func; int right_is_subtype = PyObject_IsSubclass((PyObject*)Py_TYPE(y), (PyObject*)Py_TYPE(x));
// FIXME: make subtype aware if (right_is_subtype) {
// see note at https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types // to allow subtypes to override parent behaviour, try reversed operation first
func = __Pyx_PyObject_GetAttrStr(x, PYIDENT("__matmul__")); // see note at https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types
if (func) { __Pyx_TryMatrixMethod(y, x, PYIDENT("__rmatmul__"))
PyObject *result = __Pyx_PyObject_CallMatrixMethod(func, y);
if (result != Py_NotImplemented)
return result;
Py_DECREF(result);
} else {
if (!PyErr_ExceptionMatches(PyExc_AttributeError))
return NULL;
PyErr_Clear();
} }
func = __Pyx_PyObject_GetAttrStr(y, PYIDENT("__rmatmul__")); __Pyx_TryMatrixMethod(x, y, PYIDENT("__matmul__"))
if (func) { if (!right_is_subtype) {
PyObject *result = __Pyx_PyObject_CallMatrixMethod(func, x); __Pyx_TryMatrixMethod(y, x, PYIDENT("__rmatmul__"))
if (result != Py_NotImplemented)
return result;
Py_DECREF(result);
} else {
if (!PyErr_ExceptionMatches(PyExc_AttributeError))
return NULL;
PyErr_Clear();
} }
PyErr_Format(PyExc_TypeError, PyErr_Format(PyExc_TypeError,
"unsupported operand type(s) for %.2s: '%.100s' and '%.100s'", "unsupported operand type(s) for %.2s: '%.100s' and '%.100s'",
...@@ -1433,18 +1432,9 @@ static PyObject* __Pyx__PyNumber_MatrixMultiply(PyObject* x, PyObject* y, const ...@@ -1433,18 +1432,9 @@ static PyObject* __Pyx__PyNumber_MatrixMultiply(PyObject* x, PyObject* y, const
} }
static PyObject* __Pyx_PyNumber_InPlaceMatrixMultiply(PyObject* x, PyObject* y) { static PyObject* __Pyx_PyNumber_InPlaceMatrixMultiply(PyObject* x, PyObject* y) {
PyObject *func; __Pyx_TryMatrixMethod(x, y, PYIDENT("__imatmul__"))
func = __Pyx_PyObject_GetAttrStr(x, PYIDENT("__imatmul__"));
if (func) {
PyObject *result = __Pyx_PyObject_CallMatrixMethod(func, y);
if (result != Py_NotImplemented)
return result;
Py_DECREF(result);
} else {
if (!PyErr_ExceptionMatches(PyExc_AttributeError))
return NULL;
PyErr_Clear();
}
return __Pyx__PyNumber_MatrixMultiply(x, y, "@="); return __Pyx__PyNumber_MatrixMultiply(x, y, "@=");
} }
#undef __Pyx_TryMatrixMethod
#endif #endif
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment