Commit f968f684 authored by Stefan Behnel's avatar Stefan Behnel

merge of 0.9.8.1 beta2

parents 857e7852 e751a61a
...@@ -17,8 +17,8 @@ special_chars = [(u'<', u'\xF0', u'&lt;'), ...@@ -17,8 +17,8 @@ special_chars = [(u'<', u'\xF0', u'&lt;'),
class AnnotationCCodeWriter(CCodeWriter): class AnnotationCCodeWriter(CCodeWriter):
def __init__(self, create_from=None, buffer=None): def __init__(self, create_from=None, buffer=None, copy_formatting=True):
CCodeWriter.__init__(self, create_from, buffer) CCodeWriter.__init__(self, create_from, buffer, copy_formatting=True)
self.annotation_buffer = StringIO() self.annotation_buffer = StringIO()
if create_from is None: if create_from is None:
self.annotations = [] self.annotations = []
...@@ -30,8 +30,8 @@ class AnnotationCCodeWriter(CCodeWriter): ...@@ -30,8 +30,8 @@ class AnnotationCCodeWriter(CCodeWriter):
self.annotations = create_from.annotations self.annotations = create_from.annotations
self.code = create_from.code self.code = create_from.code
def create_new(self, create_from, buffer): def create_new(self, create_from, buffer, copy_formatting):
return AnnotationCCodeWriter(create_from, buffer) return AnnotationCCodeWriter(create_from, buffer, copy_formatting)
def write(self, s): def write(self, s):
CCodeWriter.write(self, s) CCodeWriter.write(self, s)
......
...@@ -5,6 +5,7 @@ from Cython.Compiler.ExprNodes import * ...@@ -5,6 +5,7 @@ from Cython.Compiler.ExprNodes import *
from Cython.Compiler.TreeFragment import TreeFragment from Cython.Compiler.TreeFragment import TreeFragment
from Cython.Utils import EncodedString from Cython.Utils import EncodedString
from Cython.Compiler.Errors import CompileError from Cython.Compiler.Errors import CompileError
import Interpreter
import PyrexTypes import PyrexTypes
try: try:
...@@ -13,6 +14,13 @@ except NameError: ...@@ -13,6 +14,13 @@ except NameError:
from sets import Set as set from sets import Set as set
import textwrap import textwrap
# Code cleanup ideas:
# - One could be more smart about casting in some places
# - Start using CCodeWriters to generate utility functions
# - Create a struct type per ndim rather than keeping loose local vars
def dedent(text, reindent=0): def dedent(text, reindent=0):
text = textwrap.dedent(text) text = textwrap.dedent(text)
if reindent > 0: if reindent > 0:
...@@ -35,7 +43,6 @@ class IntroduceBufferAuxiliaryVars(CythonTransform): ...@@ -35,7 +43,6 @@ class IntroduceBufferAuxiliaryVars(CythonTransform):
if self.buffers_exists: if self.buffers_exists:
use_py2_buffer_functions(node.scope) use_py2_buffer_functions(node.scope)
use_empty_bufstruct_code(node.scope, self.max_ndim) use_empty_bufstruct_code(node.scope, self.max_ndim)
node.scope.use_utility_code(access_utility_code)
return result return result
...@@ -62,15 +69,12 @@ class IntroduceBufferAuxiliaryVars(CythonTransform): ...@@ -62,15 +69,12 @@ class IntroduceBufferAuxiliaryVars(CythonTransform):
if buftype.ndim > self.max_ndim: if buftype.ndim > self.max_ndim:
self.max_ndim = buftype.ndim self.max_ndim = buftype.ndim
# Get or make a type string checker
tschecker = buffer_type_checker(buftype.dtype, scope)
# Declare auxiliary vars # Declare auxiliary vars
cname = scope.mangle(Naming.bufstruct_prefix, name) cname = scope.mangle(Naming.bufstruct_prefix, name)
bufinfo = scope.declare_var(name="$%s" % cname, cname=cname, bufinfo = scope.declare_var(name="$%s" % cname, cname=cname,
type=PyrexTypes.c_py_buffer_type, pos=node.pos) type=PyrexTypes.c_py_buffer_type, pos=node.pos)
if entry.is_arg:
bufinfo.used = True bufinfo.used = True # otherwise, NameNode will mark whether it is used
def var(prefix, idx, initval): def var(prefix, idx, initval):
cname = scope.mangle(prefix, "%d_%s" % (idx, name)) cname = scope.mangle(prefix, "%d_%s" % (idx, name))
...@@ -85,17 +89,13 @@ class IntroduceBufferAuxiliaryVars(CythonTransform): ...@@ -85,17 +89,13 @@ class IntroduceBufferAuxiliaryVars(CythonTransform):
stridevars = [var(Naming.bufstride_prefix, i, "0") for i in range(entry.type.ndim)] stridevars = [var(Naming.bufstride_prefix, i, "0") for i in range(entry.type.ndim)]
shapevars = [var(Naming.bufshape_prefix, i, "0") for i in range(entry.type.ndim)] shapevars = [var(Naming.bufshape_prefix, i, "0") for i in range(entry.type.ndim)]
entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars, shapevars, tschecker)
mode = entry.type.mode mode = entry.type.mode
if mode == 'full': if mode == 'full':
suboffsetvars = [var(Naming.bufsuboffset_prefix, i, "-1") for i in range(entry.type.ndim)] suboffsetvars = [var(Naming.bufsuboffset_prefix, i, "-1") for i in range(entry.type.ndim)]
entry.buffer_aux.lookup = get_buf_lookup_full(scope, entry.type.ndim)
elif mode == 'strided': elif mode == 'strided':
suboffsetvars = None suboffsetvars = None
entry.buffer_aux.lookup = get_buf_lookup_strided(scope, entry.type.ndim)
entry.buffer_aux.suboffsetvars = suboffsetvars entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars, shapevars, suboffsetvars)
entry.buffer_aux.get_buffer_cname = tschecker
scope.buffer_entries = bufvars scope.buffer_entries = bufvars
self.scope = scope self.scope = scope
...@@ -110,7 +110,79 @@ class IntroduceBufferAuxiliaryVars(CythonTransform): ...@@ -110,7 +110,79 @@ class IntroduceBufferAuxiliaryVars(CythonTransform):
self.visitchildren(node) self.visitchildren(node)
return node return node
#
# Analysis
#
buffer_options = ("dtype", "ndim", "mode") # ordered!
buffer_defaults = {"ndim": 1, "mode": "full"}
ERR_BUF_OPTION_UNKNOWN = '"%s" is not a buffer option'
ERR_BUF_TOO_MANY = 'Too many buffer options'
ERR_BUF_DUP = '"%s" buffer option already supplied'
ERR_BUF_MISSING = '"%s" missing'
ERR_BUF_MODE = 'Only allowed buffer modes are "full" or "strided" (as a compile-time string)'
ERR_BUF_NDIM = 'ndim must be a non-negative integer'
ERR_BUF_DTYPE = 'dtype must be "object", numeric type or a struct'
def analyse_buffer_options(globalpos, env, posargs, dictargs, defaults=None, need_complete=True):
"""
Must be called during type analysis, as analyse is called
on the dtype argument.
posargs and dictargs should consist of a list and a dict
of tuples (value, pos). Defaults should be a dict of values.
Returns a dict containing all the options a buffer can have and
its value (with the positions stripped).
"""
if defaults is None:
defaults = buffer_defaults
posargs, dictargs = Interpreter.interpret_compiletime_options(posargs, dictargs, type_env=env)
if len(posargs) > len(buffer_options):
raise CompileError(posargs[-1][1], ERR_BUF_TOO_MANY)
options = {}
for name, (value, pos) in dictargs.iteritems():
if not name in buffer_options:
raise CompileError(pos, ERR_BUF_OPTION_UNKNOWN % name)
options[name] = value
for name, (value, pos) in zip(buffer_options, posargs):
if not name in buffer_options:
raise CompileError(pos, ERR_BUF_OPTION_UNKNOWN % name)
if name in options:
raise CompileError(pos, ERR_BUF_DUP % name)
options[name] = value
# Check that they are all there and copy defaults
for name in buffer_options:
if not name in options:
try:
options[name] = defaults[name]
except KeyError:
if need_complete:
raise CompileError(globalpos, ERR_BUF_MISSING % name)
dtype = options.get("dtype")
if dtype and dtype.is_extension_type:
raise CompileError(globalpos, ERR_BUF_DTYPE)
ndim = options.get("ndim")
if ndim and (not isinstance(ndim, int) or ndim < 0):
raise CompileError(globalpos, ERR_BUF_NDIM)
mode = options.get("mode")
if mode and not (mode in ('full', 'strided')):
raise CompileError(globalpos, ERR_BUF_MODE)
return options
#
# Code generation
#
def get_flags(buffer_aux, buffer_type): def get_flags(buffer_aux, buffer_type):
...@@ -129,7 +201,8 @@ def used_buffer_aux_vars(entry): ...@@ -129,7 +201,8 @@ def used_buffer_aux_vars(entry):
buffer_aux.buffer_info_var.used = True buffer_aux.buffer_info_var.used = True
for s in buffer_aux.shapevars: s.used = True for s in buffer_aux.shapevars: s.used = True
for s in buffer_aux.stridevars: s.used = True for s in buffer_aux.stridevars: s.used = True
for s in buffer_aux.suboffsetvars: s.used = True if buffer_aux.suboffsetvars:
for s in buffer_aux.suboffsetvars: s.used = True
def put_unpack_buffer_aux_into_scope(buffer_aux, mode, code): def put_unpack_buffer_aux_into_scope(buffer_aux, mode, code):
# Generate code to copy the needed struct info into local # Generate code to copy the needed struct info into local
...@@ -146,24 +219,19 @@ def put_unpack_buffer_aux_into_scope(buffer_aux, mode, code): ...@@ -146,24 +219,19 @@ def put_unpack_buffer_aux_into_scope(buffer_aux, mode, code):
(s.cname, bufstruct, field, idx) (s.cname, bufstruct, field, idx)
for idx, s in enumerate(vars)])) for idx, s in enumerate(vars)]))
def getbuffer_cond_code(obj_cname, buffer_aux, flags, ndim):
bufstruct = buffer_aux.buffer_info_var.cname
return "%s(%s, &%s, %s, %d) == -1" % (
buffer_aux.get_buffer_cname, obj_cname, bufstruct, flags, ndim)
def put_acquire_arg_buffer(entry, code, pos): def put_acquire_arg_buffer(entry, code, pos):
code.globalstate.use_utility_code(acquire_utility_code)
buffer_aux = entry.buffer_aux buffer_aux = entry.buffer_aux
cname = entry.cname getbuffer_cname = get_getbuffer_code(entry.type.dtype, code)
bufstruct = buffer_aux.buffer_info_var.cname
flags = get_flags(buffer_aux, entry.type)
# Acquire any new buffer # Acquire any new buffer
code.putln(code.error_goto_if(getbuffer_cond_code(cname, code.putln(code.error_goto_if("%s((PyObject*)%s, &%s, %s, %d) == -1" % (
buffer_aux, getbuffer_cname,
flags, entry.cname,
entry.type.ndim), entry.buffer_aux.buffer_info_var.cname,
pos)) get_flags(buffer_aux, entry.type),
entry.type.ndim), pos))
# An exception raised in arg parsing cannot be catched, so no # An exception raised in arg parsing cannot be catched, so no
# need to do care about the buffer then. # need to care about the buffer then.
put_unpack_buffer_aux_into_scope(buffer_aux, entry.type.mode, code) put_unpack_buffer_aux_into_scope(buffer_aux, entry.type.mode, code)
#def put_release_buffer_normal(entry, code): #def put_release_buffer_normal(entry, code):
...@@ -173,8 +241,7 @@ def put_acquire_arg_buffer(entry, code, pos): ...@@ -173,8 +241,7 @@ def put_acquire_arg_buffer(entry, code, pos):
# entry.buffer_aux.buffer_info_var.cname)) # entry.buffer_aux.buffer_info_var.cname))
def get_release_buffer_code(entry): def get_release_buffer_code(entry):
return "if (%s != Py_None) PyObject_ReleaseBuffer(%s, &%s)" % ( return "__Pyx_SafeReleaseBuffer((PyObject*)%s, &%s)" % (
entry.cname,
entry.cname, entry.cname,
entry.buffer_aux.buffer_info_var.cname) entry.buffer_aux.buffer_info_var.cname)
...@@ -193,25 +260,23 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type, ...@@ -193,25 +260,23 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type,
- Old buffer released, new acquired which fails, reaqcuire old lhs buffer - Old buffer released, new acquired which fails, reaqcuire old lhs buffer
(which may or may not succeed). (which may or may not succeed).
""" """
code.globalstate.use_utility_code(acquire_utility_code)
bufstruct = buffer_aux.buffer_info_var.cname bufstruct = buffer_aux.buffer_info_var.cname
flags = get_flags(buffer_aux, buffer_type) flags = get_flags(buffer_aux, buffer_type)
getbuffer = "%s(%%s, &%s, %s, %d)" % (buffer_aux.get_buffer_cname, getbuffer = "%s((PyObject*)%%s, &%s, %s, %d)" % (get_getbuffer_code(buffer_type.dtype, code),
# note: object is filled in later # note: object is filled in later (%%s)
bufstruct, bufstruct,
flags, flags,
buffer_type.ndim) buffer_type.ndim)
if is_initialized: if is_initialized:
# Release any existing buffer # Release any existing buffer
code.put('if (%s != Py_None) ' % lhs_cname) code.putln('__Pyx_SafeReleaseBuffer((PyObject*)%s, &%s);' % (
code.begin_block();
code.putln('PyObject_ReleaseBuffer(%s, &%s);' % (
lhs_cname, bufstruct)) lhs_cname, bufstruct))
code.end_block()
# Acquire # Acquire
retcode_cname = code.func.allocate_temp(PyrexTypes.c_int_type) retcode_cname = code.funcstate.allocate_temp(PyrexTypes.c_int_type)
code.putln("%s = %s;" % (retcode_cname, getbuffer % rhs_cname)) code.putln("%s = %s;" % (retcode_cname, getbuffer % rhs_cname))
code.putln('if (%s) ' % (code.unlikely("%s < 0" % retcode_cname))) code.putln('if (%s) ' % (code.unlikely("%s < 0" % retcode_cname)))
# If acquisition failed, attempt to reacquire the old buffer # If acquisition failed, attempt to reacquire the old buffer
...@@ -219,29 +284,33 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type, ...@@ -219,29 +284,33 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type,
# will cause the reacquisition exception to be reported, one # will cause the reacquisition exception to be reported, one
# can consider working around this later. # can consider working around this later.
code.begin_block() code.begin_block()
type, value, tb = [code.func.allocate_temp(PyrexTypes.py_object_type) type, value, tb = [code.funcstate.allocate_temp(PyrexTypes.py_object_type)
for i in range(3)] for i in range(3)]
code.putln('PyErr_Fetch(&%s, &%s, &%s);' % (type, value, tb)) code.putln('PyErr_Fetch(&%s, &%s, &%s);' % (type, value, tb))
code.put('if (%s) ' % code.unlikely("%s == -1" % (getbuffer % lhs_cname))) code.put('if (%s) ' % code.unlikely("%s == -1" % (getbuffer % lhs_cname)))
code.begin_block() code.begin_block()
code.putln('Py_XDECREF(%s); Py_XDECREF(%s); Py_XDECREF(%s);' % (type, value, tb)) code.putln('Py_XDECREF(%s); Py_XDECREF(%s); Py_XDECREF(%s);' % (type, value, tb))
code.globalstate.use_utility_code(raise_buffer_fallback_code)
code.putln('__Pyx_RaiseBufferFallbackError();') code.putln('__Pyx_RaiseBufferFallbackError();')
code.putln('} else {') code.putln('} else {')
code.putln('PyErr_Restore(%s, %s, %s);' % (type, value, tb)) code.putln('PyErr_Restore(%s, %s, %s);' % (type, value, tb))
for t in (type, value, tb): for t in (type, value, tb):
code.func.release_temp(t) code.funcstate.release_temp(t)
code.end_block() code.end_block()
# Unpack indices # Unpack indices
code.end_block() code.end_block()
put_unpack_buffer_aux_into_scope(buffer_aux, buffer_type.mode, code) put_unpack_buffer_aux_into_scope(buffer_aux, buffer_type.mode, code)
code.putln(code.error_goto_if_neg(retcode_cname, pos)) code.putln(code.error_goto_if_neg(retcode_cname, pos))
code.func.release_temp(retcode_cname) code.funcstate.release_temp(retcode_cname)
else: else:
# Our entry had no previous value, so set to None when acquisition fails. # Our entry had no previous value, so set to None when acquisition fails.
# In this case, auxiliary vars should be set up right in initialization to a zero-buffer, # In this case, auxiliary vars should be set up right in initialization to a zero-buffer,
# so it suffices to set the buf field to NULL. # so it suffices to set the buf field to NULL.
code.putln('if (%s) {' % code.unlikely("%s == -1" % (getbuffer % rhs_cname))) code.putln('if (%s) {' % code.unlikely("%s == -1" % (getbuffer % rhs_cname)))
code.putln('%s = Py_None; Py_INCREF(Py_None); %s.buf = NULL;' % (lhs_cname, bufstruct)) code.putln('%s = %s; Py_INCREF(Py_None); %s.buf = NULL;' %
(lhs_cname,
PyrexTypes.typecast(buffer_type, PyrexTypes.py_object_type, "Py_None"),
bufstruct))
code.putln(code.error_goto(pos)) code.putln(code.error_goto(pos))
code.put('} else {') code.put('} else {')
# Unpack indices # Unpack indices
...@@ -249,65 +318,79 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type, ...@@ -249,65 +318,79 @@ def put_assign_to_buffer(lhs_cname, rhs_cname, buffer_aux, buffer_type,
code.putln('}') code.putln('}')
def put_access(entry, index_signeds, index_cnames, pos, code): def put_buffer_lookup_code(entry, index_signeds, index_cnames, options, pos, code):
"""Returns a c string which can be used to access the buffer """
for reading or writing. Generates code to process indices and calculate an offset into
a buffer. Returns a C string which gives a pointer which can be
read from or written to at will (it is an expression so caller should
store it in a temporary if it is used more than once).
As the bounds checking can have any number of combinations of unsigned As the bounds checking can have any number of combinations of unsigned
arguments, smart optimizations etc. we insert it directly in the function arguments, smart optimizations etc. we insert it directly in the function
body. The lookup however is delegated to a inline function that is instantiated body. The lookup however is delegated to a inline function that is instantiated
once per ndim (lookup with suboffsets tend to get quite complicated). once per ndim (lookup with suboffsets tend to get quite complicated).
""" """
bufaux = entry.buffer_aux bufaux = entry.buffer_aux
bufstruct = bufaux.buffer_info_var.cname bufstruct = bufaux.buffer_info_var.cname
# Check bounds and fix negative indices
boundscheck = True if options['boundscheck']:
nonegs = True # Check bounds and fix negative indices.
tmp_cname = code.func.allocate_temp(PyrexTypes.c_int_type) # We allocate a temporary which is initialized to -1, meaning OK (!).
if boundscheck: # If an error occurs, the temp is set to the dimension index the
# error is occuring at.
tmp_cname = code.funcstate.allocate_temp(PyrexTypes.c_int_type)
code.putln("%s = -1;" % tmp_cname) code.putln("%s = -1;" % tmp_cname)
for idx, (signed, cname, shape) in enumerate(zip(index_signeds, index_cnames, for dim, (signed, cname, shape) in enumerate(zip(index_signeds, index_cnames,
bufaux.shapevars)): bufaux.shapevars)):
if signed != 0: if signed != 0:
nonegs = False # not unsigned, deal with negative index
# not unsigned, deal with negative index code.putln("if (%s < 0) {" % cname)
code.putln("if (%s < 0) {" % cname) code.putln("%s += %s;" % (cname, shape.cname))
code.putln("%s += %s;" % (cname, shape.cname))
if boundscheck:
code.putln("if (%s) %s = %d;" % ( code.putln("if (%s) %s = %d;" % (
code.unlikely("%s < 0" % cname), tmp_cname, idx)) code.unlikely("%s < 0" % cname), tmp_cname, dim))
code.put("} else ") code.put("} else ")
else:
if idx > 0: code.put("else ")
if boundscheck:
# check bounds in positive direction # check bounds in positive direction
code.putln("if (%s) %s = %d;" % ( code.putln("if (%s) %s = %d;" % (
code.unlikely("%s >= %s" % (cname, shape.cname)), code.unlikely("%s >= %s" % (cname, shape.cname)),
tmp_cname, idx)) tmp_cname, dim))
if boundscheck: code.globalstate.use_utility_code(raise_indexerror_code)
code.put("if (%s) " % code.unlikely("%s != -1" % tmp_cname)) code.put("if (%s) " % code.unlikely("%s != -1" % tmp_cname))
code.begin_block() code.begin_block()
code.putln('__Pyx_BufferIndexError(%s);' % tmp_cname) code.putln('__Pyx_RaiseBufferIndexError(%s);' % tmp_cname)
code.putln(code.error_goto(pos)) code.putln(code.error_goto(pos))
code.end_block() code.end_block()
code.func.release_temp(tmp_cname) code.funcstate.release_temp(tmp_cname)
else:
# Only fix negative indices.
for signed, cname, shape in zip(index_signeds, index_cnames,
bufaux.shapevars):
if signed != 0:
code.putln("if (%s < 0) %s += %s;" % (cname, cname, shape.cname))
# Create buffer lookup and return it # Create buffer lookup and return it
params = [] params = []
nd = entry.type.ndim
if entry.type.mode == 'full': if entry.type.mode == 'full':
for i, s, o in zip(index_cnames, bufaux.stridevars, bufaux.suboffsetvars): for i, s, o in zip(index_cnames, bufaux.stridevars, bufaux.suboffsetvars):
params.append(i) params.append(i)
params.append(s.cname) params.append(s.cname)
params.append(o.cname) params.append(o.cname)
funcname = "__Pyx_BufPtrFull%dd" % nd
funcgen = buf_lookup_full_code
else: else:
for i, s in zip(index_cnames, bufaux.stridevars): for i, s in zip(index_cnames, bufaux.stridevars):
params.append(i) params.append(i)
params.append(s.cname) params.append(s.cname)
ptrcode = "%s(%s.buf, %s)" % (bufaux.lookup, bufstruct, funcname = "__Pyx_BufPtrStrided%dd" % nd
", ".join(params)) funcgen = buf_lookup_strided_code
valuecode = "*%s" % entry.type.buffer_ptr_type.cast_code(ptrcode)
return valuecode # Make sure the utility code is available
code.globalstate.use_generated_code(funcgen, name=funcname, nd=nd)
ptrcode = "%s(%s.buf, %s)" % (funcname, bufstruct, ", ".join(params))
return entry.type.buffer_ptr_type.cast_code(ptrcode)
def use_empty_bufstruct_code(env, max_ndim): def use_empty_bufstruct_code(env, max_ndim):
...@@ -315,54 +398,35 @@ def use_empty_bufstruct_code(env, max_ndim): ...@@ -315,54 +398,35 @@ def use_empty_bufstruct_code(env, max_ndim):
Py_ssize_t __Pyx_zeros[] = {%s}; Py_ssize_t __Pyx_zeros[] = {%s};
Py_ssize_t __Pyx_minusones[] = {%s}; Py_ssize_t __Pyx_minusones[] = {%s};
""") % (", ".join(["0"] * max_ndim), ", ".join(["-1"] * max_ndim)) """) % (", ".join(["0"] * max_ndim), ", ".join(["-1"] * max_ndim))
env.use_utility_code([code, ""]) env.use_utility_code([code, ""], "empty_bufstruct_code")
def get_buf_lookup_strided(env, nd): def buf_lookup_strided_code(proto, defin, name, nd):
""" """
Generates and registers as utility a buffer lookup function for the right number Generates a buffer lookup function for the right number
of dimensions. The function gives back a void* at the right location. of dimensions. The function gives back a void* at the right location.
""" """
name = "__Pyx_BufPtrStrided_%dd" % nd # _i_ndex, _s_tride
if not env.has_utility_code(name): args = ", ".join(["i%d, s%d" % (i, i) for i in range(nd)])
# _i_ndex, _s_tride offset = " + ".join(["i%d * s%d" % (i, i) for i in range(nd)])
args = ", ".join(["i%d, s%d" % (i, i) for i in range(nd)]) proto.putln("#define %s(buf, %s) ((char*)buf + %s)" % (name, args, offset))
offset = " + ".join(["i%d * s%d" % (i, i) for i in range(nd)])
proto = dedent("""\
#define %s(buf, %s) ((char*)buf + %s)
""") % (name, args, offset)
env.use_utility_code([proto, ""], name=name)
return name
def get_buf_lookup_full(env, nd): def buf_lookup_full_code(proto, defin, name, nd):
""" """
Generates and registers as utility a buffer lookup function for the right number Generates a buffer lookup function for the right number
of dimensions. The function gives back a void* at the right location. of dimensions. The function gives back a void* at the right location.
""" """
name = "__Pyx_BufPtrFull_%dd" % nd # _i_ndex, _s_tride, sub_o_ffset
if not env.has_utility_code(name): args = ", ".join(["Py_ssize_t i%d, Py_ssize_t s%d, Py_ssize_t o%d" % (i, i, i) for i in range(nd)])
# _i_ndex, _s_tride, sub_o_ffset proto.putln("static INLINE void* %s(void* buf, %s);" % (name, args))
args = ", ".join(["Py_ssize_t i%d, Py_ssize_t s%d, Py_ssize_t o%d" % (i, i, i) for i in range(nd)]) defin.putln(dedent("""
proto = dedent("""\
static INLINE void* %s(void* buf, %s);
""") % (name, args)
func = dedent("""
static INLINE void* %s(void* buf, %s) { static INLINE void* %s(void* buf, %s) {
char* ptr = (char*)buf; char* ptr = (char*)buf;
""") % (name, args) + "".join([dedent("""\ """) % (name, args) + "".join([dedent("""\
ptr += s%d * i%d; ptr += s%d * i%d;
if (o%d >= 0) ptr = *((char**)ptr) + o%d; if (o%d >= 0) ptr = *((char**)ptr) + o%d;
""") % (i, i, i, i) for i in range(nd)] """) % (i, i, i, i) for i in range(nd)]
) + "\nreturn ptr;\n}" ) + "\nreturn ptr;\n}")
env.use_utility_code([proto, func], name=name)
return name
# #
...@@ -371,21 +435,27 @@ def get_buf_lookup_full(env, nd): ...@@ -371,21 +435,27 @@ def get_buf_lookup_full(env, nd):
def mangle_dtype_name(dtype): def mangle_dtype_name(dtype):
# Use prefixes to seperate user defined types from builtins # Use prefixes to seperate user defined types from builtins
# (consider "typedef float unsigned_int") # (consider "typedef float unsigned_int")
if dtype.typestring is None: if dtype.is_pyobject:
prefix = "nn_" return "object"
elif dtype.is_ptr:
return "ptr"
else: else:
prefix = "" if dtype.typestring is None:
return prefix + dtype.declaration_code("").replace(" ", "_") prefix = "nn_"
else:
prefix = ""
return prefix + dtype.declaration_code("").replace(" ", "_")
def get_ts_check_item(dtype, env): def get_ts_check_item(dtype, writer):
# See if we can consume one (unnamed) dtype as next item # See if we can consume one (unnamed) dtype as next item
# Put native types and structs in seperate namespaces (as one could create a struct named unsigned_int...) # Put native types and structs in seperate namespaces (as one could create a struct named unsigned_int...)
name = "__Pyx_BufferTypestringCheck_item_%s" % mangle_dtype_name(dtype) name = "__Pyx_BufferTypestringCheck_item_%s" % mangle_dtype_name(dtype)
if not env.has_utility_code(name): if not writer.globalstate.has_utility_code(name):
char = dtype.typestring char = dtype.typestring
if char is not None: if char is not None:
# Can use direct comparison # Can use direct comparison
code = dedent("""\ code = dedent("""\
if (*ts == '1') ++ts;
if (*ts != '%s') { if (*ts != '%s') {
PyErr_Format(PyExc_ValueError, "Buffer datatype mismatch (rejecting on '%%s')", ts); PyErr_Format(PyExc_ValueError, "Buffer datatype mismatch (rejecting on '%%s')", ts);
return NULL; return NULL;
...@@ -397,6 +467,7 @@ def get_ts_check_item(dtype, env): ...@@ -397,6 +467,7 @@ def get_ts_check_item(dtype, env):
ctype = dtype.declaration_code("") ctype = dtype.declaration_code("")
code = dedent("""\ code = dedent("""\
int ok; int ok;
if (*ts == '1') ++ts;
switch (*ts) {""", 2) switch (*ts) {""", 2)
if dtype.is_int: if dtype.is_int:
types = [ types = [
...@@ -417,7 +488,7 @@ def get_ts_check_item(dtype, env): ...@@ -417,7 +488,7 @@ def get_ts_check_item(dtype, env):
return NULL; return NULL;
} else return ts + 1; } else return ts + 1;
""", 2) """, 2)
env.use_utility_code([dedent("""\ writer.globalstate.use_utility_code([dedent("""\
static const char* %s(const char* ts); /*proto*/ static const char* %s(const char* ts); /*proto*/
""") % name, dedent(""" """) % name, dedent("""
static const char* %s(const char* ts) { static const char* %s(const char* ts) {
...@@ -427,7 +498,7 @@ def get_ts_check_item(dtype, env): ...@@ -427,7 +498,7 @@ def get_ts_check_item(dtype, env):
return name return name
def get_getbuffer_code(dtype, env): def get_getbuffer_code(dtype, code):
""" """
Generate a utility function for getting a buffer for the given dtype. Generate a utility function for getting a buffer for the given dtype.
The function will: The function will:
...@@ -438,9 +509,9 @@ def get_getbuffer_code(dtype, env): ...@@ -438,9 +509,9 @@ def get_getbuffer_code(dtype, env):
""" """
name = "__Pyx_GetBuffer_%s" % mangle_dtype_name(dtype) name = "__Pyx_GetBuffer_%s" % mangle_dtype_name(dtype)
if not env.has_utility_code(name): if not code.globalstate.has_utility_code(name):
env.use_utility_code(acquire_utility_code) code.globalstate.use_utility_code(acquire_utility_code)
itemchecker = get_ts_check_item(dtype, env) itemchecker = get_ts_check_item(dtype, code)
utilcode = [dedent(""" utilcode = [dedent("""
static int %s(PyObject* obj, Py_buffer* buf, int flags, int nd); /*proto*/ static int %s(PyObject* obj, Py_buffer* buf, int flags, int nd); /*proto*/
""") % name, dedent(""" """) % name, dedent("""
...@@ -451,7 +522,7 @@ def get_getbuffer_code(dtype, env): ...@@ -451,7 +522,7 @@ def get_getbuffer_code(dtype, env):
return 0; return 0;
} }
buf->buf = NULL; buf->buf = NULL;
if (PyObject_GetBuffer(obj, buf, flags) == -1) goto fail; if (__Pyx_GetBuffer(obj, buf, flags) == -1) goto fail;
if (buf->ndim != nd) { if (buf->ndim != nd) {
__Pyx_BufferNdimError(buf, nd); __Pyx_BufferNdimError(buf, nd);
goto fail; goto fail;
...@@ -475,72 +546,21 @@ def get_getbuffer_code(dtype, env): ...@@ -475,72 +546,21 @@ def get_getbuffer_code(dtype, env):
__Pyx_ZeroBuffer(buf); __Pyx_ZeroBuffer(buf);
return -1; return -1;
}""") % locals()] }""") % locals()]
env.use_utility_code(utilcode, name) code.globalstate.use_utility_code(utilcode, name)
return name return name
def buffer_type_checker(dtype, env): def buffer_type_checker(dtype, code):
# Creates a type checker function for the given type. # Creates a type checker function for the given type.
if dtype.is_struct_or_union: if dtype.is_struct_or_union:
assert False assert False
elif dtype.is_int or dtype.is_float: elif dtype.is_int or dtype.is_float:
# This includes simple typedef-ed types # This includes simple typedef-ed types
funcname = get_getbuffer_code(dtype, env) funcname = get_getbuffer_code(dtype, code)
else: else:
assert False assert False
return funcname return funcname
def use_py2_buffer_functions(env): def use_py2_buffer_functions(env):
# will be refactored
try:
env.entries[u'numpy']
env.use_utility_code(["","""
static int numpy_getbuffer(PyObject *obj, Py_buffer *view, int flags) {
/* This function is always called after a type-check; safe to cast */
PyArrayObject *arr = (PyArrayObject*)obj;
PyArray_Descr *type = (PyArray_Descr*)arr->descr;
int typenum = PyArray_TYPE(obj);
if (!PyTypeNum_ISNUMBER(typenum)) {
PyErr_Format(PyExc_TypeError, "Only numeric NumPy types currently supported.");
return -1;
}
/*
NumPy format codes doesn't completely match buffer codes;
seems safest to retranslate.
01234567890123456789012345*/
const char* base_codes = "?bBhHiIlLqQfdgfdgO";
char* format = (char*)malloc(4);
char* fp = format;
*fp++ = type->byteorder;
if (PyTypeNum_ISCOMPLEX(typenum)) *fp++ = 'Z';
*fp++ = base_codes[typenum];
*fp = 0;
view->buf = arr->data;
view->readonly = !PyArray_ISWRITEABLE(obj);
view->ndim = PyArray_NDIM(arr);
view->strides = PyArray_STRIDES(arr);
view->shape = PyArray_DIMS(arr);
view->suboffsets = NULL;
view->format = format;
view->itemsize = type->elsize;
view->internal = 0;
return 0;
}
static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
free((char*)view->format);
view->format = NULL;
}
"""])
except KeyError:
pass
codename = "PyObject_GetBuffer" # just a representative unique key codename = "PyObject_GetBuffer" # just a representative unique key
# Search all types for __getbuffer__ overloads # Search all types for __getbuffer__ overloads
...@@ -560,16 +580,9 @@ static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) { ...@@ -560,16 +580,9 @@ static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
find_buffer_types(env) find_buffer_types(env)
# For now, hard-code numpy imported as "numpy"
try:
ndarrtype = env.entries[u'numpy'].as_module.entries['ndarray'].type
types.append((ndarrtype.typeptr_cname, "numpy_getbuffer", "numpy_releasebuffer"))
except KeyError:
pass
code = dedent(""" code = dedent("""
#if PY_VERSION_HEX < 0x02060000 #if (PY_MAJOR_VERSION < 3) && !(Py_TPFLAGS_DEFAULT & Py_TPFLAGS_HAVE_NEWBUFFER)
static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags) { static int __Pyx_GetBuffer(PyObject *obj, Py_buffer *view, int flags) {
""") """)
if len(types) > 0: if len(types) > 0:
clause = "if" clause = "if"
...@@ -585,7 +598,7 @@ static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) { ...@@ -585,7 +598,7 @@ static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
code += dedent(""" code += dedent("""
} }
static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view) { static void __Pyx_ReleaseBuffer(PyObject *obj, Py_buffer *view) {
""") """)
if len(types) > 0: if len(types) > 0:
clause = "if" clause = "if"
...@@ -600,11 +613,14 @@ static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) { ...@@ -600,11 +613,14 @@ static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
""") """)
env.use_utility_code([dedent("""\ env.use_utility_code([dedent("""\
#if PY_VERSION_HEX < 0x02060000 #if (PY_MAJOR_VERSION < 3) && !(Py_TPFLAGS_DEFAULT & Py_TPFLAGS_HAVE_NEWBUFFER)
static int PyObject_GetBuffer(PyObject *obj, Py_buffer *view, int flags); static int __Pyx_GetBuffer(PyObject *obj, Py_buffer *view, int flags);
static void PyObject_ReleaseBuffer(PyObject *obj, Py_buffer *view); static void __Pyx_ReleaseBuffer(PyObject *obj, Py_buffer *view);
#else
#define __Pyx_GetBuffer PyObject_GetBuffer
#define __Pyx_ReleaseBuffer PyObject_ReleaseBuffer
#endif #endif
""") ,code], codename) """), code], codename)
# #
# Static utility code # Static utility code
...@@ -613,11 +629,11 @@ static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) { ...@@ -613,11 +629,11 @@ static void numpy_releasebuffer(PyObject *obj, Py_buffer *view) {
# Utility function to set the right exception # Utility function to set the right exception
# The caller should immediately goto_error # The caller should immediately goto_error
access_utility_code = [ raise_indexerror_code = [
"""\ """\
static void __Pyx_BufferIndexError(int axis); /*proto*/ static void __Pyx_RaiseBufferIndexError(int axis); /*proto*/
""","""\ ""","""\
static void __Pyx_BufferIndexError(int axis) { static void __Pyx_RaiseBufferIndexError(int axis) {
PyErr_Format(PyExc_IndexError, PyErr_Format(PyExc_IndexError,
"Out of bounds on buffer access (axis %d)", axis); "Out of bounds on buffer access (axis %d)", axis);
} }
...@@ -631,12 +647,18 @@ static void __Pyx_BufferIndexError(int axis) { ...@@ -631,12 +647,18 @@ static void __Pyx_BufferIndexError(int axis) {
# exporter. # exporter.
# #
acquire_utility_code = ["""\ acquire_utility_code = ["""\
static INLINE void __Pyx_SafeReleaseBuffer(PyObject* obj, Py_buffer* info);
static INLINE void __Pyx_ZeroBuffer(Py_buffer* buf); /*proto*/ static INLINE void __Pyx_ZeroBuffer(Py_buffer* buf); /*proto*/
static INLINE const char* __Pyx_ConsumeWhitespace(const char* ts); /*proto*/ static INLINE const char* __Pyx_ConsumeWhitespace(const char* ts); /*proto*/
static INLINE const char* __Pyx_BufferTypestringCheckEndian(const char* ts); /*proto*/ static INLINE const char* __Pyx_BufferTypestringCheckEndian(const char* ts); /*proto*/
static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim); /*proto*/ static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim); /*proto*/
static void __Pyx_RaiseBufferFallbackError(void); /*proto*/
""", """ """, """
static INLINE void __Pyx_SafeReleaseBuffer(PyObject* obj, Py_buffer* info) {
if (info->buf == NULL) return;
if (info->suboffsets == __Pyx_minusones) info->suboffsets = NULL;
__Pyx_ReleaseBuffer(obj, info);
}
static INLINE void __Pyx_ZeroBuffer(Py_buffer* buf) { static INLINE void __Pyx_ZeroBuffer(Py_buffer* buf) {
buf->buf = NULL; buf->buf = NULL;
buf->strides = __Pyx_zeros; buf->strides = __Pyx_zeros;
...@@ -688,9 +710,15 @@ static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim) { ...@@ -688,9 +710,15 @@ static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim) {
expected_ndim, buffer->ndim); expected_ndim, buffer->ndim);
} }
"""]
raise_buffer_fallback_code = ["""
static void __Pyx_RaiseBufferFallbackError(void); /*proto*/
""","""
static void __Pyx_RaiseBufferFallbackError(void) { static void __Pyx_RaiseBufferFallbackError(void) {
PyErr_Format(PyExc_ValueError, PyErr_Format(PyExc_ValueError,
"Buffer acquisition failed on assignment; and then reacquiring the old buffer failed too!"); "Buffer acquisition failed on assignment; and then reacquiring the old buffer failed too!");
} }
"""] """]
...@@ -180,9 +180,10 @@ def init_builtins(): ...@@ -180,9 +180,10 @@ def init_builtins():
init_builtin_funcs() init_builtin_funcs()
init_builtin_types() init_builtin_types()
init_builtin_structs() init_builtin_structs()
global list_type, tuple_type, dict_type global list_type, tuple_type, dict_type, unicode_type
list_type = builtin_scope.lookup('list').type list_type = builtin_scope.lookup('list').type
tuple_type = builtin_scope.lookup('tuple').type tuple_type = builtin_scope.lookup('tuple').type
dict_type = builtin_scope.lookup('dict').type dict_type = builtin_scope.lookup('dict').type
unicode_type = builtin_scope.lookup('unicode').type
init_builtins() init_builtins()
...@@ -38,6 +38,7 @@ Options: ...@@ -38,6 +38,7 @@ Options:
-a, --annotate Produce a colorized HTML version of the source. -a, --annotate Produce a colorized HTML version of the source.
--convert-range Convert for loops using range() function to for...from loops. --convert-range Convert for loops using range() function to for...from loops.
--cplus Output a c++ rather than c file. --cplus Output a c++ rather than c file.
-O, --option <name>=<value>[,<name=value,...] Overrides an optimization/code generation option
""" """
#The following experimental options are supported only on MacOSX: #The following experimental options are supported only on MacOSX:
# -C, --compile Compile generated .c file to .o file # -C, --compile Compile generated .c file to .o file
...@@ -45,37 +46,12 @@ Options: ...@@ -45,37 +46,12 @@ Options:
# -+, --cplus Use C++ compiler for compiling and linking # -+, --cplus Use C++ compiler for compiling and linking
# Additional .o files to link may be supplied when using -X.""" # Additional .o files to link may be supplied when using -X."""
#The following options are very experimental and is used for plugging in code
#into different transform stages.
# -T phase:factory At the phase given, hand off the tree to the transform returned
# when calling factory without arguments. Factory should be fully
# specified (ie Module.SubModule.factory) and the containing module
# will be imported. This option can be repeated to add more transforms,
# transforms for the same phase will be used in the order they are given.
def bad_usage(): def bad_usage():
sys.stderr.write(usage) sys.stderr.write(usage)
sys.exit(1) sys.exit(1)
def parse_command_line(args): def parse_command_line(args):
def parse_add_transform(transforms, param):
from Main import PHASES
def import_symbol(fqn):
modsplitpt = fqn.rfind(".")
if modsplitpt == -1: bad_usage()
modulename = fqn[:modsplitpt]
symbolname = fqn[modsplitpt+1:]
module = __import__(modulename, globals(), locals(), [symbolname])
return getattr(module, symbolname)
stagename, factoryname = param.split(":")
if not stagename in PHASES:
bad_usage()
factory = import_symbol(factoryname)
transform = factory()
transforms[stagename].append(transform)
from Cython.Compiler.Main import \ from Cython.Compiler.Main import \
CompilationOptions, default_options CompilationOptions, default_options
...@@ -138,9 +114,12 @@ def parse_command_line(args): ...@@ -138,9 +114,12 @@ def parse_command_line(args):
Options.annotate = True Options.annotate = True
elif option == "--convert-range": elif option == "--convert-range":
Options.convert_range = True Options.convert_range = True
elif option.startswith("-T"): elif option in ("-O", "--option"):
parse_add_transform(options.transforms, get_param(option)) try:
# Note: this can occur multiple times, each time appends options.pragma_overrides = Options.parse_option_list(pop_arg())
except ValueError, e:
sys.stderr.write("Error in option string: %s\n" % e.message)
sys.exit(1)
else: else:
bad_usage() bad_usage()
else: else:
......
...@@ -15,7 +15,16 @@ try: ...@@ -15,7 +15,16 @@ try:
except NameError: except NameError:
from sets import Set as set from sets import Set as set
class FunctionContext(object): class FunctionState(object):
# return_label string function return point label
# error_label string error catch point label
# continue_label string loop continue point label
# break_label string loop break point label
# return_from_error_cleanup_label string
# label_counter integer counter for naming labels
# in_try_finally boolean inside try of try...finally
# exc_vars (string * 3) exception variables for reraise, or None
# Not used for now, perhaps later # Not used for now, perhaps later
def __init__(self, names_taken=set()): def __init__(self, names_taken=set()):
self.names_taken = names_taken self.names_taken = names_taken
...@@ -28,6 +37,9 @@ class FunctionContext(object): ...@@ -28,6 +37,9 @@ class FunctionContext(object):
self.continue_label = None self.continue_label = None
self.break_label = None self.break_label = None
self.in_try_finally = 0
self.exc_vars = None
self.temps_allocated = [] # of (name, type) self.temps_allocated = [] # of (name, type)
self.temps_free = {} # type -> list of free vars self.temps_free = {} # type -> list of free vars
self.temps_used_type = {} # name -> type self.temps_used_type = {} # name -> type
...@@ -121,14 +133,259 @@ class FunctionContext(object): ...@@ -121,14 +133,259 @@ class FunctionContext(object):
freelist = self.temps_free.get(type) freelist = self.temps_free.get(type)
if freelist is None: if freelist is None:
freelist = [] freelist = []
self.temps_free[type] = freelist self.temps_free[type] = freelist
freelist.append(name) freelist.append(name)
class GlobalState(object):
# filename_table {string : int} for finding filename table indexes
# filename_list [string] filenames in filename table order
# input_file_contents dict contents (=list of lines) of any file that was used as input
# to create this output C code. This is
# used to annotate the comments.
#
# used_utility_code set(string|int) Ids of used utility code (to avoid reinsertion)
# utilprotowriter CCodeWriter
# utildefwriter CCodeWriter
#
# declared_cnames {string:Entry} used in a transition phase to merge pxd-declared
# constants etc. into the pyx-declared ones (i.e,
# check if constants are already added).
# In time, hopefully the literals etc. will be
# supplied directly instead.
# interned_strings
# consts
# py_string_decls
# interned_nums
# cached_builtins
def __init__(self, rootwriter):
self.filename_table = {}
self.filename_list = []
self.input_file_contents = {}
self.used_utility_code = set()
self.declared_cnames = {}
self.pystring_table_needed = False
def initwriters(self, rootwriter):
self.utilprotowriter = rootwriter.new_writer()
self.utildefwriter = rootwriter.new_writer()
self.decls_writer = rootwriter.new_writer()
self.pystring_table = rootwriter.new_writer()
self.init_cached_builtins_writer = rootwriter.new_writer()
self.initwriter = rootwriter.new_writer()
if Options.cache_builtins:
self.init_cached_builtins_writer.enter_cfunc_scope()
self.init_cached_builtins_writer.putln("static int __Pyx_InitCachedBuiltins(void) {")
self.initwriter.enter_cfunc_scope()
self.initwriter.putln("").putln("static int __Pyx_InitGlobals(void) {")
(self.pystring_table
.putln("")
.putln("static __Pyx_StringTabEntry %s[] = {" %
Naming.stringtab_cname)
)
#
# Global constants, interned objects, etc.
#
def insert_global_var_declarations_into(self, code):
code.insert(self.decls_writer)
def close_global_decls(self):
# This is called when it is known that no more global declarations will
# declared (but can be called before or after insert_XXX).
if self.pystring_table_needed:
self.pystring_table.putln("{0, 0, 0, 0, 0, 0}").putln("};")
import Nodes
self.use_utility_code(Nodes.init_string_tab_utility_code)
self.initwriter.putln(
"if (__Pyx_InitStrings(%s) < 0) %s;" % (
Naming.stringtab_cname,
self.initwriter.error_goto(self.module_pos)))
if Options.cache_builtins:
(self.init_cached_builtins_writer
.putln("return 0;")
.put_label(self.init_cached_builtins_writer.error_label)
.putln("return -1;")
.putln("}")
.exit_cfunc_scope()
)
(self.initwriter
.putln("return 0;")
.put_label(self.initwriter.error_label)
.putln("return -1;")
.putln("}")
.exit_cfunc_scope()
)
def insert_initcode_into(self, code):
if self.pystring_table_needed:
code.insert(self.pystring_table)
if Options.cache_builtins:
code.insert(self.init_cached_builtins_writer)
code.insert(self.initwriter)
def put_pyobject_decl(self, entry):
self.decls_writer.putln("static PyObject *%s;" % entry.cname)
# The functions below are there in a transition phase only
# and will be deprecated. They are called from Nodes.BlockNode.
# The copy&paste duplication is intentional in order to be able
# to see quickly how BlockNode worked, until this is replaced.
def should_declare(self, cname, entry):
if cname in self.declared_cnames:
other = self.declared_cnames[cname]
assert entry.type == other.type
assert entry.init == other.init
return False
else:
self.declared_cnames[cname] = entry
return True
def add_const_definition(self, entry):
if self.should_declare(entry.cname, entry):
self.decls_writer.put_var_declaration(entry, static = 1)
def add_interned_string_decl(self, entry):
if self.should_declare(entry.cname, entry):
self.decls_writer.put_var_declaration(entry, static = 1)
self.add_py_string_decl(entry)
def add_py_string_decl(self, entry):
if self.should_declare(entry.pystring_cname, entry):
self.decls_writer.putln("static PyObject *%s;" % entry.pystring_cname)
self.pystring_table_needed = True
self.pystring_table.putln("{&%s, %s, sizeof(%s), %d, %d, %d}," % (
entry.pystring_cname,
entry.cname,
entry.cname,
entry.type.is_unicode,
entry.is_interned,
entry.is_identifier
))
def add_interned_num_decl(self, entry):
if self.should_declare(entry.cname, entry):
if entry.init[-1] == "L":
self.initwriter.putln('%s = PyLong_FromString("%s", 0, 0); %s;' % (
entry.cname,
entry.init,
self.initwriter.error_goto_if_null(entry.cname, self.module_pos)))
else:
self.initwriter.putln("%s = PyInt_FromLong(%s); %s;" % (
entry.cname,
entry.init,
self.initwriter.error_goto_if_null(entry.cname, self.module_pos)))
self.put_pyobject_decl(entry)
def add_cached_builtin_decl(self, entry):
if Options.cache_builtins:
if self.should_declare(entry.cname, entry):
self.put_pyobject_decl(entry)
self.init_cached_builtins_writer.putln('%s = __Pyx_GetName(%s, %s); if (!%s) %s' % (
entry.cname,
Naming.builtins_cname,
entry.interned_cname,
entry.cname,
self.init_cached_builtins_writer.error_goto(entry.pos)))
#
# File name state
#
def lookup_filename(self, filename):
try:
index = self.filename_table[filename]
except KeyError:
index = len(self.filename_list)
self.filename_list.append(filename)
self.filename_table[filename] = index
return index
def commented_file_contents(self, source_desc):
try:
return self.input_file_contents[source_desc]
except KeyError:
F = [u' * ' + line.rstrip().replace(
u'*/', u'*[inserted by cython to avoid comment closer]/'
).encode('ASCII', 'replace') # + Py2 auto-decode to unicode
for line in source_desc.get_lines()]
if len(F) == 0: F.append(u'')
self.input_file_contents[source_desc] = F
return F
#
# Utility code state
#
def use_utility_code(self, codetup, name=None):
"""
Adds the given utility code to the C file if needed.
codetup should unpack into one prototype code part and one
definition code part, both strings inserted directly in C.
If name is provided, it is used as an identifier to avoid inserting
code twice. Otherwise, id(codetup) is used as such an identifier.
"""
if name is None: name = id(codetup)
if self.check_utility_code_needed_and_register(name):
proto, _def = codetup
self.utilprotowriter.put(proto)
self.utildefwriter.put(_def)
def has_utility_code(self, name):
return name in self.used_utility_code
def use_generated_code(self, func, name, *args, **kw):
"""
Requests that the utility code that func can generate is used in the C
file. func is called like this:
func(proto, definition, name, *args, **kw)
where proto and definition are two CCodeWriter instances; the
former should have the prototype written to it and the other the definition.
The call might happen at some later point (if compiling multiple modules
into a cache for instance), and will only happen once per utility code.
name is used to identify the utility code, so that it isn't regenerated
when the same code is requested again.
"""
if self.check_utility_code_needed_and_register(name):
func(self.utilprotowriter, self.utildefwriter,
name, *args, **kw)
def check_utility_code_needed_and_register(self, name):
if name in self.used_utility_code:
return False
else:
self.used_utility_code.add(name)
return True
def put_utility_code_protos(self, writer):
writer.insert(self.utilprotowriter)
def put_utility_code_defs(self, writer):
writer.insert(self.utildefwriter)
def funccontext_property(name): def funccontext_property(name):
def get(self): def get(self):
return getattr(self.func, name) return getattr(self.funcstate, name)
def set(self, value): def set(self, value):
setattr(self.func, name, value) setattr(self.funcstate, name, value)
return property(get, set) return property(get, set)
class CCodeWriter(object): class CCodeWriter(object):
...@@ -154,52 +411,37 @@ class CCodeWriter(object): ...@@ -154,52 +411,37 @@ class CCodeWriter(object):
# level int indentation level # level int indentation level
# bol bool beginning of line? # bol bool beginning of line?
# marker string comment to emit before next line # marker string comment to emit before next line
# return_label string function return point label # funcstate FunctionState contains state local to a C function used for code
# error_label string error catch point label # generation (labels and temps state etc.)
# continue_label string loop continue point label # globalstate GlobalState contains state global for a C file (input file info,
# break_label string loop break point label # utility code, declared constants etc.)
# return_from_error_cleanup_label string
# label_counter integer counter for naming labels
# in_try_finally boolean inside try of try...finally
# filename_table {string : int} for finding filename table indexes
# filename_list [string] filenames in filename table order
# exc_vars (string * 3) exception variables for reraise, or None
# input_file_contents dict contents (=list of lines) of any file that was used as input
# to create this output C code. This is
# used to annotate the comments.
# func FunctionContext contains labels and temps context info
in_try_finally = 0 def __init__(self, create_from=None, buffer=None, copy_formatting=False):
def __init__(self, create_from=None, buffer=None):
if buffer is None: buffer = StringIOTree() if buffer is None: buffer = StringIOTree()
self.buffer = buffer self.buffer = buffer
self.marker = None self.marker = None
self.last_marker_line = 0 self.last_marker_line = 0
self.func = None
self.funcstate = None
self.level = 0
self.bol = 1
if create_from is None: if create_from is None:
# Root CCodeWriter # Root CCodeWriter
self.level = 0 self.globalstate = GlobalState(self)
self.bol = 1 self.globalstate.initwriters(self)
self.filename_table = {} # ^^^ need seperate step because this will reference self.globalstate
self.filename_list = []
self.exc_vars = None
self.input_file_contents = {}
else: else:
# Use same global state
self.globalstate = create_from.globalstate
# Clone formatting state # Clone formatting state
c = create_from if copy_formatting:
self.level = c.level self.level = create_from.level
self.bol = c.bol self.bol = create_from.bol
# Note: NOT copying but sharing instance
self.filename_table = c.filename_table def create_new(self, create_from, buffer, copy_formatting):
self.filename_list = []
self.input_file_contents = c.input_file_contents
# Leave other state alone
def create_new(self, create_from, buffer):
# polymorphic constructor -- very slightly more versatile # polymorphic constructor -- very slightly more versatile
# than using __class__ # than using __class__
return CCodeWriter(create_from, buffer) return CCodeWriter(create_from, buffer, copy_formatting)
def copyto(self, f): def copyto(self, f):
self.buffer.copyto(f) self.buffer.copyto(f)
...@@ -211,9 +453,25 @@ class CCodeWriter(object): ...@@ -211,9 +453,25 @@ class CCodeWriter(object):
self.buffer.write(s) self.buffer.write(s)
def insertion_point(self): def insertion_point(self):
other = self.create_new(create_from=self, buffer=self.buffer.insertion_point()) other = self.create_new(create_from=self, buffer=self.buffer.insertion_point(), copy_formatting=True)
return other return other
def new_writer(self):
"""
Creates a new CCodeWriter connected to the same global state, which
can later be inserted using insert.
"""
return CCodeWriter(create_from=self)
def insert(self, writer):
"""
Inserts the contents of another code writer (created with
the same global state) in the current location.
It is ok to write to the inserted writer also after insertion.
"""
assert writer.globalstate is self.globalstate
self.buffer.insert(writer.buffer)
# Properties delegated to function scope # Properties delegated to function scope
label_counter = funccontext_property("label_counter") label_counter = funccontext_property("label_counter")
...@@ -222,26 +480,26 @@ class CCodeWriter(object): ...@@ -222,26 +480,26 @@ class CCodeWriter(object):
labels_used = funccontext_property("labels_used") labels_used = funccontext_property("labels_used")
continue_label = funccontext_property("continue_label") continue_label = funccontext_property("continue_label")
break_label = funccontext_property("break_label") break_label = funccontext_property("break_label")
return_from_error_cleanup_label = funccontext_property("return_from_error_cleanup_label")
# Functions delegated to function scope # Functions delegated to function scope
def new_label(self, name=None): return self.func.new_label(name) def new_label(self, name=None): return self.funcstate.new_label(name)
def new_error_label(self): return self.func.new_error_label() def new_error_label(self): return self.funcstate.new_error_label()
def get_loop_labels(self): return self.func.get_loop_labels() def get_loop_labels(self): return self.funcstate.get_loop_labels()
def set_loop_labels(self, labels): return self.func.set_loop_labels(labels) def set_loop_labels(self, labels): return self.funcstate.set_loop_labels(labels)
def new_loop_labels(self): return self.func.new_loop_labels() def new_loop_labels(self): return self.funcstate.new_loop_labels()
def get_all_labels(self): return self.func.get_all_labels() def get_all_labels(self): return self.funcstate.get_all_labels()
def set_all_labels(self, labels): return self.func.set_all_labels(labels) def set_all_labels(self, labels): return self.funcstate.set_all_labels(labels)
def all_new_labels(self): return self.func.all_new_labels() def all_new_labels(self): return self.funcstate.all_new_labels()
def use_label(self, lbl): return self.func.use_label(lbl) def use_label(self, lbl): return self.funcstate.use_label(lbl)
def label_used(self, lbl): return self.func.label_used(lbl) def label_used(self, lbl): return self.funcstate.label_used(lbl)
def enter_cfunc_scope(self): def enter_cfunc_scope(self):
self.func = FunctionContext() self.funcstate = FunctionState()
def exit_cfunc_scope(self): def exit_cfunc_scope(self):
self.func = None self.funcstate = None
def putln(self, code = ""): def putln(self, code = ""):
if self.marker and self.bol: if self.marker and self.bol:
...@@ -250,6 +508,7 @@ class CCodeWriter(object): ...@@ -250,6 +508,7 @@ class CCodeWriter(object):
self.put(code) self.put(code)
self.write("\n"); self.write("\n");
self.bol = 1 self.bol = 1
return self
def emit_marker(self): def emit_marker(self):
self.write("\n"); self.write("\n");
...@@ -257,6 +516,7 @@ class CCodeWriter(object): ...@@ -257,6 +516,7 @@ class CCodeWriter(object):
self.write("/* %s */\n" % self.marker[1]) self.write("/* %s */\n" % self.marker[1])
self.last_marker_line = self.marker[0] self.last_marker_line = self.marker[0]
self.marker = None self.marker = None
return self
def put_safe(self, code): def put_safe(self, code):
# put code, but ignore {} # put code, but ignore {}
...@@ -279,20 +539,25 @@ class CCodeWriter(object): ...@@ -279,20 +539,25 @@ class CCodeWriter(object):
self.level += dl self.level += dl
elif fix_indent: elif fix_indent:
self.level += 1 self.level += 1
return self
def increase_indent(self): def increase_indent(self):
self.level = self.level + 1 self.level = self.level + 1
return self
def decrease_indent(self): def decrease_indent(self):
self.level = self.level - 1 self.level = self.level - 1
return self
def begin_block(self): def begin_block(self):
self.putln("{") self.putln("{")
self.increase_indent() self.increase_indent()
return self
def end_block(self): def end_block(self):
self.decrease_indent() self.decrease_indent()
self.putln("}") self.putln("}")
return self
def indent(self): def indent(self):
self.write(" " * self.level) self.write(" " * self.level)
...@@ -300,17 +565,6 @@ class CCodeWriter(object): ...@@ -300,17 +565,6 @@ class CCodeWriter(object):
def get_py_version_hex(self, pyversion): def get_py_version_hex(self, pyversion):
return "0x%02X%02X%02X%02X" % (tuple(pyversion) + (0,0,0,0))[:4] return "0x%02X%02X%02X%02X" % (tuple(pyversion) + (0,0,0,0))[:4]
def commented_file_contents(self, source_desc):
try:
return self.input_file_contents[source_desc]
except KeyError:
F = [u' * ' + line.rstrip().replace(
u'*/', u'*[inserted by cython to avoid comment closer]/'
).encode('ASCII', 'replace') # + Py2 auto-decode to unicode
for line in source_desc.get_lines()]
self.input_file_contents[source_desc] = F
return F
def mark_pos(self, pos): def mark_pos(self, pos):
if pos is None: if pos is None:
return return
...@@ -318,8 +572,7 @@ class CCodeWriter(object): ...@@ -318,8 +572,7 @@ class CCodeWriter(object):
if self.last_marker_line == line: if self.last_marker_line == line:
return return
assert isinstance(source_desc, SourceDescriptor) assert isinstance(source_desc, SourceDescriptor)
contents = self.commented_file_contents(source_desc) contents = self.globalstate.commented_file_contents(source_desc)
lines = contents[max(0,line-3):line] # line numbers start at 1 lines = contents[max(0,line-3):line] # line numbers start at 1
lines[-1] += u' # <<<<<<<<<<<<<<' lines[-1] += u' # <<<<<<<<<<<<<<'
lines += contents[line:line+2] lines += contents[line:line+2]
...@@ -330,12 +583,14 @@ class CCodeWriter(object): ...@@ -330,12 +583,14 @@ class CCodeWriter(object):
def put_label(self, lbl): def put_label(self, lbl):
if lbl in self.func.labels_used: if lbl in self.funcstate.labels_used:
self.putln("%s:;" % lbl) self.putln("%s:;" % lbl)
return self
def put_goto(self, lbl): def put_goto(self, lbl):
self.func.use_label(lbl) self.funcstate.use_label(lbl)
self.putln("goto %s;" % lbl) self.putln("goto %s;" % lbl)
return self
def put_var_declarations(self, entries, static = 0, dll_linkage = None, def put_var_declarations(self, entries, static = 0, dll_linkage = None,
definition = True): definition = True):
...@@ -494,8 +749,8 @@ class CCodeWriter(object): ...@@ -494,8 +749,8 @@ class CCodeWriter(object):
return cond return cond
def error_goto(self, pos): def error_goto(self, pos):
lbl = self.func.error_label lbl = self.funcstate.error_label
self.func.use_label(lbl) self.funcstate.use_label(lbl)
if Options.c_line_in_traceback: if Options.c_line_in_traceback:
cinfo = " %s = %s;" % (Naming.clineno_cname, Naming.line_c_macro) cinfo = " %s = %s;" % (Naming.clineno_cname, Naming.line_c_macro)
else: else:
...@@ -522,13 +777,7 @@ class CCodeWriter(object): ...@@ -522,13 +777,7 @@ class CCodeWriter(object):
return self.error_goto_if("PyErr_Occurred()", pos) return self.error_goto_if("PyErr_Occurred()", pos)
def lookup_filename(self, filename): def lookup_filename(self, filename):
try: return self.globalstate.lookup_filename(filename)
index = self.filename_table[filename]
except KeyError:
index = len(self.filename_list)
self.filename_list.append(filename)
self.filename_table[filename] = index
return index
class PyrexCodeWriter: class PyrexCodeWriter:
......
from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform
from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import *
class ExtractPxdCode(CythonTransform):
"""
Finds nodes in a pxd file that should generate code, and
returns them in a StatListNode.
The result is a tuple (StatListNode, ModuleScope), i.e.
everything that is needed from the pxd after it is processed.
A purer approach would be to seperately compile the pxd code,
but the result would have to be slightly more sophisticated
than pure strings (functions + wanted interned strings +
wanted utility code + wanted cached objects) so for now this
approach is taken.
"""
def __call__(self, root):
self.funcs = []
self.visitchildren(root)
return (StatListNode(root.pos, stats=self.funcs), root.scope)
def visit_FuncDefNode(self, node):
self.funcs.append(node)
# Do not visit children, nested funcdefnodes will
# also be moved by this action...
return node
from Symtab import ModuleScope
from PyrexTypes import *
shape_func_type = CFuncType(
c_ptr_type(c_py_ssize_t_type),
[CFuncTypeArg("buffer", py_object_type, None)])
class CythonScope(ModuleScope):
def __init__(self, context):
ModuleScope.__init__(self, u'cython', None, context)
self.pxd_file_loaded = True
self.shape_entry = self.declare_cfunction('shape',
shape_func_type,
pos=None,
visibility='public',
cname='<error>')
def create_cython_scope(context):
return CythonScope(context)
...@@ -104,7 +104,7 @@ def report_error(err): ...@@ -104,7 +104,7 @@ def report_error(err):
def error(position, message): def error(position, message):
#print "Errors.error:", repr(position), repr(message) ### #print "Errors.error:", repr(position), repr(message) ###
err = CompileError(position, message) err = CompileError(position, message)
# if position is not None: raise Exception(err) # debug #if position is not None: raise Exception(err) # debug
report_error(err) report_error(err)
return err return err
......
...@@ -10,7 +10,7 @@ import Naming ...@@ -10,7 +10,7 @@ import Naming
from Nodes import Node from Nodes import Node
import PyrexTypes import PyrexTypes
from PyrexTypes import py_object_type, c_long_type, typecast, error_type from PyrexTypes import py_object_type, c_long_type, typecast, error_type
from Builtin import list_type, tuple_type, dict_type from Builtin import list_type, tuple_type, dict_type, unicode_type
import Symtab import Symtab
import Options import Options
from Annotate import AnnotationItem from Annotate import AnnotationItem
...@@ -708,8 +708,7 @@ class StringNode(ConstNode): ...@@ -708,8 +708,7 @@ class StringNode(ConstNode):
def coerce_to(self, dst_type, env): def coerce_to(self, dst_type, env):
if dst_type.is_int: if dst_type.is_int:
if not self.type.is_pyobject and len(self.entry.init) == 1: if not self.type.is_pyobject and len(self.entry.init) == 1:
# we use the *encoded* value here return CharNode(self.pos, value=self.value)
return CharNode(self.pos, value=self.entry.init)
else: else:
error(self.pos, "Only coerce single-character ascii strings can be used as ints.") error(self.pos, "Only coerce single-character ascii strings can be used as ints.")
return self return self
...@@ -741,7 +740,7 @@ class StringNode(ConstNode): ...@@ -741,7 +740,7 @@ class StringNode(ConstNode):
class UnicodeNode(PyConstNode): class UnicodeNode(PyConstNode):
# entry Symtab.Entry # entry Symtab.Entry
type = PyrexTypes.c_unicode_type type = unicode_type
def analyse_types(self, env): def analyse_types(self, env):
self.entry = env.add_string_const(self.value) self.entry = env.add_string_const(self.value)
...@@ -759,6 +758,9 @@ class UnicodeNode(PyConstNode): ...@@ -759,6 +758,9 @@ class UnicodeNode(PyConstNode):
# We still need to perform normal coerce_to processing on the # We still need to perform normal coerce_to processing on the
# result, because we might be coercing to an extension type, # result, because we might be coercing to an extension type,
# in which case a type test node will be needed. # in which case a type test node will be needed.
def compile_time_value(self, env):
return self.value
class IdentifierStringNode(ConstNode): class IdentifierStringNode(ConstNode):
...@@ -913,9 +915,6 @@ class NameNode(AtomicExprNode): ...@@ -913,9 +915,6 @@ class NameNode(AtomicExprNode):
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
self.entry.used = 1 self.entry.used = 1
if self.entry.type.is_buffer: if self.entry.type.is_buffer:
# Have an rhs temp just in case. All rhs I could
# think of had a single symbol result_code but better
# safe than sorry. Feel free to change this.
import Buffer import Buffer
Buffer.used_buffer_aux_vars(self.entry) Buffer.used_buffer_aux_vars(self.entry)
...@@ -992,6 +991,9 @@ class NameNode(AtomicExprNode): ...@@ -992,6 +991,9 @@ class NameNode(AtomicExprNode):
entry = self.entry entry = self.entry
if entry: if entry:
entry.used = 1 entry.used = 1
if entry.type.is_buffer:
import Buffer
Buffer.used_buffer_aux_vars(entry)
if entry.utility_code: if entry.utility_code:
env.use_utility_code(entry.utility_code) env.use_utility_code(entry.utility_code)
...@@ -1093,7 +1095,7 @@ class NameNode(AtomicExprNode): ...@@ -1093,7 +1095,7 @@ class NameNode(AtomicExprNode):
rhs.generate_post_assignment_code(code) rhs.generate_post_assignment_code(code)
def generate_acquire_buffer(self, rhs, code): def generate_acquire_buffer(self, rhs, code):
rhstmp = code.func.allocate_temp(self.entry.type) rhstmp = code.funcstate.allocate_temp(self.entry.type)
buffer_aux = self.entry.buffer_aux buffer_aux = self.entry.buffer_aux
bufstruct = buffer_aux.buffer_info_var.cname bufstruct = buffer_aux.buffer_info_var.cname
code.putln('%s = %s;' % (rhstmp, rhs.result_as(self.ctype()))) code.putln('%s = %s;' % (rhstmp, rhs.result_as(self.ctype())))
...@@ -1103,7 +1105,7 @@ class NameNode(AtomicExprNode): ...@@ -1103,7 +1105,7 @@ class NameNode(AtomicExprNode):
is_initialized=not self.skip_assignment_decref, is_initialized=not self.skip_assignment_decref,
pos=self.pos, code=code) pos=self.pos, code=code)
code.putln("%s = 0;" % rhstmp) code.putln("%s = 0;" % rhstmp)
code.func.release_temp(rhstmp) code.funcstate.release_temp(rhstmp)
def generate_deletion_code(self, code): def generate_deletion_code(self, code):
if self.entry is None: if self.entry is None:
...@@ -1370,6 +1372,9 @@ class IndexNode(ExprNode): ...@@ -1370,6 +1372,9 @@ class IndexNode(ExprNode):
self.is_buffer_access = False self.is_buffer_access = False
self.base.analyse_types(env) self.base.analyse_types(env)
# Handle the case where base is a literal char* (and we expect a string, not an int)
if isinstance(self.base, StringNode):
self.base = self.base.coerce_to_pyobject(env)
skip_child_analysis = False skip_child_analysis = False
buffer_access = False buffer_access = False
...@@ -1392,6 +1397,7 @@ class IndexNode(ExprNode): ...@@ -1392,6 +1397,7 @@ class IndexNode(ExprNode):
self.index = None self.index = None
self.type = self.base.type.dtype self.type = self.base.type.dtype
self.is_buffer_access = True self.is_buffer_access = True
self.buffer_type = self.base.entry.type
if getting: if getting:
# we only need a temp because result_code isn't refactored to # we only need a temp because result_code isn't refactored to
...@@ -1479,8 +1485,13 @@ class IndexNode(ExprNode): ...@@ -1479,8 +1485,13 @@ class IndexNode(ExprNode):
def generate_result_code(self, code): def generate_result_code(self, code):
if self.is_buffer_access: if self.is_buffer_access:
valuecode = self.buffer_access_code(code) ptrcode = self.buffer_lookup_code(code)
code.putln("%s = %s;" % (self.result_code, valuecode)) code.putln("%s = *%s;" % (
self.result_code,
self.buffer_type.buffer_ptr_type.cast_code(ptrcode)))
# Must incref the value we pulled out.
if self.buffer_type.dtype.is_pyobject:
code.putln("Py_INCREF((PyObject*)%s);" % self.result_code)
elif self.type.is_pyobject: elif self.type.is_pyobject:
if self.index.type.is_int: if self.index.type.is_int:
function = "__Pyx_GetItemInt" function = "__Pyx_GetItemInt"
...@@ -1518,8 +1529,26 @@ class IndexNode(ExprNode): ...@@ -1518,8 +1529,26 @@ class IndexNode(ExprNode):
def generate_assignment_code(self, rhs, code): def generate_assignment_code(self, rhs, code):
self.generate_subexpr_evaluation_code(code) self.generate_subexpr_evaluation_code(code)
if self.is_buffer_access: if self.is_buffer_access:
valuecode = self.buffer_access_code(code) ptrexpr = self.buffer_lookup_code(code)
code.putln("%s = %s;" % (valuecode, rhs.result_code)) if self.buffer_type.dtype.is_pyobject:
# Must manage refcounts. Decref what is already there
# and incref what we put in.
ptr = code.funcstate.allocate_temp(self.buffer_type.buffer_ptr_type)
if rhs.is_temp:
rhs_code = code.funcstate.allocate_temp(rhs.type)
else:
rhs_code = rhs.result_code
code.putln("%s = %s;" % (ptr, ptrexpr))
code.putln("Py_DECREF(*%s); Py_INCREF(%s);" % (
ptr, rhs_code
))
code.putln("*%s = %s;" % (ptr, rhs_code))
if rhs.is_temp:
code.funcstate.release_temp(rhs_code)
code.funcstate.release_temp(ptr)
else:
# Simple case
code.putln("*%s = %s;" % (ptrexpr, rhs.result_code))
elif self.type.is_pyobject: elif self.type.is_pyobject:
self.generate_setitem_code(rhs.py_result(), code) self.generate_setitem_code(rhs.py_result(), code)
else: else:
...@@ -1546,19 +1575,18 @@ class IndexNode(ExprNode): ...@@ -1546,19 +1575,18 @@ class IndexNode(ExprNode):
code.error_goto(self.pos))) code.error_goto(self.pos)))
self.generate_subexpr_disposal_code(code) self.generate_subexpr_disposal_code(code)
def buffer_access_code(self, code): def buffer_lookup_code(self, code):
# Assign indices to temps # Assign indices to temps
index_temps = [code.func.allocate_temp(i.type) for i in self.indices] index_temps = [code.funcstate.allocate_temp(i.type) for i in self.indices]
for temp, index in zip(index_temps, self.indices): for temp, index in zip(index_temps, self.indices):
code.putln("%s = %s;" % (temp, index.result_code)) code.putln("%s = %s;" % (temp, index.result_code))
# Generate buffer access code using these temps # Generate buffer access code using these temps
import Buffer import Buffer
valuecode = Buffer.put_access(entry=self.base.entry, return Buffer.put_buffer_lookup_code(entry=self.base.entry,
index_signeds=[i.type.signed for i in self.indices], index_signeds=[i.type.signed for i in self.indices],
index_cnames=index_temps, index_cnames=index_temps,
pos=self.pos, code=code) options=self.options,
return valuecode pos=self.pos, code=code)
class SliceIndexNode(ExprNode): class SliceIndexNode(ExprNode):
# 2-element slice indexing # 2-element slice indexing
......
"""
This module deals with interpreting the parse tree as Python
would have done, in the compiler.
For now this only covers parse tree to value conversion of
compile-time values.
"""
from Nodes import *
from ExprNodes import *
from Visitor import BasicVisitor
from Errors import CompileError
class EmptyScope:
def lookup(self, name):
return None
empty_scope = EmptyScope()
def interpret_compiletime_options(optlist, optdict, type_env=None):
"""
Tries to interpret a list of compile time option nodes.
The result will be a tuple (optlist, optdict) but where
all expression nodes have been interpreted. The result is
in the form of tuples (value, pos).
optlist is a list of nodes, while optdict is a DictNode (the
result optdict is a dict)
If type_env is set, all type nodes will be analysed and the resulting
type set. Otherwise only interpretateable ExprNodes
are allowed, other nodes raises errors.
A CompileError will be raised if there are problems.
"""
def interpret(node):
if isinstance(node, CBaseTypeNode):
if type_env:
return (node.analyse(type_env), node.pos)
else:
raise CompileError(node.pos, "Type not allowed here.")
else:
return (node.compile_time_value(empty_scope), node.pos)
if optlist:
optlist = [interpret(x) for x in optlist]
if optdict:
assert isinstance(optdict, DictNode)
new_optdict = {}
for item in optdict.key_value_pairs:
new_optdict[item.key.value] = interpret(item.value)
optdict = new_optdict
return (optlist, new_optdict)
...@@ -63,10 +63,11 @@ def make_lexicon(): ...@@ -63,10 +63,11 @@ def make_lexicon():
three_oct = octdigit + octdigit + octdigit three_oct = octdigit + octdigit + octdigit
two_hex = hexdigit + hexdigit two_hex = hexdigit + hexdigit
four_hex = two_hex + two_hex four_hex = two_hex + two_hex
escapeseq = Str("\\") + (two_oct | three_oct | two_hex | escapeseq = Str("\\") + (two_oct | three_oct |
Str('u') + four_hex | Str('x') + two_hex | Str('u') + four_hex | Str('x') + two_hex |
Str('U') + four_hex + four_hex | AnyChar) Str('U') + four_hex + four_hex | AnyChar)
deco = Str("@") deco = Str("@")
bra = Any("([{") bra = Any("([{")
ket = Any(")]}") ket = Any(")]}")
...@@ -75,9 +76,12 @@ def make_lexicon(): ...@@ -75,9 +76,12 @@ def make_lexicon():
"+=", "-=", "*=", "/=", "%=", "|=", "^=", "&=", "+=", "-=", "*=", "/=", "%=", "|=", "^=", "&=",
"<<=", ">>=", "**=", "//=") "<<=", ">>=", "**=", "//=")
spaces = Rep1(Any(" \t\f")) spaces = Rep1(Any(" \t\f"))
comment = Str("#") + Rep(AnyBut("\n"))
escaped_newline = Str("\\\n") escaped_newline = Str("\\\n")
lineterm = Eol + Opt(Str("\n")) lineterm = Eol + Opt(Str("\n"))
comment_start = Str("#")
comment = comment_start + Rep(AnyBut("\n"))
option_comment = comment_start + Str("cython:") + Rep(AnyBut("\n"))
return Lexicon([ return Lexicon([
(name, 'IDENT'), (name, 'IDENT'),
...@@ -94,11 +98,13 @@ def make_lexicon(): ...@@ -94,11 +98,13 @@ def make_lexicon():
#(stringlit, 'STRING'), #(stringlit, 'STRING'),
(beginstring, Method('begin_string_action')), (beginstring, Method('begin_string_action')),
(option_comment, Method('option_comment')),
(comment, IGNORE), (comment, IGNORE),
(spaces, IGNORE), (spaces, IGNORE),
(escaped_newline, IGNORE), (escaped_newline, IGNORE),
State('INDENT', [ State('INDENT', [
(option_comment + lineterm, Method('option_comment')),
(Opt(spaces) + Opt(comment) + lineterm, IGNORE), (Opt(spaces) + Opt(comment) + lineterm, IGNORE),
(indentation, Method('indentation_action')), (indentation, Method('indentation_action')),
(Eof, Method('eof_action')) (Eof, Method('eof_action'))
......
...@@ -22,11 +22,32 @@ from Scanning import PyrexScanner, FileSourceDescriptor ...@@ -22,11 +22,32 @@ from Scanning import PyrexScanner, FileSourceDescriptor
from Errors import PyrexError, CompileError, error from Errors import PyrexError, CompileError, error
from Symtab import BuiltinScope, ModuleScope from Symtab import BuiltinScope, ModuleScope
from Cython import Utils from Cython import Utils
from Cython.Utils import open_new_file, replace_suffix
import CythonScope
module_name_pattern = re.compile(r"[A-Za-z_][A-Za-z0-9_]*(\.[A-Za-z_][A-Za-z0-9_]*)*$") module_name_pattern = re.compile(r"[A-Za-z_][A-Za-z0-9_]*(\.[A-Za-z_][A-Za-z0-9_]*)*$")
verbose = 0 verbose = 0
def dumptree(t):
# For quick debugging in pipelines
print t.dump()
return t
class CompilationData:
# Bundles the information that is passed from transform to transform.
# (For now, this is only)
# While Context contains every pxd ever loaded, path information etc.,
# this only contains the data related to a single compilation pass
#
# pyx ModuleNode Main code tree of this compilation.
# pxds {string : ModuleNode} Trees for the pxds used in the pyx.
# codewriter CCodeWriter Where to output final code.
# options CompilationOptions
# result CompilationResult
pass
class Context: class Context:
# This class encapsulates the context needed for compiling # This class encapsulates the context needed for compiling
# one or more Cython implementation files along with their # one or more Cython implementation files along with their
...@@ -38,22 +59,119 @@ class Context: ...@@ -38,22 +59,119 @@ class Context:
# include_directories [string] # include_directories [string]
# future_directives [object] # future_directives [object]
def __init__(self, include_directories): def __init__(self, include_directories, pragma_overrides):
#self.modules = {"__builtin__" : BuiltinScope()} #self.modules = {"__builtin__" : BuiltinScope()}
import Builtin import Builtin, CythonScope
self.modules = {"__builtin__" : Builtin.builtin_scope} self.modules = {"__builtin__" : Builtin.builtin_scope}
self.pxds = {} self.modules["cython"] = CythonScope.create_cython_scope(self)
self.pyxs = {}
self.include_directories = include_directories self.include_directories = include_directories
self.future_directives = set() self.future_directives = set()
self.pragma_overrides = pragma_overrides
import os.path self.pxds = {} # full name -> node tree
standard_include_path = os.path.abspath( standard_include_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), '..', 'Includes')) os.path.join(os.path.dirname(__file__), '..', 'Includes'))
self.include_directories = include_directories + [standard_include_path] self.include_directories = include_directories + [standard_include_path]
def create_pipeline(self, pxd):
from Visitor import PrintTree
from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, PxdPostParse
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from ParseTreeTransforms import ResolveOptions
from Optimize import FlattenInListTransform, SwitchTransform, OptimizeRefcounting
from Buffer import IntroduceBufferAuxiliaryVars
from ModuleNode import check_c_classes
if pxd:
_check_c_classes = None
_specific_post_parse = PxdPostParse(self)
else:
_check_c_classes = check_c_classes
_specific_post_parse = None
return [
NormalizeTree(self),
PostParse(self),
_specific_post_parse,
ResolveOptions(self, self.pragma_overrides),
FlattenInListTransform(),
WithTransform(self),
DecoratorTransform(self),
AnalyseDeclarationsTransform(self),
IntroduceBufferAuxiliaryVars(self),
_check_c_classes,
AnalyseExpressionsTransform(self),
SwitchTransform(),
OptimizeRefcounting(self),
# SpecialFunctions(self),
# CreateClosureClasses(context),
]
def create_pyx_pipeline(self, options, result):
def generate_pyx_code(module_node):
module_node.process_implementation(options, result)
result.compilation_source = module_node.compilation_source
return result
def inject_pxd_code(module_node):
from textwrap import dedent
stats = module_node.body.stats
for name, (statlistnode, scope) in self.pxds.iteritems():
# Copy over function nodes to the module
# (this seems strange -- I believe the right concept is to split
# ModuleNode into a ModuleNode and a CodeGenerator, and tell that
# CodeGenerator to generate code both from the pyx and pxd ModuleNodes.
stats.append(statlistnode)
# Until utility code is moved to code generation phase everywhere,
# we need to copy it over to the main scope
module_node.scope.utility_code_list.extend(scope.utility_code_list)
return module_node
return ([
create_parse(self),
] + self.create_pipeline(pxd=False) + [
inject_pxd_code,
generate_pyx_code,
])
def create_pxd_pipeline(self, scope, module_name):
def parse_pxd(source_desc):
tree = self.parse(source_desc, scope, pxd=True,
full_module_name=module_name)
tree.scope = scope
tree.is_pxd = True
return tree
from CodeGeneration import ExtractPxdCode
# The pxd pipeline ends up with a CCodeWriter containing the
# code of the pxd, as well as a pxd scope.
return [parse_pxd] + self.create_pipeline(pxd=True) + [
ExtractPxdCode(self),
]
def process_pxd(self, source_desc, scope, module_name):
pipeline = self.create_pxd_pipeline(scope, module_name)
result = self.run_pipeline(pipeline, source_desc)
return result
def nonfatal_error(self, exc):
return Errors.report_error(exc)
def run_pipeline(self, pipeline, source):
err = None
data = source
try:
for phase in pipeline:
if phase is not None:
data = phase(data)
except CompileError, err:
# err is set
Errors.report_error(err)
return (err, data)
def find_module(self, module_name, def find_module(self, module_name,
relative_to = None, pos = None, need_pxd = 1): relative_to = None, pos = None, need_pxd = 1):
# Finds and returns the module scope corresponding to # Finds and returns the module scope corresponding to
...@@ -67,6 +185,7 @@ class Context: ...@@ -67,6 +185,7 @@ class Context:
if debug_find_module: if debug_find_module:
print("Context.find_module: module_name = %s, relative_to = %s, pos = %s, need_pxd = %s" % ( print("Context.find_module: module_name = %s, relative_to = %s, pos = %s, need_pxd = %s" % (
module_name, relative_to, pos, need_pxd)) module_name, relative_to, pos, need_pxd))
scope = None scope = None
pxd_pathname = None pxd_pathname = None
if not module_name_pattern.match(module_name): if not module_name_pattern.match(module_name):
...@@ -108,9 +227,11 @@ class Context: ...@@ -108,9 +227,11 @@ class Context:
if debug_find_module: if debug_find_module:
print("Context.find_module: Parsing %s" % pxd_pathname) print("Context.find_module: Parsing %s" % pxd_pathname)
source_desc = FileSourceDescriptor(pxd_pathname) source_desc = FileSourceDescriptor(pxd_pathname)
pxd_tree = self.parse(source_desc, scope, pxd = 1, err, result = self.process_pxd(source_desc, scope, module_name)
full_module_name = module_name) if err:
pxd_tree.analyse_declarations(scope) raise err
(pxd_codenodes, pxd_scope) = result
self.pxds[module_name] = (pxd_codenodes, pxd_scope)
except CompileError: except CompileError:
pass pass
return scope return scope
...@@ -308,15 +429,15 @@ class Context: ...@@ -308,15 +429,15 @@ class Context:
else: else:
Errors.open_listing_file(None) Errors.open_listing_file(None)
def teardown_errors(self, errors_occurred, options, result): def teardown_errors(self, err, options, result):
source_desc = result.compilation_source.source_desc source_desc = result.compilation_source.source_desc
if not isinstance(source_desc, FileSourceDescriptor): if not isinstance(source_desc, FileSourceDescriptor):
raise RuntimeError("Only file sources for code supported") raise RuntimeError("Only file sources for code supported")
Errors.close_listing_file() Errors.close_listing_file()
result.num_errors = Errors.num_errors result.num_errors = Errors.num_errors
if result.num_errors > 0: if result.num_errors > 0:
errors_occurred = True err = True
if errors_occurred and result.c_file: if err and result.c_file:
try: try:
Utils.castrate_file(result.c_file, os.stat(source_desc.filename)) Utils.castrate_file(result.c_file, os.stat(source_desc.filename))
except EnvironmentError: except EnvironmentError:
...@@ -332,20 +453,6 @@ class Context: ...@@ -332,20 +453,6 @@ class Context:
verbose_flag = options.show_version, verbose_flag = options.show_version,
cplus = options.cplus) cplus = options.cplus)
def nonfatal_error(self, exc):
return Errors.report_error(exc)
def run_pipeline(self, pipeline, source):
errors_occurred = False
data = source
try:
for phase in pipeline:
data = phase(data)
except CompileError, err:
errors_occurred = True
Errors.report_error(err)
return (errors_occurred, data)
def create_parse(context): def create_parse(context):
def parse(compsrc): def parse(compsrc):
source_desc = compsrc.source_desc source_desc = compsrc.source_desc
...@@ -355,45 +462,10 @@ def create_parse(context): ...@@ -355,45 +462,10 @@ def create_parse(context):
tree = context.parse(source_desc, scope, pxd = 0, full_module_name = full_module_name) tree = context.parse(source_desc, scope, pxd = 0, full_module_name = full_module_name)
tree.compilation_source = compsrc tree.compilation_source = compsrc
tree.scope = scope tree.scope = scope
tree.is_pxd = False
return tree return tree
return parse return parse
def create_generate_code(context, options, result):
def generate_code(module_node):
scope = module_node.scope
module_node.process_implementation(options, result)
result.compilation_source = module_node.compilation_source
return result
return generate_code
def create_default_pipeline(context, options, result):
from Visitor import PrintTree
from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from Optimize import FlattenInListTransform, SwitchTransform, OptimizeRefcounting
from Buffer import IntroduceBufferAuxiliaryVars
from ModuleNode import check_c_classes
def printit(x): print x.dump()
return [
create_parse(context),
# printit,
NormalizeTree(context),
PostParse(context),
FlattenInListTransform(),
WithTransform(context),
DecoratorTransform(context),
AnalyseDeclarationsTransform(context),
IntroduceBufferAuxiliaryVars(context),
check_c_classes,
AnalyseExpressionsTransform(context),
# BufferTransform(context),
SwitchTransform(),
OptimizeRefcounting(context),
# CreateClosureClasses(context),
create_generate_code(context, options, result)
]
def create_default_resultobj(compilation_source, options): def create_default_resultobj(compilation_source, options):
result = CompilationResult() result = CompilationResult()
result.main_source_file = compilation_source.source_desc.filename result.main_source_file = compilation_source.source_desc.filename
...@@ -409,7 +481,10 @@ def create_default_resultobj(compilation_source, options): ...@@ -409,7 +481,10 @@ def create_default_resultobj(compilation_source, options):
result.c_file = Utils.replace_suffix(source_desc.filename, c_suffix) result.c_file = Utils.replace_suffix(source_desc.filename, c_suffix)
return result return result
def run_pipeline(source, context, options, full_module_name = None): def run_pipeline(source, options, full_module_name = None):
# Set up context
context = Context(options.include_path, options.pragma_overrides)
# Set up source object # Set up source object
cwd = os.getcwd() cwd = os.getcwd()
source_desc = FileSourceDescriptor(os.path.join(cwd, source)) source_desc = FileSourceDescriptor(os.path.join(cwd, source))
...@@ -420,11 +495,11 @@ def run_pipeline(source, context, options, full_module_name = None): ...@@ -420,11 +495,11 @@ def run_pipeline(source, context, options, full_module_name = None):
result = create_default_resultobj(source, options) result = create_default_resultobj(source, options)
# Get pipeline # Get pipeline
pipeline = create_default_pipeline(context, options, result) pipeline = context.create_pyx_pipeline(options, result)
context.setup_errors(options) context.setup_errors(options)
errors_occurred, enddata = context.run_pipeline(pipeline, source) err, enddata = context.run_pipeline(pipeline, source)
context.teardown_errors(errors_occurred, options, result) context.teardown_errors(err, options, result)
return result return result
#------------------------------------------------------------------------ #------------------------------------------------------------------------
...@@ -458,6 +533,7 @@ class CompilationOptions: ...@@ -458,6 +533,7 @@ class CompilationOptions:
defaults to true when recursive is true. defaults to true when recursive is true.
verbose boolean Always print source names being compiled verbose boolean Always print source names being compiled
quiet boolean Don't print source names in recursive mode quiet boolean Don't print source names in recursive mode
pragma_overrides dict Overrides for pragma options (see Options.py)
Following options are experimental and only used on MacOSX: Following options are experimental and only used on MacOSX:
...@@ -533,10 +609,7 @@ def compile_single(source, options, full_module_name = None): ...@@ -533,10 +609,7 @@ def compile_single(source, options, full_module_name = None):
Always compiles a single file; does not perform timestamp checking or Always compiles a single file; does not perform timestamp checking or
recursion. recursion.
""" """
context = Context(options.include_path) return run_pipeline(source, options, full_module_name)
return run_pipeline(source, context, options, full_module_name)
# context = Context(options.include_path)
# return context.compile(source, options, full_module_name)
def compile_multiple(sources, options): def compile_multiple(sources, options):
...@@ -559,12 +632,11 @@ def compile_multiple(sources, options): ...@@ -559,12 +632,11 @@ def compile_multiple(sources, options):
if source not in processed: if source not in processed:
# Compiling multiple sources in one context doesn't quite # Compiling multiple sources in one context doesn't quite
# work properly yet. # work properly yet.
context = Context(options.include_path) # to be removed later
if not timestamps or context.c_file_out_of_date(source): if not timestamps or context.c_file_out_of_date(source):
if verbose: if verbose:
sys.stderr.write("Compiling %s\n" % source) sys.stderr.write("Compiling %s\n" % source)
result = run_pipeline(source, context, options) result = run_pipeline(source, options)
results.add(source, result) results.add(source, result)
processed.add(source) processed.add(source)
if recursive: if recursive:
...@@ -646,10 +718,11 @@ default_options = dict( ...@@ -646,10 +718,11 @@ default_options = dict(
generate_pxi = 0, generate_pxi = 0,
working_path = "", working_path = "",
recursive = 0, recursive = 0,
transforms = None, # deprecated
timestamps = None, timestamps = None,
verbose = 0, verbose = 0,
quiet = 0) quiet = 0,
pragma_overrides = {}
)
if sys.platform == "mac": if sys.platform == "mac":
from Cython.Mac.MacSystem import c_compile, c_link, CCompilerError from Cython.Mac.MacSystem import c_compile, c_link, CCompilerError
default_options['use_listing_file'] = 1 default_options['use_listing_file'] = 1
......
...@@ -23,7 +23,7 @@ import Version ...@@ -23,7 +23,7 @@ import Version
from Errors import error, warning from Errors import error, warning
from PyrexTypes import py_object_type from PyrexTypes import py_object_type
from Cython.Utils import open_new_file, replace_suffix, escape_byte_string from Cython.Utils import open_new_file, replace_suffix, escape_byte_string, EncodedString
def check_c_classes(module_node): def check_c_classes(module_node):
...@@ -45,9 +45,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -45,9 +45,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
def analyse_declarations(self, env): def analyse_declarations(self, env):
if Options.embed_pos_in_docstring: if Options.embed_pos_in_docstring:
env.doc = 'File: %s (starting at line %s)'%Nodes.relative_position(self.pos) env.doc = EncodedString(u'File: %s (starting at line %s)' % Nodes.relative_position(self.pos))
if not self.doc is None: if not self.doc is None:
env.doc = env.doc + '\\n' + self.doc env.doc = EncodedString(env.doc + u'\n' + self.doc)
env.doc.encoding = self.doc.encoding
else: else:
env.doc = self.doc env.doc = self.doc
self.body.analyse_declarations(env) self.body.analyse_declarations(env)
...@@ -242,16 +243,20 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -242,16 +243,20 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
h_code = code.insertion_point() h_code = code.insertion_point()
self.generate_module_preamble(env, modules, h_code) self.generate_module_preamble(env, modules, h_code)
code.globalstate.module_pos = self.pos
code.putln("") code.putln("")
code.putln("/* Implementation of %s */" % env.qualified_name) code.putln("/* Implementation of %s */" % env.qualified_name)
self.generate_const_definitions(env, code) self.generate_const_definitions(env, code)
self.generate_interned_num_decls(env, code) self.generate_interned_num_decls(env, code)
self.generate_interned_string_decls(env, code) self.generate_interned_string_decls(env, code)
self.generate_py_string_decls(env, code) self.generate_py_string_decls(env, code)
code.globalstate.insert_global_var_declarations_into(code)
self.generate_cached_builtins_decls(env, code) self.generate_cached_builtins_decls(env, code)
self.body.generate_function_definitions(env, code, options.transforms) self.body.generate_function_definitions(env, code)
code.mark_pos(None) code.mark_pos(None)
self.generate_py_string_table(env, code)
self.generate_typeobj_definitions(env, code) self.generate_typeobj_definitions(env, code)
self.generate_method_table(env, code) self.generate_method_table(env, code)
self.generate_filename_init_prototype(code) self.generate_filename_init_prototype(code)
...@@ -267,6 +272,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -267,6 +272,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
self.generate_declarations_for_modules(env, modules, h_code) self.generate_declarations_for_modules(env, modules, h_code)
h_code.write('\n') h_code.write('\n')
code.globalstate.close_global_decls()
f = open_new_file(result.c_file) f = open_new_file(result.c_file)
code.copyto(f) code.copyto(f)
f.close() f.close()
...@@ -531,8 +538,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -531,8 +538,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
def generate_filename_table(self, code): def generate_filename_table(self, code):
code.putln("") code.putln("")
code.putln("static const char *%s[] = {" % Naming.filenames_cname) code.putln("static const char *%s[] = {" % Naming.filenames_cname)
if code.filename_list: if code.globalstate.filename_list:
for source_desc in code.filename_list: for source_desc in code.globalstate.filename_list:
filename = os.path.basename(source_desc.get_filenametable_entry()) filename = os.path.basename(source_desc.get_filenametable_entry())
escaped_filename = filename.replace("\\", "\\\\").replace('"', r'\"') escaped_filename = filename.replace("\\", "\\\\").replace('"', r'\"')
code.putln('"%s",' % code.putln('"%s",' %
...@@ -1451,28 +1458,6 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1451,28 +1458,6 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln( code.putln(
"};") "};")
def generate_py_string_table(self, env, code):
entries = env.all_pystring_entries
if entries:
code.putln("")
code.putln(
"static __Pyx_StringTabEntry %s[] = {" %
Naming.stringtab_cname)
for entry in entries:
code.putln(
"{&%s, %s, sizeof(%s), %d, %d, %d}," % (
entry.pystring_cname,
entry.cname,
entry.cname,
entry.type.is_unicode,
entry.is_interned,
entry.is_identifier
))
code.putln(
"{0, 0, 0, 0, 0, 0}")
code.putln(
"};")
def generate_filename_init_prototype(self, code): def generate_filename_init_prototype(self, code):
code.putln(""); code.putln("");
code.putln("static void %s(void); /*proto*/" % Naming.fileinit_cname) code.putln("static void %s(void); /*proto*/" % Naming.fileinit_cname)
...@@ -1540,6 +1525,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1540,6 +1525,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.exit_cfunc_scope() # done with labels code.exit_cfunc_scope() # done with labels
def generate_module_init_func(self, imported_modules, env, code): def generate_module_init_func(self, imported_modules, env, code):
# Insert code stream of __Pyx_InitGlobals
code.globalstate.insert_initcode_into(code)
code.enter_cfunc_scope() code.enter_cfunc_scope()
code.putln("") code.putln("")
header2 = "PyMODINIT_FUNC init%s(void)" % env.module_name header2 = "PyMODINIT_FUNC init%s(void)" % env.module_name
...@@ -1559,19 +1547,17 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1559,19 +1547,17 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
env.generate_library_function_declarations(code) env.generate_library_function_declarations(code)
self.generate_filename_init_call(code) self.generate_filename_init_call(code)
code.putln("/*--- Initialize various global constants etc. ---*/")
code.putln(code.error_goto_if_neg("__Pyx_InitGlobals()", self.pos))
code.putln("/*--- Module creation code ---*/") code.putln("/*--- Module creation code ---*/")
self.generate_module_creation_code(env, code) self.generate_module_creation_code(env, code)
code.putln("/*--- Intern code ---*/")
self.generate_intern_code(env, code)
code.putln("/*--- String init code ---*/")
self.generate_string_init_code(env, code)
if Options.cache_builtins: if Options.cache_builtins:
code.putln("/*--- Builtin init code ---*/") code.putln("/*--- Builtin init code ---*/")
self.generate_builtin_init_code(env, code) code.putln(code.error_goto_if_neg("__Pyx_InitCachedBuiltins()",
self.pos))
code.putln("%s = 0;" % Naming.skip_dispatch_cname); code.putln("%s = 0;" % Naming.skip_dispatch_cname);
code.putln("/*--- Global init code ---*/") code.putln("/*--- Global init code ---*/")
...@@ -1615,7 +1601,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1615,7 +1601,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln('}') code.putln('}')
tempdecl_code.put_var_declarations(env.temp_entries) tempdecl_code.put_var_declarations(env.temp_entries)
tempdecl_code.put_temp_declarations(code.func) tempdecl_code.put_temp_declarations(code.funcstate)
code.exit_cfunc_scope() code.exit_cfunc_scope()
...@@ -1727,41 +1713,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1727,41 +1713,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
"if (!%s) %s;" % ( "if (!%s) %s;" % (
Naming.preimport_cname, Naming.preimport_cname,
code.error_goto(self.pos))); code.error_goto(self.pos)));
def generate_intern_code(self, env, code):
for entry in env.pynum_entries:
if entry.init[-1] == "L":
code.putln('%s = PyLong_FromString("%s", 0, 0); %s;' % (
entry.cname,
entry.init,
code.error_goto_if_null(entry.cname, self.pos)))
else:
code.putln("%s = PyInt_FromLong(%s); %s;" % (
entry.cname,
entry.init,
code.error_goto_if_null(entry.cname, self.pos)))
def generate_string_init_code(self, env, code):
if env.all_pystring_entries:
env.use_utility_code(Nodes.init_string_tab_utility_code)
code.putln(
"if (__Pyx_InitStrings(%s) < 0) %s;" % (
Naming.stringtab_cname,
code.error_goto(self.pos)))
def generate_builtin_init_code(self, env, code):
# Lookup and cache builtin objects.
if Options.cache_builtins:
for entry in env.cached_builtins:
#assert entry.interned_cname is not None
code.putln(
'%s = __Pyx_GetName(%s, %s); if (!%s) %s' % (
entry.cname,
Naming.builtins_cname,
entry.interned_cname,
entry.cname,
code.error_goto(entry.pos)))
def generate_global_init_code(self, env, code): def generate_global_init_code(self, env, code):
# Generate code to initialise global PyObject * # Generate code to initialise global PyObject *
# variables to None. # variables to None.
...@@ -1961,6 +1913,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1961,6 +1913,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
type.typeptr_cname, type.typeobj_cname)) type.typeptr_cname, type.typeobj_cname))
def generate_utility_functions(self, env, code, h_code): def generate_utility_functions(self, env, code, h_code):
for codetup, name in env.utility_code_list:
code.globalstate.use_utility_code(codetup, name)
code.globalstate.put_utility_code_protos(h_code)
code.putln("") code.putln("")
code.putln("/* Runtime support code */") code.putln("/* Runtime support code */")
code.putln("") code.putln("")
...@@ -1968,9 +1924,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1968,9 +1924,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("%s = %s;" % code.putln("%s = %s;" %
(Naming.filetable_cname, Naming.filenames_cname)) (Naming.filetable_cname, Naming.filenames_cname))
code.putln("}") code.putln("}")
for utility_code in env.utility_code_used: code.globalstate.put_utility_code_defs(code)
h_code.put(utility_code[0])
code.put(utility_code[1])
code.put(PyrexTypes.type_conversion_functions) code.put(PyrexTypes.type_conversion_functions)
code.putln("") code.putln("")
......
...@@ -63,7 +63,7 @@ def embed_position(pos, docstring): ...@@ -63,7 +63,7 @@ def embed_position(pos, docstring):
# reuse the string encoding of the original docstring # reuse the string encoding of the original docstring
doc = EncodedString(pos_line) doc = EncodedString(pos_line)
else: else:
doc = EncodedString(pos_line + u'\\n' + docstring) doc = EncodedString(pos_line + u'\n' + docstring)
doc.encoding = encoding doc.encoding = encoding
return doc return doc
...@@ -200,21 +200,15 @@ class BlockNode: ...@@ -200,21 +200,15 @@ class BlockNode:
def generate_const_definitions(self, env, code): def generate_const_definitions(self, env, code):
if env.const_entries: if env.const_entries:
code.putln("")
for entry in env.const_entries: for entry in env.const_entries:
if not entry.is_interned: if not entry.is_interned:
code.put_var_declaration(entry, static = 1) code.globalstate.add_const_definition(entry)
def generate_interned_string_decls(self, env, code): def generate_interned_string_decls(self, env, code):
entries = env.global_scope().new_interned_string_entries entries = env.global_scope().new_interned_string_entries
if entries: if entries:
code.putln("")
for entry in entries: for entry in entries:
code.put_var_declaration(entry, static = 1) code.globalstate.add_interned_string_decl(entry)
code.putln("")
for entry in entries:
code.putln(
"static PyObject *%s;" % entry.pystring_cname)
del entries[:] del entries[:]
def generate_py_string_decls(self, env, code): def generate_py_string_decls(self, env, code):
...@@ -222,11 +216,9 @@ class BlockNode: ...@@ -222,11 +216,9 @@ class BlockNode:
return # earlier error return # earlier error
entries = env.pystring_entries entries = env.pystring_entries
if entries: if entries:
code.putln("")
for entry in entries: for entry in entries:
if not entry.is_interned: if not entry.is_interned:
code.putln( code.globalstate.add_py_string_decl(entry)
"static PyObject *%s;" % entry.pystring_cname)
def generate_interned_num_decls(self, env, code): def generate_interned_num_decls(self, env, code):
# Flush accumulated interned nums from the global scope # Flush accumulated interned nums from the global scope
...@@ -234,18 +226,14 @@ class BlockNode: ...@@ -234,18 +226,14 @@ class BlockNode:
genv = env.global_scope() genv = env.global_scope()
entries = genv.interned_nums entries = genv.interned_nums
if entries: if entries:
code.putln("")
for entry in entries: for entry in entries:
code.putln( code.globalstate.add_interned_num_decl(entry)
"static PyObject *%s;" % entry.cname)
del entries[:] del entries[:]
def generate_cached_builtins_decls(self, env, code): def generate_cached_builtins_decls(self, env, code):
entries = env.global_scope().undeclared_cached_builtins entries = env.global_scope().undeclared_cached_builtins
if len(entries) > 0:
code.putln("")
for entry in entries: for entry in entries:
code.putln("static PyObject *%s;" % entry.cname) code.globalstate.add_cached_builtin_decl(entry)
del entries[:] del entries[:]
...@@ -273,10 +261,10 @@ class StatListNode(Node): ...@@ -273,10 +261,10 @@ class StatListNode(Node):
for stat in self.stats: for stat in self.stats:
stat.analyse_expressions(env) stat.analyse_expressions(env)
def generate_function_definitions(self, env, code, transforms): def generate_function_definitions(self, env, code):
#print "StatListNode.generate_function_definitions" ### #print "StatListNode.generate_function_definitions" ###
for stat in self.stats: for stat in self.stats:
stat.generate_function_definitions(env, code, transforms) stat.generate_function_definitions(env, code)
def generate_execution_code(self, code): def generate_execution_code(self, code):
#print "StatListNode.generate_execution_code" ### #print "StatListNode.generate_execution_code" ###
...@@ -302,7 +290,7 @@ class StatNode(Node): ...@@ -302,7 +290,7 @@ class StatNode(Node):
# Emit C code for executable statements. # Emit C code for executable statements.
# #
def generate_function_definitions(self, env, code, transforms): def generate_function_definitions(self, env, code):
pass pass
def generate_execution_code(self, code): def generate_execution_code(self, code):
...@@ -586,29 +574,33 @@ class CSimpleBaseTypeNode(CBaseTypeNode): ...@@ -586,29 +574,33 @@ class CSimpleBaseTypeNode(CBaseTypeNode):
else: else:
return PyrexTypes.error_type return PyrexTypes.error_type
class CBufferAccessTypeNode(Node): class CBufferAccessTypeNode(CBaseTypeNode):
# After parsing: # After parsing:
# positional_args [ExprNode] List of positional arguments # positional_args [ExprNode] List of positional arguments
# keyword_args DictNode Keyword arguments # keyword_args DictNode Keyword arguments
# base_type_node CBaseTypeNode # base_type_node CBaseTypeNode
# After PostParse:
# dtype_node CBaseTypeNode
# ndim int
# After analysis: # After analysis:
# type PyrexType.PyrexType # type PyrexType.BufferType ...containing the right options
child_attrs = ["base_type_node", "positional_args", "keyword_args", child_attrs = ["base_type_node", "positional_args",
"dtype_node"] "keyword_args", "dtype_node"]
dtype_node = None dtype_node = None
def analyse(self, env): def analyse(self, env):
base_type = self.base_type_node.analyse(env) base_type = self.base_type_node.analyse(env)
dtype = self.dtype_node.analyse(env) import Buffer
self.type = PyrexTypes.BufferType(base_type, dtype=dtype, ndim=self.ndim,
mode=self.mode) options = Buffer.analyse_buffer_options(
self.pos,
env,
self.positional_args,
self.keyword_args,
base_type.buffer_defaults)
self.type = PyrexTypes.BufferType(base_type, **options)
return self.type return self.type
class CComplexBaseTypeNode(CBaseTypeNode): class CComplexBaseTypeNode(CBaseTypeNode):
...@@ -641,7 +633,6 @@ class CVarDefNode(StatNode): ...@@ -641,7 +633,6 @@ class CVarDefNode(StatNode):
dest_scope = env dest_scope = env
self.dest_scope = dest_scope self.dest_scope = dest_scope
base_type = self.base_type.analyse(env) base_type = self.base_type.analyse(env)
if (dest_scope.is_c_class_scope if (dest_scope.is_c_class_scope
and self.visibility == 'public' and self.visibility == 'public'
and base_type.is_pyobject and base_type.is_pyobject
...@@ -853,7 +844,7 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -853,7 +844,7 @@ class FuncDefNode(StatNode, BlockNode):
self.local_scope = lenv self.local_scope = lenv
return lenv return lenv
def generate_function_definitions(self, env, code, transforms): def generate_function_definitions(self, env, code):
import Buffer import Buffer
lenv = self.local_scope lenv = self.local_scope
...@@ -907,14 +898,18 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -907,14 +898,18 @@ class FuncDefNode(StatNode, BlockNode):
for entry in lenv.arg_entries: for entry in lenv.arg_entries:
if entry.type.is_pyobject and lenv.control_flow.get_state((entry.name, 'source')) != 'arg': if entry.type.is_pyobject and lenv.control_flow.get_state((entry.name, 'source')) != 'arg':
code.put_var_incref(entry) code.put_var_incref(entry)
if entry.type.is_buffer:
Buffer.put_acquire_arg_buffer(entry, code, self.pos)
# ----- Initialise local variables # ----- Initialise local variables
for entry in lenv.var_entries: for entry in lenv.var_entries:
if entry.type.is_pyobject and entry.init_to_none and entry.used: if entry.type.is_pyobject and entry.init_to_none and entry.used:
code.put_init_var_to_py_none(entry) code.put_init_var_to_py_none(entry)
if entry.type.is_buffer and entry.buffer_aux.buffer_info_var.used:
code.putln("%s.buf = NULL;" % entry.buffer_aux.buffer_info_var.cname)
# ----- Check and convert arguments # ----- Check and convert arguments
self.generate_argument_type_tests(code) self.generate_argument_type_tests(code)
# ----- Acquire buffer arguments
for entry in lenv.arg_entries:
if entry.type.is_buffer:
Buffer.put_acquire_arg_buffer(entry, code, self.pos)
# ----- Function body # ----- Function body
self.body.generate_execution_code(code) self.body.generate_execution_code(code)
# ----- Default return value # ----- Default return value
...@@ -974,7 +969,8 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -974,7 +969,8 @@ class FuncDefNode(StatNode, BlockNode):
# goto statement in error cleanup above # goto statement in error cleanup above
code.put_label(code.return_label) code.put_label(code.return_label)
for entry in lenv.buffer_entries: for entry in lenv.buffer_entries:
code.putln("%s;" % Buffer.get_release_buffer_code(entry)) if entry.used:
code.putln("%s;" % Buffer.get_release_buffer_code(entry))
# ----- Return cleanup for both error and no-error return # ----- Return cleanup for both error and no-error return
code.put_label(code.return_from_error_cleanup_label) code.put_label(code.return_from_error_cleanup_label)
if not Options.init_local_none: if not Options.init_local_none:
...@@ -996,11 +992,11 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -996,11 +992,11 @@ class FuncDefNode(StatNode, BlockNode):
code.putln("}") code.putln("}")
# ----- Go back and insert temp variable declarations # ----- Go back and insert temp variable declarations
tempvardecl_code.put_var_declarations(lenv.temp_entries) tempvardecl_code.put_var_declarations(lenv.temp_entries)
tempvardecl_code.put_temp_declarations(code.func) tempvardecl_code.put_temp_declarations(code.funcstate)
# ----- Python version # ----- Python version
code.exit_cfunc_scope() code.exit_cfunc_scope()
if self.py_func: if self.py_func:
self.py_func.generate_function_definitions(env, code, transforms) self.py_func.generate_function_definitions(env, code)
self.generate_optarg_wrapper_function(env, code) self.generate_optarg_wrapper_function(env, code)
def put_stararg_decrefs(self, code): def put_stararg_decrefs(self, code):
...@@ -2031,10 +2027,9 @@ class PyClassDefNode(ClassDefNode): ...@@ -2031,10 +2027,9 @@ class PyClassDefNode(ClassDefNode):
#self.classobj.release_temp(env) #self.classobj.release_temp(env)
#self.target.release_target_temp(env) #self.target.release_target_temp(env)
def generate_function_definitions(self, env, code, transforms): def generate_function_definitions(self, env, code):
self.generate_py_string_decls(self.scope, code) self.generate_py_string_decls(self.scope, code)
self.body.generate_function_definitions( self.body.generate_function_definitions(self.scope, code)
self.scope, code, transforms)
def generate_execution_code(self, code): def generate_execution_code(self, code):
self.dict.generate_evaluation_code(code) self.dict.generate_evaluation_code(code)
...@@ -2062,13 +2057,26 @@ class CClassDefNode(ClassDefNode): ...@@ -2062,13 +2057,26 @@ class CClassDefNode(ClassDefNode):
# body StatNode or None # body StatNode or None
# entry Symtab.Entry # entry Symtab.Entry
# base_type PyExtensionType or None # base_type PyExtensionType or None
# buffer_defaults_node DictNode or None Declares defaults for a buffer
# buffer_defaults_pos
child_attrs = ["body"] child_attrs = ["body"]
buffer_defaults_node = None
buffer_defaults_pos = None
def analyse_declarations(self, env): def analyse_declarations(self, env):
#print "CClassDefNode.analyse_declarations:", self.class_name #print "CClassDefNode.analyse_declarations:", self.class_name
#print "...visibility =", self.visibility #print "...visibility =", self.visibility
#print "...module_name =", self.module_name #print "...module_name =", self.module_name
import Buffer
if self.buffer_defaults_node:
buffer_defaults = Buffer.analyse_buffer_options(self.buffer_defaults_pos,
env, [], self.buffer_defaults_node,
need_complete=False)
else:
buffer_defaults = None
if env.in_cinclude and not self.objstruct_name: if env.in_cinclude and not self.objstruct_name:
error(self.pos, "Object struct name specification required for " error(self.pos, "Object struct name specification required for "
"C class defined in 'extern from' block") "C class defined in 'extern from' block")
...@@ -2120,7 +2128,8 @@ class CClassDefNode(ClassDefNode): ...@@ -2120,7 +2128,8 @@ class CClassDefNode(ClassDefNode):
typeobj_cname = self.typeobj_name, typeobj_cname = self.typeobj_name,
visibility = self.visibility, visibility = self.visibility,
typedef_flag = self.typedef_flag, typedef_flag = self.typedef_flag,
api = self.api) api = self.api,
buffer_defaults = buffer_defaults)
if home_scope is not env and self.visibility == 'extern': if home_scope is not env and self.visibility == 'extern':
env.add_imported_entry(self.class_name, self.entry, pos) env.add_imported_entry(self.class_name, self.entry, pos)
scope = self.entry.type.scope scope = self.entry.type.scope
...@@ -2150,11 +2159,11 @@ class CClassDefNode(ClassDefNode): ...@@ -2150,11 +2159,11 @@ class CClassDefNode(ClassDefNode):
scope = self.entry.type.scope scope = self.entry.type.scope
self.body.analyse_expressions(scope) self.body.analyse_expressions(scope)
def generate_function_definitions(self, env, code, transforms): def generate_function_definitions(self, env, code):
self.generate_py_string_decls(self.entry.type.scope, code) self.generate_py_string_decls(self.entry.type.scope, code)
if self.body: if self.body:
self.body.generate_function_definitions( self.body.generate_function_definitions(
self.entry.type.scope, code, transforms) self.entry.type.scope, code)
def generate_execution_code(self, code): def generate_execution_code(self, code):
# This is needed to generate evaluation code for # This is needed to generate evaluation code for
...@@ -2188,8 +2197,8 @@ class PropertyNode(StatNode): ...@@ -2188,8 +2197,8 @@ class PropertyNode(StatNode):
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.body.analyse_expressions(env) self.body.analyse_expressions(env)
def generate_function_definitions(self, env, code, transforms): def generate_function_definitions(self, env, code):
self.body.generate_function_definitions(env, code, transforms) self.body.generate_function_definitions(env, code)
def generate_execution_code(self, code): def generate_execution_code(self, code):
pass pass
...@@ -2675,7 +2684,7 @@ class ContinueStatNode(StatNode): ...@@ -2675,7 +2684,7 @@ class ContinueStatNode(StatNode):
pass pass
def generate_execution_code(self, code): def generate_execution_code(self, code):
if code.in_try_finally: if code.funcstate.in_try_finally:
error(self.pos, "continue statement inside try of try...finally") error(self.pos, "continue statement inside try of try...finally")
elif not code.continue_label: elif not code.continue_label:
error(self.pos, "continue statement not inside loop") error(self.pos, "continue statement not inside loop")
...@@ -2838,7 +2847,7 @@ class ReraiseStatNode(StatNode): ...@@ -2838,7 +2847,7 @@ class ReraiseStatNode(StatNode):
gil_message = "Raising exception" gil_message = "Raising exception"
def generate_execution_code(self, code): def generate_execution_code(self, code):
vars = code.exc_vars vars = code.funcstate.exc_vars
if vars: if vars:
code.putln("__Pyx_Raise(%s, %s, %s);" % tuple(vars)) code.putln("__Pyx_Raise(%s, %s, %s);" % tuple(vars))
code.putln(code.error_goto(self.pos)) code.putln(code.error_goto(self.pos))
...@@ -3519,10 +3528,10 @@ class ExceptClauseNode(Node): ...@@ -3519,10 +3528,10 @@ class ExceptClauseNode(Node):
self.excinfo_tuple.generate_evaluation_code(code) self.excinfo_tuple.generate_evaluation_code(code)
self.excinfo_target.generate_assignment_code(self.excinfo_tuple, code) self.excinfo_target.generate_assignment_code(self.excinfo_tuple, code)
old_exc_vars = code.exc_vars old_exc_vars = code.funcstate.exc_vars
code.exc_vars = self.exc_vars code.funcstate.exc_vars = self.exc_vars
self.body.generate_execution_code(code) self.body.generate_execution_code(code)
code.exc_vars = old_exc_vars code.funcstate.exc_vars = old_exc_vars
for var in self.exc_vars: for var in self.exc_vars:
code.putln("Py_DECREF(%s); %s = 0;" % (var, var)) code.putln("Py_DECREF(%s); %s = 0;" % (var, var))
code.put_goto(end_label) code.put_goto(end_label)
...@@ -3597,11 +3606,11 @@ class TryFinallyStatNode(StatNode): ...@@ -3597,11 +3606,11 @@ class TryFinallyStatNode(StatNode):
code.putln( code.putln(
"/*try:*/ {") "/*try:*/ {")
if self.disallow_continue_in_try_finally: if self.disallow_continue_in_try_finally:
was_in_try_finally = code.in_try_finally was_in_try_finally = code.funcstate.in_try_finally
code.in_try_finally = 1 code.funcstate.in_try_finally = 1
self.body.generate_execution_code(code) self.body.generate_execution_code(code)
if self.disallow_continue_in_try_finally: if self.disallow_continue_in_try_finally:
code.in_try_finally = was_in_try_finally code.funcstate.in_try_finally = was_in_try_finally
code.putln( code.putln(
"}") "}")
code.putln( code.putln(
...@@ -3926,6 +3935,7 @@ class FromImportStatNode(StatNode): ...@@ -3926,6 +3935,7 @@ class FromImportStatNode(StatNode):
target.generate_assignment_code(self.item, code) target.generate_assignment_code(self.item, code)
self.module.generate_disposal_code(code) self.module.generate_disposal_code(code)
#------------------------------------------------------------------------------------ #------------------------------------------------------------------------------------
# #
# Runtime support code # Runtime support code
......
# #
# Pyrex - Compilation-wide options # Cython - Compilation-wide options and pragma declarations
# #
cache_builtins = 1 # Perform lookups on builtin names only once cache_builtins = 1 # Perform lookups on builtin names only once
...@@ -52,3 +52,58 @@ optimize_simple_methods = 1 ...@@ -52,3 +52,58 @@ optimize_simple_methods = 1
# Append the c file and line number to the traceback for exceptions. # Append the c file and line number to the traceback for exceptions.
c_line_in_traceback = 1 c_line_in_traceback = 1
# Declare pragmas
option_types = {
'boundscheck' : bool
}
option_defaults = {
'boundscheck' : True
}
def parse_option_list(s):
"""
Parses a comma-seperated list of pragma options. Whitespace
is not considered.
>>> parse_option_list(' ')
{}
>>> (parse_option_list('boundscheck=True') ==
... {'boundscheck': True})
True
>>> parse_option_list(' asdf')
Traceback (most recent call last):
...
ValueError: Expected "=" in option "asdf"
>>> parse_option_list('boundscheck=hey')
Traceback (most recent call last):
...
ValueError: Must pass a boolean value for option "boundscheck"
>>> parse_option_list('unknown=True')
Traceback (most recent call last):
...
ValueError: Unknown option: "unknown"
"""
result = {}
for item in s.split(','):
item = item.strip()
if not item: continue
if not '=' in item: raise ValueError('Expected "=" in option "%s"' % item)
name, value = item.strip().split('=')
try:
type = option_types[name]
except KeyError:
raise ValueError('Unknown option: "%s"' % name)
if type is bool:
value = value.lower()
if value in ('true', 'yes'):
value = True
elif value in ('false', 'no'):
value = False
else: raise ValueError('Must pass a boolean value for option "%s"' % name)
result[name] = value
else:
assert False
return result
...@@ -9,6 +9,7 @@ try: ...@@ -9,6 +9,7 @@ try:
set set
except NameError: except NameError:
from sets import Set as set from sets import Set as set
import copy
class NormalizeTree(CythonTransform): class NormalizeTree(CythonTransform):
""" """
...@@ -79,15 +80,10 @@ class NormalizeTree(CythonTransform): ...@@ -79,15 +80,10 @@ class NormalizeTree(CythonTransform):
class PostParseError(CompileError): pass class PostParseError(CompileError): pass
# error strings checked by unit tests, so define them # error strings checked by unit tests, so define them
ERR_BUF_OPTION_UNKNOWN = '"%s" is not a buffer option'
ERR_BUF_TOO_MANY = 'Too many buffer options'
ERR_BUF_DUP = '"%s" buffer option already supplied'
ERR_BUF_MISSING = '"%s" missing'
ERR_BUF_INT = '"%s" must be an integer'
ERR_BUF_NONNEG = '"%s" must be non-negative'
ERR_CDEF_INCLASS = 'Cannot assign default value to cdef class attributes' ERR_CDEF_INCLASS = 'Cannot assign default value to cdef class attributes'
ERR_BUF_LOCALONLY = 'Buffer types only allowed as function local variables' ERR_BUF_LOCALONLY = 'Buffer types only allowed as function local variables'
ERR_BUF_MODEHELP = 'Only allowed buffer modes are "full" or "strided" (as a compile-time string)' ERR_BUF_DEFAULTS = 'Invalid buffer defaults specification (see docs)'
ERR_INVALID_SPECIALATTR_TYPE = 'Special attributes must not have a type declared'
class PostParse(CythonTransform): class PostParse(CythonTransform):
""" """
Basic interpretation of the parse tree, as well as validity Basic interpretation of the parse tree, as well as validity
...@@ -99,10 +95,24 @@ class PostParse(CythonTransform): ...@@ -99,10 +95,24 @@ class PostParse(CythonTransform):
- Default values to cdef assignments are turned into single - Default values to cdef assignments are turned into single
assignments following the declaration (everywhere but in class assignments following the declaration (everywhere but in class
bodies, where they raise a compile error) bodies, where they raise a compile error)
- CBufferAccessTypeNode has its options interpreted:
- Interpret some node structures into Python runtime values.
Some nodes take compile-time arguments (currently:
CBufferAccessTypeNode[args] and __cythonbufferdefaults__ = {args}),
which should be interpreted. This happens in a general way
and other steps should be taken to ensure validity.
Type arguments cannot be interpreted in this way.
- For __cythonbufferdefaults__ the arguments are checked for
validity.
CBufferAccessTypeNode has its options interpreted:
Any first positional argument goes into the "dtype" attribute, Any first positional argument goes into the "dtype" attribute,
any "ndim" keyword argument goes into the "ndim" attribute and any "ndim" keyword argument goes into the "ndim" attribute and
so on. Also it is checked that the option combination is valid. so on. Also it is checked that the option combination is valid.
- __cythonbufferdefaults__ attributes are parsed and put into the
type information.
Note: Currently Parsing.py does a lot of interpretation and Note: Currently Parsing.py does a lot of interpretation and
reorganization that can be refactored into this transform reorganization that can be refactored into this transform
...@@ -112,6 +122,12 @@ class PostParse(CythonTransform): ...@@ -112,6 +122,12 @@ class PostParse(CythonTransform):
# Track our context. # Track our context.
scope_type = None # can be either of 'module', 'function', 'class' scope_type = None # can be either of 'module', 'function', 'class'
def __init__(self, context):
super(PostParse, self).__init__(context)
self.specialattribute_handlers = {
'__cythonbufferdefaults__' : self.handle_bufferdefaults
}
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
self.scope_type = 'module' self.scope_type = 'module'
self.visitchildren(node) self.visitchildren(node)
...@@ -120,8 +136,10 @@ class PostParse(CythonTransform): ...@@ -120,8 +136,10 @@ class PostParse(CythonTransform):
def visit_ClassDefNode(self, node): def visit_ClassDefNode(self, node):
prev = self.scope_type prev = self.scope_type
self.scope_type = 'class' self.scope_type = 'class'
self.classnode = node
self.visitchildren(node) self.visitchildren(node)
self.scope_type = prev self.scope_type = prev
del self.classnode
return node return node
def visit_FuncDefNode(self, node): def visit_FuncDefNode(self, node):
...@@ -132,6 +150,12 @@ class PostParse(CythonTransform): ...@@ -132,6 +150,12 @@ class PostParse(CythonTransform):
return node return node
# cdef variables # cdef variables
def handle_bufferdefaults(self, decl):
if not isinstance(decl.default, DictNode):
raise PostParseError(decl.pos, ERR_BUF_DEFAULTS)
self.classnode.buffer_defaults_node = decl.default
self.classnode.buffer_defaults_pos = decl.pos
def visit_CVarDefNode(self, node): def visit_CVarDefNode(self, node):
# This assumes only plain names and pointers are assignable on # This assumes only plain names and pointers are assignable on
# declaration. Also, it makes use of the fact that a cdef decl # declaration. Also, it makes use of the fact that a cdef decl
...@@ -139,82 +163,217 @@ class PostParse(CythonTransform): ...@@ -139,82 +163,217 @@ class PostParse(CythonTransform):
# "i = 3; cdef int i = i" and can simply move the nodes around. # "i = 3; cdef int i = i" and can simply move the nodes around.
try: try:
self.visitchildren(node) self.visitchildren(node)
stats = [node]
newdecls = []
for decl in node.declarators:
declbase = decl
while isinstance(declbase, CPtrDeclaratorNode):
declbase = declbase.base
if isinstance(declbase, CNameDeclaratorNode):
if declbase.default is not None:
if self.scope_type == 'class':
if isinstance(self.classnode, CClassDefNode):
handler = self.specialattribute_handlers.get(decl.name)
if handler:
if decl is not declbase:
raise PostParseError(decl.pos, ERR_INVALID_SPECIALATTR_TYPE)
handler(decl)
continue # Remove declaration
raise PostParseError(decl.pos, ERR_CDEF_INCLASS)
stats.append(SingleAssignmentNode(node.pos,
lhs=NameNode(node.pos, name=declbase.name),
rhs=declbase.default, first=True))
declbase.default = None
newdecls.append(decl)
node.declarators = newdecls
return stats
except PostParseError, e: except PostParseError, e:
# An error in a cdef clause is ok, simply remove the declaration # An error in a cdef clause is ok, simply remove the declaration
# and try to move on to report more errors # and try to move on to report more errors
self.context.nonfatal_error(e) self.context.nonfatal_error(e)
return None return None
stats = [node]
for decl in node.declarators:
while isinstance(decl, CPtrDeclaratorNode):
decl = decl.base
if isinstance(decl, CNameDeclaratorNode):
if decl.default is not None:
if self.scope_type == 'class':
raise PostParseError(decl.pos, ERR_CDEF_INCLASS)
stats.append(SingleAssignmentNode(node.pos,
lhs=NameNode(node.pos, name=decl.name),
rhs=decl.default, first=True))
decl.default = None
return stats
# buffer access
buffer_options = ("dtype", "ndim", "mode") # ordered!
def visit_CBufferAccessTypeNode(self, node): def visit_CBufferAccessTypeNode(self, node):
if not self.scope_type == 'function': if not self.scope_type == 'function':
raise PostParseError(node.pos, ERR_BUF_LOCALONLY) raise PostParseError(node.pos, ERR_BUF_LOCALONLY)
return node
class PxdPostParse(CythonTransform):
"""
Basic interpretation/validity checking that should only be
done on pxd trees.
A lot of this checking currently happens in the parser; but
what is listed below happens here.
- "def" functions are let through only if they fill the
getbuffer/releasebuffer slots
"""
ERR_FUNCDEF_NOT_ALLOWED = 'function definition not allowed here'
def __call__(self, node):
self.scope_type = 'pxd'
return super(PxdPostParse, self).__call__(node)
def visit_CClassDefNode(self, node):
old = self.scope_type
self.scope_type = 'cclass'
self.visitchildren(node)
self.scope_type = old
return node
def visit_FuncDefNode(self, node):
# FuncDefNode always come with an implementation (without
# an imp they are CVarDefNodes..)
ok = False
if (isinstance(node, DefNode) and self.scope_type == 'cclass'
and node.name in ('__getbuffer__', '__releasebuffer__')):
ok = True
if not ok:
self.context.nonfatal_error(PostParseError(node.pos,
self.ERR_FUNCDEF_NOT_ALLOWED))
return None
else:
return node
class ResolveOptions(CythonTransform):
"""
After parsing, options can be stored in a number of places:
- #cython-comments at the top of the file (stored in ModuleNode)
- Command-line arguments overriding these
- @cython.optionname decorators
- with cython.optionname: statements
This transform is responsible for annotating each node with an
"options" attribute linking it to a dict containing the exact
options that are in effect for that node. Any corresponding decorators
or with statements are removed in the process.
Note that we have to run this prior to analysis, and so some minor
duplication of functionality has to occur: We manually track cimports
to correctly intercept @cython... and with cython...
"""
def __init__(self, context, compilation_option_overrides):
super(ResolveOptions, self).__init__(context)
self.compilation_option_overrides = compilation_option_overrides
self.cython_module_names = set()
self.option_names = {}
def visit_ModuleNode(self, node):
options = copy.copy(Options.option_defaults)
options.update(node.option_comments)
options.update(self.compilation_option_overrides)
self.options = options
node.options = options
self.visitchildren(node)
return node
# Track cimports of the cython module.
def visit_CImportStatNode(self, node):
if node.module_name == u"cython":
if node.as_name:
modname = node.as_name
else:
modname = u"cython"
self.cython_module_names.add(modname)
return node
def visit_FromCImportStatNode(self, node):
if node.module_name == u"cython":
newimp = []
for pos, name, as_name, kind in node.imported_names:
if name in Options.option_types:
self.option_names[as_name] = name
if kind is not None:
self.context.nonfatal_error(PostParseError(pos,
"Compiler option imports must be plain imports"))
return None
else:
newimp.append((pos, name, as_name, kind))
node.imported_names = newimpo
return node
def visit_Node(self, node):
node.options = self.options
self.visitchildren(node)
return node
def try_to_parse_option(self, node):
# If node is the contents of an option (in a with statement or
# decorator), returns (optionname, value).
# Otherwise, returns None
optname = None
if isinstance(node, SimpleCallNode):
if (isinstance(node.function, AttributeNode) and
isinstance(node.function.obj, NameNode) and
node.function.obj.name in self.cython_module_names):
optname = node.function.attribute
elif (isinstance(node.function, NameNode) and
node.function.name in self.option_names):
optname = self.option_names[node.function.name]
if optname:
optiontype = Options.option_types.get(optname)
if optiontype:
args = node.args
if optiontype is bool:
if len(args) != 1 or not isinstance(args[0], BoolNode):
raise PostParseError(dec.function.pos,
'The %s option takes one compile-time boolean argument' % optname)
return (optname, args[0].value)
else:
assert False
return None
def visit_with_options(self, node, options):
oldoptions = self.options
newoptions = copy.copy(oldoptions)
newoptions.update(options)
self.options = newoptions
node = self.visit_Node(node)
self.options = oldoptions
return node
# Handle decorators
def visit_DefNode(self, node):
options = []
options = {} if node.decorators:
# Fetch positional arguments # Split the decorators into two lists -- real decorators and options
if len(node.positional_args) > len(self.buffer_options): realdecs = []
raise PostParseError(node.pos, ERR_BUF_TOO_MANY) for dec in node.decorators:
for arg, unicode_name in zip(node.positional_args, self.buffer_options): option = self.try_to_parse_option(dec.decorator)
name = str(unicode_name) if option is not None:
options[name] = arg options.append(option)
# Fetch named arguments else:
for item in node.keyword_args.key_value_pairs: realdecs.append(dec)
name = str(item.key.value) node.decorators = realdecs
if not name in self.buffer_options:
raise PostParseError(item.key.pos, ERR_BUF_OPTION_UNKNOWN % name) if options:
if name in options.keys(): optdict = {}
raise PostParseError(item.key.pos, ERR_BUF_DUP % key) options.reverse() # Decorators coming first take precedence
options[name] = item.value for option in options:
name, value = option
# get dtype optdict[name] = value
dtype = options.get("dtype") return self.visit_with_options(node, optdict)
if dtype is None:
raise PostParseError(node.pos, ERR_BUF_MISSING % 'dtype')
node.dtype_node = dtype
# get ndim
if "ndim" in options:
ndimnode = options["ndim"]
if not isinstance(ndimnode, IntNode):
# Compile-time values (DEF) are currently resolved by the parser,
# so nothing more to do here
raise PostParseError(ndimnode.pos, ERR_BUF_INT % 'ndim')
ndim_value = int(ndimnode.value)
if ndim_value < 0:
raise PostParseError(ndimnode.pos, ERR_BUF_NONNEG % 'ndim')
node.ndim = int(ndimnode.value)
else: else:
node.ndim = 1 return self.visit_Node(node)
if "mode" in options: # Handle with statements
modenode = options["mode"] def visit_WithStatNode(self, node):
if not isinstance(modenode, StringNode): option = self.try_to_parse_option(node.manager)
raise PostParseError(modenode.pos, ERR_BUF_MODEHELP) if option is not None:
mode = modenode.value if node.target is not None:
if not mode in ('full', 'strided'): raise PostParseError(node.pos, "Compiler option with statements cannot contain 'as'")
raise PostParseError(modenode.pos, ERR_BUF_MODEHELP) name, value = option
node.mode = mode self.visit_with_options(node.body, {name:value})
return node.body.stats
else: else:
node.mode = 'full' return self.visit_Node(node)
# We're done with the parse tree args
node.positional_args = None
node.keyword_args = None
return node
class WithTransform(CythonTransform): class WithTransform(CythonTransform):
......
...@@ -13,6 +13,7 @@ from ModuleNode import ModuleNode ...@@ -13,6 +13,7 @@ from ModuleNode import ModuleNode
from Errors import error, warning, InternalError from Errors import error, warning, InternalError
from Cython import Utils from Cython import Utils
import Future import Future
import Options
class Ctx(object): class Ctx(object):
# Parsing context # Parsing context
...@@ -602,7 +603,7 @@ def p_string_literal(s): ...@@ -602,7 +603,7 @@ def p_string_literal(s):
else: else:
c = systr[1] c = systr[1]
if c in "01234567": if c in "01234567":
chars.append(chr(int(systr[1:]))) chars.append(chr(int(systr[1:], 8)))
elif c in "'\"\\": elif c in "'\"\\":
chars.append(c) chars.append(c)
elif c in "abfnrtv": elif c in "abfnrtv":
...@@ -621,7 +622,7 @@ def p_string_literal(s): ...@@ -621,7 +622,7 @@ def p_string_literal(s):
strval = systr strval = systr
chars.append(strval) chars.append(strval)
else: else:
chars.append(r'\\' + systr[1:]) chars.append('\\' + systr[1:])
elif sy == 'NEWLINE': elif sy == 'NEWLINE':
chars.append('\n') chars.append('\n')
elif sy == 'END_STRING': elif sy == 'END_STRING':
...@@ -1412,7 +1413,7 @@ def p_statement(s, ctx, first_statement = 0): ...@@ -1412,7 +1413,7 @@ def p_statement(s, ctx, first_statement = 0):
if ctx.api: if ctx.api:
error(s.pos, "'api' not allowed with this statement") error(s.pos, "'api' not allowed with this statement")
elif s.sy == 'def': elif s.sy == 'def':
if ctx.level not in ('module', 'class', 'c_class', 'property'): if ctx.level not in ('module', 'class', 'c_class', 'c_class_pxd', 'property'):
s.error('def statement not allowed here') s.error('def statement not allowed here')
s.level = ctx.level s.level = ctx.level
return p_def_statement(s) return p_def_statement(s)
...@@ -1626,8 +1627,13 @@ def p_c_simple_base_type(s, self_flag, nonempty): ...@@ -1626,8 +1627,13 @@ def p_c_simple_base_type(s, self_flag, nonempty):
longness = longness, is_self_arg = self_flag) longness = longness, is_self_arg = self_flag)
# Treat trailing [] on type as buffer access # Treat trailing [] on type as buffer access if it appears in a context
if s.sy == '[': # where declarator names are required (so that it cannot mean int[] or
# sizeof(int[SIZE]))...
#
# (This means that buffers cannot occur where there can be empty declarators,
# which is an ok restriction to make.)
if nonempty and s.sy == '[':
return p_buffer_access(s, type_node) return p_buffer_access(s, type_node)
else: else:
return type_node return type_node
...@@ -1636,10 +1642,6 @@ def p_buffer_access(s, base_type_node): ...@@ -1636,10 +1642,6 @@ def p_buffer_access(s, base_type_node):
# s.sy == '[' # s.sy == '['
pos = s.position() pos = s.position()
s.next() s.next()
if s.sy == ']' or s.sy == 'INT':
# not buffer, could be [] on C type nameless array arguments
s.put_back('[', '[')
return base_type_node
positional_args, keyword_args = ( positional_args, keyword_args = (
p_positional_and_keyword_args(s, (']',), (0,), ('dtype',)) p_positional_and_keyword_args(s, (']',), (0,), ('dtype',))
) )
...@@ -1887,7 +1889,7 @@ def p_c_arg_decl(s, ctx, in_pyfunc, cmethod_flag = 0, nonempty = 0, kw_only = 0) ...@@ -1887,7 +1889,7 @@ def p_c_arg_decl(s, ctx, in_pyfunc, cmethod_flag = 0, nonempty = 0, kw_only = 0)
if 'pxd' in s.level: if 'pxd' in s.level:
if s.sy not in ['*', '?']: if s.sy not in ['*', '?']:
error(pos, "default values cannot be specified in pxd files, use ? or *") error(pos, "default values cannot be specified in pxd files, use ? or *")
default = 1 default = ExprNodes.BoolNode(1)
s.next() s.next()
else: else:
default = p_simple_expr(s) default = p_simple_expr(s)
...@@ -2324,6 +2326,17 @@ def p_code(s, level=None): ...@@ -2324,6 +2326,17 @@ def p_code(s, level=None):
repr(s.sy), repr(s.systring))) repr(s.sy), repr(s.systring)))
return body return body
def p_option_comments(s):
result = {}
while s.sy == 'option_comment':
opts = s.systring[len("#cython:"):]
try:
result.update(Options.parse_option_list(opts))
except ValueError, e:
s.error(e.message, fatal=False)
s.next()
return result
def p_module(s, pxd, full_module_name): def p_module(s, pxd, full_module_name):
s.add_type_name("object") s.add_type_name("object")
s.add_type_name("Py_buffer") s.add_type_name("Py_buffer")
...@@ -2333,11 +2346,16 @@ def p_module(s, pxd, full_module_name): ...@@ -2333,11 +2346,16 @@ def p_module(s, pxd, full_module_name):
level = 'module_pxd' level = 'module_pxd'
else: else:
level = 'module' level = 'module'
option_comments = p_option_comments(s)
s.parse_option_comments = False
body = p_statement_list(s, Ctx(level = level), first_statement = 1) body = p_statement_list(s, Ctx(level = level), first_statement = 1)
if s.sy != 'EOF': if s.sy != 'EOF':
s.error("Syntax error in statement [%s,%s]" % ( s.error("Syntax error in statement [%s,%s]" % (
repr(s.sy), repr(s.systring))) repr(s.sy), repr(s.systring)))
return ModuleNode(pos, doc = doc, body = body, full_module_name = full_module_name) return ModuleNode(pos, doc = doc, body = body,
full_module_name = full_module_name,
option_comments = option_comments)
#---------------------------------------------- #----------------------------------------------
# #
......
...@@ -149,6 +149,7 @@ class CTypedefType(BaseType): ...@@ -149,6 +149,7 @@ class CTypedefType(BaseType):
# typedef_base_type PyrexType # typedef_base_type PyrexType
is_typedef = 1 is_typedef = 1
typestring = None # Because typedefs are not known exactly
def __init__(self, cname, base_type): def __init__(self, cname, base_type):
self.typedef_cname = cname self.typedef_cname = cname
...@@ -223,11 +224,14 @@ class PyObjectType(PyrexType): ...@@ -223,11 +224,14 @@ class PyObjectType(PyrexType):
# #
# Base class for all Python object types (reference-counted). # Base class for all Python object types (reference-counted).
# #
# buffer_defaults dict or None Default options for bu
is_pyobject = 1 is_pyobject = 1
default_value = "0" default_value = "0"
parsetuple_format = "O" parsetuple_format = "O"
pymemberdef_typecode = "T_OBJECT" pymemberdef_typecode = "T_OBJECT"
buffer_defaults = None
typestring = "O"
def __str__(self): def __str__(self):
return "Python object" return "Python object"
...@@ -270,6 +274,7 @@ class BuiltinObjectType(PyObjectType): ...@@ -270,6 +274,7 @@ class BuiltinObjectType(PyObjectType):
return "<%s>"% self.cname return "<%s>"% self.cname
def assignable_from(self, src_type): def assignable_from(self, src_type):
if isinstance(src_type, BuiltinObjectType): if isinstance(src_type, BuiltinObjectType):
return src_type.name == self.name return src_type.name == self.name
else: else:
...@@ -998,6 +1003,19 @@ class CStringType: ...@@ -998,6 +1003,19 @@ class CStringType:
return '"%s"' % Utils.escape_byte_string(value) return '"%s"' % Utils.escape_byte_string(value)
class CUTF8CharArrayType(CStringType, CArrayType):
# C 'char []' type.
parsetuple_format = "s"
pymemberdef_typecode = "T_STRING_INPLACE"
is_unicode = 1
to_py_function = "PyUnicode_DecodeUTF8"
exception_value = "NULL"
def __init__(self, size):
CArrayType.__init__(self, c_char_type, size)
class CCharArrayType(CStringType, CArrayType): class CCharArrayType(CStringType, CArrayType):
# C 'char []' type. # C 'char []' type.
...@@ -1018,29 +1036,6 @@ class CCharPtrType(CStringType, CPtrType): ...@@ -1018,29 +1036,6 @@ class CCharPtrType(CStringType, CPtrType):
CPtrType.__init__(self, c_char_type) CPtrType.__init__(self, c_char_type)
class UnicodeType(BuiltinObjectType):
# The Python unicode type.
is_string = 1
is_unicode = 1
parsetuple_format = "U"
def __init__(self):
BuiltinObjectType.__init__(self, "unicode", "PyUnicodeObject")
def literal_code(self, value):
assert isinstance(value, str)
return '"%s"' % Utils.escape_byte_string(value)
def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0):
if pyrex or for_display:
return self.base_declaration_code(self.name, entity_code)
else:
return "%s %s[]" % (public_decl("char", dll_linkage), entity_code)
class ErrorType(PyrexType): class ErrorType(PyrexType):
# Used to prevent propagation of error messages. # Used to prevent propagation of error messages.
...@@ -1049,6 +1044,7 @@ class ErrorType(PyrexType): ...@@ -1049,6 +1044,7 @@ class ErrorType(PyrexType):
exception_check = 0 exception_check = 0
to_py_function = "dummy" to_py_function = "dummy"
from_py_function = "dummy" from_py_function = "dummy"
typestring = None
def declaration_code(self, entity_code, def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0): for_display = 0, dll_linkage = None, pyrex = 0):
...@@ -1105,8 +1101,8 @@ c_longdouble_type = CFloatType(8, typestring="g") ...@@ -1105,8 +1101,8 @@ c_longdouble_type = CFloatType(8, typestring="g")
c_null_ptr_type = CNullPtrType(c_void_type) c_null_ptr_type = CNullPtrType(c_void_type)
c_char_array_type = CCharArrayType(None) c_char_array_type = CCharArrayType(None)
c_unicode_type = UnicodeType()
c_char_ptr_type = CCharPtrType() c_char_ptr_type = CCharPtrType()
c_utf8_char_array_type = CUTF8CharArrayType(None)
c_char_ptr_ptr_type = CPtrType(c_char_ptr_type) c_char_ptr_ptr_type = CPtrType(c_char_ptr_type)
c_py_ssize_t_ptr_type = CPtrType(c_py_ssize_t_type) c_py_ssize_t_ptr_type = CPtrType(c_py_ssize_t_type)
c_int_ptr_type = CPtrType(c_int_type) c_int_ptr_type = CPtrType(c_int_type)
......
...@@ -306,6 +306,7 @@ class PyrexScanner(Scanner): ...@@ -306,6 +306,7 @@ class PyrexScanner(Scanner):
self.compile_time_env = initial_compile_time_env() self.compile_time_env = initial_compile_time_env()
self.compile_time_eval = 1 self.compile_time_eval = 1
self.compile_time_expr = 0 self.compile_time_expr = 0
self.parse_option_comments = True
self.source_encoding = source_encoding self.source_encoding = source_encoding
self.trace = trace_scanner self.trace = trace_scanner
self.indentation_stack = [0] self.indentation_stack = [0]
...@@ -314,6 +315,13 @@ class PyrexScanner(Scanner): ...@@ -314,6 +315,13 @@ class PyrexScanner(Scanner):
self.begin('INDENT') self.begin('INDENT')
self.sy = '' self.sy = ''
self.next() self.next()
def option_comment(self, text):
# #cython:-comments should be treated as literals until
# parse_option_comments is set to False, at which point
# they should be ignored.
if self.parse_option_comments:
self.produce('option_comment', text)
def current_level(self): def current_level(self):
return self.indentation_stack[-1] return self.indentation_stack[-1]
...@@ -432,12 +440,13 @@ class PyrexScanner(Scanner): ...@@ -432,12 +440,13 @@ class PyrexScanner(Scanner):
def looking_at_type_name(self): def looking_at_type_name(self):
return self.sy == 'IDENT' and self.systring in self.type_names return self.sy == 'IDENT' and self.systring in self.type_names
def error(self, message, pos = None): def error(self, message, pos = None, fatal = True):
if pos is None: if pos is None:
pos = self.position() pos = self.position()
if self.sy == 'INDENT': if self.sy == 'INDENT':
error(pos, "Possible inconsistent indentation") err = error(pos, "Possible inconsistent indentation")
raise error(pos, message) err = error(pos, message)
if fatal: raise err
def expect(self, what, message = None): def expect(self, what, message = None):
if self.sy == what: if self.sy == what:
......
...@@ -26,11 +26,12 @@ nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match ...@@ -26,11 +26,12 @@ nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match
class BufferAux: class BufferAux:
writable_needed = False writable_needed = False
def __init__(self, buffer_info_var, stridevars, shapevars, tschecker): def __init__(self, buffer_info_var, stridevars, shapevars,
suboffsetvars):
self.buffer_info_var = buffer_info_var self.buffer_info_var = buffer_info_var
self.stridevars = stridevars self.stridevars = stridevars
self.shapevars = shapevars self.shapevars = shapevars
self.tschecker = tschecker self.suboffsetvars = suboffsetvars
def __repr__(self): def __repr__(self):
return "<BufferAux %r>" % self.__dict__ return "<BufferAux %r>" % self.__dict__
...@@ -504,7 +505,7 @@ class Scope: ...@@ -504,7 +505,7 @@ class Scope:
else: else:
cname = self.new_const_cname() cname = self.new_const_cname()
if value.is_unicode: if value.is_unicode:
c_type = PyrexTypes.c_unicode_type c_type = PyrexTypes.c_utf8_char_array_type
value = value.utf8encode() value = value.utf8encode()
else: else:
c_type = PyrexTypes.c_char_array_type c_type = PyrexTypes.c_char_array_type
...@@ -629,9 +630,6 @@ class Scope: ...@@ -629,9 +630,6 @@ class Scope:
def use_utility_code(self, new_code, name=None): def use_utility_code(self, new_code, name=None):
self.global_scope().use_utility_code(new_code, name) self.global_scope().use_utility_code(new_code, name)
def has_utility_code(self, name):
return self.global_scope().has_utility_code(name)
def generate_library_function_declarations(self, code): def generate_library_function_declarations(self, code):
# Generate extern decls for C library funcs used. # Generate extern decls for C library funcs used.
...@@ -743,6 +741,8 @@ class BuiltinScope(Scope): ...@@ -743,6 +741,8 @@ class BuiltinScope(Scope):
"True": ["Py_True", py_object_type], "True": ["Py_True", py_object_type],
} }
const_counter = 1 # As a temporary solution for compiling code in pxds
class ModuleScope(Scope): class ModuleScope(Scope):
# module_name string Python name of the module # module_name string Python name of the module
# module_cname string C name of Python module object # module_cname string C name of Python module object
...@@ -750,9 +750,8 @@ class ModuleScope(Scope): ...@@ -750,9 +750,8 @@ class ModuleScope(Scope):
# method_table_cname string C name of method table # method_table_cname string C name of method table
# doc string Module doc string # doc string Module doc string
# doc_cname string C name of module doc string # doc_cname string C name of module doc string
# const_counter integer Counter for naming constants # const_counter integer Counter for naming constants (PS: MOVED TO GLOBAL)
# utility_code_used [string] Utility code to be included # utility_code_list [((string, string), string)] Queuing utility codes for forwarding to Code.py
# utility_code_names set(string) (Optional) names for named (often generated) utility code
# default_entries [Entry] Function argument default entries # default_entries [Entry] Function argument default entries
# python_include_files [string] Standard Python headers to be included # python_include_files [string] Standard Python headers to be included
# include_files [string] Other C headers to be included # include_files [string] Other C headers to be included
...@@ -785,9 +784,7 @@ class ModuleScope(Scope): ...@@ -785,9 +784,7 @@ class ModuleScope(Scope):
self.method_table_cname = Naming.methtable_cname self.method_table_cname = Naming.methtable_cname
self.doc = "" self.doc = ""
self.doc_cname = Naming.moddoc_cname self.doc_cname = Naming.moddoc_cname
self.const_counter = 1 self.utility_code_list = []
self.utility_code_used = []
self.utility_code_names = set()
self.default_entries = [] self.default_entries = []
self.module_entries = {} self.module_entries = {}
self.python_include_files = ["Python.h", "structmember.h"] self.python_include_files = ["Python.h", "structmember.h"]
...@@ -940,35 +937,20 @@ class ModuleScope(Scope): ...@@ -940,35 +937,20 @@ class ModuleScope(Scope):
return entry return entry
def new_const_cname(self): def new_const_cname(self):
global const_counter
# Create a new globally-unique name for a constant. # Create a new globally-unique name for a constant.
prefix='' prefix=''
n = self.const_counter n = const_counter
self.const_counter = n + 1 const_counter = n + 1
return "%s%s%d" % (Naming.const_prefix, prefix, n) return "%s%s%d" % (Naming.const_prefix, prefix, n)
def use_utility_code(self, new_code, name=None): def use_utility_code(self, new_code, name=None):
# Add string to list of utility code to be included, self.utility_code_list.append((new_code, name))
# if not already there (tested using the provided name,
# or 'is' if name=None -- if the utility code is dynamically
# generated, use the name, otherwise it is not needed).
if name is not None:
if name in self.utility_code_names:
return
for old_code in self.utility_code_used:
if old_code is new_code:
return
self.utility_code_used.append(new_code)
self.utility_code_names.add(name)
def has_utility_code(self, name):
# Checks if utility code (that is registered by name) has
# previously been registered. This is useful if the utility code
# is dynamically generated to avoid re-generation.
return name in self.utility_code_names
def declare_c_class(self, name, pos, defining = 0, implementing = 0, def declare_c_class(self, name, pos, defining = 0, implementing = 0,
module_name = None, base_type = None, objstruct_cname = None, module_name = None, base_type = None, objstruct_cname = None,
typeobj_cname = None, visibility = 'private', typedef_flag = 0, api = 0): typeobj_cname = None, visibility = 'private', typedef_flag = 0, api = 0,
buffer_defaults = None):
# #
# Look for previous declaration as a type # Look for previous declaration as a type
# #
...@@ -992,6 +974,7 @@ class ModuleScope(Scope): ...@@ -992,6 +974,7 @@ class ModuleScope(Scope):
if not entry: if not entry:
type = PyrexTypes.PyExtensionType(name, typedef_flag, base_type) type = PyrexTypes.PyExtensionType(name, typedef_flag, base_type)
type.pos = pos type.pos = pos
type.buffer_defaults = buffer_defaults
if visibility == 'extern': if visibility == 'extern':
type.module_name = module_name type.module_name = module_name
else: else:
......
...@@ -2,6 +2,7 @@ from Cython.TestUtils import CythonTest ...@@ -2,6 +2,7 @@ from Cython.TestUtils import CythonTest
import Cython.Compiler.Errors as Errors import Cython.Compiler.Errors as Errors
from Cython.Compiler.Nodes import * from Cython.Compiler.Nodes import *
from Cython.Compiler.ParseTreeTransforms import * from Cython.Compiler.ParseTreeTransforms import *
from Cython.Compiler.Buffer import *
class TestBufferParsing(CythonTest): class TestBufferParsing(CythonTest):
...@@ -45,6 +46,8 @@ class TestBufferParsing(CythonTest): ...@@ -45,6 +46,8 @@ class TestBufferParsing(CythonTest):
# See also tests/error/e_bufaccess.pyx and tets/run/bufaccess.pyx # See also tests/error/e_bufaccess.pyx and tets/run/bufaccess.pyx
# THESE TESTS ARE NOW DISABLED, the code they test was pretty much
# refactored away
class TestBufferOptions(CythonTest): class TestBufferOptions(CythonTest):
# Tests the full parsing of the options within the brackets # Tests the full parsing of the options within the brackets
...@@ -74,24 +77,24 @@ class TestBufferOptions(CythonTest): ...@@ -74,24 +77,24 @@ class TestBufferOptions(CythonTest):
# e = self.should_fail(lambda: self.parse_opts(opts)) # e = self.should_fail(lambda: self.parse_opts(opts))
self.assertEqual(expected_err, self.error.message_only) self.assertEqual(expected_err, self.error.message_only)
def test_basic(self): def __test_basic(self):
buf = self.parse_opts(u"unsigned short int, 3") buf = self.parse_opts(u"unsigned short int, 3")
self.assert_(isinstance(buf.dtype_node, CSimpleBaseTypeNode)) self.assert_(isinstance(buf.dtype_node, CSimpleBaseTypeNode))
self.assert_(buf.dtype_node.signed == 0 and buf.dtype_node.longness == -1) self.assert_(buf.dtype_node.signed == 0 and buf.dtype_node.longness == -1)
self.assertEqual(3, buf.ndim) self.assertEqual(3, buf.ndim)
def test_dict(self): def __test_dict(self):
buf = self.parse_opts(u"ndim=3, dtype=unsigned short int") buf = self.parse_opts(u"ndim=3, dtype=unsigned short int")
self.assert_(isinstance(buf.dtype_node, CSimpleBaseTypeNode)) self.assert_(isinstance(buf.dtype_node, CSimpleBaseTypeNode))
self.assert_(buf.dtype_node.signed == 0 and buf.dtype_node.longness == -1) self.assert_(buf.dtype_node.signed == 0 and buf.dtype_node.longness == -1)
self.assertEqual(3, buf.ndim) self.assertEqual(3, buf.ndim)
def test_ndim(self): def __test_ndim(self):
self.parse_opts(u"int, 2") self.parse_opts(u"int, 2")
self.non_parse(ERR_BUF_INT % 'ndim', u"int, 'a'") self.non_parse(ERR_BUF_NDIM, u"int, 'a'")
self.non_parse(ERR_BUF_NONNEG % 'ndim', u"int, -34") self.non_parse(ERR_BUF_NDIM, u"int, -34")
def test_use_DEF(self): def __test_use_DEF(self):
t = self.fragment(u""" t = self.fragment(u"""
DEF ndim = 3 DEF ndim = 3
def f(): def f():
......
...@@ -20,7 +20,7 @@ Support for parsing strings into code trees. ...@@ -20,7 +20,7 @@ Support for parsing strings into code trees.
class StringParseContext(Main.Context): class StringParseContext(Main.Context):
def __init__(self, include_directories, name): def __init__(self, include_directories, name):
Main.Context.__init__(self, include_directories) Main.Context.__init__(self, include_directories, {})
self.module_name = name self.module_name = name
def find_module(self, module_name, relative_to = None, pos = None, need_pxd = 1): def find_module(self, module_name, relative_to = None, pos = None, need_pxd = 1):
......
...@@ -28,6 +28,10 @@ class BasicVisitor(object): ...@@ -28,6 +28,10 @@ class BasicVisitor(object):
if m is not None: if m is not None:
break break
else: else:
print type(self), type(obj)
print self.access_path
print self.access_path[-1][0].pos
print self.access_path[-1][0].__dict__
raise RuntimeError("Visitor does not accept object: %s" % obj) raise RuntimeError("Visitor does not accept object: %s" % obj)
self.dispatch_table[mname] = m self.dispatch_table[mname] = m
return m(obj) return m(obj)
......
...@@ -2,16 +2,80 @@ cdef extern from "Python.h": ...@@ -2,16 +2,80 @@ cdef extern from "Python.h":
ctypedef int Py_intptr_t ctypedef int Py_intptr_t
cdef extern from "numpy/arrayobject.h": cdef extern from "numpy/arrayobject.h":
ctypedef Py_intptr_t npy_intp
ctypedef struct PyArray_Descr:
int elsize
ctypedef class numpy.ndarray [object PyArrayObject]: ctypedef class numpy.ndarray [object PyArrayObject]:
cdef char *data cdef:
cdef int nd char *data
cdef Py_intptr_t *dimensions int nd
cdef Py_intptr_t *strides npy_intp *dimensions
cdef object base npy_intp *strides
# descr not implemented yet here... object base
cdef int flags # descr not implemented yet here...
cdef int itemsize int flags
cdef object weakreflist int itemsize
object weakreflist
PyArray_Descr* descr
def __getbuffer__(ndarray self, Py_buffer* info, int flags):
if sizeof(npy_intp) != sizeof(Py_ssize_t):
raise RuntimeError("Py_intptr_t and Py_ssize_t differs in size, numpy.pxd does not support this")
cdef int typenum = PyArray_TYPE(self)
info.buf = <void*>self.data
info.ndim = 2
info.strides = <Py_ssize_t*>self.strides
info.shape = <Py_ssize_t*>self.dimensions
info.suboffsets = NULL
info.format = "i"
info.itemsize = self.descr.elsize
info.readonly = not PyArray_ISWRITEABLE(self)
# PS TODO TODO!: Py_ssize_t vs Py_intptr_t
## PyArrayObject *arr = (PyArrayObject*)obj;
## PyArray_Descr *type = (PyArray_Descr*)arr->descr;
## int typenum = PyArray_TYPE(obj);
## if (!PyTypeNum_ISNUMBER(typenum)) {
## PyErr_Format(PyExc_TypeError, "Only numeric NumPy types currently supported.");
## return -1;
## }
## /*
## NumPy format codes doesn't completely match buffer codes;
## seems safest to retranslate.
## 01234567890123456789012345*/
## const char* base_codes = "?bBhHiIlLqQfdgfdgO";
## char* format = (char*)malloc(4);
## char* fp = format;
## *fp++ = type->byteorder;
## if (PyTypeNum_ISCOMPLEX(typenum)) *fp++ = 'Z';
## *fp++ = base_codes[typenum];
## *fp = 0;
## view->buf = arr->data;
## view->readonly = !PyArray_ISWRITEABLE(obj);
## view->ndim = PyArray_NDIM(arr);
## view->strides = PyArray_STRIDES(arr);
## view->shape = PyArray_DIMS(arr);
## view->suboffsets = NULL;
## view->format = format;
## view->itemsize = type->elsize;
## view->internal = 0;
## return 0;
## print "hello" + str(43) + "asdf" + "three"
## pass
cdef int PyArray_TYPE(ndarray arr)
cdef int PyArray_ISWRITEABLE(ndarray arr)
ctypedef unsigned int npy_uint8 ctypedef unsigned int npy_uint8
ctypedef unsigned int npy_uint16 ctypedef unsigned int npy_uint16
...@@ -27,4 +91,5 @@ cdef extern from "numpy/arrayobject.h": ...@@ -27,4 +91,5 @@ cdef extern from "numpy/arrayobject.h":
ctypedef float npy_float96 ctypedef float npy_float96
ctypedef float npy_float128 ctypedef float npy_float128
ctypedef npy_int64 Tint64
ctypedef npy_int64 int64
...@@ -118,32 +118,28 @@ ...@@ -118,32 +118,28 @@
# just to be sure you understand what is going on. # just to be sure you understand what is going on.
# #
################################################################# #################################################################
cdef extern from "Python.h":
ctypedef void PyObject
ctypedef void PyTypeObject
ctypedef struct FILE
include 'python_ref.pxi' from python_ref cimport *
include 'python_exc.pxi' from python_exc cimport *
include 'python_module.pxi' from python_module cimport *
include 'python_mem.pxi' from python_mem cimport *
include 'python_tuple.pxi' from python_tuple cimport *
include 'python_list.pxi' from python_list cimport *
include 'python_object.pxi' from python_object cimport *
include 'python_sequence.pxi' from python_sequence cimport *
include 'python_mapping.pxi' from python_mapping cimport *
include 'python_iterator.pxi' from python_iterator cimport *
include 'python_type.pxi' from python_type cimport *
include 'python_number.pxi' from python_number cimport *
include 'python_int.pxi' from python_int cimport *
include 'python_bool.pxi' from python_bool cimport *
include 'python_long.pxi' from python_long cimport *
include 'python_float.pxi' from python_float cimport *
include 'python_complex.pxi' from python_complex cimport *
include 'python_string.pxi' from python_string cimport *
include 'python_dict.pxi' from python_dict cimport *
include 'python_instance.pxi' from python_instance cimport *
include 'python_function.pxi' from python_function cimport *
include 'python_method.pxi' from python_method cimport *
include 'python_set.pxi' from python_set cimport *
cdef extern from "Python.h": cdef extern from "Python.h":
ctypedef void PyObject
ctypedef void PyTypeObject ctypedef void PyTypeObject
ctypedef struct PyObject:
Py_ssize_t ob_refcnt
PyTypeObject *ob_type
ctypedef struct FILE ctypedef struct FILE
......
...@@ -7,8 +7,7 @@ class StringIOTree(object): ...@@ -7,8 +7,7 @@ class StringIOTree(object):
def __init__(self, stream=None): def __init__(self, stream=None):
self.prepended_children = [] self.prepended_children = []
if stream is None: stream = StringIO() self.stream = stream # if set to None, it will be constructed on first write
self.stream = stream
def getvalue(self): def getvalue(self):
return ("".join([x.getvalue() for x in self.prepended_children]) + return ("".join([x.getvalue() for x in self.prepended_children]) +
...@@ -19,20 +18,44 @@ class StringIOTree(object): ...@@ -19,20 +18,44 @@ class StringIOTree(object):
needs to happen.""" needs to happen."""
for child in self.prepended_children: for child in self.prepended_children:
child.copyto(target) child.copyto(target)
target.write(self.stream.getvalue()) if self.stream:
target.write(self.stream.getvalue())
def write(self, what): def write(self, what):
if not self.stream:
self.stream = StringIO()
self.stream.write(what) self.stream.write(what)
def commit(self):
# Save what we have written until now so that the buffer
# itself is empty -- this makes it ready for insertion
if self.stream:
self.prepended_children.append(StringIOTree(self.stream))
self.stream = None
def insert(self, iotree):
"""
Insert a StringIOTree (and all of its contents) at this location.
Further writing to self appears after what is inserted.
"""
self.commit()
self.prepended_children.append(iotree)
def insertion_point(self): def insertion_point(self):
"""
Returns a new StringIOTree, which is left behind at the current position
(it what is written to the result will appear right before whatever is
next written to self).
Calling getvalue() or copyto() on the result will only return the
contents written to it.
"""
# Save what we have written until now # Save what we have written until now
# (would it be more efficient to check with len(self.stream.getvalue())? # This is so that getvalue on the result doesn't include it.
# leaving it out for now) self.commit()
self.prepended_children.append(StringIOTree(self.stream))
# Construct the new forked object to return # Construct the new forked object to return
other = StringIOTree() other = StringIOTree()
self.prepended_children.append(other) self.prepended_children.append(other)
self.stream = StringIO()
return other return other
__doc__ = r""" __doc__ = r"""
...@@ -57,13 +80,11 @@ EXAMPLE: ...@@ -57,13 +80,11 @@ EXAMPLE:
>>> c.write('beta\n') >>> c.write('beta\n')
>>> b.getvalue().split() >>> b.getvalue().split()
['second', 'alpha', 'beta', 'gamma'] ['second', 'alpha', 'beta', 'gamma']
>>> i = StringIOTree()
>>> d.insert(i)
>>> i.write('inserted\n')
>>> out = StringIO() >>> out = StringIO()
>>> a.copyto(out) >>> a.copyto(out)
>>> out.getvalue().split() >>> out.getvalue().split()
['first', 'second', 'alpha', 'beta', 'gamma', 'third'] ['first', 'second', 'alpha', 'inserted', 'beta', 'gamma', 'third']
""" """
\ No newline at end of file
if __name__ == "__main__":
import doctest
doctest.testmod()
...@@ -46,7 +46,7 @@ class ErrorWriter(object): ...@@ -46,7 +46,7 @@ class ErrorWriter(object):
class TestBuilder(object): class TestBuilder(object):
def __init__(self, rootdir, workdir, selectors, annotate, def __init__(self, rootdir, workdir, selectors, annotate,
cleanup_workdir, cleanup_sharedlibs, with_pyregr): cleanup_workdir, cleanup_sharedlibs, with_pyregr, cythononly):
self.rootdir = rootdir self.rootdir = rootdir
self.workdir = workdir self.workdir = workdir
self.selectors = selectors self.selectors = selectors
...@@ -54,6 +54,7 @@ class TestBuilder(object): ...@@ -54,6 +54,7 @@ class TestBuilder(object):
self.cleanup_workdir = cleanup_workdir self.cleanup_workdir = cleanup_workdir
self.cleanup_sharedlibs = cleanup_sharedlibs self.cleanup_sharedlibs = cleanup_sharedlibs
self.with_pyregr = with_pyregr self.with_pyregr = with_pyregr
self.cythononly = cythononly
def build_suite(self): def build_suite(self):
suite = unittest.TestSuite() suite = unittest.TestSuite()
...@@ -102,21 +103,23 @@ class TestBuilder(object): ...@@ -102,21 +103,23 @@ class TestBuilder(object):
path, workdir, module, path, workdir, module,
annotate=self.annotate, annotate=self.annotate,
cleanup_workdir=self.cleanup_workdir, cleanup_workdir=self.cleanup_workdir,
cleanup_sharedlibs=self.cleanup_sharedlibs) cleanup_sharedlibs=self.cleanup_sharedlibs,
cythononly=self.cythononly)
else: else:
test = CythonCompileTestCase( test = CythonCompileTestCase(
path, workdir, module, path, workdir, module,
expect_errors=expect_errors, expect_errors=expect_errors,
annotate=self.annotate, annotate=self.annotate,
cleanup_workdir=self.cleanup_workdir, cleanup_workdir=self.cleanup_workdir,
cleanup_sharedlibs=self.cleanup_sharedlibs) cleanup_sharedlibs=self.cleanup_sharedlibs,
cythononly=self.cythononly)
suite.addTest(test) suite.addTest(test)
return suite return suite
class CythonCompileTestCase(unittest.TestCase): class CythonCompileTestCase(unittest.TestCase):
def __init__(self, directory, workdir, module, def __init__(self, directory, workdir, module,
expect_errors=False, annotate=False, cleanup_workdir=True, expect_errors=False, annotate=False, cleanup_workdir=True,
cleanup_sharedlibs=True): cleanup_sharedlibs=True, cythononly=False):
self.directory = directory self.directory = directory
self.workdir = workdir self.workdir = workdir
self.module = module self.module = module
...@@ -124,6 +127,7 @@ class CythonCompileTestCase(unittest.TestCase): ...@@ -124,6 +127,7 @@ class CythonCompileTestCase(unittest.TestCase):
self.annotate = annotate self.annotate = annotate
self.cleanup_workdir = cleanup_workdir self.cleanup_workdir = cleanup_workdir
self.cleanup_sharedlibs = cleanup_sharedlibs self.cleanup_sharedlibs = cleanup_sharedlibs
self.cythononly = cythononly
unittest.TestCase.__init__(self) unittest.TestCase.__init__(self)
def shortDescription(self): def shortDescription(self):
...@@ -247,7 +251,8 @@ class CythonCompileTestCase(unittest.TestCase): ...@@ -247,7 +251,8 @@ class CythonCompileTestCase(unittest.TestCase):
unexpected_error = errors[len(expected_errors)] unexpected_error = errors[len(expected_errors)]
self.assertEquals(None, unexpected_error) self.assertEquals(None, unexpected_error)
else: else:
self.run_distutils(module, workdir, incdir) if not self.cythononly:
self.run_distutils(module, workdir, incdir)
class CythonRunTestCase(CythonCompileTestCase): class CythonRunTestCase(CythonCompileTestCase):
def shortDescription(self): def shortDescription(self):
...@@ -259,8 +264,9 @@ class CythonRunTestCase(CythonCompileTestCase): ...@@ -259,8 +264,9 @@ class CythonRunTestCase(CythonCompileTestCase):
result.startTest(self) result.startTest(self)
try: try:
self.runCompileTest() self.runCompileTest()
sys.stderr.write('running doctests in %s ...\n' % self.module) if not self.cythononly:
doctest.DocTestSuite(self.module).run(result) sys.stderr.write('running doctests in %s ...\n' % self.module)
doctest.DocTestSuite(self.module).run(result)
except Exception: except Exception:
result.addError(self, sys.exc_info()) result.addError(self, sys.exc_info())
result.stopTest(self) result.stopTest(self)
...@@ -372,7 +378,10 @@ if __name__ == '__main__': ...@@ -372,7 +378,10 @@ if __name__ == '__main__':
help="do not run the file based tests") help="do not run the file based tests")
parser.add_option("--no-pyregr", dest="pyregr", parser.add_option("--no-pyregr", dest="pyregr",
action="store_false", default=True, action="store_false", default=True,
help="do not run the regression tests of CPython in tests/pyregr/") help="do not run the regression tests of CPython in tests/pyregr/")
parser.add_option("--cython-only", dest="cythononly",
action="store_true", default=False,
help="only compile pyx to c, do not run C compiler or run the tests")
parser.add_option("--sys-pyregr", dest="system_pyregr", parser.add_option("--sys-pyregr", dest="system_pyregr",
action="store_true", default=False, action="store_true", default=False,
help="run the regression tests of the CPython installation") help="run the regression tests of the CPython installation")
...@@ -445,7 +454,7 @@ if __name__ == '__main__': ...@@ -445,7 +454,7 @@ if __name__ == '__main__':
if options.filetests: if options.filetests:
filetests = TestBuilder(ROOTDIR, WORKDIR, selectors, filetests = TestBuilder(ROOTDIR, WORKDIR, selectors,
options.annotate_source, options.cleanup_workdir, options.annotate_source, options.cleanup_workdir,
options.cleanup_sharedlibs, options.pyregr) options.cleanup_sharedlibs, options.pyregr, options.cythononly)
test_suite.addTest(filetests.build_suite()) test_suite.addTest(filetests.build_suite())
if options.system_pyregr: if options.system_pyregr:
......
cdef extern from *: cdef extern from *:
cdef void foo(int[]) cdef void foo(int[])
...@@ -17,3 +18,8 @@ cdef struct OtherStruct: ...@@ -17,3 +18,8 @@ cdef struct OtherStruct:
a = sizeof(int[23][34]) a = sizeof(int[23][34])
b = sizeof(OtherStruct[43]) b = sizeof(OtherStruct[43])
DEF COUNT = 4
c = sizeof(int[COUNT])
d = sizeof(OtherStruct[COUNT])
#cython: boundscheck=False
print 3
cimport python_dict as asadf, python_exc, cython as cy
@cy.boundscheck(False)
def f(object[int, 2] buf):
print buf[3, 2]
@cy.boundscheck(True)
def g(object[int, 2] buf):
# Please leave this comment,
#cython: this should have no special meaning
# even if the above line doesn't follow indentation.
print buf[3, 2]
def h(object[int, 2] buf):
print buf[3, 2]
with cy.boundscheck(True):
print buf[3,2]
from cython cimport boundscheck as bc
def i(object[int] buf):
with bc(True):
print buf[3]
...@@ -15,11 +15,13 @@ _ERRORS = u""" ...@@ -15,11 +15,13 @@ _ERRORS = u"""
1:11: Buffer types only allowed as function local variables 1:11: Buffer types only allowed as function local variables
3:15: Buffer types only allowed as function local variables 3:15: Buffer types only allowed as function local variables
6:27: "fakeoption" is not a buffer option 6:27: "fakeoption" is not a buffer option
7:22: "ndim" must be non-negative
8:15: "dtype" missing
9:21: "ndim" must be an integer
10:15: Too many buffer options
11:24: Only allowed buffer modes are "full" or "strided" (as a compile-time string)
12:28: Only allowed buffer modes are "full" or "strided" (as a compile-time string)
""" """
#TODO:
#7:22: "ndim" must be non-negative
#8:15: "dtype" missing
#9:21: "ndim" must be an integer
#10:15: Too many buffer options
#11:24: Only allowed buffer modes are "full" or "strided" (as a compile-time string)
#12:28: Only allowed buffer modes are "full" or "strided" (as a compile-time string)
#"""
cimport e_bufaccess_pxd # was needed to provoke a bug involving ErrorType
def f():
cdef object[e_bufaccess_pxd.T] buf
_ERRORS = u"""
3:17: Syntax error in ctypedef statement
4:31: 'T' is not a type identifier
4:31: 'T' is not declared
"""
# See e_bufaccess2.pyx
ctypedef nothing T
#cython: nonexistant
#cython: some=9
# The one below should NOT raise an error
#cython: boundscheck=True
# However this one should
#cython: boundscheck=sadf
print 3
#cython: boundscheck=True
_ERRORS = u"""
2:0: Expected "=" in option "nonexistant"
3:0: Unknown option: "some"
10:0: Must pass a boolean value for option "boundscheck"
"""
cimport e_pxdimpl_imported
_ERRORS = """
6:4: function definition not allowed here
18:4: function definition not allowed here
23:8: function definition not allowed here
"""
cdef class A:
cdef int test(self)
# Should give error:
def somefunc(self):
pass
# While this should *not* be an error...:
def __getbuffer__(self, Py_buffer* info, int flags):
pass
# This neither:
def __releasebuffer__(self, Py_buffer* info):
pass
# Terminate with an error to be sure the compiler is
# not terminating prior to previous errors
def terminate(self):
pass
cdef extern from "foo.h":
cdef class pxdimpl.B [object MyB]:
def otherfunc(self):
pass
/* See bufaccess.pyx */
typedef short htypedef_short;
...@@ -7,14 +7,17 @@ ...@@ -7,14 +7,17 @@
# what we want to test is what is passed into the flags argument. # what we want to test is what is passed into the flags argument.
# #
from __future__ import unicode_literals
cimport stdlib cimport stdlib
cimport python_buffer cimport python_buffer
# Add all test_X function docstrings as unit tests cimport stdio
cimport cython
from python_ref cimport PyObject
__test__ = {} __test__ = {}
setup_string = """ setup_string = u"""
>>> A = IntMockBuffer("A", range(6)) >>> A = IntMockBuffer("A", range(6))
>>> B = IntMockBuffer("B", range(6)) >>> B = IntMockBuffer("B", range(6))
>>> C = IntMockBuffer("C", range(6), (2,3)) >>> C = IntMockBuffer("C", range(6), (2,3))
...@@ -34,6 +37,19 @@ def testcas(a): ...@@ -34,6 +37,19 @@ def testcas(a):
# Buffer acquire and release tests # Buffer acquire and release tests
# #
def nousage():
"""
The challenge here is just compilation.
"""
cdef object[int, 2] buf
def printbuf():
"""
Just compilation.
"""
cdef object[int, 2] buf
print buf
@testcase @testcase
def acquire_release(o1, o2): def acquire_release(o1, o2):
""" """
...@@ -65,6 +81,8 @@ def acquire_raise(o): ...@@ -65,6 +81,8 @@ def acquire_raise(o):
>>> A.printlog() >>> A.printlog()
acquired A acquired A
released A released A
<BLANKLINE>
""" """
cdef object[int] buf cdef object[int] buf
buf = o buf = o
...@@ -203,37 +221,37 @@ def as_argument(object[int] bufarg, int n): ...@@ -203,37 +221,37 @@ def as_argument(object[int] bufarg, int n):
""" """
>>> as_argument(A, 6) >>> as_argument(A, 6)
acquired A acquired A
0 1 2 3 4 5 0 1 2 3 4 5 END
released A released A
""" """
cdef int i cdef int i
for i in range(n): for i in range(n):
print bufarg[i], print bufarg[i],
print print 'END'
@testcase @testcase
def as_argument_defval(object[int] bufarg=IntMockBuffer('default', range(6)), int n=6): def as_argument_defval(object[int] bufarg=IntMockBuffer('default', range(6)), int n=6):
""" """
>>> as_argument_defval() >>> as_argument_defval()
acquired default acquired default
0 1 2 3 4 5 0 1 2 3 4 5 END
released default released default
>>> as_argument_defval(A, 6) >>> as_argument_defval(A, 6)
acquired A acquired A
0 1 2 3 4 5 0 1 2 3 4 5 END
released A released A
""" """
cdef int i cdef int i
for i in range(n): for i in range(n):
print bufarg[i], print bufarg[i],
print print 'END'
@testcase @testcase
def cdef_assignment(obj, n): def cdef_assignment(obj, n):
""" """
>>> cdef_assignment(A, 6) >>> cdef_assignment(A, 6)
acquired A acquired A
0 1 2 3 4 5 0 1 2 3 4 5 END
released A released A
""" """
...@@ -241,18 +259,23 @@ def cdef_assignment(obj, n): ...@@ -241,18 +259,23 @@ def cdef_assignment(obj, n):
cdef int i cdef int i
for i in range(n): for i in range(n):
print buf[i], print buf[i],
print print 'END'
@testcase @testcase
def forin_assignment(objs, int pick): def forin_assignment(objs, int pick):
""" """
>>> as_argument_defval() >>> forin_assignment([A, B, A, A], 2)
acquired default
0 1 2 3 4 5
released default
>>> as_argument_defval(A, 6)
acquired A acquired A
0 1 2 3 4 5 2
released A
acquired B
2
released B
acquired A
2
released A
acquired A
2
released A released A
""" """
cdef object[int] buf cdef object[int] buf
...@@ -341,7 +364,6 @@ def get_int_2d(object[int, 2] buf, int i, int j): ...@@ -341,7 +364,6 @@ def get_int_2d(object[int, 2] buf, int i, int j):
Traceback (most recent call last): Traceback (most recent call last):
... ...
IndexError: Out of bounds on buffer access (axis 1) IndexError: Out of bounds on buffer access (axis 1)
""" """
return buf[i, j] return buf[i, j]
...@@ -458,7 +480,7 @@ def readonly(obj): ...@@ -458,7 +480,7 @@ def readonly(obj):
acquired R acquired R
25 25
released R released R
>>> R.recieved_flags >>> [str(x) for x in R.recieved_flags] # Works in both py2 and py3
['FORMAT', 'INDIRECT', 'ND', 'STRIDES'] ['FORMAT', 'INDIRECT', 'ND', 'STRIDES']
""" """
cdef object[unsigned short int, 3] buf = obj cdef object[unsigned short int, 3] buf = obj
...@@ -471,7 +493,7 @@ def writable(obj): ...@@ -471,7 +493,7 @@ def writable(obj):
>>> writable(R) >>> writable(R)
acquired R acquired R
released R released R
>>> R.recieved_flags >>> [str(x) for x in R.recieved_flags] # Py2/3
['FORMAT', 'INDIRECT', 'ND', 'STRIDES', 'WRITABLE'] ['FORMAT', 'INDIRECT', 'ND', 'STRIDES', 'WRITABLE']
""" """
cdef object[unsigned short int, 3] buf = obj cdef object[unsigned short int, 3] buf = obj
...@@ -485,12 +507,81 @@ def strided(object[int, 1, 'strided'] buf): ...@@ -485,12 +507,81 @@ def strided(object[int, 1, 'strided'] buf):
acquired A acquired A
released A released A
2 2
>>> A.recieved_flags >>> [str(x) for x in A.recieved_flags] # Py2/3
['FORMAT', 'ND', 'STRIDES'] ['FORMAT', 'ND', 'STRIDES']
Check that the suboffsets were patched back prior to release.
>>> A.release_ok
True
""" """
return buf[2] return buf[2]
#
# Test compiler options for bounds checking. We create an array with a
# safe "boundary" (memory
# allocated outside of what it published) and then check whether we get back
# what we stored in the memory or an error.
@testcase
def safe_get(object[int] buf, int idx):
"""
>>> A = IntMockBuffer(None, range(10), shape=(3,), offset=5)
Validate our testing buffer...
>>> safe_get(A, 0)
5
>>> safe_get(A, 2)
7
>>> safe_get(A, -3)
5
Access outside it. This is already done above for bounds check
testing but we include it to tell the story right.
>>> safe_get(A, -4)
Traceback (most recent call last):
...
IndexError: Out of bounds on buffer access (axis 0)
>>> safe_get(A, 3)
Traceback (most recent call last):
...
IndexError: Out of bounds on buffer access (axis 0)
"""
return buf[idx]
@testcase
@cython.boundscheck(False)
@cython.boundscheck(True)
def unsafe_get(object[int] buf, int idx):
"""
Access outside of the area the buffer publishes.
>>> A = IntMockBuffer(None, range(10), shape=(3,), offset=5)
>>> unsafe_get(A, -4)
4
>>> unsafe_get(A, -5)
3
>>> unsafe_get(A, 3)
8
"""
return buf[idx]
@testcase
def mixed_get(object[int] buf, int unsafe_idx, int safe_idx):
"""
>>> A = IntMockBuffer(None, range(10), shape=(3,), offset=5)
>>> mixed_get(A, -4, 0)
(4, 5)
>>> mixed_get(A, 0, -4)
Traceback (most recent call last):
...
IndexError: Out of bounds on buffer access (axis 0)
"""
with cython.boundscheck(False):
one = buf[unsafe_idx]
with cython.boundscheck(True):
two = buf[safe_idx]
return (one, two)
# #
# Coercions # Coercions
# #
...@@ -519,21 +610,21 @@ def printbuf_int_2d(o, shape): ...@@ -519,21 +610,21 @@ def printbuf_int_2d(o, shape):
>>> printbuf_int_2d(IntMockBuffer("A", range(6), (2,3)), (2,3)) >>> printbuf_int_2d(IntMockBuffer("A", range(6), (2,3)), (2,3))
acquired A acquired A
0 1 2 0 1 2 END
3 4 5 3 4 5 END
released A released A
>>> printbuf_int_2d(IntMockBuffer("A", range(100), (3,3), strides=(20,5)), (3,3)) >>> printbuf_int_2d(IntMockBuffer("A", range(100), (3,3), strides=(20,5)), (3,3))
acquired A acquired A
0 5 10 0 5 10 END
20 25 30 20 25 30 END
40 45 50 40 45 50 END
released A released A
Indirect: Indirect:
>>> printbuf_int_2d(IntMockBuffer("A", [[1,2],[3,4]]), (2,2)) >>> printbuf_int_2d(IntMockBuffer("A", [[1,2],[3,4]]), (2,2))
acquired A acquired A
1 2 1 2 END
3 4 3 4 END
released A released A
""" """
# should make shape builtin # should make shape builtin
...@@ -543,14 +634,14 @@ def printbuf_int_2d(o, shape): ...@@ -543,14 +634,14 @@ def printbuf_int_2d(o, shape):
for i in range(shape[0]): for i in range(shape[0]):
for j in range(shape[1]): for j in range(shape[1]):
print buf[i, j], print buf[i, j],
print print 'END'
@testcase @testcase
def printbuf_float(o, shape): def printbuf_float(o, shape):
""" """
>>> printbuf_float(FloatMockBuffer("F", [1.0, 1.25, 0.75, 1.0]), (4,)) >>> printbuf_float(FloatMockBuffer("F", [1.0, 1.25, 0.75, 1.0]), (4,))
acquired F acquired F
1.0 1.25 0.75 1.0 1.0 1.25 0.75 1.0 END
released F released F
""" """
...@@ -560,7 +651,131 @@ def printbuf_float(o, shape): ...@@ -560,7 +651,131 @@ def printbuf_float(o, shape):
cdef int i, j cdef int i, j
for i in range(shape[0]): for i in range(shape[0]):
print buf[i], print buf[i],
print print "END"
#
# Typedefs
#
ctypedef int cytypedef_int
cdef extern from "bufaccess.h":
ctypedef cytypedef_int htypedef_short # Defined as short, but Cython doesn't know this!
ctypedef htypedef_short cytypedef2
@testcase
def printbuf_cytypedef_int(object[cytypedef_int] buf, shape):
"""
>>> printbuf_cytypedef_int(IntMockBuffer("A", range(3)), (3,))
acquired A
0 1 2 END
released A
>>> printbuf_cytypedef_int(ShortMockBuffer("B", range(3)), (3,))
Traceback (most recent call last):
...
ValueError: Buffer datatype mismatch (rejecting on 'h')
"""
cdef int i
for i in range(shape[0]):
print buf[i],
print 'END'
@testcase
def printbuf_htypedef_short(object[htypedef_short] buf, shape):
"""
>>> printbuf_htypedef_short(ShortMockBuffer("A", range(3)), (3,))
acquired A
0 1 2 END
released A
>>> printbuf_htypedef_short(IntMockBuffer("B", range(3)), (3,))
Traceback (most recent call last):
...
ValueError: Buffer datatype mismatch (rejecting on 'i')
"""
cdef int i
for i in range(shape[0]):
print buf[i],
print 'END'
@testcase
def printbuf_cytypedef2(object[cytypedef2] buf, shape):
"""
>>> printbuf_cytypedef2(ShortMockBuffer("A", range(3)), (3,))
acquired A
0 1 2 END
released A
>>> printbuf_cytypedef2(IntMockBuffer("B", range(3)), (3,))
Traceback (most recent call last):
...
ValueError: Buffer datatype mismatch (rejecting on 'i')
"""
cdef int i
for i in range(shape[0]):
print buf[i],
print 'END'
#
# Object access
#
from python_ref cimport Py_INCREF, Py_DECREF
def addref(*args):
for item in args: Py_INCREF(item)
def decref(*args):
for item in args: Py_DECREF(item)
def get_refcount(x):
return (<PyObject*>x).ob_refcnt
@testcase
def printbuf_object(object[object] buf, shape):
"""
Only play with unique objects, interned numbers etc. will have
unpredictable refcounts.
ObjectMockBuffer doesn't do anything about increfing/decrefing,
we to the "buffer implementor" refcounting directly in the
testcase.
>>> a, b, c = "globally_unique_string_23234123", {4:23}, [34,3]
>>> get_refcount(a), get_refcount(b), get_refcount(c)
(2, 2, 2)
>>> A = ObjectMockBuffer(None, [a, b, c])
>>> printbuf_object(A, (3,))
'globally_unique_string_23234123' 2
{4: 23} 2
[34, 3] 2
"""
cdef int i
for i in range(shape[0]):
print repr(buf[i]), (<PyObject*>buf[i]).ob_refcnt
@testcase
def assign_to_object(object[object] buf, int idx, obj):
"""
See comments on printbuf_object above.
>>> a, b = [1, 2, 3], [4, 5, 6]
>>> get_refcount(a), get_refcount(b)
(2, 2)
>>> addref(a)
>>> A = ObjectMockBuffer(None, [1, a]) # 1, ...,otherwise it thinks nested lists...
>>> get_refcount(a), get_refcount(b)
(3, 2)
>>> assign_to_object(A, 1, b)
>>> get_refcount(a), get_refcount(b)
(2, 3)
>>> decref(b)
"""
buf[idx] = obj
#
# Testcase support code (more tests below!, because of scope rules)
#
available_flags = ( available_flags = (
...@@ -571,10 +786,8 @@ available_flags = ( ...@@ -571,10 +786,8 @@ available_flags = (
('WRITABLE', python_buffer.PyBUF_WRITABLE) ('WRITABLE', python_buffer.PyBUF_WRITABLE)
) )
cimport stdio
cdef class MockBuffer: cdef class MockBuffer:
cdef object format cdef object format, offset
cdef void* buffer cdef void* buffer
cdef int len, itemsize, ndim cdef int len, itemsize, ndim
cdef Py_ssize_t* strides cdef Py_ssize_t* strides
...@@ -582,12 +795,16 @@ cdef class MockBuffer: ...@@ -582,12 +795,16 @@ cdef class MockBuffer:
cdef Py_ssize_t* suboffsets cdef Py_ssize_t* suboffsets
cdef object label, log cdef object label, log
cdef readonly object recieved_flags cdef readonly object recieved_flags, release_ok
cdef public object fail cdef public object fail
def __init__(self, label, data, shape=None, strides=None, format=None): def __init__(self, label, data, shape=None, strides=None, format=None, offset=0):
# It is important not to store references to data after the constructor
# as refcounting is checked on object buffers.
self.label = label self.label = label
self.release_ok = True
self.log = "" self.log = ""
self.offset = offset
self.itemsize = self.get_itemsize() self.itemsize = self.get_itemsize()
if format is None: format = self.get_default_format() if format is None: format = self.get_default_format()
if shape is None: shape = (len(data),) if shape is None: shape = (len(data),)
...@@ -680,7 +897,7 @@ cdef class MockBuffer: ...@@ -680,7 +897,7 @@ cdef class MockBuffer:
if (value & flags) == value: if (value & flags) == value:
self.recieved_flags.append(name) self.recieved_flags.append(name)
buffer.buf = self.buffer buffer.buf = <void*>(<char*>self.buffer + (<int>self.offset * self.itemsize))
buffer.len = self.len buffer.len = self.len
buffer.readonly = 0 buffer.readonly = 0
buffer.format = <char*>self.format buffer.format = <char*>self.format
...@@ -690,17 +907,21 @@ cdef class MockBuffer: ...@@ -690,17 +907,21 @@ cdef class MockBuffer:
buffer.suboffsets = self.suboffsets buffer.suboffsets = self.suboffsets
buffer.itemsize = self.itemsize buffer.itemsize = self.itemsize
buffer.internal = NULL buffer.internal = NULL
msg = "acquired %s" % self.label if self.label:
print msg msg = "acquired %s" % self.label
self.log += msg + "\n" print msg
self.log += msg + "\n"
def __releasebuffer__(MockBuffer self, Py_buffer* buffer): def __releasebuffer__(MockBuffer self, Py_buffer* buffer):
msg = "released %s" % self.label if buffer.suboffsets != self.suboffsets:
print msg self.release_ok = False
self.log += msg + "\n" if self.label:
msg = "released %s" % self.label
print msg
self.log += msg + "\n"
def printlog(self): def printlog(self):
print self.log, print self.log
def resetlog(self): def resetlog(self):
self.log = "" self.log = ""
...@@ -716,21 +937,43 @@ cdef class FloatMockBuffer(MockBuffer): ...@@ -716,21 +937,43 @@ cdef class FloatMockBuffer(MockBuffer):
(<float*>buf)[0] = <float>value (<float*>buf)[0] = <float>value
return 0 return 0
cdef get_itemsize(self): return sizeof(float) cdef get_itemsize(self): return sizeof(float)
cdef get_default_format(self): return "=f" cdef get_default_format(self): return b"=f"
cdef class IntMockBuffer(MockBuffer): cdef class IntMockBuffer(MockBuffer):
cdef int write(self, char* buf, object value) except -1: cdef int write(self, char* buf, object value) except -1:
(<int*>buf)[0] = <int>value (<int*>buf)[0] = <int>value
return 0 return 0
cdef get_itemsize(self): return sizeof(int) cdef get_itemsize(self): return sizeof(int)
cdef get_default_format(self): return "=i" cdef get_default_format(self): return b"=i"
cdef class ShortMockBuffer(MockBuffer):
cdef int write(self, char* buf, object value) except -1:
(<short*>buf)[0] = <short>value
return 0
cdef get_itemsize(self): return sizeof(short)
cdef get_default_format(self): return b"h" # Try without endian specifier
cdef class UnsignedShortMockBuffer(MockBuffer): cdef class UnsignedShortMockBuffer(MockBuffer):
cdef int write(self, char* buf, object value) except -1: cdef int write(self, char* buf, object value) except -1:
(<unsigned short*>buf)[0] = <unsigned short>value (<unsigned short*>buf)[0] = <unsigned short>value
return 0 return 0
cdef get_itemsize(self): return sizeof(unsigned short) cdef get_itemsize(self): return sizeof(unsigned short)
cdef get_default_format(self): return "=H" cdef get_default_format(self): return b"=1H" # Try with repeat count
cdef extern from *:
void* addr_of_pyobject "(void*)"(object)
cdef class ObjectMockBuffer(MockBuffer):
cdef int write(self, char* buf, object value) except -1:
(<void**>buf)[0] = addr_of_pyobject(value)
return 0
cdef get_itemsize(self): return sizeof(void*)
cdef get_default_format(self): return b"=O"
cdef class IntStridedMockBuffer(IntMockBuffer):
cdef __cythonbufferdefaults__ = {"mode" : "strided"}
cdef class ErrorBuffer: cdef class ErrorBuffer:
cdef object label cdef object label
...@@ -744,3 +987,54 @@ cdef class ErrorBuffer: ...@@ -744,3 +987,54 @@ cdef class ErrorBuffer:
def __releasebuffer__(MockBuffer self, Py_buffer* buffer): def __releasebuffer__(MockBuffer self, Py_buffer* buffer):
raise Exception("releasing %s" % self.label) raise Exception("releasing %s" % self.label)
#
# Typed buffers
#
@testcase
def typedbuffer1(obj):
"""
>>> typedbuffer1(IntMockBuffer("A", range(10)))
acquired A
released A
>>> typedbuffer1(None)
>>> typedbuffer1(4)
Traceback (most recent call last):
...
TypeError: Cannot convert int to bufaccess.IntMockBuffer
"""
cdef IntMockBuffer[int, 1] buf = obj
@testcase
def typedbuffer2(IntMockBuffer[int, 1] obj):
"""
>>> typedbuffer2(IntMockBuffer("A", range(10)))
acquired A
released A
>>> typedbuffer2(None)
>>> typedbuffer2(4)
Traceback (most recent call last):
...
TypeError: Argument 'obj' has incorrect type (expected bufaccess.IntMockBuffer, got int)
"""
pass
#
# Test __cythonbufferdefaults__
#
@testcase
def bufdefaults1(IntStridedMockBuffer[int, 1] buf):
"""
For IntStridedMockBuffer, mode should be
"strided" by defaults which should show
up in the flags.
>>> A = IntStridedMockBuffer("A", range(10))
>>> bufdefaults1(A)
acquired A
released A
>>> [str(x) for x in A.recieved_flags]
['FORMAT', 'ND', 'STRIDES']
"""
pass
cdef class A:
cpdef foo(self, bint a=*, b=*)
__doc__ = """
>>> a = A()
>>> a.foo()
(True, 'yo')
>>> a.foo(False)
(False, 'yo')
>>> a.foo(10, 'yes')
(True, 'yes')
"""
cdef class A:
cpdef foo(self, bint a=True, b="yo"):
return a, b
...@@ -4,8 +4,8 @@ __doc__ = u""" ...@@ -4,8 +4,8 @@ __doc__ = u"""
>>> test_unicode_ascii(2) >>> test_unicode_ascii(2)
u'c' u'c'
>>> test_unicode(2) >>> test_unicode(2) == u'\u00e4'
u'\u00e4' True
>>> test_int_list(2) >>> test_int_list(2)
3 3
......
# cannot be named "numpy" in order to no clash with the numpy module!
cimport numpy
try:
import numpy
__doc__ = """
>>> basic()
[[0 1 2 3 4]
[5 6 7 8 9]]
2 0 9 5
"""
except:
__doc__ = ""
def basic():
cdef object[int, 2] buf = numpy.arange(10).reshape((2, 5))
print buf
print buf[0, 2], buf[0, 0], buf[1, 4], buf[1, 0]
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment