#!/usr/local/bin/python3.8
# -*- coding: utf8 -*-
from __future__ import print_function


"""
Show a dependency tree of ports.

Styles (select one with -s/--style NUM):
0 - text only
1 - single line pseudographics
2 - double line pseudographics
3 - single line pseudographics, no descendants indicator (default)
4 - double line pseudographics, no descendants indicator
5 - vt100 pseudographics
-1 - use ACS pseudographics if detected, else text only
"""

from subprocess import Popen, PIPE
import multiprocessing
import configargparse
import networkx as nx
import fnmatch
import os
import curses
try:
    # noinspection PyPep8Naming
    import cPickle as pickle
except ImportError:
    import pickle

# one of the next two lines is needed for graphical output
# import pylab as P
# import matplotlib.pyplot as P

# Kinds of port dependencies, in the order add_one_port() asks make for the
# matching <PREFIX>_DEPENDS variables; the first letter of each tags an edge.
DEPENDS_PREFIXES = (
    'FETCH',
    'EXTRACT',
    'PATCH',
    'BUILD',
    'LIB',
    'RUN',
    'TEST',
)  # no 'PKG' yet


def colorize(arg):
    """Return *arg* with ANSI color escapes around dependence-type letters.

    Each of F, E, P, B, L, R, T gets its own color (F is underlined), a ')'
    resets the color, and every other character is copied unchanged.
    """
    palette = {
        'F': "\033[0m\033[4mF\033[0m",  # underlined F
        'E': "\033[33mE",  # yEllow E
        'P': "\033[35mP",  # magenta P (Purple)
        'B': "\033[34mB",  # Blue B
        'L': "\033[36mL",  # green-bLue (cyan) L
        'R': "\033[31mR",  # Red R
        'T': "\033[32mT",  # green T
        ')': "\033[0m)",  # default color )
    }
    return "".join(palette.get(char, char) for char in arg)


def acs_assign(pairs):
    """Turn a sequence of (key, value) pairs laid out flat into a dict.

    pairs[0] maps to pairs[1], pairs[2] to pairs[3], and so on (the layout
    of the terminfo acsc capability).  A trailing unpaired element is
    ignored, and a repeated key keeps its last value.
    """
    keys, values = pairs[::2], pairs[1::2]
    return dict(zip(keys, values))


