Commit 6353844b authored by Xavier Thompson's avatar Xavier Thompson

Fix unsafe cast of the underlying cypclass between unrelated bases

parent d8a312b2
...@@ -378,6 +378,23 @@ builtin_structs_table = [ ...@@ -378,6 +378,23 @@ builtin_structs_table = [
]) ])
] ]
# inject cyobject
def inject_cy_object(self):
global cy_object_type
def init_scope(scope):
scope.is_cpp_class_scope = 1
scope.is_cyp_class_scope = 1
scope.inherited_var_entries = []
scope.inherited_type_entries = []
cy_object_scope = Scope("CyObject", self, None)
init_scope(cy_object_scope)
cy_object_type = PyrexTypes.cy_object_type
cy_object_scope.type = PyrexTypes.cy_object_type
cy_object_type.set_scope(cy_object_scope)
cy_object_entry = self.declare("CyObject", "CyObject", cy_object_type, None, "extern")
cy_object_entry.is_type = 1
# inject acthon interfaces # inject acthon interfaces
def inject_acthon_interfaces(self): def inject_acthon_interfaces(self):
global acthon_result_type, acthon_message_type, acthon_sync_type, acthon_queue_type, acthon_activable_type global acthon_result_type, acthon_message_type, acthon_sync_type, acthon_queue_type, acthon_activable_type
...@@ -677,6 +694,7 @@ def init_builtins(): ...@@ -677,6 +694,7 @@ def init_builtins():
inject_cypclass_refcount_macros() inject_cypclass_refcount_macros()
inject_cypclass_lock_macros() inject_cypclass_lock_macros()
inject_acthon_interfaces(builtin_scope) inject_acthon_interfaces(builtin_scope)
inject_cy_object(builtin_scope)
init_builtins() init_builtins()
...@@ -220,9 +220,6 @@ class CypclassWrapperInjection(CythonTransform): ...@@ -220,9 +220,6 @@ class CypclassWrapperInjection(CythonTransform):
qualified_name, cclass_name, pyclass_name = self.type_to_names[node_type] qualified_name, cclass_name, pyclass_name = self.type_to_names[node_type]
# determine the oldest wrapped base type once and for all
node_type.find_wrapped_base_type()
cclass = self.synthesize_wrapper_cclass(node, cclass_name, qualified_name) cclass = self.synthesize_wrapper_cclass(node, cclass_name, qualified_name)
# mark this cypclass as having synthesized wrappers # mark this cypclass as having synthesized wrappers
...@@ -263,24 +260,25 @@ class CypclassWrapperInjection(CythonTransform): ...@@ -263,24 +260,25 @@ class CypclassWrapperInjection(CythonTransform):
bases_args = [] bases_args = []
first_wrapped_base = node_type.first_wrapped_base
wrapped_bases_iterator = node_type.iter_wrapped_base_types() wrapped_bases_iterator = node_type.iter_wrapped_base_types()
if first_wrapped_base: try:
# consume the first wrapped base from the iterator
first_wrapped_base = next(wrapped_bases_iterator)
first_base_cclass_name = first_wrapped_base.wrapper_type.name first_base_cclass_name = first_wrapped_base.wrapper_type.name
wrapped_first_base = NameNode(node.pos, name=first_base_cclass_name) wrapped_first_base = NameNode(node.pos, name=first_base_cclass_name)
bases_args.append(wrapped_first_base) bases_args.append(wrapped_first_base)
# consume the first wrapped base from the iterator
next(wrapped_bases_iterator)
# use the pyclass wrapper for the other bases # use the pyclass wrapper for the other bases
for other_base in wrapped_bases_iterator: for other_base in wrapped_bases_iterator:
_, __, other_base_pyclass_name = self.type_to_names[other_base] _, __, other_base_pyclass_name = self.type_to_names[other_base]
other_base_arg = NameNode(node.pos, name=other_base_pyclass_name) other_base_arg = NameNode(node.pos, name=other_base_pyclass_name)
bases_args.append(other_base_arg) bases_args.append(other_base_arg)
except StopIteration:
# no bases
pass
return TupleNode(node.pos, args=bases_args) return TupleNode(node.pos, args=bases_args)
def synthesize_wrapper_cclass(self, node, cclass_name, qualified_name): def synthesize_wrapper_cclass(self, node, cclass_name, qualified_name):
...@@ -326,19 +324,12 @@ class CypclassWrapperInjection(CythonTransform): ...@@ -326,19 +324,12 @@ class CypclassWrapperInjection(CythonTransform):
return wrapper return wrapper
def synthesize_underlying_cyobject_attribute(self, node): def synthesize_underlying_cyobject_attribute(self, node):
base_type = node.entry.type.wrapped_base_type base_type = cy_object_type
nesting_path = []
outer_scope = base_type.scope.outer_scope
while outer_scope and not outer_scope.is_module_scope:
nesting_path.append(outer_scope.name)
outer_scope = outer_scope.outer_scope
nesting_path.reverse()
base_type_node = Nodes.CSimpleBaseTypeNode( base_type_node = Nodes.CSimpleBaseTypeNode(
node.pos, node.pos,
name = base_type.name, name = base_type.name,
module_path = nesting_path, module_path = [],
is_basic_c_type = 0, is_basic_c_type = 0,
signed = 1, signed = 1,
complex = 0, complex = 0,
...@@ -1064,7 +1055,7 @@ def generate_cyp_class_wrapper_definition(type, wrapper_entry, constructor_entry ...@@ -1064,7 +1055,7 @@ def generate_cyp_class_wrapper_definition(type, wrapper_entry, constructor_entry
# initialise PyObject fields # initialise PyObject fields
if is_new_return_type and type.wrapper_type: if is_new_return_type and type.wrapper_type:
objstruct_cname = type.wrapper_type.objstruct_cname objstruct_cname = type.wrapper_type.objstruct_cname
cclass_wrapper_base = type.wrapped_base_type.wrapper_type cclass_wrapper_base = type.wrapped_base_type().wrapper_type
code.putln("if(self) {") code.putln("if(self) {")
code.putln("%s * wrapper = new %s();" % (objstruct_cname, objstruct_cname)) code.putln("%s * wrapper = new %s();" % (objstruct_cname, objstruct_cname))
code.putln("((%s *)wrapper)->nogil_cyobject = self;" % cclass_wrapper_base.objstruct_cname) code.putln("((%s *)wrapper)->nogil_cyobject = self;" % cclass_wrapper_base.objstruct_cname)
......
...@@ -1686,12 +1686,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -1686,12 +1686,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
# for cyp wrappers, just decrement the atomic counter of the underlying type # for cyp wrappers, just decrement the atomic counter of the underlying type
parent_type = scope.parent_type parent_type = scope.parent_type
if parent_type.is_cyp_wrapper: if parent_type.is_cyp_wrapper:
underlying_type_decl = parent_type.wrapped_decl
underlying_attribute_name = CypclassWrapper.underlying_name underlying_attribute_name = CypclassWrapper.underlying_name
self.generate_self_cast(scope, code) self.generate_self_cast(scope, code)
code.putln( code.putln(
"%s * p_nogil_cyobject = (%s *) (p->%s);" "CyObject * p_nogil_cyobject = p->%s;"
% (underlying_type_decl, underlying_type_decl, underlying_attribute_name) % underlying_attribute_name
) )
code.putln("Cy_DECREF(p_nogil_cyobject);") code.putln("Cy_DECREF(p_nogil_cyobject);")
code.putln("}") code.putln("}")
......
...@@ -4179,8 +4179,7 @@ class CypClassType(CppClassType): ...@@ -4179,8 +4179,7 @@ class CypClassType(CppClassType):
# _mro [CppClassType] or None the Method Resolution Order of this cypclass according to Python # _mro [CppClassType] or None the Method Resolution Order of this cypclass according to Python
# support_wrapper boolean whether this cypclass will be wrapped # support_wrapper boolean whether this cypclass will be wrapped
# wrapper_type PyExtensionType or None the type of the cclass wrapper # wrapper_type PyExtensionType or None the type of the cclass wrapper
# first_wrapped_base CypClassType or None the first cypclass base that has a wrapper if there is one # _wrapped_base_type CypClassType or None the type of the oldest wrapped cypclass base
# wrapped_base_type CypClassType or None the type of the oldest wrapped cypclass base
is_cyp_class = 1 is_cyp_class = 1
to_py_function = None to_py_function = None
...@@ -4192,26 +4191,23 @@ class CypClassType(CppClassType): ...@@ -4192,26 +4191,23 @@ class CypClassType(CppClassType):
self._mro = None self._mro = None
self.support_wrapper = False self.support_wrapper = False
self.wrapper_type = None self.wrapper_type = None
self.wrapped_base_type = None self._wrapped_base_type = None
# find the first base that has a wrapper, if there is one # return the oldest left-path superclass such that all intervening classes have a wrapper
# find the oldest superclass such that all intervening classes have a wrapper def wrapped_base_type(self):
def find_wrapped_base_type(self): # if the result has already been computed, return it
# default: the oldest superclass is self and there are no bases if self._wrapped_base_type is not None:
self.wrapped_base_type = self return self._wrapped_base_type
self.first_wrapped_base = None # find the first wrapped base (if there is one) and take the same oldest superclass
# if there are no bases, no need to look further
if not self.base_classes:
return
# otherwise, find the first wrapped base (if there is one) and take the same oldest superclass
for base_type in self.base_classes: for base_type in self.base_classes:
if base_type.is_cyp_class and base_type.support_wrapper: if base_type.is_cyp_class and base_type.support_wrapper:
# this base type is the first wrapped base self._wrapped_base_type = base_type.wrapped_base_type()
self.first_wrapped_base = base_type return self._wrapped_base_type
self.wrapped_base_type = base_type.wrapped_base_type # if no wrapped base was found, this type is the oldest wrapped base
break self._wrapped_base_type = self
return self
# iterate over the bases that support wrapping # iterate over the direct bases that support wrapping
def iter_wrapped_base_types(self): def iter_wrapped_base_types(self):
for base_type in self.base_classes: for base_type in self.base_classes:
if base_type.is_cyp_class and base_type.support_wrapper: if base_type.is_cyp_class and base_type.support_wrapper:
...@@ -4237,20 +4233,6 @@ class CypClassType(CppClassType): ...@@ -4237,20 +4233,6 @@ class CypClassType(CppClassType):
self._mro = mro_C3_merge(inputs) self._mro = mro_C3_merge(inputs)
return self._mro return self._mro
# iterate over the chain of first wrapped bases until the oldest wrapped base is reached
def first_base_iter(self):
type_item = self
while type_item is not self.wrapped_base_type:
type_item = type_item.first_wrapped_base
yield type_item
# iterate down the chain of first wrapped bases until this type is reached
def first_base_rev_iter(self):
if self is not self.wrapped_base_type:
for t in self.first_wrapped_base.first_base_rev_iter():
yield t
yield self
# allow conversion to Python only when there is a wrapper type # allow conversion to Python only when there is a wrapper type
def can_coerce_to_pyobject(self, env): def can_coerce_to_pyobject(self, env):
return self.wrapper_type is not None return self.wrapper_type is not None
...@@ -4264,7 +4246,7 @@ class CypClassType(CppClassType): ...@@ -4264,7 +4246,7 @@ class CypClassType(CppClassType):
def create_from_py_utility_code(self, env): def create_from_py_utility_code(self, env):
if not self.wrapper_type: if not self.wrapper_type:
return False return False
wrapper_objstruct = self.wrapped_base_type.wrapper_type.objstruct_cname wrapper_objstruct = self.wrapped_base_type().wrapper_type.objstruct_cname
underlying_type_name = self.cname underlying_type_name = self.cname
self.from_py_function = "__Pyx_PyObject_AsCyObject<%s, %s>" % (wrapper_objstruct, underlying_type_name) self.from_py_function = "__Pyx_PyObject_AsCyObject<%s, %s>" % (wrapper_objstruct, underlying_type_name)
return True return True
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment