Commit be95915c authored by Stefan Behnel's avatar Stefan Behnel

fix error handling in backported matrix multiplication and compare to actual behaviour in Py3.5

parent 1fed1015
......@@ -1353,7 +1353,8 @@ static CYTHON_INLINE PyObject* __Pyx_PyObject_CallNoArg(PyObject *func) {
#define __Pyx_PyNumber_MatrixMultiply(x,y) PyNumber_MatrixMultiply(x,y)
#define __Pyx_PyNumber_InPlaceMatrixMultiply(x,y) PyNumber_InPlaceMatrixMultiply(x,y)
#else
static PyObject* __Pyx_PyNumber_MatrixMultiply(PyObject* x, PyObject* y);
#define __Pyx_PyNumber_MatrixMultiply(x,y) __Pyx__PyNumber_MatrixMultiply(x, y, "@")
static PyObject* __Pyx__PyNumber_MatrixMultiply(PyObject* x, PyObject* y, const char* op_name);
static PyObject* __Pyx_PyNumber_InPlaceMatrixMultiply(PyObject* x, PyObject* y);
#endif
......@@ -1392,7 +1393,7 @@ bad:
return result;
}
static PyObject* __Pyx_PyNumber_MatrixMultiply(PyObject* x, PyObject* y) {
static PyObject* __Pyx__PyNumber_MatrixMultiply(PyObject* x, PyObject* y, const char* op_name) {
PyObject *func;
// FIXME: make subtype aware
// see note at https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types
......@@ -1410,10 +1411,20 @@ static PyObject* __Pyx_PyNumber_MatrixMultiply(PyObject* x, PyObject* y) {
func = __Pyx_PyObject_GetAttrStr(y, PYIDENT("__rmatmul__"));
if (func) {
PyObject *result = __Pyx_PyObject_CallMatrixMethod(func, x);
if (result != Py_NotImplemented)
return result;
Py_DECREF(result);
} else {
if (!PyErr_ExceptionMatches(PyExc_AttributeError))
return NULL;
PyErr_Clear();
}
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
PyErr_Format(PyExc_TypeError,
"unsupported operand type(s) for %.2s: '%.100s' and '%.100s'",
op_name,
Py_TYPE(x)->tp_name,
Py_TYPE(y)->tp_name);
return NULL;
}
static PyObject* __Pyx_PyNumber_InPlaceMatrixMultiply(PyObject* x, PyObject* y) {
......@@ -1429,6 +1440,6 @@ static PyObject* __Pyx_PyNumber_InPlaceMatrixMultiply(PyObject* x, PyObject* y)
return NULL;
PyErr_Clear();
}
return __Pyx_PyNumber_MatrixMultiply(x, y);
return __Pyx__PyNumber_MatrixMultiply(x, y, "@=");
}
#endif
......@@ -17,6 +17,28 @@ ExtMatMult(1) @ 22
ExtMatMult('ExtMatMult(1) @ ExtMatMult(2)')
>>> print(test_imatmul(a, b))
ExtMatMult("ExtMatMult('ExtMatMult(1) @ ExtMatMult(2)') @ ExtMatMult(2)")
>>> x = y = 1
>>> x @ y
Traceback (most recent call last):
TypeError: unsupported operand type(s) for @: 'int' and 'int'
>>> x @= y
Traceback (most recent call last):
TypeError: unsupported operand type(s) for @=: 'int' and 'int'
>>> y = MatMult(22)
>>> x @= y
>>> print(x)
1 @ MatMult(22)
>>> x = MatMult(22)
>>> print(x @ 1)
MatMult(22) @ 1
>>> print(1 @ x)
1 @ MatMult(22)
>>> x @= 1
>>> print(x)
MatMult('MatMult(22) @ 1')
"""
......@@ -71,6 +93,10 @@ def test_matmul(a, b):
11 @ MatMult(2)
>>> print(test_matmul(MatMult('abc'), MatMult('def')))
MatMult('abc') @ MatMult('def')
>>> test_matmul(1, 2)
Traceback (most recent call last):
TypeError: unsupported operand type(s) for @: 'int' and 'int'
"""
return a @ b
......@@ -81,6 +107,14 @@ def test_imatmul(a, b):
MatMult('MatMult(1) @ MatMult(2)')
>>> print(test_imatmul(MatMult('abc'), MatMult('def')))
MatMult("MatMult('abc') @ MatMult('def')")
>>> print(test_imatmul(11, MatMult('def')))
11 @ MatMult('def')
>>> print(test_imatmul(MatMult('abc'), 11))
MatMult("MatMult('abc') @ 11")
>>> test_imatmul(1, 2)
Traceback (most recent call last):
TypeError: unsupported operand type(s) for @=: 'int' and 'int'
"""
a @= b
return a
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment