# ###################################################################################
# ### Utility Functions for Maximum Likelihood Estimation (MLE) and Segmentation  ###
# ###                  ( based on stochastic blockmodels )                        ###
# ###################################################################################


import math
import numpy as np
import utils
import networkx as nx
import copy
import itertools
import collections
import random

# Return the index of the segment in which a particular timestamp lies, using binary search
def _findSegment(a, n, K): 
  
    start = 0
    end = n - 1
  
    # Binary search 
    while start <= end: 
        
        # mid point 
        mid = (start + end) // 2
        
        # element found 
        if (K >= a[mid][0] and K <= a[mid][1]):                 
            return mid 
        
        # search the upper half 
        elif (K > a[mid][1]): 
            start = mid + 1
        
        # search the lower half 
        else: 
            end = mid - 1
    
    # Not found 
    print('Not Found')
    return -1
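
# Example (illustrative sketch; the ranges below are hypothetical):
#
#   ranges = [[0, 10], [11, 20], [21, 30]]
#   _findSegment(ranges, len(ranges), 15)   # -> 1  (15 lies in [11, 20])
#   _findSegment(ranges, len(ranges), 99)   # -> -1 (outside every range)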
        
# if an empty group exists, split the largest group to fill it
def _split_the_empty_group(num_roles,list_of_groups,group_assignment):
    
    _is_grp_empty = False
    _empty_grp_index = 0
    
    # check whether a group is empty
    for k in range(0, num_roles):
        
        if len( list_of_groups[k]) == 0:
            
            _is_grp_empty = True
            _empty_grp_index = k
            
            print('group is empty..')
           
    # if group is empty, split the largest group        
    if _is_grp_empty:

        list_len = [len(i) for i in list_of_groups] 
        max_len_index =  np.argmax(list_len)

        two_split = np.array_split(list_of_groups[max_len_index], 2)
        list_of_groups[max_len_index] =  two_split[0].tolist()
        list_of_groups[_empty_grp_index] =  two_split[1].tolist()
        
        # update new group assignments        
        for node in list_of_groups[_empty_grp_index]:
            group_assignment[node] = _empty_grp_index
        
        # print(group_assignment)
    return group_assignment 
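
# Example (illustrative sketch; a hypothetical three-role case):
#
#   groups = [[0, 1, 2, 3], [4], []]                 # role 2 is empty
#   assignment = {0: 0, 1: 0, 2: 0, 3: 0, 4: 1}
#   assignment = _split_the_empty_group(3, groups, assignment)
#   # the largest group [0, 1, 2, 3] is split in two: nodes 2 and 3 move to role 2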
        
#  Assign groups ( ver 1 )
def assign_groups(group_assignment,lambda_estimates,change_points_arr,nodes,num_roles,num_segments,dic):

    list_of_groups=  [[] for _ in range(num_roles)]
    for idx, val in group_assignment.items():
        list_of_groups[val].append(idx)     
    
    for v in nodes:       
        # remove v 
        for g in range(0, num_roles):        
            
            # Extract Group g
            _list_g = list_of_groups[g]
            
            if v in _list_g :
                _list_g.remove(v)
        
        # store the current group of v 
        max_i = group_assignment[v]
        
        # likelihoods when node v belongs to each group i
        likelihood_sum = np.zeros((num_roles) , dtype=float) 
        
        # if node v belongs to group i
        for i in range(0, num_roles): 

            likelihood_sum[i] = 0
            
            # interactions with other groups and own group
            for j in range(0, num_roles):
                
                W = list_of_groups[j]
                
                # edge counts between v and group j, within each segment
                edge_counts = np.zeros(( num_segments) , dtype=int) 
                
                for w in W:
                    
                    v1 =  v
                    v2 =  w
                    
                    # undirected: normalize tuple order
                    if v1 > v2:
                        v1, v2 = v2, v1
                    
                    # get the list of timestamps corresponding to (v1,v2) pair
                    if dic.get((v1,v2)):
                        # change-points of (i,j) group pair
                        chg_points = change_points_arr[i,j,:]
                        
                        # build ranges from change points
                        # ex: [0 20340 29460 47640 67800 82560] implies
                        # [[0, 20340], [20341, 29460], [29461, 47640], [47641, 67800], [67801, 82560]]
                        ranges_arr = [ [chg_points[s]+1,chg_points[s+1]] for s in range(0,len(chg_points)-1)]
                        ranges_arr[0][0] = 0 
                        
                        list_time_stamps = dic.get((v1,v2))
                        
                        # iterate over timestamps list
                        for item in list_time_stamps:
                            
                            # find the segment to which the timestamp belongs
                            # (depends on which groups the two nodes belong to)
                            d =  _findSegment(ranges_arr, len(ranges_arr) , int(item)) 
                            edge_counts[d] += 1
                    
                for d in range(0, num_segments):
                    delta_t =  change_points_arr[i,j,d+1] -  change_points_arr[i,j,d]
                    
                    # include the initial time-stamp
                    if d == 0:
                        delta_t += 1
                        
                    if lambda_estimates[i,j,d]!= 0:
                        likelihood_sum[i] += (edge_counts[d]*math.log(lambda_estimates[i,j,d]) - len(W)*lambda_estimates[i,j,d]*delta_t)
                    else:
                        print('lambda zero..')
        
        if  max(likelihood_sum) != 0:  
            max_i = np.argmax(likelihood_sum)
            group_assignment[v] = max_i
            
        list_of_groups[max_i].append(v)
    
    # handle the case of an empty group
    group_assignment = _split_the_empty_group(num_roles,list_of_groups,group_assignment)    
    
    # print(list_of_groups)  
    return group_assignment 
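
# Usage sketch (hypothetical inputs; `dic` maps a sorted node pair to its list
# of timestamps, and `change_points_arr` has shape (roles, roles, segments+1)):
#
#   nodes = [0, 1, 2]
#   dic = {(0, 1): [5, 12], (1, 2): [30]}
#   group_assignment = {0: 0, 1: 0, 2: 1}
#   change_points_arr = np.tile([0, 10, 20, 40], (2, 2, 1))
#   lambda_estimates = np.full((2, 2, 3), 0.01)
#   group_assignment = assign_groups(group_assignment, lambda_estimates,
#                                    change_points_arr, nodes, 2, 3, dic)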


#  Assign groups ( ver 2.2 )
def group_assignment_ver2_2(nodes,num_roles,num_segments,lambda_estimates,group_assignment,change_points_arr,dic):  
    
    list_of_groups=  [[] for _ in range(num_roles)]
    for idx, val in group_assignment.items():
        list_of_groups[val].append(idx)
    # print(list_of_groups)
    
    # no of pairs of groups or roles
    size_all_pairs = int(num_roles*(num_roles + 1)/2)  
    
    # convert tensor to 2-d array
    l_array = np.zeros((size_all_pairs,num_segments) , dtype=float )
   
    cnt = 0    
    for z1 in range(0, num_roles):
        for z2 in range(z1, num_roles):           
            for d in range(0, num_segments):
                l_array[cnt,d]= lambda_estimates[z1,z2,d]
            cnt += 1
    
    # find h-levels
    unique_cols = np.unique(l_array, axis=1)        
    h = len(unique_cols[0])
    # print('h levels %d'%h)
    
    # h-level lambda estimates    
    lambda_estimates_h = np.zeros((num_roles, num_roles, h) , dtype=float )
    
    cnt = 0    
    for z1 in range(0, num_roles):
        for z2 in range(z1, num_roles):
            for d in range(0, h):
                lambda_estimates_h[z1,z2,d]= unique_cols[cnt,d]
                lambda_estimates_h[z2,z1,d]= unique_cols[cnt,d]
            cnt += 1 
    
    # segment to level dictionary mapping
    seg_level_dic = {}
    lst = lambda_estimates_h[0,0,:]
       
    for d in range(num_segments):   
        _h  = [ind for ind, x in enumerate( lst) if x == lambda_estimates[0,0,d]]
        seg_level_dic[d] = _h[0]   
    # print('seg_level_dic : {}'.format(seg_level_dic))
    
    # Multidigraph object 
    G = nx.MultiDiGraph()
    G.add_nodes_from(nodes)
    G.add_edges_from(list(dic))
    G = G.to_undirected()  
    
    # iterate over all nodes
    for v in nodes:
        
        # remove v 
        for g in range(0, num_roles):        
            #Extract Group g
            _list_g = list_of_groups[g]
            
            if v in _list_g :
                _list_g.remove(v)
        
        # store current group of v
        max_i = group_assignment[v]
        
        # likelihoods when node v belongs to each group i
        likelihood_sum = np.zeros((num_roles) , dtype=float)
        
        # edge counts in particular group and level
        edge_counts = np.zeros((num_roles, h) , dtype=int) 
        
        # find neighbours of node v
        neighbour_nodes  = [n for n in G.neighbors(v)]
        # print('neignbours {}'.format(neighbour_nodes))
        
        # iterate over all neighbour nodes of v
        for neigh in neighbour_nodes:
            
            v1 = v
            v2 = neigh
            
            # order of tuple : undirected
            if v1 > v2:
                v1, v2 = v2, v1
            
            i=group_assignment.get(v1)
            j=group_assignment.get(v2)
            
            # change-points of (i,j) group pair; in this case, equally partitioned
            chg_points = change_points_arr[i,j,:]
        
            ranges_arr = [ [chg_points[s]+1,chg_points[s+1]] for s in range(0,len(chg_points)-1)]
            ranges_arr[0][0] = 0
            
            # list of timestamps corresponding to (v1,v2) pair
            list_time_stamps = dic.get((v1,v2))
            
            # iterate over timestamps list
            for item in list_time_stamps:
                # find the segment to which the timestamp belongs
                # (primarily depends on which groups the two nodes belong to)
                # (herein, equally segmented)
                d =  _findSegment(ranges_arr,  len(ranges_arr) , int(item))                 
                _h = seg_level_dic[d]
                                
                if j == group_assignment.get(v):                    
                    edge_counts[i,_h] += 1
                else:                       
                    edge_counts[j,_h] += 1
        
        # merged segment (time duration) according to level
        t = np.zeros((h) , dtype=float) 
        
        for d in range (0,num_segments):
        
            _h= seg_level_dic[d]
            delta_t =  change_points_arr[0,0,d+1] -  change_points_arr[0,0,d]
            
            if d == 0:
              delta_t +=  1 
        
            t[_h] = t[_h] + delta_t 
 
        # if node v belongs to group a  
        for a in range(0, num_roles): 
            
            likelihood_sum[a] = 0
            
            for j in range(0, num_roles):                   
                W = list_of_groups[j]  
                
                # maximum possible number of  ways of interacting with node v
                factor = len(W)
                
                for k in range(0, h):
  
                    if lambda_estimates_h[a,j,k] != 0 :
                        likelihood_sum[a] += (edge_counts[j,k]*math.log(lambda_estimates_h[a,j,k])  - factor*lambda_estimates_h[a,j,k]*t[k])
                    else:
                        print('lambda zero')
                         
        if  max(likelihood_sum) != 0:  
            max_i = np.argmax(likelihood_sum)
            group_assignment[v] = max_i
              
        list_of_groups[max_i].append(v)
    
    # handle the case of an empty group
    group_assignment = _split_the_empty_group(num_roles,list_of_groups,group_assignment)  
    
    # print(list_of_groups)     
    # print(group_assignment)
    return group_assignment
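
# Usage mirrors assign_groups above; the difference is that segments sharing a
# lambda value are first pooled into h levels (via seg_level_dic), so edge
# counts and durations are accumulated per (group, level) rather than per
# (group, segment), which shrinks the inner likelihood loop from num_segments
# to h terms.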

# Estimate lambda
def estimate_lamda(group_assignment,lambda_estimates,change_points_arr,num_roles,num_segments,dic,tuning_params=None):
    lambda_estimates = np.zeros((num_roles, num_roles,num_segments) , dtype=float)
    list_of_groups=  [[] for _ in range(num_roles)]

    for idx, val in group_assignment.items():
        list_of_groups[val].append(idx)
   
    # create dictionary to store interaction counts
    # ( no of interactions between ith and jth group within dth segment)              
    key_data = [
        range(0,num_roles),
        range(0,num_roles),
        range(0,num_segments),
    ]
    keys = list(itertools.product(*key_data))  
    i_j_d = {key: 0 for key in keys}           
              
    for key, val in dic.items():
        
        i = group_assignment.get(key[0])
        j = group_assignment.get(key[1])

        # undirected: normalize so that i <= j
        if i > j:
            i, j = j, i

        # change-points of (i,j) group pair
        chg_points = change_points_arr[i,j,:]
        
        ranges_arr = [ [chg_points[s]+1,chg_points[s+1]] for s in range(0,len(chg_points)-1)]
        ranges_arr[0][0] = 0
        
        n = len(ranges_arr) 
        
        for item in val:
            
            d = _findSegment(ranges_arr, n, int(item)) 
            i_j_d[(i,j,d)] += 1    
        
    for k in range(0, num_roles):
        for g in range(k,num_roles):
                U=list_of_groups[k]
                W=list_of_groups[g]  
                
                size_all_pairs = 0
                if k == g:
                    size_all_pairs = math.comb(len(U), 2)
                if k != g:
                    size_all_pairs = len(U)*len(W)
                
                for d in range(0, num_segments):                    
                    inter_count = i_j_d[(k,g,d)] 
                    delta_t = (change_points_arr[k,g,d+1] - change_points_arr[k,g,d])
                    if d == 0:
                        delta_t += 1 
                    lambda_estimates[k,g,d] = (inter_count+tuning_params['theta'])/((delta_t*size_all_pairs)+tuning_params['eta'])
                    lambda_estimates[g,k,d] = lambda_estimates[k,g,d]
                            
                # print('{} {} {}'.format(k,g,lambda_estimates[k,g,:]) )
    return lambda_estimates
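
# The estimator above is the regularized Poisson MLE per block pair and segment:
#     lambda[i,j,d] = (count(i,j,d) + theta) / (delta_t(i,j,d) * |pairs(i,j)| + eta)
# Usage sketch (hypothetical smoothing values; positive theta and eta keep empty
# blocks and zero-length segments from producing 0/0):
#
#   tuning_params = {'theta': 1e-7, 'eta': 1e-2}
#   lambda_estimates = estimate_lamda(group_assignment, lambda_estimates,
#                                     change_points_arr, num_roles,
#                                     num_segments, dic, tuning_params)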

# Estimate (k-h) lambda
def estimate_lamda_kh(num_roles,num_segments,lambda_estimates,group_assignment,change_points_arr,dic,tuning_params=None):
    list_of_groups=  [[] for _ in range(num_roles)]
       
    for idx, val in group_assignment.items():
        list_of_groups[val].append(idx)

    # create dictionary to store interaction counts
    # ( no of interactions between ith and jth group within dth segment)              
    key_data = [
        range(0,num_roles),
        range(0,num_roles),
        range(0,num_segments),
    ]
    keys = list(itertools.product(*key_data))  
    i_j_d = {key: 0 for key in keys}  
    
    for key, val in dic.items():
        
        i = group_assignment.get(key[0])
        j = group_assignment.get(key[1])

        # undirected graph: normalize so that i <= j
        if i > j:
            i, j = j, i

        # change-points of (i,j) group pair
        chg_points = change_points_arr[i,j,:]
        
        ranges_arr = [ [chg_points[s]+1,chg_points[s+1]] for s in range(0,len(chg_points)-1)]
        ranges_arr[0][0]=0
        
        n = len(ranges_arr) 
        
        for item in val:
            d = _findSegment(ranges_arr, n, int(item)) 
            i_j_d[(i,j,d)] += 1    
                  
    for k in range(0, num_roles):
        for g in range(k,num_roles):
  
            U=list_of_groups[k]
            W=list_of_groups[g]  
            
            size_all_pairs = 0
            if k == g:
                size_all_pairs = math.comb(len(U), 2)
            if k != g:
                size_all_pairs = len(U)*len(W)

            _current_lambda_val = lambda_estimates[k,g,:].tolist()                
            _unique_val = list(set(_current_lambda_val))
                        
            h = len(_unique_val)
            
            for i in range(0,h):
                
                grp =[]
                
                for  j in range(0, num_segments):
    
                    if math.isclose(_current_lambda_val[j], _unique_val[i]):
                        grp.append(j)
                        
                inter_count = 0
                delta_t = 0
                
                for d in grp:

                    delta_t += (change_points_arr[k,g,d+1] - change_points_arr[k,g,d])
                    
                    if d == 0:
                        delta_t += 1 
                    
                    inter_count += i_j_d[(k,g,d)]   
                                                       
                alpha = inter_count
                beta = size_all_pairs*delta_t
                lamda = (alpha + random.random()*tuning_params['theta'])/(beta + tuning_params['eta'])
                
                for jj in grp:
                    lambda_estimates[k,g,jj] = lamda
                    lambda_estimates[g,k,jj] = lamda    
            # print('{} {} {}'.format(k,g,lambda_estimates[k,g,:]) )          
        
    return lambda_estimates
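
# Worked example of the pooling above (hypothetical numbers): if segments 0 and
# 2 share one level, with 6 + 4 = 10 interactions over durations 10 + 20 = 30
# across 5 node pairs, the shared rate is about 10 / (5 * 30) ~ 0.067 before the
# randomized theta/eta smoothing is applied.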

# Compute cost (edge by edge)
def com_cost(num_roles,num_segments,lambda_estimates,change_points_arr,group_assignment,dic):    

    
    list_of_groups=  [[] for _ in range(num_roles)]
       
    for idx, val in group_assignment.items():
        list_of_groups[val].append(idx)
    
    i_j_d = {}
    
    for i in range(0, num_roles):
        for j in range(0, num_roles):
           for d in range(0, num_segments): 
               i_j_d[(i,j,d)] = 0
       
    for key, val in dic.items():
        
        i=group_assignment.get(key[0])
        j=group_assignment.get(key[1])
        
        if i>j:
            i,j=j,i
            
        a = change_points_arr[i,j,:]
        
        ranges_arr = [ [a[s]+1,a[s+1]] for s in range(0,len(a)-1)]
        ranges_arr[0][0]=0
        
        n = len(ranges_arr) 
        
        for item in val:

            d =  _findSegment(ranges_arr, n, int(item))                         
            i_j_d[(i,j,d)] += 1

 
    likelihood_sum = 0

    for d in range(0, num_segments):
        for k in range(0, num_roles):
            for g in range(k, num_roles):
    
                U=list_of_groups[k]
                W=list_of_groups[g]
    
                size_all_pairs = 0
                if k == g:
                    size_all_pairs = math.comb(len(U), 2)
                if k != g:
                    size_all_pairs = len(U)*len(W)
        
                alpha = (size_all_pairs * lambda_estimates[k,g,d])
                
                delta= change_points_arr[k,g,d+1]-change_points_arr[k,g,d]
                
                if d == 0:
                    delta += 1 
                
                if lambda_estimates[k,g,d] != 0:
                    likelihood_sum += (i_j_d[(k,g,d)]* math.log(lambda_estimates[k,g,d])- (alpha*delta)) 
    print('Likelihood sum: %f'%(likelihood_sum))  
        
         
    return likelihood_sum  


    
# Compute cost 
def compute_cost(group_assignment,lambda_estimates,change_points_arr,num_roles,num_segments,dic):
    # print('compute cost...')
    
    list_of_groups=  [[] for _ in range(num_roles)]
 
    for idx, val in group_assignment.items():
        list_of_groups[val].append(idx)

    # create dictionary to store interaction counts
    # ( no of interactions between ith and jth group within dth segment)              
    key_data = [
        range(0,num_roles),
        range(0,num_roles),
        range(0,num_segments),
    ]
    keys = list(itertools.product(*key_data))  
    i_j_d = {key: 0 for key in keys}  
               
    for key, val in dic.items():
        
        i=group_assignment.get(key[0])
        j=group_assignment.get(key[1])
        
        # undirected: normalize so that i <= j
        if i>j:
            i,j=j,i  
        
        chg_points = change_points_arr[i,j,:]
        
        ranges_arr = [ [chg_points[s]+1,chg_points[s+1]] for s in range(0,len(chg_points)-1)]
        ranges_arr[0][0]=0
        
        n = len(ranges_arr) 
        # print('{} {}'.format(chg_points,ranges_arr))
        for item in val:
            d =  _findSegment(ranges_arr, n, int(item)) 
            i_j_d[(i,j,d)] += 1 
            
    likelihood_sum = 0

    for k in range(0, num_roles):
        for g in range(k,num_roles):

            U=list_of_groups[k]
            W=list_of_groups[g]  
            
            size_all_pairs = 0
            if k == g:
                size_all_pairs = math.comb(len(U), 2)
            if k != g:
                size_all_pairs = len(U)*len(W)                    
            
            for d in range(0, num_segments):                    
                inter_count = i_j_d[(k,g,d)]
               
                delta_t = (change_points_arr[k,g,d+1] - change_points_arr[k,g,d])
                
                if d == 0:
                    delta_t += 1 
                
                if lambda_estimates[k,g,d] != 0:
                    likelihood_sum += (inter_count*math.log(lambda_estimates[k,g,d]) - size_all_pairs*lambda_estimates[k,g,d]*delta_t)
                # print('Likelihood sum: %d %d %f'%(k,g,liklihood_sum))  
            # print('Likelihood sum: %d %d %f'%(k,g,temp))
    # print('Likelihood sum: %f'%(liklihood_sum))     
    return likelihood_sum
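
# Each term above is the Poisson segment log-likelihood
#     count * log(lambda) - |pairs| * lambda * delta_t
# Worked example (hypothetical numbers): count = 10, lambda = 0.05, 5 pairs and
# delta_t = 40 give 10*log(0.05) - 5*0.05*40 = -29.96 - 10 ~ -39.96.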

# Estimate change points (naive dynamic programming)
def dyn_prog_seg(group_assignment,lambda_estimates,change_points_arr,num_roles,num_segments,dic):      
        
    # create a dictionary: key: timestamp, value: list of interacting group pairs
    time_group_pairs = {}
    
    # store time stamps
    time_stamps=[]
    
    # 'dic': key: node-pair, value: list of time stamps
    for key, val in dic.items():
         
        i=group_assignment.get(key[0])
        j=group_assignment.get(key[1])
        
        # undirected
        if i > j:
            i,j = j,i
    
        for item in val:  
    
            if item not in time_group_pairs:
                time_group_pairs[item] = []                      
            time_group_pairs[item].append((i,j))  
            
            if item not in time_stamps:
                time_stamps.append(item)                                 
    # print(time_group_pairs)
    # print(time_stamps)
      
    # prepare timestamp grid
    time_stamps = sorted(time_stamps) 
    n = len(time_stamps)    
    grid = time_stamps 
            
    #cumulative edge count  dictionary for all timestamps
    cum_cnt = {}
    #cumulative edge count for instant
    total_cum_cnt = np.zeros(( num_roles, num_roles) , dtype=int) 
        
    #cumulative edge count
    for  ind, tmsp in enumerate(time_stamps):
        lst = time_group_pairs.get(tmsp)
    
        for key in lst:
            total_cum_cnt[key[0],key[1]] += 1
        
        for k in range(0, num_roles):
            for g in range(k, num_roles):
                cum_cnt[ind,k,g] = total_cum_cnt[k,g]                       
    # print(total_cum_cnt) 

    # group assignment related params      
    list_of_groups=  [[] for _ in range(num_roles)]       
    for idx, val in group_assignment.items():
        list_of_groups[val].append(idx)
 
    size_all_pairs = {}
    for k in range(0, num_roles):
        for g in range(k, num_roles):
            U=list_of_groups[k]
            W=list_of_groups[g]

            if k == g:
                size_all_pairs[k,g] = math.comb(len(U), 2)
            if k != g:
                size_all_pairs[k,g] = len(U)*len(W)
       
    cost = {}
    startsfrom = np.zeros((n,num_segments) , dtype=int) 
    
    for ind,tmp in enumerate(grid):
        tot_likelihood = 0
        
        for k in range(0, num_roles):
            for g in range(k, num_roles):
                
                intercount = cum_cnt[ind,k,g]
                delta_t = tmp+1
                
                if tmp == 0:
                    delta_t = tmp
                list_likelihood = [-(intercount*math.log(lam) - (size_all_pairs[k,g] *lam*(delta_t)))  for lam in lambda_estimates[k,g,:]] 
                tot_likelihood += min(list_likelihood)
                
        cost[0,ind,0] = tot_likelihood 
               
        
    for h in range(1, num_segments):
        for i in range((h+1), n):
            cost[0,i,h] = math.inf
            startsfrom[i,h] = -1
    
            for j in range(h,i):
                                            
                tot_likelihood = 0 
                
                for k in range(0, num_roles):
                    for g in range(k, num_roles):
                
                        inter_count =  cum_cnt[i,k,g] - cum_cnt[j,k,g]
        
                        # compute cost[j,i,0]
                        list_likelihood = [-(inter_count*math.log(lam) - (size_all_pairs[k,g] *lam*(grid[i]-grid[j])))  for lam in lambda_estimates[k,g,:]] 
                        tot_likelihood += min(list_likelihood)
                        
                cost[j,i,0]  =  tot_likelihood
                    
                if cost[0,j,h-1] + cost[j,i,0] <= cost[0,i,h]:
                    cost[0,i,h] = cost[0,j,h-1]+cost[j,i,0]                            
                    startsfrom[i,h] = j

    cost_tot = (cost[0,(n-1),(num_segments-1)])
    print(cost_tot)
        
    boundary_point_array=[grid[n-1]]    
    counter = num_segments-1
    boundary_point = n-1
    # b=[]    
                  
    while counter > -1:

        boundary_point = startsfrom[boundary_point, counter]
        boundary_point_array.append(grid[boundary_point])
        # b.append(boundary_point)
        counter -= 1
    # print(b)
    change_points_arr[:,:,:] = list(reversed(boundary_point_array))

    return change_points_arr
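
# Usage sketch (hypothetical inputs; every candidate rate in lambda_estimates
# must be strictly positive, since the inner cost takes its log):
#
#   change_points_arr = dyn_prog_seg(group_assignment, lambda_estimates,
#                                    change_points_arr, num_roles,
#                                    num_segments, dic)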

# Estimate change points : LINEAR algorithm ( Fast segmentation with SMAWK ) - Ver 2
def linear_seg_ver_2(num_roles,num_segments,group_assignment,lambda_estimates,change_points_arr,dic):
    # no of pairs of groups or roles
    size_all_pairs = int(num_roles*(num_roles + 1)/2)  
    # convert tensor to 2-d array
    l_array = np.zeros((size_all_pairs,num_segments) , dtype=float )
   
    cnt = 0    
    for z1 in range(0, num_roles):
        for z2 in range(z1, num_roles):           
            for d in range(0, num_segments):
                l_array[cnt,d]= lambda_estimates[z1,z2,d]
            # print(l_array[cnt,:])
            cnt += 1
    
    # find h-levels
    unique_cols = np.unique(l_array, axis=1) 
    # print(unique_cols)       
    h = len(unique_cols[0]) 
    # print(unique_cols[0])  
    # print('h levels initials inside seg: %d'%h)
    
    # h-level lambda estimates    
    lambda_estimates_h = np.zeros((num_roles, num_roles, h) , dtype=float )
    
    cnt = 0   
    
    for z1 in range(0, num_roles):
        for z2 in range(z1, num_roles):
            for d in range(0, h):
                lambda_estimates_h[z1,z2,d]= unique_cols[cnt,d]
                lambda_estimates_h[z2,z1,d]= unique_cols[cnt,d]
            cnt += 1

    time_edges = {}
    time_stamps=[]

    for key, val in dic.items():
    # print('{} {}'.format(key[0],key[1]))
        
        for item in val:  
            if item not in time_edges:
                time_edges[item] = []                      
            time_edges[item].append(key)
            time_stamps.append(item)
 
    # find unique time-stamps
    # (there are cases where multiple edges occur at the same timestamp )
    time_stamps = sorted(time_stamps)
    time_stamps_unique =  sorted(list(set(time_stamps)))    

    n =  len(time_stamps_unique)
      
    startsfrom = np.zeros((n,num_segments) , dtype=int) 
    lamda_value = np.zeros((n,num_segments) , dtype=int )
    
    c = np.zeros((n, h) , dtype=float)
    o_ek = np.zeros((n, num_segments) , dtype=float)
    
    list_of_groups=  [[] for _ in range(num_roles)]       
    for idx, val in group_assignment.items():
        list_of_groups[val].append(idx)        
    
    for d in range(0, h):
  
        alpha = 0
        cnt = 0

        for k in range(0, num_roles):
            for g in range(k, num_roles):
    
                U=list_of_groups[k]
                W=list_of_groups[g]
    
                size_all_pairs = 0
                if k == g:
                    size_all_pairs = math.comb(len(U), 2)
                if k != g:
                    size_all_pairs = len(U)*len(W)
                
                if  lambda_estimates_h[k,g,d] != 0:
                    alpha +=   (size_all_pairs * lambda_estimates_h[k,g,d])   
                elif not U or not W:
                    print('list is empty...')  
                             
        for  ind, tmsp in enumerate(time_stamps_unique):
            lst = time_edges.get(tmsp)

            for key in lst:
                
                i=group_assignment.get(key[0])
                j=group_assignment.get(key[1])
   
                cnt += math.log(lambda_estimates_h[i,j,d])
                
            c[ind,d] = cnt - alpha*tmsp
                                  
    for e in range(0, n):
        # print('{} : {} {}'.format(c[e,:],max(c[e,:]),np.argmax(c[e,:])))
        o_ek[e,0] = max(c[e,:])
        lamda_value[e,0] = np.argmax(c[e,:])
    
    
    LARGE_VAL = pow(10, 10)
    
    for kk in range(1, num_segments):
        max_m = np.zeros((n, h) , dtype=float)
        indices = np.zeros((n, num_segments) , dtype=int)
        
        for a in range(0, h):                       
            def lookup(i,j):
                x = -LARGE_VAL
                
                if not ((j < kk) | (j >= i)):                
                    x = o_ek[j,kk-1] + c[i,a] - c[j,a]
                              
                return x

            rows = list(range(0, n))
            cols =  list(range(0, n))
            
            col_set = utils.smawk(rows,cols,lookup)
            # col_set = utils.Maxcompute(rows,cols,lookup)
                      
            for edg in range(0, n): 
                indices[edg,a] = col_set[edg]
                max_m[edg,a] = lookup(edg,indices[edg,a])
    
        for e in range(0, n):    
            o_ek[e,kk] = max(max_m[e,:]) 
            lamda_value[e,kk] = np.argmax(max_m[e,:])
            startsfrom[e,kk] = indices[e,lamda_value[e,kk]]
        
    boundary_point_array=[n-1]           
    counter = num_segments-1
    boundary_point = n-1
                                      
    new_lambda_estimates = np.zeros((num_roles, num_roles,num_segments) , dtype=float)  
    # print(o_ek[boundary_point,counter])   
               
    while counter > -1:       
        # print(o_ek[boundary_point,counter]) 

        d= int(lamda_value[boundary_point,counter])
        # update new lambdas
        for i1 in range(0, num_roles):
            for i2 in range(i1, num_roles): 
                new_lambda_estimates[i1,i2,counter] = lambda_estimates_h[i1,i2,d] 
                new_lambda_estimates[i2,i1,counter] = lambda_estimates_h[i1,i2,d] 
        
        boundary_point = startsfrom[boundary_point, counter]        
        boundary_point_array.append(boundary_point)    
        counter -= 1
    
    # print(list(reversed(boundary_point_array)))
    b_list = []
    for b in boundary_point_array:
        b_list.append(time_stamps_unique[b])
        
    
    change_points_arr_ = list(reversed(b_list))
    change_points_arr[:,:,:] = change_points_arr_
    
    print(change_points_arr_)    
    lambda_estimates = new_lambda_estimates
    
    # Make sure h-levels do exist
    
    # convert tensor to 2-d array
    size_all_pairs = int(num_roles*(num_roles + 1)/2) 
    l_array_new = np.zeros((size_all_pairs,num_segments) , dtype=float )
   
    cnt = 0    
    for z1 in range(0, num_roles):
        for z2 in range(z1, num_roles):           
            for d in range(0, num_segments):
                l_array_new[cnt,d]= lambda_estimates[z1,z2,d]
            # print(l_array_new[cnt,:])
            cnt += 1
    
    # find h-levels
    unique_cols_new = np.unique(l_array_new, axis=1)        
    h_current = len(unique_cols_new[0]) 
    # print('new h levels: %d'%h_current)
    
    if h > h_current:
        print('Number of h-levels not satisfied (h: {}, h-new: {})'.format(h, h_current))
        # print('compute cost before modifying lambdas...')
        # compute_cost(group_assignment,lambda_estimates,change_points_arr,num_roles,num_segments,dic)
        
        h_levels = unique_cols[0]
        current_g_mapping = []
        
        lst = l_array_new[0,:]
                
        for i in lst:            
            for idx in range(0,h):
                
                if i == h_levels[idx]:
                    current_g_mapping.append(idx)
        # print([ele for ele in range(0,h)]   ) 
        # print(current_g_mapping)   
        
        list_all =  [ele for ele in range(0,h)] 

        a = set(current_g_mapping)
        b = set(list_all)
        
        diff = list(b - a)        
        n = h - len(a)
        missing_list = random.sample(diff, n)
        # print(missing_list)
       
        duplicate_entries = [item for item, count in collections.Counter(current_g_mapping).items() if count > 1]
        # print(duplicate_entries)
        
        new_g_mapping = current_g_mapping.copy()
        # we now have the duplicate and missing lists
        
        changes= []

        for idx,m_e in  enumerate(missing_list):
            # find index of duplicate entries
            if  idx < len(duplicate_entries) :
                ind = current_g_mapping.index(duplicate_entries[idx])
            else:
                new_duplicate_entries = [item for item, count in collections.Counter(new_g_mapping).items() if count > 1]
                # print(len(new_duplicate_entries))
                duplicate_entries.append(new_duplicate_entries[0])
                ind = new_g_mapping.index(duplicate_entries[idx])
            # update duplicate by a missing entry
            new_g_mapping[ind] = m_e
            # add (index,val) pair
            changes.append([ind,m_e])                
            
        # print(new_g_mapping )        
        # print(changes)
        
        # update lambda estimates
        # we have the current h-level indices (segment, index) and the missing h-values to be updated

        m_lambda_estimates =  copy.copy(lambda_estimates)
        
        for ele in changes:
            d = ele[0]
            val = ele[1]
            # print('{} ,{}'.format(d,val))
            # print(m_lambda_estimates)
            # print(lambda_estimates_h)
            for i1 in range(0, num_roles):
                for i2 in range(i1, num_roles): 
                    m_lambda_estimates[i1,i2,d] = lambda_estimates_h[i1,i2,val] 
                    m_lambda_estimates[i2,i1,d] = lambda_estimates_h[i1,i2,val]
                        
        # print(m_lambda_estimates)
        # print('after changing lambdas...')
        # com_cost(num_roles,num_segments,m_lambda_estimates,change_points_arr,group_assignment,dic)
        lambda_estimates = m_lambda_estimates       
    return [lambda_estimates,change_points_arr]
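
# Usage sketch (hypothetical inputs; relies on utils.smawk for the row-maxima
# search over the totally monotone cost matrix):
#
#   lambda_estimates, change_points_arr = linear_seg_ver_2(
#       num_roles, num_segments, group_assignment,
#       lambda_estimates, change_points_arr, dic)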

##################################

# Level-dependent group membership
def mm_group_assignment_ver2_2(nodes,num_roles,num_segments,lambda_estimates,group_assignment,change_points_arr,dic,g_mapping):

    current_g_mapping = g_mapping
    level_seg_mapping  = {}

    for d in range(num_segments):
        level = current_g_mapping[d]

        if level in level_seg_mapping:
            level_seg_mapping[level].append(d)
        else:
            level_seg_mapping[level] = []
            level_seg_mapping[level].append(d)


    h =  len(set(g_mapping))
    print('h levels %d'%h)


    # Multidigraph object
    G = nx.MultiDiGraph()
    G.add_nodes_from(nodes)
    G.add_edges_from(list(dic))
    G = G.to_undirected()

    for i_h in range(0,h):

        l_seg =  level_seg_mapping[i_h]
        # print('i: {} seg_level_dic : {}'.format(i_h,l_seg))

        grp = group_assignment[i_h]

        list_of_groups=  [[] for _ in range(num_roles)]

        for idx, val in grp.items():
            list_of_groups[val].append(idx)

        t = 0

        for d in l_seg:
            delta_t =  change_points_arr[0,0,d+1] -  change_points_arr[0,0,d]

            if d == 0:
              delta_t +=  1

            t = t + delta_t

        ss = l_seg  # segments belonging to this level

        # iterate over all nodes
        for v in nodes:
            # remove v
            for g in range(0, num_roles):
                #Extract Group g
                _list_g = list_of_groups[g]

                if v in _list_g :
                    _list_g.remove(v)


            # store current group of v
            max_i = grp[v]

            edge_counts =  np.zeros(num_roles)
            # find neighbours of node v
            neighbour_nodes  = [n for n in G.neighbors(v)]
            # print('neignbours {}'.format(neighbour_nodes))

            # iterate over all neigbour nodes of v
            for neigh in neighbour_nodes:

                v1 = v
                v2 = neigh

                # order of tuple : undirected
                if v1 > v2:
                    v1, v2 = v2, v1

                # i=group_assignment.get(v1)
                jk=grp[neigh]

                # change-points of (i,j) group pair; in this case, equally partitioned
                chg_points = change_points_arr[0,0,:]

                ranges_arr = [ [chg_points[s]+1,chg_points[s+1]] for s in range(0,len(chg_points)-1)]
                ranges_arr[0][0] = 0

                # list of timestamps corresponding to (v1,v2) pair
                list_time_stamps = dic.get((v1,v2))

                # iterate over timestamps list
                for item in list_time_stamps:
                    d =  _findSegment(ranges_arr,  len(ranges_arr) , int(item))
                    # print(' {} {} {} '.format(d, chg_points, item))
                    if d in ss:
                        # print('actual d {} ss {}'.format(d, ss))
                        edge_counts[jk] += 1
            # print('h: {} , cnt : {}'.format(i_h, edge_counts))
            likelihood_sum = np.zeros(num_roles)


            # if node v belongs to group l
            for l in range(0, num_roles):

                likelihood_sum[l] = 0

                for b in range(0, num_roles):
                    W = list_of_groups[b]

                    # maximum possible number of  ways of interacting with node v
                    factor = len(W)

                    likelihood_sum[l] += (edge_counts[b]*math.log(lambda_estimates[l,b,i_h])  - factor*lambda_estimates[l,b,i_h]*t)
            #SANITY CHECK
            # if max_i  !=    np.argmax(likelihood_sum):
            #     # print('{} : {} {} : {} : '.format(max_i, np.argmax(likelihood_sum),max(likelihood_sum),likelihood_sum))
            #     l1 = mm_compute_cost(group_assignment,lambda_estimates,change_points_arr,num_roles,num_segments,dic,g_mapping)
            #     group_assignment[i_h][v]  = np.argmax(likelihood_sum)
            #     l2 = mm_compute_cost(group_assignment,lambda_estimates,change_points_arr,num_roles,num_segments,dic,g_mapping)
            #     p1= round(abs(l2-l1),2)
            #     p2= round(abs(likelihood_sum[0]-likelihood_sum[1]),2)
            #     if p1 != p2:
            #         print('p: {} : {} {} : {} : {} {}'.format(p1,p2, likelihood_sum,edge_counts,lambda_estimates[0,:,i_h],lambda_estimates[1,:,i_h] ))
            #         if l2 < l1:
            #             print('l : {} : {} {} : {} : '.format(max_i, np.argmax(likelihood_sum),max(likelihood_sum),likelihood_sum))

            max_i = np.argmax(likelihood_sum)
            group_assignment[i_h][v] = max_i
            list_of_groups[max_i].append(v)

            # mm_compute_cost(group_assignment,lambda_estimates,change_points_arr,num_roles,num_segments,dic,g_mapping)
    return group_assignment

# Level-dependent estimate of (k-h) lambda
def mm_estimate_lamda_kh(num_roles,num_segments,lambda_estimates,group_assignment,change_points_arr,dic,g_mapping,tuning_params=None):
    # create dictionary to store interaction counts
    # ( no of interactions between ith and jth group within dth segment)
    key_data = [
        range(0,num_roles),
        range(0,num_roles),
        range(0,num_segments),
    ]
    keys = list(itertools.product(*key_data))
    i_j_d = {key: 0 for key in keys}

    current_g_mapping = g_mapping


    # print(lambda_estimates)
    level_seg_mapping  = {}
    for d in range(num_segments):
        level = current_g_mapping[d]
        if level in level_seg_mapping:
            level_seg_mapping[level].append(d)
        else:
            level_seg_mapping[level] = []
            level_seg_mapping[level].append(d)

    # change-points of (i,j) group pair
    chg_points = change_points_arr[0,0,:]

    ranges_arr = [ [chg_points[s]+1,chg_points[s+1]] for s in range(0,len(chg_points)-1)]
    ranges_arr[0][0]=0

    n = len(ranges_arr)


    for key, val in dic.items():

        for item in val:
            d =  _findSegment(ranges_arr, n, int(item))
            # l = current_g_mapping[d]
            g_a = group_assignment[current_g_mapping[d]]

            i=g_a.get(key[0]) #group of v1
            j=g_a.get(key[1]) #group of v2

            # undirected graph
            if i>j:
                i,j=j,i

            i_j_d[(i,j,d)] += 1

    # likelihood_sum = 0
    h = len(set(g_mapping))

    # iterate over levels
    for i in range(0,h):

        g_a = group_assignment[i]

        list_of_groups=  [[] for _ in range(num_roles)]
        for idx, val in g_a.items():
            list_of_groups[val].append(idx)

        for k in range(0, num_roles):
            for g in range(k,num_roles):


                U=list_of_groups[k]
                W=list_of_groups[g]

                size_all_pairs = 0
                if k == g:
                    size_all_pairs = math.comb(len(U), 2)
                if k != g:
                    size_all_pairs = len(U)*len(W)

                inter_count = 0
                delta_t = 0

                for d in level_seg_mapping[i]:
                    # print('i {} d {}'.format(i, d))

                    delta_t += (change_points_arr[k,g,d+1] - change_points_arr[k,g,d])

                    if d == 0:
                        delta_t += 1

                    inter_count += i_j_d[(k,g,d)]

                alpha = inter_count
                beta = size_all_pairs*delta_t
                lamda = (alpha + random.random()*tuning_params['theta'])/(beta + tuning_params['eta'])

                lambda_estimates[k,g,i] = lamda
                lambda_estimates[g,k,i] = lamda

    return lambda_estimates


# Level-dependent : Estimate change points : LINEAR algorithm ( Fast segmentation with SMAWK ) - Ver 2
def mm_linear_seg_ver_2(num_roles,num_segments,group_assignment,lambda_estimates,change_points_arr,dic,g_mapping):


    current_g_mapping = g_mapping
    h = len(set(current_g_mapping))

    time_edges = {}
    time_stamps=[]

    for key, val in dic.items():
    # print('{} {}'.format(key[0],key[1]))

        for item in val:
            if item not in time_edges:
                time_edges[item] = []
            time_edges[item].append(key)
            time_stamps.append(item)

    # find unique time-stamps
    # (there are cases where multiple edges occur at the same timestamp )
    time_stamps = sorted(time_stamps)
    time_stamps_unique =  sorted(list(set(time_stamps)))

    n =  len(time_stamps_unique)

    startsfrom = np.zeros((n,num_segments) , dtype=int)
    lamda_value = np.zeros((n,num_segments) , dtype=int )

    c = np.zeros((n, h) , dtype=float)
    o_ek = np.zeros((n, num_segments) , dtype=float)


    # iterate over levels
    for d in range(0, h):

        # level-specific group assignments
        g_a = group_assignment[d]
        list_of_groups=  [[] for _ in range(num_roles)]
        for idx, val in g_a.items():
            list_of_groups[val].append(idx)

        alpha = 0
        cnt = 0

        for k in range(0, num_roles):
            for g in range(k, num_roles):

                U=list_of_groups[k]
                W=list_of_groups[g]

                size_all_pairs = 0
                if k == g:
                    size_all_pairs = math.comb(len(U), 2)
                if k != g:
                    size_all_pairs = len(U)*len(W)

                if  lambda_estimates[k,g,d] != 0:
                    alpha +=   (size_all_pairs * lambda_estimates[k,g,d])
                elif not U or not W:
                    print('list is empty...')

        for  ind, tmsp in enumerate(time_stamps_unique):
            lst = time_edges.get(tmsp)

            for key in lst:

                i=g_a.get(key[0]) #group of v1
                j=g_a.get(key[1]) #group of v2

                cnt += math.log(lambda_estimates[i,j,d])

            c[ind,d] = cnt - alpha*tmsp

    for e in range(0, n):
        # print('{} : {} {}'.format(c[e,:],max(c[e,:]),np.argmax(c[e,:])))
        o_ek[e,0] = max(c[e,:])
        lamda_value[e,0] = np.argmax(c[e,:])


    LARGE_VAL = pow(10, 10)

    for kk in range(1, num_segments):
        max_m = np.zeros((n, h) , dtype=float)
        indices = np.zeros((n, num_segments) , dtype=int)

        for a in range(0, h):
            def lookup(i,j):
                x = -LARGE_VAL

                if not ((j < kk) | (j >= i)):
                    x = o_ek[j,kk-1] + c[i,a] - c[j,a]

                return x

            rows = list(range(0, n))
            cols =  list(range(0, n))

            col_set = utils.smawk(rows,cols,lookup)
            # col_set = utils.Maxcompute(rows,cols,lookup)

            for edg in range(0, n):
                indices[edg,a] = col_set[edg]
                max_m[edg,a] = lookup(edg,indices[edg,a])

        for e in range(0, n):
            o_ek[e,kk] = max(max_m[e,:])
            lamda_value[e,kk] = np.argmax(max_m[e,:])
            startsfrom[e,kk] = indices[e,lamda_value[e,kk]]

    boundary_point_array=[n-1]
    counter = num_segments-1
    boundary_point = n-1

    # new_lambda_estimates = np.zeros((num_roles, num_roles,num_segments) , dtype=float)
    # print(o_ek[boundary_point,counter])
    new_g = np.zeros(  num_segments, dtype=int)
    while counter > -1:
        # print(o_ek[boundary_point,counter])

        d= int(lamda_value[boundary_point,counter])

        new_g[counter] = int(d)


        boundary_point = startsfrom[boundary_point, counter]
        boundary_point_array.append(boundary_point)
        counter -= 1

    b_list = []

    for b in boundary_point_array:
        b_list.append(time_stamps_unique[b])


    change_points_arr_ = list(reversed(b_list))
    change_points_arr[:,:,:] = change_points_arr_
    # print(change_points_arr_)

    # mm_compute_cost(group_assignment,lambda_estimates,change_points_arr,num_roles,num_segments,dic,g_mapping)
    h_current = len(set(list(new_g)))

    current_g_mapping =  list(new_g)
    # mm_compute_cost(group_assignment,lambda_estimates,change_points_arr,num_roles,num_segments,dic,current_g_mapping)
    print('h : {} , h_cur : {}'.format(h, h_current))
    if h > h_current:
        print('Number of h-levels not satisfied (h: {}, h-new: {})'.format(h, h_current))


        list_all =  [ele for ele in range(0,h)]

        a = set(current_g_mapping)
        b = set(list_all)

        diff = list(b - a)
        n = h - len(a)
        missing_list = random.sample(diff, n)
        # print(missing_list)

        duplicate_entries = [item for item, count in collections.Counter(current_g_mapping).items() if count > 1]
        # print(duplicate_entries)

        new_g_mapping = copy.copy(current_g_mapping)
        # we have duplicate and missing list

        changes= []

        for idx,m_e in  enumerate(missing_list):
            # find index of duplicate entries
            if  idx < len(duplicate_entries) :
                ind = current_g_mapping.index(duplicate_entries[idx])
            else:
                new_duplicate_entries = [item for item, count in collections.Counter(new_g_mapping).items() if count > 1]
                # print(len(new_duplicate_entries))
                duplicate_entries.append(new_duplicate_entries[0])
                ind = new_g_mapping.index(duplicate_entries[idx])
            # update duplicate by a missing entry
            new_g_mapping[ind] = m_e
            # add (index,val) pair
            changes.append([ind,m_e])


        current_g_mapping=new_g_mapping

        print(current_g_mapping)
    return [change_points_arr,current_g_mapping,group_assignment]

# Level-dependent : Compute cost
def mm_compute_cost(group_assignment,lambda_estimates,change_points_arr,num_roles,num_segments,dic,g_mapping):


    # create dictionary to store interaction counts
    # ( no of interactions between ith and jth group within dth segment)
    key_data = [
        range(0,num_roles),
        range(0,num_roles),
        range(0,num_segments),
    ]
    keys = list(itertools.product(*key_data))
    i_j_d = {key: 0 for key in keys}

    current_g_mapping = g_mapping

    level_seg_mapping  = {}
    for d in range(0,num_segments):

        level = current_g_mapping[d]
        # print(level_seg_mapping)
        if level in level_seg_mapping:
            # print(d)
            level_seg_mapping[level].append(d)
        else:
            # print(d)
            level_seg_mapping[level] = []
            level_seg_mapping[level].append(d)

    # change-points of (i,j) group pair
    chg_points = change_points_arr[0,0,:]

    ranges_arr = [ [chg_points[s]+1,chg_points[s+1]] for s in range(0,len(chg_points)-1)]
    ranges_arr[0][0]=0

    n = len(ranges_arr)


    for key, val in dic.items():

        for item in val:
            d =  _findSegment(ranges_arr, n, int(item))

            # l = current_g_mapping[d]
            g_a = group_assignment[g_mapping[d]]

            i=g_a.get(key[0]) #group of v1
            j=g_a.get(key[1]) #group of v2

            # undirected graph
            if i>j:
                i,j=j,i

            i_j_d[(i,j,d)] += 1

    likelihood_sum = 0
    h = len(set(g_mapping))

    # iterate over levels
    for i in range(0,h):
        g_a = group_assignment[i]

        list_of_groups=  [[] for _ in range(num_roles)]
        for idx, val in g_a.items():
            list_of_groups[val].append(idx)

        for k in range(0, num_roles):
            for g in range(k,num_roles):


                U=list_of_groups[k]
                W=list_of_groups[g]

                size_all_pairs = 0
                if k == g:
                    size_all_pairs = math.comb(len(U), 2)
                if k != g:
                    size_all_pairs = len(U)*len(W)


                inter_count = 0
                delta_t = 0

                for d in level_seg_mapping[i]:
                    # print('i {} d {}'.format(i, d))

                    delta_t += (change_points_arr[k,g,d+1] - change_points_arr[k,g,d])
                    # delta_t += (change_points_arr[0,0,d+1] - change_points_arr[0,0,d])
                    if d == 0:
                        delta_t += 1

                    inter_count += i_j_d[(k,g,d)]

                # jj = level_seg_mapping[i]
                lamda =  lambda_estimates[k,g,i]
                likelihood_sum += (inter_count*math.log(lamda) - size_all_pairs*lamda*delta_t)
    # print(likelihood_sum)            
    return likelihood_sum