Commit 767fce81 authored by Robert Bradshaw's avatar Robert Bradshaw

Allow default template parameters for C++ classes.

This is done with the T=* syntax, similar to default args
of cdef methods.

This does, however, expose us to types that we can't
explicitly declare.
parent a57ff372
......@@ -28,6 +28,8 @@ Features added
* Exception type tests have slightly lower overhead.
This fixes ticket 868.
* C++ classes can now be declared with default template parameters.
Bugs fixed
----------
......
......@@ -888,7 +888,8 @@ class ExprNode(Node):
# Added the string comparison, since for c types that
# is enough, but Cython gets confused when the types are
# in different pxi files.
if not (str(src.type) == str(dst_type) or dst_type.assignable_from(src_type)):
# TODO: Remove this hack and require shared declarations.
if not (src.type == dst_type or str(src.type) == str(dst_type) or dst_type.assignable_from(src_type)):
self.fail_assignment(dst_type)
return src
......
......@@ -1403,7 +1403,7 @@ class CppClassNode(CStructOrUnionDefNode, BlockNode):
# attributes [CVarDefNode] or None
# entry Entry
# base_classes [CBaseTypeNode]
# templates [string] or None
# templates [(string, bool)] or None
# decorators [DecoratorNode] or None
decorators = None
......@@ -1412,25 +1412,31 @@ class CppClassNode(CStructOrUnionDefNode, BlockNode):
if self.templates is None:
template_types = None
else:
template_types = [PyrexTypes.TemplatePlaceholderType(template_name) for template_name in self.templates]
template_types = [PyrexTypes.TemplatePlaceholderType(template_name, not required)
for template_name, required in self.templates]
num_optional_templates = sum(not required for _, required in self.templates)
if num_optional_templates and not all(required for _, required in self.templates[:-num_optional_templates]):
error(self.pos, "Required template parameters must precede optional template parameters.")
self.entry = env.declare_cpp_class(
self.name, None, self.pos,
self.cname, base_classes = [], visibility = self.visibility, templates = template_types)
def analyse_declarations(self, env):
if self.templates is None:
template_types = template_names = None
else:
template_names = [template_name for template_name, _ in self.templates]
template_types = [PyrexTypes.TemplatePlaceholderType(template_name, not required)
for template_name, required in self.templates]
scope = None
if self.attributes is not None:
scope = CppClassScope(self.name, env, templates = self.templates)
scope = CppClassScope(self.name, env, templates = template_names)
def base_ok(base_class):
if base_class.is_cpp_class or base_class.is_struct:
return True
else:
error(self.pos, "Base class '%s' not a struct or class." % base_class)
base_class_types = filter(base_ok, [b.analyse(scope or env) for b in self.base_classes])
if self.templates is None:
template_types = None
else:
template_types = [PyrexTypes.TemplatePlaceholderType(template_name) for template_name in self.templates]
self.entry = env.declare_cpp_class(
self.name, scope, self.pos,
self.cname, base_class_types, visibility = self.visibility, templates = template_types)
......@@ -1455,7 +1461,7 @@ class CppClassNode(CStructOrUnionDefNode, BlockNode):
for func in func_attributes(self.attributes):
defined_funcs.append(func)
if self.templates is not None:
func.template_declaration = "template <typename %s>" % ", typename ".join(self.templates)
func.template_declaration = "template <typename %s>" % ", typename ".join(template_names)
self.body = StatListNode(self.pos, stats=defined_funcs)
self.scope = scope
......
......@@ -3363,6 +3363,16 @@ def p_module(s, pxd, full_module_name, ctx=Ctx):
full_module_name = full_module_name,
directive_comments = directive_comments)
def p_template_definition(s):
name = p_ident(s)
if s.sy == '=':
s.expect('=')
s.expect('*')
required = False
else:
required = True
return name, required
def p_cpp_class_definition(s, pos, ctx):
# s.sy == 'cppclass'
s.next()
......@@ -3375,19 +3385,21 @@ def p_cpp_class_definition(s, pos, ctx):
error(pos, "Qualified class name not allowed C++ class")
if s.sy == '[':
s.next()
templates = [p_ident(s)]
templates = [p_template_definition(s)]
while s.sy == ',':
s.next()
templates.append(p_ident(s))
templates.append(p_template_definition(s))
s.expect(']')
template_names = [name for name, required in templates]
else:
templates = None
template_names = None
if s.sy == '(':
s.next()
base_classes = [p_c_base_type(s, templates = templates)]
base_classes = [p_c_base_type(s, templates = template_names)]
while s.sy == ',':
s.next()
base_classes.append(p_c_base_type(s, templates = templates))
base_classes.append(p_c_base_type(s, templates = template_names))
s.expect(')')
else:
base_classes = []
......@@ -3400,7 +3412,7 @@ def p_cpp_class_definition(s, pos, ctx):
s.expect_indent()
attributes = []
body_ctx = Ctx(visibility = ctx.visibility, level='cpp_class', nogil=nogil or ctx.nogil)
body_ctx.templates = templates
body_ctx.templates = template_names
while s.sy != 'DEDENT':
if s.sy != 'pass':
attributes.append(p_cpp_class_attribute(s, body_ctx))
......
......@@ -3398,7 +3398,6 @@ builtin_cpp_conversions = ("std::pair",
"std::set", "std::unordered_set",
"std::map", "std::unordered_map")
class CppClassType(CType):
# name string
# cname string
......@@ -3425,6 +3424,7 @@ class CppClassType(CType):
self.operators = []
self.templates = templates
self.template_type = template_type
self.num_optional_templates = sum(is_optional_template_param(T) for T in templates or ())
self.specializations = {}
self.is_cpp_string = cname in cpp_string_conversions
......@@ -3554,6 +3554,13 @@ class CppClassType(CType):
if not self.is_template_type():
error(pos, "'%s' type is not a template" % self)
return error_type
if len(self.templates) - self.num_optional_templates <= len(template_values) < len(self.templates):
partial_specialization = self.declaration_code('', template_params=template_values)
template_values = template_values + [
TemplatePlaceholderType("%s %s::%s" % (
TemplatePlaceholderType.UNDECLARABLE_DEFAULT, partial_specialization, param.name),
True)
for param in self.templates[-self.num_optional_templates:]]
if len(self.templates) != len(template_values):
error(pos, "%s templated type receives %d arguments, got %d" %
(self.name, len(self.templates), len(template_values)))
......@@ -3601,10 +3608,14 @@ class CppClassType(CType):
return None
def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0):
for_display = 0, dll_linkage = None, pyrex = 0,
template_params = None):
if template_params is None:
template_params = self.templates
if self.templates:
template_strings = [param.declaration_code('', for_display, None, pyrex)
for param in self.templates]
for param in template_params
if not is_optional_template_param(param)]
if for_display:
brackets = "[%s]"
else:
......@@ -3673,11 +3684,17 @@ class CppClassType(CType):
class TemplatePlaceholderType(CType):
def __init__(self, name):
UNDECLARABLE_DEFAULT = "undeclarable default "
def __init__(self, name, optional=False):
self.name = name
self.optional = optional
def declaration_code(self, entity_code,
for_display = 0, dll_linkage = None, pyrex = 0):
if self.name.startswith(self.UNDECLARABLE_DEFAULT) and not for_display:
error(None, "Can't declare variable of type '%s'"
% self.name[len(self.UNDECLARABLE_DEFAULT) + 1:])
if entity_code:
return self.name + " " + entity_code
else:
......@@ -3713,6 +3730,9 @@ class TemplatePlaceholderType(CType):
else:
return False
def is_optional_template_param(type):
return isinstance(type, TemplatePlaceholderType) and type.optional
class CEnumType(CType):
# name string
......
......@@ -398,7 +398,7 @@ class SimpleAssignmentTypeInferer(object):
else:
entry = node.entry
node_type = spanning_type(
types, entry.might_overflow, entry.pos)
types, entry.might_overflow, entry.pos, scope)
node.inferred_type = node_type
def infer_name_node_type_partial(node):
......@@ -407,7 +407,7 @@ class SimpleAssignmentTypeInferer(object):
if not types:
return
entry = node.entry
return spanning_type(types, entry.might_overflow, entry.pos)
return spanning_type(types, entry.might_overflow, entry.pos, scope)
def resolve_assignments(assignments):
resolved = set()
......@@ -464,7 +464,7 @@ class SimpleAssignmentTypeInferer(object):
types = [assmt.inferred_type for assmt in entry.cf_assignments]
if types and all(types):
entry_type = spanning_type(
types, entry.might_overflow, entry.pos)
types, entry.might_overflow, entry.pos, scope)
inferred.add(entry)
self.set_entry_type(entry, entry_type)
......@@ -473,7 +473,7 @@ class SimpleAssignmentTypeInferer(object):
for entry in inferred:
types = [assmt.infer_type()
for assmt in entry.cf_assignments]
new_type = spanning_type(types, entry.might_overflow, entry.pos)
new_type = spanning_type(types, entry.might_overflow, entry.pos, scope)
if new_type != entry.type:
self.set_entry_type(entry, new_type)
dirty = True
......@@ -516,10 +516,10 @@ def simply_type(result_type, pos):
result_type = PyrexTypes.c_ptr_type(result_type.base_type)
return result_type
def aggressive_spanning_type(types, might_overflow, pos):
def aggressive_spanning_type(types, might_overflow, pos, scope):
return simply_type(reduce(find_spanning_type, types), pos)
def safe_spanning_type(types, might_overflow, pos):
def safe_spanning_type(types, might_overflow, pos, scope):
result_type = simply_type(reduce(find_spanning_type, types), pos)
if result_type.is_pyobject:
# In theory, any specific Python type is always safe to
......@@ -554,6 +554,8 @@ def safe_spanning_type(types, might_overflow, pos):
# to make sure everything is supported.
elif (result_type.is_int or result_type.is_enum) and not might_overflow:
return result_type
elif not result_type.can_coerce_to_pyobject(scope):
return result_type
return py_object_type
......
......@@ -3,12 +3,15 @@
from cython.operator import dereference as deref
cdef extern from "cpp_templates_helper.h":
cdef cppclass Wrap[T]:
cdef cppclass Wrap[T, S=*]:
Wrap(T)
void set(T)
T get()
bint operator==(Wrap[T])
S get_alt_type()
void set_alt_type(S)
cdef cppclass Pair[T1,T2]:
Pair(T1,T2)
T1 first()
......@@ -57,6 +60,29 @@ def test_double(double x, double y):
finally:
del a, b
def test_default_template_arguments(double x):
"""
>>> test_default_template_arguments(3.5)
(3.5, 3.0)
"""
try:
a = new Wrap[double](x)
b = new Wrap[double, int](x)
# ax = a.get_alt_type()
# a.set_alt_type(ax)
a.set_alt_type(a.get_alt_type())
# bx = b.get_alt_type()
# b.set_alt_type(bx)
b.set_alt_type(b.get_alt_type())
return a.get(), b.get()
finally:
del a
def test_pair(int i, double x):
"""
>>> test_pair(1, 1.5)
......
template <class T>
template <typename T, typename S=T>
class Wrap {
T value;
public:
......@@ -6,6 +6,9 @@ public:
void set(T v) { value = v; }
T get(void) { return value; }
bool operator==(Wrap<T> other) { return value == other.value; }
S get_alt_type(void) { return (S) value; }
void set_alt_type(S v) { value = (T) v; }
};
template <class T1, class T2>
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment