Commit 82d16f78 authored by Robert Bradshaw's avatar Robert Bradshaw

Left shift overflow guards.

parent a734b360
...@@ -8011,10 +8011,10 @@ class NumBinopNode(BinopNode): ...@@ -8011,10 +8011,10 @@ class NumBinopNode(BinopNode):
return return
if self.type.is_complex: if self.type.is_complex:
self.infix = False self.infix = False
if self.type.is_int and env.directives['overflowcheck'] and self.operator in ('+', '-', '*'): if self.type.is_int and env.directives['overflowcheck'] and self.operator in self.overflow_op_names:
self.overflow_check = True self.overflow_check = True
self.func = self.type.overflow_check_binop( self.func = self.type.overflow_check_binop(
self.op_names[self.operator], self.overflow_op_names[self.operator],
env, env,
const_rhs = self.operand2.has_constant_result()) const_rhs = self.operand2.has_constant_result())
self.is_temp = True self.is_temp = True
...@@ -8118,10 +8118,11 @@ class NumBinopNode(BinopNode): ...@@ -8118,10 +8118,11 @@ class NumBinopNode(BinopNode):
"**": "PyNumber_Power" "**": "PyNumber_Power"
} }
op_names = { overflow_op_names = {
"+": "add", "+": "add",
"-": "sub", "-": "sub",
"*": "mul", "*": "mul",
"<<": "lshift",
} }
class IntBinopNode(NumBinopNode): class IntBinopNode(NumBinopNode):
......
...@@ -403,11 +403,14 @@ class CTypedefType(BaseType): ...@@ -403,11 +403,14 @@ class CTypedefType(BaseType):
return self.typedef_base_type.create_from_py_utility_code(env) return self.typedef_base_type.create_from_py_utility_code(env)
def overflow_check_binop(self, binop, env, const_rhs=False): def overflow_check_binop(self, binop, env, const_rhs=False):
if const_rhs:
binop += "_const"
env.use_utility_code(UtilityCode.load("Common", "Overflow.c")) env.use_utility_code(UtilityCode.load("Common", "Overflow.c"))
type = self.declaration_code("") type = self.declaration_code("")
name = self.specialization_name() name = self.specialization_name()
if binop == "lshift":
env.use_utility_code(TempitaUtilityCode.load("LeftShift", "Overflow.c", context={'TYPE': type, 'NAME': name}))
else:
if const_rhs:
binop += "_const"
_load_overflow_base(env) _load_overflow_base(env)
env.use_utility_code(TempitaUtilityCode.load("SizeCheck", "Overflow.c", context={'TYPE': type, 'NAME': name})) env.use_utility_code(TempitaUtilityCode.load("SizeCheck", "Overflow.c", context={'TYPE': type, 'NAME': name}))
env.use_utility_code(TempitaUtilityCode.load("Binop", "Overflow.c", context={'TYPE': type, 'NAME': name, 'BINOP': binop})) env.use_utility_code(TempitaUtilityCode.load("Binop", "Overflow.c", context={'TYPE': type, 'NAME': name, 'BINOP': binop}))
...@@ -1563,6 +1566,9 @@ class CIntType(CNumericType): ...@@ -1563,6 +1566,9 @@ class CIntType(CNumericType):
env.use_utility_code(UtilityCode.load("Common", "Overflow.c")) env.use_utility_code(UtilityCode.load("Common", "Overflow.c"))
type = self.declaration_code("") type = self.declaration_code("")
name = self.specialization_name() name = self.specialization_name()
if binop == "lshift":
env.use_utility_code(TempitaUtilityCode.load("LeftShift", "Overflow.c", context={'TYPE': type, 'NAME': name}))
else:
if const_rhs: if const_rhs:
binop += "_const" binop += "_const"
if type in ('int', 'long', 'long long'): if type in ('int', 'long', 'long long'):
......
...@@ -268,3 +268,12 @@ static CYTHON_INLINE {{TYPE}} __Pyx_{{BINOP}}_{{NAME}}_checking_overflow({{TYPE} ...@@ -268,3 +268,12 @@ static CYTHON_INLINE {{TYPE}} __Pyx_{{BINOP}}_{{NAME}}_checking_overflow({{TYPE}
} }
} }
} }
/////////////// LeftShift.proto ///////////////
static CYTHON_INLINE {{TYPE}} __Pyx_lshift_{{NAME}}_checking_overflow({{TYPE}} a, {{TYPE}} b, int *overflow) {
*overflow |= (b < 0) | (b > (8 * sizeof({{TYPE}}))) | (a > (__PYX_MAX({{TYPE}}) >> b));
return a << b;
}
#define __Pyx_lshift_const_{{NAME}}_checking_overflow __Pyx_lshift_{{NAME}}_checking_overflow
...@@ -206,3 +206,19 @@ def test_mul_const(INT a): ...@@ -206,3 +206,19 @@ def test_mul_const(INT a):
True True
""" """
return int(a * <INT>100) return int(a * <INT>100)
@cython.overflowcheck(True)
def test_lshift(INT a, int b):
"""
>>> test_lshift(1, 10)
1024
>>> expect_overflow(test_lshift, 1, 100)
>>> expect_overflow(test_lshift, max_value, 1)
>>> test_lshift(max_value, 0) == max_value
True
>>> check(test_lshift, operator.lshift, 10, 15)
>>> check(test_lshift, operator.lshift, 10, 30)
>>> check(test_lshift, operator.lshift, 100, 60)
"""
return int(a << b)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment