Commit 8b642621 authored by Robert Bradshaw's avatar Robert Bradshaw

Infer common parent of C++ classes for spanning type of pointers.

parent e98f32aa
......@@ -4,6 +4,7 @@
from __future__ import absolute_import
import collections
import copy
import re
......@@ -12,6 +13,7 @@ try:
except NameError:
from functools import reduce
from Cython.Utils import cached_function
from .Code import UtilityCode, LazyUtilityCode, TempitaUtilityCode
from . import StringEncoding
from . import Naming
......@@ -4219,6 +4221,10 @@ def _spanning_type(type1, type2):
return py_object_type
return type2
elif type1.is_ptr and type2.is_ptr:
if type1.base_type.is_cpp_class and type2.base_type.is_cpp_class:
common_base = widest_cpp_type(type1.base_type, type2.base_type)
if common_base:
return CPtrType(common_base)
# incompatible pointers, void* will do as a result
return c_void_ptr_type
else:
......@@ -4236,6 +4242,24 @@ def widest_extension_type(type1, type2):
if type1 is None or type2 is None:
return py_object_type
def widest_cpp_type(type1, type2):
@cached_function
def bases(type):
all = set()
for base in type.base_classes:
all.add(base)
all.update(bases(base))
return all
common_bases = bases(type1).intersection(bases(type2))
common_bases_bases = reduce(set.union, [bases(b) for b in common_bases], set())
candidates = [b for b in common_bases if b not in common_bases_bases]
if len(candidates) == 1:
return candidates[0]
else:
# Fall back to void* for now.
return None
def simple_c_type(signed, longness, name):
# Find type descriptor for simple type given name and modifiers.
# Returns None if arguments don't make sense.
......
# mode: run
# tag: cpp, werror
cdef extern from "shapes.h" namespace "shapes":
cdef cppclass Shape:
float area()
cdef cppclass Circle(Shape):
int radius
Circle(int)
cdef cppclass Square(Shape):
Square(int)
from cython cimport typeof
from cython.operator cimport dereference as d
......@@ -23,3 +34,17 @@ def test_reversed_vector_iteration(L):
incr(it)
print('%s: %s' % (typeof(a), a))
print(typeof(a))
def test_derived_types(int size, bint round):
"""
>>> test_derived_types(5, True)
Shape *
>>> test_derived_types(5, False)
Shape *
"""
if round:
ptr = new Circle(size)
else:
ptr = new Square(size)
print typeof(ptr)
del ptr
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment