Skip to content
Snippets Groups Projects
utils.py 5.6 KiB
Newer Older
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from datetime import datetime


# SMAWK utilities
def reduce(r,c,lookup):
    """Prune candidate columns for a row-maxima search (SMAWK REDUCE step).

    Rows ``r`` and columns ``c`` index an implicit matrix whose entries are
    produced on demand by ``lookup(row, col)``.  Assuming that matrix is
    totally monotone (as SMAWK requires), columns that cannot hold the
    leftmost maximum of any row are discarded.  When ``len(c) > len(r)``
    exactly ``len(r)`` columns survive, in their original order; otherwise
    ``c`` is returned unchanged.  Ties keep the leftmost column, matching
    the tie-breaking used by ``Maxcompute``.

    Args:
        r: sequence of row labels, in matrix order.
        c: sequence of column labels, in matrix order (any labels, not
            only integers).
        lookup: callable ``lookup(row, col)`` returning a comparable value.

    Returns:
        A list of the surviving columns, or ``c`` itself when there are no
        more columns than rows.
    """
    n = len(r)
    m = len(c)

    # Nothing to prune when there are no more columns than rows.
    if m <= n:
        return c

    # stack[k] is the surviving candidate paired with row r[k];
    # columns c[b:] have not been examined yet.
    stack = [c[0]]
    b = 1
    # Each step either keeps the number of live columns (stack + unexamined)
    # unchanged or removes exactly one, so the loop stops at exactly n.
    # While more than n are live, len(stack) <= n guarantees b < m.
    while len(stack) + (m - b) > n:
        k = len(stack) - 1
        if lookup(r[k], stack[-1]) >= lookup(r[k], c[b]):
            # c[b] does not beat the top on row k: it becomes the candidate
            # for the next row, or is dropped if k is already the last row.
            if k < n - 1:
                stack.append(c[b])
            b += 1
        else:
            # c[b] beats the top on row k, so by total monotonicity the top
            # cannot be the leftmost maximum of row k or any later row.
            stack.pop()
            if not stack:
                stack.append(c[b])
                b += 1

    # Every column not yet examined survives.
    stack.extend(c[b:])
    return stack

def Maxcompute(rows,cols,lookup):
    """Map every row to the column holding its maximum (SMAWK-style search).

    The candidate columns are first pruned with ``reduce``; the odd-indexed
    rows are then solved recursively, and each even-indexed row is scanned
    only between the answer columns of its neighbouring odd rows.  Ties go
    to the leftmost candidate column.  Assumes, as SMAWK requires, that
    ``lookup`` describes a totally monotone matrix.

    Args:
        rows: sequence of row labels, in matrix order.
        cols: sequence of column labels, in matrix order.
        lookup: callable ``lookup(row, col)`` returning a comparable value.

    Returns:
        dict mapping each element of ``rows`` to its maximising column.
    """
    if not rows:
        return {}

    candidates = reduce(rows, cols, lookup)

    odd_rows = [rows[k] for k in range(1, len(rows), 2)]
    best = Maxcompute(odd_rows, candidates, lookup)

    # Fill in the even rows; each scan starts where the previous one ended.
    lo = 0
    final = len(rows) - 1
    for pos in range(0, len(rows), 2):
        current = rows[pos]
        if pos != final:
            # Stop at the column already chosen for the next (odd) row.
            hi = lo
            stop_col = best[rows[pos + 1]]
            while candidates[hi] != stop_col:
                hi += 1
        else:
            # Last row: nothing bounds it from below, scan to the end.
            hi = len(candidates) - 1
        scored = ((lookup(current, candidates[x]), -x, candidates[x])
                  for x in range(lo, hi + 1))
        best[current] = max(scored)[2]
        lo = hi

    return best

# Ref: David Eppstein's SMAWK Python code
def smawk(rows,cols,lookup):
    """Find the leftmost row maxima of an implicit totally monotone matrix.

    SMAWK algorithm following David Eppstein's reference code.  Entries are
    obtained on demand through ``lookup(row, col)``.

    Args:
        rows: sequence of row labels, in matrix order.
        cols: iterable of column labels, in matrix order.
        lookup: callable ``lookup(row, col)`` returning a comparable value.

    Returns:
        dict mapping each element of ``rows`` to the column holding its
        maximum; ties go to the leftmost column.
    """
    if not rows:
        return {}

    n_rows = len(rows)

    # Reduce phase: a stack of surviving columns.  The column at depth i is
    # paired with rows[i]; a newcomer that beats the top on that row evicts
    # it.  At most n_rows columns are ever kept.
    kept = []
    for col in cols:
        while kept:
            row = rows[len(kept) - 1]
            if not (lookup(row, kept[-1]) < lookup(row, col)):
                break
            kept.pop()
        if len(kept) != n_rows:
            kept.append(col)

    # Recurse on every odd-indexed row with the reduced column set.
    answer = smawk([rows[i] for i in range(1, n_rows, 2)], kept, lookup)

    # Interpolate the even-indexed rows: each one's maximum lies between the
    # answers of its neighbouring odd rows, so only that window is scanned.
    lo = 0
    for i in range(0, n_rows, 2):
        row = rows[i]
        if i + 1 < n_rows:
            hi = lo
            stop = answer[rows[i + 1]]
            while kept[hi] != stop:
                hi += 1
        else:
            hi = len(kept) - 1
        answer[row] = max(
            (lookup(row, kept[j]), -j, kept[j]) for j in range(lo, hi + 1)
        )[2]
        lo = hi

    return answer


# Graph processing
def getGraph(nodes_arr, temporal_graph):
    """Build an undirected ``nx.Graph`` from a node list and an edge list.

    Every element of ``nodes_arr`` becomes a node, so nodes without edges
    are kept.  Each record in ``temporal_graph`` contributes one edge between
    its items at index 0 and index 2; index 1 is not used here (presumably a
    timestamp — TODO confirm against the caller).
    """
    graph = nx.Graph()
    for vertex in nodes_arr:
        graph.add_node(vertex)
    graph.add_edges_from((event[0], event[2]) for event in temporal_graph)
    return graph

# Generate Plots
def generate_plots(G,group_assignment,lambda_estimates,num_roles,num_segments,dest_folder,nodes,change_points_arr,t_df,refValue, *, close_figures=True):
    """Save graph drawings and per-role-pair rate plots into ``dest_folder``.

    Writes ``spring.png``, ``random.png`` and ``circular.png`` (three
    layouts of ``G`` with nodes coloured by group), prints the group
    membership and the lambda estimates, and writes one step plot
    ``lamda<k+1><g+1>.pdf`` for every role pair ``(k, g)`` with ``k <= g``.

    Args:
        G: networkx graph; every node must be a key of ``group_assignment``.
        group_assignment: dict mapping node -> group index, used as a list
            index into ``num_roles`` groups.  Groups 0-4 get distinct
            colours, any other group is drawn white.
        lambda_estimates: array indexed ``[k, g, d]`` holding the value of
            role pair ``(k, g)`` on segment ``d`` (presumably shape
            ``(num_roles, num_roles, num_segments)`` — TODO confirm).
        num_roles: number of roles/groups.
        num_segments: number of segments drawn per role pair.
        dest_folder: directory the files are written into (joined with
            ``os.path.join``, so a trailing separator is optional).
        nodes: node collection; only its length is used (printed).
        change_points_arr: array indexed ``[k, g, d]``; segment ``d`` spans
            entries ``d`` and ``d + 1``, so the last axis needs
            ``num_segments + 1`` entries.
        t_df: unused.
        refValue: offset added to every change point before it is converted
            from a POSIX timestamp to a naive UTC datetime.
        close_figures: close each figure after saving it (default) so
            repeated calls do not accumulate open figures; pass ``False``
            to keep them open, e.g. for inline display.
    """
    import os
    from datetime import timezone

    def _utc(seconds):
        # Naive UTC datetime; same result as the deprecated (3.12+)
        # datetime.utcfromtimestamp(seconds).
        return datetime.fromtimestamp(seconds, tz=timezone.utc).replace(tzinfo=None)

    def _save(fig, file_name):
        # Save the current figure and optionally release it.
        plt.savefig(os.path.join(dest_folder, file_name))
        if close_figures:
            plt.close(fig)

    print('Number of nodes: {}'.format(len(nodes)))

    # Groups 0-4 get distinct colours; every other group is drawn white.
    palette = {0: 'blue', 1: 'red', 2: 'green', 3: 'yellow', 4: 'black'}
    color_map = [palette.get(group_assignment[node], 'white') for node in G]

    layouts = (
        (nx.draw_spring, 'spring.png'),
        (nx.draw_random, 'random.png'),
        (nx.draw_circular, 'circular.png'),
    )
    for draw, file_name in layouts:
        fig = plt.figure(figsize=(10,10))
        draw(G, node_color=color_map, with_labels=True)
        _save(fig, file_name)

    list_of_groups = [[] for _ in range(num_roles)]
    for idx, val in group_assignment.items():
        list_of_groups[val].append(idx)
    print('group assignments: {}'.format(list_of_groups))

    print('lambdas....')
    for k in range(num_roles):
        for g in range(k, num_roles):
            print(lambda_estimates[k,g,:])
            y_limit = np.max(lambda_estimates[k,g,:])

            fig, ax = plt.subplots()
            for d in range(num_segments):
                start = change_points_arr[k,g,d]
                end = change_points_arr[k,g,d+1]
                plt.hlines(y=lambda_estimates[k,g,d], xmin=_utc(start+refValue), xmax=_utc(end+refValue))

            plt.xlabel('t')
            plt.title(r'$\lambda_{ %d%d}(t)$'%(k+1,g+1))
            plt.ylim(0, y_limit)
            plt.setp(ax.get_xticklabels(), rotation=30, horizontalalignment='right')
            # NOTE(review): autoscale() recomputes the axis limits, which
            # appears to override the ylim set above — confirm intent.
            plt.autoscale()
            _save(fig, 'lamda'+str(k+1)+str(g+1)+'.pdf')