Commit 2f20794c authored by Mark Florisson's avatar Mark Florisson

Support fused cpdef functions with Python indexing

parent 1472e87b
......@@ -16,6 +16,13 @@ class CythonScope(ModuleScope):
defining = 1,
cname='<error>')
for fused_type in (cy_integral_type, cy_floating_type, cy_numeric_type):
entry = self.declare_typedef(fused_type.name,
fused_type,
None,
cname='<error>')
entry.in_cinclude = True
def lookup_type(self, name):
# This function should go away when types are all first-level objects.
type = parse_basic_type(name)
......
This diff is collapsed.
......@@ -1823,8 +1823,12 @@ class CFuncDefNode(FuncDefNode):
# An error will be produced in the cdef function
self.overridable = False
self.declare_cpdef_wrapper(env)
self.create_local_scope(env)
def declare_cpdef_wrapper(self, env):
if self.overridable:
import ExprNodes
name = self.entry.name
py_func_body = self.call_self_node(is_module_scope = env.is_module_scope)
self.py_func = DefNode(pos = self.pos,
name = self.entry.name,
......@@ -1842,7 +1846,6 @@ class CFuncDefNode(FuncDefNode):
if not env.is_module_scope or Options.lookup_module_cpdef:
self.override = OverrideCheckNode(self.pos, py_func = self.py_func)
self.body = StatListNode(self.pos, stats=[self.override, self.body])
self.create_local_scope(env)
def _validate_type_visibility(self, type, pos, env):
"""
......@@ -2027,14 +2030,13 @@ class FusedCFuncDefNode(StatListNode):
node FuncDefNode the original function
nodes [FuncDefNode] list of copies of node with different specific types
py_func DefNode the original python function (in case of a cpdef)
"""
child_attrs = ['nodes']
def __init__(self, node, env):
super(FusedCFuncDefNode, self).__init__(node.pos)
self.nodes = self.stats = []
self.nodes = []
self.node = node
self.copy_cdefs(env)
......@@ -2055,6 +2057,10 @@ class FusedCFuncDefNode(StatListNode):
node.entry.fused_cfunction = self
self.stats = self.nodes[:]
if self.py_func:
self.stats.append(self.py_func)
def copy_cdefs(self, env):
"""
Gives a list of fused types and the parent environment, make copies
......@@ -2067,8 +2073,17 @@ class FusedCFuncDefNode(StatListNode):
# len(permutations))
# import pprint; pprint.pprint([d for cname, d in permutations])
if self.node.entry in env.cfunc_entries:
env.cfunc_entries.remove(self.node.entry)
# Prevent copying of the python function
self.py_func = self.node.py_func
self.node.py_func = None
if self.py_func:
env.pyfunc_entries.remove(self.py_func.entry)
fused_types = self.node.type.get_fused_types()
for cname, fused_to_specific in permutations:
copied_node = copy.deepcopy(self.node)
......@@ -2104,6 +2119,25 @@ class FusedCFuncDefNode(StatListNode):
type.specialize_entry(entry, cname)
env.cfunc_entries.append(entry)
# If a cpdef, declare all specialized cpdefs
copied_node.declare_cpdef_wrapper(env)
if copied_node.py_func:
env.pyfunc_entries.remove(copied_node.py_func.entry)
type_strings = [str(fused_to_specific[fused_type])
for fused_type in fused_types]
if len(type_strings) == 1:
sigstring = type_strings[0]
else:
sigstring = '(%s)' % ', '.join(type_strings)
copied_node.py_func.specialized_signature_string = sigstring
copied_node.py_func.fused_py_func = self.py_func
e = copied_node.py_func.entry
e.pymethdef_cname = PyrexTypes.get_fused_cname(
cname, e.pymethdef_cname)
num_errors = Errors.num_errors
transform = ParseTreeTransforms.ReplaceFusedTypeChecks(
copied_node.local_scope)
......@@ -2112,6 +2146,18 @@ class FusedCFuncDefNode(StatListNode):
if Errors.num_errors > num_errors:
break
if self.py_func:
self.py_func.specialized_cpdefs = [n.py_func for n in self.nodes]
self.py_func.fused_args_positions = [
i for i, arg in enumerate(self.node.type.args)
if arg.is_fused]
from Cython.Compiler import TreeFragment
fragment = TreeFragment.TreeFragment(u"""
raise ValueError("Index the function to get a specialized version")
""", level='function')
self.py_func.body = fragment.substitute()
def generate_function_definitions(self, env, code):
for stat in self.stats:
# print stat.entry, stat.entry.used
......@@ -2167,6 +2213,12 @@ class DefNode(FuncDefNode):
# when the def statement is inside a Python class definition.
#
# assmt AssignmentNode Function construction/assignment
#
# fused_py_func DefNode The original fused cpdef DefNode
# (in case this is a specialization)
# specialized_cpdefs [DefNode] list of specialized cpdef DefNodes
# fused_args_positions [int] list of the positions of the
# arguments with fused types
child_attrs = ["args", "star_arg", "starstar_arg", "body", "decorators"]
......@@ -2186,6 +2238,10 @@ class DefNode(FuncDefNode):
starstar_arg = None
doc = None
fused_py_func = False
specialized_cpdefs = None
fused_args_positions = None
def __init__(self, pos, **kwds):
FuncDefNode.__init__(self, pos, **kwds)
k = rk = r = 0
......@@ -2506,11 +2562,22 @@ class DefNode(FuncDefNode):
def analyse_expressions(self, env):
self.local_scope.directives = env.directives
self.analyse_default_values(env)
if self.specialized_cpdefs:
for arg in self.args + self.local_scope.arg_entries:
arg.needs_conversion = False
arg.type = py_object_type
self.local_scope.entries.clear()
del self.local_scope.var_entries[:]
if self.needs_assignment_synthesis(env):
# Shouldn't we be doing this at the module level too?
self.synthesize_assignment_node(env)
def needs_assignment_synthesis(self, env, code=None):
if self.specialized_cpdefs:
return True
if self.no_assignment_synthesis:
return False
# Should enable for module level as well, that will require more testing...
......@@ -2534,7 +2601,11 @@ class DefNode(FuncDefNode):
self.pos, pymethdef_cname = self.entry.pymethdef_cname)
else:
rhs = ExprNodes.PyCFunctionNode(
self.pos, pymethdef_cname = self.entry.pymethdef_cname, binding = env.directives['binding'])
self.pos,
pymethdef_cname = self.entry.pymethdef_cname,
binding = env.directives['binding'],
specialized_cpdefs = self.specialized_cpdefs,
fused_args_positions = self.fused_args_positions)
if env.is_py_class_scope:
if not self.is_staticmethod and not self.is_classmethod:
......@@ -2573,8 +2644,15 @@ class DefNode(FuncDefNode):
if mf: mf += " "
header = "static %s%s(%s)" % (mf, dc, arg_code)
code.putln("%s; /*proto*/" % header)
if proto_only:
if self.fused_py_func:
# If we are the specialized version of the cpdef, we still
# want the prototype for the "fused cpdef", in case we're
# checking to see if our method was overridden in Python
self.fused_py_func.generate_function_header(code, with_pymethdef, proto_only=True)
return
if (Options.docstrings and self.entry.doc and
not self.entry.scope.is_property_scope and
(not self.entry.is_special or self.entry.wrapperbase_cname)):
......@@ -2588,7 +2666,7 @@ class DefNode(FuncDefNode):
if self.entry.is_special:
code.putln(
"struct wrapperbase %s;" % self.entry.wrapperbase_cname)
if with_pymethdef:
if with_pymethdef or self.fused_py_func:
code.put(
"static PyMethodDef %s = " %
self.entry.pymethdef_cname)
......
......@@ -1318,20 +1318,21 @@ if VALUE is not None:
# ---------------------------------------
return property
class AnalyseExpressionsTransform(CythonTransform):
nested_index_node = False
class AnalyseExpressionsTransform(EnvTransform):
def visit_ModuleNode(self, node):
self.env_stack = [node.scope]
node.scope.infer_types()
node.body.analyse_expressions(node.scope)
self.visitchildren(node)
return node
def visit_FuncDefNode(self, node):
self.env_stack.append(node.local_scope)
node.local_scope.infer_types()
node.body.analyse_expressions(node.local_scope)
self.visitchildren(node)
self.env_stack.pop()
return node
def visit_ScopedExprNode(self, node):
......@@ -1347,15 +1348,24 @@ class AnalyseExpressionsTransform(CythonTransform):
argument types with the Attribute- or NameNode referring to the
function. We then need to copy over the specialization properties to
the attribute or name node.
Because the indexing might be a Python indexing operation on a fused
function, or (usually) a Cython indexing operation, we need to
re-analyse the types.
"""
self.visit_Node(node)
type = node.type
if type.is_cfunction and node.base.type.is_fused:
if node.is_fused_index:
if node.type is PyrexTypes.error_type:
node.type = PyrexTypes.error_type
else:
node.base.type = node.type
node.base.entry = node.type.entry
node.base.entry = getattr(node, 'entry', None) or node.type.entry
node = node.base
node.analyse_types(self.env_stack[-1])
return node
......
......@@ -666,15 +666,15 @@ class FusedType(PyrexType):
See http://wiki.cython.org/enhancements/fusedtypes
types [CSimpleBaseTypeNode] is the list of types to be fused
types [PyrexType] is the list of types to be fused
name str the name of the ctypedef
"""
is_fused = 1
name = None
def __init__(self, types):
def __init__(self, types, name=None):
self.types = types
self.name = name
def declaration_code(self, entity_code, for_display = 0,
dll_linkage = None, pyrex = 0):
......@@ -2079,7 +2079,8 @@ class CFuncType(CType):
if entry.is_cmethod:
entry.cname = entry.name
if entry.is_inherited:
entry.cname = "%s.%s" % (Naming.obj_base_cname, entry.cname)
entry.cname = StringEncoding.EncodedString(
"%s.%s" % (Naming.obj_base_cname, entry.cname))
else:
entry.cname = get_fused_cname(cname, entry.cname)
......@@ -2092,7 +2093,8 @@ def get_fused_cname(fused_cname, orig_cname):
Given the fused cname id and an original cname, return a specialized cname
"""
assert fused_cname and orig_cname
return '%s%s%s' % (Naming.fused_func_prefix, fused_cname, orig_cname)
return StringEncoding.EncodedString('%s%s%s' % (Naming.fused_func_prefix,
fused_cname, orig_cname))
def get_all_specific_permutations(fused_types, id="", f2s=()):
fused_type = fused_types[0]
......@@ -2631,6 +2633,17 @@ c_size_t_ptr_type = CPtrType(c_size_t_type)
c_py_buffer_type = CStructOrUnionType("Py_buffer", "struct", None, 1, "Py_buffer")
c_py_buffer_ptr_type = CPtrType(c_py_buffer_type)
# Not sure whether the unsigned versions and 'long long' should be in there
# long long requires C99 and might be slow, and would always get preferred
# when specialization happens through calling and not indexing
cy_integral_type = FusedType([c_int_type, c_long_type], name="integral")
# Omitting long double as it might be slow
cy_floating_type = FusedType([c_float_type, c_double_type], name="floating")
cy_numeric_type = FusedType([c_long_type,
c_double_type,
c_double_complex_type], name="numeric")
error_type = ErrorType()
unspecified_type = UnspecifiedType()
......
......@@ -135,6 +135,9 @@ class PointerType(CythonType):
else:
return not self._items and not value._items
def __repr__(self):
return "%s *" % (self._basetype,)
class ArrayType(PointerType):
def __init__(self):
......@@ -218,13 +221,17 @@ def union(**members):
class typedef(CythonType):
def __init__(self, type):
def __init__(self, type, name=None):
self._basetype = type
self.name = name
def __call__(self, *arg):
value = cast(self._basetype, *arg)
return value
def __repr__(self):
return self.name or str(self._basetype)
class _FusedType(CythonType):
def __call__(self, type, value):
......@@ -235,6 +242,7 @@ def fused_type(*args):
if not args:
raise TypeError("Expected at least one type as argument")
# Find the numeric type with biggest rank if all types are numeric
rank = -1
for type in args:
if type not in (py_int, py_long, py_float, py_complex):
......@@ -251,13 +259,18 @@ def fused_type(*args):
return _FusedType()
py_int = int
def _specialized_from_args(signatures, args, kwargs):
"Perhaps this should be implemented in a TreeFragment in Cython code"
raise Exception("yet to be implemented")
py_int = typedef(int, "int")
try:
py_long = long
py_long = typedef(long, "long")
except NameError: # Py3
py_long = int
py_float = float
py_complex = complex
py_long = typedef(int, "long")
py_float = typedef(float, "float")
py_complex = typedef(complex, "complex")
try:
......@@ -278,28 +291,39 @@ float_types = ['longdouble', 'double', 'float']
complex_types = ['longdoublecomplex', 'doublecomplex', 'floatcomplex', 'complex']
other_types = ['bint', 'void']
to_repr = {
'longlong': 'long long',
'longdouble': 'long double',
'longdoublecomplex': 'long double complex',
'doublecomplex': 'double complex',
'floatcomplex': 'float complex',
}.get
gs = globals()
for name in int_types:
gs[name] = typedef(py_int)
reprname = to_repr(name, name)
gs[name] = typedef(py_int, reprname)
if name != 'Py_UNICODE' and not name.endswith('size_t'):
gs['u'+name] = typedef(py_int)
gs['s'+name] = typedef(py_int)
gs['u'+name] = typedef(py_int, "unsigned " + reprname)
gs['s'+name] = typedef(py_int, "signed " + reprname)
for name in float_types:
gs[name] = typedef(py_float)
gs[name] = typedef(py_float, to_repr(name, name))
for name in complex_types:
gs[name] = typedef(py_complex)
gs[name] = typedef(py_complex, to_repr(name, name))
bint = typedef(bool)
void = typedef(int)
bint = typedef(bool, "bint")
void = typedef(int, "void")
for t in int_types + float_types + complex_types + other_types:
for i in range(1, 4):
gs["%s_%s" % ('p'*i, t)] = globals()[t]._pointer(i)
void = typedef(None)
void = typedef(None, "void")
NULL = p_void(0)
integral = floating = numeric = _FusedType()
type_ordering = [py_int, py_long, py_float, py_complex]
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment