Commit fe73aba5 authored by Robert Bradshaw's avatar Robert Bradshaw

Support @staticmethod decorator for C++ classes.

parent 49d90c94
......@@ -5310,10 +5310,10 @@ class AttributeNode(ExprNode):
# C method of an extension type or builtin type. If successful,
# creates a corresponding NameNode and returns it, otherwise
# returns None.
type = self.obj.analyse_as_extension_type(env)
if type:
type = self.obj.analyse_as_type(env)
if type and (type.is_extension_type or type.is_builtin_type or type.is_cpp_class):
entry = type.scope.lookup_here(self.attribute)
if entry and entry.is_cmethod:
if entry and (entry.is_cmethod or type.is_cpp_class and entry.type.is_cfunction):
if type.is_builtin_type:
if not self.is_called:
# must handle this as Python object
......@@ -5326,6 +5326,9 @@ class AttributeNode(ExprNode):
cname = entry.func_cname
if entry.type.is_static_method:
ctype = entry.type
elif type.is_cpp_class:
error(self.pos, "%s not a static member of %s" % (entry.name, type))
ctype = PyrexTypes.error_type
else:
# Fix self type.
ctype = copy.copy(entry.type)
......
......@@ -826,7 +826,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
has_virtual_methods = False
has_destructor = False
for attr in scope.var_entries:
if attr.type.is_cfunction and attr.name != "<init>":
if attr.type.is_cfunction and attr.type.is_static_method:
code.put("static ")
elif attr.type.is_cfunction and attr.name != "<init>":
code.put("virtual ")
has_virtual_methods = True
if attr.cname[0] == '~':
......
......@@ -1294,6 +1294,8 @@ class CVarDefNode(StatNode):
if self.entry is not None:
self.entry.is_overridable = self.overridable
self.entry.directive_locals = copy.copy(self.directive_locals)
if 'staticmethod' in env.directives:
type.is_static_method = True
else:
if self.directive_locals:
error(self.pos, "Decorators can only be followed by functions")
......@@ -1361,6 +1363,9 @@ class CppClassNode(CStructOrUnionDefNode, BlockNode):
# entry Entry
# base_classes [CBaseTypeNode]
# templates [string] or None
# decorators [DecoratorNode] or None
decorators = None
def declare(self, env):
if self.templates is None:
......@@ -1394,15 +1399,22 @@ class CppClassNode(CStructOrUnionDefNode, BlockNode):
if scope is not None:
scope.type = self.entry.type
defined_funcs = []
def func_attributes(attributes):
for attr in attributes:
if isinstance(attr, CFuncDefNode):
yield attr
elif isinstance(attr, CompilerDirectivesNode):
for sub_attr in func_attributes(attr.body.stats):
yield sub_attr
if self.attributes is not None:
if self.in_pxd and not env.in_cinclude:
self.entry.defined_in_pxd = 1
for attr in self.attributes:
attr.analyse_declarations(scope)
if isinstance(attr, CFuncDefNode):
defined_funcs.append(attr)
if self.templates is not None:
attr.template_declaration = "template <typename %s>" % ", typename ".join(self.templates)
for func in func_attributes(self.attributes):
defined_funcs.append(func)
if self.templates is not None:
func.template_declaration = "template <typename %s>" % ", typename ".join(self.templates)
self.body = StatListNode(self.pos, stats=defined_funcs)
self.scope = scope
......
......@@ -951,11 +951,11 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
for name, value in directives.iteritems():
if name == 'locals':
node.directive_locals = value
elif name != 'final':
elif name not in ('final', 'staticmethod'):
self.context.nonfatal_error(PostParseError(
node.pos,
"Cdef functions can only take cython.locals() "
"or final decorators, got %s." % name))
"Cdef functions can only take cython.locals(), "
"staticmethod, or final decorators, got %s." % name))
body = Nodes.StatListNode(node.pos, stats=[node])
return self.visit_with_directives(body, directives)
......@@ -966,6 +966,13 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
body = Nodes.StatListNode(node.pos, stats=[node])
return self.visit_with_directives(body, directives)
def visit_CppClassNode(self, node):
directives = self._extract_directives(node, 'cppclass')
if not directives:
return self.visit_Node(node)
body = Nodes.StatListNode(node.pos, stats=[node])
return self.visit_with_directives(body, directives)
def visit_PyClassDefNode(self, node):
directives = self._extract_directives(node, 'class')
if not directives:
......
......@@ -187,3 +187,4 @@ cdef p_doc_string(PyrexScanner s)
cdef p_ignorable_statement(PyrexScanner s)
cdef p_compiler_directive_comments(PyrexScanner s)
cdef p_cpp_class_definition(PyrexScanner s, pos, ctx)
def p_cpp_class_attribute(PyrexScanner s, ctx):
......@@ -3231,12 +3231,8 @@ def p_cpp_class_definition(s, pos, ctx):
body_ctx = Ctx(visibility = ctx.visibility, level='cpp_class', nogil=nogil or ctx.nogil)
body_ctx.templates = templates
while s.sy != 'DEDENT':
if s.systring == 'cppclass':
attributes.append(
p_cpp_class_definition(s, s.position(), body_ctx))
elif s.sy != 'pass':
attributes.append(
p_c_func_or_var_declaration(s, s.position(), body_ctx))
if s.sy != 'pass':
attributes.append(p_cpp_class_attribute(s, body_ctx))
else:
s.next()
s.expect_newline("Expected a newline")
......@@ -3253,6 +3249,23 @@ def p_cpp_class_definition(s, pos, ctx):
attributes = attributes,
templates = templates)
def p_cpp_class_attribute(s, ctx):
decorators = None
if s.sy == '@':
decorators = p_decorators(s)
if s.systring == 'cppclass':
return p_cpp_class_definition(s, s.position(), ctx)
else:
node = p_c_func_or_var_declaration(s, s.position(), ctx)
if decorators is not None:
tup = Nodes.CFuncDefNode, Nodes.CVarDefNode, Nodes.CClassDefNode
if ctx.allow_struct_enum_decorator:
tup += Nodes.CStructOrUnionDefNode, Nodes.CEnumDefNode
if not isinstance(node, tup):
s.error("Decorators can only be followed by functions or classes")
node.decorators = decorators
return node
#----------------------------------------------
#
......
......@@ -3110,7 +3110,7 @@ class CppClassType(CType):
# Need to do these *after* self.specializations[key] is set
# to avoid infinite recursion on circular references.
specialized.base_classes = [b.specialize(values) for b in self.base_classes]
specialized.scope = self.scope.specialize(values)
specialized.scope = self.scope.specialize(values, specialized)
if self.namespace is not None:
specialized.namespace = self.namespace.specialize(values)
return specialized
......
......@@ -2190,8 +2190,9 @@ class CppClassScope(Scope):
utility_code = base_entry.utility_code)
entry.is_inherited = 1
def specialize(self, values):
def specialize(self, values, type_entry):
scope = CppClassScope(self.name, self.outer_scope)
scope.type = type_entry
for entry in self.entries.values():
if entry.is_type:
scope.declare_type(entry.name,
......
......@@ -42,6 +42,21 @@ def test_Poly(int n, float radius=1):
del poly
cdef cppclass WithStatic:
@staticmethod
double square(double x):
return x * x
def test_Static(x):
"""
>>> test_Static(2)
4.0
>>> test_Static(0.5)
0.25
"""
return WithStatic.square(x)
cdef cppclass InitDealloc:
__init__():
print "Init"
......
......@@ -22,6 +22,10 @@ cdef extern from "cpp_templates_helper.h":
cdef cppclass SubClass[T2, T3](SuperClass[T2, T3]):
pass
cdef cppclass Div[T]:
@staticmethod
T half(T value)
def test_int(int x, int y):
"""
>>> test_int(3, 4)
......@@ -104,3 +108,12 @@ def test_cast_template_pointer():
sup = sub
sup = <SubClass[int, float] *> sub
def test_static(x):
"""
>>> test_static(2)
(1, 1.0)
>>> test_static(3)
(1, 1.5)
"""
return Div[int].half(x), Div[double].half(x)
......@@ -30,3 +30,9 @@ public:
template <class T2, class T3>
class SubClass : public SuperClass<T2, T3> {
};
template <class T>
class Div {
public:
static T half(T value) { return value / 2; }
};
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment