Commit 2e8a0084 authored by Stefan Behnel's avatar Stefan Behnel

Rewrite of the string literal handling code

String literals pass through the compiler as follows:
- unicode string literals are stored as unicode strings and encoded to UTF-8 on the way out
- byte string literals are stored as correctly encoded byte strings by unescaping the source string literal into the corresponding byte sequence. No further encoding is done later on!
- char literals are stored as byte strings of length 1. This can be verified by the parser now, e.g. a non-ASCII char literal in UTF-8 source code will result in an error, as it would end up as two or more bytes in the C code, which can no longer be represented as a C char.

Storing byte strings is necessary as we otherwise loose the ability to encode byte string literals on the way out. They do not necessarily contain only bytes that fit into the source code encoding as the source can use escape sequences to represent them. Previously, ASCII encoded source code could not contain byte string literals with properly escaped non-ASCII bytes.

Another bug that was fixed: in Python, escape sequences behave different in unicode strings (where they represent the character code) and byte strings (where they represent a byte value). Previously, they resulted in the same byte value in Cython code. This is only a problem for non-ASCII escapes, since the character code and the byte value of ASCII characters are identical.
parent 7bc8549a
...@@ -3,7 +3,7 @@ from Cython.Compiler.ModuleNode import ModuleNode ...@@ -3,7 +3,7 @@ from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import * from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import * from Cython.Compiler.ExprNodes import *
from Cython.Compiler.TreeFragment import TreeFragment from Cython.Compiler.TreeFragment import TreeFragment
from Cython.Utils import EncodedString from Cython.Compiler.StringEncoding import EncodedString
from Cython.Compiler.Errors import CompileError from Cython.Compiler.Errors import CompileError
import Interpreter import Interpreter
import PyrexTypes import PyrexTypes
......
...@@ -6,6 +6,7 @@ import operator ...@@ -6,6 +6,7 @@ import operator
from string import join from string import join
from Errors import error, warning, InternalError from Errors import error, warning, InternalError
import StringEncoding
import Naming import Naming
from Nodes import Node from Nodes import Node
import PyrexTypes import PyrexTypes
...@@ -14,7 +15,6 @@ from Builtin import list_type, tuple_type, dict_type, unicode_type ...@@ -14,7 +15,6 @@ from Builtin import list_type, tuple_type, dict_type, unicode_type
import Symtab import Symtab
import Options import Options
from Annotate import AnnotationItem from Annotate import AnnotationItem
from Cython import Utils
from Cython.Debugging import print_call_chain from Cython.Debugging import print_call_chain
from DebugFlags import debug_disposal_code, debug_temp_alloc, \ from DebugFlags import debug_disposal_code, debug_temp_alloc, \
...@@ -640,10 +640,10 @@ class CharNode(ConstNode): ...@@ -640,10 +640,10 @@ class CharNode(ConstNode):
type = PyrexTypes.c_char_type type = PyrexTypes.c_char_type
def compile_time_value(self, denv): def compile_time_value(self, denv):
return ord(self.value.byteencode()) return ord(self.value)
def calculate_result_code(self): def calculate_result_code(self):
return "'%s'" % Utils.escape_character(self.value.byteencode()) return "'%s'" % StringEncoding.escape_character(self.value)
class IntNode(ConstNode): class IntNode(ConstNode):
......
...@@ -397,6 +397,8 @@ class Context: ...@@ -397,6 +397,8 @@ class Context:
finally: finally:
f.close() f.close()
except UnicodeDecodeError, msg: except UnicodeDecodeError, msg:
import traceback
traceback.print_exc()
error((source_desc, 0, 0), "Decoding error, missing or incorrect coding=<encoding-name> at top of source (%s)" % msg) error((source_desc, 0, 0), "Decoding error, missing or incorrect coding=<encoding-name> at top of source (%s)" % msg)
if Errors.num_errors > 0: if Errors.num_errors > 0:
raise CompileError raise CompileError
......
...@@ -23,7 +23,8 @@ import Version ...@@ -23,7 +23,8 @@ import Version
from Errors import error, warning from Errors import error, warning
from PyrexTypes import py_object_type from PyrexTypes import py_object_type
from Cython.Utils import open_new_file, replace_suffix, escape_byte_string, EncodedString from Cython.Utils import open_new_file, replace_suffix
from StringEncoding import escape_byte_string, EncodedString
def check_c_classes(module_node): def check_c_classes(module_node):
...@@ -514,9 +515,12 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -514,9 +515,12 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln('static const char *%s;' % Naming.filename_cname) code.putln('static const char *%s;' % Naming.filename_cname)
code.putln('static const char **%s;' % Naming.filetable_cname) code.putln('static const char **%s;' % Naming.filetable_cname)
if env.doc: if env.doc:
docstr = env.doc
if not isinstance(docstr, str):
docstr = docstr.utf8encode()
code.putln('') code.putln('')
code.putln('static char %s[] = "%s";' % ( code.putln('static char %s[] = "%s";' % (
env.doc_cname, escape_byte_string(env.doc.utf8encode()))) env.doc_cname, escape_byte_string(docstr)))
def generate_extern_c_macro_definition(self, code): def generate_extern_c_macro_definition(self, code):
name = Naming.extern_c_macro name = Naming.extern_c_macro
......
...@@ -13,7 +13,7 @@ from PyrexTypes import py_object_type, error_type, CTypedefType, CFuncType ...@@ -13,7 +13,7 @@ from PyrexTypes import py_object_type, error_type, CTypedefType, CFuncType
from Symtab import ModuleScope, LocalScope, GeneratorLocalScope, \ from Symtab import ModuleScope, LocalScope, GeneratorLocalScope, \
StructOrUnionScope, PyClassScope, CClassScope StructOrUnionScope, PyClassScope, CClassScope
from Cython.Utils import open_new_file, replace_suffix from Cython.Utils import open_new_file, replace_suffix
from Cython.Utils import EncodedString, escape_byte_string from StringEncoding import EncodedString, escape_byte_string
import Options import Options
import ControlFlow import ControlFlow
...@@ -1516,10 +1516,13 @@ class DefNode(FuncDefNode): ...@@ -1516,10 +1516,13 @@ class DefNode(FuncDefNode):
if proto_only: if proto_only:
return return
if self.entry.doc and Options.docstrings: if self.entry.doc and Options.docstrings:
docstr = self.entry.doc
if not isinstance(docstr, str):
docstr = docstr.utf8encode()
code.putln( code.putln(
'static char %s[] = "%s";' % ( 'static char %s[] = "%s";' % (
self.entry.doc_cname, self.entry.doc_cname,
escape_byte_string(self.entry.doc.utf8encode()))) escape_byte_string(docstr)))
if with_pymethdef: if with_pymethdef:
code.put( code.put(
"static PyMethodDef %s = " % "static PyMethodDef %s = " %
......
...@@ -3,7 +3,7 @@ from Cython.Compiler.ModuleNode import ModuleNode ...@@ -3,7 +3,7 @@ from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import * from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import * from Cython.Compiler.ExprNodes import *
from Cython.Compiler.TreeFragment import TreeFragment from Cython.Compiler.TreeFragment import TreeFragment
from Cython.Utils import EncodedString from Cython.Compiler.StringEncoding import EncodedString
from Cython.Compiler.Errors import CompileError from Cython.Compiler.Errors import CompileError
try: try:
set set
......
This diff is collapsed.
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# Pyrex - Types # Pyrex - Types
# #
from Cython import Utils import StringEncoding
import Naming import Naming
import copy import copy
...@@ -1000,7 +1000,7 @@ class CStringType: ...@@ -1000,7 +1000,7 @@ class CStringType:
def literal_code(self, value): def literal_code(self, value):
assert isinstance(value, str) assert isinstance(value, str)
return '"%s"' % Utils.escape_byte_string(value) return '"%s"' % StringEncoding.escape_byte_string(value)
class CUTF8CharArrayType(CStringType, CArrayType): class CUTF8CharArrayType(CStringType, CArrayType):
......
...@@ -17,7 +17,7 @@ from Cython.Plex.Errors import UnrecognizedInput ...@@ -17,7 +17,7 @@ from Cython.Plex.Errors import UnrecognizedInput
from Errors import CompileError, error from Errors import CompileError, error
from Lexicon import string_prefixes, raw_prefixes, make_lexicon from Lexicon import string_prefixes, raw_prefixes, make_lexicon
from Cython import Utils from StringEncoding import EncodedString
plex_version = getattr(Plex, '_version', None) plex_version = getattr(Plex, '_version', None)
#print "Plex version:", plex_version ### #print "Plex version:", plex_version ###
...@@ -413,7 +413,7 @@ class PyrexScanner(Scanner): ...@@ -413,7 +413,7 @@ class PyrexScanner(Scanner):
if systring in self.resword_dict: if systring in self.resword_dict:
sy = systring sy = systring
else: else:
systring = Utils.EncodedString(systring) systring = EncodedString(systring)
systring.encoding = self.source_encoding systring.encoding = self.source_encoding
self.sy = sy self.sy = sy
self.systring = systring self.systring = systring
......
#
# Cython -- encoding related tools
#
import re
class UnicodeLiteralBuilder(object):
"""Assemble a unicode string.
"""
def __init__(self):
self.chars = []
def append(self, characters):
if isinstance(characters, str):
# this came from a Py2 string literal in the parser code
characters = characters.decode("ASCII")
assert isinstance(characters, unicode), str(type(characters))
self.chars.append(characters)
def append_charval(self, char_number):
self.chars.append( unichr(char_number) )
def getstring(self):
return EncodedString(u''.join(self.chars))
class BytesLiteralBuilder(object):
"""Assemble a byte string or char value.
"""
def __init__(self, target_encoding):
self.chars = []
self.target_encoding = target_encoding
def append(self, characters):
if isinstance(characters, unicode):
characters = characters.encode(self.target_encoding)
assert isinstance(characters, str), str(type(characters))
self.chars.append(characters)
def append_charval(self, char_number):
self.chars.append( chr(char_number) )
def getstring(self):
# this *must* return a byte string! => fix it in Py3k!!
s = BytesLiteral(''.join(self.chars))
s.encoding = self.target_encoding
return s
def getchar(self):
# this *must* return a byte string! => fix it in Py3k!!
return self.getstring()
class EncodedString(unicode):
# unicode string subclass to keep track of the original encoding.
# 'encoding' is None for unicode strings and the source encoding
# otherwise
encoding = None
def byteencode(self):
assert self.encoding is not None
return self.encode(self.encoding)
def utf8encode(self):
assert self.encoding is None
return self.encode("UTF-8")
def is_unicode(self):
return self.encoding is None
is_unicode = property(is_unicode)
class BytesLiteral(str):
# str subclass that is compatible with EncodedString
encoding = None
def byteencode(self):
return str(self)
def utf8encode(self):
assert False, "this is not a unicode string: %r" % self
is_unicode = False
char_from_escape_sequence = {
r'\a' : u'\a',
r'\b' : u'\b',
r'\f' : u'\f',
r'\n' : u'\n',
r'\r' : u'\r',
r'\t' : u'\t',
r'\v' : u'\v',
}.get
def _to_escape_sequence(s):
if s in '\n\r\t':
return repr(s)[1:-1]
elif s == '"':
return r'\"'
else:
# within a character sequence, oct passes much better than hex
return ''.join(['\\%03o' % ord(c) for c in s])
_c_special = ('\0', '\n', '\r', '\t', '??', '"')
_c_special_replacements = zip(_c_special, map(_to_escape_sequence, _c_special))
def _build_specials_test():
subexps = []
for special in _c_special:
regexp = ''.join(['[%s]' % c for c in special])
subexps.append(regexp)
return re.compile('|'.join(subexps)).search
_has_specials = _build_specials_test()
def escape_character(c):
if c in '\n\r\t\\':
return repr(c)[1:-1]
elif c == "'":
return "\\'"
n = ord(c)
if n < 32 or n > 127:
# hex works well for characters
return "\\x%02X" % n
else:
return c
def escape_byte_string(s):
s = s.replace('\\', '\\\\')
if _has_specials(s):
for special, replacement in _c_special_replacements:
s = s.replace(special, replacement)
try:
s.decode("ASCII")
return s
except UnicodeDecodeError:
pass
l = []
append = l.append
for c in s:
o = ord(c)
if o >= 128:
append('\\%3o' % o)
else:
append(c)
return ''.join(l)
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
import re import re
from Cython import Utils from Cython import Utils
from Errors import warning, error, InternalError from Errors import warning, error, InternalError
from StringEncoding import EncodedString
import Options import Options
import Naming import Naming
import PyrexTypes import PyrexTypes
...@@ -684,14 +685,14 @@ class BuiltinScope(Scope): ...@@ -684,14 +685,14 @@ class BuiltinScope(Scope):
utility_code = None): utility_code = None):
# If python_equiv == "*", the Python equivalent has the same name # If python_equiv == "*", the Python equivalent has the same name
# as the entry, otherwise it has the name specified by python_equiv. # as the entry, otherwise it has the name specified by python_equiv.
name = Utils.EncodedString(name) name = EncodedString(name)
entry = self.declare_cfunction(name, type, None, cname) entry = self.declare_cfunction(name, type, None, cname)
entry.utility_code = utility_code entry.utility_code = utility_code
if python_equiv: if python_equiv:
if python_equiv == "*": if python_equiv == "*":
python_equiv = name python_equiv = name
else: else:
python_equiv = Utils.EncodedString(python_equiv) python_equiv = EncodedString(python_equiv)
var_entry = Entry(python_equiv, python_equiv, py_object_type) var_entry = Entry(python_equiv, python_equiv, py_object_type)
var_entry.is_variable = 1 var_entry.is_variable = 1
var_entry.is_builtin = 1 var_entry.is_builtin = 1
...@@ -699,7 +700,7 @@ class BuiltinScope(Scope): ...@@ -699,7 +700,7 @@ class BuiltinScope(Scope):
return entry return entry
def declare_builtin_type(self, name, cname): def declare_builtin_type(self, name, cname):
name = Utils.EncodedString(name) name = EncodedString(name)
type = PyrexTypes.BuiltinObjectType(name, cname) type = PyrexTypes.BuiltinObjectType(name, cname)
type.set_scope(CClassScope(name, outer_scope=None, visibility='extern')) type.set_scope(CClassScope(name, outer_scope=None, visibility='extern'))
self.type_names[name] = 1 self.type_names[name] = 1
...@@ -1370,7 +1371,7 @@ class CClassScope(ClassScope): ...@@ -1370,7 +1371,7 @@ class CClassScope(ClassScope):
if name == "__new__": if name == "__new__":
warning(pos, "__new__ method of extension type will change semantics " warning(pos, "__new__ method of extension type will change semantics "
"in a future version of Pyrex and Cython. Use __cinit__ instead.") "in a future version of Pyrex and Cython. Use __cinit__ instead.")
name = Utils.EncodedString("__cinit__") name = EncodedString("__cinit__")
entry = self.declare_var(name, py_object_type, pos, visibility='extern') entry = self.declare_var(name, py_object_type, pos, visibility='extern')
special_sig = get_special_method_signature(name) special_sig = get_special_method_signature(name)
if special_sig: if special_sig:
...@@ -1387,7 +1388,7 @@ class CClassScope(ClassScope): ...@@ -1387,7 +1388,7 @@ class CClassScope(ClassScope):
def lookup_here(self, name): def lookup_here(self, name):
if name == "__new__": if name == "__new__":
name = Utils.EncodedString("__cinit__") name = EncodedString("__cinit__")
return ClassScope.lookup_here(self, name) return ClassScope.lookup_here(self, name)
def declare_cfunction(self, name, type, pos, def declare_cfunction(self, name, type, pos,
......
...@@ -3,9 +3,9 @@ ...@@ -3,9 +3,9 @@
# and associated know-how. # and associated know-how.
# #
from Cython import Utils
import Naming import Naming
import PyrexTypes import PyrexTypes
import StringEncoding
import sys import sys
class Signature: class Signature:
...@@ -311,7 +311,7 @@ class DocStringSlot(SlotDescriptor): ...@@ -311,7 +311,7 @@ class DocStringSlot(SlotDescriptor):
doc = scope.doc.utf8encode() doc = scope.doc.utf8encode()
else: else:
doc = scope.doc.byteencode() doc = scope.doc.byteencode()
return '"%s"' % Utils.escape_byte_string(doc) return '"%s"' % StringEncoding.escape_byte_string(doc)
else: else:
return "0" return "0"
......
...@@ -5,7 +5,7 @@ import inspect ...@@ -5,7 +5,7 @@ import inspect
import Nodes import Nodes
import ExprNodes import ExprNodes
import Naming import Naming
from Cython.Utils import EncodedString from StringEncoding import EncodedString
class BasicVisitor(object): class BasicVisitor(object):
"""A generic visitor base class which can be used for visiting any kind of object.""" """A generic visitor base class which can be used for visiting any kind of object."""
......
...@@ -40,7 +40,7 @@ def file_newer_than(path, time): ...@@ -40,7 +40,7 @@ def file_newer_than(path, time):
ftime = modification_time(path) ftime = modification_time(path)
return ftime > time return ftime > time
# support for source file encoding detection and unicode decoding # support for source file encoding detection
def encode_filename(filename): def encode_filename(filename):
if isinstance(filename, unicode): if isinstance(filename, unicode):
...@@ -77,90 +77,6 @@ def open_source_file(source_filename, mode="rU"): ...@@ -77,90 +77,6 @@ def open_source_file(source_filename, mode="rU"):
encoding = detect_file_encoding(source_filename) encoding = detect_file_encoding(source_filename)
return codecs.open(source_filename, mode=mode, encoding=encoding) return codecs.open(source_filename, mode=mode, encoding=encoding)
class EncodedString(unicode):
# unicode string subclass to keep track of the original encoding.
# 'encoding' is None for unicode strings and the source encoding
# otherwise
encoding = None
def byteencode(self):
assert self.encoding is not None
return self.encode(self.encoding)
def utf8encode(self):
assert self.encoding is None
return self.encode("UTF-8")
def is_unicode(self):
return self.encoding is None
is_unicode = property(is_unicode)
# def __eq__(self, other):
# return unicode.__eq__(self, other) and \
# getattr(other, 'encoding', '') == self.encoding
char_from_escape_sequence = {
r'\a' : '\a',
r'\b' : '\b',
r'\f' : '\f',
r'\n' : '\n',
r'\r' : '\r',
r'\t' : '\t',
r'\v' : '\v',
}.get
def _to_escape_sequence(s):
if s in '\n\r\t':
return repr(s)[1:-1]
elif s == '"':
return r'\"'
else:
# within a character sequence, oct passes much better than hex
return ''.join(['\\%03o' % ord(c) for c in s])
_c_special = ('\0', '\n', '\r', '\t', '??', '"')
_c_special_replacements = zip(_c_special, map(_to_escape_sequence, _c_special))
def _build_specials_test():
subexps = []
for special in _c_special:
regexp = ''.join(['[%s]' % c for c in special])
subexps.append(regexp)
return re.compile('|'.join(subexps)).search
_has_specials = _build_specials_test()
def escape_character(c):
if c in '\n\r\t\\':
return repr(c)[1:-1]
elif c == "'":
return "\\'"
elif ord(c) < 32:
# hex works well for characters
return "\\x%02X" % ord(c)
else:
return c
def escape_byte_string(s):
s = s.replace('\\', '\\\\')
if _has_specials(s):
for special, replacement in _c_special_replacements:
s = s.replace(special, replacement)
try:
s.decode("ASCII")
return s
except UnicodeDecodeError:
pass
l = []
append = l.append
for c in s:
o = ord(c)
if o >= 128:
append('\\%3o' % o)
else:
append(c)
return ''.join(l)
def long_literal(value): def long_literal(value):
if isinstance(value, basestring): if isinstance(value, basestring):
if len(value) < 2: if len(value) < 2:
......
# coding: ASCII
__doc__ = u"""
>>> s = test()
>>> assert s == ''.join([chr(i) for i in range(0x10,0xFF,0x11)] + [chr(0xFF)]), repr(s)
"""
def test():
cdef char s[17]
s[ 0] = c'\x10'
s[ 1] = c'\x21'
s[ 2] = c'\x32'
s[ 3] = c'\x43'
s[ 4] = c'\x54'
s[ 5] = c'\x65'
s[ 6] = c'\x76'
s[ 7] = c'\x87'
s[ 8] = c'\x98'
s[ 9] = c'\xA9'
s[10] = c'\xBA'
s[11] = c'\xCB'
s[12] = c'\xDC'
s[13] = c'\xED'
s[14] = c'\xFE'
s[15] = c'\xFF'
s[16] = c'\x00'
return s
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment