# FusedNode.py
from __future__ import absolute_import

3 4
import copy

5 6 7 8
from . import (ExprNodes, PyrexTypes, MemoryView,
               ParseTreeTransforms, StringEncoding, Errors)
from .ExprNodes import CloneNode, ProxyNode, TupleNode
from .Nodes import FuncDefNode, CFuncDefNode, StatListNode, DefNode
9
from ..Utils import OrderedSet
10

11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36

class FusedCFuncDefNode(StatListNode):
    """
    This node replaces a function with fused arguments. It deep-copies the
    function for every permutation of fused types, and allocates a new local
    scope for it. It keeps track of the original function in self.node, and
    the entry of the original function in the symbol table is given the
    'fused_cfunction' attribute which points back to us.
    Then when a function lookup occurs (to e.g. call it), the call can be
    dispatched to the right function.

    node    FuncDefNode    the original function
    nodes   [FuncDefNode]  list of copies of node with different specific types
    py_func DefNode        the fused python function subscriptable from
                           Python space
    __signatures__         A DictNode mapping signature specialization strings
                           to PyCFunction nodes
    resulting_fused_function  PyCFunction for the fused DefNode that delegates
                              to specializations
    fused_func_assignment   Assignment of the fused function to the function name
    defaults_tuple          TupleNode of defaults (letting PyCFunctionNode build
                            defaults would result in many different tuples)
    specialized_pycfuncs    List of synthesized pycfunction nodes for the
                            specializations
    code_object             CodeObjectNode shared by all specializations and the
                            fused function

    fused_compound_types    All fused (compound) types (e.g. floating[:])
    """

    # Class-level defaults for attributes that are assigned on instances
    # elsewhere; None means "not set".
    __signatures__ = None
    resulting_fused_function = None
    fused_func_assignment = None
    defaults_tuple = None
    decorators = None

    # Extend the child attributes inherited from StatListNode with the three
    # attributes named here.
    child_attrs = StatListNode.child_attrs + [
        '__signatures__', 'resulting_fused_function', 'fused_func_assignment']

50 51 52 53 54 55 56 57
    def __init__(self, node, env):
        """
        Create the specialized copies of ``node`` and register this node as
        ``fused_cfunction`` on the entry of the original function.

        node    the original function node; a DefNode is handled by
                copy_def(), anything else by copy_cdef()
        env     the scope that is passed on to copy_def() / copy_cdef()
        """
        super(FusedCFuncDefNode, self).__init__(node.pos)

        self.nodes = []
        self.node = node

        is_def = isinstance(self.node, DefNode)
        if is_def:
            # self.node.decorators = []
            self.copy_def(env)
        else:
            self.copy_cdef(env)

        # Perform some sanity checks. If anything fails, it's a bug
        for n in self.nodes:
            assert not n.entry.type.is_fused
            assert not n.local_scope.return_type.is_fused
            if node.return_type.is_fused:
                assert not n.return_type.is_fused

            if not is_def and n.cfunc_declarator.optional_arg_count:
                assert n.type.op_arg_struct

        node.entry.fused_cfunction = self
        # Copy the nodes as AnalyseDeclarationsTransform will prepend
        # self.py_func to self.stats, as we only want specialized
        # CFuncDefNodes in self.nodes
        self.stats = self.nodes[:]

    def copy_def(self, env):
        """
        Create a copy of the original def or lambda function for specialized
        versions.

        One deep copy of self.node is made per permutation of the fused base
        types found in the argument types. Each copy gets specialized
        argument and return types, a fresh local scope, and its entry is
        stored in env.entries under the specialized name. Finally the
        dispatching function is built with make_fused_cpdef() and stored in
        self.py_func.
        """
        fused_compound_types = PyrexTypes.unique(
            [arg.type for arg in self.node.args if arg.type.is_fused])
        fused_types = self._get_fused_base_types(fused_compound_types)
        permutations = PyrexTypes.get_all_specialized_permutations(fused_types)

        self.fused_compound_types = fused_compound_types

        # The original (unspecialized) function entry is taken out of the
        # scope's list of Python function entries.
        if self.node.entry in env.pyfunc_entries:
            env.pyfunc_entries.remove(self.node.entry)

        for cname, fused_to_specific in permutations:
            copied_node = copy.deepcopy(self.node)
            # keep signature object identity for special casing in DefNode.analyse_declarations()
            copied_node.entry.signature = self.node.entry.signature

            self._specialize_function_args(copied_node.args, fused_to_specific)
            copied_node.return_type = self.node.return_type.specialize(
                                                    fused_to_specific)

            copied_node.analyse_declarations(env)
            # copied_node.is_staticmethod = self.node.is_staticmethod
            # copied_node.is_classmethod = self.node.is_classmethod
            self.create_new_local_scope(copied_node, env, fused_to_specific)
            self.specialize_copied_def(copied_node, cname, self.node.entry,
                                       fused_to_specific, fused_compound_types)

            PyrexTypes.specialize_entry(copied_node.entry, cname)
            copied_node.entry.used = True
            env.entries[copied_node.entry.name] = copied_node.entry

            # Stop specializing after the first copy that produced errors.
            if not self.replace_fused_typechecks(copied_node):
                break

        self.orig_py_func = self.node
        self.py_func = self.make_fused_cpdef(self.node, env, is_def=True)

    def copy_cdef(self, env):
        """
        Create a copy of the original c(p)def function for all specialized
        versions.

        The Python wrapper (self.node.py_func), if any, is detached before
        copying so that it is not deep-copied along with the C function. The
        specialized entries replace the original entry in env.cfunc_entries
        (or are appended when the original is not in that list). If there was
        a Python wrapper, the dispatching function is built with
        make_fused_cpdef() and stored in self.py_func; otherwise self.py_func
        is set to the (falsy) original value.
        """
        permutations = self.node.type.get_all_specialized_permutations()
        # print 'Node %s has %d specializations:' % (self.node.entry.name,
        #                                            len(permutations))
        # import pprint; pprint.pprint([d for cname, d in permutations])

        # Prevent copying of the python function
        self.orig_py_func = orig_py_func = self.node.py_func
        self.node.py_func = None
        if orig_py_func:
            env.pyfunc_entries.remove(orig_py_func.entry)

        fused_types = self.node.type.get_fused_types()
        self.fused_compound_types = fused_types

        new_cfunc_entries = []
        for cname, fused_to_specific in permutations:
            copied_node = copy.deepcopy(self.node)

            # Make the types in our CFuncType specific.
            type = copied_node.type.specialize(fused_to_specific)
            entry = copied_node.entry
            type.specialize_entry(entry, cname)

            # Reuse existing Entries (e.g. from .pxd files).
            for i, orig_entry in enumerate(env.cfunc_entries):
                if entry.cname == orig_entry.cname and type.same_as_resolved_type(orig_entry.type):
                    copied_node.entry = env.cfunc_entries[i]
                    if not copied_node.entry.func_cname:
                        copied_node.entry.func_cname = entry.func_cname
                    entry = copied_node.entry
                    type = entry.type
                    break
            else:
                # No matching entry exists yet: remember the new one so it
                # can be spliced into env.cfunc_entries after the loop.
                new_cfunc_entries.append(entry)

            copied_node.type = type
            entry.type, type.entry = type, entry

            entry.used = (entry.used or
                          self.node.entry.defined_in_pxd or
                          env.is_c_class_scope or
                          entry.is_cmethod)

            if self.node.cfunc_declarator.optional_arg_count:
                self.node.cfunc_declarator.declare_optional_arg_struct(
                                           type, env, fused_cname=cname)

            copied_node.return_type = type.return_type
            self.create_new_local_scope(copied_node, env, fused_to_specific)

            # Make the argument types in the CFuncDeclarator specific
            self._specialize_function_args(copied_node.cfunc_declarator.args,
                                           fused_to_specific)

            # If a cpdef, declare all specialized cpdefs (this
            # also calls analyse_declarations)
            copied_node.declare_cpdef_wrapper(env)
            if copied_node.py_func:
                env.pyfunc_entries.remove(copied_node.py_func.entry)

                self.specialize_copied_def(
                        copied_node.py_func, cname, self.node.entry.as_variable,
                        fused_to_specific, fused_types)

            # Stop specializing after the first copy that produced errors.
            if not self.replace_fused_typechecks(copied_node):
                break

        # replace old entry with new entries
        try:
            cindex = env.cfunc_entries.index(self.node.entry)
        except ValueError:
            env.cfunc_entries.extend(new_cfunc_entries)
        else:
            env.cfunc_entries[cindex:cindex+1] = new_cfunc_entries

        if orig_py_func:
            self.py_func = self.make_fused_cpdef(orig_py_func, env,
                                                 is_def=False)
        else:
            self.py_func = orig_py_func

206 207 208 209 210 211 212 213 214 215 216
    def _get_fused_base_types(self, fused_compound_types):
        """
        Get a list of unique basic fused types, from a list of
        (possibly) compound fused types.
        """
        base_types = []
        seen = set()
        for fused_type in fused_compound_types:
            fused_type.get_fused_types(result=base_types, seen=seen)
        return base_types

217 218 219 220 221
    def _specialize_function_args(self, args, fused_to_specific):
        for arg in args:
            if arg.type.is_fused:
                arg.type = arg.type.specialize(fused_to_specific)
                if arg.type.is_memoryviewslice:
222
                    arg.type.validate_memslice_dtype(arg.pos)
223 224 225 226
                if arg.annotation:
                    # TODO might be nice if annotations were specialized instead?
                    # (Or might be hard to do reliably)
                    arg.annotation.untyped = True
227 228 229 230 231

    def create_new_local_scope(self, node, env, f2s):
        """
        Create a new local scope for the copied node and append it to
        self.nodes. A new local scope is needed because the arguments with the
Unknown's avatar
Unknown committed
232
        fused types are already in the local scope, and we need the specialized
233 234 235 236 237 238 239 240 241 242 243 244
        entries created after analyse_declarations on each specialized version
        of the (CFunc)DefNode.
        f2s is a dict mapping each fused type to its specialized version
        """
        node.create_local_scope(env)
        node.local_scope.fused_to_specific = f2s

        # This is copied from the original function, set it to false to
        # stop recursion
        node.has_fused_arguments = False
        self.nodes.append(node)

245
    def specialize_copied_def(self, node, cname, py_entry, f2s, fused_compound_types):
        """Specialize the copy of a DefNode given the copied node,
        the specialization cname and the original DefNode entry.

        Sets node.specialized_signature_string to the '|'-joined signature
        strings of the fused base types, derives a specialized
        pymethdef_cname for node.entry, and copies doc / doc_cname over from
        py_entry.
        """
        fused_types = self._get_fused_base_types(fused_compound_types)
        type_strings = [
            PyrexTypes.specialization_signature_string(fused_type, f2s)
                for fused_type in fused_types
        ]

        node.specialized_signature_string = '|'.join(type_strings)

        node.entry.pymethdef_cname = PyrexTypes.get_fused_cname(
                                        cname, node.entry.pymethdef_cname)
        node.entry.doc = py_entry.doc
        node.entry.doc_cname = py_entry.doc_cname

    def replace_fused_typechecks(self, copied_node):
        """
        Branch-prune fused type checks like

            if fused_t is int:
                ...

        by running the ReplaceFusedTypeChecks transform over copied_node.

        Returns True when the transform did not raise the global error count
        and False when it did, so that callers can stop early and avoid a
        flood of follow-up errors.
        """
        errors_before = Errors.num_errors
        prune_typechecks = ParseTreeTransforms.ReplaceFusedTypeChecks(
            copied_node.local_scope)
        prune_typechecks(copied_node)
        return not (Errors.num_errors > errors_before)

    def _fused_instance_checks(self, normal_types, pyx_code, env):
        """
luz.paz's avatar
luz.paz committed
283
        Generate Cython code for instance checks, matching an object to
284 285 286 287
        specialized types.
        """
        for specialized_type in normal_types:
            # all_numeric = all_numeric and specialized_type.is_numeric
288 289 290 291
            pyx_code.context.update(
                py_type_name=specialized_type.py_type_name(),
                specialized_type_name=specialized_type.specialization_string,
            )
292 293
            pyx_code.put_chunk(
                u"""
294 295
                    if isinstance(arg, {{py_type_name}}):
                        dest_sig[{{dest_sig_idx}}] = '{{specialized_type_name}}'; break
296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315
                """)

    def _dtype_name(self, dtype):
        if dtype.is_typedef:
            return '___pyx_%s' % dtype
        return str(dtype).replace(' ', '_')

    def _dtype_type(self, dtype):
        if dtype.is_typedef:
            return self._dtype_name(dtype)
        return str(dtype)

    def _sizeof_dtype(self, dtype):
        if dtype.is_pyobject:
            return 'sizeof(void *)'
        else:
            return "sizeof(%s)" % self._dtype_type(dtype)

    def _buffer_check_numpy_dtype_setup_cases(self, pyx_code):
        "Setup some common cases to match dtypes against specializations"
316
        if pyx_code.indenter("if kind in b'iu':"):
317 318
            pyx_code.putln("pass")
            pyx_code.named_insertion_point("dtype_int")
319
            pyx_code.dedent()
320

321
        if pyx_code.indenter("elif kind == b'f':"):
322 323
            pyx_code.putln("pass")
            pyx_code.named_insertion_point("dtype_float")
324
            pyx_code.dedent()
325

326
        if pyx_code.indenter("elif kind == b'c':"):
327 328
            pyx_code.putln("pass")
            pyx_code.named_insertion_point("dtype_complex")
329
            pyx_code.dedent()
330

331
        if pyx_code.indenter("elif kind == b'O':"):
332 333
            pyx_code.putln("pass")
            pyx_code.named_insertion_point("dtype_object")
334
            pyx_code.dedent()
335 336 337

    # Source lines written into the generated dispatch code: `match` stores
    # the name of the matching specialization in dest_sig, `no_match` stores
    # None. The {{...}} placeholders are presumably filled in from the code
    # writer's context -- TODO confirm against PyxCodeWriter.
    match = "dest_sig[{{dest_sig_idx}}] = '{{specialized_type_name}}'"
    no_match = "dest_sig[{{dest_sig_idx}}] = None"
338
    def _buffer_check_numpy_dtype(self, pyx_code, specialized_buffer_types, pythran_types):
339 340 341 342 343
        """
        Match a numpy dtype object to the individual specializations.
        """
        self._buffer_check_numpy_dtype_setup_cases(pyx_code)

344 345 346 347
        for specialized_type in pythran_types+specialized_buffer_types:
            final_type = specialized_type
            if specialized_type.is_pythran_expr:
                specialized_type = specialized_type.org_buffer
348 349 350 351 352
            dtype = specialized_type.dtype
            pyx_code.context.update(
                itemsize_match=self._sizeof_dtype(dtype) + " == itemsize",
                signed_match="not (%s_is_signed ^ dtype_signed)" % self._dtype_name(dtype),
                dtype=dtype,
353
                specialized_type_name=final_type.specialization_string)
354 355 356 357 358 359 360 361 362

            dtypes = [
                (dtype.is_int, pyx_code.dtype_int),
                (dtype.is_float, pyx_code.dtype_float),
                (dtype.is_complex, pyx_code.dtype_complex)
            ]

            for dtype_category, codewriter in dtypes:
                if dtype_category:
363
                    cond = '{{itemsize_match}} and (<Py_ssize_t>arg.ndim) == %d' % (
364
                                                    specialized_type.ndim,)
365 366 367
                    if dtype.is_int:
                        cond += ' and {{signed_match}}'

368 369 370
                    if final_type.is_pythran_expr:
                        cond += ' and arg_is_pythran_compatible'

371
                    if codewriter.indenter("if %s:" % cond):
372
                        #codewriter.putln("print 'buffer match found based on numpy dtype'")
373 374
                        codewriter.putln(self.match)
                        codewriter.putln("break")
375
                        codewriter.dedent()
376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396

    def _buffer_parse_format_string_check(self, pyx_code, decl_code,
                                          specialized_type, env):
        """
        For each specialized type, try to coerce the object to a memoryview
        slice of that type. This means obtaining a buffer and parsing the
        format string.
        TODO: separate buffer acquisition from format parsing

        A declaration of the coercion function is written to decl_code and
        the matching attempt itself to pyx_code.
        """
        dtype = specialized_type.dtype
        if specialized_type.is_buffer:
            axes = [('direct', 'strided')] * specialized_type.ndim
        else:
            axes = specialized_type.axes

        memslice_type = PyrexTypes.MemoryViewSliceType(dtype, axes)
        memslice_type.create_from_py_utility_code(env)
        pyx_code.context.update(
            coerce_from_py_func=memslice_type.from_py_function,
            dtype=dtype)
        decl_code.putln(
            "{{memviewslice_cname}} {{coerce_from_py_func}}(object, int)")

        pyx_code.context.update(
            specialized_type_name=specialized_type.specialization_string,
            sizeof_dtype=self._sizeof_dtype(dtype))

        # self.match is spliced into the chunk with '%' before the chunk is
        # handed to the code writer.
        pyx_code.put_chunk(
            u"""
                # try {{dtype}}
                if itemsize == -1 or itemsize == {{sizeof_dtype}}:
                    memslice = {{coerce_from_py_func}}(arg, 0)
                    if memslice.memview:
                        __PYX_XDEC_MEMVIEW(&memslice, 1)
                        # print 'found a match for the buffer through format parsing'
                        %s
                        break
                    else:
                        __pyx_PyErr_Clear()
            """ % self.match)

417
    def _buffer_checks(self, buffer_types, pythran_types, pyx_code, decl_code, env):
418 419 420 421 422 423 424
        """
        Generate Cython code to match objects to buffer specializations.
        First try to get a numpy dtype object and match it against the individual
        specializations. If that fails, try naively to coerce the object
        to each specialization, which obtains the buffer each time and tries
        to match the format string.
        """
425
        # The first thing to find a match in this loop breaks out of the loop
426 427
        pyx_code.put_chunk(
            u"""
428
                """ + (u"arg_is_pythran_compatible = False" if pythran_types else u"") + u"""
429 430 431
                if ndarray is not None:
                    if isinstance(arg, ndarray):
                        dtype = arg.dtype
432
                        """ + (u"arg_is_pythran_compatible = True" if pythran_types else u"") + u"""
433 434 435 436
                    elif __pyx_memoryview_check(arg):
                        arg_base = arg.base
                        if isinstance(arg_base, ndarray):
                            dtype = arg_base.dtype
437 438
                        else:
                            dtype = None
439 440
                    else:
                        dtype = None
441

442 443 444 445
                    itemsize = -1
                    if dtype is not None:
                        itemsize = dtype.itemsize
                        kind = ord(dtype.kind)
446 447 448 449 450 451
                        dtype_signed = kind == 'i'
            """)
        pyx_code.indent(2)
        if pythran_types:
            pyx_code.put_chunk(
                u"""
452
                        # Pythran only supports the endianness of the current compiler
453 454 455
                        byteorder = dtype.byteorder
                        if byteorder == "<" and not __Pyx_Is_Little_Endian():
                            arg_is_pythran_compatible = False
456
                        elif byteorder == ">" and __Pyx_Is_Little_Endian():
457 458 459
                            arg_is_pythran_compatible = False
                        if arg_is_pythran_compatible:
                            cur_stride = itemsize
460 461
                            shape = arg.shape
                            strides = arg.strides
462
                            for i in range(arg.ndim-1, -1, -1):
463
                                if (<Py_ssize_t>strides[i]) != cur_stride:
464 465
                                    arg_is_pythran_compatible = False
                                    break
466
                                cur_stride *= <Py_ssize_t> shape[i]
467
                            else:
468
                                arg_is_pythran_compatible = not (arg.flags.f_contiguous and (<Py_ssize_t>arg.ndim) > 1)
469
                """)
470
        pyx_code.named_insertion_point("numpy_dtype_checks")
471
        self._buffer_check_numpy_dtype(pyx_code, buffer_types, pythran_types)
472
        pyx_code.dedent(2)
473

474 475 476
        for specialized_type in buffer_types:
            self._buffer_parse_format_string_check(
                    pyx_code, decl_code, specialized_type, env)
477

478
    def _buffer_declarations(self, pyx_code, decl_code, all_buffer_types, pythran_types):
479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499
        """
        If we have any buffer specializations, write out some variable
        declarations and imports.
        """
        decl_code.put_chunk(
            u"""
                ctypedef struct {{memviewslice_cname}}:
                    void *memview

                void __PYX_XDEC_MEMVIEW({{memviewslice_cname}} *, int have_gil)
                bint __pyx_memoryview_check(object)
            """)

        pyx_code.local_variable_declarations.put_chunk(
            u"""
                cdef {{memviewslice_cname}} memslice
                cdef Py_ssize_t itemsize
                cdef bint dtype_signed
                cdef char kind

                itemsize = -1
500 501 502 503 504
            """)

        if pythran_types:
            pyx_code.local_variable_declarations.put_chunk(u"""
                cdef bint arg_is_pythran_compatible
505
                cdef Py_ssize_t cur_stride
506 507 508 509
            """)

        pyx_code.imports.put_chunk(
            u"""
510
                cdef type ndarray
511
                ndarray = __Pyx_ImportNumPyArrayTypeIfAvailable()
512 513
            """)

514
        seen_typedefs = set()
515 516 517
        seen_int_dtypes = set()
        for buffer_type in all_buffer_types:
            dtype = buffer_type.dtype
518
            dtype_name = self._dtype_name(dtype)
519
            if dtype.is_typedef:
520 521 522 523 524
                if dtype_name not in seen_typedefs:
                    seen_typedefs.add(dtype_name)
                    decl_code.putln(
                        'ctypedef %s %s "%s"' % (dtype.resolve(), dtype_name,
                                                 dtype.empty_declaration_code()))
525 526 527 528

            if buffer_type.dtype.is_int:
                if str(dtype) not in seen_int_dtypes:
                    seen_int_dtypes.add(str(dtype))
529
                    pyx_code.context.update(dtype_name=dtype_name,
530 531 532 533
                                            dtype_type=self._dtype_type(dtype))
                    pyx_code.local_variable_declarations.put_chunk(
                        u"""
                            cdef bint {{dtype_name}}_is_signed
534
                            {{dtype_name}}_is_signed = not (<{{dtype_type}}> -1 > 0)
535 536 537 538 539 540 541
                        """)

    def _split_fused_types(self, arg):
        """
        Specialize fused types and split into normal types and buffer types.

        Returns a 4-tuple
        ``(normal_types, buffer_types, pythran_types, has_object_fallback)``.
        Types with a Python type name go into normal_types, at most one per
        name; the name 'object' is not listed but reported through
        has_object_fallback instead.
        """
        specialized_types = PyrexTypes.get_specialized_types(arg.type)

        # Prefer long over int, etc by sorting (see type classes in PyrexTypes.py)
        specialized_types.sort()

        seen_py_type_names = set()
        normal_types, buffer_types, pythran_types = [], [], []
        has_object_fallback = False
        for specialized_type in specialized_types:
            py_type_name = specialized_type.py_type_name()
            if py_type_name:
                if py_type_name in seen_py_type_names:
                    continue
                seen_py_type_names.add(py_type_name)
                if py_type_name == 'object':
                    has_object_fallback = True
                else:
                    normal_types.append(specialized_type)
            elif specialized_type.is_pythran_expr:
                pythran_types.append(specialized_type)
            elif specialized_type.is_buffer or specialized_type.is_memoryviewslice:
                buffer_types.append(specialized_type)

        return normal_types, buffer_types, pythran_types, has_object_fallback
565 566 567 568 569

    def _unpack_argument(self, pyx_code):
        pyx_code.put_chunk(
            u"""
                # PROCESSING ARGUMENT {{arg_tuple_idx}}
570 571
                if {{arg_tuple_idx}} < len(<tuple>args):
                    arg = (<tuple>args)[{{arg_tuple_idx}}]
572
                elif kwargs is not None and '{{arg.name}}' in <dict>kwargs:
573
                    arg = (<dict>kwargs)['{{arg.name}}']
574
                else:
575
                {{if arg.default}}
576
                    arg = (<tuple>defaults)[{{default_idx}}]
577
                {{else}}
578 579 580 581 582 583
                    {{if arg_tuple_idx < min_positional_args}}
                        raise TypeError("Expected at least %d argument%s, got %d" % (
                            {{min_positional_args}}, {{'"s"' if min_positional_args != 1 else '""'}}, len(<tuple>args)))
                    {{else}}
                        raise TypeError("Missing keyword-only argument: '%s'" % "{{arg.default}}")
                    {{endif}}
584 585 586
                {{endif}}
            """)

587 588 589 590 591 592 593 594 595 596
    def _fused_signature_index(self, pyx_code):
        """
        Generate Cython code for constructing a persistent nested dictionary index of
        fused type specialization signatures.
        """
        pyx_code.put_chunk(
            u"""
                if not _fused_sigindex:
                    for sig in <dict>signatures:
                        sigindex_node = _fused_sigindex
597 598
                        *sig_series, last_type = sig.strip('()').split('|')
                        for sig_type in sig_series:
599 600 601 602
                            if sig_type not in sigindex_node:
                                sigindex_node[sig_type] = sigindex_node = {}
                            else:
                                sigindex_node = sigindex_node[sig_type]
603
                        sigindex_node[last_type] = sig
604 605 606
            """
        )

607 608 609 610 611 612 613
    def make_fused_cpdef(self, orig_py_func, env, is_def):
        """
        This creates the function that is indexable from Python and does
        runtime dispatch based on the argument types. The function gets the
        arg tuple and kwargs dict (or None) and the defaults tuple
        as arguments from the Binding Fused Function's tp_call.
        """
614
        from . import TreeFragment, Code, UtilityCode
615

616 617
        fused_types = self._get_fused_base_types([
            arg.type for arg in self.node.args if arg.type.is_fused])
618 619 620 621

        context = {
            'memviewslice_cname': MemoryView.memviewslice_cname,
            'func_args': self.node.args,
622
            'n_fused': len(fused_types),
623 624 625 626
            'min_positional_args':
                self.node.num_required_args - self.node.num_required_kw_args
                if is_def else
                sum(1 for arg in self.node.args if arg.default is None),
627 628 629 630 631 632 633 634
            'name': orig_py_func.entry.name,
        }

        pyx_code = Code.PyxCodeWriter(context=context)
        decl_code = Code.PyxCodeWriter(context=context)
        decl_code.put_chunk(
            u"""
                cdef extern from *:
635
                    void __pyx_PyErr_Clear "PyErr_Clear" ()
636
                    type __Pyx_ImportNumPyArrayTypeIfAvailable()
637
                    int __Pyx_Is_Little_Endian()
638 639 640 641 642
            """)
        decl_code.indent()

        pyx_code.put_chunk(
            u"""
643
                def __pyx_fused_cpdef(signatures, args, kwargs, defaults, _fused_sigindex={}):
644 645
                    # FIXME: use a typed signature - currently fails badly because
                    #        default arguments inherit the types we specify here!
646

647 648 649 650
                    cdef list search_list

                    cdef dict sn, sigindex_node

651
                    dest_sig = [None] * {{n_fused}}
652

653 654
                    if kwargs is not None and not kwargs:
                        kwargs = None
655 656 657 658 659

                    cdef Py_ssize_t i

                    # instance check body
            """)
660

661
        pyx_code.indent()  # indent following code to function body
662
        pyx_code.named_insertion_point("imports")
663
        pyx_code.named_insertion_point("func_defs")
664 665 666 667
        pyx_code.named_insertion_point("local_variable_declarations")

        fused_index = 0
        default_idx = 0
668
        all_buffer_types = OrderedSet()
669
        seen_fused_types = set()
670
        for i, arg in enumerate(self.node.args):
671 672 673 674 675 676 677 678 679
            if arg.type.is_fused:
                arg_fused_types = arg.type.get_fused_types()
                if len(arg_fused_types) > 1:
                    raise NotImplementedError("Determination of more than one fused base "
                                              "type per argument is not implemented.")
                fused_type = arg_fused_types[0]

            if arg.type.is_fused and fused_type not in seen_fused_types:
                seen_fused_types.add(fused_type)
680 681 682 683 684 685 686 687

                context.update(
                    arg_tuple_idx=i,
                    arg=arg,
                    dest_sig_idx=fused_index,
                    default_idx=default_idx,
                )

688
                normal_types, buffer_types, pythran_types, has_object_fallback = self._split_fused_types(arg)
689 690
                self._unpack_argument(pyx_code)

691 692 693 694
                # 'unrolled' loop, first match breaks out of it
                if pyx_code.indenter("while 1:"):
                    if normal_types:
                        self._fused_instance_checks(normal_types, pyx_code, env)
695
                    if buffer_types or pythran_types:
696
                        env.use_utility_code(Code.UtilityCode.load_cached("IsLittleEndian", "ModuleSetupCode.c"))
697
                        self._buffer_checks(buffer_types, pythran_types, pyx_code, decl_code, env)
698 699 700
                    if has_object_fallback:
                        pyx_code.context.update(specialized_type_name='object')
                        pyx_code.putln(self.match)
701 702
                    else:
                        pyx_code.putln(self.no_match)
703
                    pyx_code.putln("break")
704 705 706
                    pyx_code.dedent()

                fused_index += 1
707
                all_buffer_types.update(buffer_types)
708
                all_buffer_types.update(ty.org_buffer for ty in pythran_types)
709 710 711 712 713

            if arg.default:
                default_idx += 1

        if all_buffer_types:
714
            self._buffer_declarations(pyx_code, decl_code, all_buffer_types, pythran_types)
715
            env.use_utility_code(Code.UtilityCode.load_cached("Import", "ImportExport.c"))
716
            env.use_utility_code(Code.UtilityCode.load_cached("ImportNumPyArray", "ImportExport.c"))
717

718 719
        self._fused_signature_index(pyx_code)

720 721
        pyx_code.put_chunk(
            u"""
722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745
                sigindex_matches = []
                sigindex_candidates = [_fused_sigindex]

                for dst_type in dest_sig:
                    found_matches = []
                    found_candidates = []
                    # Make two seperate lists: One for signature sub-trees
                    #        with at least one definite match, and another for
                    #        signature sub-trees with only ambiguous matches
                    #        (where `dest_sig[i] is None`).
                    if dst_type is None:
                        for sn in sigindex_matches:
                            found_matches.extend(sn.values())
                        for sn in sigindex_candidates:
                            found_candidates.extend(sn.values())
                    else:
                        for search_list in (sigindex_matches, sigindex_candidates):
                            for sn in search_list:
                                if dst_type in sn:
                                    found_matches.append(sn[dst_type])
                    sigindex_matches = found_matches
                    sigindex_candidates = found_candidates
                    if not (found_matches or found_candidates):
                        break
746

747
                candidates = sigindex_matches
748 749 750 751 752 753

                if not candidates:
                    raise TypeError("No matching signature found")
                elif len(candidates) > 1:
                    raise TypeError("Function call with ambiguous argument types")
                else:
754
                    return (<dict>signatures)[candidates[0]]
755 756 757 758 759
            """)

        fragment_code = pyx_code.getvalue()
        # print decl_code.getvalue()
        # print fragment_code
760 761 762
        from .Optimize import ConstantFolding
        fragment = TreeFragment.TreeFragment(
            fragment_code, level='module', pipeline=[ConstantFolding()])
763
        ast = TreeFragment.SetPosTransform(self.node.pos)(fragment.root)
Stefan Behnel's avatar
Stefan Behnel committed
764 765
        UtilityCode.declare_declarations_in_scope(
            decl_code.getvalue(), env.global_scope())
766
        ast.scope = env
767
        # FIXME: for static methods of cdef classes, we build the wrong signature here: first arg becomes 'self'
768
        ast.analyse_declarations(env)
Stefan Behnel's avatar
Stefan Behnel committed
769
        py_func = ast.stats[-1]  # the DefNode
770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818
        self.fragment_scope = ast.scope

        if isinstance(self.node, DefNode):
            py_func.specialized_cpdefs = self.nodes[:]
        else:
            py_func.specialized_cpdefs = [n.py_func for n in self.nodes]

        return py_func

    def update_fused_defnode_entry(self, env):
        """
        Let the dispatcher function take over the identity of the original
        Python function in the symbol table.

        The identifying attributes of the original function's entry are
        copied onto the dispatcher's entry, the temporary
        '__pyx_fused_cpdef' entry is dropped from 'env', and every
        specialization is pointed back at the dispatcher.  Finally the
        __signatures__ dict is synthesized and appended to self.stats.
        """
        fused_entry = self.py_func.entry
        orig_entry = self.orig_py_func.entry

        # Mirror name, position, C names, docs and scope of the original.
        for attr_name in ('name', 'pos', 'cname', 'func_cname', 'pyfunc_cname',
                          'pymethdef_cname', 'doc', 'doc_cname', 'is_member',
                          'scope'):
            setattr(fused_entry, attr_name, getattr(orig_entry, attr_name))

        self.py_func.name = self.orig_py_func.name
        self.py_func.doc = self.orig_py_func.doc

        env.entries.pop('__pyx_fused_cpdef', None)
        wraps_def_node = isinstance(self.node, DefNode)
        if wraps_def_node:
            # def function: the dispatcher entry replaces the name outright
            env.entries[fused_entry.name] = fused_entry
        else:
            # otherwise the existing entry stays and exposes the dispatcher
            # through its 'as_variable' attribute
            env.entries[fused_entry.name].as_variable = fused_entry

        env.pyfunc_entries.append(fused_entry)

        self.py_func.entry.fused_cfunction = self
        if wraps_def_node:
            for specialization in self.nodes:
                specialization.fused_py_func = self.py_func
        else:
            for specialization in self.nodes:
                specialization.py_func.fused_py_func = self.py_func
                specialization.entry.as_variable = fused_entry

        self.synthesize_defnodes()
        self.stats.append(self.__signatures__)

    def analyse_expressions(self, env):
        """
        Analyse the expressions. Take care to only evaluate default arguments
        once and clone the result for all specializations.

        Each default value of the original function is analysed once and
        wrapped in a ProxyNode stored in self.defaults (None for arguments
        without a default); the specializations receive CloneNodes of these
        proxies.  When a Python-level dispatcher exists (self.py_func), a
        single defaults tuple and a single code object are likewise shared
        by the dispatcher and all specialized PyCFunction nodes.

        Returns self.
        """
        # Complex specializations need their declaration utility code.
        for fused_compound_type in self.fused_compound_types:
            for fused_type in fused_compound_type.get_fused_types():
                for specialization_type in fused_type.types:
                    if specialization_type.is_complex:
                        specialization_type.create_declaration_utility_code(env)

        if self.py_func:
            self.__signatures__ = self.__signatures__.analyse_expressions(env)
            self.py_func = self.py_func.analyse_expressions(env)
            self.resulting_fused_function = self.resulting_fused_function.analyse_expressions(env)
            self.fused_func_assignment = self.fused_func_assignment.analyse_expressions(env)

        # One slot per argument of the original function, so that the list
        # can be zipped against the arguments of every specialization below.
        self.defaults = defaults = []

        for arg in self.node.args:
            if arg.default:
                arg.default = arg.default.analyse_expressions(env)
                defaults.append(ProxyNode(arg.default))
            else:
                defaults.append(None)

        for i, stat in enumerate(self.stats):
            stat = self.stats[i] = stat.analyse_expressions(env)
            if isinstance(stat, FuncDefNode) and stat is not self.py_func:
                # the dispatcher specifically doesn't want its defaults overriding
                for arg, default in zip(stat.args, defaults):
                    if default is not None:
                        arg.default = CloneNode(default).coerce_to(arg.type, env)

        if self.py_func:
            args = [CloneNode(default) for default in defaults if default]
            # Build the tuple in a local and publish only the final proxy.
            defaults_tuple = TupleNode(self.pos, args=args)
            defaults_tuple = defaults_tuple.analyse_types(env, skip_children=True).coerce_to_pyobject(env)
            self.defaults_tuple = ProxyNode(defaults_tuple)
            self.code_object = ProxyNode(self.specialized_pycfuncs[0].code_object)

            fused_func = self.resulting_fused_function.arg
            fused_func.defaults_tuple = CloneNode(self.defaults_tuple)
            fused_func.code_object = CloneNode(self.code_object)

            for i, pycfunc in enumerate(self.specialized_pycfuncs):
                # The code object clone is set before analyse_types() runs,
                # the defaults tuple clone on the node it returns.
                pycfunc.code_object = CloneNode(self.code_object)
                pycfunc = self.specialized_pycfuncs[i] = pycfunc.analyse_types(env)
                pycfunc.defaults_tuple = CloneNode(self.defaults_tuple)
        return self
864 865 866 867 868 869 870 871 872 873

    def synthesize_defnodes(self):
        """
        Create the __signatures__ dict of PyCFunctionNode specializations.

        The dict maps each specialization's signature string to a
        PyCFunctionNode built from the specialized def node.  The function
        nodes are also kept in self.specialized_pycfuncs and flagged with
        'is_specialization'.
        """
        # If the specializations are CFuncDefNodes, work on their py_func.
        if isinstance(self.nodes[0], CFuncDefNode):
            nodes = [node.py_func for node in self.nodes]
        else:
            nodes = self.nodes

        # For the moment, fused functions do not support METH_FASTCALL
        for node in nodes:
            node.entry.signature.use_fastcall = False

        signatures = [StringEncoding.EncodedString(node.specialized_signature_string)
                      for node in nodes]
        keys = [ExprNodes.StringNode(node.pos, value=sig)
                for node, sig in zip(nodes, signatures)]
        values = [ExprNodes.PyCFunctionNode.from_defnode(node, binding=True)
                  for node in nodes]

        self.__signatures__ = ExprNodes.DictNode.from_pairs(self.pos, zip(keys, values))

        self.specialized_pycfuncs = values
        for pycfuncnode in values:
            pycfuncnode.is_specialization = True

    def generate_function_definitions(self, env, code):
        """
        Generate the function definitions of the dispatcher assignment and
        of every used function node in self.stats.

        Function nodes whose entry is not marked as used are skipped.
        """
        dispatcher = self.py_func
        if dispatcher:
            dispatcher.pymethdef_required = True
            self.fused_func_assignment.generate_function_definitions(env, code)

        for stat in self.stats:
            if not isinstance(stat, FuncDefNode):
                continue
            if not stat.entry.used:
                continue
            code.mark_pos(stat.pos)
            stat.generate_function_definitions(env, code)

    def generate_execution_code(self, code):
        """
        Generate the execution code for the fused function.

        The shared default values are evaluated first and disposed of last,
        so their results are available while the statements in self.stats
        and the dispatcher set-up code are generated.
        """
        # Note: all def function specialization are wrapped in PyCFunction
        # nodes in the self.__signatures__ dictnode.
        for default in self.defaults:
            if default is not None:
                default.generate_evaluation_code(code)

        if self.py_func:
            self.defaults_tuple.generate_evaluation_code(code)
            self.code_object.generate_evaluation_code(code)

        for stat in self.stats:
            code.mark_pos(stat.pos)
            # self.stats holds both statement nodes and expression nodes.
            if isinstance(stat, ExprNodes.ExprNode):
                stat.generate_evaluation_code(code)
            else:
                stat.generate_execution_code(code)

        if self.__signatures__:
            self.resulting_fused_function.generate_evaluation_code(code)

            # Store the signatures dict on the fused function object.
            code.putln(
                "((__pyx_FusedFunctionObject *) %s)->__signatures__ = %s;" %
                                    (self.resulting_fused_function.result(),
                                     self.__signatures__.result()))
            self.__signatures__.generate_giveref(code)
            self.__signatures__.generate_post_assignment_code(code)
            self.__signatures__.free_temps(code)

            self.fused_func_assignment.generate_execution_code(code)

            # Dispose of results
            self.resulting_fused_function.generate_disposal_code(code)
            self.resulting_fused_function.free_temps(code)
            self.defaults_tuple.generate_disposal_code(code)
            self.defaults_tuple.free_temps(code)
            self.code_object.generate_disposal_code(code)
            self.code_object.free_temps(code)

        for default in self.defaults:
            if default is not None:
                default.generate_disposal_code(code)
                default.free_temps(code)
944 945 946

    def annotate(self, code):
        """
        Forward the annotation request to every statement in self.stats.
        """
        for stat in self.stats:
            stat.annotate(code)