#!/usr/bin/python3
#
# Copyright © 2022 Nicolas Dandrimont (nicolas@dandrimont.eu)
#
# This program is free software; you can redistribute it and/or modify it
# under the terms of the GNU General Public License as published by the
# Free Software Foundation; either version 2 of the License, or (at your
# option) any later version.
#
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General
# Public License for more details.
#
# You should have received a copy of the GNU General Public License along with
# this program. If not, see <https://www.gnu.org/licenses/>

from __future__ import annotations

import os
import sys
from typing import Any, Dict

import requests
import yaml

# Check that salsaci-overrides.yml overrides all of the common salsaci pipelines.
#
# NOTE: kept as a comment because a string placed after the imports is not
# picked up by Python as the module docstring.

# Top-level .gitlab-ci.yml keywords that configure the pipeline as a whole
# rather than defining jobs; they are excluded from the job comparison.
# The GitLab global-defaults keyword is "default" (singular), not "defaults".
KNOWN_TOPLEVEL_KEYS = {"default", "include", "stages", "variables", "workflow"}


def merge(source, destination):
    """Deep-merge *source* into *destination* in place and return *destination*.

    Rules, applied to each key of *source*:

    - a dict value is merged recursively into the destination's dict for that key;
    - a list value is appended to the destination's list for that key
      (a new list object is stored, so *source*'s lists are never aliased);
    - any other value overwrites the destination's value.

    If the destination holds a value of an incompatible type for a dict or
    list key (e.g. ``None`` from an empty YAML mapping, or a string), the
    source value replaces it instead of raising.

    NOTE(review): GitLab's own include merging overrides arrays rather than
    concatenating them; this script only compares top-level job names, so the
    difference does not affect its result.

    Run the examples with ``python3 -m doctest <this file>``.

    >>> a = { 'first' : { 'all_rows' : { 'pass' : 'dog', 'number' : '1' } } }
    >>> b = { 'first' : { 'all_rows' : { 'fail' : 'cat', 'number' : '5' } }, 'list': [1, 2] }
    >>> merge(b, a) == { 'first' : { 'all_rows' : { 'pass' : 'dog', 'fail' : 'cat', 'number' : '5' } }, 'list': [1, 2] }
    True
    >>> merge({'k': {'x': 1}}, {'k': None})
    {'k': {'x': 1}}
    """
    for key, value in source.items():
        current = destination.get(key)
        if isinstance(value, dict):
            if not isinstance(current, dict):
                # Missing or incompatible: start from a fresh mapping.
                current = destination[key] = {}
            merge(value, current)
        elif isinstance(value, list):
            if not isinstance(current, list):
                current = []
            destination[key] = current + value
        else:
            destination[key] = value

    return destination


def get_gitlabci_config(git_root: str) -> Dict[str, Any]:
    """Load and return the parsed ``.gitlab-ci.yml`` located in *git_root*.

    The file is read as UTF-8 and parsed with ``yaml.safe_load``; an empty
    file yields an empty dict rather than ``None``. The file handle is closed
    before returning.
    """
    gitlabci_path = os.path.join(git_root, ".gitlab-ci.yml")
    with open(gitlabci_path, "r", encoding="utf-8") as gitlabci_file:
        return yaml.safe_load(gitlabci_file) or {}


def fetch_config(url: str, timeout: float = 30.0) -> Dict[str, Any]:
    """Download the gitlab-ci config at *url*, resolving its includes.

    Included configs are fetched recursively, in order, and deep-merged with
    ``merge``. The including document's own keys are merged last, so they
    win over anything it includes.

    :param url: URL of the YAML document to fetch.
    :param timeout: per-request timeout in seconds, passed to ``requests.get``
        so an unresponsive server cannot hang the check indefinitely.
    :raises requests.HTTPError: if any fetch returns an error status.
    :raises ValueError: if an include is not a URL or a ``remote:`` include.
    """
    req = requests.get(url, timeout=timeout)
    req.raise_for_status()

    # An empty document parses to None; treat it as an empty config.
    config = yaml.safe_load(req.content) or {}

    includes = config.pop("include", [])
    # GitLab accepts a single include written as a bare string or mapping.
    if isinstance(includes, (str, dict)):
        includes = [includes]

    merged_config: Dict[str, Any] = {}
    for include in includes:
        # Long-form includes (``- remote: <url>``) carry the URL under a key;
        # local/project/template includes cannot be fetched over HTTP here.
        if isinstance(include, dict):
            if "remote" not in include:
                raise ValueError(
                    f"Cannot resolve non-remote include {include!r} from {url}"
                )
            include = include["remote"]
        merge(fetch_config(include, timeout=timeout), merged_config)

    merge(config, merged_config)
    return merged_config


def get_salsaci_config(gitlabci_config: Dict[str, Any]) -> Dict[str, Any]:
    """Fetch the salsaci job definitions referenced from ``.gitlab-ci.yml``.

    Looks through the ``include`` entries of *gitlabci_config* for the first
    one ending in ``pipeline-jobs.yml``. Both plain string includes and
    long-form ``remote:`` includes are accepted. That config is downloaded
    and returned with its own includes resolved (via ``fetch_config``).

    :raises ValueError: if there are no includes, or none points at
        ``pipeline-jobs.yml``.
    """
    includes = gitlabci_config.get("include", [])
    if not includes:
        raise ValueError("No includes found in .gitlab-ci.yml")
    # A single include may be written without a surrounding list.
    if isinstance(includes, (str, dict)):
        includes = [includes]

    for include in includes:
        location = include.get("remote") if isinstance(include, dict) else include
        if isinstance(location, str) and location.endswith("pipeline-jobs.yml"):
            return fetch_config(location)

    raise ValueError("No include for pipeline-jobs.yml found in .gitlab-ci.yml")


def get_salsaci_overrides(
    git_root: str, gitlabci_config: Dict[str, Any]
) -> Dict[str, Any]:
    """Load the local salsaci overrides referenced from ``.gitlab-ci.yml``.

    Looks through the ``include`` entries of *gitlabci_config* for the first
    one ending in ``salsaci-overrides.yml``. Both plain string includes and
    long-form ``local:`` includes are accepted. The file is read from
    *git_root* as UTF-8 and parsed with ``yaml.safe_load``; an empty file
    yields an empty dict.

    :raises ValueError: if there are no includes, or none points at
        ``salsaci-overrides.yml``.
    :raises OSError: if the referenced file cannot be opened.
    """
    includes = gitlabci_config.get("include", [])
    if not includes:
        raise ValueError("No includes found in .gitlab-ci.yml")
    # A single include may be written without a surrounding list.
    if isinstance(includes, (str, dict)):
        includes = [includes]

    for include in includes:
        location = include.get("local") if isinstance(include, dict) else include
        if isinstance(location, str) and location.endswith("salsaci-overrides.yml"):
            break
    else:
        raise ValueError("No include for salsaci-overrides.yml found in .gitlab-ci.yml")

    # GitLab local includes are relative to the repository root and are often
    # written with a leading "/"; os.path.join would otherwise treat such a
    # path as absolute and silently drop git_root.
    overrides_path = os.path.join(git_root, location.lstrip("/"))
    with open(overrides_path, "r", encoding="utf-8") as overrides_file:
        return yaml.safe_load(overrides_file) or {}


if __name__ == "__main__":
    git_root = os.path.abspath(os.path.join(os.path.dirname(sys.argv[0]), ".."))
    gitlabci_config = get_gitlabci_config(git_root)
    salsaci_config = get_salsaci_config(gitlabci_config)
    salsaci_overrides = get_salsaci_overrides(git_root, gitlabci_config)

    keys_missing = {
        key
        for key in salsaci_config.keys()
        - salsaci_overrides.keys()
        - KNOWN_TOPLEVEL_KEYS
        if not key.startswith(".")
    }

    for key in keys_missing:
        print("missing", key, salsaci_config[key])

    keys_extra = {
        key
        for key in salsaci_overrides.keys()
        - salsaci_config.keys()
        - KNOWN_TOPLEVEL_KEYS
        if salsaci_overrides[key]["stage"].startswith("salsaci/")
    }

    for key in keys_extra:
        print("extra", key, salsaci_overrides[key])

    if keys_missing or keys_extra:
        sys.exit(1)
