Commit e884efb9 authored by Stefan Behnel's avatar Stefan Behnel

prevent absolute cimports from trying relative imports

parent 6528682e
...@@ -7087,6 +7087,7 @@ class EnsureGILNode(GILExitNode): ...@@ -7087,6 +7087,7 @@ class EnsureGILNode(GILExitNode):
def generate_execution_code(self, code): def generate_execution_code(self, code):
code.put_ensure_gil(declare_gilstate=False) code.put_ensure_gil(declare_gilstate=False)
utility_code_for_cimports = { utility_code_for_cimports = {
# utility code (or inlining c) in a pxd (or pyx) file. # utility code (or inlining c) in a pxd (or pyx) file.
# TODO: Consider a generic user-level mechanism for importing # TODO: Consider a generic user-level mechanism for importing
...@@ -7094,19 +7095,23 @@ utility_code_for_cimports = { ...@@ -7094,19 +7095,23 @@ utility_code_for_cimports = {
'cpython.array.array' : ("ArrayAPI", "arrayarray.h"), 'cpython.array.array' : ("ArrayAPI", "arrayarray.h"),
} }
class CImportStatNode(StatNode): class CImportStatNode(StatNode):
# cimport statement # cimport statement
# #
# module_name string Qualified name of module being imported # module_name string Qualified name of module being imported
# as_name string or None Name specified in "as" clause, if any # as_name string or None Name specified in "as" clause, if any
# is_absolute bool True for absolute imports, False otherwise
child_attrs = [] child_attrs = []
is_absolute = False
def analyse_declarations(self, env): def analyse_declarations(self, env):
if not env.is_module_scope: if not env.is_module_scope:
error(self.pos, "cimport only allowed at module level") error(self.pos, "cimport only allowed at module level")
return return
module_scope = env.find_module(self.module_name, self.pos) module_scope = env.find_module(
self.module_name, self.pos, relative_level=0 if self.is_absolute else -1)
if "." in self.module_name: if "." in self.module_name:
names = [EncodedString(name) for name in self.module_name.split(".")] names = [EncodedString(name) for name in self.module_name.split(".")]
top_name = names[0] top_name = names[0]
......
...@@ -1279,38 +1279,43 @@ def p_raise_statement(s): ...@@ -1279,38 +1279,43 @@ def p_raise_statement(s):
else: else:
return Nodes.ReraiseStatNode(pos) return Nodes.ReraiseStatNode(pos)
def p_import_statement(s): def p_import_statement(s):
# s.sy in ('import', 'cimport') # s.sy in ('import', 'cimport')
pos = s.position() pos = s.position()
kind = s.sy kind = s.sy
s.next() s.next()
items = [p_dotted_name(s, as_allowed = 1)] items = [p_dotted_name(s, as_allowed=1)]
while s.sy == ',': while s.sy == ',':
s.next() s.next()
items.append(p_dotted_name(s, as_allowed = 1)) items.append(p_dotted_name(s, as_allowed=1))
stats = [] stats = []
is_absolute = Future.absolute_import in s.context.future_directives
for pos, target_name, dotted_name, as_name in items: for pos, target_name, dotted_name, as_name in items:
dotted_name = EncodedString(dotted_name) dotted_name = EncodedString(dotted_name)
if kind == 'cimport': if kind == 'cimport':
stat = Nodes.CImportStatNode(pos, stat = Nodes.CImportStatNode(
module_name = dotted_name, pos,
as_name = as_name) module_name=dotted_name,
as_name=as_name,
is_absolute=is_absolute)
else: else:
if as_name and "." in dotted_name: if as_name and "." in dotted_name:
name_list = ExprNodes.ListNode(pos, args = [ name_list = ExprNodes.ListNode(pos, args=[
ExprNodes.IdentifierStringNode(pos, value = EncodedString("*"))]) ExprNodes.IdentifierStringNode(pos, value=EncodedString("*"))])
else: else:
name_list = None name_list = None
stat = Nodes.SingleAssignmentNode(pos, stat = Nodes.SingleAssignmentNode(
lhs = ExprNodes.NameNode(pos, pos,
name = as_name or target_name), lhs=ExprNodes.NameNode(pos, name=as_name or target_name),
rhs = ExprNodes.ImportNode(pos, rhs=ExprNodes.ImportNode(
module_name = ExprNodes.IdentifierStringNode( pos,
pos, value = dotted_name), module_name=ExprNodes.IdentifierStringNode(pos, value=dotted_name),
level = None, level=0 if is_absolute else None,
name_list = name_list)) name_list=name_list))
stats.append(stat) stats.append(stat)
return Nodes.StatListNode(pos, stats = stats) return Nodes.StatListNode(pos, stats=stats)
def p_from_import_statement(s, first_statement = 0): def p_from_import_statement(s, first_statement = 0):
# s.sy == 'from' # s.sy == 'from'
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment