Commit fcc3461e authored by Stefan Behnel's avatar Stefan Behnel

Make vtable order of extension types with fused methods only dependant on the...

Make vtable order of extension types with fused methods only dependant on the original declaration order (e.g. in the .pxd file).
Previously, fused methods were specialised and expanded on first use, which lead to an arbitrary order in the vtable.
Also fixes compile failures when inheriting from base types with fused cdef methods.
Fixes #1873.
parent 625562c9
...@@ -63,6 +63,14 @@ Features added ...@@ -63,6 +63,14 @@ Features added
Bugs fixed Bugs fixed
---------- ----------
* Extension types that were cimported from other Cython modules could disagree
about the order of fused cdef methods in their call table. This could lead
to wrong methods being called and potentially also crashes. The fix required
changes to the ordering of fused methods in the call table, which may break
existing compiled modules that call fused cdef methods across module boundaries,
if these methods were implemented in a different order than they were declared
in the corresponding .pxd file. (Github issue #1873)
* The exception state handling in generators and coroutines could lead to * The exception state handling in generators and coroutines could lead to
exceptions in the caller being lost if an exception was raised and handled exceptions in the caller being lost if an exception was raised and handled
inside of the coroutine when yielding. (Github issue #1731) inside of the coroutine when yielding. (Github issue #1731)
......
...@@ -127,9 +127,6 @@ class FusedCFuncDefNode(StatListNode): ...@@ -127,9 +127,6 @@ class FusedCFuncDefNode(StatListNode):
# len(permutations)) # len(permutations))
# import pprint; pprint.pprint([d for cname, d in permutations]) # import pprint; pprint.pprint([d for cname, d in permutations])
if self.node.entry in env.cfunc_entries:
env.cfunc_entries.remove(self.node.entry)
# Prevent copying of the python function # Prevent copying of the python function
self.orig_py_func = orig_py_func = self.node.py_func self.orig_py_func = orig_py_func = self.node.py_func
self.node.py_func = None self.node.py_func = None
...@@ -139,12 +136,26 @@ class FusedCFuncDefNode(StatListNode): ...@@ -139,12 +136,26 @@ class FusedCFuncDefNode(StatListNode):
fused_types = self.node.type.get_fused_types() fused_types = self.node.type.get_fused_types()
self.fused_compound_types = fused_types self.fused_compound_types = fused_types
new_cfunc_entries = []
for cname, fused_to_specific in permutations: for cname, fused_to_specific in permutations:
copied_node = copy.deepcopy(self.node) copied_node = copy.deepcopy(self.node)
# Make the types in our CFuncType specific # Make the types in our CFuncType specific.
type = copied_node.type.specialize(fused_to_specific) type = copied_node.type.specialize(fused_to_specific)
entry = copied_node.entry entry = copied_node.entry
type.specialize_entry(entry, cname)
# Reuse existing Entries (e.g. from .pxd files).
for i, orig_entry in enumerate(env.cfunc_entries):
if entry.cname == orig_entry.cname and type.same_as_resolved_type(orig_entry.type):
copied_node.entry = env.cfunc_entries[i]
if not copied_node.entry.func_cname:
copied_node.entry.func_cname = entry.func_cname
entry = copied_node.entry
type = entry.type
break
else:
new_cfunc_entries.append(entry)
copied_node.type = type copied_node.type = type
entry.type, type.entry = type, entry entry.type, type.entry = type, entry
...@@ -165,9 +176,6 @@ class FusedCFuncDefNode(StatListNode): ...@@ -165,9 +176,6 @@ class FusedCFuncDefNode(StatListNode):
self._specialize_function_args(copied_node.cfunc_declarator.args, self._specialize_function_args(copied_node.cfunc_declarator.args,
fused_to_specific) fused_to_specific)
type.specialize_entry(entry, cname)
env.cfunc_entries.append(entry)
# If a cpdef, declare all specialized cpdefs (this # If a cpdef, declare all specialized cpdefs (this
# also calls analyse_declarations) # also calls analyse_declarations)
copied_node.declare_cpdef_wrapper(env) copied_node.declare_cpdef_wrapper(env)
...@@ -181,6 +189,14 @@ class FusedCFuncDefNode(StatListNode): ...@@ -181,6 +189,14 @@ class FusedCFuncDefNode(StatListNode):
if not self.replace_fused_typechecks(copied_node): if not self.replace_fused_typechecks(copied_node):
break break
# replace old entry with new entries
try:
cindex = env.cfunc_entries.index(self.node.entry)
except ValueError:
env.cfunc_entries.extend(new_cfunc_entries)
else:
env.cfunc_entries[cindex:cindex+1] = new_cfunc_entries
if orig_py_func: if orig_py_func:
self.py_func = self.make_fused_cpdef(orig_py_func, env, self.py_func = self.make_fused_cpdef(orig_py_func, env,
is_def=False) is_def=False)
......
...@@ -1535,6 +1535,13 @@ class ForwardDeclareTypes(CythonTransform): ...@@ -1535,6 +1535,13 @@ class ForwardDeclareTypes(CythonTransform):
def visit_CClassDefNode(self, node): def visit_CClassDefNode(self, node):
if node.class_name not in self.module_scope.entries: if node.class_name not in self.module_scope.entries:
node.declare(self.module_scope) node.declare(self.module_scope)
# Expand fused methods of .pxd declared types to construct the final vtable order.
type = self.module_scope.entries[node.class_name].type
if type is not None and type.is_extension_type and not type.is_builtin_type and type.scope:
scope = type.scope
for entry in scope.cfunc_entries:
if entry.type and entry.type.is_fused:
entry.type.get_all_specialized_function_types()
return node return node
......
...@@ -2884,12 +2884,10 @@ class CFuncType(CType): ...@@ -2884,12 +2884,10 @@ class CFuncType(CType):
elif self.cached_specialized_types is not None: elif self.cached_specialized_types is not None:
return self.cached_specialized_types return self.cached_specialized_types
cfunc_entries = self.entry.scope.cfunc_entries
cfunc_entries.remove(self.entry)
result = [] result = []
permutations = self.get_all_specialized_permutations() permutations = self.get_all_specialized_permutations()
new_cfunc_entries = []
for cname, fused_to_specific in permutations: for cname, fused_to_specific in permutations:
new_func_type = self.entry.type.specialize(fused_to_specific) new_func_type = self.entry.type.specialize(fused_to_specific)
...@@ -2904,7 +2902,15 @@ class CFuncType(CType): ...@@ -2904,7 +2902,15 @@ class CFuncType(CType):
new_func_type.entry = new_entry new_func_type.entry = new_entry
result.append(new_func_type) result.append(new_func_type)
cfunc_entries.append(new_entry) new_cfunc_entries.append(new_entry)
cfunc_entries = self.entry.scope.cfunc_entries
try:
cindex = cfunc_entries.index(self.entry)
except ValueError:
cfunc_entries.extend(new_cfunc_entries)
else:
cfunc_entries[cindex:cindex+1] = new_cfunc_entries
self.cached_specialized_types = result self.cached_specialized_types = result
......
...@@ -800,13 +800,23 @@ class Scope(object): ...@@ -800,13 +800,23 @@ class Scope(object):
type.entry = entry type.entry = entry
return entry return entry
def add_cfunction(self, name, type, pos, cname, visibility, modifiers): def add_cfunction(self, name, type, pos, cname, visibility, modifiers, inherited=False):
# Add a C function entry without giving it a func_cname. # Add a C function entry without giving it a func_cname.
entry = self.declare(name, cname, type, pos, visibility) entry = self.declare(name, cname, type, pos, visibility)
entry.is_cfunction = 1 entry.is_cfunction = 1
if modifiers: if modifiers:
entry.func_modifiers = modifiers entry.func_modifiers = modifiers
self.cfunc_entries.append(entry) if inherited or type.is_fused:
self.cfunc_entries.append(entry)
else:
# For backwards compatibility reasons, we must keep all non-fused methods
# before all fused methods, but separately for each type.
i = len(self.cfunc_entries)
for cfunc_entry in reversed(self.cfunc_entries):
if cfunc_entry.is_inherited or not cfunc_entry.type.is_fused:
break
i -= 1
self.cfunc_entries.insert(i, entry)
return entry return entry
def find(self, name, pos): def find(self, name, pos):
...@@ -2166,11 +2176,11 @@ class CClassScope(ClassScope): ...@@ -2166,11 +2176,11 @@ class CClassScope(ClassScope):
return entry return entry
def add_cfunction(self, name, type, pos, cname, visibility, modifiers): def add_cfunction(self, name, type, pos, cname, visibility, modifiers, inherited=False):
# Add a cfunction entry without giving it a func_cname. # Add a cfunction entry without giving it a func_cname.
prev_entry = self.lookup_here(name) prev_entry = self.lookup_here(name)
entry = ClassScope.add_cfunction(self, name, type, pos, cname, entry = ClassScope.add_cfunction(self, name, type, pos, cname,
visibility, modifiers) visibility, modifiers, inherited=inherited)
entry.is_cmethod = 1 entry.is_cmethod = 1
entry.prev_entry = prev_entry entry.prev_entry = prev_entry
return entry return entry
...@@ -2231,7 +2241,7 @@ class CClassScope(ClassScope): ...@@ -2231,7 +2241,7 @@ class CClassScope(ClassScope):
cname = adapt(cname) cname = adapt(cname)
entry = self.add_cfunction(base_entry.name, base_entry.type, entry = self.add_cfunction(base_entry.name, base_entry.type,
base_entry.pos, cname, base_entry.pos, cname,
base_entry.visibility, base_entry.func_modifiers) base_entry.visibility, base_entry.func_modifiers, inherited=True)
entry.is_inherited = 1 entry.is_inherited = 1
if base_entry.is_final_cmethod: if base_entry.is_final_cmethod:
entry.is_final_cmethod = True entry.is_final_cmethod = True
......
"""
PYTHON setup.py build_ext -i
PYTHON main.py
"""
######## main.py ########
from __future__ import absolute_import
from pkg.user import UseRegisters
def test():
from pkg import called
assert called == [], called
ureg = UseRegisters()
assert called == [
'Before setFullFlags',
'setFullFlags was called',
'After setFullFlags',
], called
del called[:]
ureg.call_write()
assert called == [
'Before regWriteWithOpWords',
'regWriteWithOpWords was called',
'regWriteWithOpWords leave function',
'After regWriteWithOpWords',
], called
del called[:]
ureg.call_non_fused()
assert called == [
'Before nonFusedMiddle',
'nonFusedMiddle was called',
'After nonFusedMiddle',
'Before nonFusedBottom',
'nonFusedBottom was called',
'After nonFusedBottom',
'Before nonFusedTop',
'nonFusedTop was called',
'After nonFusedTop',
], called
def test_sub():
from pkg import called
from pkg.registers import SubRegisters
ureg = UseRegisters(reg_type=SubRegisters)
del called[:]
ureg.call_sub()
assert called == [
'Before nonFusedSub',
'nonFusedSub was called',
'After nonFusedSub',
'Before fusedSub',
'fusedSub was called',
'After fusedSub',
], called
test()
test_sub()
######## setup.py ########
from distutils.core import setup
from Cython.Build import cythonize
setup(ext_modules = cythonize('pkg/*.pyx'))
######## pkg/__init__.py ########
called = []
######## pkg/user.pxd ########
from libc.stdint cimport *
from .registers cimport Registers, SubRegisters
cdef class UseRegisters:
cdef Registers registers
######## pkg/user.pyx ########
from . import called
cdef class UseRegisters:
def __init__(self, reg_type=Registers):
self.registers = reg_type()
called.append("Before setFullFlags")
self.registers.setFullFlags(<uint32_t>0xffffffff, <uint32_t>0)
called.append("After setFullFlags")
def call_write(self):
called.append("Before regWriteWithOpWords")
self.registers.regWriteWithOpWords(0, <uint32_t>0)
called.append("After regWriteWithOpWords")
def call_non_fused(self):
called.append("Before nonFusedMiddle")
self.registers.nonFusedMiddle(0, <uint32_t>0)
called.append("After nonFusedMiddle")
called.append("Before nonFusedBottom")
self.registers.nonFusedBottom(0, <uint32_t>0)
called.append("After nonFusedBottom")
called.append("Before nonFusedTop")
self.registers.nonFusedTop(0, <uint32_t>0)
called.append("After nonFusedTop")
def call_sub(self):
assert isinstance(self.registers, SubRegisters), type(self.registers)
called.append("Before nonFusedSub")
(<SubRegisters>self.registers).nonFusedSub(0, <uint32_t>0)
called.append("After nonFusedSub")
called.append("Before fusedSub")
(<SubRegisters>self.registers).fusedSub(0, <uint32_t>0)
called.append("After fusedSub")
######## pkg/registers.pxd ########
from libc.stdint cimport *
cdef:
ctypedef fused uint8_t_uint16_t_uint32_t:
uint8_t
uint16_t
uint32_t
ctypedef fused uint16_t_uint32_t_uint64_t:
uint16_t
uint32_t
uint64_t
cdef class Registers:
cdef uint64_t regs[1]
cdef void nonFusedTop(self, uint16_t regId, uint32_t value)
cdef void regWriteWithOpWords(self, uint16_t regId, uint16_t_uint32_t_uint64_t value)
cdef void nonFusedMiddle(self, uint16_t regId, uint32_t value)
cdef void setFullFlags(self, uint8_t_uint16_t_uint32_t reg0, uint32_t reg1)
cdef void nonFusedBottom(self, uint16_t regId, uint32_t value)
cdef void lastFusedImplFirst(self, uint8_t_uint16_t_uint32_t reg0, uint32_t reg1)
cdef class SubRegisters(Registers):
cdef void fusedSub(self, uint8_t_uint16_t_uint32_t reg0, uint32_t reg1)
cdef void nonFusedSub(self, uint16_t regId, uint32_t value)
######## pkg/registers.pyx ########
from . import called
cdef class Registers:
def __init__(self):
pass
cdef void lastFusedImplFirst(self, uint8_t_uint16_t_uint32_t reg0, uint32_t reg1):
called.append("lastFusedImplFirst was called")
cdef void nonFusedTop(self, uint16_t regId, uint32_t value):
called.append("nonFusedTop was called")
cdef void regWriteWithOpWords(self, uint16_t regId, uint16_t_uint32_t_uint64_t value):
called.append("regWriteWithOpWords was called")
self.regs[regId] = value
called.append("regWriteWithOpWords leave function")
cdef void nonFusedMiddle(self, uint16_t regId, uint32_t value):
called.append("nonFusedMiddle was called")
cdef void setFullFlags(self, uint8_t_uint16_t_uint32_t reg0, uint32_t reg1):
called.append("setFullFlags was called")
cdef void nonFusedBottom(self, uint16_t regId, uint32_t value):
called.append("nonFusedBottom was called")
cdef class SubRegisters(Registers):
cdef void fusedSub(self, uint8_t_uint16_t_uint32_t reg0, uint32_t reg1):
called.append("fusedSub was called")
cdef void nonFusedSub(self, uint16_t regId, uint32_t value):
called.append("nonFusedSub was called")
######## pkg/sub.pxd ########
from .registers cimport *
cdef class SubSubRegisters(SubRegisters):
cdef void fusedSubSubFirst(self, uint8_t_uint16_t_uint32_t reg0, uint32_t reg1)
cdef void nonFusedSubSub(self, uint16_t regId, uint32_t value)
cdef void fusedSubSubLast(self, uint8_t_uint16_t_uint32_t reg0, uint32_t reg1)
######## pkg/sub.pyx ########
from . import called
cdef class SubSubRegisters(SubRegisters):
cdef void fusedSubSubFirst(self, uint8_t_uint16_t_uint32_t reg0, uint32_t reg1):
called.append("fusedSubSubFirst was called")
cdef void nonFusedSubSub(self, uint16_t regId, uint32_t value):
called.append("nonFusedSubSub was called")
cdef void fusedSubSubLast(self, uint8_t_uint16_t_uint32_t reg0, uint32_t reg1):
called.append("fusedSubSubLast was called")
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment