Commit a4d4b46c authored by Stefan Behnel's avatar Stefan Behnel

implement 'char_val in bytes_string' and 'pyunicode_val in unicode_string'

optimise literal string case using a switch statement
enable switch transform for regular PrimaryCmpNode
parent 36b2a2af
...@@ -5535,9 +5535,10 @@ class CmpNode(object): ...@@ -5535,9 +5535,10 @@ class CmpNode(object):
(op, operand1.type, operand2.type)) (op, operand1.type, operand2.type))
def is_python_comparison(self): def is_python_comparison(self):
return (self.has_python_operands() return not self.is_c_string_contains() and (
or (self.cascade and self.cascade.is_python_comparison()) self.has_python_operands()
or self.operator in ('in', 'not_in')) or (self.cascade and self.cascade.is_python_comparison())
or self.operator in ('in', 'not_in'))
def coerce_operands_to(self, dst_type, env): def coerce_operands_to(self, dst_type, env):
operand2 = self.operand2 operand2 = self.operand2
...@@ -5548,9 +5549,19 @@ class CmpNode(object): ...@@ -5548,9 +5549,19 @@ class CmpNode(object):
def is_python_result(self): def is_python_result(self):
return ((self.has_python_operands() and return ((self.has_python_operands() and
self.operator not in ('is', 'is_not', 'in', 'not_in')) self.operator not in ('is', 'is_not', 'in', 'not_in') and
not self.is_c_string_contains())
or (self.cascade and self.cascade.is_python_result())) or (self.cascade and self.cascade.is_python_result()))
def is_c_string_contains(self):
return self.operator in ('in', 'not_in') and \
((self.operand1.type in (PyrexTypes.c_char_type, PyrexTypes.c_uchar_type)
and self.operand2.type in (PyrexTypes.c_char_ptr_type,
PyrexTypes.c_uchar_ptr_type,
bytes_type)) or
(self.operand1.type is PyrexTypes.c_py_unicode_type
and self.operand2.type is unicode_type))
def generate_operation_code(self, code, result_code, def generate_operation_code(self, code, result_code,
operand1, op , operand2): operand1, op , operand2):
if self.type.is_pyobject: if self.type.is_pyobject:
...@@ -5652,6 +5663,38 @@ static CYTHON_INLINE PyObject* __Pyx_PyBoolOrNull_FromLong(long b) { ...@@ -5652,6 +5663,38 @@ static CYTHON_INLINE PyObject* __Pyx_PyBoolOrNull_FromLong(long b) {
} }
""") """)
char_in_bytes_utility_code = UtilityCode(
proto="""
static CYTHON_INLINE int __Pyx_BytesContains(PyObject* bytes, char character); /*proto*/
""",
impl="""
static CYTHON_INLINE int __Pyx_BytesContains(PyObject* bytes, char character) {
const Py_ssize_t length = PyBytes_GET_SIZE(bytes);
char* char_start = PyBytes_AS_STRING(bytes);
char* pos;
for (pos=char_start; pos < char_start+length; pos++) {
if (character == pos[0]) return 1;
}
return 0;
}
""")
pyunicode_in_unicode_utility_code = UtilityCode(
proto="""
static CYTHON_INLINE int __Pyx_UnicodeContains(PyObject* unicode, Py_UNICODE character); /*proto*/
""",
impl="""
static CYTHON_INLINE int __Pyx_UnicodeContains(PyObject* unicode, Py_UNICODE character) {
const Py_ssize_t length = PyUnicode_GET_SIZE(unicode);
Py_UNICODE* char_start = PyUnicode_AS_UNICODE(unicode);
Py_UNICODE* pos;
for (pos=char_start; pos < char_start+length; pos++) {
if (character == pos[0]) return 1;
}
return 0;
}
""")
class PrimaryCmpNode(ExprNode, CmpNode): class PrimaryCmpNode(ExprNode, CmpNode):
# Non-cascaded comparison or first comparison of # Non-cascaded comparison or first comparison of
...@@ -5698,13 +5741,32 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -5698,13 +5741,32 @@ class PrimaryCmpNode(ExprNode, CmpNode):
self.cascade.analyse_types(env) self.cascade.analyse_types(env)
if self.operator in ('in', 'not_in'): if self.operator in ('in', 'not_in'):
common_type = py_object_type if self.is_c_string_contains():
self.is_pycmp = True self.is_pycmp = False
common_type = None
if self.cascade:
error(self.pos, "Cascading comparison not yet supported for 'int_val in string'.")
return
if self.operand2.type is unicode_type:
env.use_utility_code(pyunicode_in_unicode_utility_code)
else:
if self.operand1.type is PyrexTypes.c_uchar_type:
self.operand1 = self.operand1.coerce_to(PyrexTypes.c_char_type, env)
if self.operand2.type is not bytes_type:
self.operand2 = self.operand2.coerce_to(bytes_type, env)
env.use_utility_code(char_in_bytes_utility_code)
if not isinstance(self.operand2, (UnicodeNode, BytesNode)):
self.operand2 = NoneCheckNode(
self.operand2, "PyExc_TypeError",
"argument of type 'NoneType' is not iterable")
else:
common_type = py_object_type
self.is_pycmp = True
else: else:
common_type = self.find_common_type(env, self.operator, self.operand1) common_type = self.find_common_type(env, self.operator, self.operand1)
self.is_pycmp = common_type.is_pyobject self.is_pycmp = common_type.is_pyobject
if not common_type.is_error: if common_type is not None and not common_type.is_error:
if self.operand1.type != common_type: if self.operand1.type != common_type:
self.operand1 = self.operand1.coerce_to(common_type, env) self.operand1 = self.operand1.coerce_to(common_type, env)
self.coerce_operands_to(common_type, env) self.coerce_operands_to(common_type, env)
...@@ -5765,6 +5827,20 @@ class PrimaryCmpNode(ExprNode, CmpNode): ...@@ -5765,6 +5827,20 @@ class PrimaryCmpNode(ExprNode, CmpNode):
self.operand1.type.binary_op('=='), self.operand1.type.binary_op('=='),
self.operand1.result(), self.operand1.result(),
self.operand2.result()) self.operand2.result())
elif self.is_c_string_contains():
if self.operand2.type is bytes_type:
method = "__Pyx_BytesContains"
else:
method = "__Pyx_UnicodeContains"
if self.operator == "not_in":
negation = "!"
else:
negation = ""
return "(%s%s(%s, %s))" % (
negation,
method,
self.operand2.result(),
self.operand1.result())
else: else:
return "(%s %s %s)" % ( return "(%s %s %s)" % (
self.operand1.result(), self.operand1.result(),
......
...@@ -596,6 +596,17 @@ class SwitchTransform(Visitor.VisitorTransform): ...@@ -596,6 +596,17 @@ class SwitchTransform(Visitor.VisitorTransform):
not_in = False not_in = False
elif allow_not_in and cond.operator == '!=': elif allow_not_in and cond.operator == '!=':
not_in = True not_in = True
elif cond.is_c_string_contains() and \
isinstance(cond.operand2, (ExprNodes.UnicodeNode, ExprNodes.BytesNode)):
not_in = cond.operator == 'not_in'
if not_in and not allow_not_in:
return self.NO_MATCH
# this looks somewhat silly, but it does the right
# checks for NameNode and AttributeNode
if is_common_value(cond.operand1, cond.operand1):
return not_in, cond.operand1, self.extract_in_string_conditions(cond.operand2)
else:
return self.NO_MATCH
else: else:
return self.NO_MATCH return self.NO_MATCH
# this looks somewhat silly, but it does the right # this looks somewhat silly, but it does the right
...@@ -622,6 +633,23 @@ class SwitchTransform(Visitor.VisitorTransform): ...@@ -622,6 +633,23 @@ class SwitchTransform(Visitor.VisitorTransform):
return not_in_1, t1, c1+c2 return not_in_1, t1, c1+c2
return self.NO_MATCH return self.NO_MATCH
def extract_in_string_conditions(self, string_literal):
if isinstance(string_literal, ExprNodes.UnicodeNode):
charvals = map(ord, set(string_literal.value))
charvals.sort()
return [ ExprNodes.IntNode(string_literal.pos, value=str(charval),
constant_result=charval)
for charval in charvals ]
else:
# this is a bit tricky as Py3's bytes type returns
# integers on iteration, whereas Py2 returns 1-char byte
# strings
characters = string_literal.value
characters = set([ characters[i:i+1] for i in range(len(characters)) ])
return [ ExprNodes.CharNode(string_literal.pos, value=charval,
constant_result=charval)
for charval in characters ]
def extract_common_conditions(self, common_var, condition, allow_not_in): def extract_common_conditions(self, common_var, condition, allow_not_in):
not_in, var, conditions = self.extract_conditions(condition, allow_not_in) not_in, var, conditions = self.extract_conditions(condition, allow_not_in)
if var is None: if var is None:
...@@ -696,8 +724,22 @@ class SwitchTransform(Visitor.VisitorTransform): ...@@ -696,8 +724,22 @@ class SwitchTransform(Visitor.VisitorTransform):
return self.build_simple_switch_statement( return self.build_simple_switch_statement(
node, common_var, conditions, not_in, node, common_var, conditions, not_in,
ExprNodes.BoolNode(node.pos, value=True), ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
ExprNodes.BoolNode(node.pos, value=False)) ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
def visit_PrimaryCmpNode(self, node):
not_in, common_var, conditions = self.extract_common_conditions(
None, node, True)
if common_var is None \
or len(conditions) < 2 \
or self.has_duplicate_values(conditions):
self.visitchildren(node)
return node
return self.build_simple_switch_statement(
node, common_var, conditions, not_in,
ExprNodes.BoolNode(node.pos, value=True, constant_result=True),
ExprNodes.BoolNode(node.pos, value=False, constant_result=False))
def build_simple_switch_statement(self, node, common_var, conditions, def build_simple_switch_statement(self, node, common_var, conditions,
not_in, true_val, false_val): not_in, true_val, false_val):
......
...@@ -92,6 +92,72 @@ def m_set(int a): ...@@ -92,6 +92,72 @@ def m_set(int a):
cdef int result = a in {1,2,3,4} cdef int result = a in {1,2,3,4}
return result return result
cdef bytes bytes_string = b'abcdefg'
py_bytes_string = bytes_string
@cython.test_assert_path_exists("//PrimaryCmpNode")
@cython.test_fail_if_path_exists("//SwitchStatNode", "//BoolBinopNode")
def m_bytes(char a, bytes bytes_string):
"""
>>> m_bytes(ord('f'), py_bytes_string)
1
>>> m_bytes(ord('X'), py_bytes_string)
0
>>> 'f'.encode('ASCII') in None
Traceback (most recent call last):
TypeError: argument of type 'NoneType' is not iterable
>>> m_bytes(ord('f'), None)
Traceback (most recent call last):
TypeError: argument of type 'NoneType' is not iterable
"""
cdef int result = a in bytes_string
return result
@cython.test_assert_path_exists("//SwitchStatNode")
@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
def m_bytes_literal(char a):
"""
>>> m_bytes_literal(ord('f'))
1
>>> m_bytes_literal(ord('X'))
0
"""
cdef int result = a in b'abcdefg'
return result
cdef unicode unicode_string = u'abcdefg\u1234\uF8D2'
py_unicode_string = unicode_string
@cython.test_assert_path_exists("//PrimaryCmpNode")
@cython.test_fail_if_path_exists("//SwitchStatNode", "//BoolBinopNode")
def m_unicode(Py_UNICODE a, unicode unicode_string):
"""
>>> m_unicode(ord('f'), py_unicode_string)
1
>>> m_unicode(ord('X'), py_unicode_string)
0
>>> 'f' in None
Traceback (most recent call last):
TypeError: argument of type 'NoneType' is not iterable
>>> m_unicode(ord('f'), None)
Traceback (most recent call last):
TypeError: argument of type 'NoneType' is not iterable
"""
cdef int result = a in unicode_string
return result
@cython.test_assert_path_exists("//SwitchStatNode")
@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
def m_unicode_literal(Py_UNICODE a):
"""
>>> m_unicode_literal(ord('f'))
1
>>> m_unicode_literal(ord('X'))
0
"""
cdef int result = a in u'abcdefg\u1234\uF8D2'
return result
@cython.test_assert_path_exists("//SwitchStatNode") @cython.test_assert_path_exists("//SwitchStatNode")
@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode") @cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
def conditional_int(int a): def conditional_int(int a):
......
...@@ -82,6 +82,70 @@ def m_tuple(int a): ...@@ -82,6 +82,70 @@ def m_tuple(int a):
cdef int result = a not in (1,2,3,4) cdef int result = a not in (1,2,3,4)
return result return result
@cython.test_assert_path_exists("//SwitchStatNode")
@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
def m_set(int a):
"""
>>> m_set(2)
0
>>> m_set(5)
1
"""
cdef int result = a not in {1,2,3,4}
return result
cdef bytes bytes_string = b'abcdefg'
@cython.test_assert_path_exists("//PrimaryCmpNode")
@cython.test_fail_if_path_exists("//SwitchStatNode", "//BoolBinopNode")
def m_bytes(char a):
"""
>>> m_bytes(ord('f'))
0
>>> m_bytes(ord('X'))
1
"""
cdef int result = a not in bytes_string
return result
@cython.test_assert_path_exists("//SwitchStatNode")
@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
def m_bytes_literal(char a):
"""
>>> m_bytes_literal(ord('f'))
0
>>> m_bytes_literal(ord('X'))
1
"""
cdef int result = a not in b'abcdefg'
return result
cdef unicode unicode_string = u'abcdefg\u1234\uF8D2'
@cython.test_assert_path_exists("//PrimaryCmpNode")
@cython.test_fail_if_path_exists("//SwitchStatNode", "//BoolBinopNode")
def m_unicode(Py_UNICODE a):
"""
>>> m_unicode(ord('f'))
0
>>> m_unicode(ord('X'))
1
"""
cdef int result = a not in unicode_string
return result
@cython.test_assert_path_exists("//SwitchStatNode")
@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
def m_unicode_literal(Py_UNICODE a):
"""
>>> m_unicode_literal(ord('f'))
0
>>> m_unicode_literal(ord('X'))
1
"""
cdef int result = a not in u'abcdefg\u1234\uF8D2'
return result
@cython.test_assert_path_exists("//SwitchStatNode", "//BoolBinopNode") @cython.test_assert_path_exists("//SwitchStatNode", "//BoolBinopNode")
@cython.test_fail_if_path_exists("//PrimaryCmpNode") @cython.test_fail_if_path_exists("//PrimaryCmpNode")
def m_tuple_in_or_notin(int a): def m_tuple_in_or_notin(int a):
...@@ -138,6 +202,43 @@ def m_tuple_notin_and_notin_overlap(int a): ...@@ -138,6 +202,43 @@ def m_tuple_notin_and_notin_overlap(int a):
cdef int result = a not in (1,2,3,4) and a not in (3,4) cdef int result = a not in (1,2,3,4) and a not in (3,4)
return result return result
@cython.test_assert_path_exists("//SwitchStatNode")
@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
def conditional_int(int a):
"""
>>> conditional_int(1)
2
>>> conditional_int(0)
1
>>> conditional_int(5)
1
"""
return 1 if a not in (1,2,3,4) else 2
@cython.test_assert_path_exists("//SwitchStatNode")
@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
def conditional_object(int a):
"""
>>> conditional_object(1)
'2'
>>> conditional_object(0)
1
>>> conditional_object(5)
1
"""
return 1 if a not in (1,2,3,4) else '2'
@cython.test_assert_path_exists("//SwitchStatNode")
@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
def conditional_none(int a):
"""
>>> conditional_none(1)
1
>>> conditional_none(0)
>>> conditional_none(5)
"""
return None if a not in {1,2,3,4} else 1
def n(a): def n(a):
""" """
>>> n('d *') >>> n('d *')
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment