Commit e440f779 authored by Stefan Behnel's avatar Stefan Behnel

implement type inference across closures

parent c1690bf3
...@@ -186,6 +186,7 @@ class Entry(object): ...@@ -186,6 +186,7 @@ class Entry(object):
from_cython_utility_code = None from_cython_utility_code = None
error_on_uninitialized = False error_on_uninitialized = False
cf_used = True cf_used = True
outer_entry = None
def __init__(self, name, cname, type, pos = None, init = None): def __init__(self, name, cname, type, pos = None, init = None):
self.name = name self.name = name
...@@ -196,6 +197,7 @@ class Entry(object): ...@@ -196,6 +197,7 @@ class Entry(object):
self.overloaded_alternatives = [] self.overloaded_alternatives = []
self.cf_assignments = [] self.cf_assignments = []
self.cf_references = [] self.cf_references = []
self.inner_entries = []
def __repr__(self): def __repr__(self):
return "Entry(name=%s, type=%s)" % (self.name, self.type) return "Entry(name=%s, type=%s)" % (self.name, self.type)
...@@ -207,6 +209,22 @@ class Entry(object): ...@@ -207,6 +209,22 @@ class Entry(object):
def all_alternatives(self): def all_alternatives(self):
return [self] + self.overloaded_alternatives return [self] + self.overloaded_alternatives
def all_entries(self):
"""
Returns all entries for this entry, including the equivalent ones
in other closures.
"""
if self.from_closure:
return self.outer_entry.all_entries()
entries = []
def collect_inner_entries(entry):
entries.append(entry)
for e in entry.inner_entries:
collect_inner_entries(e)
collect_inner_entries(self)
return entries
class Scope(object): class Scope(object):
# name string Unqualified name # name string Unqualified name
...@@ -1524,6 +1542,7 @@ class LocalScope(Scope): ...@@ -1524,6 +1542,7 @@ class LocalScope(Scope):
inner_entry.from_closure = True inner_entry.from_closure = True
inner_entry.is_declared_generic = entry.is_declared_generic inner_entry.is_declared_generic = entry.is_declared_generic
self.entries[name] = inner_entry self.entries[name] = inner_entry
entry.inner_entries.append(inner_entry)
return inner_entry return inner_entry
return entry return entry
......
...@@ -335,6 +335,9 @@ class PyObjectTypeInferer(object): ...@@ -335,6 +335,9 @@ class PyObjectTypeInferer(object):
class SimpleAssignmentTypeInferer(object): class SimpleAssignmentTypeInferer(object):
""" """
Very basic type inference. Very basic type inference.
Note: in order to support cross-closure type inference, this must be
applies to nested scopes in top-down order.
""" """
# TODO: Implement a real type inference algorithm. # TODO: Implement a real type inference algorithm.
# (Something more powerful than just extending this one...) # (Something more powerful than just extending this one...)
...@@ -357,13 +360,11 @@ class SimpleAssignmentTypeInferer(object): ...@@ -357,13 +360,11 @@ class SimpleAssignmentTypeInferer(object):
ready_to_infer = [] ready_to_infer = []
for name, entry in scope.entries.items(): for name, entry in scope.entries.items():
if entry.type is unspecified_type: if entry.type is unspecified_type:
if entry.in_closure or entry.from_closure: entries = entry.all_entries()
# cross-closure type inference is not currently supported
entry.type = py_object_type
continue
all = set() all = set()
for assmt in entry.cf_assignments: for e in entries:
all.update(assmt.type_dependencies(scope)) for assmt in e.cf_assignments:
all.update(assmt.type_dependencies(e.scope))
if all: if all:
dependancies_by_entry[entry] = all dependancies_by_entry[entry] = all
for dep in all: for dep in all:
...@@ -387,14 +388,20 @@ class SimpleAssignmentTypeInferer(object): ...@@ -387,14 +388,20 @@ class SimpleAssignmentTypeInferer(object):
while True: while True:
while ready_to_infer: while ready_to_infer:
entry = ready_to_infer.pop() entry = ready_to_infer.pop()
types = [assmt.rhs.infer_type(scope) types = [
for assmt in entry.cf_assignments] assmt.rhs.infer_type(scope)
for e in entry.all_entries()
for assmt in e.cf_assignments
]
if types and Utils.all(types): if types and Utils.all(types):
entry.type = spanning_type(types, entry.might_overflow, entry.pos) entry_type = spanning_type(types, entry.might_overflow, entry.pos)
else: else:
# FIXME: raise a warning? # FIXME: raise a warning?
# print "No assignments", entry.pos, entry # print "No assignments", entry.pos, entry
entry.type = py_object_type entry_type = py_object_type
# propagate entry type to all nested scopes
for e in entry.all_entries():
e.type = entry_type
if verbose: if verbose:
message(entry.pos, "inferred '%s' to be of type '%s'" % (entry.name, entry.type)) message(entry.pos, "inferred '%s' to be of type '%s'" % (entry.name, entry.type))
resolve_dependancy(entry) resolve_dependancy(entry)
......
# mode: run
# tag: typeinference
cimport cython
def test_outer_inner_double():
"""
>>> print(test_outer_inner_double())
double
"""
x = 1.0
def inner():
nonlocal x
x = 2.0
inner()
assert x == 2.0
return cython.typeof(x)
def test_outer_inner_double_int():
"""
>>> print(test_outer_inner_double_int())
('double', 'double')
"""
x = 1.0
y = 2
def inner():
nonlocal x, y
x = 1
y = 2.0
inner()
return cython.typeof(x), cython.typeof(y)
def test_outer_inner_pyarg():
"""
>>> print(test_outer_inner_pyarg())
2
long
"""
x = 1
def inner(y):
return x + y
print inner(1)
return cython.typeof(x)
def test_outer_inner_carg():
"""
>>> print(test_outer_inner_carg())
2.0
long
"""
x = 1
def inner(double y):
return x + y
print inner(1)
return cython.typeof(x)
def test_outer_inner_incompatible():
"""
>>> print(test_outer_inner_incompatible())
Python object
"""
x = 1.0
def inner():
nonlocal x
x = 'test'
inner()
return cython.typeof(x)
def test_outer_inner2_double():
"""
>>> print(test_outer_inner2_double())
double
"""
x = 1.0
def inner1():
nonlocal x
x = 2
def inner2():
nonlocal x
x = 3.0
inner1()
inner2()
return cython.typeof(x)
...@@ -31,12 +31,12 @@ def test_unicode_loop(): ...@@ -31,12 +31,12 @@ def test_unicode_loop():
print 2, cython.typeof(c) print 2, cython.typeof(c)
yield c yield c
def test_nonlocal_disables_inference(): def test_with_nonlocal():
""" """
>>> chars = list(test_nonlocal_disables_inference()) >>> chars = list(test_with_nonlocal())
1 Python object 1 Py_UCS4
2 Python object 2 Py_UCS4
2 Python object 2 Py_UCS4
>>> len(chars) >>> len(chars)
2 2
>>> ''.join(chars) == 'ab' >>> ''.join(chars) == 'ab'
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment