Buffer.py 29.7 KB
Newer Older
1 2 3 4 5 6
from Visitor import VisitorTransform, CythonTransform
from ModuleNode import ModuleNode
from Nodes import *
from ExprNodes import *
from StringEncoding import EncodedString
from Errors import CompileError
7
from UtilityCode import CythonUtilityCode
8
from Code import UtilityCode, ContentHashingUtilityCode
9
import Cython.Compiler.Options
10
import Interpreter
11
import PyrexTypes
Stefan Behnel's avatar
Stefan Behnel committed
12 13
import Naming
import Symtab
14 15 16 17 18 19 20 21 22

import textwrap

def dedent(text, reindent=0):
    """Strip common leading whitespace from *text*.

    When *reindent* is positive, every line of the dedented result
    (including empty ones) is then prefixed with that many spaces.
    """
    result = textwrap.dedent(text)
    if reindent <= 0:
        return result
    pad = " " * reindent
    return "\n".join(pad + line for line in result.split("\n"))


class IntroduceBufferAuxiliaryVars(CythonTransform):
    """Declare the auxiliary C variables backing each buffer variable.

    For every buffer entry in a function scope, two extra variables are
    declared in that scope (of types PyrexTypes.c_pyx_buffer_nd_type and
    PyrexTypes.c_pyx_buffer_type) and attached to the entry as
    ``entry.buffer_aux``. The transform also records whether any buffer or
    memoryview-slice variables exist, so that the buffer utility code is
    only registered on the module scope when it is actually needed.
    """

    #
    # Entry point
    #

    buffers_exists = False
    using_memoryview = False

    def __call__(self, node):
        assert isinstance(node, ModuleNode)
        self.max_ndim = 0
        result = super(IntroduceBufferAuxiliaryVars, self).__call__(node)
        if self.buffers_exists:
            # Only pull in the buffer support code when some scope needs it.
            use_bufstruct_declare_code(node.scope)
            use_py2_buffer_functions(node.scope)
            node.scope.use_utility_code(empty_bufstruct_utility)

        return result


    #
    # Basic operations for transforms
    #
    def handle_scope(self, node, scope):
        # For all buffers, insert extra variables in the scope.
        # The variables are also accessible from the buffer_info
        # on the buffer entry
        bufvars = [entry for name, entry
                   in scope.entries.iteritems()
                   if entry.type.is_buffer]
        if bufvars:
            self.buffers_exists = True

        memviewslicevars = [entry for name, entry
                            in scope.entries.iteritems()
                            if entry.type.is_memoryviewslice]
        if memviewslicevars:
            self.buffers_exists = True

        # Entry names are unique dict keys, so look 'memoryview' up directly
        # instead of scanning every entry.
        memoryview_entry = scope.entries.get('memoryview')
        if (memoryview_entry is not None and
                isinstance(memoryview_entry.utility_code_definition, CythonUtilityCode)):
            self.using_memoryview = True

        if isinstance(node, ModuleNode) and bufvars:
            # for now...note that pos is wrong
            raise CompileError(node.pos, "Buffer vars not allowed in module scope")
        for entry in bufvars:
            if entry.type.dtype.is_ptr:
                raise CompileError(node.pos, "Buffers with pointer types not yet supported.")

            name = entry.name
            buftype = entry.type
            if buftype.ndim > Options.buffer_max_dims:
                raise CompileError(node.pos,
                        "Buffer ndims exceeds Options.buffer_max_dims = %d" % Options.buffer_max_dims)
            if buftype.ndim > self.max_ndim:
                self.max_ndim = buftype.ndim

            # Declare auxiliary vars
            def decvar(var_type, prefix):
                # Called immediately inside this loop iteration, so the
                # closure sees the current `name` and `entry`.
                cname = scope.mangle(prefix, name)
                aux_var = scope.declare_var(name=None, cname=cname,
                                            type=var_type, pos=node.pos)
                if entry.is_arg:
                    aux_var.used = True # otherwise, NameNode will mark whether it is used

                return aux_var

            auxvars = ((PyrexTypes.c_pyx_buffer_nd_type, Naming.pybuffernd_prefix),
                       (PyrexTypes.c_pyx_buffer_type, Naming.pybufferstruct_prefix))
            pybuffernd, rcbuffer = [decvar(var_type, prefix)
                                    for (var_type, prefix) in auxvars]

            entry.buffer_aux = Symtab.BufferAux(pybuffernd, rcbuffer)

        scope.buffer_entries = bufvars
        self.scope = scope

    def visit_ModuleNode(self, node):
        self.handle_scope(node, node.scope)
        self.visitchildren(node)
        return node

    def visit_FuncDefNode(self, node):
        self.handle_scope(node, node.local_scope)
        self.visitchildren(node)
        return node

#
# Analysis
#
# Option names accepted in buffer type declarations, in positional order.
buffer_options = ("dtype", "ndim", "mode", "negative_indices", "cast") # ordered!
# Values used for options the user did not supply (dtype has no default).
buffer_defaults = {"ndim": 1, "mode": "full", "negative_indices": True, "cast": False}
buffer_positional_options_count = 1 # anything beyond this needs keyword argument

# Error messages raised by analyse_buffer_options.
ERR_BUF_OPTION_UNKNOWN = '"%s" is not a buffer option'
ERR_BUF_TOO_MANY = 'Too many buffer options'
ERR_BUF_DUP = '"%s" buffer option already supplied'
ERR_BUF_MISSING = '"%s" missing'
ERR_BUF_MODE = 'Only allowed buffer modes are: "c", "fortran", "full", "strided" (as a compile-time string)'
ERR_BUF_NDIM = 'ndim must be a non-negative integer'
ERR_BUF_DTYPE = 'dtype must be "object", numeric type or a struct'
ERR_BUF_BOOL = '"%s" must be a boolean'


def analyse_buffer_options(globalpos, env, posargs, dictargs, defaults=None, need_complete=True):
    """
    Must be called during type analysis, as analyse is called
    on the dtype argument.

    posargs and dictargs are first passed through
    Interpreter.interpret_compiletime_options; afterwards posargs is a
    list of (value, pos) tuples and dictargs a dict mapping option names
    to (value, pos) tuples. Defaults should be a dict of values; it falls
    back to buffer_defaults.

    Returns a dict containing all the options a buffer can have and
    its value (with the positions stripped). When need_complete is
    False, options that have neither a value nor a default are left out.

    Raises CompileError for unknown, duplicated, ill-typed or (when
    need_complete is True) missing options, and for too many positional
    options.
    """
    if defaults is None:
        defaults = buffer_defaults

    posargs, dictargs = Interpreter.interpret_compiletime_options(posargs, dictargs, type_env=env, type_args = (0,'dtype'))

    if len(posargs) > buffer_positional_options_count:
        raise CompileError(posargs[-1][1], ERR_BUF_TOO_MANY)

    options = {}
    for name, (value, pos) in dictargs.iteritems():
        if name not in buffer_options:
            raise CompileError(pos, ERR_BUF_OPTION_UNKNOWN % name)
        options[name] = value

    # Positional arguments map onto the leading entries of buffer_options,
    # so their names are always valid; only duplicates need checking.
    for name, (value, pos) in zip(buffer_options, posargs):
        if name in options:
            raise CompileError(pos, ERR_BUF_DUP % name)
        options[name] = value

    # Check that they are all there and copy defaults
    for name in buffer_options:
        if name not in options:
            try:
                options[name] = defaults[name]
            except KeyError:
                if need_complete:
                    raise CompileError(globalpos, ERR_BUF_MISSING % name)

    dtype = options.get("dtype")
    if dtype and dtype.is_extension_type:
        raise CompileError(globalpos, ERR_BUF_DTYPE)

    ndim = options.get("ndim")
    if ndim and (not isinstance(ndim, int) or ndim < 0):
        raise CompileError(globalpos, ERR_BUF_NDIM)

    mode = options.get("mode")
    if mode and mode not in ('full', 'strided', 'c', 'fortran'):
        raise CompileError(globalpos, ERR_BUF_MODE)

    def assert_bool(name):
        # An absent option can only happen when need_complete is False
        # (otherwise ERR_BUF_MISSING was raised above), so skip it rather
        # than reporting it as a non-boolean value.
        if name not in options:
            return
        if not isinstance(options[name], bool):
            raise CompileError(globalpos, ERR_BUF_BOOL % name)

    assert_bool('negative_indices')
    assert_bool('cast')

    return options


#
# Code generation
#
class BufferEntry(object):
    """Code-generation helper wrapping a buffer variable's symbol table entry.

    Requires ``entry.buffer_aux`` to be set. It exposes C expressions for
    the buffer data pointer and for the per-dimension shape, stride and
    suboffset fields stored in the entry's pybuffernd struct.
    """

    def __init__(self, entry):
        self.entry = entry
        self.type = entry.type
        self.cname = entry.buffer_aux.buflocal_nd_var.cname
        self.buf_ptr = "%s.rcbuffer->pybuffer.buf" % self.cname
        self.buf_ptr_type = self.entry.type.buffer_ptr_type

    def get_buf_suboffsetvars(self):
        return self._for_all_ndim("%s.diminfo[%d].suboffsets")

    def get_buf_stridevars(self):
        return self._for_all_ndim("%s.diminfo[%d].strides")

    def get_buf_shapevars(self):
        return self._for_all_ndim("%s.diminfo[%d].shape")

    def _for_all_ndim(self, s):
        # Fill the "<struct cname>, <dim>" template once per dimension.
        return [s % (self.cname, i) for i in range(self.type.ndim)]

    def generate_buffer_lookup_code(self, code, index_cnames):
        """Return a C expression pointing at the element at *index_cnames*.

        The lookup is delegated to a per-mode, per-ndim macro/inline
        function (e.g. __Pyx_BufPtrStrided2d) whose definition is emitted
        into the global utility code sections the first time it is needed.
        """
        params = []
        nd = self.type.ndim
        mode = self.type.mode
        if mode == 'full':
            # Indirect buffers need index, stride and suboffset per dimension.
            for i, s, o in zip(index_cnames,
                               self.get_buf_stridevars(),
                               self.get_buf_suboffsetvars()):
                params.extend((i, s, o))
            funcname = "__Pyx_BufPtrFull%dd" % nd
            funcgen = buf_lookup_full_code
        else:
            if mode == 'strided':
                funcname = "__Pyx_BufPtrStrided%dd" % nd
                funcgen = buf_lookup_strided_code
            elif mode == 'c':
                funcname = "__Pyx_BufPtrCContig%dd" % nd
                funcgen = buf_lookup_c_code
            elif mode == 'fortran':
                funcname = "__Pyx_BufPtrFortranContig%dd" % nd
                funcgen = buf_lookup_fortran_code
            else:
                assert False, "unknown buffer mode %r" % (mode,)
            for i, s in zip(index_cnames, self.get_buf_stridevars()):
                params.extend((i, s))

        # Make sure the utility code is available
        if funcname not in code.globalstate.utility_codes:
            code.globalstate.utility_codes.add(funcname)
            protocode = code.globalstate['utility_code_proto']
            defcode = code.globalstate['utility_code_def']
            funcgen(protocode, defcode, name=funcname, nd=nd)

        buf_ptr_type_code = self.buf_ptr_type.declaration_code("")
        ptrcode = "%s(%s, %s, %s)" % (funcname, buf_ptr_type_code, self.buf_ptr,
                                      ", ".join(params))
        return ptrcode


def get_flags(buffer_aux, buffer_type):
    """Return the C expression of PyBUF_* flags used to request a buffer.

    The access mode of *buffer_type* selects the layout flag, and
    PyBUF_WRITABLE is appended when ``buffer_aux.writable_needed`` is set.
    """
    flags = 'PyBUF_FORMAT'
    mode = buffer_type.mode
    if mode == 'full':
        flags += '| PyBUF_INDIRECT'
    elif mode == 'strided':
        flags += '| PyBUF_STRIDES'
    elif mode == 'c':
        flags += '| PyBUF_C_CONTIGUOUS'
    elif mode == 'fortran':
        flags += '| PyBUF_F_CONTIGUOUS'
    else:
        assert False, "unknown buffer mode %r" % (mode,)
    if buffer_aux.writable_needed:
        flags += "| PyBUF_WRITABLE"
    return flags

def used_buffer_aux_vars(entry):
    """Mark both auxiliary buffer variables of *entry* as used."""
    buffer_aux = entry.buffer_aux
    buffer_aux.buflocal_nd_var.used = True
    buffer_aux.rcbuf_var.used = True

def put_unpack_buffer_aux_into_scope(buf_entry, code):
    """Emit one C line copying per-dimension buffer info into the local struct.

    For each dimension, strides and shape (plus suboffsets in 'full' mode)
    are copied from the Py_buffer reached through the pybuffernd struct's
    rcbuffer pointer into that struct's diminfo array.
    """
    buffer_aux, mode = buf_entry.buffer_aux, buf_entry.type.mode
    pybuffernd_struct = buffer_aux.buflocal_nd_var.cname

    fldnames = ['strides', 'shape']
    if mode == 'full':
        fldnames.append('suboffsets')

    ln = []
    for i in range(buf_entry.type.ndim):
        for fldname in fldnames:
            ln.append("%s.diminfo[%d].%s = %s.rcbuffer->pybuffer.%s[%d];" %
                      (pybuffernd_struct, i, fldname,
                       pybuffernd_struct, fldname, i))
    code.putln(' '.join(ln))

def put_init_vars(entry, code):
    """Emit C statements initialising the auxiliary buffer structs of *entry*.

    The pybuffer struct gets a NULL data pointer and a zero refcount; the
    pybuffernd struct gets NULL data and is pointed at the pybuffer struct.
    """
    bufaux = entry.buffer_aux
    pybuffernd_struct = bufaux.buflocal_nd_var.cname
    pybuffer_struct = bufaux.rcbuf_var.cname
    # init pybuffer_struct
    code.putln("%s.pybuffer.buf = NULL;" % pybuffer_struct)
    code.putln("%s.refcount = 0;" % pybuffer_struct)
    # init the pybuffernd_struct
    code.putln("%s.data = NULL;" % pybuffernd_struct)
    code.putln("%s.rcbuffer = &%s;" % (pybuffernd_struct, pybuffer_struct))

def put_acquire_arg_buffer(entry, code, pos):
    """Acquire the buffer of function argument *entry* and unpack its info.

    Emits a C block declaring a format-parsing stack sized by the dtype's
    struct nesting depth, runs the getbuffer call and jumps to the error
    label for *pos* if it returns -1, then copies the dimension info.
    """
    code.globalstate.use_utility_code(acquire_utility_code)
    buffer_aux = entry.buffer_aux
    getbuffer = get_getbuffer_call(code, entry.cname, buffer_aux, entry.type)

    # Acquire any new buffer
    code.putln("{")
    code.putln("__Pyx_BufFmt_StackElem __pyx_stack[%d];" % entry.type.dtype.struct_nesting_depth())
    code.putln(code.error_goto_if("%s == -1" % getbuffer, pos))
    code.putln("}")
    # An exception raised in arg parsing cannot be caught, so no
    # need to care about the buffer then.
    put_unpack_buffer_aux_into_scope(entry, code)

def put_release_buffer_code(code, entry):
    """Emit a __Pyx_SafeReleaseBuffer call for the Py_buffer held by *entry*."""
    code.globalstate.use_utility_code(acquire_utility_code)
    code.putln("__Pyx_SafeReleaseBuffer(&%s.rcbuffer->pybuffer);" % entry.buffer_aux.buflocal_nd_var.cname)

def get_getbuffer_call(code, obj_cname, buffer_aux, buffer_type):
    """Return the C expression that acquires and validates a buffer.

    The expression calls __Pyx_GetBufferAndValidate on *obj_cname*, using
    the type info struct of the dtype, the flags from get_flags, ndim and
    the cast option. It refers to a local ``__pyx_stack`` array that the
    caller must have declared. *obj_cname* may be a literal "%s"
    placeholder to be substituted later.
    """
    ndim = buffer_type.ndim
    cast = int(buffer_type.cast)
    flags = get_flags(buffer_aux, buffer_type)
    pybuffernd_struct = buffer_aux.buflocal_nd_var.cname

    dtype_typeinfo = get_type_information_cname(code, buffer_type.dtype)

    return ("__Pyx_GetBufferAndValidate(&%(pybuffernd_struct)s.rcbuffer->pybuffer, "
            "(PyObject*)%(obj_cname)s, &%(dtype_typeinfo)s, %(flags)s, %(ndim)d, "
            "%(cast)d, __pyx_stack)" % dict(pybuffernd_struct=pybuffernd_struct,
                                            obj_cname=obj_cname,
                                            dtype_typeinfo=dtype_typeinfo,
                                            flags=flags,
                                            ndim=ndim,
                                            cast=cast))

def put_assign_to_buffer(lhs_cname, rhs_cname, buf_entry,
                         is_initialized, pos, code):
    """
    Generate code for reassigning a buffer variable. This only deals with getting
    the buffer auxiliary structure and variables set up correctly, the assignment
    itself and refcounting is the responsibility of the caller.

    However, the assignment operation may throw an exception so that the reassignment
    never happens.

    Depending on the circumstances there are two possible outcomes:
    - Old buffer released, new acquired, rhs assigned to lhs
    - Old buffer released, new acquired which fails, reacquire old lhs buffer
      (which may or may not succeed).
    """
    buffer_aux, buffer_type = buf_entry.buffer_aux, buf_entry.type
    code.globalstate.use_utility_code(acquire_utility_code)
    pybuffernd_struct = buffer_aux.buflocal_nd_var.cname

    code.putln("{")  # Set up necessary stack for getbuffer
    code.putln("__Pyx_BufFmt_StackElem __pyx_stack[%d];" % buffer_type.dtype.struct_nesting_depth())

    getbuffer = get_getbuffer_call(code, "%s", buffer_aux, buffer_type) # fill in object below

    if is_initialized:
        # Release any existing buffer
        code.putln('__Pyx_SafeReleaseBuffer(&%s.rcbuffer->pybuffer);' % pybuffernd_struct)
        # Acquire
        retcode_cname = code.funcstate.allocate_temp(PyrexTypes.c_int_type, manage_ref=False)
        code.putln("%s = %s;" % (retcode_cname, getbuffer % rhs_cname))
        code.putln('if (%s) {' % (code.unlikely("%s < 0" % retcode_cname)))
        # If acquisition failed, attempt to reacquire the old buffer
        # before raising the exception. A failure of reacquisition
        # will cause the reacquisition exception to be reported, one
        # can consider working around this later.
        exc_type, exc_value, exc_tb = [
            code.funcstate.allocate_temp(PyrexTypes.py_object_type, manage_ref=False)
            for _ in range(3)]
        code.putln('PyErr_Fetch(&%s, &%s, &%s);' % (exc_type, exc_value, exc_tb))
        code.putln('if (%s) {' % code.unlikely("%s == -1" % (getbuffer % lhs_cname)))
        code.putln('Py_XDECREF(%s); Py_XDECREF(%s); Py_XDECREF(%s);' % (exc_type, exc_value, exc_tb)) # Do not refnanny these!
        code.globalstate.use_utility_code(raise_buffer_fallback_code)
        code.putln('__Pyx_RaiseBufferFallbackError();')
        code.putln('} else {')
        code.putln('PyErr_Restore(%s, %s, %s);' % (exc_type, exc_value, exc_tb))
        for temp in (exc_type, exc_value, exc_tb):
            code.funcstate.release_temp(temp)
        code.putln('}')
        code.putln('}')
        # Unpack indices
        put_unpack_buffer_aux_into_scope(buf_entry, code)
        code.putln(code.error_goto_if_neg(retcode_cname, pos))
        code.funcstate.release_temp(retcode_cname)
    else:
        # Our entry had no previous value, so set to None when acquisition fails.
        # In this case, auxiliary vars should be set up right in initialization to a zero-buffer,
        # so it suffices to set the buf field to NULL.
        code.putln('if (%s) {' % code.unlikely("%s == -1" % (getbuffer % rhs_cname)))
        code.putln('%s = %s; __Pyx_INCREF(Py_None); %s.rcbuffer->pybuffer.buf = NULL;' %
                   (lhs_cname,
                    PyrexTypes.typecast(buffer_type, PyrexTypes.py_object_type, "Py_None"),
                    pybuffernd_struct))
        code.putln(code.error_goto(pos))
        code.put('} else {')
        # Unpack indices
        put_unpack_buffer_aux_into_scope(buf_entry, code)
        code.putln('}')

    code.putln("}") # Release stack

def put_buffer_lookup_code(entry, index_signeds, index_cnames, directives,
                           pos, code, negative_indices):
    """
    Generates code to process indices and calculate an offset into
    a buffer. Returns a C string which gives a pointer which can be
    read from or written to at will (it is an expression so caller should
    store it in a temporary if it is used more than once).

    As the bounds checking can have any number of combinations of unsigned
    arguments, smart optimizations etc. we insert it directly in the function
    body. The lookup however is delegated to an inline function that is instantiated
    once per ndim (lookup with suboffsets tend to get quite complicated).

    entry is a BufferEntry. index_signeds holds, per dimension, whether the
    index C variable is signed (0 means unsigned); index_cnames names those
    variables. Negative indices are wrapped only when both the 'wraparound'
    directive and *negative_indices* are true. The emitted code may modify
    the index variables in place.
    """
    negative_indices = directives['wraparound'] and negative_indices

    if directives['boundscheck']:
        # Check bounds and fix negative indices.
        # We allocate a temporary which is initialized to -1, meaning OK (!).
        # If an error occurs, the temp is set to the dimension index the
        # error is occurring at.
        tmp_cname = code.funcstate.allocate_temp(PyrexTypes.c_int_type, manage_ref=False)
        code.putln("%s = -1;" % tmp_cname)
        for dim, (signed, cname, shape) in enumerate(zip(index_signeds, index_cnames,
                                                         entry.get_buf_shapevars())):
            if signed != 0:
                # not unsigned, deal with negative index
                code.putln("if (%s < 0) {" % cname)
                if negative_indices:
                    code.putln("%s += %s;" % (cname, shape))
                    code.putln("if (%s) %s = %d;" % (
                        code.unlikely("%s < 0" % cname), tmp_cname, dim))
                else:
                    code.putln("%s = %d;" % (tmp_cname, dim))
                code.put("} else ")
                cast = ""
            else:
                # unsigned index: compare against the shape as size_t
                cast = "(size_t)"
            # check bounds in positive direction
            code.putln("if (%s) %s = %d;" % (
                code.unlikely("%s >= %s%s" % (cname, cast, shape)),
                tmp_cname, dim))
        code.globalstate.use_utility_code(raise_indexerror_code)
        code.putln("if (%s) {" % code.unlikely("%s != -1" % tmp_cname))
        code.putln('__Pyx_RaiseBufferIndexError(%s);' % tmp_cname)
        code.putln(code.error_goto(pos))
        code.putln('}')
        code.funcstate.release_temp(tmp_cname)
    elif negative_indices:
        # Only fix negative indices.
        for signed, cname, shape in zip(index_signeds, index_cnames,
                                        entry.get_buf_shapevars()):
            if signed != 0:
                code.putln("if (%s < 0) %s += %s;" % (cname, cname, shape))

    return entry.generate_buffer_lookup_code(code, index_cnames)


def use_bufstruct_declare_code(env):
    """Register the buffer struct declaration utility code with *env*."""
    env.use_utility_code(buffer_struct_declare_code)

def get_empty_bufstruct_code(max_ndim):
    """Return utility code declaring the __Pyx_zeros and __Pyx_minusones arrays.

    Both are Py_ssize_t arrays of length *max_ndim*, filled with 0 and -1
    respectively.
    """
    code = dedent("""
        Py_ssize_t __Pyx_zeros[] = {%s};
        Py_ssize_t __Pyx_minusones[] = {%s};
    """) % (", ".join(["0"] * max_ndim), ", ".join(["-1"] * max_ndim))
    return UtilityCode(proto=code)

# Sized for the largest buffer dimensionality the compiler allows.
empty_bufstruct_utility = get_empty_bufstruct_code(Options.buffer_max_dims)

def buf_lookup_full_code(proto, defin, name, nd):
    """
    Generates a buffer lookup function for the right number
    of dimensions. The function gives back a void* at the right location.

    A macro *name* is written to *proto*; it casts the result of an inline
    function ``<name>_imp``, whose prototype goes to *proto* and whose body
    goes to *defin*. Per dimension, the body advances by index * stride and,
    when the suboffset is non-negative, dereferences the pointer and adds
    the suboffset.
    """
    # _i_ndex, _s_tride, sub_o_ffset
    macroargs = ", ".join(["i%d, s%d, o%d" % (i, i, i) for i in range(nd)])
    proto.putln("#define %s(type, buf, %s) (type)(%s_imp(buf, %s))" % (name, macroargs, name, macroargs))

    funcargs = ", ".join(["Py_ssize_t i%d, Py_ssize_t s%d, Py_ssize_t o%d" % (i, i, i) for i in range(nd)])
    proto.putln("static CYTHON_INLINE void* %s_imp(void* buf, %s);" % (name, funcargs))
    defin.putln(dedent("""
        static CYTHON_INLINE void* %s_imp(void* buf, %s) {
          char* ptr = (char*)buf;
        """) % (name, funcargs) + "".join([dedent("""\
          ptr += s%d * i%d;
          if (o%d >= 0) ptr = *((char**)ptr) + o%d;
        """) % (i, i, i, i) for i in range(nd)]
        ) + "\nreturn ptr;\n}")

def buf_lookup_strided_code(proto, defin, name, nd):
    """
    Emit a lookup macro for a strided buffer of *nd* dimensions.

    The macro ``name(type, buf, i0, s0, ...)`` adds the sum of index * stride
    byte offsets to *buf* and casts the result to *type*. Only *proto* is
    written; *defin* is accepted so every lookup generator shares one
    signature.
    """
    # _i_ndex, _s_tride
    dims = range(nd)
    macro_params = ", ".join("i%d, s%d" % (d, d) for d in dims)
    byte_offset = " + ".join("i%d * s%d" % (d, d) for d in dims)
    template = "#define %s(type, buf, %s) (type)((char*)buf + %s)"
    proto.putln(template % (name, macro_params, byte_offset))

def buf_lookup_c_code(proto, defin, name, nd):
    """
    Emit a lookup macro for a C-contiguous buffer.

    Like the strided lookup, except that the last dimension is contiguous:
    its index is added as an element offset (pointer arithmetic on *type*)
    instead of being multiplied by a byte stride. A stride parameter is
    still taken for every dimension so all lookup macros share one
    calling convention.
    """
    if nd != 1:
        last = nd - 1
        macro_params = ", ".join(["i%d, s%d" % (d, d) for d in range(nd)])
        byte_offset = " + ".join(["i%d * s%d" % (d, d) for d in range(last)])
        template = "#define %s(type, buf, %s) ((type)((char*)buf + %s) + i%d)"
        proto.putln(template % (name, macro_params, byte_offset, last))
        return
    proto.putln("#define %s(type, buf, i0, s0) ((type)buf + i0)" % name)

def buf_lookup_fortran_code(proto, defin, name, nd):
    """
    Emit a lookup macro for a Fortran-contiguous buffer.

    Mirror image of the C lookup: the first dimension is contiguous and is
    added as an element offset, while the remaining dimensions contribute
    index * stride byte offsets.
    """
    if nd == 1:
        proto.putln("#define %s(type, buf, i0, s0) ((type)buf + i0)" % name)
        return
    macro_params = ", ".join(["i%d, s%d" % (d, d) for d in range(nd)])
    byte_offset = " + ".join(["i%d * s%d" % (d, d) for d in range(1, nd)])
    template = "#define %s(type, buf, %s) ((type)((char*)buf + %s) + i0)"
    proto.putln(template % (name, macro_params, byte_offset))


def use_py2_buffer_functions(env):
    """Register the Python 2 GetBuffer/ReleaseBuffer emulation with *env*."""
    env.use_utility_code(GetAndReleaseBufferUtilityCode())

class GetAndReleaseBufferUtilityCode(object):
    # Emulation of PyObject_GetBuffer and PyBuffer_Release for Python 2.
    # For >= 2.6 we do double mode -- use the new buffer interface on objects
    # which has the right tp_flags set, but emulation otherwise.

    requires = None
    is_cython_utility = False

    def __init__(self):
        pass

    def __eq__(self, other):
        # All instances generate identical code, so they compare equal.
        return isinstance(other, GetAndReleaseBufferUtilityCode)

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__; keep the two consistent.
        return not self.__eq__(other)

    def __hash__(self):
        return 24342342

    def get_tree(self): pass

    def put_code(self, output):
        """Write the __Pyx_GetBuffer/__Pyx_ReleaseBuffer prototypes and bodies.

        Under Python 3 both names are macros for the C-API functions. Under
        Python 2 the generated functions dispatch to the __getbuffer__ /
        __releasebuffer__ implementations of every (used) extension type
        reachable from the module scope and its cimported modules.
        """
        code = output['utility_code_def']
        proto = output['utility_code_proto']
        env = output.module_node.scope
        cython_scope = env.context.cython_scope

        proto.put(dedent("""\
            #if PY_MAJOR_VERSION < 3
            static int __Pyx_GetBuffer(PyObject *obj, Py_buffer *view, int flags);
            static void __Pyx_ReleaseBuffer(Py_buffer *view);
            #else
            #define __Pyx_GetBuffer PyObject_GetBuffer
            #define __Pyx_ReleaseBuffer PyBuffer_Release
            #endif
        """))

        # Search all types for __getbuffer__ overloads
        types = []
        visited_scopes = set()
        def find_buffer_types(scope):
            if scope in visited_scopes:
                return
            visited_scopes.add(scope)
            for m in scope.cimported_modules:
                find_buffer_types(m)
            for e in scope.type_entries:
                if isinstance(e.utility_code_definition, CythonUtilityCode):
                    continue
                t = e.type
                if t.is_extension_type:
                    if scope is cython_scope and not e.used:
                        continue
                    release = get = None
                    for x in t.scope.pyfunc_entries:
                        if x.name == u"__getbuffer__": get = x.func_cname
                        elif x.name == u"__releasebuffer__": release = x.func_cname
                    if get:
                        types.append((t.typeptr_cname, get, release))

        find_buffer_types(env)

        code.put(dedent("""
            #if PY_MAJOR_VERSION < 3
            static int __Pyx_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {
              #if PY_VERSION_HEX >= 0x02060000
              if (PyObject_CheckBuffer(obj)) return PyObject_GetBuffer(obj, view, flags);
              #endif
            """))

        if types:
            clause = "if"
            for t, get, release in types:
                code.putln("  %s (PyObject_TypeCheck(obj, %s)) return %s(obj, view, flags);" % (clause, t, get))
                clause = "else if"
            code.putln("  else {")
        code.put(dedent("""\
            PyErr_Format(PyExc_TypeError, "'%100s' does not have the buffer interface", Py_TYPE(obj)->tp_name);
            return -1;
            """, 2))
        if types:
            code.putln("  }")
        code.put(dedent("""\
             }

            static void __Pyx_ReleaseBuffer(Py_buffer *view) {
              PyObject* obj = view->obj;
              if (obj) {
                #if PY_VERSION_HEX >= 0x02060000
                if (PyObject_CheckBuffer(obj)) {PyBuffer_Release(view); return;}
                #endif
        """))

        if types:
            clause = "if"
            for t, get, release in types:
                if release:
                    code.putln("%s (PyObject_TypeCheck(obj, %s)) %s(obj, view);" % (clause, t, release))
                    clause = "else if"
        code.put(dedent("""
                Py_DECREF(obj);
                view->obj = NULL;
              }
            }

            #endif
        """))



def mangle_dtype_name(dtype):
    """Return a C-identifier-friendly name fragment for *dtype*.

    Python objects map to "object" and pointers to "ptr". Any other type
    uses its C declaration with spaces replaced by underscores. Typedefs,
    structs and unions get an "nn_" prefix so a user-defined type cannot
    collide with a builtin of the same spelling (consider
    "typedef float unsigned_int").
    """
    if dtype.is_pyobject:
        return "object"
    if dtype.is_ptr:
        return "ptr"
    prefix = ""
    if dtype.is_typedef or dtype.is_struct_or_union:
        prefix = "nn_"
    mangled = dtype.declaration_code("").replace(" ", "_")
    return prefix + mangled

def get_type_information_cname(code, dtype, maxdepth=None):
    """Emit the run-time type information (__Pyx_TypeInfo) for *dtype*.

    The __Pyx_TypeInfo struct (and, for struct dtypes, its __Pyx_StructField
    table, recursing into the field types first) is written to the
    'typeinfo' section of the global state the first time a given mangled
    name is seen. Returns the C name of the type info struct, or "<error>"
    for error types.

    Structs with two floats of the same size are encoded as complex numbers.
    One can separate between complex numbers declared as struct or with native
    encoding by inspecting to see if the fields field of the type is
    filled in.

    *maxdepth* bounds the recursion into struct fields and defaults to
    dtype.struct_nesting_depth().
    """
    if dtype.is_error:
        return "<error>"

    namesuffix = mangle_dtype_name(dtype)
    name = "__Pyx_TypeInfo_%s" % namesuffix
    structinfo_name = "__Pyx_StructFields_%s" % namesuffix

    # It's critical that walking the type info doesn't use more stack
    # depth than dtype.struct_nesting_depth() returns (that value sizes the
    # __pyx_stack array in the generated code), so use an assertion for this
    if maxdepth is None:
        maxdepth = dtype.struct_nesting_depth()
    assert maxdepth > 0, "type info nesting exceeds struct_nesting_depth()"

    if name not in code.globalstate.utility_codes:
        # Register the name before recursing so each type is emitted once.
        code.globalstate.utility_codes.add(name)
        typecode = code.globalstate['typeinfo']

        complex_possible = dtype.is_struct_or_union and dtype.can_be_complex()

        declcode = dtype.declaration_code("")
        if dtype.is_simple_buffer_dtype():
            structinfo_name = "NULL"
        elif dtype.is_struct:
            fields = dtype.scope.var_entries
            # Must pre-call all used types in order not to recurse utility code
            # writing.
            assert len(fields) > 0
            field_typeinfos = [get_type_information_cname(code, f.type, maxdepth - 1)
                               for f in fields]
            typecode.putln("static __Pyx_StructField %s[] = {" % structinfo_name, safe=True)
            for f, typeinfo in zip(fields, field_typeinfos):
                typecode.putln('  {&%s, "%s", offsetof(%s, %s)},' %
                               (typeinfo, f.name, declcode, f.cname), safe=True)
            typecode.putln('  {NULL, NULL, 0}', safe=True)
            typecode.putln("};", safe=True)
        else:
            assert False, "unsupported buffer dtype: %s" % (dtype,)

        rep = str(dtype)

        flags = "0"
        if dtype.is_int:
            if dtype.signed == 0:
                typegroup = 'U'
            else:
                typegroup = 'I'
        elif complex_possible or dtype.is_complex:
            typegroup = 'C'
        elif dtype.is_float:
            typegroup = 'R'
        elif dtype.is_struct:
            typegroup = 'S'
            if dtype.packed:
                flags = "__PYX_BUF_FLAGS_PACKED_STRUCT"
        elif dtype.is_pyobject:
            typegroup = 'O'
        else:
            assert False, "unexpected buffer dtype: %s" % (dtype,)

        if dtype.is_int:
            is_unsigned = "IS_UNSIGNED(%s)" % declcode
        else:
            is_unsigned = "0"

        typecode.putln(('static __Pyx_TypeInfo %s = { "%s", %s, sizeof(%s), \'%s\', %s, %s };'
                        ) % (name,
                             rep,
                             structinfo_name,
                             declcode,
                             typegroup,
                             is_unsigned,
                             flags,
                        ), safe=True)
    return name

def load_buffer_utility(util_code_name, **kwargs):
    """Load the utility code section *util_code_name* from Buffer.c."""
    return UtilityCode.load(util_code_name, "Buffer.c", **kwargs)

# Template context passed to the Buffer.c utility code sections below.
context = dict(max_dims=Options.buffer_max_dims)
buffer_struct_declare_code = load_buffer_utility("BufferStructDeclare",
                                                 context=context)


# Utility function to set the right exception
# The caller should immediately goto_error
raise_indexerror_code = load_buffer_utility("BufferIndexError")

# NOTE(review): empty placeholder; presumably kept because other modules may
# still reference it -- confirm before removing.
parse_typestring_repeat_code = UtilityCode(
proto = """
""",
impl = """
""")

raise_buffer_fallback_code = load_buffer_utility("BufferFallbackError")
buffer_structs_code = load_buffer_utility("BufferFormatStructs")
acquire_utility_code = load_buffer_utility("BufferFormatCheck",
                                           context=context,
                                           requires=[buffer_structs_code])

# See utility code BufferFormatFromTypeInfo
_typeinfo_to_format_code = load_buffer_utility(
        "TypeInfoToFormat", context={}, requires=[buffer_structs_code])