Commit 1788299e authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Buffer parsing complete; small transform factorizations and renaming of PostParse

parent 52c259aa
......@@ -32,6 +32,7 @@ class CompileError(PyrexError):
def __init__(self, position = None, message = ""):
self.position = position
self.message_only = message
# Deprecated and withdrawn in 2.6:
# self.message = message
if position:
......
......@@ -334,17 +334,18 @@ def create_generate_code(context, options, result):
return generate_code
def create_default_pipeline(context, options, result):
from ParseTreeTransforms import WithTransform, PostParse
from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ModuleNode import check_c_classes
return [
create_parse(context),
PostParse(),
WithTransform(),
AnalyseDeclarationsTransform(),
NormalizeTree(context),
PostParse(context),
WithTransform(context),
AnalyseDeclarationsTransform(context),
check_c_classes,
AnalyseExpressionsTransform(),
AnalyseExpressionsTransform(context),
create_generate_code(context, options, result)
]
......
......@@ -565,7 +565,6 @@ class CBaseTypeNode(Node):
pass
class CSimpleBaseTypeNode(CBaseTypeNode):
# name string
# module_path [string] Qualifying name components
......
from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle
from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform
from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import *
from Cython.Compiler.TreeFragment import TreeFragment
from Cython.Utils import EncodedString
from Cython.Compiler.Errors import CompileError
from sets import Set as set
class PostParse(VisitorTransform):
class NormalizeTree(CythonTransform):
"""
This transform fixes up a few things after parsing
in order to make the parse tree more suitable for
......@@ -25,15 +28,11 @@ class PostParse(VisitorTransform):
StatListNode has no children to see if the block is empty).
"""
def __init__(self):
super(PostParse, self).__init__()
def __init__(self, context):
super(NormalizeTree, self).__init__(context)
self.is_in_statlist = False
self.is_in_expr = False
def visit_Node(self, node):
self.visitchildren(node)
return node
def visit_ExprNode(self, node):
stacktmp = self.is_in_expr
self.is_in_expr = True
......@@ -73,7 +72,80 @@ class PostParse(VisitorTransform):
return self.visit_StatNode(node, True)
class WithTransform(VisitorTransform):
class PostParseError(CompileError): pass
# error strings checked by unit tests, so define them
ERR_BUF_OPTION_UNKNOWN = '"%s" is not a buffer option'
ERR_BUF_TOO_MANY = 'Too many buffer options'
ERR_BUF_DUP = '"%s" buffer option already supplied'
ERR_BUF_MISSING = '"%s" missing'
ERR_BUF_INT = '"%s" must be an integer'
ERR_BUF_NONNEG = '"%s" must be non-negative'
class PostParse(CythonTransform):
"""
Basic interpretation of the parse tree, as well as validity
checking that can be done on a very basic level on the parse
tree (while still not being a problem with the basic syntax,
as such).
Specifically:
- CBufferAccessTypeNode has its options interpreted:
Any first positional argument goes into the "dtype" attribute,
any "ndim" keyword argument goes into the "ndim" attribute and
so on. Also it is checked that the option combination is valid.
Note: Currently Parsing.py does a lot of interpretation and
reorganization that can be refactored into this transform
if a more pure Abstract Syntax Tree is wanted.
"""
buffer_options = ("dtype", "ndim") # ordered!
def visit_CBufferAccessTypeNode(self, node):
options = {}
# Fetch positional arguments
if len(node.positional_args) > len(self.buffer_options):
self.context.error(ERR_BUF_TOO_MANY)
for arg, unicode_name in zip(node.positional_args, self.buffer_options):
name = str(unicode_name)
options[name] = arg
# Fetch named arguments
for item in node.keyword_args.key_value_pairs:
name = str(item.key.value)
if not name in self.buffer_options:
raise PostParseError(item.key.pos,
ERR_BUF_UNKNOWN % name)
if name in options.keys():
raise PostParseError(item.key.pos,
ERR_BUF_DUP % key)
options[name] = item.value
provided = options.keys()
# get dtype
dtype = options.get("dtype")
if dtype is None: raise PostParseError(node.pos, ERR_BUF_MISSING % 'dtype')
node.dtype = dtype
# get ndim
if "ndim" in provided:
ndimnode = options["ndim"]
if not isinstance(ndimnode, IntNode):
# Compile-time values (DEF) are currently resolved by the parser,
# so nothing more to do here
raise PostParseError(ndimnode.pos, ERR_BUF_INT % 'ndim')
ndim_value = int(ndimnode.value)
if ndim_value < 0:
raise PostParseError(ndimnode.pos, ERR_BUF_NONNEG % 'ndim')
node.ndim = int(ndimnode.value)
# We're done with the parse tree args
node.positional_args = None
node.keyword_args = None
return node
class WithTransform(CythonTransform):
# EXCINFO is manually set to a variable that contains
# the exc_info() tuple that can be generated by the enclosing except
......@@ -94,7 +166,7 @@ class WithTransform(VisitorTransform):
if EXC:
EXIT(None, None, None)
""", temps=[u'MGR', u'EXC', u"EXIT", u"SYS)"],
pipeline=[PostParse()])
pipeline=[NormalizeTree(None)])
template_with_target = TreeFragment(u"""
MGR = EXPR
......@@ -113,11 +185,7 @@ class WithTransform(VisitorTransform):
if EXC:
EXIT(None, None, None)
""", temps=[u'MGR', u'EXC', u"EXIT", u"VALUE", u"SYS"],
pipeline=[PostParse()])
def visit_Node(self, node):
self.visitchildren(node)
return node
pipeline=[NormalizeTree(None)])
def visit_WithStatNode(self, node):
excinfo_name = temp_name_handle('EXCINFO')
......@@ -143,7 +211,7 @@ class WithTransform(VisitorTransform):
return result.stats
class AnalyseDeclarationsTransform(VisitorTransform):
class AnalyseDeclarationsTransform(CythonTransform):
def __call__(self, root):
self.env_stack = [root.scope]
......@@ -164,12 +232,7 @@ class AnalyseDeclarationsTransform(VisitorTransform):
self.env_stack.pop()
return node
def visit_Node(self, node):
self.visitchildren(node)
return node
class AnalyseExpressionsTransform(VisitorTransform):
class AnalyseExpressionsTransform(CythonTransform):
def visit_ModuleNode(self, node):
node.body.analyse_expressions(node.scope)
......@@ -181,7 +244,4 @@ class AnalyseExpressionsTransform(VisitorTransform):
self.visitchildren(node)
return node
def visit_Node(self, node):
self.visitchildren(node)
return node
......@@ -281,20 +281,18 @@ def p_trailer(s, node1):
return ExprNodes.AttributeNode(pos,
obj = node1, attribute = name)
def p_positional_and_keyword_callargs(s, end_sy_set):
"""
Parses positional and keyword call arguments. end_sy_set
should contain any s.sy that terminate the argument chain
(this is ('*', '**', ')') for a normal function call,
and (']',) for buffers declarators).
# arglist: argument (',' argument)* [',']
# argument: [test '='] test # Really [keyword '='] test
Returns: (positional_args, keyword_args)
"""
def p_call(s, function):
# s.sy == '('
pos = s.position()
s.next()
positional_args = []
keyword_args = []
while s.sy not in end_sy_set:
if s.sy == '*' or s.sy == '**':
s.error('Argument expansion not allowed here.')
star_arg = None
starstar_arg = None
while s.sy not in ('*', '**', ')'):
arg = p_simple_expr(s)
if s.sy == '=':
s.next()
......@@ -314,20 +312,6 @@ def p_positional_and_keyword_callargs(s, end_sy_set):
if s.sy != ',':
break
s.next()
return positional_args, keyword_args
# arglist: argument (',' argument)* [',']
# argument: [test '='] test # Really [keyword '='] test
def p_call(s, function):
# s.sy == '('
pos = s.position()
s.next()
star_arg = None
starstar_arg = None
positional_args, keyword_args = (
p_positional_and_keyword_callargs(s,('*', '**', ')')))
if s.sy == '*':
s.next()
......@@ -1473,6 +1457,71 @@ def p_suite(s, ctx = Ctx(), with_doc = 0, with_pseudo_doc = 0):
else:
return body
def p_positional_and_keyword_args(s, end_sy_set, type_positions=(), type_keywords=()):
"""
Parses positional and keyword arguments. end_sy_set
should contain any s.sy that terminate the argument list.
Argument expansion (* and **) are not allowed.
type_positions and type_keywords specifies which argument
positions and/or names which should be interpreted as
types. Other arguments will be treated as expressions.
Returns: (positional_args, keyword_args)
"""
positional_args = []
keyword_args = []
pos_idx = 0
while s.sy not in end_sy_set:
if s.sy == '*' or s.sy == '**':
s.error('Argument expansion not allowed here.')
was_keyword = False
parsed_type = False
if s.sy == 'IDENT':
# Since we can have either types or expressions as positional args,
# we use a strategy of looking an extra step forward for a '=' and
# if it is a positional arg we backtrack.
ident = s.systring
s.next()
if s.sy == '=':
s.next()
# Is keyword arg
if ident in type_keywords:
arg = p_c_base_type(s)
parsed_type = True
else:
arg = p_simple_expr(s)
keyword_node = ExprNodes.IdentifierStringNode(arg.pos,
value = Utils.EncodedString(ident))
keyword_args.append((keyword_node, arg))
was_keyword = True
else:
s.put_back('IDENT', ident)
if not was_keyword:
if pos_idx in type_positions:
arg = p_c_base_type(s)
parsed_type = True
else:
arg = p_simple_expr(s)
positional_args.append(arg)
pos_idx += 1
if len(keyword_args) > 0:
s.error("Non-keyword arg following keyword arg",
pos = arg.pos)
if s.sy != ',':
if s.sy not in end_sy_set:
if parsed_type:
s.error("Expected: type")
else:
s.error("Expected: expression")
break
s.next()
return positional_args, keyword_args
def p_c_base_type(s, self_flag = 0, nonempty = 0):
# If self_flag is true, this is the base type for the
# self argument of a C method of an extension type.
......@@ -1556,24 +1605,32 @@ def p_c_simple_base_type(s, self_flag, nonempty):
if s.sy == '[':
if is_basic:
p.error("Basic C types do not support buffer access")
s.next()
positional_args, keyword_args = (
p_positional_and_keyword_callargs(s, ('[]',)))
s.expect(']')
keyword_dict = ExprNodes.DictNode(pos,
key_value_pairs = [
ExprNodes.DictItemNode(pos=key.pos, key=key, value=value)
for key, value in keyword_args
])
return Nodes.CBufferAccessTypeNode(pos,
positional_args = positional_args,
keyword_args = keyword_dict,
base_type_node = type_node)
return p_buffer_access(s, type_node)
else:
return type_node
def p_buffer_access(s, type_node):
# s.sy == '['
pos = s.position()
s.next()
positional_args, keyword_args = (
p_positional_and_keyword_args(s, (']',), (0,), ('dtype',))
)
s.expect(']')
keyword_dict = ExprNodes.DictNode(pos,
key_value_pairs = [
ExprNodes.DictItemNode(pos=key.pos, key=key, value=value)
for key, value in keyword_args
])
result = Nodes.CBufferAccessTypeNode(pos,
positional_args = positional_args,
keyword_args = keyword_dict,
base_type_node = type_node)
return result
def looking_at_type(s):
return looking_at_base_type(s) or s.looking_at_type_name()
......
......@@ -2,7 +2,7 @@ from Cython.TestUtils import TransformTest
from Cython.Compiler.ParseTreeTransforms import *
from Cython.Compiler.Nodes import *
class TestPostParse(TransformTest):
class TestNormalizeTree(TransformTest):
def test_parserbehaviour_is_what_we_coded_for(self):
t = self.fragment(u"if x: y").root
self.assertLines(u"""
......@@ -15,7 +15,7 @@ class TestPostParse(TransformTest):
""", self.treetypes(t))
def test_wrap_singlestat(self):
t = self.run_pipeline([PostParse()], u"if x: y")
t = self.run_pipeline([NormalizeTree(None)], u"if x: y")
self.assertLines(u"""
(root): StatListNode
stats[0]: IfStatNode
......@@ -27,7 +27,7 @@ class TestPostParse(TransformTest):
""", self.treetypes(t))
def test_wrap_multistat(self):
t = self.run_pipeline([PostParse()], u"""
t = self.run_pipeline([NormalizeTree(None)], u"""
if z:
x
y
......@@ -45,7 +45,7 @@ class TestPostParse(TransformTest):
""", self.treetypes(t))
def test_statinexpr(self):
t = self.run_pipeline([PostParse()], u"""
t = self.run_pipeline([NormalizeTree(None)], u"""
a, b = x, y
""")
self.assertLines(u"""
......@@ -60,7 +60,7 @@ class TestPostParse(TransformTest):
""", self.treetypes(t))
def test_wrap_offagain(self):
t = self.run_pipeline([PostParse()], u"""
t = self.run_pipeline([NormalizeTree(None)], u"""
x
y
if z:
......@@ -82,13 +82,13 @@ class TestPostParse(TransformTest):
def test_pass_eliminated(self):
t = self.run_pipeline([PostParse()], u"pass")
t = self.run_pipeline([NormalizeTree(None)], u"pass")
self.assert_(len(t.stats) == 0)
class TestWithTransform(TransformTest):
def test_simplified(self):
t = self.run_pipeline([WithTransform()], u"""
t = self.run_pipeline([WithTransform(None)], u"""
with x:
y = z ** 3
""")
......@@ -113,7 +113,7 @@ class TestWithTransform(TransformTest):
""", t)
def test_basic(self):
t = self.run_pipeline([WithTransform()], u"""
t = self.run_pipeline([WithTransform(None)], u"""
with x as y:
y = z ** 3
""")
......
......@@ -131,7 +131,6 @@ class VisitorTransform(TreeVisitor):
was not, an exception will be raised. (Typically you want to ensure that you
are within a StatListNode or similar before doing this.)
"""
def visitchildren(self, parent, attrs=None):
result = super(VisitorTransform, self).visitchildren(parent, attrs)
for attr, newnode in result.iteritems():
......@@ -152,6 +151,19 @@ class VisitorTransform(TreeVisitor):
def __call__(self, root):
return self.visit(root)
class CythonTransform(VisitorTransform):
"""
Certain common conventions and utilitues for Cython transforms.
"""
def __init__(self, context):
super(CythonTransform, self).__init__()
self.context = context
def visit_Node(self, node):
self.visitchildren(node)
return node
# Utils
def ensure_statlist(node):
if not isinstance(node, Nodes.StatListNode):
......
......@@ -50,12 +50,12 @@ class CythonTest(unittest.TestCase):
self.assertEqual(len(result_lines), len(expected_lines),
"Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected))
def fragment(self, code, pxds={}):
def fragment(self, code, pxds={}, pipeline=[]):
"Simply create a tree fragment using the name of the test-case in parse errors."
name = self.id()
if name.startswith("__main__."): name = name[len("__main__."):]
name = name.replace(".", "_")
return TreeFragment(code, name, pxds)
return TreeFragment(code, name, pxds, pipeline=pipeline)
def treetypes(self, root):
"""Returns a string representing the tree by class names.
......@@ -66,6 +66,26 @@ class CythonTest(unittest.TestCase):
w.visit(root)
return u"\n".join([u""] + w.result + [u""])
def should_fail(self, func, exc_type=Exception):
"""Calls "func" and fails if it doesn't raise the right exception
(any exception by default). Also returns the exception in question.
"""
try:
func()
self.fail("Expected an exception of type %r" % exc_type)
except exc_type, e:
self.assert_(isinstance(e, exc_type))
return e
def should_not_fail(self, func):
"""Calls func and succeeds if and only if no exception is raised
(i.e. converts exception raising into a failed testcase). Returns
the return value of func."""
try:
return func()
except:
self.fail()
class TransformTest(CythonTest):
"""
Utility base class for transform unit tests. It is based around constructing
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment