Commit 9fef188c authored by Serhiy Storchaka's avatar Serhiy Storchaka

Issue #1470548: XMLGenerator now works with binary output streams.

parents d83c8244 88efc52d
...@@ -13,7 +13,7 @@ from xml.sax.saxutils import XMLGenerator, escape, unescape, quoteattr, \ ...@@ -13,7 +13,7 @@ from xml.sax.saxutils import XMLGenerator, escape, unescape, quoteattr, \
from xml.sax.expatreader import create_parser from xml.sax.expatreader import create_parser
from xml.sax.handler import feature_namespaces from xml.sax.handler import feature_namespaces
from xml.sax.xmlreader import InputSource, AttributesImpl, AttributesNSImpl from xml.sax.xmlreader import InputSource, AttributesImpl, AttributesNSImpl
from io import StringIO from io import BytesIO, StringIO
import os.path import os.path
import shutil import shutil
from test import support from test import support
...@@ -173,31 +173,29 @@ class SaxutilsTest(unittest.TestCase): ...@@ -173,31 +173,29 @@ class SaxutilsTest(unittest.TestCase):
# ===== XMLGenerator # ===== XMLGenerator
start = '<?xml version="1.0" encoding="iso-8859-1"?>\n' class XmlgenTest:
class XmlgenTest(unittest.TestCase):
def test_xmlgen_basic(self): def test_xmlgen_basic(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
gen.startElement("doc", {}) gen.startElement("doc", {})
gen.endElement("doc") gen.endElement("doc")
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), start + "<doc></doc>") self.assertEqual(result.getvalue(), self.xml("<doc></doc>"))
def test_xmlgen_basic_empty(self): def test_xmlgen_basic_empty(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result, short_empty_elements=True) gen = XMLGenerator(result, short_empty_elements=True)
gen.startDocument() gen.startDocument()
gen.startElement("doc", {}) gen.startElement("doc", {})
gen.endElement("doc") gen.endElement("doc")
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), start + "<doc/>") self.assertEqual(result.getvalue(), self.xml("<doc/>"))
def test_xmlgen_content(self): def test_xmlgen_content(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
...@@ -206,10 +204,10 @@ class XmlgenTest(unittest.TestCase): ...@@ -206,10 +204,10 @@ class XmlgenTest(unittest.TestCase):
gen.endElement("doc") gen.endElement("doc")
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), start + "<doc>huhei</doc>") self.assertEqual(result.getvalue(), self.xml("<doc>huhei</doc>"))
def test_xmlgen_content_empty(self): def test_xmlgen_content_empty(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result, short_empty_elements=True) gen = XMLGenerator(result, short_empty_elements=True)
gen.startDocument() gen.startDocument()
...@@ -218,10 +216,10 @@ class XmlgenTest(unittest.TestCase): ...@@ -218,10 +216,10 @@ class XmlgenTest(unittest.TestCase):
gen.endElement("doc") gen.endElement("doc")
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), start + "<doc>huhei</doc>") self.assertEqual(result.getvalue(), self.xml("<doc>huhei</doc>"))
def test_xmlgen_pi(self): def test_xmlgen_pi(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
...@@ -230,10 +228,11 @@ class XmlgenTest(unittest.TestCase): ...@@ -230,10 +228,11 @@ class XmlgenTest(unittest.TestCase):
gen.endElement("doc") gen.endElement("doc")
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), start + "<?test data?><doc></doc>") self.assertEqual(result.getvalue(),
self.xml("<?test data?><doc></doc>"))
def test_xmlgen_content_escape(self): def test_xmlgen_content_escape(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
...@@ -243,10 +242,10 @@ class XmlgenTest(unittest.TestCase): ...@@ -243,10 +242,10 @@ class XmlgenTest(unittest.TestCase):
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), self.assertEqual(result.getvalue(),
start + "<doc>&lt;huhei&amp;</doc>") self.xml("<doc>&lt;huhei&amp;</doc>"))
def test_xmlgen_attr_escape(self): def test_xmlgen_attr_escape(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
...@@ -260,13 +259,43 @@ class XmlgenTest(unittest.TestCase): ...@@ -260,13 +259,43 @@ class XmlgenTest(unittest.TestCase):
gen.endElement("doc") gen.endElement("doc")
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), start + self.assertEqual(result.getvalue(), self.xml(
("<doc a='\"'><e a=\"'\"></e>" "<doc a='\"'><e a=\"'\"></e>"
"<e a=\"'&quot;\"></e>" "<e a=\"'&quot;\"></e>"
"<e a=\"&#10;&#13;&#9;\"></e></doc>")) "<e a=\"&#10;&#13;&#9;\"></e></doc>"))
def test_xmlgen_encoding(self):
encodings = ('iso-8859-15', 'utf-8', 'utf-8-sig',
'utf-16', 'utf-16be', 'utf-16le',
'utf-32', 'utf-32be', 'utf-32le')
for encoding in encodings:
result = self.ioclass()
gen = XMLGenerator(result, encoding=encoding)
gen.startDocument()
gen.startElement("doc", {"a": '\u20ac'})
gen.characters("\u20ac")
gen.endElement("doc")
gen.endDocument()
self.assertEqual(result.getvalue(),
self.xml('<doc a="\u20ac">\u20ac</doc>', encoding=encoding))
def test_xmlgen_unencodable(self):
result = self.ioclass()
gen = XMLGenerator(result, encoding='ascii')
gen.startDocument()
gen.startElement("doc", {"a": '\u20ac'})
gen.characters("\u20ac")
gen.endElement("doc")
gen.endDocument()
self.assertEqual(result.getvalue(),
self.xml('<doc a="&#8364;">&#8364;</doc>', encoding='ascii'))
def test_xmlgen_ignorable(self): def test_xmlgen_ignorable(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
...@@ -275,10 +304,10 @@ class XmlgenTest(unittest.TestCase): ...@@ -275,10 +304,10 @@ class XmlgenTest(unittest.TestCase):
gen.endElement("doc") gen.endElement("doc")
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), start + "<doc> </doc>") self.assertEqual(result.getvalue(), self.xml("<doc> </doc>"))
def test_xmlgen_ignorable_empty(self): def test_xmlgen_ignorable_empty(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result, short_empty_elements=True) gen = XMLGenerator(result, short_empty_elements=True)
gen.startDocument() gen.startDocument()
...@@ -287,10 +316,10 @@ class XmlgenTest(unittest.TestCase): ...@@ -287,10 +316,10 @@ class XmlgenTest(unittest.TestCase):
gen.endElement("doc") gen.endElement("doc")
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), start + "<doc> </doc>") self.assertEqual(result.getvalue(), self.xml("<doc> </doc>"))
def test_xmlgen_ns(self): def test_xmlgen_ns(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
...@@ -303,12 +332,12 @@ class XmlgenTest(unittest.TestCase): ...@@ -303,12 +332,12 @@ class XmlgenTest(unittest.TestCase):
gen.endPrefixMapping("ns1") gen.endPrefixMapping("ns1")
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), start + \ self.assertEqual(result.getvalue(), self.xml(
('<ns1:doc xmlns:ns1="%s"><udoc></udoc></ns1:doc>' % '<ns1:doc xmlns:ns1="%s"><udoc></udoc></ns1:doc>' %
ns_uri)) ns_uri))
def test_xmlgen_ns_empty(self): def test_xmlgen_ns_empty(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result, short_empty_elements=True) gen = XMLGenerator(result, short_empty_elements=True)
gen.startDocument() gen.startDocument()
...@@ -321,12 +350,12 @@ class XmlgenTest(unittest.TestCase): ...@@ -321,12 +350,12 @@ class XmlgenTest(unittest.TestCase):
gen.endPrefixMapping("ns1") gen.endPrefixMapping("ns1")
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), start + \ self.assertEqual(result.getvalue(), self.xml(
('<ns1:doc xmlns:ns1="%s"><udoc/></ns1:doc>' % '<ns1:doc xmlns:ns1="%s"><udoc/></ns1:doc>' %
ns_uri)) ns_uri))
def test_1463026_1(self): def test_1463026_1(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
...@@ -334,10 +363,10 @@ class XmlgenTest(unittest.TestCase): ...@@ -334,10 +363,10 @@ class XmlgenTest(unittest.TestCase):
gen.endElementNS((None, 'a'), 'a') gen.endElementNS((None, 'a'), 'a')
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), start+'<a b="c"></a>') self.assertEqual(result.getvalue(), self.xml('<a b="c"></a>'))
def test_1463026_1_empty(self): def test_1463026_1_empty(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result, short_empty_elements=True) gen = XMLGenerator(result, short_empty_elements=True)
gen.startDocument() gen.startDocument()
...@@ -345,10 +374,10 @@ class XmlgenTest(unittest.TestCase): ...@@ -345,10 +374,10 @@ class XmlgenTest(unittest.TestCase):
gen.endElementNS((None, 'a'), 'a') gen.endElementNS((None, 'a'), 'a')
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), start+'<a b="c"/>') self.assertEqual(result.getvalue(), self.xml('<a b="c"/>'))
def test_1463026_2(self): def test_1463026_2(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
...@@ -358,10 +387,10 @@ class XmlgenTest(unittest.TestCase): ...@@ -358,10 +387,10 @@ class XmlgenTest(unittest.TestCase):
gen.endPrefixMapping(None) gen.endPrefixMapping(None)
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), start+'<a xmlns="qux"></a>') self.assertEqual(result.getvalue(), self.xml('<a xmlns="qux"></a>'))
def test_1463026_2_empty(self): def test_1463026_2_empty(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result, short_empty_elements=True) gen = XMLGenerator(result, short_empty_elements=True)
gen.startDocument() gen.startDocument()
...@@ -371,10 +400,10 @@ class XmlgenTest(unittest.TestCase): ...@@ -371,10 +400,10 @@ class XmlgenTest(unittest.TestCase):
gen.endPrefixMapping(None) gen.endPrefixMapping(None)
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), start+'<a xmlns="qux"/>') self.assertEqual(result.getvalue(), self.xml('<a xmlns="qux"/>'))
def test_1463026_3(self): def test_1463026_3(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
...@@ -385,10 +414,10 @@ class XmlgenTest(unittest.TestCase): ...@@ -385,10 +414,10 @@ class XmlgenTest(unittest.TestCase):
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), self.assertEqual(result.getvalue(),
start+'<my:a xmlns:my="qux" b="c"></my:a>') self.xml('<my:a xmlns:my="qux" b="c"></my:a>'))
def test_1463026_3_empty(self): def test_1463026_3_empty(self):
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result, short_empty_elements=True) gen = XMLGenerator(result, short_empty_elements=True)
gen.startDocument() gen.startDocument()
...@@ -399,7 +428,7 @@ class XmlgenTest(unittest.TestCase): ...@@ -399,7 +428,7 @@ class XmlgenTest(unittest.TestCase):
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), self.assertEqual(result.getvalue(),
start+'<my:a xmlns:my="qux" b="c"/>') self.xml('<my:a xmlns:my="qux" b="c"/>'))
def test_5027_1(self): def test_5027_1(self):
# The xml prefix (as in xml:lang below) is reserved and bound by # The xml prefix (as in xml:lang below) is reserved and bound by
...@@ -416,13 +445,13 @@ class XmlgenTest(unittest.TestCase): ...@@ -416,13 +445,13 @@ class XmlgenTest(unittest.TestCase):
parser = make_parser() parser = make_parser()
parser.setFeature(feature_namespaces, True) parser.setFeature(feature_namespaces, True)
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
parser.setContentHandler(gen) parser.setContentHandler(gen)
parser.parse(test_xml) parser.parse(test_xml)
self.assertEqual(result.getvalue(), self.assertEqual(result.getvalue(),
start + ( self.xml(
'<a:g1 xmlns:a="http://example.com/ns">' '<a:g1 xmlns:a="http://example.com/ns">'
'<a:g2 xml:lang="en">Hello</a:g2>' '<a:g2 xml:lang="en">Hello</a:g2>'
'</a:g1>')) '</a:g1>'))
...@@ -435,7 +464,7 @@ class XmlgenTest(unittest.TestCase): ...@@ -435,7 +464,7 @@ class XmlgenTest(unittest.TestCase):
# #
# This test demonstrates the bug by direct manipulation of the # This test demonstrates the bug by direct manipulation of the
# XMLGenerator. # XMLGenerator.
result = StringIO() result = self.ioclass()
gen = XMLGenerator(result) gen = XMLGenerator(result)
gen.startDocument() gen.startDocument()
...@@ -450,15 +479,57 @@ class XmlgenTest(unittest.TestCase): ...@@ -450,15 +479,57 @@ class XmlgenTest(unittest.TestCase):
gen.endDocument() gen.endDocument()
self.assertEqual(result.getvalue(), self.assertEqual(result.getvalue(),
start + ( self.xml(
'<a:g1 xmlns:a="http://example.com/ns">' '<a:g1 xmlns:a="http://example.com/ns">'
'<a:g2 xml:lang="en">Hello</a:g2>' '<a:g2 xml:lang="en">Hello</a:g2>'
'</a:g1>')) '</a:g1>'))
def test_no_close_file(self):
result = self.ioclass()
def func(out):
gen = XMLGenerator(out)
gen.startDocument()
gen.startElement("doc", {})
func(result)
self.assertFalse(result.closed)
class StringXmlgenTest(XmlgenTest, unittest.TestCase):
ioclass = StringIO
def xml(self, doc, encoding='iso-8859-1'):
return '<?xml version="1.0" encoding="%s"?>\n%s' % (encoding, doc)
test_xmlgen_unencodable = None
class BytesXmlgenTest(XmlgenTest, unittest.TestCase):
ioclass = BytesIO
def xml(self, doc, encoding='iso-8859-1'):
return ('<?xml version="1.0" encoding="%s"?>\n%s' %
(encoding, doc)).encode(encoding, 'xmlcharrefreplace')
class WriterXmlgenTest(BytesXmlgenTest):
class ioclass(list):
write = list.append
closed = False
def seekable(self):
return True
def tell(self):
# return 0 at start and not 0 after start
return len(self)
def getvalue(self):
return b''.join(self)
start = b'<?xml version="1.0" encoding="iso-8859-1"?>\n'
class XMLFilterBaseTest(unittest.TestCase): class XMLFilterBaseTest(unittest.TestCase):
def test_filter_basic(self): def test_filter_basic(self):
result = StringIO() result = BytesIO()
gen = XMLGenerator(result) gen = XMLGenerator(result)
filter = XMLFilterBase() filter = XMLFilterBase()
filter.setContentHandler(gen) filter.setContentHandler(gen)
...@@ -470,7 +541,7 @@ class XMLFilterBaseTest(unittest.TestCase): ...@@ -470,7 +541,7 @@ class XMLFilterBaseTest(unittest.TestCase):
filter.endElement("doc") filter.endElement("doc")
filter.endDocument() filter.endDocument()
self.assertEqual(result.getvalue(), start + "<doc>content </doc>") self.assertEqual(result.getvalue(), start + b"<doc>content </doc>")
# =========================================================================== # ===========================================================================
# #
...@@ -478,7 +549,7 @@ class XMLFilterBaseTest(unittest.TestCase): ...@@ -478,7 +549,7 @@ class XMLFilterBaseTest(unittest.TestCase):
# #
# =========================================================================== # ===========================================================================
with open(TEST_XMLFILE_OUT) as f: with open(TEST_XMLFILE_OUT, 'rb') as f:
xml_test_out = f.read() xml_test_out = f.read()
class ExpatReaderTest(XmlTestBase): class ExpatReaderTest(XmlTestBase):
...@@ -487,11 +558,11 @@ class ExpatReaderTest(XmlTestBase): ...@@ -487,11 +558,11 @@ class ExpatReaderTest(XmlTestBase):
def test_expat_file(self): def test_expat_file(self):
parser = create_parser() parser = create_parser()
result = StringIO() result = BytesIO()
xmlgen = XMLGenerator(result) xmlgen = XMLGenerator(result)
parser.setContentHandler(xmlgen) parser.setContentHandler(xmlgen)
with open(TEST_XMLFILE) as f: with open(TEST_XMLFILE, 'rb') as f:
parser.parse(f) parser.parse(f)
self.assertEqual(result.getvalue(), xml_test_out) self.assertEqual(result.getvalue(), xml_test_out)
...@@ -503,7 +574,7 @@ class ExpatReaderTest(XmlTestBase): ...@@ -503,7 +574,7 @@ class ExpatReaderTest(XmlTestBase):
self.addCleanup(support.unlink, fname) self.addCleanup(support.unlink, fname)
parser = create_parser() parser = create_parser()
result = StringIO() result = BytesIO()
xmlgen = XMLGenerator(result) xmlgen = XMLGenerator(result)
parser.setContentHandler(xmlgen) parser.setContentHandler(xmlgen)
...@@ -547,13 +618,13 @@ class ExpatReaderTest(XmlTestBase): ...@@ -547,13 +618,13 @@ class ExpatReaderTest(XmlTestBase):
def resolveEntity(self, publicId, systemId): def resolveEntity(self, publicId, systemId):
inpsrc = InputSource() inpsrc = InputSource()
inpsrc.setByteStream(StringIO("<entity/>")) inpsrc.setByteStream(BytesIO(b"<entity/>"))
return inpsrc return inpsrc
def test_expat_entityresolver(self): def test_expat_entityresolver(self):
parser = create_parser() parser = create_parser()
parser.setEntityResolver(self.TestEntityResolver()) parser.setEntityResolver(self.TestEntityResolver())
result = StringIO() result = BytesIO()
parser.setContentHandler(XMLGenerator(result)) parser.setContentHandler(XMLGenerator(result))
parser.feed('<!DOCTYPE doc [\n') parser.feed('<!DOCTYPE doc [\n')
...@@ -563,7 +634,7 @@ class ExpatReaderTest(XmlTestBase): ...@@ -563,7 +634,7 @@ class ExpatReaderTest(XmlTestBase):
parser.close() parser.close()
self.assertEqual(result.getvalue(), start + self.assertEqual(result.getvalue(), start +
"<doc><entity></entity></doc>") b"<doc><entity></entity></doc>")
# ===== Attributes support # ===== Attributes support
...@@ -632,7 +703,7 @@ class ExpatReaderTest(XmlTestBase): ...@@ -632,7 +703,7 @@ class ExpatReaderTest(XmlTestBase):
def test_expat_inpsource_filename(self): def test_expat_inpsource_filename(self):
parser = create_parser() parser = create_parser()
result = StringIO() result = BytesIO()
xmlgen = XMLGenerator(result) xmlgen = XMLGenerator(result)
parser.setContentHandler(xmlgen) parser.setContentHandler(xmlgen)
...@@ -642,7 +713,7 @@ class ExpatReaderTest(XmlTestBase): ...@@ -642,7 +713,7 @@ class ExpatReaderTest(XmlTestBase):
def test_expat_inpsource_sysid(self): def test_expat_inpsource_sysid(self):
parser = create_parser() parser = create_parser()
result = StringIO() result = BytesIO()
xmlgen = XMLGenerator(result) xmlgen = XMLGenerator(result)
parser.setContentHandler(xmlgen) parser.setContentHandler(xmlgen)
...@@ -657,7 +728,7 @@ class ExpatReaderTest(XmlTestBase): ...@@ -657,7 +728,7 @@ class ExpatReaderTest(XmlTestBase):
self.addCleanup(support.unlink, fname) self.addCleanup(support.unlink, fname)
parser = create_parser() parser = create_parser()
result = StringIO() result = BytesIO()
xmlgen = XMLGenerator(result) xmlgen = XMLGenerator(result)
parser.setContentHandler(xmlgen) parser.setContentHandler(xmlgen)
...@@ -667,12 +738,12 @@ class ExpatReaderTest(XmlTestBase): ...@@ -667,12 +738,12 @@ class ExpatReaderTest(XmlTestBase):
def test_expat_inpsource_stream(self): def test_expat_inpsource_stream(self):
parser = create_parser() parser = create_parser()
result = StringIO() result = BytesIO()
xmlgen = XMLGenerator(result) xmlgen = XMLGenerator(result)
parser.setContentHandler(xmlgen) parser.setContentHandler(xmlgen)
inpsrc = InputSource() inpsrc = InputSource()
with open(TEST_XMLFILE) as f: with open(TEST_XMLFILE, 'rb') as f:
inpsrc.setByteStream(f) inpsrc.setByteStream(f)
parser.parse(inpsrc) parser.parse(inpsrc)
...@@ -681,7 +752,7 @@ class ExpatReaderTest(XmlTestBase): ...@@ -681,7 +752,7 @@ class ExpatReaderTest(XmlTestBase):
# ===== IncrementalParser support # ===== IncrementalParser support
def test_expat_incremental(self): def test_expat_incremental(self):
result = StringIO() result = BytesIO()
xmlgen = XMLGenerator(result) xmlgen = XMLGenerator(result)
parser = create_parser() parser = create_parser()
parser.setContentHandler(xmlgen) parser.setContentHandler(xmlgen)
...@@ -690,10 +761,10 @@ class ExpatReaderTest(XmlTestBase): ...@@ -690,10 +761,10 @@ class ExpatReaderTest(XmlTestBase):
parser.feed("</doc>") parser.feed("</doc>")
parser.close() parser.close()
self.assertEqual(result.getvalue(), start + "<doc></doc>") self.assertEqual(result.getvalue(), start + b"<doc></doc>")
def test_expat_incremental_reset(self): def test_expat_incremental_reset(self):
result = StringIO() result = BytesIO()
xmlgen = XMLGenerator(result) xmlgen = XMLGenerator(result)
parser = create_parser() parser = create_parser()
parser.setContentHandler(xmlgen) parser.setContentHandler(xmlgen)
...@@ -701,7 +772,7 @@ class ExpatReaderTest(XmlTestBase): ...@@ -701,7 +772,7 @@ class ExpatReaderTest(XmlTestBase):
parser.feed("<doc>") parser.feed("<doc>")
parser.feed("text") parser.feed("text")
result = StringIO() result = BytesIO()
xmlgen = XMLGenerator(result) xmlgen = XMLGenerator(result)
parser.setContentHandler(xmlgen) parser.setContentHandler(xmlgen)
parser.reset() parser.reset()
...@@ -711,12 +782,12 @@ class ExpatReaderTest(XmlTestBase): ...@@ -711,12 +782,12 @@ class ExpatReaderTest(XmlTestBase):
parser.feed("</doc>") parser.feed("</doc>")
parser.close() parser.close()
self.assertEqual(result.getvalue(), start + "<doc>text</doc>") self.assertEqual(result.getvalue(), start + b"<doc>text</doc>")
# ===== Locator support # ===== Locator support
def test_expat_locator_noinfo(self): def test_expat_locator_noinfo(self):
result = StringIO() result = BytesIO()
xmlgen = XMLGenerator(result) xmlgen = XMLGenerator(result)
parser = create_parser() parser = create_parser()
parser.setContentHandler(xmlgen) parser.setContentHandler(xmlgen)
...@@ -730,7 +801,7 @@ class ExpatReaderTest(XmlTestBase): ...@@ -730,7 +801,7 @@ class ExpatReaderTest(XmlTestBase):
self.assertEqual(parser.getLineNumber(), 1) self.assertEqual(parser.getLineNumber(), 1)
def test_expat_locator_withinfo(self): def test_expat_locator_withinfo(self):
result = StringIO() result = BytesIO()
xmlgen = XMLGenerator(result) xmlgen = XMLGenerator(result)
parser = create_parser() parser = create_parser()
parser.setContentHandler(xmlgen) parser.setContentHandler(xmlgen)
...@@ -745,7 +816,7 @@ class ExpatReaderTest(XmlTestBase): ...@@ -745,7 +816,7 @@ class ExpatReaderTest(XmlTestBase):
shutil.copyfile(TEST_XMLFILE, fname) shutil.copyfile(TEST_XMLFILE, fname)
self.addCleanup(support.unlink, fname) self.addCleanup(support.unlink, fname)
result = StringIO() result = BytesIO()
xmlgen = XMLGenerator(result) xmlgen = XMLGenerator(result)
parser = create_parser() parser = create_parser()
parser.setContentHandler(xmlgen) parser.setContentHandler(xmlgen)
...@@ -766,7 +837,7 @@ class ErrorReportingTest(unittest.TestCase): ...@@ -766,7 +837,7 @@ class ErrorReportingTest(unittest.TestCase):
parser = create_parser() parser = create_parser()
parser.setContentHandler(ContentHandler()) # do nothing parser.setContentHandler(ContentHandler()) # do nothing
source = InputSource() source = InputSource()
source.setByteStream(StringIO("<foo bar foobar>")) #ill-formed source.setByteStream(BytesIO(b"<foo bar foobar>")) #ill-formed
name = "a file name" name = "a file name"
source.setSystemId(name) source.setSystemId(name)
try: try:
...@@ -857,7 +928,9 @@ class XmlReaderTest(XmlTestBase): ...@@ -857,7 +928,9 @@ class XmlReaderTest(XmlTestBase):
def test_main(): def test_main():
run_unittest(MakeParserTest, run_unittest(MakeParserTest,
SaxutilsTest, SaxutilsTest,
XmlgenTest, StringXmlgenTest,
BytesXmlgenTest,
WriterXmlgenTest,
ExpatReaderTest, ExpatReaderTest,
ErrorReportingTest, ErrorReportingTest,
XmlReaderTest) XmlReaderTest)
......
...@@ -4,18 +4,10 @@ convenience of application and driver writers. ...@@ -4,18 +4,10 @@ convenience of application and driver writers.
""" """
import os, urllib.parse, urllib.request import os, urllib.parse, urllib.request
import io
from . import handler from . import handler
from . import xmlreader from . import xmlreader
# See whether the xmlcharrefreplace error handler is
# supported
try:
from codecs import xmlcharrefreplace_errors
_error_handling = "xmlcharrefreplace"
del xmlcharrefreplace_errors
except ImportError:
_error_handling = "strict"
def __dict_replace(s, d): def __dict_replace(s, d):
"""Replace substrings of a string using a dictionary.""" """Replace substrings of a string using a dictionary."""
for key, value in d.items(): for key, value in d.items():
...@@ -76,14 +68,50 @@ def quoteattr(data, entities={}): ...@@ -76,14 +68,50 @@ def quoteattr(data, entities={}):
return data return data
def _gettextwriter(out, encoding):
if out is None:
import sys
return sys.stdout
if isinstance(out, io.TextIOBase):
# use a text writer as is
return out
# wrap a binary writer with TextIOWrapper
if isinstance(out, io.RawIOBase):
# Keep the original file open when the TextIOWrapper is
# destroyed
class _wrapper:
__class__ = out.__class__
def __getattr__(self, name):
return getattr(out, name)
buffer = _wrapper()
buffer.close = lambda: None
else:
# This is to handle passed objects that aren't in the
# IOBase hierarchy, but just have a write method
buffer = io.BufferedIOBase()
buffer.writable = lambda: True
buffer.write = out.write
try:
# TextIOWrapper uses this methods to determine
# if BOM (for UTF-16, etc) should be added
buffer.seekable = out.seekable
buffer.tell = out.tell
except AttributeError:
pass
return io.TextIOWrapper(buffer, encoding=encoding,
errors='xmlcharrefreplace',
newline='\n',
write_through=True)
class XMLGenerator(handler.ContentHandler): class XMLGenerator(handler.ContentHandler):
def __init__(self, out=None, encoding="iso-8859-1", short_empty_elements=False): def __init__(self, out=None, encoding="iso-8859-1", short_empty_elements=False):
if out is None:
import sys
out = sys.stdout
handler.ContentHandler.__init__(self) handler.ContentHandler.__init__(self)
self._out = out out = _gettextwriter(out, encoding)
self._write = out.write
self._flush = out.flush
self._ns_contexts = [{}] # contains uri -> prefix dicts self._ns_contexts = [{}] # contains uri -> prefix dicts
self._current_context = self._ns_contexts[-1] self._current_context = self._ns_contexts[-1]
self._undeclared_ns_maps = [] self._undeclared_ns_maps = []
...@@ -91,12 +119,6 @@ class XMLGenerator(handler.ContentHandler): ...@@ -91,12 +119,6 @@ class XMLGenerator(handler.ContentHandler):
self._short_empty_elements = short_empty_elements self._short_empty_elements = short_empty_elements
self._pending_start_element = False self._pending_start_element = False
def _write(self, text):
if isinstance(text, str):
self._out.write(text)
else:
self._out.write(text.encode(self._encoding, _error_handling))
def _qname(self, name): def _qname(self, name):
"""Builds a qualified name from a (ns_url, localname) pair""" """Builds a qualified name from a (ns_url, localname) pair"""
if name[0]: if name[0]:
...@@ -125,6 +147,9 @@ class XMLGenerator(handler.ContentHandler): ...@@ -125,6 +147,9 @@ class XMLGenerator(handler.ContentHandler):
self._write('<?xml version="1.0" encoding="%s"?>\n' % self._write('<?xml version="1.0" encoding="%s"?>\n' %
self._encoding) self._encoding)
def endDocument(self):
self._flush()
def startPrefixMapping(self, prefix, uri): def startPrefixMapping(self, prefix, uri):
self._ns_contexts.append(self._current_context.copy()) self._ns_contexts.append(self._current_context.copy())
self._current_context[uri] = prefix self._current_context[uri] = prefix
...@@ -157,9 +182,9 @@ class XMLGenerator(handler.ContentHandler): ...@@ -157,9 +182,9 @@ class XMLGenerator(handler.ContentHandler):
for prefix, uri in self._undeclared_ns_maps: for prefix, uri in self._undeclared_ns_maps:
if prefix: if prefix:
self._out.write(' xmlns:%s="%s"' % (prefix, uri)) self._write(' xmlns:%s="%s"' % (prefix, uri))
else: else:
self._out.write(' xmlns="%s"' % uri) self._write(' xmlns="%s"' % uri)
self._undeclared_ns_maps = [] self._undeclared_ns_maps = []
for (name, value) in attrs.items(): for (name, value) in attrs.items():
......
...@@ -172,6 +172,8 @@ Core and Builtins ...@@ -172,6 +172,8 @@ Core and Builtins
Library Library
------- -------
- Issue #1470548: XMLGenerator now works with binary output streams.
- Issue #6975: os.path.realpath() now correctly resolves multiple nested - Issue #6975: os.path.realpath() now correctly resolves multiple nested
symlinks on POSIX platforms. symlinks on POSIX platforms.
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment