Commit 3ba21c1d authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Works with some assignment expressions

parent 72d54fb4
...@@ -1339,8 +1339,11 @@ class IndexNode(ExprNode): ...@@ -1339,8 +1339,11 @@ class IndexNode(ExprNode):
return 1 return 1
def calculate_result_code(self): def calculate_result_code(self):
return "(%s[%s])" % ( if self.is_buffer_access:
self.base.result_code, self.index.result_code) return "<not needed>"
else:
return "(%s[%s])" % (
self.base.result_code, self.index.result_code)
def index_unsigned_parameter(self): def index_unsigned_parameter(self):
if self.index.type.is_int: if self.index.type.is_int:
...@@ -3842,6 +3845,10 @@ class CoerceToPyTypeNode(CoercionNode): ...@@ -3842,6 +3845,10 @@ class CoerceToPyTypeNode(CoercionNode):
gil_message = "Converting to Python object" gil_message = "Converting to Python object"
def analyse_types(self, env):
# The arg is always already analysed
pass
def generate_result_code(self, code): def generate_result_code(self, code):
function = self.arg.type.to_py_function function = self.arg.type.to_py_function
code.putln('%s = %s(%s); %s' % ( code.putln('%s = %s(%s); %s' % (
...@@ -3866,6 +3873,10 @@ class CoerceFromPyTypeNode(CoercionNode): ...@@ -3866,6 +3873,10 @@ class CoerceFromPyTypeNode(CoercionNode):
error(arg.pos, error(arg.pos,
"Obtaining char * from temporary Python value") "Obtaining char * from temporary Python value")
def analyse_types(self, env):
# The arg is always already analysed
pass
def generate_result_code(self, code): def generate_result_code(self, code):
function = self.type.from_py_function function = self.type.from_py_function
operand = self.arg.py_result() operand = self.arg.py_result()
......
...@@ -155,6 +155,7 @@ class BufferTransform(CythonTransform): ...@@ -155,6 +155,7 @@ class BufferTransform(CythonTransform):
def __call__(self, node): def __call__(self, node):
cymod = self.context.modules[u'__cython__'] cymod = self.context.modules[u'__cython__']
self.buffer_type = cymod.entries[u'Py_buffer'].type self.buffer_type = cymod.entries[u'Py_buffer'].type
self.lhs = False
return super(BufferTransform, self).__call__(node) return super(BufferTransform, self).__call__(node)
def handle_scope(self, node, scope): def handle_scope(self, node, scope):
...@@ -229,75 +230,105 @@ class BufferTransform(CythonTransform): ...@@ -229,75 +230,105 @@ class BufferTransform(CythonTransform):
# attribute=EncodedString("strides")), # attribute=EncodedString("strides")),
# index=IntNode(node.pos, value=EncodedString(idx)))) # index=IntNode(node.pos, value=EncodedString(idx))))
# print ass.dump() # print ass.dump()
def visit_SingleAssignmentNode(self, node): def visit_SingleAssignmentNode(self, node):
self.visitchildren(node) # On assignments, two buffer-related things can happen:
# a) A buffer variable is assigned to (reacquisition)
# b) Buffer access assignment: arr[...] = ...
# Since we don't allow nested buffers, these don't overlap.
# self.lhs = True
self.visit(node.rhs)
self.visit(node.lhs)
# self.lhs = False
# self.visitchildren(node)
# Only acquire buffers on vars (not attributes) for now.
if isinstance(node.lhs, NameNode) and node.lhs.entry.buffer_aux:
# Is buffer variable
return self.reacquire_buffer(node)
elif (isinstance(node.lhs, IndexNode) and
isinstance(node.lhs.base, NameNode) and
node.lhs.base.entry.buffer_aux is not None):
return self.assign_into_buffer(node)
def reacquire_buffer(self, node):
bufaux = node.lhs.entry.buffer_aux bufaux = node.lhs.entry.buffer_aux
if bufaux is not None: auxass = []
auxass = [] for idx, entry in enumerate(bufaux.stridevars):
for idx, entry in enumerate(bufaux.stridevars): entry.used = True
entry.used = True ass = self.fetch_strides.substitute({
ass = self.fetch_strides.substitute({ u"TARGET": NameNode(node.pos, name=entry.name),
u"TARGET": NameNode(node.pos, name=entry.name), u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name), u"IDX": IntNode(node.pos, value=EncodedString(idx))
u"IDX": IntNode(node.pos, value=EncodedString(idx)) })
}) auxass.append(ass)
auxass.append(ass)
for idx, entry in enumerate(bufaux.shapevars):
for idx, entry in enumerate(bufaux.shapevars): entry.used = True
entry.used = True ass = self.fetch_shape.substitute({
ass = self.fetch_shape.substitute({ u"TARGET": NameNode(node.pos, name=entry.name),
u"TARGET": NameNode(node.pos, name=entry.name), u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name), u"IDX": IntNode(node.pos, value=EncodedString(idx))
u"IDX": IntNode(node.pos, value=EncodedString(idx)) })
}) auxass.append(ass)
auxass.append(ass)
bufaux.buffer_info_var.used = True
bufaux.buffer_info_var.used = True acq = self.acquire_buffer_fragment.substitute({
acq = self.acquire_buffer_fragment.substitute({ u"TMP" : NameNode(pos=node.pos, name=bufaux.temp_var.name),
u"TMP" : NameNode(pos=node.pos, name=bufaux.temp_var.name), u"LHS" : node.lhs,
u"LHS" : node.lhs, u"RHS": node.rhs,
u"RHS": node.rhs, u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass),
u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass), u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name)
u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name) }, pos=node.pos)
# Note: The below should probably be refactored into something
# like fragment.substitute(..., context=self.context), with
# TreeFragment getting context.pipeline_until_now() and
# applying it on the fragment.
acq.analyse_declarations(self.scope)
acq.analyse_expressions(self.scope)
stats = acq.stats
return stats
def assign_into_buffer(self, node):
result = SingleAssignmentNode(node.pos,
rhs=self.visit(node.rhs),
lhs=self.buffer_index(node.lhs))
result.analyse_expressions(self.scope)
return result
def buffer_index(self, node):
bufaux = node.base.entry.buffer_aux
assert bufaux is not None
# indices * strides...
to_sum = [ IntBinopNode(node.pos, operator='*',
operand1=index, #PhaseEnvelopeNode(PhaseEnvelopeNode.ANALYSED, index),
operand2=NameNode(node.pos, name=stride.name))
for index, stride in zip(node.indices, bufaux.stridevars)]
# then sum them
expr = to_sum[0]
for next in to_sum[1:]:
expr = IntBinopNode(node.pos, operator='+', operand1=expr, operand2=next)
tmp= self.buffer_access.substitute({
'BUF': NameNode(node.pos, name=bufaux.buffer_info_var.name),
'OFFSET': expr
}, pos=node.pos) }, pos=node.pos)
# Note: The below should probably be refactored into something
# like fragment.substitute(..., context=self.context), with return tmp.stats[0].expr
# TreeFragment getting context.pipeline_until_now() and
# applying it on the fragment.
acq.analyse_declarations(self.scope)
acq.analyse_expressions(self.scope)
stats = acq.stats
# stats += [node] # Do assignment after successful buffer acquisition
# print acq.dump()
return stats
else:
return node
buffer_access = TreeFragment(u""" buffer_access = TreeFragment(u"""
(<unsigned char*>(BUF.buf + OFFSET))[0] (<unsigned char*>(BUF.buf + OFFSET))[0]
""") """)
def visit_IndexNode(self, node): def visit_IndexNode(self, node):
# Only occurs when the IndexNode is an rvalue
if node.is_buffer_access: if node.is_buffer_access:
assert node.index is None assert node.index is None
assert node.indices is not None assert node.indices is not None
bufaux = node.base.entry.buffer_aux result = self.buffer_index(node)
assert bufaux is not None result.analyse_expressions(self.scope)
to_sum = [ IntBinopNode(node.pos, operator='*', operand1=index, return result
operand2=NameNode(node.pos, name=stride.name))
for index, stride in zip(node.indices, bufaux.stridevars)]
print to_sum
indices = node.indices
# reduce * on indices
expr = to_sum[0]
for next in to_sum[1:]:
expr = IntBinopNode(node.pos, operator='+', operand1=expr, operand2=next)
tmp= self.buffer_access.substitute({
'BUF': NameNode(node.pos, name=bufaux.buffer_info_var.name),
'OFFSET': expr
})
tmp.analyse_expressions(self.scope)
return tmp.stats[0].expr
else: else:
return node return node
...@@ -309,6 +340,56 @@ class BufferTransform(CythonTransform): ...@@ -309,6 +340,56 @@ class BufferTransform(CythonTransform):
# print node.dump() # print node.dump()
class PhaseEnvelopeNode(Node):
"""
This node is used if you need to protect a node from reevaluation
of a phase. For instance, if you extract...
Use with care!
"""
# Phases
PARSED, ANALYSED = range(2)
def __init__(self, phase, wrapped):
self.phase = phase
self.wrapped = wrapped
def get_pos(self): return self.wrapped.pos
def set_pos(self, value): self.wrapped.pos = value
pos = property(get_pos, set_pos)
def get_subexprs(self): return self.wrapped.subexprs
subexprs = property(get_subexprs)
def analyse_types(self, env):
if self.phase < self.ANALYSED:
self.wrapped.analyse_types(env)
def __getattribute__(self, attrname):
wrapped = object.__getattribute__(self, "wrapped")
phase = object.__getattribute__(self, "phase")
if attrname == "wrapped": return wrapped
if attrname == "phase": return phase
attr = getattr(wrapped, attrname)
overridden = ("analyse_types",)
print attrname, attr
if not isinstance(attr, Node):
return attr
else:
if attr is None: return None
else:
return PhaseEnvelopeNode(phase, attr)
class WithTransform(CythonTransform): class WithTransform(CythonTransform):
# EXCINFO is manually set to a variable that contains # EXCINFO is manually set to a variable that contains
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment