Commit f8233ea5 authored by Stefan Behnel's avatar Stefan Behnel

extend semantics of 'basestring' typed variables to represent exactly...

extend semantics of 'basestring' typed variables to represent exactly bytes/str/unicode but no subtypes
parent 9e3a2d7a
...@@ -8,6 +8,11 @@ Cython Changelog ...@@ -8,6 +8,11 @@ Cython Changelog
Features added Features added
-------------- --------------
* Using ``cdef basestring stringvar`` and function arguments typed as
``basestring`` is now meaningful and allows assigning exactly
``bytes`` (Py2-only), ``str`` and ``unicode`` (Py2/Py3) objects,
but no subtypes of these types.
* Support for the ``__debug__`` builtin. * Support for the ``__debug__`` builtin.
* Assertions in Cython compiled modules are disabled if the running * Assertions in Cython compiled modules are disabled if the running
......
...@@ -408,7 +408,7 @@ def init_builtins(): ...@@ -408,7 +408,7 @@ def init_builtins():
'__debug__', PyrexTypes.c_const_type(PyrexTypes.c_bint_type), '__debug__', PyrexTypes.c_const_type(PyrexTypes.c_bint_type),
pos=None, cname='(!Py_OptimizeFlag)', is_cdef=True) pos=None, cname='(!Py_OptimizeFlag)', is_cdef=True)
global list_type, tuple_type, dict_type, set_type, frozenset_type global list_type, tuple_type, dict_type, set_type, frozenset_type
global bytes_type, str_type, unicode_type global bytes_type, str_type, unicode_type, basestring_type
global float_type, bool_type, type_type, complex_type global float_type, bool_type, type_type, complex_type
type_type = builtin_scope.lookup('type').type type_type = builtin_scope.lookup('type').type
list_type = builtin_scope.lookup('list').type list_type = builtin_scope.lookup('list').type
...@@ -419,6 +419,7 @@ def init_builtins(): ...@@ -419,6 +419,7 @@ def init_builtins():
bytes_type = builtin_scope.lookup('bytes').type bytes_type = builtin_scope.lookup('bytes').type
str_type = builtin_scope.lookup('str').type str_type = builtin_scope.lookup('str').type
unicode_type = builtin_scope.lookup('unicode').type unicode_type = builtin_scope.lookup('unicode').type
basestring_type = builtin_scope.lookup('basestring').type
float_type = builtin_scope.lookup('float').type float_type = builtin_scope.lookup('float').type
bool_type = builtin_scope.lookup('bool').type bool_type = builtin_scope.lookup('bool').type
complex_type = builtin_scope.lookup('complex').type complex_type = builtin_scope.lookup('complex').type
......
...@@ -1160,7 +1160,7 @@ class BytesNode(ConstNode): ...@@ -1160,7 +1160,7 @@ class BytesNode(ConstNode):
node = BytesNode(self.pos, value=self.value, node = BytesNode(self.pos, value=self.value,
constant_result=self.constant_result) constant_result=self.constant_result)
if dst_type.is_pyobject: if dst_type.is_pyobject:
if dst_type in (py_object_type, Builtin.bytes_type): if dst_type in (py_object_type, Builtin.bytes_type, Builtin.basestring_type):
node.type = Builtin.bytes_type node.type = Builtin.bytes_type
else: else:
self.check_for_coercion_error(dst_type, env, fail=True) self.check_for_coercion_error(dst_type, env, fail=True)
...@@ -1250,9 +1250,8 @@ class UnicodeNode(ConstNode): ...@@ -1250,9 +1250,8 @@ class UnicodeNode(ConstNode):
"Unicode literals do not support coercion to C types other " "Unicode literals do not support coercion to C types other "
"than Py_UNICODE/Py_UCS4 (for characters) or Py_UNICODE* " "than Py_UNICODE/Py_UCS4 (for characters) or Py_UNICODE* "
"(for strings).") "(for strings).")
elif dst_type is not py_object_type: elif dst_type not in (py_object_type, Builtin.basestring_type):
if not self.check_for_coercion_error(dst_type, env): self.check_for_coercion_error(dst_type, env, fail=True)
self.fail_assignment(dst_type)
return self return self
def can_coerce_to_char_literal(self): def can_coerce_to_char_literal(self):
...@@ -1337,6 +1336,7 @@ class StringNode(PyConstNode): ...@@ -1337,6 +1336,7 @@ class StringNode(PyConstNode):
# return BytesNode(self.pos, value=self.value) # return BytesNode(self.pos, value=self.value)
if not dst_type.is_pyobject: if not dst_type.is_pyobject:
return BytesNode(self.pos, value=self.value).coerce_to(dst_type, env) return BytesNode(self.pos, value=self.value).coerce_to(dst_type, env)
if dst_type is not Builtin.basestring_type:
self.check_for_coercion_error(dst_type, env, fail=True) self.check_for_coercion_error(dst_type, env, fail=True)
return self return self
......
...@@ -962,6 +962,9 @@ class BuiltinObjectType(PyObjectType): ...@@ -962,6 +962,9 @@ class BuiltinObjectType(PyObjectType):
def assignable_from(self, src_type): def assignable_from(self, src_type):
if isinstance(src_type, BuiltinObjectType): if isinstance(src_type, BuiltinObjectType):
if self.name == 'basestring':
return src_type.name in ('bytes', 'str', 'unicode', 'basestring')
else:
return src_type.name == self.name return src_type.name == self.name
elif src_type.is_extension_type: elif src_type.is_extension_type:
# FIXME: This is an ugly special case that we currently # FIXME: This is an ugly special case that we currently
...@@ -1005,7 +1008,15 @@ class BuiltinObjectType(PyObjectType): ...@@ -1005,7 +1008,15 @@ class BuiltinObjectType(PyObjectType):
check = 'likely(%s(%s))' % (type_check, arg) check = 'likely(%s(%s))' % (type_check, arg)
if not notnone: if not notnone:
check += '||((%s) == Py_None)' % arg check += '||((%s) == Py_None)' % arg
error = '(PyErr_Format(PyExc_TypeError, "Expected %s, got %%.200s", Py_TYPE(%s)->tp_name), 0)' % (self.name, arg) if self.name == 'basestring':
name = '(PY_MAJOR_VERSION < 3 ? "basestring" : "str")'
space_for_name = 16
else:
name = '"%s"' % self.name
# avoid wasting too much space but limit number of different format strings
space_for_name = (len(self.name) // 16 + 1) * 16
error = '(PyErr_Format(PyExc_TypeError, "Expected %%.%ds, got %%.200s", %s, Py_TYPE(%s)->tp_name), 0)' % (
space_for_name, name, arg)
return check + '||' + error return check + '||' + error
def declaration_code(self, entity_code, def declaration_code(self, entity_code,
......
...@@ -14,7 +14,10 @@ static int __Pyx_ArgTypeTest(PyObject *obj, PyTypeObject *type, int none_allowed ...@@ -14,7 +14,10 @@ static int __Pyx_ArgTypeTest(PyObject *obj, PyTypeObject *type, int none_allowed
} }
if (none_allowed && obj == Py_None) return 1; if (none_allowed && obj == Py_None) return 1;
else if (exact) { else if (exact) {
if (Py_TYPE(obj) == type) return 1; if (likely(Py_TYPE(obj) == type)) return 1;
#if PY_MAJOR_VERSION == 2
else if ((type == &PyBaseString_Type) && __Pyx_PyBaseString_CheckExact(obj)) return 1;
#endif
} }
else { else {
if (PyObject_TypeCheck(obj, type)) return 1; if (PyObject_TypeCheck(obj, type)) return 1;
......
...@@ -186,7 +186,7 @@ ...@@ -186,7 +186,7 @@
#else #else
#define __Pyx_PyBaseString_Check(obj) (PyString_CheckExact(obj) || PyUnicode_CheckExact(obj) || \ #define __Pyx_PyBaseString_Check(obj) (PyString_CheckExact(obj) || PyUnicode_CheckExact(obj) || \
PyString_Check(obj) || PyUnicode_Check(obj)) PyString_Check(obj) || PyUnicode_Check(obj))
#define __Pyx_PyBaseString_CheckExact(obj) (Py_TYPE(obj) == &PyBaseString_Type) #define __Pyx_PyBaseString_CheckExact(obj) (PyString_CheckExact(obj) || PyUnicode_CheckExact(obj))
#endif #endif
#if PY_VERSION_HEX < 0x02060000 #if PY_VERSION_HEX < 0x02060000
......
...@@ -37,3 +37,52 @@ def unicode_subtypes_basestring(): ...@@ -37,3 +37,52 @@ def unicode_subtypes_basestring():
True True
""" """
return issubclass(unicode, basestring) return issubclass(unicode, basestring)
def basestring_typed_variable(obj):
"""
>>> basestring_typed_variable(None) is None
True
>>> basestring_typed_variable(ustring) is ustring
True
>>> basestring_typed_variable(sstring) is sstring
True
>>> if IS_PY3: print(True)
... else: print(basestring_typed_variable(bstring) is bstring)
True
>>> class S(str): pass
>>> basestring_typed_variable(S()) # doctest: +ELLIPSIS
Traceback (most recent call last):
TypeError: ...got S...
"""
cdef basestring s
s = u'abc'
assert s
s = 'abc'
assert s
s = b'abc'
assert s
# make sure coercion also works in conditional expressions
s = u'abc' if obj else b'abc' if obj else 'abc'
assert s
s = obj
return s
def basestring_typed_argument(basestring obj):
"""
>>> basestring_typed_argument(None) is None
True
>>> basestring_typed_argument(ustring) is ustring
True
>>> basestring_typed_argument(sstring) is sstring
True
>>> if IS_PY3: print(True)
... else: print(basestring_typed_argument(bstring) is bstring)
True
>>> class S(str): pass
>>> basestring_typed_argument(S()) # doctest: +ELLIPSIS
Traceback (most recent call last):
TypeError: ...got S...
"""
return obj
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment