Commit f4e4484f authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn Committed by Mark Florisson

Support utility code written in Cython

parent 4855fccb
...@@ -57,6 +57,9 @@ class UtilityCode(object): ...@@ -57,6 +57,9 @@ class UtilityCode(object):
self.specialize_list = [] self.specialize_list = []
self.proto_block = proto_block self.proto_block = proto_block
def get_tree(self):
pass
def specialize(self, pyrex_type=None, **data): def specialize(self, pyrex_type=None, **data):
# Dicts aren't hashable... # Dicts aren't hashable...
if pyrex_type is not None: if pyrex_type is not None:
......
...@@ -69,16 +69,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -69,16 +69,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
selfscope = self.scope selfscope = self.scope
selfscope.utility_code_list.extend(scope.utility_code_list) selfscope.utility_code_list.extend(scope.utility_code_list)
if merge_scope: if merge_scope:
selfscope.entries.update(scope.entries) selfscope.merge_in(scope)
for x in ('const_entries',
'type_entries',
'sue_entries',
'arg_entries',
'var_entries',
'pyfunc_entries',
'cfunc_entries',
'c_class_entries'):
getattr(selfscope, x).extend(getattr(scope, x))
def analyse_declarations(self, env): def analyse_declarations(self, env):
if Options.embed_pos_in_docstring: if Options.embed_pos_in_docstring:
......
...@@ -60,11 +60,22 @@ def inject_pxd_code_stage_factory(context): ...@@ -60,11 +60,22 @@ def inject_pxd_code_stage_factory(context):
return module_node return module_node
return inject_pxd_code_stage return inject_pxd_code_stage
def inject_utility_code_stage(module_node):
added = []
# need to copy list as the list will be altered!
for utilcode in module_node.scope.utility_code_list[:]:
if utilcode in added: continue
added.append(utilcode)
tree = utilcode.get_tree()
if tree:
module_node.merge_in(tree.body, tree.scope, merge_scope=True)
return module_node
# #
# Pipeline factories # Pipeline factories
# #
def create_pipeline(context, mode): def create_pipeline(context, mode, exclude_classes=()):
assert mode in ('pyx', 'py', 'pxd') assert mode in ('pyx', 'py', 'pxd')
from Visitor import PrintTree from Visitor import PrintTree
from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, PxdPostParse from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, PxdPostParse
...@@ -98,7 +109,7 @@ def create_pipeline(context, mode): ...@@ -98,7 +109,7 @@ def create_pipeline(context, mode):
else: else:
_align_function_definitions = None _align_function_definitions = None
return [ stages = [
NormalizeTree(context), NormalizeTree(context),
PostParse(context), PostParse(context),
_specific_post_parse, _specific_post_parse,
...@@ -134,9 +145,13 @@ def create_pipeline(context, mode): ...@@ -134,9 +145,13 @@ def create_pipeline(context, mode):
FinalOptimizePhase(context), FinalOptimizePhase(context),
GilCheck(), GilCheck(),
] ]
filtered_stages = []
for s in stages:
if s.__class__ not in exclude_classes:
filtered_stages.append(s)
return filtered_stages
def create_pyx_pipeline(context, options, result, py=False, exclude_classes=()):
def create_pyx_pipeline(context, options, result, py=False):
if py: if py:
mode = 'py' mode = 'py'
else: else:
...@@ -157,9 +172,10 @@ def create_pyx_pipeline(context, options, result, py=False): ...@@ -157,9 +172,10 @@ def create_pyx_pipeline(context, options, result, py=False):
return list(itertools.chain( return list(itertools.chain(
[parse_stage_factory(context)], [parse_stage_factory(context)],
create_pipeline(context, mode), create_pipeline(context, mode, exclude_classes=exclude_classes),
test_support, test_support,
[inject_pxd_code_stage_factory(context), [inject_pxd_code_stage_factory(context),
inject_utility_code_stage,
abort_on_errors], abort_on_errors],
debug_transform, debug_transform,
[generate_pyx_code_stage_factory(options, result)])) [generate_pyx_code_stage_factory(options, result)]))
...@@ -184,17 +200,15 @@ def create_pyx_as_pxd_pipeline(context, result): ...@@ -184,17 +200,15 @@ def create_pyx_as_pxd_pipeline(context, result):
from Optimize import ConstantFolding, FlattenInListTransform from Optimize import ConstantFolding, FlattenInListTransform
from Nodes import StatListNode from Nodes import StatListNode
pipeline = [] pipeline = []
pyx_pipeline = create_pyx_pipeline(context, context.options, result) pyx_pipeline = create_pyx_pipeline(context, context.options, result,
exclude_classes=[
AlignFunctionDefinitions,
MarkClosureVisitor,
ConstantFolding,
FlattenInListTransform,
WithTransform
])
for stage in pyx_pipeline: for stage in pyx_pipeline:
if stage.__class__ in [
AlignFunctionDefinitions,
MarkClosureVisitor,
ConstantFolding,
FlattenInListTransform,
WithTransform,
]:
# Skip these unnecessary stages.
continue
pipeline.append(stage) pipeline.append(stage)
if isinstance(stage, AnalyseDeclarationsTransform): if isinstance(stage, AnalyseDeclarationsTransform):
# This is the last stage we need. # This is the last stage we need.
......
...@@ -1968,6 +1968,9 @@ class StructUtilityCode(object): ...@@ -1968,6 +1968,9 @@ class StructUtilityCode(object):
def __hash__(self): def __hash__(self):
return hash(self.header) return hash(self.header)
def get_tree(self):
pass
def put_code(self, output): def put_code(self, output):
code = output['utility_code_def'] code = output['utility_code_def']
proto = output['utility_code_proto'] proto = output['utility_code_proto']
...@@ -1995,6 +1998,9 @@ class StructUtilityCode(object): ...@@ -1995,6 +1998,9 @@ class StructUtilityCode(object):
proto.putln(self.type.declaration_code('') + ';') proto.putln(self.type.declaration_code('') + ';')
proto.putln(self.header + ";") proto.putln(self.header + ";")
def inject_tree_and_scope_into(self, module_node):
pass
class CStructOrUnionType(CType): class CStructOrUnionType(CType):
# name string # name string
......
...@@ -122,6 +122,9 @@ class Entry(object): ...@@ -122,6 +122,9 @@ class Entry(object):
# assignments [ExprNode] List of expressions that get assigned to this entry. # assignments [ExprNode] List of expressions that get assigned to this entry.
# might_overflow boolean In an arithmetic expression that could cause # might_overflow boolean In an arithmetic expression that could cause
# overflow (used for type inference). # overflow (used for type inference).
# utility_code_definition For some Cython builtins, the utility code
# which contains the definition of the entry.
# Currently only supported for CythonScope entries.
inline_func_in_pxd = False inline_func_in_pxd = False
borrowed = 0 borrowed = 0
...@@ -173,6 +176,7 @@ class Entry(object): ...@@ -173,6 +176,7 @@ class Entry(object):
buffer_aux = None buffer_aux = None
prev_entry = None prev_entry = None
might_overflow = 0 might_overflow = 0
utility_code_definition = None
in_with_gil_block = 0 in_with_gil_block = 0
def __init__(self, name, cname, type, pos = None, init = None): def __init__(self, name, cname, type, pos = None, init = None):
...@@ -274,6 +278,20 @@ class Scope(object): ...@@ -274,6 +278,20 @@ class Scope(object):
self.return_type = None self.return_type = None
self.id_counters = {} self.id_counters = {}
def merge_in(self, other):
# Use with care...
self.entries.update(other.entries)
for x in ('const_entries',
'type_entries',
'sue_entries',
'arg_entries',
'var_entries',
'pyfunc_entries',
'cfunc_entries',
'c_class_entries'):
getattr(self, x).extend(getattr(other, x))
def __str__(self): def __str__(self):
return "<%s %s>" % (self.__class__.__name__, self.qualified_name) return "<%s %s>" % (self.__class__.__name__, self.qualified_name)
...@@ -879,8 +897,9 @@ class ModuleScope(Scope): ...@@ -879,8 +897,9 @@ class ModuleScope(Scope):
has_import_star = 0 has_import_star = 0
def __init__(self, name, parent_module, context): def __init__(self, name, parent_module, context):
import Builtin
self.parent_module = parent_module self.parent_module = parent_module
outer_scope = context.find_submodule("__builtin__") outer_scope = Builtin.builtin_scope
Scope.__init__(self, name, outer_scope, parent_module) Scope.__init__(self, name, outer_scope, parent_module)
if name != "__init__": if name != "__init__":
self.module_name = name self.module_name = name
...@@ -921,7 +940,8 @@ class ModuleScope(Scope): ...@@ -921,7 +940,8 @@ class ModuleScope(Scope):
entry = self.lookup_here(name) entry = self.lookup_here(name)
if entry is not None: if entry is not None:
return entry return entry
return self.outer_scope.lookup(name, language_level = self.context.language_level) language_level = self.context.language_level if self.context is not None else 3
return self.outer_scope.lookup(name, language_level=language_level)
def declare_builtin(self, name, pos): def declare_builtin(self, name, pos):
if not hasattr(builtins, name) \ if not hasattr(builtins, name) \
......
...@@ -20,7 +20,8 @@ Support for parsing strings into code trees. ...@@ -20,7 +20,8 @@ Support for parsing strings into code trees.
""" """
class StringParseContext(Main.Context): class StringParseContext(Main.Context):
def __init__(self, include_directories, name): def __init__(self, name, include_directories=None):
if include_directories is None: include_directories = []
Main.Context.__init__(self, include_directories, {}) Main.Context.__init__(self, include_directories, {})
self.module_name = name self.module_name = name
...@@ -29,7 +30,8 @@ class StringParseContext(Main.Context): ...@@ -29,7 +30,8 @@ class StringParseContext(Main.Context):
raise AssertionError("Not yet supporting any cimports/includes from string code snippets") raise AssertionError("Not yet supporting any cimports/includes from string code snippets")
return ModuleScope(module_name, parent_module = None, context = self) return ModuleScope(module_name, parent_module = None, context = self)
def parse_from_strings(name, code, pxds={}, level=None, initial_pos=None): def parse_from_strings(name, code, pxds={}, level=None, initial_pos=None,
context=None):
""" """
Utility method to parse a (unicode) string of code. This is mostly Utility method to parse a (unicode) string of code. This is mostly
used for internal Cython compiler purposes (creating code snippets used for internal Cython compiler purposes (creating code snippets
...@@ -37,8 +39,14 @@ def parse_from_strings(name, code, pxds={}, level=None, initial_pos=None): ...@@ -37,8 +39,14 @@ def parse_from_strings(name, code, pxds={}, level=None, initial_pos=None):
code - a unicode string containing Cython (module-level) code code - a unicode string containing Cython (module-level) code
name - a descriptive name for the code source (to use in error messages etc.) name - a descriptive name for the code source (to use in error messages etc.)
"""
RETURNS
The tree, i.e. a ModuleNode. The ModuleNode's scope attribute is
set to the scope used when parsing.
"""
if context is None:
context = StringParseContext(name)
# Since source files carry an encoding, it makes sense in this context # Since source files carry an encoding, it makes sense in this context
# to use a unicode string so that code fragments don't have to bother # to use a unicode string so that code fragments don't have to bother
# with encoding. This means that test code passed in should not have an # with encoding. This means that test code passed in should not have an
...@@ -51,7 +59,6 @@ def parse_from_strings(name, code, pxds={}, level=None, initial_pos=None): ...@@ -51,7 +59,6 @@ def parse_from_strings(name, code, pxds={}, level=None, initial_pos=None):
initial_pos = (name, 1, 0) initial_pos = (name, 1, 0)
code_source = StringSourceDescriptor(name, code) code_source = StringSourceDescriptor(name, code)
context = StringParseContext([], name)
scope = context.find_module(module_name, pos = initial_pos, need_pxd = 0) scope = context.find_module(module_name, pos = initial_pos, need_pxd = 0)
buf = StringIO(code) buf = StringIO(code)
...@@ -61,8 +68,10 @@ def parse_from_strings(name, code, pxds={}, level=None, initial_pos=None): ...@@ -61,8 +68,10 @@ def parse_from_strings(name, code, pxds={}, level=None, initial_pos=None):
if level is None: if level is None:
tree = Parsing.p_module(scanner, 0, module_name) tree = Parsing.p_module(scanner, 0, module_name)
tree.scope = scope tree.scope = scope
tree.is_pxd = False
else: else:
tree = Parsing.p_code(scanner, level=level) tree = Parsing.p_code(scanner, level=level)
tree.scope = scope
return tree return tree
class TreeCopier(VisitorTransform): class TreeCopier(VisitorTransform):
......
from TreeFragment import parse_from_strings, StringParseContext
from Scanning import StringSourceDescriptor
import Symtab
import Naming
class NonManglingModuleScope(Symtab.ModuleScope):
def mangle(self, prefix, name=None):
if name:
if prefix in (Naming.typeobj_prefix, Naming.func_prefix):
# Functions, classes etc. gets a manually defined prefix easily
# manually callable instead (the one passed to CythonUtilityCode)
prefix = self.prefix
return "%s%s" % (prefix, name)
else:
return self.base.name
class CythonUtilityCodeContext(StringParseContext):
scope = None
def find_module(self, module_name, relative_to = None, pos = None, need_pxd = 1):
if module_name != self.module_name:
raise AssertionError("Not yet supporting any cimports/includes from string code snippets")
if self.scope is None:
self.scope = NonManglingModuleScope(module_name, parent_module = None, context = self)
self.scope.prefix = self.prefix
return self.scope
class CythonUtilityCode:
"""
Utility code written in the Cython language itself.
"""
def __init__(self, pyx, name="<utility code>", prefix=""):
# 1) We need to delay the parsing/processing, so that all modules can be
# imported without import loops
# 2) The same utility code object can be used for multiple source files;
# while the generated node trees can be altered in the compilation of a
# single file.
# Hence, delay any processing until later.
self.pyx = pyx
self.name = name
self.prefix = prefix
def get_tree(self):
from AnalysedTreeTransforms import AutoTestDictTransform
excludes = [AutoTestDictTransform]
import Pipeline
context = CythonUtilityCodeContext(self.name)
context.prefix = self.prefix
tree = parse_from_strings(self.name, self.pyx, context=context)
pipeline = Pipeline.create_pipeline(context, 'pyx', exclude_classes=excludes)
(err, tree) = Pipeline.run_pipeline(pipeline, tree)
assert not err
return tree
def put_code(self, output):
pass
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment