Commit 44490b49 authored by Lisandro Dalcin's avatar Lisandro Dalcin

Update embedsignature directive

* emit function annotations
* implement ExpressionWriter visitor
parent 19d890ec
...@@ -519,3 +519,276 @@ class PxdWriter(DeclarationWriter): ...@@ -519,3 +519,276 @@ class PxdWriter(DeclarationWriter):
def visit_StatNode(self, node): def visit_StatNode(self, node):
pass pass
class ExpressionWriter(TreeVisitor):
def __init__(self, result=None):
super(ExpressionWriter, self).__init__()
if result is None:
result = u""
self.result = result
def write(self, tree):
self.visit(tree)
return self.result
def put(self, s):
self.result += s
def remove(self, s):
if self.result.endswith(s):
self.result = self.result[:-len(s)]
def comma_separated_list(self, items):
if len(items) > 0:
for item in items[:-1]:
self.visit(item)
self.put(u", ")
self.visit(items[-1])
def visit_Node(self, node):
raise AssertionError("Node not handled by serializer: %r" % node)
def visit_NameNode(self, node):
self.put(node.name)
def visit_NoneNode(self, node):
self.put(u"None")
def visit_BoolNode(self, node):
self.put(str(node.value))
def visit_ConstNode(self, node):
self.put(str(node.value))
def visit_ImagNode(self, node):
self.put(node.value)
self.put(u"j")
def visit_BytesNode(self, node):
repr_val = repr(node.value)
if repr_val[0] == 'b':
repr_val = repr_val[1:]
self.put(u"b%s" % repr_val)
def visit_StringNode(self, node):
repr_val = repr(node.value)
if repr_val[0] in 'ub':
repr_val = repr_val[1:]
self.put(u"%s" % repr_val)
def visit_UnicodeNode(self, node):
repr_val = repr(node.value)
if repr_val[0] == 'u':
repr_val = repr_val[1:]
self.put(u"u%s" % repr_val)
def emit_sequence(self, node, parens):
open_paren, close_paren = parens
items = node.subexpr_nodes()
self.put(open_paren)
self.comma_separated_list(items)
self.put(close_paren)
def visit_ListNode(self, node):
self.emit_sequence(node, u"[]")
def visit_TupleNode(self, node):
self.emit_sequence(node, u"()")
def visit_SetNode(self, node):
self.emit_sequence(node, u"{}")
def visit_DictNode(self, node):
self.emit_sequence(node, u"{}")
def visit_DictItemNode(self, node):
self.visit(node.key)
self.put(u": ")
self.visit(node.value)
unop_precedence = {
'not': 3, '!': 3,
'+': 11, '-': 11, '~': 11,
}
binop_precedence = {
'or': 1,
'and': 2,
# unary: 'not': 3, '!': 3,
'in': 4, 'not_in': 4, 'is': 4, 'is_not': 4, '<': 4, '<=': 4, '>': 4, '>=': 4, '!=': 4, '==': 4,
'|': 5,
'^': 6,
'&': 7,
'<<': 8, '>>': 8,
'+': 9, '-': 9,
'*': 10, '/': 10, '//': 10, '%': 10,
# unary: '+': 11, '-': 11, '~': 11
'**': 12,
}
def operator_enter(self, new_prec):
if not hasattr(self, 'precedence'):
self.precedence = [0]
old_prec = self.precedence[-1]
if old_prec > new_prec:
self.put(u"(")
self.precedence.append(new_prec)
def operator_exit(self):
old_prec, new_prec = self.precedence[-2:]
if old_prec > new_prec:
self.put(u")")
self.precedence.pop()
def visit_NotNode(self, node):
op = 'not'
prec = self.unop_precedence[op]
self.operator_enter(prec)
self.put(u"not ")
self.visit(node.operand)
self.operator_exit()
def visit_UnopNode(self, node):
op = node.operator
prec = self.unop_precedence[op]
self.operator_enter(prec)
self.put(u"%s" % node.operator)
self.visit(node.operand)
self.operator_exit()
def visit_BinopNode(self, node):
op = node.operator
prec = self.binop_precedence.get(op, 0)
self.operator_enter(prec)
self.visit(node.operand1)
self.put(u" %s " % op.replace('_', ' '))
self.visit(node.operand2)
self.operator_exit()
def visit_BoolBinopNode(self, node):
self.visit_BinopNode(node)
def visit_PrimaryCmpNode(self, node):
self.visit_BinopNode(node)
def visit_IndexNode(self, node):
self.visit(node.base)
self.put(u"[")
self.visit(node.index)
self.put(u"]")
def visit_SliceIndexNode(self, node):
self.visit(node.base)
self.put(u"[")
if node.start:
self.visit(node.start)
self.put(u":")
if node.stop:
self.visit(node.stop)
if node.slice:
self.put(u":")
self.visit(node.slice)
self.put(u"]")
def visit_SliceNode(self, node):
if not node.start.is_none:
self.visit(node.start)
self.put(u":")
if not node.stop.is_none:
self.visit(node.stop)
if not node.step.is_none:
self.put(u":")
self.visit(node.step)
def visit_CondExprNode(self, node):
self.visit(node.true_val)
self.put(u" if ")
self.visit(node.test)
self.put(u" else ")
self.visit(node.false_val)
def visit_AttributeNode(self, node):
self.visit(node.obj)
self.put(u".%s" % node.attribute)
def visit_SimpleCallNode(self, node):
self.visit(node.function)
self.put(u"(")
self.comma_separated_list(node.args)
self.put(")")
def emit_pos_args(self, node):
if node is None:
return
if isinstance(node, AddNode):
self.emit_pos_args(node.operand1)
self.emit_pos_args(node.operand2)
elif isinstance(node, TupleNode):
for expr in node.subexpr_nodes():
self.visit(expr)
self.put(u", ")
elif isinstance(node, AsTupleNode):
self.put("*")
self.visit(node.arg)
self.put(u", ")
else:
self.visit(node)
self.put(u", ")
def emit_kwd_args(self, node):
if node is None:
return
if isinstance(node, MergedDictNode):
for expr in node.subexpr_nodes():
self.emit_kwd_args(expr)
elif isinstance(node, DictNode):
for expr in node.subexpr_nodes():
self.put(u"%s=" % expr.key.value)
self.visit(expr.value)
self.put(u", ")
else:
self.put(u"**")
self.visit(node)
self.put(u", ")
def visit_GeneralCallNode(self, node):
self.visit(node.function)
self.put(u"(")
self.emit_pos_args(node.positional_args)
self.emit_kwd_args(node.keyword_args)
self.remove(u", ")
self.put(")")
def visit_ComprehensionNode(self, node):
tpmap = {'list': u"[]", 'dict': u"{}", 'set': u"{}"}
parens = tpmap[node.type.py_type_name()]
open_paren, close_paren = parens
body = node.loop.body
target = node.loop.target
sequence = node.loop.iterator.sequence
if isinstance(body, ComprehensionAppendNode):
condition = None
else:
condition = body.if_clauses[0].condition
body = body.if_clauses[0].body
self.put(open_paren)
self.visit(body)
self.put(u" for ")
self.visit(target)
self.put(u" in ")
self.visit(sequence)
if condition:
self.put(u" if ")
self.visit(condition)
self.put(close_paren)
def visit_ComprehensionAppendNode(self, node):
self.visit(node.expr)
def visit_DictComprehensionAppendNode(self, node):
self.visit(node.key_expr)
self.put(u": ")
self.visit(node.value_expr)
from __future__ import absolute_import from __future__ import absolute_import, print_function
from .Visitor import CythonTransform from .Visitor import CythonTransform
from .StringEncoding import EncodedString from .StringEncoding import EncodedString
from . import Options from . import Options
from . import PyrexTypes, ExprNodes from . import PyrexTypes, ExprNodes
from ..CodeWriter import ExpressionWriter
class AnnotationWriter(ExpressionWriter):
def visit_Node(self, node):
self.put(u"<???>")
def visit_LambdaNode(self, node):
# XXX Should we do better?
self.put("<lambda>")
class EmbedSignature(CythonTransform): class EmbedSignature(CythonTransform):
def __init__(self, context): def __init__(self, context):
super(EmbedSignature, self).__init__(context) super(EmbedSignature, self).__init__(context)
self.denv = None # XXX
self.class_name = None self.class_name = None
self.class_node = None self.class_node = None
unop_precedence = 11 def _fmt_expr(self, node):
binop_precedence = { writer = AnnotationWriter()
'or': 1, result = writer.write(node)
'and': 2, # print(type(node).__name__, '-->', result)
'not': 3,
'in': 4, 'not in': 4, 'is': 4, 'is not': 4, '<': 4, '<=': 4, '>': 4, '>=': 4, '!=': 4, '==': 4,
'|': 5,
'^': 6,
'&': 7,
'<<': 8, '>>': 8,
'+': 9, '-': 9,
'*': 10, '/': 10, '//': 10, '%': 10,
# unary: '+': 11, '-': 11, '~': 11
'**': 12}
def _fmt_expr_node(self, node, precedence=0):
if isinstance(node, ExprNodes.BinopNode) and not node.inplace:
new_prec = self.binop_precedence.get(node.operator, 0)
result = '%s %s %s' % (self._fmt_expr_node(node.operand1, new_prec),
node.operator,
self._fmt_expr_node(node.operand2, new_prec))
if precedence > new_prec:
result = '(%s)' % result
elif isinstance(node, ExprNodes.UnopNode):
result = '%s%s' % (node.operator,
self._fmt_expr_node(node.operand, self.unop_precedence))
if precedence > self.unop_precedence:
result = '(%s)' % result
elif isinstance(node, ExprNodes.AttributeNode):
result = '%s.%s' % (self._fmt_expr_node(node.obj), node.attribute)
else:
result = node.name
return result return result
def _fmt_arg_defv(self, arg):
default_val = arg.default
if not default_val:
return None
if isinstance(default_val, ExprNodes.NullNode):
return 'NULL'
try:
denv = self.denv # XXX
ctval = default_val.compile_time_value(self.denv)
repr_val = repr(ctval)
if isinstance(default_val, ExprNodes.UnicodeNode):
if repr_val[:1] != 'u':
return u'u%s' % repr_val
elif isinstance(default_val, ExprNodes.BytesNode):
if repr_val[:1] != 'b':
return u'b%s' % repr_val
elif isinstance(default_val, ExprNodes.StringNode):
if repr_val[:1] in 'ub':
return repr_val[1:]
return repr_val
except Exception:
try:
return self._fmt_expr_node(default_val)
except AttributeError:
return '<???>'
def _fmt_arg(self, arg): def _fmt_arg(self, arg):
if arg.type is PyrexTypes.py_object_type or arg.is_self_arg: if arg.type is PyrexTypes.py_object_type or arg.is_self_arg:
doc = arg.name doc = arg.name
else: else:
doc = arg.type.declaration_code(arg.name, for_display=1) doc = arg.type.declaration_code(arg.name, for_display=1)
if arg.annotation:
annotation = self._fmt_expr(arg.annotation)
doc = doc + (': %s' % annotation)
if arg.default: if arg.default:
arg_defv = self._fmt_arg_defv(arg) default = self._fmt_expr(arg.default)
if arg_defv: doc = doc + (' = %s' % default)
doc = doc + ('=%s' % arg_defv) elif arg.default:
default = self._fmt_expr(arg.default)
doc = doc + ('=%s' % default)
return doc return doc
def _fmt_star_arg(self, arg):
arg_doc = arg.name
if arg.annotation:
annotation = self._fmt_expr(arg.annotation)
arg_doc = arg_doc + (': %s' % annotation)
return arg_doc
def _fmt_arglist(self, args, def _fmt_arglist(self, args,
npargs=0, pargs=None, npargs=0, pargs=None,
nkargs=0, kargs=None, nkargs=0, kargs=None,
...@@ -94,11 +64,13 @@ class EmbedSignature(CythonTransform): ...@@ -94,11 +64,13 @@ class EmbedSignature(CythonTransform):
arg_doc = self._fmt_arg(arg) arg_doc = self._fmt_arg(arg)
arglist.append(arg_doc) arglist.append(arg_doc)
if pargs: if pargs:
arglist.insert(npargs, '*%s' % pargs.name) arg_doc = self._fmt_star_arg(pargs)
arglist.insert(npargs, '*%s' % arg_doc)
elif nkargs: elif nkargs:
arglist.insert(npargs, '*') arglist.insert(npargs, '*')
if kargs: if kargs:
arglist.append('**%s' % kargs.name) arg_doc = self._fmt_star_arg(kargs)
arglist.append('**%s' % arg_doc)
return arglist return arglist
def _fmt_ret_type(self, ret): def _fmt_ret_type(self, ret):
...@@ -110,6 +82,7 @@ class EmbedSignature(CythonTransform): ...@@ -110,6 +82,7 @@ class EmbedSignature(CythonTransform):
def _fmt_signature(self, cls_name, func_name, args, def _fmt_signature(self, cls_name, func_name, args,
npargs=0, pargs=None, npargs=0, pargs=None,
nkargs=0, kargs=None, nkargs=0, kargs=None,
return_expr=None,
return_type=None, hide_self=False): return_type=None, hide_self=False):
arglist = self._fmt_arglist(args, arglist = self._fmt_arglist(args,
npargs, pargs, npargs, pargs,
...@@ -119,7 +92,10 @@ class EmbedSignature(CythonTransform): ...@@ -119,7 +92,10 @@ class EmbedSignature(CythonTransform):
func_doc = '%s(%s)' % (func_name, arglist_doc) func_doc = '%s(%s)' % (func_name, arglist_doc)
if cls_name: if cls_name:
func_doc = '%s.%s' % (cls_name, func_doc) func_doc = '%s.%s' % (cls_name, func_doc)
if return_type: ret_doc = None
if return_expr:
ret_doc = self._fmt_expr(return_expr)
elif return_type:
ret_doc = self._fmt_ret_type(return_type) ret_doc = self._fmt_ret_type(return_type)
if ret_doc: if ret_doc:
func_doc = '%s -> %s' % (func_doc, ret_doc) func_doc = '%s -> %s' % (func_doc, ret_doc)
...@@ -177,6 +153,7 @@ class EmbedSignature(CythonTransform): ...@@ -177,6 +153,7 @@ class EmbedSignature(CythonTransform):
class_name, func_name, node.args, class_name, func_name, node.args,
npargs, node.star_arg, npargs, node.star_arg,
nkargs, node.starstar_arg, nkargs, node.starstar_arg,
return_expr=node.return_type_annotation,
return_type=None, hide_self=hide_self) return_type=None, hide_self=hide_self)
if signature: if signature:
if is_constructor: if is_constructor:
......
...@@ -80,6 +80,9 @@ __doc__ = ur""" ...@@ -80,6 +80,9 @@ __doc__ = ur"""
>>> print (Ext.m.__doc__) >>> print (Ext.m.__doc__)
Ext.m(self, a=u'spam') Ext.m(self, a=u'spam')
>>> print (Ext.n.__doc__)
Ext.n(self, a: int, b: float = 1.0, *args: tuple, **kwargs: dict) -> (None, True)
>>> print (Ext.get_int.__doc__) >>> print (Ext.get_int.__doc__)
Ext.get_int(self) -> int Ext.get_int(self) -> int
...@@ -185,7 +188,7 @@ __doc__ = ur""" ...@@ -185,7 +188,7 @@ __doc__ = ur"""
f_defexpr4(int x=(Ext.CONST1 + FLAG1) * Ext.CONST2) f_defexpr4(int x=(Ext.CONST1 + FLAG1) * Ext.CONST2)
>>> print(funcdoc(f_defexpr5)) >>> print(funcdoc(f_defexpr5))
f_defexpr5(int x=4) f_defexpr5(int x=2 + 2)
>>> print(funcdoc(f_charptr_null)) >>> print(funcdoc(f_charptr_null))
f_charptr_null(char *s=NULL) -> char * f_charptr_null(char *s=NULL) -> char *
...@@ -259,6 +262,9 @@ cdef class Ext: ...@@ -259,6 +262,9 @@ cdef class Ext:
def m(self, a=u'spam'): def m(self, a=u'spam'):
pass pass
def n(self, a: int, b: float = 1.0, *args: tuple, **kwargs: dict) -> (None, True):
pass
cpdef int get_int(self): cpdef int get_int(self):
return 0 return 0
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment