Commit 07e897e9 authored by Stefan Behnel's avatar Stefan Behnel

fix bug 412: str char comparison, refactoring to move comparison coercions closer in the code

parent 443e0f51
...@@ -13,7 +13,8 @@ import Nodes ...@@ -13,7 +13,8 @@ import Nodes
from Nodes import Node from Nodes import Node
import PyrexTypes import PyrexTypes
from PyrexTypes import py_object_type, c_long_type, typecast, error_type, unspecified_type from PyrexTypes import py_object_type, c_long_type, typecast, error_type, unspecified_type
from Builtin import list_type, tuple_type, set_type, dict_type, unicode_type, bytes_type, type_type from Builtin import list_type, tuple_type, set_type, dict_type, \
unicode_type, str_type, bytes_type, type_type
import Builtin import Builtin
import Symtab import Symtab
import Options import Options
...@@ -821,6 +822,9 @@ class BytesNode(ConstNode): ...@@ -821,6 +822,9 @@ class BytesNode(ConstNode):
if isinstance(sizeof_node, SizeofTypeNode): if isinstance(sizeof_node, SizeofTypeNode):
return sizeof_node.arg_type return sizeof_node.arg_type
def can_coerce_to_char_literal(self):
return len(self.value) == 1
def coerce_to(self, dst_type, env): def coerce_to(self, dst_type, env):
if dst_type == PyrexTypes.c_char_ptr_type: if dst_type == PyrexTypes.c_char_ptr_type:
self.type = PyrexTypes.c_char_ptr_type self.type = PyrexTypes.c_char_ptr_type
...@@ -830,7 +834,7 @@ class BytesNode(ConstNode): ...@@ -830,7 +834,7 @@ class BytesNode(ConstNode):
return CastNode(self, PyrexTypes.c_uchar_ptr_type) return CastNode(self, PyrexTypes.c_uchar_ptr_type)
if dst_type.is_int: if dst_type.is_int:
if len(self.value) > 1: if not self.can_coerce_to_char_literal():
error(self.pos, "Only single-character strings can be coerced into ints.") error(self.pos, "Only single-character strings can be coerced into ints.")
return self return self
return CharNode(self.pos, value=self.value) return CharNode(self.pos, value=self.value)
...@@ -905,11 +909,11 @@ class StringNode(PyConstNode): ...@@ -905,11 +909,11 @@ class StringNode(PyConstNode):
# value BytesLiteral or EncodedString # value BytesLiteral or EncodedString
# is_identifier boolean # is_identifier boolean
type = Builtin.str_type type = str_type
is_identifier = False is_identifier = False
def coerce_to(self, dst_type, env): def coerce_to(self, dst_type, env):
if dst_type is not py_object_type and dst_type is not Builtin.str_type: if dst_type is not py_object_type and dst_type is not str_type:
# if dst_type is Builtin.bytes_type: # if dst_type is Builtin.bytes_type:
# # special case: bytes = 'str literal' # # special case: bytes = 'str literal'
# return BytesNode(self.pos, value=self.value) # return BytesNode(self.pos, value=self.value)
...@@ -927,6 +931,9 @@ class StringNode(PyConstNode): ...@@ -927,6 +931,9 @@ class StringNode(PyConstNode):
return self return self
def can_coerce_to_char_literal(self):
return not self.is_identifier and len(self.value) == 1
def generate_evaluation_code(self, code): def generate_evaluation_code(self, code):
self.result_code = code.get_py_string_const( self.result_code = code.get_py_string_const(
self.value, identifier=self.is_identifier, is_str=True) self.value, identifier=self.is_identifier, is_str=True)
...@@ -5065,6 +5072,73 @@ class CmpNode(object): ...@@ -5065,6 +5072,73 @@ class CmpNode(object):
result = result and cascade.compile_time_value(operand2, denv) result = result and cascade.compile_time_value(operand2, denv)
return result return result
def try_coerce_to_int_cmp(self, env, op, operand1, operand2):
# type1 != type2 and at least one of the types is not a C int
type1 = operand1.type
type2 = operand2.type
type1_can_be_int = False
type2_can_be_int = False
if isinstance(operand1, (StringNode, BytesNode)) \
and operand1.can_coerce_to_char_literal():
type1_can_be_int = True
if isinstance(operand2, (StringNode, BytesNode)) \
and operand2.can_coerce_to_char_literal():
type2_can_be_int = True
if type1.is_int:
if type2_can_be_int:
operand2 = operand2.coerce_to(type1, env)
elif type2.is_int:
if type1_can_be_int:
operand1 = operand1.coerce_to(type2, env)
elif type1_can_be_int:
if type2_can_be_int:
operand1 = operand1.coerce_to(PyrexTypes.c_uchar_type, env)
operand2 = operand2.coerce_to(PyrexTypes.c_uchar_type, env)
return operand1, operand2
def coerce_operands(self, env, op, operand1, common_type=None):
operand2 = self.operand2
type1 = operand1.type
type2 = operand2.type
if type1 == str_type and (type2.is_string or type2 in (bytes_type, unicode_type)) or \
type2 == str_type and (type1.is_string or type1 in (bytes_type, unicode_type)):
error(self.pos, "Comparisons between bytes/unicode and str are not portable to Python 3")
elif operand1.type.is_complex or operand2.type.is_complex:
if op not in ('==', '!='):
error(self.pos, "complex types unordered")
if operand1.type.is_pyobject:
operand2 = operand2.coerce_to(operand2.type, env)
elif operand2.type.is_pyobject:
operand1 = operand1.coerce_to(operand2.type, env)
else:
common_type = PyrexTypes.widest_numeric_type(type1, type2)
operand1 = operand1.coerce_to(common_type, env)
operand2 = operand2.coerce_to(common_type, env)
elif common_type is None or not common_type.is_pyobject:
if not type1.is_int or not type2.is_int:
operand1, operand2 = self.try_coerce_to_int_cmp(env, op, operand1, operand2)
if operand1.type.is_pyobject or operand2.type.is_pyobject:
# we could do a lot better by splitting the comparison
# into a non-Python part and a Python part, but this is
# safer for now
if operand1.type == operand2.type:
common_type = operand1.type
else:
common_type = py_object_type
if self.cascade:
operand2 = self.cascade.coerce_operands(env, self.operator, operand2, common_type)
self.operand2 = operand2
return operand1
def is_python_comparison(self): def is_python_comparison(self):
return (self.has_python_operands() return (self.has_python_operands()
or (self.cascade and self.cascade.is_python_comparison()) or (self.cascade and self.cascade.is_python_comparison())
...@@ -5075,13 +5149,7 @@ class CmpNode(object): ...@@ -5075,13 +5149,7 @@ class CmpNode(object):
or (self.cascade and self.cascade.is_python_result())) or (self.cascade and self.cascade.is_python_result()))
def check_types(self, env, operand1, op, operand2): def check_types(self, env, operand1, op, operand2):
if operand1.type.is_complex or operand2.type.is_complex: if not self.types_okay(operand1, op, operand2):
if op not in ('==', '!='):
error(self.pos, "complex types unordered")
common_type = PyrexTypes.widest_numeric_type(operand1.type, operand2.type)
self.operand1 = operand1.coerce_to(common_type, env)
self.operand2 = operand2.coerce_to(common_type, env)
elif not self.types_okay(operand1, op, operand2):
error(self.pos, "Invalid types for '%s' (%s, %s)" % error(self.pos, "Invalid types for '%s' (%s, %s)" %
(self.operator, operand1.type, operand2.type)) (self.operator, operand1.type, operand2.type))
...@@ -5225,11 +5293,10 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -5225,11 +5293,10 @@ class PrimaryCmpNode(ExprNode, CmpNode):
self.operand2.analyse_types(env) self.operand2.analyse_types(env)
if self.cascade: if self.cascade:
self.cascade.analyse_types(env, self.operand2) self.cascade.analyse_types(env, self.operand2)
self.operand1 = self.coerce_operands(env, self.operator, self.operand1)
self.is_pycmp = self.is_python_comparison() self.is_pycmp = self.is_python_comparison()
if self.is_pycmp: if self.is_pycmp:
self.coerce_operands_to_pyobjects(env) self.coerce_operands_to_pyobjects(env)
if self.has_int_operands():
self.coerce_chars_to_ints(env)
if self.cascade: if self.cascade:
self.operand2 = self.operand2.coerce_to_simple(env) self.operand2 = self.operand2.coerce_to_simple(env)
self.cascade.coerce_cascaded_operands_to_temp(env) self.cascade.coerce_cascaded_operands_to_temp(env)
...@@ -5261,19 +5328,6 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -5261,19 +5328,6 @@ class PrimaryCmpNode(ExprNode, CmpNode):
if self.cascade: if self.cascade:
self.cascade.coerce_operands_to_pyobjects(env) self.cascade.coerce_operands_to_pyobjects(env)
def has_int_operands(self):
return (self.operand1.type.is_int or self.operand2.type.is_int) \
or (self.cascade and self.cascade.has_int_operands())
def coerce_chars_to_ints(self, env):
# coerce literal single-char strings to c chars
if self.operand1.type.is_string and isinstance(self.operand1, BytesNode):
self.operand1 = self.operand1.coerce_to(PyrexTypes.c_uchar_type, env)
if self.operand2.type.is_string and isinstance(self.operand2, BytesNode):
self.operand2 = self.operand2.coerce_to(PyrexTypes.c_uchar_type, env)
if self.cascade:
self.cascade.coerce_chars_to_ints(env)
def check_const(self): def check_const(self):
self.operand1.check_const() self.operand1.check_const()
self.operand2.check_const() self.operand2.check_const()
...@@ -5372,13 +5426,6 @@ class CascadedCmpNode(Node, CmpNode): ...@@ -5372,13 +5426,6 @@ class CascadedCmpNode(Node, CmpNode):
if self.cascade: if self.cascade:
self.cascade.coerce_operands_to_pyobjects(env) self.cascade.coerce_operands_to_pyobjects(env)
def has_int_operands(self):
return self.operand2.type.is_int
def coerce_chars_to_ints(self, env):
if self.operand2.type.is_string and isinstance(self.operand2, BytesNode):
self.operand2 = self.operand2.coerce_to(PyrexTypes.c_uchar_type, env)
def coerce_cascaded_operands_to_temp(self, env): def coerce_cascaded_operands_to_temp(self, env):
if self.cascade: if self.cascade:
#self.operand2 = self.operand2.coerce_to_temp(env) #CTT #self.operand2 = self.operand2.coerce_to_temp(env) #CTT
......
__doc__ = u"""
>>> test_eq()
True
True
True
True
>>> test_cascaded_eq()
True
True
True
True
True
True
True
True
>>> test_cascaded_ineq()
True
True
True
True
True
True
True
True
>>> test_long_ineq()
True
>>> test_long_ineq_py()
True
True
"""
cdef int i = 'x'
cdef char c = 'x'
cdef char* s = 'x'
def test_eq():
print i == 'x'
print i == c'x'
print c == 'x'
print c == c'x'
# print s == 'x' # error
# print s == c'x' # error
def test_cascaded_eq():
print 'x' == i == 'x'
print 'x' == i == c'x'
print c'x' == i == 'x'
print c'x' == i == c'x'
print 'x' == c == 'x'
print 'x' == c == c'x'
print c'x' == c == 'x'
print c'x' == c == c'x'
def test_cascaded_ineq():
print 'a' <= i <= 'z'
print 'a' <= i <= c'z'
print c'a' <= i <= 'z'
print c'a' <= i <= c'z'
print 'a' <= c <= 'z'
print 'a' <= c <= c'z'
print c'a' <= c <= 'z'
print c'a' <= c <= c'z'
def test_long_ineq():
print 'a' < 'b' < 'c' < 'd' < c < 'y' < 'z'
def test_long_ineq_py():
print 'abcdef' < 'b' < 'c' < 'd' < 'y' < 'z'
print 'a' < 'b' < 'cde' < 'd' < 'y' < 'z'
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment