Commit fe73aba5 authored by Robert Bradshaw's avatar Robert Bradshaw

Support @staticmethod decorator for C++ classes.

parent 49d90c94
...@@ -5310,10 +5310,10 @@ class AttributeNode(ExprNode): ...@@ -5310,10 +5310,10 @@ class AttributeNode(ExprNode):
# C method of an extension type or builtin type. If successful, # C method of an extension type or builtin type. If successful,
# creates a corresponding NameNode and returns it, otherwise # creates a corresponding NameNode and returns it, otherwise
# returns None. # returns None.
type = self.obj.analyse_as_extension_type(env) type = self.obj.analyse_as_type(env)
if type: if type and (type.is_extension_type or type.is_builtin_type or type.is_cpp_class):
entry = type.scope.lookup_here(self.attribute) entry = type.scope.lookup_here(self.attribute)
if entry and entry.is_cmethod: if entry and (entry.is_cmethod or type.is_cpp_class and entry.type.is_cfunction):
if type.is_builtin_type: if type.is_builtin_type:
if not self.is_called: if not self.is_called:
# must handle this as Python object # must handle this as Python object
...@@ -5326,6 +5326,9 @@ class AttributeNode(ExprNode): ...@@ -5326,6 +5326,9 @@ class AttributeNode(ExprNode):
cname = entry.func_cname cname = entry.func_cname
if entry.type.is_static_method: if entry.type.is_static_method:
ctype = entry.type ctype = entry.type
elif type.is_cpp_class:
error(self.pos, "%s not a static member of %s" % (entry.name, type))
ctype = PyrexTypes.error_type
else: else:
# Fix self type. # Fix self type.
ctype = copy.copy(entry.type) ctype = copy.copy(entry.type)
......
...@@ -826,7 +826,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -826,7 +826,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
has_virtual_methods = False has_virtual_methods = False
has_destructor = False has_destructor = False
for attr in scope.var_entries: for attr in scope.var_entries:
if attr.type.is_cfunction and attr.name != "<init>": if attr.type.is_cfunction and attr.type.is_static_method:
code.put("static ")
elif attr.type.is_cfunction and attr.name != "<init>":
code.put("virtual ") code.put("virtual ")
has_virtual_methods = True has_virtual_methods = True
if attr.cname[0] == '~': if attr.cname[0] == '~':
......
...@@ -1294,6 +1294,8 @@ class CVarDefNode(StatNode): ...@@ -1294,6 +1294,8 @@ class CVarDefNode(StatNode):
if self.entry is not None: if self.entry is not None:
self.entry.is_overridable = self.overridable self.entry.is_overridable = self.overridable
self.entry.directive_locals = copy.copy(self.directive_locals) self.entry.directive_locals = copy.copy(self.directive_locals)
if 'staticmethod' in env.directives:
type.is_static_method = True
else: else:
if self.directive_locals: if self.directive_locals:
error(self.pos, "Decorators can only be followed by functions") error(self.pos, "Decorators can only be followed by functions")
...@@ -1361,6 +1363,9 @@ class CppClassNode(CStructOrUnionDefNode, BlockNode): ...@@ -1361,6 +1363,9 @@ class CppClassNode(CStructOrUnionDefNode, BlockNode):
# entry Entry # entry Entry
# base_classes [CBaseTypeNode] # base_classes [CBaseTypeNode]
# templates [string] or None # templates [string] or None
# decorators [DecoratorNode] or None
decorators = None
def declare(self, env): def declare(self, env):
if self.templates is None: if self.templates is None:
...@@ -1394,15 +1399,22 @@ class CppClassNode(CStructOrUnionDefNode, BlockNode): ...@@ -1394,15 +1399,22 @@ class CppClassNode(CStructOrUnionDefNode, BlockNode):
if scope is not None: if scope is not None:
scope.type = self.entry.type scope.type = self.entry.type
defined_funcs = [] defined_funcs = []
def func_attributes(attributes):
for attr in attributes:
if isinstance(attr, CFuncDefNode):
yield attr
elif isinstance(attr, CompilerDirectivesNode):
for sub_attr in func_attributes(attr.body.stats):
yield sub_attr
if self.attributes is not None: if self.attributes is not None:
if self.in_pxd and not env.in_cinclude: if self.in_pxd and not env.in_cinclude:
self.entry.defined_in_pxd = 1 self.entry.defined_in_pxd = 1
for attr in self.attributes: for attr in self.attributes:
attr.analyse_declarations(scope) attr.analyse_declarations(scope)
if isinstance(attr, CFuncDefNode): for func in func_attributes(self.attributes):
defined_funcs.append(attr) defined_funcs.append(func)
if self.templates is not None: if self.templates is not None:
attr.template_declaration = "template <typename %s>" % ", typename ".join(self.templates) func.template_declaration = "template <typename %s>" % ", typename ".join(self.templates)
self.body = StatListNode(self.pos, stats=defined_funcs) self.body = StatListNode(self.pos, stats=defined_funcs)
self.scope = scope self.scope = scope
......
...@@ -951,11 +951,11 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -951,11 +951,11 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
for name, value in directives.iteritems(): for name, value in directives.iteritems():
if name == 'locals': if name == 'locals':
node.directive_locals = value node.directive_locals = value
elif name != 'final': elif name not in ('final', 'staticmethod'):
self.context.nonfatal_error(PostParseError( self.context.nonfatal_error(PostParseError(
node.pos, node.pos,
"Cdef functions can only take cython.locals() " "Cdef functions can only take cython.locals(), "
"or final decorators, got %s." % name)) "staticmethod, or final decorators, got %s." % name))
body = Nodes.StatListNode(node.pos, stats=[node]) body = Nodes.StatListNode(node.pos, stats=[node])
return self.visit_with_directives(body, directives) return self.visit_with_directives(body, directives)
...@@ -966,6 +966,13 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): ...@@ -966,6 +966,13 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
body = Nodes.StatListNode(node.pos, stats=[node]) body = Nodes.StatListNode(node.pos, stats=[node])
return self.visit_with_directives(body, directives) return self.visit_with_directives(body, directives)
def visit_CppClassNode(self, node):
directives = self._extract_directives(node, 'cppclass')
if not directives:
return self.visit_Node(node)
body = Nodes.StatListNode(node.pos, stats=[node])
return self.visit_with_directives(body, directives)
def visit_PyClassDefNode(self, node): def visit_PyClassDefNode(self, node):
directives = self._extract_directives(node, 'class') directives = self._extract_directives(node, 'class')
if not directives: if not directives:
......
...@@ -187,3 +187,4 @@ cdef p_doc_string(PyrexScanner s) ...@@ -187,3 +187,4 @@ cdef p_doc_string(PyrexScanner s)
cdef p_ignorable_statement(PyrexScanner s) cdef p_ignorable_statement(PyrexScanner s)
cdef p_compiler_directive_comments(PyrexScanner s) cdef p_compiler_directive_comments(PyrexScanner s)
cdef p_cpp_class_definition(PyrexScanner s, pos, ctx) cdef p_cpp_class_definition(PyrexScanner s, pos, ctx)
def p_cpp_class_attribute(PyrexScanner s, ctx):
...@@ -3231,12 +3231,8 @@ def p_cpp_class_definition(s, pos, ctx): ...@@ -3231,12 +3231,8 @@ def p_cpp_class_definition(s, pos, ctx):
body_ctx = Ctx(visibility = ctx.visibility, level='cpp_class', nogil=nogil or ctx.nogil) body_ctx = Ctx(visibility = ctx.visibility, level='cpp_class', nogil=nogil or ctx.nogil)
body_ctx.templates = templates body_ctx.templates = templates
while s.sy != 'DEDENT': while s.sy != 'DEDENT':
if s.systring == 'cppclass': if s.sy != 'pass':
attributes.append( attributes.append(p_cpp_class_attribute(s, body_ctx))
p_cpp_class_definition(s, s.position(), body_ctx))
elif s.sy != 'pass':
attributes.append(
p_c_func_or_var_declaration(s, s.position(), body_ctx))
else: else:
s.next() s.next()
s.expect_newline("Expected a newline") s.expect_newline("Expected a newline")
...@@ -3253,6 +3249,23 @@ def p_cpp_class_definition(s, pos, ctx): ...@@ -3253,6 +3249,23 @@ def p_cpp_class_definition(s, pos, ctx):
attributes = attributes, attributes = attributes,
templates = templates) templates = templates)
def p_cpp_class_attribute(s, ctx):
decorators = None
if s.sy == '@':
decorators = p_decorators(s)
if s.systring == 'cppclass':
return p_cpp_class_definition(s, s.position(), ctx)
else:
node = p_c_func_or_var_declaration(s, s.position(), ctx)
if decorators is not None:
tup = Nodes.CFuncDefNode, Nodes.CVarDefNode, Nodes.CClassDefNode
if ctx.allow_struct_enum_decorator:
tup += Nodes.CStructOrUnionDefNode, Nodes.CEnumDefNode
if not isinstance(node, tup):
s.error("Decorators can only be followed by functions or classes")
node.decorators = decorators
return node
#---------------------------------------------- #----------------------------------------------
# #
......
...@@ -3110,7 +3110,7 @@ class CppClassType(CType): ...@@ -3110,7 +3110,7 @@ class CppClassType(CType):
# Need to do these *after* self.specializations[key] is set # Need to do these *after* self.specializations[key] is set
# to avoid infinite recursion on circular references. # to avoid infinite recursion on circular references.
specialized.base_classes = [b.specialize(values) for b in self.base_classes] specialized.base_classes = [b.specialize(values) for b in self.base_classes]
specialized.scope = self.scope.specialize(values) specialized.scope = self.scope.specialize(values, specialized)
if self.namespace is not None: if self.namespace is not None:
specialized.namespace = self.namespace.specialize(values) specialized.namespace = self.namespace.specialize(values)
return specialized return specialized
......
...@@ -2190,8 +2190,9 @@ class CppClassScope(Scope): ...@@ -2190,8 +2190,9 @@ class CppClassScope(Scope):
utility_code = base_entry.utility_code) utility_code = base_entry.utility_code)
entry.is_inherited = 1 entry.is_inherited = 1
def specialize(self, values): def specialize(self, values, type_entry):
scope = CppClassScope(self.name, self.outer_scope) scope = CppClassScope(self.name, self.outer_scope)
scope.type = type_entry
for entry in self.entries.values(): for entry in self.entries.values():
if entry.is_type: if entry.is_type:
scope.declare_type(entry.name, scope.declare_type(entry.name,
......
...@@ -42,6 +42,21 @@ def test_Poly(int n, float radius=1): ...@@ -42,6 +42,21 @@ def test_Poly(int n, float radius=1):
del poly del poly
cdef cppclass WithStatic:
@staticmethod
double square(double x):
return x * x
def test_Static(x):
"""
>>> test_Static(2)
4.0
>>> test_Static(0.5)
0.25
"""
return WithStatic.square(x)
cdef cppclass InitDealloc: cdef cppclass InitDealloc:
__init__(): __init__():
print "Init" print "Init"
......
...@@ -22,6 +22,10 @@ cdef extern from "cpp_templates_helper.h": ...@@ -22,6 +22,10 @@ cdef extern from "cpp_templates_helper.h":
cdef cppclass SubClass[T2, T3](SuperClass[T2, T3]): cdef cppclass SubClass[T2, T3](SuperClass[T2, T3]):
pass pass
cdef cppclass Div[T]:
@staticmethod
T half(T value)
def test_int(int x, int y): def test_int(int x, int y):
""" """
>>> test_int(3, 4) >>> test_int(3, 4)
...@@ -104,3 +108,12 @@ def test_cast_template_pointer(): ...@@ -104,3 +108,12 @@ def test_cast_template_pointer():
sup = sub sup = sub
sup = <SubClass[int, float] *> sub sup = <SubClass[int, float] *> sub
def test_static(x):
"""
>>> test_static(2)
(1, 1.0)
>>> test_static(3)
(1, 1.5)
"""
return Div[int].half(x), Div[double].half(x)
...@@ -30,3 +30,9 @@ public: ...@@ -30,3 +30,9 @@ public:
template <class T2, class T3> template <class T2, class T3>
class SubClass : public SuperClass<T2, T3> { class SubClass : public SuperClass<T2, T3> {
}; };
template <class T>
class Div {
public:
static T half(T value) { return value / 2; }
};
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment