#!/usr/bin/env python
#
# neolog - read a NEO log
#
# Copyright (C) 2012-2019  Nexedi SA
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import argparse, bz2, gzip, errno, os, signal, sqlite3, sys, time
from bisect import insort
from itertools import chain
from logging import getLevelName
from zlib import decompress

try:
    import zstd
except ImportError:
    # No zstd module: fall back to the external 'zstdcat' command
    # (comp_dict values that are strings are run as subprocesses).
    zstdcat = 'zstdcat'
else:
    from cStringIO import StringIO
    def zstdcat(path):
        """Decompress a zstd file fully in memory; return a file object."""
        with open(path, 'rb') as f:
            return StringIO(zstd.decompress(f.read()))

# Supported compressed-log extensions, mapped to either a file-like
# opener or the name of an external decompression command.
comp_dict = dict(bz2=bz2.BZ2File, gz=gzip.GzipFile, xz='xzcat', zst=zstdcat)


class Log(object):
    """Reader for one NEO log database.

    Iterating over an instance yields, in date order, tuples
    ``(date, node_name, level_name, msg_list)`` merged from the ``log``
    and ``packet`` tables; :meth:`_emit` prints one such tuple.
    """

    # Dates of the last rows seen in the log/packet tables; a later
    # iteration resumes after these (used when following a growing file).
    _log_date = _packet_date = 0
    # Date of the currently loaded protocol description (see _reload).
    _protocol_date = None

    def __init__(self, db_path, decode=0, date_format=None,
                       filter_from=None, show_cluster=False, no_nid=False,
                       node_column=True, node_list=None):
        """Open the log database, transparently decompressing it if needed.

        decode: 0 = don't decode packet bodies, 1 = decode them,
                2 = also decompress embedded data.
        """
        self._date_format = '%F %T' if date_format is None else date_format
        self._decode = decode
        self._filter_from = filter_from
        self._no_nid = no_nid
        self._node_column = node_column
        self._node_list = node_list
        self._node_dict = {}
        self._show_cluster = show_cluster
        name = os.path.basename(db_path)
        try:
            name, ext = name.rsplit(os.extsep, 1)
            ZipFile = comp_dict[ext]
        except (KeyError, ValueError):
            # Uncompressed database.
            # BBB: Python 2 does not support URI so we can't open in read-only
            #      mode. See https://bugs.python.org/issue13773
            os.stat(db_path) # do not create empty DB if file is missing
            self._db = sqlite3.connect(db_path)
        else:
            # Compressed database: extract to a temporary file first,
            # via an external command when ZipFile is its name.
            import shutil, subprocess, tempfile
            with tempfile.NamedTemporaryFile() as f:
                if type(ZipFile) is str:
                    subprocess.check_call((ZipFile, db_path), stdout=f)
                else:
                    shutil.copyfileobj(ZipFile(db_path), f)
                self._db = sqlite3.connect(f.name)
            name = name.rsplit(os.extsep, 1)[0]
        self._default_name = name

    def __iter__(self):
        """Yield records newer than the previous pass, oldest first.

        The generator yields None once, right after starting the
        transaction, so that emit_many() can begin the transactions of
        several logs at approximately the same time.
        """
        q = self._db.execute
        try:
            q("BEGIN")
            yield
            date = self._filter_from
            if date and max(self._log_date, self._packet_date) < date:
                log_args = packet_args = date,
                date = " WHERE date>=?"
            else:
                # Start date already passed: resume after the last seen rows.
                self._filter_from = None
                log_args = self._log_date,
                packet_args = self._packet_date,
                date = " WHERE date>?"
            # Old databases have no 'node' table (no cluster/nid columns).
            old = "SELECT date, name, NULL, NULL, %s FROM %s" + date
            new = ("SELECT date, name, cluster, nid, %s"
                   " FROM %s LEFT JOIN node ON node=id" + date)
            log = 'level, pathname, lineno, msg'
            pkt = 'msg_id, code, peer, body'
            try:
                nl = q(new % (log, 'log'), log_args)
            except sqlite3.OperationalError:
                nl = q(old % (log, 'log'), log_args)
                np = q(old % (pkt, 'packet'), packet_args)
            else:
                np = q(new % (pkt, 'packet'), packet_args)
                try:
                    # 'log1'/'packet1' tables, if present, use the old
                    # schema and are chained in front of the new tables.
                    nl = chain(q(old % (log, 'log1'), log_args), nl)
                    np = chain(q(old % (pkt, 'packet1'), packet_args), np)
                except sqlite3.OperationalError:
                    pass
            try:
                p = next(np)
                self._reload(p[0])
            except StopIteration:
                p = None
            except sqlite3.DatabaseError as e:
                yield time.time(), None, 'PACKET', self._exc(e)
                p = None
            # Merge log and packet records by date.
            try:
                for date, name, cluster, nid, level, pathname, lineno, msg in nl:
                    while p and p[0] < date:
                        yield self._packet(*p)
                        try:
                            p = next(np, None)
                        except sqlite3.DatabaseError as e:
                            yield time.time(), None, 'PACKET', self._exc(e)
                            p = None
                    self._log_date = date
                    yield (date, self._node(name, cluster, nid),
                           getLevelName(level), msg.splitlines())
            except sqlite3.DatabaseError as e:
                yield time.time(), None, 'LOG', self._exc(e)
            if p:
                # No more log records: flush remaining packet records.
                yield self._packet(*p)
                try:
                    for p in np:
                        yield self._packet(*p)
                except sqlite3.DatabaseError as e:
                    yield time.time(), None, 'PACKET', self._exc(e)
        finally:
            self._db.rollback()

    @staticmethod
    def _exc(e):
        """Format an exception as a list of message lines."""
        return ('%s: %s' % (type(e).__name__, e)).splitlines()

    def _node(self, name, cluster, nid):
        """Return the text to display in the node column."""
        if nid and not self._no_nid:
            name = self.uuid_str(nid)
            if self._show_cluster:
                name = cluster + '/' + name
        return name

    def _reload(self, date):
        """Load the protocol description effective at the given date."""
        q = self._db.execute
        date, text = next(q("SELECT * FROM protocol WHERE date<=?"
                            " ORDER BY date DESC", (date,)))
        if self._protocol_date == date:
            return
        self._protocol_date = date
        g = {}
        # NOTE(review): the protocol description stored in the database is
        # executed as Python code, so the log database must be trusted.
        exec(bz2.decompress(text), g)
        for x in 'uuid_str', 'Packets', 'PacketMalformedError':
            setattr(self, x, g[x])
        # When decompressing (decode > 1), precompute for each packet code
        # the index path of the 'compression' flag whose matching 'data'
        # field is two items later in the decoded arguments.
        x = {}
        if self._decode > 1:
            PStruct = g['PStruct']
            PBoolean = g['PBoolean']
            def hasData(item):
                items = item._items
                for i, item in enumerate(items):
                    if isinstance(item, PStruct):
                        j = hasData(item)
                        if j:
                            return (i,) + j
                    elif (isinstance(item, PBoolean)
                          and item._name == 'compression'
                          and i + 2 < len(items)
                          and items[i+2]._name == 'data'):
                        return i,
            for p in self.Packets.values():
                if p._fmt is not None:
                    path = hasData(p._fmt)
                    if path:
                        assert not hasattr(p, '_neolog'), p
                        x[p._code] = path
        self._getDataPath = x.get

        try:
            # Date at which the protocol changes next, if any.
            self._next_protocol, = next(q(
                "SELECT date FROM protocol WHERE date>?", (date,)))
        except StopIteration:
            self._next_protocol = float('inf')

    def _emit(self, date, name, levelname, msg_list):
        """Print one record, filtered and formatted according to options."""
        if not name:
            name = self._default_name
        if self._node_list and name not in self._node_list:
            return
        prefix = self._date_format
        if prefix:
            d = int(date)
            # Append the fractional part with 1/10000 s resolution.
            prefix = '%s.%04u ' % (time.strftime(prefix, time.localtime(d)),
                                   int((date - d) * 10000))
        prefix += ('%-9s %-10s ' % (levelname, name) if self._node_column else
                   '%-9s ' % levelname)
        for msg in msg_list:
            print(prefix + msg)

    def _packet(self, date, name, cluster, nid, msg_id, code, peer, body):
        """Turn a row of the packet table into a printable record tuple."""
        self._packet_date = date
        if self._next_protocol <= date:
            self._reload(date)
        try:
            p = self.Packets[code]
            msg = p.__name__
        except KeyError:
            msg = 'UnknownPacket[%u]' % code
            body = None
        msg = ['#0x%04x %-30s %s' % (msg_id, msg, peer)]
        if body is not None:
            log = getattr(p, '_neolog', None)
            if log or self._decode:
                p = p()
                p._id = msg_id
                p._body = body
                try:
                    args = p.decode()
                except self.PacketMalformedError:
                    msg.append("Can't decode packet")
                else:
                    if log:
                        args, extra = log(*args)
                        msg += extra
                    else:
                        path = self._getDataPath(code)
                        if path:
                            args = self._decompress(args, path)
                    if args and self._decode:
                        msg[0] += ' \t| ' + repr(args)
        return date, self._node(name, cluster, nid), 'PACKET', msg

    def _decompress(self, args, path):
        """Follow 'path' into decoded args and replace the 3-item slice
        starting at the compression flag by a single (len(data), data)
        tuple, inflating data when the flag is set."""
        if args:
            args = list(args)
            i = path[0]
            path = path[1:]
            if path:
                args[i] = self._decompress(args[i], path)
            else:
                data = args[i+2]
                if args[i]:
                    data = decompress(data)
                args[i:i+3] = (len(data), data),
            return tuple(args)

def emit_many(log_list):
    """Print the records of several logs, merged in date order.

    Each log is iterated lazily; the oldest pending event is always
    emitted first, so output is globally sorted across files.
    """
    log_list = [(log, iter(log)) for log in log_list]
    for x in log_list: # try to start all transactions at the same time
        next(x[1])
    event_list = []
    for log, records in log_list:
        try:
            event = next(records)
        except StopIteration:
            continue
        # Dates are negated so that list.pop() returns the oldest event.
        event_list.append((-event[0], records, log._emit, event))
    if event_list:
        event_list.sort()
        while True:
            key, records, emit, event = event_list.pop()
            try:
                next_date = - event_list[-1][0]
            except IndexError:
                next_date = float('inf')
            try:
                # Emit from this log as long as it holds the oldest event.
                while event[0] <= next_date:
                    emit(*event)
                    event = next(records)
            except IOError as e:
                # Broken pipe (e.g. output piped to head): exit quietly.
                if e.errno == errno.EPIPE:
                    sys.exit(1)
                raise
            except StopIteration:
                if not event_list:
                    break
            else:
                insort(event_list, (-event[0], records, emit, event))

def main():
    """Parse command-line options and dump the given NEO log databases."""
    parser = argparse.ArgumentParser(description='NEO Log Reader')
    _ = parser.add_argument
    _('-a', '--all', action="store_true",
        help='decode body of packets')
    _('-A', '--decompress', action="store_true",
        help='decompress data when decode body of packets (implies --all)')
    _('-d', '--date', metavar='FORMAT',
        help='custom date format, according to strftime(3)')
    _('-f', '--follow', action="store_true",
        help='output appended data as the file grows')
    _('-F', '--flush', action="append", type=int, metavar='PID',
        help='with -f, tell process PID to flush logs approximately N'
              ' seconds (see -s)')
    _('-n', '--node', action="append",
        help='only show log entries from the given node'
             ' (only useful for logs produced by threaded tests),'
             " special value '-' hides the column")
    _('-s', '--sleep-interval', type=float, default=1., metavar='N',
        help='with -f, sleep for approximately N seconds (default %(default)s)'
             ' between iterations')
    _('--from', dest='filter_from', metavar='N',
        help='show records more recent than timestamp N if N > 0, or now+N'
             ' if N < 0; N can also be a string that is parseable by dateutil')
    _('file', nargs='+',
        help='log file, compressed (gz, bz2 or xz) or not')
    _ = parser.add_mutually_exclusive_group().add_argument
    _('-C', '--cluster', action="store_true",
        help='show cluster name in node column')
    _('-N', '--no-nid', action="store_true",
        help='always show node name (instead of NID) in node column')
    args = parser.parse_args()
    if args.sleep_interval <= 0:
        parser.error("sleep_interval must be positive")
    filter_from = args.filter_from
    if filter_from:
        try:
            filter_from = float(args.filter_from)
        except ValueError:
            # Not a raw timestamp: parse it as a human-readable date.
            from dateutil.parser import parse
            x = parse(filter_from)
            if x.tzinfo:
                filter_from = (x - x.fromtimestamp(0, x.tzinfo)).total_seconds()
            else:
                filter_from = time.mktime(x.timetuple()) + x.microsecond * 1e-6
        else:
            # Negative numeric value means relative to now.
            if filter_from < 0:
                filter_from += time.time()
    node_list = args.node or []
    try:
        node_list.remove('-')
        node_column = False
    except ValueError:
        node_column = True
    log_list = [Log(db_path,
                    2 if args.decompress else 1 if args.all else 0,
                    args.date, filter_from, args.cluster, args.no_nid,
                    node_column, node_list)
                for db_path in args.file]
    if args.follow:
        try:
            pid_list = args.flush or ()
            while True:
                emit_many(log_list)
                for pid in pid_list:
                    # Ask the logging processes to flush their buffers.
                    os.kill(pid, signal.SIGRTMIN)
                time.sleep(args.sleep_interval)
        except KeyboardInterrupt:
            pass
    else:
        emit_many(log_list)

if __name__ == "__main__":
    main()