Commit f51ac43e authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

In the middle of a buffer refactor (nonworking; done indexing)

parent 8fd5165e
...@@ -17,12 +17,14 @@ class PureCFuncNode(Node): ...@@ -17,12 +17,14 @@ class PureCFuncNode(Node):
self.type = type self.type = type
self.c_code = c_code self.c_code = c_code
self.visibility = visibility self.visibility = visibility
self.entry = None
def analyse_types(self, env): def analyse_expressions(self, env):
self.entry = env.declare_cfunction( if not self.entry:
"<pure c function:%s>" % self.cname, self.entry = env.declare_cfunction(
self.type, self.pos, cname=self.cname, "<pure c function:%s>" % self.cname,
defining=True, visibility=self.visibility) self.type, self.pos, cname=self.cname,
defining=True, visibility=self.visibility)
def generate_function_definitions(self, env, code, transforms): def generate_function_definitions(self, env, code, transforms):
assert self.type.optional_arg_count == 0 assert self.type.optional_arg_count == 0
...@@ -52,17 +54,7 @@ tschecker_functype = PyrexTypes.CFuncType( ...@@ -52,17 +54,7 @@ tschecker_functype = PyrexTypes.CFuncType(
tsprefix = "__Pyx_tsc" tsprefix = "__Pyx_tsc"
class BufferTransform(CythonTransform): class IntroduceBufferAuxiliaryVars(CythonTransform):
"""
Run after type analysis. Takes care of the buffer functionality.
Expects to be run on the full module. If you need to process a fragment
one should look into refactoring this transform.
"""
# Abbreviations:
# "ts" means typestring and/or typestring checking stuff
scope = None
# #
# Entry point # Entry point
...@@ -70,7 +62,6 @@ class BufferTransform(CythonTransform): ...@@ -70,7 +62,6 @@ class BufferTransform(CythonTransform):
def __call__(self, node): def __call__(self, node):
assert isinstance(node, ModuleNode) assert isinstance(node, ModuleNode)
try: try:
cymod = self.context.modules[u'__cython__'] cymod = self.context.modules[u'__cython__']
except KeyError: except KeyError:
...@@ -82,7 +73,7 @@ class BufferTransform(CythonTransform): ...@@ -82,7 +73,7 @@ class BufferTransform(CythonTransform):
self.ts_item_checkers = {} self.ts_item_checkers = {}
self.module_scope = node.scope self.module_scope = node.scope
self.module_pos = node.pos self.module_pos = node.pos
result = super(BufferTransform, self).__call__(node) result = super(IntroduceBufferAuxiliaryVars, self).__call__(node)
# Register ts stuff # Register ts stuff
if "endian.h" not in node.scope.include_files: if "endian.h" not in node.scope.include_files:
node.scope.include_files.append("endian.h") node.scope.include_files.append("endian.h")
...@@ -101,6 +92,9 @@ class BufferTransform(CythonTransform): ...@@ -101,6 +92,9 @@ class BufferTransform(CythonTransform):
in scope.entries.iteritems() in scope.entries.iteritems()
if entry.type.is_buffer] if entry.type.is_buffer]
if isinstance(node, ModuleNode) and len(bufvars) > 0:
# for now...note that pos is wrong
raise CompileError(node.pos, "Buffer vars not allowed in module scope")
for entry in bufvars: for entry in bufvars:
name = entry.name name = entry.name
buftype = entry.type buftype = entry.type
...@@ -133,147 +127,6 @@ class BufferTransform(CythonTransform): ...@@ -133,147 +127,6 @@ class BufferTransform(CythonTransform):
scope.buffer_entries = bufvars scope.buffer_entries = bufvars
self.scope = scope self.scope = scope
acquire_buffer_fragment = TreeFragment(u"""
__cython__.PyObject_GetBuffer(<__cython__.PyObject*>SUBJECT, &BUFINFO, 0)
TSCHECKER(<char*>BUFINFO.format)
""")
fetch_strides = TreeFragment(u"""
TARGET = BUFINFO.strides[IDX]
""")
fetch_shape = TreeFragment(u"""
TARGET = BUFINFO.shape[IDX]
""")
def acquire_buffer_stats(self, entry, buffer_aux, pos):
# Just the stats for acquiring and unpacking the buffer auxiliaries
auxass = []
for idx, strideentry in enumerate(buffer_aux.stridevars):
strideentry.used = True
ass = self.fetch_strides.substitute({
u"TARGET": NameNode(pos, name=strideentry.name),
u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name),
u"IDX": IntNode(pos, value=EncodedString(idx)),
})
auxass += ass.stats
for idx, shapeentry in enumerate(buffer_aux.shapevars):
shapeentry.used = True
ass = self.fetch_shape.substitute({
u"TARGET": NameNode(pos, name=shapeentry.name),
u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name),
u"IDX": IntNode(pos, value=EncodedString(idx))
})
auxass += ass.stats
buffer_aux.buffer_info_var.used = True
acq = self.acquire_buffer_fragment.substitute({
u"SUBJECT" : NameNode(pos, name=entry.name),
u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name),
u"TSCHECKER": NameNode(pos, name=buffer_aux.tschecker.name)
}, pos=pos)
return acq.stats + auxass
def acquire_argument_buffer_stats(self, entry, pos):
# On function entry, not getting a buffer is an uncatchable
# exception, so we don't need to worry about what happens if
# we don't get a buffer.
stats = self.acquire_buffer_stats(entry, entry.buffer_aux, pos)
for s in stats:
s.analyse_declarations(self.scope)
s.analyse_expressions(self.scope)
return stats
# Notes: The cast to <char*> gets around Cython not supporting const types
reacquire_buffer_fragment = TreeFragment(u"""
TMP = LHS
if TMP is not None:
__cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO)
TMP = RHS
if TMP is not None:
ACQUIRE
LHS = TMP
""")
def reacquire_buffer(self, node):
buffer_aux = node.lhs.entry.buffer_aux
acquire_stats = self.acquire_buffer_stats(buffer_aux.temp_var, buffer_aux, node.pos)
acq = self.reacquire_buffer_fragment.substitute({
u"TMP" : NameNode(pos=node.pos, name=buffer_aux.temp_var.name),
u"LHS" : node.lhs,
u"RHS": node.rhs,
u"ACQUIRE": StatListNode(node.pos, stats=acquire_stats),
u"BUFINFO": NameNode(pos=node.pos, name=buffer_aux.buffer_info_var.name)
}, pos=node.pos)
# Preserve first assignment info on LHS
if node.first:
# TODO: Prettier code
acq.stats[4].first = True
del acq.stats[0]
del acq.stats[0]
# Note: The below should probably be refactored into something
# like fragment.substitute(..., context=self.context), with
# TreeFragment getting context.pipeline_until_now() and
# applying it on the fragment.
acq.analyse_declarations(self.scope)
acq.analyse_expressions(self.scope)
stats = acq.stats
return stats
def assign_into_buffer(self, node):
result = SingleAssignmentNode(node.pos,
rhs=self.visit(node.rhs),
lhs=self.buffer_index(node.lhs))
result.analyse_expressions(self.scope)
return result
def buffer_index(self, node):
pos = node.pos
bufaux = node.base.entry.buffer_aux
assert bufaux is not None
# indices * strides...
to_sum = [ IntBinopNode(pos, operator='*',
operand1=index, #PhaseEnvelopeNode(PhaseEnvelopeNode.ANALYSED, index),
operand2=NameNode(node.pos, name=stride.name))
for index, stride in zip(node.indices, bufaux.stridevars)]
# then sum them with the buffer pointer
expr = AttributeNode(pos,
obj=NameNode(pos, name=bufaux.buffer_info_var.name),
attribute=EncodedString("buf"))
for next in to_sum:
expr = AddNode(pos, operator='+', operand1=expr, operand2=next)
casted = TypecastNode(pos, operand=expr,
type=PyrexTypes.c_ptr_type(node.base.entry.type.dtype))
result = IndexNode(pos, base=casted, index=IntNode(pos, value='0'))
return result
buffer_cleanup_fragment = TreeFragment(u"""
if BUF is not None:
__cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>BUF, &BUFINFO)
""")
def funcdef_buffer_cleanup(self, node, pos):
env = node.local_scope
cleanups = [self.buffer_cleanup_fragment.substitute({
u"BUF" : NameNode(pos, name=entry.name),
u"BUFINFO": NameNode(pos, name=entry.buffer_aux.buffer_info_var.name)
}, pos=pos)
for entry in node.local_scope.buffer_entries]
cleanup_stats = []
for c in cleanups: cleanup_stats += c.stats
cleanup = StatListNode(pos, stats=cleanup_stats)
cleanup.analyse_expressions(env)
result = TryFinallyStatNode.create_analysed(pos, env, body=node.body, finally_clause=cleanup)
node.body = StatListNode.create_analysed(pos, env, stats=[result])
return node
#
# Transforms
#
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
self.handle_scope(node, node.scope) self.handle_scope(node, node.scope)
self.visitchildren(node) self.visitchildren(node)
...@@ -282,42 +135,8 @@ class BufferTransform(CythonTransform): ...@@ -282,42 +135,8 @@ class BufferTransform(CythonTransform):
def visit_FuncDefNode(self, node): def visit_FuncDefNode(self, node):
self.handle_scope(node, node.local_scope) self.handle_scope(node, node.local_scope)
self.visitchildren(node) self.visitchildren(node)
node = self.funcdef_buffer_cleanup(node, node.pos)
stats = []
for arg in node.local_scope.arg_entries:
if arg.type.is_buffer:
stats += self.acquire_argument_buffer_stats(arg, node.pos)
node.body.stats = stats + node.body.stats
return node return node
def visit_SingleAssignmentNode(self, node):
# On assignments, two buffer-related things can happen:
# a) A buffer variable is assigned to (reacquisition)
# b) Buffer access assignment: arr[...] = ...
# Since we don't allow nested buffers, these don't overlap.
self.visitchildren(node)
# Only acquire buffers on vars (not attributes) for now.
if isinstance(node.lhs, NameNode) and node.lhs.entry.buffer_aux:
# Is buffer variable
return self.reacquire_buffer(node)
elif (isinstance(node.lhs, IndexNode) and
isinstance(node.lhs.base, NameNode) and
node.lhs.base.entry.buffer_aux is not None):
return self.assign_into_buffer(node)
else:
return node
def visit_IndexNode(self, node):
# Only occurs when the IndexNode is an rvalue
if node.is_buffer_access:
assert node.index is None
assert node.indices is not None
result = self.buffer_index(node)
result.analyse_expressions(self.scope)
return result
else:
return node
# #
# Utils for creating type string checkers # Utils for creating type string checkers
# #
...@@ -325,7 +144,7 @@ class BufferTransform(CythonTransform): ...@@ -325,7 +144,7 @@ class BufferTransform(CythonTransform):
def new_ts_func(self, name, code): def new_ts_func(self, name, code):
cname = "%s_%s" % (tsprefix, name) cname = "%s_%s" % (tsprefix, name)
funcnode = PureCFuncNode(self.module_pos, cname, tschecker_functype, code) funcnode = PureCFuncNode(self.module_pos, cname, tschecker_functype, code)
funcnode.analyse_types(self.module_scope) funcnode.analyse_expressions(self.module_scope)
self.ts_funcs.append(funcnode) self.ts_funcs.append(funcnode)
return funcnode return funcnode
...@@ -462,9 +281,181 @@ class BufferTransform(CythonTransform): ...@@ -462,9 +281,181 @@ class BufferTransform(CythonTransform):
self.tscheckers[dtype] = funcnode self.tscheckers[dtype] = funcnode
return funcnode.entry return funcnode.entry
class BufferTransform(CythonTransform):
"""
Run after type analysis. Takes care of the buffer functionality.
Expects to be run on the full module. If you need to process a fragment
one should look into refactoring this transform.
"""
# Abbreviations:
# "ts" means typestring and/or typestring checking stuff
scope = None
#
# Entry point
#
def __call__(self, node):
assert isinstance(node, ModuleNode)
try:
cymod = self.context.modules[u'__cython__']
except KeyError:
# No buffer fun for this module
return node
self.bufstruct_type = cymod.entries[u'Py_buffer'].type
self.tscheckers = {}
self.ts_funcs = []
self.ts_item_checkers = {}
self.module_scope = node.scope
self.module_pos = node.pos
result = super(BufferTransform, self).__call__(node)
# Register ts stuff
if "endian.h" not in node.scope.include_files:
node.scope.include_files.append("endian.h")
result.body.stats += self.ts_funcs
return result
acquire_buffer_fragment = TreeFragment(u"""
__cython__.PyObject_GetBuffer(<__cython__.PyObject*>SUBJECT, &BUFINFO, 0)
TSCHECKER(<char*>BUFINFO.format)
""")
fetch_strides = TreeFragment(u"""
TARGET = BUFINFO.strides[IDX]
""")
fetch_shape = TreeFragment(u"""
TARGET = BUFINFO.shape[IDX]
""")
def acquire_buffer_stats(self, entry, buffer_aux, pos):
# Just the stats for acquiring and unpacking the buffer auxiliaries
auxass = []
for idx, strideentry in enumerate(buffer_aux.stridevars):
strideentry.used = True
ass = self.fetch_strides.substitute({
u"TARGET": NameNode(pos, name=strideentry.name),
u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name),
u"IDX": IntNode(pos, value=EncodedString(idx)),
})
auxass += ass.stats
for idx, shapeentry in enumerate(buffer_aux.shapevars):
shapeentry.used = True
ass = self.fetch_shape.substitute({
u"TARGET": NameNode(pos, name=shapeentry.name),
u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name),
u"IDX": IntNode(pos, value=EncodedString(idx))
})
auxass += ass.stats
buffer_aux.buffer_info_var.used = True
acq = self.acquire_buffer_fragment.substitute({
u"SUBJECT" : NameNode(pos, name=entry.name),
u"BUFINFO": NameNode(pos, name=buffer_aux.buffer_info_var.name),
u"TSCHECKER": NameNode(pos, name=buffer_aux.tschecker.name)
}, pos=pos)
return acq.stats + auxass
def acquire_argument_buffer_stats(self, entry, pos):
# On function entry, not getting a buffer is an uncatchable
# exception, so we don't need to worry about what happens if
# we don't get a buffer.
stats = self.acquire_buffer_stats(entry, entry.buffer_aux, pos)
for s in stats:
s.analyse_declarations(self.scope)
#s.analyse_expressions(self.scope)
return stats
# Notes: The cast to <char*> gets around Cython not supporting const types
reacquire_buffer_fragment = TreeFragment(u"""
TMP = LHS
if TMP is not None:
__cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO)
TMP = RHS
if TMP is not None:
ACQUIRE
LHS = TMP
""")
def reacquire_buffer(self, node):
buffer_aux = node.lhs.entry.buffer_aux
acquire_stats = self.acquire_buffer_stats(buffer_aux.temp_var, buffer_aux, node.pos)
acq = self.reacquire_buffer_fragment.substitute({
u"TMP" : NameNode(pos=node.pos, name=buffer_aux.temp_var.name),
u"LHS" : node.lhs,
u"RHS": node.rhs,
u"ACQUIRE": StatListNode(node.pos, stats=acquire_stats),
u"BUFINFO": NameNode(pos=node.pos, name=buffer_aux.buffer_info_var.name)
}, pos=node.pos)
# Preserve first assignment info on LHS
if node.first:
# TODO: Prettier code
acq.stats[4].first = True
del acq.stats[0]
del acq.stats[0]
# Note: The below should probably be refactored into something
# like fragment.substitute(..., context=self.context), with
# TreeFragment getting context.pipeline_until_now() and
# applying it on the fragment.
acq.analyse_declarations(self.scope)
acq.analyse_expressions(self.scope)
stats = acq.stats
return stats
def assign_into_buffer(self, node):
result = SingleAssignmentNode(node.pos,
rhs=self.visit(node.rhs),
lhs=self.buffer_index(node.lhs))
result.analyse_expressions(self.scope)
return result
buffer_cleanup_fragment = TreeFragment(u"""
if BUF is not None:
__cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>BUF, &BUFINFO)
""")
def funcdef_buffer_cleanup(self, node, pos):
env = node.local_scope
cleanups = [self.buffer_cleanup_fragment.substitute({
u"BUF" : NameNode(pos, name=entry.name),
u"BUFINFO": NameNode(pos, name=entry.buffer_aux.buffer_info_var.name)
}, pos=pos)
for entry in node.local_scope.buffer_entries]
cleanup_stats = []
for c in cleanups: cleanup_stats += c.stats
cleanup = StatListNode(pos, stats=cleanup_stats)
cleanup.analyse_expressions(env)
result = TryFinallyStatNode.create_analysed(pos, env, body=node.body, finally_clause=cleanup)
node.body = StatListNode.create_analysed(pos, env, stats=[result])
return node
#
# Transforms
#
def visit_ModuleNode(self, node):
self.handle_scope(node, node.scope)
self.visitchildren(node)
return node
def visit_FuncDefNode(self, node):
self.handle_scope(node, node.local_scope)
self.visitchildren(node)
node = self.funcdef_buffer_cleanup(node, node.pos)
stats = []
for arg in node.local_scope.arg_entries:
if arg.type.is_buffer:
stats += self.acquire_argument_buffer_stats(arg, node.pos)
node.body.stats = stats + node.body.stats
return node
# TODO: # TODO:
# - buf must be NULL before getting new buffer # - buf must be NULL before getting new buffer
...@@ -893,6 +893,10 @@ class NameNode(AtomicExprNode): ...@@ -893,6 +893,10 @@ class NameNode(AtomicExprNode):
% self.name) % self.name)
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
self.entry.used = 1 self.entry.used = 1
if self.entry.type.is_buffer:
# Need some temps
print self.dump()
def analyse_rvalue_entry(self, env): def analyse_rvalue_entry(self, env):
#print "NameNode.analyse_rvalue_entry:", self.name ### #print "NameNode.analyse_rvalue_entry:", self.name ###
...@@ -1311,6 +1315,9 @@ class IndexNode(ExprNode): ...@@ -1311,6 +1315,9 @@ class IndexNode(ExprNode):
self.analyse_base_and_index_types(env, setting = 1) self.analyse_base_and_index_types(env, setting = 1)
def analyse_base_and_index_types(self, env, getting = 0, setting = 0): def analyse_base_and_index_types(self, env, getting = 0, setting = 0):
# Note: This might be cleaned up by having IndexNode
# parsed in a saner way and only construct the tuple if
# needed.
self.is_buffer_access = False self.is_buffer_access = False
self.base.analyse_types(env) self.base.analyse_types(env)
...@@ -1318,6 +1325,7 @@ class IndexNode(ExprNode): ...@@ -1318,6 +1325,7 @@ class IndexNode(ExprNode):
skip_child_analysis = False skip_child_analysis = False
buffer_access = False buffer_access = False
if self.base.type.is_buffer: if self.base.type.is_buffer:
assert isinstance(self.base, NameNode)
if isinstance(self.index, TupleNode): if isinstance(self.index, TupleNode):
indices = self.index.args indices = self.index.args
else: else:
...@@ -1329,21 +1337,19 @@ class IndexNode(ExprNode): ...@@ -1329,21 +1337,19 @@ class IndexNode(ExprNode):
x.analyse_types(env) x.analyse_types(env)
if not x.type.is_int: if not x.type.is_int:
buffer_access = False buffer_access = False
if buffer_access:
# self.indices = [
# x.coerce_to(PyrexTypes.c_py_ssize_t_type, env).coerce_to_simple(env)
# for x in indices]
self.indices = indices
self.index = None
self.type = self.base.type.dtype
self.is_temp = 1
self.is_buffer_access = True
# Note: This might be cleaned up by having IndexNode
# parsed in a saner way and only construct the tuple if
# needed.
if not buffer_access: if buffer_access:
self.indices = indices
self.index = None
self.type = self.base.type.dtype
self.is_buffer_access = True
self.index_temps = [Symtab.new_temp(i.type) for i in indices]
self.temps = self.index_temps
if getting:
# we only need a temp because result_code isn't refactored to
# generation time, but this seems an ok shortcut to take
self.is_temp = True
else:
if isinstance(self.index, TupleNode): if isinstance(self.index, TupleNode):
self.index.analyse_types(env, skip_children=skip_child_analysis) self.index.analyse_types(env, skip_children=skip_child_analysis)
elif not skip_child_analysis: elif not skip_child_analysis:
...@@ -1388,7 +1394,7 @@ class IndexNode(ExprNode): ...@@ -1388,7 +1394,7 @@ class IndexNode(ExprNode):
def calculate_result_code(self): def calculate_result_code(self):
if self.is_buffer_access: if self.is_buffer_access:
return "<not needed>" return "<not used>"
else: else:
return "(%s[%s])" % ( return "(%s[%s])" % (
self.base.result_code, self.index.result_code) self.base.result_code, self.index.result_code)
...@@ -1407,7 +1413,8 @@ class IndexNode(ExprNode): ...@@ -1407,7 +1413,8 @@ class IndexNode(ExprNode):
if self.index is not None: if self.index is not None:
self.index.generate_evaluation_code(code) self.index.generate_evaluation_code(code)
else: else:
for i in self.indices: i.generate_evaluation_code(code) for i in self.indices:
i.generate_evaluation_code(code)
def generate_subexpr_disposal_code(self, code): def generate_subexpr_disposal_code(self, code):
self.base.generate_disposal_code(code) self.base.generate_disposal_code(code)
...@@ -1417,7 +1424,10 @@ class IndexNode(ExprNode): ...@@ -1417,7 +1424,10 @@ class IndexNode(ExprNode):
for i in self.indices: i.generate_disposal_code(code) for i in self.indices: i.generate_disposal_code(code)
def generate_result_code(self, code): def generate_result_code(self, code):
if self.type.is_pyobject: if self.is_buffer_access:
valuecode = self.buffer_access_code(code)
code.putln("%s = %s;" % (self.result_code, valuecode))
elif self.type.is_pyobject:
if self.index.type.is_int: if self.index.type.is_int:
function = "__Pyx_GetItemInt" function = "__Pyx_GetItemInt"
index_code = self.index.result_code index_code = self.index.result_code
...@@ -1453,7 +1463,10 @@ class IndexNode(ExprNode): ...@@ -1453,7 +1463,10 @@ class IndexNode(ExprNode):
def generate_assignment_code(self, rhs, code): def generate_assignment_code(self, rhs, code):
self.generate_subexpr_evaluation_code(code) self.generate_subexpr_evaluation_code(code)
if self.type.is_pyobject: if self.is_buffer_access:
valuecode = self.buffer_access_code(code)
code.putln("%s = %s;" % (valuecode, rhs.result_code))
elif self.type.is_pyobject:
self.generate_setitem_code(rhs.py_result(), code) self.generate_setitem_code(rhs.py_result(), code)
else: else:
code.putln( code.putln(
...@@ -1479,6 +1492,23 @@ class IndexNode(ExprNode): ...@@ -1479,6 +1492,23 @@ class IndexNode(ExprNode):
code.error_goto(self.pos))) code.error_goto(self.pos)))
self.generate_subexpr_disposal_code(code) self.generate_subexpr_disposal_code(code)
def buffer_access_code(self, code):
# 1. Assign indices to temps
for temp, index in zip(self.index_temps, self.indices):
code.putln("%s = %s;" % (temp.cname, index.result_code))
# 2. Output code to do bounds checking on these
# 3. Return a code fragment string which does buffer
# lookup, which can be used on lhs or rhs of an assignment
# in the caller depending on the scenario.
bufaux = self.base.entry.buffer_aux
offset = " + ".join(["%s * %s" % (idx.cname, stride.cname)
for idx, stride in
zip(self.index_temps, bufaux.stridevars)])
ptrcode = "(%s.buf + %s)" % (bufaux.buffer_info_var.cname, offset)
valuecode = "*%s" % self.base.type.buffer_ptr_type.cast_code(ptrcode)
return valuecode
class SliceIndexNode(ExprNode): class SliceIndexNode(ExprNode):
# 2-element slice indexing # 2-element slice indexing
......
...@@ -361,23 +361,25 @@ def create_default_pipeline(context, options, result): ...@@ -361,23 +361,25 @@ def create_default_pipeline(context, options, result):
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from Optimize import FlattenInListTransform, SwitchTransform, OptimizeRefcounting from Optimize import FlattenInListTransform, SwitchTransform, OptimizeRefcounting
from CodeGeneration import AnchorTemps from CodeGeneration import AnchorTemps
from Buffer import BufferTransform from Buffer import BufferTransform, IntroduceBufferAuxiliaryVars
from ModuleNode import check_c_classes from ModuleNode import check_c_classes
def printit(x): print x.dump()
return [ return [
create_parse(context), create_parse(context),
# printit,
NormalizeTree(context), NormalizeTree(context),
PostParse(context), PostParse(context),
FlattenInListTransform(), FlattenInListTransform(),
WithTransform(context), WithTransform(context),
DecoratorTransform(context), DecoratorTransform(context),
AnalyseDeclarationsTransform(context), AnalyseDeclarationsTransform(context),
IntroduceBufferAuxiliaryVars(context),
check_c_classes, check_c_classes,
AnalyseExpressionsTransform(context), AnalyseExpressionsTransform(context),
BufferTransform(context), # BufferTransform(context),
SwitchTransform(), SwitchTransform(),
OptimizeRefcounting(context), OptimizeRefcounting(context),
# AnchorTemps(context), AnchorTemps(context),
# CreateClosureClasses(context), # CreateClosureClasses(context),
create_generate_code(context, options, result) create_generate_code(context, options, result)
] ]
......
...@@ -203,6 +203,7 @@ class BufferType(BaseType): ...@@ -203,6 +203,7 @@ class BufferType(BaseType):
self.base = base self.base = base
self.dtype = dtype self.dtype = dtype
self.ndim = ndim self.ndim = ndim
self.buffer_ptr_type = CPtrType(dtype)
def as_argument_type(self): def as_argument_type(self):
return self return self
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment