# -*- coding: utf-8 -*-
"""
@author: Riku Laine

Script for creating the summary figures used in the paper.

Change the 'path' variable below to the folder on your computer containing
the experiment results.
"""

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.transforms import Affine2D
from matplotlib.lines import Line2D

# Global matplotlib defaults for the figures drawn by this script. Later
# sections override some of these (font size, figure size) for their figures.
# Figures are saved as PNG at 1200 dpi unless overridden.
plt.rcParams.update({
    'font.size': 18,
    'figure.figsize': (10, 4),
    'savefig.format': 'png',
    'savefig.dpi': '1200',
})

# Folder holding the experiment result .npy files. File names are appended
# directly to this string, so it must end with a slash.
path = "C:/Users/Riku_L/bachelors-thesis/data/result_files/"

### Draw summary figures.

# Font size and figure size are identical for every summary figure, so they
# are set once here rather than on each loop iteration. They remain in effect
# for the figures drawn later in the script.
plt.rcParams.update({'font.size': 20})
plt.rcParams.update({'figure.figsize': (12, 6)})

for z_coef in ["1", "5"]:

    fig, ax = plt.subplots()

    # Shift CFBI markers 0.1 data units left and Contraction markers 0.1 right
    # so the two error bars at each x position sit side by side.
    trans1 = Affine2D().translate(-0.1, 0.0) + ax.transData
    trans2 = Affine2D().translate(+0.1, 0.0) + ax.transData

    legend_elements = [Line2D([0], [0], color='r', label='CFBI'),
                       Line2D([0], [0], color='b', label='Contraction', ls='--')]

    # One tick label per (deciderM, deciderH) combination. A label's index in
    # this list is also the x position of that combination's error bars.
    ticks = []

    for deciderM in ["random", "batch", "independent"]:
        for deciderH in ["random", "batch", "independent"]:

            # All result files of one configuration share this prefix.
            prefix = (path + "_deciderH_" + deciderH + "_deciderM_" + deciderM
                      + "_maxR_0_9coefZ" + z_coef + "_0_")

            true = np.load(prefix + "true_FRs.npy")
            contraction = np.load(prefix + "contraction_FRs.npy")
            counterfactuals = np.load(prefix + "counterfactuals_FRs.npy")

            x = len(ticks)

            # Counterfactuals (CFBI): mean absolute error w.r.t. the true
            # evaluation and its sample std (ddof=1), pooled over all entries
            # and ignoring NaNs.
            y1 = np.abs(true - counterfactuals)
            ax.errorbar(x, np.nanmean(y1), yerr=np.nanstd(y1, ddof=1),
                        fmt="o", transform=trans1, c='r')

            # Contraction: same statistics.
            y2 = np.abs(true - contraction)
            er2 = ax.errorbar(x, np.nanmean(y2), yerr=np.nanstd(y2, ddof=1),
                              fmt="o", transform=trans2, c='b')

            # Dash the vertical error bar so it matches the legend entry.
            er2[-1][0].set_linestyle('--')

            ticks.append("H: " + deciderH + "\nM: " + deciderM)

    # Customize xticks
    plt.xticks(np.arange(0, len(ticks)), ticks, rotation=45)

    plt.ylabel("MAE w.r.t. True evaluation")
    plt.xlabel("Decision makers")
    plt.grid(axis='y')
    plt.ylim((-0.005, 0.08))
    plt.axhline(0, c="k", linestyle=":", lw=1)
    ax.legend(handles=legend_elements, loc="upper right", title="$\\bf{Evaluators}$")

    # Written relative to the current working directory (not `path`). File
    # format and resolution come from the savefig.* rcParams set earlier.
    plt.savefig("summary_z" + z_coef + "_dpi", bbox_inches='tight')

    plt.show()
    plt.close("all")

# Boxplot version

# for z_coef in ["1", "5"]:

#     plt.rcParams.update({'font.size': 16})
#     plt.rcParams.update({'figure.figsize': (12, 6)})

#     fig, ax = plt.subplots()

#     trans1 = Affine2D().translate(-0.1, 0.0) + ax.transData
#     trans2 = Affine2D().translate(+0.1, 0.0) + ax.transData

#     legend_elements = [Line2D([0], [0], color='r', label='CFBI'),
#                        Line2D([0], [0], color='b', label='Contraction')]

#     ticks = []
#     i = 0

#     for deciderM in ["random", "batch", "independent"]:
#         for deciderH in ["random", "batch", "independent"]:

#             true = np.load(path + "_deciderH_" + deciderH + "_deciderM_" + deciderM + "_maxR_0_9coefZ" + z_coef + "_0_true_FRs.npy")
#             contraction = np.load(path + "_deciderH_" + deciderH + "_deciderM_" + deciderM + "_maxR_0_9coefZ" + z_coef + "_0_contraction_FRs.npy")
#             counterfactuals = np.load(path + "_deciderH_" + deciderH + "_deciderM_" + deciderM + "_maxR_0_9coefZ" + z_coef + "_0_counterfactuals_FRs.npy")

#             # Counterfactuals
#             y1 = np.abs(counterfactuals - true)

#             # Contraction
#             y2 = np.abs(contraction - true)

#             # Plot boxplots

#             er1 = ax.boxplot(y1[~np.isnan(y1)].flatten(), positions=[i-.1])

#             # Colors
#             for item in ['boxes', 'whiskers', 'fliers', 'medians', 'caps']:
#                 plt.setp(er1[item], color="r")
#             plt.setp(er1["fliers"], markeredgecolor="r")

#             er2 = ax.boxplot(y2[~np.isnan(y2)].flatten(), positions=[i+.1])

#             # Colors
#             for item in ['boxes', 'whiskers', 'fliers', 'medians', 'caps']:
#                 plt.setp(er2[item], color="b")
#             plt.setp(er2["fliers"], markeredgecolor="b")

#             ticks = np.append(ticks, "H: " + deciderH + "\nM: " + deciderM)

#             i = i + 1

#     # Customize xticks
#     plt.xticks(np.arange(0, i), ticks, rotation=45)

#     plt.ylabel("MAE w.r.t. True evaluation")
#     plt.xlabel("Decision makers")
#     plt.grid(axis='y')
#     plt.ylim((-0.01, 0.15))
#     plt.axhline(0, c="k", linestyle=":", lw=1)
#     ax.legend(handles=legend_elements, loc="upper left", title="$\\bf{Evaluators}$")

#     # Save manually.

#     plt.show()
#     plt.close("all")


# ### Draw the single result figures for all different configurations.

# plt.close("all")

# for z_coef in ["1", "5"]:
#     for r in ["5", "9"]:
#         for deciderM in ["random", "batch"]:
#             for deciderH in ["random", "batch", "independent", "probabilistic"]:

#                 if z_coef == "5" and r == "5":
#                     continue

#                 true = np.load(path + "_deciderH_" + deciderH + "_deciderM_" + deciderM + "_maxR_0_" + r + "coefZ" + z_coef + "_0_true_FRs.npy")
#                 labeled = np.load(path + "_deciderH_" + deciderH + "_deciderM_" + deciderM + "_maxR_0_" + r + "coefZ" + z_coef + "_0_labeled_FRs.npy")
#                 contraction = np.load(path + "_deciderH_" + deciderH + "_deciderM_" + deciderM + "_maxR_0_" + r + "coefZ" + z_coef + "_0_contraction_FRs.npy")
#                 counterfactuals = np.load(path + "_deciderH_" + deciderH + "_deciderM_" + deciderM + "_maxR_0_" + r + "coefZ" + z_coef + "_0_counterfactuals_FRs.npy")

#                 failure_rates = np.zeros((8, 4))
#                 failure_stds = np.zeros((8, 4))

#                 failure_rates[:, 0] = np.nanmean(true, axis=0)
#                 failure_rates[:, 1] = np.nanmean(labeled, axis=0)
#                 failure_rates[:, 2] = np.nanmean(contraction, axis=0)
#                 failure_rates[:, 3] = np.nanmean(counterfactuals, axis=0)

#                 # Compute sample std
#                 failure_stds[:, 0] = np.nanstd(true, axis=0, ddof=1)
#                 failure_stds[:, 1] = np.nanstd(labeled, axis=0, ddof=1)
#                 failure_stds[:, 2] = np.nanstd(contraction, axis=0, ddof=1)
#                 failure_stds[:, 3] = np.nanstd(counterfactuals, axis=0, ddof=1)

#                 x_ax = np.arange(0.1, 0.9, 0.1)

#                 labels = ['True evaluation', 'Labeled outcomes', 'Contraction',
#                           'CFBI']

#                 colours = ['g', 'magenta', 'b', 'r']

#                 line_styles = ['--', ':', '-.', '-']

#                 # General plot: Failure rate vs. Acceptance rate
#                 for i in range(failure_rates.shape[1]):
#                     plt.errorbar(x_ax,
#                                  failure_rates[:, i],
#                                  label=labels[i],
#                                  c=colours[i],
#                                  linestyle=line_styles[i],
#                                  yerr=failure_stds[:, i])

#                 plt.xlabel('Acceptance rate')
#                 plt.ylabel('Failure rate')
#                 plt.legend(title="$\\bf{Evaluators}$")
#                 plt.grid()

#                 # plt.savefig(path + "_deciderH_" + deciderH + "_deciderM_" + deciderM + "_maxR_0_" + r + "coefZ" + z_coef + "_0_all")

#                 plt.show()


######

def baseFigure(true, labeled, contraction, counterfactuals, J):
    """Plot failure rate against acceptance rate for four evaluators.

    Each of ``true``, ``labeled``, ``contraction`` and ``counterfactuals`` is
    reduced along ``axis=0`` with a NaN-ignoring mean and a NaN-ignoring
    sample standard deviation (``ddof=1``). The results are stored in an
    8-row array and plotted as error-bar curves against the acceptance rates
    ``np.arange(0.1, 0.9, 0.1)``. The figure is then shown and all figures
    are closed.

    Parameters
    ----------
    true, labeled, contraction, counterfactuals : array_like
        Failure-rate results of each evaluator. Presumably one row per
        repetition and one column per acceptance rate -- TODO confirm.
    J : str
        Number of judges. Only referenced by the commented-out ``savefig``
        call below; it does not affect the plot.
    """
    per_evaluator = (true, labeled, contraction, counterfactuals)

    # Column k holds the statistics of per_evaluator[k].
    means = np.zeros((8, 4))
    spreads = np.zeros((8, 4))
    for col, results in enumerate(per_evaluator):
        means[:, col] = np.nanmean(results, axis=0)
        spreads[:, col] = np.nanstd(results, axis=0, ddof=1)

    acceptance_rates = np.arange(0.1, 0.9, 0.1)

    # (legend label, colour, line style), in the same order as per_evaluator.
    appearance = [('True evaluation', 'g', '--'),
                  ('Labeled outcomes', 'magenta', ':'),
                  ('Contraction', 'b', '-.'),
                  ('CFBI', 'r', '-')]

    for col, (label, colour, line_style) in enumerate(appearance):
        plt.errorbar(acceptance_rates,
                     means[:, col],
                     yerr=spreads[:, col],
                     linestyle=line_style,
                     c=colour,
                     label=label)

    plt.xlabel('Acceptance rate')
    plt.ylabel('Failure rate')
    plt.legend(title="$\\bf{Evaluators}$")
    plt.grid()

    # plt.savefig(path + "sl_compas_nJudges" + J + "_all")

    plt.show()

    plt.close('all')

### Draw COMPAS result figures (FR VS AR).

# # J = 12
# true = np.load(path + "sl_sl_compas_nJudges12_true_FRs.npy")
# labeled = np.load(path + "sl_sl_compas_nJudges12_labeled_FRs.npy")
# contraction = np.load(path + "sl_sl_compas_nJudges12_contraction_FRs.npy")
# counterfactuals = np.load(path + "sl_sl_compas_nJudges12_counterfactuals_FRs.npy")

# baseFigure(true, labeled, contraction, counterfactuals, "12")

# # J = 24
# true = np.load(path + "sl_sl_compas_nJudges24_true_FRs.npy")
# labeled = np.load(path + "sl_sl_compas_nJudges24_labeled_FRs.npy")
# contraction = np.load(path + "sl_sl_compas_nJudges24_contraction_FRs.npy")
# counterfactuals = np.load(path + "sl_sl_compas_nJudges24_counterfactuals_FRs.npy")

# baseFigure(true, labeled, contraction, counterfactuals, "24")

# # J = 48
# true = np.load(path + "sl_sl_compas_nJudges48_true_FRs.npy")
# labeled = np.load(path + "sl_sl_compas_nJudges48_labeled_FRs.npy")
# contraction = np.load(path + "sl_sl_compas_nJudges48_contraction_FRs.npy")
# counterfactuals = np.load(path + "sl_sl_compas_nJudges48_counterfactuals_FRs.npy")

# baseFigure(true, labeled, contraction, counterfactuals, "48")

### Draw COMPAS result figure (FR error vs number of judges).

plt.close("all")

# Uses whatever font size / figure size rcParams are in effect at this point
# in the script.
fig, ax = plt.subplots()

# Shift CFBI markers 0.1 data units left and Contraction markers 0.1 right
# so the two error bars at each x position sit side by side.
trans1 = Affine2D().translate(-0.1, 0.0) + ax.transData
trans2 = Affine2D().translate(+0.1, 0.0) + ax.transData

legend_elements = [Line2D([0], [0], color='r', label='CFBI'),
                   Line2D([0], [0], color='b', label='Contraction', ls='--')]

# Tick labels (number of judges). A label's index is its x position.
ticks = []

for i, n_judges in enumerate(["12", "24", "48"]):

    prefix = path + "sl_sl_compas_nJudges" + n_judges + "_"
    true = np.load(prefix + "true_FRs.npy")
    contraction = np.load(prefix + "contraction_FRs.npy")
    counterfactuals = np.load(prefix + "counterfactuals_FRs.npy")

    # Counterfactuals (CFBI): mean absolute error w.r.t. the true evaluation
    # and its sample std (ddof=1), pooled over all entries, ignoring NaNs.
    y1 = np.abs(true - counterfactuals)
    ax.errorbar(i, np.nanmean(y1), yerr=np.nanstd(y1, ddof=1),
                fmt="o", transform=trans1, c='r')

    # Contraction: same statistics.
    y2 = np.abs(true - contraction)
    er2 = ax.errorbar(i, np.nanmean(y2), yerr=np.nanstd(y2, ddof=1),
                      fmt="o", transform=trans2, c='b')

    # Dash the vertical error bar so it matches the legend entry.
    er2[-1][0].set_linestyle('--')

    ticks.append(n_judges)

# Customize xticks
plt.xticks(np.arange(0, len(ticks)), ticks)

plt.ylabel("Error w.r.t. True evaluation")
plt.xlabel("Number of judges")
plt.grid(axis='y')
plt.axhline(0, c="k", linestyle=":", lw=1)
ax.legend(handles=legend_elements, loc='upper left', title="$\\bf{Evaluators}$")

# Written relative to the current working directory (not `path`).
plt.savefig("sl_errors_compas_squeezed_dpi", bbox_inches='tight')

plt.show()

### Redraw fig 6 without the erroneous tail.

def _draw_fr_vs_ar(r, out_name, nan_points=(), deciderH="independent",
                   deciderM="batch", z_coef="1"):
    """Draw failure rate vs. acceptance rate for five evaluators and save it.

    Loads the true, labeled, imputed, contraction and counterfactual
    failure-rate arrays of one configuration. For each, plots the
    NaN-ignoring mean along ``axis=0`` with NaN-ignoring sample-std
    (``ddof=1``) error bars against the acceptance rates
    ``np.arange(0.1, 0.9, 0.1)``. The figure is then saved, shown and all
    figures are closed.

    Parameters
    ----------
    r : str
        Digit following "maxR_0_" in the result file names (e.g. "5", "9").
    out_name : str
        File name passed to ``plt.savefig``, relative to the current working
        directory. Format and dpi come from the savefig.* rcParams.
    nan_points : iterable of (row, col) pairs, optional
        Entries of the mean matrix to replace with NaN before plotting, so
        those points are not drawn. Rows follow the acceptance rates on the
        x axis; columns follow the evaluator order in ``evaluators`` below.
    deciderH, deciderM, z_coef : str, optional
        Remaining components of the result file names.
    """
    plt.rcParams.update({'font.size': 18})
    plt.figure(figsize=(10, 5.5))

    prefix = (path + "_deciderH_" + deciderH + "_deciderM_" + deciderM
              + "_maxR_0_" + r + "coefZ" + z_coef + "_0_")

    # (file name stem, legend label, colour, line style), in plotting order.
    evaluators = [
        ("true", 'True evaluation', 'g', ':'),
        ("labeled", 'Labeled outcomes', 'magenta', '-.'),
        ("imputed", 'Logistic regression', 'darkmagenta', '-.'),
        ("contraction", 'Contraction', 'b', '--'),
        ("counterfactuals", 'CFBI', 'r', '-'),
    ]

    # Load every file before drawing so a missing file fails before any
    # plotting happens.
    results = [np.load(prefix + stem + "_FRs.npy") for stem, *_ in evaluators]

    # 8 rows: one per acceptance rate on x_ax below.
    failure_rates = np.zeros((8, len(evaluators)))
    failure_stds = np.zeros((8, len(evaluators)))
    for col, data in enumerate(results):
        failure_rates[:, col] = np.nanmean(data, axis=0)
        failure_stds[:, col] = np.nanstd(data, axis=0, ddof=1)

    for row, col in nan_points:
        failure_rates[row, col] = np.nan

    x_ax = np.arange(0.1, 0.9, 0.1)

    # General plot: Failure rate vs. Acceptance rate
    for col, (_, label, colour, style) in enumerate(evaluators):
        plt.errorbar(x_ax,
                     failure_rates[:, col],
                     label=label,
                     c=colour,
                     yerr=failure_stds[:, col],
                     linestyle=style)

    plt.xlabel('Acceptance rate')
    plt.ylabel('Failure rate')
    plt.legend(title="$\\bf{Evaluators}$")
    plt.grid()

    plt.savefig(out_name, bbox_inches='tight')

    plt.show()
    plt.close('all')


# maxR 0.5: the Contraction mean at row 5 (acceptance rate 0.6) is blanked
# out; the section header calls this the erroneous tail.
_draw_fr_vs_ar("5", "fig4_right_styled_dpi", nan_points=[(5, 3)])

##############

# maxR 0.9: plotted without corrections.
_draw_fr_vs_ar("9", "fig4_left_styled_dpi")