#!/usr/bin/python3

#
# Copyright oVirt Authors
# SPDX-License-Identifier: Apache-2.0
#

from __future__ import print_function

import argparse
import getpass
import os
import sys

import psycopg2
import psycopg2.extras


class PwdArgsParser(argparse.ArgumentParser):
    """Argument parser that obtains the database password out-of-band.

    The password is never accepted as a command-line option.  It is read
    from the ``PASSWORD`` environment variable or, when that variable is
    not set, requested interactively through ``getpass``.  Usage, help and
    error output are extended with a reminder about the variable.
    """

    MISSING_PWD_MSG = 'PASSWORD variable not found in environment'
    USAGE_PWD_MSG = ('usage: Database password must be set via an '
                     'environment variable ("PASSWORD")')

    def print_usage(self, file=None):
        """Write the password reminder followed by the regular usage text.

        Both parts are written to *file* (``sys.stdout`` when omitted).
        ``ArgumentParser.error`` calls this with ``sys.stderr``, so the
        reminder has to follow the same stream as the usage line instead
        of always going to stdout.
        """
        if file is None:
            file = sys.stdout
        print(self.USAGE_PWD_MSG, file=file)
        super(PwdArgsParser, self).print_usage(file)

    def format_help(self, *args, **kwargs):
        """Return the regular help text prefixed with the password reminder."""
        res = super(PwdArgsParser, self).format_help(*args, **kwargs)
        return '{}\n{}'.format(self.USAGE_PWD_MSG, res)

    def error(self, message):
        """Report a parsing error, noting a missing ``PASSWORD`` variable.

        The base implementation prints the usage and the message and then
        exits the process.
        """
        if 'PASSWORD' not in os.environ:
            message = '{},\n{}'.format(message, self.MISSING_PWD_MSG)
        super(PwdArgsParser, self).error(message)

    def parse_args(self, args=None, namespace=None):
        """Parse the arguments and attach the database password.

        :param args: argument strings to parse; the base class falls back
            to ``sys.argv[1:]`` when omitted.
        :param namespace: optional object to populate instead of a new
            ``argparse.Namespace``.
        :returns: the populated namespace with an extra ``password``
            attribute, taken from ``os.environ['PASSWORD']`` or prompted
            for when the variable is not set.
        """
        parsed = super(PwdArgsParser, self).parse_args(args, namespace)
        if 'PASSWORD' in os.environ:
            parsed.password = os.environ['PASSWORD']
        else:
            # Prompt only after the arguments parsed successfully, so a
            # usage error never leaves the user typing a password first.
            parsed.password = getpass.getpass(
                self.MISSING_PWD_MSG + ', please enter now: ')
        return parsed


def _create_network_name_map(cursor):
    """Map (data-center name, vdsm name) pairs to logical network names.

    :param cursor: database cursor whose rows can be indexed by column
        name; it is used both for the data-center lookup and for the
        ``network`` table query.
    :returns: dict keyed by ``(data-center name, vdsm_name)`` tuples with
        the network ``name`` as value.
    :raises KeyError: if a network row references a ``storage_pool_id``
        that is absent from the data-center map.
    """
    dc_names = _create_dc_name_map(cursor)

    cursor.execute("SELECT vdsm_name, name, storage_pool_id FROM network;")
    name_map = {}
    for row in cursor.fetchall():
        key = (dc_names[row['storage_pool_id']], row['vdsm_name'])
        name_map[key] = row['name']
    return name_map


def _create_dc_name_map(cursor):
    cursor.execute("SELECT id, name FROM storage_pool;")
    return {row['id']: row['name'] for row in cursor.fetchall()}


def _connect(args):
    """Open a database connection described by the parsed arguments.

    :param args: namespace providing ``host``, ``port``, ``user``,
        ``password``, ``database``, ``secure`` and ``host_validation``.
    :returns: a psycopg2 connection whose cursors are created with
        ``psycopg2.extras.RealDictCursor``.
    """
    ssl_mode = _determine_ssl_mode(args.secure, args.host_validation)
    connection_options = {
        'host': args.host,
        'port': args.port,
        'user': args.user,
        'password': args.password,
        'database': args.database,
        'sslmode': ssl_mode,
        'cursor_factory': psycopg2.extras.RealDictCursor,
    }
    return psycopg2.connect(**connection_options)


def _determine_ssl_mode(secure, host_validation):
    if not secure:
        return 'allow'
    if host_validation:
        return 'verify-full'
    return 'require'


def _parse_args():
    """Build the command-line parser and parse the process arguments.

    :returns: the namespace produced by ``PwdArgsParser.parse_args``,
        holding ``host``, ``port``, ``database``, ``secure``,
        ``host_validation``, ``user`` and the ``password`` attribute that
        the parser adds.
    """
    description = ('Outputs a mapping of all logical networks, '
                   'each paired with its matching vdsm (interface) name '
                   '(vdsm name | network name | data-center)')
    parser = PwdArgsParser(description=description)

    # Options are registered in the same order as they appear in the help.
    options_with_defaults = (
        ('--host', 'localhost'),
        ('--port', '5432'),
        ('--database', 'engine'),
    )
    for flag, default in options_with_defaults:
        parser.add_argument(flag, default=default)

    for flag in ('--secure', '--host-validation'):
        parser.add_argument(flag, action='store_true')

    parser.add_argument('--user', required=True)
    return parser.parse_args()


def main():
    """Print every logical network next to its vdsm name and data center.

    Parses the command line, connects to the database, reads the network
    name map and prints one line per network in the form
    ``vdsm name <tab> network name <tab> data-center name``.

    :returns: ``None``; the caller turns that into exit status 0.
    """
    args = _parse_args()
    conn = _connect(args)
    try:
        network_name_map = _create_network_name_map(conn.cursor())
    finally:
        # Release the connection even when creating the cursor or running
        # the queries raises.
        conn.close()

    for (dc_name, vdsm_name), network_name in network_name_map.items():
        print(vdsm_name, '\t', network_name, '\t', dc_name)


if __name__ == '__main__':
    # main() returns None on success; map any falsy result to status 0.
    exit_status = main() or 0
    sys.exit(exit_status)
