Commit 4ac77f29 authored by Serhiy Storchaka's avatar Serhiy Storchaka

Issue #1470548: XMLGenerator now works with UTF-16 and UTF-32 encodings.

parent 554f0081
...@@ -14,6 +14,7 @@ from xml.sax.expatreader import create_parser ...@@ -14,6 +14,7 @@ from xml.sax.expatreader import create_parser
from xml.sax.handler import feature_namespaces from xml.sax.handler import feature_namespaces
from xml.sax.xmlreader import InputSource, AttributesImpl, AttributesNSImpl from xml.sax.xmlreader import InputSource, AttributesImpl, AttributesNSImpl
from cStringIO import StringIO from cStringIO import StringIO
import io
import os.path import os.path
import shutil import shutil
import test.test_support as support import test.test_support as support
...@@ -170,9 +171,9 @@ class SaxutilsTest(unittest.TestCase): ...@@ -170,9 +171,9 @@ class SaxutilsTest(unittest.TestCase):
start = '<?xml version="1.0" encoding="iso-8859-1"?>\n' start = '<?xml version="1.0" encoding="iso-8859-1"?>\n'
class XmlgenTest(unittest.TestCase): class XmlgenTest:
def test_xmlgen_basic(self): def test_xmlgen_basic(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
gen.startElement("doc", {}) gen.startElement("doc", {})
...@@ -182,7 +183,7 @@ class XmlgenTest(unittest.TestCase): ...@@ -182,7 +183,7 @@ class XmlgenTest(unittest.TestCase):
self.assertEqual(result.getvalue(), start + "<doc></doc>") self.assertEqual(result.getvalue(), start + "<doc></doc>")
def test_xmlgen_content(self): def test_xmlgen_content(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
...@@ -194,7 +195,7 @@ class XmlgenTest(unittest.TestCase): ...@@ -194,7 +195,7 @@ class XmlgenTest(unittest.TestCase):
self.assertEqual(result.getvalue(), start + "<doc>huhei</doc>") self.assertEqual(result.getvalue(), start + "<doc>huhei</doc>")
def test_xmlgen_pi(self): def test_xmlgen_pi(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
...@@ -206,7 +207,7 @@ class XmlgenTest(unittest.TestCase): ...@@ -206,7 +207,7 @@ class XmlgenTest(unittest.TestCase):
self.assertEqual(result.getvalue(), start + "<?test data?><doc></doc>") self.assertEqual(result.getvalue(), start + "<?test data?><doc></doc>")
def test_xmlgen_content_escape(self): def test_xmlgen_content_escape(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
...@@ -219,7 +220,7 @@ class XmlgenTest(unittest.TestCase): ...@@ -219,7 +220,7 @@ class XmlgenTest(unittest.TestCase):
start + "<doc>&lt;huhei&amp;</doc>") start + "<doc>&lt;huhei&amp;</doc>")
def test_xmlgen_attr_escape(self): def test_xmlgen_attr_escape(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
...@@ -238,8 +239,41 @@ class XmlgenTest(unittest.TestCase): ...@@ -238,8 +239,41 @@ class XmlgenTest(unittest.TestCase):
"<e a=\"'&quot;\"></e>" "<e a=\"'&quot;\"></e>"
"<e a=\"&#10;&#13;&#9;\"></e></doc>")) "<e a=\"&#10;&#13;&#9;\"></e></doc>"))
def test_xmlgen_encoding(self):
encodings = ('iso-8859-15', 'utf-8',
'utf-16be', 'utf-16le',
'utf-32be', 'utf-32le')
for encoding in encodings:
result = self.ioclass()
gen = XMLGenerator(result, encoding=encoding)
gen.startDocument()
gen.startElement("doc", {"a": u'\u20ac'})
gen.characters(u"\u20ac")
gen.endElement("doc")
gen.endDocument()
self.assertEqual(result.getvalue(), (
u'<?xml version="1.0" encoding="%s"?>\n'
u'<doc a="\u20ac">\u20ac</doc>' % encoding
).encode(encoding, 'xmlcharrefreplace'))
def test_xmlgen_unencodable(self):
result = self.ioclass()
gen = XMLGenerator(result, encoding='ascii')
gen.startDocument()
gen.startElement("doc", {"a": u'\u20ac'})
gen.characters(u"\u20ac")
gen.endElement("doc")
gen.endDocument()
self.assertEqual(result.getvalue(),
'<?xml version="1.0" encoding="ascii"?>\n'
'<doc a="&#8364;">&#8364;</doc>')
def test_xmlgen_ignorable(self): def test_xmlgen_ignorable(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
...@@ -251,7 +285,7 @@ class XmlgenTest(unittest.TestCase): ...@@ -251,7 +285,7 @@ class XmlgenTest(unittest.TestCase):
self.assertEqual(result.getvalue(), start + "<doc> </doc>") self.assertEqual(result.getvalue(), start + "<doc> </doc>")
def test_xmlgen_ns(self): def test_xmlgen_ns(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
...@@ -269,7 +303,7 @@ class XmlgenTest(unittest.TestCase): ...@@ -269,7 +303,7 @@ class XmlgenTest(unittest.TestCase):
ns_uri)) ns_uri))
def test_1463026_1(self): def test_1463026_1(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
...@@ -280,7 +314,7 @@ class XmlgenTest(unittest.TestCase): ...@@ -280,7 +314,7 @@ class XmlgenTest(unittest.TestCase):
self.assertEqual(result.getvalue(), start+'<a b="c"></a>') self.assertEqual(result.getvalue(), start+'<a b="c"></a>')
def test_1463026_2(self): def test_1463026_2(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
...@@ -293,7 +327,7 @@ class XmlgenTest(unittest.TestCase): ...@@ -293,7 +327,7 @@ class XmlgenTest(unittest.TestCase):
self.assertEqual(result.getvalue(), start+'<a xmlns="qux"></a>') self.assertEqual(result.getvalue(), start+'<a xmlns="qux"></a>')
def test_1463026_3(self): def test_1463026_3(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
...@@ -321,7 +355,7 @@ class XmlgenTest(unittest.TestCase): ...@@ -321,7 +355,7 @@ class XmlgenTest(unittest.TestCase):
parser = make_parser() parser = make_parser()
parser.setFeature(feature_namespaces, True) parser.setFeature(feature_namespaces, True)
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
parser.setContentHandler(gen) parser.setContentHandler(gen)
parser.parse(test_xml) parser.parse(test_xml)
...@@ -340,7 +374,7 @@ class XmlgenTest(unittest.TestCase): ...@@ -340,7 +374,7 @@ class XmlgenTest(unittest.TestCase):
# #
# This test demonstrates the bug by direct manipulation of the # This test demonstrates the bug by direct manipulation of the
# XMLGenerator. # XMLGenerator.
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
...@@ -360,6 +394,29 @@ class XmlgenTest(unittest.TestCase): ...@@ -360,6 +394,29 @@ class XmlgenTest(unittest.TestCase):
'<a:g2 xml:lang="en">Hello</a:g2>' '<a:g2 xml:lang="en">Hello</a:g2>'
'</a:g1>')) '</a:g1>'))
def test_no_close_file(self):
result = self.ioclass()
def func(out):
gen = XMLGenerator(out)
gen.startDocument()
gen.startElement("doc", {})
func(result)
self.assertFalse(result.closed)
class StringXmlgenTest(XmlgenTest, unittest.TestCase):
ioclass = StringIO
class BytesIOXmlgenTest(XmlgenTest, unittest.TestCase):
ioclass = io.BytesIO
class WriterXmlgenTest(XmlgenTest, unittest.TestCase):
class ioclass(list):
write = list.append
closed = False
def getvalue(self):
return b''.join(self)
class XMLFilterBaseTest(unittest.TestCase): class XMLFilterBaseTest(unittest.TestCase):
def test_filter_basic(self): def test_filter_basic(self):
...@@ -804,7 +861,9 @@ class XmlReaderTest(XmlTestBase): ...@@ -804,7 +861,9 @@ class XmlReaderTest(XmlTestBase):
def test_main(): def test_main():
run_unittest(MakeParserTest, run_unittest(MakeParserTest,
SaxutilsTest, SaxutilsTest,
XmlgenTest, StringXmlgenTest,
BytesIOXmlgenTest,
WriterXmlgenTest,
ExpatReaderTest, ExpatReaderTest,
ErrorReportingTest, ErrorReportingTest,
XmlReaderTest) XmlReaderTest)
......
...@@ -4,6 +4,7 @@ convenience of application and driver writers. ...@@ -4,6 +4,7 @@ convenience of application and driver writers.
""" """
import os, urlparse, urllib, types import os, urlparse, urllib, types
import io
import sys import sys
import handler import handler
import xmlreader import xmlreader
...@@ -13,15 +14,6 @@ try: ...@@ -13,15 +14,6 @@ try:
except AttributeError: except AttributeError:
_StringTypes = [types.StringType] _StringTypes = [types.StringType]
# See whether the xmlcharrefreplace error handler is
# supported
try:
from codecs import xmlcharrefreplace_errors
_error_handling = "xmlcharrefreplace"
del xmlcharrefreplace_errors
except ImportError:
_error_handling = "strict"
def __dict_replace(s, d): def __dict_replace(s, d):
"""Replace substrings of a string using a dictionary.""" """Replace substrings of a string using a dictionary."""
for key, value in d.items(): for key, value in d.items():
...@@ -82,25 +74,46 @@ def quoteattr(data, entities={}): ...@@ -82,25 +74,46 @@ def quoteattr(data, entities={}):
return data return data
def _gettextwriter(out, encoding):
if out is None:
import sys
out = sys.stdout
if isinstance(out, io.RawIOBase):
buffer = io.BufferedIOBase(out)
# Keep the original file open when the TextIOWrapper is
# destroyed
buffer.close = lambda: None
else:
# This is to handle passed objects that aren't in the
# IOBase hierarchy, but just have a write method
buffer = io.BufferedIOBase()
buffer.writable = lambda: True
buffer.write = out.write
try:
# TextIOWrapper uses this methods to determine
# if BOM (for UTF-16, etc) should be added
buffer.seekable = out.seekable
buffer.tell = out.tell
except AttributeError:
pass
# wrap a binary writer with TextIOWrapper
return io.TextIOWrapper(buffer, encoding=encoding,
errors='xmlcharrefreplace',
newline='\n')
class XMLGenerator(handler.ContentHandler): class XMLGenerator(handler.ContentHandler):
def __init__(self, out=None, encoding="iso-8859-1"): def __init__(self, out=None, encoding="iso-8859-1"):
if out is None:
import sys
out = sys.stdout
handler.ContentHandler.__init__(self) handler.ContentHandler.__init__(self)
self._out = out out = _gettextwriter(out, encoding)
self._write = out.write
self._flush = out.flush
self._ns_contexts = [{}] # contains uri -> prefix dicts self._ns_contexts = [{}] # contains uri -> prefix dicts
self._current_context = self._ns_contexts[-1] self._current_context = self._ns_contexts[-1]
self._undeclared_ns_maps = [] self._undeclared_ns_maps = []
self._encoding = encoding self._encoding = encoding
def _write(self, text):
if isinstance(text, str):
self._out.write(text)
else:
self._out.write(text.encode(self._encoding, _error_handling))
def _qname(self, name): def _qname(self, name):
"""Builds a qualified name from a (ns_url, localname) pair""" """Builds a qualified name from a (ns_url, localname) pair"""
if name[0]: if name[0]:
...@@ -121,9 +134,12 @@ class XMLGenerator(handler.ContentHandler): ...@@ -121,9 +134,12 @@ class XMLGenerator(handler.ContentHandler):
# ContentHandler methods # ContentHandler methods
def startDocument(self): def startDocument(self):
self._write('<?xml version="1.0" encoding="%s"?>\n' % self._write(u'<?xml version="1.0" encoding="%s"?>\n' %
self._encoding) self._encoding)
def endDocument(self):
self._flush()
def startPrefixMapping(self, prefix, uri): def startPrefixMapping(self, prefix, uri):
self._ns_contexts.append(self._current_context.copy()) self._ns_contexts.append(self._current_context.copy())
self._current_context[uri] = prefix self._current_context[uri] = prefix
...@@ -134,39 +150,39 @@ class XMLGenerator(handler.ContentHandler): ...@@ -134,39 +150,39 @@ class XMLGenerator(handler.ContentHandler):
del self._ns_contexts[-1] del self._ns_contexts[-1]
def startElement(self, name, attrs): def startElement(self, name, attrs):
self._write('<' + name) self._write(u'<' + name)
for (name, value) in attrs.items(): for (name, value) in attrs.items():
self._write(' %s=%s' % (name, quoteattr(value))) self._write(u' %s=%s' % (name, quoteattr(value)))
self._write('>') self._write(u'>')
def endElement(self, name): def endElement(self, name):
self._write('</%s>' % name) self._write(u'</%s>' % name)
def startElementNS(self, name, qname, attrs): def startElementNS(self, name, qname, attrs):
self._write('<' + self._qname(name)) self._write(u'<' + self._qname(name))
for prefix, uri in self._undeclared_ns_maps: for prefix, uri in self._undeclared_ns_maps:
if prefix: if prefix:
self._out.write(' xmlns:%s="%s"' % (prefix, uri)) self._write(u' xmlns:%s="%s"' % (prefix, uri))
else: else:
self._out.write(' xmlns="%s"' % uri) self._write(u' xmlns="%s"' % uri)
self._undeclared_ns_maps = [] self._undeclared_ns_maps = []
for (name, value) in attrs.items(): for (name, value) in attrs.items():
self._write(' %s=%s' % (self._qname(name), quoteattr(value))) self._write(u' %s=%s' % (self._qname(name), quoteattr(value)))
self._write('>') self._write(u'>')
def endElementNS(self, name, qname): def endElementNS(self, name, qname):
self._write('</%s>' % self._qname(name)) self._write(u'</%s>' % self._qname(name))
def characters(self, content): def characters(self, content):
self._write(escape(content)) self._write(escape(unicode(content)))
def ignorableWhitespace(self, content): def ignorableWhitespace(self, content):
self._write(content) self._write(unicode(content))
def processingInstruction(self, target, data): def processingInstruction(self, target, data):
self._write('<?%s %s?>' % (target, data)) self._write(u'<?%s %s?>' % (target, data))
class XMLFilterBase(xmlreader.XMLReader): class XMLFilterBase(xmlreader.XMLReader):
......
...@@ -202,6 +202,8 @@ Core and Builtins ...@@ -202,6 +202,8 @@ Core and Builtins
Library Library
------- -------
- Issue #1470548: XMLGenerator now works with UTF-16 and UTF-32 encodings.
- Issue #6975: os.path.realpath() now correctly resolves multiple nested - Issue #6975: os.path.realpath() now correctly resolves multiple nested
symlinks on POSIX platforms. symlinks on POSIX platforms.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment