Skip to content
Projects
Groups
Snippets
Help
Loading...
Help
Support
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
C
cython
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Milestones
Merge Requests
0
Merge Requests
0
Analytics
Analytics
Repository
Value Stream
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Commits
Issue Boards
Open sidebar
Gwenaël Samain
cython
Commits
8912ea26
Commit
8912ea26
authored
Oct 07, 2008
by
Robert Bradshaw
Browse files
Options
Browse Files
Download
Plain Diff
merge
parents
7d26e739
6fff2b59
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
392 additions
and
97 deletions
+392
-97
Cython/Compiler/Buffer.py
Cython/Compiler/Buffer.py
+111
-12
Cython/Compiler/Code.py
Cython/Compiler/Code.py
+22
-32
Cython/Compiler/ExprNodes.py
Cython/Compiler/ExprNodes.py
+80
-19
Cython/Compiler/PyrexTypes.py
Cython/Compiler/PyrexTypes.py
+6
-1
Cython/Includes/numpy.pxd
Cython/Includes/numpy.pxd
+17
-14
tests/run/bufaccess.pyx
tests/run/bufaccess.pyx
+119
-8
tests/run/numpy_test.pyx
tests/run/numpy_test.pyx
+37
-11
No files found.
Cython/Compiler/Buffer.py
View file @
8912ea26
...
@@ -416,7 +416,7 @@ def put_buffer_lookup_code(entry, index_signeds, index_cnames, options, pos, cod
...
@@ -416,7 +416,7 @@ def put_buffer_lookup_code(entry, index_signeds, index_cnames, options, pos, cod
params
.
append
(
s
.
cname
)
params
.
append
(
s
.
cname
)
# Make sure the utility code is available
# Make sure the utility code is available
code
.
globalstate
.
use_
generated_code
(
funcgen
,
name
=
funcname
,
nd
=
nd
)
code
.
globalstate
.
use_
code_from
(
funcgen
,
name
=
funcname
,
nd
=
nd
)
ptr_type
=
entry
.
type
.
buffer_ptr_type
ptr_type
=
entry
.
type
.
buffer_ptr_type
ptrcode
=
"%s(%s, %s.buf, %s)"
%
(
funcname
,
ptrcode
=
"%s(%s, %s.buf, %s)"
%
(
funcname
,
...
@@ -507,14 +507,14 @@ def mangle_dtype_name(dtype):
...
@@ -507,14 +507,14 @@ def mangle_dtype_name(dtype):
def
get_ts_check_item
(
dtype
,
writer
):
def
get_ts_check_item
(
dtype
,
writer
):
# See if we can consume one (unnamed) dtype as next item
# See if we can consume one (unnamed) dtype as next item
# Put native
types and structs in seperate namespaces (as one could create a struct
named unsigned_int...)
# Put native
and custom types in seperate namespaces (as one could create a type
named unsigned_int...)
name
=
"__Pyx_
BufferTypestringCheck_i
tem_%s"
%
mangle_dtype_name
(
dtype
)
name
=
"__Pyx_
CheckTypestringI
tem_%s"
%
mangle_dtype_name
(
dtype
)
if
not
writer
.
globalstate
.
has_
utility_
code
(
name
):
if
not
writer
.
globalstate
.
has_code
(
name
):
char
=
dtype
.
typestring
char
=
dtype
.
typestring
if
char
is
not
None
:
if
char
is
not
None
:
assert
len
(
char
)
==
1
# Can use direct comparison
# Can use direct comparison
code
=
dedent
(
"""
\
code
=
dedent
(
"""
\
if (*ts == '1') ++ts;
if (*ts != '%s') {
if (*ts != '%s') {
PyErr_Format(PyExc_ValueError, "Buffer datatype mismatch (expected '%s', got '%%s')", ts);
PyErr_Format(PyExc_ValueError, "Buffer datatype mismatch (expected '%s', got '%%s')", ts);
return NULL;
return NULL;
...
@@ -526,7 +526,6 @@ def get_ts_check_item(dtype, writer):
...
@@ -526,7 +526,6 @@ def get_ts_check_item(dtype, writer):
ctype
=
dtype
.
declaration_code
(
""
)
ctype
=
dtype
.
declaration_code
(
""
)
code
=
dedent
(
"""
\
code
=
dedent
(
"""
\
int ok;
int ok;
if (*ts == '1') ++ts;
switch (*ts) {"""
,
2
)
switch (*ts) {"""
,
2
)
if
dtype
.
is_int
:
if
dtype
.
is_int
:
types
=
[
types
=
[
...
@@ -536,8 +535,7 @@ def get_ts_check_item(dtype, writer):
...
@@ -536,8 +535,7 @@ def get_ts_check_item(dtype, writer):
elif
dtype
.
is_float
:
elif
dtype
.
is_float
:
types
=
[(
'f'
,
'float'
),
(
'd'
,
'double'
),
(
'g'
,
'long double'
)]
types
=
[(
'f'
,
'float'
),
(
'd'
,
'double'
),
(
'g'
,
'long double'
)]
else
:
else
:
assert
dtype
.
is_error
assert
False
return
name
if
dtype
.
signed
==
0
:
if
dtype
.
signed
==
0
:
code
+=
""
.
join
([
"
\
n
case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 > 0); break;"
%
code
+=
""
.
join
([
"
\
n
case '%s': ok = (sizeof(%s) == sizeof(%s) && (%s)-1 > 0); break;"
%
(
char
.
upper
(),
ctype
,
against
,
ctype
)
for
char
,
against
in
types
])
(
char
.
upper
(),
ctype
,
against
,
ctype
)
for
char
,
against
in
types
])
...
@@ -564,6 +562,82 @@ def get_ts_check_item(dtype, writer):
...
@@ -564,6 +562,82 @@ def get_ts_check_item(dtype, writer):
return
name
return
name
def
create_typestringchecker
(
protocode
,
defcode
,
name
,
dtype
):
if
dtype
.
is_error
:
return
simple
=
dtype
.
is_int
or
dtype
.
is_float
or
dtype
.
is_pyobject
or
dtype
.
is_extension_type
or
dtype
.
is_ptr
complex_possible
=
dtype
.
is_struct_or_union
and
dtype
.
can_be_complex
()
# Cannot add utility code recursively...
if
simple
:
itemchecker
=
get_ts_check_item
(
dtype
,
protocode
)
else
:
dtype_t
=
dtype
.
declaration_code
(
""
)
protocode
.
globalstate
.
use_utility_code
(
parse_typestring_repeat_code
)
fields
=
dtype
.
scope
.
var_entries
# divide fields into blocks of equal type (for repeat count)
field_blocks
=
[]
# of (n, type, checkerfunc)
n
=
0
prevtype
=
None
for
f
in
fields
:
if
n
and
f
.
type
!=
prevtype
:
field_blocks
.
append
((
n
,
prevtype
,
get_ts_check_item
(
prevtype
,
protocode
)))
n
=
0
prevtype
=
f
.
type
n
+=
1
field_blocks
.
append
((
n
,
f
.
type
,
get_ts_check_item
(
f
.
type
,
protocode
)))
protocode
.
putln
(
"static const char* %s(const char* ts); /*proto*/"
%
name
)
defcode
.
putln
(
"static const char* %s(const char* ts) {"
%
name
)
if
simple
:
defcode
.
putln
(
"ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;"
)
defcode
.
putln
(
"if (*ts == '1') ++ts;"
)
defcode
.
putln
(
"ts = %s(ts); if (!ts) return NULL;"
%
itemchecker
)
elif
complex_possible
:
# Could be a struct representing a complex number, so allow
# for parsing a "Zf" spec.
real_t
,
imag_t
=
[
x
.
type
for
x
in
fields
]
defcode
.
putln
(
"ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;"
)
defcode
.
putln
(
"if (*ts == '1') ++ts;"
)
defcode
.
putln
(
"if (*ts == 'Z') {"
)
if
len
(
field_blocks
)
==
2
:
# Different float type, sizeof check needed
defcode
.
putln
(
"if (sizeof(%s) != sizeof(%s)) {"
%
(
real_t
.
declaration_code
(
""
),
imag_t
.
declaration_code
(
""
)))
defcode
.
putln
(
'PyErr_SetString(PyExc_ValueError, "Cannot store complex number in
\
'
%s
\
'
as
\
'
%s
\
'
differs from
\
'
%s
\
'
in size.");'
%
(
dtype
.
declaration_code
(
""
,
for_display
=
True
),
real_t
.
declaration_code
(
""
,
for_display
=
True
),
imag_t
.
declaration_code
(
""
,
for_display
=
True
)))
defcode
.
putln
(
"return NULL;"
)
defcode
.
putln
(
"}"
)
check_real
,
check_imag
=
[
x
[
2
]
for
x
in
field_blocks
]
else
:
assert
len
(
field_blocks
)
==
1
check_real
=
check_imag
=
field_blocks
[
0
][
2
]
defcode
.
putln
(
"ts = %s(ts + 1); if (!ts) return NULL;"
%
check_real
)
defcode
.
putln
(
"} else {"
)
defcode
.
putln
(
"ts = %s(ts); if (!ts) return NULL;"
%
check_real
)
defcode
.
putln
(
"ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;"
)
defcode
.
putln
(
"ts = %s(ts); if (!ts) return NULL;"
%
check_imag
)
defcode
.
putln
(
"}"
)
else
:
defcode
.
putln
(
"int n, count;"
)
defcode
.
putln
(
"ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;"
)
for
n
,
type
,
checker
in
field_blocks
:
if
n
==
1
:
defcode
.
putln
(
"if (*ts == '1') ++ts;"
)
defcode
.
putln
(
"ts = %s(ts); if (!ts) return NULL;"
%
checker
)
else
:
defcode
.
putln
(
"n = %d;"
%
n
);
defcode
.
putln
(
"do {"
)
defcode
.
putln
(
"ts = __Pyx_ParseTypestringRepeat(ts, &count); n -= count;"
)
defcode
.
putln
(
"ts = %s(ts); if (!ts) return NULL;"
%
checker
)
defcode
.
putln
(
"} while (n > 0);"
);
defcode
.
putln
(
"ts = __Pyx_ConsumeWhitespace(ts); if (!ts) return NULL;"
)
defcode
.
putln
(
"return ts;"
)
defcode
.
putln
(
"}"
)
def
get_getbuffer_code
(
dtype
,
code
):
def
get_getbuffer_code
(
dtype
,
code
):
"""
"""
Generate a utility function for getting a buffer for the given dtype.
Generate a utility function for getting a buffer for the given dtype.
...
@@ -575,9 +649,14 @@ def get_getbuffer_code(dtype, code):
...
@@ -575,9 +649,14 @@ def get_getbuffer_code(dtype, code):
"""
"""
name
=
"__Pyx_GetBuffer_%s"
%
mangle_dtype_name
(
dtype
)
name
=
"__Pyx_GetBuffer_%s"
%
mangle_dtype_name
(
dtype
)
if
not
code
.
globalstate
.
has_
utility_
code
(
name
):
if
not
code
.
globalstate
.
has_code
(
name
):
code
.
globalstate
.
use_utility_code
(
acquire_utility_code
)
code
.
globalstate
.
use_utility_code
(
acquire_utility_code
)
itemchecker
=
get_ts_check_item
(
dtype
,
code
)
typestringchecker
=
"__Pyx_CheckTypestring_%s"
%
mangle_dtype_name
(
dtype
)
code
.
globalstate
.
use_code_from
(
create_typestringchecker
,
typestringchecker
,
dtype
=
dtype
)
dtype_name
=
str
(
dtype
)
dtype_cname
=
dtype
.
declaration_code
(
""
)
dtype_cname
=
dtype
.
declaration_code
(
""
)
utilcode
=
[
dedent
(
"""
utilcode
=
[
dedent
(
"""
static int %s(PyObject* obj, Py_buffer* buf, int flags, int nd, int cast); /*proto*/
static int %s(PyObject* obj, Py_buffer* buf, int flags, int nd, int cast); /*proto*/
...
@@ -598,13 +677,13 @@ def get_getbuffer_code(dtype, code):
...
@@ -598,13 +677,13 @@ def get_getbuffer_code(dtype, code):
ts = buf->format;
ts = buf->format;
ts = __Pyx_ConsumeWhitespace(ts);
ts = __Pyx_ConsumeWhitespace(ts);
if (!ts) goto fail;
if (!ts) goto fail;
ts = %(
item
checker)s(ts);
ts = %(
typestring
checker)s(ts);
if (!ts) goto fail;
if (!ts) goto fail;
ts = __Pyx_ConsumeWhitespace(ts);
ts = __Pyx_ConsumeWhitespace(ts);
if (!ts) goto fail;
if (!ts) goto fail;
if (*ts != 0) {
if (*ts != 0) {
PyErr_Format(PyExc_ValueError,
PyErr_Format(PyExc_ValueError,
"
Expected non-struct buffer data type
(expected end, got '%%s')", ts);
"
Buffer format string specifies more data than '%(dtype_name)s' can hold
(expected end, got '%%s')", ts);
goto fail;
goto fail;
}
}
} else {
} else {
...
@@ -781,6 +860,26 @@ static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim) {
...
@@ -781,6 +860,26 @@ static void __Pyx_BufferNdimError(Py_buffer* buffer, int expected_ndim) {
"""
]
"""
]
parse_typestring_repeat_code
=
[
"""
static INLINE const char* __Pyx_ParseTypestringRepeat(const char* ts, int* out_count); /*proto*/
"""
,
"""
static INLINE const char* __Pyx_ParseTypestringRepeat(const char* ts, int* out_count) {
int count;
if (*ts < '0' || *ts > '9') {
count = 1;
} else {
count = *ts++ - '0';
while (*ts >= '0' && *ts < '9') {
count *= 10;
count += *ts++ - '0';
}
}
*out_count = count;
return ts;
}
"""
]
raise_buffer_fallback_code
=
[
"""
raise_buffer_fallback_code
=
[
"""
static void __Pyx_RaiseBufferFallbackError(void); /*proto*/
static void __Pyx_RaiseBufferFallbackError(void); /*proto*/
"""
,
"""
"""
,
"""
...
...
Cython/Compiler/Code.py
View file @
8912ea26
...
@@ -174,6 +174,7 @@ class GlobalState(object):
...
@@ -174,6 +174,7 @@ class GlobalState(object):
self
.
used_utility_code
=
set
()
self
.
used_utility_code
=
set
()
self
.
declared_cnames
=
{}
self
.
declared_cnames
=
{}
self
.
pystring_table_needed
=
False
self
.
pystring_table_needed
=
False
self
.
in_utility_code_generation
=
False
self
.
emit_linenums
=
emit_linenums
self
.
emit_linenums
=
emit_linenums
def
initwriters
(
self
,
rootwriter
):
def
initwriters
(
self
,
rootwriter
):
...
@@ -189,13 +190,12 @@ class GlobalState(object):
...
@@ -189,13 +190,12 @@ class GlobalState(object):
self
.
init_cached_builtins_writer
.
putln
(
"static int __Pyx_InitCachedBuiltins(void) {"
)
self
.
init_cached_builtins_writer
.
putln
(
"static int __Pyx_InitCachedBuiltins(void) {"
)
self
.
initwriter
.
enter_cfunc_scope
()
self
.
initwriter
.
enter_cfunc_scope
()
self
.
initwriter
.
putln
(
""
).
putln
(
"static int __Pyx_InitGlobals(void) {"
)
self
.
initwriter
.
putln
(
""
)
self
.
initwriter
.
putln
(
"static int __Pyx_InitGlobals(void) {"
)
(
self
.
pystring_table
self
.
pystring_table
.
putln
(
""
)
.
putln
(
""
)
self
.
pystring_table
.
putln
(
"static __Pyx_StringTabEntry %s[] = {"
%
.
putln
(
"static __Pyx_StringTabEntry %s[] = {"
%
Naming
.
stringtab_cname
)
Naming
.
stringtab_cname
)
)
#
#
# Global constants, interned objects, etc.
# Global constants, interned objects, etc.
...
@@ -207,7 +207,8 @@ class GlobalState(object):
...
@@ -207,7 +207,8 @@ class GlobalState(object):
# This is called when it is known that no more global declarations will
# This is called when it is known that no more global declarations will
# declared (but can be called before or after insert_XXX).
# declared (but can be called before or after insert_XXX).
if
self
.
pystring_table_needed
:
if
self
.
pystring_table_needed
:
self
.
pystring_table
.
putln
(
"{0, 0, 0, 0, 0, 0}"
).
putln
(
"};"
)
self
.
pystring_table
.
putln
(
"{0, 0, 0, 0, 0, 0}"
)
self
.
pystring_table
.
putln
(
"};"
)
import
Nodes
import
Nodes
self
.
use_utility_code
(
Nodes
.
init_string_tab_utility_code
)
self
.
use_utility_code
(
Nodes
.
init_string_tab_utility_code
)
self
.
initwriter
.
putln
(
self
.
initwriter
.
putln
(
...
@@ -216,21 +217,19 @@ class GlobalState(object):
...
@@ -216,21 +217,19 @@ class GlobalState(object):
self
.
initwriter
.
error_goto
(
self
.
module_pos
)))
self
.
initwriter
.
error_goto
(
self
.
module_pos
)))
if
Options
.
cache_builtins
:
if
Options
.
cache_builtins
:
(
self
.
init_cached_builtins_writer
w
=
self
.
init_cached_builtins_writer
.
putln
(
"return 0;"
)
w
.
putln
(
"return 0;"
)
.
put_label
(
self
.
init_cached_builtins_writer
.
error_label
)
w
.
put_label
(
w
.
error_label
)
.
putln
(
"return -1;"
)
w
.
putln
(
"return -1;"
)
.
putln
(
"}"
)
w
.
putln
(
"}"
)
.
exit_cfunc_scope
()
w
.
exit_cfunc_scope
()
)
w
=
self
.
initwriter
(
self
.
initwriter
w
.
putln
(
"return 0;"
)
.
putln
(
"return 0;"
)
w
.
put_label
(
w
.
error_label
)
.
put_label
(
self
.
initwriter
.
error_label
)
w
.
putln
(
"return -1;"
)
.
putln
(
"return -1;"
)
w
.
putln
(
"}"
)
.
putln
(
"}"
)
w
.
exit_cfunc_scope
()
.
exit_cfunc_scope
()
)
def
insert_initcode_into
(
self
,
code
):
def
insert_initcode_into
(
self
,
code
):
if
self
.
pystring_table_needed
:
if
self
.
pystring_table_needed
:
...
@@ -351,10 +350,10 @@ class GlobalState(object):
...
@@ -351,10 +350,10 @@ class GlobalState(object):
self
.
utilprotowriter
.
put
(
proto
)
self
.
utilprotowriter
.
put
(
proto
)
self
.
utildefwriter
.
put
(
_def
)
self
.
utildefwriter
.
put
(
_def
)
def
has_
utility_
code
(
self
,
name
):
def
has_code
(
self
,
name
):
return
name
in
self
.
used_utility_code
return
name
in
self
.
used_utility_code
def
use_
generated_code
(
self
,
func
,
name
,
*
args
,
**
kw
):
def
use_
code_from
(
self
,
func
,
name
,
*
args
,
**
kw
):
"""
"""
Requests that the utility code that func can generate is used in the C
Requests that the utility code that func can generate is used in the C
file. func is called like this:
file. func is called like this:
...
@@ -525,7 +524,6 @@ class CCodeWriter(object):
...
@@ -525,7 +524,6 @@ class CCodeWriter(object):
self
.
put
(
code
)
self
.
put
(
code
)
self
.
write
(
"
\
n
"
);
self
.
write
(
"
\
n
"
);
self
.
bol
=
1
self
.
bol
=
1
return
self
def
emit_marker
(
self
):
def
emit_marker
(
self
):
self
.
write
(
"
\
n
"
);
self
.
write
(
"
\
n
"
);
...
@@ -533,7 +531,6 @@ class CCodeWriter(object):
...
@@ -533,7 +531,6 @@ class CCodeWriter(object):
self
.
write
(
"/* %s */
\
n
"
%
self
.
marker
[
1
])
self
.
write
(
"/* %s */
\
n
"
%
self
.
marker
[
1
])
self
.
last_marker_line
=
self
.
marker
[
0
]
self
.
last_marker_line
=
self
.
marker
[
0
]
self
.
marker
=
None
self
.
marker
=
None
return
self
def
put_safe
(
self
,
code
):
def
put_safe
(
self
,
code
):
# put code, but ignore {}
# put code, but ignore {}
...
@@ -556,25 +553,20 @@ class CCodeWriter(object):
...
@@ -556,25 +553,20 @@ class CCodeWriter(object):
self
.
level
+=
dl
self
.
level
+=
dl
elif
fix_indent
:
elif
fix_indent
:
self
.
level
+=
1
self
.
level
+=
1
return
self
def
increase_indent
(
self
):
def
increase_indent
(
self
):
self
.
level
=
self
.
level
+
1
self
.
level
=
self
.
level
+
1
return
self
def
decrease_indent
(
self
):
def
decrease_indent
(
self
):
self
.
level
=
self
.
level
-
1
self
.
level
=
self
.
level
-
1
return
self
def
begin_block
(
self
):
def
begin_block
(
self
):
self
.
putln
(
"{"
)
self
.
putln
(
"{"
)
self
.
increase_indent
()
self
.
increase_indent
()
return
self
def
end_block
(
self
):
def
end_block
(
self
):
self
.
decrease_indent
()
self
.
decrease_indent
()
self
.
putln
(
"}"
)
self
.
putln
(
"}"
)
return
self
def
indent
(
self
):
def
indent
(
self
):
self
.
write
(
" "
*
self
.
level
)
self
.
write
(
" "
*
self
.
level
)
...
@@ -603,12 +595,10 @@ class CCodeWriter(object):
...
@@ -603,12 +595,10 @@ class CCodeWriter(object):
def
put_label
(
self
,
lbl
):
def
put_label
(
self
,
lbl
):
if
lbl
in
self
.
funcstate
.
labels_used
:
if
lbl
in
self
.
funcstate
.
labels_used
:
self
.
putln
(
"%s:;"
%
lbl
)
self
.
putln
(
"%s:;"
%
lbl
)
return
self
def
put_goto
(
self
,
lbl
):
def
put_goto
(
self
,
lbl
):
self
.
funcstate
.
use_label
(
lbl
)
self
.
funcstate
.
use_label
(
lbl
)
self
.
putln
(
"goto %s;"
%
lbl
)
self
.
putln
(
"goto %s;"
%
lbl
)
return
self
def
put_var_declarations
(
self
,
entries
,
static
=
0
,
dll_linkage
=
None
,
def
put_var_declarations
(
self
,
entries
,
static
=
0
,
dll_linkage
=
None
,
definition
=
True
):
definition
=
True
):
...
...
Cython/Compiler/ExprNodes.py
View file @
8912ea26
...
@@ -169,6 +169,7 @@ class ExprNode(Node):
...
@@ -169,6 +169,7 @@ class ExprNode(Node):
saved_subexpr_nodes
=
None
saved_subexpr_nodes
=
None
is_temp
=
0
is_temp
=
0
is_target
=
0
def
get_child_attrs
(
self
):
def
get_child_attrs
(
self
):
return
self
.
subexprs
return
self
.
subexprs
...
@@ -207,10 +208,10 @@ class ExprNode(Node):
...
@@ -207,10 +208,10 @@ class ExprNode(Node):
return
self
.
saved_subexpr_nodes
return
self
.
saved_subexpr_nodes
def
result
(
self
):
def
result
(
self
):
if
self
.
is_temp
:
if
not
self
.
is_temp
or
self
.
is_target
:
return
self
.
result_code
else
:
return
self
.
calculate_result_code
()
return
self
.
calculate_result_code
()
else
:
# i.e. self.is_temp:
return
self
.
result_code
def
result_as
(
self
,
type
=
None
):
def
result_as
(
self
,
type
=
None
):
# Return the result code cast to the specified C type.
# Return the result code cast to the specified C type.
...
@@ -341,7 +342,7 @@ class ExprNode(Node):
...
@@ -341,7 +342,7 @@ class ExprNode(Node):
if
debug_temp_alloc
:
if
debug_temp_alloc
:
print
(
"%s Allocating target temps"
%
self
)
print
(
"%s Allocating target temps"
%
self
)
self
.
allocate_subexpr_temps
(
env
)
self
.
allocate_subexpr_temps
(
env
)
self
.
result_code
=
self
.
target_code
()
self
.
is_target
=
True
if
rhs
:
if
rhs
:
rhs
.
release_temp
(
env
)
rhs
.
release_temp
(
env
)
self
.
release_subexpr_temps
(
env
)
self
.
release_subexpr_temps
(
env
)
...
@@ -436,9 +437,13 @@ class ExprNode(Node):
...
@@ -436,9 +437,13 @@ class ExprNode(Node):
# its sub-expressions, and dispose of any
# its sub-expressions, and dispose of any
# temporary results of its sub-expressions.
# temporary results of its sub-expressions.
self
.
generate_subexpr_evaluation_code
(
code
)
self
.
generate_subexpr_evaluation_code
(
code
)
self
.
pre_generate_result_code
(
code
)
self
.
generate_result_code
(
code
)
self
.
generate_result_code
(
code
)
if
self
.
is_temp
:
if
self
.
is_temp
:
self
.
generate_subexpr_disposal_code
(
code
)
self
.
generate_subexpr_disposal_code
(
code
)
def
pre_generate_result_code
(
self
,
code
):
pass
def
generate_subexpr_evaluation_code
(
self
,
code
):
def
generate_subexpr_evaluation_code
(
self
,
code
):
for
node
in
self
.
subexpr_nodes
():
for
node
in
self
.
subexpr_nodes
():
...
@@ -569,6 +574,66 @@ class ExprNode(Node):
...
@@ -569,6 +574,66 @@ class ExprNode(Node):
return
None
return
None
class
NewTempExprNode
(
ExprNode
):
backwards_compatible_result
=
None
def
result
(
self
):
if
self
.
is_temp
:
return
self
.
temp_code
else
:
return
self
.
calculate_result_code
()
def
allocate_target_temps
(
self
,
env
,
rhs
):
self
.
allocate_subexpr_temps
(
env
)
rhs
.
release_temp
(
rhs
)
self
.
release_subexpr_temps
(
env
)
def
allocate_temps
(
self
,
env
,
result
=
None
):
self
.
allocate_subexpr_temps
(
env
)
self
.
backwards_compatible_result
=
result
if
self
.
is_temp
:
self
.
release_subexpr_temps
(
env
)
def
allocate_temp
(
self
,
env
,
result
=
None
):
assert
result
is
None
def
release_temp
(
self
,
env
):
pass
def
pre_generate_result_code
(
self
,
code
):
if
self
.
is_temp
:
type
=
self
.
type
if
not
type
.
is_void
:
if
type
.
is_pyobject
:
type
=
PyrexTypes
.
py_object_type
if
self
.
backwards_compatible_result
:
self
.
temp_code
=
self
.
backwards_compatible_result
else
:
self
.
temp_code
=
code
.
funcstate
.
allocate_temp
(
type
)
else
:
self
.
temp_code
=
None
def
generate_disposal_code
(
self
,
code
):
if
self
.
is_temp
:
if
self
.
type
.
is_pyobject
:
code
.
put_decref_clear
(
self
.
result
(),
self
.
ctype
())
if
not
self
.
backwards_compatible_result
:
code
.
funcstate
.
release_temp
(
self
.
temp_code
)
else
:
self
.
generate_subexpr_disposal_code
(
code
)
def
generate_post_assignment_code
(
self
,
code
):
if
self
.
is_temp
:
if
self
.
type
.
is_pyobject
:
code
.
putln
(
"%s = 0;"
%
self
.
temp_code
)
if
not
self
.
backwards_compatible_result
:
code
.
funcstate
.
release_temp
(
self
.
temp_code
)
else
:
self
.
generate_subexpr_disposal_code
(
code
)
class
AtomicExprNode
(
ExprNode
):
class
AtomicExprNode
(
ExprNode
):
# Abstract base class for expression nodes which have
# Abstract base class for expression nodes which have
# no sub-expressions.
# no sub-expressions.
...
@@ -1463,10 +1528,8 @@ class IndexNode(ExprNode):
...
@@ -1463,10 +1528,8 @@ class IndexNode(ExprNode):
self
.
type
=
self
.
base
.
type
.
dtype
self
.
type
=
self
.
base
.
type
.
dtype
self
.
is_buffer_access
=
True
self
.
is_buffer_access
=
True
self
.
buffer_type
=
self
.
base
.
entry
.
type
self
.
buffer_type
=
self
.
base
.
entry
.
type
if
getting
:
if
getting
and
self
.
type
.
is_pyobject
:
# we only need a temp because result_code isn't refactored to
# generation time, but this seems an ok shortcut to take
self
.
is_temp
=
True
self
.
is_temp
=
True
if
setting
:
if
setting
:
if
not
self
.
base
.
entry
.
type
.
writable
:
if
not
self
.
base
.
entry
.
type
.
writable
:
...
@@ -1515,10 +1578,10 @@ class IndexNode(ExprNode):
...
@@ -1515,10 +1578,10 @@ class IndexNode(ExprNode):
def
is_lvalue
(
self
):
def
is_lvalue
(
self
):
return
1
return
1
def
calculate_result_code
(
self
):
def
calculate_result_code
(
self
):
if
self
.
is_buffer_access
:
if
self
.
is_buffer_access
:
return
"
<not used>"
return
"
(*%s)"
%
self
.
buffer_ptr_code
else
:
else
:
return
"(%s[%s])"
%
(
return
"(%s[%s])"
%
(
self
.
base
.
result
(),
self
.
index
.
result
())
self
.
base
.
result
(),
self
.
index
.
result
())
...
@@ -1552,12 +1615,10 @@ class IndexNode(ExprNode):
...
@@ -1552,12 +1615,10 @@ class IndexNode(ExprNode):
if
self
.
is_buffer_access
:
if
self
.
is_buffer_access
:
if
code
.
globalstate
.
directives
[
'nonecheck'
]:
if
code
.
globalstate
.
directives
[
'nonecheck'
]:
self
.
put_nonecheck
(
code
)
self
.
put_nonecheck
(
code
)
ptrcode
=
self
.
buffer_lookup_code
(
code
)
self
.
buffer_ptr_code
=
self
.
buffer_lookup_code
(
code
)
code
.
putln
(
"%s = *%s;"
%
(
if
self
.
type
.
is_pyobject
:
self
.
result
(),
# is_temp is True, so must pull out value and incref it.
self
.
buffer_type
.
buffer_ptr_type
.
cast_code
(
ptrcode
)))
code
.
putln
(
"%s = *%s;"
%
(
self
.
result
(),
self
.
buffer_ptr_code
))
# Must incref the value we pulled out.
if
self
.
buffer_type
.
dtype
.
is_pyobject
:
code
.
putln
(
"Py_INCREF((PyObject*)%s);"
%
self
.
result
())
code
.
putln
(
"Py_INCREF((PyObject*)%s);"
%
self
.
result
())
elif
self
.
type
.
is_pyobject
:
elif
self
.
type
.
is_pyobject
:
if
self
.
index
.
type
.
is_int
:
if
self
.
index
.
type
.
is_int
:
...
@@ -3380,7 +3441,7 @@ def get_compile_time_binop(node):
...
@@ -3380,7 +3441,7 @@ def get_compile_time_binop(node):
%
node
.
operator
)
%
node
.
operator
)
return
func
return
func
class
BinopNode
(
ExprNode
):
class
BinopNode
(
NewTemp
ExprNode
):
# operator string
# operator string
# operand1 ExprNode
# operand1 ExprNode
# operand2 ExprNode
# operand2 ExprNode
...
@@ -4377,7 +4438,7 @@ class CloneNode(CoercionNode):
...
@@ -4377,7 +4438,7 @@ class CloneNode(CoercionNode):
if
hasattr
(
arg
,
'entry'
):
if
hasattr
(
arg
,
'entry'
):
self
.
entry
=
arg
.
entry
self
.
entry
=
arg
.
entry
def
calculate_result_code
(
self
):
def
result
(
self
):
return
self
.
arg
.
result
()
return
self
.
arg
.
result
()
def
analyse_types
(
self
,
env
):
def
analyse_types
(
self
,
env
):
...
@@ -4397,7 +4458,7 @@ class CloneNode(CoercionNode):
...
@@ -4397,7 +4458,7 @@ class CloneNode(CoercionNode):
pass
pass
def
allocate_temps
(
self
,
env
):
def
allocate_temps
(
self
,
env
):
self
.
result_code
=
self
.
calculate_result_code
()
pass
def
release_temp
(
self
,
env
):
def
release_temp
(
self
,
env
):
pass
pass
...
...
Cython/Compiler/PyrexTypes.py
View file @
8912ea26
...
@@ -101,6 +101,7 @@ class PyrexType(BaseType):
...
@@ -101,6 +101,7 @@ class PyrexType(BaseType):
default_value
=
""
default_value
=
""
parsetuple_format
=
""
parsetuple_format
=
""
pymemberdef_typecode
=
None
pymemberdef_typecode
=
None
typestring
=
None
def
resolve
(
self
):
def
resolve
(
self
):
# If a typedef, returns the base type.
# If a typedef, returns the base type.
...
@@ -140,7 +141,6 @@ class PyrexType(BaseType):
...
@@ -140,7 +141,6 @@ class PyrexType(BaseType):
# a struct whose attributes are not defined, etc.
# a struct whose attributes are not defined, etc.
return
1
return
1
class
CTypedefType
(
BaseType
):
class
CTypedefType
(
BaseType
):
#
#
# Pseudo-type defined with a ctypedef statement in a
# Pseudo-type defined with a ctypedef statement in a
...
@@ -965,6 +965,11 @@ class CStructOrUnionType(CType):
...
@@ -965,6 +965,11 @@ class CStructOrUnionType(CType):
def
attributes_known
(
self
):
def
attributes_known
(
self
):
return
self
.
is_complete
()
return
self
.
is_complete
()
def
can_be_complex
(
self
):
# Does the struct consist of exactly two floats?
fields
=
self
.
scope
.
var_entries
return
len
(
fields
)
==
2
and
fields
[
0
].
type
.
is_float
and
fields
[
1
].
type
.
is_float
class
CEnumType
(
CType
):
class
CEnumType
(
CType
):
# name string
# name string
...
...
Cython/Includes/numpy.pxd
View file @
8912ea26
...
@@ -69,20 +69,23 @@ cdef extern from "numpy/arrayobject.h":
...
@@ -69,20 +69,23 @@ cdef extern from "numpy/arrayobject.h":
# made available from this pxd file yet.
# made available from this pxd file yet.
cdef
int
t
=
PyArray_TYPE
(
self
)
cdef
int
t
=
PyArray_TYPE
(
self
)
cdef
char
*
f
=
NULL
cdef
char
*
f
=
NULL
if
t
==
NPY_BYTE
:
f
=
"b"
if
t
==
NPY_BYTE
:
f
=
"b"
elif
t
==
NPY_UBYTE
:
f
=
"B"
elif
t
==
NPY_UBYTE
:
f
=
"B"
elif
t
==
NPY_SHORT
:
f
=
"h"
elif
t
==
NPY_SHORT
:
f
=
"h"
elif
t
==
NPY_USHORT
:
f
=
"H"
elif
t
==
NPY_USHORT
:
f
=
"H"
elif
t
==
NPY_INT
:
f
=
"i"
elif
t
==
NPY_INT
:
f
=
"i"
elif
t
==
NPY_UINT
:
f
=
"I"
elif
t
==
NPY_UINT
:
f
=
"I"
elif
t
==
NPY_LONG
:
f
=
"l"
elif
t
==
NPY_LONG
:
f
=
"l"
elif
t
==
NPY_ULONG
:
f
=
"L"
elif
t
==
NPY_ULONG
:
f
=
"L"
elif
t
==
NPY_LONGLONG
:
f
=
"q"
elif
t
==
NPY_LONGLONG
:
f
=
"q"
elif
t
==
NPY_ULONGLONG
:
f
=
"Q"
elif
t
==
NPY_ULONGLONG
:
f
=
"Q"
elif
t
==
NPY_FLOAT
:
f
=
"f"
elif
t
==
NPY_FLOAT
:
f
=
"f"
elif
t
==
NPY_DOUBLE
:
f
=
"d"
elif
t
==
NPY_DOUBLE
:
f
=
"d"
elif
t
==
NPY_LONGDOUBLE
:
f
=
"g"
elif
t
==
NPY_LONGDOUBLE
:
f
=
"g"
elif
t
==
NPY_OBJECT
:
f
=
"O"
elif
t
==
NPY_CFLOAT
:
f
=
"Zf"
elif
t
==
NPY_CDOUBLE
:
f
=
"Zd"
elif
t
==
NPY_CLONGDOUBLE
:
f
=
"Zg"
elif
t
==
NPY_OBJECT
:
f
=
"O"
if
f
==
NULL
:
if
f
==
NULL
:
raise
ValueError
(
"only objects, int and float dtypes supported for ndarray buffer access so far (dtype is %d)"
%
t
)
raise
ValueError
(
"only objects, int and float dtypes supported for ndarray buffer access so far (dtype is %d)"
%
t
)
...
...
tests/run/bufaccess.pyx
View file @
8912ea26
...
@@ -17,16 +17,15 @@ cimport cython
...
@@ -17,16 +17,15 @@ cimport cython
from
python_ref
cimport
PyObject
from
python_ref
cimport
PyObject
__test__
=
{}
__test__
=
{}
setup_string
=
u"""
>>> A = IntMockBuffer("A", range(6))
>>> B = IntMockBuffer("B", range(6))
>>> C = IntMockBuffer("C", range(6), (2,3))
>>> E = ErrorBuffer("E")
"""
import
re
exclude
=
[]
#re.compile('object').search]
def
testcase
(
func
):
def
testcase
(
func
):
__test__
[
func
.
__name__
]
=
setup_string
+
func
.
__doc__
for
e
in
exclude
:
if
e
(
func
.
__name__
):
return
func
__test__
[
func
.
__name__
]
=
func
.
__doc__
return
func
return
func
def
testcas
(
a
):
def
testcas
(
a
):
...
@@ -53,6 +52,8 @@ def printbuf():
...
@@ -53,6 +52,8 @@ def printbuf():
@
testcase
@
testcase
def
acquire_release
(
o1
,
o2
):
def
acquire_release
(
o1
,
o2
):
"""
"""
>>> A = IntMockBuffer("A", range(6))
>>> B = IntMockBuffer("B", range(6))
>>> acquire_release(A, B)
>>> acquire_release(A, B)
acquired A
acquired A
released A
released A
...
@@ -73,6 +74,7 @@ def acquire_raise(o):
...
@@ -73,6 +74,7 @@ def acquire_raise(o):
Apparently, doctest won't handle mixed exceptions and print
Apparently, doctest won't handle mixed exceptions and print
stats, so need to circumvent this.
stats, so need to circumvent this.
>>> A = IntMockBuffer("A", range(6))
>>> A.resetlog()
>>> A.resetlog()
>>> acquire_raise(A)
>>> acquire_raise(A)
Traceback (most recent call last):
Traceback (most recent call last):
...
@@ -218,6 +220,7 @@ def acquire_nonbuffer2():
...
@@ -218,6 +220,7 @@ def acquire_nonbuffer2():
@
testcase
@
testcase
def
as_argument
(
object
[
int
]
bufarg
,
int
n
):
def
as_argument
(
object
[
int
]
bufarg
,
int
n
):
"""
"""
>>> A = IntMockBuffer("A", range(6))
>>> as_argument(A, 6)
>>> as_argument(A, 6)
acquired A
acquired A
0 1 2 3 4 5 END
0 1 2 3 4 5 END
...
@@ -235,6 +238,7 @@ def as_argument_defval(object[int] bufarg=IntMockBuffer('default', range(6)), in
...
@@ -235,6 +238,7 @@ def as_argument_defval(object[int] bufarg=IntMockBuffer('default', range(6)), in
acquired default
acquired default
0 1 2 3 4 5 END
0 1 2 3 4 5 END
released default
released default
>>> A = IntMockBuffer("A", range(6))
>>> as_argument_defval(A, 6)
>>> as_argument_defval(A, 6)
acquired A
acquired A
0 1 2 3 4 5 END
0 1 2 3 4 5 END
...
@@ -248,6 +252,7 @@ def as_argument_defval(object[int] bufarg=IntMockBuffer('default', range(6)), in
...
@@ -248,6 +252,7 @@ def as_argument_defval(object[int] bufarg=IntMockBuffer('default', range(6)), in
@
testcase
@
testcase
def
cdef_assignment
(
obj
,
n
):
def
cdef_assignment
(
obj
,
n
):
"""
"""
>>> A = IntMockBuffer("A", range(6))
>>> cdef_assignment(A, 6)
>>> cdef_assignment(A, 6)
acquired A
acquired A
0 1 2 3 4 5 END
0 1 2 3 4 5 END
...
@@ -263,6 +268,8 @@ def cdef_assignment(obj, n):
...
@@ -263,6 +268,8 @@ def cdef_assignment(obj, n):
@
testcase
@
testcase
def
forin_assignment
(
objs
,
int
pick
):
def
forin_assignment
(
objs
,
int
pick
):
"""
"""
>>> A = IntMockBuffer("A", range(6))
>>> B = IntMockBuffer("B", range(6))
>>> forin_assignment([A, B, A, A], 2)
>>> forin_assignment([A, B, A, A], 2)
acquired A
acquired A
2
2
...
@@ -284,6 +291,7 @@ def forin_assignment(objs, int pick):
...
@@ -284,6 +291,7 @@ def forin_assignment(objs, int pick):
@
testcase
@
testcase
def
cascaded_buffer_assignment
(
obj
):
def
cascaded_buffer_assignment
(
obj
):
"""
"""
>>> A = IntMockBuffer("A", range(6))
>>> cascaded_buffer_assignment(A)
>>> cascaded_buffer_assignment(A)
acquired A
acquired A
acquired A
acquired A
...
@@ -296,6 +304,8 @@ def cascaded_buffer_assignment(obj):
...
@@ -296,6 +304,8 @@ def cascaded_buffer_assignment(obj):
@
testcase
@
testcase
def
tuple_buffer_assignment1
(
a
,
b
):
def
tuple_buffer_assignment1
(
a
,
b
):
"""
"""
>>> A = IntMockBuffer("A", range(6))
>>> B = IntMockBuffer("B", range(6))
>>> tuple_buffer_assignment1(A, B)
>>> tuple_buffer_assignment1(A, B)
acquired A
acquired A
acquired B
acquired B
...
@@ -308,6 +318,8 @@ def tuple_buffer_assignment1(a, b):
...
@@ -308,6 +318,8 @@ def tuple_buffer_assignment1(a, b):
@
testcase
@
testcase
def
tuple_buffer_assignment2
(
tup
):
def
tuple_buffer_assignment2
(
tup
):
"""
"""
>>> A = IntMockBuffer("A", range(6))
>>> B = IntMockBuffer("B", range(6))
>>> tuple_buffer_assignment2((A, B))
>>> tuple_buffer_assignment2((A, B))
acquired A
acquired A
acquired B
acquired B
...
@@ -358,12 +370,27 @@ def alignment_string(object[int] buf):
...
@@ -358,12 +370,27 @@ def alignment_string(object[int] buf):
"""
"""
print
buf
[
1
]
print
buf
[
1
]
@
testcase
def
wrong_string
(
object
[
int
]
buf
):
"""
>>> wrong_string(IntMockBuffer(None, [1,2], format="iasdf"))
Traceback (most recent call last):
...
ValueError: Buffer format string specifies more data than 'int' can hold (expected end, got 'asdf')
>>> wrong_string(IntMockBuffer(None, [1,2], format="$$"))
Traceback (most recent call last):
...
ValueError: Buffer datatype mismatch (expected 'i', got '$$')
"""
print
buf
[
1
]
#
#
# Getting items and index bounds checking
# Getting items and index bounds checking
#
#
@
testcase
@
testcase
def
get_int_2d
(
object
[
int
,
ndim
=
2
]
buf
,
int
i
,
int
j
):
def
get_int_2d
(
object
[
int
,
ndim
=
2
]
buf
,
int
i
,
int
j
):
"""
"""
>>> C = IntMockBuffer("C", range(6), (2,3))
>>> get_int_2d(C, 1, 1)
>>> get_int_2d(C, 1, 1)
acquired C
acquired C
released C
released C
...
@@ -399,6 +426,7 @@ def get_int_2d(object[int, ndim=2] buf, int i, int j):
...
@@ -399,6 +426,7 @@ def get_int_2d(object[int, ndim=2] buf, int i, int j):
def
get_int_2d_uintindex
(
object
[
int
,
ndim
=
2
]
buf
,
unsigned
int
i
,
unsigned
int
j
):
def
get_int_2d_uintindex
(
object
[
int
,
ndim
=
2
]
buf
,
unsigned
int
i
,
unsigned
int
j
):
"""
"""
Unsigned indexing:
Unsigned indexing:
>>> C = IntMockBuffer("C", range(6), (2,3))
>>> get_int_2d_uintindex(C, 0, 0)
>>> get_int_2d_uintindex(C, 0, 0)
acquired C
acquired C
released C
released C
...
@@ -418,6 +446,7 @@ def set_int_2d(object[int, ndim=2] buf, int i, int j, int value):
...
@@ -418,6 +446,7 @@ def set_int_2d(object[int, ndim=2] buf, int i, int j, int value):
Uses get_int_2d to read back the value afterwards. For pure
Uses get_int_2d to read back the value afterwards. For pure
unit test, one should support reading in MockBuffer instead.
unit test, one should support reading in MockBuffer instead.
>>> C = IntMockBuffer("C", range(6), (2,3))
>>> set_int_2d(C, 1, 1, 10)
>>> set_int_2d(C, 1, 1, 10)
acquired C
acquired C
released C
released C
...
@@ -1175,7 +1204,6 @@ cdef class DoubleMockBuffer(MockBuffer):
...
@@ -1175,7 +1204,6 @@ cdef class DoubleMockBuffer(MockBuffer):
cdef
get_itemsize
(
self
):
return
sizeof
(
double
)
cdef
get_itemsize
(
self
):
return
sizeof
(
double
)
cdef
get_default_format
(
self
):
return
b"d"
cdef
get_default_format
(
self
):
return
b"d"
cdef
extern
from
*
:
cdef
extern
from
*
:
void
*
addr_of_pyobject
"(void*)"
(
object
)
void
*
addr_of_pyobject
"(void*)"
(
object
)
...
@@ -1254,3 +1282,86 @@ def bufdefaults1(IntStridedMockBuffer[int, ndim=1] buf):
...
@@ -1254,3 +1282,86 @@ def bufdefaults1(IntStridedMockBuffer[int, ndim=1] buf):
pass
pass
#
# Structs
#
cdef
struct
MyStruct
:
char
a
char
b
long
long
int
c
int
d
int
e
cdef
class
MyStructMockBuffer
(
MockBuffer
):
cdef
int
write
(
self
,
char
*
buf
,
object
value
)
except
-
1
:
cdef
MyStruct
*
s
s
=
<
MyStruct
*>
buf
;
s
.
a
,
s
.
b
,
s
.
c
,
s
.
d
,
s
.
e
=
value
return
0
cdef
get_itemsize
(
self
):
return
sizeof
(
MyStruct
)
cdef
get_default_format
(
self
):
return
b"2bq2i"
@
testcase
def
basic_struct
(
object
[
MyStruct
]
buf
):
"""
>>> basic_struct(MyStructMockBuffer(None, [(1, 2, 3, 4, 5)]))
1 2 3 4 5
>>> basic_struct(MyStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="bbqii"))
1 2 3 4 5
>>> basic_struct(MyStructMockBuffer(None, [(1, 2, 3, 4, 5)], format="i"))
Traceback (most recent call last):
...
ValueError: Buffer datatype mismatch (expected 'b', got 'i')
"""
print
buf
[
0
].
a
,
buf
[
0
].
b
,
buf
[
0
].
c
,
buf
[
0
].
d
,
buf
[
0
].
e
cdef
struct
LongComplex
:
long
double
real
long
double
imag
cdef
struct
MixedComplex
:
long
double
real
float
imag
cdef
class
LongComplexMockBuffer
(
MockBuffer
):
cdef
int
write
(
self
,
char
*
buf
,
object
value
)
except
-
1
:
cdef
LongComplex
*
s
s
=
<
LongComplex
*>
buf
;
s
.
real
,
s
.
imag
=
value
return
0
cdef
get_itemsize
(
self
):
return
sizeof
(
LongComplex
)
cdef
get_default_format
(
self
):
return
b"Zg"
@
testcase
def
complex_struct_dtype
(
object
[
LongComplex
]
buf
):
"""
Note that the format string is "Zg" rather than "2g"...
>>> complex_struct_dtype(LongComplexMockBuffer(None, [(0, -1)]))
0.0 -1.0
"""
print
buf
[
0
].
real
,
buf
[
0
].
imag
@
testcase
def
mixed_complex_struct_dtype
(
object
[
MixedComplex
]
buf
):
"""
Triggering a specific execution path for this case.
>>> mixed_complex_struct_dtype(LongComplexMockBuffer(None, [(0, -1)]))
Traceback (most recent call last):
...
ValueError: Cannot store complex number in 'MixedComplex' as 'long double' differs from 'float' in size.
"""
print
buf
[
0
].
real
,
buf
[
0
].
imag
@
testcase
def
complex_struct_inplace
(
object
[
LongComplex
]
buf
):
"""
>>> complex_struct_inplace(LongComplexMockBuffer(None, [(0, -1)]))
1.0 1.0
"""
buf
[
0
].
real
+=
1
buf
[
0
].
imag
+=
2
print
buf
[
0
].
real
,
buf
[
0
].
imag
tests/run/numpy_test.pyx
View file @
8912ea26
...
@@ -115,6 +115,9 @@ try:
...
@@ -115,6 +115,9 @@ try:
>>> test_dtype('d', inc1_double)
>>> test_dtype('d', inc1_double)
>>> test_dtype('g', inc1_longdouble)
>>> test_dtype('g', inc1_longdouble)
>>> test_dtype('O', inc1_object)
>>> test_dtype('O', inc1_object)
>>> test_dtype('F', inc1_cfloat) # numpy format codes differ from buffer ones here
>>> test_dtype('D', inc1_cdouble)
>>> test_dtype('G', inc1_clongdouble)
>>> test_dtype(np.int, inc1_int_t)
>>> test_dtype(np.int, inc1_int_t)
>>> test_dtype(np.long, inc1_long_t)
>>> test_dtype(np.long, inc1_long_t)
...
@@ -127,11 +130,6 @@ try:
...
@@ -127,11 +130,6 @@ try:
>>> test_dtype(np.float64, inc1_float64_t)
>>> test_dtype(np.float64, inc1_float64_t)
Unsupported types:
Unsupported types:
>>> test_dtype(np.complex, inc1_byte)
Traceback (most recent call last):
...
ValueError: only objects, int and float dtypes supported for ndarray buffer access so far (dtype is 15)
>>> a = np.zeros((10,), dtype=np.dtype('i4,i4'))
>>> a = np.zeros((10,), dtype=np.dtype('i4,i4'))
>>> inc1_byte(a)
>>> inc1_byte(a)
Traceback (most recent call last):
Traceback (most recent call last):
...
@@ -194,7 +192,19 @@ def test_f_contig(np.ndarray[int, ndim=2, mode='fortran'] arr):
...
@@ -194,7 +192,19 @@ def test_f_contig(np.ndarray[int, ndim=2, mode='fortran'] arr):
for
i
in
range
(
arr
.
shape
[
0
]):
for
i
in
range
(
arr
.
shape
[
0
]):
print
" "
.
join
([
str
(
arr
[
i
,
j
])
for
j
in
range
(
arr
.
shape
[
1
])])
print
" "
.
join
([
str
(
arr
[
i
,
j
])
for
j
in
range
(
arr
.
shape
[
1
])])
# Exhaustive dtype tests -- increments element [1] by 1 for all dtypes
cdef
struct
cfloat
:
float
real
float
imag
cdef
struct
cdouble
:
double
real
double
imag
cdef
struct
clongdouble
:
long
double
real
long
double
imag
# Exhaustive dtype tests -- increments element [1] by 1 (or 1+1j) for all dtypes
def
inc1_byte
(
np
.
ndarray
[
char
]
arr
):
arr
[
1
]
+=
1
def
inc1_byte
(
np
.
ndarray
[
char
]
arr
):
arr
[
1
]
+=
1
def
inc1_ubyte
(
np
.
ndarray
[
unsigned
char
]
arr
):
arr
[
1
]
+=
1
def
inc1_ubyte
(
np
.
ndarray
[
unsigned
char
]
arr
):
arr
[
1
]
+=
1
def
inc1_short
(
np
.
ndarray
[
short
]
arr
):
arr
[
1
]
+=
1
def
inc1_short
(
np
.
ndarray
[
short
]
arr
):
arr
[
1
]
+=
1
...
@@ -210,6 +220,19 @@ def inc1_float(np.ndarray[float] arr): arr[1] += 1
...
@@ -210,6 +220,19 @@ def inc1_float(np.ndarray[float] arr): arr[1] += 1
def
inc1_double
(
np
.
ndarray
[
double
]
arr
):
arr
[
1
]
+=
1
def
inc1_double
(
np
.
ndarray
[
double
]
arr
):
arr
[
1
]
+=
1
def
inc1_longdouble
(
np
.
ndarray
[
long
double
]
arr
):
arr
[
1
]
+=
1
def
inc1_longdouble
(
np
.
ndarray
[
long
double
]
arr
):
arr
[
1
]
+=
1
def
inc1_cfloat
(
np
.
ndarray
[
cfloat
]
arr
):
arr
[
1
].
real
+=
1
arr
[
1
].
imag
+=
1
def
inc1_cdouble
(
np
.
ndarray
[
cdouble
]
arr
):
arr
[
1
].
real
+=
1
arr
[
1
].
imag
+=
1
def
inc1_clongdouble
(
np
.
ndarray
[
clongdouble
]
arr
):
cdef
long
double
x
x
=
arr
[
1
].
real
+
1
arr
[
1
].
real
=
x
arr
[
1
].
imag
=
arr
[
1
].
imag
+
1
def
inc1_object
(
np
.
ndarray
[
object
]
arr
):
def
inc1_object
(
np
.
ndarray
[
object
]
arr
):
o
=
arr
[
1
]
o
=
arr
[
1
]
...
@@ -229,10 +252,14 @@ def inc1_float64_t(np.ndarray[np.float64_t] arr): arr[1] += 1
...
@@ -229,10 +252,14 @@ def inc1_float64_t(np.ndarray[np.float64_t] arr): arr[1] += 1
def
test_dtype
(
dtype
,
inc1
):
def
test_dtype
(
dtype
,
inc1
):
a
=
np
.
array
([
0
,
10
],
dtype
=
dtype
)
if
dtype
in
(
'F'
,
'D'
,
'G'
):
inc1
(
a
)
a
=
np
.
array
([
0
,
10
+
10j
],
dtype
=
dtype
)
if
a
[
1
]
!=
11
:
print
"failed!"
inc1
(
a
)
if
a
[
1
]
!=
(
11
+
11j
):
print
"failed!"
,
a
[
1
]
else
:
a
=
np
.
array
([
0
,
10
],
dtype
=
dtype
)
inc1
(
a
)
if
a
[
1
]
!=
11
:
print
"failed!"
def
test_good_cast
():
def
test_good_cast
():
# Check that a signed int can round-trip through casted unsigned int access
# Check that a signed int can round-trip through casted unsigned int access
...
@@ -243,4 +270,3 @@ def test_good_cast():
...
@@ -243,4 +270,3 @@ def test_good_cast():
def
test_bad_cast
():
def
test_bad_cast
():
# This should raise an exception
# This should raise an exception
cdef
np
.
ndarray
[
long
,
cast
=
True
]
arr
=
np
.
array
([
1
],
dtype
=
'b'
)
cdef
np
.
ndarray
[
long
,
cast
=
True
]
arr
=
np
.
array
([
1
],
dtype
=
'b'
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment