from Cython.Compiler.Visitor import VisitorTransform, ScopeTrackingTransform, TreeVisitor
from Nodes import StatListNode, SingleAssignmentNode, CFuncDefNode
from ExprNodes import (DictNode, DictItemNode, NameNode, UnicodeNode, NoneNode,
                       ExprNode, AttributeNode, ModuleRefNode, DocstringRefNode)
from PyrexTypes import py_object_type
from Builtin import dict_type
from StringEncoding import EncodedString
import Naming


class DoctestHackTransform(ScopeTrackingTransform):
    """Handle the ``doctesthack`` directive.

    When the directive is enabled, the docstrings of the module's def and
    cpdef functions/methods are gathered into a public module-level
    ``__test__`` dict whose keys look like ``u"<name> (line <n>)"``.
    """

    def visit_ModuleNode(self, node):
        """Declare ``__test__`` for the module and fill it with doctests.

        Returns *node*. When the directive is off, or the module scope already
        has a ``__test__`` entry, the tree is returned unchanged (apart from
        the scope-tracking attributes set below).
        """
        self.scope_type = 'module'
        self.scope_node = node
        if not self.current_directives['doctesthack']:
            return node

        assert isinstance(node.body, StatListNode)

        # If __test__ is already declared in the module scope, leave it alone.
        if u'__test__' in node.scope.entries:
            return node

        pos = node.pos
        self.tests = []
        # Every generated node reuses the module's position.
        self.testspos = pos

        test_dict_entry = node.scope.declare_var(EncodedString(u'__test__'),
                                                 py_object_type,
                                                 pos,
                                                 visibility='public')
        # The DictNode keeps a reference to the (still empty) self.tests list.
        # add_test() appends to that same list while the children are visited
        # below, so the dict literal ends up containing every collected test.
        create_test_dict_assignment = SingleAssignmentNode(
            pos,
            lhs=NameNode(pos, name=EncodedString(u'__test__'),
                         entry=test_dict_entry),
            rhs=DictNode(pos, key_value_pairs=self.tests))
        self.visitchildren(node)
        # Put the assignment at the end of the module body, after the
        # definitions whose docstrings it looks up.
        node.body.stats.append(create_test_dict_assignment)
        return node

    def add_test(self, testpos, name, func_ref_node):
        """Append one ``key: docstring`` item to ``self.tests``.

        testpos: position of the function definition; ``testpos[1]`` is
            formatted into the key as its line number.
        name: display name of the function, e.g. ``"f"`` or ``"Cls.f"``.
        func_ref_node: expression node that must evaluate to the function
            object containing the docstring, BUT it should not be the function
            definition itself (which would lead to a new *definition* of the
            function).
        """
        pos = self.testspos
        keystr = u'%s (line %d)' % (name, testpos[1])
        key = UnicodeNode(pos, value=EncodedString(keystr))
        value = DocstringRefNode(pos, func_ref_node)
        self.tests.append(DictItemNode(pos, key=key, value=value))

    def visit_FuncDefNode(self, node):
        """Register a doctest for *node* if it has a docstring.

        Plain cdef functions (no ``py_func`` wrapper) are skipped. Children are
        not visited, so functions nested inside other functions are never
        collected. Returns *node* unchanged.
        """
        if not node.doc:
            return node
        if isinstance(node, CFuncDefNode) and not node.py_func:
            # skip non-cpdef cdef functions
            return node

        pos = self.testspos
        if self.scope_type == 'module':
            parent = ModuleRefNode(pos)
            name = node.entry.name
        elif self.scope_type in ('pyclass', 'cclass'):
            # NOTE(review): scope_type/scope_node for classes are presumably
            # maintained by ScopeTrackingTransform's class visitors -- confirm.
            if self.scope_type == 'pyclass':
                clsname = self.scope_node.name
            else:
                clsname = self.scope_node.class_name
            # Reach the class object as an attribute of the module object.
            parent = AttributeNode(pos, obj=ModuleRefNode(pos),
                                   attribute=clsname,
                                   type=py_object_type,
                                   is_py_attr=True,
                                   is_temp=True)
            name = "%s.%s" % (clsname, node.entry.name)
        else:
            # Explicit raise rather than `assert False`: asserts are stripped
            # under -O, which would leave `parent` and `name` unbound below.
            raise AssertionError(
                "DoctestHackTransform: unexpected scope type %r"
                % (self.scope_type,))

        # Fetch the function via attribute lookup rather than reusing the
        # definition node, so no new definition is generated.
        getfunc = AttributeNode(pos, obj=parent,
                                attribute=node.entry.name,
                                type=py_object_type,
                                is_py_attr=True,
                                is_temp=True)
        self.add_test(node.pos, name, getfunc)
        return node