#!/usr/bin/env python3
"""
BAM-based variation and connectedness summary pipeline.

This complementary analysis consumes the per-sample alignments generated by
`01_align_and_visualize.py` and derives rapid metrics of genetic variation
without requiring joint VCF generation. Core analyses include:

1. `samtools stats` quality reports for each BAM.
2. Random-sampled reference positions (default 500 sites) where base counts,
   mismatch rates, and putative heterozygosity are estimated directly from BAMs.
3. Pairwise identity-by-state (IBS) similarity and PCA using consensus-based
   mismatch indicators to visualise connectedness across sampling locations.

All outputs, figures, and logs follow the conventions in `INSTRUCTIONS.md`,
writing to `output/02_bam_summary/`. The workflow is designed to respect the
available compute resources (≤50 CPUs, ≈200 MB RAM) by streaming over BAM files
sequentially and sampling a modest number of genomic sites.
"""
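# Example invocation (illustrative only; the script filename and flag values
# below are placeholders, and the defaults in `parse_args()` already match the
# layout produced by step 01):
#
#   python 02_bam_summary.py \
#       --alignment-dir output/01_align_and_visualize/alignments \
#       --num-sites 1000 \
#       --min-depth 5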
import argparse
import json
import logging
import multiprocessing
import os
import shutil
import subprocess
import sys
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Sequence, Tuple

import numpy as np
import pandas as pd

try:
    import pysam
except ImportError as err:  # pragma: no cover - runtime check only
    raise ImportError(
        "pysam is required for BAM-based analyses. Install it via `pip install pysam`."
    ) from err

BASES = np.array(["A", "C", "G", "T"])
MISSING_CODE = 4  # sentinel for missing base calls


@dataclass(frozen=True)
class Position:
    """Reference genomic position sampled for BAM interrogation."""

    contig: str
    position: int  # 0-based coordinate


def parse_args() -> argparse.Namespace:
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(
        description="Summarise variation directly from aligned BAM files."
    )
    parser.add_argument(
        "--alignment-dir",
        type=Path,
        default=Path("output/01_align_and_visualize/alignments"),
        help="Directory containing per-sample sorted BAM files.",
    )
    parser.add_argument(
        "--sample-metadata",
        type=Path,
        default=Path("output/01_align_and_visualize/metrics/sample_metadata.tsv"),
        help="Sample metadata with `sample_id` and `location` columns from step 01.",
    )
    parser.add_argument(
        "--reference-fai",
        type=Path,
        default=Path("output/01_align_and_visualize/reference/Olurida_v081.fa.fai"),
        help="Reference FASTA index produced during step 01.",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("output/02_bam_summary"),
        help="Destination directory for generated reports and figures.",
    )
    parser.add_argument(
        "--num-sites",
        type=int,
        default=500,
        help="Number of genomic sites to sample for variation estimates.",
    )
    parser.add_argument(
        "--min-depth",
        type=int,
        default=3,
        help="Minimum depth required at a sampled site to consider it informative.",
    )
    parser.add_argument(
        "--min-base-quality",
        type=int,
        default=20,
        help="Minimum base quality used when counting bases at sampled sites.",
    )
    parser.add_argument(
        "--heterozygous-alt-fraction",
        type=float,
        default=0.3,
        help="Minimum alternative allele fraction to classify a site as heterozygous.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducible site sampling.",
    )
    parser.add_argument(
        "--threads",
        type=int,
        default=min(50, multiprocessing.cpu_count()),
        help="Maximum parallel workers for optional concurrent steps (≤50).",
    )
    parser.add_argument(
        "--force",
        action="store_true",
        help="Recompute outputs even if they already exist.",
    )
    return parser.parse_args()


def configure_logging(log_path: Path) -> None:
    """Configure logging to file and stdout."""
    log_path.parent.mkdir(parents=True, exist_ok=True)
    log_format = "%(asctime)s [%(levelname)s] %(message)s"
    handlers = [
        logging.FileHandler(log_path, mode="w"),
        logging.StreamHandler(sys.stdout),
    ]
    logging.basicConfig(level=logging.INFO, format=log_format, handlers=handlers)


def check_dependencies(dependencies: Iterable[str]) -> None:
    """Ensure required external binaries are available."""
    missing = [exe for exe in dependencies if shutil.which(exe) is None]
    if missing:
        raise RuntimeError(
            "Missing required executables: "
            + ", ".join(missing)
            + ". Install them before rerunning."
        )


def to_relative_path(path: Path, base: Path) -> str:
    """Render `path` relative to `base` when possible."""
    try:
        return str(path.relative_to(base))
    except ValueError:
        return str(path)


def load_sample_metadata(sample_metadata_path: Path) -> pd.DataFrame:
    """Load metadata linking samples to sampling locations."""
    if not sample_metadata_path.exists():
        raise FileNotFoundError(f"Sample metadata not found: {sample_metadata_path}")
    metadata = pd.read_csv(sample_metadata_path, sep="\t")
    expected_cols = {"sample_id", "location"}
    missing = expected_cols - set(metadata.columns)
    if missing:
        raise ValueError(
            "Sample metadata missing required columns: " + ", ".join(sorted(missing))
        )
    return metadata
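# `load_sample_metadata` (above) expects a tab-separated table shaped like the
# sketch below; the column names are required, the values are hypothetical, and
# any extra columns are read but not used by this script.
#
#   sample_id    location
#   sample_01    Fidalgo_Bay
#   sample_02    Dabob_Bay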
def discover_bams(alignment_dir: Path, samples: Sequence[str]) -> List[Tuple[str, Path]]:
    """Build ordered list of BAM paths for the provided sample IDs."""
    bam_paths: List[Tuple[str, Path]] = []
    for sample_id in samples:
        bam_path = alignment_dir / f"{sample_id}.sorted.bam"
        if not bam_path.exists():
            raise FileNotFoundError(f"BAM file not found for sample {sample_id}: {bam_path}")
        index_candidates = [
            bam_path.with_suffix(".bai"),
            Path(str(bam_path) + ".bai"),
        ]
        if not any(candidate.exists() for candidate in index_candidates):
            raise FileNotFoundError(
                f"Index (.bai) missing for {bam_path}. Run `samtools index` first."
            )
        bam_paths.append((sample_id, bam_path))
    return bam_paths


def run_command(
    command: Iterable[str] | str,
    *,
    cwd: Path | None = None,
) -> None:
    """Execute a command with logging."""
    if isinstance(command, str):
        cmd_display = command
    else:
        cmd_display = " ".join(command)
    logging.info("Running command: %s", cmd_display)
    subprocess.run(command, cwd=str(cwd) if cwd else None, check=True)


def parse_samtools_stats(stats_path: Path, sample_id: str) -> Dict[str, float]:
    """Extract key summary metrics from a samtools stats file."""
    metrics: Dict[str, float] = {"sample_id": sample_id}
    with stats_path.open() as handle:
        for line in handle:
            if not line.startswith("SN"):
                continue
            parts = line.strip().split("\t", maxsplit=3)
            if len(parts) < 3:
                continue
            _, key, value = parts[:3]
            key_clean = key.rstrip(":")
            try:
                metrics[key_clean] = float(value)
            except ValueError:
                metrics[key_clean] = value
    return metrics


def run_samtools_stats(
    bam_paths: List[Tuple[str, Path]],
    stats_dir: Path,
    force: bool,
) -> pd.DataFrame:
    """Generate samtools stats reports and return aggregated metrics."""
    stats_dir.mkdir(parents=True, exist_ok=True)
    summary_records: List[Dict[str, float]] = []
    for sample_id, bam_path in bam_paths:
        output_path = stats_dir / f"{sample_id}.stats"
        if not output_path.exists() or force:
            logging.info("Running samtools stats for %s", sample_id)
            with output_path.open("w") as handle:
                subprocess.run(
                    ["samtools", "stats", str(bam_path)],
                    check=True,
                    stdout=handle,
                )
        summary_records.append(parse_samtools_stats(output_path, sample_id))
    summary_df = pd.DataFrame(summary_records)
    return summary_df


def read_reference_index(reference_fai: Path) -> List[Tuple[str, int]]:
    """Read FASTA index and return (contig, length) tuples."""
    if not reference_fai.exists():
        raise FileNotFoundError(f"Reference FAI not found: {reference_fai}")
    contigs: List[Tuple[str, int]] = []
    with reference_fai.open() as handle:
        for line in handle:
            fields = line.strip().split("\t")
            if len(fields) < 2:
                continue
            contigs.append((fields[0], int(fields[1])))
    if not contigs:
        raise ValueError(f"No contig information available in {reference_fai}")
    return contigs
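# Sites are drawn in proportion to contig length by `sample_positions` below.
# As a hypothetical example, a 900 kb contig and a 100 kb contig receive
# sampling probabilities of 0.9 and 0.1, so roughly 450 and 50 of the default
# 500 sites fall on them, respectively.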
def sample_positions(
    contigs: Sequence[Tuple[str, int]],
    num_sites: int,
    seed: int,
) -> List[Position]:
    """Sample genomic positions proportional to contig length."""
    rng = np.random.default_rng(seed)
    lengths = np.array([length for _, length in contigs], dtype=np.float64)
    probabilities = lengths / lengths.sum()
    contig_indices = rng.choice(len(contigs), size=num_sites, p=probabilities, replace=True)
    positions: List[Position] = []
    for idx in contig_indices:
        contig, length = contigs[idx]
        pos = int(rng.integers(0, length))
        positions.append(Position(contig=contig, position=pos))
    return positions


def collect_sample_base_metrics(
    bam_paths: List[Tuple[str, Path]],
    positions: List[Position],
    min_depth: int,
    min_base_quality: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    For each BAM and sampled position, collect major base calls, depth, and
    alternative allele fraction.

    Returns
    -------
    allele_matrix : ndarray
        (samples x sites) with base codes (0-3) or 4 for missing
    depth_matrix : ndarray
        (samples x sites) with total depth
    alt_fraction_matrix : ndarray
        (samples x sites) storing alt/(ref+alt)
    """
    num_samples = len(bam_paths)
    num_sites = len(positions)
    allele_matrix = np.full((num_samples, num_sites), MISSING_CODE, dtype=np.int8)
    depth_matrix = np.zeros((num_samples, num_sites), dtype=np.float32)
    alt_fraction_matrix = np.zeros((num_samples, num_sites), dtype=np.float32)
    for sample_idx, (sample_id, bam_path) in enumerate(bam_paths):
        logging.info("Collecting base counts for %s", sample_id)
        with pysam.AlignmentFile(bam_path, "rb") as bam:
            for pos_idx, position in enumerate(positions):
                try:
                    coverage = bam.count_coverage(
                        position.contig,
                        position.position,
                        position.position + 1,
                        quality_threshold=min_base_quality,
                    )
                except ValueError:
                    continue  # contig absent or out of range
                counts = np.array(
                    [int(base_counts[0]) for base_counts in coverage], dtype=np.int32
                )
                depth = counts.sum()
                depth_matrix[sample_idx, pos_idx] = depth
                if depth < min_depth:
                    continue
                major_idx = int(np.argmax(counts[:4]))
                major_count = counts[major_idx]
                alt_count = depth - major_count
                allele_matrix[sample_idx, pos_idx] = major_idx
                alt_fraction_matrix[sample_idx, pos_idx] = alt_count / depth if depth else 0.0
    return allele_matrix, depth_matrix, alt_fraction_matrix


def build_consensus(
    allele_matrix: np.ndarray,
    depth_matrix: np.ndarray,
    min_depth: int,
) -> np.ndarray:
    """Determine consensus (major) allele per site across all samples."""
    num_sites = allele_matrix.shape[1]
    consensus = np.full(num_sites, MISSING_CODE, dtype=np.int8)
    for site_idx in range(num_sites):
        valid_mask = (allele_matrix[:, site_idx] != MISSING_CODE) & (
            depth_matrix[:, site_idx] >= min_depth
        )
        if not np.any(valid_mask):
            continue
        counts = np.bincount(allele_matrix[valid_mask, site_idx], minlength=5)
        counts[MISSING_CODE] = 0
        consensus[site_idx] = int(np.argmax(counts[:4]))
    return consensus
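# The heterozygosity call in `compute_sample_metrics` below is a heuristic,
# not a genotype call: a site counts as putatively heterozygous when the
# non-major base fraction reaches `--heterozygous-alt-fraction`. With
# hypothetical numbers, a site covered by 10 reads of which 6 support the
# major base has an alt fraction of 0.4, which clears the default threshold
# of 0.3 and is therefore flagged.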
def compute_sample_metrics(
    sample_ids: Sequence[str],
    locations: Sequence[str],
    allele_matrix: np.ndarray,
    consensus: np.ndarray,
    depth_matrix: np.ndarray,
    alt_fraction_matrix: np.ndarray,
    min_depth: int,
    heterozygous_alt_fraction: float,
) -> pd.DataFrame:
    """Compute per-sample variation metrics."""
    num_samples, num_sites = allele_matrix.shape
    valid = (
        (allele_matrix != MISSING_CODE)
        & (consensus != MISSING_CODE)
        & (depth_matrix >= min_depth)
    )
    mismatch = (allele_matrix != consensus) & valid
    heterozygous = (alt_fraction_matrix >= heterozygous_alt_fraction) & valid
    observed_sites = valid.sum(axis=1)
    mismatch_rate = np.divide(
        mismatch.sum(axis=1),
        observed_sites,
        out=np.zeros(num_samples, dtype=float),
        where=observed_sites > 0,
    )
    heterozygosity_rate = np.divide(
        heterozygous.sum(axis=1),
        observed_sites,
        out=np.zeros(num_samples, dtype=float),
        where=observed_sites > 0,
    )
    mean_depth_selected = np.divide(
        (depth_matrix * valid).sum(axis=1),
        observed_sites,
        out=np.zeros(num_samples, dtype=float),
        where=observed_sites > 0,
    )
    mean_alt_fraction = np.divide(
        (alt_fraction_matrix * valid).sum(axis=1),
        observed_sites,
        out=np.zeros(num_samples, dtype=float),
        where=observed_sites > 0,
    )
    records = []
    for idx, sample_id in enumerate(sample_ids):
        records.append(
            {
                "sample_id": sample_id,
                "location": locations[idx],
                "sites_evaluated": int(observed_sites[idx]),
                "mismatch_rate": float(mismatch_rate[idx]),
                "heterozygosity_rate": float(heterozygosity_rate[idx]),
                "mean_depth_selected": float(mean_depth_selected[idx]),
                "mean_alt_fraction": float(mean_alt_fraction[idx]),
            }
        )
    return pd.DataFrame(records)


def summarise_by_location(sample_metrics: pd.DataFrame) -> pd.DataFrame:
    """Aggregate variation metrics at the location level."""
    summary = (
        sample_metrics.groupby("location")
        .agg(
            samples=("sample_id", "count"),
            mean_mismatch=("mismatch_rate", "mean"),
            sd_mismatch=("mismatch_rate", "std"),
            mean_heterozygosity=("heterozygosity_rate", "mean"),
            sd_heterozygosity=("heterozygosity_rate", "std"),
            mean_depth=("mean_depth_selected", "mean"),
            sd_depth=("mean_depth_selected", "std"),
        )
        .reset_index()
    )
    summary.fillna(0.0, inplace=True)
    return summary


def compute_ibs_matrix(
    allele_matrix: np.ndarray,
    consensus: np.ndarray,
    depth_matrix: np.ndarray,
    min_depth: int,
) -> np.ndarray:
    """Compute pairwise identity-by-state similarity matrix."""
    num_samples = allele_matrix.shape[0]
    ibs = np.full((num_samples, num_samples), np.nan, dtype=float)
    for i in range(num_samples):
        ibs[i, i] = 1.0
        for j in range(i + 1, num_samples):
            valid = (
                (allele_matrix[i] != MISSING_CODE)
                & (allele_matrix[j] != MISSING_CODE)
                & (consensus != MISSING_CODE)
                & (depth_matrix[i] >= min_depth)
                & (depth_matrix[j] >= min_depth)
            )
            valid_sites = np.count_nonzero(valid)
            if valid_sites == 0:
                continue
            matches = np.count_nonzero(allele_matrix[i, valid] == allele_matrix[j, valid])
            similarity = matches / valid_sites
            ibs[i, j] = ibs[j, i] = similarity
    return ibs
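# `compute_pca` below runs PCA on 0/1 mismatch indicators against the
# consensus: missing entries are imputed with the per-site mean, columns are
# centred, and the left singular vectors scaled by the singular values serve
# as principal components. A minimal sketch of the same idea on a toy matrix
# (values are made up):
#
#     >>> import numpy as np
#     >>> x = np.array([[0.0, 1.0], [0.0, 0.0], [1.0, 1.0]])
#     >>> x -= x.mean(axis=0)
#     >>> u, s, _ = np.linalg.svd(x, full_matrices=False)
#     >>> pcs = u * s  # rows are samples, columns are principal components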
def compute_pca(
    allele_matrix: np.ndarray,
    consensus: np.ndarray,
    sample_ids: Sequence[str],
    locations: Sequence[str],
    depth_matrix: np.ndarray,
    min_depth: int,
) -> pd.DataFrame:
    """Perform PCA on mismatch indicators relative to the consensus."""
    valid = (
        (allele_matrix != MISSING_CODE)
        & (consensus != MISSING_CODE)
        & (depth_matrix >= min_depth)
    )
    mismatch_matrix = np.where(
        valid,
        (allele_matrix != consensus).astype(float),
        np.nan,
    )
    col_means = np.nanmean(mismatch_matrix, axis=0)
    finite_cols = np.isfinite(col_means)
    mismatch_matrix = mismatch_matrix[:, finite_cols]
    col_means = col_means[finite_cols]
    if mismatch_matrix.shape[1] == 0:
        raise RuntimeError(
            "No informative sites detected for PCA. "
            "Increase `--num-sites` or lower `--min-depth`."
        )
    # Impute remaining missing entries with the per-site mean before centring.
    inds = np.where(np.isnan(mismatch_matrix))
    mismatch_matrix[inds] = np.take(col_means, inds[1])
    mismatch_matrix -= col_means
    u, s, _ = np.linalg.svd(mismatch_matrix, full_matrices=False)
    pcs = u[:, :3] * s[:3]
    df = pd.DataFrame(
        {
            "sample_id": sample_ids,
            "location": locations,
            "PC1": pcs[:, 0],
            "PC2": pcs[:, 1],
            "PC3": pcs[:, 2] if pcs.shape[1] > 2 else np.zeros(len(sample_ids)),
        }
    )
    return df


def plot_results(
    pca_df: pd.DataFrame,
    ibs_matrix: np.ndarray,
    sample_ids: Sequence[str],
    locations: Sequence[str],
    figure_dir: Path,
) -> Path:
    """Create PCA scatter and IBS heatmap figure."""
    figure_dir.mkdir(parents=True, exist_ok=True)
    try:
        import matplotlib.pyplot as plt
    except ImportError as err:  # pragma: no cover - runtime dependency
        raise ImportError("matplotlib is required for plotting.") from err
    unique_locations = sorted(set(locations))
    # `plt.get_cmap` keeps compatibility with newer matplotlib releases, where
    # `matplotlib.cm.get_cmap` has been removed.
    cmap = plt.get_cmap("tab20")
    location_colors = {loc: cmap(i % cmap.N) for i, loc in enumerate(unique_locations)}
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))
    ax_pca = axes[0]
    ax_pca.set_title("BAM-derived PCA (PC1 vs PC2)")
    for location in unique_locations:
        subset = pca_df[pca_df["location"] == location]
        ax_pca.scatter(
            subset["PC1"],
            subset["PC2"],
            label=location,
            s=35,
            alpha=0.85,
            color=location_colors.get(location, "#555555"),
        )
    ax_pca.set_xlabel("PC1")
    ax_pca.set_ylabel("PC2")
    ax_pca.legend(frameon=False, fontsize="small")
    ax_heatmap = axes[1]
    im = ax_heatmap.imshow(ibs_matrix, cmap="viridis", vmin=0, vmax=1)
    ax_heatmap.set_title("Pairwise IBS (BAM consensus)")
    ax_heatmap.set_xticks(range(len(sample_ids)))
    ax_heatmap.set_yticks(range(len(sample_ids)))
    ax_heatmap.set_xticklabels(sample_ids, rotation=90, fontsize=6)
    ax_heatmap.set_yticklabels(sample_ids, fontsize=6)
    fig.colorbar(im, ax=ax_heatmap, fraction=0.046, pad=0.04, label="IBS similarity")
    fig.tight_layout()
    figure_path = figure_dir / "bam_connectedness.png"
    fig.savefig(figure_path, dpi=300)
    plt.close(fig)
    return figure_path


def assemble_metadata(
    script_start: float,
    params: Dict[str, object],
    outputs: Dict[str, List[Path]],
    metadata_path: Path,
    repo_root: Path,
) -> None:
    """Write metadata.json describing the run."""
    metadata = {
        "script": Path(__file__).name,
        "date": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "runtime_seconds": time.time() - script_start,
        "parameters": params,
        "outputs": {
            key: [to_relative_path(path, repo_root) for path in paths]
            for key, paths in outputs.items()
        },
    }
    metadata_path.write_text(json.dumps(metadata, indent=2))
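# `assemble_metadata` above writes a metadata.json roughly shaped like the
# hypothetical excerpt below (paths are rendered relative to the repo root and
# the elided keys follow the same pattern):
#
#     {
#       "script": "02_bam_summary.py",
#       "date": "2024-01-01T00:00:00Z",
#       "runtime_seconds": 123.4,
#       "parameters": {"num_sites": 500, "min_depth": 3, "...": "..."},
#       "outputs": {"tables": ["output/02_bam_summary/tables/ibs_matrix.tsv"], "...": "..."}
#     }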
def main() -> None:
    args = parse_args()
    repo_root = Path(__file__).resolve().parents[1]
    os.chdir(repo_root)
    output_dir = (
        args.output_dir if args.output_dir.is_absolute() else repo_root / args.output_dir
    )
    stats_dir = output_dir / "stats"
    tables_dir = output_dir / "tables"
    figures_dir = output_dir / "figures"
    positions_dir = output_dir / "positions"
    logs_dir = output_dir / "logs"
    tmp_dir = output_dir / "tmp"
    for directory in [output_dir, stats_dir, tables_dir, figures_dir, positions_dir, tmp_dir]:
        directory.mkdir(parents=True, exist_ok=True)
    configure_logging(logs_dir / "pipeline.log")
    script_start = time.time()
    try:
        check_dependencies(["samtools"])
        metadata_df = load_sample_metadata(
            args.sample_metadata
            if args.sample_metadata.is_absolute()
            else repo_root / args.sample_metadata
        )
        sample_ids = metadata_df["sample_id"].tolist()
        locations = metadata_df["location"].tolist()
        bam_paths = discover_bams(
            args.alignment_dir
            if args.alignment_dir.is_absolute()
            else repo_root / args.alignment_dir,
            sample_ids,
        )
        samtools_stats_df = run_samtools_stats(
            bam_paths,
            stats_dir / "samtools_stats",
            force=args.force,
        )
        samtools_stats_df.to_csv(tables_dir / "samtools_stats_summary.tsv", sep="\t", index=False)
        contigs = read_reference_index(
            args.reference_fai
            if args.reference_fai.is_absolute()
            else repo_root / args.reference_fai
        )
        positions = sample_positions(contigs, args.num_sites, args.seed)
        positions_df = pd.DataFrame(
            {
                "contig": [pos.contig for pos in positions],
                "position": [pos.position for pos in positions],
            }
        )
        positions_path = positions_dir / "sampled_positions.tsv"
        positions_df.to_csv(positions_path, sep="\t", index=False)
        allele_matrix, depth_matrix, alt_fraction_matrix = collect_sample_base_metrics(
            bam_paths,
            positions,
            args.min_depth,
            args.min_base_quality,
        )
        consensus = build_consensus(allele_matrix, depth_matrix, args.min_depth)
        sample_metrics_df = compute_sample_metrics(
            sample_ids,
            locations,
            allele_matrix,
            consensus,
            depth_matrix,
            alt_fraction_matrix,
            args.min_depth,
            args.heterozygous_alt_fraction,
        )
        sample_metrics_path = tables_dir / "bam_variation_metrics_per_sample.tsv"
        sample_metrics_df.to_csv(sample_metrics_path, sep="\t", index=False)
        location_summary_df = summarise_by_location(sample_metrics_df)
        location_summary_path = tables_dir / "bam_variation_metrics_by_location.tsv"
        location_summary_df.to_csv(location_summary_path, sep="\t", index=False)
        ibs_matrix = compute_ibs_matrix(allele_matrix, consensus, depth_matrix, args.min_depth)
        ibs_df = pd.DataFrame(ibs_matrix, index=sample_ids, columns=sample_ids)
        ibs_matrix_path = tables_dir / "ibs_matrix.tsv"
        ibs_df.to_csv(ibs_matrix_path, sep="\t")
        pca_df = compute_pca(
            allele_matrix,
            consensus,
            sample_ids,
            locations,
            depth_matrix,
            args.min_depth,
        )
        pca_path = tables_dir / "bam_pca_components.tsv"
        pca_df.to_csv(pca_path, sep="\t", index=False)
        figure_path = plot_results(
            pca_df,
            ibs_matrix,
            sample_ids,
            locations,
            figures_dir,
        )
        assemble_metadata(
            script_start,
            params={
                "alignment_dir": to_relative_path(
                    args.alignment_dir
                    if args.alignment_dir.is_absolute()
                    else repo_root / args.alignment_dir,
                    repo_root,
                ),
                "sample_metadata": to_relative_path(
                    args.sample_metadata
                    if args.sample_metadata.is_absolute()
                    else repo_root / args.sample_metadata,
                    repo_root,
                ),
                "reference_fai": to_relative_path(
                    args.reference_fai
                    if args.reference_fai.is_absolute()
                    else repo_root / args.reference_fai,
                    repo_root,
                ),
                "num_sites": args.num_sites,
                "min_depth": args.min_depth,
                "min_base_quality": args.min_base_quality,
                "heterozygous_alt_fraction": args.heterozygous_alt_fraction,
                "seed": args.seed,
                "force": args.force,
            },
            outputs={
                "stats": [stats_dir / "samtools_stats"]
                + list((stats_dir / "samtools_stats").glob("*.stats")),
                "tables": [
                    tables_dir / "samtools_stats_summary.tsv",
                    sample_metrics_path,
                    location_summary_path,
                    ibs_matrix_path,
                    pca_path,
                ],
                "figures": [figure_path],
                "positions": [positions_path],
                "logs": [logs_dir / "pipeline.log"],
            },
            metadata_path=output_dir / "metadata.json",
            repo_root=repo_root,
        )
        logging.info("BAM-based variation summary complete. Outputs at %s", output_dir)
    except Exception as exc:  # pylint: disable=broad-except
        logging.exception("BAM summary pipeline failed: %s", exc)
        raise


if __name__ == "__main__":
    main()