Commit a67aaf75 authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Introduce TempsBlockNode utility, improve TreeFragment-generated temps

parent a5e1aea2
from Cython.Compiler.Visitor import TreeVisitor, get_temp_name_handle_desc
from Cython.Compiler.Visitor import TreeVisitor
from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import *
......@@ -37,6 +37,7 @@ class CodeWriter(TreeVisitor):
self.result = result
self.numindents = 0
self.tempnames = {}
self.tempblockindex = 0
def write(self, tree):
self.visit(tree)
......@@ -60,12 +61,6 @@ class CodeWriter(TreeVisitor):
self.startline(s)
self.endline()
def putname(self, name):
tmpdesc = get_temp_name_handle_desc(name)
if tmpdesc is not None:
name = self.tempnames.setdefault(tmpdesc, u"$" +tmpdesc)
self.put(name)
def comma_seperated_list(self, items, output_rhs=False):
if len(items) > 0:
for item in items[:-1]:
......@@ -132,7 +127,7 @@ class CodeWriter(TreeVisitor):
self.endline()
def visit_NameNode(self, node):
self.putname(node.name)
self.put(node.name)
def visit_IntNode(self, node):
self.put(node.value)
......@@ -312,3 +307,18 @@ class CodeWriter(TreeVisitor):
self.visit(node.operand)
self.put(u")")
def visit_TempsBlockNode(self, node):
"""
Temporaries are output like $1_1', where the first number is
an index of the TempsBlockNode and the second number is an index
of the temporary which that block allocates.
"""
idx = 0
for handle in node.handles:
self.tempnames[handle] = "$%d_%d" % (self.tempblockindex, idx)
idx += 1
self.tempblockindex += 1
self.visit(node.body)
def visit_TempRefNode(self, node):
self.put(self.tempnames[node.handle])
from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform
from Cython.Compiler.Visitor import VisitorTransform, CythonTransform
from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import *
from Cython.Compiler.TreeFragment import TreeFragment
from Cython.Compiler.StringEncoding import EncodedString
from Cython.Compiler.Errors import CompileError
import Interpreter
......
from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform
from Cython.Compiler.Visitor import VisitorTransform, CythonTransform
from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import *
......
......@@ -206,10 +206,10 @@ class ExprNode(Node):
return self.saved_subexpr_nodes
def result(self):
if self.is_temp:
return self.result_code
else:
return self.calculate_result_code()
if self.is_temp:
return self.result_code
else:
return self.calculate_result_code()
def result_as(self, type = None):
# Return the result code cast to the specified C type.
......
......@@ -4188,6 +4188,7 @@ class FromImportStatNode(StatNode):
self.module.generate_disposal_code(code)
#------------------------------------------------------------------------------------
#
# Runtime support code
......
from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform
from Cython.Compiler.Visitor import VisitorTransform, CythonTransform
from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import *
from Cython.Compiler.UtilNodes import *
from Cython.Compiler.TreeFragment import TreeFragment
from Cython.Compiler.StringEncoding import EncodedString
from Cython.Compiler.Errors import CompileError
......@@ -409,7 +410,7 @@ class WithTransform(CythonTransform):
finally:
if EXC:
EXIT(None, None, None)
""", temps=[u'MGR', u'EXC', u"EXIT", u"SYS)"],
""", temps=[u'MGR', u'EXC', u"EXIT"],
pipeline=[NormalizeTree(None)])
template_with_target = TreeFragment(u"""
......@@ -428,32 +429,33 @@ class WithTransform(CythonTransform):
finally:
if EXC:
EXIT(None, None, None)
""", temps=[u'MGR', u'EXC', u"EXIT", u"VALUE", u"SYS"],
""", temps=[u'MGR', u'EXC', u"EXIT", u"VALUE"],
pipeline=[NormalizeTree(None)])
def visit_WithStatNode(self, node):
excinfo_name = temp_name_handle('EXCINFO')
excinfo_namenode = NameNode(pos=node.pos, name=excinfo_name)
excinfo_target = NameNode(pos=node.pos, name=excinfo_name)
excinfo_tempblock = TempsBlockNode(node.pos, [PyrexTypes.py_object_type], None)
if node.target is not None:
result = self.template_with_target.substitute({
u'EXPR' : node.manager,
u'BODY' : node.body,
u'TARGET' : node.target,
u'EXCINFO' : excinfo_namenode
u'EXCINFO' : excinfo_tempblock.get_ref_node(0, node.pos)
}, pos=node.pos)
# Set except excinfo target to EXCINFO
result.stats[4].body.stats[0].except_clauses[0].excinfo_target = excinfo_target
result.body.stats[4].body.stats[0].except_clauses[0].excinfo_target = (
excinfo_tempblock.get_ref_node(0, node.pos))
else:
result = self.template_without_target.substitute({
u'EXPR' : node.manager,
u'BODY' : node.body,
u'EXCINFO' : excinfo_namenode
u'EXCINFO' : excinfo_tempblock.get_ref_node(0, node.pos)
}, pos=node.pos)
# Set except excinfo target to EXCINFO
result.stats[4].body.stats[0].except_clauses[0].excinfo_target = excinfo_target
return result.stats
result.body.stats[4].body.stats[0].except_clauses[0].excinfo_target = (
excinfo_tempblock.get_ref_node(0, node.pos))
excinfo_tempblock.body = result
return excinfo_tempblock
class DecoratorTransform(CythonTransform):
......
......@@ -92,23 +92,23 @@ class TestWithTransform(TransformTest):
with x:
y = z ** 3
""")
self.assertCode(u"""
$MGR = x
$EXIT = $MGR.__exit__
$MGR.__enter__()
$EXC = True
$1_0 = x
$1_2 = $1_0.__exit__
$1_0.__enter__()
$1_1 = True
try:
try:
y = z ** 3
except:
$EXC = False
if (not $EXIT($EXCINFO)):
$1_1 = False
if (not $1_2($0_0)):
raise
finally:
if $EXC:
$EXIT(None, None, None)
if $1_1:
$1_2(None, None, None)
""", t)
......@@ -119,21 +119,21 @@ class TestWithTransform(TransformTest):
""")
self.assertCode(u"""
$MGR = x
$EXIT = $MGR.__exit__
$VALUE = $MGR.__enter__()
$EXC = True
$1_0 = x
$1_2 = $1_0.__exit__
$1_3 = $1_0.__enter__()
$1_1 = True
try:
try:
y = $VALUE
y = $1_3
y = z ** 3
except:
$EXC = False
if (not $EXIT($EXCINFO)):
$1_1 = False
if (not $1_2($0_0)):
raise
finally:
if $EXC:
$EXIT(None, None, None)
if $1_1:
$1_2(None, None, None)
""", t)
......
from Cython.TestUtils import CythonTest
from Cython.Compiler.TreeFragment import *
from Cython.Compiler.Nodes import *
from Cython.Compiler.UtilNodes import *
import Cython.Compiler.Naming as Naming
class TestTreeFragments(CythonTest):
......@@ -54,10 +55,10 @@ class TestTreeFragments(CythonTest):
x = TMP
""")
T = F.substitute(temps=[u"TMP"])
s = T.stats
self.assert_(s[0].expr.name == Naming.temp_prefix + u"1_TMP", s[0].expr.name)
self.assert_(s[1].rhs.name == Naming.temp_prefix + u"1_TMP")
self.assert_(s[0].expr.name != u"TMP")
s = T.body.stats
self.assert_(isinstance(s[0].expr, TempRefNode))
self.assert_(isinstance(s[1].rhs, TempRefNode))
self.assert_(s[0].expr.handle is s[1].rhs.handle)
if __name__ == "__main__":
import unittest
......
......@@ -8,11 +8,12 @@ from Scanning import PyrexScanner, StringSourceDescriptor
from Symtab import BuiltinScope, ModuleScope
import Symtab
import PyrexTypes
from Visitor import VisitorTransform, temp_name_handle
from Visitor import VisitorTransform
from Nodes import Node, StatListNode
from ExprNodes import NameNode
import Parsing
import Main
import UtilNodes
"""
Support for parsing strings into code trees.
......@@ -111,12 +112,18 @@ class TemplateTransform(VisitorTransform):
def __call__(self, node, substitutions, temps, pos):
self.substitutions = substitutions
tempdict = {}
for key in temps:
tempdict[key] = temp_name_handle(key) # pending result_code refactor: Symtab.new_temp(PyrexTypes.py_object_type, key)
self.temp_key_to_entries = tempdict
self.pos = pos
return super(TemplateTransform, self).__call__(node)
self.temps = temps
if len(temps) > 0:
self.tempblock = UtilNodes.TempsBlockNode(self.get_pos(node),
[PyrexTypes.py_object_type for x in temps],
body=None)
self.tempblock.body = super(TemplateTransform, self).__call__(node)
return self.tempblock
else:
return super(TemplateTransform, self).__call__(node)
def get_pos(self, node):
if self.pos:
......@@ -145,13 +152,13 @@ class TemplateTransform(VisitorTransform):
def visit_NameNode(self, node):
tempentry = self.temp_key_to_entries.get(node.name)
if tempentry is not None:
# Replace name with temporary
return NameNode(self.get_pos(node), name=tempentry)
# Pending result_code refactor: return NameNode(self.get_pos(node), entry=tempentry)
else:
try:
tmpidx = self.temps.index(node.name)
except:
return self.try_substitution(node, node.name)
else:
# Replace name with temporary
return self.tempblock.get_ref_node(tmpidx, self.get_pos(node))
def visit_ExprStatNode(self, node):
# If an expression-as-statement consists of only a replaceable
......
#
# Nodes used as utilities and support for transforms etc.
# These often make up sets including both Nodes and ExprNodes
# so it is convenient to have them in a seperate module.
#
import Nodes
import ExprNodes
from Nodes import Node
from ExprNodes import ExprNode
class TempHandle(object):
temp = None
def __init__(self, type):
self.type = type
class TempRefNode(ExprNode):
# handle TempHandle
subexprs = []
def analyse_types(self, env):
assert self.type == self.handle.type
def analyse_target_types(self, env):
assert self.type == self.handle.type
def analyse_target_declaration(self, env):
pass
def calculate_result_code(self):
result = self.handle.temp
if result is None: result = "<error>" # might be called and overwritten
return result
def generate_result_code(self, code):
pass
def generate_assignment_code(self, rhs, code):
if self.type.is_pyobject:
rhs.make_owned_reference(code)
code.put_xdecref(self.result(), self.ctype())
code.putln('%s = %s;' % (self.result(), rhs.result_as(self.ctype())))
rhs.generate_post_assignment_code(code)
class TempsBlockNode(Node):
"""
Creates a block which allocates temporary variables.
This is used by transforms to output constructs that need
to make use of a temporary variable. Simply pass the types
of the needed temporaries to the constructor.
The variables can be referred to using a TempRefNode
(which can be constructed by calling get_ref_node).
"""
child_attrs = ["body"]
def __init__(self, pos, types, body):
self.handles = [TempHandle(t) for t in types]
Node.__init__(self, pos, body=body)
def get_ref_node(self, index, pos):
handle = self.handles[index]
return TempRefNode(pos, handle=handle, type=handle.type)
def append_temp(self, type, pos):
"""
Appends a new temporary which this block manages, and returns
its index.
"""
self.handle.append(TempHandle(type))
return len(self.handle) - 1
def generate_execution_code(self, code):
for handle in self.handles:
handle.temp = code.funcstate.allocate_temp(handle.type)
self.body.generate_execution_code(code)
for handle in self.handles:
code.funcstate.release_temp(handle.temp)
def analyse_control_flow(self, env):
self.body.analyse_control_flow(env)
def analyse_declarations(self, env):
self.body.analyse_declarations(env)
def analyse_expressions(self, env):
self.body.analyse_expressions(env)
def generate_function_definitions(self, env, code):
self.body.generate_function_definitions(env, code)
def annotate(self, code):
self.body.annotate(code)
......@@ -199,23 +199,6 @@ def replace_node(ptr, value):
else:
getattr(parent, attrname)[listidx] = value
tmpnamectr = 0
def temp_name_handle(description=None):
global tmpnamectr
tmpnamectr += 1
if description is not None:
name = u"%d_%s" % (tmpnamectr, description)
else:
name = u"%d" % tmpnamectr
return EncodedString(Naming.temp_prefix + name)
def get_temp_name_handle_desc(handle):
if not handle.startswith(u"__cyt_"):
return None
else:
idx = handle.find(u"_", 6)
return handle[idx+1:]
class PrintTree(TreeVisitor):
"""Prints a representation of the tree to standard output.
Subclass and override repr_of to provide more information
......
......@@ -47,10 +47,16 @@ class CythonTest(unittest.TestCase):
self.assertEqual(len(expected), len(result),
"Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(expected), u"\n".join(result)))
def assertCode(self, expected, result_tree):
def codeToLines(self, tree):
writer = CodeWriter()
writer.write(result_tree)
result_lines = writer.result.lines
writer.write(tree)
return writer.result.lines
def codeToString(self, tree):
return "\n".join(self.codeToLines(tree))
def assertCode(self, expected, result_tree):
result_lines = self.codeToLines(result_tree)
expected_lines = strip_common_indent(expected.split("\n"))
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment