Commit c5ba9581 authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Moved buffer transform to Buffer.py

parent 3ba21c1d
from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform
from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import *
from Cython.Compiler.TreeFragment import TreeFragment
from Cython.Utils import EncodedString
from Cython.Compiler.Errors import CompileError
from sets import Set as set
class BufferTransform(CythonTransform):
"""
Run after type analysis. Takes care of the buffer functionality.
"""
scope = None
def __call__(self, node):
cymod = self.context.modules[u'__cython__']
self.buffer_type = cymod.entries[u'Py_buffer'].type
return super(BufferTransform, self).__call__(node)
def handle_scope(self, node, scope):
# For all buffers, insert extra variables in the scope.
# The variables are also accessible from the buffer_info
# on the buffer entry
bufvars = [(name, entry) for name, entry
in scope.entries.iteritems()
if entry.type.buffer_options is not None]
for name, entry in bufvars:
# Variable has buffer opts, declare auxiliary vars
bufopts = entry.type.buffer_options
bufinfo = scope.declare_var(temp_name_handle(u"%s_bufinfo" % name),
self.buffer_type, node.pos)
temp_var = scope.declare_var(temp_name_handle(u"%s_tmp" % name),
entry.type, node.pos)
stridevars = []
shapevars = []
for idx in range(bufopts.ndim):
# stride
varname = temp_name_handle(u"%s_%s%d" % (name, "stride", idx))
var = scope.declare_var(varname, PyrexTypes.c_int_type, node.pos, is_cdef=True)
stridevars.append(var)
# shape
varname = temp_name_handle(u"%s_%s%d" % (name, "shape", idx))
var = scope.declare_var(varname, PyrexTypes.c_uint_type, node.pos, is_cdef=True)
shapevars.append(var)
entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars,
shapevars)
entry.buffer_aux.temp_var = temp_var
self.scope = scope
def visit_ModuleNode(self, node):
self.handle_scope(node, node.scope)
self.visitchildren(node)
return node
def visit_FuncDefNode(self, node):
self.handle_scope(node, node.local_scope)
self.visitchildren(node)
return node
acquire_buffer_fragment = TreeFragment(u"""
TMP = LHS
if TMP is not None:
__cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO)
TMP = RHS
__cython__.PyObject_GetBuffer(<__cython__.PyObject*>TMP, &BUFINFO, 0)
ASSIGN_AUX
LHS = TMP
""")
fetch_strides = TreeFragment(u"""
TARGET = BUFINFO.strides[IDX]
""")
fetch_shape = TreeFragment(u"""
TARGET = BUFINFO.shape[IDX]
""")
def visit_SingleAssignmentNode(self, node):
# On assignments, two buffer-related things can happen:
# a) A buffer variable is assigned to (reacquisition)
# b) Buffer access assignment: arr[...] = ...
# Since we don't allow nested buffers, these don't overlap.
self.visitchildren(node)
# Only acquire buffers on vars (not attributes) for now.
if isinstance(node.lhs, NameNode) and node.lhs.entry.buffer_aux:
# Is buffer variable
return self.reacquire_buffer(node)
elif (isinstance(node.lhs, IndexNode) and
isinstance(node.lhs.base, NameNode) and
node.lhs.base.entry.buffer_aux is not None):
return self.assign_into_buffer(node)
def reacquire_buffer(self, node):
bufaux = node.lhs.entry.buffer_aux
auxass = []
for idx, entry in enumerate(bufaux.stridevars):
entry.used = True
ass = self.fetch_strides.substitute({
u"TARGET": NameNode(node.pos, name=entry.name),
u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
u"IDX": IntNode(node.pos, value=EncodedString(idx))
})
auxass.append(ass)
for idx, entry in enumerate(bufaux.shapevars):
entry.used = True
ass = self.fetch_shape.substitute({
u"TARGET": NameNode(node.pos, name=entry.name),
u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
u"IDX": IntNode(node.pos, value=EncodedString(idx))
})
auxass.append(ass)
bufaux.buffer_info_var.used = True
acq = self.acquire_buffer_fragment.substitute({
u"TMP" : NameNode(pos=node.pos, name=bufaux.temp_var.name),
u"LHS" : node.lhs,
u"RHS": node.rhs,
u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass),
u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name)
}, pos=node.pos)
# Note: The below should probably be refactored into something
# like fragment.substitute(..., context=self.context), with
# TreeFragment getting context.pipeline_until_now() and
# applying it on the fragment.
acq.analyse_declarations(self.scope)
acq.analyse_expressions(self.scope)
stats = acq.stats
return stats
def assign_into_buffer(self, node):
result = SingleAssignmentNode(node.pos,
rhs=self.visit(node.rhs),
lhs=self.buffer_index(node.lhs))
result.analyse_expressions(self.scope)
return result
def buffer_index(self, node):
bufaux = node.base.entry.buffer_aux
assert bufaux is not None
# indices * strides...
to_sum = [ IntBinopNode(node.pos, operator='*',
operand1=index, #PhaseEnvelopeNode(PhaseEnvelopeNode.ANALYSED, index),
operand2=NameNode(node.pos, name=stride.name))
for index, stride in zip(node.indices, bufaux.stridevars)]
# then sum them
expr = to_sum[0]
for next in to_sum[1:]:
expr = IntBinopNode(node.pos, operator='+', operand1=expr, operand2=next)
tmp= self.buffer_access.substitute({
'BUF': NameNode(node.pos, name=bufaux.buffer_info_var.name),
'OFFSET': expr
}, pos=node.pos)
return tmp.stats[0].expr
buffer_access = TreeFragment(u"""
(<unsigned char*>(BUF.buf + OFFSET))[0]
""")
def visit_IndexNode(self, node):
# Only occurs when the IndexNode is an rvalue
if node.is_buffer_access:
assert node.index is None
assert node.indices is not None
result = self.buffer_index(node)
result.analyse_expressions(self.scope)
return result
else:
return node
def visit_CallNode(self, node):
### print node.dump()
return node
# def visit_FuncDefNode(self, node):
# print node.dump()
...@@ -354,9 +354,10 @@ def create_generate_code(context, options, result): ...@@ -354,9 +354,10 @@ def create_generate_code(context, options, result):
return generate_code return generate_code
def create_default_pipeline(context, options, result): def create_default_pipeline(context, options, result):
from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, BufferTransform from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor
from Buffer import BufferTransform
from ModuleNode import check_c_classes from ModuleNode import check_c_classes
return [ return [
......
...@@ -146,250 +146,6 @@ class PostParse(CythonTransform): ...@@ -146,250 +146,6 @@ class PostParse(CythonTransform):
node.keyword_args = None node.keyword_args = None
return node return node
class BufferTransform(CythonTransform):
"""
Run after type analysis. Takes care of the buffer functionality.
"""
scope = None
def __call__(self, node):
cymod = self.context.modules[u'__cython__']
self.buffer_type = cymod.entries[u'Py_buffer'].type
self.lhs = False
return super(BufferTransform, self).__call__(node)
def handle_scope(self, node, scope):
# For all buffers, insert extra variables in the scope.
# The variables are also accessible from the buffer_info
# on the buffer entry
bufvars = [(name, entry) for name, entry
in scope.entries.iteritems()
if entry.type.buffer_options is not None]
for name, entry in bufvars:
# Variable has buffer opts, declare auxiliary vars
bufopts = entry.type.buffer_options
bufinfo = scope.declare_var(temp_name_handle(u"%s_bufinfo" % name),
self.buffer_type, node.pos)
temp_var = scope.declare_var(temp_name_handle(u"%s_tmp" % name),
entry.type, node.pos)
stridevars = []
shapevars = []
for idx in range(bufopts.ndim):
# stride
varname = temp_name_handle(u"%s_%s%d" % (name, "stride", idx))
var = scope.declare_var(varname, PyrexTypes.c_int_type, node.pos, is_cdef=True)
stridevars.append(var)
# shape
varname = temp_name_handle(u"%s_%s%d" % (name, "shape", idx))
var = scope.declare_var(varname, PyrexTypes.c_uint_type, node.pos, is_cdef=True)
shapevars.append(var)
entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars,
shapevars)
entry.buffer_aux.temp_var = temp_var
self.scope = scope
def visit_ModuleNode(self, node):
self.handle_scope(node, node.scope)
self.visitchildren(node)
return node
def visit_FuncDefNode(self, node):
self.handle_scope(node, node.local_scope)
self.visitchildren(node)
return node
acquire_buffer_fragment = TreeFragment(u"""
TMP = LHS
if TMP is not None:
__cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO)
TMP = RHS
__cython__.PyObject_GetBuffer(<__cython__.PyObject*>TMP, &BUFINFO, 0)
ASSIGN_AUX
LHS = TMP
""")
fetch_strides = TreeFragment(u"""
TARGET = BUFINFO.strides[IDX]
""")
fetch_shape = TreeFragment(u"""
TARGET = BUFINFO.shape[IDX]
""")
# ass = SingleAssignmentNode(pos=node.pos,
# lhs=NameNode(node.pos, name=entry.name),
# rhs=IndexNode(node.pos,
# base=AttributeNode(node.pos,
# obj=NameNode(node.pos, name=bufaux.buffer_info_var.name),
# attribute=EncodedString("strides")),
# index=IntNode(node.pos, value=EncodedString(idx))))
# print ass.dump()
def visit_SingleAssignmentNode(self, node):
# On assignments, two buffer-related things can happen:
# a) A buffer variable is assigned to (reacquisition)
# b) Buffer access assignment: arr[...] = ...
# Since we don't allow nested buffers, these don't overlap.
# self.lhs = True
self.visit(node.rhs)
self.visit(node.lhs)
# self.lhs = False
# self.visitchildren(node)
# Only acquire buffers on vars (not attributes) for now.
if isinstance(node.lhs, NameNode) and node.lhs.entry.buffer_aux:
# Is buffer variable
return self.reacquire_buffer(node)
elif (isinstance(node.lhs, IndexNode) and
isinstance(node.lhs.base, NameNode) and
node.lhs.base.entry.buffer_aux is not None):
return self.assign_into_buffer(node)
def reacquire_buffer(self, node):
bufaux = node.lhs.entry.buffer_aux
auxass = []
for idx, entry in enumerate(bufaux.stridevars):
entry.used = True
ass = self.fetch_strides.substitute({
u"TARGET": NameNode(node.pos, name=entry.name),
u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
u"IDX": IntNode(node.pos, value=EncodedString(idx))
})
auxass.append(ass)
for idx, entry in enumerate(bufaux.shapevars):
entry.used = True
ass = self.fetch_shape.substitute({
u"TARGET": NameNode(node.pos, name=entry.name),
u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
u"IDX": IntNode(node.pos, value=EncodedString(idx))
})
auxass.append(ass)
bufaux.buffer_info_var.used = True
acq = self.acquire_buffer_fragment.substitute({
u"TMP" : NameNode(pos=node.pos, name=bufaux.temp_var.name),
u"LHS" : node.lhs,
u"RHS": node.rhs,
u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass),
u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name)
}, pos=node.pos)
# Note: The below should probably be refactored into something
# like fragment.substitute(..., context=self.context), with
# TreeFragment getting context.pipeline_until_now() and
# applying it on the fragment.
acq.analyse_declarations(self.scope)
acq.analyse_expressions(self.scope)
stats = acq.stats
return stats
def assign_into_buffer(self, node):
result = SingleAssignmentNode(node.pos,
rhs=self.visit(node.rhs),
lhs=self.buffer_index(node.lhs))
result.analyse_expressions(self.scope)
return result
def buffer_index(self, node):
bufaux = node.base.entry.buffer_aux
assert bufaux is not None
# indices * strides...
to_sum = [ IntBinopNode(node.pos, operator='*',
operand1=index, #PhaseEnvelopeNode(PhaseEnvelopeNode.ANALYSED, index),
operand2=NameNode(node.pos, name=stride.name))
for index, stride in zip(node.indices, bufaux.stridevars)]
# then sum them
expr = to_sum[0]
for next in to_sum[1:]:
expr = IntBinopNode(node.pos, operator='+', operand1=expr, operand2=next)
tmp= self.buffer_access.substitute({
'BUF': NameNode(node.pos, name=bufaux.buffer_info_var.name),
'OFFSET': expr
}, pos=node.pos)
return tmp.stats[0].expr
buffer_access = TreeFragment(u"""
(<unsigned char*>(BUF.buf + OFFSET))[0]
""")
def visit_IndexNode(self, node):
# Only occurs when the IndexNode is an rvalue
if node.is_buffer_access:
assert node.index is None
assert node.indices is not None
result = self.buffer_index(node)
result.analyse_expressions(self.scope)
return result
else:
return node
def visit_CallNode(self, node):
### print node.dump()
return node
# def visit_FuncDefNode(self, node):
# print node.dump()
class PhaseEnvelopeNode(Node):
"""
This node is used if you need to protect a node from reevaluation
of a phase. For instance, if you extract...
Use with care!
"""
# Phases
PARSED, ANALYSED = range(2)
def __init__(self, phase, wrapped):
self.phase = phase
self.wrapped = wrapped
def get_pos(self): return self.wrapped.pos
def set_pos(self, value): self.wrapped.pos = value
pos = property(get_pos, set_pos)
def get_subexprs(self): return self.wrapped.subexprs
subexprs = property(get_subexprs)
def analyse_types(self, env):
if self.phase < self.ANALYSED:
self.wrapped.analyse_types(env)
def __getattribute__(self, attrname):
wrapped = object.__getattribute__(self, "wrapped")
phase = object.__getattribute__(self, "phase")
if attrname == "wrapped": return wrapped
if attrname == "phase": return phase
attr = getattr(wrapped, attrname)
overridden = ("analyse_types",)
print attrname, attr
if not isinstance(attr, Node):
return attr
else:
if attr is None: return None
else:
return PhaseEnvelopeNode(phase, attr)
class WithTransform(CythonTransform): class WithTransform(CythonTransform):
# EXCINFO is manually set to a variable that contains # EXCINFO is manually set to a variable that contains
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment