Commit 1f8ffaa1 authored by Stefan Behnel

ticket 436: efficiently support char*.decode() through C-API calls

parent eed632a1
...@@ -136,7 +136,7 @@ class Context(object): ...@@ -136,7 +136,7 @@ class Context(object):
IntroduceBufferAuxiliaryVars(self), IntroduceBufferAuxiliaryVars(self),
_check_c_declarations, _check_c_declarations,
AnalyseExpressionsTransform(self), AnalyseExpressionsTransform(self),
OptimizeBuiltinCalls(), OptimizeBuiltinCalls(self),
IterationTransform(), IterationTransform(),
SwitchTransform(), SwitchTransform(),
DropRefcountingTransform(), DropRefcountingTransform(),
......
...@@ -305,7 +305,7 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -305,7 +305,7 @@ class IterationTransform(Visitor.VisitorTransform):
if dest_type != obj_node.type: if dest_type != obj_node.type:
if dest_type.is_extension_type or dest_type.is_builtin_type: if dest_type.is_extension_type or dest_type.is_builtin_type:
obj_node = ExprNodes.PyTypeTestNode( obj_node = ExprNodes.PyTypeTestNode(
obj_node, dest_type, FakePythonEnv(), notnone=True) obj_node, dest_type, self.current_scope, notnone=True)
result = ExprNodes.TypecastNode( result = ExprNodes.TypecastNode(
obj_node.pos, obj_node.pos,
operand = obj_node, operand = obj_node,
...@@ -320,7 +320,7 @@ class IterationTransform(Visitor.VisitorTransform): ...@@ -320,7 +320,7 @@ class IterationTransform(Visitor.VisitorTransform):
return temp_result.result() return temp_result.result()
def generate_execution_code(self, code): def generate_execution_code(self, code):
self.generate_result_code(code) self.generate_result_code(code)
return (temp_result, CoercedTempNode(dest_type, obj_node, FakePythonEnv())) return (temp_result, CoercedTempNode(dest_type, obj_node, self.current_scope))
if isinstance(node.body, Nodes.StatListNode): if isinstance(node.body, Nodes.StatListNode):
body = node.body body = node.body
...@@ -633,7 +633,7 @@ class DropRefcountingTransform(Visitor.VisitorTransform): ...@@ -633,7 +633,7 @@ class DropRefcountingTransform(Visitor.VisitorTransform):
return (base.name, index_val) return (base.name, index_val)
class OptimizeBuiltinCalls(Visitor.VisitorTransform): class OptimizeBuiltinCalls(Visitor.EnvTransform):
"""Optimize some common methods calls and instantiation patterns """Optimize some common methods calls and instantiation patterns
for builtin types. for builtin types.
""" """
...@@ -961,7 +961,7 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform): ...@@ -961,7 +961,7 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
_special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII', _special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII',
'unicode_escape', 'raw_unicode_escape'] 'unicode_escape', 'raw_unicode_escape']
_special_encoders = [ (name, codecs.getencoder(name)) _special_codecs = [ (name, codecs.getencoder(name))
for name in _special_encodings ] for name in _special_encodings ]
def _handle_simple_method_unicode_encode(self, node, args, is_unbound_method): def _handle_simple_method_unicode_encode(self, node, args, is_unbound_method):
...@@ -969,43 +969,19 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform): ...@@ -969,43 +969,19 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
self._error_wrong_arg_count('unicode.encode', node, args, '1-3') self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
return node return node
null_node = ExprNodes.NullNode(node.pos)
string_node = args[0] string_node = args[0]
if len(args) == 1: if len(args) == 1:
null_node = ExprNodes.NullNode(node.pos)
return self._substitute_method_call( return self._substitute_method_call(
node, "PyUnicode_AsEncodedString", node, "PyUnicode_AsEncodedString",
self.PyUnicode_AsEncodedString_func_type, self.PyUnicode_AsEncodedString_func_type,
'encode', is_unbound_method, [string_node, null_node, null_node]) 'encode', is_unbound_method, [string_node, null_node, null_node])
encoding_node = args[1] parameters = self._unpack_encoding_and_error_mode(node.pos, args)
if isinstance(encoding_node, ExprNodes.CoerceToPyTypeNode): if parameters is None:
encoding_node = encoding_node.arg
if not isinstance(encoding_node, (ExprNodes.UnicodeNode, ExprNodes.StringNode,
ExprNodes.BytesNode)):
return node return node
encoding = encoding_node.value encoding, encoding_node, error_handling, error_handling_node = parameters
encoding_node = ExprNodes.BytesNode(encoding_node.pos, value=encoding,
type=PyrexTypes.c_char_ptr_type)
if len(args) == 3:
error_handling_node = args[2]
if isinstance(error_handling_node, ExprNodes.CoerceToPyTypeNode):
error_handling_node = error_handling_node.arg
if not isinstance(error_handling_node,
(ExprNodes.UnicodeNode, ExprNodes.StringNode,
ExprNodes.BytesNode)):
return node
error_handling = error_handling_node.value
if error_handling == 'strict':
error_handling_node = null_node
else:
error_handling_node = ExprNodes.BytesNode(
error_handling_node.pos, value=error_handling,
type=PyrexTypes.c_char_ptr_type)
else:
error_handling = 'strict'
error_handling_node = null_node
if isinstance(string_node, ExprNodes.UnicodeNode): if isinstance(string_node, ExprNodes.UnicodeNode):
# constant, so try to do the encoding at compile time # constant, so try to do the encoding at compile time
...@@ -1022,18 +998,9 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform): ...@@ -1022,18 +998,9 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
if error_handling == 'strict': if error_handling == 'strict':
# try to find a specific encoder function # try to find a specific encoder function
try: requested_encoder = codecs.getencoder(encoding) codec_name = self._find_special_codec_name(encoding)
except: pass if codec_name is not None:
else: encode_function = "PyUnicode_As%sString" % codec_name
encode_function = None
for name, encoder in self._special_encoders:
if encoder == requested_encoder:
if '_' in name:
name = ''.join([ s.capitalize()
for s in name.split('_')])
encode_function = "PyUnicode_As%sString" % name
break
if encode_function is not None:
return self._substitute_method_call( return self._substitute_method_call(
node, encode_function, node, encode_function,
self.PyUnicode_AsXyzString_func_type, self.PyUnicode_AsXyzString_func_type,
...@@ -1045,6 +1012,128 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform): ...@@ -1045,6 +1012,128 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
'encode', is_unbound_method, 'encode', is_unbound_method,
[string_node, encoding_node, error_handling_node]) [string_node, encoding_node, error_handling_node])
# C signature used for the codec-specific decoders that are called as
# "PyUnicode_Decode%s" % codec_name:
#     unicode f(char* string, Py_ssize_t size, char* errors)
# The exception value of the call is NULL.
PyUnicode_DecodeXyz_func_type = PyrexTypes.CFuncType(
    Builtin.unicode_type,
    [
        PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
    ],
    exception_value="NULL",
)

# C signature used for the generic "PyUnicode_Decode" call, which takes
# the encoding name as an additional char* argument:
#     unicode f(char* string, Py_ssize_t size, char* encoding, char* errors)
# The exception value of the call is NULL.
PyUnicode_Decode_func_type = PyrexTypes.CFuncType(
    Builtin.unicode_type,
    [
        PyrexTypes.CFuncTypeArg("string", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("size", PyrexTypes.c_py_ssize_t_type, None),
        PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
        PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
    ],
    exception_value="NULL",
)
def _handle_simple_method_bytes_decode(self, node, args, is_unbound_method):
if len(args) < 1 or len(args) > 3:
self._error_wrong_arg_count('bytes.decode', node, args, '1-3')
return node
if is_unbound_method:
return node
if not isinstance(args[0], ExprNodes.SliceIndexNode):
# we need the string length as a slice end index
return node
index_node = args[0]
string_node = index_node.base
if not string_node.type.is_string:
# nothing to optimise here
return node
start, stop = index_node.start, index_node.stop
if not stop:
# FIXME: could use strlen() - although Python will do that anyway ...
return node
if stop.type.is_pyobject:
stop = stop.coerce_to(PyrexTypes.c_py_ssize_t_type, self.env_stack[-1])
if start and start.constant_result != 0:
# FIXME: put start into a temp and do the math
return node
parameters = self._unpack_encoding_and_error_mode(node.pos, args)
if parameters is None:
return node
encoding, encoding_node, error_handling, error_handling_node = parameters
# try to find a specific encoder function
codec_name = self._find_special_codec_name(encoding)
if codec_name is not None:
decode_function = "PyUnicode_Decode%s" % codec_name
return ExprNodes.PythonCapiCallNode(
node.pos, decode_function,
self.PyUnicode_DecodeXyz_func_type,
args = [string_node, stop, error_handling_node],
is_temp = node.is_temp,
)
return self._substitute_method_call(
node, decode_function,
self.PyUnicode_DecodeXyz_func_type,
'decode', is_unbound_method,
[string_node, stop, error_handling_node])
return ExprNodes.PythonCapiCallNode(
node.pos, "PyUnicode_Decode",
self.PyUnicode_Decode_func_type,
args = [string_node, stop, encoding_node, error_handling_node],
is_temp = node.is_temp,
)
return self._substitute_method_call(
node, "PyUnicode_Decode",
self.PyUnicode_Decode_func_type,
'decode', is_unbound_method,
[string_node, stop, encoding_node, error_handling_node])
def _find_special_codec_name(self, encoding):
try:
requested_codec = codecs.getencoder(encoding)
except:
return None
for name, codec in self._special_codecs:
if codec == requested_codec:
if '_' in name:
name = ''.join([ s.capitalize()
for s in name.split('_')])
return name
return None
def _unpack_encoding_and_error_mode(self, pos, args):
encoding_node = args[1]
if isinstance(encoding_node, ExprNodes.CoerceToPyTypeNode):
encoding_node = encoding_node.arg
if not isinstance(encoding_node, (ExprNodes.UnicodeNode, ExprNodes.StringNode,
ExprNodes.BytesNode)):
return None
encoding = encoding_node.value
encoding_node = ExprNodes.BytesNode(encoding_node.pos, value=encoding,
type=PyrexTypes.c_char_ptr_type)
null_node = ExprNodes.NullNode(pos)
if len(args) == 3:
error_handling_node = args[2]
if isinstance(error_handling_node, ExprNodes.CoerceToPyTypeNode):
error_handling_node = error_handling_node.arg
if not isinstance(error_handling_node,
(ExprNodes.UnicodeNode, ExprNodes.StringNode,
ExprNodes.BytesNode)):
return None
error_handling = error_handling_node.value
if error_handling == 'strict':
error_handling_node = null_node
else:
error_handling_node = ExprNodes.BytesNode(
error_handling_node.pos, value=error_handling,
type=PyrexTypes.c_char_ptr_type)
else:
error_handling = 'strict'
error_handling_node = null_node
return (encoding, encoding_node, error_handling, error_handling_node)
def _substitute_method_call(self, node, name, func_type, def _substitute_method_call(self, node, name, func_type,
attr_name, is_unbound_method, args=()): attr_name, is_unbound_method, args=()):
args = list(args) args = list(args)
......
from Cython.Compiler.Visitor import VisitorTransform, CythonTransform, TreeVisitor from Cython.Compiler.Visitor import VisitorTransform, TreeVisitor
from Cython.Compiler.Visitor import CythonTransform, EnvTransform
from Cython.Compiler.ModuleNode import ModuleNode from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import * from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import * from Cython.Compiler.ExprNodes import *
...@@ -938,21 +939,6 @@ class GilCheck(VisitorTransform): ...@@ -938,21 +939,6 @@ class GilCheck(VisitorTransform):
return node return node
class EnvTransform(CythonTransform):
    """
    This transformation keeps a stack of the environments.

    ``self.env_stack`` starts out with the scope of the root node and
    gets the local scope of each function pushed while the children of
    that function are being visited, so ``self.env_stack[-1]`` is the
    innermost scope at any point of the traversal.
    """
    def __call__(self, root):
        # start each run with a fresh stack that holds only the root scope
        self.env_stack = [root.scope]
        return super(EnvTransform, self).__call__(root)

    def visit_FuncDefNode(self, node):
        self.env_stack.append(node.local_scope)
        try:
            self.visitchildren(node)
        finally:
            # keep the stack balanced even if visiting a child raises
            self.env_stack.pop()
        return node
class TransformBuiltinMethods(EnvTransform): class TransformBuiltinMethods(EnvTransform):
def visit_SingleAssignmentNode(self, node): def visit_SingleAssignmentNode(self, node):
......
...@@ -306,6 +306,22 @@ class ScopeTrackingTransform(CythonTransform): ...@@ -306,6 +306,22 @@ class ScopeTrackingTransform(CythonTransform):
def visit_CStructOrUnionDefNode(self, node): def visit_CStructOrUnionDefNode(self, node):
return self.visit_scope(node, 'struct') return self.visit_scope(node, 'struct')
class EnvTransform(CythonTransform):
    """
    This transformation keeps a stack of the environments.

    ``self.env_stack`` starts out with the scope of the root node and
    gets the local scope of each function pushed while the children of
    that function are being visited, so ``self.env_stack[-1]`` is the
    innermost scope at any point of the traversal.
    """
    def __call__(self, root):
        # start each run with a fresh stack that holds only the root scope
        self.env_stack = [root.scope]
        return super(EnvTransform, self).__call__(root)

    def visit_FuncDefNode(self, node):
        self.env_stack.append(node.local_scope)
        try:
            self.visitchildren(node)
        finally:
            # keep the stack balanced even if visiting a child raises
            self.env_stack.pop()
        return node
class RecursiveNodeReplacer(VisitorTransform): class RecursiveNodeReplacer(VisitorTransform):
""" """
Recursively replace all occurrences of a node in a subtree by Recursively replace all occurrences of a node in a subtree by
......
# Module-level C string (9 ASCII characters) that the slicing tests below read from.
cdef char* cstring = "abcABCqtp"
# Slices the char* with an end index only (1, 3 and the full length 9) and
# returns the three results as a tuple.  The doctest strips the "b'" prefix
# from the printed repr before comparing.
def slice_charptr_end():
"""
>>> print str(slice_charptr_end()).replace("b'", "'")
('a', 'abc', 'abcABCqtp')
"""
return cstring[:1], cstring[:3], cstring[:9]
# Calls .decode('UTF-8') directly on end-indexed slices of the char*, i.e. a
# literal encoding and no error mode argument.  The doctest strips the "u'"
# prefix from the printed repr before comparing.
def slice_charptr_decode():
"""
>>> print str(slice_charptr_decode()).replace("u'", "'")
('a', 'abc', 'abcABCqtp')
"""
return (cstring[:1].decode('UTF-8'),
cstring[:3].decode('UTF-8'),
cstring[:9].decode('UTF-8'))
# Same as the plain decode test, but passes a second (error mode) argument:
# 'strict', 'replace' and 'unicode_escape'.
# NOTE(review): 'unicode_escape' looks like a codec name rather than an error
# handler name; the input is plain ASCII, so presumably the error mode is
# never consulted here - confirm that this value is intended.
def slice_charptr_decode_errormode():
"""
>>> print str(slice_charptr_decode_errormode()).replace("u'", "'")
('a', 'abc', 'abcABCqtp')
"""
return (cstring[:1].decode('UTF-8', 'strict'),
cstring[:3].decode('UTF-8', 'replace'),
cstring[:9].decode('UTF-8', 'unicode_escape'))
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment