Commit e440f779 authored by Stefan Behnel's avatar Stefan Behnel

implement type inference across closures

parent c1690bf3
......@@ -186,6 +186,7 @@ class Entry(object):
from_cython_utility_code = None
error_on_uninitialized = False
cf_used = True
outer_entry = None
def __init__(self, name, cname, type, pos = None, init = None):
self.name = name
......@@ -196,6 +197,7 @@ class Entry(object):
self.overloaded_alternatives = []
self.cf_assignments = []
self.cf_references = []
self.inner_entries = []
def __repr__(self):
return "Entry(name=%s, type=%s)" % (self.name, self.type)
......@@ -207,6 +209,22 @@ class Entry(object):
def all_alternatives(self):
return [self] + self.overloaded_alternatives
def all_entries(self):
"""
Returns all entries for this entry, including the equivalent ones
in other closures.
"""
if self.from_closure:
return self.outer_entry.all_entries()
entries = []
def collect_inner_entries(entry):
entries.append(entry)
for e in entry.inner_entries:
collect_inner_entries(e)
collect_inner_entries(self)
return entries
class Scope(object):
# name string Unqualified name
......@@ -1524,6 +1542,7 @@ class LocalScope(Scope):
inner_entry.from_closure = True
inner_entry.is_declared_generic = entry.is_declared_generic
self.entries[name] = inner_entry
entry.inner_entries.append(inner_entry)
return inner_entry
return entry
......
......@@ -335,6 +335,9 @@ class PyObjectTypeInferer(object):
class SimpleAssignmentTypeInferer(object):
"""
Very basic type inference.
Note: in order to support cross-closure type inference, this must be
applies to nested scopes in top-down order.
"""
# TODO: Implement a real type inference algorithm.
# (Something more powerful than just extending this one...)
......@@ -357,13 +360,11 @@ class SimpleAssignmentTypeInferer(object):
ready_to_infer = []
for name, entry in scope.entries.items():
if entry.type is unspecified_type:
if entry.in_closure or entry.from_closure:
# cross-closure type inference is not currently supported
entry.type = py_object_type
continue
entries = entry.all_entries()
all = set()
for assmt in entry.cf_assignments:
all.update(assmt.type_dependencies(scope))
for e in entries:
for assmt in e.cf_assignments:
all.update(assmt.type_dependencies(e.scope))
if all:
dependancies_by_entry[entry] = all
for dep in all:
......@@ -387,14 +388,20 @@ class SimpleAssignmentTypeInferer(object):
while True:
while ready_to_infer:
entry = ready_to_infer.pop()
types = [assmt.rhs.infer_type(scope)
for assmt in entry.cf_assignments]
types = [
assmt.rhs.infer_type(scope)
for e in entry.all_entries()
for assmt in e.cf_assignments
]
if types and Utils.all(types):
entry.type = spanning_type(types, entry.might_overflow, entry.pos)
entry_type = spanning_type(types, entry.might_overflow, entry.pos)
else:
# FIXME: raise a warning?
# print "No assignments", entry.pos, entry
entry.type = py_object_type
entry_type = py_object_type
# propagate entry type to all nested scopes
for e in entry.all_entries():
e.type = entry_type
if verbose:
message(entry.pos, "inferred '%s' to be of type '%s'" % (entry.name, entry.type))
resolve_dependancy(entry)
......
# mode: run
# tag: typeinference
cimport cython
def test_outer_inner_double():
"""
>>> print(test_outer_inner_double())
double
"""
x = 1.0
def inner():
nonlocal x
x = 2.0
inner()
assert x == 2.0
return cython.typeof(x)
def test_outer_inner_double_int():
"""
>>> print(test_outer_inner_double_int())
('double', 'double')
"""
x = 1.0
y = 2
def inner():
nonlocal x, y
x = 1
y = 2.0
inner()
return cython.typeof(x), cython.typeof(y)
def test_outer_inner_pyarg():
"""
>>> print(test_outer_inner_pyarg())
2
long
"""
x = 1
def inner(y):
return x + y
print inner(1)
return cython.typeof(x)
def test_outer_inner_carg():
"""
>>> print(test_outer_inner_carg())
2.0
long
"""
x = 1
def inner(double y):
return x + y
print inner(1)
return cython.typeof(x)
def test_outer_inner_incompatible():
"""
>>> print(test_outer_inner_incompatible())
Python object
"""
x = 1.0
def inner():
nonlocal x
x = 'test'
inner()
return cython.typeof(x)
def test_outer_inner2_double():
"""
>>> print(test_outer_inner2_double())
double
"""
x = 1.0
def inner1():
nonlocal x
x = 2
def inner2():
nonlocal x
x = 3.0
inner1()
inner2()
return cython.typeof(x)
......@@ -31,12 +31,12 @@ def test_unicode_loop():
print 2, cython.typeof(c)
yield c
def test_nonlocal_disables_inference():
def test_with_nonlocal():
"""
>>> chars = list(test_nonlocal_disables_inference())
1 Python object
2 Python object
2 Python object
>>> chars = list(test_with_nonlocal())
1 Py_UCS4
2 Py_UCS4
2 Py_UCS4
>>> len(chars)
2
>>> ''.join(chars) == 'ab'
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment