Commit 7e1b07cc authored by Robert Bradshaw's avatar Robert Bradshaw

Cheaper overflow checks for nested expressions.

parent 4cf1df5c
...@@ -7980,6 +7980,8 @@ class BinopNode(ExprNode): ...@@ -7980,6 +7980,8 @@ class BinopNode(ExprNode):
extra_args, extra_args,
code.error_goto_if_null(self.result(), self.pos))) code.error_goto_if_null(self.result(), self.pos)))
code.put_gotref(self.py_result()) code.put_gotref(self.py_result())
elif self.is_temp:
code.putln("%s = %s;" % (self.result(), self.calculate_result_code()))
def type_error(self): def type_error(self):
if not (self.operand1.type.is_error if not (self.operand1.type.is_error
...@@ -8027,6 +8029,7 @@ class NumBinopNode(BinopNode): ...@@ -8027,6 +8029,7 @@ class NumBinopNode(BinopNode):
infix = True infix = True
overflow_check = False overflow_check = False
overflow_bit_node = None
def analyse_c_operation(self, env): def analyse_c_operation(self, env):
type1 = self.operand1.type type1 = self.operand1.type
...@@ -8091,12 +8094,13 @@ class NumBinopNode(BinopNode): ...@@ -8091,12 +8094,13 @@ class NumBinopNode(BinopNode):
return (type1.is_numeric or type1.is_enum) \ return (type1.is_numeric or type1.is_enum) \
and (type2.is_numeric or type2.is_enum) and (type2.is_numeric or type2.is_enum)
def generate_result_code(self, code): def generate_evaluation_code(self, code):
super(NumBinopNode, self).generate_result_code(code)
if self.overflow_check: if self.overflow_check:
self.overflow_bit_node = self
self.overflow_bit = code.funcstate.allocate_temp(PyrexTypes.c_int_type, manage_ref=False) self.overflow_bit = code.funcstate.allocate_temp(PyrexTypes.c_int_type, manage_ref=False)
code.putln("%s = 0;" % self.overflow_bit) code.putln("%s = 0;" % self.overflow_bit)
code.putln("%s = %s;" % (self.result(), self.calculate_result_code())) super(NumBinopNode, self).generate_evaluation_code(code)
if self.overflow_check:
code.putln("if (unlikely(%s)) {" % self.overflow_bit) code.putln("if (unlikely(%s)) {" % self.overflow_bit)
code.putln('PyErr_Format(PyExc_OverflowError, "value too large");') code.putln('PyErr_Format(PyExc_OverflowError, "value too large");')
code.putln(code.error_goto(self.pos)) code.putln(code.error_goto(self.pos))
...@@ -8104,12 +8108,12 @@ class NumBinopNode(BinopNode): ...@@ -8104,12 +8108,12 @@ class NumBinopNode(BinopNode):
code.funcstate.release_temp(self.overflow_bit) code.funcstate.release_temp(self.overflow_bit)
def calculate_result_code(self): def calculate_result_code(self):
if self.overflow_check: if self.overflow_bit_node is not None:
return "%s(%s, %s, &%s)" % ( return "%s(%s, %s, &%s)" % (
self.func, self.func,
self.operand1.result(), self.operand1.result(),
self.operand2.result(), self.operand2.result(),
self.overflow_bit) self.overflow_bit_node.overflow_bit)
elif self.infix: elif self.infix:
return "(%s %s %s)" % ( return "(%s %s %s)" % (
self.operand1.result(), self.operand1.result(),
......
...@@ -3244,3 +3244,37 @@ class FinalOptimizePhase(Visitor.CythonTransform): ...@@ -3244,3 +3244,37 @@ class FinalOptimizePhase(Visitor.CythonTransform):
if not node.arg.may_be_none(): if not node.arg.may_be_none():
return node.arg return node.arg
return node return node
class ConsolidateOverflowCheck(Visitor.CythonTransform):
"""
This class facilitates the sharing of overflow checking among all nodes
of a nested arithmetic expression. For example, given the expression
a*b + c, where a, b, and x are all possibly overflowing ints, the entire
sequence will be evaluated and the overflow bit checked only at the end.
"""
overflow_bit_node = None
def visit_Node(self, node):
if self.overflow_bit_node is not None:
saved = self.overflow_bit_node
self.overflow_bit_node = None
self.visitchildren(node)
self.overflow_bit_node = saved
else:
self.visitchildren(node)
return node
def visit_NumBinopNode(self, node):
if node.overflow_check:
top_level_overflow = self.overflow_bit_node is None
if top_level_overflow:
self.overflow_bit_node = node
else:
node.overflow_bit_node = self.overflow_bit_node
node.overflow_check = False
self.visitchildren(node)
if top_level_overflow:
self.overflow_bit_node = None
else:
self.visitchildren(node)
return node
...@@ -144,6 +144,7 @@ def create_pipeline(context, mode, exclude_classes=()): ...@@ -144,6 +144,7 @@ def create_pipeline(context, mode, exclude_classes=()):
from Optimize import InlineDefNodeCalls from Optimize import InlineDefNodeCalls
from Optimize import ConstantFolding, FinalOptimizePhase from Optimize import ConstantFolding, FinalOptimizePhase
from Optimize import DropRefcountingTransform from Optimize import DropRefcountingTransform
from Optimize import ConsolidateOverflowCheck
from Buffer import IntroduceBufferAuxiliaryVars from Buffer import IntroduceBufferAuxiliaryVars
from ModuleNode import check_c_declarations, check_c_declarations_pxd from ModuleNode import check_c_declarations, check_c_declarations_pxd
...@@ -196,6 +197,7 @@ def create_pipeline(context, mode, exclude_classes=()): ...@@ -196,6 +197,7 @@ def create_pipeline(context, mode, exclude_classes=()):
CreateClosureClasses(context), ## After all lookups and type inference CreateClosureClasses(context), ## After all lookups and type inference
ExpandInplaceOperators(context), ExpandInplaceOperators(context),
OptimizeBuiltinCalls(context), ## Necessary? OptimizeBuiltinCalls(context), ## Necessary?
ConsolidateOverflowCheck(context),
IterationTransform(), IterationTransform(),
SwitchTransform(), SwitchTransform(),
DropRefcountingTransform(), DropRefcountingTransform(),
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment