#!/usr/local/bin/python3.11

"""Munin plugin to monitor nftables named counters.

=head1 NAME

nft_counter - monitor nftables named counters

=head1 APPLICABLE SYSTEMS

Linux systems with nftables running and counters defined.

=head1 CONFIGURATION

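Include/exclude regexp filters can be configured for each graph. Regexps
are searched (not anchored) in the sanitized counter name
"<family>_<table>_<name>", for example "inet_foomuuri_ssh". Default is to
include everything.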

Example:

  [nft_counter]
  user root

  # Graph only "inet_foomuuri_ssh" and "inet_foomuuri_http" counters
  env.bytes_include_re inet_foomuuri_(ssh|http)
  #env.bytes_exclude_re nothing

  # Graph everything but "inet_foomuuri_ssh" counter
  #env.packets_include_re .
  env.packets_exclude_re inet_foomuuri_ssh

  # Automatically find "_in" and "_out" counters and graph them in
  # traffic style, similar to "if_" plugin.
  #env.traffic_include_re .
  #env.traffic_exclude_re nothing

=head1 AUTHOR

Kim B. Heino <b@bbbs.net>

=head1 LICENSE

GPLv2

=head1 MAGIC MARKERS

 #%# family=auto
 #%# capabilities=autoconf

=cut
"""

import os
import re
import subprocess
import json
import sys
import unicodedata


def safename(name):
    """Convert a string into a lower-case name made of [a-z0-9_] only.

    Accented letters are reduced to their ASCII base letter, characters
    without any ASCII form are dropped, and every remaining character that
    is not a letter or digit is replaced by an underscore.
    """
    # Split accented letters into base letter + combining mark, then throw
    # away everything non-ASCII. Needed because 'ä'.isalnum() is true.
    decomposed = unicodedata.normalize('NFKD', name)
    ascii_only = decomposed.encode('ASCII', 'ignore').decode('utf-8')

    # Only ASCII is left, so lower-casing first does not change which
    # characters count as alphanumeric.
    return ''.join(
        char if char.isalnum() else '_' for char in ascii_only.lower())


def collect_data():
    """Ask nft for all named counters and index them by safe name.

    Returns a dict that maps safename("<family>_<table>_<name>") to the
    counter object found in nft's JSON output. An empty dict is returned
    when the nft binary can not be found or when its output is not valid
    JSON.
    """
    command = ['nft', '--json', 'list', 'counters']
    try:
        # check=False: exit status is ignored, only stdout is looked at
        result = subprocess.run(command, stdout=subprocess.PIPE, check=False,
                                encoding='utf-8', errors='ignore')
    except FileNotFoundError:
        return {}

    try:
        parsed = json.loads(result.stdout)
    except ValueError:
        return {}

    # Pick "counter" objects from the "nftables" list, skip everything else
    found = {}
    for entry in parsed.get('nftables', []):
        if 'counter' in entry:
            counter = entry['counter']
            family = counter['family']
            table = counter['table']
            found[safename(f'{family}_{table}_{counter["name"]}')] = counter
    return found


def filter_data(data, prefix):
    """Return the counters selected by the regexps configured for a graph.

    The regexps are read from environment variables "<prefix>_include_re"
    and "<prefix>_exclude_re". A counter is kept when its name matches the
    include regexp and does not match the exclude regexp. Both are applied
    with re.search(), so they match anywhere in the name.
    """
    # Defaults: empty pattern matches every name, '$^' matches no
    # non-empty name.
    include_re = os.getenv(f'{prefix}_include_re', '')
    exclude_re = os.getenv(f'{prefix}_exclude_re', '$^')
    return {
        name: values
        for name, values in data.items()
        if re.search(include_re, name) and not re.search(exclude_re, name)
    }


def group_traffic(data):
    """Return base names of counters that exist as "_in" and "_out" pair.

    Only counters that pass the "traffic" include/exclude filters are
    considered. For every counter "<base>_in" that has a matching
    "<base>_out" counter, "<base>" is added to the returned list.
    """
    candidates = filter_data(data, 'traffic')
    return [
        name[:-3]  # strip "_in"
        for name in candidates
        if name.endswith('_in') and f'{name[:-3]}_out' in candidates
    ]


def config():
    """Print plugin config for every graph that has counters to show.

    Up to three multigraph sections are printed: bytes per counter (graphed
    as bits), packets per counter, and "_in"/"_out" counter pairs as one
    traffic graph. A section is left out when nothing is selected for it.
    If environment variable MUNIN_CAP_DIRTYCONFIG is "1", the current
    values are printed right after the config.
    """
    data = collect_data()

    def print_fields(selected, as_bits):
        """Print the field attributes of a bytes or packets graph."""
        for field, info in selected.items():
            # Family "inet" is not shown in the label, other families are
            if info['family'] == 'inet':
                label = info['name']
            else:
                label = f'{info["family"]} {info["name"]}'
            print(f'{field}.label {label}')
            print(f'{field}.type DERIVE')
            if as_bits:
                # Value is fetched in bytes, multiply by 8 for the graph
                print(f'{field}.cdef {field},8,*')
            print(f'{field}.min 0')

    by_bytes = filter_data(data, 'bytes')
    if by_bytes:
        print('multigraph nft_counter_bytes')
        print('graph_title nftables counter bytes')
        print('graph_category network')
        print('graph_vlabel bits per ${graph_period}')
        print('graph_args --base 1024')
        print_fields(by_bytes, as_bits=True)

    by_packets = filter_data(data, 'packets')
    if by_packets:
        print('multigraph nft_counter_packets')
        print('graph_title nftables counter packets')
        print('graph_category network')
        print('graph_vlabel packets per ${graph_period}')
        print('graph_args --base 1000')
        print_fields(by_packets, as_bits=False)

    pairs = group_traffic(data)
    if pairs:
        print('multigraph nft_counter_traffic')
        print('graph_title nftables counter traffic')
        print('graph_category network')
        print('graph_vlabel bits in (-) / out (+) per ${graph_period}')
        for base in pairs:
            rx_field = f'{base}_in'
            tx_field = f'{base}_out'
            # "_in" is marked "graph no" and referenced as the negative
            # of "_out" below.
            print(f'{rx_field}.label received')
            print(f'{rx_field}.type DERIVE')
            print(f'{rx_field}.graph no')
            print(f'{rx_field}.cdef {rx_field},8,*')
            print(f'{rx_field}.min 0')
            # Label is the nft name of the "_in" counter without its last
            # three characters, cut to 15 characters.
            pair_label = data[rx_field]['name'][:-3][:15]
            print(f'{tx_field}.label {pair_label}')
            print(f'{tx_field}.type DERIVE')
            print(f'{tx_field}.negative {rx_field}')
            print(f'{tx_field}.cdef {tx_field},8,*')
            print(f'{tx_field}.min 0')

    if os.environ.get('MUNIN_CAP_DIRTYCONFIG') == '1':
        fetch(data)


def fetch(data):
    """Print current values for every graph that has counters to show.

    Sections and fields are printed in the same order as in config():
    bytes, packets, and then "_in"/"_out" pairs of the traffic graph.
    """
    # For these two graphs the filter prefix, the multigraph name suffix
    # and the key of the value in the counter object are the same word.
    for kind in ('bytes', 'packets'):
        selected = filter_data(data, kind)
        if selected:
            print(f'multigraph nft_counter_{kind}')
            for field, info in selected.items():
                print(f'{field}.value {info[kind]}')

    # Traffic graph always uses the byte count of both counters
    pairs = group_traffic(data)
    if pairs:
        print('multigraph nft_counter_traffic')
        for base in pairs:
            for direction in ('in', 'out'):
                field = f'{base}_{direction}'
                print(f'{field}.value {data[field]["bytes"]}')


if __name__ == '__main__':
    # First command line argument selects the mode. No argument, or any
    # other argument than "autoconf" or "config", prints the values.
    mode = sys.argv[1] if len(sys.argv) > 1 else None
    if mode == 'autoconf':
        print('yes' if collect_data() else 'no (no nft counters found)')
    elif mode == 'config':
        config()
    else:
        fetch(collect_data())
