Commit 061b132e authored by Stefan Behnel

fix bytes literal creation from compile-time DEF expressions (used to become...

fix bytes literal creation from compile-time DEF expressions (used to become Unicode strings due to missing encoding)
parent 067c6cca
...@@ -15,6 +15,9 @@ Bugs fixed ...@@ -15,6 +15,9 @@ Bugs fixed
* Incorrect handling of large integers in compile-time evaluated DEF * Incorrect handling of large integers in compile-time evaluated DEF
expressions under Python 2.x. expressions under Python 2.x.
* Byte string constants could end up as Unicode strings when originating
from compile-time evaluated DEF expressions.
* Invalid C code when caching known builtin methods. * Invalid C code when caching known builtin methods.
This fixes ticket 860. This fixes ticket 860.
......
...@@ -1268,10 +1268,8 @@ class BytesNode(ConstNode): ...@@ -1268,10 +1268,8 @@ class BytesNode(ConstNode):
self.constant_result = self.value self.constant_result = self.value
def as_sliced_node(self, start, stop, step=None): def as_sliced_node(self, start, stop, step=None):
value = StringEncoding.BytesLiteral(self.value[start:stop:step]) value = StringEncoding.bytes_literal(self.value[start:stop:step], self.value.encoding)
value.encoding = self.value.encoding return BytesNode(self.pos, value=value, constant_result=value)
return BytesNode(
self.pos, value=value, constant_result=value)
def compile_time_value(self, denv): def compile_time_value(self, denv):
return self.value return self.value
...@@ -1362,9 +1360,8 @@ class UnicodeNode(ConstNode): ...@@ -1362,9 +1360,8 @@ class UnicodeNode(ConstNode):
value = StringEncoding.EncodedString(self.value[start:stop:step]) value = StringEncoding.EncodedString(self.value[start:stop:step])
value.encoding = self.value.encoding value.encoding = self.value.encoding
if self.bytes_value is not None: if self.bytes_value is not None:
bytes_value = StringEncoding.BytesLiteral( bytes_value = StringEncoding.bytes_literal(
self.bytes_value[start:stop:step]) self.bytes_value[start:stop:step], self.bytes_value.encoding)
bytes_value.encoding = self.bytes_value.encoding
else: else:
bytes_value = None bytes_value = None
return UnicodeNode( return UnicodeNode(
...@@ -8503,7 +8500,7 @@ class CodeObjectNode(ExprNode): ...@@ -8503,7 +8500,7 @@ class CodeObjectNode(ExprNode):
func_name = code.get_py_string_const( func_name = code.get_py_string_const(
func.name, identifier=True, is_str=False, unicode_value=func.name) func.name, identifier=True, is_str=False, unicode_value=func.name)
# FIXME: better way to get the module file path at module init time? Encoding to use? # FIXME: better way to get the module file path at module init time? Encoding to use?
file_path = StringEncoding.BytesLiteral(func.pos[0].get_filenametable_entry().encode('utf8')) file_path = StringEncoding.bytes_literal(func.pos[0].get_filenametable_entry().encode('utf8'), 'utf8')
file_path_const = code.get_py_string_const(file_path, identifier=False, is_str=True) file_path_const = code.get_py_string_const(file_path, identifier=False, is_str=True)
flags = [] flags = []
......
...@@ -7,7 +7,7 @@ import codecs ...@@ -7,7 +7,7 @@ import codecs
from . import TypeSlots from . import TypeSlots
from .ExprNodes import not_a_constant from .ExprNodes import not_a_constant
import cython import cython
cython.declare(UtilityCode=object, EncodedString=object, BytesLiteral=object, cython.declare(UtilityCode=object, EncodedString=object, bytes_literal=object,
Nodes=object, ExprNodes=object, PyrexTypes=object, Builtin=object, Nodes=object, ExprNodes=object, PyrexTypes=object, Builtin=object,
UtilNodes=object, _py_int_types=object) UtilNodes=object, _py_int_types=object)
...@@ -25,7 +25,7 @@ from . import UtilNodes ...@@ -25,7 +25,7 @@ from . import UtilNodes
from . import Options from . import Options
from .Code import UtilityCode, TempitaUtilityCode from .Code import UtilityCode, TempitaUtilityCode
from .StringEncoding import EncodedString, BytesLiteral from .StringEncoding import EncodedString, bytes_literal
from .Errors import error from .Errors import error
from .ParseTreeTransforms import SkipDeclarations from .ParseTreeTransforms import SkipDeclarations
...@@ -342,7 +342,7 @@ class IterationTransform(Visitor.EnvTransform): ...@@ -342,7 +342,7 @@ class IterationTransform(Visitor.EnvTransform):
if slice_node.is_literal: if slice_node.is_literal:
# try to reduce to byte iteration for plain Latin-1 strings # try to reduce to byte iteration for plain Latin-1 strings
try: try:
bytes_value = BytesLiteral(slice_node.value.encode('latin1')) bytes_value = bytes_literal(slice_node.value.encode('latin1'), 'iso8859-1')
except UnicodeEncodeError: except UnicodeEncodeError:
pass pass
else: else:
...@@ -3298,10 +3298,8 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, ...@@ -3298,10 +3298,8 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
# well, looks like we can't # well, looks like we can't
pass pass
else: else:
value = BytesLiteral(value) value = bytes_literal(value, encoding)
value.encoding = encoding return ExprNodes.BytesNode(string_node.pos, value=value, type=Builtin.bytes_type)
return ExprNodes.BytesNode(
string_node.pos, value=value, type=Builtin.bytes_type)
if encoding and error_handling == 'strict': if encoding and error_handling == 'strict':
# try to find a specific encoder function # try to find a specific encoder function
...@@ -3508,8 +3506,7 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, ...@@ -3508,8 +3506,7 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
if isinstance(node, ExprNodes.UnicodeNode): if isinstance(node, ExprNodes.UnicodeNode):
encoding = node.value encoding = node.value
node = ExprNodes.BytesNode( node = ExprNodes.BytesNode(
node.pos, value=BytesLiteral(encoding.utf8encode()), node.pos, value=encoding.as_utf8_string(), type=PyrexTypes.c_char_ptr_type)
type=PyrexTypes.c_char_ptr_type)
elif isinstance(node, (ExprNodes.StringNode, ExprNodes.BytesNode)): elif isinstance(node, (ExprNodes.StringNode, ExprNodes.BytesNode)):
encoding = node.value.decode('ISO-8859-1') encoding = node.value.decode('ISO-8859-1')
node = ExprNodes.BytesNode( node = ExprNodes.BytesNode(
...@@ -3833,15 +3830,15 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): ...@@ -3833,15 +3830,15 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
bytes_value = None bytes_value = None
if str1.bytes_value is not None and str2.bytes_value is not None: if str1.bytes_value is not None and str2.bytes_value is not None:
if str1.bytes_value.encoding == str2.bytes_value.encoding: if str1.bytes_value.encoding == str2.bytes_value.encoding:
bytes_value = BytesLiteral(str1.bytes_value + str2.bytes_value) bytes_value = bytes_literal(
bytes_value.encoding = str1.bytes_value.encoding str1.bytes_value + str2.bytes_value,
str1.bytes_value.encoding)
string_value = EncodedString(node.constant_result) string_value = EncodedString(node.constant_result)
return ExprNodes.UnicodeNode( return ExprNodes.UnicodeNode(
str1.pos, value=string_value, constant_result=node.constant_result, bytes_value=bytes_value) str1.pos, value=string_value, constant_result=node.constant_result, bytes_value=bytes_value)
elif isinstance(str1, ExprNodes.BytesNode) and isinstance(str2, ExprNodes.BytesNode): elif isinstance(str1, ExprNodes.BytesNode) and isinstance(str2, ExprNodes.BytesNode):
if str1.value.encoding == str2.value.encoding: if str1.value.encoding == str2.value.encoding:
bytes_value = BytesLiteral(node.constant_result) bytes_value = bytes_literal(node.constant_result, str1.value.encoding)
bytes_value.encoding = str1.value.encoding
return ExprNodes.BytesNode(str1.pos, value=bytes_value, constant_result=node.constant_result) return ExprNodes.BytesNode(str1.pos, value=bytes_value, constant_result=node.constant_result)
# all other combinations are rather complicated # all other combinations are rather complicated
# to get right in Py2/3: encodings, unicode escapes, ... # to get right in Py2/3: encodings, unicode escapes, ...
......
...@@ -8,7 +8,7 @@ from __future__ import absolute_import ...@@ -8,7 +8,7 @@ from __future__ import absolute_import
# This should be done automatically # This should be done automatically
import cython import cython
cython.declare(Nodes=object, ExprNodes=object, EncodedString=object, cython.declare(Nodes=object, ExprNodes=object, EncodedString=object,
BytesLiteral=object, StringEncoding=object, bytes_literal=object, StringEncoding=object,
FileSourceDescriptor=object, lookup_unicodechar=object, FileSourceDescriptor=object, lookup_unicodechar=object,
Future=object, Options=object, error=object, warning=object, Future=object, Options=object, error=object, warning=object,
Builtin=object, ModuleNode=object, Utils=object, Builtin=object, ModuleNode=object, Utils=object,
...@@ -25,7 +25,7 @@ from . import Nodes ...@@ -25,7 +25,7 @@ from . import Nodes
from . import ExprNodes from . import ExprNodes
from . import Builtin from . import Builtin
from . import StringEncoding from . import StringEncoding
from .StringEncoding import EncodedString, BytesLiteral, _unicode, _bytes from .StringEncoding import EncodedString, bytes_literal, _unicode, _bytes
from .ModuleNode import ModuleNode from .ModuleNode import ModuleNode
from .Errors import error, warning from .Errors import error, warning
from .. import Utils from .. import Utils
...@@ -768,7 +768,8 @@ def wrap_compile_time_constant(pos, value): ...@@ -768,7 +768,8 @@ def wrap_compile_time_constant(pos, value):
elif isinstance(value, _unicode): elif isinstance(value, _unicode):
return ExprNodes.UnicodeNode(pos, value=EncodedString(value)) return ExprNodes.UnicodeNode(pos, value=EncodedString(value))
elif isinstance(value, _bytes): elif isinstance(value, _bytes):
return ExprNodes.BytesNode(pos, value=BytesLiteral(value)) bvalue = bytes_literal(value, 'ascii') # actually: unknown encoding, but BytesLiteral requires one
return ExprNodes.BytesNode(pos, value=bvalue, constant_result=value)
elif isinstance(value, tuple): elif isinstance(value, tuple):
args = [wrap_compile_time_constant(pos, arg) args = [wrap_compile_time_constant(pos, arg)
for arg in value] for arg in value]
...@@ -809,8 +810,7 @@ def p_cat_string_literal(s): ...@@ -809,8 +810,7 @@ def p_cat_string_literal(s):
# join and rewrap the partial literals # join and rewrap the partial literals
if kind in ('b', 'c', '') or kind == 'u' and None not in bstrings: if kind in ('b', 'c', '') or kind == 'u' and None not in bstrings:
# Py3 enforced unicode literals are parsed as bytes/unicode combination # Py3 enforced unicode literals are parsed as bytes/unicode combination
bytes_value = BytesLiteral( StringEncoding.join_bytes(bstrings) ) bytes_value = bytes_literal(StringEncoding.join_bytes(bstrings), s.source_encoding)
bytes_value.encoding = s.source_encoding
if kind in ('u', ''): if kind in ('u', ''):
unicode_value = EncodedString( u''.join([ u for u in ustrings if u is not None ]) ) unicode_value = EncodedString( u''.join([ u for u in ustrings if u is not None ]) )
return kind, bytes_value, unicode_value return kind, bytes_value, unicode_value
......
...@@ -78,9 +78,7 @@ class BytesLiteralBuilder(object): ...@@ -78,9 +78,7 @@ class BytesLiteralBuilder(object):
def getstring(self): def getstring(self):
# this *must* return a byte string! # this *must* return a byte string!
s = BytesLiteral(join_bytes(self.chars)) return bytes_literal(join_bytes(self.chars), self.target_encoding)
s.encoding = self.target_encoding
return s
def getchar(self): def getchar(self):
# this *must* return a byte string! # this *must* return a byte string!
...@@ -136,6 +134,9 @@ class EncodedString(_unicode): ...@@ -136,6 +134,9 @@ class EncodedString(_unicode):
def contains_surrogates(self): def contains_surrogates(self):
return string_contains_surrogates(self) return string_contains_surrogates(self)
def as_utf8_string(self):
    """Return the UTF-8 encoded form of this string as a bytes literal.

    The bytes come from ``self.utf8encode()`` and are passed to
    ``bytes_literal()`` together with the encoding name ``'utf8'``.
    """
    utf8_bytes = self.utf8encode()
    return bytes_literal(utf8_bytes, 'utf8')
def string_contains_surrogates(ustring): def string_contains_surrogates(ustring):
""" """
...@@ -178,6 +179,13 @@ class BytesLiteral(_bytes): ...@@ -178,6 +179,13 @@ class BytesLiteral(_bytes):
is_unicode = False is_unicode = False
def bytes_literal(s, encoding):
    """Wrap the byte string *s* in a ``BytesLiteral`` tagged with *encoding*.

    *s* must already be a ``bytes`` object (checked with an assertion).
    The given *encoding* is stored on the returned object's ``encoding``
    attribute, so callers cannot forget to set it after construction.
    """
    assert isinstance(s, bytes)
    literal = BytesLiteral(s)
    literal.encoding = encoding
    return literal
char_from_escape_sequence = { char_from_escape_sequence = {
r'\a' : u'\a', r'\a' : u'\a',
r'\b' : u'\b', r'\b' : u'\b',
......
...@@ -6,8 +6,12 @@ __doc__ = u""" ...@@ -6,8 +6,12 @@ __doc__ = u"""
b'spam' b'spam'
""" """
_unicode = unicode
import sys import sys
if sys.version_info[0] < 3: IS_PY3 = sys.version_info[0] >= 3
if not IS_PY3:
__doc__ = __doc__.replace(u" b'", u" '") __doc__ = __doc__.replace(u" b'", u" '")
...@@ -28,7 +32,8 @@ DEF LONG = 666L ...@@ -28,7 +32,8 @@ DEF LONG = 666L
DEF LARGE_NUM32 = (1 << 32) - 1 DEF LARGE_NUM32 = (1 << 32) - 1
DEF LARGE_NUM64 = (1 << 64) - 1 DEF LARGE_NUM64 = (1 << 64) - 1
DEF FLOAT = 12.5 DEF FLOAT = 12.5
DEF STR = b"spam" DEF BYTES = b"spam"
DEF UNICODE = u"spam-u"
DEF TWO = TUPLE[1] DEF TWO = TUPLE[1]
DEF FIVE = TWO + 3 DEF FIVE = TWO + 3
DEF TRUE = TRUE_FALSE[0] DEF TRUE = TRUE_FALSE[0]
...@@ -111,9 +116,29 @@ def s(): ...@@ -111,9 +116,29 @@ def s():
""" """
see module docstring above see module docstring above
""" """
cdef char* s = STR cdef char* s = BYTES
return s return s
# Regression test: the compile-time constant BYTES (DEF BYTES = b"spam") must
# be a byte string at runtime, not a Unicode string.
def type_of_bytes():
"""
>>> t, s = type_of_bytes()
>>> assert t is bytes, t
>>> assert type(s) is bytes, type(s)
"""
# Use the DEF constant both as a call argument and as a plain assigned value;
# the doctest checks the resulting type in each case.
t = type(BYTES)
s = BYTES
return t, s
# Counterpart to the bytes test: the compile-time constant UNICODE
# (DEF UNICODE = u"spam-u") must be a Unicode string at runtime.
def type_of_unicode():
"""
>>> t, s = type_of_unicode()
>>> assert t is _unicode, t
>>> assert type(s) is _unicode, type(s)
"""
# Use the DEF constant both as a call argument and as a plain assigned value;
# the doctest compares each resulting type against the module-level _unicode alias.
t = type(UNICODE)
s = UNICODE
return t, s
@cython.test_assert_path_exists('//TupleNode') @cython.test_assert_path_exists('//TupleNode')
def constant_tuple(): def constant_tuple():
""" """
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment