Commit ad4b2066 authored by Robert Bradshaw's avatar Robert Bradshaw

Module-level cpdef functions

parent be3c7caf
...@@ -867,7 +867,7 @@ class CFuncDefNode(FuncDefNode): ...@@ -867,7 +867,7 @@ class CFuncDefNode(FuncDefNode):
if self.overridable: if self.overridable:
import ExprNodes import ExprNodes
py_func_body = self.call_self_node() py_func_body = self.call_self_node(is_module_scope = env.is_module_scope)
self.py_func = DefNode(pos = self.pos, self.py_func = DefNode(pos = self.pos,
name = self.declarator.base.name, name = self.declarator.base.name,
args = self.declarator.args, args = self.declarator.args,
...@@ -875,23 +875,30 @@ class CFuncDefNode(FuncDefNode): ...@@ -875,23 +875,30 @@ class CFuncDefNode(FuncDefNode):
starstar_arg = None, starstar_arg = None,
doc = self.doc, doc = self.doc,
body = py_func_body) body = py_func_body)
self.py_func.is_module_scope = env.is_module_scope
self.py_func.analyse_declarations(env) self.py_func.analyse_declarations(env)
self.entry.as_variable = self.py_func.entry
# Reset scope entry the above cfunction # Reset scope entry the above cfunction
env.entries[name] = self.entry env.entries[name] = self.entry
if Options.intern_names: if Options.intern_names:
self.py_func.interned_attr_cname = env.intern(self.py_func.entry.name) self.py_func.interned_attr_cname = env.intern(self.py_func.entry.name)
self.override = OverrideCheckNode(self.pos, py_func = self.py_func) if not env.is_module_scope or Options.lookup_module_cpdef:
self.body = StatListNode(self.pos, stats=[self.override, self.body]) self.override = OverrideCheckNode(self.pos, py_func = self.py_func)
self.body = StatListNode(self.pos, stats=[self.override, self.body])
def call_self_node(self, omit_optional_args=0): def call_self_node(self, omit_optional_args=0, is_module_scope=0):
import ExprNodes import ExprNodes
args = self.type.args args = self.type.args
if omit_optional_args: if omit_optional_args:
args = args[:len(args) - self.type.optional_arg_count] args = args[:len(args) - self.type.optional_arg_count]
arg_names = [arg.name for arg in args] arg_names = [arg.name for arg in args]
self_arg = ExprNodes.NameNode(self.pos, name=arg_names[0]) if is_module_scope:
cfunc = ExprNodes.AttributeNode(self.pos, obj=self_arg, attribute=self.declarator.base.name) cfunc = ExprNodes.NameNode(self.pos, name=self.declarator.base.name)
c_call = ExprNodes.SimpleCallNode(self.pos, function=cfunc, args=[ExprNodes.NameNode(self.pos, name=n) for n in arg_names[1:]], wrapper_call=True) else:
self_arg = ExprNodes.NameNode(self.pos, name=arg_names[0])
cfunc = ExprNodes.AttributeNode(self.pos, obj=self_arg, attribute=self.declarator.base.name)
skip_dispatch = not is_module_scope or Options.lookup_module_cpdef
c_call = ExprNodes.SimpleCallNode(self.pos, function=cfunc, args=[ExprNodes.NameNode(self.pos, name=n) for n in arg_names[1-is_module_scope:]], wrapper_call=skip_dispatch)
return ReturnStatNode(pos=self.pos, return_type=PyrexTypes.py_object_type, value=c_call) return ReturnStatNode(pos=self.pos, return_type=PyrexTypes.py_object_type, value=c_call)
def declare_arguments(self, env): def declare_arguments(self, env):
...@@ -1667,12 +1674,16 @@ class OverrideCheckNode(StatNode): ...@@ -1667,12 +1674,16 @@ class OverrideCheckNode(StatNode):
def analyse_expressions(self, env): def analyse_expressions(self, env):
self.args = env.arg_entries self.args = env.arg_entries
if self.py_func.is_module_scope:
first_arg = 0
else:
first_arg = 1
import ExprNodes import ExprNodes
self.func_node = ExprNodes.PyTempNode(self.pos, env) self.func_node = ExprNodes.PyTempNode(self.pos, env)
call_tuple = ExprNodes.TupleNode(self.pos, args=[ExprNodes.NameNode(self.pos, name=arg.name) for arg in self.args[1:]]) call_tuple = ExprNodes.TupleNode(self.pos, args=[ExprNodes.NameNode(self.pos, name=arg.name) for arg in self.args[first_arg:]])
call_node = ExprNodes.SimpleCallNode(self.pos, call_node = ExprNodes.SimpleCallNode(self.pos,
function=self.func_node, function=self.func_node,
args=[ExprNodes.NameNode(self.pos, name=arg.name) for arg in self.args[1:]]) args=[ExprNodes.NameNode(self.pos, name=arg.name) for arg in self.args[first_arg:]])
self.body = ReturnStatNode(self.pos, value=call_node) self.body = ReturnStatNode(self.pos, value=call_node)
# self.func_temp = env.allocate_temp_pyobject() # self.func_temp = env.allocate_temp_pyobject()
self.body.analyse_expressions(env) self.body.analyse_expressions(env)
...@@ -1680,11 +1691,17 @@ class OverrideCheckNode(StatNode): ...@@ -1680,11 +1691,17 @@ class OverrideCheckNode(StatNode):
def generate_execution_code(self, code): def generate_execution_code(self, code):
# Check to see if we are an extension type # Check to see if we are an extension type
self_arg = "((PyObject *)%s)" % self.args[0].cname if self.py_func.is_module_scope:
self_arg = "((PyObject *)%s)" % Naming.module_cname
else:
self_arg = "((PyObject *)%s)" % self.args[0].cname
code.putln("/* Check if called by wrapper */") code.putln("/* Check if called by wrapper */")
code.putln("if (unlikely(%s)) %s = 0;" % (Naming.skip_dispatch_cname, Naming.skip_dispatch_cname)) code.putln("if (unlikely(%s)) %s = 0;" % (Naming.skip_dispatch_cname, Naming.skip_dispatch_cname))
code.putln("/* Check if overriden in Python */") code.putln("/* Check if overriden in Python */")
code.putln("else if (unlikely(%s->ob_type->tp_dictoffset != 0)) {" % self_arg) if self.py_func.is_module_scope:
code.putln("else {")
else:
code.putln("else if (unlikely(%s->ob_type->tp_dictoffset != 0)) {" % self_arg)
err = code.error_goto_if_null(self_arg, self.pos) err = code.error_goto_if_null(self_arg, self.pos)
# need to get attribute manually--scope would return cdef method # need to get attribute manually--scope would return cdef method
if Options.intern_names: if Options.intern_names:
......
...@@ -33,3 +33,9 @@ annotate = 0 ...@@ -33,3 +33,9 @@ annotate = 0
# raised before the loop is entered, wheras without this option the loop # raised before the loop is entered, wheras without this option the loop
# will execute util a overflowing value is encountered. # will execute util a overflowing value is encountered.
convert_range = 0 convert_range = 0
# Enable this to allow one to write your_module.foo = ... to overwrite the
# definition if the cpdef function foo, at the cost of an extra dictionary
# lookup on every call.
# If this is 0 it simply creates a wrapper.
lookup_module_cpdef = 0
...@@ -1722,8 +1722,6 @@ def p_api(s): ...@@ -1722,8 +1722,6 @@ def p_api(s):
def p_cdef_statement(s, level, visibility = 'private', api = 0, def p_cdef_statement(s, level, visibility = 'private', api = 0,
overridable = False): overridable = False):
pos = s.position() pos = s.position()
if overridable and level not in ('c_class', 'c_class_pxd'):
error(pos, "Overridable cdef function not allowed here")
visibility = p_visibility(s, visibility) visibility = p_visibility(s, visibility)
api = api or p_api(s) api = api or p_api(s)
if api: if api:
......
...@@ -148,6 +148,7 @@ class Scope: ...@@ -148,6 +148,7 @@ class Scope:
is_py_class_scope = 0 is_py_class_scope = 0
is_c_class_scope = 0 is_c_class_scope = 0
is_module_scope = 0
scope_prefix = "" scope_prefix = ""
in_cinclude = 0 in_cinclude = 0
...@@ -673,6 +674,8 @@ class ModuleScope(Scope): ...@@ -673,6 +674,8 @@ class ModuleScope(Scope):
# interned_names [string] Interned names pending generation of declarations # interned_names [string] Interned names pending generation of declarations
# all_pystring_entries [Entry] Python string consts from all scopes # all_pystring_entries [Entry] Python string consts from all scopes
# types_imported {PyrexType : 1} Set of types for which import code generated # types_imported {PyrexType : 1} Set of types for which import code generated
is_module_scope = 1
def __init__(self, name, parent_module, context): def __init__(self, name, parent_module, context):
self.parent_module = parent_module self.parent_module = parent_module
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment