#!/usr/bin/env python3

import os, sys, datetime

from hashlib import md5
import xml.etree.ElementTree as ET

from typing import Tuple


class SystemhealthXMLLoader:
    """
    Contains logic to load a systemhealth XML file descriptor.

    A descriptor contains <kpis> elements holding <kpi> elements. A <kpi> either
    defines a KPI through its XML attributes, or carries an "import" attribute
    naming another descriptor file whose KPIs must be loaded as well.

    Any parsing or validation error is printed to stdout and terminates the
    process with exit status 2.
    """

    # Accepted values of the optional "thread_aggregation_function" XML attribute,
    # mapped to the identifier stored in the parsed KPI dictionary.
    SUPPORTED_THREAD_AGG_FUNCS = {
        "sum": "SYSHEALTH_KPI_AGGREGATION_SUM",
        "avg": "SYSHEALTH_KPI_AGGREGATION_AVG",
        "median": "SYSHEALTH_KPI_AGGREGATION_MEDIAN",
        "min": "SYSHEALTH_KPI_AGGREGATION_MIN",
        "max": "SYSHEALTH_KPI_AGGREGATION_MAX",
    }

    # Attributes every (non-import) <kpi> element must define with a non-empty value.
    _REQUIRED_KPI_ATTRIBUTES = ("type", "name", "kpi_type", "description")

    def __init__(self, verbose: bool):
        """
        Args:
            verbose: when True, print a message after each XML file is loaded successfully.
        """
        self.verbose = verbose

    def __get_all_import_files(self, xml_filepath, xml_filename, _visited=None) -> list:
        """
        Return the paths of all files imported, directly or transitively, by one XML file.

        Each <kpi> element having an "import" attribute (and no "hidden" attribute)
        contributes os.path.join(xml_filepath, <import value>), followed by that file's
        own imports. Nested imports are resolved against the same ``xml_filepath``
        directory, not against the directory of the importing file.

        The result may contain duplicates (e.g. a file imported from two places);
        callers are expected to de-duplicate.

        Args:
            xml_filepath: directory used to resolve the file and all of its imports.
            xml_filename: name (relative to xml_filepath) of the file to scan.
            _visited: internal; normalized paths already scanned during this traversal.
                Used to stop on import cycles instead of recursing forever.

        Returns:
            List of imported file paths, in discovery order.
        """
        if _visited is None:
            _visited = set()
        _visited.add(os.path.normpath(os.path.join(xml_filepath, xml_filename)))
        try:
            import_files_list = []
            root = ET.parse(os.path.join(xml_filepath, xml_filename)).getroot()
            for kpi in root.iter("kpi"):
                if "import" not in kpi.attrib or "hidden" in kpi.attrib:
                    continue
                xml_file = kpi.attrib["import"]
                import_path = os.path.join(xml_filepath, xml_file)
                import_files_list.append(import_path)
                # A file already scanned in this traversal (an import cycle, or a file
                # reached through several paths) is listed but not scanned again.
                if os.path.normpath(import_path) not in _visited:
                    import_files_list += self.__get_all_import_files(
                        xml_filepath, xml_file, _visited
                    )
            return import_files_list
        except Exception as e:
            print(
                f"Exception '{str(e)}', while parsing XML file '{xml_filename}' stored at '{xml_filepath}'"
            )
            # FIXME: Implement a better way to report error and stop execution
            sys.exit(2)

    def expand_xml_list(self, system_health_xmls: list) -> list:
        """
        Return 'system_health_xmls' extended with every XML file referenced, directly or
        transitively, through the 'import' attribute of their <kpi> nodes.

        The input list is NOT modified. The result starts with the input entries in
        their original order, followed by each newly discovered file exactly once, in
        discovery order. Recursive imports, including import cycles, are handled.
        """
        expanded = list(system_health_xmls)
        seen = set(expanded)
        # 'expanded' intentionally grows while it is iterated: every newly discovered
        # file is scanned as well, so its imports are also resolved relative to its
        # own directory.
        for system_health_xml in expanded:
            xml_filepath, xml_filename = os.path.split(system_health_xml)
            for import_file in self.__get_all_import_files(xml_filepath, xml_filename):
                if import_file not in seen:
                    seen.add(import_file)
                    expanded.append(import_file)
        return expanded

    def __normalize_kpi(self, attributes: dict, xml_filepath: str) -> dict:
        """
        Fill in defaults, validate, and return the attribute dictionary of one <kpi>.

        Args:
            attributes: the XML attributes of the <kpi> element (modified in place).
            xml_filepath: file the KPI was read from; stored under "source_xml".

        Returns:
            The same dictionary, completed with defaults, the mapped
            "thread_aggregation_function" and "source_xml".

        Raises:
            RuntimeError: a required attribute is missing or empty, or an attribute
                has an unsupported value.

        Exits the process with status 2 if "thread_aggregation_function" is unsupported.
        """
        # Optional attributes and their defaults.
        attributes.setdefault("requires_labels", "false")
        attributes.setdefault("service_impact", "false")
        thread_agg_func = attributes.get("thread_aggregation_function", "sum")

        # Validate the values of optional attributes.
        for flag in ("requires_labels", "service_impact"):
            if attributes[flag] not in ("false", "true"):
                raise RuntimeError(f"invalid XML attribute {flag} [{attributes[flag]}]")
        if thread_agg_func not in self.SUPPORTED_THREAD_AGG_FUNCS:
            print(
                f"ERROR : Invalid value '{thread_agg_func}' found in attribute 'thread_aggregation_function'"
            )
            # FIXME: Implement a better way to report error and stop execution
            sys.exit(2)
        attributes["thread_aggregation_function"] = self.SUPPORTED_THREAD_AGG_FUNCS[
            thread_agg_func
        ]

        # Check that all required attributes are present and non-empty.
        for required_attr in self._REQUIRED_KPI_ATTRIBUTES:
            if required_attr not in attributes:
                raise RuntimeError(f"missing required XML attribute '{required_attr}'")
            if not attributes[required_attr]:
                raise RuntimeError(f"empty required XML attribute '{required_attr}'")

        # Validate the values of required attributes.
        if attributes["type"] not in ("uint64", "float"):
            raise RuntimeError(f"invalid XML attribute type [{attributes['type']}]")
        if attributes["kpi_type"] not in ("gauge", "counter"):
            raise RuntimeError(
                f"invalid XML attribute kpi_type [{attributes['kpi_type']}]"
            )

        # Remember which .XML file this KPI comes from.
        attributes["source_xml"] = xml_filepath
        return attributes

    def parse_systemhealth_xml(self, xml_filepath: str) -> list:
        """
        Parse a systemhealth XML file into a list of KPIs.
        Each KPI is a dictionary with keys being XML attributes, plus defaults for
        the optional ones and a "source_xml" key holding 'xml_filepath'.
        Example return value: [
            {'name': "some_nice_kpi", 'type': 'uint64', ...},
            ...
        ]

        <kpi> elements carrying an XML "import" attribute only reference other files.
        They are skipped here; the referenced files are discovered by expand_xml_list()
        and parsed separately by get_all_kpis_recursively().
        The "hidden" attribute is honored only on "import" elements (by the import
        expansion). Regular KPIs are returned regardless of it.

        Args:
            xml_filepath: syshealth XML file to parse

        Returns:
            A list of dictionaries, each dictionary represents a KPI.

        Any parse or validation error is printed and terminates the process with
        exit status 2.
        """
        try:
            root = ET.parse(xml_filepath).getroot()
            all_fields = []
            for kpis in root.iter("kpis"):
                for kpi in kpis.iter("kpi"):
                    if "import" in kpi.attrib:
                        continue
                    # Work on a copy so the parsed XML tree is left untouched.
                    all_fields.append(
                        self.__normalize_kpi(dict(kpi.attrib), xml_filepath)
                    )

            if self.verbose:
                print(f"Successfully loaded systemhealth XML file '{xml_filepath}'")

            return all_fields
        except Exception as e:
            print(
                f"Exception '{str(e)}', while parsing XML file '{xml_filepath}'"
            )
            # FIXME: Implement a better way to report error and stop execution
            sys.exit(2)

    def get_all_kpis_recursively(self, system_health_xmls: list) -> dict:
        """
        Returns a dictionary of KPIs having KEY=source_xml, VALUE= list of dictionaries
        The returned dictionary is obtained merging together the
        KPI definitions from all the XML files provided as argument and from every
        file they import (see expand_xml_list()):

            full_kpi_list = get_all_kpis_recursively(...)
            full_kpi_list["path/to/source1.xml"] = [ dict(name="foobar", type="uint64", ...) ]

        Keys are the file paths as listed by expand_xml_list(), sorted.
        See parse_systemhealth_xml() docs for more info about each nested dictionary.
        """
        full_xml_list = self.expand_xml_list(system_health_xmls)
        # Parse in discovery order, then sort by file path so the result is deterministic.
        retdict = {fn: self.parse_systemhealth_xml(fn) for fn in full_xml_list}
        return dict(sorted(retdict.items()))
