Commit 8912ea26 authored by Robert Bradshaw's avatar Robert Bradshaw

merge

parents 7d26e739 6fff2b59
...@@ -416,7 +416,7 @@ def put_buffer_lookup_code(entry, index_signeds, index_cnames, options, pos, cod ...@@ -416,7 +416,7 @@ def put_buffer_lookup_code(entry, index_signeds, index_cnames, options, pos, cod
params.append(s.cname) params.append(s.cname)
# Make sure the utility code is available # Make sure the utility code is available
code.globalstate.use_generated_code(funcgen, name=funcname, nd=nd) code.globalstate.use_code_from(funcgen, name=funcname, nd=nd)
ptr_type = entry.type.buffer_ptr_type ptr_type = entry.type.buffer_ptr_type
ptrcode = "%s(%s, %s.buf, %s)" % (funcname, ptrcode = "%s(%s, %s.buf, %s)" % (funcname,
...@@ -507,14 +507,14 @@ def mangle_dtype_name(dtype): ...@@ -507,14 +507,14 @@ def mangle_dtype_name(dtype):
def get_ts_check_item(dtype, writer): def get_ts_check_item(dtype, writer):
# See if we can consume one (unnamed) dtype as next item # See if we can consume one (unnamed) dtype as next item
# Put native types and structs in seperate namespaces (as one could create a struct named unsigned_int...) # Put native and custom types in seperate namespaces (as one could create a type named unsigned_int...)
name = "__Pyx_BufferTypestringCheck_item_%s" % mangle_dtype_name(dtype) name = "__Pyx_CheckTypestringItem_%s" % mangle_dtype_name(dtype)
if not writer.globalstate.has_utility_code(name): if not writer.globalstate.has_code(name):
char = dtype.typestring char = dtype.typestring
if char is not None: if char is not None:
assert len(char) == 1
# Can use direct comparison # Can use direct comparison
code = dedent("""\ code = dedent("""\
if (*ts == '1') ++ts;
if (*ts != '%s') { if (*ts != '%s') {
PyErr_Format(PyExc_ValueError, "Buffer datatype mismatch (expected '%s', got '%%s')", ts); PyErr_Format(PyExc_ValueError, "Buffer datatype mismatch (expected '%s', got '%%s')", ts);
return NULL; return NULL;
...@@ -526,7 +526,6 @@ def get_ts_check_item(dtype, writer): ...@@ -526,7 +526,6 @@ def get_ts_check_item(dtype, writer):
ctype = dtype.declaration_code("") ctype = dtype.declaration_code("")
code = dedent("""\ code = dedent("""\
int ok; int ok;
if (*ts == '1') ++ts;
switch (*ts) {""", 2) switch (*ts) {""", 2)
if dtype.is_int: if dtype.is_int:
types = [ types = [
...@@ -536,8 +535,7 @@ def get_ts_check_item(dtype, writer): ...@@ -536,8 +535,7 @@ def get_ts_check_item(dtype, writer):
elif dtype.is_float: elif dtype.is_float:
types = [('f', 'float'), ('d', 'double'), ('g', 'long double')] types = [('f', 'float'), ('d', 'double'), ('g', 'long double')]
else: else:
assert dtype.is_error assert False
return name
if dtype.signed == 0: if dtype.signed == 0:
code += "".join(["\n case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 > 0); break;" % code += "".join(["\n case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 > 0); break;" %
(char.upper(), ctype, against, ctype) for char, against in types]) (char.upper(), ctype, against, ctype) for char, against in types])
...@@ -564,6 +562,82 @@ def get_ts_check_item(dtype, writer): ...@@ -564,6 +562,82 @@ def get_ts_check_item(dtype, writer):
return name return name
def create_typestringchecker(protocode, defcode, name, dtype):
if dtype.is_error: return
simple = dtype.is_int or dtype.is_float or dtype.is_pyobject or dtype.is_extension_type or dtype.is_ptr
complex_possible = dtype.is_struct_or_union and dtype.can_be_complex()
# Cannot add utility code recursively...
if simple:
itemchecker = get_ts_check_item(dtype, protocode)
else:
dtype_t = dtype.declaration_code("")
protocode.globalstate.use_utility_code(parse_typestring_repeat_code)
fields = dtype.scope.var_entries
# divide fields into blocks of equal type (for repeat count)
field_blocks = [] # of (n, type, checkerfunc)
n = 0
prevtype = None
for f in fields:
if n and f.type != prevtype:
field_blocks.append((n, prevtype, get_ts_check_item(prevtype, protocode)))
n = 0
prevtype = f.type
n += 1
field_blocks.append((n, f.type, get_ts_check_item(f.type, protocode)))
protocode.putln("static const char* %s(const char* ts); /*proto*/" % name)
defcode.putln("static const char* %s(const char* ts) {" % name)
if simple:
defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
defcode.putln("if (*ts == '1') ++ts;")
defcode.putln("ts = %s(ts); if (!ts) return NULL;" % itemchecker)
elif complex_possible:
# Could be a struct representing a complex number, so allow
# for parsing a "Zf" spec.
real_t, imag_t = [x.type for x in fields]
defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
defcode.putln("if (*ts == '1') ++ts;")
defcode.putln("if (*ts == 'Z') {")
if len(field_blocks) == 2:
# Different float type, sizeof check needed
defcode.putln("if (sizeof(%s) != sizeof(%s)) {" % (
real_t.declaration_code(""),
imag_t.declaration_code("")))
defcode.putln('PyErr_SetString(PyExc_ValueError, "Cannot store complex number in \'%s\' as \'%s\' differs from \'%s\' in size.");' % (
dtype.declaration_code("", for_display=True),
real_t.declaration_code("", for_display=True),
imag_t.declaration_code("", for_display=True)))
defcode.putln("return NULL;")
defcode.putln("}")
check_real, check_imag = [x[2] for x in field_blocks]
else:
assert len(field_blocks) == 1
check_real = check_imag = field_blocks[0][2]
defcode.putln("ts = %s(ts + 1); if (!ts) return NULL;" % check_real)
defcode.putln("} else {")
defcode.putln("ts = %s(ts); if (!ts) return NULL;" % check_real)
defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
defcode.putln("ts = %s(ts); if (!ts) return NULL;" % check_imag)
defcode.putln("}")
else:
defcode.putln("int n, count;")
defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
for n, type, checker in field_blocks:
if n == 1:
defcode.putln("if (*ts == '1') ++ts;")
defcode.putln("ts = %s(ts); if (!ts) return NULL;" % checker)
else:
defcode.putln("n = %d;" % n);
defcode.putln("do {")
defcode.putln("ts = __Pyx_ParseTypestringRepeat(ts, &count); n -= count;")
defcode.putln("ts = %s(ts); if (!ts) return NULL;" % checker)
defcode.putln("} while (n > 0);");
defcode.putln("ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;")
defcode.putln("return ts;")
defcode.putln("}")
def get_getbuffer_code(dtype, code): def get_getbuffer_code(dtype, code):
""" """
Generate a utility function for getting a buffer for the given dtype. Generate a utility function for getting a buffer for the given dtype.
...@@ -575,9 +649,14 @@ def get_getbuffer_code(dtype, code): ...@@ -575,9 +649,14 @@ def get_getbuffer_code(dtype, code):
""" """
name = "__Pyx_GetBuffer_%s" % mangle_dtype_name(dtype) name = "__Pyx_GetBuffer_%s" % mangle_dtype_name(dtype)
if not code.globalstate.has_utility_code(name): if not code.globalstate.has_code(name):
code.globalstate.use_utility_code(acquire_utility_code) code.globalstate.use_utility_code(acquire_utility_code)
itemchecker = get_ts_check_item(dtype, code) typestringchecker = "__Pyx_CheckTypestring_%s" % mangle_dtype_name(dtype)
code.globalstate.use_code_from(create_typestringchecker,
typestringchecker,
dtype=dtype)
dtype_name = str(dtype)
dtype_cname = dtype.declaration_code("") dtype_cname = dtype.declaration_code("")
utilcode = [dedent(""" utilcode = [dedent("""
static int %s(PyObject* obj, Py_buffer* buf, int flags, int nd, int cast); /*proto*/ static int %s(PyObject* obj, Py_buffer* buf, int flags, int nd, int cast); /*proto*/
...@@ -598,13 +677,13 @@ def get_getbuffer_code(dtype, code): ...@@ -598,13 +677,13 @@ def get_getbuffer_code(dtype, code):
ts = buf->format; ts = buf->format;
ts = __Pyx_ConsumeWhitespace(ts); ts = __Pyx_ConsumeWhitespace(ts);
if (!ts) goto fail; if (!ts) goto fail;
ts = %(itemchecker)s(ts); ts = %(typestringchecker)s(ts);
if (!ts) goto fail; if (!ts) goto fail;
ts = __Pyx_ConsumeWhitespace(ts); ts = __Pyx_ConsumeWhitespace(ts);
if (!ts) goto fail; if (!ts) goto fail;
if (*ts != 0) { if (*ts != 0) {
PyErr_Format(PyExc_ValueError, PyErr_Format(PyExc_ValueError,
"Expected non-struct buffer data type (expected end, got '%%s')", ts); "Buffer format string specifies more data than '%(dtype_name)s' can hold (expected end, got '%%s')", ts);
goto fail; goto fail;
} }
} else { } else {
...@@ -781,6 +860,26 @@ static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim) { ...@@ -781,6 +860,26 @@ static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim) {
"""] """]
parse_typestring_repeat_code = ["""
static INLINE const char* __Pyx_ParseTypestringRepeat(const char* ts, int* out_count); /*proto*/
""","""
static INLINE const char* __Pyx_ParseTypestringRepeat(const char* ts, int* out_count) {
int count;
if (*ts < '0' || *ts > '9') {
count = 1;
} else {
count = *ts++ - '0';
while (*ts >= '0' && *ts < '9') {
count *= 10;
count += *ts++ - '0';
}
}
*out_count = count;
return ts;
}
"""]
raise_buffer_fallback_code = [""" raise_buffer_fallback_code = ["""
static void __Pyx_RaiseBufferFallbackError(void); /*proto*/ static void __Pyx_RaiseBufferFallbackError(void); /*proto*/
""",""" ""","""
......
...@@ -174,6 +174,7 @@ class GlobalState(object): ...@@ -174,6 +174,7 @@ class GlobalState(object):
self.used_utility_code = set() self.used_utility_code = set()
self.declared_cnames = {} self.declared_cnames = {}
self.pystring_table_needed = False self.pystring_table_needed = False
self.in_utility_code_generation = False
self.emit_linenums = emit_linenums self.emit_linenums = emit_linenums
def initwriters(self, rootwriter): def initwriters(self, rootwriter):
...@@ -189,13 +190,12 @@ class GlobalState(object): ...@@ -189,13 +190,12 @@ class GlobalState(object):
self.init_cached_builtins_writer.putln("static int __Pyx_InitCachedBuiltins(void) {") self.init_cached_builtins_writer.putln("static int __Pyx_InitCachedBuiltins(void) {")
self.initwriter.enter_cfunc_scope() self.initwriter.enter_cfunc_scope()
self.initwriter.putln("").putln("static int __Pyx_InitGlobals(void) {") self.initwriter.putln("")
self.initwriter.putln("static int __Pyx_InitGlobals(void) {")
(self.pystring_table self.pystring_table.putln("")
.putln("") self.pystring_table.putln("static __Pyx_StringTabEntry %s[] = {" %
.putln("static __Pyx_StringTabEntry %s[] = {" %
Naming.stringtab_cname) Naming.stringtab_cname)
)
# #
# Global constants, interned objects, etc. # Global constants, interned objects, etc.
...@@ -207,7 +207,8 @@ class GlobalState(object): ...@@ -207,7 +207,8 @@ class GlobalState(object):
# This is called when it is known that no more global declarations will # This is called when it is known that no more global declarations will
# declared (but can be called before or after insert_XXX). # declared (but can be called before or after insert_XXX).
if self.pystring_table_needed: if self.pystring_table_needed:
self.pystring_table.putln("{0, 0, 0, 0, 0, 0}").putln("};") self.pystring_table.putln("{0, 0, 0, 0, 0, 0}")
self.pystring_table.putln("};")
import Nodes import Nodes
self.use_utility_code(Nodes.init_string_tab_utility_code) self.use_utility_code(Nodes.init_string_tab_utility_code)
self.initwriter.putln( self.initwriter.putln(
...@@ -216,21 +217,19 @@ class GlobalState(object): ...@@ -216,21 +217,19 @@ class GlobalState(object):
self.initwriter.error_goto(self.module_pos))) self.initwriter.error_goto(self.module_pos)))
if Options.cache_builtins: if Options.cache_builtins:
(self.init_cached_builtins_writer w = self.init_cached_builtins_writer
.putln("return 0;") w.putln("return 0;")
.put_label(self.init_cached_builtins_writer.error_label) w.put_label(w.error_label)
.putln("return -1;") w.putln("return -1;")
.putln("}") w.putln("}")
.exit_cfunc_scope() w.exit_cfunc_scope()
)
w = self.initwriter
(self.initwriter w.putln("return 0;")
.putln("return 0;") w.put_label(w.error_label)
.put_label(self.initwriter.error_label) w.putln("return -1;")
.putln("return -1;") w.putln("}")
.putln("}") w.exit_cfunc_scope()
.exit_cfunc_scope()
)
def insert_initcode_into(self, code): def insert_initcode_into(self, code):
if self.pystring_table_needed: if self.pystring_table_needed:
...@@ -351,10 +350,10 @@ class GlobalState(object): ...@@ -351,10 +350,10 @@ class GlobalState(object):
self.utilprotowriter.put(proto) self.utilprotowriter.put(proto)
self.utildefwriter.put(_def) self.utildefwriter.put(_def)
def has_utility_code(self, name): def has_code(self, name):
return name in self.used_utility_code return name in self.used_utility_code
def use_generated_code(self, func, name, *args, **kw): def use_code_from(self, func, name, *args, **kw):
""" """
Requests that the utility code that func can generate is used in the C Requests that the utility code that func can generate is used in the C
file. func is called like this: file. func is called like this:
...@@ -525,7 +524,6 @@ class CCodeWriter(object): ...@@ -525,7 +524,6 @@ class CCodeWriter(object):
self.put(code) self.put(code)
self.write("\n"); self.write("\n");
self.bol = 1 self.bol = 1
return self
def emit_marker(self): def emit_marker(self):
self.write("\n"); self.write("\n");
...@@ -533,7 +531,6 @@ class CCodeWriter(object): ...@@ -533,7 +531,6 @@ class CCodeWriter(object):
self.write("/* %s */\n" % self.marker[1]) self.write("/* %s */\n" % self.marker[1])
self.last_marker_line = self.marker[0] self.last_marker_line = self.marker[0]
self.marker = None self.marker = None
return self
def put_safe(self, code): def put_safe(self, code):
# put code, but ignore {} # put code, but ignore {}
...@@ -556,25 +553,20 @@ class CCodeWriter(object): ...@@ -556,25 +553,20 @@ class CCodeWriter(object):
self.level += dl self.level += dl
elif fix_indent: elif fix_indent:
self.level += 1 self.level += 1
return self
def increase_indent(self): def increase_indent(self):
self.level = self.level + 1 self.level = self.level + 1
return self
def decrease_indent(self): def decrease_indent(self):
self.level = self.level - 1 self.level = self.level - 1
return self
def begin_block(self): def begin_block(self):
self.putln("{") self.putln("{")
self.increase_indent() self.increase_indent()
return self
def end_block(self): def end_block(self):
self.decrease_indent() self.decrease_indent()
self.putln("}") self.putln("}")
return self
def indent(self): def indent(self):
self.write(" " * self.level) self.write(" " * self.level)
...@@ -603,12 +595,10 @@ class CCodeWriter(object): ...@@ -603,12 +595,10 @@ class CCodeWriter(object):
def put_label(self, lbl): def put_label(self, lbl):
if lbl in self.funcstate.labels_used: if lbl in self.funcstate.labels_used:
self.putln("%s:;" % lbl) self.putln("%s:;" % lbl)
return self
def put_goto(self, lbl): def put_goto(self, lbl):
self.funcstate.use_label(lbl) self.funcstate.use_label(lbl)
self.putln("goto %s;" % lbl) self.putln("goto %s;" % lbl)
return self
def put_var_declarations(self, entries, static = 0, dll_linkage = None, def put_var_declarations(self, entries, static = 0, dll_linkage = None,
definition = True): definition = True):
......
...@@ -169,6 +169,7 @@ class ExprNode(Node): ...@@ -169,6 +169,7 @@ class ExprNode(Node):
saved_subexpr_nodes = None saved_subexpr_nodes = None
is_temp = 0 is_temp = 0
is_target = 0
def get_child_attrs(self): def get_child_attrs(self):
return self.subexprs return self.subexprs
...@@ -207,10 +208,10 @@ class ExprNode(Node): ...@@ -207,10 +208,10 @@ class ExprNode(Node):
return self.saved_subexpr_nodes return self.saved_subexpr_nodes
def result(self): def result(self):
if self.is_temp: if not self.is_temp or self.is_target:
return self.result_code
else:
return self.calculate_result_code() return self.calculate_result_code()
else: # i.e. self.is_temp:
return self.result_code
def result_as(self, type = None): def result_as(self, type = None):
# Return the result code cast to the specified C type. # Return the result code cast to the specified C type.
...@@ -341,7 +342,7 @@ class ExprNode(Node): ...@@ -341,7 +342,7 @@ class ExprNode(Node):
if debug_temp_alloc: if debug_temp_alloc:
print("%s Allocating target temps" % self) print("%s Allocating target temps" % self)
self.allocate_subexpr_temps(env) self.allocate_subexpr_temps(env)
self.result_code = self.target_code() self.is_target = True
if rhs: if rhs:
rhs.release_temp(env) rhs.release_temp(env)
self.release_subexpr_temps(env) self.release_subexpr_temps(env)
...@@ -436,9 +437,13 @@ class ExprNode(Node): ...@@ -436,9 +437,13 @@ class ExprNode(Node):
# its sub-expressions, and dispose of any # its sub-expressions, and dispose of any
# temporary results of its sub-expressions. # temporary results of its sub-expressions.
self.generate_subexpr_evaluation_code(code) self.generate_subexpr_evaluation_code(code)
self.pre_generate_result_code(code)
self.generate_result_code(code) self.generate_result_code(code)
if self.is_temp: if self.is_temp:
self.generate_subexpr_disposal_code(code) self.generate_subexpr_disposal_code(code)
def pre_generate_result_code(self, code):
pass
def generate_subexpr_evaluation_code(self, code): def generate_subexpr_evaluation_code(self, code):
for node in self.subexpr_nodes(): for node in self.subexpr_nodes():
...@@ -569,6 +574,66 @@ class ExprNode(Node): ...@@ -569,6 +574,66 @@ class ExprNode(Node):
return None return None
class NewTempExprNode(ExprNode):
backwards_compatible_result = None
def result(self):
if self.is_temp:
return self.temp_code
else:
return self.calculate_result_code()
def allocate_target_temps(self, env, rhs):
self.allocate_subexpr_temps(env)
rhs.release_temp(rhs)
self.release_subexpr_temps(env)
def allocate_temps(self, env, result = None):
self.allocate_subexpr_temps(env)
self.backwards_compatible_result = result
if self.is_temp:
self.release_subexpr_temps(env)
def allocate_temp(self, env, result = None):
assert result is None
def release_temp(self, env):
pass
def pre_generate_result_code(self, code):
if self.is_temp:
type = self.type
if not type.is_void:
if type.is_pyobject:
type = PyrexTypes.py_object_type
if self.backwards_compatible_result:
self.temp_code = self.backwards_compatible_result
else:
self.temp_code = code.funcstate.allocate_temp(type)
else:
self.temp_code = None
def generate_disposal_code(self, code):
if self.is_temp:
if self.type.is_pyobject:
code.put_decref_clear(self.result(), self.ctype())
if not self.backwards_compatible_result:
code.funcstate.release_temp(self.temp_code)
else:
self.generate_subexpr_disposal_code(code)
def generate_post_assignment_code(self, code):
if self.is_temp:
if self.type.is_pyobject:
code.putln("%s = 0;" % self.temp_code)
if not self.backwards_compatible_result:
code.funcstate.release_temp(self.temp_code)
else:
self.generate_subexpr_disposal_code(code)
class AtomicExprNode(ExprNode): class AtomicExprNode(ExprNode):
# Abstract base class for expression nodes which have # Abstract base class for expression nodes which have
# no sub-expressions. # no sub-expressions.
...@@ -1463,10 +1528,8 @@ class IndexNode(ExprNode): ...@@ -1463,10 +1528,8 @@ class IndexNode(ExprNode):
self.type = self.base.type.dtype self.type = self.base.type.dtype
self.is_buffer_access = True self.is_buffer_access = True
self.buffer_type = self.base.entry.type self.buffer_type = self.base.entry.type
if getting: if getting and self.type.is_pyobject:
# we only need a temp because result_code isn't refactored to
# generation time, but this seems an ok shortcut to take
self.is_temp = True self.is_temp = True
if setting: if setting:
if not self.base.entry.type.writable: if not self.base.entry.type.writable:
...@@ -1515,10 +1578,10 @@ class IndexNode(ExprNode): ...@@ -1515,10 +1578,10 @@ class IndexNode(ExprNode):
def is_lvalue(self): def is_lvalue(self):
return 1 return 1
def calculate_result_code(self): def calculate_result_code(self):
if self.is_buffer_access: if self.is_buffer_access:
return "<not used>" return "(*%s)" % self.buffer_ptr_code
else: else:
return "(%s[%s])" % ( return "(%s[%s])" % (
self.base.result(), self.index.result()) self.base.result(), self.index.result())
...@@ -1552,12 +1615,10 @@ class IndexNode(ExprNode): ...@@ -1552,12 +1615,10 @@ class IndexNode(ExprNode):
if self.is_buffer_access: if self.is_buffer_access:
if code.globalstate.directives['nonecheck']: if code.globalstate.directives['nonecheck']:
self.put_nonecheck(code) self.put_nonecheck(code)
ptrcode = self.buffer_lookup_code(code) self.buffer_ptr_code = self.buffer_lookup_code(code)
code.putln("%s = *%s;" % ( if self.type.is_pyobject:
self.result(), # is_temp is True, so must pull out value and incref it.
self.buffer_type.buffer_ptr_type.cast_code(ptrcode))) code.putln("%s = *%s;" % (self.result(), self.buffer_ptr_code))
# Must incref the value we pulled out.
if self.buffer_type.dtype.is_pyobject:
code.putln("Py_INCREF((PyObject*)%s);" % self.result()) code.putln("Py_INCREF((PyObject*)%s);" % self.result())
elif self.type.is_pyobject: elif self.type.is_pyobject:
if self.index.type.is_int: if self.index.type.is_int:
...@@ -3380,7 +3441,7 @@ def get_compile_time_binop(node): ...@@ -3380,7 +3441,7 @@ def get_compile_time_binop(node):
% node.operator) % node.operator)
return func return func
class BinopNode(ExprNode): class BinopNode(NewTempExprNode):
# operator string # operator string
# operand1 ExprNode # operand1 ExprNode
# operand2 ExprNode # operand2 ExprNode
...@@ -4377,7 +4438,7 @@ class CloneNode(CoercionNode): ...@@ -4377,7 +4438,7 @@ class CloneNode(CoercionNode):
if hasattr(arg, 'entry'): if hasattr(arg, 'entry'):
self.entry = arg.entry self.entry = arg.entry
def calculate_result_code(self): def result(self):
return self.arg.result() return self.arg.result()
def analyse_types(self, env): def analyse_types(self, env):
...@@ -4397,7 +4458,7 @@ class CloneNode(CoercionNode): ...@@ -4397,7 +4458,7 @@ class CloneNode(CoercionNode):
pass pass
def allocate_temps(self, env): def allocate_temps(self, env):
self.result_code = self.calculate_result_code() pass
def release_temp(self, env): def release_temp(self, env):
pass pass
......
...@@ -101,6 +101,7 @@ class PyrexType(BaseType): ...@@ -101,6 +101,7 @@ class PyrexType(BaseType):
default_value = "" default_value = ""
parsetuple_format = "" parsetuple_format = ""
pymemberdef_typecode = None pymemberdef_typecode = None
typestring = None
def resolve(self): def resolve(self):
# If a typedef, returns the base type. # If a typedef, returns the base type.
...@@ -140,7 +141,6 @@ class PyrexType(BaseType): ...@@ -140,7 +141,6 @@ class PyrexType(BaseType):
# a struct whose attributes are not defined, etc. # a struct whose attributes are not defined, etc.
return 1 return 1
class CTypedefType(BaseType): class CTypedefType(BaseType):
# #
# Pseudo-type defined with a ctypedef statement in a # Pseudo-type defined with a ctypedef statement in a
...@@ -965,6 +965,11 @@ class CStructOrUnionType(CType): ...@@ -965,6 +965,11 @@ class CStructOrUnionType(CType):
def attributes_known(self): def attributes_known(self):
return self.is_complete() return self.is_complete()
def can_be_complex(self):
# Does the struct consist of exactly two floats?
fields = self.scope.var_entries
return len(fields) == 2 and fields[0].type.is_float and fields[1].type.is_float
class CEnumType(CType): class CEnumType(CType):
# name string # name string
......
...@@ -69,20 +69,23 @@ cdef extern from "numpy/arrayobject.h": ...@@ -69,20 +69,23 @@ cdef extern from "numpy/arrayobject.h":
# made available from this pxd file yet. # made available from this pxd file yet.
cdef int t = PyArray_TYPE(self) cdef int t = PyArray_TYPE(self)
cdef char* f = NULL cdef char* f = NULL
if t == NPY_BYTE: f = "b" if t == NPY_BYTE: f = "b"
elif t == NPY_UBYTE: f = "B" elif t == NPY_UBYTE: f = "B"
elif t == NPY_SHORT: f = "h" elif t == NPY_SHORT: f = "h"
elif t == NPY_USHORT: f = "H" elif t == NPY_USHORT: f = "H"
elif t == NPY_INT: f = "i" elif t == NPY_INT: f = "i"
elif t == NPY_UINT: f = "I" elif t == NPY_UINT: f = "I"
elif t == NPY_LONG: f = "l" elif t == NPY_LONG: f = "l"
elif t == NPY_ULONG: f = "L" elif t == NPY_ULONG: f = "L"
elif t == NPY_LONGLONG: f = "q" elif t == NPY_LONGLONG: f = "q"
elif t == NPY_ULONGLONG: f = "Q" elif t == NPY_ULONGLONG: f = "Q"
elif t == NPY_FLOAT: f = "f" elif t == NPY_FLOAT: f = "f"
elif t == NPY_DOUBLE: f = "d" elif t == NPY_DOUBLE: f = "d"
elif t == NPY_LONGDOUBLE: f = "g" elif t == NPY_LONGDOUBLE: f = "g"
elif t == NPY_OBJECT: f = "O" elif t == NPY_CFLOAT: f = "Zf"
elif t == NPY_CDOUBLE: f = "Zd"
elif t == NPY_CLONGDOUBLE: f = "Zg"
elif t == NPY_OBJECT: f = "O"
if f == NULL: if f == NULL:
raise ValueError("only objects, int and float dtypes supported for ndarray buffer access so far (dtype is %d)" % t) raise ValueError("only objects, int and float dtypes supported for ndarray buffer access so far (dtype is %d)" % t)
......
...@@ -17,16 +17,15 @@ cimport cython ...@@ -17,16 +17,15 @@ cimport cython
from python_ref cimport PyObject from python_ref cimport PyObject
__test__ = {} __test__ = {}
setup_string = u"""
>>> A = IntMockBuffer("A", range(6))
>>> B = IntMockBuffer("B", range(6))
>>> C = IntMockBuffer("C", range(6), (2,3))
>>> E = ErrorBuffer("E")
""" import re
exclude = []#re.compile('object').search]
def testcase(func): def testcase(func):
__test__[func.__name__] = setup_string + func.__doc__ for e in exclude:
if e(func.__name__):
return func
__test__[func.__name__] = func.__doc__
return func return func
def testcas(a): def testcas(a):
...@@ -53,6 +52,8 @@ def printbuf(): ...@@ -53,6 +52,8 @@ def printbuf():
@testcase @testcase
def acquire_release(o1, o2): def acquire_release(o1, o2):
""" """
>>> A = IntMockBuffer("A", range(6))
>>> B = IntMockBuffer("B", range(6))
>>> acquire_release(A, B) >>> acquire_release(A, B)
acquired A acquired A
released A released A
...@@ -73,6 +74,7 @@ def acquire_raise(o): ...@@ -73,6 +74,7 @@ def acquire_raise(o):
Apparently, doctest won't handle mixed exceptions and print Apparently, doctest won't handle mixed exceptions and print
stats, so need to circumvent this. stats, so need to circumvent this.
>>> A = IntMockBuffer("A", range(6))
>>> A.resetlog() >>> A.resetlog()
>>> acquire_raise(A) >>> acquire_raise(A)
Traceback (most recent call last): Traceback (most recent call last):
...@@ -218,6 +220,7 @@ def acquire_nonbuffer2(): ...@@ -218,6 +220,7 @@ def acquire_nonbuffer2():
@testcase @testcase
def as_argument(object[int] bufarg, int n): def as_argument(object[int] bufarg, int n):
""" """
>>> A = IntMockBuffer("A", range(6))
>>> as_argument(A, 6) >>> as_argument(A, 6)
acquired A acquired A
0 1 2 3 4 5 END 0 1 2 3 4 5 END
...@@ -235,6 +238,7 @@ def as_argument_defval(object[int] bufarg=IntMockBuffer('default', range(6)), in ...@@ -235,6 +238,7 @@ def as_argument_defval(object[int] bufarg=IntMockBuffer('default', range(6)), in
acquired default acquired default
0 1 2 3 4 5 END 0 1 2 3 4 5 END
released default released default
>>> A = IntMockBuffer("A", range(6))
>>> as_argument_defval(A, 6) >>> as_argument_defval(A, 6)
acquired A acquired A
0 1 2 3 4 5 END 0 1 2 3 4 5 END
...@@ -248,6 +252,7 @@ def as_argument_defval(object[int] bufarg=IntMockBuffer('default', range(6)), in ...@@ -248,6 +252,7 @@ def as_argument_defval(object[int] bufarg=IntMockBuffer('default', range(6)), in
@testcase @testcase
def cdef_assignment(obj, n): def cdef_assignment(obj, n):
""" """
>>> A = IntMockBuffer("A", range(6))
>>> cdef_assignment(A, 6) >>> cdef_assignment(A, 6)
acquired A acquired A
0 1 2 3 4 5 END 0 1 2 3 4 5 END
...@@ -263,6 +268,8 @@ def cdef_assignment(obj, n): ...@@ -263,6 +268,8 @@ def cdef_assignment(obj, n):
@testcase @testcase
def forin_assignment(objs, int pick): def forin_assignment(objs, int pick):
""" """
>>> A = IntMockBuffer("A", range(6))
>>> B = IntMockBuffer("B", range(6))
>>> forin_assignment([A, B, A, A], 2) >>> forin_assignment([A, B, A, A], 2)
acquired A acquired A
2 2
...@@ -284,6 +291,7 @@ def forin_assignment(objs, int pick): ...@@ -284,6 +291,7 @@ def forin_assignment(objs, int pick):
@testcase @testcase
def cascaded_buffer_assignment(obj): def cascaded_buffer_assignment(obj):
""" """
>>> A = IntMockBuffer("A", range(6))
>>> cascaded_buffer_assignment(A) >>> cascaded_buffer_assignment(A)
acquired A acquired A
acquired A acquired A
...@@ -296,6 +304,8 @@ def cascaded_buffer_assignment(obj): ...@@ -296,6 +304,8 @@ def cascaded_buffer_assignment(obj):
@testcase @testcase
def tuple_buffer_assignment1(a, b): def tuple_buffer_assignment1(a, b):
""" """
>>> A = IntMockBuffer("A", range(6))
>>> B = IntMockBuffer("B", range(6))
>>> tuple_buffer_assignment1(A, B) >>> tuple_buffer_assignment1(A, B)
acquired A acquired A
acquired B acquired B
...@@ -308,6 +318,8 @@ def tuple_buffer_assignment1(a, b): ...@@ -308,6 +318,8 @@ def tuple_buffer_assignment1(a, b):
@testcase @testcase
def tuple_buffer_assignment2(tup): def tuple_buffer_assignment2(tup):
""" """
>>> A = IntMockBuffer("A", range(6))
>>> B = IntMockBuffer("B", range(6))
>>> tuple_buffer_assignment2((A, B)) >>> tuple_buffer_assignment2((A, B))
acquired A acquired A
acquired B acquired B
...@@ -358,12 +370,27 @@ def alignment_string(object[int] buf): ...@@ -358,12 +370,27 @@ def alignment_string(object[int] buf):
""" """
print buf[1] print buf[1]
@testcase
def wrong_string(object[int] buf):
"""
>>> wrong_string(IntMockBuffer(None, [1,2], format="iasdf"))
Traceback (most recent call last):
...
ValueError: Buffer format string specifies more data than 'int' can hold (expected end, got 'asdf')
>>> wrong_string(IntMockBuffer(None, [1,2], format="$$"))
Traceback (most recent call last):
...
ValueError: Buffer datatype mismatch (expected 'i', got '$$')
"""
print buf[1]
# #
# Getting items and index bounds checking # Getting items and index bounds checking
# #
@testcase @testcase
def get_int_2d(object[int, ndim=2] buf, int i, int j): def get_int_2d(object[int, ndim=2] buf, int i, int j):
""" """
>>> C = IntMockBuffer("C", range(6), (2,3))
>>> get_int_2d(C, 1, 1) >>> get_int_2d(C, 1, 1)
acquired C acquired C
released C released C
...@@ -399,6 +426,7 @@ def get_int_2d(object[int, ndim=2] buf, int i, int j): ...@@ -399,6 +426,7 @@ def get_int_2d(object[int, ndim=2] buf, int i, int j):
def get_int_2d_uintindex(object[int, ndim=2] buf, unsigned int i, unsigned int j): def get_int_2d_uintindex(object[int, ndim=2] buf, unsigned int i, unsigned int j):
""" """
Unsigned indexing: Unsigned indexing:
>>> C = IntMockBuffer("C", range(6), (2,3))
>>> get_int_2d_uintindex(C, 0, 0) >>> get_int_2d_uintindex(C, 0, 0)
acquired C acquired C
released C released C
...@@ -418,6 +446,7 @@ def set_int_2d(object[int, ndim=2] buf, int i, int j, int value): ...@@ -418,6 +446,7 @@ def set_int_2d(object[int, ndim=2] buf, int i, int j, int value):
Uses get_int_2d to read back the value afterwards. For pure Uses get_int_2d to read back the value afterwards. For pure
unit test, one should support reading in MockBuffer instead. unit test, one should support reading in MockBuffer instead.
>>> C = IntMockBuffer("C", range(6), (2,3))
>>> set_int_2d(C, 1, 1, 10) >>> set_int_2d(C, 1, 1, 10)
acquired C acquired C
released C released C
...@@ -1175,7 +1204,6 @@ cdef class DoubleMockBuffer(MockBuffer): ...@@ -1175,7 +1204,6 @@ cdef class DoubleMockBuffer(MockBuffer):
cdef get_itemsize(self): return sizeof(double) cdef get_itemsize(self): return sizeof(double)
cdef get_default_format(self): return b"d" cdef get_default_format(self): return b"d"
cdef extern from *: cdef extern from *:
void* addr_of_pyobject "(void*)"(object) void* addr_of_pyobject "(void*)"(object)
...@@ -1254,3 +1282,86 @@ def bufdefaults1(IntStridedMockBuffer[int, ndim=1] buf): ...@@ -1254,3 +1282,86 @@ def bufdefaults1(IntStridedMockBuffer[int, ndim=1] buf):
pass pass
#
# Structs
#
cdef struct MyStruct:
char a
char b
long long int c
int d
int e
cdef class MyStructMockBuffer(MockBuffer):
cdef int write(self, char* buf, object value) except -1:
cdef MyStruct* s
s = <MyStruct*>buf;
s.a, s.b, s.c, s.d, s.e = value
return 0
cdef get_itemsize(self): return sizeof(MyStruct)
cdef get_default_format(self): return b"2bq2i"
@testcase
def basic_struct(object[MyStruct] buf):
"""
>>> basic_struct(MyStructMockBuffer(None, [(1, 2, 3, 4, 5)]))
1 2 3 4 5
>>> basic_struct(MyStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="bbqii"))
1 2 3 4 5
>>> basic_struct(MyStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="i"))
Traceback (most recent call last):
...
ValueError: Buffer datatype mismatch (expected 'b', got 'i')
"""
print buf[0].a, buf[0].b, buf[0].c, buf[0].d, buf[0].e
cdef struct LongComplex:
long double real
long double imag
cdef struct MixedComplex:
long double real
float imag
cdef class LongComplexMockBuffer(MockBuffer):
cdef int write(self, char* buf, object value) except -1:
cdef LongComplex* s
s = <LongComplex*>buf;
s.real, s.imag = value
return 0
cdef get_itemsize(self): return sizeof(LongComplex)
cdef get_default_format(self): return b"Zg"
@testcase
def complex_struct_dtype(object[LongComplex] buf):
"""
Note that the format string is "Zg" rather than "2g"...
>>> complex_struct_dtype(LongComplexMockBuffer(None, [(0, -1)]))
0.0 -1.0
"""
print buf[0].real, buf[0].imag
@testcase
def mixed_complex_struct_dtype(object[MixedComplex] buf):
"""
Triggering a specific execution path for this case.
>>> mixed_complex_struct_dtype(LongComplexMockBuffer(None, [(0, -1)]))
Traceback (most recent call last):
...
ValueError: Cannot store complex number in 'MixedComplex' as 'long double' differs from 'float' in size.
"""
print buf[0].real, buf[0].imag
@testcase
def complex_struct_inplace(object[LongComplex] buf):
"""
>>> complex_struct_inplace(LongComplexMockBuffer(None, [(0, -1)]))
1.0 1.0
"""
buf[0].real += 1
buf[0].imag += 2
print buf[0].real, buf[0].imag
...@@ -115,6 +115,9 @@ try: ...@@ -115,6 +115,9 @@ try:
>>> test_dtype('d', inc1_double) >>> test_dtype('d', inc1_double)
>>> test_dtype('g', inc1_longdouble) >>> test_dtype('g', inc1_longdouble)
>>> test_dtype('O', inc1_object) >>> test_dtype('O', inc1_object)
>>> test_dtype('F', inc1_cfloat) # numpy format codes differ from buffer ones here
>>> test_dtype('D', inc1_cdouble)
>>> test_dtype('G', inc1_clongdouble)
>>> test_dtype(np.int, inc1_int_t) >>> test_dtype(np.int, inc1_int_t)
>>> test_dtype(np.long, inc1_long_t) >>> test_dtype(np.long, inc1_long_t)
...@@ -127,11 +130,6 @@ try: ...@@ -127,11 +130,6 @@ try:
>>> test_dtype(np.float64, inc1_float64_t) >>> test_dtype(np.float64, inc1_float64_t)
Unsupported types: Unsupported types:
>>> test_dtype(np.complex, inc1_byte)
Traceback (most recent call last):
...
ValueError: only objects, int and float dtypes supported for ndarray buffer access so far (dtype is 15)
>>> a = np.zeros((10,), dtype=np.dtype('i4,i4')) >>> a = np.zeros((10,), dtype=np.dtype('i4,i4'))
>>> inc1_byte(a) >>> inc1_byte(a)
Traceback (most recent call last): Traceback (most recent call last):
...@@ -194,7 +192,19 @@ def test_f_contig(np.ndarray[int, ndim=2, mode='fortran'] arr): ...@@ -194,7 +192,19 @@ def test_f_contig(np.ndarray[int, ndim=2, mode='fortran'] arr):
for i in range(arr.shape[0]): for i in range(arr.shape[0]):
print " ".join([str(arr[i, j]) for j in range(arr.shape[1])]) print " ".join([str(arr[i, j]) for j in range(arr.shape[1])])
# Exhaustive dtype tests -- increments element [1] by 1 for all dtypes cdef struct cfloat:
float real
float imag
cdef struct cdouble:
double real
double imag
cdef struct clongdouble:
long double real
long double imag
# Exhaustive dtype tests -- increments element [1] by 1 (or 1+1j) for all dtypes
def inc1_byte(np.ndarray[char] arr): arr[1] += 1 def inc1_byte(np.ndarray[char] arr): arr[1] += 1
def inc1_ubyte(np.ndarray[unsigned char] arr): arr[1] += 1 def inc1_ubyte(np.ndarray[unsigned char] arr): arr[1] += 1
def inc1_short(np.ndarray[short] arr): arr[1] += 1 def inc1_short(np.ndarray[short] arr): arr[1] += 1
...@@ -210,6 +220,19 @@ def inc1_float(np.ndarray[float] arr): arr[1] += 1 ...@@ -210,6 +220,19 @@ def inc1_float(np.ndarray[float] arr): arr[1] += 1
def inc1_double(np.ndarray[double] arr): arr[1] += 1 def inc1_double(np.ndarray[double] arr): arr[1] += 1
def inc1_longdouble(np.ndarray[long double] arr): arr[1] += 1 def inc1_longdouble(np.ndarray[long double] arr): arr[1] += 1
def inc1_cfloat(np.ndarray[cfloat] arr):
arr[1].real += 1
arr[1].imag += 1
def inc1_cdouble(np.ndarray[cdouble] arr):
arr[1].real += 1
arr[1].imag += 1
def inc1_clongdouble(np.ndarray[clongdouble] arr):
cdef long double x
x = arr[1].real + 1
arr[1].real = x
arr[1].imag = arr[1].imag + 1
def inc1_object(np.ndarray[object] arr): def inc1_object(np.ndarray[object] arr):
o = arr[1] o = arr[1]
...@@ -229,10 +252,14 @@ def inc1_float64_t(np.ndarray[np.float64_t] arr): arr[1] += 1 ...@@ -229,10 +252,14 @@ def inc1_float64_t(np.ndarray[np.float64_t] arr): arr[1] += 1
def test_dtype(dtype, inc1): def test_dtype(dtype, inc1):
a = np.array([0, 10], dtype=dtype) if dtype in ('F', 'D', 'G'):
inc1(a) a = np.array([0, 10+10j], dtype=dtype)
if a[1] != 11: print "failed!" inc1(a)
if a[1] != (11 + 11j): print "failed!", a[1]
else:
a = np.array([0, 10], dtype=dtype)
inc1(a)
if a[1] != 11: print "failed!"
def test_good_cast(): def test_good_cast():
# Check that a signed int can round-trip through casted unsigned int access # Check that a signed int can round-trip through casted unsigned int access
...@@ -243,4 +270,3 @@ def test_good_cast(): ...@@ -243,4 +270,3 @@ def test_good_cast():
def test_bad_cast(): def test_bad_cast():
# This should raise an exception # This should raise an exception
cdef np.ndarray[long, cast=True] arr = np.array([1], dtype='b') cdef np.ndarray[long, cast=True] arr = np.array([1], dtype='b')
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment