import pandas as pd 
import numpy as np 

def explode(df, columns):
    """Unnest list-valued columns so that every list item gets its own row.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame whose ``columns`` hold sized iterables (e.g. lists) per cell.
    columns : sequence of column labels
        Columns to unnest. Within a row, all of them are expected to hold
        the same number of items; the row is repeated once per item of
        ``columns[0]``.

    Returns
    -------
    pandas.DataFrame
        The remaining columns (each row repeated as needed) followed by the
        unnested ``columns`` in the order given, on a fresh 0..n-1 index.
        Rows whose list is empty contribute no output rows.

    Raises
    ------
    ValueError
        If a listed column holds, in total, a different number of items
        than ``columns[0]``.
    """
    columns = list(columns)

    # Number of output rows each input row expands to, from the first column.
    counts = np.fromiter(
        (len(cell) for cell in df[columns[0]]), dtype=np.intp, count=len(df)
    )

    # Repeat rows by position rather than by label so that a non-unique
    # index or rows with empty lists cannot break index alignment.
    row_positions = np.repeat(np.arange(len(df)), counts)
    result = df.drop(columns, axis=1).iloc[row_positions].reset_index(drop=True)

    # Flatten each column on its own: pooling all columns into one NumPy
    # array would coerce them to a common dtype (e.g. ints become strings).
    for col in columns:
        result[col] = [item for cell in df[col] for item in cell]
    return result

def _precursor_lookup(mature_names):
    """Return a dict mapping precursor-style names to the given miRNA names.

    A name containing more than two '-' has its last '-'-separated part
    removed and is lower-cased; any other name is used as its own key.
    If several names reduce to the same key, the last one wins.
    """
    lookup = {}
    for mature in mature_names:
        if mature.count('-') > 2:
            precursor = mature.rsplit('-', 1)[0].lower()
        else:
            precursor = mature
        lookup[precursor] = mature
    return lookup


def _tfs_for_mirnas(tf_table, mature_names):
    """Select the miRNA/TF rows of ``tf_table`` matching ``mature_names``.

    Matching is done on the precursor-style names derived by
    ``_precursor_lookup``; the originating name is added back as the
    'miRNA (Mature)' column. Duplicate rows are NOT removed here.
    """
    lookup = _precursor_lookup(mature_names)
    precursor_names = list(lookup)
    hits = tf_table.query('(miRNA in @precursor_names)')
    # Copy before adding a column so the assignment does not target a slice
    # of ``tf_table`` (avoids pandas' SettingWithCopyWarning).
    hits = hits.loc[:, ['miRNA', 'TF']].copy()
    hits['miRNA (Mature)'] = hits['miRNA'].map(lookup)
    return hits


subtypes = ["44", "LA", "LB", "H2", "BL"]
hsa_tf_df = pd.read_excel("../data/hsa.xlsx", header=0)
trrust_df = pd.read_excel("../data/TRRUST_DB_Gene_TF.xlsx", header=0)

# For every subtype: collect miRNA->gene, miRNA->TF and TF->gene relations,
# then report and export the genes reachable both ways.
for subtype in subtypes:
    print('\n\n{}'.format(subtype))

    # One miRNA name per line; the context manager closes the file promptly.
    with open('../data/miRNA_list/1star{}.txt'.format(subtype)) as handle:
        mirnas_44 = [line.rstrip('\n') for line in handle]
    tarbase_df = pd.read_excel("../Results/tarbase_genes_KEGG/{}.xlsx".format(subtype), sheet_name="Top_10", header=0)

    # --- miRNA -> gene -----------------------------------------------------
    mirna_gene_df = tarbase_df.query('(miRNA in @mirnas_44)')
    mirna_gene_df = mirna_gene_df.loc[:, ['miRNA', 'Gene']]
    # Disabled export:
    # mirna_gene_df.to_csv("../Results/Network/mirna-gene/{}_mirna_gene.csv".format(subtype), index=False)
    mirna_genes = mirna_gene_df['Gene']
    print("Genes from miRNA = {}, unique = {}".format(mirna_gene_df.shape, mirna_gene_df.drop_duplicates(subset=['Gene']).shape))

    # --- miRNA -> TF ---------------------------------------------------------
    mirna_tf_df = _tfs_for_mirnas(hsa_tf_df, mirnas_44)
    # Disabled export:
    # mirna_tf_df.to_csv("../Results/Network/mirna-tf/{}mirna_tf.csv".format(subtype), index=False)
    all_tfs = mirna_tf_df['TF'].tolist()
    mirna_tf_df = mirna_tf_df.drop_duplicates()
    print("TFs from miRNA = {}, unique = {}".format(mirna_tf_df.shape, mirna_tf_df.drop_duplicates(subset=['TF']).shape))

    # --- TF -> gene ----------------------------------------------------------
    tf_gene_df = trrust_df.query('(TF in @all_tfs)')
    tf_gene_df = tf_gene_df.loc[:, ['TF', 'Gene']]
    tf_gene_df = tf_gene_df.drop_duplicates()
    print("Genes from TF = {}, unique = {}".format(tf_gene_df.shape, tf_gene_df.drop_duplicates(subset=['Gene']).shape))
    # Disabled export:
    # tf_gene_df.to_csv("../Results/Network/tf-gene/{}.csv".format(subtype), index=False)

    # --- genes reached both directly from miRNAs and via TFs -----------------
    overlap_genes = np.intersect1d(mirna_genes, tf_gene_df['Gene'])
    overlap_df = tarbase_df.query('(miRNA in @mirnas_44) & (Gene in @overlap_genes)')
    overlap_df = overlap_df.loc[:, ['Gene', 'miRNA']]
    overlap_df = overlap_df.drop_duplicates()
    print("Overlap genes = {}, unique = {}".format(overlap_df.shape, overlap_df.drop_duplicates(subset=['Gene']).shape))
    print("Common miRNAs = {}".format(overlap_df.drop_duplicates(subset=['miRNA']).shape))
    overlap_df.to_csv("../Results/Network/Intersection/gene-mirna_top10/{}.csv".format(subtype), index=False)
    common_mirnas = overlap_df["miRNA"].unique().tolist()

    # --- TFs of the miRNAs that target the overlap genes ---------------------
    common_tf_df = _tfs_for_mirnas(hsa_tf_df, common_mirnas)
    common_tf_df = common_tf_df.drop_duplicates()
    print("TFs by miRNAs from overlap genes = {}, unique = {}".format(common_tf_df.shape, common_tf_df.drop_duplicates(subset=['TF']).shape))
    common_tf_df.to_csv("../Results/Network/Intersection/mirna-tf_top10/{}.csv".format(subtype), index=False)

    # Names that occur both among the miRNA target genes and among those TFs.
    # NOTE(review): the label below repeats "Overlap genes" from the earlier
    # report although it counts gene/TF name matches - confirm the wording.
    overlap_gene_tf = np.intersect1d(mirna_genes, common_tf_df['TF'])
    print("Overlap genes = {}, unique = {}".format(overlap_gene_tf.shape, np.unique(overlap_gene_tf).shape))
