Commit f51ac43e authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

In the middle of a buffer refactor (nonworking; done indexing)

parent 8fd5165e
......@@ -17,8 +17,10 @@ class PureCFuncNode(Node):
self.type = type
self.c_code = c_code
self.visibility = visibility
self.entry = None
def analyse_types(self, env):
def analyse_expressions(self, env):
if not self.entry:
self.entry = env.declare_cfunction(
"<pure c function:%s>" % self.cname,
self.type, self.pos, cname=self.cname,
......@@ -52,17 +54,7 @@ tschecker_functype = PyrexTypes.CFuncType(
tsprefix = "__Pyx_tsc"
class BufferTransform(CythonTransform):
"""
Run after type analysis. Takes care of the buffer functionality.
Expects to be run on the full module. If you need to process a fragment
one should look into refactoring this transform.
"""
# Abbreviations:
# "ts" means typestring and/or typestring checking stuff
scope = None
class IntroduceBufferAuxiliaryVars(CythonTransform):
#
# Entry point
......@@ -70,7 +62,6 @@ class BufferTransform(CythonTransform):
def __call__(self, node):
assert isinstance(node, ModuleNode)
try:
cymod = self.context.modules[u'__cython__']
except KeyError:
......@@ -82,7 +73,7 @@ class BufferTransform(CythonTransform):
self.ts_item_checkers = {}
self.module_scope = node.scope
self.module_pos = node.pos
result = super(BufferTransform, self).__call__(node)
result = super(IntroduceBufferAuxiliaryVars, self).__call__(node)
# Register ts stuff
if "endian.h" not in node.scope.include_files:
node.scope.include_files.append("endian.h")
......@@ -101,6 +92,9 @@ class BufferTransform(CythonTransform):
in scope.entries.iteritems()
if entry.type.is_buffer]
if isinstance(node, ModuleNode) and len(bufvars) > 0:
# for now...note that pos is wrong
raise CompileError(node.pos, "Buffer vars not allowed in module scope")
for entry in bufvars:
name = entry.name
buftype = entry.type
......@@ -133,147 +127,6 @@ class BufferTransform(CythonTransform):
scope.buffer_entries = bufvars
self.scope = scope
acquire_buffer_fragment = TreeFragment(u"""
__cython__.PyObject_GetBuffer(<__cython__.PyObject*>SUBJECT, &BUFINFO, 0)
TSCHECKER(<char*>BUFINFO.format)
""")
fetch_strides = TreeFragment(u"""
TARGET = BUFINFO.strides[IDX]
""")
fetch_shape = TreeFragment(u"""
TARGET = BUFINFO.shape[IDX]
""")
def acquire_buffer_stats(self, entry, buffer_aux, pos):
# Just the stats for acquiring and unpacking the buffer auxiliaries
auxass = []
for idx, strideentry in enumerate(buffer_aux.stridevars):
strideentry.used = True
ass = self.fetch_strides.substitute({
u"TARGET": NameNode(pos, name=strideentry.name),
u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name),
u"IDX": IntNode(pos, value=EncodedString(idx)),
})
auxass += ass.stats
for idx, shapeentry in enumerate(buffer_aux.shapevars):
shapeentry.used = True
ass = self.fetch_shape.substitute({
u"TARGET": NameNode(pos, name=shapeentry.name),
u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name),
u"IDX": IntNode(pos, value=EncodedString(idx))
})
auxass += ass.stats
buffer_aux.buffer_info_var.used = True
acq = self.acquire_buffer_fragment.substitute({
u"SUBJECT" : NameNode(pos, name=entry.name),
u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name),
u"TSCHECKER": NameNode(pos, name=buffer_aux.tschecker.name)
}, pos=pos)
return acq.stats + auxass
def acquire_argument_buffer_stats(self, entry, pos):
# On function entry, not getting a buffer is an uncatchable
# exception, so we don't need to worry about what happens if
# we don't get a buffer.
stats = self.acquire_buffer_stats(entry, entry.buffer_aux, pos)
for s in stats:
s.analyse_declarations(self.scope)
s.analyse_expressions(self.scope)
return stats
# Notes: The cast to <char*> gets around Cython not supporting const types
reacquire_buffer_fragment = TreeFragment(u"""
TMP = LHS
if TMP is not None:
__cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO)
TMP = RHS
if TMP is not None:
ACQUIRE
LHS = TMP
""")
def reacquire_buffer(self, node):
buffer_aux = node.lhs.entry.buffer_aux
acquire_stats = self.acquire_buffer_stats(buffer_aux.temp_var, buffer_aux, node.pos)
acq = self.reacquire_buffer_fragment.substitute({
u"TMP" : NameNode(pos=node.pos, name=buffer_aux.temp_var.name),
u"LHS" : node.lhs,
u"RHS": node.rhs,
u"ACQUIRE": StatListNode(node.pos, stats=acquire_stats),
u"BUFINFO": NameNode(pos=node.pos, name=buffer_aux.buffer_info_var.name)
}, pos=node.pos)
# Preserve first assignment info on LHS
if node.first:
# TODO: Prettier code
acq.stats[4].first = True
del acq.stats[0]
del acq.stats[0]
# Note: The below should probably be refactored into something
# like fragment.substitute(..., context=self.context), with
# TreeFragment getting context.pipeline_until_now() and
# applying it on the fragment.
acq.analyse_declarations(self.scope)
acq.analyse_expressions(self.scope)
stats = acq.stats
return stats
def assign_into_buffer(self, node):
result = SingleAssignmentNode(node.pos,
rhs=self.visit(node.rhs),
lhs=self.buffer_index(node.lhs))
result.analyse_expressions(self.scope)
return result
def buffer_index(self, node):
pos = node.pos
bufaux = node.base.entry.buffer_aux
assert bufaux is not None
# indices * strides...
to_sum = [ IntBinopNode(pos, operator='*',
operand1=index, #PhaseEnvelopeNode(PhaseEnvelopeNode.ANALYSED, index),
operand2=NameNode(node.pos, name=stride.name))
for index, stride in zip(node.indices, bufaux.stridevars)]
# then sum them with the buffer pointer
expr = AttributeNode(pos,
obj=NameNode(pos, name=bufaux.buffer_info_var.name),
attribute=EncodedString("buf"))
for next in to_sum:
expr = AddNode(pos, operator='+', operand1=expr, operand2=next)
casted = TypecastNode(pos, operand=expr,
type=PyrexTypes.c_ptr_type(node.base.entry.type.dtype))
result = IndexNode(pos, base=casted, index=IntNode(pos, value='0'))
return result
buffer_cleanup_fragment = TreeFragment(u"""
if BUF is not None:
__cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>BUF, &BUFINFO)
""")
def funcdef_buffer_cleanup(self, node, pos):
env = node.local_scope
cleanups = [self.buffer_cleanup_fragment.substitute({
u"BUF" : NameNode(pos, name=entry.name),
u"BUFINFO": NameNode(pos, name=entry.buffer_aux.buffer_info_var.name)
}, pos=pos)
for entry in node.local_scope.buffer_entries]
cleanup_stats = []
for c in cleanups: cleanup_stats += c.stats
cleanup = StatListNode(pos, stats=cleanup_stats)
cleanup.analyse_expressions(env)
result = TryFinallyStatNode.create_analysed(pos, env, body=node.body, finally_clause=cleanup)
node.body = StatListNode.create_analysed(pos, env, stats=[result])
return node
#
# Transforms
#
def visit_ModuleNode(self, node):
self.handle_scope(node, node.scope)
self.visitchildren(node)
......@@ -282,40 +135,6 @@ class BufferTransform(CythonTransform):
def visit_FuncDefNode(self, node):
self.handle_scope(node, node.local_scope)
self.visitchildren(node)
node = self.funcdef_buffer_cleanup(node, node.pos)
stats = []
for arg in node.local_scope.arg_entries:
if arg.type.is_buffer:
stats += self.acquire_argument_buffer_stats(arg, node.pos)
node.body.stats = stats + node.body.stats
return node
def visit_SingleAssignmentNode(self, node):
# On assignments, two buffer-related things can happen:
# a) A buffer variable is assigned to (reacquisition)
# b) Buffer access assignment: arr[...] = ...
# Since we don't allow nested buffers, these don't overlap.
self.visitchildren(node)
# Only acquire buffers on vars (not attributes) for now.
if isinstance(node.lhs, NameNode) and node.lhs.entry.buffer_aux:
# Is buffer variable
return self.reacquire_buffer(node)
elif (isinstance(node.lhs, IndexNode) and
isinstance(node.lhs.base, NameNode) and
node.lhs.base.entry.buffer_aux is not None):
return self.assign_into_buffer(node)
else:
return node
def visit_IndexNode(self, node):
# Only occurs when the IndexNode is an rvalue
if node.is_buffer_access:
assert node.index is None
assert node.indices is not None
result = self.buffer_index(node)
result.analyse_expressions(self.scope)
return result
else:
return node
#
......@@ -325,7 +144,7 @@ class BufferTransform(CythonTransform):
def new_ts_func(self, name, code):
cname = "%s_%s" % (tsprefix, name)
funcnode = PureCFuncNode(self.module_pos, cname, tschecker_functype, code)
funcnode.analyse_types(self.module_scope)
funcnode.analyse_expressions(self.module_scope)
self.ts_funcs.append(funcnode)
return funcnode
......@@ -464,7 +283,179 @@ class BufferTransform(CythonTransform):
class BufferTransform(CythonTransform):
"""
Run after type analysis. Takes care of the buffer functionality.
Expects to be run on the full module. If you need to process a fragment
one should look into refactoring this transform.
"""
# Abbreviations:
# "ts" means typestring and/or typestring checking stuff
scope = None
#
# Entry point
#
def __call__(self, node):
assert isinstance(node, ModuleNode)
try:
cymod = self.context.modules[u'__cython__']
except KeyError:
# No buffer fun for this module
return node
self.bufstruct_type = cymod.entries[u'Py_buffer'].type
self.tscheckers = {}
self.ts_funcs = []
self.ts_item_checkers = {}
self.module_scope = node.scope
self.module_pos = node.pos
result = super(BufferTransform, self).__call__(node)
# Register ts stuff
if "endian.h" not in node.scope.include_files:
node.scope.include_files.append("endian.h")
result.body.stats += self.ts_funcs
return result
acquire_buffer_fragment = TreeFragment(u"""
__cython__.PyObject_GetBuffer(<__cython__.PyObject*>SUBJECT, &BUFINFO, 0)
TSCHECKER(<char*>BUFINFO.format)
""")
fetch_strides = TreeFragment(u"""
TARGET = BUFINFO.strides[IDX]
""")
fetch_shape = TreeFragment(u"""
TARGET = BUFINFO.shape[IDX]
""")
def acquire_buffer_stats(self, entry, buffer_aux, pos):
# Just the stats for acquiring and unpacking the buffer auxiliaries
auxass = []
for idx, strideentry in enumerate(buffer_aux.stridevars):
strideentry.used = True
ass = self.fetch_strides.substitute({
u"TARGET": NameNode(pos, name=strideentry.name),
u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name),
u"IDX": IntNode(pos, value=EncodedString(idx)),
})
auxass += ass.stats
for idx, shapeentry in enumerate(buffer_aux.shapevars):
shapeentry.used = True
ass = self.fetch_shape.substitute({
u"TARGET": NameNode(pos, name=shapeentry.name),
u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name),
u"IDX": IntNode(pos, value=EncodedString(idx))
})
auxass += ass.stats
buffer_aux.buffer_info_var.used = True
acq = self.acquire_buffer_fragment.substitute({
u"SUBJECT" : NameNode(pos, name=entry.name),
u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name),
u"TSCHECKER": NameNode(pos, name=buffer_aux.tschecker.name)
}, pos=pos)
return acq.stats + auxass
def acquire_argument_buffer_stats(self, entry, pos):
# On function entry, not getting a buffer is an uncatchable
# exception, so we don't need to worry about what happens if
# we don't get a buffer.
stats = self.acquire_buffer_stats(entry, entry.buffer_aux, pos)
for s in stats:
s.analyse_declarations(self.scope)
#s.analyse_expressions(self.scope)
return stats
# Notes: The cast to <char*> gets around Cython not supporting const types
reacquire_buffer_fragment = TreeFragment(u"""
TMP = LHS
if TMP is not None:
__cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO)
TMP = RHS
if TMP is not None:
ACQUIRE
LHS = TMP
""")
def reacquire_buffer(self, node):
buffer_aux = node.lhs.entry.buffer_aux
acquire_stats = self.acquire_buffer_stats(buffer_aux.temp_var, buffer_aux, node.pos)
acq = self.reacquire_buffer_fragment.substitute({
u"TMP" : NameNode(pos=node.pos, name=buffer_aux.temp_var.name),
u"LHS" : node.lhs,
u"RHS": node.rhs,
u"ACQUIRE": StatListNode(node.pos, stats=acquire_stats),
u"BUFINFO": NameNode(pos=node.pos, name=buffer_aux.buffer_info_var.name)
}, pos=node.pos)
# Preserve first assignment info on LHS
if node.first:
# TODO: Prettier code
acq.stats[4].first = True
del acq.stats[0]
del acq.stats[0]
# Note: The below should probably be refactored into something
# like fragment.substitute(..., context=self.context), with
# TreeFragment getting context.pipeline_until_now() and
# applying it on the fragment.
acq.analyse_declarations(self.scope)
acq.analyse_expressions(self.scope)
stats = acq.stats
return stats
def assign_into_buffer(self, node):
result = SingleAssignmentNode(node.pos,
rhs=self.visit(node.rhs),
lhs=self.buffer_index(node.lhs))
result.analyse_expressions(self.scope)
return result
buffer_cleanup_fragment = TreeFragment(u"""
if BUF is not None:
__cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>BUF, &BUFINFO)
""")
def funcdef_buffer_cleanup(self, node, pos):
env = node.local_scope
cleanups = [self.buffer_cleanup_fragment.substitute({
u"BUF" : NameNode(pos, name=entry.name),
u"BUFINFO": NameNode(pos, name=entry.buffer_aux.buffer_info_var.name)
}, pos=pos)
for entry in node.local_scope.buffer_entries]
cleanup_stats = []
for c in cleanups: cleanup_stats += c.stats
cleanup = StatListNode(pos, stats=cleanup_stats)
cleanup.analyse_expressions(env)
result = TryFinallyStatNode.create_analysed(pos, env, body=node.body, finally_clause=cleanup)
node.body = StatListNode.create_analysed(pos, env, stats=[result])
return node
#
# Transforms
#
def visit_ModuleNode(self, node):
self.handle_scope(node, node.scope)
self.visitchildren(node)
return node
def visit_FuncDefNode(self, node):
self.handle_scope(node, node.local_scope)
self.visitchildren(node)
node = self.funcdef_buffer_cleanup(node, node.pos)
stats = []
for arg in node.local_scope.arg_entries:
if arg.type.is_buffer:
stats += self.acquire_argument_buffer_stats(arg, node.pos)
node.body.stats = stats + node.body.stats
return node
# TODO:
# - buf must be NULL before getting new buffer
......@@ -893,6 +893,10 @@ class NameNode(AtomicExprNode):
% self.name)
self.type = PyrexTypes.error_type
self.entry.used = 1
if self.entry.type.is_buffer:
# Need some temps
print self.dump()
def analyse_rvalue_entry(self, env):
#print "NameNode.analyse_rvalue_entry:", self.name ###
......@@ -1311,6 +1315,9 @@ class IndexNode(ExprNode):
self.analyse_base_and_index_types(env, setting = 1)
def analyse_base_and_index_types(self, env, getting = 0, setting = 0):
# Note: This might be cleaned up by having IndexNode
# parsed in a saner way and only construct the tuple if
# needed.
self.is_buffer_access = False
self.base.analyse_types(env)
......@@ -1318,6 +1325,7 @@ class IndexNode(ExprNode):
skip_child_analysis = False
buffer_access = False
if self.base.type.is_buffer:
assert isinstance(self.base, NameNode)
if isinstance(self.index, TupleNode):
indices = self.index.args
else:
......@@ -1329,21 +1337,19 @@ class IndexNode(ExprNode):
x.analyse_types(env)
if not x.type.is_int:
buffer_access = False
if buffer_access:
# self.indices = [
# x.coerce_to(PyrexTypes.c_py_ssize_t_type, env).coerce_to_simple(env)
# for x in indices]
self.indices = indices
self.index = None
self.type = self.base.type.dtype
self.is_temp = 1
self.is_buffer_access = True
# Note: This might be cleaned up by having IndexNode
# parsed in a saner way and only construct the tuple if
# needed.
if not buffer_access:
self.index_temps = [Symtab.new_temp(i.type) for i in indices]
self.temps = self.index_temps
if getting:
# we only need a temp because result_code isn't refactored to
# generation time, but this seems an ok shortcut to take
self.is_temp = True
else:
if isinstance(self.index, TupleNode):
self.index.analyse_types(env, skip_children=skip_child_analysis)
elif not skip_child_analysis:
......@@ -1388,7 +1394,7 @@ class IndexNode(ExprNode):
def calculate_result_code(self):
if self.is_buffer_access:
return "<not needed>"
return "<not used>"
else:
return "(%s[%s])" % (
self.base.result_code, self.index.result_code)
......@@ -1407,7 +1413,8 @@ class IndexNode(ExprNode):
if self.index is not None:
self.index.generate_evaluation_code(code)
else:
for i in self.indices: i.generate_evaluation_code(code)
for i in self.indices:
i.generate_evaluation_code(code)
def generate_subexpr_disposal_code(self, code):
self.base.generate_disposal_code(code)
......@@ -1417,7 +1424,10 @@ class IndexNode(ExprNode):
for i in self.indices: i.generate_disposal_code(code)
def generate_result_code(self, code):
if self.type.is_pyobject:
if self.is_buffer_access:
valuecode = self.buffer_access_code(code)
code.putln("%s = %s;" % (self.result_code, valuecode))
elif self.type.is_pyobject:
if self.index.type.is_int:
function = "__Pyx_GetItemInt"
index_code = self.index.result_code
......@@ -1453,7 +1463,10 @@ class IndexNode(ExprNode):
def generate_assignment_code(self, rhs, code):
self.generate_subexpr_evaluation_code(code)
if self.type.is_pyobject:
if self.is_buffer_access:
valuecode = self.buffer_access_code(code)
code.putln("%s = %s;" % (valuecode, rhs.result_code))
elif self.type.is_pyobject:
self.generate_setitem_code(rhs.py_result(), code)
else:
code.putln(
......@@ -1479,6 +1492,23 @@ class IndexNode(ExprNode):
code.error_goto(self.pos)))
self.generate_subexpr_disposal_code(code)
def buffer_access_code(self, code):
# 1. Assign indices to temps
for temp, index in zip(self.index_temps, self.indices):
code.putln("%s = %s;" % (temp.cname, index.result_code))
# 2. Output code to do bounds checking on these
# 3. Return a code fragment string which does buffer
# lookup, which can be used on lhs or rhs of an assignment
# in the caller depending on the scenario.
bufaux = self.base.entry.buffer_aux
offset = " + ".join(["%s * %s" % (idx.cname, stride.cname)
for idx, stride in
zip(self.index_temps, bufaux.stridevars)])
ptrcode = "(%s.buf + %s)" % (bufaux.buffer_info_var.cname, offset)
valuecode = "*%s" % self.base.type.buffer_ptr_type.cast_code(ptrcode)
return valuecode
class SliceIndexNode(ExprNode):
# 2-element slice indexing
......
......@@ -361,23 +361,25 @@ def create_default_pipeline(context, options, result):
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from Optimize import FlattenInListTransform, SwitchTransform, OptimizeRefcounting
from CodeGeneration import AnchorTemps
from Buffer import BufferTransform
from Buffer import BufferTransform, IntroduceBufferAuxiliaryVars
from ModuleNode import check_c_classes
def printit(x): print x.dump()
return [
create_parse(context),
# printit,
NormalizeTree(context),
PostParse(context),
FlattenInListTransform(),
WithTransform(context),
DecoratorTransform(context),
AnalyseDeclarationsTransform(context),
IntroduceBufferAuxiliaryVars(context),
check_c_classes,
AnalyseExpressionsTransform(context),
BufferTransform(context),
# BufferTransform(context),
SwitchTransform(),
OptimizeRefcounting(context),
# AnchorTemps(context),
AnchorTemps(context),
# CreateClosureClasses(context),
create_generate_code(context, options, result)
]
......
......@@ -203,6 +203,7 @@ class BufferType(BaseType):
self.base = base
self.dtype = dtype
self.ndim = ndim
self.buffer_ptr_type = CPtrType(dtype)
def as_argument_type(self):
return self
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment