Commit bd019533 authored by Stefan Behnel's avatar Stefan Behnel

Py3 fix

parent b8c14a8d
...@@ -3,6 +3,17 @@ ...@@ -3,6 +3,17 @@
# #
import re import re
import sys
if sys.version_info[0] >= 3:
_str, _bytes = str, bytes
else:
_str, _bytes = unicode, str
empty_bytes = _bytes()
empty_str = _str()
join_bytes = empty_bytes.join
class UnicodeLiteralBuilder(object): class UnicodeLiteralBuilder(object):
"""Assemble a unicode string. """Assemble a unicode string.
...@@ -11,10 +22,10 @@ class UnicodeLiteralBuilder(object): ...@@ -11,10 +22,10 @@ class UnicodeLiteralBuilder(object):
self.chars = [] self.chars = []
def append(self, characters): def append(self, characters):
if isinstance(characters, str): if isinstance(characters, _bytes):
# this came from a Py2 string literal in the parser code # this came from a Py2 string literal in the parser code
characters = characters.decode("ASCII") characters = characters.decode("ASCII")
assert isinstance(characters, unicode), str(type(characters)) assert isinstance(characters, _str), str(type(characters))
self.chars.append(characters) self.chars.append(characters)
def append_charval(self, char_number): def append_charval(self, char_number):
...@@ -32,9 +43,9 @@ class BytesLiteralBuilder(object): ...@@ -32,9 +43,9 @@ class BytesLiteralBuilder(object):
self.target_encoding = target_encoding self.target_encoding = target_encoding
def append(self, characters): def append(self, characters):
if isinstance(characters, unicode): if isinstance(characters, _str):
characters = characters.encode(self.target_encoding) characters = characters.encode(self.target_encoding)
assert isinstance(characters, str), str(type(characters)) assert isinstance(characters, _bytes), str(type(characters))
self.chars.append(characters) self.chars.append(characters)
def append_charval(self, char_number): def append_charval(self, char_number):
...@@ -42,7 +53,7 @@ class BytesLiteralBuilder(object): ...@@ -42,7 +53,7 @@ class BytesLiteralBuilder(object):
def getstring(self): def getstring(self):
# this *must* return a byte string! => fix it in Py3k!! # this *must* return a byte string! => fix it in Py3k!!
s = BytesLiteral(''.join(self.chars)) s = BytesLiteral(join_bytes(self.chars))
s.encoding = self.target_encoding s.encoding = self.target_encoding
return s return s
...@@ -50,7 +61,7 @@ class BytesLiteralBuilder(object): ...@@ -50,7 +61,7 @@ class BytesLiteralBuilder(object):
# this *must* return a byte string! => fix it in Py3k!! # this *must* return a byte string! => fix it in Py3k!!
return self.getstring() return self.getstring()
class EncodedString(unicode): class EncodedString(_str):
# unicode string subclass to keep track of the original encoding. # unicode string subclass to keep track of the original encoding.
# 'encoding' is None for unicode strings and the source encoding # 'encoding' is None for unicode strings and the source encoding
# otherwise # otherwise
...@@ -68,7 +79,7 @@ class EncodedString(unicode): ...@@ -68,7 +79,7 @@ class EncodedString(unicode):
return self.encoding is None return self.encoding is None
is_unicode = property(is_unicode) is_unicode = property(is_unicode)
class BytesLiteral(str): class BytesLiteral(_bytes):
# str subclass that is compatible with EncodedString # str subclass that is compatible with EncodedString
encoding = None encoding = None
...@@ -95,19 +106,23 @@ def _to_escape_sequence(s): ...@@ -95,19 +106,23 @@ def _to_escape_sequence(s):
return repr(s)[1:-1] return repr(s)[1:-1]
elif s == '"': elif s == '"':
return r'\"' return r'\"'
elif s == '\\':
return r'\\'
else: else:
# within a character sequence, oct passes much better than hex # within a character sequence, oct passes much better than hex
return ''.join(['\\%03o' % ord(c) for c in s]) return ''.join(['\\%03o' % ord(c) for c in s])
_c_special = ('\0', '\n', '\r', '\t', '??', '"') _c_special = ('\\', '\0', '\n', '\r', '\t', '??', '"')
_c_special_replacements = zip(_c_special, map(_to_escape_sequence, _c_special)) _c_special_replacements = [(orig.encode('ASCII'),
_to_escape_sequence(orig).encode('ASCII'))
for orig in _c_special ]
def _build_specials_test(): def _build_specials_test():
subexps = [] subexps = []
for special in _c_special: for special in _c_special:
regexp = ''.join(['[%s]' % c for c in special]) regexp = ''.join(['[%s]' % c for c in special])
subexps.append(regexp) subexps.append(regexp)
return re.compile('|'.join(subexps)).search return re.compile('|'.join(subexps).encode('ASCII')).search
_has_specials = _build_specials_test() _has_specials = _build_specials_test()
...@@ -124,24 +139,33 @@ def escape_character(c): ...@@ -124,24 +139,33 @@ def escape_character(c):
return c return c
def escape_byte_string(s): def escape_byte_string(s):
s = s.replace('\\', '\\\\')
if _has_specials(s): if _has_specials(s):
for special, replacement in _c_special_replacements: for special, replacement in _c_special_replacements:
s = s.replace(special, replacement) s = s.replace(special, replacement)
try: try:
s.decode("ASCII") s.decode("ASCII") # trial decoding: plain ASCII => done
return s return s
except UnicodeDecodeError: except UnicodeDecodeError:
pass pass
l = [] if sys.version_info[0] >= 3:
append = l.append s_new = bytearray()
for c in s: append, extend = s_new.append, s_new.extend
o = ord(c) for b in s:
if o >= 128: if b >= 128:
append('\\%3o' % o) extend(('\\%3o' % b).encode('ASCII'))
else: else:
append(c) append(b)
return ''.join(l) return bytes(s_new)
else:
l = []
append = l.append
for c in s:
o = ord(c)
if o >= 128:
append('\\%3o' % o)
else:
append(c)
return join_bytes(l)
def split_docstring(s): def split_docstring(s):
if len(s) < 2047: if len(s) < 2047:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment