Commit ab433f3d authored by Stefan Behnel's avatar Stefan Behnel

'safe' mode for type inference: only infer types that are very unlikely to break code

parent 52679b28
...@@ -62,7 +62,7 @@ directive_defaults = { ...@@ -62,7 +62,7 @@ directive_defaults = {
'ccomplex' : False, # use C99/C++ for complex types and arith 'ccomplex' : False, # use C99/C++ for complex types and arith
'callspec' : "", 'callspec' : "",
'profile': False, 'profile': False,
'infer_types': False, 'infer_types': 'none', # 'none', 'safe', 'all'
'autotestdict': True, 'autotestdict': True,
# test support # test support
...@@ -87,7 +87,7 @@ directive_scopes = { # defaults to available everywhere ...@@ -87,7 +87,7 @@ directive_scopes = { # defaults to available everywhere
def parse_directive_value(name, value): def parse_directive_value(name, value):
""" """
Parses value as an option value for the given name and returns Parses value as an option value for the given name and returns
the interpreted value. None is returned if the option does not exist. the interpreted value. None is returned if the option does not exist.
>>> print parse_directive_value('nonexisting', 'asdf asdfd') >>> print parse_directive_value('nonexisting', 'asdf asdfd')
None None
...@@ -110,6 +110,8 @@ def parse_directive_value(name, value): ...@@ -110,6 +110,8 @@ def parse_directive_value(name, value):
return int(value) return int(value)
except ValueError: except ValueError:
raise ValueError("%s directive must be set to an integer" % name) raise ValueError("%s directive must be set to an integer" % name)
elif type is str:
return str(value)
else: else:
assert False assert False
......
import ExprNodes import ExprNodes
import PyrexTypes
from PyrexTypes import py_object_type, unspecified_type, spanning_type from PyrexTypes import py_object_type, unspecified_type, spanning_type
from Visitor import CythonTransform from Visitor import CythonTransform
...@@ -119,6 +120,7 @@ class SimpleAssignmentTypeInferer: ...@@ -119,6 +120,7 @@ class SimpleAssignmentTypeInferer:
# TODO: Implement a real type inference algorithm. # TODO: Implement a real type inference algorithm.
# (Something more powerful than just extending this one...) # (Something more powerful than just extending this one...)
def infer_types(self, scope): def infer_types(self, scope):
which_types_to_infer = scope.directives['infer_types']
dependancies_by_entry = {} # entry -> dependancies dependancies_by_entry = {} # entry -> dependancies
entries_by_dependancy = {} # dependancy -> entries entries_by_dependancy = {} # dependancy -> entries
ready_to_infer = [] ready_to_infer = []
...@@ -150,11 +152,12 @@ class SimpleAssignmentTypeInferer: ...@@ -150,11 +152,12 @@ class SimpleAssignmentTypeInferer:
entry = ready_to_infer.pop() entry = ready_to_infer.pop()
types = [expr.infer_type(scope) for expr in entry.assignments] types = [expr.infer_type(scope) for expr in entry.assignments]
if types: if types:
entry.type = reduce(spanning_type, types) result_type = reduce(spanning_type, types)
else: else:
# List comprehension? # List comprehension?
# print "No assignments", entry.pos, entry # print "No assignments", entry.pos, entry
entry.type = py_object_type result_type = py_object_type
entry.type = find_safe_type(result_type, which_types_to_infer)
resolve_dependancy(entry) resolve_dependancy(entry)
# Deal with simple circular dependancies... # Deal with simple circular dependancies...
for entry, deps in dependancies_by_entry.items(): for entry, deps in dependancies_by_entry.items():
...@@ -164,6 +167,7 @@ class SimpleAssignmentTypeInferer: ...@@ -164,6 +167,7 @@ class SimpleAssignmentTypeInferer:
entry.type = reduce(spanning_type, types) entry.type = reduce(spanning_type, types)
types = [expr.infer_type(scope) for expr in entry.assignments] types = [expr.infer_type(scope) for expr in entry.assignments]
entry.type = reduce(spanning_type, types) # might be wider... entry.type = reduce(spanning_type, types) # might be wider...
entry.type = find_safe_type(entry.type, which_types_to_infer)
resolve_dependancy(entry) resolve_dependancy(entry)
del dependancies_by_entry[entry] del dependancies_by_entry[entry]
if ready_to_infer: if ready_to_infer:
...@@ -175,5 +179,18 @@ class SimpleAssignmentTypeInferer: ...@@ -175,5 +179,18 @@ class SimpleAssignmentTypeInferer:
for entry in dependancies_by_entry: for entry in dependancies_by_entry:
entry.type = py_object_type entry.type = py_object_type
def find_safe_type(result_type, which_types_to_infer):
if which_types_to_infer == 'all':
return result_type
elif which_types_to_infer == 'safe':
if result_type.is_pyobject:
# any specific Python type is always safe to infer
return result_type
elif result_type in (PyrexTypes.c_double_type, PyrexTypes.c_float_type):
# Python's float type is just a C double, so it's safe to
# use the C type instead
return PyrexTypes.c_double_type
return py_object_type
def get_type_inferer(): def get_type_inferer():
return SimpleAssignmentTypeInferer() return SimpleAssignmentTypeInferer()
# cython: infer_types = True # cython: infer_types = all
from cython cimport typeof from cython cimport typeof, infer_types
cdef class MyType:
pass
def simple(): def simple():
""" """
...@@ -26,6 +29,23 @@ def simple(): ...@@ -26,6 +29,23 @@ def simple():
t = (4,5,6) t = (4,5,6)
assert typeof(t) == "tuple object", typeof(t) assert typeof(t) == "tuple object", typeof(t)
def builtin_types():
"""
>>> builtin_types()
"""
b = bytes()
assert typeof(b) == "bytes object", typeof(b)
u = unicode()
assert typeof(u) == "unicode object", typeof(u)
L = list()
assert typeof(L) == "list object", typeof(L)
t = tuple()
assert typeof(t) == "tuple object", typeof(t)
d = dict()
assert typeof(d) == "dict object", typeof(d)
B = bool()
assert typeof(B) == "bool object", typeof(B)
def multiple_assignments(): def multiple_assignments():
""" """
>>> multiple_assignments() >>> multiple_assignments()
...@@ -43,9 +63,9 @@ def multiple_assignments(): ...@@ -43,9 +63,9 @@ def multiple_assignments():
c = [1,2,3] c = [1,2,3]
assert typeof(c) == "Python object" assert typeof(c) == "Python object"
def arithmatic(): def arithmetic():
""" """
>>> arithmatic() >>> arithmetic()
""" """
a = 1 + 2 a = 1 + 2
assert typeof(a) == "long" assert typeof(a) == "long"
...@@ -105,3 +125,15 @@ def loop(): ...@@ -105,3 +125,15 @@ def loop():
for d in range(0, 10L, 2): for d in range(0, 10L, 2):
pass pass
assert typeof(a) == "long" assert typeof(a) == "long"
@infer_types('safe')
def safe_only():
"""
>>> safe_only()
"""
a = 1.0
assert typeof(a) == "double", typeof(c)
b = 1
assert typeof(b) == "Python object", typeof(c)
c = MyType()
assert typeof(c) == "MyType", typeof(c)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment