Commit 6a589fd0 authored by scoder's avatar scoder

Merge pull request #233 from vitek/_type_inference_new

Assignmment based type inference
parents dd495c21 4fc031e7
......@@ -1476,6 +1476,7 @@ class NameNode(AtomicExprNode):
cf_is_null = False
allow_null = False
nogil = False
inferred_type = None
def as_cython_attribute(self):
return self.cython_attribute
......@@ -1484,7 +1485,7 @@ class NameNode(AtomicExprNode):
if self.entry is None:
self.entry = env.lookup(self.name)
if self.entry is not None and self.entry.type.is_unspecified:
return (self.entry,)
return (self,)
else:
return ()
......@@ -1492,6 +1493,8 @@ class NameNode(AtomicExprNode):
if self.entry is None:
self.entry = env.lookup(self.name)
if self.entry is None or self.entry.type is unspecified_type:
if self.inferred_type is not None:
return self.inferred_type
return py_object_type
elif (self.entry.type.is_extension_type or self.entry.type.is_builtin_type) and \
self.name == self.entry.type.name:
......@@ -1506,6 +1509,12 @@ class NameNode(AtomicExprNode):
# special case: referring to a C function must return its pointer
return PyrexTypes.CPtrType(self.entry.type)
else:
# If entry is inferred as pyobject it's safe to use local
# NameNode's inferred_type.
if self.entry.type.is_pyobject and self.inferred_type:
# Overflow may happen if integer
if not (self.inferred_type.is_int and self.entry.might_overflow):
return self.inferred_type
return self.entry.type
def compile_time_value(self, denv):
......
......@@ -35,6 +35,7 @@ cdef class NameAssignment:
cdef public object pos
cdef public set refs
cdef public object bit
cdef public object inferred_type
cdef class AssignmentList:
cdef public object bit
......
......@@ -319,15 +319,23 @@ class NameAssignment(object):
self.refs = set()
self.is_arg = False
self.is_deletion = False
self.inferred_type = None
def __repr__(self):
return '%s(entry=%r)' % (self.__class__.__name__, self.entry)
def infer_type(self, scope):
return self.rhs.infer_type(scope)
def infer_type(self):
self.inferred_type = self.rhs.infer_type(self.entry.scope)
return self.inferred_type
def type_dependencies(self, scope):
return self.rhs.type_dependencies(scope)
def type_dependencies(self):
return self.rhs.type_dependencies(self.entry.scope)
@property
def type(self):
if not self.entry.type.is_unspecified:
return self.entry.type
return self.inferred_type
class StaticAssignment(NameAssignment):
......@@ -341,11 +349,11 @@ class StaticAssignment(NameAssignment):
entry.type, may_be_none=may_be_none, pos=entry.pos)
super(StaticAssignment, self).__init__(lhs, lhs, entry)
def infer_type(self, scope):
def infer_type(self):
return self.entry.type
def type_dependencies(self, scope):
return []
def type_dependencies(self):
return ()
class Argument(NameAssignment):
......@@ -359,11 +367,12 @@ class NameDeletion(NameAssignment):
NameAssignment.__init__(self, lhs, lhs, entry)
self.is_deletion = True
def infer_type(self, scope):
inferred_type = self.rhs.infer_type(scope)
def infer_type(self):
inferred_type = self.rhs.infer_type(self.entry.scope)
if (not inferred_type.is_pyobject and
inferred_type.can_coerce_to_pyobject(scope)):
inferred_type.can_coerce_to_pyobject(self.entry.scope)):
return py_object_type
self.inferred_type = inferred_type
return inferred_type
......@@ -410,7 +419,9 @@ class ControlFlowState(list):
else:
if len(state) == 1:
self.is_single = True
super(ControlFlowState, self).__init__(state)
# XXX: Remove fake_rhs_expr
super(ControlFlowState, self).__init__(
[i for i in state if i.rhs is not fake_rhs_expr])
def one(self):
return self[0]
......
......@@ -339,8 +339,11 @@ class SimpleAssignmentTypeInferer(object):
Note: in order to support cross-closure type inference, this must be
applies to nested scopes in top-down order.
"""
# TODO: Implement a real type inference algorithm.
# (Something more powerful than just extending this one...)
def set_entry_type(self, entry, entry_type):
entry.type = entry_type
for e in entry.all_entries():
e.type = entry_type
def infer_types(self, scope):
enabled = scope.directives['infer_types']
verbose = scope.directives['infer_types.verbose']
......@@ -352,85 +355,126 @@ class SimpleAssignmentTypeInferer(object):
else:
for entry in scope.entries.values():
if entry.type is unspecified_type:
entry.type = py_object_type
self.set_entry_type(entry, py_object_type)
return
dependancies_by_entry = {} # entry -> dependancies
entries_by_dependancy = {} # dependancy -> entries
ready_to_infer = []
# Set of assignemnts
assignments = set([])
assmts_resolved = set([])
dependencies = {}
assmt_to_names = {}
for name, entry in scope.entries.items():
for assmt in entry.cf_assignments:
names = assmt.type_dependencies()
assmt_to_names[assmt] = names
assmts = set()
for node in names:
assmts.update(node.cf_state)
dependencies[assmt] = assmts
if entry.type is unspecified_type:
all = set()
for assmt in entry.cf_assignments:
all.update(assmt.type_dependencies(entry.scope))
if all:
dependancies_by_entry[entry] = all
for dep in all:
if dep not in entries_by_dependancy:
entries_by_dependancy[dep] = set([entry])
else:
entries_by_dependancy[dep].add(entry)
else:
ready_to_infer.append(entry)
def resolve_dependancy(dep):
if dep in entries_by_dependancy:
for entry in entries_by_dependancy[dep]:
entry_deps = dependancies_by_entry[entry]
entry_deps.remove(dep)
if not entry_deps and entry != dep:
del dependancies_by_entry[entry]
ready_to_infer.append(entry)
# Try to infer things in order...
assignments.update(entry.cf_assignments)
else:
assmts_resolved.update(entry.cf_assignments)
def infer_name_node_type(node):
types = [assmt.inferred_type for assmt in node.cf_state]
if not types:
node_type = py_object_type
else:
node_type = spanning_type(
types, entry.might_overflow, entry.pos)
node.inferred_type = node_type
def infer_name_node_type_partial(node):
types = [assmt.inferred_type for assmt in node.cf_state
if assmt.inferred_type is not None]
if not types:
return
return spanning_type(types, entry.might_overflow, entry.pos)
def resolve_assignments(assignments):
resolved = set()
for assmt in assignments:
deps = dependencies[assmt]
# All assignments are resolved
if assmts_resolved.issuperset(deps):
for node in assmt_to_names[assmt]:
infer_name_node_type(node)
# Resolve assmt
inferred_type = assmt.infer_type()
done = False
assmts_resolved.add(assmt)
resolved.add(assmt)
assignments -= resolved
return resolved
def partial_infer(assmt):
partial_types = []
for node in assmt_to_names[assmt]:
partial_type = infer_name_node_type_partial(node)
if partial_type is None:
return False
partial_types.append((node, partial_type))
for node, partial_type in partial_types:
node.inferred_type = partial_type
assmt.infer_type()
return True
partial_assmts = set()
def resolve_partial(assignments):
# try to handle circular references
partials = set()
for assmt in assignments:
partial_types = []
if assmt in partial_assmts:
continue
for node in assmt_to_names[assmt]:
if partial_infer(assmt):
partials.add(assmt)
assmts_resolved.add(assmt)
partial_assmts.update(partials)
return partials
# Infer assignments
while True:
while ready_to_infer:
entry = ready_to_infer.pop()
types = [
assmt.rhs.infer_type(scope)
for assmt in entry.cf_assignments
]
if not resolve_assignments(assignments):
if not resolve_partial(assignments):
break
inferred = set()
# First pass
for entry in scope.entries.values():
if entry.type is not unspecified_type:
continue
entry_type = py_object_type
if assmts_resolved.issuperset(entry.cf_assignments):
types = [assmt.inferred_type for assmt in entry.cf_assignments]
if types and Utils.all(types):
entry_type = spanning_type(types, entry.might_overflow, entry.pos)
else:
# FIXME: raise a warning?
# print "No assignments", entry.pos, entry
entry_type = py_object_type
# propagate entry type to all nested scopes
for e in entry.all_entries():
if e.type is unspecified_type:
e.type = entry_type
else:
# FIXME: can this actually happen?
assert e.type == entry_type, (
'unexpected type mismatch between closures for inferred type %s: %s vs. %s' %
entry_type, e, entry)
if verbose:
message(entry.pos, "inferred '%s' to be of type '%s'" % (entry.name, entry.type))
resolve_dependancy(entry)
# Deal with simple circular dependancies...
for entry, deps in dependancies_by_entry.items():
if len(deps) == 1 and deps == set([entry]):
types = [assmt.infer_type(scope)
for assmt in entry.cf_assignments
if assmt.type_dependencies(scope) == ()]
if types:
entry.type = spanning_type(types, entry.might_overflow, entry.pos)
types = [assmt.infer_type(scope)
for assmt in entry.cf_assignments]
entry.type = spanning_type(types, entry.might_overflow, entry.pos) # might be wider...
resolve_dependancy(entry)
del dependancies_by_entry[entry]
if ready_to_infer:
break
if not ready_to_infer:
break
# We can't figure out the rest with this algorithm, let them be objects.
for entry in dependancies_by_entry:
entry.type = py_object_type
if verbose:
message(entry.pos, "inferred '%s' to be of type '%s' (default)" % (entry.name, entry.type))
entry_type = spanning_type(
types, entry.might_overflow, entry.pos)
inferred.add(entry)
self.set_entry_type(entry, entry_type)
def reinfer():
dirty = False
for entry in inferred:
types = [assmt.infer_type()
for assmt in entry.cf_assignments]
new_type = spanning_type(types, entry.might_overflow, entry.pos)
if new_type != entry.type:
self.set_entry_type(entry, new_type)
dirty = True
return dirty
# types propagation
while reinfer():
pass
if verbose:
for entry in inferred:
message(entry.pos, "inferred '%s' to be of type '%s'" % (
entry.name, entry.type))
def find_spanning_type(type1, type2):
if type1 is type2:
......
cimport cython
from cython cimport typeof, infer_types
def test_swap():
"""
>>> test_swap()
"""
a = 0
b = 1
tmp = a
a = b
b = tmp
assert typeof(a) == "long", typeof(a)
assert typeof(b) == "long", typeof(b)
assert typeof(tmp) == "long", typeof(tmp)
def test_object_assmt():
"""
>>> test_object_assmt()
"""
a = 1
b = a
a = "str"
assert typeof(a) == "Python object", typeof(a)
assert typeof(b) == "long", typeof(b)
def test_long_vs_double(cond):
"""
>>> test_long_vs_double(0)
"""
assert typeof(a) == "double", typeof(a)
assert typeof(b) == "double", typeof(b)
assert typeof(c) == "double", typeof(c)
assert typeof(d) == "double", typeof(d)
if cond:
a = 1
b = 2
c = (a + b) / 2
else:
a = 1.0
b = 2.0
d = (a + b) / 2
def test_double_vs_pyobject():
"""
>>> test_double_vs_pyobject()
"""
assert typeof(a) == "Python object", typeof(a)
assert typeof(b) == "Python object", typeof(b)
assert typeof(d) == "double", typeof(d)
a = []
b = []
a = 1.0
b = 2.0
d = (a + b) / 2
def test_python_objects(cond):
"""
>>> test_python_objects(0)
"""
if cond == 1:
a = [1, 2, 3]
o_list = a
elif cond == 2:
a = set([1, 2, 3])
o_set = a
else:
a = {1:1, 2:2, 3:3}
o_dict = a
assert typeof(a) == "Python object", typeof(a)
assert typeof(o_list) == "list object", typeof(o_list)
assert typeof(o_dict) == "dict object", typeof(o_dict)
assert typeof(o_set) == "set object", typeof(o_set)
# CF loops
def test_cf_loop():
"""
>>> test_cf_loop()
"""
cdef int i
a = 0.0
for i in range(3):
a += 1
assert typeof(a) == "double", typeof(a)
def test_cf_loop_intermediate():
"""
>>> test_cf_loop()
"""
cdef int i
a = 0
for i in range(3):
b = a
a = b + 1
assert typeof(a) == "long", typeof(a)
assert typeof(b) == "long", typeof(b)
# Integer overflow
def test_integer_overflow():
"""
>>> test_integer_overflow()
"""
a = 1
b = 2
c = a + b
assert typeof(a) == "Python object", typeof(a)
assert typeof(b) == "Python object", typeof(b)
assert typeof(c) == "Python object", typeof(c)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment