Commit ccd17d53 authored by Tao He's avatar Tao He Committed by GitHub

Improve syntax feature support of Cython CodeWriter (GH-3514)

parent 75f4aff7
""" """
Serializes a Cython code tree to Cython code. This is primarily useful for Serializes a Cython code tree to Cython code. This is primarily useful for
debugging and testing purposes. debugging and testing purposes.
The output is in a strict format, no whitespace or comments from the input The output is in a strict format, no whitespace or comments from the input
is preserved (and it could not be as it is not present in the code tree). is preserved (and it could not be as it is not present in the code tree).
""" """
...@@ -10,6 +9,7 @@ from __future__ import absolute_import, print_function ...@@ -10,6 +9,7 @@ from __future__ import absolute_import, print_function
from .Compiler.Visitor import TreeVisitor from .Compiler.Visitor import TreeVisitor
from .Compiler.ExprNodes import * from .Compiler.ExprNodes import *
from .Compiler.Nodes import CNameDeclaratorNode, CSimpleBaseTypeNode
class LinesResult(object): class LinesResult(object):
...@@ -80,6 +80,14 @@ class DeclarationWriter(TreeVisitor): ...@@ -80,6 +80,14 @@ class DeclarationWriter(TreeVisitor):
self.visit(item.default) self.visit(item.default)
self.put(u", ") self.put(u", ")
self.visit(items[-1]) self.visit(items[-1])
if output_rhs and items[-1].default is not None:
self.put(u" = ")
self.visit(items[-1].default)
def _visit_indented(self, node):
self.indent()
self.visit(node)
self.dedent()
def visit_Node(self, node): def visit_Node(self, node):
raise AssertionError("Node not handled by serializer: %r" % node) raise AssertionError("Node not handled by serializer: %r" % node)
...@@ -96,9 +104,7 @@ class DeclarationWriter(TreeVisitor): ...@@ -96,9 +104,7 @@ class DeclarationWriter(TreeVisitor):
else: else:
file = u'"%s"' % node.include_file file = u'"%s"' % node.include_file
self.putline(u"cdef extern from %s:" % file) self.putline(u"cdef extern from %s:" % file)
self.indent() self._visit_indented(node.body)
self.visit(node.body)
self.dedent()
def visit_CPtrDeclaratorNode(self, node): def visit_CPtrDeclaratorNode(self, node):
self.put('*') self.put('*')
...@@ -133,13 +139,12 @@ class DeclarationWriter(TreeVisitor): ...@@ -133,13 +139,12 @@ class DeclarationWriter(TreeVisitor):
self.put("short " * -node.longness) self.put("short " * -node.longness)
elif node.longness > 0: elif node.longness > 0:
self.put("long " * node.longness) self.put("long " * node.longness)
self.put(node.name) if node.name is not None:
self.put(node.name)
def visit_CComplexBaseTypeNode(self, node): def visit_CComplexBaseTypeNode(self, node):
self.put(u'(')
self.visit(node.base_type) self.visit(node.base_type)
self.visit(node.declarator) self.visit(node.declarator)
self.put(u')')
def visit_CNestedBaseTypeNode(self, node): def visit_CNestedBaseTypeNode(self, node):
self.visit(node.base_type) self.visit(node.base_type)
...@@ -159,7 +164,7 @@ class DeclarationWriter(TreeVisitor): ...@@ -159,7 +164,7 @@ class DeclarationWriter(TreeVisitor):
self.comma_separated_list(node.declarators, output_rhs=True) self.comma_separated_list(node.declarators, output_rhs=True)
self.endline() self.endline()
def visit_container_node(self, node, decl, extras, attributes): def _visit_container_node(self, node, decl, extras, attributes):
# TODO: visibility # TODO: visibility
self.startline(decl) self.startline(decl)
if node.name: if node.name:
...@@ -188,7 +193,7 @@ class DeclarationWriter(TreeVisitor): ...@@ -188,7 +193,7 @@ class DeclarationWriter(TreeVisitor):
if node.packed: if node.packed:
decl += u'packed ' decl += u'packed '
decl += node.kind decl += node.kind
self.visit_container_node(node, decl, None, node.attributes) self._visit_container_node(node, decl, None, node.attributes)
def visit_CppClassNode(self, node): def visit_CppClassNode(self, node):
extras = "" extras = ""
...@@ -196,10 +201,10 @@ class DeclarationWriter(TreeVisitor): ...@@ -196,10 +201,10 @@ class DeclarationWriter(TreeVisitor):
extras = u"[%s]" % ", ".join(node.templates) extras = u"[%s]" % ", ".join(node.templates)
if node.base_classes: if node.base_classes:
extras += "(%s)" % ", ".join(node.base_classes) extras += "(%s)" % ", ".join(node.base_classes)
self.visit_container_node(node, u"cdef cppclass", extras, node.attributes) self._visit_container_node(node, u"cdef cppclass", extras, node.attributes)
def visit_CEnumDefNode(self, node): def visit_CEnumDefNode(self, node):
self.visit_container_node(node, u"cdef enum", None, node.items) self._visit_container_node(node, u"cdef enum", None, node.items)
def visit_CEnumDefItemNode(self, node): def visit_CEnumDefItemNode(self, node):
self.startline(node.name) self.startline(node.name)
...@@ -225,9 +230,7 @@ class DeclarationWriter(TreeVisitor): ...@@ -225,9 +230,7 @@ class DeclarationWriter(TreeVisitor):
self.put(node.base_class_name) self.put(node.base_class_name)
self.put(u")") self.put(u")")
self.endline(u":") self.endline(u":")
self.indent() self._visit_indented(node.body)
self.visit(node.body)
self.dedent()
def visit_CTypeDefNode(self, node): def visit_CTypeDefNode(self, node):
self.startline(u"ctypedef ") self.startline(u"ctypedef ")
...@@ -241,14 +244,45 @@ class DeclarationWriter(TreeVisitor): ...@@ -241,14 +244,45 @@ class DeclarationWriter(TreeVisitor):
self.startline(u"def %s(" % node.name) self.startline(u"def %s(" % node.name)
self.comma_separated_list(node.args) self.comma_separated_list(node.args)
self.endline(u"):") self.endline(u"):")
self.indent() self._visit_indented(node.body)
self.visit(node.body)
self.dedent() def visit_CFuncDefNode(self, node):
self.startline(u'cpdef ' if node.overridable else u'cdef ')
if node.modifiers:
self.put(' '.join(node.modifiers))
self.put(' ')
if node.visibility != 'private':
self.put(node.visibility)
self.put(u' ')
if node.api:
self.put(u'api ')
if node.base_type:
self.visit(node.base_type)
if node.base_type.name is not None:
self.put(u' ')
# visit the CFuncDeclaratorNode, but put a `:` at the end of line
self.visit(node.declarator.base)
self.put(u'(')
self.comma_separated_list(node.declarator.args)
self.endline(u'):')
self._visit_indented(node.body)
def visit_CArgDeclNode(self, node): def visit_CArgDeclNode(self, node):
if node.base_type.name is not None: # For "CSimpleBaseTypeNode", the variable type may have been parsed as type.
# For other node types, the "name" is always None.
if not isinstance(node.base_type, CSimpleBaseTypeNode) or \
node.base_type.name is not None:
self.visit(node.base_type) self.visit(node.base_type)
self.put(u" ")
# If we printed something for "node.base_type", we may need to print an extra ' '.
#
# Special case: if "node.declarator" is a "CNameDeclaratorNode",
# its "name" might be an empty string, for example, for "cdef f(x)".
if node.declarator.declared_name():
self.put(u" ")
self.visit(node.declarator) self.visit(node.declarator)
if node.default is not None: if node.default is not None:
self.put(u" = ") self.put(u" = ")
...@@ -328,14 +362,10 @@ class StatementWriter(DeclarationWriter): ...@@ -328,14 +362,10 @@ class StatementWriter(DeclarationWriter):
self.put(u" in ") self.put(u" in ")
self.visit(node.iterator.sequence) self.visit(node.iterator.sequence)
self.endline(u":") self.endline(u":")
self.indent() self._visit_indented(node.body)
self.visit(node.body)
self.dedent()
if node.else_clause is not None: if node.else_clause is not None:
self.line(u"else:") self.line(u"else:")
self.indent() self._visit_indented(node.else_clause)
self.visit(node.else_clause)
self.dedent()
def visit_IfStatNode(self, node): def visit_IfStatNode(self, node):
# The IfClauseNode is handled directly without a separate match # The IfClauseNode is handled directly without a separate match
...@@ -343,21 +373,30 @@ class StatementWriter(DeclarationWriter): ...@@ -343,21 +373,30 @@ class StatementWriter(DeclarationWriter):
self.startline(u"if ") self.startline(u"if ")
self.visit(node.if_clauses[0].condition) self.visit(node.if_clauses[0].condition)
self.endline(":") self.endline(":")
self.indent() self._visit_indented(node.if_clauses[0].body)
self.visit(node.if_clauses[0].body)
self.dedent()
for clause in node.if_clauses[1:]: for clause in node.if_clauses[1:]:
self.startline("elif ") self.startline("elif ")
self.visit(clause.condition) self.visit(clause.condition)
self.endline(":") self.endline(":")
self.indent() self._visit_indented(clause.body)
self.visit(clause.body)
self.dedent()
if node.else_clause is not None: if node.else_clause is not None:
self.line("else:") self.line("else:")
self.indent() self._visit_indented(node.else_clause)
self.visit(node.else_clause)
self.dedent() def visit_WhileStatNode(self, node):
self.startline(u"while ")
self.visit(node.condition)
self.endline(u":")
self._visit_indented(node.body)
if node.else_clause is not None:
self.line("else:")
self._visit_indented(node.else_clause)
def visit_ContinueStatNode(self, node):
self.line(u"continue")
def visit_BreakStatNode(self, node):
self.line(u"break")
def visit_SequenceNode(self, node): def visit_SequenceNode(self, node):
self.comma_separated_list(node.args) # Might need to discover whether we need () around tuples...hmm... self.comma_separated_list(node.args) # Might need to discover whether we need () around tuples...hmm...
...@@ -382,25 +421,17 @@ class StatementWriter(DeclarationWriter): ...@@ -382,25 +421,17 @@ class StatementWriter(DeclarationWriter):
self.put(u" as ") self.put(u" as ")
self.visit(node.target) self.visit(node.target)
self.endline(u":") self.endline(u":")
self.indent() self._visit_indented(node.body)
self.visit(node.body)
self.dedent()
def visit_TryFinallyStatNode(self, node): def visit_TryFinallyStatNode(self, node):
self.line(u"try:") self.line(u"try:")
self.indent() self._visit_indented(node.body)
self.visit(node.body)
self.dedent()
self.line(u"finally:") self.line(u"finally:")
self.indent() self._visit_indented(node.finally_clause)
self.visit(node.finally_clause)
self.dedent()
def visit_TryExceptStatNode(self, node): def visit_TryExceptStatNode(self, node):
self.line(u"try:") self.line(u"try:")
self.indent() self._visit_indented(node.body)
self.visit(node.body)
self.dedent()
for x in node.except_clauses: for x in node.except_clauses:
self.visit(x) self.visit(x)
if node.else_clause is not None: if node.else_clause is not None:
...@@ -415,9 +446,7 @@ class StatementWriter(DeclarationWriter): ...@@ -415,9 +446,7 @@ class StatementWriter(DeclarationWriter):
self.put(u", ") self.put(u", ")
self.visit(node.target) self.visit(node.target)
self.endline(":") self.endline(":")
self.indent() self._visit_indented(node.body)
self.visit(node.body)
self.dedent()
def visit_ReturnStatNode(self, node): def visit_ReturnStatNode(self, node):
self.startline("return ") self.startline("return ")
...@@ -480,12 +509,18 @@ class ExpressionWriter(TreeVisitor): ...@@ -480,12 +509,18 @@ class ExpressionWriter(TreeVisitor):
def visit_Node(self, node): def visit_Node(self, node):
raise AssertionError("Node not handled by serializer: %r" % node) raise AssertionError("Node not handled by serializer: %r" % node)
def visit_NameNode(self, node): def visit_IntNode(self, node):
self.put(node.name) self.put(node.value)
def visit_FloatNode(self, node):
self.put(node.value)
def visit_NoneNode(self, node): def visit_NoneNode(self, node):
self.put(u"None") self.put(u"None")
def visit_NameNode(self, node):
self.put(node.name)
def visit_EllipsisNode(self, node): def visit_EllipsisNode(self, node):
self.put(u"...") self.put(u"...")
...@@ -756,12 +791,13 @@ class PxdWriter(DeclarationWriter, ExpressionWriter): ...@@ -756,12 +791,13 @@ class PxdWriter(DeclarationWriter, ExpressionWriter):
return node return node
def visit_CFuncDefNode(self, node): def visit_CFuncDefNode(self, node):
if 'inline' in node.modifiers:
return
if node.overridable: if node.overridable:
self.startline(u'cpdef ') self.startline(u'cpdef ')
else: else:
self.startline(u'cdef ') self.startline(u'cdef ')
if node.modifiers:
self.put(' '.join(node.modifiers))
self.put(' ')
if node.visibility != 'private': if node.visibility != 'private':
self.put(node.visibility) self.put(node.visibility)
self.put(u' ') self.put(u' ')
......
...@@ -21,6 +21,7 @@ class TestCodeWriter(CythonTest): ...@@ -21,6 +21,7 @@ class TestCodeWriter(CythonTest):
self.t(u""" self.t(u"""
print(x + y ** 2) print(x + y ** 2)
print(x, y, z) print(x, y, z)
print(x + y, x + y * z, x * (y + z))
""") """)
def test_if(self): def test_if(self):
...@@ -46,6 +47,20 @@ class TestCodeWriter(CythonTest): ...@@ -46,6 +47,20 @@ class TestCodeWriter(CythonTest):
pass pass
""") """)
def test_cdef(self):
self.t(u"""
cdef f(x, y, z):
pass
cdef public void (x = 34, y = 54, z):
pass
cdef f(int *x, void *y, Value *z):
pass
cdef f(int **x, void **y, Value **z):
pass
cdef inline f(int &x, Value &z):
pass
""")
def test_longness_and_signedness(self): def test_longness_and_signedness(self):
self.t(u"def f(unsigned long long long long long int y):\n pass") self.t(u"def f(unsigned long long long long long int y):\n pass")
...@@ -75,6 +90,14 @@ class TestCodeWriter(CythonTest): ...@@ -75,6 +90,14 @@ class TestCodeWriter(CythonTest):
print(43) print(43)
""") """)
def test_while_loop(self):
self.t(u"""
while True:
while True:
while True:
continue
""")
def test_inplace_assignment(self): def test_inplace_assignment(self):
self.t(u"x += 43") self.t(u"x += 43")
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment