Commit b614aa32 authored by Stefan Behnel's avatar Stefan Behnel

clean up and simplify signature matching test

parent df57c172
...@@ -4,17 +4,17 @@ from Cython.Compiler import PyrexTypes as pt ...@@ -4,17 +4,17 @@ from Cython.Compiler import PyrexTypes as pt
from Cython.Compiler.ExprNodes import NameNode from Cython.Compiler.ExprNodes import NameNode
from Cython.Compiler.PyrexTypes import CFuncTypeArg from Cython.Compiler.PyrexTypes import CFuncTypeArg
def cfunctype(*arg_types):
return pt.CFuncType(pt.c_int_type,
[ CFuncTypeArg("name", arg_type, None) for arg_type in arg_types ])
def cppclasstype(name, base_classes):
return pt.CppClassType(name, None, 'CPP_'+name, base_classes)
class SignatureMatcherTest(unittest.TestCase): class SignatureMatcherTest(unittest.TestCase):
""" """
Test the signature matching algorithm for overloaded signatures. Test the signature matching algorithm for overloaded signatures.
""" """
def _cfunctype(self, return_type, *arg_types):
return pt.CFuncType(return_type,
[ CFuncTypeArg("name", arg_type, None) for arg_type in arg_types ])
def _cppclasstype(self, name, base_classes):
return pt.CppClassType(name, None, 'CPP_'+name, base_classes)
def assertMatches(self, expected_type, arg_types, functions): def assertMatches(self, expected_type, arg_types, functions):
args = [ NameNode(None, type=arg_type) for arg_type in arg_types ] args = [ NameNode(None, type=arg_type) for arg_type in arg_types ]
match = pt.best_match(args, functions) match = pt.best_match(args, functions)
...@@ -24,9 +24,9 @@ class SignatureMatcherTest(unittest.TestCase): ...@@ -24,9 +24,9 @@ class SignatureMatcherTest(unittest.TestCase):
def test_cpp_reference_single_arg(self): def test_cpp_reference_single_arg(self):
function_types = [ function_types = [
self._cfunctype(pt.c_int_type, pt.CReferenceType(pt.c_int_type)), cfunctype(pt.CReferenceType(pt.c_int_type)),
self._cfunctype(pt.c_long_type, pt.CReferenceType(pt.c_long_type)), cfunctype(pt.CReferenceType(pt.c_long_type)),
self._cfunctype(pt.c_double_type, pt.CReferenceType(pt.c_double_type)), cfunctype(pt.CReferenceType(pt.c_double_type)),
] ]
functions = [ NameNode(None, type=t) for t in function_types ] functions = [ NameNode(None, type=t) for t in function_types ]
...@@ -36,11 +36,9 @@ class SignatureMatcherTest(unittest.TestCase): ...@@ -36,11 +36,9 @@ class SignatureMatcherTest(unittest.TestCase):
def test_cpp_reference_two_args(self): def test_cpp_reference_two_args(self):
function_types = [ function_types = [
self._cfunctype( cfunctype(
pt.c_int_type,
pt.CReferenceType(pt.c_int_type), pt.CReferenceType(pt.c_long_type)), pt.CReferenceType(pt.c_int_type), pt.CReferenceType(pt.c_long_type)),
self._cfunctype( cfunctype(
pt.c_int_type,
pt.CReferenceType(pt.c_long_type), pt.CReferenceType(pt.c_long_type)), pt.CReferenceType(pt.c_long_type), pt.CReferenceType(pt.c_long_type)),
] ]
...@@ -50,10 +48,10 @@ class SignatureMatcherTest(unittest.TestCase): ...@@ -50,10 +48,10 @@ class SignatureMatcherTest(unittest.TestCase):
self.assertMatches(function_types[1], [pt.c_long_type, pt.c_int_type], functions) self.assertMatches(function_types[1], [pt.c_long_type, pt.c_int_type], functions)
def test_cpp_reference_cpp_class(self): def test_cpp_reference_cpp_class(self):
classes = [ self._cppclasstype("Test%d"%i, []) for i in range(2) ] classes = [ cppclasstype("Test%d"%i, []) for i in range(2) ]
function_types = [ function_types = [
self._cfunctype(pt.c_int_type, pt.CReferenceType(classes[0])), cfunctype(pt.CReferenceType(classes[0])),
self._cfunctype(pt.c_int_type, pt.CReferenceType(classes[1])), cfunctype(pt.CReferenceType(classes[1])),
] ]
functions = [ NameNode(None, type=t) for t in function_types ] functions = [ NameNode(None, type=t) for t in function_types ]
...@@ -61,14 +59,16 @@ class SignatureMatcherTest(unittest.TestCase): ...@@ -61,14 +59,16 @@ class SignatureMatcherTest(unittest.TestCase):
self.assertMatches(function_types[1], [classes[1]], functions) self.assertMatches(function_types[1], [classes[1]], functions)
def test_cpp_reference_cpp_class_and_int(self): def test_cpp_reference_cpp_class_and_int(self):
classes = [ self._cppclasstype("Test%d"%i, []) for i in range(2) ] classes = [ cppclasstype("Test%d"%i, []) for i in range(2) ]
function_types = [ function_types = [
self._cfunctype(pt.c_int_type, cfunctype(pt.CReferenceType(classes[0]), pt.c_int_type),
pt.CReferenceType(classes[0]), pt.c_int_type), cfunctype(pt.CReferenceType(classes[0]), pt.c_long_type),
self._cfunctype(pt.c_int_type, cfunctype(pt.CReferenceType(classes[1]), pt.c_int_type),
pt.CReferenceType(classes[1]), pt.c_long_type), cfunctype(pt.CReferenceType(classes[1]), pt.c_long_type),
] ]
functions = [ NameNode(None, type=t) for t in function_types ] functions = [ NameNode(None, type=t) for t in function_types ]
self.assertMatches(function_types[0], [classes[0], pt.c_int_type], functions) self.assertMatches(function_types[0], [classes[0], pt.c_int_type], functions)
self.assertMatches(function_types[1], [classes[1], pt.c_int_type], functions) self.assertMatches(function_types[1], [classes[0], pt.c_long_type], functions)
self.assertMatches(function_types[2], [classes[1], pt.c_int_type], functions)
self.assertMatches(function_types[3], [classes[1], pt.c_long_type], functions)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment