Commit af9fce7d authored by Just van Rossum's avatar Just van Rossum

Fix and test for bug #764548:

Use isinstance() instead of comparing types directly, to enable
subclasses of str and unicode to be used as patterns.
Blessed by /F.
parent a8ac59aa
...@@ -219,9 +219,9 @@ def _compile(*key): ...@@ -219,9 +219,9 @@ def _compile(*key):
if p is not None: if p is not None:
return p return p
pattern, flags = key pattern, flags = key
if type(pattern) is _pattern_type: if isinstance(pattern, _pattern_type):
return pattern return pattern
if type(pattern) not in sre_compile.STRING_TYPES: if not isinstance(pattern, sre_compile.STRING_TYPES):
raise TypeError, "first argument must be string or compiled pattern" raise TypeError, "first argument must be string or compiled pattern"
try: try:
p = sre_compile.compile(pattern, flags) p = sre_compile.compile(pattern, flags)
......
...@@ -428,12 +428,12 @@ def _compile_info(code, pattern, flags): ...@@ -428,12 +428,12 @@ def _compile_info(code, pattern, flags):
_compile_charset(charset, flags, code) _compile_charset(charset, flags, code)
code[skip] = len(code) - skip code[skip] = len(code) - skip
STRING_TYPES = [type("")]
try: try:
STRING_TYPES.append(type(unicode(""))) unicode
except NameError: except NameError:
pass STRING_TYPES = type("")
else:
STRING_TYPES = (type(""), type(unicode("")))
def _code(p, flags): def _code(p, flags):
...@@ -453,7 +453,7 @@ def _code(p, flags): ...@@ -453,7 +453,7 @@ def _code(p, flags):
def compile(p, flags=0): def compile(p, flags=0):
# internal: convert pattern list to internal format # internal: convert pattern list to internal format
if type(p) in STRING_TYPES: if isinstance(p, STRING_TYPES):
import sre_parse import sre_parse
pattern = p pattern = p
p = sre_parse.parse(p, flags) p = sre_parse.parse(p, flags)
......
...@@ -474,6 +474,16 @@ class ReTests(unittest.TestCase): ...@@ -474,6 +474,16 @@ class ReTests(unittest.TestCase):
self.assertEqual(re.match('(a)((?!(b)*))*', 'abb').groups(), self.assertEqual(re.match('(a)((?!(b)*))*', 'abb').groups(),
('a', None, None)) ('a', None, None))
def test_bug_764548(self):
# bug 764548, re.compile() barfs on str/unicode subclasses
try:
unicode
except NameError:
return # no problem if we have no unicode
class my_unicode(unicode): pass
pat = re.compile(my_unicode("abc"))
self.assertEqual(pat.match("xyz"), None)
def test_finditer(self): def test_finditer(self):
iter = re.finditer(r":+", "a:b::c:::d") iter = re.finditer(r":+", "a:b::c:::d")
self.assertEqual([item.group(0) for item in iter], self.assertEqual([item.group(0) for item in iter],
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment