/**
 * @file
 * @brief A generic [binary search tree](https://en.wikipedia.org/wiki/Binary_search_tree) implementation.
 * @see binary_search_tree.cpp
 */

#include <cassert>
#include <functional>
#include <iostream>
#include <memory>
#include <vector>

/**
 * @brief The Binary Search Tree class.
 *
 * Stores unique keys ordered by `operator<`. Two keys `a` and `b` are treated
 * as the same key when neither `a < b` nor `b < a` holds. `T` therefore only
 * needs `operator<` and to be copyable; it does not need a default
 * constructor, `operator>` or `operator==`.
 *
 * Insert, remove, lookup and min/max are iterative. The three traversal
 * helpers recurse once per tree level.
 *
 * @tparam T The type of the binary search tree key.
 */
template <class T>
class binary_search_tree {
 private:
    /**
     * @brief A struct to represent a node in the Binary Search Tree.
     */
    struct bst_node {
        T value;                         /**< The value/key of the node. */
        std::unique_ptr<bst_node> left;  /**< Pointer to left subtree. */
        std::unique_ptr<bst_node> right; /**< Pointer to right subtree. */

        /**
         * @brief Construct a leaf node; both subtrees start empty.
         *
         * The value is initialized directly (not assigned), so `T` does not
         * have to be default-constructible.
         *
         * @param node_value The value of the constructed node.
         */
        explicit bst_node(T node_value) : value(std::move(node_value)) {}
    };

    std::unique_ptr<bst_node> root_; /**< Pointer to the root of the BST. */
    std::size_t size_ = 0; /**< Number of elements/nodes in the BST. */

    /**
     * @brief Find the owning pointer ("slot") for a key.
     *
     * @param key The key to look for.
     * @return std::unique_ptr<bst_node>& The slot that holds the node
     * equivalent to `key`. If `key` is absent, this is the empty slot where it
     * would be inserted.
     */
    std::unique_ptr<bst_node>& find_slot(const T& key) {
        std::unique_ptr<bst_node>* slot = &root_;
        while (*slot) {
            if (key < (*slot)->value) {
                slot = &(*slot)->left;
            } else if ((*slot)->value < key) {
                slot = &(*slot)->right;
            } else {
                break;  // equivalent key found
            }
        }
        return *slot;
    }

    /**
     * @brief Read-only lookup of the node holding a key.
     *
     * @param key The key to look for.
     * @return const bst_node* The matching node, or nullptr if absent.
     */
    const bst_node* find_node(const T& key) const {
        const bst_node* current = root_.get();
        while (current) {
            if (key < current->value) {
                current = current->left.get();
            } else if (current->value < key) {
                current = current->right.get();
            } else {
                return current;
            }
        }
        return nullptr;
    }

    /**
     * @brief Recursive function to traverse the tree in in-order order.
     *
     * @param callback Function that is called when a value needs to be
     * processed.
     * @param node The node to traverse from (may be null).
     */
    static void traverse_inorder(const std::function<void(const T&)>& callback,
                                 const bst_node* node) {
        if (!node) {
            return;
        }

        traverse_inorder(callback, node->left.get());
        callback(node->value);
        traverse_inorder(callback, node->right.get());
    }

    /**
     * @brief Recursive function to traverse the tree in pre-order order.
     *
     * @param callback Function that is called when a value needs to be
     * processed.
     * @param node The node to traverse from (may be null).
     */
    static void traverse_preorder(const std::function<void(const T&)>& callback,
                                  const bst_node* node) {
        if (!node) {
            return;
        }

        callback(node->value);
        traverse_preorder(callback, node->left.get());
        traverse_preorder(callback, node->right.get());
    }

    /**
     * @brief Recursive function to traverse the tree in post-order order.
     *
     * @param callback Function that is called when a value needs to be
     * processed.
     * @param node The node to traverse from (may be null).
     */
    static void traverse_postorder(
        const std::function<void(const T&)>& callback, const bst_node* node) {
        if (!node) {
            return;
        }

        traverse_postorder(callback, node->left.get());
        traverse_postorder(callback, node->right.get());
        callback(node->value);
    }

 public:
    /**
     * @brief Construct a new, empty Binary Search Tree object.
     */
    binary_search_tree() = default;

    /**
     * @brief Insert a new value into the BST.
     *
     * @param new_value The value to insert into the BST.
     * @return true If the insertion was successful.
     * @return false If an equivalent value is already present.
     */
    bool insert(T new_value) {
        std::unique_ptr<bst_node>& slot = find_slot(new_value);
        if (slot) {
            return false;  // duplicates are rejected
        }
        slot.reset(new bst_node(std::move(new_value)));
        ++size_;
        return true;
    }

    /**
     * @brief Remove a specified value from the BST.
     *
     * A node with two children takes the value of its in-order predecessor
     * (the maximum of its left subtree). That predecessor node is then
     * unlinked.
     *
     * @param rm_value The value to remove.
     * @return true If the removal was successful.
     * @return false If the value was not present.
     */
    bool remove(T rm_value) {
        std::unique_ptr<bst_node>& slot = find_slot(rm_value);
        if (!slot) {
            return false;
        }

        if (slot->left && slot->right) {
            // Walk to the in-order predecessor's slot.
            std::unique_ptr<bst_node>* pred = &slot->left;
            while ((*pred)->right) {
                pred = &(*pred)->right;
            }
            slot->value = std::move((*pred)->value);
            // The predecessor has no right child, so its left subtree
            // (possibly empty) takes its place.
            std::unique_ptr<bst_node> detached = std::move(*pred);
            *pred = std::move(detached->left);
        } else {
            // Zero or one child: the (possibly empty) child replaces the node.
            std::unique_ptr<bst_node> detached = std::move(slot);
            slot = detached->left ? std::move(detached->left)
                                  : std::move(detached->right);
        }

        --size_;
        return true;
    }

    /**
     * @brief Check if a value is in the BST.
     *
     * @param value The value to find.
     * @return true If value is in the BST.
     * @return false Otherwise.
     */
    bool contains(T value) const { return find_node(value) != nullptr; }

    /**
     * @brief Find the smallest value in the BST.
     *
     * @param ret_value Variable to hold the minimum value. It is left
     * untouched when the tree is empty.
     * @return true If minimum value was successfully found.
     * @return false If the tree is empty.
     */
    bool find_min(T& ret_value) const {
        const bst_node* node = root_.get();
        if (!node) {
            return false;
        }
        while (node->left) {
            node = node->left.get();
        }
        ret_value = node->value;
        return true;
    }

    /**
     * @brief Find the largest value in the BST.
     *
     * @param ret_value Variable to hold the maximum value. It is left
     * untouched when the tree is empty.
     * @return true If maximum value was successfully found.
     * @return false If the tree is empty.
     */
    bool find_max(T& ret_value) const {
        const bst_node* node = root_.get();
        if (!node) {
            return false;
        }
        while (node->right) {
            node = node->right.get();
        }
        ret_value = node->value;
        return true;
    }

    /**
     * @brief Get the number of values in the BST.
     *
     * @return std::size_t Number of values in the BST.
     */
    std::size_t size() const { return size_; }

    /**
     * @brief Get all values of the BST in in-order order.
     *
     * @return std::vector<T> List of values, sorted in in-order order.
     */
    std::vector<T> get_elements_inorder() const {
        std::vector<T> result;
        result.reserve(size_);
        traverse_inorder(
            [&result](const T& node_value) { result.push_back(node_value); },
            root_.get());
        return result;
    }

    /**
     * @brief Get all values of the BST in pre-order order.
     *
     * @return std::vector<T> List of values, sorted in pre-order order.
     */
    std::vector<T> get_elements_preorder() const {
        std::vector<T> result;
        result.reserve(size_);
        traverse_preorder(
            [&result](const T& node_value) { result.push_back(node_value); },
            root_.get());
        return result;
    }

    /**
     * @brief Get all values of the BST in post-order order.
     *
     * @return std::vector<T> List of values, sorted in post-order order.
     */
    std::vector<T> get_elements_postorder() const {
        std::vector<T> result;
        result.reserve(size_);
        traverse_postorder(
            [&result](const T& node_value) { result.push_back(node_value); },
            root_.get());
        return result;
    }
};

/**
 * @brief Function for testing insert().
 * 
 * @returns `void`
 */
static void test_insert() {
    std::cout << "Testing BST insert...";

    binary_search_tree<int> tree;
    bool res = tree.insert(5);
    int min = -1, max = -1;
    assert(res);
    assert(tree.find_max(max));
    assert(tree.find_min(min));
    assert(max == 5);
    assert(min == 5);
    assert(tree.size() == 1);

    tree.insert(4);
    tree.insert(3);
    tree.insert(6);
    assert(tree.find_max(max));
    assert(tree.find_min(min));
    assert(max == 6);
    assert(min == 3);
    assert(tree.size() == 4);

    bool fail_res = tree.insert(4);
    assert(!fail_res);
    assert(tree.size() == 4);

    std::cout << "ok" << std::endl;
}

/**
 * @brief Function for testing remove().
 *
 * Covers removing a root that has two children, checking the resulting tree
 * shape. It then removes every remaining key, checking each return value, and
 * finally removes a key that is no longer present.
 *
 * @returns `void`
 */
static void test_remove() {
    std::cout << "Testing BST remove...";

    binary_search_tree<int> tree;
    tree.insert(5);
    tree.insert(4);
    tree.insert(3);
    tree.insert(6);

    bool res = tree.remove(5);
    int min = -1, max = -1;
    assert(res);
    assert(tree.find_max(max));
    assert(tree.find_min(min));
    assert(max == 6);
    assert(min == 3);
    assert(tree.size() == 3);
    assert(tree.contains(5) == false);
    // The removed root's in-order predecessor (4) must take its place.
    assert((tree.get_elements_inorder() == std::vector<int>{3, 4, 6}));
    assert((tree.get_elements_preorder() == std::vector<int>{4, 3, 6}));

    // Every remaining key is present, so each removal must report success.
    bool removed_4 = tree.remove(4);
    bool removed_3 = tree.remove(3);
    bool removed_6 = tree.remove(6);
    assert(removed_4);
    assert(removed_3);
    assert(removed_6);
    assert(tree.size() == 0);
    assert(tree.contains(6) == false);
    assert(!tree.find_min(min));

    bool fail_res = tree.remove(5);
    assert(!fail_res);
    assert(tree.size() == 0);

    std::cout << "ok" << std::endl;
}

/**
 * @brief Function for testing contains().
 * 
 * @returns `void`
 */
static void test_contains() {
    std::cout << "Testing BST contains...";

    binary_search_tree<int> tree;
    tree.insert(5);
    tree.insert(4);
    tree.insert(3);
    tree.insert(6);

    assert(tree.contains(5));
    assert(tree.contains(4));
    assert(tree.contains(3));
    assert(tree.contains(6));
    assert(!tree.contains(999));

    std::cout << "ok" << std::endl;
}

/**
 * @brief Function for testing find_min().
 * 
 * @returns `void`
 */
static void test_find_min() {
    std::cout << "Testing BST find_min...";

    int min = 0;
    binary_search_tree<int> tree;
    assert(!tree.find_min(min));

    tree.insert(5);
    tree.insert(4);
    tree.insert(3);
    tree.insert(6);

    assert(tree.find_min(min));
    assert(min == 3);

    std::cout << "ok" << std::endl;
}

/**
 * @brief Function for testing find_max().
 * 
 * @returns `void`
 */
static void test_find_max() {
    std::cout << "Testing BST find_max...";

    int max = 0;
    binary_search_tree<int> tree;
    assert(!tree.find_max(max));

    tree.insert(5);
    tree.insert(4);
    tree.insert(3);
    tree.insert(6);

    assert(tree.find_max(max));
    assert(max == 6);

    std::cout << "ok" << std::endl;
}

/**
 * @brief Function for testing get_elements_inorder().
 * 
 * @returns `void`
 */
static void test_get_elements_inorder() {
    std::cout << "Testing BST get_elements_inorder...";

    binary_search_tree<int> tree;
    tree.insert(5);
    tree.insert(4);
    tree.insert(3);
    tree.insert(6);

    std::vector<int> expected = {3, 4, 5, 6};
    std::vector<int> actual = tree.get_elements_inorder();
    assert(actual == expected);

    std::cout << "ok" << std::endl;
}

/**
 * @brief Function for testing get_elements_preorder().
 * 
 * @returns `void`
 */
static void test_get_elements_preorder() {
    std::cout << "Testing BST get_elements_preorder...";

    binary_search_tree<int> tree;
    tree.insert(5);
    tree.insert(4);
    tree.insert(3);
    tree.insert(6);

    std::vector<int> expected = {5, 4, 3, 6};
    std::vector<int> actual = tree.get_elements_preorder();
    assert(actual == expected);

    std::cout << "ok" << std::endl;
}

/**
 * @brief Function for testing get_elements_postorder().
 * 
 * @returns `void`
 */
static void test_get_elements_postorder() {
    std::cout << "Testing BST get_elements_postorder...";

    binary_search_tree<int> tree;
    tree.insert(5);
    tree.insert(4);
    tree.insert(3);
    tree.insert(6);

    std::vector<int> expected = {3, 4, 6, 5};
    std::vector<int> actual = tree.get_elements_postorder();
    assert(actual == expected);

    std::cout << "ok" << std::endl;
}

/**
 * @brief Runs every self-test in a fixed order; each test prints its own
 * status line.
 *
 * @returns 0. When assertions are enabled, a failing check terminates the
 * program before this point.
 */
int main() {
    void (*const self_tests[])() = {
        test_insert,
        test_remove,
        test_contains,
        test_find_max,
        test_find_min,
        test_get_elements_inorder,
        test_get_elements_preorder,
        test_get_elements_postorder,
    };
    for (auto run_test : self_tests) {
        run_test();
    }
    return 0;
}

// Binary Search Tree2