Commit 55feb21e authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Temp allocation possible in CCodeWriter

parent b6f46a39
...@@ -160,7 +160,7 @@ def get_release_buffer_code(entry): ...@@ -160,7 +160,7 @@ def get_release_buffer_code(entry):
entry.cname, entry.cname,
entry.buffer_aux.buffer_info_var.cname) entry.buffer_aux.buffer_info_var.cname)
def put_assign_to_buffer(lhs_cname, rhs_cname, retcode_cname, buffer_aux, buffer_type, def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type,
is_initialized, pos, code): is_initialized, pos, code):
""" """
Generate code for reassigning a buffer variables. This only deals with getting Generate code for reassigning a buffer variables. This only deals with getting
...@@ -193,27 +193,31 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, retcode_cname, buffer_aux, buffer ...@@ -193,27 +193,31 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, retcode_cname, buffer_aux, buffer
lhs_cname, bufstruct)) lhs_cname, bufstruct))
code.end_block() code.end_block()
# Acquire # Acquire
retcode_cname = code.func.allocate_temp(PyrexTypes.c_int_type)
code.putln("%s = %s;" % (retcode_cname, getbuffer % rhs_cname)) code.putln("%s = %s;" % (retcode_cname, getbuffer % rhs_cname))
code.putln('if (%s) ' % (code.unlikely("%s < 0" % retcode_cname)))
# If acquisition failed, attempt to reacquire the old buffer # If acquisition failed, attempt to reacquire the old buffer
# before raising the exception. A failure of reacquisition # before raising the exception. A failure of reacquisition
# will cause the reacquisition exception to be reported, one # will cause the reacquisition exception to be reported, one
# can consider working around this later. # can consider working around this later.
code.putln('if (%s) ' % (code.unlikely("%s < 0" % retcode_cname)))
code.begin_block() code.begin_block()
# In anticipation of a better temp system, create non-consistent C code for now type, value, tb = [code.func.allocate_temp(PyrexTypes.py_object_type)
code.putln('PyObject *__pyx_type, *__pyx_value, *__pyx_tb;') for i in range(3)]
code.putln('PyErr_Fetch(&__pyx_type, &__pyx_value, &__pyx_tb);') code.putln('PyErr_Fetch(&%s, &%s, &%s);' % (type, value, tb))
code.put('if (%s) ' % code.unlikely("%s == -1" % (getbuffer % lhs_cname))) code.put('if (%s) ' % code.unlikely("%s == -1" % (getbuffer % lhs_cname)))
code.begin_block() code.begin_block()
code.putln('Py_XDECREF(__pyx_type); Py_XDECREF(__pyx_value); Py_XDECREF(__pyx_tb);') code.putln('Py_XDECREF(%s); Py_XDECREF(%s); Py_XDECREF(%s);' % (type, value, tb))
code.putln('__Pyx_RaiseBufferFallbackError();') code.putln('__Pyx_RaiseBufferFallbackError();')
code.putln('} else {') code.putln('} else {')
code.putln('PyErr_Restore(__pyx_type, __pyx_value, __pyx_tb);') code.putln('PyErr_Restore(%s, %s, %s);' % (type, value, tb))
for t in (type, value, tb):
code.func.release_temp(t)
code.end_block() code.end_block()
# Unpack indices # Unpack indices
code.end_block() code.end_block()
put_unpack_buffer_aux_into_scope(buffer_aux, code) put_unpack_buffer_aux_into_scope(buffer_aux, code)
code.putln(code.error_goto_if_neg(retcode_cname, pos)) code.putln(code.error_goto_if_neg(retcode_cname, pos))
code.func.release_temp(retcode_cname)
else: else:
# Our entry had no previous value, so set to None when acquisition fails. # Our entry had no previous value, so set to None when acquisition fails.
# In this case, auxiliary vars should be set up right in initialization to a zero-buffer, # In this case, auxiliary vars should be set up right in initialization to a zero-buffer,
...@@ -227,7 +231,7 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, retcode_cname, buffer_aux, buffer ...@@ -227,7 +231,7 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, retcode_cname, buffer_aux, buffer
code.putln('}') code.putln('}')
def put_access(entry, index_types, index_cnames, tmp_cname, pos, code): def put_access(entry, index_signeds, index_cnames, pos, code):
"""Returns a c string which can be used to access the buffer """Returns a c string which can be used to access the buffer
for reading or writing. for reading or writing.
...@@ -241,11 +245,12 @@ def put_access(entry, index_types, index_cnames, tmp_cname, pos, code): ...@@ -241,11 +245,12 @@ def put_access(entry, index_types, index_cnames, tmp_cname, pos, code):
# Check bounds and fix negative indices # Check bounds and fix negative indices
boundscheck = True boundscheck = True
nonegs = True nonegs = True
tmp_cname = code.func.allocate_temp(PyrexTypes.c_int_type)
if boundscheck: if boundscheck:
code.putln("%s = -1;" % tmp_cname) code.putln("%s = -1;" % tmp_cname)
for idx, (type, cname, shape) in enumerate(zip(index_types, index_cnames, for idx, (signed, cname, shape) in enumerate(zip(index_signeds, index_cnames,
bufaux.shapevars)): bufaux.shapevars)):
if type.signed != 0: if signed != 0:
nonegs = False nonegs = False
# not unsigned, deal with negative index # not unsigned, deal with negative index
code.putln("if (%s < 0) {" % cname) code.putln("if (%s < 0) {" % cname)
...@@ -269,6 +274,7 @@ def put_access(entry, index_types, index_cnames, tmp_cname, pos, code): ...@@ -269,6 +274,7 @@ def put_access(entry, index_types, index_cnames, tmp_cname, pos, code):
code.putln('__Pyx_BufferIndexError(%s);' % tmp_cname) code.putln('__Pyx_BufferIndexError(%s);' % tmp_cname)
code.putln(code.error_goto(pos)) code.putln(code.error_goto(pos))
code.end_block() code.end_block()
code.func.release_temp(tmp_cname)
# Create buffer lookup and return it # Create buffer lookup and return it
......
...@@ -10,10 +10,13 @@ from PyrexTypes import py_object_type, typecast ...@@ -10,10 +10,13 @@ from PyrexTypes import py_object_type, typecast
from TypeSlots import method_coexist from TypeSlots import method_coexist
from Scanning import SourceDescriptor from Scanning import SourceDescriptor
from Cython.StringIOTree import StringIOTree from Cython.StringIOTree import StringIOTree
from sets import Set as set
class FunctionContext(object): class FunctionContext(object):
# Not used for now, perhaps later # Not used for now, perhaps later
def __init__(self): def __init__(self, names_taken=set()):
self.names_taken = names_taken
self.error_label = None self.error_label = None
self.label_counter = 0 self.label_counter = 0
self.labels_used = {} self.labels_used = {}
...@@ -22,8 +25,10 @@ class FunctionContext(object): ...@@ -22,8 +25,10 @@ class FunctionContext(object):
self.continue_label = None self.continue_label = None
self.break_label = None self.break_label = None
self.temps_allocated = [] self.temps_allocated = [] # of (name, type)
self.temps_free = {} # type -> list of free vars self.temps_free = {} # type -> list of free vars
self.temps_used_type = {} # name -> type
self.temp_counter = 0
def new_label(self): def new_label(self):
n = self.label_counter n = self.label_counter
...@@ -82,13 +87,36 @@ class FunctionContext(object): ...@@ -82,13 +87,36 @@ class FunctionContext(object):
return lbl in self.labels_used return lbl in self.labels_used
def allocate_temp(self, type): def allocate_temp(self, type):
"""
Allocates a temporary (which may create a new one or get a previously
allocated and released one of the same type). Type is simply registered
and handed back, but will usually be a PyrexType.
A C string referring to the variable is returned.
"""
freelist = self.temps_free.get(type) freelist = self.temps_free.get(type)
if freelist is not None and len(freelist) > 0: if freelist is not None and len(freelist) > 0:
return freelist.pop() result = freelist.pop()
else: else:
pass while True:
self.temp_counter += 1
result = "%s%d" % (Naming.codewriter_temp_prefix, self.temp_counter)
if not result in self.names_taken: break
self.temps_allocated.append((result, type))
self.temps_used_type[result] = type
return result
def release_temp(self, name):
"""
Releases a temporary so that it can be reused by other code needing
a temp of the same type.
"""
type = self.temps_used_type[name]
freelist = self.temps_free.get(type)
if freelist is None:
freelist = []
self.temps_free[type] = freelist
freelist.append(name)
def funccontext_property(name): def funccontext_property(name):
def get(self): def get(self):
...@@ -333,6 +361,14 @@ class CCodeWriter(object): ...@@ -333,6 +361,14 @@ class CCodeWriter(object):
self.put(" = %s" % entry.type.literal_code(entry.init)) self.put(" = %s" % entry.type.literal_code(entry.init))
self.putln(";") self.putln(";")
def put_temp_declarations(self, func_context):
for name, type in func_context.temps_allocated:
decl = type.declaration_code(name)
if type.is_pyobject:
self.putln("%s = NULL;" % decl)
else:
self.putln("%s;" % decl)
def entry_as_pyobject(self, entry): def entry_as_pyobject(self, entry):
type = entry.type type = entry.type
if (not entry.is_self_arg and not entry.type.is_complete()) \ if (not entry.is_self_arg and not entry.type.is_complete()) \
......
from Visitor import CythonTransform
from sets import Set as set
class AnchorTemps(CythonTransform):
def init_scope(self, scope):
scope.free_temp_entries = []
def handle_node(self, node):
if node.temps:
for temp in node.temps:
temp.cname = self.scope.allocate_temp(temp.type)
self.temps_beneath_try.add(temp.cname)
self.visitchildren(node)
for temp in node.temps:
self.scope.release_temp(temp.cname)
else:
self.visitchildren(node)
def visit_Node(self, node):
self.handle_node(node)
return node
def visit_ModuleNode(self, node):
self.scope = node.scope
self.temps_beneath_try = set()
self.init_scope(self.scope)
self.handle_node(node)
return node
def visit_FuncDefNode(self, node):
pscope = self.scope
pscope_temps = self.temps_beneath_try
self.scope = node.local_scope
self.init_scope(node.local_scope)
self.handle_node(node)
self.scope = pscope
self.temps_beneath_try = pscope_temps
return node
def visit_TryExceptNode(self, node):
old_tbt = self.temps_beneath_try
self.temps_beneath_try = set()
self.handle_node(node)
entries = [ scope.cname_to_entry[cname] for
cname in self.temps_beneath_try]
node.cleanup_list.extend(entries)
return node
...@@ -889,9 +889,6 @@ class NameNode(AtomicExprNode): ...@@ -889,9 +889,6 @@ class NameNode(AtomicExprNode):
# think of had a single symbol result_code but better # think of had a single symbol result_code but better
# safe than sorry. Feel free to change this. # safe than sorry. Feel free to change this.
import Buffer import Buffer
self.new_buffer_temp = Symtab.new_temp(self.entry.type)
self.retcode_temp = Symtab.new_temp(PyrexTypes.c_int_type)
self.temps = [self.new_buffer_temp, self.retcode_temp]
Buffer.used_buffer_aux_vars(self.entry) Buffer.used_buffer_aux_vars(self.entry)
def analyse_rvalue_entry(self, env): def analyse_rvalue_entry(self, env):
...@@ -1068,13 +1065,13 @@ class NameNode(AtomicExprNode): ...@@ -1068,13 +1065,13 @@ class NameNode(AtomicExprNode):
rhs.generate_post_assignment_code(code) rhs.generate_post_assignment_code(code)
def generate_acquire_buffer(self, rhs, code): def generate_acquire_buffer(self, rhs, code):
rhstmp = self.new_buffer_temp.cname rhstmp = code.func.allocate_temp(self.entry.type)
buffer_aux = self.entry.buffer_aux buffer_aux = self.entry.buffer_aux
bufstruct = buffer_aux.buffer_info_var.cname bufstruct = buffer_aux.buffer_info_var.cname
code.putln('%s = %s;' % (rhstmp, rhs.result_as(self.ctype()))) code.putln('%s = %s;' % (rhstmp, rhs.result_as(self.ctype())))
import Buffer import Buffer
Buffer.put_assign_to_buffer(self.result_code, rhstmp, self.retcode_temp.cname, buffer_aux, self.entry.type, Buffer.put_assign_to_buffer(self.result_code, rhstmp, buffer_aux, self.entry.type,
is_initialized=not self.skip_assignment_decref, is_initialized=not self.skip_assignment_decref,
pos=self.pos, code=code) pos=self.pos, code=code)
code.putln("%s = 0;" % rhstmp) code.putln("%s = 0;" % rhstmp)
...@@ -1366,10 +1363,7 @@ class IndexNode(ExprNode): ...@@ -1366,10 +1363,7 @@ class IndexNode(ExprNode):
self.index = None self.index = None
self.type = self.base.type.dtype self.type = self.base.type.dtype
self.is_buffer_access = True self.is_buffer_access = True
self.index_temps = [Symtab.new_temp(i.type) for i in indices]
self.tmpint = Symtab.new_temp(PyrexTypes.c_int_type)
self.temps = self.index_temps + [self.tmpint]
if getting: if getting:
# we only need a temp because result_code isn't refactored to # we only need a temp because result_code isn't refactored to
# generation time, but this seems an ok shortcut to take # generation time, but this seems an ok shortcut to take
...@@ -1525,14 +1519,15 @@ class IndexNode(ExprNode): ...@@ -1525,14 +1519,15 @@ class IndexNode(ExprNode):
def buffer_access_code(self, code): def buffer_access_code(self, code):
# Assign indices to temps # Assign indices to temps
for temp, index in zip(self.index_temps, self.indices): index_temps = [code.func.allocate_temp(i.type) for i in self.indices]
code.putln("%s = %s;" % (temp.cname, index.result_code)) for temp, index in zip(index_temps, self.indices):
code.putln("%s = %s;" % (temp, index.result_code))
# Generate buffer access code using these temps # Generate buffer access code using these temps
import Buffer import Buffer
valuecode = Buffer.put_access(entry=self.base.entry, valuecode = Buffer.put_access(entry=self.base.entry,
index_types=[i.type for i in self.index_temps], index_signeds=[i.type.signed for i in self.indices],
index_cnames=[i.cname for i in self.index_temps], index_cnames=index_temps,
pos=self.pos, tmp_cname=self.tmpint.cname, code=code) pos=self.pos, code=code)
return valuecode return valuecode
......
...@@ -370,7 +370,6 @@ def create_default_pipeline(context, options, result): ...@@ -370,7 +370,6 @@ def create_default_pipeline(context, options, result):
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from Optimize import FlattenInListTransform, SwitchTransform, OptimizeRefcounting from Optimize import FlattenInListTransform, SwitchTransform, OptimizeRefcounting
from CodeGeneration import AnchorTemps
from Buffer import IntroduceBufferAuxiliaryVars from Buffer import IntroduceBufferAuxiliaryVars
from ModuleNode import check_c_classes from ModuleNode import check_c_classes
def printit(x): print x.dump() def printit(x): print x.dump()
...@@ -389,7 +388,6 @@ def create_default_pipeline(context, options, result): ...@@ -389,7 +388,6 @@ def create_default_pipeline(context, options, result):
# BufferTransform(context), # BufferTransform(context),
SwitchTransform(), SwitchTransform(),
OptimizeRefcounting(context), OptimizeRefcounting(context),
AnchorTemps(context),
# CreateClosureClasses(context), # CreateClosureClasses(context),
create_generate_code(context, options, result) create_generate_code(context, options, result)
] ]
......
...@@ -8,6 +8,9 @@ ...@@ -8,6 +8,9 @@
pyrex_prefix = "__pyx_" pyrex_prefix = "__pyx_"
codewriter_temp_prefix = "_tmp"
temp_prefix = u"__cyt_" temp_prefix = u"__cyt_"
builtin_prefix = pyrex_prefix + "builtin_" builtin_prefix = pyrex_prefix + "builtin_"
......
...@@ -866,7 +866,7 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -866,7 +866,7 @@ class FuncDefNode(StatNode, BlockNode):
(self.return_type.declaration_code( (self.return_type.declaration_code(
Naming.retval_cname), Naming.retval_cname),
init)) init))
code.put_var_declarations(lenv.temp_entries) tempvardecl_code = code.insertion_point()
self.generate_keyword_list(code) self.generate_keyword_list(code)
# ----- Extern library function declarations # ----- Extern library function declarations
lenv.generate_library_function_declarations(code) lenv.generate_library_function_declarations(code)
...@@ -966,6 +966,9 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -966,6 +966,9 @@ class FuncDefNode(StatNode, BlockNode):
if not self.return_type.is_void: if not self.return_type.is_void:
code.putln("return %s;" % Naming.retval_cname) code.putln("return %s;" % Naming.retval_cname)
code.putln("}") code.putln("}")
# ----- Go back and insert temp variable declarations
tempvardecl_code.put_var_declarations(lenv.temp_entries)
tempvardecl_code.put_temp_declarations(code.func)
# ----- Python version # ----- Python version
code.exit_cfunc_scope() code.exit_cfunc_scope()
if self.py_func: if self.py_func:
......
...@@ -145,15 +145,6 @@ class Entry: ...@@ -145,15 +145,6 @@ class Entry:
error(pos, "'%s' does not match previous declaration" % self.name) error(pos, "'%s' does not match previous declaration" % self.name)
error(self.pos, "Previous declaration is here") error(self.pos, "Previous declaration is here")
def new_temp(type, description=""):
# Returns a temporary entry which is "floating" and not finally resolved
# before the AnchorTemps transform is run. cname will not be available on
# the temp before this transform is run. See the mentioned transform for
# more docs.
e = Entry(name="$" + description, type=type, cname="<temperror>")
e.is_variable = True
return e
class Scope: class Scope:
# name string Unqualified name # name string Unqualified name
# outer_scope Scope or None Enclosing scope # outer_scope Scope or None Enclosing scope
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment