Commit 97503e02 authored by Stefan Behnel's avatar Stefan Behnel

fold sliced literal sequences (e.g. from DEFs) into constants

parent e559cf61
...@@ -1068,6 +1068,9 @@ class BytesNode(ConstNode): ...@@ -1068,6 +1068,9 @@ class BytesNode(ConstNode):
# start off as Python 'bytes' to support len() in O(1) # start off as Python 'bytes' to support len() in O(1)
type = bytes_type type = bytes_type
def calculate_constant_result(self):
self.constant_result = self.value
def compile_time_value(self, denv): def compile_time_value(self, denv):
return self.value return self.value
...@@ -1149,6 +1152,9 @@ class UnicodeNode(PyConstNode): ...@@ -1149,6 +1152,9 @@ class UnicodeNode(PyConstNode):
bytes_value = None bytes_value = None
type = unicode_type type = unicode_type
def calculate_constant_result(self):
self.constant_result = self.value
def coerce_to(self, dst_type, env): def coerce_to(self, dst_type, env):
if dst_type is self.type: if dst_type is self.type:
pass pass
...@@ -1214,6 +1220,9 @@ class StringNode(PyConstNode): ...@@ -1214,6 +1220,9 @@ class StringNode(PyConstNode):
is_identifier = None is_identifier = None
unicode_value = None unicode_value = None
def calculate_constant_result(self):
self.constant_result = self.value
def coerce_to(self, dst_type, env): def coerce_to(self, dst_type, env):
if dst_type is not py_object_type and not str_type.subtype_of(dst_type): if dst_type is not py_object_type and not str_type.subtype_of(dst_type):
# if dst_type is Builtin.bytes_type: # if dst_type is Builtin.bytes_type:
......
from Cython.Compiler.ExprNodes import not_a_constant
import cython import cython
cython.declare(UtilityCode=object, EncodedString=object, BytesLiteral=object, cython.declare(UtilityCode=object, EncodedString=object, BytesLiteral=object,
Nodes=object, ExprNodes=object, PyrexTypes=object, Builtin=object, Nodes=object, ExprNodes=object, PyrexTypes=object, Builtin=object,
...@@ -3190,10 +3191,33 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): ...@@ -3190,10 +3191,33 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
def visit_SliceIndexNode(self, node): def visit_SliceIndexNode(self, node):
self._calculate_const(node) self._calculate_const(node)
# normalise start/stop values # normalise start/stop values
if node.start and node.start.constant_result is None: if node.start is None or node.start.constant_result is None:
node.start = None start = node.start = None
if node.stop and node.stop.constant_result is None: else:
node.stop = None start = node.start.constant_result
if node.stop is None or node.stop.constant_result is None:
stop = node.stop = None
else:
stop = node.stop.constant_result
# cut down sliced constant sequences
if node.constant_result is not not_a_constant:
base = node.base
if base.is_sequence_constructor:
base.args = base.args[start:stop]
return base
elif base.is_string_literal:
value = type(base.value)(node.constant_result)
value.encoding = base.value.encoding
base.value = value
if isinstance(base, ExprNodes.StringNode):
if base.unicode_value is not None:
base.unicode_value = EncodedString(
base.unicode_value[start:stop])
elif isinstance(base, ExprNodes.UnicodeNode):
if base.bytes_value is not None:
base.bytes_value = BytesLiteral(
base.bytes_value[start:stop])
return base
return node return node
def visit_ForInStatNode(self, node): def visit_ForInStatNode(self, node):
......
# coding=utf8
# mode: run # mode: run
# tag: constant_folding # tag: constant_folding
...@@ -86,3 +87,47 @@ def binop_bool(): ...@@ -86,3 +87,47 @@ def binop_bool():
ormix3 = False | 0 | False | True ormix3 = False | 0 | False | True
xor3 = False ^ True ^ False ^ True xor3 = False ^ True ^ False ^ True
return plus1, pmix1, minus1, and1, or1, ormix1, xor1, plus3, pmix3, minus3, and3, or3, ormix3, xor3 return plus1, pmix1, minus1, and1, or1, ormix1, xor1, plus3, pmix3, minus3, and3, or3, ormix3, xor3
@cython.test_fail_if_path_exists(
"//SliceIndexNode",
)
def slicing2():
"""
>>> slicing2()
([1, 2, 3, 4], [3, 4], [1, 2, 3, 4], [3, 4], (1, 2, 3, 4), (3, 4), (1, 2, 3, 4), (3, 4))
"""
lst0 = [1, 2, 3, 4][:]
lst1 = [1, 2, 3, 4][2:]
lst2 = [1, 2, 3, 4][:4]
lst3 = [1, 2, 3, 4][2:4]
tpl0 = (1, 2, 3, 4)[:]
tpl1 = (1, 2, 3, 4)[2:]
tpl2 = (1, 2, 3, 4)[:4]
tpl3 = (1, 2, 3, 4)[2:4]
return lst0, lst1, lst2, lst3, tpl0, tpl1, tpl2, tpl3
@cython.test_fail_if_path_exists(
"//SliceIndexNode",
)
def str_slicing2():
"""
>>> a,b,c,d = str_slicing2()
>>> a == 'abc\\xE9def'[:]
True
>>> b == 'abc\\xE9def'[2:]
True
>>> c == 'abc\\xE9def'[:4]
True
>>> d == 'abc\\xE9def'[2:4]
True
"""
str0 = 'abc\xE9def'[:]
str1 = 'abc\xE9def'[2:]
str2 = 'abc\xE9def'[:4]
str3 = 'abc\xE9def'[2:4]
return str0, str1, str2, str3
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment