Buffer.py 34.7 KB
Newer Older
1
from Cython.Compiler.Visitor import VisitorTransform, CythonTransform
2 3 4
from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import *
5
from Cython.Compiler.StringEncoding import EncodedString
6
from Cython.Compiler.Errors import CompileError
7
from Cython.Utils import UtilityCode
8
import Interpreter
9
import PyrexTypes
10

Stefan Behnel's avatar
Stefan Behnel committed
11 12 13 14 15
try:
    set
except NameError:
    from sets import Set as set

16 17
import textwrap

18 19 20 21 22 23
# Code cleanup ideas:
# - One could be smarter about casting in some places
# - Start using CCodeWriters to generate utility functions
# - Create a struct type per ndim rather than keeping loose local vars


24 25 26 27 28 29
def dedent(text, reindent=0):
    """
    Remove the common leading whitespace from *text* (textwrap.dedent) and,
    when *reindent* is positive, prefix every resulting line -- empty ones
    included -- with that many spaces.
    """
    stripped = textwrap.dedent(text)
    if reindent <= 0:
        return stripped
    pad = " " * reindent
    return '\n'.join(pad + line for line in stripped.split('\n'))
30 31 32 33 34 35 36 37 38 39 40

class IntroduceBufferAuxiliaryVars(CythonTransform):
    """
    Declare the auxiliary C variables needed by every buffer variable.

    For each buffer-typed entry of a function's local scope, a Py_buffer
    struct variable is declared. One Py_ssize_t variable per dimension is
    declared for the strides and for the shape, plus the suboffsets in
    'full' mode. These are attached to the entry as ``entry.buffer_aux``.
    If any buffer was found, the module-wide buffer utility code is
    registered after the whole module has been visited.
    """

    #
    # Entry point
    #

    buffers_exists = False

    def __call__(self, node):
        assert isinstance(node, ModuleNode)
        # Reset all per-module state, so that a reused transform instance
        # does not register buffer utility code for a module without buffers.
        self.max_ndim = 0
        self.buffers_exists = False
        result = super(IntroduceBufferAuxiliaryVars, self).__call__(node)
        if self.buffers_exists:
            use_py2_buffer_functions(node.scope)
            use_empty_bufstruct_code(node.scope, self.max_ndim)
        return result


    #
    # Basic operations for transforms
    #
    def _declare_dim_var(self, scope, node, entry, prefix, idx, initval):
        # Declare one per-dimension Py_ssize_t auxiliary variable of entry,
        # initialised to the C expression initval.
        cname = scope.mangle(prefix, "%d_%s" % (idx, entry.name))
        result = scope.declare_var("$%s" % cname, PyrexTypes.c_py_ssize_t_type,
                                   node.pos, cname=cname, is_cdef=True)
        result.init = initval
        if entry.is_arg:
            result.used = True
        return result

    def handle_scope(self, node, scope):
        # For all buffers, insert extra variables in the scope.
        # The variables are also reachable through entry.buffer_aux
        # on the buffer entry.
        bufvars = [entry for entry in scope.entries.values()
                   if entry.type.is_buffer]
        if bufvars:
            self.buffers_exists = True

        if isinstance(node, ModuleNode) and bufvars:
            # for now...note that pos is wrong
            raise CompileError(node.pos, "Buffer vars not allowed in module scope")
        for entry in bufvars:
            if entry.type.dtype.is_ptr:
                raise CompileError(node.pos, "Buffers with pointer types not yet supported.")

            buftype = entry.type
            if buftype.ndim > self.max_ndim:
                self.max_ndim = buftype.ndim

            # Declare auxiliary vars
            cname = scope.mangle(Naming.bufstruct_prefix, entry.name)
            bufinfo = scope.declare_var(name="$%s" % cname, cname=cname,
                                        type=PyrexTypes.c_py_buffer_type, pos=node.pos)
            if entry.is_arg:
                bufinfo.used = True # otherwise, NameNode will mark whether it is used

            dims = range(buftype.ndim)
            stridevars = [self._declare_dim_var(scope, node, entry,
                                                Naming.bufstride_prefix, i, "0")
                          for i in dims]
            shapevars = [self._declare_dim_var(scope, node, entry,
                                               Naming.bufshape_prefix, i, "0")
                         for i in dims]
            if buftype.mode == 'full':
                suboffsetvars = [self._declare_dim_var(scope, node, entry,
                                                       Naming.bufsuboffset_prefix, i, "-1")
                                 for i in dims]
            else:
                suboffsetvars = None

            entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars, shapevars, suboffsetvars)

        scope.buffer_entries = bufvars
        self.scope = scope

    def visit_ModuleNode(self, node):
        self.handle_scope(node, node.scope)
        self.visitchildren(node)
        return node

    def visit_FuncDefNode(self, node):
        self.handle_scope(node, node.local_scope)
        self.visitchildren(node)
        return node

116 117 118
#
# Analysis
#

# All buffer options, ordered: positional arguments are matched against
# this sequence from its start.
buffer_options = (
    "dtype",
    "ndim",
    "mode",
    "negative_indices",
    "cast",
)
# Values used for options that were not supplied ("dtype" has none).
buffer_defaults = dict(ndim=1, mode="full", negative_indices=True, cast=False)
# Anything beyond this many positional options needs a keyword argument.
buffer_positional_options_count = 1

# Error messages raised by analyse_buffer_options.
ERR_BUF_OPTION_UNKNOWN = '"%s" is not a buffer option'
ERR_BUF_TOO_MANY = 'Too many buffer options'
ERR_BUF_DUP = '"%s" buffer option already supplied'
ERR_BUF_MISSING = '"%s" missing'
ERR_BUF_MODE = 'Only allowed buffer modes are: "c", "fortran", "full", "strided" (as a compile-time string)'
ERR_BUF_NDIM = 'ndim must be a non-negative integer'
ERR_BUF_DTYPE = 'dtype must be "object", numeric type or a struct'
ERR_BUF_BOOL = '"%s" must be a boolean'
131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147

def analyse_buffer_options(globalpos, env, posargs, dictargs, defaults=None, need_complete=True):
    """
    Validate the options given to a buffer type and merge in defaults.

    Must be called during type analysis, as analyse is called
    on the dtype argument.

    posargs and dictargs should consist of a list and a dict
    of tuples (value, pos). Defaults should be a dict of values.

    Options missing from both the arguments and *defaults* raise a
    CompileError when *need_complete* is true. Otherwise they are simply
    absent from the result, and no value validation is done for them.

    Returns a dict containing all the options a buffer can have and
    its value (with the positions stripped).

    Raises CompileError for unknown, duplicated, missing or invalid options.
    """
    if defaults is None:
        defaults = buffer_defaults

    posargs, dictargs = Interpreter.interpret_compiletime_options(posargs, dictargs, type_env=env)

    if len(posargs) > buffer_positional_options_count:
        raise CompileError(posargs[-1][1], ERR_BUF_TOO_MANY)

    options = {}
    for name, (value, pos) in dictargs.iteritems():
        if name not in buffer_options:
            raise CompileError(pos, ERR_BUF_OPTION_UNKNOWN % name)
        options[name.encode("ASCII")] = value

    # zip() pairs positional values with option names, so every name here
    # is a known option; only clashes with keyword arguments need checking.
    for name, (value, pos) in zip(buffer_options, posargs):
        if name in options:
            raise CompileError(pos, ERR_BUF_DUP % name)
        options[name] = value

    # Check that they are all there and copy defaults
    for name in buffer_options:
        if name not in options:
            try:
                options[name] = defaults[name]
            except KeyError:
                if need_complete:
                    raise CompileError(globalpos, ERR_BUF_MISSING % name)

    dtype = options.get("dtype")
    if dtype and dtype.is_extension_type:
        raise CompileError(globalpos, ERR_BUF_DTYPE)

    # Compare against None rather than relying on truthiness, so that falsy
    # but invalid values (e.g. "" or 0.0) are rejected too.
    ndim = options.get("ndim")
    if ndim is not None and (not isinstance(ndim, int) or ndim < 0):
        raise CompileError(globalpos, ERR_BUF_NDIM)

    mode = options.get("mode")
    if mode is not None and mode not in ('full', 'strided', 'c', 'fortran'):
        raise CompileError(globalpos, ERR_BUF_MODE)

    def assert_bool(name):
        # Options legitimately left out (need_complete=False and no default)
        # are not validated; present ones must be real booleans.
        if name in options and not isinstance(options[name], bool):
            raise CompileError(globalpos, ERR_BUF_BOOL % name)

    assert_bool('negative_indices')
    assert_bool('cast')

    return options
    

#
# Code generation
#
199 200


201
def get_flags(buffer_aux, buffer_type):
    """
    Return the C expression of PyBUF_* flags used when requesting a buffer.

    PyBUF_FORMAT is always requested. The access flag is chosen from
    ``buffer_type.mode``, and PyBUF_WRITABLE is added when
    ``buffer_aux.writable_needed`` is set.
    """
    mode_flags = {
        'full': '| PyBUF_INDIRECT',
        'strided': '| PyBUF_STRIDES',
        'c': '| PyBUF_C_CONTIGUOUS',
        'fortran': '| PyBUF_F_CONTIGUOUS',
    }
    mode = buffer_type.mode
    assert mode in mode_flags
    parts = ['PyBUF_FORMAT', mode_flags.get(mode, '')]
    if buffer_aux.writable_needed:
        parts.append("| PyBUF_WRITABLE")
    return ''.join(parts)
        
217 218 219 220 221
def used_buffer_aux_vars(entry):
    """
    Mark the buffer struct variable and every per-dimension auxiliary
    variable (shape, strides and, if present, suboffsets) of *entry* as used.
    """
    aux = entry.buffer_aux
    to_mark = [aux.buffer_info_var]
    to_mark.extend(aux.shapevars)
    to_mark.extend(aux.stridevars)
    if aux.suboffsetvars:
        to_mark.extend(aux.suboffsetvars)
    for var in to_mark:
        var.used = True
224

225 226 227
def put_unpack_buffer_aux_into_scope(buffer_aux, mode, code):
    """
    Emit C statements copying the buffer struct's strides and shape arrays
    (plus suboffsets in 'full' mode) into the per-dimension local variables.
    One output line is written per array.
    """
    struct_cname = buffer_aux.buffer_info_var.cname
    fields = [("strides", buffer_aux.stridevars),
              ("shape", buffer_aux.shapevars)]
    if mode == 'full':
        fields.append(("suboffsets", buffer_aux.suboffsetvars))
    for field_name, local_vars in fields:
        assignments = []
        for dim, var in enumerate(local_vars):
            assignments.append("%s = %s.%s[%d];" % (var.cname, struct_cname, field_name, dim))
        code.putln(" ".join(assignments))
239 240

def put_acquire_arg_buffer(entry, code, pos):
    """
    Emit code acquiring the buffer of the function argument *entry*.

    If acquisition fails, the generated code jumps to the error label
    for *pos*. Otherwise the struct's strides/shape (and suboffsets) are
    copied into the auxiliary variables.
    """
    code.globalstate.use_utility_code(acquire_utility_code)
    aux = entry.buffer_aux
    buftype = entry.type
    getbuffer_func = get_getbuffer_code(buftype.dtype, code)

    # Acquire any new buffer
    acquire_call = "%s((PyObject*)%s, &%s, %s, %d, %d)" % (
        getbuffer_func,
        entry.cname,
        aux.buffer_info_var.cname,
        get_flags(aux, buftype),
        buftype.ndim,
        int(buftype.cast))
    code.putln(code.error_goto_if(acquire_call + " == -1", pos))
    # An exception raised during argument parsing cannot be caught, so
    # there is no need to take care of the buffer in that case.
    put_unpack_buffer_aux_into_scope(aux, buftype.mode, code)
256

257 258 259 260 261 262
#def put_release_buffer_normal(entry, code):
#    code.putln("if (%s != Py_None) PyObject_ReleaseBuffer(%s, &%s);" % (
#        entry.cname,
#        entry.cname,
#        entry.buffer_aux.buffer_info_var.cname))

263
def get_release_buffer_code(entry):
    """Return the C expression releasing *entry*'s buffer via __Pyx_SafeReleaseBuffer."""
    struct_cname = entry.buffer_aux.buffer_info_var.cname
    return "__Pyx_SafeReleaseBuffer(&{0})".format(struct_cname)
265

266
def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type,
                         is_initialized, pos, code):
    """
    Generate code for reassigning a buffer variable.

    Only the buffer struct and the auxiliary variables are dealt with here.
    The assignment itself and the reference counting are the caller's
    responsibility, because the assignment may raise so that the
    reassignment never happens.

    Possible outcomes of the generated code:
    - the old buffer is released, the new one acquired, rhs assigned to lhs;
    - the old buffer is released, acquiring the new one fails, and the old
      lhs buffer is reacquired. The reacquisition may fail too, in which
      case __Pyx_RaiseBufferFallbackError() is raised instead.
    When lhs had no previous value (not *is_initialized*), a failed
    acquisition sets lhs to None and clears the struct's buf field.
    """
    code.globalstate.use_utility_code(acquire_utility_code)
    struct_cname = buffer_aux.buffer_info_var.cname
    flags = get_flags(buffer_aux, buffer_type)
    getbuffer_func = get_getbuffer_code(buffer_type.dtype, code)
    ndim = buffer_type.ndim
    cast = int(buffer_type.cast)

    def acquire_expr(obj_cname):
        # C call acquiring the buffer of obj_cname into the buffer struct.
        return "%s((PyObject*)%s, &%s, %s, %d, %d)" % (
            getbuffer_func, obj_cname, struct_cname, flags, ndim, cast)

    if not is_initialized:
        # Our entry had no previous value, so set to None when acquisition fails.
        # In this case, auxiliary vars should be set up right in initialization
        # to a zero-buffer, so it suffices to set the buf field to NULL.
        code.putln('if (%s) {' % code.unlikely("%s == -1" % acquire_expr(rhs_cname)))
        none_value = PyrexTypes.typecast(buffer_type, PyrexTypes.py_object_type, "Py_None")
        code.putln('%s = %s; Py_INCREF(Py_None); %s.buf = NULL;' %
                   (lhs_cname, none_value, struct_cname))
        code.putln(code.error_goto(pos))
        code.put('} else {')
        # Unpack indices
        put_unpack_buffer_aux_into_scope(buffer_aux, buffer_type.mode, code)
        code.putln('}')
        return

    # Release any existing buffer, then acquire the new one.
    code.putln('__Pyx_SafeReleaseBuffer(&%s);' % struct_cname)
    retcode = code.funcstate.allocate_temp(PyrexTypes.c_int_type)
    code.putln("%s = %s;" % (retcode, acquire_expr(rhs_cname)))
    code.putln('if (%s) ' % (code.unlikely("%s < 0" % retcode)))
    # If acquisition failed, attempt to reacquire the old buffer before
    # raising the exception. A failed reacquisition causes the reacquisition
    # exception to be reported instead; this could be worked around later.
    code.begin_block()
    exc_temps = [code.funcstate.allocate_temp(PyrexTypes.py_object_type)
                 for _ in range(3)]
    code.putln('PyErr_Fetch(&%s, &%s, &%s);' % tuple(exc_temps))
    code.put('if (%s) ' % code.unlikely("%s == -1" % acquire_expr(lhs_cname)))
    code.begin_block()
    code.putln('Py_XDECREF(%s); Py_XDECREF(%s); Py_XDECREF(%s);' % tuple(exc_temps))
    code.globalstate.use_utility_code(raise_buffer_fallback_code)
    code.putln('__Pyx_RaiseBufferFallbackError();')
    code.putln('} else {')
    code.putln('PyErr_Restore(%s, %s, %s);' % tuple(exc_temps))
    for temp in exc_temps:
        code.funcstate.release_temp(temp)
    code.end_block()
    code.end_block()
    # Unpack indices
    put_unpack_buffer_aux_into_scope(buffer_aux, buffer_type.mode, code)
    code.putln(code.error_goto_if_neg(retcode, pos))
    code.funcstate.release_temp(retcode)
337

338

339 340 341 342 343 344
def _put_bounds_check(code, bufaux, index_signeds, index_cnames, negative_indices, pos):
    # Emit bounds checks for every index (wrapping negative indices when
    # allowed) and raise IndexError naming the offending dimension.
    # The temp starts at -1, meaning OK (!); on error it is set to the
    # index of the dimension at which the error occurs.
    err_dim = code.funcstate.allocate_temp(PyrexTypes.c_int_type)
    code.putln("%s = -1;" % err_dim)
    dims = zip(index_signeds, index_cnames, bufaux.shapevars)
    for dim, (signed, idx, shape) in enumerate(dims):
        if signed != 0:
            # not unsigned, deal with negative index
            code.putln("if (%s < 0) {" % idx)
            if negative_indices:
                code.putln("%s += %s;" % (idx, shape.cname))
                code.putln("if (%s) %s = %d;" % (
                    code.unlikely("%s < 0" % idx), err_dim, dim))
            else:
                code.putln("%s = %d;" % (err_dim, dim))
            code.put("} else ")
        # check bounds in positive direction
        code.putln("if (%s) %s = %d;" % (
            code.unlikely("%s >= %s" % (idx, shape.cname)), err_dim, dim))
    code.globalstate.use_utility_code(raise_indexerror_code)
    code.put("if (%s) " % code.unlikely("%s != -1" % err_dim))
    code.begin_block()
    code.putln('__Pyx_RaiseBufferIndexError(%s);' % err_dim)
    code.putln(code.error_goto(pos))
    code.end_block()
    code.funcstate.release_temp(err_dim)


def _put_negative_index_wrap(code, bufaux, index_signeds, index_cnames):
    # Emit only the wrap-around of negative (signed) indices, no checks.
    for signed, idx, shape in zip(index_signeds, index_cnames, bufaux.shapevars):
        if signed != 0:
            code.putln("if (%s < 0) %s += %s;" % (idx, idx, shape.cname))


def _buffer_lookup_function(entry, index_cnames):
    # Pick the lookup macro/function name and its code generator for the
    # buffer's access mode, and build the argument list for it.
    bufaux = entry.buffer_aux
    nd = entry.type.ndim
    mode = entry.type.mode
    params = []
    if mode == 'full':
        for idx, stride, suboffset in zip(index_cnames, bufaux.stridevars,
                                          bufaux.suboffsetvars):
            params.extend([idx, stride.cname, suboffset.cname])
        return "__Pyx_BufPtrFull%dd" % nd, buf_lookup_full_code, params
    if mode == 'strided':
        funcname, funcgen = "__Pyx_BufPtrStrided%dd" % nd, buf_lookup_strided_code
    elif mode == 'c':
        funcname, funcgen = "__Pyx_BufPtrCContig%dd" % nd, buf_lookup_c_code
    elif mode == 'fortran':
        funcname, funcgen = "__Pyx_BufPtrFortranContig%dd" % nd, buf_lookup_fortran_code
    else:
        assert False
    for idx, stride in zip(index_cnames, bufaux.stridevars):
        params.extend([idx, stride.cname])
    return funcname, funcgen, params


def put_buffer_lookup_code(entry, index_signeds, index_cnames, options, pos, code):
    """
    Generate code to process indices and calculate an offset into a buffer.

    Returns a C expression giving a pointer that can be read from or
    written to at will. It is an expression, so the caller should store
    it in a temporary if it is used more than once.

    Bounds checking, which must handle any combination of unsigned
    arguments, smart optimizations etc., is inserted directly into the
    function body. The lookup itself is delegated to a utility macro or
    inline function that is instantiated once per ndim (lookups with
    suboffsets tend to get quite complicated).
    """
    bufaux = entry.buffer_aux
    negative_indices = entry.type.negative_indices

    if options['boundscheck']:
        _put_bounds_check(code, bufaux, index_signeds, index_cnames,
                          negative_indices, pos)
    elif negative_indices:
        _put_negative_index_wrap(code, bufaux, index_signeds, index_cnames)

    funcname, funcgen, params = _buffer_lookup_function(entry, index_cnames)
    # Make sure the utility code is available
    code.globalstate.use_code_from(funcgen, name=funcname, nd=entry.type.ndim)

    ptr_decl = entry.type.buffer_ptr_type.declaration_code("")
    return "%s(%s, %s.buf, %s)" % (funcname, ptr_decl,
                                   bufaux.buffer_info_var.cname,
                                   ", ".join(params))
431

432 433 434 435 436 437

def use_empty_bufstruct_code(env, max_ndim):
    """
    Register utility code declaring the shared Py_ssize_t arrays
    __Pyx_zeros (all 0) and __Pyx_minusones (all -1), each with max_ndim
    entries.

    Each array always gets at least one entry. When only 0-dimensional
    buffers are used (max_ndim == 0), the code would otherwise emit an empty
    initializer list ``{}`` for a zero-length array, which is not valid C
    before C23.
    """
    size = max(max_ndim, 1)
    code = dedent("""
        Py_ssize_t __Pyx_zeros[] = {%s};
        Py_ssize_t __Pyx_minusones[] = {%s};
    """) % (", ".join(["0"] * size), ", ".join(["-1"] * size))
    env.use_utility_code(UtilityCode(proto=code), "empty_bufstruct_code")
439

440

441
def buf_lookup_full_code(proto, defin, name, nd):
    """
    Generate the lookup helper for 'full' (indirect) buffers of nd dimensions.

    A macro ``<name>(type, buf, i0, s0, o0, ...)`` is written to *proto*.
    It casts the void* result of the inline function ``<name>_imp``, whose
    prototype also goes to *proto* and whose definition goes to *defin*.
    For each dimension, the function advances the pointer by index * stride.
    When the suboffset is non-negative, it then dereferences the pointer and
    adds the suboffset.
    """
    dims = range(nd)
    # _i_ndex, _s_tride, sub_o_ffset
    macro_params = ", ".join(["i%d, s%d, o%d" % (d, d, d) for d in dims])
    c_params = ", ".join(["Py_ssize_t i%d, Py_ssize_t s%d, Py_ssize_t o%d" % (d, d, d)
                          for d in dims])
    proto.putln("#define %s(type, buf, %s) (type)(%s_imp(buf, %s))"
                % (name, macro_params, name, macro_params))
    proto.putln("static INLINE void* %s_imp(void* buf, %s);" % (name, c_params))

    body = ["",
            "static INLINE void* %s_imp(void* buf, %s) {" % (name, c_params),
            "  char* ptr = (char*)buf;"]
    for d in dims:
        body.append("ptr += s%d * i%d;" % (d, d))
        body.append("if (o%d >= 0) ptr = *((char**)ptr) + o%d; " % (d, d))
    body.extend(["", "return ptr;", "}"])
    defin.putln("\n".join(body))
460

461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493
def buf_lookup_strided_code(proto, defin, name, nd):
    """
    Generate a lookup macro ``<name>(type, buf, i0, s0, ...)`` for strided
    buffers. It adds sum(i_k * s_k) bytes to buf and casts the result to
    type. Only *proto* is written to; *defin* is unused.
    """
    # _i_ndex, _s_tride
    macro_params = []
    byte_terms = []
    for d in range(nd):
        macro_params.append("i%d, s%d" % (d, d))
        byte_terms.append("i%d * s%d" % (d, d))
    proto.putln("#define %s(type, buf, %s) (type)((char*)buf + %s)"
                % (name, ", ".join(macro_params), " + ".join(byte_terms)))

def buf_lookup_c_code(proto, defin, name, nd):
    """
    Like the strided lookup, but for C-contiguous buffers. The last
    dimension's index is not multiplied by its stride. Instead it is added
    as an element offset to the pointer after the cast to type. The macro
    keeps the same (index, stride) signature as the strided one for now.
    """
    if nd == 1:
        proto.putln("#define %s(type, buf, i0, s0) ((type)buf + i0)" % name)
        return
    last = nd - 1
    macro_params = ", ".join(["i%d, s%d" % (d, d) for d in range(nd)])
    byte_offset = " + ".join(["i%d * s%d" % (d, d) for d in range(last)])
    proto.putln("#define %s(type, buf, %s) ((type)((char*)buf + %s) + i%d)"
                % (name, macro_params, byte_offset, last))

def buf_lookup_fortran_code(proto, defin, name, nd):
    """
    Like the C-contiguous lookup, but for Fortran order. Here the first
    dimension's index is added as an element offset instead of being
    multiplied by its stride.
    """
    if nd == 1:
        proto.putln("#define %s(type, buf, i0, s0) ((type)buf + i0)" % name)
        return
    macro_params = ", ".join(["i%d, s%d" % (d, d) for d in range(nd)])
    byte_offset = " + ".join(["i%d * s%d" % (d, d) for d in range(1, nd)])
    proto.putln("#define %s(type, buf, %s) ((type)((char*)buf + %s) + i0)"
                % (name, macro_params, byte_offset))
494

495 496 497 498 499 500
#
# Utils for creating type string checkers
#
def mangle_dtype_name(dtype):
    """
    Return an identifier-safe name for *dtype*, used to build C helper names.

    Python objects map to "object" and pointers to "ptr". Other types use
    their C declaration with spaces replaced by underscores. Typedefs and
    structs/unions get an "nn_" prefix, so that user-defined types cannot
    collide with builtins (consider "typedef float unsigned_int").
    """
    if dtype.is_pyobject:
        return "object"
    if dtype.is_ptr:
        return "ptr"
    user_defined = dtype.is_typedef or dtype.is_struct_or_union
    prefix = "nn_" if user_defined else ""
    return prefix + dtype.declaration_code("").replace(" ", "_")
511

512 513 514 515 516 517 518 519 520 521 522
def get_typestringchecker(code, dtype):
    """
    Return the C name of the typestring checker for *dtype*. The checker is
    requested through ``code.globalstate.use_code_from`` with
    create_typestringchecker as its generator, so it is emitted when needed.
    """
    checker_name = "__Pyx_CheckTypestring_" + mangle_dtype_name(dtype)
    code.globalstate.use_code_from(create_typestringchecker, checker_name, dtype=dtype)
    return checker_name

523
def _put_typestring_assert(defcode, cond, msg):
    # Emit C code that, unless cond holds, sets a ValueError
    # "Buffer dtype mismatch (<msg>)" and returns NULL. msg may contain one
    # C "%s", which receives a description of the current token in ts.
    defcode.putln("if (!(%s)) {" % cond)
    defcode.putln('PyErr_Format(PyExc_ValueError, "Buffer dtype mismatch (%s)", __Pyx_DescribeTokenInFormatString(ts));' % msg)
    defcode.putln("return NULL;")
    defcode.putln("}")


def _typestring_field_blocks(protocode, fields):
    # Divide struct fields into runs of equal type (for repeat counts),
    # returning a list of (count, type, checker_cname) tuples. Requesting
    # each member checker here also makes sure it is emitted.
    blocks = []
    repeat = 0
    prevtype = None
    for f in fields:
        if repeat and f.type != prevtype:
            blocks.append((repeat, prevtype, get_typestringchecker(protocode, prevtype)))
            repeat = 0
        prevtype = f.type
        repeat += 1
    if repeat:
        # Flush the last run. A struct without fields yields no blocks,
        # instead of failing on an unbound loop variable.
        blocks.append((repeat, prevtype, get_typestringchecker(protocode, prevtype)))
    return blocks


def _put_simple_typestring_check(defcode, dtype):
    # Check a single format code against an object or numeric dtype.
    defcode.putln("int ok;")
    defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
    defcode.putln("if (*ts == '1') ++ts;")
    if dtype.is_pyobject:
        defcode.putln("ok = (*ts == 'O');")
    else:
        # Cannot trust declared size; but rely on int vs float and
        # signed/unsigned to be correctly declared. Use a switch statement
        # on all possible format codes to validate that the size is ok.
        # (Note that many codes may map to same size, e.g. 'i' and 'l'
        # may both be four bytes).
        ctype = dtype.declaration_code("")
        defcode.putln("switch (*ts) {")
        if dtype.is_int:
            types = [
                ('b', 'char'), ('h', 'short'), ('i', 'int'),
                ('l', 'long'), ('q', 'long long')
            ]
        elif dtype.is_float:
            types = [('f', 'float'), ('d', 'double'), ('g', 'long double')]
        else:
            assert False
        if dtype.signed == 0:
            for char, against in types:
                defcode.putln("case '%s': ok = (sizeof(%s) == sizeof(unsigned %s) && (%s)-1 > 0); break;" %
                              (char.upper(), ctype, against, ctype))
        else:
            for char, against in types:
                defcode.putln("case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 < 0); break;" %
                              (char, ctype, against, ctype))
        defcode.putln("default: ok = 0;")
        defcode.putln("}")
    _put_typestring_assert(defcode, "ok", "expected %s, got %%s" % dtype)
    defcode.putln("++ts;")


def _put_complex_typestring_check(defcode, dtype, fields, field_blocks):
    # A two-member struct that could represent a complex number: accept a
    # "Z" spec (e.g. "Zf") as well as the two members one after the other.
    real_t, imag_t = [x.type for x in fields]
    defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
    defcode.putln("if (*ts == '1') ++ts;")
    defcode.putln("if (*ts == 'Z') {")
    if len(field_blocks) == 2:
        # Different float type, sizeof check needed
        defcode.putln("if (sizeof(%s) != sizeof(%s)) {" % (
            real_t.declaration_code(""),
            imag_t.declaration_code("")))
        defcode.putln('PyErr_SetString(PyExc_ValueError, "Cannot store complex number in \'%s\' as \'%s\' differs from \'%s\' in size.");' % (
            dtype, real_t, imag_t))
        defcode.putln("return NULL;")
        defcode.putln("}")
        check_real, check_imag = [block[2] for block in field_blocks]
    else:
        assert len(field_blocks) == 1
        check_real = check_imag = field_blocks[0][2]
    defcode.putln("ts = %s(ts + 1); if (!ts) return NULL;" % check_real)
    defcode.putln("} else {")
    defcode.putln("ts = %s(ts); if (!ts) return NULL;" % check_real)
    defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
    defcode.putln("ts = %s(ts); if (!ts) return NULL;" % check_imag)
    defcode.putln("}")


def _put_struct_typestring_check(defcode, field_blocks):
    # Check a general struct: each run of equal-typed members is matched,
    # honouring repeat counts, and nested structs must appear as "T{...}".
    defcode.putln("int n, count;")
    defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
    next_types = [block[1] for block in field_blocks[1:]] + ["end"]
    for (repeat, field_type, checker), next_type in zip(field_blocks, next_types):
        if repeat == 1:
            defcode.putln("if (*ts == '1') ++ts;")
        else:
            defcode.putln("n = %d;" % repeat)
            defcode.putln("do {")
            defcode.putln("ts = __Pyx_ParseTypestringRepeat(ts, &count); n -= count;")
            _put_typestring_assert(defcode, "n >= 0", "expected %s, got %%s" % next_type)

        nested = not field_type.is_simple_buffer_dtype()
        if nested:
            _put_typestring_assert(defcode, "*ts == 'T' && *(ts+1) == '{'",
                                   "expected %s, got %%s" % field_type)
            defcode.putln("ts += 2;")
        defcode.putln("ts = %s(ts); if (!ts) return NULL;" % checker)
        if nested:
            _put_typestring_assert(defcode, "*ts == '}'",
                                   "expected end of %s struct, got %%s" % field_type)
            defcode.putln("++ts;")

        if repeat > 1:
            defcode.putln("} while (n > 0);")
    defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")


def create_typestringchecker(protocode, defcode, name, dtype):
    """
    Emit a C function ``static const char* <name>(const char* ts)`` that
    checks the part of a buffer format string describing one item of
    *dtype*.

    The generated function returns ts advanced past the consumed characters.
    On a mismatch it returns NULL with a ValueError set. Three shapes are
    handled:
    - simple dtypes, checked against a single format code;
    - structs that may be complex numbers, where a "Z" prefix is accepted;
    - general structs, whose members are grouped into runs of equal type.
    Nothing is emitted for error types.
    """
    if dtype.is_error:
        return
    simple = dtype.is_simple_buffer_dtype()
    complex_possible = dtype.is_struct_or_union and dtype.can_be_complex()
    # Cannot add utility code recursively, so everything this checker
    # depends on (including member checkers) is requested before any of
    # its own code is written.
    if not simple:
        protocode.globalstate.use_utility_code(parse_typestring_repeat_code)
        fields = dtype.scope.var_entries
        field_blocks = _typestring_field_blocks(protocode, fields)

    protocode.putln("static const char* %s(const char* ts); /*proto*/" % name)
    defcode.putln("static const char* %s(const char* ts) {" % name)
    if simple:
        _put_simple_typestring_check(defcode, dtype)
    elif complex_possible:
        _put_complex_typestring_check(defcode, dtype, fields, field_blocks)
    else:
        _put_struct_typestring_check(defcode, field_blocks)
    defcode.putln("return ts;")
    defcode.putln("}")

645
def get_getbuffer_code(dtype, code):
    """
    Generate a utility function for getting a buffer of the given dtype and
    return its C name. The generated function will:
    - Treat None as an empty (zeroed) buffer
    - Call __Pyx_GetBuffer
    - Check that ndim matched the expected value
    - Check that the format string is right (or, when casting, only that the
      item size matches)
    - Set suboffsets to __Pyx_minusones if it is returned as NULL.

    Once __Pyx_GetBuffer has succeeded, any later validation failure releases
    the buffer again before zeroing the struct. Under the buffer protocol,
    every successful acquisition must be paired with a release; without this,
    a rejected buffer would leak the acquisition.
    """
    name = "__Pyx_GetBuffer_%s" % mangle_dtype_name(dtype)
    if not code.globalstate.has_code(name):
        code.globalstate.use_utility_code(acquire_utility_code)
        typestringchecker = get_typestringchecker(code, dtype)
        dtype_cname = dtype.declaration_code("")
        template_values = {
            "name": name,
            "typestringchecker": typestringchecker,
            "dtype_cname": dtype_cname,
        }
        utilcode = UtilityCode(proto = dedent("""
        static int %s(PyObject* obj, Py_buffer* buf, int flags, int nd, int cast); /*proto*/
        """) % name, impl = dedent("""
        static int %(name)s(PyObject* obj, Py_buffer* buf, int flags, int nd, int cast) {
          const char* ts;
          if (obj == Py_None) {
            __Pyx_ZeroBuffer(buf);
            return 0;
          }
          buf->buf = NULL;
          if (__Pyx_GetBuffer(obj, buf, flags) == -1) {
            __Pyx_ZeroBuffer(buf);
            return -1;
          }
          /* The buffer is acquired now; every failure below must release it. */
          if (buf->ndim != nd) {
            __Pyx_BufferNdimError(buf, nd);
            goto fail;
          }
          if (!cast) {
            ts = buf->format;
            ts = __Pyx_ConsumeWhitespace(ts);
            if (!ts) goto fail;
            ts = %(typestringchecker)s(ts);
            if (!ts) goto fail;
            ts = __Pyx_ConsumeWhitespace(ts);
            if (!ts) goto fail;
            if (*ts != 0) {
              PyErr_Format(PyExc_ValueError,
                "Buffer dtype mismatch (expected end, got %%s)",
                __Pyx_DescribeTokenInFormatString(ts));
              goto fail;
            }
          } else {
            if (buf->itemsize != sizeof(%(dtype_cname)s)) {
              PyErr_SetString(PyExc_ValueError,
                "Attempted cast of buffer to datatype of different size.");
              goto fail;
            }
          }
          if (buf->suboffsets == NULL) buf->suboffsets = __Pyx_minusones;
          return 0;
        fail:;
          __Pyx_SafeReleaseBuffer(buf);
          __Pyx_ZeroBuffer(buf);
          return -1;
        }""") % template_values)
        code.globalstate.use_utility_code(utilcode, name)
    return name

def use_py2_buffer_functions(env):
    """Register __Pyx_GetBuffer/__Pyx_ReleaseBuffer utility code on env.

    Under Python 3 both names are macros for PyObject_GetBuffer and
    PyBuffer_Release.  Under Python 2 they are emitted as C functions that
    emulate the new buffer protocol:

    * on Python >= 2.6, objects whose type has Py_TPFLAGS_HAVE_NEWBUFFER set
      are acquired with the real PyObject_GetBuffer and released with the
      real PyBuffer_Release;
    * otherwise the object is matched (PyObject_TypeCheck) against every
      extension type found in env or, transitively, in its cimported modules
      that defines __getbuffer__, and that type's __getbuffer__ /
      __releasebuffer__ implementation is called directly;
    * any other object makes __Pyx_GetBuffer raise TypeError.

    :param env: the scope the utility code is registered on
                (via env.use_utility_code).
    """
    # Emulation of PyObject_GetBuffer and PyBuffer_Release for Python 2.
    # For >= 2.6 we do double mode -- use the new buffer interface on objects
    # which have the right tp_flags set, but emulation otherwise.
    codename = "PyObject_GetBuffer"  # just a representative unique key

    # Search all types for __getbuffer__ overloads.  Each entry is
    # (typeptr_cname, getbuffer_cname, releasebuffer_cname or None).
    types = []
    visited_scopes = set()

    def find_buffer_types(scope):
        # Visit every scope only once; cimported modules can cimport each other.
        if scope in visited_scopes:
            return
        visited_scopes.add(scope)
        for m in scope.cimported_modules:
            find_buffer_types(m)
        for e in scope.type_entries:
            t = e.type
            if t.is_extension_type:
                release = get = None
                for x in t.scope.pyfunc_entries:
                    if x.name == u"__getbuffer__": get = x.func_cname
                    elif x.name == u"__releasebuffer__": release = x.func_cname
                if get:
                    types.append((t.typeptr_cname, get, release))

    find_buffer_types(env)

    code = dedent("""
        #if PY_MAJOR_VERSION < 3
        static int __Pyx_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {
          #if PY_VERSION_HEX >= 0x02060000
          if (Py_TYPE(obj)->tp_flags & Py_TPFLAGS_HAVE_NEWBUFFER)
              return PyObject_GetBuffer(obj, view, flags);
          #endif
    """)
    if types:
        clause = "if"
        for t, get, release in types:
            code += "  %s (PyObject_TypeCheck(obj, %s)) return %s(obj, view, flags);\n" % (clause, t, get)
            clause = "else if"
        code += "  else {\n"
    code += dedent("""\
        PyErr_Format(PyExc_TypeError, "'%100s' does not have the buffer interface", Py_TYPE(obj)->tp_name);
        return -1;
    """, 2)
    if types:
        code += "  }"

    code += dedent("""
        }

        static void __Pyx_ReleaseBuffer(Py_buffer *view) {
          PyObject* obj = view->obj;
          if (obj) {
            #if PY_VERSION_HEX >= 0x02060000
            /* Mirror __Pyx_GetBuffer: a buffer acquired through the real
               PyObject_GetBuffer must be released through PyBuffer_Release,
               otherwise the exporter's bf_releasebuffer is never called. */
            if (Py_TYPE(obj)->tp_flags & Py_TPFLAGS_HAVE_NEWBUFFER) {
              PyBuffer_Release(view);
              return;
            }
            #endif
    """)
    if types:
        clause = "if"
        for t, get, release in types:
            if release:
                code += "    %s (PyObject_TypeCheck(obj, %s)) %s(obj, view);\n" % (clause, t, release)
                clause = "else if"
    code += dedent("""
            Py_DECREF(obj);
            view->obj = NULL;
          }
        }

        #endif
    """)

    env.use_utility_code(UtilityCode(
            proto = dedent("""\
        #if PY_MAJOR_VERSION < 3
        static int __Pyx_GetBuffer(PyObject *obj, Py_buffer *view, int flags);
        static void __Pyx_ReleaseBuffer(Py_buffer *view);
        #else
        #define __Pyx_GetBuffer PyObject_GetBuffer
        #define __Pyx_ReleaseBuffer PyBuffer_Release
        #endif
    """), impl = code), codename)

#
# Static utility code
#


# Utility function that sets the IndexError for an out-of-bounds buffer
# access. The caller should immediately goto_error afterwards.
raise_indexerror_code = UtilityCode(
# __Pyx_RaiseBufferIndexError(axis) sets an IndexError naming the axis whose
# index was out of bounds; it only sets the exception and returns.
proto = """\
static void __Pyx_RaiseBufferIndexError(int axis); /*proto*/
""",
impl = """\
static void __Pyx_RaiseBufferIndexError(int axis) {
  PyErr_Format(PyExc_IndexError,
     "Out of bounds on buffer access (axis %d)", axis);
}

""")

#
# Buffer type checking. Utility code for checking that acquired
# buffers match our assumptions. We only need to check ndim and
# the format string; the access mode and flags are checked by the
# exporter.
#
# C helpers used when acquiring and releasing buffers:
#  - __Pyx_SafeReleaseBuffer: release only if a buffer was acquired
#    (buf != NULL), first undoing the __Pyx_minusones suboffsets placeholder.
#  - __Pyx_ZeroBuffer: reset a Py_buffer to the "no buffer" state, pointing
#    shape/strides/suboffsets at __Pyx_zeros/__Pyx_minusones (defined elsewhere).
#  - __Pyx_ConsumeWhitespace: skip whitespace and '@' in a format string;
#    set ValueError and return NULL on an explicit byte-order/size prefix.
#  - __Pyx_BufferNdimError: set ValueError for an ndim mismatch.
#  - __Pyx_DescribeTokenInFormatString: readable name of the format token at
#    ts, used in dtype-mismatch error messages.
acquire_utility_code = UtilityCode(
proto = """\
static INLINE void __Pyx_SafeReleaseBuffer(Py_buffer* info);
static INLINE void __Pyx_ZeroBuffer(Py_buffer* buf); /*proto*/
static INLINE const char* __Pyx_ConsumeWhitespace(const char* ts); /*proto*/
static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim); /*proto*/
static const char* __Pyx_DescribeTokenInFormatString(const char* ts); /*proto*/
""",
impl = """
static INLINE void __Pyx_SafeReleaseBuffer(Py_buffer* info) {
  if (info->buf == NULL) return;
  if (info->suboffsets == __Pyx_minusones) info->suboffsets = NULL;
  __Pyx_ReleaseBuffer(info);
}

static INLINE void __Pyx_ZeroBuffer(Py_buffer* buf) {
  buf->buf = NULL;
  buf->obj = NULL;
  buf->strides = __Pyx_zeros;
  buf->shape = __Pyx_zeros;
  buf->suboffsets = __Pyx_minusones;
}

static INLINE const char* __Pyx_ConsumeWhitespace(const char* ts) {
  while (1) {
    switch (*ts) {
      case '@':
      case 10:
      case 13:
      case ' ':
        ++ts;
        break;
      case '=':
      case '<':
      case '>':
      case '!':
        PyErr_SetString(PyExc_ValueError, "Buffer acquisition error: Only native byte order, size and alignment supported.");
        return NULL;
      default:
        return ts;
    }
  }
}

static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim) {
  PyErr_Format(PyExc_ValueError,
               "Buffer has wrong number of dimensions (expected %d, got %d)",
               expected_ndim, buffer->ndim);
}

static const char* __Pyx_DescribeTokenInFormatString(const char* ts) {
  switch (*ts) {
    case 'b': return "char";
    case 'B': return "unsigned char";
    case 'h': return "short";
    case 'H': return "unsigned short";
    case 'i': return "int";
    case 'I': return "unsigned int";
    case 'l': return "long";
    case 'L': return "unsigned long";
    case 'q': return "long long";
    case 'Q': return "unsigned long long";
    case 'f': return "float";
    case 'd': return "double";
    case 'g': return "long double";
    case 'Z': switch (*(ts+1)) {
        case 'f': return "complex float";
        case 'd': return "complex double";
        case 'g': return "complex long double";
        default: return "unparseable format string";
    }
    case 'T': return "a struct";
    case 'O': return "Python object";
    case 'P': return "a pointer";
    default: return "unparseable format string";
  }
}

""")


# __Pyx_ParseTypestringRepeat parses the optional decimal repeat count that
# may precede a type character in a buffer format string (e.g. "3" in "3d").
# It stores the count in *out_count (1 when no digits are present) and
# returns a pointer to the first character after the digits.
parse_typestring_repeat_code = UtilityCode(
proto = """
static INLINE const char* __Pyx_ParseTypestringRepeat(const char* ts, int* out_count); /*proto*/
""",
impl = """
static INLINE const char* __Pyx_ParseTypestringRepeat(const char* ts, int* out_count) {
    int count;
    if (*ts < '0' || *ts > '9') {
        count = 1;
    } else {
        count = *ts++ - '0';
        /* '9' is a valid digit too: the bound must be inclusive. */
        while (*ts >= '0' && *ts <= '9') {
            count *= 10;
            count += *ts++ - '0';
        }
    }
    *out_count = count;
    return ts;
}
""")
# __Pyx_RaiseBufferFallbackError sets a ValueError reporting that acquiring a
# new buffer on assignment failed and re-acquiring the previous one also failed.
raise_buffer_fallback_code = UtilityCode(
proto = """
static void __Pyx_RaiseBufferFallbackError(void); /*proto*/
""",
impl = """
static void __Pyx_RaiseBufferFallbackError(void) {
  PyErr_Format(PyExc_ValueError,
     "Buffer acquisition failed on assignment; and then reacquiring the old buffer failed too!");
}

""")