import pandas as pd 
import numpy as np 

def explode(df, columns):
    """Expand list-like cells of ``columns`` into one output row per element.

    Each row of ``df`` is repeated once per element of its list in
    ``columns[0]``; the lists of every column in ``columns`` are flattened in
    row order and placed next to the repeated rows. Rows whose list is empty
    therefore contribute no output rows.

    Rows are matched by position rather than by index label, so the result
    does not depend on the index of ``df`` being unique, sorted, or the
    default range. Every exploded column is built separately, so columns
    holding different element types keep their own dtype.

    Args:
        df: Input frame. Every column named in ``columns`` must hold
            list-like cells.
        columns: Names of the columns to expand; order is preserved.

    Returns:
        A new DataFrame with a fresh 0..n-1 index, containing the
        non-exploded columns first (in their original order) followed by
        ``columns``.

    Raises:
        ValueError: If, for some row, the list lengths differ between the
            exploded columns (the elements could not be paired up).
    """
    columns = list(columns)
    lengths = df[columns[0]].str.len()

    # All exploded columns must agree row by row; otherwise flattening would
    # silently pair elements that come from different input rows.
    for name in columns[1:]:
        if not df[name].str.len().equals(lengths):
            raise ValueError(
                "column {!r} does not have the same per-row list lengths "
                "as {!r}".format(name, columns[0])
            )

    # Repeat the untouched columns positionally, once per list element.
    row_positions = np.repeat(np.arange(len(df)), lengths)
    kept = df.drop(columns, axis=1).iloc[row_positions].reset_index(drop=True)

    # Flatten each column independently so dtypes are not mixed together.
    flattened = {
        name: [item for cell in df[name] for item in cell] for name in columns
    }
    exploded = pd.DataFrame(flattened, columns=columns)

    return pd.concat([kept, exploded], axis=1)

def _precursor_names(mature_names):
    """Derive the lookup names used against the miRNA-TF table.

    A name containing more than two ``-`` characters loses its last
    ``-``-separated segment and is lower-cased; any other name is kept
    unchanged.

    Returns a ``(precursors, mature_by_precursor)`` pair: the derived names in
    input order, and a dict from derived name back to the input name. When
    several inputs collapse to the same derived name, the last one wins.
    """
    precursors = []
    mature_by_precursor = {}
    for mature in mature_names:
        if mature.count('-') > 2:
            precursor = mature.rsplit('-', 1)[0].lower()
        else:
            precursor = mature
        precursors.append(precursor)
        mature_by_precursor[precursor] = mature
    return precursors, mature_by_precursor


def _extend_edges(edges, sources, targets, source_type, target_type):
    """Append one (source, target, source type, target type) row per pair."""
    edges.extend(
        (source, target, source_type, target_type)
        for source, target in zip(sources, targets)
    )


# Numeric node-type codes written to the Cytoscape edge table: 1 is attached
# to miRNA columns, 2 to gene columns and 3 to TF columns.
NODE_MIRNA = 1
NODE_GENE = 2
NODE_TF = 3

top_num = 5
subtypes = [
    'common',
    'age',
    'tumor',
    'status',
    'stage',
    'histological',
]

hsa_tf_df = pd.read_excel("../data/hsa.xlsx", header=0)
# NOTE(review): this table is never used -- `tarbase_df` is rebound to the
# per-subtype KEGG sheet at the top of every loop iteration below. The load is
# kept so the script's inputs are unchanged; confirm whether it can be dropped.
tarbase_df = pd.read_csv("../data/tarbase.csv", header=0)
trrust_df = pd.read_excel("../data/TRRUST_DB_Gene_TF.xlsx", header=0)

for subtype in subtypes:
    print('\n\n{}'.format(subtype))

    # Read the miRNA list with a context manager so the handle is closed.
    with open('../data/list/adaptive/{}.txt'.format(subtype)) as handle:
        mirnas_all = [line.rstrip('\n') for line in handle]
    tarbase_df = pd.read_excel(
        "../Results/KEGG/adaptive/KEGG_{}.xlsx".format(subtype),
        sheet_name="Top_{}".format(top_num),
        header=0,
    )

    # One row per edge: (source, target, source type, target type).
    edges = []

    print("Total miRNA: {}".format(len(mirnas_all)))
    mirnas_precursor, precursor_dict = _precursor_names(mirnas_all)
    hsa_hits = hsa_tf_df[hsa_tf_df['miRNA'].isin(mirnas_precursor)]
    print("HSA Common miRNAs: {}".format(
        hsa_hits.drop_duplicates(subset=['miRNA']).shape[0]))

    # --- miRNA -> gene pairs from the KEGG sheet -------------------------
    mirna_gene_df = tarbase_df.loc[
        tarbase_df['miRNA'].isin(mirnas_all), ['miRNA', 'Gene']
    ].drop_duplicates()
    mirna_gene_df.to_csv(
        "../Results/Network/mirna-gene/{}_mirna_gene.csv".format(subtype),
        index=False)
    mirna_genes = mirna_gene_df['Gene']
    print("Genes from miRNA = {}, unique = {}".format(
        mirna_gene_df.shape,
        mirna_gene_df.drop_duplicates(subset=['Gene']).shape))
    _extend_edges(
        edges,
        mirna_gene_df['miRNA'].tolist(),
        mirna_gene_df['Gene'].tolist(),
        NODE_MIRNA,
        NODE_GENE,
    )

    # --- miRNA / TF pairs for every listed miRNA (reported, not an edge) --
    # `.copy()` makes the selection an independent frame, so adding a column
    # does not raise SettingWithCopyWarning.
    mirna_tf_df = hsa_hits.loc[:, ['miRNA', 'TF']].copy()
    mirna_tf_df['miRNA (Mature)'] = mirna_tf_df['miRNA'].map(precursor_dict)
    mirna_tf_df = mirna_tf_df.drop_duplicates()
    # NOTE(review): unlike the miRNA-gene file, this name has no "_" after the
    # subtype; kept as-is so existing consumers of the file keep working.
    mirna_tf_df.to_csv(
        "../Results/Network/mirna-tf/{}mirna_tf.csv".format(subtype),
        index=False)
    print("TFs from miRNA = {}, unique = {}".format(
        mirna_tf_df.shape,
        mirna_tf_df.drop_duplicates(subset=['TF']).shape))

    # --- TF -> gene rows from TRRUST -------------------------------------
    # NOTE(review): the previous code first filtered TRRUST by the TFs found
    # above and kept only the TF/Gene columns, but then immediately replaced
    # that result with the gene-only filter below, so the TF filter never took
    # effect. The effective behaviour is preserved here; confirm whether the
    # TF filter (and the two-column selection) should be applied as well.
    tf_gene_df = trrust_df[trrust_df['Gene'].isin(mirna_genes)]
    tf_gene_df = tf_gene_df.drop_duplicates()
    print("Genes from TF = {}, unique = {}".format(
        tf_gene_df.shape,
        tf_gene_df.drop_duplicates(subset=['Gene']).shape))
    tf_gene_df.to_csv(
        "../Results/Network/tf-gene/{}.csv".format(subtype), index=False)
    # NOTE(review): source/target are taken by position (the first two columns
    # of the TRRUST sheet) and tagged TF -> gene; confirm the sheet's column
    # order really is TF, Gene.
    _extend_edges(
        edges,
        tf_gene_df.iloc[:, 0].tolist(),
        tf_gene_df.iloc[:, 1].tolist(),
        NODE_TF,
        NODE_GENE,
    )

    # --- genes present in both the miRNA and the TRRUST selections --------
    overlap_genes = np.intersect1d(mirna_genes, tf_gene_df['Gene'])
    gene_mirna_df = tarbase_df.loc[
        tarbase_df['miRNA'].isin(mirnas_all)
        & tarbase_df['Gene'].isin(overlap_genes),
        ['Gene', 'miRNA'],
    ].drop_duplicates()
    print("Overlap genes = {}, unique = {}".format(
        gene_mirna_df.shape,
        gene_mirna_df.drop_duplicates(subset=['Gene']).shape))
    print("Common miRNAs = {}".format(
        gene_mirna_df.drop_duplicates(subset=['miRNA']).shape))
    gene_mirna_df.to_csv(
        "../Results/Network/Intersection/gene-mirna/{}.csv".format(subtype),
        index=False)
    common_mirnas = gene_mirna_df["miRNA"].unique().tolist()

    # --- TFs for the miRNAs that hit an overlap gene ----------------------
    common_precursors, common_precursor_dict = _precursor_names(common_mirnas)
    common_tf_df = hsa_tf_df.loc[
        hsa_tf_df['miRNA'].isin(common_precursors), ['miRNA', 'TF']
    ].copy()
    common_tf_df['miRNA (Mature)'] = (
        common_tf_df['miRNA'].map(common_precursor_dict))
    common_tf_df = common_tf_df.drop_duplicates()
    print("TFs by miRNAs from overlap genes = {}, unique = {}".format(
        common_tf_df.shape,
        common_tf_df.drop_duplicates(subset=['TF']).shape))
    common_tf_df.to_csv(
        "../Results/Network/Intersection/mirna-tf/{}.csv".format(subtype),
        index=False)

    # --- same pairs, restricted to TFs that TRRUST links to overlap genes -
    overlap_tf = trrust_df.loc[
        trrust_df['Gene'].isin(overlap_genes), 'TF'
    ].unique()
    # Filtering the de-duplicated frame yields the same rows as re-running
    # the selection from scratch and de-duplicating afterwards.
    reduced_tf_df = common_tf_df[common_tf_df['TF'].isin(overlap_tf)]
    print("Reduced TFs by miRNAs from overlap genes = {}, unique = {}".format(
        reduced_tf_df.shape,
        reduced_tf_df.drop_duplicates(subset=['TF']).shape))
    reduced_tf_df.to_csv(
        "../Results/Network/Intersection/mirna-tf-overlap/{}.csv".format(subtype),
        index=False)
    _extend_edges(
        edges,
        reduced_tf_df['TF'].tolist(),
        reduced_tf_df['miRNA (Mature)'].tolist(),
        NODE_TF,
        NODE_MIRNA,
    )

    # --- combined edge table for Cytoscape --------------------------------
    res_df = pd.DataFrame(
        edges, columns=["Source", "Target", "Source Type", "Target Type"])
    res_df.to_csv(
        "../Results/Network/cytoscape_{}.csv".format(subtype), index=False)