Commit 15833ae4 authored by Robert Bradshaw's avatar Robert Bradshaw

(Python) override modifier for cdef methods

parent b9539756
...@@ -1033,6 +1033,9 @@ class TempNode(AtomicExprNode): ...@@ -1033,6 +1033,9 @@ class TempNode(AtomicExprNode):
if type.is_pyobject: if type.is_pyobject:
self.result_ctype = py_object_type self.result_ctype = py_object_type
self.is_temp = 1 self.is_temp = 1
def analyse_types(self, env):
return self.type
def generate_result_code(self, code): def generate_result_code(self, code):
pass pass
...@@ -1686,11 +1689,10 @@ class AttributeNode(ExprNode): ...@@ -1686,11 +1689,10 @@ class AttributeNode(ExprNode):
error(self.pos, "Illegal use of special attribute __weakref__") error(self.pos, "Illegal use of special attribute __weakref__")
# methods need the normal attribute lookup # methods need the normal attribute lookup
# because they do not have struct entries # because they do not have struct entries
if not entry.is_method: if entry.is_variable or entry.is_cmethod:
if entry.is_variable or entry.is_cmethod: self.type = entry.type
self.type = entry.type self.member = entry.cname
self.member = entry.cname return
return
else: else:
# If it's not a variable or C method, it must be a Python # If it's not a variable or C method, it must be a Python
# method of an extension type, so we treat it like a Python # method of an extension type, so we treat it like a Python
......
...@@ -534,6 +534,8 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -534,6 +534,8 @@ class FuncDefNode(StatNode, BlockNode):
# #filename string C name of filename string const # #filename string C name of filename string const
# entry Symtab.Entry # entry Symtab.Entry
py_func = None
def analyse_expressions(self, env): def analyse_expressions(self, env):
pass pass
...@@ -559,6 +561,10 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -559,6 +561,10 @@ class FuncDefNode(StatNode, BlockNode):
self.generate_const_definitions(lenv, code) self.generate_const_definitions(lenv, code)
# ----- Function header # ----- Function header
code.putln("") code.putln("")
if self.py_func:
self.py_func.generate_function_header(code,
with_pymethdef = env.is_py_class_scope,
proto_only=True)
self.generate_function_header(code, self.generate_function_header(code,
with_pymethdef = env.is_py_class_scope) with_pymethdef = env.is_py_class_scope)
# ----- Local variable declarations # ----- Local variable declarations
...@@ -639,7 +645,11 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -639,7 +645,11 @@ class FuncDefNode(StatNode, BlockNode):
# retval_code) # retval_code)
code.putln("return %s;" % retval_code) code.putln("return %s;" % retval_code)
code.putln("}") code.putln("}")
# ----- Python version
if self.py_func:
self.py_func.generate_function_definitions(env, code)
def put_stararg_decrefs(self, code): def put_stararg_decrefs(self, code):
pass pass
...@@ -671,7 +681,7 @@ class FuncDefNode(StatNode, BlockNode): ...@@ -671,7 +681,7 @@ class FuncDefNode(StatNode, BlockNode):
class CFuncDefNode(FuncDefNode): class CFuncDefNode(FuncDefNode):
# C function definition. # C function definition.
# #
# modifiers 'inline ' or '' # modifiers 'inline ' or 'visible' or 'overrideable'
# visibility 'private' or 'public' or 'extern' # visibility 'private' or 'public' or 'extern'
# base_type CBaseTypeNode # base_type CBaseTypeNode
# declarator CDeclaratorNode # declarator CDeclaratorNode
...@@ -700,7 +710,34 @@ class CFuncDefNode(FuncDefNode): ...@@ -700,7 +710,34 @@ class CFuncDefNode(FuncDefNode):
cname = cname, visibility = self.visibility, cname = cname, visibility = self.visibility,
defining = self.body is not None) defining = self.body is not None)
self.return_type = type.return_type self.return_type = type.return_type
if 'overrideable' in self.modifiers or 'visible' in self.modifiers:
if 'visible' in self.modifiers:
self.modifiers.remove('visible')
import ExprNodes
arg_names = [arg.name for arg in self.type.args]
self_arg = ExprNodes.NameNode(self.pos, name=arg_names[0])
cfunc = ExprNodes.AttributeNode(self.pos, obj=self_arg, attribute=self.declarator.base.name)
c_call = ExprNodes.SimpleCallNode(self.pos, function=cfunc, args=[ExprNodes.NameNode(self.pos, name=n) for n in arg_names[1:]])
py_func_body = ReturnStatNode(pos=self.pos, return_type=PyrexTypes.py_object_type, value=c_call)
self.py_func = DefNode(pos = self.pos,
name = self.declarator.base.name,
args = self.declarator.args,
star_arg = None,
starstar_arg = None,
doc = None, # self.doc,
body = py_func_body)
self.py_func.analyse_declarations(env)
# Reset scope entry the above cfunction
env.entries[name] = self.entry
if Options.intern_names:
self.py_func.interned_attr_cname = env.intern(self.py_func.entry.name)
if 'overrideable' in self.modifiers:
self.modifiers.remove('overrideable')
self.override = OverrideCheckNode(self.pos, py_func = self.py_func)
self.body.stats.insert(0, self.override)
def declare_arguments(self, env): def declare_arguments(self, env):
for arg in self.type.args: for arg in self.type.args:
if not arg.name: if not arg.name:
...@@ -732,7 +769,7 @@ class CFuncDefNode(FuncDefNode): ...@@ -732,7 +769,7 @@ class CFuncDefNode(FuncDefNode):
storage_class = "" storage_class = ""
code.putln("%s%s%s {" % ( code.putln("%s%s%s {" % (
storage_class, storage_class,
self.modifiers, ' '.join(self.modifiers).upper(), # macro forms
header)) header))
def generate_argument_declarations(self, env, code): def generate_argument_declarations(self, env, code):
...@@ -792,7 +829,7 @@ class CFuncDefNode(FuncDefNode): ...@@ -792,7 +829,7 @@ class CFuncDefNode(FuncDefNode):
def caller_will_check_exceptions(self): def caller_will_check_exceptions(self):
return self.entry.type.exception_check return self.entry.type.exception_check
class PyArgDeclNode(Node): class PyArgDeclNode(Node):
# Argument which must be a Python object (used # Argument which must be a Python object (used
...@@ -923,7 +960,7 @@ class DefNode(FuncDefNode): ...@@ -923,7 +960,7 @@ class DefNode(FuncDefNode):
else: else:
self.entry.doc = self.doc self.entry.doc = self.doc
self.entry.func_cname = \ self.entry.func_cname = \
Naming.func_prefix + env.scope_prefix + self.name Naming.func_prefix + "py_" + env.scope_prefix + self.name
self.entry.doc_cname = \ self.entry.doc_cname = \
Naming.funcdoc_prefix + env.scope_prefix + self.name Naming.funcdoc_prefix + env.scope_prefix + self.name
self.entry.pymethdef_cname = \ self.entry.pymethdef_cname = \
...@@ -989,7 +1026,7 @@ class DefNode(FuncDefNode): ...@@ -989,7 +1026,7 @@ class DefNode(FuncDefNode):
self.assmt.analyse_declarations(env) self.assmt.analyse_declarations(env)
self.assmt.analyse_expressions(env) self.assmt.analyse_expressions(env)
def generate_function_header(self, code, with_pymethdef): def generate_function_header(self, code, with_pymethdef, proto_only=0):
arg_code_list = [] arg_code_list = []
sig = self.entry.signature sig = self.entry.signature
if sig.has_dummy_arg: if sig.has_dummy_arg:
...@@ -1012,6 +1049,8 @@ class DefNode(FuncDefNode): ...@@ -1012,6 +1049,8 @@ class DefNode(FuncDefNode):
dc = self.return_type.declaration_code(self.entry.func_cname) dc = self.return_type.declaration_code(self.entry.func_cname)
header = "static %s(%s)" % (dc, arg_code) header = "static %s(%s)" % (dc, arg_code)
code.putln("%s; /*proto*/" % header) code.putln("%s; /*proto*/" % header)
if proto_only:
return
if self.entry.doc: if self.entry.doc:
code.putln( code.putln(
'static char %s[] = "%s";' % ( 'static char %s[] = "%s";' % (
...@@ -1272,6 +1311,50 @@ class DefNode(FuncDefNode): ...@@ -1272,6 +1311,50 @@ class DefNode(FuncDefNode):
def caller_will_check_exceptions(self): def caller_will_check_exceptions(self):
return 1 return 1
class OverrideCheckNode(StatNode):
# A Node for dispatching to the def method if it
# is overriden.
#
# py_func
#
# args
# func_temp
# body
def analyse_expressions(self, env):
self.args = env.arg_entries
import ExprNodes
self.func_node = ExprNodes.PyTempNode(self.pos, env)
call_tuple = ExprNodes.TupleNode(self.pos, args=[ExprNodes.NameNode(self.pos, name=arg.name) for arg in self.args[1:]])
call_node = ExprNodes.SimpleCallNode(self.pos,
function=self.func_node,
args=[ExprNodes.NameNode(self.pos, name=arg.name) for arg in self.args[1:]])
self.body = ReturnStatNode(self.pos, value=call_node)
# self.func_temp = env.allocate_temp_pyobject()
self.body.analyse_expressions(env)
# env.release_temp(self.func_temp)
def generate_execution_code(self, code):
# Check to see if we are an extension type
self_arg = "((PyObject *)%s)" % self.args[0].cname
code.putln("/* Check if overriden in Python */")
code.putln("if (unlikely(%s->ob_type->tp_dictoffset != 0)) {" % self_arg)
err = code.error_goto_if_null(self_arg, self.pos)
# need to get attribute manually--scope would return cdef method
if Options.intern_names:
code.putln("%s = PyObject_GetAttr(%s, %s); %s" % (self.func_node.result_code, self_arg, self.py_func.interned_attr_cname, err))
else:
code.putln('%s = PyObject_GetAttrString(%s, "%s"); %s' % (self.func_node.result_code, self_arg, self.py_func.entry.name, err))
# It appears that this type is not anywhere exposed in the Python/C API
is_builtin_function_or_method = '(strcmp(%s->ob_type->tp_name, "builtin_function_or_method") == 0)' % self.func_node.result_code
is_overridden = '(PyCFunction_GET_FUNCTION(%s) != &%s)' % (self.func_node.result_code, self.py_func.entry.func_cname)
code.putln('if (!%s || %s) {' % (is_builtin_function_or_method, is_overridden))
self.body.generate_execution_code(code)
code.putln('}')
# code.put_decref(self.func_temp, PyrexTypes.py_object_type)
code.putln("}")
class PyClassDefNode(StatNode, BlockNode): class PyClassDefNode(StatNode, BlockNode):
# A Python class definition. # A Python class definition.
......
...@@ -1691,11 +1691,11 @@ def p_visibility(s, prev_visibility): ...@@ -1691,11 +1691,11 @@ def p_visibility(s, prev_visibility):
return visibility return visibility
def p_c_modifiers(s): def p_c_modifiers(s):
if s.systring in ('inline', ): if s.sy == 'IDENT' and s.systring in ('inline', 'visible', 'overrideable'):
modifier = s.systring.upper() # uppercase is macro defined for various compilers modifier = s.systring
s.next() s.next()
return modifier + ' ' + p_c_modifiers(s) return [modifier] + p_c_modifiers(s)
return "" return []
def p_c_func_or_var_declaration(s, level, pos, visibility = 'private'): def p_c_func_or_var_declaration(s, level, pos, visibility = 'private'):
cmethod_flag = level in ('c_class', 'c_class_pxd') cmethod_flag = level in ('c_class', 'c_class_pxd')
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment