Commit 4626458c authored by Guido van Rossum's avatar Guido van Rossum

SF patch# 1767398 by Adam Hupp.

Fix csv to read/write bytes from/to binary files.
Fix the unit tests to test this and to use with TemporaryFile().
parent dd766d53
...@@ -6,7 +6,7 @@ import sys ...@@ -6,7 +6,7 @@ import sys
import os import os
import unittest import unittest
from StringIO import StringIO from StringIO import StringIO
import tempfile from tempfile import TemporaryFile
import csv import csv
import gc import gc
from test import test_support from test import test_support
...@@ -117,17 +117,12 @@ class Test_Csv(unittest.TestCase): ...@@ -117,17 +117,12 @@ class Test_Csv(unittest.TestCase):
def _write_test(self, fields, expect, **kwargs): def _write_test(self, fields, expect, **kwargs):
fd, name = tempfile.mkstemp() with TemporaryFile("w+b") as fileobj:
fileobj = os.fdopen(fd, "w+b")
try:
writer = csv.writer(fileobj, **kwargs) writer = csv.writer(fileobj, **kwargs)
writer.writerow(fields) writer.writerow(fields)
fileobj.seek(0) fileobj.seek(0)
self.assertEqual(fileobj.read(), self.assertEqual(str(fileobj.read()),
expect + writer.dialect.lineterminator) expect + writer.dialect.lineterminator)
finally:
fileobj.close()
os.unlink(name)
def test_write_arg_valid(self): def test_write_arg_valid(self):
self.assertRaises(csv.Error, self._write_test, None, '') self.assertRaises(csv.Error, self._write_test, None, '')
...@@ -192,17 +187,13 @@ class Test_Csv(unittest.TestCase): ...@@ -192,17 +187,13 @@ class Test_Csv(unittest.TestCase):
raise IOError raise IOError
writer = csv.writer(BrokenFile()) writer = csv.writer(BrokenFile())
self.assertRaises(IOError, writer.writerows, [['a']]) self.assertRaises(IOError, writer.writerows, [['a']])
fd, name = tempfile.mkstemp()
fileobj = os.fdopen(fd, "w+b") with TemporaryFile("w+b") as fileobj:
try:
writer = csv.writer(fileobj) writer = csv.writer(fileobj)
self.assertRaises(TypeError, writer.writerows, None) self.assertRaises(TypeError, writer.writerows, None)
writer.writerows([['a','b'],['c','d']]) writer.writerows([['a','b'],['c','d']])
fileobj.seek(0) fileobj.seek(0)
self.assertEqual(fileobj.read(), "a,b\r\nc,d\r\n") self.assertEqual(fileobj.read(), b"a,b\r\nc,d\r\n")
finally:
fileobj.close()
os.unlink(name)
def _read_test(self, input, expect, **kwargs): def _read_test(self, input, expect, **kwargs):
reader = csv.reader(input, **kwargs) reader = csv.reader(input, **kwargs)
...@@ -333,17 +324,19 @@ class TestDialectRegistry(unittest.TestCase): ...@@ -333,17 +324,19 @@ class TestDialectRegistry(unittest.TestCase):
quoting = csv.QUOTE_NONE quoting = csv.QUOTE_NONE
escapechar = "\\" escapechar = "\\"
fd, name = tempfile.mkstemp() with TemporaryFile("w+") as fileobj:
fileobj = os.fdopen(fd, "w+b")
try:
fileobj.write("abc def\nc1ccccc1 benzene\n") fileobj.write("abc def\nc1ccccc1 benzene\n")
fileobj.seek(0) fileobj.seek(0)
reader = csv.reader(fileobj, dialect=space()) reader = csv.reader(fileobj, dialect=space())
self.assertEqual(next(reader), ["abc", "def"]) self.assertEqual(next(reader), ["abc", "def"])
self.assertEqual(next(reader), ["c1ccccc1", "benzene"]) self.assertEqual(next(reader), ["c1ccccc1", "benzene"])
finally:
fileobj.close() def compare_dialect_123(self, expected, *writeargs, **kwwriteargs):
os.unlink(name) with TemporaryFile("w+b") as fileobj:
writer = csv.writer(fileobj, *writeargs, **kwwriteargs)
writer.writerow([1,2,3])
fileobj.seek(0)
self.assertEqual(str(fileobj.read()), expected)
def test_dialect_apply(self): def test_dialect_apply(self):
class testA(csv.excel): class testA(csv.excel):
...@@ -352,63 +345,19 @@ class TestDialectRegistry(unittest.TestCase): ...@@ -352,63 +345,19 @@ class TestDialectRegistry(unittest.TestCase):
delimiter = ":" delimiter = ":"
class testC(csv.excel): class testC(csv.excel):
delimiter = "|" delimiter = "|"
class testUni(csv.excel):
delimiter = "\u039B"
csv.register_dialect('testC', testC) csv.register_dialect('testC', testC)
try: try:
fd, name = tempfile.mkstemp() self.compare_dialect_123("1,2,3\r\n")
fileobj = os.fdopen(fd, "w+b") self.compare_dialect_123("1\t2\t3\r\n", testA)
try: self.compare_dialect_123("1:2:3\r\n", dialect=testB())
writer = csv.writer(fileobj) self.compare_dialect_123("1|2|3\r\n", dialect='testC')
writer.writerow([1,2,3]) self.compare_dialect_123("1;2;3\r\n", dialect=testA,
fileobj.seek(0) delimiter=';')
self.assertEqual(fileobj.read(), "1,2,3\r\n") self.compare_dialect_123("1\u039B2\u039B3\r\n",
finally: dialect=testUni)
fileobj.close()
os.unlink(name)
fd, name = tempfile.mkstemp()
fileobj = os.fdopen(fd, "w+b")
try:
writer = csv.writer(fileobj, testA)
writer.writerow([1,2,3])
fileobj.seek(0)
self.assertEqual(fileobj.read(), "1\t2\t3\r\n")
finally:
fileobj.close()
os.unlink(name)
fd, name = tempfile.mkstemp()
fileobj = os.fdopen(fd, "w+b")
try:
writer = csv.writer(fileobj, dialect=testB())
writer.writerow([1,2,3])
fileobj.seek(0)
self.assertEqual(fileobj.read(), "1:2:3\r\n")
finally:
fileobj.close()
os.unlink(name)
fd, name = tempfile.mkstemp()
fileobj = os.fdopen(fd, "w+b")
try:
writer = csv.writer(fileobj, dialect='testC')
writer.writerow([1,2,3])
fileobj.seek(0)
self.assertEqual(fileobj.read(), "1|2|3\r\n")
finally:
fileobj.close()
os.unlink(name)
fd, name = tempfile.mkstemp()
fileobj = os.fdopen(fd, "w+b")
try:
writer = csv.writer(fileobj, dialect=testA, delimiter=';')
writer.writerow([1,2,3])
fileobj.seek(0)
self.assertEqual(fileobj.read(), "1;2;3\r\n")
finally:
fileobj.close()
os.unlink(name)
finally: finally:
csv.unregister_dialect('testC') csv.unregister_dialect('testC')
...@@ -423,29 +372,19 @@ class TestDialectRegistry(unittest.TestCase): ...@@ -423,29 +372,19 @@ class TestDialectRegistry(unittest.TestCase):
class TestCsvBase(unittest.TestCase): class TestCsvBase(unittest.TestCase):
def readerAssertEqual(self, input, expected_result): def readerAssertEqual(self, input, expected_result):
fd, name = tempfile.mkstemp() with TemporaryFile("w+") as fileobj:
fileobj = os.fdopen(fd, "w+b")
try:
fileobj.write(input) fileobj.write(input)
fileobj.seek(0) fileobj.seek(0)
reader = csv.reader(fileobj, dialect = self.dialect) reader = csv.reader(fileobj, dialect = self.dialect)
fields = list(reader) fields = list(reader)
self.assertEqual(fields, expected_result) self.assertEqual(fields, expected_result)
finally:
fileobj.close()
os.unlink(name)
def writerAssertEqual(self, input, expected_result): def writerAssertEqual(self, input, expected_result):
fd, name = tempfile.mkstemp() with TemporaryFile("w+b") as fileobj:
fileobj = os.fdopen(fd, "w+b")
try:
writer = csv.writer(fileobj, dialect = self.dialect) writer = csv.writer(fileobj, dialect = self.dialect)
writer.writerows(input) writer.writerows(input)
fileobj.seek(0) fileobj.seek(0)
self.assertEqual(fileobj.read(), expected_result) self.assertEqual(str(fileobj.read()), expected_result)
finally:
fileobj.close()
os.unlink(name)
class TestDialectExcel(TestCsvBase): class TestDialectExcel(TestCsvBase):
dialect = 'excel' dialect = 'excel'
...@@ -574,91 +513,59 @@ class TestDictFields(unittest.TestCase): ...@@ -574,91 +513,59 @@ class TestDictFields(unittest.TestCase):
### "long" means the row is longer than the number of fieldnames ### "long" means the row is longer than the number of fieldnames
### "short" means there are fewer elements in the row than fieldnames ### "short" means there are fewer elements in the row than fieldnames
def test_write_simple_dict(self): def test_write_simple_dict(self):
fd, name = tempfile.mkstemp() with TemporaryFile("w+b") as fileobj:
fileobj = os.fdopen(fd, "w+b")
try:
writer = csv.DictWriter(fileobj, fieldnames = ["f1", "f2", "f3"]) writer = csv.DictWriter(fileobj, fieldnames = ["f1", "f2", "f3"])
writer.writerow({"f1": 10, "f3": "abc"}) writer.writerow({"f1": 10, "f3": "abc"})
fileobj.seek(0) fileobj.seek(0)
self.assertEqual(fileobj.read(), "10,,abc\r\n") self.assertEqual(str(fileobj.read()), "10,,abc\r\n")
finally:
fileobj.close()
os.unlink(name)
def test_write_no_fields(self): def test_write_no_fields(self):
fileobj = StringIO() fileobj = StringIO()
self.assertRaises(TypeError, csv.DictWriter, fileobj) self.assertRaises(TypeError, csv.DictWriter, fileobj)
def test_read_dict_fields(self): def test_read_dict_fields(self):
fd, name = tempfile.mkstemp() with TemporaryFile("w+") as fileobj:
fileobj = os.fdopen(fd, "w+b")
try:
fileobj.write("1,2,abc\r\n") fileobj.write("1,2,abc\r\n")
fileobj.seek(0) fileobj.seek(0)
reader = csv.DictReader(fileobj, reader = csv.DictReader(fileobj,
fieldnames=["f1", "f2", "f3"]) fieldnames=["f1", "f2", "f3"])
self.assertEqual(next(reader), {"f1": '1', "f2": '2', "f3": 'abc'}) self.assertEqual(next(reader), {"f1": '1', "f2": '2', "f3": 'abc'})
finally:
fileobj.close()
os.unlink(name)
def test_read_dict_no_fieldnames(self): def test_read_dict_no_fieldnames(self):
fd, name = tempfile.mkstemp() with TemporaryFile("w+") as fileobj:
fileobj = os.fdopen(fd, "w+b")
try:
fileobj.write("f1,f2,f3\r\n1,2,abc\r\n") fileobj.write("f1,f2,f3\r\n1,2,abc\r\n")
fileobj.seek(0) fileobj.seek(0)
reader = csv.DictReader(fileobj) reader = csv.DictReader(fileobj)
self.assertEqual(next(reader), {"f1": '1', "f2": '2', "f3": 'abc'}) self.assertEqual(next(reader), {"f1": '1', "f2": '2', "f3": 'abc'})
finally:
fileobj.close()
os.unlink(name)
def test_read_long(self): def test_read_long(self):
fd, name = tempfile.mkstemp() with TemporaryFile("w+") as fileobj:
fileobj = os.fdopen(fd, "w+b")
try:
fileobj.write("1,2,abc,4,5,6\r\n") fileobj.write("1,2,abc,4,5,6\r\n")
fileobj.seek(0) fileobj.seek(0)
reader = csv.DictReader(fileobj, reader = csv.DictReader(fileobj,
fieldnames=["f1", "f2"]) fieldnames=["f1", "f2"])
self.assertEqual(next(reader), {"f1": '1', "f2": '2', self.assertEqual(next(reader), {"f1": '1', "f2": '2',
None: ["abc", "4", "5", "6"]}) None: ["abc", "4", "5", "6"]})
finally:
fileobj.close()
os.unlink(name)
def test_read_long_with_rest(self): def test_read_long_with_rest(self):
fd, name = tempfile.mkstemp() with TemporaryFile("w+") as fileobj:
fileobj = os.fdopen(fd, "w+b")
try:
fileobj.write("1,2,abc,4,5,6\r\n") fileobj.write("1,2,abc,4,5,6\r\n")
fileobj.seek(0) fileobj.seek(0)
reader = csv.DictReader(fileobj, reader = csv.DictReader(fileobj,
fieldnames=["f1", "f2"], restkey="_rest") fieldnames=["f1", "f2"], restkey="_rest")
self.assertEqual(next(reader), {"f1": '1', "f2": '2', self.assertEqual(next(reader), {"f1": '1', "f2": '2',
"_rest": ["abc", "4", "5", "6"]}) "_rest": ["abc", "4", "5", "6"]})
finally:
fileobj.close()
os.unlink(name)
def test_read_long_with_rest_no_fieldnames(self): def test_read_long_with_rest_no_fieldnames(self):
fd, name = tempfile.mkstemp() with TemporaryFile("w+") as fileobj:
fileobj = os.fdopen(fd, "w+b")
try:
fileobj.write("f1,f2\r\n1,2,abc,4,5,6\r\n") fileobj.write("f1,f2\r\n1,2,abc,4,5,6\r\n")
fileobj.seek(0) fileobj.seek(0)
reader = csv.DictReader(fileobj, restkey="_rest") reader = csv.DictReader(fileobj, restkey="_rest")
self.assertEqual(next(reader), {"f1": '1', "f2": '2', self.assertEqual(next(reader), {"f1": '1', "f2": '2',
"_rest": ["abc", "4", "5", "6"]}) "_rest": ["abc", "4", "5", "6"]})
finally:
fileobj.close()
os.unlink(name)
def test_read_short(self): def test_read_short(self):
fd, name = tempfile.mkstemp() with TemporaryFile("w+") as fileobj:
fileobj = os.fdopen(fd, "w+b")
try:
fileobj.write("1,2,abc,4,5,6\r\n1,2,abc\r\n") fileobj.write("1,2,abc,4,5,6\r\n1,2,abc\r\n")
fileobj.seek(0) fileobj.seek(0)
reader = csv.DictReader(fileobj, reader = csv.DictReader(fileobj,
...@@ -669,9 +576,6 @@ class TestDictFields(unittest.TestCase): ...@@ -669,9 +576,6 @@ class TestDictFields(unittest.TestCase):
self.assertEqual(next(reader), {"1": '1', "2": '2', "3": 'abc', self.assertEqual(next(reader), {"1": '1', "2": '2', "3": 'abc',
"4": 'DEFAULT', "5": 'DEFAULT', "4": 'DEFAULT', "5": 'DEFAULT',
"6": 'DEFAULT'}) "6": 'DEFAULT'})
finally:
fileobj.close()
os.unlink(name)
def test_read_multi(self): def test_read_multi(self):
sample = [ sample = [
...@@ -710,64 +614,45 @@ class TestArrayWrites(unittest.TestCase): ...@@ -710,64 +614,45 @@ class TestArrayWrites(unittest.TestCase):
contents = [(20-i) for i in range(20)] contents = [(20-i) for i in range(20)]
a = array.array('i', contents) a = array.array('i', contents)
fd, name = tempfile.mkstemp() with TemporaryFile("w+b") as fileobj:
fileobj = os.fdopen(fd, "w+b")
try:
writer = csv.writer(fileobj, dialect="excel") writer = csv.writer(fileobj, dialect="excel")
writer.writerow(a) writer.writerow(a)
expected = ",".join([str(i) for i in a])+"\r\n" expected = ",".join([str(i) for i in a])+"\r\n"
fileobj.seek(0) fileobj.seek(0)
self.assertEqual(fileobj.read(), expected) self.assertEqual(str(fileobj.read()), expected)
finally:
fileobj.close()
os.unlink(name)
def test_double_write(self): def test_double_write(self):
import array import array
contents = [(20-i)*0.1 for i in range(20)] contents = [(20-i)*0.1 for i in range(20)]
a = array.array('d', contents) a = array.array('d', contents)
fd, name = tempfile.mkstemp() with TemporaryFile("w+b") as fileobj:
fileobj = os.fdopen(fd, "w+b")
try:
writer = csv.writer(fileobj, dialect="excel") writer = csv.writer(fileobj, dialect="excel")
writer.writerow(a) writer.writerow(a)
expected = ",".join([str(i) for i in a])+"\r\n" expected = ",".join([str(i) for i in a])+"\r\n"
fileobj.seek(0) fileobj.seek(0)
self.assertEqual(fileobj.read(), expected) self.assertEqual(str(fileobj.read()), expected)
finally:
fileobj.close()
os.unlink(name)
def test_float_write(self): def test_float_write(self):
import array import array
contents = [(20-i)*0.1 for i in range(20)] contents = [(20-i)*0.1 for i in range(20)]
a = array.array('f', contents) a = array.array('f', contents)
fd, name = tempfile.mkstemp() with TemporaryFile("w+b") as fileobj:
fileobj = os.fdopen(fd, "w+b")
try:
writer = csv.writer(fileobj, dialect="excel") writer = csv.writer(fileobj, dialect="excel")
writer.writerow(a) writer.writerow(a)
expected = ",".join([str(i) for i in a])+"\r\n" expected = ",".join([str(i) for i in a])+"\r\n"
fileobj.seek(0) fileobj.seek(0)
self.assertEqual(fileobj.read(), expected) self.assertEqual(str(fileobj.read()), expected)
finally:
fileobj.close()
os.unlink(name)
def test_char_write(self): def test_char_write(self):
import array, string import array, string
a = array.array('c', string.letters) a = array.array('u', string.letters)
fd, name = tempfile.mkstemp()
fileobj = os.fdopen(fd, "w+b") with TemporaryFile("w+b") as fileobj:
try:
writer = csv.writer(fileobj, dialect="excel") writer = csv.writer(fileobj, dialect="excel")
writer.writerow(a) writer.writerow(a)
expected = ",".join(a)+"\r\n" expected = ",".join(a)+"\r\n"
fileobj.seek(0) fileobj.seek(0)
self.assertEqual(fileobj.read(), expected) self.assertEqual(str(fileobj.read()), expected)
finally:
fileobj.close()
os.unlink(name)
class TestDialectValidity(unittest.TestCase): class TestDialectValidity(unittest.TestCase):
def test_quoting(self): def test_quoting(self):
...@@ -970,20 +855,36 @@ else: ...@@ -970,20 +855,36 @@ else:
# if writer leaks during write, last delta should be 5 or more # if writer leaks during write, last delta should be 5 or more
self.assertEqual(delta < 5, True) self.assertEqual(delta < 5, True)
# commented out for now - csv module doesn't yet support Unicode class TestUnicode(unittest.TestCase):
## class TestUnicode(unittest.TestCase):
## def test_unicode_read(self): names = ["Martin von Löwis",
## import codecs "Marc André Lemburg",
## f = codecs.EncodedFile(StringIO("Martin von Löwis," "Guido van Rossum",
## "Marc André Lemburg," "François Pinard"]
## "Guido van Rossum,"
## "François Pinard\r\n"), def test_unicode_read(self):
## data_encoding='iso-8859-1') import io
## reader = csv.reader(f) fileobj = io.TextIOWrapper(TemporaryFile("w+b"), encoding="utf-16")
## self.assertEqual(list(reader), [["Martin von Löwis", with fileobj as fileobj:
## "Marc André Lemburg", fileobj.write(",".join(self.names) + "\r\n")
## "Guido van Rossum",
## "François Pinardn"]]) fileobj.seek(0)
reader = csv.reader(fileobj)
self.assertEqual(list(reader), [self.names])
def test_unicode_write(self):
import io
with TemporaryFile("w+b") as fileobj:
encwriter = io.TextIOWrapper(fileobj, encoding="utf-8")
writer = csv.writer(encwriter)
writer.writerow(self.names)
expected = ",".join(self.names)+"\r\n"
fileobj.seek(0)
self.assertEqual(str(fileobj.read()), expected)
def test_main(): def test_main():
mod = sys.modules[__name__] mod = sys.modules[__name__]
......
...@@ -93,11 +93,11 @@ static StyleDesc quote_styles[] = { ...@@ -93,11 +93,11 @@ static StyleDesc quote_styles[] = {
typedef struct { typedef struct {
PyObject_HEAD PyObject_HEAD
int doublequote; /* is " represented by ""? */ int doublequote; /* is " represented by ""? */
char delimiter; /* field separator */ Py_UNICODE delimiter; /* field separator */
char quotechar; /* quote character */ Py_UNICODE quotechar; /* quote character */
char escapechar; /* escape character */ Py_UNICODE escapechar; /* escape character */
int skipinitialspace; /* ignore spaces following delimiter? */ int skipinitialspace; /* ignore spaces following delimiter? */
PyObject *lineterminator; /* string to write between records */ PyObject *lineterminator; /* string to write between records */
int quoting; /* style of quoting to write */ int quoting; /* style of quoting to write */
...@@ -116,9 +116,9 @@ typedef struct { ...@@ -116,9 +116,9 @@ typedef struct {
PyObject *fields; /* field list for current record */ PyObject *fields; /* field list for current record */
ParserState state; /* current CSV parse state */ ParserState state; /* current CSV parse state */
char *field; /* build current field in here */ Py_UNICODE *field; /* build current field in here */
int field_size; /* size of allocated buffer */ int field_size; /* size of allocated buffer */
int field_len; /* length of current field */ Py_ssize_t field_len; /* length of current field */
int numeric_field; /* treat field as numeric */ int numeric_field; /* treat field as numeric */
unsigned long line_num; /* Source-file line number */ unsigned long line_num; /* Source-file line number */
} ReaderObj; } ReaderObj;
...@@ -134,11 +134,11 @@ typedef struct { ...@@ -134,11 +134,11 @@ typedef struct {
DialectObj *dialect; /* parsing dialect */ DialectObj *dialect; /* parsing dialect */
char *rec; /* buffer for parser.join */ Py_UNICODE *rec; /* buffer for parser.join */
int rec_size; /* size of allocated record */ int rec_size; /* size of allocated record */
int rec_len; /* length of record */ Py_ssize_t rec_len; /* length of record */
int num_fields; /* number of fields in record */ int num_fields; /* number of fields in record */
} WriterObj; } WriterObj;
static PyTypeObject Writer_Type; static PyTypeObject Writer_Type;
...@@ -176,7 +176,7 @@ get_nullchar_as_None(char c) ...@@ -176,7 +176,7 @@ get_nullchar_as_None(char c)
return Py_None; return Py_None;
} }
else else
return PyString_FromStringAndSize((char*)&c, 1); return PyUnicode_DecodeASCII((char*)&c, 1, NULL);
} }
static PyObject * static PyObject *
...@@ -230,20 +230,21 @@ _set_int(const char *name, int *target, PyObject *src, int dflt) ...@@ -230,20 +230,21 @@ _set_int(const char *name, int *target, PyObject *src, int dflt)
} }
static int static int
_set_char(const char *name, char *target, PyObject *src, char dflt) _set_char(const char *name, Py_UNICODE *target, PyObject *src, Py_UNICODE dflt)
{ {
if (src == NULL) if (src == NULL)
*target = dflt; *target = dflt;
else { else {
*target = '\0'; *target = '\0';
if (src != Py_None) { if (src != Py_None) {
const char *buf; Py_UNICODE *buf;
Py_ssize_t len; Py_ssize_t len;
if (PyObject_AsCharBuffer(src, &buf, &len) < 0 || buf = PyUnicode_AsUnicode(src);
len > 1) { len = PyUnicode_GetSize(src);
if (buf == NULL || len > 1) {
PyErr_Format(PyExc_TypeError, PyErr_Format(PyExc_TypeError,
"\"%s\" must be an 1-character string", "\"%s\" must be an 1-character string",
name); name);
return -1; return -1;
} }
if (len > 0) if (len > 0)
...@@ -257,7 +258,7 @@ static int ...@@ -257,7 +258,7 @@ static int
_set_str(const char *name, PyObject **target, PyObject *src, const char *dflt) _set_str(const char *name, PyObject **target, PyObject *src, const char *dflt)
{ {
if (src == NULL) if (src == NULL)
*target = PyString_FromString(dflt); *target = PyUnicode_DecodeASCII(dflt, strlen(dflt), NULL);
else { else {
if (src == Py_None) if (src == Py_None)
*target = NULL; *target = NULL;
...@@ -528,7 +529,7 @@ parse_save_field(ReaderObj *self) ...@@ -528,7 +529,7 @@ parse_save_field(ReaderObj *self)
{ {
PyObject *field; PyObject *field;
field = PyString_FromStringAndSize(self->field, self->field_len); field = PyUnicode_FromUnicode(self->field, self->field_len);
if (field == NULL) if (field == NULL)
return -1; return -1;
self->field_len = 0; self->field_len = 0;
...@@ -556,11 +557,12 @@ parse_grow_buff(ReaderObj *self) ...@@ -556,11 +557,12 @@ parse_grow_buff(ReaderObj *self)
self->field_size = 4096; self->field_size = 4096;
if (self->field != NULL) if (self->field != NULL)
PyMem_Free(self->field); PyMem_Free(self->field);
self->field = PyMem_Malloc(self->field_size); self->field = PyMem_New(Py_UNICODE, self->field_size);
} }
else { else {
self->field_size *= 2; self->field_size *= 2;
self->field = PyMem_Realloc(self->field, self->field_size); self->field = PyMem_Resize(self->field, Py_UNICODE,
self->field_size);
} }
if (self->field == NULL) { if (self->field == NULL) {
PyErr_NoMemory(); PyErr_NoMemory();
...@@ -570,7 +572,7 @@ parse_grow_buff(ReaderObj *self) ...@@ -570,7 +572,7 @@ parse_grow_buff(ReaderObj *self)
} }
static int static int
parse_add_char(ReaderObj *self, char c) parse_add_char(ReaderObj *self, Py_UNICODE c)
{ {
if (self->field_len >= field_limit) { if (self->field_len >= field_limit) {
PyErr_Format(error_obj, "field larger than field limit (%ld)", PyErr_Format(error_obj, "field larger than field limit (%ld)",
...@@ -584,7 +586,7 @@ parse_add_char(ReaderObj *self, char c) ...@@ -584,7 +586,7 @@ parse_add_char(ReaderObj *self, char c)
} }
static int static int
parse_process_char(ReaderObj *self, char c) parse_process_char(ReaderObj *self, Py_UNICODE c)
{ {
DialectObj *dialect = self->dialect; DialectObj *dialect = self->dialect;
...@@ -771,8 +773,8 @@ Reader_iternext(ReaderObj *self) ...@@ -771,8 +773,8 @@ Reader_iternext(ReaderObj *self)
{ {
PyObject *lineobj; PyObject *lineobj;
PyObject *fields = NULL; PyObject *fields = NULL;
char *line, c; Py_UNICODE *line, c;
int linelen; Py_ssize_t linelen;
if (parse_reset(self) < 0) if (parse_reset(self) < 0)
return NULL; return NULL;
...@@ -785,11 +787,9 @@ Reader_iternext(ReaderObj *self) ...@@ -785,11 +787,9 @@ Reader_iternext(ReaderObj *self)
"newline inside string"); "newline inside string");
return NULL; return NULL;
} }
++self->line_num; ++self->line_num;
line = PyUnicode_AsUnicode(lineobj);
line = PyString_AsString(lineobj); linelen = PyUnicode_GetSize(lineobj);
linelen = PyString_Size(lineobj);
if (line == NULL || linelen < 0) { if (line == NULL || linelen < 0) {
Py_DECREF(lineobj); Py_DECREF(lineobj);
return NULL; return NULL;
...@@ -962,12 +962,13 @@ join_reset(WriterObj *self) ...@@ -962,12 +962,13 @@ join_reset(WriterObj *self)
* record length. * record length.
*/ */
static int static int
join_append_data(WriterObj *self, char *field, int quote_empty, join_append_data(WriterObj *self, Py_UNICODE *field, int quote_empty,
int *quoted, int copy_phase) int *quoted, int copy_phase)
{ {
DialectObj *dialect = self->dialect; DialectObj *dialect = self->dialect;
int i, rec_len; int i;
char *lineterm; int rec_len;
Py_UNICODE *lineterm;
#define ADDCH(c) \ #define ADDCH(c) \
do {\ do {\
...@@ -976,7 +977,7 @@ join_append_data(WriterObj *self, char *field, int quote_empty, ...@@ -976,7 +977,7 @@ join_append_data(WriterObj *self, char *field, int quote_empty,
rec_len++;\ rec_len++;\
} while(0) } while(0)
lineterm = PyString_AsString(dialect->lineterminator); lineterm = PyUnicode_AsUnicode(dialect->lineterminator);
if (lineterm == NULL) if (lineterm == NULL)
return -1; return -1;
...@@ -991,8 +992,9 @@ join_append_data(WriterObj *self, char *field, int quote_empty, ...@@ -991,8 +992,9 @@ join_append_data(WriterObj *self, char *field, int quote_empty,
ADDCH(dialect->quotechar); ADDCH(dialect->quotechar);
/* Copy/count field data */ /* Copy/count field data */
for (i = 0;; i++) { /* If field is null just pass over */
char c = field[i]; for (i = 0; field; i++) {
Py_UNICODE c = field[i];
int want_escape = 0; int want_escape = 0;
if (c == '\0') if (c == '\0')
...@@ -1000,8 +1002,8 @@ join_append_data(WriterObj *self, char *field, int quote_empty, ...@@ -1000,8 +1002,8 @@ join_append_data(WriterObj *self, char *field, int quote_empty,
if (c == dialect->delimiter || if (c == dialect->delimiter ||
c == dialect->escapechar || c == dialect->escapechar ||
c == dialect->quotechar || c == dialect->quotechar ||
strchr(lineterm, c)) { Py_UNICODE_strchr(lineterm, c)) {
if (dialect->quoting == QUOTE_NONE) if (dialect->quoting == QUOTE_NONE)
want_escape = 1; want_escape = 1;
else { else {
...@@ -1033,7 +1035,7 @@ join_append_data(WriterObj *self, char *field, int quote_empty, ...@@ -1033,7 +1035,7 @@ join_append_data(WriterObj *self, char *field, int quote_empty,
if (i == 0 && quote_empty) { if (i == 0 && quote_empty) {
if (dialect->quoting == QUOTE_NONE) { if (dialect->quoting == QUOTE_NONE) {
PyErr_Format(error_obj, PyErr_Format(error_obj,
"single empty field record must be quoted"); "single empty field record must be quoted");
return -1; return -1;
} }
else else
...@@ -1058,13 +1060,14 @@ join_check_rec_size(WriterObj *self, int rec_len) ...@@ -1058,13 +1060,14 @@ join_check_rec_size(WriterObj *self, int rec_len)
self->rec_size = (rec_len / MEM_INCR + 1) * MEM_INCR; self->rec_size = (rec_len / MEM_INCR + 1) * MEM_INCR;
if (self->rec != NULL) if (self->rec != NULL)
PyMem_Free(self->rec); PyMem_Free(self->rec);
self->rec = PyMem_Malloc(self->rec_size); self->rec = PyMem_New(Py_UNICODE, self->rec_size);
} }
else { else {
char *old_rec = self->rec; Py_UNICODE* old_rec = self->rec;
self->rec_size = (rec_len / MEM_INCR + 1) * MEM_INCR; self->rec_size = (rec_len / MEM_INCR + 1) * MEM_INCR;
self->rec = PyMem_Realloc(self->rec, self->rec_size); self->rec = PyMem_Resize(self->rec, Py_UNICODE,
self->rec_size);
if (self->rec == NULL) if (self->rec == NULL)
PyMem_Free(old_rec); PyMem_Free(old_rec);
} }
...@@ -1077,7 +1080,7 @@ join_check_rec_size(WriterObj *self, int rec_len) ...@@ -1077,7 +1080,7 @@ join_check_rec_size(WriterObj *self, int rec_len)
} }
static int static int
join_append(WriterObj *self, char *field, int *quoted, int quote_empty) join_append(WriterObj *self, Py_UNICODE *field, int *quoted, int quote_empty)
{ {
int rec_len; int rec_len;
...@@ -1099,9 +1102,9 @@ static int ...@@ -1099,9 +1102,9 @@ static int
join_append_lineterminator(WriterObj *self) join_append_lineterminator(WriterObj *self)
{ {
int terminator_len; int terminator_len;
char *terminator; Py_UNICODE *terminator;
terminator_len = PyString_Size(self->dialect->lineterminator); terminator_len = PyUnicode_GetSize(self->dialect->lineterminator);
if (terminator_len == -1) if (terminator_len == -1)
return 0; return 0;
...@@ -1109,10 +1112,11 @@ join_append_lineterminator(WriterObj *self) ...@@ -1109,10 +1112,11 @@ join_append_lineterminator(WriterObj *self)
if (!join_check_rec_size(self, self->rec_len + terminator_len)) if (!join_check_rec_size(self, self->rec_len + terminator_len))
return 0; return 0;
terminator = PyString_AsString(self->dialect->lineterminator); terminator = PyUnicode_AsUnicode(self->dialect->lineterminator);
if (terminator == NULL) if (terminator == NULL)
return 0; return 0;
memmove(self->rec + self->rec_len, terminator, terminator_len); memmove(self->rec + self->rec_len, terminator,
sizeof(Py_UNICODE)*terminator_len);
self->rec_len += terminator_len; self->rec_len += terminator_len;
return 1; return 1;
...@@ -1161,26 +1165,27 @@ csv_writerow(WriterObj *self, PyObject *seq) ...@@ -1161,26 +1165,27 @@ csv_writerow(WriterObj *self, PyObject *seq)
break; break;
} }
if (PyString_Check(field)) { if (PyUnicode_Check(field)) {
append_ok = join_append(self, append_ok = join_append(self,
PyString_AS_STRING(field), PyUnicode_AS_UNICODE(field),
&quoted, len == 1); &quoted, len == 1);
Py_DECREF(field); Py_DECREF(field);
} }
else if (field == Py_None) { else if (field == Py_None) {
append_ok = join_append(self, "", &quoted, len == 1); append_ok = join_append(self, NULL,
&quoted, len == 1);
Py_DECREF(field); Py_DECREF(field);
} }
else { else {
PyObject *str; PyObject *str;
str = PyObject_Str(field); str = PyObject_Unicode(field);
Py_DECREF(field); Py_DECREF(field);
if (str == NULL) if (str == NULL)
return NULL; return NULL;
append_ok = join_append(self,
append_ok = join_append(self, PyString_AS_STRING(str), PyUnicode_AS_UNICODE(str),
&quoted, len == 1); &quoted, len == 1);
Py_DECREF(str); Py_DECREF(str);
} }
if (!append_ok) if (!append_ok)
...@@ -1192,8 +1197,9 @@ csv_writerow(WriterObj *self, PyObject *seq) ...@@ -1192,8 +1197,9 @@ csv_writerow(WriterObj *self, PyObject *seq)
if (!join_append_lineterminator(self)) if (!join_append_lineterminator(self))
return 0; return 0;
return PyObject_CallFunction(self->writeline, return PyObject_CallFunction(self->writeline,
"(s#)", self->rec, self->rec_len); "(u#)", self->rec,
self->rec_len);
} }
PyDoc_STRVAR(csv_writerows_doc, PyDoc_STRVAR(csv_writerows_doc,
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment