#!/usr/bin/python3

import subprocess

import socket
import os
import tempfile
import shutil
import subprocess
import sys
import signal
import time
import psycopg2


from os.path import join, exists

# Root directory scanned for numbered PostgreSQL version subdirectories
# (each expected to contain bin/initdb and bin/postgres).
lib_postgresql = "/usr/lib/postgresql/"

# Names of tests to skip.
# NOTE(review): not referenced anywhere in this script — confirm whether it is still used.
skip_tests = ["test_use_json_for_changes_delete"]


class Postgres:
    def __init__(self, host: str, port: int):
        self.bin_dir = self.find_postgresql_bin()

        if not self.bin_dir:
            print("postgresql bin dir not found")
            sys.exit(1)

        self.host = host
        self.port = port
        self.base_dir = tempfile.mkdtemp()
        os.mkdir(join(self.base_dir, "tmp"))

    def run_initdb(self) -> None:
        print("running initdb")
        initdb = join(self.bin_dir, "bin", "initdb")

        args = [
            initdb,
            "-D",
            join(self.base_dir, "data"),
            "--lc-messages=C",
            "-U",
            "postgres",
            "-A",
            "trust",
        ]
        subprocess.run(args, check=True)

    def run_server(self) -> None:
        print("starting server")
        postgres = join(self.bin_dir, "bin", "postgres")
        assert os.path.exists(postgres)

        args = [
            postgres,
            "-p",
            str(self.port),
            "-D",
            os.path.join(self.base_dir, "data"),
            "-k",
            os.path.join(self.base_dir, "tmp"),
            "-h",
            self.host,
            "-F",
            "-c",
            "logging_collector=off",
        ]
        return subprocess.Popen(args, stderr=subprocess.PIPE)

    def stop_server(self, server):
        print("stopping postgres server")
        if server is None:
            return

        server.send_signal(signal.SIGTERM)
        t0 = time.time()

        while server.poll() is None:
            if time.time() - t0 > 5:
                server.kill()
            time.sleep(0.1)

    def configure(self):
        conn_params = {"user": "postgres", "host": self.host, "port": self.port}
        for _ in range(50):
            try:
                conn = psycopg2.connect(dbname="template1", **conn_params)
            except psycopg2.OperationalError as e:
                allowed = [
                    "the database system is starting up",
                    "could not connect to server: Connection refused",
                    "failed: Connection refused",
                ]
                if not any(msg in e.args[0] for msg in allowed):
                    raise
                time.sleep(0.1)
                continue
            break
        conn.set_session(autocommit=True)
        cur = conn.cursor()
        cur.execute(
            "CREATE ROLE test PASSWORD 'test' SUPERUSER CREATEDB CREATEROLE INHERIT LOGIN"
        )
        cur.execute("CREATE DATABASE test")
        cur.execute('GRANT CREATE ON DATABASE test TO "test"')
        conn.close()

    def find_postgresql_bin(self) -> str | None:
        versions = [int(d) for d in os.listdir(lib_postgresql) if d.isdigit()]

        for v in sorted(versions, reverse=True):
            bin_dir = join(lib_postgresql, str(v))
            if all(exists(join(bin_dir, "bin", f)) for f in ("initdb", "postgres")):
                return bin_dir

        return None


def get_unused_port(host: str) -> int:
    """Return a TCP port on *host* that was free at the moment of the call.

    The OS chooses the port (bind to port 0). The probe socket is closed before
    returning, so another process could still claim the port before the caller
    binds it.
    """
    # `with` guarantees the socket is closed even if bind() raises.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind((host, 0))
        return sock.getsockname()[1]


def clean_up(base_dir: str) -> None:
    """Recursively delete *base_dir* on a best-effort basis."""
    print("clean up")
    # ignore_errors=True: a missing or partly undeletable tree is not fatal.
    shutil.rmtree(base_dir, ignore_errors=True)


def main():
    """Start scratch database server and run tests.

    Runs ``manage.py test`` inside ``auditlog_tests`` against a temporary
    PostgreSQL server and returns its exit status. The server is always
    stopped and its temporary directory removed, even when a step fails. The
    server log is printed whenever the run did not succeed.
    """

    host = "localhost"
    pg_port = get_unused_port(host)

    pg = Postgres(host, pg_port)
    pg_server = None
    ret = 1  # assume failure until the test run reports otherwise
    try:
        pg.run_initdb()
        pg_server = pg.run_server()
        pg.configure()

        args = ["python3", "./manage.py", "test"]
        # No check=True: a failing test run must still reach the teardown below
        # and have its exit status propagated (check=True raised instead).
        p = subprocess.run(
            args,
            cwd="auditlog_tests",  # instead of a process-wide os.chdir()
            # NOTE(review): this replaces the whole environment (PATH, HOME, ...);
            # the child sees only these two variables — confirm that is intended.
            env={"PYTHONPATH": "..", "TEST_DB_PORT": str(pg_port)},
        )
        ret = p.returncode
    finally:
        pg.stop_server(pg_server)

        if pg_server is not None:
            pg_server_errors = pg_server.stderr.read()
            if ret:
                print("postgres server error log")
                print()
                print(pg_server_errors.decode("utf-8", errors="replace"))

        clean_up(pg.base_dir)

    return ret


if __name__ == "__main__":
    # PostgreSQL refuses to run as root; treat that case as a successful skip.
    if os.getuid() != 0:
        sys.exit(main())
    sys.exit(0)
