#! /usr/local/bin/python3.9
import sys; import os.path; __dir__ = os.path.dirname(os.path.realpath(__file__)); sys.path.insert(0, os.path.realpath(os.path.join(__dir__, '../python'))); sys.path.insert(0, os.path.realpath(os.path.join(__dir__, '../'))) # <-- added by BASIS
"""Read segmentation overlap measures of individual pairwise registrations
from CSV files generated by mirtk evaluate-overlap and compute the average
overlap for each class.
"""

import re
import math
import sys
import csv
import numpy as np
import argparse


def abbreviate(measure):
    """Return the canonical abbreviation for an overlap measure name.

    The input is matched case-insensitively with spaces and hyphens removed;
    if the name is not recognized, it is returned unchanged.
    """
    # Normalized lookup key: lowercase, no spaces or hyphens.
    key = measure.lower().replace(' ', '').replace('-', '')
    synonyms = {
        'DSC': ('dice', 'dicecoefficient', 'dicesimilaritycoefficient',
                'sorensendicecoefficient', 'sorensendiceindex', 'sorensenindex',
                'sorensondicecoefficient', 'sorensondiceindex', 'sorensonindex'),
        'JSC': ('jaccard', 'jaccardindex', 'jaccardsimilaritycoefficient'),
        'TPR': ('sensitivity', 'truepositiverate'),
        'TNR': ('specificity', 'truenegativerate'),
        'PPV': ('positivepredictivevalue', 'precision'),
        'NPV': ('negativepredictivevalue',),
        'FPR': ('falsepositiverate',),
        'FDR': ('falsediscoveryrate',),
        'ACC': ('accuracy',),
        'BM':  ('informedness', 'bookmakerinformedness'),
        'MCC': ('matthewscorrelation', 'matthewscorrelationcoefficient'),
        'F1':  ('f1score', 'fscore', 'fmeasure'),
    }
    for abbr, names in synonyms.items():
        if key in names:
            return abbr
    return measure


def evaluate_overlap(tp, fp, fn, tn, measure='DSC'):
    """Evaluate an overlap measure from confusion-matrix entries.

    Parameters are the true-positive, false-positive, false-negative, and
    true-negative counts; ``measure`` is a full name or abbreviation accepted
    by :func:`abbreviate`. Raises ``Exception`` for an unknown measure.
    """
    key = abbreviate(measure).lower()
    # Work in floating point throughout so integer inputs divide correctly.
    tp, fp, fn, tn = float(tp), float(fp), float(fn), float(tn)
    if key == 'tpr':
        return tp / (tp + fn)
    if key == 'tnr':
        return tn / (fp + tn)
    if key == 'ppv':
        return tp / (tp + fp)
    if key == 'npv':
        return tn / (tn + fn)
    if key == 'fpr':
        return fp / (fp + tn)
    if key == 'fdr':
        return fp / (fp + tp)
    if key == 'fnr':
        return fn / (tp + fn)
    if key == 'acc':
        return (tp + tn) / (tp + fn + fp + tn)
    if key == 'bm':
        # Informedness = sensitivity + specificity - 1
        return tp / (tp + fn) + tn / (fp + tn) - 1.
    if key == 'mcc':
        denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
        return (tp * tn - fp * fn) / denom
    if key == 'markedness':
        # Markedness = precision + negative predictive value - 1
        return tp / (tp + fp) + tn / (tn + fn) - 1.
    if key in ('f1', 'dsc'):
        return (2. * tp) / ((fp + tp) + (tp + fn))
    if key == 'jsc':
        return tp / (fp + tp + fn)
    raise Exception("Unknown overlap measure: {} (abbr: {})".format(measure, key))


class ParseSegmentOptionArguments(argparse.Action):
    """Custom argparse action for the --segment option.

    Expects at least two arguments: a segment NAME followed by one or more
    LABEL arguments, where each LABEL is either an integer or an inclusive
    range written as 'A..B' (e.g. '1..5'). Appends a ``(name, set_of_labels)``
    tuple to the destination list on the namespace.
    """

    def __init__(self, option_strings, dest, nargs=None, **kwargs):
        # BUG FIX: was "nargs is not '+'" -- identity comparison with a string
        # literal; must compare by value.
        if nargs != '+':
            raise ValueError("nargs must be +")
        # Raw string avoids invalid-escape-sequence warnings; compile once here
        # instead of on every __call__.
        self.re_range = re.compile(r'^(\d+)\.\.(\d+)$')
        super(ParseSegmentOptionArguments, self).__init__(option_strings, dest, nargs, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        if len(values) <= 1:
            raise argparse.ArgumentError(self, 'At least two arguments required')
        labels = []
        for value in values[1:]:
            try:
                m_range = self.re_range.match(value)
                if m_range:
                    a = int(m_range.group(1))
                    b = int(m_range.group(2))
                    labels.extend(range(a, b + 1))  # inclusive upper bound
                else:
                    labels.append(int(value))
            except ValueError:
                # BUG FIX: the exception was constructed but never raised, so
                # invalid label arguments were silently ignored.
                raise argparse.ArgumentError(self, "Label argument must be integer or range specification, i.e., 1..5")
        segments = getattr(namespace, self.dest)
        if not segments:
            # Start a fresh list rather than mutating the (shared) default.
            segments = []
        segments.append((values[0], set(labels)))
        setattr(namespace, self.dest, segments)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('tables', nargs='+',
                        help="List of CSV files written by evaluate-overlap -table option.")
    parser.add_argument('--micro', '-micro', action='store_true',
                        help="Compute micro-average of each measure instead of a macro-average.")
    parser.add_argument('--measure', '-measure', '--metric', '-metric', dest='measure', nargs='+', default=[],
                        help="Overlap measure(s) to include in output, all by default.")
    parser.add_argument('--segment', '-segment', dest='segments', nargs='+', default=[], metavar='NAME LABEL...', action=ParseSegmentOptionArguments,
                        help="Average the overlap measures for the specified labels, LABEL must be integer or range '5..10'.")
    parser.add_argument('--noheader', dest="header", action="store_false",
                        help="Input tables have no header row, cannot use --measure and --micro options then.")
    parser.add_argument('--noid', dest="idcol", action="store_false",
                        help="Input tables have no label/subject ID column at index 0.")
    parser.add_argument('--output', '-output',
                        help="Name of output text file.")
    args = parser.parse_args()
    # BUG FIX: was args.measures -- no such attribute (dest is 'measure'),
    # which raised AttributeError whenever --noheader was given.
    if not args.header and (args.measure or args.micro):
        raise Exception("Option --measure/--micro cannot be used when tables have --noheader")
    if not args.idcol and args.segments:
        raise Exception("Option --segment cannot be used when tables have --noid column with label IDs at index 0")
    # Scan the first table once to determine the table shape, the header
    # names, and (when present) the label IDs in column 0. All tables are
    # assumed to share this layout.
    header = []
    labels = []
    ncols = -1
    nrows = 0
    # BUG FIX (Py3): csv.reader needs a text-mode file opened with
    # newline='', not 'rb'; and reader.next() is next(reader).
    with open(args.tables[0], newline='') as f:
        reader = csv.reader(f)
        if args.header:
            header = next(reader)
            ncols = len(header)
        for row in reader:
            nrows += 1
            if ncols == -1:
                ncols = len(row)
            elif len(row) != ncols:
                raise ValueError("Rows of CSV tables must have equal number of columns")
            if args.idcol:
                labels.append(int(row[0]))
    if ncols <= 0 or nrows <= 0:
        raise ValueError("CSV tables must have at least one row/column")
    first_col = 1 if args.idcol else 0
    if header:
        header = [abbreviate(name) for name in header]
    else:
        # --noheader: placeholder entries so indexing stays uniform.
        header = [None] * ncols
    measures = [abbreviate(measure) for measure in args.measure]
    # Locate confusion-matrix columns (needed for micro-averaging).
    tp_col = fp_col = fn_col = tn_col = -1
    for c in range(first_col, ncols):
        # BUG FIX: with --noheader the entries are None; skip instead of
        # crashing on None.lower().
        if header[c] is None:
            continue
        lstr = header[c].lower()
        if lstr == 'tp':
            tp_col = c
        elif lstr == 'fp':
            fp_col = c
        elif lstr == 'fn':
            fn_col = c
        elif lstr == 'tn':
            tn_col = c
    if args.micro:
        if -1 in (tp_col, fp_col, fn_col, tn_col):
            raise Exception("Missing one or more of TP,FP,FN,TN columns needed for micro-averaging")
        usecols = [tp_col, fp_col, fn_col, tn_col]
    else:
        usecols = list(range(first_col, ncols))
    # BUG FIX: np.genfromtxt returns the columns in the order listed in
    # usecols, so map original column index -> index within the summed array
    # instead of assuming "col - first_col" (which was wrong for --micro
    # unless TP,FP,FN,TN happened to be the first columns after the id).
    colpos = {c: i for i, c in enumerate(usecols)}
    if args.micro and measures:
        if args.idcol:
            header = [header[0]] + measures
        else:
            header = list(measures)
        # BUG FIX: was range(first_col, ncols); the output columns must match
        # the number of requested measures.
        cols = range(first_col, first_col + len(measures))
    else:
        cols = []
        if args.micro or measures:
            for measure in measures:
                if measure not in header:
                    raise Exception("Requested measure {} not found in input tables, consider --micro average if TP,FP,FN,TN columns available".format(measure))
            for col in range(first_col, ncols):
                if measures and header[col] not in measures:
                    continue
                if args.micro and col in usecols:
                    continue
                cols.append(col)
        else:
            cols = usecols
    # Sum the selected columns over all input tables.
    # BUG FIX: np.float was removed in NumPy >= 1.20; use builtin float.
    sums = np.zeros((nrows, len(usecols)), dtype=float)
    for csv_name in args.tables:
        tmp = np.genfromtxt(csv_name, delimiter=',',
                            skip_header=1 if args.header else 0,
                            usecols=usecols, dtype=float)
        if len(usecols) > 1:
            sums += tmp
        else:
            # genfromtxt returns a 1-D array for a single column.
            sums[:, 0] += tmp
    num = len(args.tables)

    def _micro_value(row, col):
        """Micro-average: pool TP/FP/FN/TN over all tables, then evaluate."""
        return evaluate_overlap(sums[row, colpos[tp_col]],
                                sums[row, colpos[fp_col]],
                                sums[row, colpos[fn_col]],
                                sums[row, colpos[tn_col]],
                                measure=header[col])

    out = open(args.output, 'w') if args.output else sys.stdout
    try:
        # BUG FIX: was "if header:", which is always true (list of Nones when
        # --noheader) and would have written a bogus header row.
        if args.header:
            if args.segments:
                out.write('Segment')
                out.write(',')
            elif args.idcol:
                out.write(header[0])
                out.write(',')
            out.write(','.join([header[col] for col in cols]))
            out.write('\n')
        if args.segments:
            # One output row per named segment: average over its labels.
            for name, seg_labels in args.segments:
                out.write(name)
                for col in cols:
                    avg = 0.
                    for label in seg_labels:
                        row = labels.index(label)
                        if args.micro:
                            avg += _micro_value(row, col)
                        else:
                            avg += sums[row, colpos[col]] / num
                    avg /= len(seg_labels)
                    out.write(',')
                    out.write('{:.5f}'.format(avg))
                out.write('\n')
        else:
            # One output row per input row (label/subject).
            for row in range(nrows):
                if args.idcol:
                    out.write(str(labels[row]))
                    out.write(',')
                for c, col in enumerate(cols):
                    if c > 0:
                        out.write(',')
                    if args.micro:
                        avg = _micro_value(row, col)
                    else:
                        avg = sums[row, colpos[col]] / num
                    out.write('{:.5f}'.format(avg))
                out.write('\n')
    finally:
        if out is not sys.stdout:
            out.close()