Commit 32b1a50f authored by Stefan Behnel's avatar Stefan Behnel

rewrite of type hierarchy sorting patch

parent 65e017e1
...@@ -6,6 +6,11 @@ import os, time ...@@ -6,6 +6,11 @@ import os, time
from cStringIO import StringIO from cStringIO import StringIO
from PyrexTypes import CPtrType from PyrexTypes import CPtrType
try:
set
except NameError: # Python 2.3
from sets import Set as set
import Annotate import Annotate
import Code import Code
import Naming import Naming
...@@ -19,32 +24,6 @@ from Errors import error ...@@ -19,32 +24,6 @@ from Errors import error
from PyrexTypes import py_object_type from PyrexTypes import py_object_type
from Cython.Utils import open_new_file, replace_suffix from Cython.Utils import open_new_file, replace_suffix
def recurse_vtab_check_inheritance(entry, b, dict):
base = entry
while base is not None:
if base.type.base_type is None or base.type.base_type.vtabstruct_cname is None:
return False
if base.type.base_type.vtabstruct_cname == b.type.vtabstruct_cname:
return True
try:
base = dict[base.type.base_type.vtabstruct_cname]
except KeyError:
return False
return False
def recurse_vtabslot_check_inheritance(entry, b, dict):
base = entry
while base is not None:
if base.type.base_type is None:
return False
if base.type.base_type.objstruct_cname == b.type.objstruct_cname:
return True
try:
base = dict[base.type.base_type.objstruct_cname]
except KeyError:
return False
return False
class ModuleNode(Nodes.Node, Nodes.BlockNode): class ModuleNode(Nodes.Node, Nodes.BlockNode):
# doc string or None # doc string or None
...@@ -278,68 +257,69 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -278,68 +257,69 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
self.find_referenced_modules(imported_module, module_list, modules_seen) self.find_referenced_modules(imported_module, module_list, modules_seen)
module_list.append(env) module_list.append(env)
def generate_vtab_dict(self, module_list): def collect_inheritance_hierarchies(self, type_dict, getkey):
base_dict = {}
for key, entry in type_dict.items():
hierarchy = set()
base = entry
while base:
base_type = base.type.base_type
if not base_type:
break
base_key = getkey(base_type)
hierarchy.add(base_key)
base = type_dict.get(base_key)
entry.base_keys = hierarchy
base_dict[key] = entry
return base_dict
def sort_types_by_inheritance(self, base_dict):
type_items = base_dict.items()
type_list = []
for i, item in enumerate(type_items):
type_key, new_entry = item
for j in range(i):
entry = type_list[j]
if type_key in entry.base_keys:
type_list.insert(j, new_entry)
break
else:
type_list.append(new_entry)
return type_list
def sort_type_hierarchy(self, module_list, env):
vtab_dict = {} vtab_dict = {}
vtabslot_dict = {}
for module in module_list: for module in module_list:
for entry in module.c_class_entries: for entry in module.c_class_entries:
if not entry.in_cinclude: if not entry.in_cinclude:
type = entry.type type = entry.type
scope = type.scope
if type.vtabstruct_cname: if type.vtabstruct_cname:
vtab_dict[type.vtabstruct_cname]=entry vtab_dict[type.vtabstruct_cname] = entry
return vtab_dict all_defined_here = module is env
def generate_vtab_list(self, vtab_dict):
vtab_list = list()
for entry in vtab_dict.itervalues():
vtab_list.append(entry)
for i in range(0,len(vtab_list)):
for j in range(0,len(vtab_list)):
if(recurse_vtab_check_inheritance(vtab_list[j],vtab_list[i], vtab_dict)==1):
if i > j:
vtab_list.insert(j,vtab_list[i])
if i > j:
vtab_list.pop(i+1)
else:
vtab_list.pop(i)
return vtab_list
def generate_vtabslot_dict(self, module_list, env):
vtab_dict={}
type_entries=[]
for module in module_list:
definition = module is env
if definition:
type_entries.extend( env.type_entries)
else:
for entry in module.type_entries: for entry in module.type_entries:
if entry.defined_in_pxd: if all_defined_here or entry.defined_in_pxd:
type_entries.append(entry)
for entry in type_entries:
type = entry.type type = entry.type
if type.is_extension_type: if type.is_extension_type and not entry.in_cinclude:
if not entry.in_cinclude:
type = entry.type type = entry.type
scope = type.scope vtabslot_dict[type.objstruct_cname] = entry
vtab_dict[type.objstruct_cname]=entry
return vtab_dict def vtabstruct_cname(entry_type):
return entry_type.vtabstruct_cname
def generate_vtabslot_list(self, vtab_dict): vtab_hierarchies = self.sort_types_by_inheritance(
vtab_list = list() self.collect_inheritance_hierarchies(
for entry in vtab_dict.itervalues(): vtab_dict, vtabstruct_cname))
vtab_list.append(entry)
for i in range(0,len(vtab_list)):
for j in range(0,len(vtab_list)):
if(recurse_vtabslot_check_inheritance(vtab_list[j],vtab_list[i], vtab_dict)==1):
if i > j:
vtab_list.insert(j,vtab_list[i])
if i > j:
vtab_list.pop(i+1)
else:
vtab_list.pop(i)
return vtab_list
def objstruct_cname(entry_type):
return entry_type.objstruct_cname
vtabslot_hierarchies = self.sort_types_by_inheritance(
self.collect_inheritance_hierarchies(
vtabslot_dict, objstruct_cname))
return (vtab_hierarchies, vtabslot_hierarchies)
def generate_type_definitions(self, env, modules, vtab_list, vtabslot_list, code): def generate_type_definitions(self, env, modules, vtab_list, vtabslot_list, code):
vtabslot_entries = set(vtabslot_list)
for module in modules: for module in modules:
definition = module is env definition = module is env
if definition: if definition:
...@@ -359,7 +339,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -359,7 +339,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
self.generate_struct_union_definition(entry, code) self.generate_struct_union_definition(entry, code)
elif type.is_enum: elif type.is_enum:
self.generate_enum_definition(entry, code) self.generate_enum_definition(entry, code)
elif type.is_extension_type and (not (entry in vtabslot_list)): elif type.is_extension_type and entry not in vtabslot_entries:
self.generate_obj_struct_definition(type, code) self.generate_obj_struct_definition(type, code)
for entry in vtabslot_list: for entry in vtabslot_list:
self.generate_obj_struct_definition(entry.type, code) self.generate_obj_struct_definition(entry.type, code)
...@@ -368,19 +348,16 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -368,19 +348,16 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
self.generate_exttype_vtable_struct(entry, code) self.generate_exttype_vtable_struct(entry, code)
self.generate_exttype_vtabptr_declaration(entry, code) self.generate_exttype_vtabptr_declaration(entry, code)
def generate_declarations_for_modules(self, env, modules, code): def generate_declarations_for_modules(self, env, modules, code):
code.putln("") code.putln("")
code.putln("/* Declarations */") code.putln("/* Declarations */")
vtab_dict = self.generate_vtab_dict(modules) vtab_list, vtabslot_list = self.sort_type_hierarchy(modules, env)
vtab_list = self.generate_vtab_list(vtab_dict) self.generate_type_definitions(
vtabslot_dict = self.generate_vtabslot_dict(modules,env) env, modules, vtab_list, vtabslot_list, code)
vtabslot_list = self.generate_vtabslot_list(vtabslot_dict)
self.generate_type_definitions(env, modules, vtab_list, vtabslot_list, code)
for module in modules: for module in modules:
definition = module is env defined_here = module is env
self.generate_global_declarations(module, code, definition) self.generate_global_declarations(module, code, defined_here)
self.generate_cfunction_predeclarations(module, code, definition) self.generate_cfunction_predeclarations(module, code, defined_here)
def generate_module_preamble(self, env, cimported_modules, code): def generate_module_preamble(self, env, cimported_modules, code):
code.putln('/* Generated by Cython %s on %s */' % ( code.putln('/* Generated by Cython %s on %s */' % (
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment