import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd 
from statannot.statannot import add_stat_annotation
from scipy import stats 
from tqdm import tqdm 

# Pin NumPy's global random state so repeated runs are reproducible.
np.random.seed(seed=2018)

# Plain white theme, large fonts (3x scale) and an 11x8 inch default figure.
sns.set(style="white", font_scale=3, rc={"figure.figsize": (11, 8)})

# Category key -> human-readable title.
# The key selects the input sub-paths (list file and data.csv directory).
# The title is used as the "Category" plot value and as a p-value table column.
categories = dict(
    tumor="Condition",
    age="Age",
    status="Survival Status",
    stage="Clinical Stage",
    histological="Histological Type",
)

# p-value table: one column per category title.
# Rows (one per miRNA) are added by the analysis loop.
p_df = pd.DataFrame(columns=[*categories.values()])
# For every category and every miRNA listed for it:
#   - split the samples into label == 1 vs. label != 1,
#   - run an independent two-sample t-test (scipy default: equal variances),
#   - save a box plot into significant/ or insignificant/ by p-value,
#   - record the p-value in p_df, which is written to pvals.csv at the end.

# Legend names for the two sample groups, per category.
# The first name is for samples with label == 1; the second is for all others.
# Categories not listed fall back to the histological-type names.
GROUP_LABELS = {
    'tumor': ("Normal (Condition)", "Tumor (Condition)"),
    'age': ("Age < 60", "Age >= 60"),
    'status': ("Alive (Srv-Sta)", "Dead (Srv-Sta)"),
    'stage': ("Group I (Cln-Stg)", "Group II (Cln-Stg)"),
}
DEFAULT_GROUP_LABELS = ("Group I (Hist-Typ)", "Group II (Hist-Typ)")

# p-values strictly below this are saved under "significant".
ALPHA = 0.05

for category, category_title in categories.items():
    fig, _ = plt.subplots()

    # Close the list file deterministically. Skip blank lines: an empty
    # name would be used as a column key below and raise KeyError.
    # strip() also tolerates CRLF line endings.
    with open('../data/list/adaptive/{}.txt'.format(category)) as list_file:
        mirnas = [row.strip() for row in list_file if row.strip()]

    # The expression table depends only on the category. Load and split it
    # once here rather than re-reading the same CSV for every miRNA.
    df = pd.read_csv("../data/R3/Level- 3/categorization_miRNA(RPMlog2)/{}/data.csv".format(category), header=0, index_col=0)
    group1 = df.loc[df['label'] == 1]
    group2 = df.loc[df['label'] != 1]
    label1, label2 = GROUP_LABELS.get(category, DEFAULT_GROUP_LABELS)

    for mirna in mirnas:
        df1 = group1.loc[:, mirna]
        df2 = group2.loc[:, mirna]
        n1, n2 = df1.shape[0], df2.shape[0]

        # Long-format frame for seaborn: one row per sample.
        res_df = pd.DataFrame({
            'Category': [category_title] * (n1 + n2),
            'Expression': df1.values.tolist() + df2.values.tolist(),
            'Label': [label1] * n1 + [label2] * n2,
        })

        _, pval = stats.ttest_ind(a=df1.values, b=df2.values)
        p_df.loc[mirna, category_title] = pval

        g = sns.boxplot(x="Category", y="Expression", hue="Label",
                        palette="Set2",
                        data=res_df)
        sns.despine(left=True)
        g.set_xticklabels([])
        g.set(xlabel=mirna, ylabel="")
        plt.legend(title="p-value={:.2e}".format(pval), loc="right", bbox_to_anchor=(1, 1))

        # A NaN p-value compares False, so it lands in "insignificant".
        subdir = 'significant' if pval < ALPHA else 'insignificant'
        plt.savefig('../Results/box/{}/{}_{}.pdf'.format(subdir, category, mirna))
        # Clear the figure so the next miRNA is drawn on fresh axes.
        plt.clf()
    plt.close(fig)
p_df.to_csv("../Results/box/pvals.csv")