Commit cab342a2 authored by Lisandro Dalcin's avatar Lisandro Dalcin

Extension type cast should reject None (ticket #417)

--HG--
extra : rebase_source : 37bb9de5574e1f7b4f288192eaa3c70a2ae350ca
parent 97d7ad39
......@@ -4245,7 +4245,7 @@ class TypecastNode(ExprNode):
warning(self.pos, "No conversion from %s to %s, python object pointer used." % (self.type, self.operand.type))
elif from_py and to_py:
if self.typecheck and self.type.is_extension_type:
self.operand = PyTypeTestNode(self.operand, self.type, env)
self.operand = PyTypeTestNode(self.operand, self.type, env, notnone=True)
def nogil_check(self, env):
if self.type and self.type.is_pyobject and self.is_temp:
......@@ -5563,13 +5563,14 @@ class PyTypeTestNode(CoercionNode):
# object is an instance of a particular extension type.
# This node borrows the result of its argument node.
def __init__(self, arg, dst_type, env):
def __init__(self, arg, dst_type, env, notnone=False):
# The arg is know to be a Python object, and
# the dst_type is known to be an extension type.
assert dst_type.is_extension_type or dst_type.is_builtin_type, "PyTypeTest on non extension type"
CoercionNode.__init__(self, arg)
self.type = dst_type
self.result_ctype = arg.ctype()
self.notnone = notnone
nogil_check = Node.gil_error
gil_message = "Python type test"
......@@ -5596,7 +5597,7 @@ class PyTypeTestNode(CoercionNode):
code.globalstate.use_utility_code(type_test_utility_code)
code.putln(
"if (!(%s)) %s" % (
self.type.type_test_code(self.arg.py_result()),
self.type.type_test_code(self.arg.py_result(), self.notnone),
code.error_goto(self.pos)))
else:
error(self.pos, "Cannot test type of extern C class "
......@@ -6008,18 +6009,18 @@ bad:
type_test_utility_code = UtilityCode(
proto = """
static int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type); /*proto*/
static INLINE int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type); /*proto*/
""",
impl = """
static int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type) {
if (!type) {
static INLINE int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type) {
if (unlikely(!type)) {
PyErr_Format(PyExc_SystemError, "Missing type object");
return 0;
}
if (obj == Py_None || PyObject_TypeCheck(obj, type))
if (likely(PyObject_TypeCheck(obj, type)))
return 1;
PyErr_Format(PyExc_TypeError, "Cannot convert %s to %s",
Py_TYPE(obj)->tp_name, type->tp_name);
PyErr_Format(PyExc_TypeError, "Cannot convert %.200s to %.200s",
Py_TYPE(obj)->tp_name, type->tp_name);
return 0;
}
""")
......
......@@ -408,19 +408,24 @@ class BuiltinObjectType(PyObjectType):
def subtype_of(self, type):
return type.is_pyobject and self.assignable_from(type)
def type_test_code(self, arg):
def type_test_code(self, arg, notnone=False):
type_name = self.name
if type_name == 'str':
check = 'PyString_CheckExact'
type_check = 'PyString_CheckExact'
elif type_name == 'set':
check = 'PyAnySet_CheckExact'
type_check = 'PyAnySet_CheckExact'
elif type_name == 'frozenset':
check = 'PyFrozenSet_CheckExact'
type_check = 'PyFrozenSet_CheckExact'
elif type_name == 'bool':
check = 'PyBool_Check'
type_check = 'PyBool_Check'
else:
check = 'Py%s_CheckExact' % type_name.capitalize()
return 'likely(%s(%s)) || (%s) == Py_None || (PyErr_Format(PyExc_TypeError, "Expected %s, got %%s", Py_TYPE(%s)->tp_name), 0)' % (check, arg, arg, self.name, arg)
type_check = 'Py%s_CheckExact' % type_name.capitalize()
check = 'likely(%s(%s))' % (type_check, arg)
if not notnone:
check = check + ('||((%s) == Py_None)' % arg)
error = '(PyErr_Format(PyExc_TypeError, "Expected %s, got %%.200s", Py_TYPE(%s)->tp_name), 0)' % (self.name, arg)
return check + '||' + error
def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0):
......@@ -504,9 +509,16 @@ class PyExtensionType(PyObjectType):
else:
return "%s *%s" % (base, entity_code)
def type_test_code(self, py_arg):
return "__Pyx_TypeTest(%s, %s)" % (py_arg, self.typeptr_cname)
def type_test_code(self, py_arg, notnone=False):
none_check = "((%s) == Py_None)" % py_arg
type_check = "likely(__Pyx_TypeTest(%s, %s))" % (
py_arg, self.typeptr_cname)
if notnone:
return type_check
else:
return "likely(%s || %s)" % (none_check, type_check)
def attributes_known(self):
return self.scope is not None
......
#cython: autotestdict=True
cdef class Foo:
pass
cdef class SubFoo(Foo):
pass
cdef class Bar:
pass
def foo1(arg):
"""
>>> foo1(Foo())
>>> foo1(SubFoo())
>>> foo1(None)
>>> foo1(123)
>>> foo1(Bar())
"""
cdef Foo val = <Foo>arg
def foo2(arg):
"""
>>> foo2(Foo())
>>> foo2(SubFoo())
>>> foo2(None)
>>> foo2(123)
Traceback (most recent call last):
...
TypeError: Cannot convert int to typetest_T417.Foo
>>> foo2(Bar())
Traceback (most recent call last):
...
TypeError: Cannot convert typetest_T417.Bar to typetest_T417.Foo
"""
cdef Foo val = arg
def foo3(arg):
"""
>>> foo3(Foo())
>>> foo3(SubFoo())
>>> foo3(None)
Traceback (most recent call last):
...
TypeError: Cannot convert NoneType to typetest_T417.Foo
>>> foo3(123)
Traceback (most recent call last):
...
TypeError: Cannot convert int to typetest_T417.Foo
>>> foo2(Bar())
Traceback (most recent call last):
...
TypeError: Cannot convert typetest_T417.Bar to typetest_T417.Foo
"""
cdef val = <Foo?>arg
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment