Commit 024a3d31 authored by Robert Bradshaw's avatar Robert Bradshaw

Actual type inference.

parent d3ef6b9c
...@@ -309,12 +309,12 @@ class ExprNode(Node): ...@@ -309,12 +309,12 @@ class ExprNode(Node):
# --------------- Type Inference ----------------- # --------------- Type Inference -----------------
def type_dependencies(self): def type_dependencies(self, env):
# Returns the list of entries whose types must be determined # Returns the list of entries whose types must be determined
# before the type of self can be infered. # before the type of self can be infered.
if hasattr(self, 'type') and self.type is not None: if hasattr(self, 'type') and self.type is not None:
return () return ()
return sum([node.type_dependencies() for node in self.subexpr_nodes()], ()) return sum([node.type_dependencies(env) for node in self.subexpr_nodes()], ())
def infer_type(self, env): def infer_type(self, env):
# Attempt to deduce the type of self. # Attempt to deduce the type of self.
...@@ -832,8 +832,9 @@ class StringNode(ConstNode): ...@@ -832,8 +832,9 @@ class StringNode(ConstNode):
def calculate_result_code(self): def calculate_result_code(self):
return self.result_code return self.result_code
class UnicodeNode(PyConstNode): class UnicodeNode(PyConstNode):
# entry Symtab.Entry
type = unicode_type type = unicode_type
def coerce_to(self, dst_type, env): def coerce_to(self, dst_type, env):
...@@ -976,8 +977,21 @@ class NameNode(AtomicExprNode): ...@@ -976,8 +977,21 @@ class NameNode(AtomicExprNode):
create_analysed_rvalue = staticmethod(create_analysed_rvalue) create_analysed_rvalue = staticmethod(create_analysed_rvalue)
def type_dependencies(self): def type_dependencies(self, env):
return self.entry if self.entry is None:
self.entry = env.lookup(self.name)
if self.entry is not None and self.entry.type.is_unspecified:
return (self.entry,)
else:
return ()
def infer_type(self, env):
if self.entry is None:
self.entry = env.lookup(self.name)
if self.entry is None:
return py_object_type
else:
return self.entry.type
def compile_time_value(self, denv): def compile_time_value(self, denv):
try: try:
...@@ -1628,8 +1642,8 @@ class IndexNode(ExprNode): ...@@ -1628,8 +1642,8 @@ class IndexNode(ExprNode):
return PyrexTypes.CArrayType(base_type, int(self.index.compile_time_value(env))) return PyrexTypes.CArrayType(base_type, int(self.index.compile_time_value(env)))
return None return None
def type_dependencies(self): def type_dependencies(self, env):
return self.base.type_dependencies() return self.base.type_dependencies(env)
def infer_type(self, env): def infer_type(self, env):
if isinstance(self.base, StringNode): if isinstance(self.base, StringNode):
...@@ -2251,10 +2265,10 @@ class SimpleCallNode(CallNode): ...@@ -2251,10 +2265,10 @@ class SimpleCallNode(CallNode):
except Exception, e: except Exception, e:
self.compile_time_value_error(e) self.compile_time_value_error(e)
def type_dependencies(self): def type_dependencies(self, env):
# TODO: Update when Danilo's C++ code merged in to handle the # TODO: Update when Danilo's C++ code merged in to handle the
# the case of function overloading. # the case of function overloading.
return self.function.type_dependencies() return self.function.type_dependencies(env)
def infer_type(self, env): def infer_type(self, env):
func_type = self.function.infer_type(env) func_type = self.function.infer_type(env)
...@@ -2705,13 +2719,16 @@ class AttributeNode(ExprNode): ...@@ -2705,13 +2719,16 @@ class AttributeNode(ExprNode):
except Exception, e: except Exception, e:
self.compile_time_value_error(e) self.compile_time_value_error(e)
def type_dependencies(self, env):
return self.obj.type_dependencies(env)
def infer_type(self, env): def infer_type(self, env):
if self.analyse_as_cimported_attribute(env, 0): if self.analyse_as_cimported_attribute(env, 0):
return self.entry.type return self.entry.type
elif self.analyse_as_unbound_cmethod(env): elif self.analyse_as_unbound_cmethod(env):
return self.entry.type return self.entry.type
else: else:
self.analyse_attribute(env) self.analyse_attribute(env, obj_type = self.obj.infer_type(env))
return self.type return self.type
def analyse_target_declaration(self, env): def analyse_target_declaration(self, env):
...@@ -2816,13 +2833,17 @@ class AttributeNode(ExprNode): ...@@ -2816,13 +2833,17 @@ class AttributeNode(ExprNode):
self.is_temp = 1 self.is_temp = 1
self.result_ctype = py_object_type self.result_ctype = py_object_type
def analyse_attribute(self, env): def analyse_attribute(self, env, obj_type = None):
# Look up attribute and set self.type and self.member. # Look up attribute and set self.type and self.member.
self.is_py_attr = 0 self.is_py_attr = 0
self.member = self.attribute self.member = self.attribute
if self.obj.type.is_string: if obj_type is None:
self.obj = self.obj.coerce_to_pyobject(env) if self.obj.type.is_string:
obj_type = self.obj.type self.obj = self.obj.coerce_to_pyobject(env)
obj_type = self.obj.type
else:
if obj_type.is_string:
obj_type = py_object_type
if obj_type.is_ptr or obj_type.is_array: if obj_type.is_ptr or obj_type.is_array:
obj_type = obj_type.base_type obj_type = obj_type.base_type
self.op = "->" self.op = "->"
...@@ -2861,10 +2882,11 @@ class AttributeNode(ExprNode): ...@@ -2861,10 +2882,11 @@ class AttributeNode(ExprNode):
# type, or it is an extension type and the attribute is either not # type, or it is an extension type and the attribute is either not
# declared or is declared as a Python method. Treat it as a Python # declared or is declared as a Python method. Treat it as a Python
# attribute reference. # attribute reference.
self.analyse_as_python_attribute(env) self.analyse_as_python_attribute(env, obj_type)
def analyse_as_python_attribute(self, env): def analyse_as_python_attribute(self, env, obj_type = None):
obj_type = self.obj.type if obj_type is None:
obj_type = self.obj.type
self.member = self.attribute self.member = self.attribute
if obj_type.is_pyobject: if obj_type.is_pyobject:
self.type = py_object_type self.type = py_object_type
...@@ -3017,6 +3039,7 @@ class StarredTargetNode(ExprNode): ...@@ -3017,6 +3039,7 @@ class StarredTargetNode(ExprNode):
subexprs = ['target'] subexprs = ['target']
is_starred = 1 is_starred = 1
type = py_object_type type = py_object_type
is_temp = 1
def __init__(self, pos, target): def __init__(self, pos, target):
self.pos = pos self.pos = pos
...@@ -3347,7 +3370,7 @@ class ListNode(SequenceNode): ...@@ -3347,7 +3370,7 @@ class ListNode(SequenceNode):
gil_message = "Constructing Python list" gil_message = "Constructing Python list"
def type_dependencies(self): def type_dependencies(self, env):
return () return ()
def infer_type(self, env): def infer_type(self, env):
...@@ -3608,7 +3631,7 @@ class DictNode(ExprNode): ...@@ -3608,7 +3631,7 @@ class DictNode(ExprNode):
except Exception, e: except Exception, e:
self.compile_time_value_error(e) self.compile_time_value_error(e)
def type_dependencies(self): def type_dependencies(self, env):
return () return ()
def infer_type(self, env): def infer_type(self, env):
...@@ -4064,10 +4087,10 @@ class TypecastNode(ExprNode): ...@@ -4064,10 +4087,10 @@ class TypecastNode(ExprNode):
subexprs = ['operand'] subexprs = ['operand']
base_type = declarator = type = None base_type = declarator = type = None
def type_dependencies(self): def type_dependencies(self, env):
return () return ()
def infer_types(self, env): def infer_type(self, env):
if self.type is None: if self.type is None:
base_type = self.base_type.analyse(env) base_type = self.base_type.analyse(env)
_, self.type = self.declarator.analyse(base_type, env) _, self.type = self.declarator.analyse(base_type, env)
...@@ -4297,8 +4320,8 @@ class BinopNode(ExprNode): ...@@ -4297,8 +4320,8 @@ class BinopNode(ExprNode):
def infer_type(self, env): def infer_type(self, env):
return self.result_type(self.operand1.infer_type(env), return self.result_type(self.operand1.infer_type(env),
self.operand1.infer_type(env)) self.operand2.infer_type(env))
def analyse_types(self, env): def analyse_types(self, env):
self.operand1.analyse_types(env) self.operand1.analyse_types(env)
self.operand2.analyse_types(env) self.operand2.analyse_types(env)
...@@ -4821,9 +4844,12 @@ class CondExprNode(ExprNode): ...@@ -4821,9 +4844,12 @@ class CondExprNode(ExprNode):
subexprs = ['test', 'true_val', 'false_val'] subexprs = ['test', 'true_val', 'false_val']
def type_dependencies(self): def type_dependencies(self, env):
return self.true_val.type_dependencies() + self.false_val.type_dependencies() return self.true_val.type_dependencies(env) + self.false_val.type_dependencies(env)
def infer_type(self, env):
return self.compute_result_type(self.true_val.infer_type(env),
self.false_val.infer_type(env))
def infer_types(self, env): def infer_types(self, env):
return self.compute_result_type(self.true_val.infer_types(env), return self.compute_result_type(self.true_val.infer_types(env),
self.false_val.infer_types(env)) self.false_val.infer_types(env))
...@@ -5078,6 +5104,13 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -5078,6 +5104,13 @@ class PrimaryCmpNode(ExprNode, CmpNode):
cascade = None cascade = None
def infer_type(self, env):
# TODO: Actually implement this (after merging with -unstable).
return py_object_type
def type_dependencies(self, env):
return ()
def calculate_constant_result(self): def calculate_constant_result(self):
self.constant_result = self.calculate_cascaded_constant_result( self.constant_result = self.calculate_cascaded_constant_result(
self.operand1.constant_result) self.operand1.constant_result)
...@@ -5212,6 +5245,13 @@ class CascadedCmpNode(Node, CmpNode): ...@@ -5212,6 +5245,13 @@ class CascadedCmpNode(Node, CmpNode):
cascade = None cascade = None
constant_result = constant_value_not_set # FIXME: where to calculate this? constant_result = constant_value_not_set # FIXME: where to calculate this?
def infer_type(self, env):
# TODO: Actually implement this (after merging with -unstable).
return py_object_type
def type_dependencies(self, env):
return ()
def analyse_types(self, env, operand1): def analyse_types(self, env, operand1):
self.operand2.analyse_types(env) self.operand2.analyse_types(env)
if self.cascade: if self.cascade:
...@@ -5435,11 +5475,10 @@ class CoerceToPyTypeNode(CoercionNode): ...@@ -5435,11 +5475,10 @@ class CoerceToPyTypeNode(CoercionNode):
# to a Python object. # to a Python object.
type = py_object_type type = py_object_type
is_temp = 1
def __init__(self, arg, env): def __init__(self, arg, env):
CoercionNode.__init__(self, arg) CoercionNode.__init__(self, arg)
self.type = py_object_type
self.is_temp = 1
if not arg.type.create_to_py_utility_code(env): if not arg.type.create_to_py_utility_code(env):
error(arg.pos, error(arg.pos,
"Cannot convert '%s' to Python object" % arg.type) "Cannot convert '%s' to Python object" % arg.type)
...@@ -5611,16 +5650,16 @@ class CloneNode(CoercionNode): ...@@ -5611,16 +5650,16 @@ class CloneNode(CoercionNode):
self.result_ctype = arg.result_ctype self.result_ctype = arg.result_ctype
if hasattr(arg, 'entry'): if hasattr(arg, 'entry'):
self.entry = arg.entry self.entry = arg.entry
def result(self): def result(self):
return self.arg.result() return self.arg.result()
def type_dependencies(self): def type_dependencies(self, env):
return self.arg.type_dependencies() return self.arg.type_dependencies(env)
def infer_type(self, env): def infer_type(self, env):
return self.arg.infer_type(env) return self.arg.infer_type(env)
def analyse_types(self, env): def analyse_types(self, env):
self.type = self.arg.type self.type = self.arg.type
self.result_ctype = self.arg.result_ctype self.result_ctype = self.arg.result_ctype
......
...@@ -77,6 +77,7 @@ class PyrexType(BaseType): ...@@ -77,6 +77,7 @@ class PyrexType(BaseType):
# #
is_pyobject = 0 is_pyobject = 0
is_unspecified = 0
is_extension_type = 0 is_extension_type = 0
is_builtin_type = 0 is_builtin_type = 0
is_numeric = 0 is_numeric = 0
...@@ -1591,6 +1592,8 @@ class CUCharPtrType(CStringType, CPtrType): ...@@ -1591,6 +1592,8 @@ class CUCharPtrType(CStringType, CPtrType):
class UnspecifiedType(PyrexType): class UnspecifiedType(PyrexType):
# Used as a placeholder until the type can be determined. # Used as a placeholder until the type can be determined.
is_unspecified = 1
def declaration_code(self, entity_code, def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0): for_display = 0, dll_linkage = None, pyrex = 0):
...@@ -1788,6 +1791,20 @@ def widest_numeric_type(type1, type2): ...@@ -1788,6 +1791,20 @@ def widest_numeric_type(type1, type2):
return sign_and_rank_to_type[min(type1.signed, type2.signed), max(type1.rank, type2.rank)] return sign_and_rank_to_type[min(type1.signed, type2.signed), max(type1.rank, type2.rank)]
return widest_type return widest_type
def spanning_type(type1, type2):
# Return a type assignable from both type1 and type2.
if type1 == type2:
return type1
elif type1.is_numeric and type2.is_numeric:
return widest_numeric_type(type1, type2)
elif type1.assignable_from(type2):
return type1
elif type2.assignable_from(type1):
return type2
else:
return py_object_type
def simple_c_type(signed, longness, name): def simple_c_type(signed, longness, name):
# Find type descriptor for simple type given name and modifiers. # Find type descriptor for simple type given name and modifiers.
# Returns None if arguments don't make sense. # Returns None if arguments don't make sense.
......
...@@ -173,6 +173,9 @@ class Entry(object): ...@@ -173,6 +173,9 @@ class Entry(object):
self.pos = pos self.pos = pos
self.init = init self.init = init
self.assignments = [] self.assignments = []
def __repr__(self):
return "Entry(name=%s, type=%s)" % (self.name, self.type)
def redeclared(self, pos): def redeclared(self, pos):
error(pos, "'%s' does not match previous declaration" % self.name) error(pos, "'%s' does not match previous declaration" % self.name)
...@@ -546,10 +549,8 @@ class Scope(object): ...@@ -546,10 +549,8 @@ class Scope(object):
return 0 return 0
def infer_types(self): def infer_types(self):
for name, entry in self.entries.items(): from TypeInference import get_type_inferer
if entry.type is unspecified_type: get_type_inferer().infer_types(self)
entry.type = py_object_type
entry.init_to_none = Options.init_local_none # TODO: is there a better place for this?
class PreImportScope(Scope): class PreImportScope(Scope):
...@@ -1053,6 +1054,10 @@ class ModuleScope(Scope): ...@@ -1053,6 +1054,10 @@ class ModuleScope(Scope):
var_entry.is_cglobal = 1 var_entry.is_cglobal = 1
var_entry.is_readonly = 1 var_entry.is_readonly = 1
entry.as_variable = var_entry entry.as_variable = var_entry
def infer_types(self):
from TypeInference import PyObjectTypeInferer
PyObjectTypeInferer().infer_types(self)
class LocalScope(Scope): class LocalScope(Scope):
...@@ -1084,7 +1089,7 @@ class LocalScope(Scope): ...@@ -1084,7 +1089,7 @@ class LocalScope(Scope):
cname, visibility, is_cdef) cname, visibility, is_cdef)
if type.is_pyobject and not Options.init_local_none: if type.is_pyobject and not Options.init_local_none:
entry.init = "0" entry.init = "0"
entry.init_to_none = type.is_pyobject and Options.init_local_none entry.init_to_none = (type.is_pyobject or type.is_unspecified) and Options.init_local_none
entry.is_local = 1 entry.is_local = 1
self.var_entries.append(entry) self.var_entries.append(entry)
return entry return entry
......
import ExprNodes import ExprNodes
import PyrexTypes from PyrexTypes import py_object_type, unspecified_type, spanning_type
from Visitor import CythonTransform from Visitor import CythonTransform
try:
set
except NameError:
# Python 2.3
from sets import Set as set
class TypedExprNode(ExprNodes.ExprNode): class TypedExprNode(ExprNodes.ExprNode):
# Used for declaring assignments of a specified type whithout a known entry. # Used for declaring assignments of a specified type whithout a known entry.
def __init__(self, type): def __init__(self, type):
self.type = type self.type = type
object_expr = TypedExprNode(PyrexTypes.py_object_type) object_expr = TypedExprNode(py_object_type)
class MarkAssignments(CythonTransform): class MarkAssignments(CythonTransform):
...@@ -42,15 +49,35 @@ class MarkAssignments(CythonTransform): ...@@ -42,15 +49,35 @@ class MarkAssignments(CythonTransform):
return node return node
def visit_ForInStatNode(self, node): def visit_ForInStatNode(self, node):
# TODO: Figure out how this interacts with the range optimization... # TODO: Remove redundancy with range optimization...
self.mark_assignment(node.target, object_expr) sequence = node.iterator.sequence
if isinstance(sequence, ExprNodes.SimpleCallNode):
function = sequence.function
if sequence.self is None and \
isinstance(function, ExprNodes.NameNode) and \
function.name in ('range', 'xrange'):
self.mark_assignment(node.target, sequence.args[0])
if len(sequence.args) > 1:
self.mark_assignment(node.target, sequence.args[1])
if len(sequence.args) > 2:
self.mark_assignment(node.target,
ExprNodes.binop_node(node.pos,
'+',
sequence.args[0],
sequence.args[2]))
else:
self.mark_assignment(node.target, object_expr)
self.visitchildren(node) self.visitchildren(node)
return node return node
def visit_ForFromStatNode(self, node): def visit_ForFromStatNode(self, node):
self.mark_assignment(node.target, node.bound1) self.mark_assignment(node.target, node.bound1)
if node.step is not None: if node.step is not None:
self.mark_assignment(node.target, ExprNodes.binop_node(node.pos, '+', node.bound1, node.step)) self.mark_assignment(node.target,
ExprNodes.binop_node(node.pos,
'+',
node.bound1,
node.step))
self.visitchildren(node) self.visitchildren(node)
return node return node
...@@ -69,3 +96,79 @@ class MarkAssignments(CythonTransform): ...@@ -69,3 +96,79 @@ class MarkAssignments(CythonTransform):
self.mark_assignment(target, object_expr) self.mark_assignment(target, object_expr)
self.visitchildren(node) self.visitchildren(node)
return node return node
class PyObjectTypeInferer:
"""
If it's not declared, it's a PyObject.
"""
def infer_types(self, scope):
"""
Given a dict of entries, map all unspecified types to a specified type.
"""
for name, entry in scope.entries.items():
if entry.type is unspecified_type:
entry.type = py_object_type
class SimpleAssignmentTypeInferer:
"""
Very basic type inference.
"""
# TODO: Implement a real type inference algorithm.
# (Something more powerful than just extending this one...)
def infer_types(self, scope):
dependancies_by_entry = {} # entry -> dependancies
entries_by_dependancy = {} # dependancy -> entries
ready_to_infer = []
for name, entry in scope.entries.items():
if entry.type is unspecified_type:
all = set()
for expr in entry.assignments:
all.update(expr.type_dependencies(scope))
if all:
dependancies_by_entry[entry] = all
for dep in all:
if dep not in entries_by_dependancy:
entries_by_dependancy[dep] = set([entry])
else:
entries_by_dependancy[dep].add(entry)
else:
ready_to_infer.append(entry)
def resolve_dependancy(dep):
if dep in entries_by_dependancy:
for entry in entries_by_dependancy[dep]:
entry_deps = dependancies_by_entry[entry]
entry_deps.remove(dep)
if not entry_deps and entry != dep:
del dependancies_by_entry[entry]
ready_to_infer.append(entry)
# Try to infer things in order...
while ready_to_infer:
while ready_to_infer:
entry = ready_to_infer.pop()
types = [expr.infer_type(scope) for expr in entry.assignments]
if types:
entry.type = reduce(spanning_type, types)
else:
print "No assignments", entry.pos, entry
entry.type = py_object_type
resolve_dependancy(entry)
# Deal with simple circular dependancies...
for entry, deps in dependancies_by_entry.items():
if len(deps) == 1 and deps == set([entry]):
types = [expr.infer_type(scope) for expr in entry.assignments if expr.type_dependencies(scope) == ()]
if types:
entry.type = reduce(spanning_type, types)
types = [expr.infer_type(scope) for expr in entry.assignments]
entry.type = reduce(spanning_type, types) # might be wider...
resolve_dependancy(entry)
del dependancies_by_entry[entry]
if ready_to_infer:
break
# We can't figure out the rest with this algorithm, let them be objects.
for entry in dependancies_by_entry:
entry.type = py_object_type
def get_type_inferer():
return SimpleAssignmentTypeInferer()
print "starting" print "starting"
def primes(int kmax): def primes(int kmax):
cdef int n, k, i # cdef int n, k, i
cdef int p[1000] cdef int p[1000]
result = [] result = []
if kmax > 1000: if kmax > 1000:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment