Commit 2e8a0084 authored by Stefan Behnel's avatar Stefan Behnel

Rewrite of the string literal handling code

String literals pass through the compiler as follows:
- unicode string literals are stored as unicode strings and encoded to UTF-8 on the way out
- byte string literals are stored as correctly encoded byte strings by unescaping the source string literal into the corresponding byte sequence. No further encoding is done later on!
- char literals are stored as byte strings of length 1. This can be verified by the parser now, e.g. a non-ASCII char literal in UTF-8 source code will result in an error, as it would end up as two or more bytes in the C code, which can no longer be represented as a C char.

Storing byte strings is necessary as we otherwise loose the ability to encode byte string literals on the way out. They do not necessarily contain only bytes that fit into the source code encoding as the source can use escape sequences to represent them. Previously, ASCII encoded source code could not contain byte string literals with properly escaped non-ASCII bytes.

Another bug that was fixed: in Python, escape sequences behave different in unicode strings (where they represent the character code) and byte strings (where they represent a byte value). Previously, they resulted in the same byte value in Cython code. This is only a problem for non-ASCII escapes, since the character code and the byte value of ASCII characters are identical.
parent 7bc8549a
......@@ -3,7 +3,7 @@ from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import *
from Cython.Compiler.TreeFragment import TreeFragment
from Cython.Utils import EncodedString
from Cython.Compiler.StringEncoding import EncodedString
from Cython.Compiler.Errors import CompileError
import Interpreter
import PyrexTypes
......
......@@ -6,6 +6,7 @@ import operator
from string import join
from Errors import error, warning, InternalError
import StringEncoding
import Naming
from Nodes import Node
import PyrexTypes
......@@ -14,7 +15,6 @@ from Builtin import list_type, tuple_type, dict_type, unicode_type
import Symtab
import Options
from Annotate import AnnotationItem
from Cython import Utils
from Cython.Debugging import print_call_chain
from DebugFlags import debug_disposal_code, debug_temp_alloc, \
......@@ -640,10 +640,10 @@ class CharNode(ConstNode):
type = PyrexTypes.c_char_type
def compile_time_value(self, denv):
return ord(self.value.byteencode())
return ord(self.value)
def calculate_result_code(self):
return "'%s'" % Utils.escape_character(self.value.byteencode())
return "'%s'" % StringEncoding.escape_character(self.value)
class IntNode(ConstNode):
......
......@@ -397,6 +397,8 @@ class Context:
finally:
f.close()
except UnicodeDecodeError, msg:
import traceback
traceback.print_exc()
error((source_desc, 0, 0), "Decoding error, missing or incorrect coding=<encoding-name> at top of source (%s)" % msg)
if Errors.num_errors > 0:
raise CompileError
......
......@@ -23,7 +23,8 @@ import Version
from Errors import error, warning
from PyrexTypes import py_object_type
from Cython.Utils import open_new_file, replace_suffix, escape_byte_string, EncodedString
from Cython.Utils import open_new_file, replace_suffix
from StringEncoding import escape_byte_string, EncodedString
def check_c_classes(module_node):
......@@ -514,9 +515,12 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln('static const char *%s;' % Naming.filename_cname)
code.putln('static const char **%s;' % Naming.filetable_cname)
if env.doc:
docstr = env.doc
if not isinstance(docstr, str):
docstr = docstr.utf8encode()
code.putln('')
code.putln('static char %s[] = "%s";' % (
env.doc_cname, escape_byte_string(env.doc.utf8encode())))
env.doc_cname, escape_byte_string(docstr)))
def generate_extern_c_macro_definition(self, code):
name = Naming.extern_c_macro
......
......@@ -13,7 +13,7 @@ from PyrexTypes import py_object_type, error_type, CTypedefType, CFuncType
from Symtab import ModuleScope, LocalScope, GeneratorLocalScope, \
StructOrUnionScope, PyClassScope, CClassScope
from Cython.Utils import open_new_file, replace_suffix
from Cython.Utils import EncodedString, escape_byte_string
from StringEncoding import EncodedString, escape_byte_string
import Options
import ControlFlow
......@@ -1516,10 +1516,13 @@ class DefNode(FuncDefNode):
if proto_only:
return
if self.entry.doc and Options.docstrings:
docstr = self.entry.doc
if not isinstance(docstr, str):
docstr = docstr.utf8encode()
code.putln(
'static char %s[] = "%s";' % (
self.entry.doc_cname,
escape_byte_string(self.entry.doc.utf8encode())))
escape_byte_string(docstr)))
if with_pymethdef:
code.put(
"static PyMethodDef %s = " %
......
......@@ -3,7 +3,7 @@ from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import *
from Cython.Compiler.TreeFragment import TreeFragment
from Cython.Utils import EncodedString
from Cython.Compiler.StringEncoding import EncodedString
from Cython.Compiler.Errors import CompileError
try:
set
......
......@@ -9,6 +9,8 @@ from types import ListType, TupleType
from Scanning import PyrexScanner, FileSourceDescriptor
import Nodes
import ExprNodes
import StringEncoding
from StringEncoding import EncodedString, BytesLiteral
from ModuleNode import ModuleNode
from Errors import error, warning, InternalError
from Cython import Utils
......@@ -280,7 +282,7 @@ def p_trailer(s, node1):
return p_index(s, node1)
else: # s.sy == '.'
s.next()
name = Utils.EncodedString( p_ident(s) )
name = EncodedString( p_ident(s) )
return ExprNodes.AttributeNode(pos,
obj = node1, attribute = name)
......@@ -302,7 +304,7 @@ def p_call(s, function):
if not arg.is_name:
s.error("Expected an identifier before '='",
pos = arg.pos)
encoded_name = Utils.EncodedString(arg.name)
encoded_name = EncodedString(arg.name)
keyword = ExprNodes.IdentifierStringNode(arg.pos,
value = encoded_name)
arg = p_simple_expr(s)
......@@ -498,7 +500,7 @@ def p_atom(s):
else:
return ExprNodes.StringNode(pos, value = value)
elif sy == 'IDENT':
name = Utils.EncodedString( s.systring )
name = EncodedString( s.systring )
s.next()
if name == "None":
return ExprNodes.NoneNode(pos)
......@@ -533,6 +535,8 @@ def p_name(s, name):
return ExprNodes.FloatNode(pos, value = rep)
elif isinstance(value, unicode):
return ExprNodes.StringNode(pos, value = value)
elif isinstance(value, str):
return ExprNodes.StringNode(pos, value = value)
else:
error(pos, "Invalid type for compile-time constant: %s"
% value.__class__.__name__)
......@@ -549,11 +553,21 @@ def p_cat_string_literal(s):
if next_kind == 'c':
error(s.position(),
"Cannot concatenate char literal with another string or char literal")
elif next_kind == 'u':
elif next_kind != kind:
# we have to switch to unicode now
if kind == 'b':
# concatenating a unicode string to byte strings
strings = [u''.join([s.decode(s.encoding) for s in strings])]
elif kind == 'u':
# concatenating a byte string to unicode strings
strings.append(next_value.decode(next_value.encoding))
kind = 'u'
else:
strings.append(next_value)
value = Utils.EncodedString( u''.join(strings) )
if kind != 'u':
if kind == 'u':
value = EncodedString( u''.join(strings) )
else:
value = BytesLiteral( ''.join(strings) )
value.encoding = s.source_encoding
return kind, value
......@@ -582,7 +596,10 @@ def p_string_literal(s):
kind = 'u'
elif kind == '':
kind = 'b'
chars = []
if kind == 'u':
chars = StringEncoding.UnicodeLiteralBuilder()
else:
chars = StringEncoding.BytesLiteralBuilder(s.source_encoding)
while 1:
s.next()
sy = s.sy
......@@ -590,41 +607,46 @@ def p_string_literal(s):
if sy == 'CHARS':
chars.append(s.systring)
elif sy == 'ESCAPE':
has_escape = True
systr = s.systring
if is_raw:
if systr == '\\\n':
chars.append('\\\n')
elif systr == '\\\"':
chars.append('"')
elif systr == '\\\'':
chars.append("'")
if systr == u'\\\n':
chars.append(u'\\\n')
elif systr == u'\\\"':
chars.append(u'"')
elif systr == u'\\\'':
chars.append(u"'")
else:
chars.append(systr)
else:
c = systr[1]
if c in "01234567":
chars.append(chr(int(systr[1:], 8)))
elif c in "'\"\\":
if c in u"01234567":
chars.append_charval( int(systr[1:], 8) )
elif c in u"'\"\\":
chars.append(c)
elif c in "abfnrtv":
chars.append(Utils.char_from_escape_sequence(systr))
elif c == '\n':
elif c in u"abfnrtv":
chars.append(
StringEncoding.char_from_escape_sequence(systr))
elif c == u'\n':
pass
elif c in 'Uux':
elif c in u'Uux':
if kind == 'u' or c == 'x':
chrval = int(systr[2:], 16)
if chrval > 1114111: # sys.maxunicode:
s.error("Invalid unicode escape '%s'" % systr,
pos = pos)
strval = unichr(chrval)
elif chrval > 65535:
warning(s.position(),
"Unicode characters above 65535 are not "
"necessarily portable across Python installations", 1)
chars.append_charval(chrval)
else:
# unicode escapes in plain byte strings are not unescaped
strval = systr
chars.append(strval)
chars.append(systr)
else:
chars.append('\\' + systr[1:])
chars.append(u'\\' + systr[1:])
elif sy == 'NEWLINE':
chars.append('\n')
chars.append(u'\n')
elif sy == 'END_STRING':
break
elif sy == 'EOF':
......@@ -633,13 +655,13 @@ def p_string_literal(s):
s.error(
"Unexpected token %r:%r in string literal" %
(sy, s.systring))
string = u''.join(chars)
if kind == 'c' and len(string) != 1:
error(pos, u"invalid character literal: %r" % string)
if kind == 'c':
value = chars.getchar()
if len(value) != 1:
error(pos, u"invalid character literal: %r" % value)
else:
value = chars.getstring()
s.next()
value = Utils.EncodedString(string)
if kind != 'u':
value.encoding = s.source_encoding
#print "p_string_literal: value =", repr(value) ###
return kind, value
......@@ -943,7 +965,7 @@ def p_import_statement(s):
items.append(p_dotted_name(s, as_allowed = 1))
stats = []
for pos, target_name, dotted_name, as_name in items:
dotted_name = Utils.EncodedString(dotted_name)
dotted_name = EncodedString(dotted_name)
if kind == 'cimport':
stat = Nodes.CImportStatNode(pos,
module_name = dotted_name,
......@@ -951,7 +973,7 @@ def p_import_statement(s):
else:
if as_name and "." in dotted_name:
name_list = ExprNodes.ListNode(pos, args = [
ExprNodes.StringNode(pos, value = Utils.EncodedString("*"))])
ExprNodes.StringNode(pos, value = EncodedString("*"))])
else:
name_list = None
stat = Nodes.SingleAssignmentNode(pos,
......@@ -984,7 +1006,7 @@ def p_from_import_statement(s, first_statement = 0):
while s.sy == ',':
s.next()
imported_names.append(p_imported_name(s, is_cimport))
dotted_name = Utils.EncodedString(dotted_name)
dotted_name = EncodedString(dotted_name)
if dotted_name == '__future__':
if not first_statement:
s.error("from __future__ imports must occur at the beginning of the file")
......@@ -1011,7 +1033,7 @@ def p_from_import_statement(s, first_statement = 0):
imported_name_strings = []
items = []
for (name_pos, name, as_name, kind) in imported_names:
encoded_name = Utils.EncodedString(name)
encoded_name = EncodedString(name)
imported_name_strings.append(
ExprNodes.IdentifierStringNode(name_pos, value = encoded_name))
items.append(
......@@ -1020,7 +1042,7 @@ def p_from_import_statement(s, first_statement = 0):
name = as_name or name)))
import_list = ExprNodes.ListNode(
imported_names[0][0], args = imported_name_strings)
dotted_name = Utils.EncodedString(dotted_name)
dotted_name = EncodedString(dotted_name)
return Nodes.FromImportStatNode(pos,
module = ExprNodes.ImportNode(dotted_name_pos,
module_name = ExprNodes.IdentifierStringNode(pos, value = dotted_name),
......@@ -1520,7 +1542,7 @@ def p_positional_and_keyword_args(s, end_sy_set, type_positions=(), type_keyword
else:
arg = p_simple_expr(s)
keyword_node = ExprNodes.IdentifierStringNode(arg.pos,
value = Utils.EncodedString(ident))
value = EncodedString(ident))
keyword_args.append((keyword_node, arg))
was_keyword = True
else:
......@@ -2136,10 +2158,10 @@ def p_decorators(s):
s.next()
decstring = p_dotted_name(s, as_allowed=0)[2]
names = decstring.split('.')
decorator = ExprNodes.NameNode(pos, name=Utils.EncodedString(names[0]))
decorator = ExprNodes.NameNode(pos, name=EncodedString(names[0]))
for name in names[1:]:
decorator = ExprNodes.AttributeNode(pos,
attribute=Utils.EncodedString(name),
attribute=EncodedString(name),
obj=decorator)
if s.sy == '(':
decorator = p_call(s, decorator)
......@@ -2187,7 +2209,7 @@ def p_class_statement(s):
# s.sy == 'class'
pos = s.position()
s.next()
class_name = Utils.EncodedString( p_ident(s) )
class_name = EncodedString( p_ident(s) )
class_name.encoding = s.source_encoding
if s.sy == '(':
s.next()
......
......@@ -2,7 +2,7 @@
# Pyrex - Types
#
from Cython import Utils
import StringEncoding
import Naming
import copy
......@@ -1000,7 +1000,7 @@ class CStringType:
def literal_code(self, value):
assert isinstance(value, str)
return '"%s"' % Utils.escape_byte_string(value)
return '"%s"' % StringEncoding.escape_byte_string(value)
class CUTF8CharArrayType(CStringType, CArrayType):
......
......@@ -17,7 +17,7 @@ from Cython.Plex.Errors import UnrecognizedInput
from Errors import CompileError, error
from Lexicon import string_prefixes, raw_prefixes, make_lexicon
from Cython import Utils
from StringEncoding import EncodedString
plex_version = getattr(Plex, '_version', None)
#print "Plex version:", plex_version ###
......@@ -413,7 +413,7 @@ class PyrexScanner(Scanner):
if systring in self.resword_dict:
sy = systring
else:
systring = Utils.EncodedString(systring)
systring = EncodedString(systring)
systring.encoding = self.source_encoding
self.sy = sy
self.systring = systring
......
#
# Cython -- encoding related tools
#
import re
class UnicodeLiteralBuilder(object):
"""Assemble a unicode string.
"""
def __init__(self):
self.chars = []
def append(self, characters):
if isinstance(characters, str):
# this came from a Py2 string literal in the parser code
characters = characters.decode("ASCII")
assert isinstance(characters, unicode), str(type(characters))
self.chars.append(characters)
def append_charval(self, char_number):
self.chars.append( unichr(char_number) )
def getstring(self):
return EncodedString(u''.join(self.chars))
class BytesLiteralBuilder(object):
"""Assemble a byte string or char value.
"""
def __init__(self, target_encoding):
self.chars = []
self.target_encoding = target_encoding
def append(self, characters):
if isinstance(characters, unicode):
characters = characters.encode(self.target_encoding)
assert isinstance(characters, str), str(type(characters))
self.chars.append(characters)
def append_charval(self, char_number):
self.chars.append( chr(char_number) )
def getstring(self):
# this *must* return a byte string! => fix it in Py3k!!
s = BytesLiteral(''.join(self.chars))
s.encoding = self.target_encoding
return s
def getchar(self):
# this *must* return a byte string! => fix it in Py3k!!
return self.getstring()
class EncodedString(unicode):
# unicode string subclass to keep track of the original encoding.
# 'encoding' is None for unicode strings and the source encoding
# otherwise
encoding = None
def byteencode(self):
assert self.encoding is not None
return self.encode(self.encoding)
def utf8encode(self):
assert self.encoding is None
return self.encode("UTF-8")
def is_unicode(self):
return self.encoding is None
is_unicode = property(is_unicode)
class BytesLiteral(str):
# str subclass that is compatible with EncodedString
encoding = None
def byteencode(self):
return str(self)
def utf8encode(self):
assert False, "this is not a unicode string: %r" % self
is_unicode = False
char_from_escape_sequence = {
r'\a' : u'\a',
r'\b' : u'\b',
r'\f' : u'\f',
r'\n' : u'\n',
r'\r' : u'\r',
r'\t' : u'\t',
r'\v' : u'\v',
}.get
def _to_escape_sequence(s):
if s in '\n\r\t':
return repr(s)[1:-1]
elif s == '"':
return r'\"'
else:
# within a character sequence, oct passes much better than hex
return ''.join(['\\%03o' % ord(c) for c in s])
_c_special = ('\0', '\n', '\r', '\t', '??', '"')
_c_special_replacements = zip(_c_special, map(_to_escape_sequence, _c_special))
def _build_specials_test():
subexps = []
for special in _c_special:
regexp = ''.join(['[%s]' % c for c in special])
subexps.append(regexp)
return re.compile('|'.join(subexps)).search
_has_specials = _build_specials_test()
def escape_character(c):
if c in '\n\r\t\\':
return repr(c)[1:-1]
elif c == "'":
return "\\'"
n = ord(c)
if n < 32 or n > 127:
# hex works well for characters
return "\\x%02X" % n
else:
return c
def escape_byte_string(s):
s = s.replace('\\', '\\\\')
if _has_specials(s):
for special, replacement in _c_special_replacements:
s = s.replace(special, replacement)
try:
s.decode("ASCII")
return s
except UnicodeDecodeError:
pass
l = []
append = l.append
for c in s:
o = ord(c)
if o >= 128:
append('\\%3o' % o)
else:
append(c)
return ''.join(l)
......@@ -5,6 +5,7 @@
import re
from Cython import Utils
from Errors import warning, error, InternalError
from StringEncoding import EncodedString
import Options
import Naming
import PyrexTypes
......@@ -684,14 +685,14 @@ class BuiltinScope(Scope):
utility_code = None):
# If python_equiv == "*", the Python equivalent has the same name
# as the entry, otherwise it has the name specified by python_equiv.
name = Utils.EncodedString(name)
name = EncodedString(name)
entry = self.declare_cfunction(name, type, None, cname)
entry.utility_code = utility_code
if python_equiv:
if python_equiv == "*":
python_equiv = name
else:
python_equiv = Utils.EncodedString(python_equiv)
python_equiv = EncodedString(python_equiv)
var_entry = Entry(python_equiv, python_equiv, py_object_type)
var_entry.is_variable = 1
var_entry.is_builtin = 1
......@@ -699,7 +700,7 @@ class BuiltinScope(Scope):
return entry
def declare_builtin_type(self, name, cname):
name = Utils.EncodedString(name)
name = EncodedString(name)
type = PyrexTypes.BuiltinObjectType(name, cname)
type.set_scope(CClassScope(name, outer_scope=None, visibility='extern'))
self.type_names[name] = 1
......@@ -1370,7 +1371,7 @@ class CClassScope(ClassScope):
if name == "__new__":
warning(pos, "__new__ method of extension type will change semantics "
"in a future version of Pyrex and Cython. Use __cinit__ instead.")
name = Utils.EncodedString("__cinit__")
name = EncodedString("__cinit__")
entry = self.declare_var(name, py_object_type, pos, visibility='extern')
special_sig = get_special_method_signature(name)
if special_sig:
......@@ -1387,7 +1388,7 @@ class CClassScope(ClassScope):
def lookup_here(self, name):
if name == "__new__":
name = Utils.EncodedString("__cinit__")
name = EncodedString("__cinit__")
return ClassScope.lookup_here(self, name)
def declare_cfunction(self, name, type, pos,
......
......@@ -3,9 +3,9 @@
# and associated know-how.
#
from Cython import Utils
import Naming
import PyrexTypes
import StringEncoding
import sys
class Signature:
......@@ -311,7 +311,7 @@ class DocStringSlot(SlotDescriptor):
doc = scope.doc.utf8encode()
else:
doc = scope.doc.byteencode()
return '"%s"' % Utils.escape_byte_string(doc)
return '"%s"' % StringEncoding.escape_byte_string(doc)
else:
return "0"
......
......@@ -5,7 +5,7 @@ import inspect
import Nodes
import ExprNodes
import Naming
from Cython.Utils import EncodedString
from StringEncoding import EncodedString
class BasicVisitor(object):
"""A generic visitor base class which can be used for visiting any kind of object."""
......
......@@ -40,7 +40,7 @@ def file_newer_than(path, time):
ftime = modification_time(path)
return ftime > time
# support for source file encoding detection and unicode decoding
# support for source file encoding detection
def encode_filename(filename):
if isinstance(filename, unicode):
......@@ -77,90 +77,6 @@ def open_source_file(source_filename, mode="rU"):
encoding = detect_file_encoding(source_filename)
return codecs.open(source_filename, mode=mode, encoding=encoding)
class EncodedString(unicode):
# unicode string subclass to keep track of the original encoding.
# 'encoding' is None for unicode strings and the source encoding
# otherwise
encoding = None
def byteencode(self):
assert self.encoding is not None
return self.encode(self.encoding)
def utf8encode(self):
assert self.encoding is None
return self.encode("UTF-8")
def is_unicode(self):
return self.encoding is None
is_unicode = property(is_unicode)
# def __eq__(self, other):
# return unicode.__eq__(self, other) and \
# getattr(other, 'encoding', '') == self.encoding
char_from_escape_sequence = {
r'\a' : '\a',
r'\b' : '\b',
r'\f' : '\f',
r'\n' : '\n',
r'\r' : '\r',
r'\t' : '\t',
r'\v' : '\v',
}.get
def _to_escape_sequence(s):
if s in '\n\r\t':
return repr(s)[1:-1]
elif s == '"':
return r'\"'
else:
# within a character sequence, oct passes much better than hex
return ''.join(['\\%03o' % ord(c) for c in s])
_c_special = ('\0', '\n', '\r', '\t', '??', '"')
_c_special_replacements = zip(_c_special, map(_to_escape_sequence, _c_special))
def _build_specials_test():
subexps = []
for special in _c_special:
regexp = ''.join(['[%s]' % c for c in special])
subexps.append(regexp)
return re.compile('|'.join(subexps)).search
_has_specials = _build_specials_test()
def escape_character(c):
if c in '\n\r\t\\':
return repr(c)[1:-1]
elif c == "'":
return "\\'"
elif ord(c) < 32:
# hex works well for characters
return "\\x%02X" % ord(c)
else:
return c
def escape_byte_string(s):
s = s.replace('\\', '\\\\')
if _has_specials(s):
for special, replacement in _c_special_replacements:
s = s.replace(special, replacement)
try:
s.decode("ASCII")
return s
except UnicodeDecodeError:
pass
l = []
append = l.append
for c in s:
o = ord(c)
if o >= 128:
append('\\%3o' % o)
else:
append(c)
return ''.join(l)
def long_literal(value):
if isinstance(value, basestring):
if len(value) < 2:
......
# coding: ASCII
__doc__ = u"""
>>> s = test()
>>> assert s == ''.join([chr(i) for i in range(0x10,0xFF,0x11)] + [chr(0xFF)]), repr(s)
"""
def test():
cdef char s[17]
s[ 0] = c'\x10'
s[ 1] = c'\x21'
s[ 2] = c'\x32'
s[ 3] = c'\x43'
s[ 4] = c'\x54'
s[ 5] = c'\x65'
s[ 6] = c'\x76'
s[ 7] = c'\x87'
s[ 8] = c'\x98'
s[ 9] = c'\xA9'
s[10] = c'\xBA'
s[11] = c'\xCB'
s[12] = c'\xDC'
s[13] = c'\xED'
s[14] = c'\xFE'
s[15] = c'\xFF'
s[16] = c'\x00'
return s
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment