Commit 8e34cc36 authored by Robert Bradshaw's avatar Robert Bradshaw

Allow enum type access patterns from Cython as well as Python.

parent 60e419e6
...@@ -1859,7 +1859,12 @@ class NameNode(AtomicExprNode): ...@@ -1859,7 +1859,12 @@ class NameNode(AtomicExprNode):
entry = self.entry entry = self.entry
if entry.is_type and entry.type.is_extension_type: if entry.is_type and entry.type.is_extension_type:
self.type_entry = entry self.type_entry = entry
if not (entry.is_const or entry.is_variable if entry.is_type and entry.type.is_enum:
py_entry = Symtab.Entry(self.name, None, py_object_type)
py_entry.is_pyglobal = True
py_entry.scope = self.entry.scope
self.entry = py_entry
elif not (entry.is_const or entry.is_variable
or entry.is_builtin or entry.is_cfunction or entry.is_builtin or entry.is_cfunction
or entry.is_cpp_class): or entry.is_cpp_class):
if self.entry.as_variable: if self.entry.as_variable:
...@@ -5832,7 +5837,7 @@ class AttributeNode(ExprNode): ...@@ -5832,7 +5837,7 @@ class AttributeNode(ExprNode):
node = self.analyse_as_cimported_attribute_node(env, target=False) node = self.analyse_as_cimported_attribute_node(env, target=False)
if node is not None: if node is not None:
return node.entry.type return node.entry.type
node = self.analyse_as_unbound_cmethod_node(env) node = self.analyse_as_type_attribute(env)
if node is not None: if node is not None:
return node.entry.type return node.entry.type
obj_type = self.obj.infer_type(env) obj_type = self.obj.infer_type(env)
...@@ -5859,7 +5864,7 @@ class AttributeNode(ExprNode): ...@@ -5859,7 +5864,7 @@ class AttributeNode(ExprNode):
self.initialized_check = env.directives['initializedcheck'] self.initialized_check = env.directives['initializedcheck']
node = self.analyse_as_cimported_attribute_node(env, target) node = self.analyse_as_cimported_attribute_node(env, target)
if node is None and not target: if node is None and not target:
node = self.analyse_as_unbound_cmethod_node(env) node = self.analyse_as_type_attribute(env)
if node is None: if node is None:
node = self.analyse_as_ordinary_attribute_node(env, target) node = self.analyse_as_ordinary_attribute_node(env, target)
assert node is not None assert node is not None
...@@ -5883,7 +5888,7 @@ class AttributeNode(ExprNode): ...@@ -5883,7 +5888,7 @@ class AttributeNode(ExprNode):
return self.as_name_node(env, entry, target) return self.as_name_node(env, entry, target)
return None return None
def analyse_as_unbound_cmethod_node(self, env): def analyse_as_type_attribute(self, env):
# Try to interpret this as a reference to an unbound # Try to interpret this as a reference to an unbound
# C method of an extension type or builtin type. If successful, # C method of an extension type or builtin type. If successful,
# creates a corresponding NameNode and returns it, otherwise # creates a corresponding NameNode and returns it, otherwise
...@@ -5891,7 +5896,8 @@ class AttributeNode(ExprNode): ...@@ -5891,7 +5896,8 @@ class AttributeNode(ExprNode):
if self.obj.is_string_literal: if self.obj.is_string_literal:
return return
type = self.obj.analyse_as_type(env) type = self.obj.analyse_as_type(env)
if type and (type.is_extension_type or type.is_builtin_type or type.is_cpp_class): if type:
if type.is_extension_type or type.is_builtin_type or type.is_cpp_class:
entry = type.scope.lookup_here(self.attribute) entry = type.scope.lookup_here(self.attribute)
if entry and (entry.is_cmethod or type.is_cpp_class and entry.type.is_cfunction): if entry and (entry.is_cmethod or type.is_cpp_class and entry.type.is_cfunction):
if type.is_builtin_type: if type.is_builtin_type:
...@@ -5922,6 +5928,11 @@ class AttributeNode(ExprNode): ...@@ -5922,6 +5928,11 @@ class AttributeNode(ExprNode):
ubcm_entry.func_cname = entry.func_cname ubcm_entry.func_cname = entry.func_cname
ubcm_entry.is_unbound_cmethod = 1 ubcm_entry.is_unbound_cmethod = 1
return self.as_name_node(env, ubcm_entry, target=False) return self.as_name_node(env, ubcm_entry, target=False)
elif type.is_enum:
if self.attribute in type.values:
return self.as_name_node(env, env.lookup(self.attribute), target=False)
else:
error(self.pos, "%s not a known value of %s" % (self.attribute, type))
return None return None
def analyse_as_type(self, env): def analyse_as_type(self, env):
......
...@@ -1496,11 +1496,15 @@ class CEnumDefNode(StatNode): ...@@ -1496,11 +1496,15 @@ class CEnumDefNode(StatNode):
self.entry.defined_in_pxd = 1 self.entry.defined_in_pxd = 1
for item in self.items: for item in self.items:
item.analyse_declarations(env, self.entry) item.analyse_declarations(env, self.entry)
if self.name is not None:
self.entry.type.values = set(
(item.name) for item in self.items)
if self.create_wrapper and self.name is not None: if self.create_wrapper and self.name is not None:
from .UtilityCode import CythonUtilityCode from .UtilityCode import CythonUtilityCode
env.use_utility_code(CythonUtilityCode.load( env.use_utility_code(CythonUtilityCode.load(
"EnumType", "CpdefEnums.pyx", "EnumType", "CpdefEnums.pyx",
context={"name": self.name, "items": tuple(item.name for item in self.items)}, context={"name": self.name,
"items": tuple(item.name for item in self.items)},
outer_module_scope=env.global_scope())) outer_module_scope=env.global_scope()))
def analyse_expressions(self, env): def analyse_expressions(self, env):
......
...@@ -2098,7 +2098,7 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin, ...@@ -2098,7 +2098,7 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
entry=type_entry, entry=type_entry,
type=type_entry.type), type=type_entry.type),
attribute=attr_name, attribute=attr_name,
is_called=True).analyse_as_unbound_cmethod_node(self.current_env()) is_called=True).analyse_as_type_attribute(self.current_env())
if method is None: if method is None:
return node return node
args = node.args args = node.args
......
...@@ -58,3 +58,19 @@ cpdef enum PyxEnum: ...@@ -58,3 +58,19 @@ cpdef enum PyxEnum:
cdef enum SecretPyxEnum: cdef enum SecretPyxEnum:
SEVEN = 7 SEVEN = 7
def test_as_variable_from_cython():
"""
>>> test_as_variable_from_cython()
"""
assert list(PyxEnum) == [TWO, THREE, FIVE]
assert list(PxdEnum) == [RANK_0, RANK_1, RANK_2]
cdef int verify_pure_c() nogil:
cdef int x = TWO
cdef int y = PyxEnum.THREE
cdef int z = SecretPyxEnum.SEVEN
return x + y + z
# Use it to suppress warning.
verify_pure_c()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment