Commit 3ba21c1d authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Works with some assignment expressions

parent 72d54fb4
......@@ -1339,8 +1339,11 @@ class IndexNode(ExprNode):
return 1
def calculate_result_code(self):
return "(%s[%s])" % (
self.base.result_code, self.index.result_code)
if self.is_buffer_access:
return "<not needed>"
else:
return "(%s[%s])" % (
self.base.result_code, self.index.result_code)
def index_unsigned_parameter(self):
if self.index.type.is_int:
......@@ -3842,6 +3845,10 @@ class CoerceToPyTypeNode(CoercionNode):
gil_message = "Converting to Python object"
def analyse_types(self, env):
# The arg is always already analysed
pass
def generate_result_code(self, code):
function = self.arg.type.to_py_function
code.putln('%s = %s(%s); %s' % (
......@@ -3866,6 +3873,10 @@ class CoerceFromPyTypeNode(CoercionNode):
error(arg.pos,
"Obtaining char * from temporary Python value")
def analyse_types(self, env):
# The arg is always already analysed
pass
def generate_result_code(self, code):
function = self.type.from_py_function
operand = self.arg.py_result()
......
......@@ -155,6 +155,7 @@ class BufferTransform(CythonTransform):
def __call__(self, node):
cymod = self.context.modules[u'__cython__']
self.buffer_type = cymod.entries[u'Py_buffer'].type
self.lhs = False
return super(BufferTransform, self).__call__(node)
def handle_scope(self, node, scope):
......@@ -229,75 +230,105 @@ class BufferTransform(CythonTransform):
# attribute=EncodedString("strides")),
# index=IntNode(node.pos, value=EncodedString(idx))))
# print ass.dump()
def visit_SingleAssignmentNode(self, node):
self.visitchildren(node)
# On assignments, two buffer-related things can happen:
# a) A buffer variable is assigned to (reacquisition)
# b) Buffer access assignment: arr[...] = ...
# Since we don't allow nested buffers, these don't overlap.
# self.lhs = True
self.visit(node.rhs)
self.visit(node.lhs)
# self.lhs = False
# self.visitchildren(node)
# Only acquire buffers on vars (not attributes) for now.
if isinstance(node.lhs, NameNode) and node.lhs.entry.buffer_aux:
# Is buffer variable
return self.reacquire_buffer(node)
elif (isinstance(node.lhs, IndexNode) and
isinstance(node.lhs.base, NameNode) and
node.lhs.base.entry.buffer_aux is not None):
return self.assign_into_buffer(node)
def reacquire_buffer(self, node):
bufaux = node.lhs.entry.buffer_aux
if bufaux is not None:
auxass = []
for idx, entry in enumerate(bufaux.stridevars):
entry.used = True
ass = self.fetch_strides.substitute({
u"TARGET": NameNode(node.pos, name=entry.name),
u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
u"IDX": IntNode(node.pos, value=EncodedString(idx))
})
auxass.append(ass)
for idx, entry in enumerate(bufaux.shapevars):
entry.used = True
ass = self.fetch_shape.substitute({
u"TARGET": NameNode(node.pos, name=entry.name),
u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
u"IDX": IntNode(node.pos, value=EncodedString(idx))
})
auxass.append(ass)
bufaux.buffer_info_var.used = True
acq = self.acquire_buffer_fragment.substitute({
u"TMP" : NameNode(pos=node.pos, name=bufaux.temp_var.name),
u"LHS" : node.lhs,
u"RHS": node.rhs,
u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass),
u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name)
auxass = []
for idx, entry in enumerate(bufaux.stridevars):
entry.used = True
ass = self.fetch_strides.substitute({
u"TARGET": NameNode(node.pos, name=entry.name),
u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
u"IDX": IntNode(node.pos, value=EncodedString(idx))
})
auxass.append(ass)
for idx, entry in enumerate(bufaux.shapevars):
entry.used = True
ass = self.fetch_shape.substitute({
u"TARGET": NameNode(node.pos, name=entry.name),
u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
u"IDX": IntNode(node.pos, value=EncodedString(idx))
})
auxass.append(ass)
bufaux.buffer_info_var.used = True
acq = self.acquire_buffer_fragment.substitute({
u"TMP" : NameNode(pos=node.pos, name=bufaux.temp_var.name),
u"LHS" : node.lhs,
u"RHS": node.rhs,
u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass),
u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name)
}, pos=node.pos)
# Note: The below should probably be refactored into something
# like fragment.substitute(..., context=self.context), with
# TreeFragment getting context.pipeline_until_now() and
# applying it on the fragment.
acq.analyse_declarations(self.scope)
acq.analyse_expressions(self.scope)
stats = acq.stats
return stats
def assign_into_buffer(self, node):
result = SingleAssignmentNode(node.pos,
rhs=self.visit(node.rhs),
lhs=self.buffer_index(node.lhs))
result.analyse_expressions(self.scope)
return result
def buffer_index(self, node):
bufaux = node.base.entry.buffer_aux
assert bufaux is not None
# indices * strides...
to_sum = [ IntBinopNode(node.pos, operator='*',
operand1=index, #PhaseEnvelopeNode(PhaseEnvelopeNode.ANALYSED, index),
operand2=NameNode(node.pos, name=stride.name))
for index, stride in zip(node.indices, bufaux.stridevars)]
# then sum them
expr = to_sum[0]
for next in to_sum[1:]:
expr = IntBinopNode(node.pos, operator='+', operand1=expr, operand2=next)
tmp= self.buffer_access.substitute({
'BUF': NameNode(node.pos, name=bufaux.buffer_info_var.name),
'OFFSET': expr
}, pos=node.pos)
# Note: The below should probably be refactored into something
# like fragment.substitute(..., context=self.context), with
# TreeFragment getting context.pipeline_until_now() and
# applying it on the fragment.
acq.analyse_declarations(self.scope)
acq.analyse_expressions(self.scope)
stats = acq.stats
# stats += [node] # Do assignment after successful buffer acquisition
# print acq.dump()
return stats
else:
return node
return tmp.stats[0].expr
buffer_access = TreeFragment(u"""
(<unsigned char*>(BUF.buf + OFFSET))[0]
""")
def visit_IndexNode(self, node):
# Only occurs when the IndexNode is an rvalue
if node.is_buffer_access:
assert node.index is None
assert node.indices is not None
bufaux = node.base.entry.buffer_aux
assert bufaux is not None
to_sum = [ IntBinopNode(node.pos, operator='*', operand1=index,
operand2=NameNode(node.pos, name=stride.name))
for index, stride in zip(node.indices, bufaux.stridevars)]
print to_sum
indices = node.indices
# reduce * on indices
expr = to_sum[0]
for next in to_sum[1:]:
expr = IntBinopNode(node.pos, operator='+', operand1=expr, operand2=next)
tmp= self.buffer_access.substitute({
'BUF': NameNode(node.pos, name=bufaux.buffer_info_var.name),
'OFFSET': expr
})
tmp.analyse_expressions(self.scope)
return tmp.stats[0].expr
result = self.buffer_index(node)
result.analyse_expressions(self.scope)
return result
else:
return node
......@@ -309,6 +340,56 @@ class BufferTransform(CythonTransform):
# print node.dump()
class PhaseEnvelopeNode(Node):
"""
This node is used if you need to protect a node from reevaluation
of a phase. For instance, if you extract...
Use with care!
"""
# Phases
PARSED, ANALYSED = range(2)
def __init__(self, phase, wrapped):
self.phase = phase
self.wrapped = wrapped
def get_pos(self): return self.wrapped.pos
def set_pos(self, value): self.wrapped.pos = value
pos = property(get_pos, set_pos)
def get_subexprs(self): return self.wrapped.subexprs
subexprs = property(get_subexprs)
def analyse_types(self, env):
if self.phase < self.ANALYSED:
self.wrapped.analyse_types(env)
def __getattribute__(self, attrname):
wrapped = object.__getattribute__(self, "wrapped")
phase = object.__getattribute__(self, "phase")
if attrname == "wrapped": return wrapped
if attrname == "phase": return phase
attr = getattr(wrapped, attrname)
overridden = ("analyse_types",)
print attrname, attr
if not isinstance(attr, Node):
return attr
else:
if attr is None: return None
else:
return PhaseEnvelopeNode(phase, attr)
class WithTransform(CythonTransform):
# EXCINFO is manually set to a variable that contains
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment