// bimap.hpp
#include <cstddef>
#include <iterator>
#include <set>
#include <unordered_set>
#include <utility>
#include <vector>

/// Tag wrapper marking a value as a *left* element of a bimap.
/// bimap's insert / erase / contains / count / equal_range overloads take
/// this wrapper so a left key stays distinguishable from a right key even
/// when both sides share the same type.
template <class T>
struct bimap_left { T val; };

/// Tag wrapper marking a value as a *right* element of a bimap.
/// Counterpart of bimap_left: it lets overloads tell right keys apart from
/// left keys when both sides have the same type.
template <class T>
struct bimap_right { T val; };


/// Strict weak ordering on (L, R) relations that sorts by the left element
/// first and breaks ties with the right element.
///
/// Only `operator<` is required on L and R, which is the same requirement
/// std::set places on its keys. Two left elements are treated as equal when
/// neither is less than the other.
///
/// The mixed (L, pair) overloads plus `is_transparent` enable heterogeneous
/// lookup. A std::set ordered by this comparator can therefore be queried
/// with find / count / equal_range using a bare L value. Such a query
/// matches every relation whose left element is equivalent to it.
template <typename L, typename R>
struct CompareLeft {
    bool
    operator()(const std::pair<L, R> &lhs, const std::pair<L, R> &rhs) const
    {
        if (lhs.first < rhs.first) {
            return true;
        }
        if (rhs.first < lhs.first) {
            return false;
        }
        // Left elements are equivalent: order by the right element.
        return lhs.second < rhs.second;
    }

    bool
    operator()(const L &lhs, const std::pair<L, R> &rhs) const
    {
        return lhs < rhs.first;
    }

    bool
    operator()(const std::pair<L, R> &lhs, const L &rhs) const
    {
        return lhs.first < rhs;
    }

    // std::set only checks that this member type exists; the aliased type
    // itself is irrelevant.
    using is_transparent = L;
};


/// Strict weak ordering on (L, R) relations that sorts by the right element
/// first and breaks ties with the left element.
///
/// Only `operator<` is required on L and R, which is the same requirement
/// std::set places on its keys. Two right elements are treated as equal when
/// neither is less than the other.
///
/// The mixed (R, pair) overloads plus `is_transparent` enable heterogeneous
/// lookup. A std::set ordered by this comparator can therefore be queried
/// with find / count / equal_range using a bare R value. Such a query
/// matches every relation whose right element is equivalent to it.
template <typename L, typename R>
struct CompareRight {
    bool
    operator()(const std::pair<L, R> &lhs, const std::pair<L, R> &rhs) const
    {
        if (lhs.second < rhs.second) {
            return true;
        }
        if (rhs.second < lhs.second) {
            return false;
        }
        // Right elements are equivalent: order by the left element.
        return lhs.first < rhs.first;
    }

    bool
    operator()(const R &lhs, const std::pair<L, R> &rhs) const
    {
        return lhs < rhs.second;
    }

    bool
    operator()(const std::pair<L, R> &lhs, const R &rhs) const
    {
        return lhs.second < rhs;
    }

    // std::set only checks that this member type exists; the aliased type
    // itself is irrelevant.
    using is_transparent = L;
};


// TODO: when L != R, consider adding untagged overloads of the member
// functions that take L / R directly instead of the left / right wrappers.

template <typename L, typename R>
struct bimap {
    using left = bimap_left<L>;
    using right = bimap_right<R>;

    using relation = std::pair<L,R>;

    using l_relation_set  = std::set<relation, CompareLeft<L,R>>;
    using r_relation_set  = std::set<relation, CompareRight<L,R>>;

    using relation_iterator = typename l_relation_set::iterator;

    //TODO: see if the iterator types are the same

    std::set<L>  left_elements;
    std::set<R>  right_elements;
    l_relation_set l_relations;
    r_relation_set r_relations;

    //TODO: badly named? Maybe should make the sets private if I have accessors...?
    const l_relation_set&
    relations_ref()
    {
	return l_relations;
    }

    std::vector<relation>
    all_relations_vec()
    {
	std::vector<relation> rel;
	for (auto it = l_relations.begin(); it != l_relations.end(); it++){
	    rel.push_back(*it);
	}
	return rel;
    }


    // Insert elements
    void
    insert(left l_key)
    {
        left_elements.insert(l_key.val);
    }

    void
    insert(right r_key)
    {
        right_elements.insert(r_key.val);
    }

    void
    l_insert(L l)
    {
        left_elements.insert(l);
    }

    void
    r_insert(R r)
    {
        right_elements.insert(r);
    }

    // Insert relations
    void
    insert(L l, R r)
    {
        left_elements.insert(l);
        right_elements.insert(r);
        l_relations.insert({l,r});
        r_relations.insert({l,r});
    }

    void
    insert(relation rel)
    {
        left_elements.insert(rel.first);
        right_elements.insert(rel.second);
        l_relations.insert(rel);
        r_relations.insert(rel);
    }

    // Erase relations
    void
    erase(const relation &rel)
    {
        r_relations.erase(rel);
        l_relations.erase(rel);
    }

    // Erase elements
    void
    erase(left l_key)
    {
        left_elements.erase(l_key.val);
        //Find all relations in l_relations and remove them from r_relations
        auto &[lb, ub] = l_relations.equal_range(l_key.val);
        for (auto it = lb; it < ub; it++){
            r_relations.erase(*it);
        }
        l_relations.erase(l_key.val);
    }

    void
    erase(right r_key)
    {
        right_elements.erase(r_key.val);
        auto &[lb, ub] = r_relations.equal_range(r_key.val);
        for (auto it = lb; it < ub; it++){
            l_relations.erase(*it);
        }
        r_relations.erase(r_key.val);
    }

    void
    l_erase(L l)
    {
        left_elements.erase(l);
        auto &[lb, ub] = l_relations.equal_range(l);
        for (auto it = lb; it < ub; it++){
            r_relations.erase(*it);
        }
        l_relations.erase(l);
    }

    void
    r_erase(R r)
    {
        right_elements.erase(r);
        auto &[lb, ub] = l_relations.equal_range(r);
        for (auto it = lb; it < ub; it++){
            r_relations.erase(*it);
        }
        r_relations.erase(r);
    }

    // Contains
    bool contains(relation rel)
    {
        return l_relations.contains(rel);
    }

    bool contains(left l_key)
    {
        return l_relations.contains(l_key.val);
    }

    bool contains(right r_key)
    {
        return r_relations.contains(r_key.val);
    }

    bool l_contains(L l)
    {
        return l_relations.contains(l);
    }

    bool r_contains(R r)
    {
        return r_relations.contains(r);
    }

    
    // Count
    size_t
    count(left l_key)
    {
        return l_relations.count(l_key.val);
    }

    size_t
    count(right r_key)
    {
        return r_relations.count(r_key.val);
    }

    size_t
    l_count(L l)
    {
        return l_relations.count(l);
    }

    size_t
    r_count(R r)
    {
        return r_relations.count(r);
    }

    // Equal range
    std::pair<relation_iterator, relation_iterator>
    equal_range(left l_key)
    {
        return l_relations.equal_range(l_key.val);
    }

    std::pair<relation_iterator, relation_iterator>
    equal_range(right r_key)
    {
        return r_relations.equal_range(r_key.val);
    }

    std::pair<relation_iterator, relation_iterator>
    l_equal_range(L l)
    {
        return equal_range(left{l});
    }

    std::pair<relation_iterator, relation_iterator>
    r_equal_range(R r)
    {
        return equal_range(right{r});
    }

    // Convenience function for getting a vector out
    std::vector<R>
    mapped_vector(left l_key)
    {
        std::vector<R> v;
        auto [lb, ub] = equal_range(l_key);
        v.reserve(std::distance(lb,ub));
        for (auto it = lb; it != ub; it++){
            v.push_back(it->second);
        }
        return v;
    }

    std::vector<L>
    mapped_vector(right r_key)
    {
        std::vector<L> v;
        auto [lb, ub] = equal_range(r_key);
        v.reserve(std::distance(lb,ub));
        for (auto it = lb; it != ub; it++){
            v.push_back(it->first);
        }
        return v;
    }

    // Convenience function for getting a set
    std::vector<R>
    mapped_set(left l_key)
    {
        std::set<R> s;
        auto [lb, ub] = equal_range(l_key);
        for (auto it = lb; it != ub; it++){
            s.insert(it->second);
        }
        return s;
    }

    std::vector<L>
    mapped_set(right r_key)
    {
        std::set<L> s;
        auto [lb, ub] = equal_range(r_key);
        s.insert(ub - lb);
        for (auto it = lb; it != ub; it++){
            s.insert(it->first);
        }
        return s;
    }

    std::vector<R>
    l_mapped_vector(L l)
    {
        return mapped_vector(left{l});
    }

    std::vector<L>
    r_mapped_vector(R r)
    {
        return mapped_vector(right{r});
    }

    std::vector<R>
    l_mapped_set(L l)
    {
        return mapped_set(left{l});
    }

    std::vector<L>
    r_mapped_set(R r)
    {
        return mapped_set(right{r});
    }

};