Commit 4355a479 authored by Guido van Rossum's avatar Guido van Rossum

Make all of test_bytes pass (except pickling, which is too badly busted).

parent 6c1e6741
...@@ -102,35 +102,35 @@ class BytesTest(unittest.TestCase): ...@@ -102,35 +102,35 @@ class BytesTest(unittest.TestCase):
self.failIf(b3 <= b2) self.failIf(b3 <= b2)
def test_compare_to_str(self): def test_compare_to_str(self):
self.assertEqual(b"abc" == "abc", True) self.assertEqual(b"abc" == str8("abc"), True)
self.assertEqual(b"ab" != "abc", True) self.assertEqual(b"ab" != str8("abc"), True)
self.assertEqual(b"ab" <= "abc", True) self.assertEqual(b"ab" <= str8("abc"), True)
self.assertEqual(b"ab" < "abc", True) self.assertEqual(b"ab" < str8("abc"), True)
self.assertEqual(b"abc" >= "ab", True) self.assertEqual(b"abc" >= str8("ab"), True)
self.assertEqual(b"abc" > "ab", True) self.assertEqual(b"abc" > str8("ab"), True)
self.assertEqual(b"abc" != "abc", False) self.assertEqual(b"abc" != str8("abc"), False)
self.assertEqual(b"ab" == "abc", False) self.assertEqual(b"ab" == str8("abc"), False)
self.assertEqual(b"ab" > "abc", False) self.assertEqual(b"ab" > str8("abc"), False)
self.assertEqual(b"ab" >= "abc", False) self.assertEqual(b"ab" >= str8("abc"), False)
self.assertEqual(b"abc" < "ab", False) self.assertEqual(b"abc" < str8("ab"), False)
self.assertEqual(b"abc" <= "ab", False) self.assertEqual(b"abc" <= str8("ab"), False)
self.assertEqual("abc" == b"abc", True) self.assertEqual(str8("abc") == b"abc", True)
self.assertEqual("ab" != b"abc", True) self.assertEqual(str8("ab") != b"abc", True)
self.assertEqual("ab" <= b"abc", True) self.assertEqual(str8("ab") <= b"abc", True)
self.assertEqual("ab" < b"abc", True) self.assertEqual(str8("ab") < b"abc", True)
self.assertEqual("abc" >= b"ab", True) self.assertEqual(str8("abc") >= b"ab", True)
self.assertEqual("abc" > b"ab", True) self.assertEqual(str8("abc") > b"ab", True)
self.assertEqual("abc" != b"abc", False) self.assertEqual(str8("abc") != b"abc", False)
self.assertEqual("ab" == b"abc", False) self.assertEqual(str8("ab") == b"abc", False)
self.assertEqual("ab" > b"abc", False) self.assertEqual(str8("ab") > b"abc", False)
self.assertEqual("ab" >= b"abc", False) self.assertEqual(str8("ab") >= b"abc", False)
self.assertEqual("abc" < b"ab", False) self.assertEqual(str8("abc") < b"ab", False)
self.assertEqual("abc" <= b"ab", False) self.assertEqual(str8("abc") <= b"ab", False)
# But they should never compare equal to Unicode! # Bytes should never compare equal to Unicode!
# Test this for all expected byte orders and Unicode character sizes # Test this for all expected byte orders and Unicode character sizes
self.assertEqual(b"\0a\0b\0c" == "abc", False) self.assertEqual(b"\0a\0b\0c" == "abc", False)
self.assertEqual(b"\0\0\0a\0\0\0b\0\0\0c" == "abc", False) self.assertEqual(b"\0\0\0a\0\0\0b\0\0\0c" == "abc", False)
...@@ -326,7 +326,7 @@ class BytesTest(unittest.TestCase): ...@@ -326,7 +326,7 @@ class BytesTest(unittest.TestCase):
sample = "Hello world\n\u1234\u5678\u9abc\udef0" sample = "Hello world\n\u1234\u5678\u9abc\udef0"
for enc in ("utf8", "utf16"): for enc in ("utf8", "utf16"):
b = bytes(sample, enc) b = bytes(sample, enc)
self.assertEqual(b, bytes(map(ord, sample.encode(enc)))) self.assertEqual(b, bytes(sample.encode(enc)))
self.assertRaises(UnicodeEncodeError, bytes, sample, "latin1") self.assertRaises(UnicodeEncodeError, bytes, sample, "latin1")
b = bytes(sample, "latin1", "ignore") b = bytes(sample, "latin1", "ignore")
self.assertEqual(b, bytes(sample[:-4])) self.assertEqual(b, bytes(sample[:-4]))
...@@ -342,7 +342,7 @@ class BytesTest(unittest.TestCase): ...@@ -342,7 +342,7 @@ class BytesTest(unittest.TestCase):
self.assertEqual(b.decode("utf8", "ignore"), "Hello world\n") self.assertEqual(b.decode("utf8", "ignore"), "Hello world\n")
def test_from_buffer(self): def test_from_buffer(self):
sample = "Hello world\n\x80\x81\xfe\xff" sample = str8("Hello world\n\x80\x81\xfe\xff")
buf = buffer(sample) buf = buffer(sample)
b = bytes(buf) b = bytes(buf)
self.assertEqual(b, bytes(map(ord, sample))) self.assertEqual(b, bytes(map(ord, sample)))
...@@ -364,8 +364,8 @@ class BytesTest(unittest.TestCase): ...@@ -364,8 +364,8 @@ class BytesTest(unittest.TestCase):
b1 = bytes("abc") b1 = bytes("abc")
b2 = bytes("def") b2 = bytes("def")
self.assertEqual(b1 + b2, bytes("abcdef")) self.assertEqual(b1 + b2, bytes("abcdef"))
self.assertEqual(b1 + "def", bytes("abcdef")) self.assertEqual(b1 + str8("def"), bytes("abcdef"))
self.assertEqual("def" + b1, bytes("defabc")) self.assertEqual(str8("def") + b1, bytes("defabc"))
self.assertRaises(TypeError, lambda: b1 + "def") self.assertRaises(TypeError, lambda: b1 + "def")
self.assertRaises(TypeError, lambda: "abc" + b2) self.assertRaises(TypeError, lambda: "abc" + b2)
...@@ -388,7 +388,7 @@ class BytesTest(unittest.TestCase): ...@@ -388,7 +388,7 @@ class BytesTest(unittest.TestCase):
self.assertEqual(b, bytes("abcdef")) self.assertEqual(b, bytes("abcdef"))
self.assertEqual(b, b1) self.assertEqual(b, b1)
self.failUnless(b is b1) self.failUnless(b is b1)
b += "xyz" b += str8("xyz")
self.assertEqual(b, b"abcdefxyz") self.assertEqual(b, b"abcdefxyz")
try: try:
b += "" b += ""
...@@ -456,8 +456,8 @@ class BytesTest(unittest.TestCase): ...@@ -456,8 +456,8 @@ class BytesTest(unittest.TestCase):
b = bytes([0x1a, 0x2b, 0x30]) b = bytes([0x1a, 0x2b, 0x30])
self.assertEquals(bytes.fromhex('1a2B30'), b) self.assertEquals(bytes.fromhex('1a2B30'), b)
self.assertEquals(bytes.fromhex(' 1A 2B 30 '), b) self.assertEquals(bytes.fromhex(' 1A 2B 30 '), b)
self.assertEquals(bytes.fromhex(buffer('')), bytes()) self.assertEquals(bytes.fromhex(buffer(b'')), bytes())
self.assertEquals(bytes.fromhex(buffer('0000')), bytes([0, 0])) self.assertEquals(bytes.fromhex(buffer(b'0000')), bytes([0, 0]))
self.assertRaises(ValueError, bytes.fromhex, 'a') self.assertRaises(ValueError, bytes.fromhex, 'a')
self.assertRaises(ValueError, bytes.fromhex, 'rt') self.assertRaises(ValueError, bytes.fromhex, 'rt')
self.assertRaises(ValueError, bytes.fromhex, '1a b cd') self.assertRaises(ValueError, bytes.fromhex, '1a b cd')
...@@ -717,5 +717,5 @@ def test_main(): ...@@ -717,5 +717,5 @@ def test_main():
if __name__ == "__main__": if __name__ == "__main__":
test_main() ##test_main()
##unittest.main() unittest.main()
...@@ -218,6 +218,7 @@ bytes_iconcat(PyBytesObject *self, PyObject *other) ...@@ -218,6 +218,7 @@ bytes_iconcat(PyBytesObject *self, PyObject *other)
Py_ssize_t mysize; Py_ssize_t mysize;
Py_ssize_t size; Py_ssize_t size;
/* XXX What if other == self? */
osize = _getbuffer(other, &optr); osize = _getbuffer(other, &optr);
if (osize < 0) { if (osize < 0) {
PyErr_Format(PyExc_TypeError, PyErr_Format(PyExc_TypeError,
...@@ -698,33 +699,24 @@ bytes_init(PyBytesObject *self, PyObject *args, PyObject *kwds) ...@@ -698,33 +699,24 @@ bytes_init(PyBytesObject *self, PyObject *args, PyObject *kwds)
if (PyUnicode_Check(arg)) { if (PyUnicode_Check(arg)) {
/* Encode via the codec registry */ /* Encode via the codec registry */
PyObject *encoded; PyObject *encoded, *new;
char *bytes;
Py_ssize_t size;
if (encoding == NULL) if (encoding == NULL)
encoding = PyUnicode_GetDefaultEncoding(); encoding = PyUnicode_GetDefaultEncoding();
encoded = PyCodec_Encode(arg, encoding, errors); encoded = PyCodec_Encode(arg, encoding, errors);
if (encoded == NULL) if (encoded == NULL)
return -1; return -1;
if (!PyString_Check(encoded)) { if (!PyBytes_Check(encoded) && !PyString_Check(encoded)) {
PyErr_Format(PyExc_TypeError, PyErr_Format(PyExc_TypeError,
"encoder did not return a string object (type=%.400s)", "encoder did not return a str8 or bytes object (type=%.400s)",
encoded->ob_type->tp_name); encoded->ob_type->tp_name);
Py_DECREF(encoded); Py_DECREF(encoded);
return -1; return -1;
} }
bytes = PyString_AS_STRING(encoded); new = bytes_iconcat(self, encoded);
size = PyString_GET_SIZE(encoded);
if (size < self->ob_alloc) {
self->ob_size = size;
self->ob_bytes[self->ob_size] = '\0'; /* Trailing null byte */
}
else if (PyBytes_Resize((PyObject *)self, size) < 0) {
Py_DECREF(encoded); Py_DECREF(encoded);
if (new == NULL)
return -1; return -1;
} Py_DECREF(new);
memcpy(self->ob_bytes, bytes, size);
Py_DECREF(encoded);
return 0; return 0;
} }
...@@ -2689,7 +2681,7 @@ bytes_fromhex(PyObject *cls, PyObject *args) ...@@ -2689,7 +2681,7 @@ bytes_fromhex(PyObject *cls, PyObject *args)
return NULL; return NULL;
buf = PyBytes_AS_STRING(newbytes); buf = PyBytes_AS_STRING(newbytes);
for (i = j = 0; ; i += 2) { for (i = j = 0; i < len; i += 2) {
/* skip over spaces in the input */ /* skip over spaces in the input */
while (Py_CHARMASK(hex[i]) == ' ') while (Py_CHARMASK(hex[i]) == ' ')
i++; i++;
......
...@@ -5634,6 +5634,12 @@ unicode_encode(PyUnicodeObject *self, PyObject *args) ...@@ -5634,6 +5634,12 @@ unicode_encode(PyUnicodeObject *self, PyObject *args)
if (v == NULL) if (v == NULL)
goto onError; goto onError;
if (!PyBytes_Check(v)) { if (!PyBytes_Check(v)) {
if (PyString_Check(v)) {
/* Old codec, turn it into bytes */
PyObject *b = PyBytes_FromObject(v);
Py_DECREF(v);
return b;
}
PyErr_Format(PyExc_TypeError, PyErr_Format(PyExc_TypeError,
"encoder did not return a bytes object " "encoder did not return a bytes object "
"(type=%.400s)", "(type=%.400s)",
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment