Commit 47ce63c9 authored by Mark Florisson's avatar Mark Florisson

Support for fused types in cdef functions

parent b386862c
......@@ -2949,6 +2949,7 @@ class SimpleCallNode(CallNode):
else:
for arg in self.args:
arg.analyse_types(env)
if self.self and func_type.args:
# Coerce 'self' to the type expected by the method.
self_arg = func_type.args[0]
......@@ -2965,10 +2966,13 @@ class SimpleCallNode(CallNode):
def function_type(self):
# Return the type of the function being called, coercing a function
# pointer to a function if necessary.
# pointer to a function if necessary. If the function has fused
# arguments, return the specific type.
func_type = self.function.type
if func_type.is_ptr:
func_type = func_type.base_type
return func_type
def is_simple(self):
......@@ -2982,6 +2986,7 @@ class SimpleCallNode(CallNode):
if self.function.type is error_type:
self.type = error_type
return
if self.function.type.is_cpp_class:
overloaded_entry = self.function.type.scope.lookup("operator()")
if overloaded_entry is None:
......@@ -2992,8 +2997,16 @@ class SimpleCallNode(CallNode):
overloaded_entry = self.function.entry
else:
overloaded_entry = None
if overloaded_entry:
entry = PyrexTypes.best_match(self.args, overloaded_entry.all_alternatives(), self.pos)
if overloaded_entry.fused_cfunction:
specific_cdef_funcs = overloaded_entry.fused_cfunction.nodes
alternatives = [n.entry for n in specific_cdef_funcs]
else:
alternatives = overloaded_entry.all_alternatives()
entry = PyrexTypes.best_match(self.args, alternatives, self.pos, env)
if not entry:
self.type = PyrexTypes.error_type
self.result_code = "<error>"
......@@ -3130,8 +3143,8 @@ class SimpleCallNode(CallNode):
for actual_arg in self.args[len(formal_args):]:
arg_list_code.append(actual_arg.result())
result = "%s(%s)" % (self.function.result(),
', '.join(arg_list_code))
result = "%s(%s)" % (self.function.result(), ', '.join(arg_list_code))
return result
def generate_result_code(self, code):
......
......@@ -156,13 +156,22 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
f.close()
def generate_public_declaration(self, entry, h_code, i_code):
if entry.fused_cfunction:
for cfunction in entry.fused_cfunction.nodes:
self._generate_public_declaration(cfunction.entry,
cfunction.entry.cname, h_code, i_code)
else:
self._generate_public_declaration(entry, entry.cname,
h_code, i_code)
def _generate_public_declaration(self, entry, cname, h_code, i_code):
h_code.putln("%s %s;" % (
Naming.extern_c_macro,
entry.type.declaration_code(
entry.cname, dll_linkage = "DL_IMPORT")))
cname, dll_linkage = "DL_IMPORT")))
if i_code:
i_code.putln("cdef extern %s" %
entry.type.declaration_code(entry.cname, pyrex = 1))
entry.type.declaration_code(cname, pyrex = 1))
def api_name(self, env):
return env.qualified_name.replace(".", "__")
......@@ -987,6 +996,15 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
def generate_cfunction_predeclarations(self, env, code, definition):
for entry in env.cfunc_entries:
if entry.fused_cfunction:
for node in entry.fused_cfunction.nodes:
self._generate_cfunction_predeclaration(
code, definition, node.entry)
else:
self._generate_cfunction_predeclaration(code, definition, entry)
def _generate_cfunction_predeclaration(self, code, definition, entry):
if entry.inline_func_in_pxd or (not entry.in_cinclude and (definition
or entry.defined_in_pxd or entry.visibility == 'extern')):
if entry.visibility == 'public':
......
......@@ -93,6 +93,7 @@ enc_scope_cname = pyrex_prefix + "enc_scope"
frame_cname = pyrex_prefix + "frame"
frame_code_cname = pyrex_prefix + "frame_code"
binding_cfunc = pyrex_prefix + "binding_PyCFunctionType"
fused_func_prefix = pyrex_prefix + 'fuse_'
genexpr_id_ref = 'genexpr'
......
This diff is collapsed.
......@@ -610,8 +610,9 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
'operator.comma' : ExprNodes.c_binop_constructor(','),
}
special_methods = cython.set(['declare', 'union', 'struct', 'typedef', 'sizeof',
'cast', 'pointer', 'compiled', 'NULL'])
special_methods = cython.set(['declare', 'union', 'struct', 'typedef',
'sizeof', 'cast', 'pointer', 'compiled',
'NULL', 'fused_type'])
special_methods.update(unop_method_nodes.keys())
def __init__(self, context, compilation_directive_defaults):
......@@ -896,6 +897,36 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
return self.visit_with_directives(node.body, directive_dict)
return self.visit_Node(node)
def visit_CTypeDefNode(self, node):
"Don't skip ctypedefs"
self.visitchildren(node)
return node
def visit_FusedTypeNode(self, node):
"""
See if a function call expression in a ctypedef is actually
cython.fused_type()
"""
def err():
error(node.pos, "Can only fuse types with cython.fused_type()")
if len(node.funcname) == 1:
fused_type, = node.funcname
else:
cython_module, fused_type = node.funcname
wrong_module = cython_module not in self.cython_module_names
if wrong_module or fused_type != u'fused_type':
err()
return node
if not self.directive_names.get(fused_type):
err()
return node
class WithTransform(CythonTransform, SkipDeclarations):
# EXCINFO is manually set to a variable that contains
......@@ -1115,6 +1146,14 @@ if VALUE is not None:
return node
def visit_FuncDefNode(self, node):
"""
Analyse a function and its body, as that hasn't happend yet. Also
analyse the directive_locals set by @cython.locals(). Then, if we are
a function with fused arguments, replace the function (after it has
declared itself in the symbol table!) with a FusedCFuncDefNode, and
analyse its children (which are in turn normal functions). If we're a
normal function, just analyse the body of the function.
"""
self.seen_vars_stack.append(cython.set())
lenv = node.local_scope
node.body.analyse_control_flow(lenv) # this will be totally refactored
......@@ -1126,10 +1165,16 @@ if VALUE is not None:
lenv.declare_var(var, type, type_node.pos)
else:
error(type_node.pos, "Not a type")
if node.has_fused_arguments:
node = Nodes.FusedCFuncDefNode(node, self.env_stack[-1])
self.visitchildren(node)
else:
node.body.analyse_declarations(lenv)
self.env_stack.append(lenv)
self.visitchildren(node)
self.env_stack.pop()
self.seen_vars_stack.pop()
return node
......
......@@ -2572,6 +2572,24 @@ def p_c_func_or_var_declaration(s, pos, ctx):
overridable = ctx.overridable)
return result
def p_typelist(s):
"""
parse a list of basic c types as part of a function call, like
cython.fused_type(int, long, double)
"""
types = []
pos = s.position()
while s.sy == 'IDENT':
types.append(p_c_base_type(s))
if s.sy != ',':
if s.sy != ')':
s.expect(',')
break
s.next()
return Nodes.FusedTypeNode(pos, types=types)
def p_ctypedef_statement(s, ctx):
# s.sy == 'ctypedef'
pos = s.position()
......@@ -2588,10 +2606,30 @@ def p_ctypedef_statement(s, ctx):
return p_c_enum_definition(s, pos, ctx)
else:
return p_c_struct_or_union_definition(s, pos, ctx)
elif looking_at_expr(s):
# ctypedef cython.fused_types(int, long) integral
if s.sy == 'IDENT':
funcname = [s.systring]
s.next()
if s.systring == u'.':
s.next()
funcname.append(s.systring)
s.expect('IDENT')
s.expect('(')
base_type = p_typelist(s)
s.expect(')')
# Check if funcname equals cython.fused_types in
# InterpretCompilerDirectives
base_type.funcname = funcname
else:
s.error("Syntax error in ctypedef statement")
else:
base_type = p_c_base_type(s, nonempty = 1)
if base_type.name is None:
s.error("Syntax error in ctypedef statement")
declarator = p_c_declarator(s, ctx, is_type = 1, nonempty = 1)
s.expect_newline("Syntax error in ctypedef statement")
return Nodes.CTypeDefNode(
......
......@@ -2,6 +2,8 @@
# Pyrex - Types
#
import cython
from Code import UtilityCode
import StringEncoding
import Naming
......@@ -12,6 +14,9 @@ class BaseType(object):
#
# Base class for all Pyrex types including pseudo-types.
# List of attribute names of any subtypes
subtypes = []
def can_coerce_to_pyobject(self, env):
return False
......@@ -27,6 +32,42 @@ class BaseType(object):
else:
return base_code
def __deepcopy__(self, memo):
"""
Types never need to be copied, if we do copy, Unfortunate Things
Will Happen!
"""
return self
def get_fused_types(self, result=None, seen=None):
if self.subtypes:
def add_fused_types(types):
for type in types or ():
if type not in seen:
seen.add(type)
result.append(type)
if result is None:
result = []
seen = cython.set()
for attr in self.subtypes:
list_or_subtype = getattr(self, attr)
if isinstance(list_or_subtype, BaseType):
list_or_subtype.get_fused_types(result, seen)
else:
for subtype in list_or_subtype:
subtype.get_fused_types(result, seen)
return result
return None
is_fused = property(get_fused_types, doc="Whether this type or any of its "
"subtypes is a fused type")
class PyrexType(BaseType):
#
# Base class for all Pyrex types.
......@@ -196,6 +237,7 @@ class CTypedefType(BaseType):
to_py_utility_code = None
from_py_utility_code = None
subtypes = ['typedef_base_type']
def __init__(self, name, base_type, cname, is_external=0):
assert not base_type.is_complex
......@@ -314,6 +356,9 @@ class BufferType(BaseType):
is_buffer = 1
writable = True
subtypes = ['dtype']
def __init__(self, base, dtype, ndim, mode, negative_indices, cast):
self.base = base
self.dtype = dtype
......@@ -618,6 +663,45 @@ class CType(PyrexType):
return 0
class FusedType(CType):
"""
Represents a Fused Type. All it needs to do is keep track of the types
it aggregates, as it will be replaced with its specific version wherever
needed.
See http://wiki.cython.org/enhancements/fusedtypes
types [CSimpleBaseTypeNode] is the list of types to be fused
name str the name of the ctypedef
"""
is_fused = 1
def __init__(self, types):
self.types = types
def declaration_code(self, entity_code, for_display = 0,
dll_linkage = None, pyrex = 0):
if pyrex or for_display:
return self.name
raise Exception("This may never happen, please report a bug")
def __repr__(self):
return 'FusedType(name=%r)' % self.name
def specialize(self, values):
return values[self]
def get_fused_types(self, result=None, seen=None):
if result is None:
return [self]
if self not in seen:
result.append(self)
seen.add(self)
class CVoidType(CType):
#
# C "void" type
......@@ -1532,6 +1616,8 @@ class CArrayType(CType):
is_array = 1
subtypes = ['base_type']
def __init__(self, base_type, size):
self.base_type = base_type
self.size = size
......@@ -1578,6 +1664,8 @@ class CPtrType(CType):
is_ptr = 1
default_value = "0"
subtypes = ['base_type']
def __init__(self, base_type):
self.base_type = base_type
......@@ -1676,6 +1764,8 @@ class CFuncType(CType):
is_cfunction = 1
original_sig = None
subtypes = ['return_type', 'args']
def __init__(self, return_type, args, has_varargs = 0,
exception_value = None, exception_check = 0, calling_convention = "",
nogil = 0, with_gil = 0, is_overridable = 0, optional_arg_count = 0,
......@@ -1915,7 +2005,7 @@ class CFuncType(CType):
return self.op_arg_struct.base_type.scope.lookup(arg_name).cname
class CFuncTypeArg(object):
class CFuncTypeArg(BaseType):
# name string
# cname string
# type PyrexType
......@@ -1926,6 +2016,8 @@ class CFuncTypeArg(object):
or_none = False
accept_none = True
subtypes = ['type']
def __init__(self, name, type, pos, cname=None):
self.name = name
if cname is not None:
......@@ -2478,7 +2570,7 @@ def is_promotion(src_type, dst_type):
return src_type.is_float and src_type.rank <= dst_type.rank
return False
def best_match(args, functions, pos=None):
def best_match(args, functions, pos=None, env=None):
"""
Given a list args of arguments and a list of functions, choose one
to call which seems to be the "best" fit for this list of arguments.
......@@ -2546,12 +2638,33 @@ def best_match(args, functions, pos=None):
possibilities = []
bad_types = []
needed_coercions = {}
for func, func_type in candidates:
score = [0,0,0]
for i in range(min(len(args), len(func_type.args))):
src_type = args[i].type
dst_type = func_type.args[i].type
if dst_type.assignable_from(src_type):
assignable = dst_type.assignable_from(src_type)
# Now take care of normal string literals. So when you call a cdef
# function that takes a char *, the coercion will mean that the
# type will simply become bytes. We need to do this coercion
# manually for overloaded and fused functions
if not assignable and src_type.is_pyobject:
if (src_type.is_builtin_type and src_type.name == 'str' and
dst_type.resolve() is c_char_ptr_type):
c_src_type = c_char_ptr_type
else:
c_src_type = src_type.default_coerced_ctype()
if c_src_type:
assignable = dst_type.assignable_from(c_src_type)
if assignable:
src_type = c_src_type
needed_coercions[func] = i, dst_type
if assignable:
if src_type == dst_type or dst_type.same_as(src_type):
pass # score 0
elif is_promotion(src_type, dst_type):
......@@ -2567,18 +2680,28 @@ def best_match(args, functions, pos=None):
break
else:
possibilities.append((score, func)) # so we can sort it
if possibilities:
possibilities.sort()
if len(possibilities) > 1 and possibilities[0][0] == possibilities[1][0]:
if pos is not None:
error(pos, "ambiguous overloaded method")
return None
return possibilities[0][1]
function = possibilities[0][1]
if function in needed_coercions and env:
arg_i, coerce_to_type = needed_coercions[function]
args[arg_i] = args[arg_i].coerce_to(coerce_to_type, env)
return function
if pos is not None:
if len(bad_types) == 1:
error(pos, bad_types[0][1])
else:
error(pos, "no suitable method found")
return None
def widest_numeric_type(type1, type2):
......
......@@ -176,6 +176,7 @@ class Entry(object):
buffer_aux = None
prev_entry = None
might_overflow = 0
fused_cfunction = None
def __init__(self, name, cname, type, pos = None, init = None):
self.name = name
......@@ -241,6 +242,7 @@ class Scope(object):
scope_prefix = ""
in_cinclude = 0
nogil = 0
fused_to_specific = None
def __init__(self, name, outer_scope, parent_scope):
# The outer_scope is the next scope in the lookup chain.
......@@ -279,6 +281,9 @@ class Scope(object):
self.return_type = None
self.id_counters = {}
def __deepcopy__(self, memo):
return self
def start_branching(self, pos):
self.control_flow = self.control_flow.start_branch(pos)
......@@ -677,6 +682,8 @@ class Scope(object):
def lookup_type(self, name):
entry = self.lookup(name)
if entry and entry.is_type:
if entry.type.is_fused and self.fused_to_specific:
return entry.type.specialize(self.fused_to_specific)
return entry.type
def lookup_operator(self, operator, operands):
......
......@@ -225,6 +225,30 @@ class typedef(CythonType):
value = cast(self._basetype, *arg)
return value
class _FusedType(CythonType):
def __call__(self, type, value):
return value
def fused_type(*args):
if not args:
raise TypeError("Expected at least one type as argument")
rank = -1
for type in args:
if type not in (py_int, py_long, py_float, py_complex):
break
if type_ordering.index(type) > rank:
result_type = type
else:
return result_type
# Not a simple numeric type, return a fused type instance. The result
# isn't really meant to be used, as we can't keep track of the context in
# pure-mode. Casting won't do anything in this case.
return _FusedType()
py_int = int
......@@ -277,3 +301,5 @@ for t in int_types + float_types + complex_types + other_types:
void = typedef(None)
NULL = p_void(0)
type_ordering = [py_int, py_long, py_float, py_complex]
\ No newline at end of file
# mode: error
cimport cython
from cython import fused_type
# This is all invalid
ctypedef foo(int) dtype1
ctypedef foo.bar(float) dtype2
ctypedef fused_type(foo) dtype3
dtype4 = cython.typedef(cython.fused_type(int, long, kw=None))
# This is all valid
ctypedef fused_type(int, long, float) dtype5
ctypedef cython.fused_type(int, long) dtype6
_ERRORS = u"""
fused_types.pyx:7:13: Can only fuse types with cython.fused_type()
fused_types.pyx:8:17: Can only fuse types with cython.fused_type()
fused_types.pyx:9:20: 'foo' is not a type identifier
fused_types.pyx:10:23: fused_type does not take keyword arguments
"""
# mode: run
cimport cython
from cpython cimport Py_INCREF
from Cython import Shadow as pure_cython
ctypedef char * string_t
ctypedef cython.fused_type(int, long, float, double, string_t) fused_type1
ctypedef cython.fused_type(string_t) fused_type2
def test_pure():
"""
>>> test_pure()
(10+0j)
"""
mytype = pure_cython.typedef(pure_cython.fused_type(int, long, complex))
print mytype(10)
cdef cdef_func_with_fused_args(fused_type1 x, fused_type1 y, fused_type2 z):
print x, y, z
return x + y
def test_cdef_func_with_fused_args():
"""
>>> test_cdef_func_with_fused_args()
spam ham eggs
spamham
10 20 butter
30
4.2 8.6 bunny
12.8
"""
print cdef_func_with_fused_args('spam', 'ham', 'eggs')
print cdef_func_with_fused_args(10, 20, 'butter')
print cdef_func_with_fused_args(4.2, 8.6, 'bunny')
cdef fused_type1 fused_with_pointer(fused_type1 *array):
for i in range(5):
print array[i]
obj = array[0] + array[1] + array[2] + array[3] + array[4]
# if cython.typeof(fused_type1) is string_t:
Py_INCREF(obj)
return obj
def test_fused_with_pointer():
"""
>>> test_fused_with_pointer()
0
1
2
3
4
10
<BLANKLINE>
0
1
2
3
4
10
<BLANKLINE>
0.0
1.0
2.0
3.0
4.0
10.0
<BLANKLINE>
humpty
dumpty
fall
splatch
breakfast
humptydumptyfallsplatchbreakfast
"""
cdef int int_array[5]
cdef long long_array[5]
cdef float float_array[5]
cdef string_t string_array[5]
cdef char *s1 = "humpty", *s2 = "dumpty", *s3 = "fall", *s4 = "splatch", *s5 = "breakfast"
strings = ["humpty", "dumpty", "fall", "splatch", "breakfast"]
for i in range(5):
int_array[i] = i
long_array[i] = i
float_array[i] = i
s = strings[i]
string_array[i] = s
print fused_with_pointer(int_array)
print
print fused_with_pointer(long_array)
print
print fused_with_pointer(float_array)
print
print fused_with_pointer(string_array)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment