Commit d3ef6b9c authored by Robert Bradshaw's avatar Robert Bradshaw

Type inference methods for expression nodes.

parent cf6abbcf
......@@ -307,6 +307,27 @@ class ExprNode(Node):
temp_bool = bool.coerce_to_temp(env)
return temp_bool
# --------------- Type Inference -----------------
def type_dependencies(self):
# Returns the list of entries whose types must be determined
# before the type of self can be infered.
if hasattr(self, 'type') and self.type is not None:
return ()
return sum([node.type_dependencies() for node in self.subexpr_nodes()], ())
def infer_type(self, env):
# Attempt to deduce the type of self.
# Differs from analyse_types as it avoids unnecessary
# analysis of subexpressions, but can assume everything
# in self.type_dependencies() has been resolved.
if hasattr(self, 'type') and self.type is not None:
return self.type
elif hasattr(self, 'entry') and self.entry is not None:
return self.entry.type
else:
self.not_implemented("infer_type")
# --------------- Type Analysis ------------------
def analyse_as_module(self, env):
......@@ -858,6 +879,8 @@ class LongNode(AtomicExprNode):
#
# value string
type = py_object_type
def calculate_constant_result(self):
self.constant_result = long(self.value)
......@@ -865,7 +888,6 @@ class LongNode(AtomicExprNode):
return long(self.value)
def analyse_types(self, env):
self.type = py_object_type
self.is_temp = 1
gil_message = "Constructing Python long int"
......@@ -954,6 +976,9 @@ class NameNode(AtomicExprNode):
create_analysed_rvalue = staticmethod(create_analysed_rvalue)
def type_dependencies(self):
return self.entry
def compile_time_value(self, denv):
try:
return denv.lookup(self.name)
......@@ -1298,12 +1323,13 @@ class BackquoteNode(ExprNode):
#
# arg ExprNode
type = py_object_type
subexprs = ['arg']
def analyse_types(self, env):
self.arg.analyse_types(env)
self.arg = self.arg.coerce_to_pyobject(env)
self.type = py_object_type
self.is_temp = 1
gil_message = "Backquote expression"
......@@ -1329,6 +1355,8 @@ class ImportNode(ExprNode):
# module_name IdentifierStringNode dotted name of module
# name_list ListNode or None list of names to be imported
type = py_object_type
subexprs = ['module_name', 'name_list']
def analyse_types(self, env):
......@@ -1337,7 +1365,6 @@ class ImportNode(ExprNode):
if self.name_list:
self.name_list.analyse_types(env)
self.name_list.coerce_to_pyobject(env)
self.type = py_object_type
self.is_temp = 1
env.use_utility_code(import_utility_code)
......@@ -1367,12 +1394,13 @@ class IteratorNode(ExprNode):
#
# sequence ExprNode
type = py_object_type
subexprs = ['sequence']
def analyse_types(self, env):
self.sequence.analyse_types(env)
self.sequence = self.sequence.coerce_to_pyobject(env)
self.type = py_object_type
self.is_temp = 1
gil_message = "Iterating over Python object"
......@@ -1424,10 +1452,11 @@ class NextNode(AtomicExprNode):
#
# iterator ExprNode
type = py_object_type
def __init__(self, iterator, env):
self.pos = iterator.pos
self.iterator = iterator
self.type = py_object_type
self.is_temp = 1
def generate_result_code(self, code):
......@@ -1480,9 +1509,10 @@ class ExcValueNode(AtomicExprNode):
# of an ExceptClauseNode to fetch the current
# exception value.
type = py_object_type
def __init__(self, pos, env):
ExprNode.__init__(self, pos)
self.type = py_object_type
def set_var(self, var):
self.var = var
......@@ -1598,6 +1628,19 @@ class IndexNode(ExprNode):
return PyrexTypes.CArrayType(base_type, int(self.index.compile_time_value(env)))
return None
def type_dependencies(self):
return self.base.type_dependencies()
def infer_type(self, env):
if isinstance(self.base, StringNode):
return py_object_type
base_type = self.base.infer_type(env)
if base_type.is_ptr or base_type.is_array:
return base_type.base_type
else:
# TODO: Handle buffers (hopefully without too much redundancy).
return py_object_type
def analyse_types(self, env):
self.analyse_base_and_index_types(env, getting = 1)
......@@ -2107,6 +2150,9 @@ class SliceNode(ExprNode):
# stop ExprNode
# step ExprNode
type = py_object_type
is_temp = 1
def calculate_constant_result(self):
self.constant_result = self.base.constant_result[
self.start.constant_result : \
......@@ -2137,8 +2183,6 @@ class SliceNode(ExprNode):
self.start = self.start.coerce_to_pyobject(env)
self.stop = self.stop.coerce_to_pyobject(env)
self.step = self.step.coerce_to_pyobject(env)
self.type = py_object_type
self.is_temp = 1
gil_message = "Constructing Python slice object"
......@@ -2154,6 +2198,7 @@ class SliceNode(ExprNode):
class CallNode(ExprNode):
def analyse_as_type_constructor(self, env):
type = self.function.analyse_as_type(env)
if type and type.is_struct_or_union:
......@@ -2206,6 +2251,20 @@ class SimpleCallNode(CallNode):
except Exception, e:
self.compile_time_value_error(e)
def type_dependencies(self):
# TODO: Update when Danilo's C++ code merged in to handle the
# the case of function overloading.
return self.function.type_dependencies()
def infer_type(self, env):
func_type = self.function.infer_type(env)
if func_type.is_ptr:
func_type = func_type.base_type
if func_type.is_cfunction:
return func_type.return_type
else:
return py_object_type
def analyse_as_type(self, env):
attr = self.function.as_cython_attribute()
if attr == 'pointer':
......@@ -2466,6 +2525,8 @@ class GeneralCallNode(CallNode):
# keyword_args ExprNode or None Dict of keyword arguments
# starstar_arg ExprNode or None Dict of extra keyword args
type = py_object_type
subexprs = ['function', 'positional_args', 'keyword_args', 'starstar_arg']
nogil_check = Node.gil_error
......@@ -2644,6 +2705,15 @@ class AttributeNode(ExprNode):
except Exception, e:
self.compile_time_value_error(e)
def infer_type(self, env):
if self.analyse_as_cimported_attribute(env, 0):
return self.entry.type
elif self.analyse_as_unbound_cmethod(env):
return self.entry.type
else:
self.analyse_attribute(env)
return self.type
def analyse_target_declaration(self, env):
pass
......@@ -3208,6 +3278,8 @@ class SequenceNode(ExprNode):
class TupleNode(SequenceNode):
# Tuple constructor.
type = tuple_type
gil_message = "Constructing Python tuple"
def analyse_types(self, env, skip_children=False):
......@@ -3216,7 +3288,6 @@ class TupleNode(SequenceNode):
self.is_literal = 1
else:
SequenceNode.analyse_types(self, env, skip_children)
self.type = tuple_type
def calculate_result_code(self):
if len(self.args) > 0:
......@@ -3276,6 +3347,13 @@ class ListNode(SequenceNode):
gil_message = "Constructing Python list"
def type_dependencies(self):
return ()
def infer_type(self, env):
# TOOD: Infer non-object list arrays.
return list_type
def analyse_expressions(self, env):
SequenceNode.analyse_expressions(self, env)
self.coerce_to_pyobject(env)
......@@ -3382,6 +3460,9 @@ class ComprehensionNode(ExprNode):
subexprs = ["target"]
child_attrs = ["loop", "append"]
def infer_type(self, env):
return self.target.infer_type(env)
def analyse_types(self, env):
self.target.analyse_expressions(env)
self.type = self.target.type
......@@ -3458,6 +3539,8 @@ class DictComprehensionAppendNode(ComprehensionAppendNode):
class SetNode(ExprNode):
# Set constructor.
type = set_type
subexprs = ['args']
gil_message = "Constructing Python set"
......@@ -3525,6 +3608,13 @@ class DictNode(ExprNode):
except Exception, e:
self.compile_time_value_error(e)
def type_dependencies(self):
return ()
def infer_type(self, env):
# TOOD: Infer struct constructors.
return dict_type
def analyse_types(self, env):
hold_errors()
for item in self.key_value_pairs:
......@@ -3688,12 +3778,13 @@ class UnboundMethodNode(ExprNode):
#
# function ExprNode Function object
type = py_object_type
is_temp = 1
subexprs = ['function']
def analyse_types(self, env):
self.function.analyse_types(env)
self.type = py_object_type
self.is_temp = 1
gil_message = "Constructing an unbound method"
......@@ -3714,9 +3805,11 @@ class PyCFunctionNode(AtomicExprNode):
#
# pymethdef_cname string PyMethodDef structure
type = py_object_type
is_temp = 1
def analyse_types(self, env):
self.type = py_object_type
self.is_temp = 1
pass
gil_message = "Constructing Python function"
......@@ -3772,6 +3865,9 @@ class UnopNode(ExprNode):
except Exception, e:
self.compile_time_value_error(e)
def infer_type(self, env):
return self.operand.infer_type(env)
def analyse_types(self, env):
self.operand.analyse_types(env)
if self.is_py_operation():
......@@ -3820,6 +3916,10 @@ class NotNode(ExprNode):
#
# operand ExprNode
type = PyrexTypes.c_bint_type
subexprs = ['operand']
def calculate_constant_result(self):
self.constant_result = not self.operand.constant_result
......@@ -3830,12 +3930,12 @@ class NotNode(ExprNode):
except Exception, e:
self.compile_time_value_error(e)
subexprs = ['operand']
def infer_type(self, env):
return PyrexTypes.c_bint_type
def analyse_types(self, env):
self.operand.analyse_types(env)
self.operand = self.operand.coerce_to_boolean(env)
self.type = PyrexTypes.c_bint_type
def calculate_result_code(self):
return "(!%s)" % self.operand.result()
......@@ -3904,6 +4004,9 @@ class AmpersandNode(ExprNode):
subexprs = ['operand']
def infer_type(self, env):
return PyrexTypes.c_ptr_type(self.operand.infer_type(env))
def analyse_types(self, env):
self.operand.analyse_types(env)
argtype = self.operand.type
......@@ -3961,6 +4064,15 @@ class TypecastNode(ExprNode):
subexprs = ['operand']
base_type = declarator = type = None
def type_dependencies(self):
return ()
def infer_types(self, env):
if self.type is None:
base_type = self.base_type.analyse(env)
_, self.type = self.declarator.analyse(base_type, env)
return self.type
def analyse_types(self, env):
if self.type is None:
base_type = self.base_type.analyse(env)
......@@ -4183,6 +4295,10 @@ class BinopNode(ExprNode):
except Exception, e:
self.compile_time_value_error(e)
def infer_type(self, env):
return self.result_type(self.operand1.infer_type(env),
self.operand1.infer_type(env))
def analyse_types(self, env):
self.operand1.analyse_types(env)
self.operand2.analyse_types(env)
......@@ -4196,8 +4312,16 @@ class BinopNode(ExprNode):
self.analyse_c_operation(env)
def is_py_operation(self):
return (self.operand1.type.is_pyobject
or self.operand2.type.is_pyobject)
return self.is_py_operation_types(self.operand1.type, self.operand2.type)
def is_py_operation_types(self, type1, type2):
return type1.is_pyobject or type2.is_pyobject
def result_type(self, type1, type2):
if self.is_py_operation_types(type1, type2):
return py_object_type
else:
return self.compute_c_result_type(type1, type2)
def nogil_check(self, env):
if self.is_py_operation():
......@@ -4321,12 +4445,11 @@ class IntBinopNode(NumBinopNode):
class AddNode(NumBinopNode):
# '+' operator.
def is_py_operation(self):
if self.operand1.type.is_string \
and self.operand2.type.is_string:
def is_py_operation_types(self, type1, type2):
if type1.is_string and type2.is_string:
return 1
else:
return NumBinopNode.is_py_operation(self)
return NumBinopNode.is_py_operation_types(self, type1, type2)
def compute_c_result_type(self, type1, type2):
#print "AddNode.compute_c_result_type:", type1, self.operator, type2 ###
......@@ -4355,14 +4478,12 @@ class SubNode(NumBinopNode):
class MulNode(NumBinopNode):
# '*' operator.
def is_py_operation(self):
type1 = self.operand1.type
type2 = self.operand2.type
def is_py_operation_types(self, type1, type2):
if (type1.is_string and type2.is_int) \
or (type2.is_string and type1.is_int):
return 1
else:
return NumBinopNode.is_py_operation(self)
return NumBinopNode.is_py_operation_types(self, type1, type2)
class DivNode(NumBinopNode):
......@@ -4499,10 +4620,10 @@ class DivNode(NumBinopNode):
class ModNode(DivNode):
# '%' operator.
def is_py_operation(self):
return (self.operand1.type.is_string
or self.operand2.type.is_string
or NumBinopNode.is_py_operation(self))
def is_py_operation_types(self, type1, type2):
return (type1.is_string
or type2.is_string
or NumBinopNode.is_py_operation_types(self, type1, type2))
def zero_division_message(self):
if self.type.is_int:
......@@ -4579,6 +4700,13 @@ class BoolBinopNode(ExprNode):
subexprs = ['operand1', 'operand2']
def infer_type(self, env):
if (self.operand1.infer_type(env).is_pyobject or
self.operand2.infer_type(env).is_pyobject):
return py_object_type
else:
return PyrexTypes.c_bint_type
def calculate_constant_result(self):
if self.operator == 'and':
self.constant_result = \
......@@ -4693,6 +4821,13 @@ class CondExprNode(ExprNode):
subexprs = ['test', 'true_val', 'false_val']
def type_dependencies(self):
return self.true_val.type_dependencies() + self.false_val.type_dependencies()
def infer_types(self, env):
return self.compute_result_type(self.true_val.infer_types(env),
self.false_val.infer_types(env))
def calculate_constant_result(self):
if self.test.constant_result:
self.constant_result = self.true_val.constant_result
......@@ -4777,6 +4912,10 @@ class CmpNode(object):
# Mixin class containing code common to PrimaryCmpNodes
# and CascadedCmpNodes.
def infer_types(self, env):
# TODO: Actually implement this (after merging with -unstable).
return py_object_type
def calculate_cascaded_constant_result(self, operand1_result):
func = compile_time_binary_operators[self.operator]
operand2_result = self.operand2.constant_result
......@@ -5295,6 +5434,8 @@ class CoerceToPyTypeNode(CoercionNode):
# This node is used to convert a C data type
# to a Python object.
type = py_object_type
def __init__(self, arg, env):
CoercionNode.__init__(self, arg)
self.type = py_object_type
......@@ -5366,9 +5507,10 @@ class CoerceToBooleanNode(CoercionNode):
# This node is used when a result needs to be used
# in a boolean context.
type = PyrexTypes.c_bint_type
def __init__(self, arg, env):
CoercionNode.__init__(self, arg)
self.type = PyrexTypes.c_bint_type
if arg.type.is_pyobject:
self.is_temp = 1
......@@ -5473,6 +5615,12 @@ class CloneNode(CoercionNode):
def result(self):
return self.arg.result()
def type_dependencies(self):
return self.arg.type_dependencies()
def infer_type(self, env):
return self.arg.infer_type(env)
def analyse_types(self, env):
self.type = self.arg.type
self.result_ctype = self.arg.result_ctype
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment