Commit 267b8fb9 authored by scoder's avatar scoder

Merge pull request #288 from scoder/pep465

Implement PEP 465: dedicated infix operator for matrix multiplication
parents 1eb122af a3c4210f
...@@ -8843,6 +8843,16 @@ class TypeofNode(ExprNode): ...@@ -8843,6 +8843,16 @@ class TypeofNode(ExprNode):
# #
#------------------------------------------------------------------- #-------------------------------------------------------------------
try:
matmul_operator = operator.matmul
except AttributeError:
def matmul_operator(a, b):
try:
func = a.__matmul__
except AttributeError:
func = b.__rmatmul__
return func(a, b)
compile_time_binary_operators = { compile_time_binary_operators = {
'<': operator.lt, '<': operator.lt,
'<=': operator.le, '<=': operator.le,
...@@ -8864,6 +8874,7 @@ compile_time_binary_operators = { ...@@ -8864,6 +8874,7 @@ compile_time_binary_operators = {
'>>': operator.rshift, '>>': operator.rshift,
'-': operator.sub, '-': operator.sub,
'^': operator.xor, '^': operator.xor,
'@': matmul_operator,
'in': lambda x, seq: x in seq, 'in': lambda x, seq: x in seq,
'not_in': lambda x, seq: x not in seq, 'not_in': lambda x, seq: x not in seq,
} }
...@@ -9180,10 +9191,11 @@ class NumBinopNode(BinopNode): ...@@ -9180,10 +9191,11 @@ class NumBinopNode(BinopNode):
"+": "PyNumber_Add", "+": "PyNumber_Add",
"-": "PyNumber_Subtract", "-": "PyNumber_Subtract",
"*": "PyNumber_Multiply", "*": "PyNumber_Multiply",
"@": "__Pyx_PyNumber_MatrixMultiply",
"/": "__Pyx_PyNumber_Divide", "/": "__Pyx_PyNumber_Divide",
"//": "PyNumber_FloorDivide", "//": "PyNumber_FloorDivide",
"%": "PyNumber_Remainder", "%": "PyNumber_Remainder",
"**": "PyNumber_Power" "**": "PyNumber_Power",
} }
overflow_op_names = { overflow_op_names = {
...@@ -9282,6 +9294,17 @@ class MulNode(NumBinopNode): ...@@ -9282,6 +9294,17 @@ class MulNode(NumBinopNode):
return None return None
class MatMultNode(NumBinopNode):
# '@' operator.
def is_py_operation_types(self, type1, type2):
return True
def generate_evaluation_code(self, code):
code.globalstate.use_utility_code(UtilityCode.load_cached("MatrixMultiply", "ObjectHandling.c"))
super(MatMultNode, self).generate_evaluation_code(code)
class DivNode(NumBinopNode): class DivNode(NumBinopNode):
# '/' or '//' operator. # '/' or '//' operator.
...@@ -10451,10 +10474,11 @@ binop_node_classes = { ...@@ -10451,10 +10474,11 @@ binop_node_classes = {
"+": AddNode, "+": AddNode,
"-": SubNode, "-": SubNode,
"*": MulNode, "*": MulNode,
"@": MatMultNode,
"/": DivNode, "/": DivNode,
"//": DivNode, "//": DivNode,
"%": ModNode, "%": ModNode,
"**": PowNode "**": PowNode,
} }
def binop_node(pos, operator, operand1, operand2, inplace=False): def binop_node(pos, operator, operand1, operand2, inplace=False):
......
...@@ -10,6 +10,7 @@ char_prefixes = "cC" ...@@ -10,6 +10,7 @@ char_prefixes = "cC"
any_string_prefix = raw_prefixes + string_prefixes + char_prefixes any_string_prefix = raw_prefixes + string_prefixes + char_prefixes
IDENT = 'IDENT' IDENT = 'IDENT'
def make_lexicon(): def make_lexicon():
from Cython.Plex import \ from Cython.Plex import \
Str, Any, AnyBut, AnyChar, Rep, Rep1, Opt, Bol, Eol, Eof, \ Str, Any, AnyBut, AnyChar, Rep, Rep1, Opt, Bol, Eol, Eof, \
...@@ -50,13 +51,12 @@ def make_lexicon(): ...@@ -50,13 +51,12 @@ def make_lexicon():
Str('u') + four_hex | Str('x') + two_hex | Str('u') + four_hex | Str('x') + two_hex |
Str('U') + four_hex + four_hex | AnyChar) Str('U') + four_hex + four_hex | AnyChar)
deco = Str("@")
bra = Any("([{") bra = Any("([{")
ket = Any(")]}") ket = Any(")]}")
punct = Any(":,;+-*/|&<>=.%`~^?!") punct = Any(":,;+-*/|&<>=.%`~^?!@")
diphthong = Str("==", "<>", "!=", "<=", ">=", "<<", ">>", "**", "//", diphthong = Str("==", "<>", "!=", "<=", ">=", "<<", ">>", "**", "//",
"+=", "-=", "*=", "/=", "%=", "|=", "^=", "&=", "+=", "-=", "*=", "/=", "%=", "|=", "^=", "&=",
"<<=", ">>=", "**=", "//=", "->") "<<=", ">>=", "**=", "//=", "->", "@=")
spaces = Rep1(Any(" \t\f")) spaces = Rep1(Any(" \t\f"))
escaped_newline = Str("\\\n") escaped_newline = Str("\\\n")
lineterm = Eol + Opt(Str("\n")) lineterm = Eol + Opt(Str("\n"))
...@@ -68,7 +68,6 @@ def make_lexicon(): ...@@ -68,7 +68,6 @@ def make_lexicon():
(intliteral, 'INT'), (intliteral, 'INT'),
(fltconst, 'FLOAT'), (fltconst, 'FLOAT'),
(imagconst, 'IMAG'), (imagconst, 'IMAG'),
(deco, 'DECORATOR'),
(punct | diphthong, TEXT), (punct | diphthong, TEXT),
(bra, Method('open_bracket_action')), (bra, Method('open_bracket_action')),
......
...@@ -267,10 +267,10 @@ def p_shift_expr(s): ...@@ -267,10 +267,10 @@ def p_shift_expr(s):
def p_arith_expr(s): def p_arith_expr(s):
return p_binop_expr(s, ('+', '-'), p_term) return p_binop_expr(s, ('+', '-'), p_term)
#term: factor (('*'|'/'|'%') factor)* #term: factor (('*'|'@'|'/'|'%'|'//') factor)*
def p_term(s): def p_term(s):
return p_binop_expr(s, ('*', '/', '%', '//'), p_factor) return p_binop_expr(s, ('*', '@', '/', '%', '//'), p_factor)
#factor: ('+'|'-'|'~'|'&'|typecast|sizeof) factor | power #factor: ('+'|'-'|'~'|'&'|typecast|sizeof) factor | power
...@@ -1129,7 +1129,7 @@ def p_expression_or_assignment(s): ...@@ -1129,7 +1129,7 @@ def p_expression_or_assignment(s):
expr = p_testlist_star_expr(s) expr = p_testlist_star_expr(s)
expr_list.append(expr) expr_list.append(expr)
if len(expr_list) == 1: if len(expr_list) == 1:
if re.match(r"([+*/\%^\&|-]|<<|>>|\*\*|//)=", s.sy): if re.match(r"([+*/\%^\&|-]|<<|>>|\*\*|//|@)=", s.sy):
lhs = expr_list[0] lhs = expr_list[0]
if isinstance(lhs, ExprNodes.SliceIndexNode): if isinstance(lhs, ExprNodes.SliceIndexNode):
# implementation requires IndexNode # implementation requires IndexNode
...@@ -1837,7 +1837,7 @@ def p_statement(s, ctx, first_statement = 0): ...@@ -1837,7 +1837,7 @@ def p_statement(s, ctx, first_statement = 0):
return p_DEF_statement(s) return p_DEF_statement(s)
elif s.sy == 'IF': elif s.sy == 'IF':
return p_IF_statement(s, ctx) return p_IF_statement(s, ctx)
elif s.sy == 'DECORATOR': elif s.sy == '@':
if ctx.level not in ('module', 'class', 'c_class', 'function', 'property', 'module_pxd', 'c_class_pxd', 'other'): if ctx.level not in ('module', 'class', 'c_class', 'function', 'property', 'module_pxd', 'c_class_pxd', 'other'):
s.error('decorator not allowed here') s.error('decorator not allowed here')
s.level = ctx.level s.level = ctx.level
...@@ -2884,7 +2884,7 @@ def p_ctypedef_statement(s, ctx): ...@@ -2884,7 +2884,7 @@ def p_ctypedef_statement(s, ctx):
def p_decorators(s): def p_decorators(s):
decorators = [] decorators = []
while s.sy == 'DECORATOR': while s.sy == '@':
pos = s.position() pos = s.position()
s.next() s.next()
decstring = p_dotted_name(s, as_allowed=0)[2] decstring = p_dotted_name(s, as_allowed=0)[2]
......
...@@ -703,7 +703,11 @@ PyNumberMethods = ( ...@@ -703,7 +703,11 @@ PyNumberMethods = (
MethodSlot(ibinaryfunc, "nb_inplace_true_divide", "__itruediv__"), MethodSlot(ibinaryfunc, "nb_inplace_true_divide", "__itruediv__"),
# Added in release 2.5 # Added in release 2.5
MethodSlot(unaryfunc, "nb_index", "__index__", ifdef = "PY_VERSION_HEX >= 0x02050000") MethodSlot(unaryfunc, "nb_index", "__index__"),
# Added in release 3.5
MethodSlot(binaryfunc, "nb_matrix_multiply", "__matmul__", ifdef="PY_VERSION_HEX >= 0x03050000"),
MethodSlot(ibinaryfunc, "nb_inplace_matrix_multiply", "__imatmul__", ifdef="PY_VERSION_HEX >= 0x03050000"),
) )
PySequenceMethods = ( PySequenceMethods = (
......
...@@ -1156,3 +1156,62 @@ static CYTHON_INLINE PyObject* __Pyx_PyObject_Call(PyObject *func, PyObject *arg ...@@ -1156,3 +1156,62 @@ static CYTHON_INLINE PyObject* __Pyx_PyObject_Call(PyObject *func, PyObject *arg
return result; return result;
} }
#endif #endif
/////////////// MatrixMultiply.proto ///////////////
#if PY_VERSION_HEX >= 0x03050000
#define __Pyx_PyNumber_MatrixMultiply(x,y) PyNumber_MatrixMultiply(x,y)
#define __Pyx_PyNumber_InPlaceMatrixMultiply(x,y) PyNumber_InPlaceMatrixMultiply(x,y)
#else
static PyObject* __Pyx_PyNumber_MatrixMultiply(PyObject* x, PyObject* y);
static PyObject* __Pyx_PyNumber_InPlaceMatrixMultiply(PyObject* x, PyObject* y);
#endif
/////////////// MatrixMultiply ///////////////
//@requires: PyObjectGetAttrStr
#if PY_VERSION_HEX < 0x03050000
static PyObject* __Pyx_PyNumber_MatrixMultiply(PyObject* x, PyObject* y) {
PyObject *func;
// FIXME: make subtype aware
// see note at https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types
func = __Pyx_PyObject_GetAttrStr(x, PYIDENT("__matmul__"));
if (func) {
PyObject *result = PyObject_CallFunctionObjArgs(func, y, NULL);
Py_DECREF(func);
if (result != Py_NotImplemented)
return result;
Py_DECREF(result);
} else {
if (!PyErr_ExceptionMatches(PyExc_AttributeError))
return NULL;
PyErr_Clear();
}
func = __Pyx_PyObject_GetAttrStr(y, PYIDENT("__rmatmul__"));
if (func) {
PyObject *result = PyObject_CallFunctionObjArgs(func, x, NULL);
Py_DECREF(func);
return result;
}
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
static PyObject* __Pyx_PyNumber_InPlaceMatrixMultiply(PyObject* x, PyObject* y) {
PyObject *func;
func = __Pyx_PyObject_GetAttrStr(x, PYIDENT("__imatmul__"));
if (func) {
PyObject *result = PyObject_CallFunctionObjArgs(func, y, NULL);
Py_DECREF(func);
if (result != Py_NotImplemented)
return result;
Py_DECREF(result);
} else {
if (!PyErr_ExceptionMatches(PyExc_AttributeError))
return NULL;
PyErr_Clear();
}
return __Pyx_PyNumber_MatrixMultiply(x, y);
}
#endif
import sys
if sys.version_info >= (3, 5):
__doc__ = """\
Note: support for providing Python special methods despite missing the C-level slot
is currently not supported.
>>> a, b = ExtMatMult(1), ExtMatMult(2)
>>> print(test_matmul(a, b))
ExtMatMult(1) @ ExtMatMult(2)
>>> print(test_matmul(a, 22))
ExtMatMult(1) @ 22
>>> print(test_matmul(11, b))
11 @ ExtMatMult(2)
>>> print(test_imatmul(a, b))
ExtMatMult('ExtMatMult(1) @ ExtMatMult(2)')
>>> print(test_imatmul(a, b))
ExtMatMult("ExtMatMult('ExtMatMult(1) @ ExtMatMult(2)') @ ExtMatMult(2)")
"""
class MatMult(object):
def __init__(self, myself):
self.myself = myself
def __matmul__(self, other):
return '%r @ %r' % (self, other)
def __rmatmul__(self, other):
return '%r @ %r' % (other, self)
def __imatmul__(self, other):
self.myself = '%r @ %r' % (self, other)
return self
def __repr__(self):
return 'MatMult(%r)' % self.myself
cdef class ExtMatMult:
"""
Note: support for providing Python special methods despite missing the C-level slot
is currently not supported.
"""
cdef object myself
def __init__(self, myself):
self.myself = myself
def __matmul__(self, other):
return '%r @ %r' % (self, other)
def __rmatmul__(self, other):
return '%r @ %r' % (other, self)
def __imatmul__(self, other):
self.myself = '%r @ %r' % (self, other)
return self
def __repr__(self):
return 'ExtMatMult(%r)' % self.myself
def test_matmul(a, b):
"""
>>> print(test_matmul(MatMult(1), MatMult(2)))
MatMult(1) @ MatMult(2)
>>> print(test_matmul(MatMult(1), 22))
MatMult(1) @ 22
>>> print(test_matmul(11, MatMult(2)))
11 @ MatMult(2)
>>> print(test_matmul(MatMult('abc'), MatMult('def')))
MatMult('abc') @ MatMult('def')
"""
return a @ b
def test_imatmul(a, b):
"""
>>> print(test_imatmul(MatMult(1), MatMult(2)))
MatMult('MatMult(1) @ MatMult(2)')
>>> print(test_imatmul(MatMult('abc'), MatMult('def')))
MatMult("MatMult('abc') @ MatMult('def')")
"""
a @= b
return a
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment