#pragma once

#include "testlib.h"
#include <algorithm>
#include <chrono>
#include <cstdint>
#include <cstdlib>
#include <filesystem>
#include <format>
#include <fstream>
#include <functional>
#include <iterator>
#include <optional>
#include <print>
#include <ranges>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

// Project-wide assertion: when `test` is false, prints the stringified
// expression, `message` and the current line number to stderr and terminates
// the process.
//
// The body is wrapped in do { ... } while (0) so that the macro expands to
// exactly one statement: it has to be followed by a semicolon and it composes
// safely with an unbraced if/else. The process exits with a non-zero status so
// that scripts driving the generator can tell a failed assertion from success.
#define myAssert(test, message)                                                                            \
    do {                                                                                                   \
        if (!(test)) {                                                                                     \
            std::println(stderr, "Assertion ({}) <=> {} failed at line {}", #test, message, __LINE__);     \
            exit(1);                                                                                       \
        }                                                                                                  \
    } while (0)

// Scope timer: logs "<name> started" to stderr on construction and
// "<name> took <milliseconds>" on destruction.
struct timer {
    std::string name;

    timer(const std::string name) : name(name), beg(clock_::now()) { std::println(stderr, "{} started", name); }
    ~timer() { std::println(stderr, "{} took {}", name, elapsed()); }

    // Whole milliseconds elapsed since construction (truncated, returned as double).
    double elapsed() const {
        const auto sinceStart = clock_::now() - beg;
        return std::chrono::duration_cast<std::chrono::milliseconds>(sinceStart).count();
    }

private:
    // steady_clock is monotonic; high_resolution_clock may alias the wall clock,
    // which can be adjusted while a measurement is running.
    using clock_ = std::chrono::steady_clock;
    std::chrono::time_point<clock_> beg;
};

// Declares a scoped `timer` whose label is std::format(__VA_ARGS__): the timer
// logs "<label> started" immediately and "<label> took <ms>" when the enclosing
// scope ends. The expansion already ends in ';' and always declares the same
// variable name, so it can be used at most once per scope.
// NOTE(review): identifiers containing "__" are reserved for the implementation;
// consider renaming the variable if nothing else refers to it.
#define TIMEIT(...) timer __INTERNAL_TIMER_NAME(std::format(__VA_ARGS__));

// Constraint type exposing the `validate(std::istream&) const` member that
// Subtask::validate calls on its stored constraints.
// NOTE(review): validate() is only declared here; no definition is visible in
// this header - presumably it lives in another file, confirm before relying on it.
struct BaseConstraints {
    // Checks the test input read from `in`.
    void validate(std::istream& in) const;
};

// A solution executable, identified by its base name.
struct Solution {
    // Name exactly as passed to the constructor.
    std::string name;
    // "bin/" followed by the name; fixed for the lifetime of the object.
    const std::filesystem::path path;

    Solution(const std::string name)
        : name{name},
          path{"bin/" + name} {}
};

// One scoring group of a task: a name, a point value, the generators that
// produce its tests, and its dependency links to other subtasks.
//
// `Constraints` must provide `void validate(std::istream&) const`.
template <class Constraints>
struct Subtask {
    // Next index to hand out; one counter per `Constraints` instantiation.
    static int64_t globalSubtaskIndex;
    // Creation order of this subtask. generateSubtasks asserts that it equals the
    // subtask's position in the vector it is given.
    int64_t subtaskIndex;

    // First and last test number of the subtask. Mutable because generateSubtasks
    // fills them in through a const reference.
    mutable int64_t start = 0;
    mutable int64_t end = 0;

    std::string name;
    int64_t points;

    // Generators take (seed, output stream) and return whether the produced test is acceptable.
    std::vector<std::function<bool(int64_t, std::ostream&)>> testGens;
    // Indices of the subtasks this one depends on.
    std::vector<int64_t> dependencies;
    // Indices of the subtasks that depend on this one.
    std::vector<int64_t> dependees;
    std::optional<Constraints> constraints;

    Subtask(const std::string name, const int64_t points) : subtaskIndex(globalSubtaskIndex++), name(name), points(points), testGens() {}

    // Validates `in` against the stored constraints, or logs that validation is
    // skipped when none were provided.
    void validate(std::istream& in) const {
        if (constraints) {
            constraints.value().validate(in);
        } else {
            std::println(stderr, "-- Skipping validation, as no constraints were provided");
        }
    }

    // Returns the sorted, duplicate-free indices of this subtask and of every
    // subtask that transitively depends on it. `subtasks` must be indexed by
    // subtaskIndex.
    //
    // The traversal is iterative and visits each subtask once, so it terminates
    // on cyclic dependency data and does not blow up on diamond-shaped graphs.
    std::vector<int64_t> allDependees(const std::vector<Subtask>& subtasks) const {
        std::vector<int64_t> reached = {subtaskIndex};
        std::vector<int64_t> pending = dependees;
        while (!pending.empty()) {
            const int64_t current = pending.back();
            pending.pop_back();
            myAssert(current >= 0 && static_cast<std::size_t>(current) < subtasks.size(), "dependee index out of range");
            if (std::find(reached.begin(), reached.end(), current) != reached.end()) {
                continue;  // already expanded
            }
            reached.push_back(current);
            const auto& next = subtasks[current].dependees;
            pending.insert(pending.end(), next.begin(), next.end());
        }
        std::sort(reached.begin(), reached.end());
        return reached;
    }

    // Registers one generator per element of the cartesian product of `ranges`.
    // When invoked with (seed, out), a generator seeds `rnd` and calls
    // f(out, constraints, args...).
    //
    // The constraints are copied at registration time, so withConstraints() has
    // to be called before withTestGens() for the generators to see them.
    template <class F, class... Ranges>
    Subtask& withTestGens(const std::string name, const F f, const Ranges... ranges) {
        for (const auto args : std::views::cartesian_product(ranges...)) {
            // Remove refs: store decayed copies so the generator owns its arguments.
            const auto argscpy = std::apply([](const auto&... elems) { return std::tuple<std::decay_t<decltype(elems)>...>(elems...); }, args);
            testGens.emplace_back([name, argscpy, constraints = this->constraints, f](const int64_t seed, std::ostream& out) {
                std::println(stderr, "We have added test gen ({})", name);
                rnd.setSeed(seed);
                return std::apply([&](const auto... _args) { return f(out, constraints, _args...); }, argscpy);
            });
        }
        return *this;
    }

    // Same as above, repeated `count` times; every copy is invoked with its own seed.
    template <class F, class... Ranges>
    Subtask& withTestGens(const int64_t count, const std::string name, const F f, const Ranges... ranges) {
        for (int64_t i = 0; i < count; i++) {
            withTestGens(name, f, ranges...);
        }
        return *this;
    }

    // Records that this subtask depends on `other`, updating both sides of the link.
    Subtask& dependsOn(Subtask& other) {
        dependencies.push_back(other.subtaskIndex);
        other.dependees.push_back(subtaskIndex);
        return *this;
    }

    // Sets the constraints; they may be set only once.
    Subtask& withConstraints(const Constraints constraints_) {
        myAssert(!constraints, "cannot overwrite constraints");
        constraints = constraints_;
        return *this;
    }
};

// Expands to two comma-separated arguments: the spelling of `CALLABLE` as a
// std::string, followed by `CALLABLE` itself. Meant for argument lists that take
// a display name and then the named entity, such as Subtask::withTestGens.
#define nameFun(CALLABLE) std::string(#CALLABLE), CALLABLE

// Builds the input/output file names of test number `test`:
// {"tests/<taskName>.<index>.in", "tests/<taskName>.<index>.out"}, where <index>
// is `test` left-padded with zeros to exactly `digits` characters.
//
// Asserts that `digits` is non-negative and that the index fits into it.
// Declared inline because the function is defined in a header and may be
// included from several translation units.
inline std::pair<std::string, std::string> formatFileNames(const std::string taskName, const int64_t test, const int64_t digits) {
    // TODO: fix (original note; what needs fixing is not specified)
    myAssert(digits >= 0, "digits must be non-negative");
    const auto width = static_cast<std::size_t>(digits);

    std::string testIndex = vtos(test);
    myAssert(testIndex.size() <= width, "index needs to fit into digits");
    // Pad in one step instead of prepending one character at a time.
    testIndex = std::string(width - testIndex.size(), '0') + testIndex;

    const std::string testFileName = std::format("tests/{}.{}", taskName, testIndex);
    return {testFileName + ".in", testFileName + ".out"};
}

// Logs that grade.properties is being generated and returns its file name.
// Declared inline because the function is defined in a header; without it,
// including the header from two translation units is a multiple-definition error.
inline std::string gradeProperties() {
    const std::string fileName = "grade.properties";
    std::println(stderr, "Generating grade.properties");
    return fileName;
}

// Test generator that copies a hand-written test file verbatim to `out`.
//
// The unnamed std::optional<Constraints> parameter is unused; it is there so the
// function matches the (out, constraints, args...) call shape that
// Subtask::withTestGens invokes generators with.
//
// Always returns true. A `path` that cannot be opened throws std::runtime_error
// instead of silently producing an empty test; returning false is not an option
// because callers retry a generator until it returns true.
template <class Constraints>
bool fileTestGen(std::ostream& out, const std::optional<Constraints>& /*constraints*/, const std::filesystem::path path) {
    std::ifstream fIn(path);
    if (!fIn.is_open()) {
        throw std::runtime_error("fileTestGen: cannot open " + path.string());
    }
    // Buffer the whole file first: streaming an empty rdbuf() would set failbit on `out`.
    const std::string fileContents{std::istreambuf_iterator<char>(fIn), std::istreambuf_iterator<char>()};
    out << fileContents;
    out.flush();
    return true;
}

// Out-of-class definition of the per-instantiation subtask counter; the first
// Subtask constructed for a given `Constraints` type receives index 0.
template <typename ConstraintsT>
int64_t Subtask<ConstraintsT>::globalSubtaskIndex{0};

// Generates every test of every subtask, validates it and produces its expected
// output with the author's solution.
//
// Checks first that subtask i sits at position i, that dependencies only point
// to earlier subtasks and that the points sum to 100. Then, per generator:
//   1. writes tests/<taskName>.<NNN>.in, retrying with a fresh seed until the
//      generator accepts its own output;
//   2. validates the input against the subtask and all subtasks depending on it;
//   3. runs `author` with the input redirected in and the .out file redirected out.
// Each subtask's `start` / `end` are set to its first and last test number.
template <class Constraints>
void generateSubtasks(const std::string taskName, const std::vector<Subtask<Constraints>>& subtasks, const Solution& author) {
    // TODO: validation
    const auto subtaskCount = static_cast<int64_t>(subtasks.size());

    int64_t totalPoints = 0;
    for (int64_t i = 0; i < subtaskCount; i++) {
        myAssert(i == subtasks[i].subtaskIndex, "subtask indexing invariants violated");
        totalPoints += subtasks[i].points;
        for (const auto dId : subtasks[i].dependencies) {
            myAssert(dId < i, "dependencies should be left-to-right");
        }
    }
    myAssert(totalPoints == 100, "point over all subtasks should sum to 100");

    int64_t testInd = 1;
    int64_t currentGen = 0;  // seed counter, advanced on every generator attempt

    for (const auto& st : subtasks) {
        st.start = testInd;
        std::println(stderr, "Started generating subtask {}", st.name);
        for (const auto& tg : st.testGens) {
            const auto files = formatFileNames(taskName, testInd++, 3);
            {
                TIMEIT("---- Generation {}", files.first);
                bool accepted = false;
                while (!accepted) {
                    // Reopen per attempt: opening truncates the file, so output of a
                    // rejected attempt never leaks into the final test.
                    std::ofstream input(files.first);
                    myAssert(input.is_open(), "cannot open test input file for writing");
                    accepted = tg(currentGen++, input);
                }
            }
            {
                TIMEIT("---- Validation {}", files.first);
                for (const auto sId : st.allDependees(subtasks)) {
                    std::ifstream input(files.first);
                    subtasks[sId].validate(input);
                }
            }
            {
                TIMEIT("---- Solving");
                const std::string command = std::format("./{} < {} > {}", author.path.string(), files.first, files.second);
                const int status = std::system(command.c_str());
                if (status != 0) {
                    // Keep going, but make a broken expected output visible.
                    std::println(stderr, "-- Warning: '{}' exited with status {}", command, status);
                }
            }
        }
        myAssert(testInd != st.start, "empty subtask");
        st.end = testInd - 1;
    }
}

// Runs `sol1` and `sol2` against freshly generated tests until their outputs
// differ.
//
// Cycles through every generator of every subtask forever, each attempt with a
// new seed. Every test is written to stress.in, validated like in
// generateSubtasks, and fed to both solutions (stress_1.out / stress_2.out).
// Returns as soon as `diff` reports a difference, leaving the three files in
// place. Returns immediately if no generator is registered at all.
template <class Constraints>
void stressTest(const std::vector<Subtask<Constraints>>& subtasks, const Solution& sol1, const Solution& sol2) {
    std::println(stderr, "------------------- Begin stress testing -------------------");

    const bool hasGenerators = std::any_of(subtasks.begin(), subtasks.end(), [](const auto& st) { return !st.testGens.empty(); });
    if (!hasGenerators) {
        // Without this guard the outer loop would spin forever doing nothing.
        std::println(stderr, "No test generators registered, nothing to stress test");
        return;
    }

    const std::string fileIn = "stress.in";
    const std::string fileOut1 = "stress_1.out";
    const std::string fileOut2 = "stress_2.out";
    const std::string diffCommand = "diff " + fileOut1 + " " + fileOut2;

    // Runs one solution on the current stress input and reports a failing exit status.
    const auto solveWith = [&fileIn](const Solution& sol, const std::string& fileOut) {
        TIMEIT("---- Solving with {}", sol.path.string());
        const std::string command = std::format("./{} < {} > {}", sol.path.string(), fileIn, fileOut);
        const int status = std::system(command.c_str());
        if (status != 0) {
            std::println(stderr, "-- Warning: '{}' exited with status {}", command, status);
        }
    };

    int64_t currentGen = 0;  // seed counter, advanced on every generator attempt

    while (true) {
        for (const auto& st : subtasks) {
            std::println(stderr, "Started generating subtask {}", st.name);
            for (const auto& tg : st.testGens) {
                {
                    TIMEIT("---- Generation strees test {}", fileIn);
                    bool accepted = false;
                    while (!accepted) {
                        // Reopen per attempt: opening truncates the file, so output of a
                        // rejected attempt never leaks into the final test.
                        std::ofstream input(fileIn);
                        myAssert(input.is_open(), "cannot open stress input file for writing");
                        accepted = tg(currentGen++, input);
                    }
                }
                {
                    TIMEIT("---- Validation {}", fileIn);
                    for (const auto sId : st.allDependees(subtasks)) {
                        std::ifstream input(fileIn);
                        subtasks[sId].validate(input);
                    }
                }
                solveWith(sol1, fileOut1);
                solveWith(sol2, fileOut2);
                // diff exits non-zero when the outputs differ (or cannot be compared).
                if (std::system(diffCommand.c_str())) {
                    std::println(stderr, "Found diff!");
                    return;
                }
            }
        }
    }
}

// Writes grade.properties: the memory and time limits, followed by one
// comma-separated entry per subtask for `weights` (points), `groups`
// (<start>-<end> test numbers) and `dependencies` (sorted, de-duplicated
// dependency indices).
//
// `start` / `end` are whatever the subtasks currently hold; generateSubtasks is
// what fills them in, so call it first. Every key is terminated by a newline,
// including when `subtasks` is empty.
template <class Constraints>
void generateGradeProperties(const std::vector<Subtask<Constraints>>& subtasks, const std::string memory, const std::string time) {
    // TODO: checks
    const auto gprop = gradeProperties();
    std::ofstream out(gprop);
    myAssert(out.is_open(), "cannot open grade.properties for writing");

    out << "memory=" << memory << "\n";
    out << "time=" << time << "\n";

    const std::size_t subtaskCount = subtasks.size();

    out << "weights=";
    for (std::size_t i = 0; i < subtaskCount; i++) {
        if (i != 0) {
            out << ",";
        }
        out << subtasks[i].points;
    }
    out << "\n";

    out << "groups=";
    for (std::size_t i = 0; i < subtaskCount; i++) {
        if (i != 0) {
            out << ",";
        }
        out << subtasks[i].start << "-" << subtasks[i].end;
    }
    out << "\n";

    out << "dependencies=";
    for (std::size_t i = 0; i < subtaskCount; i++) {
        if (i != 0) {
            out << ",";
        }
        std::vector<int64_t> filtered = subtasks[i].dependencies;
        std::sort(filtered.begin(), filtered.end());
        filtered.erase(std::unique(filtered.begin(), filtered.end()), filtered.end());
        // NOTE(review): each index is followed by a space and entries are joined
        // with ';' (e.g. "0 ;1 "); kept exactly as before - confirm against the
        // format the grader expects.
        for (std::size_t j = 0; j < filtered.size(); j++) {
            out << filtered[j] << " ";
            if (j + 1 != filtered.size()) {
                out << ";";
            }
        }
    }
    out << "\n";
}