class Net(object):
    """Show a network as a tree with back references using pseudo-graphics.

    Each glyph table below is indexed by the style number 0...5 (see the
    module docstring).  Style -1 indexes the last element: after a
    successful init_acs() every table ends with a glyph built from the
    terminal's alternative character set (ACS).
    """

    # Result of the ACS probe, shared by all instances: None = not probed
    # yet, True = ACS glyphs appended to the tables, False = unavailable.
    acs_ok = None
    # acsc keys of every ACS glyph that init_acs() uses.
    _ACS_GLYPHS = "`qlwtxm"
    none = [u']', u'◆', u'◆', u'◆', u'◆', u'\x1b(0`\x1b(B']  # root without children
    child = [u'+', u'┌', u'╔', u'┌', u'╔', u'\x1b(0l\x1b(B']  # root with children
    both = [u'+', u'┬', u'╦', u'', u'', u'\x1b(0w\x1b(B']  # descendants indicator
    ref = [u' -> ', u' ─> ', u' ═> ', u' → ', u' ⇒ ', u' \x1b(0q\x1b(B> ']  # back reference
    now_then = [u'+', u'├', u'╠', u'├', u'╠', u'\x1b(0t\x1b(B']  # a non-last child
    then = [u'|', u'│', u'║', u'│', u'║', u'\x1b(0x\x1b(B']  # column below a non-last child
    now = [u'`', u'└', u'╚', u'└', u'╚', u'\x1b(0m\x1b(B']  # the last child
    never = [u' ', u' ', u' ', u' ', u' ', u' ']  # column below a last child

    @staticmethod
    def _as_text(value):
        """Return a terminfo capability value as text ('' for None).

        curses.tigetstr() returns bytes on Python 3; latin-1 maps every byte
        to exactly one character, so escape sequences survive unchanged.
        """
        if value is None:
            return u''
        if isinstance(value, bytes):
            return value.decode('latin-1')
        return value

    def init_acs(self):
        """Probe the terminal's alternative character set (ACS).

        Reads the terminfo capabilities enacs, smacs, rmacs and acsc.  If
        smacs and rmacs exist and acsc maps every glyph in _ACS_GLYPHS, one
        ACS glyph is appended to each glyph table, enacs (if any) is written
        to stdout and Net.acs_ok becomes True.  Otherwise Net.acs_ok becomes
        False and this instance falls back to style 0 (text only).
        """
        try:
            curses.setupterm()
        except curses.error:
            # Unknown or unset TERM: there is no terminfo entry to probe.
            Net.acs_ok = False
            self.style = 0
            return
        # tigetstr() yields bytes (or None) on Python 3; work with text so
        # acsc keys are characters and glyphs concatenate with str.
        en_acs = self._as_text(curses.tigetstr("enacs"))
        sm_acs = self._as_text(curses.tigetstr("smacs"))
        rm_acs = self._as_text(curses.tigetstr("rmacs"))
        acs_dict = acs_assign(self._as_text(curses.tigetstr("acsc")))
        # enacs is treated as optional: terminfo(5) provides it only for
        # terminals that need an explicit "enable ACS" sequence.
        if not (sm_acs and rm_acs) or any(glyph not in acs_dict
                                          for glyph in self._ACS_GLYPHS):
            Net.acs_ok = False
            self.style = 0
            return

        def acs(glyph):
            """Wrap one acsc glyph in the enter/exit ACS mode sequences."""
            return sm_acs + acs_dict[glyph] + rm_acs

        # Appending to the shared class lists makes index -1 select ACS.
        self.none.append(acs('`'))  # '◆'
        self.child.append(acs('l'))  # '┌'
        self.both.append(acs('w'))  # '┬'
        self.ref.append(u" " + acs('q') + u"> ")  # ' ─> '
        self.now_then.append(acs('t'))  # '├'
        self.then.append(acs('x'))  # '│'
        self.now.append(acs('m'))  # '└'
        self.never.append(u' ')
        if en_acs:
            print(en_acs, end=u'')  # a control sequence, not a line of output
        Net.acs_ok = True

    def __init__(self, net, options):
        """Wrap graph *net*; read style, number_all, portsdir, color from options.

        Style -1 probes the terminal once per process (see init_acs) and
        degrades to style 0 whenever no usable ACS was found.
        """
        self.net = net
        self.line = 0
        self.style = options.style
        self.number_all = options.number_all
        self.shown = {}  # port -> line number where it was first printed
        self.portsdir = options.portsdir
        self.ignore = len(self.portsdir) + 1
        self.color = options.color
        if self.style == -1:
            if Net.acs_ok is None:
                self.init_acs()
            if not Net.acs_ok:
                self.style = 0

    def name(self, folder):
        """Return names on a node determined by folder"""
        if folder.startswith(self.portsdir):
            return folder[self.ignore:]
        return folder

    def deeper1(self, parent_is_last, has_child):
        """Return the glyphs drawn right before a node's name.

        parent_is_last is None for a root, True when the node is the last
        child of its parent and False otherwise; has_child is truthy when the
        node has children.
        """
        style = self.style
        if parent_is_last is None:
            return self.child[style] if has_child else self.none[style]  # '┌' / '◆'
        branch = self.now[style] if parent_is_last else self.now_then[style]  # '└' / '├'
        if has_child:
            return branch + self.both[style]  # '└┬' / '├┬'
        return branch

    def deeper2(self, parent_is_last):
        """Return what a node adds to the prefix of its children's lines."""
        if parent_is_last is None:
            return u''  # a root adds nothing
        return self.never[self.style] if parent_is_last else self.then[self.style]  # ' ' / '│'

    def dep_type(self, parent, child):
        """A string showing the type of dependence by the first char of its/their name(s)"""
        if parent == "" or child == "":
            return ""
        deptype = " (%s)" % "".join(sorted(self.net.edges[parent, child]["deptype"]))
        if self.color:
            return colorize(deptype)
        return deptype

    def show_descendants(self, port, prefix=u"", last=None, parent=""):
        """Return the tree rooted at *port* as newline-separated lines.

        Each line is "<number>\\t<prefix><glyphs><name><dep type>".  A port
        printed before is shown again only as a back reference (ref glyph plus
        its first line number) and not expanded; that line is unnumbered
        unless number_all is set.
        """
        children = sorted(self.net.neighbors(port))

        if port in self.shown:
            number = u''
            if self.number_all:
                number = u'%d' % self.line
                self.line += 1
            return u'%s\t%s%s%s%s%s%d' % (
                number, prefix, self.deeper1(last, children), self.name(port),
                self.dep_type(parent, port), self.ref[self.style], self.shown[port])

        lines = [u'%d\t%s%s%s%s' % (self.line, prefix, self.deeper1(last, children),
                                    self.name(port), self.dep_type(parent, port))]
        self.shown[port] = self.line
        self.line += 1
        new_prefix = prefix + self.deeper2(last)
        last_index = len(children) - 1
        for index, child in enumerate(children):
            lines.append(self.show_descendants(child, new_prefix, index == last_index, port))
        return u'\n'.join(lines)

    def show(self, port=None):
        """Return the tree of *port*, or of the whole graph when port is None.

        For the whole graph nodes are visited in sorted order; self.shown
        grows while the comprehension runs, so every node reached from an
        earlier root is skipped instead of becoming a new root.
        """
        if port is not None:
            return self.show_descendants(port)
        return u'\n'.join([self.show_descendants(node)
                           for node in sorted(self.net.nodes())
                           if node not in self.shown])


def parse_args():
    """Set a parser up and parse the args.

    Options may also be given in the default config files listed below or
    in a file named with -c/--config.  Returns the (options, ports) pair
    produced by check_args().
    """
    parser = configargparse.ArgumentParser(
                default_config_files=['/usr/local/etc/porttree.conf', '~/.porttree.conf'],
                args_for_setting_config_path=["-c", "--config"])
    parser.add_argument("-C", "--color", action="store_true", help="Use colors")
    parser.add_argument("-D", "--no-depends", dest="nodepends",
                        action="store_true", help="Do not show direct depends")
    parser.add_argument("-i", "--cache", action="store_true",
                        help="Ignore updates, use cache unconditionally")
    parser.add_argument("-n", "--number-all", dest="number_all", action="store_true",
                        help="Number all lines, even with repeated port names")
    parser.add_argument("-O", "--others", action="store_true", help="Show all ports")
    parser.add_argument("-P", "--portsdir", metavar="DIR", default='/usr/ports',
                        help="Search for ports in directory DIR")
    parser.add_argument("-q", "--quiet", action="store_true",
                        help="Do not show scan progress")
    parser.add_argument("--no-quiet", dest="noquiet", action="store_true",
                        help="Show scan progress")
    parser.add_argument("-r", "-R", "--reverse", dest="reverse",
                        action="store_true", help="Also show reversed tree")
    parser.add_argument("-s", "--style", default=3, metavar="NUM", type=int,
                        dest="style", help="set style to NUM (-1...5)")
    parser.add_argument("-S", "--save", dest="save", metavar="FILE",
                        default="/var/tmp/porttree.cache",
                        help="Save the traversed network into FILE")
    parser.add_argument("-U", "--use-saved", dest="use_saved", metavar="FILE",
                        default="/var/tmp/porttree.cache",
                        help="Use the traversed network saved in FILE")
    # The two literals are concatenated; keep the ", " separator between them.
    parser.add_argument(dest="port", nargs="*",
                        help="Port in form category/portname, "
                             "/portsdir/category/port, or package glob")
    return check_args(parser)


def check_args(parser):
    """Check and normalize options and args.

    Returns (options, ports).  Path-like arguments (containing '/') are
    normalized (duplicate and trailing slashes removed) and made relative to
    --portsdir when they lie inside it; arguments without a slash are
    resolved to port origins with ``pkg info -qo``.  Invalid input is
    reported through parser.error().
    """
    options = parser.parse_args()
    if options.nodepends:
        options.reverse = True
    if not -1 <= options.style <= 5:
        parser.error("Style should be an integer in the range -1...5.")
    args = options.port or [os.getcwd()]
    options.portsdir = os.path.normpath(options.portsdir)
    if not options.portsdir.startswith('/'):
        parser.error("Ports directory should start with a slash '/'.")
    # Match "<portsdir>/" so that e.g. /usr/portsnap is not taken for a
    # path inside /usr/ports.
    portsdir_prefix = options.portsdir.rstrip('/') + '/'
    new_args = []
    for port in args:
        if '/' in port:
            port = os.path.normpath(port)  # "devel/git/" -> "devel/git"
            if port == options.portsdir:
                parser.error("%s is the ports directory itself, not a port." % port)
            if port.startswith(portsdir_prefix):
                port = port[len(portsdir_prefix):]
            new_args.append(port)
            continue
        var, err = Popen(('pkg', 'info', '-qo', port),
                         stdout=PIPE, stderr=PIPE, universal_newlines=True).communicate()
        if err or var == "":
            parser.error("Unrecognized category/port or package %s, error: %s." % (port, err))
        new_args.extend([arg for arg in var.split('\n') if arg != ""])
    if options.cache:
        options.save = None
    if options.noquiet:
        options.quiet = None
    return options, new_args


def un_flavour(port):
    """Return *port* without its '@flavour' suffix.

    Everything from the first '@' onwards is dropped; a name with no '@'
    comes back unchanged.
    """
    base, _separator, _flavour = port.partition("@")
    return base


def add_one_port(port, port_trees, port_time):
    """Extend port_trees by *port* and edges to its direct dependencies.

    The port node gets a "time" attribute of *port_time*.  The dependency
    variables named by DEPENDS_PREFIXES are read with ``make -C port -V ...``;
    every entry adds an edge port -> dependency whose "deptype" set collects
    the first letter of each kind of dependence.  If make writes anything to
    stderr the error is reported and only the node itself is kept.
    """
    port = un_flavour(port)
    port_trees.add_node(port, time=port_time)
    # Ask for the variables in DEPENDS_PREFIXES order so the output lines
    # can be zipped back with their prefixes.
    command = ['make', '-C', port]
    for prefix in DEPENDS_PREFIXES:
        command.extend(('-V', prefix + '_DEPENDS'))
    var, err = Popen(command, stdout=PIPE, stderr=PIPE,
                     universal_newlines=True).communicate()
    if err:
        print("porttree: Error from 'make -C %s' subprocess: %s." % (port, err))
        return
    for deptype, value in zip(DEPENDS_PREFIXES, var.split('\n')):
        for depend in value.split():
            # Field 1 of each colon-separated entry is the dependency's port;
            # skip entries without it instead of crashing the scan.
            fields = depend.split(':')
            if len(fields) < 2 or not fields[1]:
                print("porttree: Skipping malformed %s_DEPENDS entry '%s' of %s."
                      % (deptype, depend, port))
                continue
            dep = un_flavour(fields[1])
            port_trees.add_edge(port, dep)  # adds dep if needed; repeating an edge is OK
            port_trees.edges[port, dep].setdefault("deptype", set()).add(deptype[0])


def print_results(port_tree, options, args):
    """Output the results according to options set by user.

    Unless options.nodepends is set, prints the "Depending on" tree of each
    requested port; if options.reverse is set, also prints the "Required by"
    trees of the reversed graph.  With options.others every remaining port
    follows.  Requested ports that are not nodes of *port_tree* are reported
    and skipped instead of aborting with a traceback.
    """
    ports = []
    for port in args:
        if port in port_tree:
            ports.append(port)
        else:
            print("porttree: Port %s is not in the ports tree; skipped." % port)
    if not options.nodepends:
        _print_net(Net(port_tree, options), "Depending on", ports, options.others)
    if options.reverse:
        _print_net(Net(port_tree.reverse(), options), "\nRequired by:", ports, options.others)


def _print_net(net, title, ports, others):
    """Print *title*, the tree of each of *ports*, then (if *others*) the rest of *net*."""
    print(title)
    for port in ports:
        print(net.show(port))
    if others:
        print(net.show())

#    # get graphical output with the following lines (plus an import above)
#    nx.draw(net.net, pos=nx.spring_layout(net.net))
#    P.draw()
#    P.show()


def get_mtime(filename, timestamp):
    """Return the newer of *filename*'s modification time and *timestamp*.

    Errors from os.path.getmtime() (missing file, no permission, ...) are
    reported and *timestamp* is returned unchanged.  Only OSError is caught,
    so KeyboardInterrupt and SystemExit still stop the program.
    """
    try:
        res = os.path.getmtime(filename)
    except OSError as err:
        print("get_mtime() error=%s; file=%s." % (err, filename))
        return timestamp
    return max(res, timestamp)


def latest_mtime(args, prune=None, skip=None):
    """Find the latest modification time for a list of filesystem subtrees.

    Every path in *args* and everything below it counts, except that
    directories named in *prune* are neither counted nor descended into, and
    files named in *skip* are ignored.  Returns 0.0 when nothing is readable.
    """
    res = 0.0  # Start with 0.0 from epoch
    for arg in args:
        res = get_mtime(arg, res)
        for root, dirs, files in os.walk(arg):
            if prune is not None:
                # Prune via slice assignment: os.walk honours in-place edits
                # of dirs, and removing items from the list being iterated
                # would silently skip the entry after each removed one.
                dirs[:] = [cat for cat in dirs if cat not in prune]
            for cat in dirs:
                res = get_mtime(os.path.join(root, cat), res)
            for fil in files:
                if skip is not None and fil in skip:
                    continue
                res = get_mtime(os.path.join(root, fil), res)
    return res


def latest_mk():
    """Return the newest mtime among the ports framework paths.

    The paths (Keywords, Mk, Templates, Tools and the top-level Makefile)
    are relative, so the current directory must be the ports directory.
    """
    framework_paths = ("Keywords", "Mk", "Templates", "Tools", "Makefile")
    return latest_mtime(framework_paths)


def latest_cat(category):
    """Return the mtime of *category*'s Makefile, or 0.0 if it cannot be read."""
    makefile = os.path.join(category, "Makefile")
    return latest_mtime([makefile])


def copy_from_old(portdir, port_trees, old_trees):
    """Copy node *portdir* (with its cached "time") and its outgoing edges,
    including their attributes, from old_trees into port_trees."""
    cached_time = old_trees.nodes[portdir]["time"]
    port_trees.add_node(portdir, time=cached_time)
    for child, attributes in old_trees.adj[portdir].items():
        port_trees.add_edge(portdir, child, **attributes)


def renew_tree(old_trees, options):
    """Renew ports trees from the newer of old cache or disk.

    Each port category under options.portsdir is scanned by renew_category()
    in a multiprocessing pool and the per-category graphs are merged.  The
    original working directory is restored and the pool is shut down even
    when a worker fails or the scan is interrupted.
    """
    cwd = os.getcwd()
    pool = multiprocessing.Pool(multiprocessing.cpu_count() + 1)
    try:
        os.chdir(options.portsdir)
        mk_mtime = latest_mk()
        # Every port category name starts with a small Latin character
        categories = sorted(fnmatch.filter(os.listdir(options.portsdir), "[a-z]*"))
        # Skip non-dirs, 'distfiles' and 'packages', these are not port categories; 'base' is special
        categories = [category for category in categories
                      if category not in ('distfiles', 'packages', 'base') and os.path.isdir(category)]
        res_list = [pool.apply_async(renew_category, args=(category, mk_mtime, old_trees, options))
                    for category in categories]
        category_list = [res.get() for res in res_list]
        pool.close()
    except BaseException:
        # Stop the remaining workers instead of leaving them running.
        pool.terminate()
        raise
    finally:
        pool.join()
        os.chdir(cwd)
    return merge_categories(category_list)


def renew_category(category, mk_mtime, old_trees, options):
    """Scan for updates one directory of the ports tree.

    Returns a DiGraph for *category* only.  A port is re-read with make when
    it is missing from *old_trees*, has no cached "time", or has files newer
    than that time (also counting the framework and category Makefile
    times); otherwise its node and edges are copied from *old_trees*.
    """
    # Runs as a pool task, so set the working directory explicitly.
    os.chdir(options.portsdir)
    category_trees = nx.DiGraph()
    if not options.quiet:
        print(category)
    cat_mtime = latest_cat(category)
    for port in sorted(os.listdir(category)):
        # Skip 'Makefile' and its includes
        if port in ("Makefile", "Makefile.inc"):
            continue
        portdir = os.path.join(category, port)
        if not os.path.isdir(portdir):
            # Report remaining non-ports and go on
            print("main() portdir error: %s is not dir." % portdir)
            continue
        port_mtime = latest_mtime((portdir,), prune=('work',),
                                  skip=('distinfo', 'pkg-descr', 'pkg-plist'))
        port_mtime = max(port_mtime, mk_mtime, cat_mtime)
        # A cached node has no 'time' if it was only ever seen as a
        # dependency target (e.g. the port did not exist at the last scan).
        cached_time = None
        if old_trees is not None and portdir in old_trees:
            cached_time = old_trees.nodes[portdir].get('time')
        if cached_time is None or cached_time < port_mtime:
            add_one_port(portdir, category_trees, port_mtime)
        else:
            copy_from_old(portdir, category_trees, old_trees)
    return category_trees


def merge_categories(category_list):
    """Combine the per-category DiGraphs into one DiGraph, copying every
    node and edge together with its attribute dict."""
    merged = nx.DiGraph()
    for graph in category_list:
        merged.add_nodes_from(graph.nodes(data=True))
        merged.add_edges_from(graph.edges(data=True))
    return merged


def main():
    """Do the job: load or rebuild the ports graph, save it, print the trees."""
    options, args = parse_args()
    old_trees = None
    if options.use_saved:
        # NOTE(review): pickle.load can run arbitrary code from a crafted
        # file, and the default cache path is in /var/tmp, which is usually
        # world-writable.  Only use a cache file you own.
        try:
            with open(options.use_saved, "rb") as saved_file:
                old_trees = pickle.load(saved_file)
        except (IOError, OSError, EOFError, ValueError, AttributeError,
                ImportError, IndexError, pickle.UnpicklingError) as err:
            # Missing, truncated or incompatible cache: fall back to a scan.
            print("Cannot read cache, err=%s; will scan all ports." % err)
            options.cache = old_trees = None
    if options.cache and old_trees is not None:
        port_trees = old_trees
    else:
        port_trees = renew_tree(old_trees, options)
    if options.save:
        try:
            with open(options.save, "wb") as saved_file:
                pickle.dump(port_trees, saved_file, 2)  # save DiGraph, protocol compatible with py2
        except (IOError, OSError) as err:
            # The results are still worth printing without a cache.
            print("Cannot write cache, err=%s." % err)
    print_results(port_trees, options, args)


if __name__ == "__main__":
    # Run only when executed as a script, not when imported as a module.
    main()
