Commit 767fce81 authored by Robert Bradshaw's avatar Robert Bradshaw

Allow default template parameters for C++ classes.

This is done with the T=* syntax, similar to default args
of cdef methods.

This does, however, expose us to types that we can't
explicitly declare.
parent a57ff372
...@@ -28,6 +28,8 @@ Features added ...@@ -28,6 +28,8 @@ Features added
* Exception type tests have slightly lower overhead. * Exception type tests have slightly lower overhead.
This fixes ticket 868. This fixes ticket 868.
* C++ classes can now be declared with default template parameters.
Bugs fixed Bugs fixed
---------- ----------
......
...@@ -888,7 +888,8 @@ class ExprNode(Node): ...@@ -888,7 +888,8 @@ class ExprNode(Node):
# Added the string comparison, since for c types that # Added the string comparison, since for c types that
# is enough, but Cython gets confused when the types are # is enough, but Cython gets confused when the types are
# in different pxi files. # in different pxi files.
if not (str(src.type) == str(dst_type) or dst_type.assignable_from(src_type)): # TODO: Remove this hack and require shared declarations.
if not (src.type == dst_type or str(src.type) == str(dst_type) or dst_type.assignable_from(src_type)):
self.fail_assignment(dst_type) self.fail_assignment(dst_type)
return src return src
......
...@@ -1403,7 +1403,7 @@ class CppClassNode(CStructOrUnionDefNode, BlockNode): ...@@ -1403,7 +1403,7 @@ class CppClassNode(CStructOrUnionDefNode, BlockNode):
# attributes [CVarDefNode] or None # attributes [CVarDefNode] or None
# entry Entry # entry Entry
# base_classes [CBaseTypeNode] # base_classes [CBaseTypeNode]
# templates [string] or None # templates [(string, bool)] or None
# decorators [DecoratorNode] or None # decorators [DecoratorNode] or None
decorators = None decorators = None
...@@ -1412,25 +1412,31 @@ class CppClassNode(CStructOrUnionDefNode, BlockNode): ...@@ -1412,25 +1412,31 @@ class CppClassNode(CStructOrUnionDefNode, BlockNode):
if self.templates is None: if self.templates is None:
template_types = None template_types = None
else: else:
template_types = [PyrexTypes.TemplatePlaceholderType(template_name) for template_name in self.templates] template_types = [PyrexTypes.TemplatePlaceholderType(template_name, not required)
for template_name, required in self.templates]
num_optional_templates = sum(not required for _, required in self.templates)
if num_optional_templates and not all(required for _, required in self.templates[:-num_optional_templates]):
error(self.pos, "Required template parameters must precede optional template parameters.")
self.entry = env.declare_cpp_class( self.entry = env.declare_cpp_class(
self.name, None, self.pos, self.name, None, self.pos,
self.cname, base_classes = [], visibility = self.visibility, templates = template_types) self.cname, base_classes = [], visibility = self.visibility, templates = template_types)
def analyse_declarations(self, env): def analyse_declarations(self, env):
if self.templates is None:
template_types = template_names = None
else:
template_names = [template_name for template_name, _ in self.templates]
template_types = [PyrexTypes.TemplatePlaceholderType(template_name, not required)
for template_name, required in self.templates]
scope = None scope = None
if self.attributes is not None: if self.attributes is not None:
scope = CppClassScope(self.name, env, templates = self.templates) scope = CppClassScope(self.name, env, templates = template_names)
def base_ok(base_class): def base_ok(base_class):
if base_class.is_cpp_class or base_class.is_struct: if base_class.is_cpp_class or base_class.is_struct:
return True return True
else: else:
error(self.pos, "Base class '%s' not a struct or class." % base_class) error(self.pos, "Base class '%s' not a struct or class." % base_class)
base_class_types = filter(base_ok, [b.analyse(scope or env) for b in self.base_classes]) base_class_types = filter(base_ok, [b.analyse(scope or env) for b in self.base_classes])
if self.templates is None:
template_types = None
else:
template_types = [PyrexTypes.TemplatePlaceholderType(template_name) for template_name in self.templates]
self.entry = env.declare_cpp_class( self.entry = env.declare_cpp_class(
self.name, scope, self.pos, self.name, scope, self.pos,
self.cname, base_class_types, visibility = self.visibility, templates = template_types) self.cname, base_class_types, visibility = self.visibility, templates = template_types)
...@@ -1455,7 +1461,7 @@ class CppClassNode(CStructOrUnionDefNode, BlockNode): ...@@ -1455,7 +1461,7 @@ class CppClassNode(CStructOrUnionDefNode, BlockNode):
for func in func_attributes(self.attributes): for func in func_attributes(self.attributes):
defined_funcs.append(func) defined_funcs.append(func)
if self.templates is not None: if self.templates is not None:
func.template_declaration = "template <typename %s>" % ", typename ".join(self.templates) func.template_declaration = "template <typename %s>" % ", typename ".join(template_names)
self.body = StatListNode(self.pos, stats=defined_funcs) self.body = StatListNode(self.pos, stats=defined_funcs)
self.scope = scope self.scope = scope
......
...@@ -3363,6 +3363,16 @@ def p_module(s, pxd, full_module_name, ctx=Ctx): ...@@ -3363,6 +3363,16 @@ def p_module(s, pxd, full_module_name, ctx=Ctx):
full_module_name = full_module_name, full_module_name = full_module_name,
directive_comments = directive_comments) directive_comments = directive_comments)
def p_template_definition(s):
name = p_ident(s)
if s.sy == '=':
s.expect('=')
s.expect('*')
required = False
else:
required = True
return name, required
def p_cpp_class_definition(s, pos, ctx): def p_cpp_class_definition(s, pos, ctx):
# s.sy == 'cppclass' # s.sy == 'cppclass'
s.next() s.next()
...@@ -3375,19 +3385,21 @@ def p_cpp_class_definition(s, pos, ctx): ...@@ -3375,19 +3385,21 @@ def p_cpp_class_definition(s, pos, ctx):
error(pos, "Qualified class name not allowed C++ class") error(pos, "Qualified class name not allowed C++ class")
if s.sy == '[': if s.sy == '[':
s.next() s.next()
templates = [p_ident(s)] templates = [p_template_definition(s)]
while s.sy == ',': while s.sy == ',':
s.next() s.next()
templates.append(p_ident(s)) templates.append(p_template_definition(s))
s.expect(']') s.expect(']')
template_names = [name for name, required in templates]
else: else:
templates = None templates = None
template_names = None
if s.sy == '(': if s.sy == '(':
s.next() s.next()
base_classes = [p_c_base_type(s, templates = templates)] base_classes = [p_c_base_type(s, templates = template_names)]
while s.sy == ',': while s.sy == ',':
s.next() s.next()
base_classes.append(p_c_base_type(s, templates = templates)) base_classes.append(p_c_base_type(s, templates = template_names))
s.expect(')') s.expect(')')
else: else:
base_classes = [] base_classes = []
...@@ -3400,7 +3412,7 @@ def p_cpp_class_definition(s, pos, ctx): ...@@ -3400,7 +3412,7 @@ def p_cpp_class_definition(s, pos, ctx):
s.expect_indent() s.expect_indent()
attributes = [] attributes = []
body_ctx = Ctx(visibility = ctx.visibility, level='cpp_class', nogil=nogil or ctx.nogil) body_ctx = Ctx(visibility = ctx.visibility, level='cpp_class', nogil=nogil or ctx.nogil)
body_ctx.templates = templates body_ctx.templates = template_names
while s.sy != 'DEDENT': while s.sy != 'DEDENT':
if s.sy != 'pass': if s.sy != 'pass':
attributes.append(p_cpp_class_attribute(s, body_ctx)) attributes.append(p_cpp_class_attribute(s, body_ctx))
......
...@@ -3398,7 +3398,6 @@ builtin_cpp_conversions = ("std::pair", ...@@ -3398,7 +3398,6 @@ builtin_cpp_conversions = ("std::pair",
"std::set", "std::unordered_set", "std::set", "std::unordered_set",
"std::map", "std::unordered_map") "std::map", "std::unordered_map")
class CppClassType(CType): class CppClassType(CType):
# name string # name string
# cname string # cname string
...@@ -3425,6 +3424,7 @@ class CppClassType(CType): ...@@ -3425,6 +3424,7 @@ class CppClassType(CType):
self.operators = [] self.operators = []
self.templates = templates self.templates = templates
self.template_type = template_type self.template_type = template_type
self.num_optional_templates = sum(is_optional_template_param(T) for T in templates or ())
self.specializations = {} self.specializations = {}
self.is_cpp_string = cname in cpp_string_conversions self.is_cpp_string = cname in cpp_string_conversions
...@@ -3554,6 +3554,13 @@ class CppClassType(CType): ...@@ -3554,6 +3554,13 @@ class CppClassType(CType):
if not self.is_template_type(): if not self.is_template_type():
error(pos, "'%s' type is not a template" % self) error(pos, "'%s' type is not a template" % self)
return error_type return error_type
if len(self.templates) - self.num_optional_templates <= len(template_values) < len(self.templates):
partial_specialization = self.declaration_code('', template_params=template_values)
template_values = template_values + [
TemplatePlaceholderType("%s %s::%s" % (
TemplatePlaceholderType.UNDECLARABLE_DEFAULT, partial_specialization, param.name),
True)
for param in self.templates[-self.num_optional_templates:]]
if len(self.templates) != len(template_values): if len(self.templates) != len(template_values):
error(pos, "%s templated type receives %d arguments, got %d" % error(pos, "%s templated type receives %d arguments, got %d" %
(self.name, len(self.templates), len(template_values))) (self.name, len(self.templates), len(template_values)))
...@@ -3601,10 +3608,14 @@ class CppClassType(CType): ...@@ -3601,10 +3608,14 @@ class CppClassType(CType):
return None return None
def declaration_code(self, entity_code, def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0): for_display = 0, dll_linkage = None, pyrex = 0,
template_params = None):
if template_params is None:
template_params = self.templates
if self.templates: if self.templates:
template_strings = [param.declaration_code('', for_display, None, pyrex) template_strings = [param.declaration_code('', for_display, None, pyrex)
for param in self.templates] for param in template_params
if not is_optional_template_param(param)]
if for_display: if for_display:
brackets = "[%s]" brackets = "[%s]"
else: else:
...@@ -3673,11 +3684,17 @@ class CppClassType(CType): ...@@ -3673,11 +3684,17 @@ class CppClassType(CType):
class TemplatePlaceholderType(CType): class TemplatePlaceholderType(CType):
def __init__(self, name): UNDECLARABLE_DEFAULT = "undeclarable default "
def __init__(self, name, optional=False):
self.name = name self.name = name
self.optional = optional
def declaration_code(self, entity_code, def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0): for_display = 0, dll_linkage = None, pyrex = 0):
if self.name.startswith(self.UNDECLARABLE_DEFAULT) and not for_display:
error(None, "Can't declare variable of type '%s'"
% self.name[len(self.UNDECLARABLE_DEFAULT) + 1:])
if entity_code: if entity_code:
return self.name + " " + entity_code return self.name + " " + entity_code
else: else:
...@@ -3713,6 +3730,9 @@ class TemplatePlaceholderType(CType): ...@@ -3713,6 +3730,9 @@ class TemplatePlaceholderType(CType):
else: else:
return False return False
def is_optional_template_param(type):
return isinstance(type, TemplatePlaceholderType) and type.optional
class CEnumType(CType): class CEnumType(CType):
# name string # name string
......
...@@ -398,7 +398,7 @@ class SimpleAssignmentTypeInferer(object): ...@@ -398,7 +398,7 @@ class SimpleAssignmentTypeInferer(object):
else: else:
entry = node.entry entry = node.entry
node_type = spanning_type( node_type = spanning_type(
types, entry.might_overflow, entry.pos) types, entry.might_overflow, entry.pos, scope)
node.inferred_type = node_type node.inferred_type = node_type
def infer_name_node_type_partial(node): def infer_name_node_type_partial(node):
...@@ -407,7 +407,7 @@ class SimpleAssignmentTypeInferer(object): ...@@ -407,7 +407,7 @@ class SimpleAssignmentTypeInferer(object):
if not types: if not types:
return return
entry = node.entry entry = node.entry
return spanning_type(types, entry.might_overflow, entry.pos) return spanning_type(types, entry.might_overflow, entry.pos, scope)
def resolve_assignments(assignments): def resolve_assignments(assignments):
resolved = set() resolved = set()
...@@ -464,7 +464,7 @@ class SimpleAssignmentTypeInferer(object): ...@@ -464,7 +464,7 @@ class SimpleAssignmentTypeInferer(object):
types = [assmt.inferred_type for assmt in entry.cf_assignments] types = [assmt.inferred_type for assmt in entry.cf_assignments]
if types and all(types): if types and all(types):
entry_type = spanning_type( entry_type = spanning_type(
types, entry.might_overflow, entry.pos) types, entry.might_overflow, entry.pos, scope)
inferred.add(entry) inferred.add(entry)
self.set_entry_type(entry, entry_type) self.set_entry_type(entry, entry_type)
...@@ -473,7 +473,7 @@ class SimpleAssignmentTypeInferer(object): ...@@ -473,7 +473,7 @@ class SimpleAssignmentTypeInferer(object):
for entry in inferred: for entry in inferred:
types = [assmt.infer_type() types = [assmt.infer_type()
for assmt in entry.cf_assignments] for assmt in entry.cf_assignments]
new_type = spanning_type(types, entry.might_overflow, entry.pos) new_type = spanning_type(types, entry.might_overflow, entry.pos, scope)
if new_type != entry.type: if new_type != entry.type:
self.set_entry_type(entry, new_type) self.set_entry_type(entry, new_type)
dirty = True dirty = True
...@@ -516,10 +516,10 @@ def simply_type(result_type, pos): ...@@ -516,10 +516,10 @@ def simply_type(result_type, pos):
result_type = PyrexTypes.c_ptr_type(result_type.base_type) result_type = PyrexTypes.c_ptr_type(result_type.base_type)
return result_type return result_type
def aggressive_spanning_type(types, might_overflow, pos): def aggressive_spanning_type(types, might_overflow, pos, scope):
return simply_type(reduce(find_spanning_type, types), pos) return simply_type(reduce(find_spanning_type, types), pos)
def safe_spanning_type(types, might_overflow, pos): def safe_spanning_type(types, might_overflow, pos, scope):
result_type = simply_type(reduce(find_spanning_type, types), pos) result_type = simply_type(reduce(find_spanning_type, types), pos)
if result_type.is_pyobject: if result_type.is_pyobject:
# In theory, any specific Python type is always safe to # In theory, any specific Python type is always safe to
...@@ -554,6 +554,8 @@ def safe_spanning_type(types, might_overflow, pos): ...@@ -554,6 +554,8 @@ def safe_spanning_type(types, might_overflow, pos):
# to make sure everything is supported. # to make sure everything is supported.
elif (result_type.is_int or result_type.is_enum) and not might_overflow: elif (result_type.is_int or result_type.is_enum) and not might_overflow:
return result_type return result_type
elif not result_type.can_coerce_to_pyobject(scope):
return result_type
return py_object_type return py_object_type
......
...@@ -3,12 +3,15 @@ ...@@ -3,12 +3,15 @@
from cython.operator import dereference as deref from cython.operator import dereference as deref
cdef extern from "cpp_templates_helper.h": cdef extern from "cpp_templates_helper.h":
cdef cppclass Wrap[T]: cdef cppclass Wrap[T, S=*]:
Wrap(T) Wrap(T)
void set(T) void set(T)
T get() T get()
bint operator==(Wrap[T]) bint operator==(Wrap[T])
S get_alt_type()
void set_alt_type(S)
cdef cppclass Pair[T1,T2]: cdef cppclass Pair[T1,T2]:
Pair(T1,T2) Pair(T1,T2)
T1 first() T1 first()
...@@ -57,6 +60,29 @@ def test_double(double x, double y): ...@@ -57,6 +60,29 @@ def test_double(double x, double y):
finally: finally:
del a, b del a, b
def test_default_template_arguments(double x):
"""
>>> test_default_template_arguments(3.5)
(3.5, 3.0)
"""
try:
a = new Wrap[double](x)
b = new Wrap[double, int](x)
# ax = a.get_alt_type()
# a.set_alt_type(ax)
a.set_alt_type(a.get_alt_type())
# bx = b.get_alt_type()
# b.set_alt_type(bx)
b.set_alt_type(b.get_alt_type())
return a.get(), b.get()
finally:
del a
def test_pair(int i, double x): def test_pair(int i, double x):
""" """
>>> test_pair(1, 1.5) >>> test_pair(1, 1.5)
......
template <class T> template <typename T, typename S=T>
class Wrap { class Wrap {
T value; T value;
public: public:
...@@ -6,6 +6,9 @@ public: ...@@ -6,6 +6,9 @@ public:
void set(T v) { value = v; } void set(T v) { value = v; }
T get(void) { return value; } T get(void) { return value; }
bool operator==(Wrap<T> other) { return value == other.value; } bool operator==(Wrap<T> other) { return value == other.value; }
S get_alt_type(void) { return (S) value; }
void set_alt_type(S v) { value = (T) v; }
}; };
template <class T1, class T2> template <class T1, class T2>
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment