Commit 5668a4fb authored by Robert Bradshaw's avatar Robert Bradshaw

Handle inplace arithmatic via parse tree transform.

Excludes buffers and C++, which have their own code.
This is in preparation for #591 (inline vs. cdivision) and
support for inline complex arithamtic.
parent 0f5d7221
...@@ -2746,6 +2746,7 @@ class SimpleCallNode(CallNode): ...@@ -2746,6 +2746,7 @@ class SimpleCallNode(CallNode):
wrapper_call = False wrapper_call = False
has_optional_args = False has_optional_args = False
nogil = False nogil = False
analysed = False
def compile_time_value(self, denv): def compile_time_value(self, denv):
function = self.function.compile_time_value(denv) function = self.function.compile_time_value(denv)
...@@ -2799,6 +2800,9 @@ class SimpleCallNode(CallNode): ...@@ -2799,6 +2800,9 @@ class SimpleCallNode(CallNode):
def analyse_types(self, env): def analyse_types(self, env):
if self.analyse_as_type_constructor(env): if self.analyse_as_type_constructor(env):
return return
if self.analysed:
return
self.analysed = True
function = self.function function = self.function
function.is_called = 1 function.is_called = 1
self.function.analyse_types(env) self.function.analyse_types(env)
......
...@@ -98,6 +98,7 @@ class Context(object): ...@@ -98,6 +98,7 @@ class Context(object):
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
from ParseTreeTransforms import ExpandInplaceOperators
from TypeInference import MarkAssignments, MarkOverflowingArithmetic from TypeInference import MarkAssignments, MarkOverflowingArithmetic
from ParseTreeTransforms import AlignFunctionDefinitions, GilCheck from ParseTreeTransforms import AlignFunctionDefinitions, GilCheck
from AnalysedTreeTransforms import AutoTestDictTransform from AnalysedTreeTransforms import AutoTestDictTransform
...@@ -143,6 +144,7 @@ class Context(object): ...@@ -143,6 +144,7 @@ class Context(object):
IntroduceBufferAuxiliaryVars(self), IntroduceBufferAuxiliaryVars(self),
_check_c_declarations, _check_c_declarations,
AnalyseExpressionsTransform(self), AnalyseExpressionsTransform(self),
ExpandInplaceOperators(self),
OptimizeBuiltinCalls(self), ## Necessary? OptimizeBuiltinCalls(self), ## Necessary?
IterationTransform(), IterationTransform(),
SwitchTransform(), SwitchTransform(),
......
...@@ -3520,15 +3520,15 @@ class InPlaceAssignmentNode(AssignmentNode): ...@@ -3520,15 +3520,15 @@ class InPlaceAssignmentNode(AssignmentNode):
# (it must be a NameNode, AttributeNode, or IndexNode). # (it must be a NameNode, AttributeNode, or IndexNode).
child_attrs = ["lhs", "rhs"] child_attrs = ["lhs", "rhs"]
dup = None
def analyse_declarations(self, env): def analyse_declarations(self, env):
self.lhs.analyse_target_declaration(env) self.lhs.analyse_target_declaration(env)
def analyse_types(self, env): def analyse_types(self, env):
self.dup = self.create_dup_node(env) # re-assigns lhs to a shallow copy
self.rhs.analyse_types(env) self.rhs.analyse_types(env)
self.lhs.analyse_target_types(env) self.lhs.analyse_target_types(env)
return
import ExprNodes import ExprNodes
if self.lhs.type.is_pyobject: if self.lhs.type.is_pyobject:
self.rhs = self.rhs.coerce_to_pyobject(env) self.rhs = self.rhs.coerce_to_pyobject(env)
...@@ -3539,6 +3539,28 @@ class InPlaceAssignmentNode(AssignmentNode): ...@@ -3539,6 +3539,28 @@ class InPlaceAssignmentNode(AssignmentNode):
self.result_value = self.result_value_temp.coerce_to(self.lhs.type, env) self.result_value = self.result_value_temp.coerce_to(self.lhs.type, env)
def generate_execution_code(self, code): def generate_execution_code(self, code):
import ExprNodes
self.rhs.generate_evaluation_code(code)
self.lhs.generate_subexpr_evaluation_code(code)
c_op = self.operator
if c_op == "//":
c_op = "/"
elif c_op == "**":
error(self.pos, "No C inplace power operator")
if isinstance(self.lhs, ExprNodes.IndexNode) and self.lhs.is_buffer_access:
if self.lhs.type.is_pyobject:
error(self.pos, "In-place operators not allowed on object buffers in this release.")
self.lhs.generate_buffer_setitem_code(self.rhs, code, c_op)
else:
# C++
# TODO: make sure overload is declared
code.putln("%s %s= %s;" % (self.lhs.result(), c_op, self.rhs.result()))
self.lhs.generate_subexpr_disposal_code(code)
self.lhs.free_subexpr_temps(code)
self.rhs.generate_disposal_code(code)
self.rhs.free_temps(code)
return
import ExprNodes import ExprNodes
self.rhs.generate_evaluation_code(code) self.rhs.generate_evaluation_code(code)
self.dup.generate_subexpr_evaluation_code(code) self.dup.generate_subexpr_evaluation_code(code)
...@@ -3581,9 +3603,14 @@ class InPlaceAssignmentNode(AssignmentNode): ...@@ -3581,9 +3603,14 @@ class InPlaceAssignmentNode(AssignmentNode):
# have to do assignment directly to avoid side-effects # have to do assignment directly to avoid side-effects
if isinstance(self.lhs, ExprNodes.IndexNode) and self.lhs.is_buffer_access: if isinstance(self.lhs, ExprNodes.IndexNode) and self.lhs.is_buffer_access:
if self.lhs.type.is_int and c_op == "/" and not code.globalstate.directives['cdivision']:
error(self.pos, "Inplace non-c division not implemented for buffer types. (Use cdivision=False for now.)")
self.lhs.generate_buffer_setitem_code(self.rhs, code, c_op) self.lhs.generate_buffer_setitem_code(self.rhs, code, c_op)
else: else:
self.dup.generate_result_code(code) self.dup.generate_result_code(code)
if self.lhs.type.is_int and c_op == "/" and not code.globalstate.directives['cdivision']:
error(self.pos, "bad")
else:
code.putln("%s %s= %s;" % (self.lhs.result(), c_op, self.rhs.result()) ) code.putln("%s %s= %s;" % (self.lhs.result(), c_op, self.rhs.result()) )
self.rhs.generate_disposal_code(code) self.rhs.generate_disposal_code(code)
self.rhs.free_temps(code) self.rhs.free_temps(code)
...@@ -3645,7 +3672,6 @@ class InPlaceAssignmentNode(AssignmentNode): ...@@ -3645,7 +3672,6 @@ class InPlaceAssignmentNode(AssignmentNode):
def annotate(self, code): def annotate(self, code):
self.lhs.annotate(code) self.lhs.annotate(code)
self.rhs.annotate(code) self.rhs.annotate(code)
self.dup.annotate(code)
def create_binop_node(self): def create_binop_node(self):
import ExprNodes import ExprNodes
......
...@@ -1195,6 +1195,73 @@ class AnalyseExpressionsTransform(CythonTransform): ...@@ -1195,6 +1195,73 @@ class AnalyseExpressionsTransform(CythonTransform):
self.visitchildren(node) self.visitchildren(node)
return node return node
class ExpandInplaceOperators(CythonTransform):
def __call__(self, root):
self.env_stack = [root.scope]
return super(ExpandInplaceOperators, self).__call__(root)
def visit_FuncDefNode(self, node):
self.env_stack.append(node.local_scope)
self.visitchildren(node)
self.env_stack.pop()
return node
def visit_InPlaceAssignmentNode(self, node):
lhs = node.lhs
rhs = node.rhs
if lhs.type.is_cpp_class:
# No getting around this exact operator here.
return node
if isinstance(lhs, IndexNode) and lhs.is_buffer_access:
# There is code to handle this case.
return node
def side_effect_free_reference(node, setting=False):
if node.type.is_pyobject and not setting:
node = LetRefNode(node)
return node, [node]
elif isinstance(node, IndexNode):
if node.is_buffer_access:
raise ValueError, "Buffer access"
base, temps = side_effect_free_reference(node.base)
index = LetRefNode(node.index)
return IndexNode(node.pos, base=base, index=index), temps + [index]
elif isinstance(node, AttributeNode):
obj, temps = side_effect_free_reference(node.obj)
return AttributeNode(node.pos, obj=obj, attribute=node.attribute), temps
elif isinstance(node, NameNode):
return node, []
else:
node = LetRefNode(node)
return node, [node]
try:
lhs, let_ref_nodes = side_effect_free_reference(lhs, setting=True)
except ValueError:
return node
dup = lhs.__class__(**lhs.__dict__)
binop = binop_node(node.pos,
operator = node.operator,
operand1 = dup,
operand2 = rhs)
node = SingleAssignmentNode(node.pos, lhs=lhs, rhs=binop) #, inplace=True)
# Use LetRefNode to avoid side effects.
let_ref_nodes.reverse()
for t in let_ref_nodes:
node = LetNode(t, node)
node.analyse_expressions(self.env_stack[-1])
return node
def visit_ExprNode(self, node):
# In-place assignments can't happen within an expression.
return node
def visit_Node(self, node):
self.visitchildren(node)
return node
class AlignFunctionDefinitions(CythonTransform): class AlignFunctionDefinitions(CythonTransform):
""" """
This class takes the signatures from a .pxd file and applies them to This class takes the signatures from a .pxd file and applies them to
......
...@@ -8,6 +8,7 @@ import Nodes ...@@ -8,6 +8,7 @@ import Nodes
import ExprNodes import ExprNodes
from Nodes import Node from Nodes import Node
from ExprNodes import AtomicExprNode from ExprNodes import AtomicExprNode
from PyrexTypes import c_ptr_type
class TempHandle(object): class TempHandle(object):
# THIS IS DEPRECATED, USE LetRefNode instead # THIS IS DEPRECATED, USE LetRefNode instead
...@@ -196,6 +197,8 @@ class LetNodeMixin: ...@@ -196,6 +197,8 @@ class LetNodeMixin:
def setup_temp_expr(self, code): def setup_temp_expr(self, code):
self.temp_expression.generate_evaluation_code(code) self.temp_expression.generate_evaluation_code(code)
self.temp_type = self.temp_expression.type self.temp_type = self.temp_expression.type
if self.temp_type.is_array:
self.temp_type = c_ptr_type(self.temp_type.base_type)
self._result_in_temp = self.temp_expression.result_in_temp() self._result_in_temp = self.temp_expression.result_in_temp()
if self._result_in_temp: if self._result_in_temp:
self.temp = self.temp_expression.result() self.temp = self.temp_expression.result()
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment