ParseTreeTransforms.py 39.3 KB
Newer Older
1
from Cython.Compiler.Visitor import VisitorTransform, CythonTransform, TreeVisitor
2
from Cython.Compiler.ModuleNode import ModuleNode
3
from Cython.Compiler.Nodes import *
4
from Cython.Compiler.ExprNodes import *
5
from Cython.Compiler.UtilNodes import *
6
from Cython.Compiler.TreeFragment import TreeFragment, TemplateTransform
7
from Cython.Compiler.StringEncoding import EncodedString
8
from Cython.Compiler.Errors import error, CompileError
Stefan Behnel's avatar
Stefan Behnel committed
9 10 11 12
try:
    set
except NameError:
    from sets import Set as set
13
import copy
14

15 16 17 18 19 20 21 22 23

class NameNodeCollector(TreeVisitor):
    """Collect all NameNodes of a (sub-)tree in the ``name_nodes``
    attribute.
    """
    def __init__(self):
        super(NameNodeCollector, self).__init__()
        # NameNode objects, appended by visit_NameNode as they are met.
        self.name_nodes = []

    # Nodes without a more specific handler only have their children visited.
    visit_Node = TreeVisitor.visitchildren

    def visit_NameNode(self, node):
        # Record the node itself; its children are not visited.
        self.name_nodes.append(node)


30
class SkipDeclarations(object):
    """
    Variable and function declarations can often have a deep tree structure,
    and yet most transformations don't need to descend to this depth.

    Declaration nodes are removed after AnalyseDeclarationsTransform, so there
    is no need to use this for transformations after that point.

    Every handler below returns the node unchanged without visiting its
    children, which is what keeps a transform out of the declaration subtree.
    """
    def visit_CTypeDefNode(self, node):
        return node

    def visit_CVarDefNode(self, node):
        return node

    def visit_CDeclaratorNode(self, node):
        return node

    def visit_CBaseTypeNode(self, node):
        return node

    def visit_CEnumDefNode(self, node):
        return node

    def visit_CStructOrUnionDefNode(self, node):
        return node


57
class NormalizeTree(CythonTransform):
    """
    This transform fixes up a few things after parsing
    in order to make the parse tree more suitable for
    transforms.

    a) After parsing, blocks with only one statement will
    be represented by that statement, not by a StatListNode.
    When doing transforms this is annoying and inconsistent,
    as one cannot in general remove a statement in a consistent
    way and so on. This transform wraps any single statements
    in a StatListNode containing a single statement.

    b) The PassStatNode is a noop and serves no purpose beyond
    plugging such one-statement blocks; i.e., once parsed a
    "pass" can just as well be represented using an empty
    StatListNode. This means less special cases to worry about
    in subsequent transforms (one always checks to see if a
    StatListNode has no children to see if the block is empty).
    """

    def __init__(self, context):
        super(NormalizeTree, self).__init__(context)
        # True while visiting the direct children of a statement list
        # (or of a node that is treated as a list container).
        self.is_in_statlist = False
        # True while visiting anything below an expression node.
        self.is_in_expr = False

    def visit_ExprNode(self, node):
        # Mark the subtree as "inside an expression" and restore the
        # previous flag afterwards.
        stacktmp = self.is_in_expr
        self.is_in_expr = True
        self.visitchildren(node)
        self.is_in_expr = stacktmp
        return node

    def visit_StatNode(self, node, is_listcontainer=False):
        # While the children are visited the flag reflects whether *this*
        # node acts as a list container; afterwards it is restored so that
        # the check below is about the position of 'node' itself.
        stacktmp = self.is_in_statlist
        self.is_in_statlist = is_listcontainer
        self.visitchildren(node)
        self.is_in_statlist = stacktmp
        if not self.is_in_statlist and not self.is_in_expr:
            # A lone statement outside a statement list gets wrapped.
            return StatListNode(pos=node.pos, stats=[node])
        else:
            return node

    def visit_StatListNode(self, node):
        # NOTE(review): the flag is reset to False rather than restored to
        # its previous value; kept as-is to preserve behaviour.
        self.is_in_statlist = True
        self.visitchildren(node)
        self.is_in_statlist = False
        return node

    def visit_ParallelAssignmentNode(self, node):
        return self.visit_StatNode(node, True)

    def visit_CEnumDefNode(self, node):
        return self.visit_StatNode(node, True)

    def visit_CStructOrUnionDefNode(self, node):
        return self.visit_StatNode(node, True)

    # Eliminate PassStatNode
    def visit_PassStatNode(self, node):
        if not self.is_in_statlist:
            # The pass was the whole block: replace it by an empty list.
            return StatListNode(pos=node.pos, stats=[])
        else:
            # Inside a statement list the pass is simply dropped.
            return []

    def visit_CDeclaratorNode(self, node):
        # Declarators are returned untouched; their children are not visited.
        return node

125

126 127 128
class PostParseError(CompileError):
    """Compile error raised or reported by the transforms in this module."""

# error strings checked by unit tests, so define them
ERR_CDEF_INCLASS = 'Cannot assign default value to fields in cdef classes, structs or unions'
ERR_BUF_LOCALONLY = 'Buffer types only allowed as function local variables'
ERR_BUF_DEFAULTS = 'Invalid buffer defaults specification (see docs)'
ERR_INVALID_SPECIALATTR_TYPE = 'Special attributes must not have a type declared'
133 134 135 136 137 138 139 140
class PostParse(CythonTransform):
    """
    Basic interpretation of the parse tree, as well as validity
    checking that can be done on a very basic level on the parse
    tree (while still not being a problem with the basic syntax,
    as such).

    Specifically:
141
    - Default values to cdef assignments are turned into single
142 143
    assignments following the declaration (everywhere but in class
    bodies, where they raise a compile error)
144 145 146 147 148 149 150 151 152 153 154 155
    
    - Interpret some node structures into Python runtime values.
    Some nodes take compile-time arguments (currently:
    CBufferAccessTypeNode[args] and __cythonbufferdefaults__ = {args}),
    which should be interpreted. This happens in a general way
    and other steps should be taken to ensure validity.

    Type arguments cannot be interpreted in this way.

    - For __cythonbufferdefaults__ the arguments are checked for
    validity.

156
    CBufferAccessTypeNode has its directives interpreted:
157 158
    Any first positional argument goes into the "dtype" attribute,
    any "ndim" keyword argument goes into the "ndim" attribute and
159
    so on. Also it is checked that the directive combination is valid.
160 161
    - __cythonbufferdefaults__ attributes are parsed and put into the
    type information.
162 163 164 165 166 167

    Note: Currently Parsing.py does a lot of interpretation and
    reorganization that can be refactored into this transform
    if a more pure Abstract Syntax Tree is wanted.
    """

168
    # Track our context.
169 170
    scope_type = None # can be either of 'module', 'function', 'class'

171 172 173 174 175 176
    def __init__(self, context):
        super(PostParse, self).__init__(context)
        self.specialattribute_handlers = {
            '__cythonbufferdefaults__' : self.handle_bufferdefaults
        }

177 178
    def visit_ModuleNode(self, node):
        self.scope_type = 'module'
179
        self.scope_node = node
180 181
        self.visitchildren(node)
        return node
182 183 184 185 186

    def visit_scope(self, node, scope_type):
        prev = self.scope_type, self.scope_node
        self.scope_type = scope_type
        self.scope_node = node
187
        self.visitchildren(node)
188
        self.scope_type, self.scope_node = prev
189
        return node
190 191 192
    
    def visit_ClassDefNode(self, node):
        return self.visit_scope(node, 'class')
193 194

    def visit_FuncDefNode(self, node):
195 196 197 198
        return self.visit_scope(node, 'function')

    def visit_CStructOrUnionDefNode(self, node):
        return self.visit_scope(node, 'struct')
199 200

    # cdef variables
201 202 203
    def handle_bufferdefaults(self, decl):
        if not isinstance(decl.default, DictNode):
            raise PostParseError(decl.pos, ERR_BUF_DEFAULTS)
204 205
        self.scope_node.buffer_defaults_node = decl.default
        self.scope_node.buffer_defaults_pos = decl.pos
206

207 208
    def visit_CVarDefNode(self, node):
        # This assumes only plain names and pointers are assignable on
209 210 211
        # declaration. Also, it makes use of the fact that a cdef decl
        # must appear before the first use, so we don't have to deal with
        # "i = 3; cdef int i = i" and can simply move the nodes around.
212 213
        try:
            self.visitchildren(node)
214 215 216 217 218 219 220 221
            stats = [node]
            newdecls = []
            for decl in node.declarators:
                declbase = decl
                while isinstance(declbase, CPtrDeclaratorNode):
                    declbase = declbase.base
                if isinstance(declbase, CNameDeclaratorNode):
                    if declbase.default is not None:
222 223
                        if self.scope_type in ('class', 'struct'):
                            if isinstance(self.scope_node, CClassDefNode):
224 225 226 227 228 229 230
                                handler = self.specialattribute_handlers.get(decl.name)
                                if handler:
                                    if decl is not declbase:
                                        raise PostParseError(decl.pos, ERR_INVALID_SPECIALATTR_TYPE)
                                    handler(decl)
                                    continue # Remove declaration
                            raise PostParseError(decl.pos, ERR_CDEF_INCLASS)
231
                        first_assignment = self.scope_type != 'module'
232 233
                        stats.append(SingleAssignmentNode(node.pos,
                            lhs=NameNode(node.pos, name=declbase.name),
234
                            rhs=declbase.default, first=first_assignment))
235 236 237 238
                        declbase.default = None
                newdecls.append(decl)
            node.declarators = newdecls
            return stats
239 240 241 242 243
        except PostParseError, e:
            # An error in a cdef clause is ok, simply remove the declaration
            # and try to move on to report more errors
            self.context.nonfatal_error(e)
            return None
244

245
    def visit_CBufferAccessTypeNode(self, node):
246 247
        if not self.scope_type == 'function':
            raise PostParseError(node.pos, ERR_BUF_LOCALONLY)
248 249
        return node

250
class PxdPostParse(CythonTransform, SkipDeclarations):
    """
    Basic interpretation/validity checking that should only be
    done on pxd trees.

    A lot of this checking currently happens in the parser; but
    what is listed below happens here.

    - "def" functions are let through only if they fill the
    getbuffer/releasebuffer slots

    - cdef functions are let through only if they are on the
    top level and are declared "inline"
    """
    ERR_INLINE_ONLY = "function definition in pxd file must be declared 'cdef inline'"
    ERR_NOGO_WITH_INLINE = "inline function definition in pxd file cannot be '%s'"

    def __call__(self, node):
        # 'pxd' marks the top level of the pxd file.
        self.scope_type = 'pxd'
        return super(PxdPostParse, self).__call__(node)

    def visit_CClassDefNode(self, node):
        # Visit the class body with scope_type 'cclass', then restore it.
        old = self.scope_type
        self.scope_type = 'cclass'
        self.visitchildren(node)
        self.scope_type = old
        return node

    def visit_FuncDefNode(self, node):
        # FuncDefNode always come with an implementation (without
        # an imp they are CVarDefNodes..)
        #
        # Returns the node if it is allowed in a pxd file; otherwise
        # reports a nonfatal error and returns None to drop the node.
        err = self.ERR_INLINE_ONLY

        if (isinstance(node, DefNode) and self.scope_type == 'cclass'
            and node.name in ('__getbuffer__', '__releasebuffer__')):
            err = None # allow these slots

        if isinstance(node, CFuncDefNode):
            if u'inline' in node.modifiers and self.scope_type == 'pxd':
                node.inline_in_pxd = True
                if node.visibility != 'private':
                    err = self.ERR_NOGO_WITH_INLINE % node.visibility
                elif node.api:
                    err = self.ERR_NOGO_WITH_INLINE % 'api'
                else:
                    err = None # allow inline function
            else:
                err = self.ERR_INLINE_ONLY

        if err:
            self.context.nonfatal_error(PostParseError(node.pos, err))
            return None
        else:
            return node
304 305
    
class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
    """
    After parsing, directives can be stored in a number of places:
    - #cython-comments at the top of the file (stored in ModuleNode)
    - Command-line arguments overriding these
    - @cython.directivename decorators
    - with cython.directivename: statements

    This transform is responsible for interpreting these various sources
    and storing the directives in two ways:
    - Set the directives attribute of the ModuleNode for global directives.
    - Use a CompilerDirectivesNode to override directives for a subtree.

    (The first one is primarily to not have to modify the tree
    structure, so that ModuleNode stays on top.)

    The directives are stored in dictionaries from name to value in effect.
    Each such dictionary is always filled in for all possible directives,
    using default values where no value is given by the user.

    The available directives are controlled in Options.py.

    Note that we have to run this prior to analysis, and so some minor
    duplication of functionality has to occur: We manually track cimports
    and which names the "cython" module may have been imported to.
    """
    special_methods = set(['declare', 'union', 'struct', 'typedef', 'sizeof', 'typeof', 'cast', 'address', 'pointer', 'compiled', 'NULL'])

    def __init__(self, context, compilation_directive_defaults):
        super(InterpretCompilerDirectives, self).__init__(context)
        # Copy of the passed-in defaults with the keys converted to unicode.
        self.compilation_directive_defaults = {}
        for key, value in compilation_directive_defaults.iteritems():
            self.compilation_directive_defaults[unicode(key)] = value
        # Names under which the "cython" module itself was imported.
        self.cython_module_names = set()
        # Maps a local name to the cython attribute it was imported from.
        self.directive_names = {}

    def check_directive_scope(self, pos, directive, scope):
        # Returns True if 'directive' may be used in 'scope'. Otherwise a
        # nonfatal error is reported at 'pos' and False is returned.
        legal_scopes = Options.directive_scopes.get(directive, None)
        if legal_scopes and scope not in legal_scopes:
            self.context.nonfatal_error(PostParseError(pos, 'The %s compiler directive '
                                        'is not allowed in %s scope' % (directive, scope)))
            return False
        else:
            return True

    # Set up processing and handle the cython: comments.
    def visit_ModuleNode(self, node):
        # Iterate over a snapshot of the keys, since entries are deleted
        # from the dictionary inside the loop.
        for key in list(node.directive_comments.keys()):
            if not self.check_directive_scope(node.pos, key, 'module'):
                # check_directive_scope has already reported the error;
                # just drop the offending directive.
                del node.directive_comments[key]

        # Later updates override earlier ones: built-in defaults, then the
        # defaults given to this transform, then the source comments.
        directives = copy.copy(Options.directive_defaults)
        directives.update(self.compilation_directive_defaults)
        directives.update(node.directive_comments)
        self.directives = directives
        node.directives = directives
        self.visitchildren(node)
        node.cython_module_names = self.cython_module_names
        return node

    # Track cimports of the cython module.
    def visit_CImportStatNode(self, node):
        if node.module_name == u"cython":
            if node.as_name:
                modname = node.as_name
            else:
                modname = u"cython"
            self.cython_module_names.add(modname)
        return node

    def visit_FromCImportStatNode(self, node):
        # Names cimported from "cython" that are directives, special
        # methods or basic types are recorded in directive_names and
        # removed from the statement; the statement itself is dropped
        # when nothing else remains.
        if node.module_name == u"cython":
            newimp = []
            for pos, name, as_name, kind in node.imported_names:
                if (name in Options.directive_types or
                        name in self.special_methods or
                        PyrexTypes.parse_basic_type(name)):
                    if as_name is None:
                        as_name = name
                    self.directive_names[as_name] = name
                    if kind is not None:
                        self.context.nonfatal_error(PostParseError(pos,
                            "Compiler directive imports must be plain imports"))
                else:
                    newimp.append((pos, name, as_name, kind))
            if not newimp:
                return None
            node.imported_names = newimp
        return node

    def visit_FromImportStatNode(self, node):
        # Same treatment as visit_FromCImportStatNode, for the Python-level
        # "from cython import ..." form.
        if node.module.module_name.value == u"cython":
            newimp = []
            for name, name_node in node.items:
                if (name in Options.directive_types or
                        name in self.special_methods or
                        PyrexTypes.parse_basic_type(name)):
                    self.directive_names[name_node.name] = name
                else:
                    newimp.append((name, name_node))
            if not newimp:
                return None
            node.items = newimp
        return node

    def visit_SingleAssignmentNode(self, node):
        # An assignment whose right-hand side imports the "cython" module
        # is replaced by an equivalent CImportStatNode.
        if (isinstance(node.rhs, ImportNode) and
                node.rhs.module_name.value == u'cython'):
            node = CImportStatNode(node.pos,
                                   module_name = u'cython',
                                   as_name = node.lhs.name)
            self.visit_CImportStatNode(node)
        else:
            self.visitchildren(node)
        return node

    def visit_NameNode(self, node):
        if node.name in self.cython_module_names:
            node.is_cython_module = True
        else:
            # None when the name was not imported from cython.
            node.cython_attribute = self.directive_names.get(node.name)
        return node

    def try_to_parse_directive(self, node):
        # If node is the contents of a directive (in a with statement or
        # decorator), returns (directivename, value).
        # Otherwise, returns None
        #
        # Raises PostParseError if the call names a directive but its
        # arguments do not match the directive's type.
        optname = None
        if isinstance(node, CallNode):
            self.visit(node.function)
            optname = node.function.as_cython_attribute()

        if optname:
            directivetype = Options.directive_types.get(optname)
            if directivetype:
                args, kwds = node.explicit_args_kwds()
                if directivetype is bool:
                    if kwds is not None or len(args) != 1 or not isinstance(args[0], BoolNode):
                        raise PostParseError(node.function.pos,
                            'The %s directive takes one compile-time boolean argument' % optname)
                    return (optname, args[0].value)
                elif directivetype is str:
                    if kwds is not None or len(args) != 1 or not isinstance(args[0], (StringNode, UnicodeNode)):
                        raise PostParseError(node.function.pos,
                            'The %s directive takes one compile-time string argument' % optname)
                    return (optname, str(args[0].value))
                elif directivetype is dict:
                    if len(args) != 0:
                        raise PostParseError(node.function.pos,
                            'The %s directive takes no prepositional arguments' % optname)
                    if kwds is None:
                        # Call without any keyword arguments: nothing to
                        # collect, so the directive value is an empty dict.
                        return optname, {}
                    return optname, dict([(key.value, value) for key, value in kwds.key_value_pairs])
                elif directivetype is list:
                    if kwds and len(kwds) != 0:
                        raise PostParseError(node.function.pos,
                            'The %s directive takes no keyword arguments' % optname)
                    return optname, [ str(arg.value) for arg in args ]
                else:
                    assert False

        return None

    def visit_with_directives(self, body, directives):
        # Visit 'body' with 'directives' layered on top of the directives
        # currently in effect, and return the visited body wrapped in a
        # CompilerDirectivesNode carrying the merged dictionary.
        olddirectives = self.directives
        newdirectives = copy.copy(olddirectives)
        newdirectives.update(directives)
        self.directives = newdirectives
        assert isinstance(body, StatListNode), body
        retbody = self.visit_Node(body)
        directive = CompilerDirectivesNode(pos=retbody.pos, body=retbody,
                                           directives=newdirectives)
        self.directives = olddirectives
        return directive

    # Handle decorators
    def visit_FuncDefNode(self, node):
        directives = []
        if node.decorators:
            # Split the decorators into two lists -- real decorators and directives
            realdecs = []
            for dec in node.decorators:
                directive = self.try_to_parse_directive(dec.decorator)
                if directive is not None:
                    directives.append(directive)
                else:
                    realdecs.append(dec)
            if realdecs and isinstance(node, CFuncDefNode):
                raise PostParseError(realdecs[0].pos, "Cdef functions cannot take arbitrary decorators.")
            else:
                node.decorators = realdecs

        if directives:
            optdict = {}
            directives.reverse() # Decorators coming first take precedence
            for directive in directives:
                name, value = directive
                if not self.check_directive_scope(node.pos, name, 'function'):
                    continue
                if name in optdict:
                    old_value = optdict[name]
                    # keywords and arg lists can be merged, everything
                    # else overrides completely
                    if isinstance(old_value, dict):
                        old_value.update(value)
                    elif isinstance(old_value, list):
                        old_value.extend(value)
                    else:
                        optdict[name] = value
                else:
                    optdict[name] = value
            body = StatListNode(node.pos, stats=[node])
            return self.visit_with_directives(body, optdict)
        else:
            return self.visit_Node(node)

    def visit_CVarDefNode(self, node):
        # Only the 'locals' directive is accepted as a decorator here; any
        # other decorator is reported as a nonfatal error.
        if node.decorators:
            for dec in node.decorators:
                directive = self.try_to_parse_directive(dec.decorator)
                if directive is not None and directive[0] == u'locals':
                    node.directive_locals = directive[1]
                else:
                    self.context.nonfatal_error(PostParseError(dec.pos,
                        "Cdef functions can only take cython.locals() decorator."))
        return node

    # Handle with statements
    def visit_WithStatNode(self, node):
        directive = self.try_to_parse_directive(node.manager)
        if directive is not None:
            if node.target is not None:
                self.context.nonfatal_error(
                    PostParseError(node.pos, "Compiler directive with statements cannot contain 'as'"))
            else:
                name, value = directive
                if self.check_directive_scope(node.pos, name, 'with statement'):
                    return self.visit_with_directives(node.body, {name:value})
        # Not a (valid) directive: treat it like any other node.
        return self.visit_Node(node)
545

546
class WithTransform(CythonTransform, SkipDeclarations):
    # Rewrites each with statement into the try/except/finally code given
    # by one of the two templates below.

    # EXCINFO is manually set to a variable that contains
    # the exc_info() tuple that can be generated by the enclosing except
    # statement.
    template_without_target = TreeFragment(u"""
        MGR = EXPR
        EXIT = MGR.__exit__
        MGR.__enter__()
        EXC = True
        try:
            try:
                EXCINFO = None
                BODY
            except:
                EXC = False
                if not EXIT(*EXCINFO):
                    raise
        finally:
            if EXC:
                EXIT(None, None, None)
    """, temps=[u'MGR', u'EXC', u"EXIT"],
    pipeline=[NormalizeTree(None)])

    template_with_target = TreeFragment(u"""
        MGR = EXPR
        EXIT = MGR.__exit__
        VALUE = MGR.__enter__()
        EXC = True
        try:
            try:
                EXCINFO = None
                TARGET = VALUE
                BODY
            except:
                EXC = False
                if not EXIT(*EXCINFO):
                    raise
        finally:
            if EXC:
                EXIT(None, None, None)
            MGR = EXIT = VALUE = EXC = None

    """, temps=[u'MGR', u'EXC', u"EXIT", u"VALUE"],
    pipeline=[NormalizeTree(None)])

    def visit_WithStatNode(self, node):
        # Returns the substituted template tree that replaces 'node'.
        # TODO: Cleanup badly needed
        # A fresh name per with statement, used for the EXCINFO variable.
        TemplateTransform.temp_name_counter += 1
        handle = "__tmpvar_%d" % TemplateTransform.temp_name_counter

        # Only the body is visited here (nested with statements).
        self.visitchildren(node, ['body'])
        excinfo_temp = NameNode(node.pos, name=handle)#TempHandle(Builtin.tuple_type)
        if node.target is not None:
            result = self.template_with_target.substitute({
                u'EXPR' : node.manager,
                u'BODY' : node.body,
                u'TARGET' : node.target,
                u'EXCINFO' : excinfo_temp
                }, pos=node.pos)
        else:
            result = self.template_without_target.substitute({
                u'EXPR' : node.manager,
                u'BODY' : node.body,
                u'EXCINFO' : excinfo_temp
                }, pos=node.pos)

        # Set except excinfo target to EXCINFO
        # NOTE: this indexing depends on the statement layout of the
        # templates above (last statement of the body of the last statement).
        try_except = result.stats[-1].body.stats[-1]
        try_except.except_clauses[0].excinfo_target = NameNode(node.pos, name=handle)
#            excinfo_temp.ref(node.pos))

#        result.stats[-1].body.stats[-1] = TempsBlockNode(
#            node.pos, temps=[excinfo_temp], body=try_except)

        return result

    def visit_ExprNode(self, node):
        # With statements are never inside expressions.
        return node
        
627

628
class DecoratorTransform(CythonTransform, SkipDeclarations):
    # Turns decorators on functions and classes into an explicit
    # reassignment of the decorated name that follows the definition.

    def visit_DefNode(self, func_node):
        self.visitchildren(func_node)
        if not func_node.decorators:
            return func_node
        return self._handle_decorators(
            func_node, func_node.name)

    def _visit_CClassDefNode(self, class_node):
        # This doesn't currently work, so it's disabled (also in the
        # parser).
        #
        # Problem: assignments to cdef class names do not work.  They
        # would require an additional check anyway, as the extension
        # type must not change its C type, so decorators cannot
        # replace an extension type, just alter it and return it.

        self.visitchildren(class_node)
        if not class_node.decorators:
            return class_node
        return self._handle_decorators(
            class_node, class_node.class_name)

    def visit_ClassDefNode(self, class_node):
        self.visitchildren(class_node)
        if not class_node.decorators:
            return class_node
        return self._handle_decorators(
            class_node, class_node.name)

    def _handle_decorators(self, node, name):
        # Build the nested call dec1(dec2(...(name))) by wrapping from the
        # last decorator outwards, then return the definition followed by
        # the assignment "name = <that call>".
        decorator_result = NameNode(node.pos, name = name)
        for decorator in node.decorators[::-1]:
            decorator_result = SimpleCallNode(
                decorator.pos,
                function = decorator.decorator,
                args = [decorator_result])

        name_node = NameNode(node.pos, name = name)
        reassignment = SingleAssignmentNode(
            node.pos,
            lhs = name_node,
            rhs = decorator_result)
        return [node, reassignment]
673

674

675
class AnalyseDeclarationsTransform(CythonTransform):
    # Runs declaration analysis on the module and on every function, and
    # drops declaration nodes that are not needed afterwards.

    basic_property = TreeFragment(u"""
property NAME:
    def __get__(self):
        return ATTR
    def __set__(self, value):
        ATTR = value
    """, level='c_class')

    def __call__(self, root):
        # Stack of scopes; the innermost scope is the last element.
        self.env_stack = [root.scope]
        # needed to determine if a cdef var is declared after it's used.
        self.seen_vars_stack = []
        return super(AnalyseDeclarationsTransform, self).__call__(root)

    def visit_NameNode(self, node):
        # Remember every name met so far in the current module/function.
        self.seen_vars_stack[-1].add(node.name)
        return node

    def visit_ModuleNode(self, node):
        self.seen_vars_stack.append(set())
        node.analyse_declarations(self.env_stack[-1])
        self.visitchildren(node)
        self.seen_vars_stack.pop()
        return node

    def visit_FuncDefNode(self, node):
        self.seen_vars_stack.append(set())
        lenv = node.create_local_scope(self.env_stack[-1])
        node.body.analyse_control_flow(lenv) # this will be totally refactored
        node.declare_arguments(lenv)
        # Declare the variables typed through directive_locals.
        for var, type_node in node.directive_locals.items():
            if not lenv.lookup_here(var):   # don't redeclare args
                local_type = type_node.analyse_as_type(lenv)
                if local_type:
                    lenv.declare_var(var, local_type, type_node.pos)
                else:
                    error(type_node.pos, "Not a type")
        node.body.analyse_declarations(lenv)
        self.env_stack.append(lenv)
        self.visitchildren(node)
        self.env_stack.pop()
        self.seen_vars_stack.pop()
        return node

    # Some nodes are no longer needed after declaration
    # analysis and can be dropped. The analysis was performed
    # on these nodes in a separate recursive process from the
    # enclosing function or module, so we can simply drop them.
    def visit_CDeclaratorNode(self, node):
        # necessary to ensure that all CNameDeclaratorNodes are visited.
        self.visitchildren(node)
        return node

    def visit_CTypeDefNode(self, node):
        return node

    def visit_CBaseTypeNode(self, node):
        return None

    def visit_CEnumDefNode(self, node):
        # Only public enums are kept in the tree.
        if node.visibility == 'public':
            return node
        else:
            return None

    def visit_CStructOrUnionDefNode(self, node):
        return None

    def visit_CNameDeclaratorNode(self, node):
        # Warn when the declared name was already met as a NameNode in the
        # current module/function, unless the entry is an extern one.
        if node.name in self.seen_vars_stack[-1]:
            entry = self.env_stack[-1].lookup(node.name)
            if entry is None or entry.visibility != 'extern':
                warning(node.pos, "cdef variable '%s' declared after it is used" % node.name, 2)
        self.visitchildren(node)
        return node

    def visit_CVarDefNode(self, node):

        # to ensure all CNameDeclaratorNodes are visited.
        self.visitchildren(node)

        if node.need_properties:
            # cdef public attributes may need type testing on
            # assignment, so we create a property access
            # mechanism for them.
            stats = []
            for entry in node.need_properties:
                prop = self.create_Property(entry)
                prop.analyse_declarations(node.dest_scope)
                self.visit(prop)
                stats.append(prop)
            return StatListNode(pos=node.pos, stats=stats)
        else:
            # The declaration itself is dropped from the tree.
            return None

    def create_Property(self, entry):
        # Instantiate the basic_property template for 'entry', with ATTR
        # replaced by "self.<entry.name>", and name the property after it.
        template = self.basic_property
        prop = template.substitute({
                u"ATTR": AttributeNode(pos=entry.pos,
                                       obj=NameNode(pos=entry.pos, name="self"),
                                       attribute=entry.name),
            }, pos=entry.pos).stats[0]
        prop.name = entry.name
        return prop
781

782
class AnalyseExpressionsTransform(CythonTransform):
    """
    Runs type inference and expression analysis on the module body and on
    every function body, each against its own scope.
    """

    def visit_ModuleNode(self, node):
        """Infer types in the module scope, analyse the module body, then descend."""
        node.scope.infer_types()
        node.body.analyse_expressions(node.scope)
        self.visitchildren(node)
        return node

    def visit_FuncDefNode(self, node):
        """Infer types in the function's local scope, analyse its body, then descend."""
        node.local_scope.infer_types()
        node.body.analyse_expressions(node.local_scope)
        self.visitchildren(node)
        return node
        
class AlignFunctionDefinitions(CythonTransform):
    """
    This class takes the signatures from a .pxd file and applies them to
    the def methods in a .py file.
    """

    def visit_ModuleNode(self, node):
        """Remember the module scope and directives, then walk the module."""
        self.scope = node.scope
        self.directives = node.directives
        self.visitchildren(node)
        return node

    def visit_PyClassDefNode(self, node):
        """Turn a Python class into a cdef class when the current scope declares one.

        Returns the node unchanged when the current scope has no entry of
        that name, the processed cdef class when the entry is a cclass, and
        None (after reporting a redeclaration) for any other entry.
        """
        pxd_def = self.scope.lookup(node.name)
        if not pxd_def:
            return node
        if pxd_def.is_cclass:
            return self.visit_CClassDefNode(node.as_cclass(), pxd_def)
        error(node.pos, "'%s' redeclared" % node.name)
        error(pxd_def.pos, "previous declaration here")
        return None

    def visit_CClassDefNode(self, node, pxd_def=None):
        """Visit the class body with the declared class scope as current scope.

        `pxd_def` is looked up by ``node.class_name`` when not supplied.
        If no entry is found the children are visited in the current scope.
        """
        if pxd_def is None:
            pxd_def = self.scope.lookup(node.class_name)
        if pxd_def:
            outer_scope = self.scope
            self.scope = pxd_def.type.scope
        self.visitchildren(node)
        if pxd_def:
            self.scope = outer_scope
        return node

    def visit_DefNode(self, node):
        """Convert a def function into a C function when a declaration asks for it.

        A matching entry in the current scope that is a C function, or the
        ``auto_cpdef`` directive at module scope, turns the node into its
        ``as_cfunction`` form.  A matching entry that is not a C function is
        reported as a redeclaration and None is returned.
        """
        pxd_def = self.scope.lookup(node.name)
        if pxd_def:
            # Reject non-function entries before touching `type.args`:
            # only a function type is known to carry an argument list.
            if not pxd_def.is_cfunction:
                error(node.pos, "'%s' redeclared" % node.name)
                error(pxd_def.pos, "previous declaration here")
                return None
            if self.scope.is_c_class_scope and len(pxd_def.type.args) > 0:
                # The self parameter type needs adjusting.
                pxd_def.type.args[0].type = self.scope.parent_type
            node = node.as_cfunction(pxd_def)
        elif self.scope.is_module_scope and self.directives['auto_cpdef']:
            node = node.as_cfunction(scope=self.scope)
        # Enable this when internal def functions are allowed.
        # self.visitchildren(node)
        return node
        
class MarkClosureVisitor(CythonTransform):
    """
    Sets ``needs_closure`` on every function node.  A function is marked
    when a nested function, a class definition or a yield was visited
    somewhere inside it.
    """

    # Running flag; read by the enclosing function once its children are done.
    needs_closure = False

    def visit_FuncDefNode(self, node):
        # Start from a clean flag so only this function's subtree counts.
        self.needs_closure = False
        self.visitchildren(node)
        node.needs_closure = self.needs_closure
        # The function itself then marks whatever encloses it.
        self.needs_closure = True
        return node

    def visit_ClassDefNode(self, node):
        self.visitchildren(node)
        self.needs_closure = True
        return node

    def visit_YieldNode(self, node):
        # NOTE(review): this implicitly returns None and does not visit the
        # node's children, unlike the other handlers -- confirm that is
        # intended for a transform.
        self.needs_closure = True
        
class CreateClosureClasses(CythonTransform):
    # Output closure classes in module scope for all functions
    # that need it.
    #
    # NOTE(review): visit_FuncDefNode below does not consult
    # node.needs_closure and does not descend into the function's
    # children -- confirm whether that is intended.

    def visit_ModuleNode(self, node):
        """Remember the module scope as the home of the generated classes."""
        self.module_scope = node.scope
        self.visitchildren(node)
        return node

    def create_class_from_scope(self, node, target_module_scope):
        """Declare a C class holding one variable per entry of the function scope.

        The class is declared in `target_module_scope` under a fresh
        "closure" temp name.  Every entry of ``node.local_scope`` is
        re-declared in the class scope with the same name, cname and type.
        """
        as_name = temp_name_handle("closure")
        func_scope = node.local_scope

        class_entry = target_module_scope.declare_c_class(name = as_name,
            pos = node.pos, defining = True, implementing = True)
        class_scope = class_entry.type.scope
        # A separate loop name keeps `class_entry` from being clobbered.
        for var_entry in func_scope.entries.values():
            class_scope.declare_var(pos=node.pos,
                                    name=var_entry.name,
                                    cname=var_entry.cname,
                                    type=var_entry.type,
                                    is_cdef=True)

    def visit_FuncDefNode(self, node):
        self.create_class_from_scope(node, self.module_scope)
        return node


class GilCheck(VisitorTransform):
    """
    Call `node.nogil_check(env)` on each node that provides it while
    inside a `nogil` environment, to make sure we hold the GIL when we
    need it.  Raise an error on Python operations inside a `nogil`
    environment.
    """
    def __call__(self, root):
        # env_stack tracks the enclosing scopes; nogil tracks whether the
        # code currently being visited runs without the GIL.
        self.env_stack = [root.scope]
        self.nogil = False
        return super(GilCheck, self).__call__(root)

    def visit_FuncDefNode(self, node):
        self.env_stack.append(node.local_scope)
        was_nogil = self.nogil
        # The function's own scope decides the GIL state of its body.
        self.nogil = node.local_scope.nogil
        if self.nogil and node.nogil_check:
            node.nogil_check(node.local_scope)
        self.visitchildren(node)
        self.env_stack.pop()
        self.nogil = was_nogil
        return node

    def visit_GILStatNode(self, node):
        env = self.env_stack[-1]
        if self.nogil and node.nogil_check:
            # Pass the current environment, as the other call sites in
            # this class do.
            node.nogil_check(env)
        was_nogil = self.nogil
        self.nogil = (node.state == 'nogil')
        self.visitchildren(node)
        self.nogil = was_nogil
        return node

    def visit_Node(self, node):
        if self.env_stack and self.nogil and node.nogil_check:
            node.nogil_check(self.env_stack[-1])
        self.visitchildren(node)
        return node

class EnvTransform(CythonTransform):
    """
    Transform base class that records the enclosing scopes in
    ``self.env_stack``; the innermost scope is the last element.
    """
    def __call__(self, root):
        # The root's scope is the outermost environment.
        self.env_stack = [root.scope]
        result = super(EnvTransform, self).__call__(root)
        return result

    def visit_FuncDefNode(self, node):
        # The function's local scope is current while its children are visited.
        self.env_stack.append(node.local_scope)
        self.visitchildren(node)
        self.env_stack.pop()
        return node


class TransformBuiltinMethods(EnvTransform):

    def visit_SingleAssignmentNode(self, node):
        """Return None for declaration-only assignments; otherwise visit
        the children and return the node."""
        if not node.declaration_only:
            self.visitchildren(node)
            return node
        return None
    
    def visit_AttributeNode(self, node):
        """Visit the children, then return whatever ``visit_cython_attribute``
        returns for the node."""
        self.visitchildren(node)
        return self.visit_cython_attribute(node)

    def visit_NameNode(self, node):
        """Return whatever ``visit_cython_attribute`` returns for the name node."""
        replacement = self.visit_cython_attribute(node)
        return replacement
        
    def visit_cython_attribute(self, node):
        """Replace ``cython.compiled`` / ``cython.NULL`` and validate other names.

        ``node.as_cython_attribute()`` supplies the attribute name.  A
        false result returns the node untouched.  ``compiled`` becomes a
        true BoolNode and ``NULL`` a NullNode.  Any other name that
        ``PyrexTypes.parse_basic_type`` does not accept is reported as an
        error; the original node is returned in that case as well.
        """
        attribute = node.as_cython_attribute()
        if not attribute:
            return node
        if attribute == u'compiled':
            return BoolNode(node.pos, value=True)
        if attribute == u'NULL':
            return NullNode(node.pos)
        if not PyrexTypes.parse_basic_type(attribute):
            error(node.pos, u"'%s' not a valid cython attribute or is being used incorrectly" % attribute)
        return node

    def visit_SimpleCallNode(self, node):
        """Rewrite calls to the ``locals`` builtin and to ``cython.*`` functions.

        * ``locals()`` with no arguments, when the innermost scope has no
          entry of its own named ``locals``, becomes a DictNode mapping
          each name in that scope's ``entries`` to a NameNode of the same
          name.  These paths return without visiting the call's children.
        * ``cast``, ``sizeof``, ``typeof``, ``address``, ``cmod`` and
          ``cdiv`` (as named by ``node.function.as_cython_attribute()``)
          are replaced by the matching expression node.  A wrong argument
          count, or any other cython function name, is reported through
          ``error`` and the call is left in place.

        Returns the replacement node, or ``node`` itself when nothing was
        rewritten.
        """
        # locals builtin
        called = node.function
        if isinstance(called, ExprNodes.NameNode) and called.name == 'locals':
            lenv = self.env_stack[-1]
            if lenv.lookup_here('locals'):
                # not the builtin 'locals'
                return node
            if len(node.args) > 0:
                # Report at the position of the call itself.
                error(node.pos, "Builtin 'locals()' called with wrong number of args, expected 0, got %d" % len(node.args))
                return node
            pos = node.pos
            items = [ExprNodes.DictItemNode(pos,
                                            key=ExprNodes.StringNode(pos, value=var),
                                            value=ExprNodes.NameNode(pos, name=var))
                     for var in lenv.entries]
            return ExprNodes.DictNode(pos, key_value_pairs=items)

        # cython.foo
        function = node.function.as_cython_attribute()
        if function:
            # Captured up front: `node` may be rebound to its replacement below.
            func_pos = node.function.pos
            if function == u'cast':
                if len(node.args) != 2:
                    error(func_pos, u"cast takes exactly two arguments")
                else:
                    cast_type = node.args[0].analyse_as_type(self.env_stack[-1])
                    if cast_type:
                        node = TypecastNode(func_pos, type=cast_type, operand=node.args[1])
                    else:
                        error(node.args[0].pos, "Not a type")
            elif function == u'sizeof':
                if len(node.args) != 1:
                    error(func_pos, u"%s takes exactly one argument" % function)
                else:
                    # A type argument gives sizeof(type); anything else is
                    # treated as an expression operand.
                    arg_type = node.args[0].analyse_as_type(self.env_stack[-1])
                    if arg_type:
                        node = SizeofTypeNode(func_pos, arg_type=arg_type)
                    else:
                        node = SizeofVarNode(func_pos, operand=node.args[0])
            elif function == 'typeof':
                if len(node.args) != 1:
                    error(func_pos, u"%s takes exactly one argument" % function)
                else:
                    node = TypeofNode(func_pos, operand=node.args[0])
            elif function == 'address':
                if len(node.args) != 1:
                    error(func_pos, u"%s takes exactly one argument" % function)
                else:
                    node = AmpersandNode(func_pos, operand=node.args[0])
            elif function == 'cmod':
                if len(node.args) != 2:
                    error(func_pos, u"%s takes exactly two arguments" % function)
                else:
                    node = binop_node(func_pos, '%', node.args[0], node.args[1])
                    node.cdivision = True
            elif function == 'cdiv':
                if len(node.args) != 2:
                    error(func_pos, u"%s takes exactly two arguments" % function)
                else:
                    node = binop_node(func_pos, '/', node.args[0], node.args[1])
                    node.cdivision = True
            else:
                error(func_pos, u"'%s' not a valid cython language construct" % function)

        self.visitchildren(node)
        return node