Commit 33e7929b authored by Robert Bradshaw's avatar Robert Bradshaw

Merge branch 'overflow'

parents 8d2a7185 c93a2314
...@@ -69,10 +69,18 @@ class UtilityCodeBase(object): ...@@ -69,10 +69,18 @@ class UtilityCodeBase(object):
##### MyUtility.proto ##### ##### MyUtility.proto #####
[proto declarations]
##### MyUtility.init #####
[code run at module initialization]
##### MyUtility ##### ##### MyUtility #####
#@requires: MyOtherUtility #@requires: MyOtherUtility
#@substitute: naming #@substitute: naming
[definitions]
for prototypes and implementation respectively. For non-python or for prototypes and implementation respectively. For non-python or
-cython files backslashes should be used instead. 5 to 30 comment -cython files backslashes should be used instead. 5 to 30 comment
characters may be used on either side. characters may be used on either side.
...@@ -374,10 +382,13 @@ class UtilityCode(UtilityCodeBase): ...@@ -374,10 +382,13 @@ class UtilityCode(UtilityCodeBase):
output['utility_code_def'].put(self.format_code(self.impl)) output['utility_code_def'].put(self.format_code(self.impl))
if self.init: if self.init:
writer = output['init_globals'] writer = output['init_globals']
writer.putln("/* %s.init */" % self.name)
if isinstance(self.init, basestring): if isinstance(self.init, basestring):
writer.put(self.format_code(self.init)) writer.put(self.format_code(self.init))
else: else:
self.init(writer, output.module_pos) self.init(writer, output.module_pos)
writer.putln(writer.error_goto_if_PyErr(output.module_pos))
writer.putln()
if self.cleanup and Options.generate_cleanup_code: if self.cleanup and Options.generate_cleanup_code:
writer = output['cleanup_globals'] writer = output['cleanup_globals']
if isinstance(self.cleanup, basestring): if isinstance(self.cleanup, basestring):
...@@ -400,13 +411,14 @@ def sub_tempita(s, context, file=None, name=None): ...@@ -400,13 +411,14 @@ def sub_tempita(s, context, file=None, name=None):
return sub(s, **context) return sub(s, **context)
class TempitaUtilityCode(UtilityCode): class TempitaUtilityCode(UtilityCode):
def __init__(self, name=None, proto=None, impl=None, file=None, context=None, **kwargs): def __init__(self, name=None, proto=None, impl=None, init=None, file=None, context=None, **kwargs):
if context is None: if context is None:
context = {} context = {}
proto = sub_tempita(proto, context, file, name) proto = sub_tempita(proto, context, file, name)
impl = sub_tempita(impl, context, file, name) impl = sub_tempita(impl, context, file, name)
init = sub_tempita(init, context, file, name)
super(TempitaUtilityCode, self).__init__( super(TempitaUtilityCode, self).__init__(
proto, impl, name=name, file=file, **kwargs) proto, impl, init=init, name=name, file=file, **kwargs)
def none_or_sub(self, s, context): def none_or_sub(self, s, context):
""" """
......
...@@ -7341,6 +7341,9 @@ class TypecastNode(ExprNode): ...@@ -7341,6 +7341,9 @@ class TypecastNode(ExprNode):
if self.type is None: if self.type is None:
base_type = self.base_type.analyse(env) base_type = self.base_type.analyse(env)
_, self.type = self.declarator.analyse(base_type, env) _, self.type = self.declarator.analyse(base_type, env)
if self.operand.has_constant_result():
# Must be done after self.type is resolved.
self.calculate_constant_result()
if self.type.is_cfunction: if self.type.is_cfunction:
error(self.pos, error(self.pos,
"Cannot cast to a function type") "Cannot cast to a function type")
...@@ -7400,11 +7403,11 @@ class TypecastNode(ExprNode): ...@@ -7400,11 +7403,11 @@ class TypecastNode(ExprNode):
return self.operand.check_const() return self.operand.check_const()
def calculate_constant_result(self): def calculate_constant_result(self):
# we usually do not know the result of a type cast at code self.constant_result = self.calculate_result_code(self.operand.constant_result)
# generation time
pass
def calculate_result_code(self): def calculate_result_code(self, operand_result = None):
if operand_result is None:
operand_result = self.operand.result()
if self.type.is_complex: if self.type.is_complex:
operand_result = self.operand.result() operand_result = self.operand.result()
if self.operand.type.is_complex: if self.operand.type.is_complex:
...@@ -7418,7 +7421,7 @@ class TypecastNode(ExprNode): ...@@ -7418,7 +7421,7 @@ class TypecastNode(ExprNode):
real_part, real_part,
imag_part) imag_part)
else: else:
return self.type.cast_code(self.operand.result()) return self.type.cast_code(operand_result)
def get_constant_c_result_code(self): def get_constant_c_result_code(self):
operand_result = self.operand.get_constant_c_result_code() operand_result = self.operand.get_constant_c_result_code()
...@@ -7997,6 +8000,7 @@ class NumBinopNode(BinopNode): ...@@ -7997,6 +8000,7 @@ class NumBinopNode(BinopNode):
# Binary operation taking numeric arguments. # Binary operation taking numeric arguments.
infix = True infix = True
overflow_check = False
def analyse_c_operation(self, env): def analyse_c_operation(self, env):
type1 = self.operand1.type type1 = self.operand1.type
...@@ -8007,6 +8011,13 @@ class NumBinopNode(BinopNode): ...@@ -8007,6 +8011,13 @@ class NumBinopNode(BinopNode):
return return
if self.type.is_complex: if self.type.is_complex:
self.infix = False self.infix = False
if self.type.is_int and env.directives['overflowcheck'] and self.operator in self.overflow_op_names:
self.overflow_check = True
self.func = self.type.overflow_check_binop(
self.overflow_op_names[self.operator],
env,
const_rhs = self.operand2.has_constant_result())
self.is_temp = True
if not self.infix or (type1.is_numeric and type2.is_numeric): if not self.infix or (type1.is_numeric and type2.is_numeric):
self.operand1 = self.operand1.coerce_to(self.type, env) self.operand1 = self.operand1.coerce_to(self.type, env)
self.operand2 = self.operand2.coerce_to(self.type, env) self.operand2 = self.operand2.coerce_to(self.type, env)
...@@ -8048,8 +8059,26 @@ class NumBinopNode(BinopNode): ...@@ -8048,8 +8059,26 @@ class NumBinopNode(BinopNode):
return (type1.is_numeric or type1.is_enum) \ return (type1.is_numeric or type1.is_enum) \
and (type2.is_numeric or type2.is_enum) and (type2.is_numeric or type2.is_enum)
def generate_result_code(self, code):
super(NumBinopNode, self).generate_result_code(code)
if self.overflow_check:
self.overflow_bit = code.funcstate.allocate_temp(PyrexTypes.c_int_type, manage_ref=False)
code.putln("%s = 0;" % self.overflow_bit);
code.putln("%s = %s;" % (self.result(), self.calculate_result_code()))
code.putln("if (unlikely(%s)) {" % self.overflow_bit)
code.putln('PyErr_Format(PyExc_OverflowError, "value too large");')
code.putln(code.error_goto(self.pos))
code.putln("}")
code.funcstate.release_temp(self.overflow_bit)
def calculate_result_code(self): def calculate_result_code(self):
if self.infix: if self.overflow_check:
return "%s(%s, %s, &%s)" % (
self.func,
self.operand1.result(),
self.operand2.result(),
self.overflow_bit)
elif self.infix:
return "(%s %s %s)" % ( return "(%s %s %s)" % (
self.operand1.result(), self.operand1.result(),
self.operator, self.operator,
...@@ -8089,6 +8118,13 @@ class NumBinopNode(BinopNode): ...@@ -8089,6 +8118,13 @@ class NumBinopNode(BinopNode):
"**": "PyNumber_Power" "**": "PyNumber_Power"
} }
overflow_op_names = {
"+": "add",
"-": "sub",
"*": "mul",
"<<": "lshift",
}
class IntBinopNode(NumBinopNode): class IntBinopNode(NumBinopNode):
# Binary operation taking integer arguments. # Binary operation taking integer arguments.
......
...@@ -81,6 +81,7 @@ directive_defaults = { ...@@ -81,6 +81,7 @@ directive_defaults = {
'auto_cpdef': False, 'auto_cpdef': False,
'cdivision': False, # was True before 0.12 'cdivision': False, # was True before 0.12
'cdivision_warnings': False, 'cdivision_warnings': False,
'overflowcheck': False,
'always_allow_keywords': False, 'always_allow_keywords': False,
'allow_none_for_extension_args': True, 'allow_none_for_extension_args': True,
'wraparound' : True, 'wraparound' : True,
......
...@@ -25,7 +25,7 @@ class BaseType(object): ...@@ -25,7 +25,7 @@ class BaseType(object):
# This is not entirely robust. # This is not entirely robust.
safe = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_0123456789' safe = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_0123456789'
all = [] all = []
for c in self.declaration_code("").replace(" ", "__"): for c in self.declaration_code("").replace("unsigned ", "unsigned_").replace("long long", "long_long").replace(" ", "__"):
if c in safe: if c in safe:
all.append(c) all.append(c)
else: else:
...@@ -402,6 +402,26 @@ class CTypedefType(BaseType): ...@@ -402,6 +402,26 @@ class CTypedefType(BaseType):
# delegation # delegation
return self.typedef_base_type.create_from_py_utility_code(env) return self.typedef_base_type.create_from_py_utility_code(env)
def overflow_check_binop(self, binop, env, const_rhs=False):
env.use_utility_code(UtilityCode.load("Common", "Overflow.c"))
type = self.declaration_code("")
name = self.specialization_name()
if binop == "lshift":
env.use_utility_code(TempitaUtilityCode.load(
"LeftShift", "Overflow.c",
context={'TYPE': type, 'NAME': name, 'SIGNED': self.signed}))
else:
if const_rhs:
binop += "_const"
_load_overflow_base(env)
env.use_utility_code(TempitaUtilityCode.load(
"SizeCheck", "Overflow.c",
context={'TYPE': type, 'NAME': name}))
env.use_utility_code(TempitaUtilityCode.load(
"Binop", "Overflow.c",
context={'TYPE': type, 'NAME': name, 'BINOP': binop}))
return "__Pyx_%s_%s_checking_overflow" % (binop, name)
def error_condition(self, result_code): def error_condition(self, result_code):
if self.typedef_is_external: if self.typedef_is_external:
if self.exception_value: if self.exception_value:
...@@ -1546,7 +1566,51 @@ class CIntType(CNumericType): ...@@ -1546,7 +1566,51 @@ class CIntType(CNumericType):
# We do not really know the size of the type, so return # We do not really know the size of the type, so return
# a 32-bit literal and rely on casting to final type. It will # a 32-bit literal and rely on casting to final type. It will
# be negative for signed ints, which is good. # be negative for signed ints, which is good.
return "0xbad0bad0"; return "0xbad0bad0"
def overflow_check_binop(self, binop, env, const_rhs=False):
env.use_utility_code(UtilityCode.load("Common", "Overflow.c"))
type = self.declaration_code("")
name = self.specialization_name()
if binop == "lshift":
env.use_utility_code(TempitaUtilityCode.load(
"LeftShift", "Overflow.c",
context={'TYPE': type, 'NAME': name, 'SIGNED': not self.signed}))
else:
if const_rhs:
binop += "_const"
if type in ('int', 'long', 'long long'):
env.use_utility_code(TempitaUtilityCode.load(
"BaseCaseSigned", "Overflow.c",
context={'INT': type, 'NAME': name}))
elif type in ('unsigned int', 'unsigned long', 'unsigned long long'):
env.use_utility_code(TempitaUtilityCode.load(
"BaseCaseUnsigned", "Overflow.c",
context={'UINT': type, 'NAME': name}))
elif self.rank <= 1:
# sizeof(short) < sizeof(int)
return "__Pyx_%s_%s_no_overflow" % (binop, name)
else:
_load_overflow_base(env)
env.use_utility_code(TempitaUtilityCode.load(
"SizeCheck", "Overflow.c",
context={'TYPE': type, 'NAME': name}))
env.use_utility_code(TempitaUtilityCode.load(
"Binop", "Overflow.c",
context={'TYPE': type, 'NAME': name, 'BINOP': binop}))
return "__Pyx_%s_%s_checking_overflow" % (binop, name)
def _load_overflow_base(env):
env.use_utility_code(UtilityCode.load("Common", "Overflow.c"))
for type in ('int', 'long', 'long long'):
env.use_utility_code(TempitaUtilityCode.load(
"BaseCaseSigned", "Overflow.c",
context={'INT': type, 'NAME': type.replace(' ', '_')}))
for type in ('unsigned int', 'unsigned long', 'unsigned long long'):
env.use_utility_code(TempitaUtilityCode.load(
"BaseCaseUnsigned", "Overflow.c",
context={'UINT': type, 'NAME': type.replace(' ', '_')}))
class CAnonEnumType(CIntType): class CAnonEnumType(CIntType):
......
This diff is collapsed.
...@@ -149,6 +149,11 @@ Cython code. Here is the list of currently supported directives: ...@@ -149,6 +149,11 @@ Cython code. Here is the list of currently supported directives:
appropriate exception is raised. This is off by default for appropriate exception is raised. This is off by default for
performance reasons. Default is False. performance reasons. Default is False.
``overflowcheck`` (True / False)
If set to True, raise errors on overflowing C integer arithmetic
operations. Incurs a slight runtime penalty, but much faster than
using Python ints. Default is False.
``embedsignature`` (True / False) ``embedsignature`` (True / False)
If set to True, Cython will embed a textual copy of the call If set to True, Cython will embed a textual copy of the call
signature in the docstring of all Python visible functions and signature in the docstring of all Python visible functions and
......
cimport cython
cdef object two = 2
cdef int size_in_bits = sizeof(INT) * 8
cdef bint is_signed_ = (<INT>-1 < 0)
cdef INT max_value_ = <INT>(two ** (size_in_bits - is_signed_) - 1)
cdef INT min_value_ = ~max_value_
cdef INT half_ = max_value_ // 2
# Python visible.
is_signed = is_signed_
max_value = max_value_
min_value = min_value_
half = half_
import operator
from libc.math cimport sqrt
cpdef check(func, op, a, b):
cdef INT res, op_res
cdef bint func_overflow = False
cdef bint assign_overflow = False
try:
res = func(a, b)
except OverflowError:
func_overflow = True
try:
op_res = op(a, b)
except OverflowError:
assign_overflow = True
assert func_overflow == assign_overflow, "Inconsistant overflow: %s(%s, %s)" % (func, a, b)
if not func_overflow:
assert res == op_res, "Inconsistant values: %s(%s, %s) == %s != %s" % (func, a, b, res, op_res)
medium_values = (max_value_ / 2, max_value_ / 3, min_value_ / 2, <INT>sqrt(max_value_) - 1, <INT>sqrt(max_value_) + 1)
def run_test(func, op):
cdef INT offset, b
check(func, op, 300, 200)
check(func, op, max_value_, max_value_)
check(func, op, max_value_, min_value_)
if not is_signed_ or not func is test_sub:
check(func, op, min_value_, min_value_)
for offset in range(5):
check(func, op, max_value_ - 1, offset)
check(func, op, min_value_ + 1, offset)
if is_signed_:
check(func, op, max_value_ - 1, 2 - offset)
check(func, op, min_value_ + 1, 2 - offset)
for offset in range(9):
check(func, op, max_value_ / 2, offset)
check(func, op, min_value_ / 3, offset)
check(func, op, max_value_ / 4, offset)
check(func, op, min_value_ / 5, offset)
if is_signed_:
check(func, op, max_value_ / 2, 4 - offset)
check(func, op, min_value_ / 3, 4 - offset)
check(func, op, max_value_ / -4, 3 - offset)
check(func, op, min_value_ / -5, 3 - offset)
for offset in range(-3, 4):
for a in medium_values:
for b in medium_values:
check(func, op, a, b + offset)
@cython.overflowcheck(True)
def test_add(INT a, INT b):
"""
>>> test_add(1, 2)
3
>>> test_add(max_value, max_value) #doctest: +ELLIPSIS
Traceback (most recent call last):
...
OverflowError: value too large
>>> run_test(test_add, operator.add)
"""
return int(a + b)
@cython.overflowcheck(True)
def test_sub(INT a, INT b):
"""
>>> test_sub(10, 1)
9
>>> test_sub(min_value, 1) #doctest: +ELLIPSIS
Traceback (most recent call last):
...
OverflowError: value too large
>>> run_test(test_sub, operator.sub)
"""
return int(a - b)
@cython.overflowcheck(True)
def test_mul(INT a, INT b):
"""
>>> test_mul(11, 13)
143
>>> test_mul(max_value / 2, max_value / 2) #doctest: +ELLIPSIS
Traceback (most recent call last):
...
OverflowError: value too large
>>> run_test(test_mul, operator.mul)
"""
return int(a * b)
@cython.overflowcheck(True)
def test_nested(INT a, INT b, INT c):
"""
>>> test_nested(1, 2, 3)
6
>>> expect_overflow(test_nested, half + 1, half + 1, half + 1)
>>> expect_overflow(test_nested, half - 1, half - 1, half - 1)
"""
return int(a + b + c)
def expect_overflow(func, *args):
try:
res = func(*args)
except OverflowError:
return
assert False, "Expected OverflowError, got %s" % res
cpdef format(INT value):
"""
>>> format(1)
'1'
>>> format(half - 1)
'half - 1'
>>> format(half)
'half'
>>> format(half + 2)
'half + 2'
>>> format(half + half - 3)
'half + half - 3'
>>> format(max_value)
'max_value'
"""
if value == max_value_:
return "max_value"
elif value == half_:
return "half"
elif max_value_ - value <= max_value_ // 4:
return "half + half - %s" % (half_ + half_ - value)
elif max_value_ - value <= half_:
return "half + %s" % (value - half_)
elif max_value_ - value <= half_ + max_value_ // 4:
return "half - %s" % (half_ - value)
else:
return "%s" % value
cdef INT called(INT value):
print("called(%s)" % format(value))
return value
@cython.overflowcheck(True)
def test_nested_func(INT a, INT b, INT c):
"""
>>> test_nested_func(1, 2, 3)
called(5)
6
>>> expect_overflow(test_nested_func, half + 1, half + 1, half + 1)
>>> expect_overflow(test_nested_func, half - 1, half - 1, half - 1)
called(half + half - 2)
>>> print(format(test_nested_func(1, half - 1, half - 1)))
called(half + half - 2)
half + half - 1
>>>
"""
return int(a + called(b + c))
@cython.overflowcheck(True)
def test_add_const(INT a):
"""
>>> test_add_const(1)
101
>>> expect_overflow(test_add_const, max_value)
>>> expect_overflow(test_add_const , max_value - 99)
>>> test_add_const(max_value - 100) == max_value
True
"""
return int(a + <INT>100)
@cython.overflowcheck(True)
def test_sub_const(INT a):
"""
>>> test_sub_const(101)
1
>>> expect_overflow(test_sub_const, min_value)
>>> expect_overflow(test_sub_const, min_value + 99)
>>> test_sub_const(min_value + 100) == min_value
True
"""
return int(a - <INT>100)
@cython.overflowcheck(True)
def test_mul_const(INT a):
"""
>>> test_mul_const(2)
200
>>> expect_overflow(test_mul_const, max_value)
>>> expect_overflow(test_mul_const, max_value // 99)
>>> test_mul_const(max_value // 100) == max_value - max_value % 100
True
"""
return int(a * <INT>100)
@cython.overflowcheck(True)
def test_lshift(INT a, int b):
"""
>>> test_lshift(1, 10)
1024
>>> expect_overflow(test_lshift, 1, 100)
>>> expect_overflow(test_lshift, max_value, 1)
>>> test_lshift(max_value, 0) == max_value
True
>>> check(test_lshift, operator.lshift, 10, 15)
>>> check(test_lshift, operator.lshift, 10, 30)
>>> check(test_lshift, operator.lshift, 100, 60)
"""
return int(a << b)
ctypedef int INT
include "overflow_check.pxi"
ctypedef long long INT
include "overflow_check.pxi"
ctypedef unsigned int INT
include "overflow_check.pxi"
ctypedef unsigned long long INT
include "overflow_check.pxi"
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment