Commit 024a3d31 authored by Robert Bradshaw's avatar Robert Bradshaw

Actual type inference.

parent d3ef6b9c
......@@ -309,12 +309,12 @@ class ExprNode(Node):
# --------------- Type Inference -----------------
def type_dependencies(self):
def type_dependencies(self, env):
# Returns the list of entries whose types must be determined
# before the type of self can be infered.
if hasattr(self, 'type') and self.type is not None:
return ()
return sum([node.type_dependencies() for node in self.subexpr_nodes()], ())
return sum([node.type_dependencies(env) for node in self.subexpr_nodes()], ())
def infer_type(self, env):
# Attempt to deduce the type of self.
......@@ -832,8 +832,9 @@ class StringNode(ConstNode):
def calculate_result_code(self):
return self.result_code
class UnicodeNode(PyConstNode):
# entry Symtab.Entry
type = unicode_type
def coerce_to(self, dst_type, env):
......@@ -976,8 +977,21 @@ class NameNode(AtomicExprNode):
create_analysed_rvalue = staticmethod(create_analysed_rvalue)
def type_dependencies(self):
return self.entry
def type_dependencies(self, env):
if self.entry is None:
self.entry = env.lookup(self.name)
if self.entry is not None and self.entry.type.is_unspecified:
return (self.entry,)
else:
return ()
def infer_type(self, env):
if self.entry is None:
self.entry = env.lookup(self.name)
if self.entry is None:
return py_object_type
else:
return self.entry.type
def compile_time_value(self, denv):
try:
......@@ -1628,8 +1642,8 @@ class IndexNode(ExprNode):
return PyrexTypes.CArrayType(base_type, int(self.index.compile_time_value(env)))
return None
def type_dependencies(self):
return self.base.type_dependencies()
def type_dependencies(self, env):
return self.base.type_dependencies(env)
def infer_type(self, env):
if isinstance(self.base, StringNode):
......@@ -2251,10 +2265,10 @@ class SimpleCallNode(CallNode):
except Exception, e:
self.compile_time_value_error(e)
def type_dependencies(self):
def type_dependencies(self, env):
# TODO: Update when Danilo's C++ code merged in to handle the
# the case of function overloading.
return self.function.type_dependencies()
return self.function.type_dependencies(env)
def infer_type(self, env):
func_type = self.function.infer_type(env)
......@@ -2705,13 +2719,16 @@ class AttributeNode(ExprNode):
except Exception, e:
self.compile_time_value_error(e)
def type_dependencies(self, env):
return self.obj.type_dependencies(env)
def infer_type(self, env):
if self.analyse_as_cimported_attribute(env, 0):
return self.entry.type
elif self.analyse_as_unbound_cmethod(env):
return self.entry.type
else:
self.analyse_attribute(env)
self.analyse_attribute(env, obj_type = self.obj.infer_type(env))
return self.type
def analyse_target_declaration(self, env):
......@@ -2816,13 +2833,17 @@ class AttributeNode(ExprNode):
self.is_temp = 1
self.result_ctype = py_object_type
def analyse_attribute(self, env):
def analyse_attribute(self, env, obj_type = None):
# Look up attribute and set self.type and self.member.
self.is_py_attr = 0
self.member = self.attribute
if obj_type is None:
if self.obj.type.is_string:
self.obj = self.obj.coerce_to_pyobject(env)
obj_type = self.obj.type
else:
if obj_type.is_string:
obj_type = py_object_type
if obj_type.is_ptr or obj_type.is_array:
obj_type = obj_type.base_type
self.op = "->"
......@@ -2861,9 +2882,10 @@ class AttributeNode(ExprNode):
# type, or it is an extension type and the attribute is either not
# declared or is declared as a Python method. Treat it as a Python
# attribute reference.
self.analyse_as_python_attribute(env)
self.analyse_as_python_attribute(env, obj_type)
def analyse_as_python_attribute(self, env):
def analyse_as_python_attribute(self, env, obj_type = None):
if obj_type is None:
obj_type = self.obj.type
self.member = self.attribute
if obj_type.is_pyobject:
......@@ -3017,6 +3039,7 @@ class StarredTargetNode(ExprNode):
subexprs = ['target']
is_starred = 1
type = py_object_type
is_temp = 1
def __init__(self, pos, target):
self.pos = pos
......@@ -3347,7 +3370,7 @@ class ListNode(SequenceNode):
gil_message = "Constructing Python list"
def type_dependencies(self):
def type_dependencies(self, env):
return ()
def infer_type(self, env):
......@@ -3608,7 +3631,7 @@ class DictNode(ExprNode):
except Exception, e:
self.compile_time_value_error(e)
def type_dependencies(self):
def type_dependencies(self, env):
return ()
def infer_type(self, env):
......@@ -4064,10 +4087,10 @@ class TypecastNode(ExprNode):
subexprs = ['operand']
base_type = declarator = type = None
def type_dependencies(self):
def type_dependencies(self, env):
return ()
def infer_types(self, env):
def infer_type(self, env):
if self.type is None:
base_type = self.base_type.analyse(env)
_, self.type = self.declarator.analyse(base_type, env)
......@@ -4297,7 +4320,7 @@ class BinopNode(ExprNode):
def infer_type(self, env):
return self.result_type(self.operand1.infer_type(env),
self.operand1.infer_type(env))
self.operand2.infer_type(env))
def analyse_types(self, env):
self.operand1.analyse_types(env)
......@@ -4821,9 +4844,12 @@ class CondExprNode(ExprNode):
subexprs = ['test', 'true_val', 'false_val']
def type_dependencies(self):
return self.true_val.type_dependencies() + self.false_val.type_dependencies()
def type_dependencies(self, env):
return self.true_val.type_dependencies(env) + self.false_val.type_dependencies(env)
def infer_type(self, env):
return self.compute_result_type(self.true_val.infer_type(env),
self.false_val.infer_type(env))
def infer_types(self, env):
return self.compute_result_type(self.true_val.infer_types(env),
self.false_val.infer_types(env))
......@@ -5078,6 +5104,13 @@ class PrimaryCmpNode(ExprNode, CmpNode):
cascade = None
def infer_type(self, env):
# TODO: Actually implement this (after merging with -unstable).
return py_object_type
def type_dependencies(self, env):
return ()
def calculate_constant_result(self):
self.constant_result = self.calculate_cascaded_constant_result(
self.operand1.constant_result)
......@@ -5212,6 +5245,13 @@ class CascadedCmpNode(Node, CmpNode):
cascade = None
constant_result = constant_value_not_set # FIXME: where to calculate this?
def infer_type(self, env):
# TODO: Actually implement this (after merging with -unstable).
return py_object_type
def type_dependencies(self, env):
return ()
def analyse_types(self, env, operand1):
self.operand2.analyse_types(env)
if self.cascade:
......@@ -5435,11 +5475,10 @@ class CoerceToPyTypeNode(CoercionNode):
# to a Python object.
type = py_object_type
is_temp = 1
def __init__(self, arg, env):
CoercionNode.__init__(self, arg)
self.type = py_object_type
self.is_temp = 1
if not arg.type.create_to_py_utility_code(env):
error(arg.pos,
"Cannot convert '%s' to Python object" % arg.type)
......@@ -5615,8 +5654,8 @@ class CloneNode(CoercionNode):
def result(self):
return self.arg.result()
def type_dependencies(self):
return self.arg.type_dependencies()
def type_dependencies(self, env):
return self.arg.type_dependencies(env)
def infer_type(self, env):
return self.arg.infer_type(env)
......
......@@ -77,6 +77,7 @@ class PyrexType(BaseType):
#
is_pyobject = 0
is_unspecified = 0
is_extension_type = 0
is_builtin_type = 0
is_numeric = 0
......@@ -1592,6 +1593,8 @@ class CUCharPtrType(CStringType, CPtrType):
class UnspecifiedType(PyrexType):
# Used as a placeholder until the type can be determined.
is_unspecified = 1
def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0):
return "<unspecified>"
......@@ -1788,6 +1791,20 @@ def widest_numeric_type(type1, type2):
return sign_and_rank_to_type[min(type1.signed, type2.signed), max(type1.rank, type2.rank)]
return widest_type
def spanning_type(type1, type2):
# Return a type assignable from both type1 and type2.
if type1 == type2:
return type1
elif type1.is_numeric and type2.is_numeric:
return widest_numeric_type(type1, type2)
elif type1.assignable_from(type2):
return type1
elif type2.assignable_from(type1):
return type2
else:
return py_object_type
def simple_c_type(signed, longness, name):
# Find type descriptor for simple type given name and modifiers.
# Returns None if arguments don't make sense.
......
......@@ -174,6 +174,9 @@ class Entry(object):
self.init = init
self.assignments = []
def __repr__(self):
return "Entry(name=%s, type=%s)" % (self.name, self.type)
def redeclared(self, pos):
error(pos, "'%s' does not match previous declaration" % self.name)
error(self.pos, "Previous declaration is here")
......@@ -546,10 +549,8 @@ class Scope(object):
return 0
def infer_types(self):
for name, entry in self.entries.items():
if entry.type is unspecified_type:
entry.type = py_object_type
entry.init_to_none = Options.init_local_none # TODO: is there a better place for this?
from TypeInference import get_type_inferer
get_type_inferer().infer_types(self)
class PreImportScope(Scope):
......@@ -1054,6 +1055,10 @@ class ModuleScope(Scope):
var_entry.is_readonly = 1
entry.as_variable = var_entry
def infer_types(self):
from TypeInference import PyObjectTypeInferer
PyObjectTypeInferer().infer_types(self)
class LocalScope(Scope):
def __init__(self, name, outer_scope):
......@@ -1084,7 +1089,7 @@ class LocalScope(Scope):
cname, visibility, is_cdef)
if type.is_pyobject and not Options.init_local_none:
entry.init = "0"
entry.init_to_none = type.is_pyobject and Options.init_local_none
entry.init_to_none = (type.is_pyobject or type.is_unspecified) and Options.init_local_none
entry.is_local = 1
self.var_entries.append(entry)
return entry
......
import ExprNodes
import PyrexTypes
from PyrexTypes import py_object_type, unspecified_type, spanning_type
from Visitor import CythonTransform
try:
set
except NameError:
# Python 2.3
from sets import Set as set
class TypedExprNode(ExprNodes.ExprNode):
# Used for declaring assignments of a specified type whithout a known entry.
def __init__(self, type):
self.type = type
object_expr = TypedExprNode(PyrexTypes.py_object_type)
object_expr = TypedExprNode(py_object_type)
class MarkAssignments(CythonTransform):
......@@ -42,7 +49,23 @@ class MarkAssignments(CythonTransform):
return node
def visit_ForInStatNode(self, node):
# TODO: Figure out how this interacts with the range optimization...
# TODO: Remove redundancy with range optimization...
sequence = node.iterator.sequence
if isinstance(sequence, ExprNodes.SimpleCallNode):
function = sequence.function
if sequence.self is None and \
isinstance(function, ExprNodes.NameNode) and \
function.name in ('range', 'xrange'):
self.mark_assignment(node.target, sequence.args[0])
if len(sequence.args) > 1:
self.mark_assignment(node.target, sequence.args[1])
if len(sequence.args) > 2:
self.mark_assignment(node.target,
ExprNodes.binop_node(node.pos,
'+',
sequence.args[0],
sequence.args[2]))
else:
self.mark_assignment(node.target, object_expr)
self.visitchildren(node)
return node
......@@ -50,7 +73,11 @@ class MarkAssignments(CythonTransform):
def visit_ForFromStatNode(self, node):
self.mark_assignment(node.target, node.bound1)
if node.step is not None:
self.mark_assignment(node.target, ExprNodes.binop_node(node.pos, '+', node.bound1, node.step))
self.mark_assignment(node.target,
ExprNodes.binop_node(node.pos,
'+',
node.bound1,
node.step))
self.visitchildren(node)
return node
......@@ -69,3 +96,79 @@ class MarkAssignments(CythonTransform):
self.mark_assignment(target, object_expr)
self.visitchildren(node)
return node
class PyObjectTypeInferer:
"""
If it's not declared, it's a PyObject.
"""
def infer_types(self, scope):
"""
Given a dict of entries, map all unspecified types to a specified type.
"""
for name, entry in scope.entries.items():
if entry.type is unspecified_type:
entry.type = py_object_type
class SimpleAssignmentTypeInferer:
"""
Very basic type inference.
"""
# TODO: Implement a real type inference algorithm.
# (Something more powerful than just extending this one...)
def infer_types(self, scope):
dependancies_by_entry = {} # entry -> dependancies
entries_by_dependancy = {} # dependancy -> entries
ready_to_infer = []
for name, entry in scope.entries.items():
if entry.type is unspecified_type:
all = set()
for expr in entry.assignments:
all.update(expr.type_dependencies(scope))
if all:
dependancies_by_entry[entry] = all
for dep in all:
if dep not in entries_by_dependancy:
entries_by_dependancy[dep] = set([entry])
else:
entries_by_dependancy[dep].add(entry)
else:
ready_to_infer.append(entry)
def resolve_dependancy(dep):
if dep in entries_by_dependancy:
for entry in entries_by_dependancy[dep]:
entry_deps = dependancies_by_entry[entry]
entry_deps.remove(dep)
if not entry_deps and entry != dep:
del dependancies_by_entry[entry]
ready_to_infer.append(entry)
# Try to infer things in order...
while ready_to_infer:
while ready_to_infer:
entry = ready_to_infer.pop()
types = [expr.infer_type(scope) for expr in entry.assignments]
if types:
entry.type = reduce(spanning_type, types)
else:
print "No assignments", entry.pos, entry
entry.type = py_object_type
resolve_dependancy(entry)
# Deal with simple circular dependancies...
for entry, deps in dependancies_by_entry.items():
if len(deps) == 1 and deps == set([entry]):
types = [expr.infer_type(scope) for expr in entry.assignments if expr.type_dependencies(scope) == ()]
if types:
entry.type = reduce(spanning_type, types)
types = [expr.infer_type(scope) for expr in entry.assignments]
entry.type = reduce(spanning_type, types) # might be wider...
resolve_dependancy(entry)
del dependancies_by_entry[entry]
if ready_to_infer:
break
# We can't figure out the rest with this algorithm, let them be objects.
for entry in dependancies_by_entry:
entry.type = py_object_type
def get_type_inferer():
return SimpleAssignmentTypeInferer()
print "starting"
def primes(int kmax):
cdef int n, k, i
# cdef int n, k, i
cdef int p[1000]
result = []
if kmax > 1000:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment