Commit a3c4210f authored by Stefan Behnel's avatar Stefan Behnel

implement PEP 465: dedicated infix operator for matrix multiplication

http://www.python.org/dev/peps/pep-0465/

New syntax: "c = a @ b" and "a @= b"

Includes a partial backport that works with the special methods at the Python level,
but not (properly) with extension types, which lack the necessary slot methods.  Also
currently lacks the subtype special casing that Python does for the numeric special
methods.
parent 7c20a895
......@@ -8841,6 +8841,16 @@ class TypeofNode(ExprNode):
#
#-------------------------------------------------------------------
try:
matmul_operator = operator.matmul
except AttributeError:
def matmul_operator(a, b):
try:
func = a.__matmul__
except AttributeError:
func = b.__rmatmul__
return func(a, b)
compile_time_binary_operators = {
'<': operator.lt,
'<=': operator.le,
......@@ -8862,6 +8872,7 @@ compile_time_binary_operators = {
'>>': operator.rshift,
'-': operator.sub,
'^': operator.xor,
'@': matmul_operator,
'in': lambda x, seq: x in seq,
'not_in': lambda x, seq: x not in seq,
}
......@@ -9177,10 +9188,11 @@ class NumBinopNode(BinopNode):
"+": "PyNumber_Add",
"-": "PyNumber_Subtract",
"*": "PyNumber_Multiply",
"@": "__Pyx_PyNumber_MatrixMultiply",
"/": "__Pyx_PyNumber_Divide",
"//": "PyNumber_FloorDivide",
"%": "PyNumber_Remainder",
"**": "PyNumber_Power"
"**": "PyNumber_Power",
}
overflow_op_names = {
......@@ -9279,6 +9291,17 @@ class MulNode(NumBinopNode):
return None
class MatMultNode(NumBinopNode):
# '@' operator.
def is_py_operation_types(self, type1, type2):
return True
def generate_evaluation_code(self, code):
code.globalstate.use_utility_code(UtilityCode.load_cached("MatrixMultiply", "ObjectHandling.c"))
super(MatMultNode, self).generate_evaluation_code(code)
class DivNode(NumBinopNode):
# '/' or '//' operator.
......@@ -10436,10 +10459,11 @@ binop_node_classes = {
"+": AddNode,
"-": SubNode,
"*": MulNode,
"@": MatMultNode,
"/": DivNode,
"//": DivNode,
"%": ModNode,
"**": PowNode
"**": PowNode,
}
def binop_node(pos, operator, operand1, operand2, inplace=False):
......
......@@ -10,6 +10,7 @@ char_prefixes = "cC"
any_string_prefix = raw_prefixes + string_prefixes + char_prefixes
IDENT = 'IDENT'
def make_lexicon():
from Cython.Plex import \
Str, Any, AnyBut, AnyChar, Rep, Rep1, Opt, Bol, Eol, Eof, \
......@@ -50,13 +51,12 @@ def make_lexicon():
Str('u') + four_hex | Str('x') + two_hex |
Str('U') + four_hex + four_hex | AnyChar)
deco = Str("@")
bra = Any("([{")
ket = Any(")]}")
punct = Any(":,;+-*/|&<>=.%`~^?!")
punct = Any(":,;+-*/|&<>=.%`~^?!@")
diphthong = Str("==", "<>", "!=", "<=", ">=", "<<", ">>", "**", "//",
"+=", "-=", "*=", "/=", "%=", "|=", "^=", "&=",
"<<=", ">>=", "**=", "//=", "->")
"<<=", ">>=", "**=", "//=", "->", "@=")
spaces = Rep1(Any(" \t\f"))
escaped_newline = Str("\\\n")
lineterm = Eol + Opt(Str("\n"))
......@@ -68,7 +68,6 @@ def make_lexicon():
(intliteral, 'INT'),
(fltconst, 'FLOAT'),
(imagconst, 'IMAG'),
(deco, 'DECORATOR'),
(punct | diphthong, TEXT),
(bra, Method('open_bracket_action')),
......
......@@ -267,10 +267,10 @@ def p_shift_expr(s):
def p_arith_expr(s):
return p_binop_expr(s, ('+', '-'), p_term)
#term: factor (('*'|'/'|'%') factor)*
#term: factor (('*'|'@'|'/'|'%'|'//') factor)*
def p_term(s):
return p_binop_expr(s, ('*', '/', '%', '//'), p_factor)
return p_binop_expr(s, ('*', '@', '/', '%', '//'), p_factor)
#factor: ('+'|'-'|'~'|'&'|typecast|sizeof) factor | power
......@@ -1129,7 +1129,7 @@ def p_expression_or_assignment(s):
expr = p_testlist_star_expr(s)
expr_list.append(expr)
if len(expr_list) == 1:
if re.match(r"([+*/\%^\&|-]|<<|>>|\*\*|//)=", s.sy):
if re.match(r"([+*/\%^\&|-]|<<|>>|\*\*|//|@)=", s.sy):
lhs = expr_list[0]
if isinstance(lhs, ExprNodes.SliceIndexNode):
# implementation requires IndexNode
......@@ -1837,7 +1837,7 @@ def p_statement(s, ctx, first_statement = 0):
return p_DEF_statement(s)
elif s.sy == 'IF':
return p_IF_statement(s, ctx)
elif s.sy == 'DECORATOR':
elif s.sy == '@':
if ctx.level not in ('module', 'class', 'c_class', 'function', 'property', 'module_pxd', 'c_class_pxd', 'other'):
s.error('decorator not allowed here')
s.level = ctx.level
......@@ -2884,7 +2884,7 @@ def p_ctypedef_statement(s, ctx):
def p_decorators(s):
decorators = []
while s.sy == 'DECORATOR':
while s.sy == '@':
pos = s.position()
s.next()
decstring = p_dotted_name(s, as_allowed=0)[2]
......
......@@ -702,7 +702,11 @@ PyNumberMethods = (
MethodSlot(ibinaryfunc, "nb_inplace_true_divide", "__itruediv__"),
# Added in release 2.5
MethodSlot(unaryfunc, "nb_index", "__index__", ifdef = "PY_VERSION_HEX >= 0x02050000")
MethodSlot(unaryfunc, "nb_index", "__index__"),
# Added in release 3.5
MethodSlot(binaryfunc, "nb_matrix_multiply", "__matmul__", ifdef="PY_VERSION_HEX >= 0x03050000"),
MethodSlot(ibinaryfunc, "nb_inplace_matrix_multiply", "__imatmul__", ifdef="PY_VERSION_HEX >= 0x03050000"),
)
PySequenceMethods = (
......
......@@ -1156,3 +1156,62 @@ static CYTHON_INLINE PyObject* __Pyx_PyObject_Call(PyObject *func, PyObject *arg
return result;
}
#endif
/////////////// MatrixMultiply.proto ///////////////
#if PY_VERSION_HEX >= 0x03050000
#define __Pyx_PyNumber_MatrixMultiply(x,y) PyNumber_MatrixMultiply(x,y)
#define __Pyx_PyNumber_InPlaceMatrixMultiply(x,y) PyNumber_InPlaceMatrixMultiply(x,y)
#else
static PyObject* __Pyx_PyNumber_MatrixMultiply(PyObject* x, PyObject* y);
static PyObject* __Pyx_PyNumber_InPlaceMatrixMultiply(PyObject* x, PyObject* y);
#endif
/////////////// MatrixMultiply ///////////////
//@requires: PyObjectGetAttrStr
#if PY_VERSION_HEX < 0x03050000
static PyObject* __Pyx_PyNumber_MatrixMultiply(PyObject* x, PyObject* y) {
PyObject *func;
// FIXME: make subtype aware
// see note at https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types
func = __Pyx_PyObject_GetAttrStr(x, PYIDENT("__matmul__"));
if (func) {
PyObject *result = PyObject_CallFunctionObjArgs(func, y, NULL);
Py_DECREF(func);
if (result != Py_NotImplemented)
return result;
Py_DECREF(result);
} else {
if (!PyErr_ExceptionMatches(PyExc_AttributeError))
return NULL;
PyErr_Clear();
}
func = __Pyx_PyObject_GetAttrStr(y, PYIDENT("__rmatmul__"));
if (func) {
PyObject *result = PyObject_CallFunctionObjArgs(func, x, NULL);
Py_DECREF(func);
return result;
}
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
static PyObject* __Pyx_PyNumber_InPlaceMatrixMultiply(PyObject* x, PyObject* y) {
PyObject *func;
func = __Pyx_PyObject_GetAttrStr(x, PYIDENT("__imatmul__"));
if (func) {
PyObject *result = PyObject_CallFunctionObjArgs(func, y, NULL);
Py_DECREF(func);
if (result != Py_NotImplemented)
return result;
Py_DECREF(result);
} else {
if (!PyErr_ExceptionMatches(PyExc_AttributeError))
return NULL;
PyErr_Clear();
}
return __Pyx_PyNumber_MatrixMultiply(x, y);
}
#endif
import sys
if sys.version_info >= (3, 5):
__doc__ = """\
Note: support for providing Python special methods despite missing the C-level slot
is currently not supported.
>>> a, b = ExtMatMult(1), ExtMatMult(2)
>>> print(test_matmul(a, b))
ExtMatMult(1) @ ExtMatMult(2)
>>> print(test_matmul(a, 22))
ExtMatMult(1) @ 22
>>> print(test_matmul(11, b))
11 @ ExtMatMult(2)
>>> print(test_imatmul(a, b))
ExtMatMult('ExtMatMult(1) @ ExtMatMult(2)')
>>> print(test_imatmul(a, b))
ExtMatMult("ExtMatMult('ExtMatMult(1) @ ExtMatMult(2)') @ ExtMatMult(2)")
"""
class MatMult(object):
def __init__(self, myself):
self.myself = myself
def __matmul__(self, other):
return '%r @ %r' % (self, other)
def __rmatmul__(self, other):
return '%r @ %r' % (other, self)
def __imatmul__(self, other):
self.myself = '%r @ %r' % (self, other)
return self
def __repr__(self):
return 'MatMult(%r)' % self.myself
cdef class ExtMatMult:
"""
Note: support for providing Python special methods despite missing the C-level slot
is currently not supported.
"""
cdef object myself
def __init__(self, myself):
self.myself = myself
def __matmul__(self, other):
return '%r @ %r' % (self, other)
def __rmatmul__(self, other):
return '%r @ %r' % (other, self)
def __imatmul__(self, other):
self.myself = '%r @ %r' % (self, other)
return self
def __repr__(self):
return 'ExtMatMult(%r)' % self.myself
def test_matmul(a, b):
"""
>>> print(test_matmul(MatMult(1), MatMult(2)))
MatMult(1) @ MatMult(2)
>>> print(test_matmul(MatMult(1), 22))
MatMult(1) @ 22
>>> print(test_matmul(11, MatMult(2)))
11 @ MatMult(2)
>>> print(test_matmul(MatMult('abc'), MatMult('def')))
MatMult('abc') @ MatMult('def')
"""
return a @ b
def test_imatmul(a, b):
"""
>>> print(test_imatmul(MatMult(1), MatMult(2)))
MatMult('MatMult(1) @ MatMult(2)')
>>> print(test_imatmul(MatMult('abc'), MatMult('def')))
MatMult("MatMult('abc') @ MatMult('def')")
"""
a @= b
return a
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment