Skip to content
Snippets Groups Projects
MultinomialNBXGBoost.ipynb 419 KiB
Newer Older
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>C&gt;A</td>\n",
       "      <td>ACA</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>C&gt;A</td>\n",
       "      <td>ACC</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>2 rows × 9693 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "  Mutation type Trinucleotide  ALL::TARGET-10-PAIXPH-03A-01D  \\\n",
       "0           C>A           ACA                              0   \n",
       "1           C>A           ACC                              0   \n",
       "\n",
       "   ALL::TARGET-10-PAKHZT-03A-01R  ALL::TARGET-10-PAKMVD-09A-01D  \\\n",
       "0                              0                              0   \n",
       "1                              0                              0   \n",
       "\n",
       "   ALL::TARGET-10-PAKSWW-03A-01D  ALL::TARGET-10-PALETF-03A-01D  \\\n",
       "0                              1                              0   \n",
       "1                              1                              0   \n",
       "\n",
       "   ALL::TARGET-10-PALLSD-09A-01D  ALL::TARGET-10-PAMDKS-03A-01D  \\\n",
       "0                              0                              0   \n",
       "1                              0                              0   \n",
       "\n",
       "   ALL::TARGET-10-PAPJIB-04A-01D  ...  Head-SCC::V-109  Head-SCC::V-112  \\\n",
       "0                              2  ...                0                0   \n",
       "1                              0  ...                1                0   \n",
       "\n",
       "   Head-SCC::V-116  Head-SCC::V-119  Head-SCC::V-123  Head-SCC::V-124  \\\n",
       "0                0                0                0                0   \n",
       "1                0                0                0                0   \n",
       "\n",
       "   Head-SCC::V-125  Head-SCC::V-14  Head-SCC::V-29  Head-SCC::V-98  \n",
       "0                0               0               0               1  \n",
       "1                0               1               0               0  \n",
       "\n",
       "[2 rows x 9693 columns]"
      ]
     },
     "execution_count": 418,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "other_wes_mut = pd.read_csv(\"./project_data/catalogs/WES/WES_Other.96.csv\")\n",
    "other_wes_mut.head(2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 419,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Cancer Types</th>\n",
       "      <th>Sample Names</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>SBS1</th>\n",
       "      <th>SBS2</th>\n",
       "      <th>SBS3</th>\n",
       "      <th>SBS4</th>\n",
       "      <th>SBS5</th>\n",
       "      <th>SBS6</th>\n",
       "      <th>SBS7a</th>\n",
       "      <th>...</th>\n",
       "      <th>SBS51</th>\n",
       "      <th>SBS52</th>\n",
       "      <th>SBS53</th>\n",
       "      <th>SBS54</th>\n",
       "      <th>SBS55</th>\n",
       "      <th>SBS56</th>\n",
       "      <th>SBS57</th>\n",
       "      <th>SBS58</th>\n",
       "      <th>SBS59</th>\n",
       "      <th>SBS60</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>ALL</td>\n",
       "      <td>TARGET-10-PAIXPH-03A-01D</td>\n",
       "      <td>0.529</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>ALL</td>\n",
       "      <td>TARGET-10-PAKHZT-03A-01R</td>\n",
       "      <td>0.696</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>2 rows × 68 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "  Cancer Types              Sample Names  Accuracy  SBS1  SBS2  SBS3  SBS4  \\\n",
       "0          ALL  TARGET-10-PAIXPH-03A-01D     0.529     0     0     0     0   \n",
       "1          ALL  TARGET-10-PAKHZT-03A-01R     0.696     0     0     0     0   \n",
       "\n",
       "   SBS5  SBS6  SBS7a  ...  SBS51  SBS52  SBS53  SBS54  SBS55  SBS56  SBS57  \\\n",
       "0     0     0      0  ...      0      0      0      1      0      0      0   \n",
       "1     0     0      0  ...      0      0      0      1      0      0      0   \n",
       "\n",
       "   SBS58  SBS59  SBS60  \n",
       "0      0      0      0  \n",
       "1      0      0      0  \n",
       "\n",
       "[2 rows x 68 columns]"
      ]
     },
     "execution_count": 419,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "other_wes_act = pd.read_csv(\"./project_data/activities/WES/WES_Other.activities.csv\")\n",
    "other_wes_act.head(2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Imports and helpers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 420,
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import sklearn\n",
    "from sklearn.decomposition import PCA\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "#import torch \n",
    "\n",
    "from sklearn.ensemble import RandomForestClassifier\n",
    "\n",
    "from sklearn.metrics import accuracy_score\n",
    "from sklearn.metrics import roc_auc_score\n",
    "from sklearn.metrics import roc_curve\n",
    "from sklearn.metrics import classification_report\n",
    "\n",
    "from sklearn.model_selection import cross_val_score, train_test_split, KFold\n",
    "from sklearn.model_selection import StratifiedShuffleSplit\n",
    "from sklearn.model_selection import StratifiedKFold, GridSearchCV\n",
    "from sklearn.model_selection import learning_curve\n",
    "\n",
    "from sklearn.preprocessing import LabelEncoder\n",
    "\n",
    "# These ones are work in progress\n",
    "def plot_roc_auc(X_tst, y_test, model, is_multi_class=False):\n",
    "    probs = model.predict_proba(X_tst)\n",
    "    probs = probs[:, 1]\n",
    "    if is_multi_class:\n",
    "        auc = roc_auc_score(y_test, probs, multi_class='ovo')\n",
    "    else:\n",
    "        auc = roc_auc_score(y_test, probs, multi_class='ovo')\n",
    "    \n",
    "    fp_rate, tp_rate, thresholds = roc_curve(y_test, probs)\n",
    "    \n",
    "    plt.figure(figsize=(7,6))\n",
    "    plt.axis('scaled')\n",
    "    plt.xlim([0,1])\n",
    "    plt.ylim([0,1])\n",
    "    plt.title(\"AUC & ROC\")\n",
    "    plt.plot(fp_rate, tp_rate, 'g')\n",
    "    plt.fill_between(fp_rate, tp_rate, facecolor = \"green\", alpha = 0.7)\n",
    "    plt.text(0.95, 0.05, f'AUC = {auc}', ha='right', fontsize=12, weight='bold', color='blue')\n",
    "    plt.xlabel(\"False Positive Rate\")\n",
    "    plt.ylabel(\"True Positive Rate\")\n",
    "\n",
    "def plot_confusion_mat(y_test, y_pred, labs=None, size=None):\n",
    "    cm = sklearn.metrics.confusion_matrix(y_test, y_pred)\n",
    "    if size is None:\n",
    "        plt.figure(figsize=(12,10))\n",
    "    else:\n",
    "        plt.figure(figsize=size)\n",
    "    if labs is None:\n",
    "        sns.heatmap(cm, square=False, annot=True, fmt='d', cmap='viridis', cbar=True)\n",
    "    else:\n",
    "        sns.heatmap(cm, square=False, annot=True, fmt='d', cmap='viridis', xticklabels=labs, yticklabels=labs, cbar=True)\n",
    "    plt.xlabel('Predicted label')\n",
    "    plt.ylabel('True label')\n",
    "    #plt.ylim(0, 2)\n",
    "\n",
    "def plot_learning_curve(model, X, y):\n",
    "    N, train_lc, val_lc = learning_curve(model, X, y, cv=7, train_sizes=np.linspace(0.3, 1, 25))\n",
    "    plt.figure(figsize=(7,6))\n",
    "    plt.title(\"Learning curve\")\n",
    "    plt.plot(N, np.mean(train_lc, 1), color='blue', label='training score')\n",
    "    plt.plot(N, np.mean(val_lc, 1), color='red', label='validation score')\n",
    "    #plt.hlines(N, np.mean([train_lc[-1],  val_lc[-1]]), N[0], N[-1], color='gray', label='mean', linestyle='dashed')\n",
    "\n",
    "def plot_trn_tst_dist(y_all, y_train, y_test, y_pred, in_cols=False):\n",
    "    #fig = None\n",
    "    #ax = None\n",
    "    if in_cols:\n",
    "        fig, ax = plt.subplots(2,2)\n",
    "    else:\n",
    "        fig, ax = plt.subplots(4,1)\n",
    "\n",
    "    fig.set_size_inches(15,8)\n",
    "\n",
    "    plt_sets = [y_all, y_train, y_test, y_pred]\n",
    "    plt_labels = [\"All\", \"Train\", \"Test\", \"Pred\"]\n",
    "    plt_set_df = pd.DataFrame()\n",
    "    for i in range(len(plt_sets)):\n",
    "        s = pd.Series(plt_sets[i]).value_counts().sort_index()\n",
    "        plt_set_df[plt_labels[i]] = s\n",
    "    \n",
    "        pd.DataFrame({plt_labels[i]: s}).plot(ax=ax.flat[i], kind=\"bar\")\n",
    "        #sns.countplot(x=s, \n",
    "        #            palette=sns.hls_palette(2),\n",
    "        #            ax=ax[i])\n",
    "        ax.flat[i].tick_params(axis=\"x\", rotation=90)\n",
    "\n",
    "    fig.tight_layout()\n",
    "    with pd.option_context('display.max_rows', None,\n",
    "                       'display.max_columns', None,\n",
    "                       'display.precision', 2,\n",
    "                       ):\n",
    "        print(plt_set_df)\n",
    "\n",
    "\n",
    "   \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset preprocess, combine profile data to a single data frame\n",
    "\n",
    "From all profile sets, a combined data frame is made, which has samples in the rows and features in the columns."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 421,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Profile data:\n",
      "\n",
      "---Data set diagnostics print---\n",
      "\n",
      "Missing entries in mutations: 0\n",
      "The shape of the mutations data frame (20343, 97)\n",
      "Checking normalization: sum of some rows:\n",
      " cancer::TK74_LCIS2    1.0\n",
      "RCC::TCGA             1.0\n",
      "AdenoCA::TCGA         1.0\n",
      "AdenoCA::TCGA         1.0\n",
      "Melanoma::TCGA        1.0\n",
      "Tumor counts:\n",
      " AdenoCA      7712\n",
      "SCC          2188\n",
      "cancer       1639\n",
      "HCC          1318\n",
      "Melanoma     1231\n",
      "BNHL          822\n",
      "RCC           775\n",
      "GBM           605\n",
      "Medullo       557\n",
      "CA            462\n",
      "cell          389\n",
      "CMDI          357\n",
      "LGG           326\n",
      "CLL           302\n",
      "Papillary     297\n",
      "neoplasm      247\n",
      "ALL           240\n",
      "Ewings        231\n",
      "AML           230\n",
      "30            212\n",
      "bone          203\n",
      "Name: tumor_types, dtype: int64\n",
      "\n",
      "\n",
      "Tumor types with smallish counts: 7\n",
      "Papillary    297\n",
      "neoplasm     247\n",
      "ALL          240\n",
      "Ewings       231\n",
      "AML          230\n",
      "30           212\n",
      "bone         203\n",
      "Name: tumor_types, dtype: int64\n",
      "Unique tumor types:  21\n",
      "['30', 'ALL', 'AML', 'AdenoCA', 'BNHL', 'CA', 'CLL', 'CMDI', 'Ewings', 'GBM', 'HCC', 'LGG', 'Medullo', 'Melanoma', 'Papillary', 'RCC', 'SCC', 'bone', 'cancer', 'cell', 'neoplasm']\n"
     ]
    }
   ],
   "source": [
    "\n",
    "def prepare_mut_df(raw_mutation_dfs, is_profile, small_sample_limit=None):\n",
    "\n",
    "    mutations_all = pd.DataFrame()\n",
    "\n",
    "    for df in raw_mutation_dfs:\n",
    "        # Make a copy of the original data frame and start processing from there\n",
    "        mutations  = df.copy()\n",
    "    \n",
    "        if is_profile:\n",
    "            mutations['mut_tri'] = mutations.apply(lambda a: '{}_{}'.format(a['Mutation type'], a['Trinucleotide']), axis=1)\n",
    "            mutations = mutations.set_index('mut_tri').drop(['Mutation type', 'Trinucleotide'], axis=1)\n",
    "            mutations = mutations.T\n",
    "        else:\n",
    "            mutations['mut_tri'] = mutations.apply(lambda a: '{}::{}'.format(a['Cancer Types'], a['Sample Names']), axis=1)\n",
    "            mutations = mutations.set_index('mut_tri').drop(['Cancer Types', 'Sample Names', 'Accuracy'], axis=1)\n",
    "     \n",
    "        # Rename some index names\n",
    "        renamed_items = list(mutations.index)\n",
    "        index_items = list(mutations.index)\n",
    "\n",
    "        # Combine rows for low count labels\n",
    "        for i in range(len(index_items)):\n",
    "            result = index_items[i]\n",
    "            parts = result.split('-')\n",
    "            if len(parts) > 1:\n",
    "                result = parts[1]\n",
    "            else:\n",
    "                result = parts[0]\n",
    "            \n",
    "            #result = result.split('-')[0]\n",
    "            #for to_sub in ['Adeno', 'Bone', 'Breast', 'Cervix', 'CNS', 'Eye', 'Liver', 'Lymph', 'Lung', 'Kidney', 'Myeloid', 'Panc' ]:\n",
    "            #    result = re.sub( to_sub + r'(-\\w*)', to_sub, result)\n",
    "                \n",
    "            renamed_items[i] = result.replace('Ca', 'CA')\n",
    "       \n",
    "        mutations.rename(index=dict(zip(index_items, renamed_items)), inplace = True)\n",
    "   \n",
    "        # Normalize \n",
    "        row_sums = mutations.sum(axis=1)\n",
    "        mutations = mutations.divide(row_sums, axis = 0)\n",
    "\n",
    "        mutations_all = pd.concat([mutations_all, mutations])\n",
    "\n",
    "    mutations_all.sort_index(inplace=True)\n",
    "\n",
    "    # Do we need to renormalize after obtaining the full dataframe?\n",
    "  \n",
    "    # Figure out tumor types based on the first part of the index\n",
    "    tumor_types = [a.split(':')[0] for a in mutations_all.index]\n",
    "    \n",
    "    #print(\"ttt\", tumor_types)\n",
    "    mutations_all[\"tumor_types\"] = tumor_types\n",
    "\n",
    "    # Get rid of types with very few samples if the limit is specified\n",
    "    if small_sample_limit is not None:\n",
    "        mutations_all = cull_small_sample_counts(mutations_all, small_sample_limit)\n",
    "\n",
    "    tumor_types = mutations_all[\"tumor_types\"] \n",
    "    # Prepare a list with all the types appearing only once\n",
    "    unique_tumor_types = sorted(list(set(tumor_types)))\n",
    "    # Attach this back to the frame\n",
    "    \n",
    "    return (mutations_all, unique_tumor_types)\n",
    "\n",
    "def cull_small_sample_counts(mutations, small_sample_limit):\n",
    "    \n",
    "    counts = mutations[\"tumor_types\"].value_counts()\n",
    "    big_counts = counts[list(counts > small_sample_limit)]\n",
    "    big_index = mutations[\"tumor_types\"].isin(list(big_counts.index))\n",
    "    mutations = mutations[big_index]\n",
    "\n",
    "    return mutations\n",
    "\n",
    "def print_dset_diag(mut_df, unique_tumor_types, small_sample_limit):\n",
    "    # Check if the data frame is ok\n",
    "    print(\"\\n---Data set diagnostics print---\\n\")\n",
    "    print(\"Missing entries in mutations:\", mut_df.isnull().sum().sum())\n",
    "    print(\"The shape of the mutations data frame\", mut_df.shape)\n",
    "\n",
    "    # Check to see if the rows are normalized to one, take a sample from the data frame\n",
    "    norm_df = mut_df.sample(n=5, random_state=5)\n",
    "    print(\"Checking normalization: sum of some rows:\\n\", norm_df.iloc[:,0:-1].sum(axis=1))\n",
    "    print(\"\\n\")\n",
    "\n",
    "    # Check some counts of tumor types\n",
    "    tumor_counts = mut_df[\"tumor_types\"].value_counts() #.sort_values(ascending=True)\n",
    "    print(\"Tumor counts:\\n\", tumor_counts)\n",
    "    print(\"\\n\")\n",
    "\n",
    "    small_counts = tumor_counts < 1.5*small_sample_limit\n",
    "    print(\"Tumor types with smallish counts:\",  sum(small_counts))\n",
    "\n",
    "    print(tumor_counts[small_counts])\n",
    "    print(\"\\n\")\n",
    "\n",
    "    # Tumor types\n",
    "    print(\"Unique tumor types: \", len(unique_tumor_types))\n",
    "    print(unique_tumor_types)\n",
    "\n",
    "\n",
    "small_sample_limit = 250\n",
    "\n",
    "profile_raw_data_sets = [PCAWG_wgs_mut, TCGA_wes_mut, nonPCAWG_wgs_mut, other_wes_mut]\n",
    "profile_mut_all, prf_unique_tumor_types = prepare_mut_df(profile_raw_data_sets, True, small_sample_limit)\n",
    "\n",
    "# Print some diagnostics from the prepared data set\n",
    "print(\"Profile data:\")\n",
    "print_dset_diag(profile_mut_all, prf_unique_tumor_types, small_sample_limit)\n",
    "\n",
    "# Data matrix X for fitting, omit the tumor labeling from there, use that information in constructing true y\n",
    "# Note: this contains profile data only\n",
    "#X_prf = profile_mut_all.drop(\"tumor_types\", axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Dataset preprocess for activites data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 422,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Activities data:\n",
      "\n",
      "---Data set diagnostics print---\n",
      "\n",
      "Missing entries in mutations: 0\n",
      "The shape of the mutations data frame (20343, 66)\n",
      "Checking normalization: sum of some rows:\n",
      " mut_tri\n",
      "cancer::TK74_LCIS2    1.0\n",
      "RCC::TCGA             1.0\n",
      "AdenoCA::TCGA         1.0\n",
      "AdenoCA::TCGA         1.0\n",
      "Melanoma::TCGA        1.0\n",
      "Tumor counts:\n",
      " AdenoCA      7712\n",
      "SCC          2188\n",
      "cancer       1639\n",
      "HCC          1318\n",
      "Melanoma     1231\n",
      "BNHL          822\n",
      "RCC           775\n",
      "GBM           605\n",
      "Medullo       557\n",
      "CA            462\n",
      "cell          389\n",
      "CMDI          357\n",
      "LGG           326\n",
      "CLL           302\n",
      "Papillary     297\n",
      "neoplasm      247\n",
      "ALL           240\n",
      "Ewings        231\n",
      "AML           230\n",
      "30            212\n",
      "bone          203\n",
      "Name: tumor_types, dtype: int64\n",
      "\n",
      "\n",
      "Tumor types with smallish counts: 7\n",
      "Papillary    297\n",
      "neoplasm     247\n",
      "ALL          240\n",
      "Ewings       231\n",
      "AML          230\n",
      "30           212\n",
      "bone         203\n",
      "Name: tumor_types, dtype: int64\n",
      "Unique tumor types:  21\n",
      "['30', 'ALL', 'AML', 'AdenoCA', 'BNHL', 'CA', 'CLL', 'CMDI', 'Ewings', 'GBM', 'HCC', 'LGG', 'Medullo', 'Melanoma', 'Papillary', 'RCC', 'SCC', 'bone', 'cancer', 'cell', 'neoplasm']\n"
     ]
    }
   ],
   "source": [
    "act_raw_data_sets = [PCAWG_wgs_act, TCGA_wes_act, nonPCAWG_wgs_act, other_wes_act]\n",
    "act_mut_all, act_unique_tumor_types = prepare_mut_df(act_raw_data_sets, is_profile=False, small_sample_limit=small_sample_limit)\n",
    "\n",
    "# Print some diagnostics from the prepared data set\n",
    "print(\"Activities data:\")\n",
    "print_dset_diag(act_mut_all, act_unique_tumor_types, small_sample_limit)\n",
    "\n",
    "# Data matrix X for fitting, omit the tumor labeling from there, use that information in constructing true y\n",
    "# Note: this contains profile data only\n",
    "X_act = act_mut_all.drop(\"tumor_types\", axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check profile data content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 423,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Some content from the full profile set:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th>mut_tri</th>\n",
       "      <th>C&gt;A_ACA</th>\n",
       "      <th>C&gt;A_ACC</th>\n",
       "      <th>C&gt;A_ACG</th>\n",
       "      <th>C&gt;A_ACT</th>\n",
       "      <th>C&gt;A_CCA</th>\n",
       "      <th>C&gt;A_CCC</th>\n",
       "      <th>C&gt;A_CCG</th>\n",
       "      <th>C&gt;A_CCT</th>\n",
       "      <th>C&gt;A_GCA</th>\n",
       "      <th>C&gt;A_GCC</th>\n",
       "      <th>...</th>\n",
       "      <th>T&gt;G_CTT</th>\n",
       "      <th>T&gt;G_GTA</th>\n",
       "      <th>T&gt;G_GTC</th>\n",
       "      <th>T&gt;G_GTG</th>\n",
       "      <th>T&gt;G_GTT</th>\n",
       "      <th>T&gt;G_TTA</th>\n",
       "      <th>T&gt;G_TTC</th>\n",
       "      <th>T&gt;G_TTG</th>\n",
       "      <th>T&gt;G_TTT</th>\n",
       "      <th>tumor_types</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.100000</td>\n",
       "      <td>0.100000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.1</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.00</td>\n",
       "      <td>30</td>\n",
       "      <th>30</th>\n",
       "      <td>0.040000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.02</td>\n",
       "      <td>0.08</td>\n",
       "      <td>0.020000</td>\n",
       "      <td>0.020000</td>\n",
       "      <td>0.040000</td>\n",
       "      <td>0.140000</td>\n",
       "      <td>...</td>\n",
       "      <td>0.04</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.02</td>\n",
       "      <td>0.0</td>\n",
       "      <td>30</td>\n",
       "      <th>30</th>\n",
       "      <td>0.153846</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.076923</td>\n",
       "      <td>0.076923</td>\n",
       "      <td>0.076923</td>\n",
       "      <td>...</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.00</td>\n",
       "      <td>30</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>0.032258</td>\n",
       "      <td>0.032258</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.032258</td>\n",
       "      <td>0.032258</td>\n",
       "      <td>0.032258</td>\n",
       "      <td>0.032258</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>30</td>\n",
       "      <th>30</th>\n",
       "      <td>0.100000</td>\n",
       "      <td>0.100000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>...</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.00</td>\n",
       "      <td>0.0</td>\n",
       "      <td>30</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 97 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "mut_tri   C>A_ACA   C>A_ACC   C>A_ACG  C>A_ACT  C>A_CCA   C>A_CCC   C>A_CCG  \\\n",
       "30       0.000000  0.000000  0.000000     0.00     0.00  0.000000  0.000000   \n",
       "30       0.040000  0.000000  0.000000     0.02     0.08  0.020000  0.020000   \n",
       "30       0.153846  0.000000  0.000000     0.00     0.00  0.076923  0.000000   \n",
       "30       0.000000  0.032258  0.032258     0.00     0.00  0.000000  0.032258   \n",
       "30       0.100000  0.100000  0.000000     0.00     0.00  0.000000  0.000000   \n",
       "\n",
       "mut_tri   C>A_CCT   C>A_GCA   C>A_GCC  ...  T>G_CTT  T>G_GTA  T>G_GTC  \\\n",
       "30       0.100000  0.100000  0.000000  ...     0.00      0.0      0.0   \n",
       "30       0.040000  0.140000  0.000000  ...     0.04      0.0      0.0   \n",
       "30       0.076923  0.076923  0.000000  ...     0.00      0.0      0.0   \n",
       "30       0.032258  0.032258  0.032258  ...     0.00      0.0      0.0   \n",
       "30       0.000000  0.000000  0.000000  ...     0.00      0.0      0.0   \n",
       "\n",
       "mut_tri  T>G_GTG  T>G_GTT  T>G_TTA  T>G_TTC  T>G_TTG  T>G_TTT  tumor_types  \n",
       "30           0.1      0.0      0.0      0.0     0.00      0.0           30  \n",
       "30           0.0      0.0      0.0      0.0     0.02      0.0           30  \n",
       "30           0.0      0.0      0.0      0.0     0.00      0.0           30  \n",
       "30           0.0      0.0      0.0      0.0     0.00      0.0           30  \n",
       "30           0.0      0.0      0.0      0.0     0.00      0.0           30  \n",
     "execution_count": 423,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(\"Some content from the full profile set:\")\n",
    "profile_mut_all.head(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 424,
      "image/png": "",
      "text/plain": [
       "<Figure size 1800x360 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure(figsize=(25, 5))\n",
    "sns.set_theme()\n",
    "profile_mut_all[\"tumor_types\"].value_counts().sort_index().plot(kind=\"bar\")\n",
    "#sns.countplot(x=profile_mut_all[\"tumor_types\"], palette=sns.hls_palette(2))\n",
    "plt.xticks(rotation=90);\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check activites data content"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 425,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Some content from the full act set:\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>SBS1</th>\n",
       "      <th>SBS2</th>\n",
       "      <th>SBS3</th>\n",
       "      <th>SBS4</th>\n",
       "      <th>SBS5</th>\n",
       "      <th>SBS6</th>\n",
       "      <th>SBS7a</th>\n",
       "      <th>SBS7b</th>\n",
       "      <th>SBS7c</th>\n",
       "      <th>SBS7d</th>\n",
       "      <th>...</th>\n",
       "      <th>SBS52</th>\n",
       "      <th>SBS53</th>\n",
       "      <th>SBS54</th>\n",
       "      <th>SBS55</th>\n",
       "      <th>SBS56</th>\n",
       "      <th>SBS57</th>\n",
       "      <th>SBS58</th>\n",
       "      <th>SBS59</th>\n",
       "      <th>SBS60</th>\n",
       "      <th>tumor_types</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mut_tri</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>30</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>30</td>\n",
       "      <th>30</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>...</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>30</td>\n",