#!/usr/bin/env python
"""Rank informative genes in a count matrix and visualise them.

Pipeline:
1. keep samples present in both the count matrix and the metadata, and genes
   whose summed count exceeds ``--low_expression_threshold``;
2. convert counts to log2(CPM + 1), using each sample's total count as its
   library size;
3. plot a heatmap of the ``--n_variable_genes`` most variable genes;
4. rank those genes by mutual information with the combined
   ``thermal.tolerance`` / ``day`` group, then plot a PCA, a heatmap and a
   hierarchically clustered heatmap of the top ``--mutual_information_threshold``
   genes.

Outputs go to ``--output_dir``, which may be a local directory or an
``s3://bucket/prefix`` location (figures are uploaded with boto3).
"""
import argparse
import io
import os
import sys

import matplotlib

# Select the non-interactive backend before pyplot is imported.
matplotlib.use('Agg')

import matplotlib.pyplot as plt  # noqa: E402
import numpy as np  # noqa: E402
import pandas as pd  # noqa: E402
import seaborn as sns  # noqa: E402
import boto3  # noqa: E402
from botocore.exceptions import NoCredentialsError  # noqa: E402
from sklearn.decomposition import PCA  # noqa: E402
from sklearn.feature_selection import mutual_info_classif  # noqa: E402
# NOTE(review): the sklearn.metrics names are currently unused in this file.
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report  # noqa: E402,F401
from sklearn.preprocessing import LabelEncoder, StandardScaler  # noqa: E402

S3_PREFIX = "s3://"
LOG_CPM_LABEL = "Log2(CPM + 1)"


def save_figure(fig, path, format='png', **kwargs):
    """
    Save a matplotlib figure to a local file or an S3 bucket using in-memory storage.

    Parameters:
        fig (matplotlib.figure.Figure): The figure to save (anything with a
            ``savefig`` method works).
        path (str): Local file path, or S3 location 's3://bucket-name/key-name'.
        format (str): The file format (e.g. 'png', 'jpg', 'pdf').
        **kwargs: Extra keyword arguments forwarded to ``fig.savefig()``.

    Raises:
        ValueError: If an S3 path lacks a bucket or a key.
        botocore.exceptions.NoCredentialsError: If no AWS credentials are configured.
        RuntimeError: If rendering or uploading/writing the figure fails.
    """
    if path.startswith(S3_PREFIX):
        s3_path = path[len(S3_PREFIX):].split("/", 1)
        if len(s3_path) != 2 or not all(s3_path):
            raise ValueError("Invalid S3 path. Format should be 's3://bucket-name/key-name'.")
        bucket_name, key_name = s3_path

        # Render into memory so nothing touches the local filesystem.
        with io.BytesIO() as buffer:
            try:
                fig.savefig(buffer, format=format, **kwargs)
                buffer.seek(0)  # upload from the start of the rendered bytes
                boto3.client('s3').upload_fileobj(buffer, bucket_name, key_name)
            except NoCredentialsError:
                # botocore exceptions only take keyword arguments, so re-raise the
                # original instead of constructing a new one with a message.
                print("AWS credentials not found. Please configure your AWS environment.",
                      file=sys.stderr)
                raise
            except Exception as e:
                raise RuntimeError(f"Failed to save figure to S3: {e}") from e
        print(f"Figure saved to S3 at {path}")
    else:
        try:
            fig.savefig(path, format=format, **kwargs)
        except Exception as e:
            raise RuntimeError(f"Failed to save figure locally: {e}") from e
        print(f"Figure saved locally at {path}")


def _positive_int(text):
    """argparse ``type=``: parse *text* as an integer >= 1."""
    value = int(text)
    if value < 1:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {text!r}")
    return value


def parse_args(argv=None):
    """Parse command-line arguments (``argv=None`` reads ``sys.argv``)."""
    parser = argparse.ArgumentParser(
        description="Process gene counts and metadata for PCA and LDA analysis.")
    parser.add_argument('--gene_counts', required=True,
                        help="Path to the gene count matrix file.")
    parser.add_argument('--metadata', required=True, help="Path to the metadata file.")
    parser.add_argument('--output_dir', required=True, help="Directory to save output files.")
    parser.add_argument('--mutual_information_threshold', type=_positive_int, default=10,
                        help="Top N highest mutual information genes. Default is 10.")
    parser.add_argument('--low_expression_threshold', type=int, default=10,
                        help="Threshold for filtering low-expression genes. Default is 10.")
    parser.add_argument('--n_repeats', type=_positive_int, default=10,
                        help="Number of repetitions for mutual information calculation. "
                             "Default is 10.")
    parser.add_argument('--n_variable_genes', type=_positive_int, default=50,
                        help="Number of most variable genes used for the first heatmap and "
                             "the mutual information ranking. Default is 50.")
    args = parser.parse_args(argv)
    if args.mutual_information_threshold < 2:
        parser.error("--mutual_information_threshold must be at least 2 "
                     "(the PCA plot uses two components).")
    return args


def _load_inputs(gene_counts_path, metadata_path):
    """Read the count matrix (TSV, first column as index) and metadata (CSV).

    Returns ``(counts, metadata)``, both restricted to the samples that appear
    in the count-matrix columns AND the metadata ``sample`` column, so every
    later lookup by sample name is guaranteed to succeed.
    """
    gene_counts = pd.read_csv(gene_counts_path, sep='\t', index_col=0)
    metadata = pd.read_csv(metadata_path)

    shared_samples = gene_counts.columns.intersection(metadata['sample'].tolist())
    if shared_samples.empty:
        raise ValueError("No samples are shared between the gene count matrix and the metadata.")
    counts = gene_counts.loc[:, shared_samples]
    metadata = metadata[metadata['sample'].isin(shared_samples)]
    return counts, metadata


def _log_cpm(counts, library_sizes):
    """Return log2(CPM + 1), scaling each sample column by its library size.

    ``library_sizes`` is a Series indexed by sample name (the column labels of
    ``counts``).
    """
    empty_samples = library_sizes[library_sizes <= 0].index.tolist()
    if empty_samples:
        raise ValueError(f"Samples with zero total counts cannot be CPM-normalised: {empty_samples}")
    # axis=1 aligns the Series with the columns, i.e. each sample is divided by its own total.
    cpm = counts.div(library_sizes, axis=1) * 1e6
    return np.log2(cpm + 1)


def _group_labels(metadata):
    """Return '<thermal.tolerance>-Day<day>' for each row of *metadata*."""
    return metadata['thermal.tolerance'].astype(str) + "-Day" + metadata['day'].astype(str)


def _mutual_information(expression, labels, n_repeats):
    """Average ``mutual_info_classif`` scores over ``n_repeats`` runs.

    ``expression`` has one row per sample (aligned with ``labels``) and one
    column per gene; the result holds one averaged score per gene. The calls
    use ``random_state=None``, so individual runs may differ; averaging smooths
    that out.
    """
    encoded_labels = LabelEncoder().fit_transform(labels)
    runs = [
        mutual_info_classif(expression, encoded_labels, discrete_features=False, random_state=None)
        for _ in range(n_repeats)
    ]
    return np.mean(runs, axis=0)


def _plot_variable_gene_heatmap(data_sorted, x_labels, n_genes, output_dir):
    """Heatmap of the most variable genes with samples in grouping order."""
    fig, ax = plt.subplots(figsize=(16, 10))
    sns.heatmap(data_sorted, cmap="viridis", yticklabels=False, xticklabels=x_labels,
                cbar_kws={"label": LOG_CPM_LABEL}, ax=ax)
    ax.set_title(f"Heatmap of Top {n_genes} Most Variable Genes "
                 "(Grouped by Thermal Resilience and Day)")
    ax.set_xlabel("Samples (Resilience-Day)")
    ax.set_ylabel("Genes")
    ax.tick_params(axis='x', labelrotation=90, labelsize=8)
    save_figure(fig, os.path.join(output_dir, f"heatmap_top_{n_genes}_variable_genes.png"),
                bbox_inches='tight')  # keep the rotated sample labels inside the image
    plt.close(fig)


def _plot_pca(data, sample_metadata, n_top, output_dir):
    """Two-component PCA of samples (columns of *data*), coloured by thermal tolerance."""
    samples = data.columns
    scaled = StandardScaler().fit_transform(data.T)  # samples become rows
    coords = PCA(n_components=2).fit_transform(scaled)

    color_map = {'resistant': 'blue', 'susceptible': 'red'}
    # Tolerance values outside color_map are drawn gray instead of raising KeyError.
    colors = [color_map.get(tol, 'gray')
              for tol in sample_metadata.loc[samples, 'thermal.tolerance']]
    days = sample_metadata.loc[samples, 'day'].tolist()

    fig, ax = plt.subplots(figsize=(12, 10))
    ax.scatter(coords[:, 0], coords[:, 1], c=colors, s=50, alpha=0.7)
    for (pc1, pc2), day in zip(coords, days):
        ax.text(pc1, pc2, f"Day {day}", fontsize=8, ha='right')
    ax.set_title(f"PCA of Samples Based on Top {n_top} Genes")
    ax.set_xlabel("PC1")
    ax.set_ylabel("PC2")
    ax.grid(True)
    legend_handles = [
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='blue', markersize=10,
                   label='Resistant'),
        plt.Line2D([0], [0], marker='o', color='w', markerfacecolor='red', markersize=10,
                   label='Susceptible'),
    ]
    ax.legend(handles=legend_handles, title="Thermal Tolerance")
    save_figure(fig, os.path.join(output_dir, f"PCA_top_{n_top}_mutual_info_genes.png"),
                bbox_inches='tight')
    plt.close(fig)


def _plot_mi_heatmap(data_sorted, sample_names, gene_names, n_top, output_dir):
    """Heatmap of the top mutual-information genes with samples in grouping order."""
    fig, ax = plt.subplots(figsize=(12, 8))
    sns.heatmap(data_sorted, cmap="viridis", xticklabels=sample_names, yticklabels=gene_names,
                cbar_kws={"label": LOG_CPM_LABEL}, ax=ax)
    ax.set_title(f"Heatmap of Top {n_top} Genes with Highest Mutual Information "
                 "(Grouped by Thermal Tolerance and Day)")
    ax.set_xlabel("Samples (Grouped by Resilience and Day)")
    ax.set_ylabel("Genes")
    ax.tick_params(axis='x', labelrotation=90, labelsize=8)
    save_figure(fig, os.path.join(output_dir, f"heatmap_top_{n_top}_mutual_info_genes.png"),
                bbox_inches='tight')
    plt.close(fig)


def _plot_mi_clustermap(data, n_top, output_dir):
    """Hierarchically clustered heatmap (genes and samples) of the top MI genes."""
    # clustermap builds its own figure; size it here instead of via plt.figure().
    grid = sns.clustermap(data, cmap="viridis", metric="euclidean", method="average",
                          col_cluster=True, row_cluster=True,
                          cbar_kws={"label": LOG_CPM_LABEL}, figsize=(12, 10))
    grid.fig.suptitle(
        f"Hierarchical Clustering Heatmap of Top {n_top} Mutual Information Genes")
    save_figure(grid.fig,
                os.path.join(output_dir,
                             f"heatmap_hierarchical_top_{n_top}_mutual_info_genes.png"),
                bbox_inches='tight')
    plt.close(grid.fig)


def main(argv=None):
    """Run the full pipeline described in the module docstring."""
    args = parse_args(argv)
    output_dir = args.output_dir
    if not output_dir.startswith(S3_PREFIX):
        os.makedirs(output_dir, exist_ok=True)

    counts, metadata = _load_inputs(args.gene_counts, args.metadata)

    # Filter low-expression genes (rows) by their summed count across the kept samples.
    expressed = counts[counts.sum(axis=1) > args.low_expression_threshold]
    if expressed.empty:
        raise ValueError("No genes pass the low-expression threshold.")
    # Library size = each sample's total count over all genes, before gene filtering.
    log_cpm = _log_cpm(expressed, counts.sum(axis=0))

    # Most variable genes across samples.
    n_variable = args.n_variable_genes
    top_variable_genes = log_cpm.var(axis=1).nlargest(n_variable).index

    # Order samples by tolerance then day; the labels share that order.
    metadata_sorted = metadata.sort_values(by=['thermal.tolerance', 'day'])
    sorted_samples = metadata_sorted['sample'].tolist()
    group_labels = _group_labels(metadata_sorted)
    top_variable_sorted = log_cpm.loc[top_variable_genes, sorted_samples]

    _plot_variable_gene_heatmap(top_variable_sorted, group_labels.tolist(), n_variable, output_dir)

    # Rows of the expression matrix follow sorted_samples, matching group_labels.
    avg_mi_scores = _mutual_information(top_variable_sorted.T.values, group_labels, args.n_repeats)
    mutual_info_results = pd.DataFrame({
        'Gene': top_variable_sorted.index,
        'Mutual_Information': avg_mi_scores,
    }).sort_values(by='Mutual_Information', ascending=False)
    mutual_info_results.to_csv(os.path.join(output_dir, "mutual_information_results.csv"),
                               index=False)

    n_top = args.mutual_information_threshold
    top_genes_mi = mutual_info_results['Gene'].iloc[:n_top].tolist()
    top_data_mi = log_cpm.loc[top_genes_mi]

    _plot_pca(top_data_mi, metadata.set_index('sample'), n_top, output_dir)

    top_data_mi_sorted = top_data_mi[sorted_samples]
    _plot_mi_heatmap(top_data_mi_sorted, sorted_samples, top_genes_mi, n_top, output_dir)
    _plot_mi_clustermap(top_data_mi_sorted, n_top, output_dir)
    # Additional plots and outputs can follow a similar pattern.


if __name__ == "__main__":
    main()