Commit 698f7535 authored by scoder's avatar scoder

Merge pull request #118 from vitek/_markassignments

Use assignments collected by CF for type inference
parents 7c852f79 e99f6ac7
......@@ -26,6 +26,7 @@ cdef class ExitBlock(ControlBlock):
cdef class NameAssignment:
cdef public bint is_arg
cdef public bint is_deletion
cdef public object lhs
cdef public object rhs
cdef public object entry
......
......@@ -8,7 +8,8 @@ cython.declare(PyrexTypes=object, Naming=object, ExprNodes=object, Nodes=object,
import Builtin
import ExprNodes
import Nodes
from PyrexTypes import py_object_type
from PyrexTypes import py_object_type, unspecified_type
import PyrexTypes
from Visitor import TreeVisitor, CythonTransform
from Errors import error, warning, InternalError
......@@ -24,6 +25,9 @@ class TypedExprNode(ExprNodes.ExprNode):
object_expr = TypedExprNode(py_object_type, may_be_none=True)
object_expr_not_none = TypedExprNode(py_object_type, may_be_none=False)
# Fake rhs to silence "unused variable" warning
fake_rhs_expr = TypedExprNode(unspecified_type)
class ControlBlock(object):
"""Control flow graph node. Sequence of assignments and name references.
......@@ -174,7 +178,7 @@ class ControlFlow(object):
def mark_deletion(self, node, entry):
if self.block and self.is_tracked(entry):
assignment = NameAssignment(node, None, entry)
assignment = NameDeletion(node, entry)
self.block.stats.append(assignment)
self.block.gen[entry] = Uninitialized
self.entries.add(entry)
......@@ -293,6 +297,7 @@ class ExceptionDescr(object):
self.finally_enter = finally_enter
self.finally_exit = finally_exit
class NameAssignment(object):
def __init__(self, lhs, rhs, entry):
if lhs.cf_state is None:
......@@ -303,15 +308,24 @@ class NameAssignment(object):
self.pos = lhs.pos
self.refs = set()
self.is_arg = False
self.is_deletion = False
def __repr__(self):
return '%s(entry=%r)' % (self.__class__.__name__, self.entry)
class Argument(NameAssignment):
def __init__(self, lhs, rhs, entry):
NameAssignment.__init__(self, lhs, rhs, entry)
self.is_arg = True
class NameDeletion(NameAssignment):
def __init__(self, lhs, entry):
NameAssignment.__init__(self, lhs, lhs, entry)
self.is_deletion = True
class Uninitialized(object):
pass
......@@ -462,12 +476,13 @@ def check_definitions(flow, compiler_directives):
stat.lhs.cf_state.update(state)
assmt_nodes.add(stat.lhs)
i_state = i_state & ~i_assmts.mask
if stat.rhs:
i_state |= stat.bit
else:
if stat.is_deletion:
i_state |= i_assmts.bit
else:
i_state |= stat.bit
assignments.add(stat)
stat.entry.cf_assignments.append(stat)
if stat.rhs is not fake_rhs_expr:
stat.entry.cf_assignments.append(stat)
elif isinstance(stat, NameReference):
references[stat.node] = stat.entry
stat.entry.cf_references.append(stat)
......@@ -754,7 +769,8 @@ class ControlFlowAnalysis(CythonTransform):
entry = self.env.lookup(node.name)
if entry:
may_be_none = not node.not_none
self.flow.mark_argument(node, TypedExprNode(entry.type, may_be_none), entry)
self.flow.mark_argument(
node, TypedExprNode(entry.type, may_be_none), entry)
return node
def visit_NameNode(self, node):
......@@ -838,6 +854,59 @@ class ControlFlowAnalysis(CythonTransform):
self.flow.block = None
return node
def mark_forloop_target(self, node):
# TODO: Remove redundancy with range optimization...
is_special = False
sequence = node.iterator.sequence
target = node.target
if isinstance(sequence, ExprNodes.SimpleCallNode):
function = sequence.function
if sequence.self is None and function.is_name:
entry = self.env.lookup(function.name)
if not entry or entry.is_builtin:
if function.name == 'reversed' and len(sequence.args) == 1:
sequence = sequence.args[0]
elif function.name == 'enumerate' and len(sequence.args) == 1:
if target.is_sequence_constructor and len(target.args) == 2:
iterator = sequence.args[0]
if iterator.is_name:
iterator_type = iterator.infer_type(self.env)
if iterator_type.is_builtin_type:
# assume that builtin types have a length within Py_ssize_t
self.mark_assignment(
target.args[0],
ExprNodes.IntNode(target.pos, value='PY_SSIZE_T_MAX',
type=PyrexTypes.c_py_ssize_t_type))
target = target.args[1]
sequence = sequence.args[0]
if isinstance(sequence, ExprNodes.SimpleCallNode):
function = sequence.function
if sequence.self is None and function.is_name:
entry = self.env.lookup(function.name)
if not entry or entry.is_builtin:
if function.name in ('range', 'xrange'):
is_special = True
for arg in sequence.args[:2]:
self.mark_assignment(target, arg)
if len(sequence.args) > 2:
self.mark_assignment(
target,
ExprNodes.binop_node(node.pos,
'+',
sequence.args[0],
sequence.args[2]))
if not is_special:
# A for-loop basically translates to subsequent calls to
# __getitem__(), so using an IndexNode here allows us to
# naturally infer the base type of pointers, C arrays,
# Python strings, etc., while correctly falling back to an
# object type when the base type cannot be handled.
self.mark_assignment(target, ExprNodes.IndexNode(
node.pos,
base = sequence,
index = ExprNodes.IntNode(node.pos, value = '0')))
def visit_ForInStatNode(self, node):
condition_block = self.flow.nextblock()
next_block = self.flow.newblock()
......@@ -846,7 +915,11 @@ class ControlFlowAnalysis(CythonTransform):
self.visit(node.iterator)
# Target assignment
self.flow.nextblock()
self.mark_assignment(node.target)
if isinstance(node, Nodes.ForInStatNode):
self.mark_forloop_target(node)
else: # Parallel
self.mark_assignment(node.target)
# Body block
if isinstance(node, Nodes.ParallelRangeNode):
......@@ -916,12 +989,15 @@ class ControlFlowAnalysis(CythonTransform):
self.flow.loops.append(LoopDescr(next_block, condition_block))
self.visit(node.bound1)
self.visit(node.bound2)
if node.step:
if node.step is not None:
self.visit(node.step)
# Target assignment
self.flow.nextblock()
self.mark_assignment(node.target)
self.mark_assignment(node.target, node.bound1)
if node.step is not None:
self.mark_assignment(node.target,
ExprNodes.binop_node(node.pos, '+',
node.bound1, node.step))
# Body block
self.flow.nextblock()
self.visit(node.body)
......@@ -1143,6 +1219,6 @@ class ControlFlowAnalysis(CythonTransform):
def visit_AmpersandNode(self, node):
if node.operand.is_name:
# Fake assignment to silence warning
self.mark_assignment(node.operand)
self.mark_assignment(node.operand, fake_rhs_expr)
self.visitchildren(node)
return node
......@@ -129,7 +129,7 @@ def create_pipeline(context, mode, exclude_classes=()):
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
from ParseTreeTransforms import ExpandInplaceOperators, ParallelRangeTransform
from TypeInference import MarkAssignments, MarkOverflowingArithmetic
from TypeInference import MarkParallelAssignments, MarkOverflowingArithmetic
from ParseTreeTransforms import AdjustDefByDirectives, AlignFunctionDefinitions
from ParseTreeTransforms import RemoveUnreachableCode, GilCheck
from FlowControl import ControlFlowAnalysis
......@@ -179,10 +179,10 @@ def create_pipeline(context, mode, exclude_classes=()):
EmbedSignature(context),
EarlyReplaceBuiltinCalls(context), ## Necessary?
TransformBuiltinMethods(context), ## Necessary?
MarkAssignments(context),
MarkParallelAssignments(context),
ControlFlowAnalysis(context),
RemoveUnreachableCode(context),
# MarkAssignments(context),
# MarkParallelAssignments(context),
MarkOverflowingArithmetic(context),
IntroduceBufferAuxiliaryVars(context),
_check_c_declarations,
......
......@@ -112,7 +112,6 @@ class Entry(object):
# buffer_aux BufferAux or None Extra information needed for buffer variables
# inline_func_in_pxd boolean Hacky special case for inline function in pxd file.
# Ideally this should not be necesarry.
# assignments [ExprNode] List of expressions that get assigned to this entry.
# might_overflow boolean In an arithmetic expression that could cause
# overflow (used for type inference).
# utility_code_definition For some Cython builtins, the utility code
......@@ -193,7 +192,6 @@ class Entry(object):
self.pos = pos
self.init = init
self.overloaded_alternatives = []
self.assignments = []
self.cf_assignments = []
self.cf_references = []
......
......@@ -15,7 +15,11 @@ class TypedExprNode(ExprNodes.ExprNode):
object_expr = TypedExprNode(py_object_type)
class MarkAssignments(EnvTransform):
class MarkParallelAssignments(EnvTransform):
# Collects assignments inside parallel blocks prange, with parallel.
# Perhaps it's better to move it to ControlFlowAnalysis.
# tells us whether we're in a normal loop
in_loop = False
......@@ -24,14 +28,13 @@ class MarkAssignments(EnvTransform):
def __init__(self, context):
# Track the parallel block scopes (with parallel, for i in prange())
self.parallel_block_stack = []
return super(MarkAssignments, self).__init__(context)
return super(MarkParallelAssignments, self).__init__(context)
def mark_assignment(self, lhs, rhs, inplace_op=None):
if isinstance(lhs, (ExprNodes.NameNode, Nodes.PyArgDeclNode)):
if lhs.entry is None:
# TODO: This shouldn't happen...
return
lhs.entry.assignments.append(rhs)
if self.parallel_block_stack:
parallel_node = self.parallel_block_stack[-1]
......@@ -359,8 +362,8 @@ class SimpleAssignmentTypeInferer(object):
entry.type = py_object_type
continue
all = set()
for expr in entry.assignments:
all.update(expr.type_dependencies(scope))
for assmt in entry.cf_assignments:
all.update(assmt.rhs.type_dependencies(scope))
if all:
dependancies_by_entry[entry] = all
for dep in all:
......@@ -384,7 +387,8 @@ class SimpleAssignmentTypeInferer(object):
while True:
while ready_to_infer:
entry = ready_to_infer.pop()
types = [expr.infer_type(scope) for expr in entry.assignments]
types = [assmt.rhs.infer_type(scope)
for assmt in entry.cf_assignments]
if types and Utils.all(types):
entry.type = spanning_type(types, entry.might_overflow)
else:
......@@ -397,10 +401,13 @@ class SimpleAssignmentTypeInferer(object):
# Deal with simple circular dependancies...
for entry, deps in dependancies_by_entry.items():
if len(deps) == 1 and deps == set([entry]):
types = [expr.infer_type(scope) for expr in entry.assignments if expr.type_dependencies(scope) == ()]
types = [assmt.rhs.infer_type(scope)
for assmt in entry.cf_assignments
if assmt.rhs.type_dependencies(scope) == ()]
if types:
entry.type = spanning_type(types, entry.might_overflow)
types = [expr.infer_type(scope) for expr in entry.assignments]
types = [assmt.rhs.infer_type(scope)
for assmt in entry.cf_assignments]
entry.type = spanning_type(types, entry.might_overflow) # might be wider...
resolve_dependancy(entry)
del dependancies_by_entry[entry]
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment