"""
Inject Cython type declarations into a .py file using the Jedi static analysis tool.
"""

from __future__ import absolute_import

from io import open
from collections import defaultdict
from itertools import chain

import jedi
from jedi.parser.tree import Module, ImportName
from jedi.evaluate.representation import Function, Instance, Class
from jedi.evaluate.iterable import ArrayMixin, GeneratorComprehension

from Cython.Utils import open_source_file


# Default translation from inferred Python type names to the type names
# written into the generated declarations.  inject_types() looks names up
# with ``type_map.get(name, name)``, so unlisted names pass through unchanged.
default_type_map = {'float': 'double', 'int': 'long'}


def _type_name_of(name_type):
    """
    Return the type name to record for one Jedi lookup result, or None to skip it.
    """
    if isinstance(name_type, Instance):
        if isinstance(name_type.base, Class):
            # The instance's base is a Jedi 'Class' object (presumably a class
            # defined in analysed source - TODO confirm): declare it generically.
            return 'object'
        # Otherwise the base wraps a real Python object; use that object's name.
        return name_type.base.obj.__name__
    if isinstance(name_type, ArrayMixin):
        return name_type.type
    if isinstance(name_type, GeneratorComprehension):
        return None
    try:
        return type(name_type.obj).__name__
    except AttributeError:
        # Results without an '.obj' attribute carry no usable type name.
        return None


def analyse(source_path=None, code=None):
    """
    Analyse a Python source code file with Jedi.

    At least one of 'source_path' and 'code' must be given; both are passed
    on to ``jedi.names()``.

    Returns a mapping from (scope-name, (line, column)) pairs to a name-types mapping.
    The scope name is None for names whose scope is the module, otherwise the
    name of the enclosing scope.  The name-types mapping maps each name to the
    set of type names that were found for it.

    Raises ValueError if neither 'source_path' nor 'code' is provided.
    """
    if not source_path and code is None:
        raise ValueError("Either 'source_path' or 'code' is required.")

    scoped_names = {}
    statement_iter = jedi.names(source=code, path=source_path, all_scopes=True)

    for statement in statement_iter:
        parent = statement.parent()
        scope = parent._definition
        evaluator = statement._evaluator

        # skip function/generator definitions, class definitions, and module imports
        if isinstance(statement._definition, (Function, Class, ImportName)):
            continue

        key = (None if isinstance(scope, Module) else str(parent.name), scope.start_pos)
        names = scoped_names.get(key)
        if names is None:
            names = scoped_names[key] = defaultdict(set)

        # The first time a name is seen in a scope, the lookup is not restricted
        # to a position; later occurrences pass their own start position.
        position = statement.start_pos if statement.name in names else None

        for name_type in evaluator.find_types(scope, statement.name, position=position, search_global=True):
            type_name = _type_name_of(name_type)
            if type_name is not None:
                names[str(statement.name)].add(type_name)

    return scoped_names


def inject_types(source_path, types, type_map=default_type_map, mode='python'):
    """
    Hack type declarations into source code file.

    @param source_path  path of the source file, read with open_source_file()
    @param types  mapping as returned by analyse():
                  {(scope_name or None, (line, column)): {name: set(type_names)}}
    @param type_map  translation of inferred type names to declared type names;
                     names that are not in the map are used unchanged
    @param mode is currently 'python', which means that the generated type declarations use pure Python syntax.
                (The argument is not read by this function.)

    Returns the list of output lines: an 'import cython' line, followed by the
    original lines with a declaration line inserted before each scope start
    line that has at least one unambiguously typed name.
    """
    # {line: (column, scope_name or None, [(name, type)])}
    # Only names with exactly one inferred type are kept.
    col_and_types_by_line = {}
    for key, name_types in types.items():
        scope_name, start_pos = key[0], key[-1]
        unambiguous = [(name, next(iter(found)))
                       for (name, found) in name_types.items() if len(found) == 1]
        col_and_types_by_line[start_pos[0]] = (start_pos[1], scope_name, unambiguous)

    lines = [u'import cython\n']
    with open_source_file(source_path) as f:
        for line_no, line in enumerate(f, 1):
            entry = col_and_types_by_line.get(line_no)
            if entry is not None:
                col, scope, declared = entry
                if declared:
                    decl_args = ', '.join("%s='%s'" % (name, type_map.get(type_name, type_name))
                                          for name, type_name in declared)
                    if scope is None:
                        # module level names: plain declaration statement
                        type_decl = u'{indent}cython.declare({types})\n'
                    else:
                        # named scope: decorator line placed before the scope's start line
                        type_decl = u'{indent}@cython.locals({types})\n'
                    lines.append(type_decl.format(indent=' '*col, types=decl_args))
            lines.append(line)

    return lines


def main(file_paths=None, overwrite=False):
    """
    Process a list of .py files and write them back with inferred type declarations injected.

    Without 'file_paths', the paths are taken from the command line arguments.
    With 'overwrite', each source file is replaced; otherwise the result is
    written next to it under the source path with '_typed.py' appended.
    """
    if file_paths is None:
        from sys import argv
        file_paths = argv[1:]

    suffix = '' if overwrite else '_typed.py'
    for source_path in file_paths:
        typed_lines = inject_types(source_path, analyse(source_path))
        with open(source_path + suffix, 'w', encoding='utf8') as out:
            out.writelines(typed_lines)


# Script entry point: called without arguments, main() takes the file paths
# to process from sys.argv[1:].
if __name__ == '__main__':
    main()