Commit 13327ba0 authored by Stefan Behnel's avatar Stefan Behnel

fix type inference for sliced builtins

parent 93ade469
...@@ -1942,29 +1942,39 @@ class IndexNode(ExprNode): ...@@ -1942,29 +1942,39 @@ class IndexNode(ExprNode):
return self.base.type_dependencies(env) return self.base.type_dependencies(env)
def infer_type(self, env): def infer_type(self, env):
is_slice = isinstance(self.index, SliceNode) base_type = self.base.infer_type(env)
if isinstance(self.base, BytesNode): if isinstance(self.index, SliceNode):
if is_slice: # slicing!
if base_type.is_string:
return bytes_type return bytes_type
elif base_type in (unicode_type, bytes_type, str_type, list_type, tuple_type):
# slicing these returns the same type
return base_type
else: else:
return py_object_type # Py2/3 return different types # TODO: Handle buffers (hopefully without too much redundancy).
base_type = self.base.infer_type(env) return py_object_type
if base_type.is_ptr or base_type.is_array:
return base_type.base_type if isinstance(self.base, BytesNode):
elif base_type is unicode_type and self.index.infer_type(env).is_int: # Py2/3 return different types on indexing bytes objects
# Py_UNICODE will automatically coerce to a unicode string # and we can't be sure if we are slicing, so we can't do
# if required, so this is safe. We only infer Py_UNICODE # any better than this:
# when the index is a C integer type. Otherwise, we may return py_object_type
# need to use normal Python item access, in which case
# it's faster to return the one-char unicode string than if self.index.infer_type(env).is_int or isinstance(self.index, (IntNode, LongNode)):
# to receive it, throw it away, and potentially rebuild it # indexing!
# on a subsequent PyObject coercion. if base_type is unicode_type:
return PyrexTypes.c_py_unicode_type # Py_UNICODE will automatically coerce to a unicode string
elif base_type in (str_type, unicode_type): # if required, so this is safe. We only infer Py_UNICODE
# these types will always return their own type on Python indexing/slicing # when the index is a C integer type. Otherwise, we may
return base_type # need to use normal Python item access, in which case
elif is_slice and base_type in (bytes_type, list_type, tuple_type): # it's faster to return the one-char unicode string than
# slicing these returns the same type # to receive it, throw it away, and potentially rebuild it
# on a subsequent PyObject coercion.
return PyrexTypes.c_py_unicode_type
elif base_type.is_ptr or base_type.is_array:
return base_type.base_type
if base_type is unicode_type:
# this type always returns its own type on Python indexing/slicing
return base_type return base_type
else: else:
# TODO: Handle buffers (hopefully without too much redundancy). # TODO: Handle buffers (hopefully without too much redundancy).
...@@ -1993,11 +2003,12 @@ class IndexNode(ExprNode): ...@@ -1993,11 +2003,12 @@ class IndexNode(ExprNode):
self.type = PyrexTypes.error_type self.type = PyrexTypes.error_type
return return
if isinstance(self.index, IntNode) and Utils.long_literal(self.index.value): is_slice = isinstance(self.index, SliceNode)
if not is_slice and isinstance(self.index, IntNode) and Utils.long_literal(self.index.value):
self.index = self.index.coerce_to_pyobject(env) self.index = self.index.coerce_to_pyobject(env)
# Handle the case where base is a literal char* (and we expect a string, not an int) # Handle the case where base is a literal char* (and we expect a string, not an int)
if isinstance(self.base, BytesNode): if isinstance(self.base, BytesNode) or is_slice:
self.base = self.base.coerce_to_pyobject(env) self.base = self.base.coerce_to_pyobject(env)
skip_child_analysis = False skip_child_analysis = False
...@@ -2069,6 +2080,8 @@ class IndexNode(ExprNode): ...@@ -2069,6 +2080,8 @@ class IndexNode(ExprNode):
# Py_UNICODE will automatically coerce to a unicode string # Py_UNICODE will automatically coerce to a unicode string
# if required, so this is fast and safe # if required, so this is fast and safe
self.type = PyrexTypes.c_py_unicode_type self.type = PyrexTypes.c_py_unicode_type
elif is_slice and base_type in (bytes_type, str_type, unicode_type, list_type, tuple_type):
self.type = base_type
else: else:
self.type = py_object_type self.type = py_object_type
else: else:
......
...@@ -60,10 +60,20 @@ def slicing(): ...@@ -60,10 +60,20 @@ def slicing():
assert typeof(b) == "char *", typeof(b) assert typeof(b) == "char *", typeof(b)
b1 = b[1:2] b1 = b[1:2]
assert typeof(b1) == "bytes object", typeof(b1) assert typeof(b1) == "bytes object", typeof(b1)
b2 = b[1:2:2]
assert typeof(b2) == "bytes object", typeof(b2)
u = u"xyz" u = u"xyz"
assert typeof(u) == "unicode object", typeof(u) assert typeof(u) == "unicode object", typeof(u)
u1 = u[1:2] u1 = u[1:2]
assert typeof(u1) == "unicode object", typeof(u1) assert typeof(u1) == "unicode object", typeof(u1)
u2 = u[1:2:2]
assert typeof(u2) == "unicode object", typeof(u2)
s = "xyz"
assert typeof(s) == "str object", typeof(s)
s1 = s[1:2]
assert typeof(s1) == "str object", typeof(s1)
s2 = s[1:2:2]
assert typeof(s2) == "str object", typeof(s2)
L = [1,2,3] L = [1,2,3]
assert typeof(L) == "list object", typeof(L) assert typeof(L) == "list object", typeof(L)
L1 = L[1:2] L1 = L[1:2]
...@@ -84,11 +94,15 @@ def indexing(): ...@@ -84,11 +94,15 @@ def indexing():
b = b"abc" b = b"abc"
assert typeof(b) == "char *", typeof(b) assert typeof(b) == "char *", typeof(b)
b1 = b[1] b1 = b[1]
assert typeof(b1) == "char", typeof(b1) # FIXME: bytes object ?? assert typeof(b1) == "char", typeof(b1) # FIXME: Python object ??
u = u"xyz" u = u"xyz"
assert typeof(u) == "unicode object", typeof(u) assert typeof(u) == "unicode object", typeof(u)
u1 = u[1] u1 = u[1]
assert typeof(u1) == "Py_UNICODE", typeof(u1) assert typeof(u1) == "Py_UNICODE", typeof(u1)
s = "xyz"
assert typeof(s) == "str object", typeof(s)
s1 = s[1]
assert typeof(s1) == "Python object", typeof(s1)
L = [1,2,3] L = [1,2,3]
assert typeof(L) == "list object", typeof(L) assert typeof(L) == "list object", typeof(L)
L1 = L[1] L1 = L[1]
...@@ -267,7 +281,7 @@ def loop_over_bytes(): ...@@ -267,7 +281,7 @@ def loop_over_bytes():
def loop_over_str(): def loop_over_str():
""" """
>>> print( loop_over_str() ) >>> print( loop_over_str() )
str object Python object
""" """
cdef str string = 'abcdefg' cdef str string = 'abcdefg'
for c in string: for c in string:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment