Commit b9f31598 authored by Robert Bradshaw's avatar Robert Bradshaw

Merge pull request #437 from robertwb/cpdef-enums

More complete cpdef enums.
parents 0f339147 9a80574b
...@@ -14,6 +14,8 @@ Features added ...@@ -14,6 +14,8 @@ Features added
* The embedded C code comments that show the original source code * The embedded C code comments that show the original source code
can be discarded with the new directive ``emit_code_comments=False``. can be discarded with the new directive ``emit_code_comments=False``.
* Cpdef enums are now first-class iterable, callable types in Python.
* Posix declarations for DLL loading and stdio extensions were added. * Posix declarations for DLL loading and stdio extensions were added.
Patch by Lars Buitinck. Patch by Lars Buitinck.
......
...@@ -553,6 +553,7 @@ class LazyUtilityCode(UtilityCodeBase): ...@@ -553,6 +553,7 @@ class LazyUtilityCode(UtilityCodeBase):
Utility code that calls a callback with the root code writer when Utility code that calls a callback with the root code writer when
available. Useful when you only have 'env' but not 'code'. available. Useful when you only have 'env' but not 'code'.
""" """
__name__ = '<lazy>'
def __init__(self, callback): def __init__(self, callback):
self.callback = callback self.callback = callback
......
...@@ -1927,7 +1927,12 @@ class NameNode(AtomicExprNode): ...@@ -1927,7 +1927,12 @@ class NameNode(AtomicExprNode):
entry = self.entry entry = self.entry
if entry.is_type and entry.type.is_extension_type: if entry.is_type and entry.type.is_extension_type:
self.type_entry = entry self.type_entry = entry
if not (entry.is_const or entry.is_variable if entry.is_type and entry.type.is_enum:
py_entry = Symtab.Entry(self.name, None, py_object_type)
py_entry.is_pyglobal = True
py_entry.scope = self.entry.scope
self.entry = py_entry
elif not (entry.is_const or entry.is_variable
or entry.is_builtin or entry.is_cfunction or entry.is_builtin or entry.is_cfunction
or entry.is_cpp_class): or entry.is_cpp_class):
if self.entry.as_variable: if self.entry.as_variable:
...@@ -6060,7 +6065,7 @@ class AttributeNode(ExprNode): ...@@ -6060,7 +6065,7 @@ class AttributeNode(ExprNode):
node = self.analyse_as_cimported_attribute_node(env, target=False) node = self.analyse_as_cimported_attribute_node(env, target=False)
if node is not None: if node is not None:
return node.entry.type return node.entry.type
node = self.analyse_as_unbound_cmethod_node(env) node = self.analyse_as_type_attribute(env)
if node is not None: if node is not None:
return node.entry.type return node.entry.type
obj_type = self.obj.infer_type(env) obj_type = self.obj.infer_type(env)
...@@ -6087,7 +6092,7 @@ class AttributeNode(ExprNode): ...@@ -6087,7 +6092,7 @@ class AttributeNode(ExprNode):
self.initialized_check = env.directives['initializedcheck'] self.initialized_check = env.directives['initializedcheck']
node = self.analyse_as_cimported_attribute_node(env, target) node = self.analyse_as_cimported_attribute_node(env, target)
if node is None and not target: if node is None and not target:
node = self.analyse_as_unbound_cmethod_node(env) node = self.analyse_as_type_attribute(env)
if node is None: if node is None:
node = self.analyse_as_ordinary_attribute_node(env, target) node = self.analyse_as_ordinary_attribute_node(env, target)
assert node is not None assert node is not None
...@@ -6111,7 +6116,7 @@ class AttributeNode(ExprNode): ...@@ -6111,7 +6116,7 @@ class AttributeNode(ExprNode):
return self.as_name_node(env, entry, target) return self.as_name_node(env, entry, target)
return None return None
def analyse_as_unbound_cmethod_node(self, env): def analyse_as_type_attribute(self, env):
# Try to interpret this as a reference to an unbound # Try to interpret this as a reference to an unbound
# C method of an extension type or builtin type. If successful, # C method of an extension type or builtin type. If successful,
# creates a corresponding NameNode and returns it, otherwise # creates a corresponding NameNode and returns it, otherwise
...@@ -6119,37 +6124,43 @@ class AttributeNode(ExprNode): ...@@ -6119,37 +6124,43 @@ class AttributeNode(ExprNode):
if self.obj.is_string_literal: if self.obj.is_string_literal:
return return
type = self.obj.analyse_as_type(env) type = self.obj.analyse_as_type(env)
if type and (type.is_extension_type or type.is_builtin_type or type.is_cpp_class): if type:
entry = type.scope.lookup_here(self.attribute) if type.is_extension_type or type.is_builtin_type or type.is_cpp_class:
if entry and (entry.is_cmethod or type.is_cpp_class and entry.type.is_cfunction): entry = type.scope.lookup_here(self.attribute)
if type.is_builtin_type: if entry and (entry.is_cmethod or type.is_cpp_class and entry.type.is_cfunction):
if not self.is_called: if type.is_builtin_type:
# must handle this as Python object if not self.is_called:
return None # must handle this as Python object
ubcm_entry = entry return None
else: ubcm_entry = entry
# Create a temporary entry describing the C method
# as an ordinary function.
if entry.func_cname and not hasattr(entry.type, 'op_arg_struct'):
cname = entry.func_cname
if entry.type.is_static_method:
ctype = entry.type
elif type.is_cpp_class:
error(self.pos, "%s not a static member of %s" % (entry.name, type))
ctype = PyrexTypes.error_type
else:
# Fix self type.
ctype = copy.copy(entry.type)
ctype.args = ctype.args[:]
ctype.args[0] = PyrexTypes.CFuncTypeArg('self', type, 'self', None)
else: else:
cname = "%s->%s" % (type.vtabptr_cname, entry.cname) # Create a temporary entry describing the C method
ctype = entry.type # as an ordinary function.
ubcm_entry = Symtab.Entry(entry.name, cname, ctype) if entry.func_cname and not hasattr(entry.type, 'op_arg_struct'):
ubcm_entry.is_cfunction = 1 cname = entry.func_cname
ubcm_entry.func_cname = entry.func_cname if entry.type.is_static_method:
ubcm_entry.is_unbound_cmethod = 1 ctype = entry.type
return self.as_name_node(env, ubcm_entry, target=False) elif type.is_cpp_class:
error(self.pos, "%s not a static member of %s" % (entry.name, type))
ctype = PyrexTypes.error_type
else:
# Fix self type.
ctype = copy.copy(entry.type)
ctype.args = ctype.args[:]
ctype.args[0] = PyrexTypes.CFuncTypeArg('self', type, 'self', None)
else:
cname = "%s->%s" % (type.vtabptr_cname, entry.cname)
ctype = entry.type
ubcm_entry = Symtab.Entry(entry.name, cname, ctype)
ubcm_entry.is_cfunction = 1
ubcm_entry.func_cname = entry.func_cname
ubcm_entry.is_unbound_cmethod = 1
return self.as_name_node(env, ubcm_entry, target=False)
elif type.is_enum:
if self.attribute in type.values:
return self.as_name_node(env, env.lookup(self.attribute), target=False)
else:
error(self.pos, "%s not a known value of %s" % (self.attribute, type))
return None return None
def analyse_as_type(self, env): def analyse_as_type(self, env):
......
...@@ -1488,7 +1488,7 @@ class CEnumDefNode(StatNode): ...@@ -1488,7 +1488,7 @@ class CEnumDefNode(StatNode):
self.entry = env.declare_enum(self.name, self.pos, self.entry = env.declare_enum(self.name, self.pos,
cname = self.cname, typedef_flag = self.typedef_flag, cname = self.cname, typedef_flag = self.typedef_flag,
visibility = self.visibility, api = self.api, visibility = self.visibility, api = self.api,
create_wrapper = self.create_wrapper) create_wrapper = self.create_wrapper and self.name is None)
def analyse_declarations(self, env): def analyse_declarations(self, env):
if self.items is not None: if self.items is not None:
...@@ -1496,6 +1496,16 @@ class CEnumDefNode(StatNode): ...@@ -1496,6 +1496,16 @@ class CEnumDefNode(StatNode):
self.entry.defined_in_pxd = 1 self.entry.defined_in_pxd = 1
for item in self.items: for item in self.items:
item.analyse_declarations(env, self.entry) item.analyse_declarations(env, self.entry)
if self.name is not None:
self.entry.type.values = set(
(item.name) for item in self.items)
if self.create_wrapper and self.name is not None:
from .UtilityCode import CythonUtilityCode
env.use_utility_code(CythonUtilityCode.load(
"EnumType", "CpdefEnums.pyx",
context={"name": self.name,
"items": tuple(item.name for item in self.items)},
outer_module_scope=env.global_scope()))
def analyse_expressions(self, env): def analyse_expressions(self, env):
return self return self
...@@ -1535,7 +1545,7 @@ class CEnumDefItemNode(StatNode): ...@@ -1535,7 +1545,7 @@ class CEnumDefItemNode(StatNode):
entry = env.declare_const(self.name, enum_entry.type, entry = env.declare_const(self.name, enum_entry.type,
self.value, self.pos, cname = self.cname, self.value, self.pos, cname = self.cname,
visibility = enum_entry.visibility, api = enum_entry.api, visibility = enum_entry.visibility, api = enum_entry.api,
create_wrapper = enum_entry.create_wrapper) create_wrapper = enum_entry.create_wrapper and enum_entry.name is None)
enum_entry.enum_values.append(entry) enum_entry.enum_values.append(entry)
if enum_entry.name: if enum_entry.name:
enum_entry.type.values.append(entry.cname) enum_entry.type.values.append(entry.cname)
......
...@@ -2097,7 +2097,7 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, ...@@ -2097,7 +2097,7 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
entry=type_entry, entry=type_entry,
type=type_entry.type), type=type_entry.type),
attribute=attr_name, attribute=attr_name,
is_called=True).analyse_as_unbound_cmethod_node(self.current_env()) is_called=True).analyse_as_type_attribute(self.current_env())
if method is None: if method is None:
return node return node
args = node.args args = node.args
......
...@@ -80,11 +80,37 @@ def use_utility_code_definitions(scope, target, seen=None): ...@@ -80,11 +80,37 @@ def use_utility_code_definitions(scope, target, seen=None):
elif entry.as_module: elif entry.as_module:
use_utility_code_definitions(entry.as_module, target, seen) use_utility_code_definitions(entry.as_module, target, seen)
def sort_utility_codes(utilcodes):
ranks = {}
def get_rank(utilcode):
if utilcode not in ranks:
ranks[utilcode] = 0 # prevent infinite recursion on circular dependencies
original_order = len(ranks)
ranks[utilcode] = 1 + min([get_rank(dep) for dep in utilcode.requires or ()] or [-1]) + original_order * 1e-8
return ranks[utilcode]
for utilcode in utilcodes:
get_rank(utilcode)
return [utilcode for utilcode, _ in sorted(ranks.items(), key=lambda kv: kv[1])]
def normalize_deps(utilcodes):
deps = {}
for utilcode in utilcodes:
deps[utilcode] = utilcode
def unify_dep(dep):
if dep in deps:
return deps[dep]
else:
deps[dep] = dep
return dep
for utilcode in utilcodes:
utilcode.requires = [unify_dep(dep) for dep in utilcode.requires or ()]
def inject_utility_code_stage_factory(context): def inject_utility_code_stage_factory(context):
def inject_utility_code_stage(module_node): def inject_utility_code_stage(module_node):
module_node.prepare_utility_code() module_node.prepare_utility_code()
use_utility_code_definitions(context.cython_scope, module_node.scope) use_utility_code_definitions(context.cython_scope, module_node.scope)
module_node.scope.utility_code_list = sort_utility_codes(module_node.scope.utility_code_list)
normalize_deps(module_node.scope.utility_code_list)
added = [] added = []
# Note: the list might be extended inside the loop (if some utility code # Note: the list might be extended inside the loop (if some utility code
# pulls in other utility code, explicitly or implicitly) # pulls in other utility code, explicitly or implicitly)
......
...@@ -146,6 +146,18 @@ class CythonUtilityCode(Code.UtilityCodeBase): ...@@ -146,6 +146,18 @@ class CythonUtilityCode(Code.UtilityCodeBase):
pipeline = Pipeline.insert_into_pipeline(pipeline, scope_transform, pipeline = Pipeline.insert_into_pipeline(pipeline, scope_transform,
before=transform) before=transform)
for dep in self.requires:
if (isinstance(dep, CythonUtilityCode)
and hasattr(dep, 'tree')
and not cython_scope):
def scope_transform(module_node):
module_node.scope.merge_in(dep.tree.scope)
return module_node
transform = ParseTreeTransforms.AnalyseDeclarationsTransform
pipeline = Pipeline.insert_into_pipeline(pipeline, scope_transform,
before=transform)
if self.outer_module_scope: if self.outer_module_scope:
# inject outer module between utility code module and builtin module # inject outer module between utility code module and builtin module
def scope_transform(module_node): def scope_transform(module_node):
...@@ -158,6 +170,7 @@ class CythonUtilityCode(Code.UtilityCodeBase): ...@@ -158,6 +170,7 @@ class CythonUtilityCode(Code.UtilityCodeBase):
(err, tree) = Pipeline.run_pipeline(pipeline, tree, printtree=False) (err, tree) = Pipeline.run_pipeline(pipeline, tree, printtree=False)
assert not err, err assert not err, err
self.tree = tree
return tree return tree
def put_code(self, output): def put_code(self, output):
......
#################### EnumBase ####################
cimport cython
cdef extern from *:
int PY_VERSION_HEX
cdef object __Pyx_OrderedDict
if PY_VERSION_HEX >= 0x02070000:
from collections import OrderedDict as __Pyx_OrderedDict
else:
__Pyx_OrderedDict = dict
@cython.internal
cdef class __Pyx_EnumMeta(type):
def __init__(cls, name, parents, dct):
type.__init__(cls, name, parents, dct)
cls.__members__ = __Pyx_OrderedDict()
def __iter__(cls):
return iter(cls.__members__.values())
def __getitem__(cls, name):
return cls.__members__[name]
# @cython.internal
cdef type __Pyx_EnumBase
class __Pyx_EnumBase(int):
__metaclass__ = __Pyx_EnumMeta
def __new__(cls, value, name=None):
for v in cls:
if v == value:
return v
if name is None:
raise ValueError("Unknown enum value: '%s'" % value)
res = int.__new__(cls, value)
res.name = name
setattr(cls, name, res)
cls.__members__[name] = res
return res
def __repr__(self):
return "<%s.%s: %d>" % (self.__class__.__name__, self.name, self)
def __str__(self):
return "%s.%s" % (self.__class__.__name__, self.name)
#################### EnumType ####################
#@requires: EnumBase
cdef dict __Pyx_globals = globals()
if PY_VERSION_HEX >= 0x03040000:
from enum import IntEnum
{{name}} = IntEnum('{{name}}', __Pyx_OrderedDict([
{{for item in items}}
('{{item}}', {{item}}),
{{endfor}}
]))
{{for item in items}}
__Pyx_globals['{{item}}'] = {{name}}.{{item}}
{{endfor}}
else:
class {{name}}(__Pyx_EnumBase):
pass
{{for item in items}}
__Pyx_globals['{{item}}'] = {{name}}({{item}}, '{{item}}')
{{endfor}}
...@@ -153,7 +153,9 @@ the same effect as the C directive ``#pragma pack(1)``. ...@@ -153,7 +153,9 @@ the same effect as the C directive ``#pragma pack(1)``.
cheddar, edam, cheddar, edam,
camembert camembert
cdef enum CheeseState: Declaring an enum as ```cpdef`` will create a PEP 435-style Python wrapper::
cpdef enum CheeseState:
hard = 1 hard = 1
soft = 2 soft = 2
runny = 3 runny = 3
...@@ -803,16 +805,3 @@ Conditional Statements ...@@ -803,16 +805,3 @@ Conditional Statements
.. [#] The conversion is to/from str for Python 2.x, and bytes for Python 3.x. .. [#] The conversion is to/from str for Python 2.x, and bytes for Python 3.x.
...@@ -32,6 +32,15 @@ True ...@@ -32,6 +32,15 @@ True
>>> RANK_3 # doctest: +ELLIPSIS >>> RANK_3 # doctest: +ELLIPSIS
Traceback (most recent call last): Traceback (most recent call last):
NameError: ...name 'RANK_3' is not defined NameError: ...name 'RANK_3' is not defined
>>> set(PyxEnum) == set([TWO, THREE, FIVE])
True
>>> str(PyxEnum.TWO)
'PyxEnum.TWO'
>>> PyxEnum.TWO + PyxEnum.THREE == PyxEnum.FIVE
True
>>> PyxEnum(2) is PyxEnum["TWO"] is PyxEnum.TWO
True
""" """
...@@ -51,3 +60,25 @@ cpdef enum PyxEnum: ...@@ -51,3 +60,25 @@ cpdef enum PyxEnum:
cdef enum SecretPyxEnum: cdef enum SecretPyxEnum:
SEVEN = 7 SEVEN = 7
def test_as_variable_from_cython():
"""
>>> test_as_variable_from_cython()
"""
import sys
if sys.version_info >= (2, 7):
assert list(PyxEnum) == [TWO, THREE, FIVE], list(PyxEnum)
assert list(PxdEnum) == [RANK_0, RANK_1, RANK_2], list(PxdEnum)
else:
# No OrderedDict.
assert set(PyxEnum) == {TWO, THREE, FIVE}, list(PyxEnum)
assert set(PxdEnum) == {RANK_0, RANK_1, RANK_2}, list(PxdEnum)
cdef int verify_pure_c() nogil:
cdef int x = TWO
cdef int y = PyxEnum.THREE
cdef int z = SecretPyxEnum.SEVEN
return x + y + z
# Use it to suppress warning.
verify_pure_c()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment