Commit 1788299e authored by Dag Sverre Seljebotn's avatar Dag Sverre Seljebotn

Buffer parsing complete; small transform factorizations and renaming of PostParse

parent 52c259aa
...@@ -32,6 +32,7 @@ class CompileError(PyrexError): ...@@ -32,6 +32,7 @@ class CompileError(PyrexError):
def __init__(self, position = None, message = ""): def __init__(self, position = None, message = ""):
self.position = position self.position = position
self.message_only = message
# Deprecated and withdrawn in 2.6: # Deprecated and withdrawn in 2.6:
# self.message = message # self.message = message
if position: if position:
......
...@@ -334,17 +334,18 @@ def create_generate_code(context, options, result): ...@@ -334,17 +334,18 @@ def create_generate_code(context, options, result):
return generate_code return generate_code
def create_default_pipeline(context, options, result): def create_default_pipeline(context, options, result):
from ParseTreeTransforms import WithTransform, PostParse from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ModuleNode import check_c_classes from ModuleNode import check_c_classes
return [ return [
create_parse(context), create_parse(context),
PostParse(), NormalizeTree(context),
WithTransform(), PostParse(context),
AnalyseDeclarationsTransform(), WithTransform(context),
AnalyseDeclarationsTransform(context),
check_c_classes, check_c_classes,
AnalyseExpressionsTransform(), AnalyseExpressionsTransform(context),
create_generate_code(context, options, result) create_generate_code(context, options, result)
] ]
......
...@@ -565,7 +565,6 @@ class CBaseTypeNode(Node): ...@@ -565,7 +565,6 @@ class CBaseTypeNode(Node):
pass pass
class CSimpleBaseTypeNode(CBaseTypeNode): class CSimpleBaseTypeNode(CBaseTypeNode):
# name string # name string
# module_path [string] Qualifying name components # module_path [string] Qualifying name components
......
from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform
from Cython.Compiler.Nodes import * from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import * from Cython.Compiler.ExprNodes import *
from Cython.Compiler.TreeFragment import TreeFragment from Cython.Compiler.TreeFragment import TreeFragment
from Cython.Utils import EncodedString
from Cython.Compiler.Errors import CompileError
from sets import Set as set
class PostParse(VisitorTransform): class NormalizeTree(CythonTransform):
""" """
This transform fixes up a few things after parsing This transform fixes up a few things after parsing
in order to make the parse tree more suitable for in order to make the parse tree more suitable for
...@@ -25,15 +28,11 @@ class PostParse(VisitorTransform): ...@@ -25,15 +28,11 @@ class PostParse(VisitorTransform):
StatListNode has no children to see if the block is empty). StatListNode has no children to see if the block is empty).
""" """
def __init__(self): def __init__(self, context):
super(PostParse, self).__init__() super(NormalizeTree, self).__init__(context)
self.is_in_statlist = False self.is_in_statlist = False
self.is_in_expr = False self.is_in_expr = False
def visit_Node(self, node):
self.visitchildren(node)
return node
def visit_ExprNode(self, node): def visit_ExprNode(self, node):
stacktmp = self.is_in_expr stacktmp = self.is_in_expr
self.is_in_expr = True self.is_in_expr = True
...@@ -73,7 +72,80 @@ class PostParse(VisitorTransform): ...@@ -73,7 +72,80 @@ class PostParse(VisitorTransform):
return self.visit_StatNode(node, True) return self.visit_StatNode(node, True)
class WithTransform(VisitorTransform): class PostParseError(CompileError): pass
# error strings checked by unit tests, so define them
ERR_BUF_OPTION_UNKNOWN = '"%s" is not a buffer option'
ERR_BUF_TOO_MANY = 'Too many buffer options'
ERR_BUF_DUP = '"%s" buffer option already supplied'
ERR_BUF_MISSING = '"%s" missing'
ERR_BUF_INT = '"%s" must be an integer'
ERR_BUF_NONNEG = '"%s" must be non-negative'
class PostParse(CythonTransform):
"""
Basic interpretation of the parse tree, as well as validity
checking that can be done on a very basic level on the parse
tree (while still not being a problem with the basic syntax,
as such).
Specifically:
- CBufferAccessTypeNode has its options interpreted:
Any first positional argument goes into the "dtype" attribute,
any "ndim" keyword argument goes into the "ndim" attribute and
so on. Also it is checked that the option combination is valid.
Note: Currently Parsing.py does a lot of interpretation and
reorganization that can be refactored into this transform
if a more pure Abstract Syntax Tree is wanted.
"""
buffer_options = ("dtype", "ndim") # ordered!
def visit_CBufferAccessTypeNode(self, node):
options = {}
# Fetch positional arguments
if len(node.positional_args) > len(self.buffer_options):
self.context.error(ERR_BUF_TOO_MANY)
for arg, unicode_name in zip(node.positional_args, self.buffer_options):
name = str(unicode_name)
options[name] = arg
# Fetch named arguments
for item in node.keyword_args.key_value_pairs:
name = str(item.key.value)
if not name in self.buffer_options:
raise PostParseError(item.key.pos,
ERR_BUF_UNKNOWN % name)
if name in options.keys():
raise PostParseError(item.key.pos,
ERR_BUF_DUP % key)
options[name] = item.value
provided = options.keys()
# get dtype
dtype = options.get("dtype")
if dtype is None: raise PostParseError(node.pos, ERR_BUF_MISSING % 'dtype')
node.dtype = dtype
# get ndim
if "ndim" in provided:
ndimnode = options["ndim"]
if not isinstance(ndimnode, IntNode):
# Compile-time values (DEF) are currently resolved by the parser,
# so nothing more to do here
raise PostParseError(ndimnode.pos, ERR_BUF_INT % 'ndim')
ndim_value = int(ndimnode.value)
if ndim_value < 0:
raise PostParseError(ndimnode.pos, ERR_BUF_NONNEG % 'ndim')
node.ndim = int(ndimnode.value)
# We're done with the parse tree args
node.positional_args = None
node.keyword_args = None
return node
class WithTransform(CythonTransform):
# EXCINFO is manually set to a variable that contains # EXCINFO is manually set to a variable that contains
# the exc_info() tuple that can be generated by the enclosing except # the exc_info() tuple that can be generated by the enclosing except
...@@ -94,7 +166,7 @@ class WithTransform(VisitorTransform): ...@@ -94,7 +166,7 @@ class WithTransform(VisitorTransform):
if EXC: if EXC:
EXIT(None, None, None) EXIT(None, None, None)
""", temps=[u'MGR', u'EXC', u"EXIT", u"SYS)"], """, temps=[u'MGR', u'EXC', u"EXIT", u"SYS)"],
pipeline=[PostParse()]) pipeline=[NormalizeTree(None)])
template_with_target = TreeFragment(u""" template_with_target = TreeFragment(u"""
MGR = EXPR MGR = EXPR
...@@ -113,11 +185,7 @@ class WithTransform(VisitorTransform): ...@@ -113,11 +185,7 @@ class WithTransform(VisitorTransform):
if EXC: if EXC:
EXIT(None, None, None) EXIT(None, None, None)
""", temps=[u'MGR', u'EXC', u"EXIT", u"VALUE", u"SYS"], """, temps=[u'MGR', u'EXC', u"EXIT", u"VALUE", u"SYS"],
pipeline=[PostParse()]) pipeline=[NormalizeTree(None)])
def visit_Node(self, node):
self.visitchildren(node)
return node
def visit_WithStatNode(self, node): def visit_WithStatNode(self, node):
excinfo_name = temp_name_handle('EXCINFO') excinfo_name = temp_name_handle('EXCINFO')
...@@ -143,7 +211,7 @@ class WithTransform(VisitorTransform): ...@@ -143,7 +211,7 @@ class WithTransform(VisitorTransform):
return result.stats return result.stats
class AnalyseDeclarationsTransform(VisitorTransform): class AnalyseDeclarationsTransform(CythonTransform):
def __call__(self, root): def __call__(self, root):
self.env_stack = [root.scope] self.env_stack = [root.scope]
...@@ -164,12 +232,7 @@ class AnalyseDeclarationsTransform(VisitorTransform): ...@@ -164,12 +232,7 @@ class AnalyseDeclarationsTransform(VisitorTransform):
self.env_stack.pop() self.env_stack.pop()
return node return node
def visit_Node(self, node): class AnalyseExpressionsTransform(CythonTransform):
self.visitchildren(node)
return node
class AnalyseExpressionsTransform(VisitorTransform):
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
node.body.analyse_expressions(node.scope) node.body.analyse_expressions(node.scope)
...@@ -181,7 +244,4 @@ class AnalyseExpressionsTransform(VisitorTransform): ...@@ -181,7 +244,4 @@ class AnalyseExpressionsTransform(VisitorTransform):
self.visitchildren(node) self.visitchildren(node)
return node return node
def visit_Node(self, node):
self.visitchildren(node)
return node
...@@ -281,20 +281,18 @@ def p_trailer(s, node1): ...@@ -281,20 +281,18 @@ def p_trailer(s, node1):
return ExprNodes.AttributeNode(pos, return ExprNodes.AttributeNode(pos,
obj = node1, attribute = name) obj = node1, attribute = name)
def p_positional_and_keyword_callargs(s, end_sy_set): # arglist: argument (',' argument)* [',']
""" # argument: [test '='] test # Really [keyword '='] test
Parses positional and keyword call arguments. end_sy_set
should contain any s.sy that terminate the argument chain
(this is ('*', '**', ')') for a normal function call,
and (']',) for buffers declarators).
Returns: (positional_args, keyword_args) def p_call(s, function):
""" # s.sy == '('
pos = s.position()
s.next()
positional_args = [] positional_args = []
keyword_args = [] keyword_args = []
while s.sy not in end_sy_set: star_arg = None
if s.sy == '*' or s.sy == '**': starstar_arg = None
s.error('Argument expansion not allowed here.') while s.sy not in ('*', '**', ')'):
arg = p_simple_expr(s) arg = p_simple_expr(s)
if s.sy == '=': if s.sy == '=':
s.next() s.next()
...@@ -314,20 +312,6 @@ def p_positional_and_keyword_callargs(s, end_sy_set): ...@@ -314,20 +312,6 @@ def p_positional_and_keyword_callargs(s, end_sy_set):
if s.sy != ',': if s.sy != ',':
break break
s.next() s.next()
return positional_args, keyword_args
# arglist: argument (',' argument)* [',']
# argument: [test '='] test # Really [keyword '='] test
def p_call(s, function):
# s.sy == '('
pos = s.position()
s.next()
star_arg = None
starstar_arg = None
positional_args, keyword_args = (
p_positional_and_keyword_callargs(s,('*', '**', ')')))
if s.sy == '*': if s.sy == '*':
s.next() s.next()
...@@ -1473,6 +1457,71 @@ def p_suite(s, ctx = Ctx(), with_doc = 0, with_pseudo_doc = 0): ...@@ -1473,6 +1457,71 @@ def p_suite(s, ctx = Ctx(), with_doc = 0, with_pseudo_doc = 0):
else: else:
return body return body
def p_positional_and_keyword_args(s, end_sy_set, type_positions=(), type_keywords=()):
"""
Parses positional and keyword arguments. end_sy_set
should contain any s.sy that terminate the argument list.
Argument expansion (* and **) are not allowed.
type_positions and type_keywords specifies which argument
positions and/or names which should be interpreted as
types. Other arguments will be treated as expressions.
Returns: (positional_args, keyword_args)
"""
positional_args = []
keyword_args = []
pos_idx = 0
while s.sy not in end_sy_set:
if s.sy == '*' or s.sy == '**':
s.error('Argument expansion not allowed here.')
was_keyword = False
parsed_type = False
if s.sy == 'IDENT':
# Since we can have either types or expressions as positional args,
# we use a strategy of looking an extra step forward for a '=' and
# if it is a positional arg we backtrack.
ident = s.systring
s.next()
if s.sy == '=':
s.next()
# Is keyword arg
if ident in type_keywords:
arg = p_c_base_type(s)
parsed_type = True
else:
arg = p_simple_expr(s)
keyword_node = ExprNodes.IdentifierStringNode(arg.pos,
value = Utils.EncodedString(ident))
keyword_args.append((keyword_node, arg))
was_keyword = True
else:
s.put_back('IDENT', ident)
if not was_keyword:
if pos_idx in type_positions:
arg = p_c_base_type(s)
parsed_type = True
else:
arg = p_simple_expr(s)
positional_args.append(arg)
pos_idx += 1
if len(keyword_args) > 0:
s.error("Non-keyword arg following keyword arg",
pos = arg.pos)
if s.sy != ',':
if s.sy not in end_sy_set:
if parsed_type:
s.error("Expected: type")
else:
s.error("Expected: expression")
break
s.next()
return positional_args, keyword_args
def p_c_base_type(s, self_flag = 0, nonempty = 0): def p_c_base_type(s, self_flag = 0, nonempty = 0):
# If self_flag is true, this is the base type for the # If self_flag is true, this is the base type for the
# self argument of a C method of an extension type. # self argument of a C method of an extension type.
...@@ -1556,24 +1605,32 @@ def p_c_simple_base_type(s, self_flag, nonempty): ...@@ -1556,24 +1605,32 @@ def p_c_simple_base_type(s, self_flag, nonempty):
if s.sy == '[': if s.sy == '[':
if is_basic: if is_basic:
p.error("Basic C types do not support buffer access") p.error("Basic C types do not support buffer access")
s.next() return p_buffer_access(s, type_node)
positional_args, keyword_args = (
p_positional_and_keyword_callargs(s, ('[]',)))
s.expect(']')
keyword_dict = ExprNodes.DictNode(pos,
key_value_pairs = [
ExprNodes.DictItemNode(pos=key.pos, key=key, value=value)
for key, value in keyword_args
])
return Nodes.CBufferAccessTypeNode(pos,
positional_args = positional_args,
keyword_args = keyword_dict,
base_type_node = type_node)
else: else:
return type_node return type_node
def p_buffer_access(s, type_node):
# s.sy == '['
pos = s.position()
s.next()
positional_args, keyword_args = (
p_positional_and_keyword_args(s, (']',), (0,), ('dtype',))
)
s.expect(']')
keyword_dict = ExprNodes.DictNode(pos,
key_value_pairs = [
ExprNodes.DictItemNode(pos=key.pos, key=key, value=value)
for key, value in keyword_args
])
result = Nodes.CBufferAccessTypeNode(pos,
positional_args = positional_args,
keyword_args = keyword_dict,
base_type_node = type_node)
return result
def looking_at_type(s): def looking_at_type(s):
return looking_at_base_type(s) or s.looking_at_type_name() return looking_at_base_type(s) or s.looking_at_type_name()
......
...@@ -2,7 +2,7 @@ from Cython.TestUtils import TransformTest ...@@ -2,7 +2,7 @@ from Cython.TestUtils import TransformTest
from Cython.Compiler.ParseTreeTransforms import * from Cython.Compiler.ParseTreeTransforms import *
from Cython.Compiler.Nodes import * from Cython.Compiler.Nodes import *
class TestPostParse(TransformTest): class TestNormalizeTree(TransformTest):
def test_parserbehaviour_is_what_we_coded_for(self): def test_parserbehaviour_is_what_we_coded_for(self):
t = self.fragment(u"if x: y").root t = self.fragment(u"if x: y").root
self.assertLines(u""" self.assertLines(u"""
...@@ -15,7 +15,7 @@ class TestPostParse(TransformTest): ...@@ -15,7 +15,7 @@ class TestPostParse(TransformTest):
""", self.treetypes(t)) """, self.treetypes(t))
def test_wrap_singlestat(self): def test_wrap_singlestat(self):
t = self.run_pipeline([PostParse()], u"if x: y") t = self.run_pipeline([NormalizeTree(None)], u"if x: y")
self.assertLines(u""" self.assertLines(u"""
(root): StatListNode (root): StatListNode
stats[0]: IfStatNode stats[0]: IfStatNode
...@@ -27,7 +27,7 @@ class TestPostParse(TransformTest): ...@@ -27,7 +27,7 @@ class TestPostParse(TransformTest):
""", self.treetypes(t)) """, self.treetypes(t))
def test_wrap_multistat(self): def test_wrap_multistat(self):
t = self.run_pipeline([PostParse()], u""" t = self.run_pipeline([NormalizeTree(None)], u"""
if z: if z:
x x
y y
...@@ -45,7 +45,7 @@ class TestPostParse(TransformTest): ...@@ -45,7 +45,7 @@ class TestPostParse(TransformTest):
""", self.treetypes(t)) """, self.treetypes(t))
def test_statinexpr(self): def test_statinexpr(self):
t = self.run_pipeline([PostParse()], u""" t = self.run_pipeline([NormalizeTree(None)], u"""
a, b = x, y a, b = x, y
""") """)
self.assertLines(u""" self.assertLines(u"""
...@@ -60,7 +60,7 @@ class TestPostParse(TransformTest): ...@@ -60,7 +60,7 @@ class TestPostParse(TransformTest):
""", self.treetypes(t)) """, self.treetypes(t))
def test_wrap_offagain(self): def test_wrap_offagain(self):
t = self.run_pipeline([PostParse()], u""" t = self.run_pipeline([NormalizeTree(None)], u"""
x x
y y
if z: if z:
...@@ -82,13 +82,13 @@ class TestPostParse(TransformTest): ...@@ -82,13 +82,13 @@ class TestPostParse(TransformTest):
def test_pass_eliminated(self): def test_pass_eliminated(self):
t = self.run_pipeline([PostParse()], u"pass") t = self.run_pipeline([NormalizeTree(None)], u"pass")
self.assert_(len(t.stats) == 0) self.assert_(len(t.stats) == 0)
class TestWithTransform(TransformTest): class TestWithTransform(TransformTest):
def test_simplified(self): def test_simplified(self):
t = self.run_pipeline([WithTransform()], u""" t = self.run_pipeline([WithTransform(None)], u"""
with x: with x:
y = z ** 3 y = z ** 3
""") """)
...@@ -113,7 +113,7 @@ class TestWithTransform(TransformTest): ...@@ -113,7 +113,7 @@ class TestWithTransform(TransformTest):
""", t) """, t)
def test_basic(self): def test_basic(self):
t = self.run_pipeline([WithTransform()], u""" t = self.run_pipeline([WithTransform(None)], u"""
with x as y: with x as y:
y = z ** 3 y = z ** 3
""") """)
......
...@@ -131,7 +131,6 @@ class VisitorTransform(TreeVisitor): ...@@ -131,7 +131,6 @@ class VisitorTransform(TreeVisitor):
was not, an exception will be raised. (Typically you want to ensure that you was not, an exception will be raised. (Typically you want to ensure that you
are within a StatListNode or similar before doing this.) are within a StatListNode or similar before doing this.)
""" """
def visitchildren(self, parent, attrs=None): def visitchildren(self, parent, attrs=None):
result = super(VisitorTransform, self).visitchildren(parent, attrs) result = super(VisitorTransform, self).visitchildren(parent, attrs)
for attr, newnode in result.iteritems(): for attr, newnode in result.iteritems():
...@@ -152,6 +151,19 @@ class VisitorTransform(TreeVisitor): ...@@ -152,6 +151,19 @@ class VisitorTransform(TreeVisitor):
def __call__(self, root): def __call__(self, root):
return self.visit(root) return self.visit(root)
class CythonTransform(VisitorTransform):
"""
Certain common conventions and utilitues for Cython transforms.
"""
def __init__(self, context):
super(CythonTransform, self).__init__()
self.context = context
def visit_Node(self, node):
self.visitchildren(node)
return node
# Utils # Utils
def ensure_statlist(node): def ensure_statlist(node):
if not isinstance(node, Nodes.StatListNode): if not isinstance(node, Nodes.StatListNode):
......
...@@ -50,12 +50,12 @@ class CythonTest(unittest.TestCase): ...@@ -50,12 +50,12 @@ class CythonTest(unittest.TestCase):
self.assertEqual(len(result_lines), len(expected_lines), self.assertEqual(len(result_lines), len(expected_lines),
"Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected)) "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(result_lines), expected))
def fragment(self, code, pxds={}): def fragment(self, code, pxds={}, pipeline=[]):
"Simply create a tree fragment using the name of the test-case in parse errors." "Simply create a tree fragment using the name of the test-case in parse errors."
name = self.id() name = self.id()
if name.startswith("__main__."): name = name[len("__main__."):] if name.startswith("__main__."): name = name[len("__main__."):]
name = name.replace(".", "_") name = name.replace(".", "_")
return TreeFragment(code, name, pxds) return TreeFragment(code, name, pxds, pipeline=pipeline)
def treetypes(self, root): def treetypes(self, root):
"""Returns a string representing the tree by class names. """Returns a string representing the tree by class names.
...@@ -66,6 +66,26 @@ class CythonTest(unittest.TestCase): ...@@ -66,6 +66,26 @@ class CythonTest(unittest.TestCase):
w.visit(root) w.visit(root)
return u"\n".join([u""] + w.result + [u""]) return u"\n".join([u""] + w.result + [u""])
def should_fail(self, func, exc_type=Exception):
"""Calls "func" and fails if it doesn't raise the right exception
(any exception by default). Also returns the exception in question.
"""
try:
func()
self.fail("Expected an exception of type %r" % exc_type)
except exc_type, e:
self.assert_(isinstance(e, exc_type))
return e
def should_not_fail(self, func):
"""Calls func and succeeds if and only if no exception is raised
(i.e. converts exception raising into a failed testcase). Returns
the return value of func."""
try:
return func()
except:
self.fail()
class TransformTest(CythonTest): class TransformTest(CythonTest):
""" """
Utility base class for transform unit tests. It is based around constructing Utility base class for transform unit tests. It is based around constructing
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment