#include <cstddef>
#include <iterator>
#include <set>
#include <unordered_set>
#include <utility>
#include <vector>

namespace ojl {
  /// Tag wrapper marking a value as belonging to the LEFT side of a bimap.
  /// Lets overloaded members (insert/erase/contains/...) pick a side even
  /// when L and R are the same type, e.g. `bm.erase(bimap<L,R>::left{x})`.
  template <typename T>
  struct bimap_left {
      T val;
  };

  /// Tag wrapper marking a value as belonging to the RIGHT side of a bimap.
  template <typename T>
  struct bimap_right {
      T val;
  };


  /// Orders relations lexicographically by (first, second).
  ///
  /// Also transparent: a bare left key is compared against a relation's
  /// `first` only, so `equal_range(key)` on a std::set using this comparator
  /// yields exactly the relations whose left value is `key`.
  /// Only operator< is required of L and R.
  template <typename L, typename R>
  struct CompareLeft {
      bool
      operator()(const std::pair<L, R> &lhs, const std::pair<L, R> &rhs) const
      {
          // Equivalence is expressed through < alone (as std::set itself does),
          // so L does not need an operator==.
          if (lhs.first < rhs.first) {
              return true;
          }
          if (rhs.first < lhs.first) {
              return false;
          }
          return lhs.second < rhs.second;
      }

      bool
      operator()(const L &lhs, const std::pair<L, R> &rhs) const
      {
          return lhs < rhs.first;
      }

      bool
      operator()(const std::pair<L, R> &lhs, const L &rhs) const
      {
          return lhs.first < rhs;
      }

      // Enables heterogeneous lookup in std::set; only the member's existence matters.
      using is_transparent = L;
  };


  /// Orders relations lexicographically by (second, first).
  ///
  /// Also transparent: a bare right key is compared against a relation's
  /// `second` only, so `equal_range(key)` on a std::set using this comparator
  /// yields exactly the relations whose right value is `key`.
  /// Only operator< is required of L and R.
  template <typename L, typename R>
  struct CompareRight {
      bool
      operator()(const std::pair<L, R> &lhs, const std::pair<L, R> &rhs) const
      {
          // Equivalence via < alone, mirroring CompareLeft with the roles swapped.
          if (lhs.second < rhs.second) {
              return true;
          }
          if (rhs.second < lhs.second) {
              return false;
          }
          return lhs.first < rhs.first;
      }

      bool
      operator()(const R &lhs, const std::pair<L, R> &rhs) const
      {
          return lhs < rhs.second;
      }

      bool
      operator()(const std::pair<L, R> &lhs, const R &rhs) const
      {
          return lhs.second < rhs;
      }

      // Enables heterogeneous lookup in std::set; only the member's existence matters.
      using is_transparent = L;
  };


  //TODO: maybe add unkeyed versions of member functions for when L != R
  // (replacing left with L, right with R)

  /// Many-to-many relation between values of type L and values of type R.
  ///
  /// Every relation is stored twice: in `l_relations` (ordered by left, then
  /// right) and in `r_relations` (ordered by right, then left), so lookups by
  /// either side are logarithmic. `left_elements` / `right_elements` record
  /// values inserted on each side, whether or not they take part in a relation.
  ///
  /// Overloads taking `left{...}` / `right{...}` select the side explicitly;
  /// the `l_*` / `r_*` variants take a bare L / R.
  template <typename L, typename R>
  struct bimap {
      using left = bimap_left<L>;
      using right = bimap_right<R>;

      using relation = std::pair<L,R>;

      using l_relation_set  = std::set<relation, CompareLeft<L,R>>;
      using r_relation_set  = std::set<relation, CompareRight<L,R>>;

      /// Iterator into `l_relations` (the left-ordered view).
      using relation_iterator = typename l_relation_set::iterator;
      /// Iterator into `r_relations` (the right-ordered view). The standard does
      /// not guarantee that sets with different comparators share an iterator
      /// type, so right-side range queries return this alias instead.
      using r_relation_iterator = typename r_relation_set::iterator;

      std::set<L>  left_elements;
      std::set<R>  right_elements;
      l_relation_set l_relations;
      r_relation_set r_relations;

      //TODO: badly named? Maybe should make the sets private if I have accessors...?
      /// Read-only view of all relations, ordered by (left, right).
      const l_relation_set&
      relations_ref() const
      {
          return l_relations;
      }

      /// Copy of all relations, ordered by (left, right).
      std::vector<relation>
      all_relations_vec() const
      {
          return std::vector<relation>(l_relations.begin(), l_relations.end());
      }


      // Insert elements (these do not create any relation)
      void
      insert(left l_key)
      {
          left_elements.insert(std::move(l_key.val));
      }

      void
      insert(right r_key)
      {
          right_elements.insert(std::move(r_key.val));
      }

      void
      l_insert(L l)
      {
          left_elements.insert(std::move(l));
      }

      void
      r_insert(R r)
      {
          right_elements.insert(std::move(r));
      }

      // Insert relations (both endpoints are also recorded as elements)
      void
      insert(L l, R r)
      {
          insert(relation{std::move(l), std::move(r)});
      }

      void
      insert(relation rel)
      {
          left_elements.insert(rel.first);
          right_elements.insert(rel.second);
          l_relations.insert(rel);
          r_relations.insert(std::move(rel));
      }

      // Erase relations

      /// Removes one relation from both orderings. Its endpoint values stay in
      /// left_elements / right_elements.
      void
      erase(const relation &rel)
      {
          r_relations.erase(rel);
          l_relations.erase(rel);
      }

      // Erase elements

      /// Removes `l_key.val` from left_elements together with every relation
      /// whose left value it is. Related right values stay in right_elements.
      void
      erase(left l_key)
      {
          left_elements.erase(l_key.val);
          // All (l_key.val, *) relations form one contiguous range in l_relations.
          // Mirror their removal in r_relations, then drop the range itself
          // (std::set has no heterogeneous erase(key) before C++23).
          auto [lb, ub] = l_relations.equal_range(l_key.val);
          for (auto it = lb; it != ub; ++it) {
              r_relations.erase(*it);
          }
          l_relations.erase(lb, ub);
      }

      /// Removes `r_key.val` from right_elements together with every relation
      /// whose right value it is. Related left values stay in left_elements.
      void
      erase(right r_key)
      {
          right_elements.erase(r_key.val);
          // All (*, r_key.val) relations form one contiguous range in r_relations.
          auto [lb, ub] = r_relations.equal_range(r_key.val);
          for (auto it = lb; it != ub; ++it) {
              l_relations.erase(*it);
          }
          r_relations.erase(lb, ub);
      }

      void
      l_erase(L l)
      {
          erase(left{std::move(l)});
      }

      void
      r_erase(R r)
      {
          erase(right{std::move(r)});
      }

      // Contains
      // Note: the keyed overloads check relations only; a value added solely via
      // insert(left)/insert(right)/l_insert/r_insert is not reported.
      bool contains(relation rel) const
      {
          return l_relations.find(rel) != l_relations.end();
      }

      bool contains(left l_key) const
      {
          return l_relations.find(l_key.val) != l_relations.end();
      }

      bool contains(right r_key) const
      {
          return r_relations.find(r_key.val) != r_relations.end();
      }

      bool l_contains(L l) const
      {
          return contains(left{std::move(l)});
      }

      bool r_contains(R r) const
      {
          return contains(right{std::move(r)});
      }


      // Count: number of relations involving the given value.
      std::size_t
      count(left l_key) const
      {
          return l_relations.count(l_key.val);
      }

      std::size_t
      count(right r_key) const
      {
          return r_relations.count(r_key.val);
      }

      std::size_t
      l_count(L l) const
      {
          return count(left{std::move(l)});
      }

      std::size_t
      r_count(R r) const
      {
          return count(right{std::move(r)});
      }

      // Equal range: the relations involving the given value.
      /// Range in l_relations, ordered by right value.
      std::pair<relation_iterator, relation_iterator>
      equal_range(left l_key)
      {
          return l_relations.equal_range(l_key.val);
      }

      /// Range in r_relations, ordered by left value.
      std::pair<r_relation_iterator, r_relation_iterator>
      equal_range(right r_key)
      {
          return r_relations.equal_range(r_key.val);
      }

      std::pair<relation_iterator, relation_iterator>
      l_equal_range(L l)
      {
          return equal_range(left{std::move(l)});
      }

      std::pair<r_relation_iterator, r_relation_iterator>
      r_equal_range(R r)
      {
          return equal_range(right{std::move(r)});
      }

      // Convenience function for getting a vector out
      /// Right values related to `l_key.val`, in ascending order.
      std::vector<R>
      mapped_vector(left l_key) const
      {
          const auto [lb, ub] = l_relations.equal_range(l_key.val);
          std::vector<R> values;
          values.reserve(static_cast<std::size_t>(std::distance(lb, ub)));
          for (auto it = lb; it != ub; ++it) {
              values.push_back(it->second);
          }
          return values;
      }

      /// Left values related to `r_key.val`, in ascending order.
      std::vector<L>
      mapped_vector(right r_key) const
      {
          const auto [lb, ub] = r_relations.equal_range(r_key.val);
          std::vector<L> values;
          values.reserve(static_cast<std::size_t>(std::distance(lb, ub)));
          for (auto it = lb; it != ub; ++it) {
              values.push_back(it->first);
          }
          return values;
      }

      // Convenience function for getting a set
      /// Right values related to `l_key.val`, as a std::set.
      std::set<R>
      mapped_set(left l_key) const
      {
          std::set<R> values;
          const auto [lb, ub] = l_relations.equal_range(l_key.val);
          // The range is already sorted by right value, so hinting at end()
          // makes each insertion amortized constant time.
          for (auto it = lb; it != ub; ++it) {
              values.insert(values.end(), it->second);
          }
          return values;
      }

      /// Left values related to `r_key.val`, as a std::set.
      std::set<L>
      mapped_set(right r_key) const
      {
          std::set<L> values;
          const auto [lb, ub] = r_relations.equal_range(r_key.val);
          // The range is already sorted by left value (see CompareRight).
          for (auto it = lb; it != ub; ++it) {
              values.insert(values.end(), it->first);
          }
          return values;
      }

      std::vector<R>
      l_mapped_vector(L l) const
      {
          return mapped_vector(left{std::move(l)});
      }

      std::vector<L>
      r_mapped_vector(R r) const
      {
          return mapped_vector(right{std::move(r)});
      }

      std::set<R>
      l_mapped_set(L l) const
      {
          return mapped_set(left{std::move(l)});
      }

      std::set<L>
      r_mapped_set(R r) const
      {
          return mapped_set(right{std::move(r)});
      }

  };
}