Commit 3a4fbd82 authored by Eli Bendersky

_elementtree.XMLParser._setevents should support any sequence, not just tuples

Also clean up some code around this
parent 5b6616de
...@@ -979,6 +979,21 @@ class IncrementalParserTest(unittest.TestCase): ...@@ -979,6 +979,21 @@ class IncrementalParserTest(unittest.TestCase):
parser.eof_received() parser.eof_received()
self.assertEqual(parser.root.tag, '{namespace}root') self.assertEqual(parser.root.tag, '{namespace}root')
def test_ns_events(self):
    # Only namespace declaration events are requested, so the comment and
    # the elements fed below must not add anything to the event queue.
    ns_parser = ET.IncrementalParser(events=('start-ns', 'end-ns'))

    # The root start tag carries the default namespace declaration.
    for chunk in ("<!-- comment -->\n", "<root xmlns='namespace'>\n"):
        self._feed(ns_parser, chunk)
    reported = list(ns_parser.events())
    self.assertEqual(reported, [('start-ns', ('', 'namespace'))])

    # 'end-ns' is expected once the root element has been closed.
    remaining_chunks = (
        "<element key='value'>text</element",
        ">\n",
        "<element>text</element>tail\n",
        "<empty-element/>\n",
        "</root>\n",
    )
    for chunk in remaining_chunks:
        self._feed(ns_parser, chunk)
    reported = list(ns_parser.events())
    self.assertEqual(reported, [('end-ns', None)])

    ns_parser.eof_received()
def test_events(self): def test_events(self):
parser = ET.IncrementalParser(events=()) parser = ET.IncrementalParser(events=())
self._feed(parser, "<root/>\n") self._feed(parser, "<root/>\n")
...@@ -1026,6 +1041,26 @@ class IncrementalParserTest(unittest.TestCase): ...@@ -1026,6 +1041,26 @@ class IncrementalParserTest(unittest.TestCase):
parser.eof_received() parser.eof_received()
self.assertEqual(parser.root.tag, 'root') self.assertEqual(parser.root.tag, 'root')
def test_events_sequence(self):
    # The events argument does not have to be a tuple or a list: first a
    # set, then a hand-written iterator object is passed in.
    parser_from_set = ET.IncrementalParser(events={'end', 'start'})
    self._feed(parser_from_set, "<foo>bar</foo>")
    self.assert_event_tags(parser_from_set,
                           [('start', 'foo'), ('end', 'foo')])

    class DummyIter:
        # Minimal iterator that hands out the event names one at a time.
        def __init__(self):
            self.events = iter(['start', 'end', 'start-ns'])

        def __iter__(self):
            return self

        def __next__(self):
            return next(self.events)

    parser_from_iter = ET.IncrementalParser(events=DummyIter())
    self._feed(parser_from_iter, "<foo>bar</foo>")
    self.assert_event_tags(parser_from_iter,
                           [('start', 'foo'), ('end', 'foo')])
def test_unknown_event(self): def test_unknown_event(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
ET.IncrementalParser(events=('start', 'end', 'bogus')) ET.IncrementalParser(events=('start', 'end', 'bogus'))
......
...@@ -1498,33 +1498,38 @@ class XMLParser: ...@@ -1498,33 +1498,38 @@ class XMLParser:
except AttributeError: except AttributeError:
pass # unknown pass # unknown
def _setevents(self, event_list, events): def _setevents(self, events_queue, events_to_report):
# Internal API for IncrementalParser # Internal API for IncrementalParser
# events_to_report: a sequence of event names to report during parsing
# (same as the *events* argument of IncrementalParser's constructor).
# events_queue: a list of actual parsing events that will be populated
# by the underlying parser.
#
parser = self._parser parser = self._parser
append = event_list.append append = events_queue.append
for event in events: for event_name in events_to_report:
if event == "start": if event_name == "start":
parser.ordered_attributes = 1 parser.ordered_attributes = 1
parser.specified_attributes = 1 parser.specified_attributes = 1
def handler(tag, attrib_in, event=event, append=append, def handler(tag, attrib_in, event=event_name, append=append,
start=self._start_list): start=self._start_list):
append((event, start(tag, attrib_in))) append((event, start(tag, attrib_in)))
parser.StartElementHandler = handler parser.StartElementHandler = handler
elif event == "end": elif event_name == "end":
def handler(tag, event=event, append=append, def handler(tag, event=event_name, append=append,
end=self._end): end=self._end):
append((event, end(tag))) append((event, end(tag)))
parser.EndElementHandler = handler parser.EndElementHandler = handler
elif event == "start-ns": elif event_name == "start-ns":
def handler(prefix, uri, event=event, append=append): def handler(prefix, uri, event=event_name, append=append):
append((event, (prefix or "", uri or ""))) append((event, (prefix or "", uri or "")))
parser.StartNamespaceDeclHandler = handler parser.StartNamespaceDeclHandler = handler
elif event == "end-ns": elif event_name == "end-ns":
def handler(prefix, event=event, append=append): def handler(prefix, event=event_name, append=append):
append((event, None)) append((event, None))
parser.EndNamespaceDeclHandler = handler parser.EndNamespaceDeclHandler = handler
else: else:
raise ValueError("unknown event %r" % event) raise ValueError("unknown event %r" % event_name)
def _raiseerror(self, value): def _raiseerror(self, value):
err = ParseError(value) err = ParseError(value)
......
...@@ -3431,14 +3431,14 @@ static PyObject* ...@@ -3431,14 +3431,14 @@ static PyObject*
xmlparser_setevents(XMLParserObject *self, PyObject* args) xmlparser_setevents(XMLParserObject *self, PyObject* args)
{ {
/* activate element event reporting */ /* activate element event reporting */
Py_ssize_t i, seqlen;
TreeBuilderObject *target;
Py_ssize_t i; PyObject *events_queue;
TreeBuilderObject* target; PyObject *events_to_report = Py_None;
PyObject *events_seq;
PyObject* events; /* event collector */ if (!PyArg_ParseTuple(args, "O!|O:_setevents", &PyList_Type, &events_queue,
PyObject* event_set = Py_None; &events_to_report))
if (!PyArg_ParseTuple(args, "O!|O:_setevents", &PyList_Type, &events,
&event_set))
return NULL; return NULL;
if (!TreeBuilder_CheckExact(self->target)) { if (!TreeBuilder_CheckExact(self->target)) {
...@@ -3452,9 +3452,9 @@ xmlparser_setevents(XMLParserObject *self, PyObject* args) ...@@ -3452,9 +3452,9 @@ xmlparser_setevents(XMLParserObject *self, PyObject* args)
target = (TreeBuilderObject*) self->target; target = (TreeBuilderObject*) self->target;
Py_INCREF(events); Py_INCREF(events_queue);
Py_XDECREF(target->events); Py_XDECREF(target->events);
target->events = events; target->events = events_queue;
/* clear out existing events */ /* clear out existing events */
Py_CLEAR(target->start_event_obj); Py_CLEAR(target->start_event_obj);
...@@ -3462,69 +3462,65 @@ xmlparser_setevents(XMLParserObject *self, PyObject* args) ...@@ -3462,69 +3462,65 @@ xmlparser_setevents(XMLParserObject *self, PyObject* args)
Py_CLEAR(target->start_ns_event_obj); Py_CLEAR(target->start_ns_event_obj);
Py_CLEAR(target->end_ns_event_obj); Py_CLEAR(target->end_ns_event_obj);
if (event_set == Py_None) { if (events_to_report == Py_None) {
/* default is "end" only */ /* default is "end" only */
target->end_event_obj = PyUnicode_FromString("end"); target->end_event_obj = PyUnicode_FromString("end");
Py_RETURN_NONE; Py_RETURN_NONE;
} }
if (!PyTuple_Check(event_set)) /* FIXME: handle arbitrary sequences */ if (!(events_seq = PySequence_Fast(events_to_report,
goto error; "events must be a sequence"))) {
return NULL;
}
for (i = 0; i < PyTuple_GET_SIZE(event_set); i++) { seqlen = PySequence_Size(events_seq);
PyObject* item = PyTuple_GET_ITEM(event_set, i); for (i = 0; i < seqlen; ++i) {
char* event; PyObject *event_name_obj = PySequence_Fast_GET_ITEM(events_seq, i);
if (PyUnicode_Check(item)) { char *event_name = NULL;
event = _PyUnicode_AsString(item); if (PyUnicode_Check(event_name_obj)) {
if (event == NULL) event_name = _PyUnicode_AsString(event_name_obj);
goto error; } else if (PyBytes_Check(event_name_obj)) {
} else if (PyBytes_Check(item)) event_name = PyBytes_AS_STRING(event_name_obj);
event = PyBytes_AS_STRING(item);
else {
goto error;
} }
if (strcmp(event, "start") == 0) {
Py_INCREF(item); if (event_name == NULL) {
target->start_event_obj = item; Py_DECREF(events_seq);
} else if (strcmp(event, "end") == 0) { PyErr_Format(PyExc_ValueError, "invalid events sequence");
Py_INCREF(item); return NULL;
} else if (strcmp(event_name, "start") == 0) {
Py_INCREF(event_name_obj);
target->start_event_obj = event_name_obj;
} else if (strcmp(event_name, "end") == 0) {
Py_INCREF(event_name_obj);
Py_XDECREF(target->end_event_obj); Py_XDECREF(target->end_event_obj);
target->end_event_obj = item; target->end_event_obj = event_name_obj;
} else if (strcmp(event, "start-ns") == 0) { } else if (strcmp(event_name, "start-ns") == 0) {
Py_INCREF(item); Py_INCREF(event_name_obj);
Py_XDECREF(target->start_ns_event_obj); Py_XDECREF(target->start_ns_event_obj);
target->start_ns_event_obj = item; target->start_ns_event_obj = event_name_obj;
EXPAT(SetNamespaceDeclHandler)( EXPAT(SetNamespaceDeclHandler)(
self->parser, self->parser,
(XML_StartNamespaceDeclHandler) expat_start_ns_handler, (XML_StartNamespaceDeclHandler) expat_start_ns_handler,
(XML_EndNamespaceDeclHandler) expat_end_ns_handler (XML_EndNamespaceDeclHandler) expat_end_ns_handler
); );
} else if (strcmp(event, "end-ns") == 0) { } else if (strcmp(event_name, "end-ns") == 0) {
Py_INCREF(item); Py_INCREF(event_name_obj);
Py_XDECREF(target->end_ns_event_obj); Py_XDECREF(target->end_ns_event_obj);
target->end_ns_event_obj = item; target->end_ns_event_obj = event_name_obj;
EXPAT(SetNamespaceDeclHandler)( EXPAT(SetNamespaceDeclHandler)(
self->parser, self->parser,
(XML_StartNamespaceDeclHandler) expat_start_ns_handler, (XML_StartNamespaceDeclHandler) expat_start_ns_handler,
(XML_EndNamespaceDeclHandler) expat_end_ns_handler (XML_EndNamespaceDeclHandler) expat_end_ns_handler
); );
} else { } else {
PyErr_Format( Py_DECREF(events_seq);
PyExc_ValueError, PyErr_Format(PyExc_ValueError, "unknown event '%s'", event_name);
"unknown event '%s'", event
);
return NULL; return NULL;
} }
} }
Py_DECREF(events_seq);
Py_RETURN_NONE; Py_RETURN_NONE;
error:
PyErr_SetString(
PyExc_TypeError,
"invalid event tuple"
);
return NULL;
} }
static PyMethodDef xmlparser_methods[] = { static PyMethodDef xmlparser_methods[] = {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment