Commit 061b132e authored by Stefan Behnel's avatar Stefan Behnel

fix bytes literal creation from compile-time DEF expressions (used to become...

fix bytes literal creation from compile-time DEF expressions (used to become Unicode strings due to missing encoding)
parent 067c6cca
......@@ -15,6 +15,9 @@ Bugs fixed
* Incorrect handling of large integers in compile-time evaluated DEF
expressions under Python 2.x.
* Byte string constants could end up as Unicode strings when originating
from compile-time evaluated DEF expressions.
* Invalid C code when caching known builtin methods.
This fixes ticket 860.
......
......@@ -1268,10 +1268,8 @@ class BytesNode(ConstNode):
self.constant_result = self.value
def as_sliced_node(self, start, stop, step=None):
value = StringEncoding.BytesLiteral(self.value[start:stop:step])
value.encoding = self.value.encoding
return BytesNode(
self.pos, value=value, constant_result=value)
value = StringEncoding.bytes_literal(self.value[start:stop:step], self.value.encoding)
return BytesNode(self.pos, value=value, constant_result=value)
def compile_time_value(self, denv):
return self.value
......@@ -1362,9 +1360,8 @@ class UnicodeNode(ConstNode):
value = StringEncoding.EncodedString(self.value[start:stop:step])
value.encoding = self.value.encoding
if self.bytes_value is not None:
bytes_value = StringEncoding.BytesLiteral(
self.bytes_value[start:stop:step])
bytes_value.encoding = self.bytes_value.encoding
bytes_value = StringEncoding.bytes_literal(
self.bytes_value[start:stop:step], self.bytes_value.encoding)
else:
bytes_value = None
return UnicodeNode(
......@@ -8503,7 +8500,7 @@ class CodeObjectNode(ExprNode):
func_name = code.get_py_string_const(
func.name, identifier=True, is_str=False, unicode_value=func.name)
# FIXME: better way to get the module file path at module init time? Encoding to use?
file_path = StringEncoding.BytesLiteral(func.pos[0].get_filenametable_entry().encode('utf8'))
file_path = StringEncoding.bytes_literal(func.pos[0].get_filenametable_entry().encode('utf8'), 'utf8')
file_path_const = code.get_py_string_const(file_path, identifier=False, is_str=True)
flags = []
......
......@@ -7,7 +7,7 @@ import codecs
from . import TypeSlots
from .ExprNodes import not_a_constant
import cython
cython.declare(UtilityCode=object, EncodedString=object, BytesLiteral=object,
cython.declare(UtilityCode=object, EncodedString=object, bytes_literal=object,
Nodes=object, ExprNodes=object, PyrexTypes=object, Builtin=object,
UtilNodes=object, _py_int_types=object)
......@@ -25,7 +25,7 @@ from . import UtilNodes
from . import Options
from .Code import UtilityCode, TempitaUtilityCode
from .StringEncoding import EncodedString, BytesLiteral
from .StringEncoding import EncodedString, bytes_literal
from .Errors import error
from .ParseTreeTransforms import SkipDeclarations
......@@ -342,7 +342,7 @@ class IterationTransform(Visitor.EnvTransform):
if slice_node.is_literal:
# try to reduce to byte iteration for plain Latin-1 strings
try:
bytes_value = BytesLiteral(slice_node.value.encode('latin1'))
bytes_value = bytes_literal(slice_node.value.encode('latin1'), 'iso8859-1')
except UnicodeEncodeError:
pass
else:
......@@ -3298,10 +3298,8 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
# well, looks like we can't
pass
else:
value = BytesLiteral(value)
value.encoding = encoding
return ExprNodes.BytesNode(
string_node.pos, value=value, type=Builtin.bytes_type)
value = bytes_literal(value, encoding)
return ExprNodes.BytesNode(string_node.pos, value=value, type=Builtin.bytes_type)
if encoding and error_handling == 'strict':
# try to find a specific encoder function
......@@ -3508,8 +3506,7 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
if isinstance(node, ExprNodes.UnicodeNode):
encoding = node.value
node = ExprNodes.BytesNode(
node.pos, value=BytesLiteral(encoding.utf8encode()),
type=PyrexTypes.c_char_ptr_type)
node.pos, value=encoding.as_utf8_string(), type=PyrexTypes.c_char_ptr_type)
elif isinstance(node, (ExprNodes.StringNode, ExprNodes.BytesNode)):
encoding = node.value.decode('ISO-8859-1')
node = ExprNodes.BytesNode(
......@@ -3833,15 +3830,15 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
bytes_value = None
if str1.bytes_value is not None and str2.bytes_value is not None:
if str1.bytes_value.encoding == str2.bytes_value.encoding:
bytes_value = BytesLiteral(str1.bytes_value + str2.bytes_value)
bytes_value.encoding = str1.bytes_value.encoding
bytes_value = bytes_literal(
str1.bytes_value + str2.bytes_value,
str1.bytes_value.encoding)
string_value = EncodedString(node.constant_result)
return ExprNodes.UnicodeNode(
str1.pos, value=string_value, constant_result=node.constant_result, bytes_value=bytes_value)
elif isinstance(str1, ExprNodes.BytesNode) and isinstance(str2, ExprNodes.BytesNode):
if str1.value.encoding == str2.value.encoding:
bytes_value = BytesLiteral(node.constant_result)
bytes_value.encoding = str1.value.encoding
bytes_value = bytes_literal(node.constant_result, str1.value.encoding)
return ExprNodes.BytesNode(str1.pos, value=bytes_value, constant_result=node.constant_result)
# all other combinations are rather complicated
# to get right in Py2/3: encodings, unicode escapes, ...
......
......@@ -8,7 +8,7 @@ from __future__ import absolute_import
# This should be done automatically
import cython
cython.declare(Nodes=object, ExprNodes=object, EncodedString=object,
BytesLiteral=object, StringEncoding=object,
bytes_literal=object, StringEncoding=object,
FileSourceDescriptor=object, lookup_unicodechar=object,
Future=object, Options=object, error=object, warning=object,
Builtin=object, ModuleNode=object, Utils=object,
......@@ -25,7 +25,7 @@ from . import Nodes
from . import ExprNodes
from . import Builtin
from . import StringEncoding
from .StringEncoding import EncodedString, BytesLiteral, _unicode, _bytes
from .StringEncoding import EncodedString, bytes_literal, _unicode, _bytes
from .ModuleNode import ModuleNode
from .Errors import error, warning
from .. import Utils
......@@ -768,7 +768,8 @@ def wrap_compile_time_constant(pos, value):
elif isinstance(value, _unicode):
return ExprNodes.UnicodeNode(pos, value=EncodedString(value))
elif isinstance(value, _bytes):
return ExprNodes.BytesNode(pos, value=BytesLiteral(value))
bvalue = bytes_literal(value, 'ascii') # actually: unknown encoding, but BytesLiteral requires one
return ExprNodes.BytesNode(pos, value=bvalue, constant_result=value)
elif isinstance(value, tuple):
args = [wrap_compile_time_constant(pos, arg)
for arg in value]
......@@ -809,8 +810,7 @@ def p_cat_string_literal(s):
# join and rewrap the partial literals
if kind in ('b', 'c', '') or kind == 'u' and None not in bstrings:
# Py3 enforced unicode literals are parsed as bytes/unicode combination
bytes_value = BytesLiteral( StringEncoding.join_bytes(bstrings) )
bytes_value.encoding = s.source_encoding
bytes_value = bytes_literal(StringEncoding.join_bytes(bstrings), s.source_encoding)
if kind in ('u', ''):
unicode_value = EncodedString( u''.join([ u for u in ustrings if u is not None ]) )
return kind, bytes_value, unicode_value
......
......@@ -78,9 +78,7 @@ class BytesLiteralBuilder(object):
def getstring(self):
# this *must* return a byte string!
s = BytesLiteral(join_bytes(self.chars))
s.encoding = self.target_encoding
return s
return bytes_literal(join_bytes(self.chars), self.target_encoding)
def getchar(self):
# this *must* return a byte string!
......@@ -136,6 +134,9 @@ class EncodedString(_unicode):
def contains_surrogates(self):
return string_contains_surrogates(self)
def as_utf8_string(self):
return bytes_literal(self.utf8encode(), 'utf8')
def string_contains_surrogates(ustring):
"""
......@@ -178,6 +179,13 @@ class BytesLiteral(_bytes):
is_unicode = False
def bytes_literal(s, encoding):
assert isinstance(s, bytes)
s = BytesLiteral(s)
s.encoding = encoding
return s
char_from_escape_sequence = {
r'\a' : u'\a',
r'\b' : u'\b',
......
......@@ -6,8 +6,12 @@ __doc__ = u"""
b'spam'
"""
_unicode = unicode
import sys
if sys.version_info[0] < 3:
IS_PY3 = sys.version_info[0] >= 3
if not IS_PY3:
__doc__ = __doc__.replace(u" b'", u" '")
......@@ -28,7 +32,8 @@ DEF LONG = 666L
DEF LARGE_NUM32 = (1 << 32) - 1
DEF LARGE_NUM64 = (1 << 64) - 1
DEF FLOAT = 12.5
DEF STR = b"spam"
DEF BYTES = b"spam"
DEF UNICODE = u"spam-u"
DEF TWO = TUPLE[1]
DEF FIVE = TWO + 3
DEF TRUE = TRUE_FALSE[0]
......@@ -111,9 +116,29 @@ def s():
"""
see module docstring above
"""
cdef char* s = STR
cdef char* s = BYTES
return s
def type_of_bytes():
"""
>>> t, s = type_of_bytes()
>>> assert t is bytes, t
>>> assert type(s) is bytes, type(s)
"""
t = type(BYTES)
s = BYTES
return t, s
def type_of_unicode():
"""
>>> t, s = type_of_unicode()
>>> assert t is _unicode, t
>>> assert type(s) is _unicode, type(s)
"""
t = type(UNICODE)
s = UNICODE
return t, s
@cython.test_assert_path_exists('//TupleNode')
def constant_tuple():
"""
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment