Commit 5668a4fb authored by Robert Bradshaw's avatar Robert Bradshaw

Handle inplace arithmatic via parse tree transform.

Excludes buffers and C++, which have their own code.
This is in preparation for #591 (inline vs. cdivision) and
support for inline complex arithamtic.
parent 0f5d7221
......@@ -2746,6 +2746,7 @@ class SimpleCallNode(CallNode):
wrapper_call = False
has_optional_args = False
nogil = False
analysed = False
def compile_time_value(self, denv):
function = self.function.compile_time_value(denv)
......@@ -2799,6 +2800,9 @@ class SimpleCallNode(CallNode):
def analyse_types(self, env):
if self.analyse_as_type_constructor(env):
return
if self.analysed:
return
self.analysed = True
function = self.function
function.is_called = 1
self.function.analyse_types(env)
......
......@@ -98,6 +98,7 @@ class Context(object):
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
from ParseTreeTransforms import ExpandInplaceOperators
from TypeInference import MarkAssignments, MarkOverflowingArithmetic
from ParseTreeTransforms import AlignFunctionDefinitions, GilCheck
from AnalysedTreeTransforms import AutoTestDictTransform
......@@ -143,6 +144,7 @@ class Context(object):
IntroduceBufferAuxiliaryVars(self),
_check_c_declarations,
AnalyseExpressionsTransform(self),
ExpandInplaceOperators(self),
OptimizeBuiltinCalls(self), ## Necessary?
IterationTransform(),
SwitchTransform(),
......
......@@ -3520,15 +3520,15 @@ class InPlaceAssignmentNode(AssignmentNode):
# (it must be a NameNode, AttributeNode, or IndexNode).
child_attrs = ["lhs", "rhs"]
dup = None
def analyse_declarations(self, env):
self.lhs.analyse_target_declaration(env)
def analyse_types(self, env):
self.dup = self.create_dup_node(env) # re-assigns lhs to a shallow copy
self.rhs.analyse_types(env)
self.lhs.analyse_target_types(env)
return
import ExprNodes
if self.lhs.type.is_pyobject:
self.rhs = self.rhs.coerce_to_pyobject(env)
......@@ -3539,6 +3539,28 @@ class InPlaceAssignmentNode(AssignmentNode):
self.result_value = self.result_value_temp.coerce_to(self.lhs.type, env)
def generate_execution_code(self, code):
import ExprNodes
self.rhs.generate_evaluation_code(code)
self.lhs.generate_subexpr_evaluation_code(code)
c_op = self.operator
if c_op == "//":
c_op = "/"
elif c_op == "**":
error(self.pos, "No C inplace power operator")
if isinstance(self.lhs, ExprNodes.IndexNode) and self.lhs.is_buffer_access:
if self.lhs.type.is_pyobject:
error(self.pos, "In-place operators not allowed on object buffers in this release.")
self.lhs.generate_buffer_setitem_code(self.rhs, code, c_op)
else:
# C++
# TODO: make sure overload is declared
code.putln("%s %s= %s;" % (self.lhs.result(), c_op, self.rhs.result()))
self.lhs.generate_subexpr_disposal_code(code)
self.lhs.free_subexpr_temps(code)
self.rhs.generate_disposal_code(code)
self.rhs.free_temps(code)
return
import ExprNodes
self.rhs.generate_evaluation_code(code)
self.dup.generate_subexpr_evaluation_code(code)
......@@ -3581,10 +3603,15 @@ class InPlaceAssignmentNode(AssignmentNode):
# have to do assignment directly to avoid side-effects
if isinstance(self.lhs, ExprNodes.IndexNode) and self.lhs.is_buffer_access:
if self.lhs.type.is_int and c_op == "/" and not code.globalstate.directives['cdivision']:
error(self.pos, "Inplace non-c division not implemented for buffer types. (Use cdivision=False for now.)")
self.lhs.generate_buffer_setitem_code(self.rhs, code, c_op)
else:
self.dup.generate_result_code(code)
code.putln("%s %s= %s;" % (self.lhs.result(), c_op, self.rhs.result()) )
if self.lhs.type.is_int and c_op == "/" and not code.globalstate.directives['cdivision']:
error(self.pos, "bad")
else:
code.putln("%s %s= %s;" % (self.lhs.result(), c_op, self.rhs.result()) )
self.rhs.generate_disposal_code(code)
self.rhs.free_temps(code)
if self.dup.is_temp:
......@@ -3645,7 +3672,6 @@ class InPlaceAssignmentNode(AssignmentNode):
def annotate(self, code):
self.lhs.annotate(code)
self.rhs.annotate(code)
self.dup.annotate(code)
def create_binop_node(self):
import ExprNodes
......
......@@ -1194,7 +1194,74 @@ class AnalyseExpressionsTransform(CythonTransform):
node.analyse_scoped_expressions(node.expr_scope)
self.visitchildren(node)
return node
class ExpandInplaceOperators(CythonTransform):
def __call__(self, root):
self.env_stack = [root.scope]
return super(ExpandInplaceOperators, self).__call__(root)
def visit_FuncDefNode(self, node):
self.env_stack.append(node.local_scope)
self.visitchildren(node)
self.env_stack.pop()
return node
def visit_InPlaceAssignmentNode(self, node):
lhs = node.lhs
rhs = node.rhs
if lhs.type.is_cpp_class:
# No getting around this exact operator here.
return node
if isinstance(lhs, IndexNode) and lhs.is_buffer_access:
# There is code to handle this case.
return node
def side_effect_free_reference(node, setting=False):
if node.type.is_pyobject and not setting:
node = LetRefNode(node)
return node, [node]
elif isinstance(node, IndexNode):
if node.is_buffer_access:
raise ValueError, "Buffer access"
base, temps = side_effect_free_reference(node.base)
index = LetRefNode(node.index)
return IndexNode(node.pos, base=base, index=index), temps + [index]
elif isinstance(node, AttributeNode):
obj, temps = side_effect_free_reference(node.obj)
return AttributeNode(node.pos, obj=obj, attribute=node.attribute), temps
elif isinstance(node, NameNode):
return node, []
else:
node = LetRefNode(node)
return node, [node]
try:
lhs, let_ref_nodes = side_effect_free_reference(lhs, setting=True)
except ValueError:
return node
dup = lhs.__class__(**lhs.__dict__)
binop = binop_node(node.pos,
operator = node.operator,
operand1 = dup,
operand2 = rhs)
node = SingleAssignmentNode(node.pos, lhs=lhs, rhs=binop) #, inplace=True)
# Use LetRefNode to avoid side effects.
let_ref_nodes.reverse()
for t in let_ref_nodes:
node = LetNode(t, node)
node.analyse_expressions(self.env_stack[-1])
return node
def visit_ExprNode(self, node):
# In-place assignments can't happen within an expression.
return node
def visit_Node(self, node):
self.visitchildren(node)
return node
class AlignFunctionDefinitions(CythonTransform):
"""
This class takes the signatures from a .pxd file and applies them to
......
......@@ -8,6 +8,7 @@ import Nodes
import ExprNodes
from Nodes import Node
from ExprNodes import AtomicExprNode
from PyrexTypes import c_ptr_type
class TempHandle(object):
# THIS IS DEPRECATED, USE LetRefNode instead
......@@ -196,6 +197,8 @@ class LetNodeMixin:
def setup_temp_expr(self, code):
self.temp_expression.generate_evaluation_code(code)
self.temp_type = self.temp_expression.type
if self.temp_type.is_array:
self.temp_type = c_ptr_type(self.temp_type.base_type)
self._result_in_temp = self.temp_expression.result_in_temp()
if self._result_in_temp:
self.temp = self.temp_expression.result()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment