Commit 404c8968 authored by Robert Bradshaw's avatar Robert Bradshaw

Add support for external C++ template functions.

The syntax follows that of template classes, namely

    cdef T foo[T](T, ...)
parent 93e0ec35
...@@ -2721,6 +2721,7 @@ class IndexNode(ExprNode): ...@@ -2721,6 +2721,7 @@ class IndexNode(ExprNode):
# base ExprNode # base ExprNode
# index ExprNode # index ExprNode
# indices [ExprNode] # indices [ExprNode]
# type_indices [PyrexType]
# is_buffer_access boolean Whether this is a buffer access. # is_buffer_access boolean Whether this is a buffer access.
# #
# indices is used on buffer access, index on non-buffer access. # indices is used on buffer access, index on non-buffer access.
...@@ -2732,6 +2733,7 @@ class IndexNode(ExprNode): ...@@ -2732,6 +2733,7 @@ class IndexNode(ExprNode):
subexprs = ['base', 'index', 'indices'] subexprs = ['base', 'index', 'indices']
indices = None indices = None
type_indices = None
is_subscript = True is_subscript = True
is_fused_index = False is_fused_index = False
...@@ -3103,8 +3105,7 @@ class IndexNode(ExprNode): ...@@ -3103,8 +3105,7 @@ class IndexNode(ExprNode):
else: else:
base_type = self.base.type base_type = self.base.type
fused_index_operation = base_type.is_cfunction and base_type.is_fused if not base_type.is_cfunction:
if not fused_index_operation:
if isinstance(self.index, TupleNode): if isinstance(self.index, TupleNode):
self.index = self.index.analyse_types( self.index = self.index.analyse_types(
env, skip_children=skip_child_analysis) env, skip_children=skip_child_analysis)
...@@ -3188,8 +3189,17 @@ class IndexNode(ExprNode): ...@@ -3188,8 +3189,17 @@ class IndexNode(ExprNode):
self.type = func_type.return_type self.type = func_type.return_type
if setting and not func_type.return_type.is_reference: if setting and not func_type.return_type.is_reference:
error(self.pos, "Can't set non-reference result '%s'" % self.type) error(self.pos, "Can't set non-reference result '%s'" % self.type)
elif fused_index_operation: elif base_type.is_cfunction:
if base_type.is_fused:
self.parse_indexed_fused_cdef(env) self.parse_indexed_fused_cdef(env)
else:
self.type_indices = self.parse_index_as_types(env)
if base_type.templates is None:
error(self.pos, "Can only parameterize template functions.")
elif len(base_type.templates) != len(self.type_indices):
error(self.pos, "Wrong number of template arguments: expected %s, got %s" % (
(len(base_type.templates), len(self.type_indices))))
self.type = base_type.specialize(dict(zip(base_type.templates, self.type_indices)))
else: else:
error(self.pos, error(self.pos,
"Attempting to index non-array type '%s'" % "Attempting to index non-array type '%s'" %
...@@ -3215,6 +3225,20 @@ class IndexNode(ExprNode): ...@@ -3215,6 +3225,20 @@ class IndexNode(ExprNode):
self.base = self.base.as_none_safe_node(msg) self.base = self.base.as_none_safe_node(msg)
def parse_index_as_types(self, env, required=True):
if isinstance(self.index, TupleNode):
indices = self.index.args
else:
indices = [self.index]
type_indices = []
for index in indices:
type_indices.append(index.analyse_as_type(env))
if type_indices[-1] is None:
if required:
error(index.pos, "not parsable as a type")
return None
return type_indices
def parse_indexed_fused_cdef(self, env): def parse_indexed_fused_cdef(self, env):
""" """
Interpret fused_cdef_func[specific_type1, ...] Interpret fused_cdef_func[specific_type1, ...]
...@@ -3234,16 +3258,12 @@ class IndexNode(ExprNode): ...@@ -3234,16 +3258,12 @@ class IndexNode(ExprNode):
if self.index.is_name or self.index.is_attribute: if self.index.is_name or self.index.is_attribute:
positions.append(self.index.pos) positions.append(self.index.pos)
specific_types.append(self.index.analyse_as_type(env))
elif isinstance(self.index, TupleNode): elif isinstance(self.index, TupleNode):
for arg in self.index.args: for arg in self.index.args:
positions.append(arg.pos) positions.append(arg.pos)
specific_type = arg.analyse_as_type(env) specific_types = self.parse_index_as_types(env, required=False)
specific_types.append(specific_type)
else:
specific_types = [False]
if not Utils.all(specific_types): if specific_types is None:
self.index = self.index.analyse_types(env) self.index = self.index.analyse_types(env)
if not self.base.entry.as_variable: if not self.base.entry.as_variable:
...@@ -3362,6 +3382,10 @@ class IndexNode(ExprNode): ...@@ -3362,6 +3382,10 @@ class IndexNode(ExprNode):
index_code = "((unsigned char)(PyByteArray_AS_STRING(%s)[%s]))" index_code = "((unsigned char)(PyByteArray_AS_STRING(%s)[%s]))"
else: else:
assert False, "unexpected base type in indexing: %s" % self.base.type assert False, "unexpected base type in indexing: %s" % self.base.type
elif self.base.type.is_cfunction:
return "%s<%s>" % (
self.base.result(),
",".join([param.declaration_code("") for param in self.type_indices]))
else: else:
if (self.type.is_ptr or self.type.is_array) and self.type == self.base.type: if (self.type.is_ptr or self.type.is_array) and self.type == self.base.type:
error(self.pos, "Invalid use of pointer slice") error(self.pos, "Invalid use of pointer slice")
...@@ -3388,7 +3412,9 @@ class IndexNode(ExprNode): ...@@ -3388,7 +3412,9 @@ class IndexNode(ExprNode):
def generate_subexpr_evaluation_code(self, code): def generate_subexpr_evaluation_code(self, code):
self.base.generate_evaluation_code(code) self.base.generate_evaluation_code(code)
if self.indices is None: if self.type_indices is not None:
pass
elif self.indices is None:
self.index.generate_evaluation_code(code) self.index.generate_evaluation_code(code)
else: else:
for i in self.indices: for i in self.indices:
...@@ -3396,7 +3422,9 @@ class IndexNode(ExprNode): ...@@ -3396,7 +3422,9 @@ class IndexNode(ExprNode):
def generate_subexpr_disposal_code(self, code): def generate_subexpr_disposal_code(self, code):
self.base.generate_disposal_code(code) self.base.generate_disposal_code(code)
if self.indices is None: if self.type_indices is not None:
pass
elif self.indices is None:
self.index.generate_disposal_code(code) self.index.generate_disposal_code(code)
else: else:
for i in self.indices: for i in self.indices:
......
...@@ -19,8 +19,8 @@ import Naming ...@@ -19,8 +19,8 @@ import Naming
import PyrexTypes import PyrexTypes
import TypeSlots import TypeSlots
from PyrexTypes import py_object_type, error_type from PyrexTypes import py_object_type, error_type
from Symtab import ModuleScope, LocalScope, ClosureScope, \ from Symtab import (ModuleScope, LocalScope, ClosureScope,
StructOrUnionScope, PyClassScope, CppClassScope StructOrUnionScope, PyClassScope, CppClassScope, TemplateScope)
from Code import UtilityCode from Code import UtilityCode
from StringEncoding import EncodedString, escape_byte_string, split_string_literal from StringEncoding import EncodedString, escape_byte_string, split_string_literal
import Options import Options
...@@ -465,6 +465,9 @@ class CDeclaratorNode(Node): ...@@ -465,6 +465,9 @@ class CDeclaratorNode(Node):
calling_convention = "" calling_convention = ""
def analyse_templates(self):
# Only C++ functions have templates.
return None
class CNameDeclaratorNode(CDeclaratorNode): class CNameDeclaratorNode(CDeclaratorNode):
# name string The Cython name being declared # name string The Cython name being declared
...@@ -523,7 +526,7 @@ class CArrayDeclaratorNode(CDeclaratorNode): ...@@ -523,7 +526,7 @@ class CArrayDeclaratorNode(CDeclaratorNode):
child_attrs = ["base", "dimension"] child_attrs = ["base", "dimension"]
def analyse(self, base_type, env, nonempty = 0): def analyse(self, base_type, env, nonempty = 0):
if base_type.is_cpp_class: if base_type.is_cpp_class or base_type.is_cfunction:
from ExprNodes import TupleNode from ExprNodes import TupleNode
if isinstance(self.dimension, TupleNode): if isinstance(self.dimension, TupleNode):
args = self.dimension.args args = self.dimension.args
...@@ -565,6 +568,7 @@ class CArrayDeclaratorNode(CDeclaratorNode): ...@@ -565,6 +568,7 @@ class CArrayDeclaratorNode(CDeclaratorNode):
class CFuncDeclaratorNode(CDeclaratorNode): class CFuncDeclaratorNode(CDeclaratorNode):
# base CDeclaratorNode # base CDeclaratorNode
# args [CArgDeclNode] # args [CArgDeclNode]
# templates [TemplatePlaceholderType]
# has_varargs boolean # has_varargs boolean
# exception_value ConstNode # exception_value ConstNode
# exception_check boolean True if PyErr_Occurred check needed # exception_check boolean True if PyErr_Occurred check needed
...@@ -575,6 +579,28 @@ class CFuncDeclaratorNode(CDeclaratorNode): ...@@ -575,6 +579,28 @@ class CFuncDeclaratorNode(CDeclaratorNode):
overridable = 0 overridable = 0
optional_arg_count = 0 optional_arg_count = 0
templates = None
def analyse_templates(self):
if isinstance(self.base, CArrayDeclaratorNode):
from ExprNodes import TupleNode, NameNode
template_node = self.base.dimension
if isinstance(template_node, TupleNode):
template_nodes = template_node.args
elif isinstance(template_node, NameNode):
template_nodes = [template_node]
else:
error(template_node.pos, "Template arguments must be a list of names")
self.templates = []
for template in template_nodes:
if isinstance(template, NameNode):
self.templates.append(PyrexTypes.TemplatePlaceholderType(template.name))
else:
error(template.pos, "Template arguments must be a list of names")
self.base = self.base.base
return self.templates
else:
return None
def analyse(self, return_type, env, nonempty = 0, directive_locals = {}): def analyse(self, return_type, env, nonempty = 0, directive_locals = {}):
if nonempty: if nonempty:
...@@ -659,7 +685,8 @@ class CFuncDeclaratorNode(CDeclaratorNode): ...@@ -659,7 +685,8 @@ class CFuncDeclaratorNode(CDeclaratorNode):
optional_arg_count = self.optional_arg_count, optional_arg_count = self.optional_arg_count,
exception_value = exc_val, exception_check = exc_check, exception_value = exc_val, exception_check = exc_check,
calling_convention = self.base.calling_convention, calling_convention = self.base.calling_convention,
nogil = self.nogil, with_gil = self.with_gil, is_overridable = self.overridable) nogil = self.nogil, with_gil = self.with_gil, is_overridable = self.overridable,
templates = self.templates)
if self.optional_arg_count: if self.optional_arg_count:
if func_type.is_fused: if func_type.is_fused:
...@@ -1164,6 +1191,21 @@ class CVarDefNode(StatNode): ...@@ -1164,6 +1191,21 @@ class CVarDefNode(StatNode):
if not dest_scope: if not dest_scope:
dest_scope = env dest_scope = env
self.dest_scope = dest_scope self.dest_scope = dest_scope
if self.declarators:
templates = self.declarators[0].analyse_templates()
else:
templates = None
if templates is not None:
if self.visibility != 'extern':
error(self.pos, "Only extern functions allowed")
if len(self.declarators) > 1:
error(self.declarators[1].pos, "Can't multiply declare template types")
env = TemplateScope('func_template', env)
env.directives = env.outer_scope.directives
for template_param in templates:
env.declare_type(template_param.name, template_param, self.pos)
base_type = self.base_type.analyse(env) base_type = self.base_type.analyse(env)
if base_type.is_fused and not self.in_pxd and (env.is_c_class_scope or if base_type.is_fused and not self.in_pxd and (env.is_c_class_scope or
......
...@@ -2525,11 +2525,6 @@ class CFuncType(CType): ...@@ -2525,11 +2525,6 @@ class CFuncType(CType):
return '(%s)' % s return '(%s)' % s
def specialize(self, values): def specialize(self, values):
if self.templates is None:
new_templates = None
else:
new_templates = [v.specialize(values) for v in self.templates]
result = CFuncType(self.return_type.specialize(values), result = CFuncType(self.return_type.specialize(values),
[arg.specialize(values) for arg in self.args], [arg.specialize(values) for arg in self.args],
has_varargs = self.has_varargs, has_varargs = self.has_varargs,
...@@ -2540,7 +2535,7 @@ class CFuncType(CType): ...@@ -2540,7 +2535,7 @@ class CFuncType(CType):
with_gil = self.with_gil, with_gil = self.with_gil,
is_overridable = self.is_overridable, is_overridable = self.is_overridable,
optional_arg_count = self.optional_arg_count, optional_arg_count = self.optional_arg_count,
templates = new_templates) templates = self.templates)
result.from_fused = self.is_fused result.from_fused = self.is_fused
return result return result
......
...@@ -2237,3 +2237,8 @@ class CConstScope(Scope): ...@@ -2237,3 +2237,8 @@ class CConstScope(Scope):
entry = copy.copy(entry) entry = copy.copy(entry)
entry.type = PyrexTypes.c_const_type(entry.type) entry.type = PyrexTypes.c_const_type(entry.type)
return entry return entry
class TemplateScope(Scope):
def __init__(self, name, outer_scope):
Scope.__init__(self, name, outer_scope, None)
self.directives = outer_scope.directives
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment