TypeInference.py 21 KB
Newer Older
1 2 3 4 5 6 7 8 9 10
from __future__ import absolute_import

from .Errors import error, message
from . import ExprNodes
from . import Nodes
from . import Builtin
from . import PyrexTypes
from .. import Utils
from .PyrexTypes import py_object_type, unspecified_type
from .Visitor import CythonTransform, EnvTransform
11

12 13 14 15 16
try:
    reduce
except NameError:
    from functools import reduce

Robert Bradshaw's avatar
Robert Bradshaw committed
17

18
class TypedExprNode(ExprNodes.ExprNode):
    """Placeholder expression that carries nothing but a result type.

    Used for declaring assignments of a specified type without a known entry.

    type   the type this placeholder expression evaluates to
    pos    optional source position (defaults to None)
    """
    # Empty: the placeholder has no sub-expressions.
    subexprs = []

    def __init__(self, type, pos=None):
        super(TypedExprNode, self).__init__(pos, type=type)
24

Robert Bradshaw's avatar
Robert Bradshaw committed
25
# Shared placeholder for "a value of generic Python object type"; passed as
# the right-hand side when marking except-clause and from-import targets.
object_expr = TypedExprNode(py_object_type)
26

27 28 29 30 31

class MarkParallelAssignments(EnvTransform):
    """Collect assignments made inside parallel blocks.

    Collects assignments inside parallel blocks (prange, with parallel).
    Every assignment to a name that happens while a parallel block is open
    is recorded on the innermost open block, in its ``assignments`` mapping
    (entry -> (position, in-place operator)) and its ``assigned_nodes`` list.
    Perhaps it's better to move it to ControlFlowAnalysis.
    """

    # tells us whether we're in a normal loop
    in_loop = False

    # Set once "Only prange() may be nested" has been reported, so that
    # blocks nested further inside do not report it again; reset when
    # visit_ParallelStatNode() finishes.
    parallel_errors = False

    def __init__(self, context):
        # Track the parallel block scopes (with parallel, for i in prange())
        self.parallel_block_stack = []
        super(MarkParallelAssignments, self).__init__(context)

    def mark_assignment(self, lhs, rhs, inplace_op=None):
        """Record an assignment of 'rhs' to 'lhs' on the innermost parallel block.

        Name-like targets are recorded directly; sequence targets are
        unpacked and each item is marked recursively.  Any other kind of
        target is ignored.  'inplace_op' is the operator of an in-place
        assignment (e.g. for reductions), or None.
        """
        if isinstance(lhs, (ExprNodes.NameNode, Nodes.PyArgDeclNode)):
            if lhs.entry is None:
                # TODO: This shouldn't happen...
                return
            if not self.parallel_block_stack:
                # Not inside any parallel block: nothing to record.
                return

            parallel_node = self.parallel_block_stack[-1]
            previous_assignment = parallel_node.assignments.get(lhs.entry)

            # If there was a previous assignment to the variable, keep the
            # previous assignment position
            if previous_assignment:
                pos, previous_inplace_op = previous_assignment

                if (inplace_op and previous_inplace_op and
                        inplace_op != previous_inplace_op):
                    # x += y; x *= y
                    t = (inplace_op, previous_inplace_op)
                    error(lhs.pos,
                          "Reduction operator '%s' is inconsistent "
                          "with previous reduction operator '%s'" % t)
            else:
                pos = lhs.pos

            parallel_node.assignments[lhs.entry] = (pos, inplace_op)
            parallel_node.assigned_nodes.append(lhs)

        elif isinstance(lhs, ExprNodes.SequenceNode):
            for i, arg in enumerate(lhs.args):
                if not rhs or arg.is_starred:
                    item_node = None
                else:
                    item_node = rhs.inferable_item_node(i)
                self.mark_assignment(arg, item_node)
        else:
            # Could use this info to infer cdef class attributes...
            pass

    def visit_WithTargetAssignmentStatNode(self, node):
        self.mark_assignment(node.lhs, node.with_node.enter_call)
        self.visitchildren(node)
        return node

    def visit_SingleAssignmentNode(self, node):
        self.mark_assignment(node.lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_CascadedAssignmentNode(self, node):
        # a = b = rhs: every target receives the same right-hand side
        for lhs in node.lhs_list:
            self.mark_assignment(lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_InPlaceAssignmentNode(self, node):
        self.mark_assignment(node.lhs, node.create_binop_node(), node.operator)
        self.visitchildren(node)
        return node

    def _unshadowed_call_name(self, sequence):
        """Return the called name if 'sequence' is a plain call of a name
        that is either unknown to the current scope or a builtin there;
        return None otherwise."""
        if not isinstance(sequence, ExprNodes.SimpleCallNode):
            return None
        function = sequence.function
        if sequence.self is not None or not function.is_name:
            return None
        entry = self.current_env().lookup(function.name)
        if entry and not entry.is_builtin:
            # the name was rebound by user code
            return None
        return function.name

    def visit_ForInStatNode(self, node):
        """Mark the loop target(s) of a for-in loop as assigned.

        reversed(), enumerate() and range()/xrange() calls are looked
        through so that the target is marked with a more precise value.
        """
        # TODO: Remove redundancy with range optimization...
        is_special = False
        sequence = node.iterator.sequence
        target = node.target

        call_name = self._unshadowed_call_name(sequence)
        if call_name == 'reversed' and len(sequence.args) == 1:
            sequence = sequence.args[0]
        elif call_name == 'enumerate' and len(sequence.args) == 1:
            if target.is_sequence_constructor and len(target.args) == 2:
                iterator = sequence.args[0]
                if iterator.is_name:
                    iterator_type = iterator.infer_type(self.current_env())
                    if iterator_type.is_builtin_type:
                        # assume that builtin types have a length within Py_ssize_t
                        self.mark_assignment(
                            target.args[0],
                            ExprNodes.IntNode(target.pos, value='PY_SSIZE_T_MAX',
                                              type=PyrexTypes.c_py_ssize_t_type))
                        target = target.args[1]
                        sequence = sequence.args[0]

        # The sequence may have been unwrapped above, so check it again.
        if self._unshadowed_call_name(sequence) in ('range', 'xrange'):
            is_special = True
            for arg in sequence.args[:2]:
                self.mark_assignment(target, arg)
            if len(sequence.args) > 2:
                self.mark_assignment(
                    target,
                    ExprNodes.binop_node(node.pos,
                                         '+',
                                         sequence.args[0],
                                         sequence.args[2]))

        if not is_special:
            # A for-loop basically translates to subsequent calls to
            # __getitem__(), so using an IndexNode here allows us to
            # naturally infer the base type of pointers, C arrays,
            # Python strings, etc., while correctly falling back to an
            # object type when the base type cannot be handled.
            self.mark_assignment(target, ExprNodes.IndexNode(
                node.pos,
                base=sequence,
                index=ExprNodes.IntNode(target.pos, value='PY_SSIZE_T_MAX',
                                        type=PyrexTypes.c_py_ssize_t_type)))

        self.visitchildren(node)
        return node

    def visit_ForFromStatNode(self, node):
        self.mark_assignment(node.target, node.bound1)
        if node.step is not None:
            self.mark_assignment(node.target,
                    ExprNodes.binop_node(node.pos,
                                         '+',
                                         node.bound1,
                                         node.step))
        self.visitchildren(node)
        return node

    def visit_WhileStatNode(self, node):
        self.visitchildren(node)
        return node

    def visit_ExceptClauseNode(self, node):
        if node.target is not None:
            self.mark_assignment(node.target, object_expr)
        self.visitchildren(node)
        return node

    def visit_FromCImportStatNode(self, node):
        # Can't be assigned to...
        # Return the node like every other visitor in this class does.
        # NOTE(review): previously this fell through and returned None,
        # which a tree transform presumably treats as "drop this node" --
        # confirm against the transform base class.
        return node

    def visit_FromImportStatNode(self, node):
        for name, target in node.items:
            if name != "*":
                self.mark_assignment(target, object_expr)
        self.visitchildren(node)
        return node

    def visit_DefNode(self, node):
        # use fake expressions with the right result type
        if node.star_arg:
            self.mark_assignment(
                node.star_arg, TypedExprNode(Builtin.tuple_type, node.pos))
        if node.starstar_arg:
            self.mark_assignment(
                node.starstar_arg, TypedExprNode(Builtin.dict_type, node.pos))
        EnvTransform.visit_FuncDefNode(self, node)
        return node

    def visit_DelStatNode(self, node):
        # Each deleted target is marked as assigned, with itself as the value.
        for arg in node.args:
            self.mark_assignment(arg, arg)
        self.visitchildren(node)
        return node

    def visit_ParallelStatNode(self, node):
        """Set 'parent' and 'is_parallel' on the block, report illegal
        nesting, and visit the block while it is the innermost entry of
        the parallel block stack."""
        if self.parallel_block_stack:
            node.parent = self.parallel_block_stack[-1]
        else:
            node.parent = None

        nested = False
        if node.is_prange:
            if not node.parent:
                node.is_parallel = True
            else:
                node.is_parallel = (node.parent.is_prange or not
                                    node.parent.is_parallel)
                nested = node.parent.is_prange
        else:
            node.is_parallel = True
            # Note: nested with parallel() blocks are handled by
            # ParallelRangeTransform!
            # nested = node.parent
            nested = node.parent and node.parent.is_prange

        self.parallel_block_stack.append(node)

        nested = nested or len(self.parallel_block_stack) > 2
        if not self.parallel_errors and nested and not node.is_prange:
            error(node.pos, "Only prange() may be nested")
            self.parallel_errors = True

        if node.is_prange:
            # Visit only body, target and args while the block is on the
            # stack; the else clause is visited after the block was popped.
            child_attrs = node.child_attrs
            node.child_attrs = ['body', 'target', 'args']
            self.visitchildren(node)
            node.child_attrs = child_attrs

            self.parallel_block_stack.pop()
            if node.else_clause:
                node.else_clause = self.visit(node.else_clause)
        else:
            self.visitchildren(node)
            self.parallel_block_stack.pop()

        self.parallel_errors = False
        return node

    def visit_YieldExprNode(self, node):
        # Reported whenever any parallel block is open; the children of
        # the node are not visited.
        if self.parallel_block_stack:
            error(node.pos, "'%s' not allowed in parallel sections" % node.expr_keyword)
        return node

    def visit_ReturnStatNode(self, node):
        # Remember on the node whether it sits inside a parallel block.
        node.in_parallel = bool(self.parallel_block_stack)
        return node


Craig Citro's avatar
Craig Citro committed
261
class MarkOverflowingArithmetic(CythonTransform):
    """Set 'might_overflow' on entries used in possibly overflowing arithmetic.

    While the tree is walked, 'self.might_overflow' tells whether the
    expression currently being visited sits below an operation that is
    treated as able to overflow.  Names visited in that state, and names
    assigned a long integer literal, get 'entry.might_overflow = True'.
    """

    # It may be possible to integrate this with the above for
    # performance improvements (though likely not worth it).

    might_overflow = False

    def __call__(self, root):
        self.env_stack = []
        self.env = root.scope
        return super(MarkOverflowingArithmetic, self).__call__(root)

    def _visit_children_in_state(self, node, state):
        # Visit the children with might_overflow set to 'state', then
        # restore the previous value.
        saved = self.might_overflow
        self.might_overflow = state
        self.visitchildren(node)
        self.might_overflow = saved
        return node

    def _flag_entry(self, name_node):
        # Flag the entry behind a name node, looking the name up in the
        # current scope when the node carries no entry itself.
        entry = name_node.entry or self.env.lookup(name_node.name)
        if entry:
            entry.might_overflow = True

    def visit_safe_node(self, node):
        # Children are visited with the overflow state cleared.
        return self._visit_children_in_state(node, False)

    def visit_neutral_node(self, node):
        # Children inherit the current overflow state unchanged.
        self.visitchildren(node)
        return node

    def visit_dangerous_node(self, node):
        # Children are visited with the overflow state set.
        return self._visit_children_in_state(node, True)

    def visit_FuncDefNode(self, node):
        # Switch to the function's local scope while visiting its body.
        self.env_stack.append(self.env)
        self.env = node.local_scope
        self.visit_safe_node(node)
        self.env = self.env_stack.pop()
        return node

    def visit_NameNode(self, node):
        if self.might_overflow:
            self._flag_entry(node)
        return node

    def visit_BinopNode(self, node):
        # Exact operator match; a substring test against '&|^' would also
        # accept '' and '&|'.
        if node.operator in ('&', '|', '^'):
            return self.visit_neutral_node(node)
        else:
            return self.visit_dangerous_node(node)

    def visit_SimpleCallNode(self, node):
        if node.function.is_name and node.function.name == 'abs':
            # Overflows for minimum value of fixed size ints.
            return self.visit_dangerous_node(node)
        else:
            return self.visit_neutral_node(node)

    visit_UnopNode = visit_neutral_node

    visit_UnaryMinusNode = visit_dangerous_node

    visit_InPlaceAssignmentNode = visit_dangerous_node

    visit_Node = visit_safe_node

    def visit_assignment(self, lhs, rhs):
        # Assigning a long integer literal to a plain name flags the name.
        if (isinstance(rhs, ExprNodes.IntNode)
                and isinstance(lhs, ExprNodes.NameNode)
                and Utils.long_literal(rhs.value)):
            self._flag_entry(lhs)

    def visit_SingleAssignmentNode(self, node):
        self.visit_assignment(node.lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_CascadedAssignmentNode(self, node):
        for lhs in node.lhs_list:
            self.visit_assignment(lhs, node.rhs)
        self.visitchildren(node)
        return node
Robert Bradshaw's avatar
Robert Bradshaw committed
342

Stefan Behnel's avatar
Stefan Behnel committed
343
class PyObjectTypeInferer(object):
    """
    If it's not declared, it's a PyObject.
    """
    def infer_types(self, scope):
        """
        Give every entry of 'scope' whose type is still unspecified the
        generic Python object type.  Entries are modified in place.
        """
        for entry in scope.entries.values():
            if entry.type is unspecified_type:
                entry.type = py_object_type

Stefan Behnel's avatar
Stefan Behnel committed
355
class SimpleAssignmentTypeInferer(object):
    """
    Very basic type inference.

    Note: in order to support cross-closure type inference, this must be
    applied to nested scopes in top-down order.
    """
    def set_entry_type(self, entry, entry_type):
        """Set 'entry_type' on 'entry' and on everything returned by
        entry.all_entries()."""
        entry.type = entry_type
        for e in entry.all_entries():
            e.type = entry_type

    def infer_types(self, scope):
        """Infer a type for every entry of 'scope' that has none yet.

        The 'infer_types' directive picks the strategy: a value equal to
        True selects aggressive_spanning_type, None selects
        safe_spanning_type, and anything else disables inference, in which
        case all unspecified entries become Python objects.  With
        'infer_types.verbose' set, each inferred type is reported.
        """
        enabled = scope.directives['infer_types']
        verbose = scope.directives['infer_types.verbose']

        # Equality rather than identity is kept on purpose: a value that
        # merely compares equal to True (such as 1) selects aggressive mode.
        if enabled == True:
            spanning_type = aggressive_spanning_type
        elif enabled is None: # safe mode
            spanning_type = safe_spanning_type
        else:
            for entry in scope.entries.values():
                if entry.type is unspecified_type:
                    self.set_entry_type(entry, py_object_type)
            return

        # Set of assignments
        assignments = set()        # assignments to entries still lacking a type
        assmts_resolved = set()    # assignments whose type is considered known
        dependencies = {}          # assignment -> assignments it depends on
        assmt_to_names = {}        # assignment -> name nodes its type depends on

        for entry in scope.entries.values():
            for assmt in entry.cf_assignments:
                names = assmt.type_dependencies()
                assmt_to_names[assmt] = names
                assmts = set()
                for node in names:
                    assmts.update(node.cf_state)
                dependencies[assmt] = assmts
            if entry.type is unspecified_type:
                assignments.update(entry.cf_assignments)
            else:
                assmts_resolved.update(entry.cf_assignments)

        def infer_name_node_type(node):
            # Span the types of all assignments that reach this name node.
            types = [assmt.inferred_type for assmt in node.cf_state]
            if not types:
                node_type = py_object_type
            else:
                entry = node.entry
                node_type = spanning_type(
                    types, entry.might_overflow, entry.pos, scope)
            node.inferred_type = node_type

        def infer_name_node_type_partial(node):
            # Like infer_name_node_type(), but uses only the assignments
            # that already have a type; returns None if there are none.
            types = [assmt.inferred_type for assmt in node.cf_state
                     if assmt.inferred_type is not None]
            if not types:
                return
            entry = node.entry
            return spanning_type(types, entry.might_overflow, entry.pos, scope)

        def resolve_assignments(pending):
            # Infer every pending assignment whose dependencies are all
            # resolved, remove those from 'pending' and return them.
            resolved = set()
            for assmt in pending:
                deps = dependencies[assmt]
                # All assignments are resolved
                if assmts_resolved.issuperset(deps):
                    for node in assmt_to_names[assmt]:
                        infer_name_node_type(node)
                    # Resolve assmt
                    assmt.infer_type()
                    assmts_resolved.add(assmt)
                    resolved.add(assmt)
            pending.difference_update(resolved)
            return resolved

        def partial_infer(assmt):
            # Infer 'assmt' from partially known name types; nothing is
            # changed unless every name it depends on has a partial type.
            partial_types = []
            for node in assmt_to_names[assmt]:
                partial_type = infer_name_node_type_partial(node)
                if partial_type is None:
                    return False
                partial_types.append((node, partial_type))
            for node, partial_type in partial_types:
                node.inferred_type = partial_type
            assmt.infer_type()
            return True

        partial_assmts = set()
        def resolve_partial(pending):
            # try to handle circular references
            partials = set()
            for assmt in pending:
                if assmt in partial_assmts:
                    continue
                if partial_infer(assmt):
                    partials.add(assmt)
                    assmts_resolved.add(assmt)
            partial_assmts.update(partials)
            return partials

        # Infer assignments
        while True:
            if not resolve_assignments(assignments):
                if not resolve_partial(assignments):
                    break
        inferred = set()
        # First pass
        for entry in scope.entries.values():
            if entry.type is not unspecified_type:
                continue
            entry_type = py_object_type
            if assmts_resolved.issuperset(entry.cf_assignments):
                types = [assmt.inferred_type for assmt in entry.cf_assignments]
                if types and all(types):
                    entry_type = spanning_type(
                        types, entry.might_overflow, entry.pos, scope)
                    inferred.add(entry)
            self.set_entry_type(entry, entry_type)

        def reinfer():
            # Re-run inference for all inferred entries; report whether
            # any entry type changed.
            dirty = False
            for entry in inferred:
                types = [assmt.infer_type()
                         for assmt in entry.cf_assignments]
                new_type = spanning_type(types, entry.might_overflow, entry.pos, scope)
                if new_type != entry.type:
                    self.set_entry_type(entry, new_type)
                    dirty = True
            return dirty

        # types propagation
        while reinfer():
            pass

        if verbose:
            for entry in inferred:
                message(entry.pos, "inferred '%s' to be of type '%s'" % (
                    entry.name, entry.type))

Robert Bradshaw's avatar
Robert Bradshaw committed
497

498 499
def find_spanning_type(type1, type2):
    """Return a type able to hold values of both 'type1' and 'type2'.

    Identical types span themselves, a pairing that involves the C bint
    type (with a different type) yields the Python object type, and any
    other pairing is delegated to PyrexTypes.spanning_type().  A C double,
    C float or Python float result is always returned as C double.
    """
    if type1 is type2:
        result_type = type1
    elif type1 is PyrexTypes.c_bint_type or type2 is PyrexTypes.c_bint_type:
        # type inference can break the coercion back to a Python bool
        # if it returns an arbitrary int type here
        return py_object_type
    else:
        result_type = PyrexTypes.spanning_type(type1, type2)

    float_like_types = (PyrexTypes.c_double_type, PyrexTypes.c_float_type,
                        Builtin.float_type)
    if result_type in float_like_types:
        # Python's float type is just a C double, so it's safe to
        # use the C type instead
        return PyrexTypes.c_double_type
    return result_type

Robert Bradshaw's avatar
Robert Bradshaw committed
514
def simply_type(result_type, pos):
    """Reduce 'result_type' to the type used for an inferred variable.

    Applied in this order: a reference type is replaced by its base type,
    a const type by its base type, a C++ class gets
    check_nullary_constructor(pos) called on it, and an array type is
    replaced by a pointer to its base type.
    """
    if result_type.is_reference:
        result_type = result_type.ref_base_type
    if result_type.is_const:
        result_type = result_type.const_base_type
    if result_type.is_cpp_class:
        result_type.check_nullary_constructor(pos)
    if result_type.is_array:
        result_type = PyrexTypes.c_ptr_type(result_type.base_type)
    return result_type
524

525
def aggressive_spanning_type(types, might_overflow, pos, scope):
    """Return the simplified spanning type of 'types' without any safety
    restrictions.

    'might_overflow' and 'scope' are accepted for signature compatibility
    with safe_spanning_type() and are not used.  'types' must not be empty.
    """
    return simply_type(reduce(find_spanning_type, types), pos)
527

528
def safe_spanning_type(types, might_overflow, pos, scope):
    """Return the simplified spanning type of 'types' if it is considered
    safe to infer, and the Python object type otherwise.

    'might_overflow' disables the inference of C integer and enum types.
    'types' must not be empty.
    """
    result_type = simply_type(reduce(find_spanning_type, types), pos)
    if result_type.is_pyobject:
        # In theory, any specific Python type is always safe to
        # infer. However, inferring str can cause some existing code
        # to break, since we are also now much more strict about
        # coercion from str to char *. See trac #553.
        if result_type.name == 'str':
            return py_object_type
        else:
            return result_type
    elif result_type is PyrexTypes.c_double_type:
        # Python's float type is just a C double, so it's safe to use
        # the C type instead
        return result_type
    elif result_type is PyrexTypes.c_bint_type:
        # find_spanning_type() only returns 'bint' for clean boolean
        # operations without other int types, so this is safe, too
        return result_type
    elif result_type.is_ptr:
        # Any pointer except (signed|unsigned|) char* can't implicitly
        # become a PyObject, and inferring char* is now accepted, too.
        return result_type
    elif result_type.is_cpp_class:
        # These can't implicitly become Python objects either.
        return result_type
    elif result_type.is_struct:
        # Though we have struct -> object for some structs, this is uncommonly
        # used, won't arise in pure Python, and there shouldn't be side
        # effects, so I'm declaring this safe.
        return result_type
    # TODO: double complex should be OK as well, but we need
    # to make sure everything is supported.
    elif (result_type.is_int or result_type.is_enum) and not might_overflow:
        return result_type
    elif (not result_type.can_coerce_to_pyobject(scope)
            and not result_type.is_error):
        # Types that cannot be coerced to a Python object (and are not
        # the error type) are kept as they are.
        return result_type
    return py_object_type

568

Robert Bradshaw's avatar
Robert Bradshaw committed
569 570
def get_type_inferer():
    """Create and return a new SimpleAssignmentTypeInferer instance."""
    inferer = SimpleAssignmentTypeInferer()
    return inferer