#!/usr/local/bin/python3.11

#
# Copyright (c) 2021, Chad Miller  chad.org
# All rights reserved.
#

import curses
import subprocess
from contextlib import suppress
from base64 import encodebytes as b64encodebytes
import argparse
from itertools import groupby
from time import time
from random import choice as randomchoice
import sys


# Cells whose population is (nearly) zero are drawn with EMPTY_CHAR, in colors
# that alternate by column so neighbouring columns stay distinguishable.
EMPTY_COLORS = (238, 8)
EMPTY_CHAR = "|"

# Selectable color palettes. Each is a range of terminal color numbers given
# as (second color, first color, last color): it starts at the first color and
# steps by the gap between the first two until the last color is included.
PALETTES = [
    range(first, last + (second - first), second - first)
    for second, first, last in (
        (63, 27, 207),
        (87, 51, 231),
        (116, 123, 88),
        (220, 226, 196),
        (224, 231, 196),
        (80, 51, 196),
        (77, 40, 225),
        (225, 231, 201),
        (243, 240, 255),
        (227, 226, 231),
        (27, 21, 51),
    )
]

# Glyph alphabets, emptiest first. A cell's glyph is picked by how full its
# bucket is relative to the fullest bucket in the same column.
DISPLAY_CHARS_LETTERS = ".abcdefghijklmnopqrstuvwxyz^"
DISPLAY_CHARS_DIGITS = ".0123456789#"
DISPLAY_CHARS_SYMBOLS = " .:;*#"
DISPLAY_CHARS = DISPLAY_CHARS_LETTERS

# Indicator in the differential title; advances by one glyph per reload.
DIFFL_CLOCK_CHARS = "╷╴╵╶"

# How many earlier snapshots are remembered for differential display.
DIFFL_STAT_MEMORY = 5

# Most recent unparsed text read by get_stats(); attached to crash reports.
raw = None

def short_time_and_color(ns, palette=None):
    """Given a time in nanoseconds, return a pleasant, legible string
    approximating that time, and a color for that much waiting. Bigger is
    worse.

    The color is taken from `palette`, a sequence of six entries ordered from
    best to worst. When `palette` is omitted, the module-level
    TIME_ANSI_COLORS is used; that name is bound only by the command-line
    entry point, so callers importing this module should pass a palette.

    Returns a (text, color) pair. The text is the number in a seven-character
    field followed by its unit (ns, µs, ms or s)."""
    if palette is None:
        palette = TIME_ANSI_COLORS

    # Exclusive upper bounds, in nanoseconds, of the five best severity bands.
    # The band index is the number of bounds the value has reached.
    severity_bounds = (30000, 600000, 8500000, 15000000, 155000000)
    color = palette[sum(1 for bound in severity_bounds if ns >= bound)]

    # Switch units a little before the next power of a thousand so the number
    # stays short.
    if ns < 800:
        return "{: 7.0f}ns".format(ns), color
    elif ns < 800000:
        return "{: 7.2f}µs".format(ns/1000), color
    elif ns < 1000000000:
        return "{: 7.2f}ms".format(ns/1000/1000), color
    else:
        return "{: 7.2f}s".format(ns/1000/1000/1000), color


def get_stats(pool_names, filename=None):
    """Read latency-histogram text and parse it into rows grouped by name.

    The text is read from `filename` when one is given; otherwise it is the
    output of running "zpool iostat -wvHp --" followed by `pool_names` (no
    names are passed when `pool_names` is empty or None). The unparsed text is
    also stored in the "raw" global so it can be attached to a crash report.

    A non-empty line without a tab starts a new group and is used as its name.
    A line with tabs is a row of integers belonging to the most recent name.

    Returns a list of [name, rows], where each row is
    (first_column, ((header, value), ...)) -- the remaining columns paired, in
    order, with the known column headers.

    Raises ValueError if a row of numbers appears before any name or a field
    is not an integer, and subprocess.CalledProcessError if zpool fails."""
    global raw
    if filename:
        with open(filename, encoding="UTF-8") as f:
            raw = f.read()
    else:
        zpool_cmd = subprocess.run(["zpool", "iostat", "-wvHp", "--"] + list(pool_names or []), check=True, stdout=subprocess.PIPE, encoding="UTF-8")
        raw = zpool_cmd.stdout

    ## TODO: Get this header list from zpool somehow. Only it is authoritative.
    headers = "total wait read/total wait write/disk wait read/disk wait write/syncq wait read (through txg)/syncq wait write (through txg)/asyncq wait read (from zil)/asyncq wait write (to zil)/scrub/trim".split("/")

    stats = []
    for line in raw.split("\n"):
        if not line:
            continue
        if "\t" not in line:
            stats.append([line, []])
            continue
        if not stats:
            # Without a preceding name there is no group to attach this row to.
            raise ValueError("histogram row appears before any pool or vdev name: {!r}".format(line))
        values = tuple(int(s) for s in line.split("\t"))
        stats[-1][1].append((values[0], tuple(zip(headers, values[1:]))))

    return stats


def scaled_to_fraction(range_minimum, subject_value, range_maximum):
    """Say how far along the range the subject value sits, as a fraction:
    0 at the minimum, 1 at the maximum. A range whose ends are equal always
    gives 0. The value is asserted to lie within the range."""
    if range_maximum == range_minimum:
        # Zero-width range: there is nothing to scale against.
        return 0

    bounds = (range_minimum, subject_value, range_maximum)
    assert range_minimum <= subject_value <= range_maximum, bounds

    distance_in = subject_value - range_minimum
    width = range_maximum - range_minimum
    return distance_in / width


def stats_as_device_centric(stats):
    """Statistics already arrive grouped by device, so this view hands them
    back untouched: one [device, rows] entry per device, each row being
    (bucket, ((measurement, count), ...)).

    >>> stats_as_device_centric([['dev1', [(1, (('mes1', 2), ('mes2', 3))), (4, (('mes1', 5), ('mes2', 6)))]], ['dev2', [(1, (('mes1', 7), ('mes2', 8))), (4, (('mes1', 9), ('mes2', 10)))]]])
    [['dev1', [(1, (('mes1', 2), ('mes2', 3))), (4, (('mes1', 5), ('mes2', 6)))]], ['dev2', [(1, (('mes1', 7), ('mes2', 8))), (4, (('mes1', 9), ('mes2', 10)))]]]
    """
    return stats


def stats_as_measurement_centric(stats):
    """Pivot device-centric statistics so that each entry is one measurement.

    Input is [[device, [(bucket, ((measurement, count), ...)), ...]], ...].
    Output is [[measurement, [(bucket, ((device, count), ...)), ...]], ...].

    Measurements come out in the position they hold within a row, devices in
    the order they appear in the input (devices sharing a name stay separate),
    and buckets in ascending order.

    >>> stats_as_measurement_centric([['dev1', [(1, (('mes1', 2), ('mes2', 3))), (4, (('mes1', 5), ('mes2', 6)))]], ['dev2', [(1, (('mes1', 7), ('mes2', 8))), (4, (('mes1', 9), ('mes2', 10)))]]])
    [['mes1', [(1, (('dev1', 2), ('dev2', 7))), (4, (('dev1', 5), ('dev2', 9)))]], ['mes2', [(1, (('dev1', 3), ('dev2', 8))), (4, (('dev1', 6), ('dev2', 10)))]]]
    >>> stats_as_measurement_centric([['devx', [(1, (('mes1', 2), ('mes2', 3))), (4, (('mes1', 5), ('mes2', 6 )))]], ['devx', [(1, (('mes1', 7), ('mes2', 8))), (4, (('mes1', 9), ('mes2', 10)))]]])
    [['mes1', [(1, (('devx', 2), ('devx', 7))), (4, (('devx', 5), ('devx', 9)))]], ['mes2', [(1, (('devx', 3), ('devx', 8))), (4, (('devx', 6), ('devx', 10)))]]]
    >>> stats_as_measurement_centric([['dev2', [(1, (('mes2', 2), ('mes1', 3))), (4, (('mes2', 5), ('mes1', 6)))]], ['dev1', [(1, (('mes2', 7), ('mes1', 8))), (4, (('mes2', 9), ('mes1', 10)))]]])
    [['mes2', [(1, (('dev2', 2), ('dev1', 7))), (4, (('dev2', 5), ('dev1', 9)))]], ['mes1', [(1, (('dev2', 3), ('dev1', 8))), (4, (('dev2', 6), ('dev1', 10)))]]]

    Device order is kept even past ten devices:

    >>> [dev for dev, _ in stats_as_measurement_centric([['d%d' % n, [(1, (('m', n),))]] for n in range(12)])[0][1][0][1]]
    ['d0', 'd1', 'd2', 'd3', 'd4', 'd5', 'd6', 'd7', 'd8', 'd9', 'd10', 'd11']
    """
    # (position in row, measurement name) -> {bucket: [(device, count), ...]}
    # Keying on the integer position keeps measurements in row order; walking
    # the devices in input order keeps the columns in device order.
    pivoted = {}
    for device, timings_per_device in stats:
        for tns, measurements in timings_per_device:
            for position, (measurement, count) in enumerate(measurements):
                per_bucket = pivoted.setdefault((position, measurement), {})
                per_bucket.setdefault(tns, []).append((device, count))

    return [
        [measurement, [(tns, tuple(cells)) for tns, cells in sorted(per_bucket.items(), key=lambda item: item[0])]]
        for (_, measurement), per_bucket in sorted(pivoted.items(), key=lambda item: item[0])
    ]


def render_stats(window, transform, should_show_differential, pool, filename=None):
    """Draw the histogram screen and respond to keys until the user quits.

    Each pass reloads the statistics if the refresh interval has elapsed,
    paints one entry of the transformed statistics as a grid of colored glyphs
    (one line per time bucket, one column per value in the row), labels the
    columns underneath, and then waits for a key.

    In differential mode the counts shown are the current snapshot minus the
    oldest remembered one; up to DIFFL_STAT_MEMORY earlier snapshots are kept.

    Keys: left/right pick the entry shown; up/down lengthen/shorten the
    refresh interval; "d" toggles differential mode; "m" switches between the
    device-centric and measurement-centric transforms and reloads; "q", "x"
    or Escape return.

    window -- curses window to draw on and read keys from
    transform -- function applied to what get_stats() returns
    should_show_differential -- whether to start in differential mode
    pool, filename -- passed through to get_stats()"""
    read_count = 0
    stats = None
    stats_history = []
    current = 0
    load_time = None
    diffl_stat_intervals = (2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 30, 60, 120, 180, 300, 600)  # seconds
    diffl_stat_interval_index = 2
    diffl_title = "(&){:d}s↕ ".format(diffl_stat_intervals[diffl_stat_interval_index])
    while True:
        if not load_time or load_time + diffl_stat_intervals[diffl_stat_interval_index] < time():
            # Keep the outgoing snapshot for differential display, bounded.
            if stats:
                stats_history.append(stats)
            while len(stats_history) > DIFFL_STAT_MEMORY:
                stats_history.pop(0)

            stats = transform(get_stats(pool, filename))
            load_time = time()
            read_count += 1

        if stats:
            # Wrap the selection here: a key press may have stepped past either
            # end, and a reload may have returned fewer entries than before.
            current %= len(stats)
            name, rows = stats[current]
        else:
            # Nothing was reported at all; fall through to the "(no data)" screen.
            name, rows = "(nothing)", []

        if should_show_differential:
            # Differences can appear in any bucket, so every row is a candidate.
            rows_containing_data = list(range(len(rows)))
        else:
            rows_containing_data = [row_number for row_number, (_, data) in enumerate(rows) if any(colval > 0 for _, colval in data)]

        if should_show_differential and stats_history:
            # Replace absolute counts with growth since the oldest snapshot.
            rows = [
                (bucket, tuple((newk, newv - oldv) for (newk, newv), (_, oldv) in zip(data, stats_history[0][current][1][rn][1])))
                for rn, (bucket, data) in enumerate(rows)
            ]

        # Each column is scaled against its own largest value.
        max_per_column = tuple(max(column) for column in zip(*([v for _, v in data] for _, data in rows)))

        window.clear()
        if should_show_differential:
            if stats_history:
                diffl_title = "({}){:d}s↕ ".format(DIFFL_CLOCK_CHARS[read_count%4], diffl_stat_intervals[diffl_stat_interval_index])
            with suppress(curses.error):
                window.addstr(0, 0, diffl_title)
        with suppress(curses.error):
            window.addstr(0, 10 if should_show_differential else 0, "{:>2d}/{}↔  Histogram for {}".format(current+1, len(stats), name))

        printed_row_number = 0
        if rows_containing_data:
            # Show the populated rows plus one row of margin on either side.
            first_shown = min(rows_containing_data) - 1
            last_shown = max(rows_containing_data) + 1
            for row_number, (time_ns, data) in enumerate(rows):
                if not first_shown <= row_number <= last_shown:
                    continue
                printed_row_number += 1

                with suppress(curses.error):
                    t, col = short_time_and_color(time_ns)
                    window.addstr(printed_row_number, 0, t, curses.color_pair(col))  # write legend

                for col_number, (_, value) in enumerate(data):
                    scaled_0_to_1 = scaled_to_fraction(0, value, max_per_column[col_number])
                    if scaled_0_to_1 > 0.001:
                        glyph = DISPLAY_CHARS[int(scaled_0_to_1 * (len(DISPLAY_CHARS)-1))]
                        color = HISTOGRAM_ANSI_COLORS[int(scaled_0_to_1 * (len(HISTOGRAM_ANSI_COLORS)-1))]
                    else:
                        glyph = EMPTY_CHAR
                        color = EMPTY_COLORS[col_number % len(EMPTY_COLORS)]

                    with suppress(curses.error):
                        window.addstr(printed_row_number, 12+(col_number*2), glyph, curses.color_pair(color))

            # Column labels: one line per column, rightmost column first. Each
            # line extends the column guides down to the label's own column.
            column_names = [k for k, _ in rows[0][1]]
            for label_column in range(len(column_names) - 1, -1, -1):
                printed_row_number += 1
                with suppress(curses.error):
                    for col_number in range(label_column + 1):
                        color = EMPTY_COLORS[col_number % len(EMPTY_COLORS)]
                        window.addstr(printed_row_number, 12+(2*col_number), EMPTY_CHAR, curses.color_pair(color))
                with suppress(curses.error):
                    window.addstr(printed_row_number, 12+(2*label_column), "`{}".format(column_names[label_column]))
        else:
            with suppress(curses.error):
                window.addstr(4, 16, "(no data)")

        height, width = window.getmaxyx()
        message = "  Population of histogram buckets shown with .a-z^ and colors"
        with suppress(curses.error):
            window.addstr(height-1, width-len(message)-1, message)
            # Repaint the tail of the message, one palette color per character.
            for i, (ch, color) in enumerate(zip(reversed(message), reversed(HISTOGRAM_ANSI_COLORS))):
                window.addstr(height-1, width-1-i-1, ch, curses.color_pair(color))

        window.refresh()

        in_key = window.getch()
        if in_key == -1:
            pass  # no key this time; go round again and redraw
        elif in_key == curses.KEY_RIGHT:
            current += 1
        elif in_key == curses.KEY_LEFT:
            current -= 1
        elif in_key == curses.KEY_UP:
            if diffl_stat_interval_index < len(diffl_stat_intervals) - 1:
                diffl_stat_interval_index += 1
        elif in_key == curses.KEY_DOWN:
            if diffl_stat_interval_index > 0:
                diffl_stat_interval_index -= 1
        elif in_key == ord('d'):
            should_show_differential = not should_show_differential
        elif in_key == ord('m'):
            # Snapshots from the other view have a different shape, so forget
            # them and force a reload on the next pass.
            load_time = None
            current = 0
            stats = None
            stats_history = []
            if transform == stats_as_device_centric:
                transform = stats_as_measurement_centric
            else:
                transform = stats_as_device_centric
        elif in_key in (ord('q'), ord('x'), 27):
            return


def main(window, should_show_differential, pool, filename, views):
    """Prepare the curses screen, pick the starting view, and run the display.

    window -- curses window, as supplied by curses.wrapper()
    should_show_differential, pool, filename -- passed through to render_stats()
    views -- iterable of one-letter view codes: "d" for device-centric, "m"
             for measurement-centric. The last recognised code wins; with none,
             the view is measurement-centric."""
    # Wait at most a second for a key.
    window.timeout(1000)

    curses.use_default_colors()
    # Hiding the cursor is cosmetic, and curs_set() raises on terminals that
    # cannot do it.
    with suppress(curses.error):
        curses.curs_set(0)
    # Make pair N mean "color N on the default background", so a color number
    # can be used directly as a pair number. Pair numbers must stay below
    # COLOR_PAIRS, which can be smaller than COLORS.
    for i in range(min(curses.COLORS, curses.COLOR_PAIRS)):
        curses.init_pair(i, i, -1)

    transformation_function = stats_as_measurement_centric
    for view in views:
        if view == "d": transformation_function = stats_as_device_centric
        if view == "m": transformation_function = stats_as_measurement_centric

    render_stats(window, transformation_function, should_show_differential, pool, filename)


if __name__ == "__main__":
    # Run the embedded doctests on every start and refuse to continue if any fail.
    import doctest
    if doctest.testmod().failed:
        sys.exit(1)

    try:
        arg_parser = argparse.ArgumentParser("zpool-iostat-viz", description="Display ZFS pool statistics, by device and by measurement")

        arg_parser.add_argument("--diff", "-d", dest="diff", action="store_true", help="show changes while running")
        arg_parser.add_argument("--by", dest="by", choices="dm", action="store", default="m", help="slice data by device or measurement")
        arg_parser.add_argument("--from-file", "-f", dest="file", action="store")
        arg_parser.add_argument("--pal-time", "--pt", action="store", metavar="P", default="3", help="palette for time buckets")
        arg_parser.add_argument("--pal-count", "--pc", action="store", metavar="P", default="0", help="palette for bucket populations")
        arg_parser.add_argument("parts", metavar="pool/vdev", nargs="*", help="Pools or vdevs to display")
        arg_parser.add_argument("--help-colors", action="store_true", help="see color palettes available")
        arg_parser.add_argument("--digits", action="store_true", help="use digits instead of letters")
        arg_parser.add_argument("--symbols", action="store_true", help="use symbols instead of letters")

        parsed_args = vars(arg_parser.parse_args())

        # Palettes are chosen by a hexadecimal index into PALETTES. An unusable
        # choice shows the palette list instead of starting the display.
        help_see_colors = parsed_args["help_colors"]
        try:
            HISTOGRAM_ANSI_COLORS = PALETTES[int(parsed_args["pal_count"], 16)]
            TIME_ANSI_COLORS = PALETTES[int(parsed_args["pal_time"], 16)]
        except (IndexError, ValueError):
            help_see_colors = True

        if help_see_colors:
            print("color palettes (P) for use with --pal-time P or --pal-count P")
            for pi, palette in enumerate(PALETTES):
                rainbow = ["\033[38;5;{0}m {0:03d}\033[m".format(color) for color in palette]
                print("   {0:x}   {1}".format(pi, "".join(rainbow)), end="")
                # Mark the palettes currently selected.
                if hex(pi)[2:] == parsed_args["pal_count"]: print("  (count)", end="")
                if hex(pi)[2:] == parsed_args["pal_time"]: print("  (time)", end="")
                print()
            sys.exit(0)

        if parsed_args["digits"]:
            DISPLAY_CHARS = DISPLAY_CHARS_DIGITS
        if parsed_args["symbols"]:
            DISPLAY_CHARS = DISPLAY_CHARS_SYMBOLS

        curses.wrapper(lambda window: main(window, parsed_args["diff"], parsed_args["parts"], parsed_args["file"], parsed_args["by"]))
    except subprocess.CalledProcessError as exc:
        print("I couldn't get your pool information. Make sure you have 'zpool' program and specify your pool correctly.")
        print(exc)
        # Report the failure to the caller instead of exiting with success.
        sys.exit(1)
    except Exception:
        print("CRASH! Sorry!")
        print("Please report this error at \nhttps://github.com/chadmiller/zpool-iostat-viz/issues/new")
        print()
        print("Paste the following:")
        if raw is not None:
            # The input that was being processed, encoded so it pastes cleanly;
            # the traceback printed by the re-raise follows it.
            print(b64encodebytes(raw.encode("UTF-8")).decode("ASCII"), end="and also include ")
        raise
