Commit cab342a2 authored by Lisandro Dalcin's avatar Lisandro Dalcin

Extension type cast should reject None (ticket #417)

--HG--
extra : rebase_source : 37bb9de5574e1f7b4f288192eaa3c70a2ae350ca
parent 97d7ad39
...@@ -4245,7 +4245,7 @@ class TypecastNode(ExprNode): ...@@ -4245,7 +4245,7 @@ class TypecastNode(ExprNode):
warning(self.pos, "No conversion from %s to %s, python object pointer used." % (self.type, self.operand.type)) warning(self.pos, "No conversion from %s to %s, python object pointer used." % (self.type, self.operand.type))
elif from_py and to_py: elif from_py and to_py:
if self.typecheck and self.type.is_extension_type: if self.typecheck and self.type.is_extension_type:
self.operand = PyTypeTestNode(self.operand, self.type, env) self.operand = PyTypeTestNode(self.operand, self.type, env, notnone=True)
def nogil_check(self, env): def nogil_check(self, env):
if self.type and self.type.is_pyobject and self.is_temp: if self.type and self.type.is_pyobject and self.is_temp:
...@@ -5563,13 +5563,14 @@ class PyTypeTestNode(CoercionNode): ...@@ -5563,13 +5563,14 @@ class PyTypeTestNode(CoercionNode):
# object is an instance of a particular extension type. # object is an instance of a particular extension type.
# This node borrows the result of its argument node. # This node borrows the result of its argument node.
def __init__(self, arg, dst_type, env): def __init__(self, arg, dst_type, env, notnone=False):
# The arg is know to be a Python object, and # The arg is know to be a Python object, and
# the dst_type is known to be an extension type. # the dst_type is known to be an extension type.
assert dst_type.is_extension_type or dst_type.is_builtin_type, "PyTypeTest on non extension type" assert dst_type.is_extension_type or dst_type.is_builtin_type, "PyTypeTest on non extension type"
CoercionNode.__init__(self, arg) CoercionNode.__init__(self, arg)
self.type = dst_type self.type = dst_type
self.result_ctype = arg.ctype() self.result_ctype = arg.ctype()
self.notnone = notnone
nogil_check = Node.gil_error nogil_check = Node.gil_error
gil_message = "Python type test" gil_message = "Python type test"
...@@ -5596,7 +5597,7 @@ class PyTypeTestNode(CoercionNode): ...@@ -5596,7 +5597,7 @@ class PyTypeTestNode(CoercionNode):
code.globalstate.use_utility_code(type_test_utility_code) code.globalstate.use_utility_code(type_test_utility_code)
code.putln( code.putln(
"if (!(%s)) %s" % ( "if (!(%s)) %s" % (
self.type.type_test_code(self.arg.py_result()), self.type.type_test_code(self.arg.py_result(), self.notnone),
code.error_goto(self.pos))) code.error_goto(self.pos)))
else: else:
error(self.pos, "Cannot test type of extern C class " error(self.pos, "Cannot test type of extern C class "
...@@ -6008,18 +6009,18 @@ bad: ...@@ -6008,18 +6009,18 @@ bad:
type_test_utility_code = UtilityCode( type_test_utility_code = UtilityCode(
proto = """ proto = """
static int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type); /*proto*/ static INLINE int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type); /*proto*/
""", """,
impl = """ impl = """
static int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type) { static INLINE int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type) {
if (!type) { if (unlikely(!type)) {
PyErr_Format(PyExc_SystemError, "Missing type object"); PyErr_Format(PyExc_SystemError, "Missing type object");
return 0; return 0;
} }
if (obj == Py_None || PyObject_TypeCheck(obj, type)) if (likely(PyObject_TypeCheck(obj, type)))
return 1; return 1;
PyErr_Format(PyExc_TypeError, "Cannot convert %s to %s", PyErr_Format(PyExc_TypeError, "Cannot convert %.200s to %.200s",
Py_TYPE(obj)->tp_name, type->tp_name); Py_TYPE(obj)->tp_name, type->tp_name);
return 0; return 0;
} }
""") """)
......
...@@ -408,19 +408,24 @@ class BuiltinObjectType(PyObjectType): ...@@ -408,19 +408,24 @@ class BuiltinObjectType(PyObjectType):
def subtype_of(self, type): def subtype_of(self, type):
return type.is_pyobject and self.assignable_from(type) return type.is_pyobject and self.assignable_from(type)
def type_test_code(self, arg): def type_test_code(self, arg, notnone=False):
type_name = self.name type_name = self.name
if type_name == 'str': if type_name == 'str':
check = 'PyString_CheckExact' type_check = 'PyString_CheckExact'
elif type_name == 'set': elif type_name == 'set':
check = 'PyAnySet_CheckExact' type_check = 'PyAnySet_CheckExact'
elif type_name == 'frozenset': elif type_name == 'frozenset':
check = 'PyFrozenSet_CheckExact' type_check = 'PyFrozenSet_CheckExact'
elif type_name == 'bool': elif type_name == 'bool':
check = 'PyBool_Check' type_check = 'PyBool_Check'
else: else:
check = 'Py%s_CheckExact' % type_name.capitalize() type_check = 'Py%s_CheckExact' % type_name.capitalize()
return 'likely(%s(%s)) || (%s) == Py_None || (PyErr_Format(PyExc_TypeError, "Expected %s, got %%s", Py_TYPE(%s)->tp_name), 0)' % (check, arg, arg, self.name, arg)
check = 'likely(%s(%s))' % (type_check, arg)
if not notnone:
check = check + ('||((%s) == Py_None)' % arg)
error = '(PyErr_Format(PyExc_TypeError, "Expected %s, got %%.200s", Py_TYPE(%s)->tp_name), 0)' % (self.name, arg)
return check + '||' + error
def declaration_code(self, entity_code, def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0): for_display = 0, dll_linkage = None, pyrex = 0):
...@@ -504,9 +509,16 @@ class PyExtensionType(PyObjectType): ...@@ -504,9 +509,16 @@ class PyExtensionType(PyObjectType):
else: else:
return "%s *%s" % (base, entity_code) return "%s *%s" % (base, entity_code)
def type_test_code(self, py_arg): def type_test_code(self, py_arg, notnone=False):
return "__Pyx_TypeTest(%s, %s)" % (py_arg, self.typeptr_cname)
none_check = "((%s) == Py_None)" % py_arg
type_check = "likely(__Pyx_TypeTest(%s, %s))" % (
py_arg, self.typeptr_cname)
if notnone:
return type_check
else:
return "likely(%s || %s)" % (none_check, type_check)
def attributes_known(self): def attributes_known(self):
return self.scope is not None return self.scope is not None
......
#cython: autotestdict=True
cdef class Foo:
pass
cdef class SubFoo(Foo):
pass
cdef class Bar:
pass
def foo1(arg):
"""
>>> foo1(Foo())
>>> foo1(SubFoo())
>>> foo1(None)
>>> foo1(123)
>>> foo1(Bar())
"""
cdef Foo val = <Foo>arg
def foo2(arg):
"""
>>> foo2(Foo())
>>> foo2(SubFoo())
>>> foo2(None)
>>> foo2(123)
Traceback (most recent call last):
...
TypeError: Cannot convert int to typetest_T417.Foo
>>> foo2(Bar())
Traceback (most recent call last):
...
TypeError: Cannot convert typetest_T417.Bar to typetest_T417.Foo
"""
cdef Foo val = arg
def foo3(arg):
"""
>>> foo3(Foo())
>>> foo3(SubFoo())
>>> foo3(None)
Traceback (most recent call last):
...
TypeError: Cannot convert NoneType to typetest_T417.Foo
>>> foo3(123)
Traceback (most recent call last):
...
TypeError: Cannot convert int to typetest_T417.Foo
>>> foo2(Bar())
Traceback (most recent call last):
...
TypeError: Cannot convert typetest_T417.Bar to typetest_T417.Foo
"""
cdef val = <Foo?>arg
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment