// bijection.hpp
//WARNING: WIP, not tested, use at your own risk


#include <algorithm>
#include <iterator>
#include <set>
#include <stdexcept>
#include <unordered_set>
#include <utility>
#include <vector>

// Thin aggregate wrapper that tags a value of type T as a "left" key.
// It is a distinct type from bimap_right<T>, so the two can select different
// overloads even when both wrap the same underlying T.
template <class T>
struct bimap_left {
    T val;  // the wrapped key value
};

// Thin aggregate wrapper that tags a value of type T as a "right" key.
// It is a distinct type from bimap_left<T>, so the two can select different
// overloads even when both wrap the same underlying T.
template <class T>
struct bimap_right {
    T val;  // the wrapped key value
};

template <typename L, typename R>
struct CompareBijection {
    bool
    operator()(const std::pair<L, R> &lhs, const std::pair<L, R> &rhs) const
    {
        return lhs.first < rhs.first;
    }

    bool
    operator()(const bimap_left<L> &lhs, const std::pair<L, R> &rhs) const
    {
        return lhs.val < rhs.first;
    }

    bool
    operator()(const bimap_right<R> &lhs, const std::pair<L, R> &rhs) const 
    {
        return lhs.val < rhs.second;
    }

    bool
    operator()(const std::pair<L, R> &lhs, const bimap_left<L> &rhs) const 
    {
        return lhs.first < rhs.val;
    }

    bool
    operator()(const std::pair<L, R> &lhs, const bimap_right<R> &rhs) const 
    {
        return lhs.second < rhs.val;
    }

    using is_transparent = L;
};

template <typename L, typename R>
struct bijection {
    using left = bimap_left<L>;
    using right = bimap_right<R>;
    using relation = std::pair<L,R>;
    using relation_set  = std::set<relation, CompareBijection<L,R>>;
    using relation_iterator = typename relation_set::iterator;

    std::set<L>  left_elements;
    std::set<R>  right_elements;
    relation_set relations;

    // Insert relations
    void
    insert(L l, R r)
    {
        relations.insert({l,r});
    }

    void
    insert(relation rel)
    {
        relations.insert(rel);
    }

    // Erase elements
    void
    erase(left l_key)
    {
        relations.erase(l_key);
    }

    void
    erase(right r_key)
    {
        relations.erase(r_key);
    }

    void
    l_erase(L l)
    {
        relations.erase(left{l});
    }

    void
    r_erase(R r)
    {
        relations.erase(right{r});
    }

    // Contains
    template <typename Key>
    bool contains(Key key)
    {
        return relations.contains(key);
    }
    
    // Count
    template <typename Key>
    size_t
    count(Key key)
    {
        return relations.count(key);
    }

    size_t
    l_count(L l)
    {
        return count(left{l});
    }

    size_t
    r_count(R r)
    {
        return count(right{r});
    }

    relation_iterator
    find(left l_key){
        return relations.find(l_key);
    }

    relation_iterator
    find(right r_key){
        return relations.find(l_key);
    }


    //TODO: add operator [] for both left and right access

    R
    at(left l_key){
        auto r = relations.find(l_key);
        return r->second;
    }

    L
    at(right r_key){
        auto r = relations.find(l_key);
        return r->first;
    }


};