Commit ac9ebdda authored by Robert Bradshaw's avatar Robert Bradshaw

Fix cpdef enums cimported across modules.

Closes #531.
parent af2c3655
...@@ -123,6 +123,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode): ...@@ -123,6 +123,10 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
env = self.scope env = self.scope
if env.has_import_star: if env.has_import_star:
self.create_import_star_conversion_utility_code(env) self.create_import_star_conversion_utility_code(env)
for name, entry in sorted(env.entries.items()):
if (entry.create_wrapper and entry.scope is env
and entry.is_type and entry.type.is_enum):
entry.type.create_type_wrapper(env)
def process_implementation(self, options, result): def process_implementation(self, options, result):
env = self.scope env = self.scope
......
...@@ -1497,7 +1497,7 @@ class CEnumDefNode(StatNode): ...@@ -1497,7 +1497,7 @@ class CEnumDefNode(StatNode):
self.name, self.pos, self.name, self.pos,
cname=self.cname, typedef_flag=self.typedef_flag, cname=self.cname, typedef_flag=self.typedef_flag,
visibility=self.visibility, api=self.api, visibility=self.visibility, api=self.api,
create_wrapper=self.create_wrapper and self.name is None) create_wrapper=self.create_wrapper)
def analyse_declarations(self, env): def analyse_declarations(self, env):
if self.items is not None: if self.items is not None:
...@@ -1505,15 +1505,6 @@ class CEnumDefNode(StatNode): ...@@ -1505,15 +1505,6 @@ class CEnumDefNode(StatNode):
self.entry.defined_in_pxd = 1 self.entry.defined_in_pxd = 1
for item in self.items: for item in self.items:
item.analyse_declarations(env, self.entry) item.analyse_declarations(env, self.entry)
if self.name is not None:
self.entry.type.values = set(item.name for item in self.items)
if self.create_wrapper and self.name is not None:
from .UtilityCode import CythonUtilityCode
env.use_utility_code(CythonUtilityCode.load(
"EnumType", "CpdefEnums.pyx",
context={"name": self.name,
"items": tuple(item.name for item in self.items)},
outer_module_scope=env.global_scope()))
def analyse_expressions(self, env): def analyse_expressions(self, env):
return self return self
...@@ -1557,7 +1548,7 @@ class CEnumDefItemNode(StatNode): ...@@ -1557,7 +1548,7 @@ class CEnumDefItemNode(StatNode):
create_wrapper=enum_entry.create_wrapper and enum_entry.name is None) create_wrapper=enum_entry.create_wrapper and enum_entry.name is None)
enum_entry.enum_values.append(entry) enum_entry.enum_values.append(entry)
if enum_entry.name: if enum_entry.name:
enum_entry.type.values.append(entry.cname) enum_entry.type.values.append(entry.name)
class CTypeDefNode(StatNode): class CTypeDefNode(StatNode):
......
...@@ -3597,6 +3597,7 @@ class CEnumType(CType): ...@@ -3597,6 +3597,7 @@ class CEnumType(CType):
# name string # name string
# cname string or None # cname string or None
# typedef_flag boolean # typedef_flag boolean
# values [string], populated during declaration analysis
is_enum = 1 is_enum = 1
signed = 1 signed = 1
...@@ -3654,6 +3655,14 @@ class CEnumType(CType): ...@@ -3654,6 +3655,14 @@ class CEnumType(CType):
typecast(self, c_long_type, rhs), typecast(self, c_long_type, rhs),
' %s' % code.error_goto_if(error_condition or self.error_condition(result_code), error_pos)) ' %s' % code.error_goto_if(error_condition or self.error_condition(result_code), error_pos))
def create_type_wrapper(self, env):
from .UtilityCode import CythonUtilityCode
env.use_utility_code(CythonUtilityCode.load(
"EnumType", "CpdefEnums.pyx",
context={"name": self.name,
"items": tuple(self.values)},
outer_module_scope=env.global_scope()))
class CTupleType(CType): class CTupleType(CType):
# components [PyrexType] # components [PyrexType]
......
...@@ -23,12 +23,18 @@ cpdef foo(): pass ...@@ -23,12 +23,18 @@ cpdef foo(): pass
cpdef enum: cpdef enum:
FOO FOO
cpdef enum NamedEnumType:
NamedEnumValue = 389
cpdef foo() cpdef foo()
######## no_enums.pyx ######## ######## no_enums.pyx ########
from enums cimport * from enums cimport *
def get_named_enum_value():
return NamedEnumType.NamedEnumValue
######## import_enums_test.py ######## ######## import_enums_test.py ########
# We can import enums with a star import. # We can import enums with a star import.
...@@ -36,6 +42,7 @@ from enums import * ...@@ -36,6 +42,7 @@ from enums import *
print(dir()) print(dir())
assert 'BAR' in dir() and 'FOO' in dir() assert 'BAR' in dir() and 'FOO' in dir()
assert 'NamedEnumType' in dir()
# enums not generated in the wrong module # enums not generated in the wrong module
import no_enums import no_enums
...@@ -43,3 +50,4 @@ print(dir(no_enums)) ...@@ -43,3 +50,4 @@ print(dir(no_enums))
assert 'FOO' not in dir(no_enums) assert 'FOO' not in dir(no_enums)
assert 'foo' not in dir(no_enums) assert 'foo' not in dir(no_enums)
assert no_enums.get_named_enum_value() == NamedEnumType.NamedEnumValue
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment