Commit 96eea730 authored by Robert Bradshaw's avatar Robert Bradshaw

Make cpdef enums into first-class types.

For example

    cpdef enum Eggs
        SOFT
        HARD
        SCRAMBLED

produces three constants with int values that print as strings,
and a type Eggs with attributes Eggs.SOFT, etc. and list(Eggs)
giving the set of all enum values.  Instantiating Eggs with a
numeric or string value will return the appropriate constant.
parent 7da49602
...@@ -14,6 +14,8 @@ Features added ...@@ -14,6 +14,8 @@ Features added
* The embedded C code comments that show the original source code * The embedded C code comments that show the original source code
can be discarded with the new directive ``emit_code_comments=False``. can be discarded with the new directive ``emit_code_comments=False``.
* Cpdef enums are now first-class iterable, callable types in Python.
Bugs fixed Bugs fixed
---------- ----------
......
...@@ -1488,7 +1488,7 @@ class CEnumDefNode(StatNode): ...@@ -1488,7 +1488,7 @@ class CEnumDefNode(StatNode):
self.entry = env.declare_enum(self.name, self.pos, self.entry = env.declare_enum(self.name, self.pos,
cname = self.cname, typedef_flag = self.typedef_flag, cname = self.cname, typedef_flag = self.typedef_flag,
visibility = self.visibility, api = self.api, visibility = self.visibility, api = self.api,
create_wrapper = self.create_wrapper) create_wrapper = self.create_wrapper and self.name is None)
def analyse_declarations(self, env): def analyse_declarations(self, env):
if self.items is not None: if self.items is not None:
...@@ -1496,6 +1496,12 @@ class CEnumDefNode(StatNode): ...@@ -1496,6 +1496,12 @@ class CEnumDefNode(StatNode):
self.entry.defined_in_pxd = 1 self.entry.defined_in_pxd = 1
for item in self.items: for item in self.items:
item.analyse_declarations(env, self.entry) item.analyse_declarations(env, self.entry)
if self.create_wrapper and self.name is not None:
from .UtilityCode import CythonUtilityCode
env.use_utility_code(CythonUtilityCode.load(
"EnumType", "CpdefEnums.pyx",
context={"name": self.name, "items": tuple(item.name for item in self.items)},
outer_module_scope=env.global_scope()))
def analyse_expressions(self, env): def analyse_expressions(self, env):
return self return self
...@@ -1535,7 +1541,7 @@ class CEnumDefItemNode(StatNode): ...@@ -1535,7 +1541,7 @@ class CEnumDefItemNode(StatNode):
entry = env.declare_const(self.name, enum_entry.type, entry = env.declare_const(self.name, enum_entry.type,
self.value, self.pos, cname = self.cname, self.value, self.pos, cname = self.cname,
visibility = enum_entry.visibility, api = enum_entry.api, visibility = enum_entry.visibility, api = enum_entry.api,
create_wrapper = enum_entry.create_wrapper) create_wrapper = enum_entry.create_wrapper and enum_entry.name is None)
enum_entry.enum_values.append(entry) enum_entry.enum_values.append(entry)
if enum_entry.name: if enum_entry.name:
enum_entry.type.values.append(entry.cname) enum_entry.type.values.append(entry.cname)
......
...@@ -80,11 +80,36 @@ def use_utility_code_definitions(scope, target, seen=None): ...@@ -80,11 +80,36 @@ def use_utility_code_definitions(scope, target, seen=None):
elif entry.as_module: elif entry.as_module:
use_utility_code_definitions(entry.as_module, target, seen) use_utility_code_definitions(entry.as_module, target, seen)
def sort_utility_codes(utilcodes):
ranks = {}
def get_rank(utilcode):
if utilcode not in ranks:
ranks[utilcode] = 0
ranks[utilcode] = 1 + min([get_rank(dep) for dep in utilcode.requires or ()] or [-1])
return ranks[utilcode]
for utilcode in utilcodes:
get_rank(utilcode)
return [utilcode for utilcode, _ in sorted(ranks.items(), key=lambda kv: kv[1])]
def normalize_deps(utilcodes):
deps = {}
for utilcode in utilcodes:
deps[utilcode] = utilcode
def unify_dep(dep):
if dep in deps:
return deps[dep]
else:
deps[dep] = dep
return dep
for utilcode in utilcodes:
utilcode.requires = [unify_dep(dep) for dep in utilcode.requires or ()]
def inject_utility_code_stage_factory(context): def inject_utility_code_stage_factory(context):
def inject_utility_code_stage(module_node): def inject_utility_code_stage(module_node):
module_node.prepare_utility_code() module_node.prepare_utility_code()
use_utility_code_definitions(context.cython_scope, module_node.scope) use_utility_code_definitions(context.cython_scope, module_node.scope)
module_node.scope.utility_code_list = sort_utility_codes(module_node.scope.utility_code_list)
normalize_deps(module_node.scope.utility_code_list)
added = [] added = []
# Note: the list might be extended inside the loop (if some utility code # Note: the list might be extended inside the loop (if some utility code
# pulls in other utility code, explicitly or implicitly) # pulls in other utility code, explicitly or implicitly)
......
...@@ -146,6 +146,16 @@ class CythonUtilityCode(Code.UtilityCodeBase): ...@@ -146,6 +146,16 @@ class CythonUtilityCode(Code.UtilityCodeBase):
pipeline = Pipeline.insert_into_pipeline(pipeline, scope_transform, pipeline = Pipeline.insert_into_pipeline(pipeline, scope_transform,
before=transform) before=transform)
for dep in self.requires:
if isinstance(dep, CythonUtilityCode):
def scope_transform(module_node):
module_node.scope.merge_in(dep.tree.scope)
return module_node
transform = ParseTreeTransforms.AnalyseDeclarationsTransform
pipeline = Pipeline.insert_into_pipeline(pipeline, scope_transform,
before=transform)
if self.outer_module_scope: if self.outer_module_scope:
# inject outer module between utility code module and builtin module # inject outer module between utility code module and builtin module
def scope_transform(module_node): def scope_transform(module_node):
...@@ -158,6 +168,7 @@ class CythonUtilityCode(Code.UtilityCodeBase): ...@@ -158,6 +168,7 @@ class CythonUtilityCode(Code.UtilityCodeBase):
(err, tree) = Pipeline.run_pipeline(pipeline, tree, printtree=False) (err, tree) = Pipeline.run_pipeline(pipeline, tree, printtree=False)
assert not err, err assert not err, err
self.tree = tree
return tree return tree
def put_code(self, output): def put_code(self, output):
......
#################### EnumBase ####################
cimport cython
@cython.internal
cdef class __Pyx_EnumMeta(type):
def __init__(cls, name, parents, dct):
type.__init__(cls, name, parents, dct)
cls.__values__ = []
def __iter__(cls):
return iter(getattr(cls, '__values__', ()))
# @cython.internal
cdef type __Pyx_EnumBase
class __Pyx_EnumBase(int):
__metaclass__ = __Pyx_EnumMeta
def __new__(cls, value, name=None):
for v in cls.__values__:
if v == value or v.name == value:
return v
if name is None:
raise ValueError("Unknown enum value: '%s'" % value)
res = int.__new__(cls, value)
res.name = name
setattr(cls, name, res)
cls.__values__.append(res)
return res
def __repr__(self):
return self.name
def __str__(self):
return self.name
#################### EnumType ####################
#@requires: EnumBase
class {{name}}(__Pyx_EnumBase):
pass
cdef dict __Pyx_globals = globals()
{{for item in items}}
__Pyx_globals['{{item}}'] = {{name}}({{item}}, '{{item}}')
{{endfor}}
...@@ -32,6 +32,13 @@ True ...@@ -32,6 +32,13 @@ True
>>> RANK_3 # doctest: +ELLIPSIS >>> RANK_3 # doctest: +ELLIPSIS
Traceback (most recent call last): Traceback (most recent call last):
NameError: ...name 'RANK_3' is not defined NameError: ...name 'RANK_3' is not defined
>>> list(PyxEnum)
[TWO, THREE, FIVE]
>>> PyxEnum.TWO + PyxEnum.THREE == PyxEnum.FIVE
True
>>> PyxEnum(2) is PyxEnum("TWO") is PyxEnum.TWO
True
""" """
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment