Commit 3ba21c1d authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Works with some assignment expressions

parent 72d54fb4
...@@ -1339,6 +1339,9 @@ class IndexNode(ExprNode): ...@@ -1339,6 +1339,9 @@ class IndexNode(ExprNode):
return 1 return 1
def calculate_result_code(self): def calculate_result_code(self):
if self.is_buffer_access:
return "<not needed>"
else:
return "(%s[%s])" % ( return "(%s[%s])" % (
self.base.result_code, self.index.result_code) self.base.result_code, self.index.result_code)
...@@ -3842,6 +3845,10 @@ class CoerceToPyTypeNode(CoercionNode): ...@@ -3842,6 +3845,10 @@ class CoerceToPyTypeNode(CoercionNode):
gil_message = "Converting to Python object" gil_message = "Converting to Python object"
def analyse_types(self, env):
# The arg is always already analysed
pass
def generate_result_code(self, code): def generate_result_code(self, code):
function = self.arg.type.to_py_function function = self.arg.type.to_py_function
code.putln('%s = %s(%s); %s' % ( code.putln('%s = %s(%s); %s' % (
...@@ -3866,6 +3873,10 @@ class CoerceFromPyTypeNode(CoercionNode): ...@@ -3866,6 +3873,10 @@ class CoerceFromPyTypeNode(CoercionNode):
error(arg.pos, error(arg.pos,
"Obtaining char * from temporary Python value") "Obtaining char * from temporary Python value")
def analyse_types(self, env):
# The arg is always already analysed
pass
def generate_result_code(self, code): def generate_result_code(self, code):
function = self.type.from_py_function function = self.type.from_py_function
operand = self.arg.py_result() operand = self.arg.py_result()
......
...@@ -155,6 +155,7 @@ class BufferTransform(CythonTransform): ...@@ -155,6 +155,7 @@ class BufferTransform(CythonTransform):
def __call__(self, node): def __call__(self, node):
cymod = self.context.modules[u'__cython__'] cymod = self.context.modules[u'__cython__']
self.buffer_type = cymod.entries[u'Py_buffer'].type self.buffer_type = cymod.entries[u'Py_buffer'].type
self.lhs = False
return super(BufferTransform, self).__call__(node) return super(BufferTransform, self).__call__(node)
def handle_scope(self, node, scope): def handle_scope(self, node, scope):
...@@ -229,10 +230,29 @@ class BufferTransform(CythonTransform): ...@@ -229,10 +230,29 @@ class BufferTransform(CythonTransform):
# attribute=EncodedString("strides")), # attribute=EncodedString("strides")),
# index=IntNode(node.pos, value=EncodedString(idx)))) # index=IntNode(node.pos, value=EncodedString(idx))))
# print ass.dump() # print ass.dump()
def visit_SingleAssignmentNode(self, node): def visit_SingleAssignmentNode(self, node):
self.visitchildren(node) # On assignments, two buffer-related things can happen:
# a) A buffer variable is assigned to (reacquisition)
# b) Buffer access assignment: arr[...] = ...
# Since we don't allow nested buffers, these don't overlap.
# self.lhs = True
self.visit(node.rhs)
self.visit(node.lhs)
# self.lhs = False
# self.visitchildren(node)
# Only acquire buffers on vars (not attributes) for now.
if isinstance(node.lhs, NameNode) and node.lhs.entry.buffer_aux:
# Is buffer variable
return self.reacquire_buffer(node)
elif (isinstance(node.lhs, IndexNode) and
isinstance(node.lhs.base, NameNode) and
node.lhs.base.entry.buffer_aux is not None):
return self.assign_into_buffer(node)
def reacquire_buffer(self, node):
bufaux = node.lhs.entry.buffer_aux bufaux = node.lhs.entry.buffer_aux
if bufaux is not None:
auxass = [] auxass = []
for idx, entry in enumerate(bufaux.stridevars): for idx, entry in enumerate(bufaux.stridevars):
entry.used = True entry.used = True
...@@ -267,37 +287,48 @@ class BufferTransform(CythonTransform): ...@@ -267,37 +287,48 @@ class BufferTransform(CythonTransform):
acq.analyse_declarations(self.scope) acq.analyse_declarations(self.scope)
acq.analyse_expressions(self.scope) acq.analyse_expressions(self.scope)
stats = acq.stats stats = acq.stats
# stats += [node] # Do assignment after successful buffer acquisition
# print acq.dump()
return stats return stats
else:
return node
buffer_access = TreeFragment(u""" def assign_into_buffer(self, node):
(<unsigned char*>(BUF.buf + OFFSET))[0] result = SingleAssignmentNode(node.pos,
""") rhs=self.visit(node.rhs),
def visit_IndexNode(self, node): lhs=self.buffer_index(node.lhs))
if node.is_buffer_access: result.analyse_expressions(self.scope)
assert node.index is None return result
assert node.indices is not None
def buffer_index(self, node):
bufaux = node.base.entry.buffer_aux bufaux = node.base.entry.buffer_aux
assert bufaux is not None assert bufaux is not None
to_sum = [ IntBinopNode(node.pos, operator='*', operand1=index, # indices * strides...
to_sum = [ IntBinopNode(node.pos, operator='*',
operand1=index, #PhaseEnvelopeNode(PhaseEnvelopeNode.ANALYSED, index),
operand2=NameNode(node.pos, name=stride.name)) operand2=NameNode(node.pos, name=stride.name))
for index, stride in zip(node.indices, bufaux.stridevars)] for index, stride in zip(node.indices, bufaux.stridevars)]
print to_sum
indices = node.indices # then sum them
# reduce * on indices
expr = to_sum[0] expr = to_sum[0]
for next in to_sum[1:]: for next in to_sum[1:]:
expr = IntBinopNode(node.pos, operator='+', operand1=expr, operand2=next) expr = IntBinopNode(node.pos, operator='+', operand1=expr, operand2=next)
tmp= self.buffer_access.substitute({ tmp= self.buffer_access.substitute({
'BUF': NameNode(node.pos, name=bufaux.buffer_info_var.name), 'BUF': NameNode(node.pos, name=bufaux.buffer_info_var.name),
'OFFSET': expr 'OFFSET': expr
}) }, pos=node.pos)
tmp.analyse_expressions(self.scope)
return tmp.stats[0].expr return tmp.stats[0].expr
buffer_access = TreeFragment(u"""
(<unsigned char*>(BUF.buf + OFFSET))[0]
""")
def visit_IndexNode(self, node):
# Only occurs when the IndexNode is an rvalue
if node.is_buffer_access:
assert node.index is None
assert node.indices is not None
result = self.buffer_index(node)
result.analyse_expressions(self.scope)
return result
else: else:
return node return node
...@@ -309,6 +340,56 @@ class BufferTransform(CythonTransform): ...@@ -309,6 +340,56 @@ class BufferTransform(CythonTransform):
# print node.dump() # print node.dump()
class PhaseEnvelopeNode(Node):
"""
This node is used if you need to protect a node from reevaluation
of a phase. For instance, if you extract...
Use with care!
"""
# Phases
PARSED, ANALYSED = range(2)
def __init__(self, phase, wrapped):
self.phase = phase
self.wrapped = wrapped
def get_pos(self): return self.wrapped.pos
def set_pos(self, value): self.wrapped.pos = value
pos = property(get_pos, set_pos)
def get_subexprs(self): return self.wrapped.subexprs
subexprs = property(get_subexprs)
def analyse_types(self, env):
if self.phase < self.ANALYSED:
self.wrapped.analyse_types(env)
def __getattribute__(self, attrname):
wrapped = object.__getattribute__(self, "wrapped")
phase = object.__getattribute__(self, "phase")
if attrname == "wrapped": return wrapped
if attrname == "phase": return phase
attr = getattr(wrapped, attrname)
overridden = ("analyse_types",)
print attrname, attr
if not isinstance(attr, Node):
return attr
else:
if attr is None: return None
else:
return PhaseEnvelopeNode(phase, attr)
class WithTransform(CythonTransform): class WithTransform(CythonTransform):
# EXCINFO is manually set to a variable that contains # EXCINFO is manually set to a variable that contains
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment