Commit 04f8cc68 authored by Stefan Behnel's avatar Stefan Behnel

split BytesNode, UnicodeNode and StringNode

parent 0ebf976f
......@@ -132,6 +132,7 @@ class CodeWriter(TreeVisitor):
def visit_IntNode(self, node):
self.put(node.value)
# FIXME: represent string nodes correctly
def visit_StringNode(self, node):
value = node.value
if value.encoding is not None:
......
......@@ -767,7 +767,7 @@ class FloatNode(ConstNode):
return strval
class StringNode(ConstNode):
class BytesNode(ConstNode):
type = PyrexTypes.c_char_ptr_type
def compile_time_value(self, denv):
......@@ -794,13 +794,11 @@ class StringNode(ConstNode):
return CastNode(self, PyrexTypes.c_uchar_ptr_type)
if dst_type.is_int:
if not self.type.is_pyobject and len(self.value) == 1:
return CharNode(self.pos, value=self.value)
else:
error(self.pos, "Only single-character byte strings can be coerced into ints.")
if len(self.value) > 1:
error(self.pos, "Only single-character strings can be coerced into ints.")
return self
# Arrange for a Python version of the string to be pre-allocated
# when coercing to a Python type.
return CharNode(self.pos, value=self.value)
if dst_type.is_pyobject and not self.type.is_pyobject:
node = self.as_py_string_node(env)
else:
......@@ -811,13 +809,9 @@ class StringNode(ConstNode):
return ConstNode.coerce_to(node, dst_type, env)
def as_py_string_node(self, env):
# Return a new StringNode with the same value as this node
# Return a new BytesNode with the same value as this node
# but whose type is a Python type instead of a C type.
if self.value.encoding is None:
py_type = Builtin.unicode_type
else:
py_type = Builtin.bytes_type
return StringNode(self.pos, value = self.value, type = py_type)
return BytesNode(self.pos, value = self.value, type = Builtin.bytes_type)
def generate_evaluation_code(self, code):
if self.type.is_pyobject:
......@@ -831,8 +825,11 @@ class StringNode(ConstNode):
def calculate_result_code(self):
return self.result_code
class UnicodeNode(PyConstNode):
# entry Symtab.Entry
# A Python unicode object
#
# value EncodedString
type = unicode_type
......@@ -844,10 +841,7 @@ class UnicodeNode(PyConstNode):
return self
def generate_evaluation_code(self, code):
if self.type.is_pyobject:
self.result_code = code.get_py_string_const(self.value)
else:
self.result_code = code.get_string_const(self.value)
self.result_code = code.get_py_string_const(self.value)
def calculate_result_code(self):
return self.result_code
......@@ -856,16 +850,30 @@ class UnicodeNode(PyConstNode):
return self.value
class IdentifierStringNode(ConstNode):
# A Python string that behaves like an identifier, e.g. for
# keyword arguments in a call, or for imported names
class StringNode(PyConstNode):
# A Python str object, i.e. a byte string in Python 2.x and a
# unicode string in Python 3.x
#
# Can be coerced to a BytesNode (and thus to C types), but not to
# a UnicodeNode.
#
# value BytesLiteral
type = PyrexTypes.py_object_type
def generate_evaluation_code(self, code):
if self.type.is_pyobject:
self.result_code = code.get_py_string_const(self.value, True)
def coerce_to(self, dst_type, env):
if dst_type is Builtin.unicode_type:
error(self.pos, "str objects do not support coercion to unicode, use a unicode string literal instead (u'')")
return self
if dst_type is Builtin.bytes_type:
return BytesNode(self.pos, value=self.value)
elif dst_type.is_pyobject:
return self
else:
self.result_code = code.get_string_const(self.value)
return BytesNode(self.pos, value=self.value).coerce_to(dst_type, env)
def generate_evaluation_code(self, code):
self.result_code = code.get_py_string_const(self.value, True)
def get_constant_c_result_code(self):
return None
......@@ -1370,8 +1378,8 @@ class ImportNode(ExprNode):
# Implements result =
# __import__(module_name, globals(), None, name_list)
#
# module_name IdentifierStringNode dotted name of module
# name_list ListNode or None list of names to be imported
# module_name StringNode dotted name of module
# name_list ListNode or None list of names to be imported
type = py_object_type
......@@ -1650,7 +1658,7 @@ class IndexNode(ExprNode):
return self.base.type_dependencies(env)
def infer_type(self, env):
if isinstance(self.base, StringNode):
if isinstance(self.base, (StringNode, UnicodeNode)): # FIXME: BytesNode?
return py_object_type
base_type = self.base.infer_type(env)
if base_type.is_ptr or base_type.is_array:
......@@ -1677,7 +1685,7 @@ class IndexNode(ExprNode):
self.base.analyse_types(env)
# Handle the case where base is a literal char* (and we expect a string, not an int)
if isinstance(self.base, StringNode):
if isinstance(self.base, BytesNode):
self.base = self.base.coerce_to_pyobject(env)
skip_child_analysis = False
......@@ -2223,7 +2231,7 @@ class CallNode(ExprNode):
args, kwds = self.explicit_args_kwds()
items = []
for arg, member in zip(args, type.scope.var_entries):
items.append(DictItemNode(pos=arg.pos, key=IdentifierStringNode(pos=arg.pos, value=member.name), value=arg))
items.append(DictItemNode(pos=arg.pos, key=StringNode(pos=arg.pos, value=member.name), value=arg))
if kwds:
items += kwds.key_value_pairs
self.key_value_pairs = items
......@@ -3663,9 +3671,9 @@ class DictNode(ExprNode):
for item in self.key_value_pairs:
if isinstance(item.key, CoerceToPyTypeNode):
item.key = item.key.arg
if not isinstance(item.key, (StringNode, IdentifierStringNode)):
if not isinstance(item.key, (UnicodeNode, StringNode, BytesNode)):
error(item.key.pos, "Invalid struct field identifier")
item.key = IdentifierStringNode(item.key.pos, value="<error>")
item.key = StringNode(item.key.pos, value="<error>")
else:
key = str(item.key.value) # converts string literals to unicode in Py3
member = dst_type.scope.lookup_here(key)
......@@ -4262,8 +4270,8 @@ class TypeofNode(ExprNode):
def analyse_types(self, env):
self.operand.analyse_types(env)
from StringEncoding import EncodedString
self.literal = StringNode(self.pos, value=EncodedString(str(self.operand.type)))
self.literal = StringNode(
self.pos, value=StringEncoding.EncodedString(str(self.operand.type)))
self.literal.analyse_types(env)
self.literal = self.literal.coerce_to_pyobject(env)
......@@ -5190,9 +5198,9 @@ class PrimaryCmpNode(ExprNode, CmpNode):
def coerce_chars_to_ints(self, env):
# coerce literal single-char strings to c chars
if self.operand1.type.is_string and isinstance(self.operand1, StringNode):
if self.operand1.type.is_string and isinstance(self.operand1, BytesNode):
self.operand1 = self.operand1.coerce_to(PyrexTypes.c_uchar_type, env)
if self.operand2.type.is_string and isinstance(self.operand2, StringNode):
if self.operand2.type.is_string and isinstance(self.operand2, BytesNode):
self.operand2 = self.operand2.coerce_to(PyrexTypes.c_uchar_type, env)
if self.cascade:
self.cascade.coerce_chars_to_ints(env)
......@@ -5299,7 +5307,7 @@ class CascadedCmpNode(Node, CmpNode):
return self.operand2.type.is_int
def coerce_chars_to_ints(self, env):
if self.operand2.type.is_string and isinstance(self.operand2, StringNode):
if self.operand2.type.is_string and isinstance(self.operand2, BytesNode):
self.operand2 = self.operand2.coerce_to(PyrexTypes.c_uchar_type, env)
def coerce_cascaded_operands_to_temp(self, env):
......
......@@ -2528,6 +2528,7 @@ class PyClassDefNode(ClassDefNode):
self.dict = ExprNodes.DictNode(pos, key_value_pairs = [])
if self.doc and Options.docstrings:
doc = embed_position(self.pos, self.doc)
# FIXME: correct string node?
doc_node = ExprNodes.StringNode(pos, value = doc)
else:
doc_node = None
......
......@@ -224,7 +224,7 @@ class IterationTransform(Visitor.VisitorTransform):
bound2 = args[1].coerce_to_integer(self.current_scope)
step = step.coerce_to_integer(self.current_scope)
if not isinstance(bound2, ExprNodes.ConstNode):
if not bound2.is_literal:
# stop bound must be immutable => keep it in a temp var
bound2_is_temp = True
bound2 = UtilNodes.LetRefNode(bound2)
......@@ -416,12 +416,12 @@ class SwitchTransform(Visitor.VisitorTransform):
and cond.operator == '=='
and not cond.is_python_comparison()):
if is_common_value(cond.operand1, cond.operand1):
if isinstance(cond.operand2, ExprNodes.ConstNode):
if cond.operand2.is_literal:
return cond.operand1, [cond.operand2]
elif hasattr(cond.operand2, 'entry') and cond.operand2.entry and cond.operand2.entry.is_const:
return cond.operand1, [cond.operand2]
if is_common_value(cond.operand2, cond.operand2):
if isinstance(cond.operand1, ExprNodes.ConstNode):
if cond.operand1.is_literal:
return cond.operand2, [cond.operand1]
elif hasattr(cond.operand1, 'entry') and cond.operand1.entry and cond.operand1.entry.is_const:
return cond.operand2, [cond.operand1]
......@@ -853,10 +853,11 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
encoding_node = args[1]
if isinstance(encoding_node, ExprNodes.CoerceToPyTypeNode):
encoding_node = encoding_node.arg
if not isinstance(encoding_node, (ExprNodes.UnicodeNode, ExprNodes.StringNode)):
if not isinstance(encoding_node, (ExprNodes.UnicodeNode, ExprNodes.StringNode,
ExprNodes.BytesNode)):
return node
encoding = encoding_node.value
encoding_node = ExprNodes.StringNode(encoding_node.pos, value=encoding,
encoding_node = ExprNodes.BytesNode(encoding_node.pos, value=encoding,
type=PyrexTypes.c_char_ptr_type)
if len(args) == 3:
......@@ -864,13 +865,14 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
if isinstance(error_handling_node, ExprNodes.CoerceToPyTypeNode):
error_handling_node = error_handling_node.arg
if not isinstance(error_handling_node,
(ExprNodes.UnicodeNode, ExprNodes.StringNode)):
(ExprNodes.UnicodeNode, ExprNodes.StringNode,
ExprNodes.BytesNode)):
return node
error_handling = error_handling_node.value
if error_handling == 'strict':
error_handling_node = null_node
else:
error_handling_node = ExprNodes.StringNode(
error_handling_node = ExprNodes.BytesNode(
error_handling_node.pos, value=error_handling,
type=PyrexTypes.c_char_ptr_type)
else:
......@@ -887,7 +889,7 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
else:
value = BytesLiteral(value)
value.encoding = encoding
return ExprNodes.StringNode(
return ExprNodes.BytesNode(
string_node.pos, value=value, type=Builtin.bytes_type)
if error_handling == 'strict':
......@@ -1030,8 +1032,7 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
# the compiler, but we do not aggregate them into a
# constant node to prevent any loss of precision.
return node
if not isinstance(node.operand1, ExprNodes.ConstNode) or \
not isinstance(node.operand2, ExprNodes.ConstNode):
if not node.operand1.is_literal or not node.operand2.is_literal:
# We calculate other constants to make them available to
# the compiler, but we only aggregate constant nodes
# recursively, so non-const nodes are straight out.
......
......@@ -444,22 +444,22 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
args, kwds = node.explicit_args_kwds()
if optiontype is bool:
if kwds is not None or len(args) != 1 or not isinstance(args[0], BoolNode):
raise PostParseError(dec.function.pos,
raise PostParseError(node.function.pos,
'The %s option takes one compile-time boolean argument' % optname)
return (optname, args[0].value)
elif optiontype is str:
if kwds is not None or len(args) != 1 or not isinstance(args[0], StringNode):
raise PostParseError(dec.function.pos,
if kwds is not None or len(args) != 1 or not isinstance(args[0], (StringNode, UnicodeNode)):
raise PostParseError(node.function.pos,
'The %s option takes one compile-time string argument' % optname)
return (optname, str(args[0].value))
elif optiontype is dict:
if len(args) != 0:
raise PostParseError(dec.function.pos,
raise PostParseError(node.function.pos,
'The %s option takes no prepositional arguments' % optname)
return optname, dict([(key.value, value) for key, value in kwds.key_value_pairs])
elif optiontype is list:
if kwds and len(kwds) != 0:
raise PostParseError(dec.function.pos,
raise PostParseError(node.function.pos,
'The %s option takes no keyword arguments' % optname)
return optname, [ str(arg.value) for arg in args ]
else:
......@@ -984,7 +984,7 @@ class TransformBuiltinMethods(EnvTransform):
pos = node.pos
lenv = self.env_stack[-1]
items = [ExprNodes.DictItemNode(pos,
key=ExprNodes.IdentifierStringNode(pos, value=var),
key=ExprNodes.StringNode(pos, value=var),
value=ExprNodes.NameNode(pos, name=var)) for var in lenv.entries]
return ExprNodes.DictNode(pos, key_value_pairs=items)
......
......@@ -14,7 +14,7 @@ from Cython.Compiler.Scanning import PyrexScanner, FileSourceDescriptor
import Nodes
import ExprNodes
import StringEncoding
from StringEncoding import EncodedString, BytesLiteral, _str, _bytes
from StringEncoding import EncodedString, BytesLiteral, _unicode, _bytes
from ModuleNode import ModuleNode
from Errors import error, warning, InternalError
from Cython import Utils
......@@ -348,8 +348,7 @@ def p_call(s, function):
s.error("Expected an identifier before '='",
pos = arg.pos)
encoded_name = EncodedString(arg.name)
keyword = ExprNodes.IdentifierStringNode(arg.pos,
value = encoded_name)
keyword = ExprNodes.StringNode(arg.pos, value = encoded_name)
arg = p_simple_expr(s)
keyword_args.append((keyword, arg))
else:
......@@ -540,6 +539,8 @@ def p_atom(s):
return ExprNodes.CharNode(pos, value = value)
elif kind == 'u':
return ExprNodes.UnicodeNode(pos, value = value)
elif kind == 'b':
return ExprNodes.BytesNode(pos, value = value)
else:
return ExprNodes.StringNode(pos, value = value)
elif sy == 'IDENT':
......@@ -571,8 +572,10 @@ def p_name(s, name):
return ExprNodes.IntNode(pos, value = rep, longness = "L")
elif isinstance(value, float):
return ExprNodes.FloatNode(pos, value = rep)
elif isinstance(value, (_str, _bytes)):
return ExprNodes.StringNode(pos, value = value)
elif isinstance(value, _unicode):
return ExprNodes.UnicodeNode(pos, value = value)
elif isinstance(value, _bytes):
return ExprNodes.BytesNode(pos, value = value)
else:
error(pos, "Invalid type for compile-time constant: %s"
% value.__class__.__name__)
......@@ -580,24 +583,20 @@ def p_name(s, name):
def p_cat_string_literal(s):
# A sequence of one or more adjacent string literals.
# Returns (kind, value) where kind in ('b', 'c', 'u')
# Returns (kind, value) where kind in ('b', 'c', 'u', '')
kind, value = p_string_literal(s)
if s.sy != 'BEGIN_STRING':
return kind, value
if kind != 'c':
strings = [value]
while s.sy == 'BEGIN_STRING':
pos = s.position()
next_kind, next_value = p_string_literal(s)
if next_kind == 'c':
error(s.position(),
"Cannot concatenate char literal with another string or char literal")
error(pos, "Cannot concatenate char literal with another string or char literal")
elif next_kind != kind:
# we have to switch to unicode now
if kind == 'b':
# concatenating a unicode string to byte strings
strings = [u''.join([s.decode(s.encoding) for s in strings])]
elif kind == 'u':
# concatenating a byte string to unicode strings
strings.append(next_value.decode(next_value.encoding))
kind = 'u'
error(pos, "Cannot mix string literals of different types, expected %s'', got %s''" %
(kind, next_kind))
else:
strings.append(next_value)
if kind == 'u':
......@@ -630,8 +629,6 @@ def p_string_literal(s):
if Future.unicode_literals in s.context.future_directives:
if kind == '':
kind = 'u'
elif kind == '':
kind = 'b'
if kind == 'u':
chars = StringEncoding.UnicodeLiteralBuilder()
else:
......@@ -896,7 +893,7 @@ def p_expression_or_assignment(s):
rhs = p_expr(s)
return Nodes.InPlaceAssignmentNode(lhs.pos, operator = operator, lhs = lhs, rhs = rhs)
expr = expr_list[0]
if isinstance(expr, ExprNodes.StringNode):
if isinstance(expr, (ExprNodes.UnicodeNode, ExprNodes.StringNode, ExprNodes.BytesNode)):
return Nodes.PassStatNode(expr.pos)
else:
return Nodes.ExprStatNode(expr.pos, expr = expr)
......@@ -1131,15 +1128,14 @@ def p_import_statement(s):
else:
if as_name and "." in dotted_name:
name_list = ExprNodes.ListNode(pos, args = [
ExprNodes.IdentifierStringNode(
pos, value = EncodedString("*"))])
ExprNodes.StringNode(pos, value = EncodedString("*"))])
else:
name_list = None
stat = Nodes.SingleAssignmentNode(pos,
lhs = ExprNodes.NameNode(pos,
name = as_name or target_name),
rhs = ExprNodes.ImportNode(pos,
module_name = ExprNodes.IdentifierStringNode(
module_name = ExprNodes.StringNode(
pos, value = dotted_name),
name_list = name_list))
stats.append(stat)
......@@ -1197,7 +1193,7 @@ def p_from_import_statement(s, first_statement = 0):
for (name_pos, name, as_name, kind) in imported_names:
encoded_name = EncodedString(name)
imported_name_strings.append(
ExprNodes.IdentifierStringNode(name_pos, value = encoded_name))
ExprNodes.StringNode(name_pos, value = encoded_name))
items.append(
(name,
ExprNodes.NameNode(name_pos,
......@@ -1207,7 +1203,7 @@ def p_from_import_statement(s, first_statement = 0):
dotted_name = EncodedString(dotted_name)
return Nodes.FromImportStatNode(pos,
module = ExprNodes.ImportNode(dotted_name_pos,
module_name = ExprNodes.IdentifierStringNode(pos, value = dotted_name),
module_name = ExprNodes.StringNode(pos, value = dotted_name),
name_list = import_list),
items = items)
......@@ -1717,8 +1713,8 @@ def p_positional_and_keyword_args(s, end_sy_set, type_positions=(), type_keyword
parsed_type = True
else:
arg = p_simple_expr(s)
keyword_node = ExprNodes.IdentifierStringNode(arg.pos,
value = EncodedString(ident))
keyword_node = ExprNodes.StringNode(
arg.pos, value = EncodedString(ident))
keyword_args.append((keyword_node, arg))
was_keyword = True
else:
......
......@@ -6,14 +6,14 @@ import re
import sys
if sys.version_info[0] >= 3:
_str, _bytes = str, bytes
_unicode, _str, _bytes = str, str, bytes
IS_PYTHON3 = True
else:
_str, _bytes = unicode, str
_unicode, _str, _bytes = unicode, str, str
IS_PYTHON3 = False
empty_bytes = _bytes()
empty_str = _str()
empty_unicode = _unicode()
join_bytes = empty_bytes.join
......@@ -27,7 +27,7 @@ class UnicodeLiteralBuilder(object):
if isinstance(characters, _bytes):
# this came from a Py2 string literal in the parser code
characters = characters.decode("ASCII")
assert isinstance(characters, _str), str(type(characters))
assert isinstance(characters, _unicode), str(type(characters))
self.chars.append(characters)
def append_charval(self, char_number):
......@@ -45,7 +45,7 @@ class BytesLiteralBuilder(object):
self.target_encoding = target_encoding
def append(self, characters):
if isinstance(characters, _str):
if isinstance(characters, _unicode):
characters = characters.encode(self.target_encoding)
assert isinstance(characters, _bytes), str(type(characters))
self.chars.append(characters)
......@@ -63,7 +63,7 @@ class BytesLiteralBuilder(object):
# this *must* return a byte string!
return self.getstring()
class EncodedString(_str):
class EncodedString(_unicode):
# unicode string subclass to keep track of the original encoding.
# 'encoding' is None for unicode strings and the source encoding
# otherwise
......@@ -82,7 +82,7 @@ class EncodedString(_str):
is_unicode = property(is_unicode)
class BytesLiteral(_bytes):
# str subclass that is compatible with EncodedString
# bytes subclass that is compatible with EncodedString
encoding = None
def byteencode(self):
......
import Cython.Compiler.Errors as Errors
from Cython.CodeWriter import CodeWriter
import unittest
from Cython.Compiler.ModuleNode import ModuleNode
import Cython.Compiler.Main as Main
from Cython.Compiler.TreeFragment import TreeFragment, strip_common_indent
from Cython.Compiler.Visitor import TreeVisitor, VisitorTransform
from Cython.Compiler import TreePath
import unittest
import sys
class NodeTypeWriter(TreeVisitor):
def __init__(self):
super(NodeTypeWriter, self).__init__()
......@@ -107,7 +109,7 @@ class CythonTest(unittest.TestCase):
try:
return func()
except:
self.fail()
self.fail(str(sys.exc_info()[1]))
class TransformTest(CythonTest):
"""
......
......@@ -3,11 +3,13 @@ cdef int c2 = "te" # fails
cdef int cx = "test" # fails
cdef int x1 = "\xFF" # works
cdef int x2 = u"\xFF" # fails
cdef int x2 = "\u0FFF" # fails
cdef int x3 = u"\xFF" # fails
_ERRORS = u"""
2:14: Only single-character byte strings can be coerced into ints.
3:14: Only single-character byte strings can be coerced into ints.
6:14: Unicode objects do not support coercion to C types.
2:14: Only single-character strings can be coerced into ints.
3:14: Only single-character strings can be coerced into ints.
6:15: Only single-character strings can be coerced into ints.
7:14: Unicode objects do not support coercion to C types.
"""
__doc__ = u"""
>>> y
(b'1', b'2', b'3')
('1', '2', '3')
>>> x
b'1foo2foo3'
'1foo2foo3'
"""
import sys
if sys.version_info[0] < 3:
__doc__ = __doc__.replace(u"b'", u"'")
y = ('1','2','3')
......
......@@ -23,7 +23,7 @@ def test_ints(int x):
return L[3], Li[3], Lii[1][0]
def test_chars(foo):
cdef char** ss = ["a", "bc", foo]
cdef char** ss = [b"a", b"bc", foo]
return ss[0], ss[1], ss[2]
cdef struct MyStruct:
......
......@@ -50,7 +50,8 @@ with ' and " quotes"""
q = "NameLikeString2"
r = "99_percent_un_namelike"
s = "Not an \escape"
t = b'this' b'parrot' b'is' b'resting'
u = u'this' u'parrot' u'is' u'resting'
def test_float(x):
......
__doc__ = u"""
>>> c = C()
>>> c.x
b'foo'
'foo'
"""
import sys
if sys.version_info[0] < 3:
__doc__ = __doc__.replace(u" b'", u" '")
class C:
x = "foo"
__doc__ = ur"""
>>> s1
b'abc\x11'
>>> s1 == b'abc\x11'
'abc\x11'
>>> s1 == 'abc\x11'
True
>>> len(s1)
4
>>> s2
b'abc\\x11'
>>> s2 == br'abc\x11'
'abc\\x11'
>>> s2 == r'abc\x11'
True
>>> len(s2)
7
>>> s3
b'abc\\x11'
>>> s3 == bR'abc\x11'
'abc\\x11'
>>> s3 == R'abc\x11'
True
>>> len(s3)
7
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment