Metadata-Version: 2.4
Name: checkdyn
Version: 0.1.2
Summary: Modeling treatment-induced immune checkpoint dynamics
License: MIT
Requires-Python: >=3.9
Description-Content-Type: text/markdown
Requires-Dist: torch>=2.0
Requires-Dist: pandas>=2.0
Requires-Dist: numpy>=1.24
Requires-Dist: scikit-learn>=1.4
Requires-Dist: scanpy>=1.10
Requires-Dist: anndata>=0.10
Requires-Dist: lifelines>=0.28
Requires-Dist: networkx>=3.0
Requires-Dist: matplotlib>=3.8
Requires-Dist: seaborn>=0.13
Requires-Dist: scipy>=1.12
Requires-Dist: statsmodels>=0.14
Requires-Dist: pycomBat>=0.3
Requires-Dist: pyyaml>=6.0
Requires-Dist: typer>=0.9
Requires-Dist: rich>=13.0
Requires-Dist: tqdm>=4.66
Requires-Dist: GEOparse>=2.0
Provides-Extra: dev
Requires-Dist: pytest>=8.0; extra == "dev"
Requires-Dist: pytest-cov>=5.0; extra == "dev"
Requires-Dist: ruff>=0.4; extra == "dev"

# CheckDyn

**CheckDyn** is a computational biology toolkit for modeling treatment-induced immune checkpoint dynamics and predicting adaptive resistance to immune checkpoint blockade (ICB) therapy.

It implements an analysis pipeline that runs from paired pre/post-treatment transcriptomic data to a clinically oriented resistance risk score. The pipeline covers differential expression, compensatory upregulation detection, network rewiring analysis, and a Transformer-based prediction model.

---

## Table of Contents

- [Installation](#installation)
- [Quick Start](#quick-start)
- [CLI Reference](#cli-reference)
  - [checkdyn demo](#checkdyn-demo)
  - [checkdyn datasets](#checkdyn-datasets)
  - [checkdyn collect](#checkdyn-collect)
  - [checkdyn analyze](#checkdyn-analyze)
  - [checkdyn predict](#checkdyn-predict)
  - [checkdyn train](#checkdyn-train)
- [Python API](#python-api)
  - [Data ingestion](#1-data-ingestion)
  - [Paired dynamics analysis](#2-paired-dynamics-analysis)
  - [Compensatory upregulation detection](#3-compensatory-upregulation-detection)
  - [Network rewiring analysis](#4-network-rewiring-analysis)
  - [Model training & evaluation](#5-model-training--evaluation)
  - [Resistance risk prediction](#6-resistance-risk-prediction)
  - [Biological validation](#7-biological-validation)
  - [Visualization](#8-visualization)
- [Input Data Format](#input-data-format)
  - [AnnData (.h5ad) format](#anndata-h5ad-format)
  - [Patient CSV format](#patient-csv-format)
- [Sample Data](#sample-data)
- [Checkpoint Gene Panel](#checkpoint-gene-panel)
- [Architecture Overview](#architecture-overview)
- [Supported Datasets](#supported-datasets)
- [Dependencies](#dependencies)
- [License](#license)

---

## Installation

```bash
pip install checkdyn
```

**Requirements:** Python ≥ 3.9, PyTorch ≥ 2.0

---

## Quick Start

The fastest way to explore the full pipeline is the `demo` command, which generates synthetic data and runs everything end-to-end:

```bash
checkdyn demo --output demo_out/
```

This creates:

```
demo_out/
├── demo_data.h5ad          ← synthetic paired AnnData (60 samples × 42 genes)
├── sample_patients.csv     ← pre-treatment CSV for `checkdyn predict`
├── de_results.csv          ← differential expression results
├── compensation.csv        ← compensatory checkpoint candidates
├── validation_report.json  ← biological validation report
└── volcano_all.png         ← volcano plot
```

Then predict resistance risk for the generated patients:

```bash
checkdyn predict --input demo_out/sample_patients.csv
```

To re-run the analysis on an `.h5ad` file (the demo output here; substitute your own cohort):

```bash
checkdyn analyze --input demo_out/demo_data.h5ad --output demo_out/results/
```

---

## CLI Reference

All commands support `--help` for full option descriptions.

---

### `checkdyn demo`

Generate synthetic paired data and run the full pipeline. Use this to verify your installation and understand expected file formats.

```bash
checkdyn demo [OPTIONS]

Options:
  --output     -o  PATH   Output directory  [default: demo_out/]
  --n-patients     INT    Number of synthetic patients  [default: 30]
  --seed           INT    Random seed  [default: 42]
```

**Example:**

```bash
checkdyn demo --output my_demo/ --n-patients 50
```

**What it does:**

1. Generates `n_patients` synthetic patients (half responders, half non-responders)
2. Simulates compensatory upregulation of HAVCR2, LAG3, TIGIT in non-responders
3. Runs paired differential expression analysis
4. Detects compensatory checkpoints
5. Runs biological validation
6. Saves all outputs + volcano plot

---

### `checkdyn datasets`

List the paired pre/post-treatment immune checkpoint inhibitor (ICI) cohorts in the built-in dataset registry.

```bash
checkdyn datasets
```

**Output:** A table with dataset name, cancer type, treatment, patient counts, platform, and reference.

---

### `checkdyn collect`

Download paired pre/post-treatment transcriptomic datasets from GEO.

```bash
checkdyn collect [OPTIONS]

Options:
  --output  -o  PATH    Output directory  [default: data/]
  --datasets -d STRING  Comma-separated dataset names (default: all 8 datasets)
```

**Examples:**

```bash
# Download all 8 datasets
checkdyn collect --output data/

# Download specific datasets
checkdyn collect --output data/ --datasets Riaz2017,Gide2019,Hanna2018
```

**Downloaded to:** `<output>/raw/<DatasetName>/`

**Available datasets:** Riaz2017, Gide2019, Chen2016, Sade-Feldman2018, Yost2019, Caushi2021, Hanna2018, Miao2018

> **Note:** Raw GEO data is downloaded as SOFT files via GEOparse. Further preprocessing (SOFT → h5ad) is required before `checkdyn analyze`. See [Python API – Data ingestion](#1-data-ingestion).
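
The harmonization this note refers to is handled by `TranscriptomeHarmonizer` in the Python API. To illustrate two of its steps (gene-symbol intersection and `log1p` normalization), here is a minimal pandas sketch. It is not CheckDyn's implementation: it assumes each cohort has already been parsed into a genes × samples `DataFrame` of non-negative values (e.g. CPM or TPM), skips batch correction, and `harmonize_cohorts` is a hypothetical helper name.

```python
import numpy as np
import pandas as pd


def harmonize_cohorts(datasets: dict[str, pd.DataFrame]) -> pd.DataFrame:
    """Hypothetical helper: intersect genes across cohorts, log1p-normalize,
    and return one samples x genes table with a 'dataset' column."""
    # Keep only genes measured in every cohort (gene symbols are the row index)
    shared = sorted(set.intersection(*(set(df.index) for df in datasets.values())))
    frames = []
    for name, df in datasets.items():
        # Transpose to samples x genes, then apply log1p to the raw values
        sub = np.log1p(df.loc[shared].T.astype(float))
        sub.insert(0, "dataset", name)
        frames.append(sub)
    return pd.concat(frames)
```

Genes missing from any cohort are dropped rather than imputed, which keeps the combined matrix free of placeholders at the cost of a smaller gene set.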

---

### `checkdyn analyze`

Run the full paired checkpoint dynamics analysis on an AnnData file.

```bash
checkdyn analyze [OPTIONS]

Options:
  --input   -i  PATH     AnnData .h5ad file (required)
  --output  -o  PATH     Output directory  [default: results/]
  --genes   -g  STRING   Comma-separated gene list override (default: 42-gene panel)
  --no-plots             Skip saving figures
```

**Examples:**

```bash
# Basic usage
checkdyn analyze --input data/processed/cohort.h5ad --output results/

# Custom gene list
checkdyn analyze --input cohort.h5ad --genes PDCD1,CD274,HAVCR2,LAG3,TIGIT

# Skip figures (faster, headless environments)
checkdyn analyze --input cohort.h5ad --no-plots
```

**Required `.obs` columns in h5ad:**

| Column | Values | Description |
|--------|--------|-------------|
| `paired_id` | string | Links pre/post samples for the same patient |
| `timepoint` | `"pre"` / `"post"` | Treatment timepoint |
| `response` | `"R"` / `"NR"` | Treatment response label |
| `cancer_type` | string | Cancer type (e.g. `"melanoma"`) |
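
Before running `checkdyn analyze`, it can help to confirm that `.obs` meets these requirements. The pandas sketch below is a hypothetical pre-flight check (`check_paired_obs` is not part of CheckDyn). It assumes each `paired_id` has exactly one `"pre"` and one `"post"` sample and that a patient's response label is the same at both timepoints.

```python
import pandas as pd

REQUIRED = ["paired_id", "timepoint", "response", "cancer_type"]


def check_paired_obs(obs: pd.DataFrame) -> list:
    """Hypothetical pre-flight check; returns a list of problems (empty if OK)."""
    problems = [f"missing column: {c}" for c in REQUIRED if c not in obs.columns]
    if problems:
        return problems
    for col, allowed in (("timepoint", {"pre", "post"}), ("response", {"R", "NR"})):
        unexpected = set(obs[col]) - allowed
        if unexpected:
            problems.append(f"unexpected {col} values: {sorted(map(str, unexpected))}")
    # Each patient should contribute exactly one pre and one post sample
    counts = obs.groupby(["paired_id", "timepoint"], observed=True).size().unstack(fill_value=0)
    counts = counts.reindex(columns=["pre", "post"], fill_value=0)
    mask = (counts["pre"] != 1) | (counts["post"] != 1)
    unpaired = list(counts.index[mask.to_numpy()])
    if unpaired:
        problems.append(f"paired_ids without exactly one pre/post sample: {unpaired}")
    # A patient's response label should be identical at both timepoints
    n_labels = obs.groupby("paired_id", observed=True)["response"].nunique()
    conflicting = list(n_labels.index[(n_labels > 1).to_numpy()])
    if conflicting:
        problems.append(f"paired_ids with conflicting response labels: {conflicting}")
    return problems
```

For an AnnData object, pass `adata.obs`; an empty list means the basic pairing structure looks sound.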

**Outputs:**

| File | Description |
|------|-------------|
| `de_results.csv` | Paired differential expression for all checkpoint genes |
| `compensation.csv` | Compensatory upregulation candidates |
| `validation_report.json` | Biological plausibility validation |
| `volcano_all.png` | Volcano plot (log2FC vs −log10(padj)) |

**`de_results.csv` columns:**

| Column | Description |
|--------|-------------|
| `gene` | HGNC gene symbol |
| `log2FC_all` | Overall paired log2 fold change (post/pre) |
| `pval_all` | Wilcoxon signed-rank p-value |
| `padj_all` | BH-corrected FDR |
| `log2FC_R` | log2FC in responders |
| `log2FC_NR` | log2FC in non-responders |
| `padj_R` / `padj_NR` | BH-corrected FDR per group |
| `direction` | `"up"` / `"down"` / `"stable"` |
| `R_vs_NR_interaction_p` | R vs NR delta interaction p-value (Mann-Whitney U) |
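
The per-gene statistics in this table can be approximated as follows. This is a minimal sketch, not CheckDyn's code path. It assumes `pre` and `post` are already aligned by `paired_id` and on a log2 scale, so a per-patient log2FC is simply `post - pre`. It summarizes the fold change as the mean paired difference and uses SciPy's `wilcoxon` and `mannwhitneyu` plus a hand-written Benjamini-Hochberg step. `paired_gene_stats` and `bh_fdr` are hypothetical names.

```python
import numpy as np
from scipy.stats import mannwhitneyu, wilcoxon


def bh_fdr(pvals):
    """Benjamini-Hochberg adjusted p-values (FDR), returned in input order."""
    p = np.asarray(pvals, dtype=float)
    n = p.size
    order = np.argsort(p)
    ranked = p[order] * n / np.arange(1, n + 1)
    # Enforce monotonicity from the largest p-value downwards, then cap at 1
    adjusted = np.minimum.accumulate(ranked[::-1])[::-1]
    out = np.empty(n)
    out[order] = np.minimum(adjusted, 1.0)
    return out


def paired_gene_stats(pre, post, is_responder):
    """Paired statistics for one gene; inputs are aligned 1-D arrays (one entry per patient)."""
    pre, post = np.asarray(pre, float), np.asarray(post, float)
    is_responder = np.asarray(is_responder, bool)
    delta = post - pre  # per-patient log2FC when values are log2-scaled
    return {
        "log2FC_all": delta.mean(),
        "pval_all": wilcoxon(post, pre).pvalue,  # Wilcoxon signed-rank test
        "log2FC_R": delta[is_responder].mean(),
        "log2FC_NR": delta[~is_responder].mean(),
        # R vs NR interaction: compare the per-patient deltas between groups
        "R_vs_NR_interaction_p": mannwhitneyu(
            delta[is_responder], delta[~is_responder], alternative="two-sided"
        ).pvalue,
    }
```

Running `paired_gene_stats` for every gene and then passing the collected `pval_all` values through `bh_fdr` gives a `padj_all`-style column.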

---

### `checkdyn predict`

Predict adaptive resistance risk for patients from pre-treatment expression data.

```bash
checkdyn predict [OPTIONS]

Options:
  --input        PATH     CSV file with patient expression (required)
  --output  -o   PATH     Output CSV  (default: print table to terminal)
  --model   -m   PATH     Model checkpoint (.pt) or "untrained"  [default: untrained]
  --cancer-type  STRING   Cancer type  [default: other]
                          Options: melanoma, nsclc, bcc, hnscc, rcc, other
  --treatment    STRING   Treatment  [default: anti-pd-1]
                          Options: anti-pd-1, anti-ctla-4, anti-pd-1+anti-ctla-4, other
  --show-format           Print expected CSV column format and exit
```

**Examples:**

```bash
# Check expected format
checkdyn predict --show-format

# Predict (demo mode, untrained model)
checkdyn predict --input demo_out/sample_patients.csv --cancer-type melanoma

# Predict with trained model, save output
checkdyn predict \
  --input patients.csv \
  --model models/checkdyn_best.pt \
  --cancer-type nsclc \
  --treatment anti-pd-1 \
  --output results/risk_predictions.csv
```

**Input CSV format:** See [Patient CSV format](#patient-csv-format).

**Output columns:**

| Column | Description |
|--------|-------------|
| `patient_id` | Patient identifier (from CSV or row index) |
| `risk_score` | Continuous resistance risk score (0–1) |
| `risk_category` | `"low"` / `"medium"` / `"high"` |
| `predicted_compensation_pattern` | One of 5 compensatory patterns (see below) |
| `top_compensatory_checkpoints` | Top 3 genes ranked by predicted delta (written to the CSV as `gene(+delta)` strings) |

**Compensation patterns:**

| Pattern | Name | Biology |
|---------|------|---------|
| A | TIM3_LAG3_coUpregulation | Deepened T cell exhaustion |
| B | VISTA_B7H3_escape | Novel checkpoint escape |
| C | CD47_SIRPa_innate_switch | Innate immune checkpoint switch |
| D | IDO1_CD73_metabolic_switch | Metabolic checkpoint switch |
| E | minimal_change | Non-compensatory resistance |

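> **Note on untrained model:** Without a trained checkpoint, `risk_score` values come from an untrained network and carry no predictive meaning; use this mode only to check that your input format works. To obtain a model, run `checkdyn train` (for a quick test, on the output of `checkdyn demo`). A model trained only on synthetic demo data is not clinically meaningful either.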

---

### `checkdyn train`

Train the CheckpointDynamicsTransformer on your paired expression data.

```bash
checkdyn train [OPTIONS]

Options:
  --input   -i  PATH   AnnData .h5ad file (required)
  --output  -o  PATH   Directory for model checkpoints  [default: models/]
  --config       PATH  YAML model config (default: configs/model_config.yaml)
  --epochs       INT   Override max_epochs in config
```

**Example:**

```bash
# Train on demo data (quick test)
checkdyn train --input demo_out/demo_data.h5ad --output models/ --epochs 20

# Full training
checkdyn train --input data/processed/all_cohorts.h5ad --output models/
```

**Saves:** `models/checkdyn_best.pt` — use with `checkdyn predict --model models/checkdyn_best.pt`.

**Training details:**
- Optimizer: AdamW, lr=1e-3
- Scheduler: CosineAnnealingLR
- Multi-task loss: delta prediction (MSE) + response classification (BCE) + compensation pattern (CE)
- Data augmentation: Gaussian noise, gene dropout, mixup
- Early stopping (patience=30)
- Target: response AUROC ≥ 0.75
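
The three augmentations listed above are standard techniques. The NumPy sketch below shows each one applied to a batch of samples × genes expression vectors. The noise scale, dropout rate and mixup `alpha` are illustrative values, not CheckDyn's configured defaults, and the function names are hypothetical.

```python
import numpy as np


def gaussian_noise(x, rng, sigma=0.1):
    """Add zero-mean Gaussian noise to every expression value."""
    return x + rng.normal(0.0, sigma, size=x.shape)


def gene_dropout(x, rng, rate=0.1):
    """Zero out a random subset of genes, independently for each sample."""
    keep = rng.random(x.shape) >= rate
    return x * keep


def mixup(x, y, rng, alpha=0.2):
    """Blend each sample (and its target) with a randomly chosen partner in the batch."""
    lam = rng.beta(alpha, alpha)
    perm = rng.permutation(len(x))
    return lam * x + (1 - lam) * x[perm], lam * y + (1 - lam) * y[perm]
```

Mixup blends targets as well as inputs, so categorical targets such as the compensation pattern would need to be one-hot encoded before mixing.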

---

## Python API

### 1. Data ingestion

```python
from checkdyn.data import PairedICIDataCollector, TranscriptomeHarmonizer, DataSplitter

# ── Dataset registry ──────────────────────────────────────────────────────────
collector = PairedICIDataCollector(output_dir="data/raw")

# View all registered datasets
df = collector.create_dataset_summary()
print(df[["dataset", "cancer_type", "treatment", "n_paired_samples", "platform"]])

# Download a single dataset (requires GEOparse + internet)
collector.download_dataset("Riaz2017")   # → data/raw/Riaz2017/

# Download all (skips already-downloaded)
results = collector.download_all()

# Search GEO for additional paired cohorts
extra = collector.search_additional_datasets()  # returns List[dict]

# ── Harmonization ─────────────────────────────────────────────────────────────
# datasets_dict: {name: pd.DataFrame (genes × samples)}
harmonizer = TranscriptomeHarmonizer(datasets_dict)

datasets_unified = harmonizer.unify_gene_symbols()     # intersect gene sets
datasets_norm    = harmonizer.normalize_expression()    # log1p normalization
batch_corrected  = harmonizer.batch_correction()        # pyComBat
adata            = harmonizer.create_unified_dataset()  # → AnnData

# Validate batch correction preserved biological signal
report = harmonizer.validate_batch_correction(adata)
# → {"pdl1_signal_preserved": True, "batch_silhouette": 0.12, "n_samples": 180}

# ── Splitting ─────────────────────────────────────────────────────────────────
splitter = DataSplitter()

# Dataset-level split
splits = splitter.split_by_dataset(adata,
    train_datasets=["Riaz2017", "Gide2019", "Chen2016"],
    val_datasets=["Hanna2018"],
    test_datasets=["Yost2019"])
# splits = {"train": AnnData, "val": AnnData, "test": AnnData}

# Patient-level stratified split (maintains paired integrity)
splits = splitter.split_by_patient(adata, test_size=0.2, val_size=0.1)

# Leave-one-cancer-out cross-validation
for train_adata, test_adata, cancer_type in splitter.leave_one_cancer_out(adata):
    print(f"Testing on {cancer_type}: {test_adata.n_obs} samples")
```
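
A patient-level split must keep each patient's pre and post samples in the same partition. Otherwise the model could see a patient's post-treatment profile during training and be tested on the same patient. One way to get that property is to split unique `paired_id`s, stratified by response, and then map the assignment back to samples. The sketch below is an assumption about the general approach, not the `DataSplitter` implementation, and `split_patients` is a hypothetical name.

```python
import numpy as np
import pandas as pd


def split_patients(obs, test_size=0.2, val_size=0.1, seed=0):
    """Assign every sample to train/val/test by patient, stratified on response."""
    rng = np.random.default_rng(seed)
    # One response label per patient (pre and post samples share it)
    patients = obs.groupby("paired_id", observed=True)["response"].first()
    split = pd.Series("train", index=patients.index)
    for _, ids in patients.groupby(patients).groups.items():
        ids = rng.permutation(np.asarray(ids))
        n_test = int(round(test_size * len(ids)))
        n_val = int(round(val_size * len(ids)))
        split.loc[ids[:n_test]] = "test"
        split.loc[ids[n_test:n_test + n_val]] = "val"
    # Map the patient-level assignment back onto every sample
    return obs["paired_id"].map(split)
```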

---

### 2. Paired dynamics analysis

```python
import anndata as ad
from checkdyn.dynamics import PairedCheckpointDynamics
from checkdyn.utils import load_checkpoint_genes

adata = ad.read_h5ad("data/sample/demo_paired.h5ad")
genes = load_checkpoint_genes()   # 42-gene panel

analyzer = PairedCheckpointDynamics(adata, genes)

# ── Paired differential expression ───────────────────────────────────────────
de = analyzer.paired_differential_expression()
# Columns: gene, log2FC_all, pval_all, padj_all,
#          log2FC_R, padj_R, log2FC_NR, padj_NR,
#          direction ("up"/"down"/"stable"), R_vs_NR_interaction_p

# Significant genes
sig = de[de["padj_all"] < 0.05].sort_values("log2FC_all", ascending=False)
print(sig[["gene", "log2FC_all", "padj_all", "direction"]].head(10))

# ── Volcano plot data ─────────────────────────────────────────────────────────
volcano = analyzer.volcano_plot_data()
# Keys: "all", "responders", "non_responders"
# Each DataFrame: gene, log2FC, neg_log10_p, padj, is_significant

print(volcano["non_responders"][volcano["non_responders"]["is_significant"]])

# ── Response-stratified dynamics (LME model) ──────────────────────────────────
lme = analyzer.response_stratified_dynamics()
# Columns: gene, interaction_coef, interaction_p, padj_interaction,
#          interpretation ("ns" / "R_specific_up" / "NR_specific_up")

# ── Treatment-stratified dynamics ─────────────────────────────────────────────
treatment_de = analyzer.treatment_specific_dynamics()
# Columns: gene, treatment, log2FC, pval, padj
```

---

### 3. Compensatory upregulation detection

```python
from checkdyn.dynamics import CompensatoryUpregulationDetector

detector = CompensatoryUpregulationDetector(de, adata)

# ── Identify compensatory checkpoints ────────────────────────────────────────
# Criteria: log2FC_all > 0.5 AND padj_all < 0.05 AND R_vs_NR_interaction_p < 0.05
comp = detector.define_compensation()
# Columns: compensatory_checkpoint, log2FC_NR, log2FC_R, interaction_p, compensation_score

# ── Cluster patients into 5 compensation patterns ────────────────────────────
patterns = detector.compensation_pattern_clustering(n_clusters=5)
# {cluster_id: {pattern_name, patient_list, characteristic_checkpoints, frequency}}

for cid, meta in patterns.items():
    print(f"Pattern {cid}: {meta['pattern_name']}")
    print(f"  Patients: {len(meta['patient_list'])}")
    print(f"  Key checkpoints: {meta['characteristic_checkpoints']}")

# ── Co-occurrence network ─────────────────────────────────────────────────────
import networkx as nx
G = detector.compensation_network()   # nx.DiGraph, edges weighted by Spearman r
print(f"Network: {G.number_of_nodes()} nodes, {G.number_of_edges()} edges")

# ── Clinical impact (requires pfs_months / pfs_event in adata.obs) ───────────
impact = detector.clinical_impact_of_compensation()
# Columns: pattern, n_patients, median_pfs, logrank_p
```
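
The selection criteria in the comment above translate directly into a pandas filter on the `de` table. The sketch below applies those thresholds. As a rough stand-in for the co-occurrence network, it also computes Spearman correlations between per-patient deltas. It is illustrative only: the helper names are hypothetical, the edge threshold of 0.5 is an assumption, and it returns an edge list rather than CheckDyn's `nx.DiGraph`.

```python
import pandas as pd


def compensatory_candidates(de, fc=0.5, alpha=0.05):
    """Apply the documented criteria to a de_results-style DataFrame."""
    mask = (
        (de["log2FC_all"] > fc)
        & (de["padj_all"] < alpha)
        & (de["R_vs_NR_interaction_p"] < alpha)
    )
    return de.loc[mask, "gene"].tolist()


def delta_coexpression_edges(deltas, min_abs_r=0.5):
    """Spearman r between genes' per-patient deltas (patients x genes DataFrame)."""
    corr = deltas.corr(method="spearman")
    genes = list(corr.columns)
    edges = [
        (g1, g2, corr.loc[g1, g2])
        for i, g1 in enumerate(genes)
        for g2 in genes[i + 1:]
        if abs(corr.loc[g1, g2]) >= min_abs_r
    ]
    return pd.DataFrame(edges, columns=["gene1", "gene2", "spearman_r"])
```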

---

### 4. Network rewiring analysis

```python
from checkdyn.dynamics import CheckpointNetworkRewiring

rewirer = CheckpointNetworkRewiring(adata, genes)

# ── Build pre/post co-expression networks ────────────────────────────────────
G_pre, G_post = rewirer.build_pre_post_networks()

# ── Detect rewired edges ──────────────────────────────────────────────────────
rewired = rewirer.detect_rewired_edges()
# Columns: gene1, gene2, r_pre, r_post, delta_r, p_value, edge_type
# edge_type: "new" / "lost" / "strengthened" / "weakened" / "stable"

print(rewired[rewired["edge_type"] == "new"][["gene1", "gene2", "r_post"]].head(10))

# ── Hub gene shift ────────────────────────────────────────────────────────────
hubs = rewirer.hub_gene_shift()
# Columns: gene, degree_pre, degree_post, delta_degree,
#          betweenness_pre, betweenness_post, delta_betweenness

# Top newly central genes post-treatment
print(hubs.sort_values("delta_degree", ascending=False).head(5))

# ── Differential network by response ─────────────────────────────────────────
diff = rewirer.differential_network_by_response()
# Keys: R_pre, R_post, NR_pre, NR_post (nx.Graph), R, NR, differential_edges (DataFrame)
```
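
The five `edge_type` labels can be illustrated with a simple rule set applied to pre- and post-treatment Spearman correlations. The thresholds here are assumptions for illustration: |r| ≥ 0.5 to call an edge present, and a change in |r| of at least 0.2 to call it strengthened or weakened. The sketch also omits the `p_value` column that the real method reports. `classify_edge` and `rewired_edges` are hypothetical names.

```python
import pandas as pd


def classify_edge(r_pre, r_post, present=0.5, min_change=0.2):
    """Label an edge using hypothetical thresholds on |r| and on the change in |r|."""
    was, now = abs(r_pre) >= present, abs(r_post) >= present
    if now and not was:
        return "new"
    if was and not now:
        return "lost"
    change = abs(r_post) - abs(r_pre)
    if was and now and change >= min_change:
        return "strengthened"
    if was and now and change <= -min_change:
        return "weakened"
    return "stable"  # includes pairs that are below threshold at both timepoints


def rewired_edges(pre, post):
    """pre/post: samples x genes DataFrames with matching gene columns."""
    r_pre = pre.corr(method="spearman")
    r_post = post.corr(method="spearman")
    genes = list(r_pre.columns)
    rows = []
    for i, g1 in enumerate(genes):
        for g2 in genes[i + 1:]:
            a, b = r_pre.loc[g1, g2], r_post.loc[g1, g2]
            rows.append((g1, g2, a, b, b - a, classify_edge(a, b)))
    return pd.DataFrame(
        rows, columns=["gene1", "gene2", "r_pre", "r_post", "delta_r", "edge_type"]
    )
```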

---

### 5. Model training & evaluation

```python
import yaml
from checkdyn.model import CheckpointDynamicsTransformer, CheckDynTrainer, CheckDynEvaluator
from checkdyn.data import DataSplitter

# ── Build model ───────────────────────────────────────────────────────────────
model = CheckpointDynamicsTransformer(
    n_checkpoints=42,       # number of checkpoint genes
    n_context_genes=200,    # tumour microenvironment context genes
    d_model=128,            # transformer embedding dimension
    nhead=4,                # attention heads
    num_layers=3,           # transformer encoder layers
    n_cancer_types=8,       # distinct cancer type embeddings
    n_treatments=4,         # distinct treatment embeddings
)
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

# ── Load config ───────────────────────────────────────────────────────────────
with open("configs/model_config.yaml") as f:
    cfg = yaml.safe_load(f)

# ── Split data ────────────────────────────────────────────────────────────────
splitter = DataSplitter()
splits = splitter.split_by_patient(adata, test_size=0.2, val_size=0.1)

# ── Prepare data loaders ──────────────────────────────────────────────────────
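# IDs follow the cancer_type_id / treatment_id scheme used by AdaptiveResistancePredictor
cancer_map    = {"melanoma": 0, "nsclc": 1, "bcc": 2, "hnscc": 3, "rcc": 4, "other": 5}
treatment_map = {"anti-PD-1": 0, "anti-CTLA-4": 1, "anti-PD-1+anti-CTLA-4": 2, "other": 3}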

trainer = CheckDynTrainer(model, cfg["training"])
train_loader = trainer.prepare_training_data(
    splits["train"], genes, context_genes=[], 
    cancer_type_map=cancer_map, treatment_map=treatment_map
)
val_loader = trainer.prepare_training_data(
    splits["val"], genes, context_genes=[],
    cancer_type_map=cancer_map, treatment_map=treatment_map
)

# ── Train ─────────────────────────────────────────────────────────────────────
history = trainer.train(train_loader, val_loader)
print(f"Best epoch: {history['best_epoch']}")
print(f"Best val metrics: {history['best_val_metrics']}")

# ── Save / load checkpoint ────────────────────────────────────────────────────
trainer.save_checkpoint("models/checkdyn_best.pt", epoch=history["best_epoch"],
                        metrics=history["best_val_metrics"])

meta = trainer.load_checkpoint("models/checkdyn_best.pt")
print(f"Loaded epoch {meta['epoch']}")

# ── Evaluate ──────────────────────────────────────────────────────────────────
evaluator = CheckDynEvaluator(model, device="cpu")
test_loader = trainer.prepare_training_data(
    splits["test"], genes, context_genes=[],
    cancer_type_map=cancer_map, treatment_map=treatment_map
)

delta_metrics = evaluator.evaluate_delta_prediction(test_loader)
# {"per_gene": [...], "overall": {"mse": ..., "mae": ..., "mean_pearson_r": ...}}

response_metrics = evaluator.evaluate_response_prediction(test_loader)
# {"auc": 0.81, "auprc": 0.74, "accuracy": 0.77, "f1": 0.75}

comp_metrics = evaluator.evaluate_compensation_detection(test_loader)
# {"per_class_f1": [0.82, 0.71, 0.69, 0.74, 0.88], "macro_f1": 0.77}

# ── Benchmark vs baselines ────────────────────────────────────────────────────
bench = evaluator.benchmark_vs_baselines(test_loader)
print(bench[["mse", "mae", "mean_pearson_r"]])
# Methods: NoChange, MeanDelta, LinearRegression, RandomForest, Ridge, CheckDyn

# ── Leave-one-dataset-out evaluation ─────────────────────────────────────────
lodo = evaluator.leave_one_dataset_out(adata, datasets_list=["Riaz2017", "Gide2019"],
                                       checkpoint_genes=genes, cancer_map=cancer_map,
                                       treatment_map=treatment_map)
print(lodo[["dataset", "pearson_r", "response_auc", "compensation_f1"]])
```
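
Two of the listed baselines are simple enough to write down directly. Judging by their names (an interpretation, not taken from CheckDyn's source), `NoChange` predicts that expression does not move (delta = 0), and `MeanDelta` predicts each gene's average training-set delta for every patient. The NumPy sketch below scores both with MSE and MAE. It only makes those definitions concrete and does not reproduce `benchmark_vs_baselines`; `score_baselines` is a hypothetical name.

```python
import numpy as np


def score_baselines(train_delta, test_delta):
    """MSE/MAE of two naive delta predictors (arrays are patients x genes)."""
    preds = {
        "NoChange": np.zeros_like(test_delta),
        "MeanDelta": np.broadcast_to(train_delta.mean(axis=0), test_delta.shape),
    }
    return {
        name: {
            "mse": float(np.mean((test_delta - p) ** 2)),
            "mae": float(np.mean(np.abs(test_delta - p))),
        }
        for name, p in preds.items()
    }
```

A learned model is only useful for delta prediction if it clearly beats both of these.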

---

### 6. Resistance risk prediction

```python
import pandas as pd
from checkdyn.model import AdaptiveResistancePredictor

predictor = AdaptiveResistancePredictor(
    dynamics_model=model,
    compensation_patterns=patterns,   # from detector.compensation_pattern_clustering()
    checkpoint_genes=genes,
)

# ── Batch prediction ──────────────────────────────────────────────────────────
patient_df = pd.read_csv("data/sample/patients_sample.csv")

results = predictor.predict_resistance_risk(
    patient_df,
    cancer_type_id=0,    # 0=melanoma, 1=nsclc, 2=bcc, 3=hnscc, 4=rcc, 5=other
    treatment_id=0,      # 0=anti-PD-1, 1=anti-CTLA-4, 2=combo, 3=other
)

print(results[["patient_id", "risk_score", "risk_category", "predicted_compensation_pattern"]])

# ── Single patient: combination therapy suggestion ────────────────────────────
patient = patient_df.iloc[0].copy()
patient["cancer_type_id"] = 0
patient["treatment_id"] = 0

suggestions = predictor.suggest_combination_strategy(patient)
for s in suggestions:
    print(f"{s['target']}  →  {s['drug_name']}  ({s['evidence_level']})")
    print(f"  Trial: {s['clinical_trial_id']}")
    print(f"  {s['rationale']}\n")

# ── Single patient: resistance timeline ──────────────────────────────────────
timeline = predictor.resistance_timeline(patient)
print(f"Expected progression window: {timeline['predicted_progression_window']}")
print(f"Confidence: {timeline['confidence']}")
print(f"Note: {timeline['note']}")
```

**Built-in combination therapy rules:**

| Gene | Target | Drug | Evidence |
|------|--------|------|----------|
| HAVCR2 / TIM3 | TIM-3 | Sabatolimab (MBG453) | Phase II |
| LAG3 | LAG-3 | Relatlimab | Phase III (RELATIVITY-047) |
| TIGIT | TIGIT | Tiragolumab | Phase II/III (SKYSCRAPER-01) |
| CD47 | CD47 | Magrolimab | Phase III |
| IDO1 | IDO1 | Epacadostat | Phase III (ECHO-301) |
| VISTA | VISTA | CA-170 / HMBD-002 | Phase I |
| CD274 | PD-L1 | Atezolizumab / Durvalumab | FDA-approved |
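
The table works as a gene-to-therapy lookup. In the plain-Python sketch below, the dictionary transcribes the table. The function keeps the three largest positive predicted deltas (mirroring the `top_compensatory_checkpoints` column) and returns the matching rows. The top-3 cutoff and the name `suggest_from_deltas` are assumptions, not the actual logic of `suggest_combination_strategy`.

```python
# Transcribed from the table above; keys are HGNC symbols
RULES = {
    "HAVCR2": ("TIM-3", "Sabatolimab (MBG453)", "Phase II"),
    "LAG3": ("LAG-3", "Relatlimab", "Phase III (RELATIVITY-047)"),
    "TIGIT": ("TIGIT", "Tiragolumab", "Phase II/III (SKYSCRAPER-01)"),
    "CD47": ("CD47", "Magrolimab", "Phase III"),
    "IDO1": ("IDO1", "Epacadostat", "Phase III (ECHO-301)"),
    "VSIR": ("VISTA", "CA-170 / HMBD-002", "Phase I"),  # table lists VISTA; VSIR is its HGNC symbol
    "CD274": ("PD-L1", "Atezolizumab / Durvalumab", "FDA-approved"),
}


def suggest_from_deltas(deltas, top_n=3):
    """Return rule rows for the top_n genes with the largest positive predicted delta."""
    top = sorted((g for g in deltas if deltas[g] > 0), key=deltas.get, reverse=True)[:top_n]
    return [
        {"gene": g, "target": RULES[g][0], "drug_name": RULES[g][1], "evidence_level": RULES[g][2]}
        for g in top
        if g in RULES
    ]
```

Genes without a rule (e.g. `NT5E`) still count toward the top three, so fewer than three suggestions may be returned.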

---

### 7. Biological validation

```python
from checkdyn.validation import BiologicalValidator

validator = BiologicalValidator(genes)

# Individual checks
pdl1 = validator.validate_pdl1_upregulation(de)
# {"passed": True, "pdl1_log2fc": 0.72, "message": "CD274 upregulated post-ICI (expected)"}

comp_val = validator.validate_compensation_known_pairs(comp)
# {"known_pairs_detected": ["TIM3-LAG3", "CD47-SIRPA"], "novel_pairs": [...],
#  "validation_rate": 0.67}

exhaust = validator.validate_exhaustion_markers(de)
# {"passed": True, "exhaustion_fc": {"TOX": 0.41, "NR4A1": 0.58, ...}, "message": ...}

# Combined report
report = validator.generate_validation_report(de, comp)
# {"pdl1_upregulation": {...}, "compensation_known_pairs": {...},
#  "exhaustion_markers": {...}, "summary": {"passed": 3, "total_checks": 3}}
```

---

### 8. Visualization

```python
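import matplotlib.pyplot as plt
import pandas as pd
from checkdyn.visualization import DynamicsPlotter

plotter = DynamicsPlotter(figsize_base=(8, 6), dpi=150)

# ── Volcano plot ──────────────────────────────────────────────────────────────
fig, ax = plt.subplots(figsize=(7, 6))
plotter.volcano(de, log2fc_col="log2FC_all", padj_col="padj_all", label_col="gene", ax=ax)
plotter.save(fig, "figures/volcano.pdf")

# ── Spaghetti plot (per-patient trajectories) ─────────────────────────────────
# Align pre and post samples by paired_id so each patient's two values line up
obs = adata.obs
is_pre  = (obs["timepoint"] == "pre").to_numpy()
is_post = (obs["timepoint"] == "post").to_numpy()
pre_names  = pd.Series(obs.index[is_pre],  index=obs.loc[is_pre, "paired_id"].values)
post_names = pd.Series(obs.index[is_post], index=obs.loc[is_post, "paired_id"].values)
common     = pre_names.index.intersection(post_names.index)   # patients with both samples
pre_names, post_names = pre_names.loc[common].values, post_names.loc[common].values

fig, ax = plt.subplots()
plotter.paired_spaghetti(
    pre=adata[pre_names].obs_vector("HAVCR2"),    # obs_vector also handles sparse X
    post=adata[post_names].obs_vector("HAVCR2"),
    response=obs.loc[pre_names, "response"],
    gene_name="HAVCR2",
    ax=ax,
)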

# ── Checkpoint heatmap ────────────────────────────────────────────────────────
fig, ax = plt.subplots(figsize=(10, 8))
plotter.checkpoint_heatmap(de, value_col="log2FC_all", ax=ax)

# ── Compensation network ──────────────────────────────────────────────────────
G = detector.compensation_network()
fig, ax = plt.subplots(figsize=(8, 8))
plotter.compensation_network_plot(G, ax=ax)

# ── Kaplan-Meier curves ───────────────────────────────────────────────────────
# Requires pfs_months and pfs_event in adata.obs
fig, ax = plt.subplots()
plotter.kaplan_meier(
    time=adata.obs["pfs_months"],
    event=adata.obs["pfs_event"],
    group=adata.obs["response"],
    ax=ax,
)
```

---

## Input Data Format

### AnnData (.h5ad) format

Use for `checkdyn analyze`, `checkdyn train`, and direct Python API usage.

**Create from scratch:**

```python
import anndata as ad
import numpy as np
import pandas as pd
from checkdyn.utils import load_checkpoint_genes

genes = load_checkpoint_genes()    # 42 HGNC symbols
n_patients = 30
rng = np.random.default_rng(42)

# Expression matrix: rows = samples (one pre + one post per patient), columns = genes
X = rng.lognormal(1.5, 0.8, size=(2 * n_patients, len(genes))).astype("float32")

# obs: one row per sample, pre/post interleaved for each patient
patient_ids = np.repeat([f"P{i + 1:03d}" for i in range(n_patients)], 2)
timepoints  = ["pre", "post"] * n_patients
obs = pd.DataFrame({
    "patient_id":  patient_ids,                  # patient identifier
    "paired_id":   patient_ids,                  # links pre/post (usually same as patient_id)
    "timepoint":   timepoints,                   # MUST be "pre" or "post"
    "response":    np.repeat(["R", "NR"] * (n_patients // 2), 2),  # "R" = responder, "NR" = non-responder
    "cancer_type": "melanoma",
    "treatment":   "anti-PD-1",                  # optional but recommended
    "dataset":     "MyStudy2024",                # optional, for multi-cohort analysis
}, index=[f"{p}_{t}" for p, t in zip(patient_ids, timepoints)])

# var: one row per gene, indexed by HGNC symbol
var = pd.DataFrame({"gene_symbol": genes, "is_checkpoint": True}, index=genes)

adata = ad.AnnData(X=X, obs=obs, var=var)
adata.write_h5ad("my_cohort.h5ad")
```

**Required `obs` columns:**

| Column | Type | Required | Values | Notes |
|--------|------|----------|--------|-------|
| `paired_id` | str | Yes | patient ID | Used to link pre/post samples |
| `timepoint` | str | Yes | `"pre"`, `"post"` | Exact strings required |
| `response` | str | Yes | `"R"`, `"NR"` | R=responder, NR=non-responder |
| `cancer_type` | str | Yes | free text | e.g. `"melanoma"` |
| `treatment` | str | No | free text | e.g. `"anti-PD-1"` |
| `dataset` | str | No | free text | Cohort name for multi-dataset analysis |
| `pfs_months` | float | No | ≥0 | For Kaplan-Meier analysis |
| `pfs_event` | int | No | 0/1 | PFS event indicator |

**Gene name conventions:**

Gene names (the `var` index, which labels the columns of `X`) must use HGNC standard symbols. CheckDyn resolves common aliases automatically:

| Alias | Canonical |
|-------|-----------|
| PD-1 | PDCD1 |
| PD-L1 | CD274 |
| TIM-3, TIM3 | HAVCR2 |
| LAG-3 | LAG3 |
| B7-H3 | CD276 |
| VISTA | VSIR |
| SIRPa, SIRP-alpha | SIRPA |
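
Alias resolution of this kind amounts to a lookup table. The sketch below transcribes the table into a dictionary and upper-cases names before lookup. It is a hypothetical stand-in for `checkdyn.utils.resolve_gene_alias`: case-insensitive matching is an assumption, and unknown names are returned unchanged.

```python
# Transcribed from the alias table above; keys are upper-cased aliases
ALIASES = {
    "PD-1": "PDCD1",
    "PD-L1": "CD274",
    "TIM-3": "HAVCR2",
    "TIM3": "HAVCR2",
    "LAG-3": "LAG3",
    "B7-H3": "CD276",
    "VISTA": "VSIR",
    "SIRPA": "SIRPA",       # covers "SIRPa"
    "SIRP-ALPHA": "SIRPA",
}


def canonical_symbol(name: str) -> str:
    """Map a known alias to its HGNC symbol; return anything else unchanged."""
    cleaned = name.strip()
    return ALIASES.get(cleaned.upper(), cleaned)
```

To clean a whole table before building AnnData, the function can be passed straight to pandas, e.g. `df.rename(columns=canonical_symbol)`.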

---

### Patient CSV format

Use for `checkdyn predict`.

**Structure:** One patient per row. Columns are HGNC gene symbols plus optional metadata.

```
patient_id,cancer_type,treatment,PDCD1,CD274,PDCD1LG2,CTLA4,...
P001,melanoma,anti-PD-1,5.7188,1.9503,8.1691,...
P002,melanoma,anti-PD-1,2.6316,5.3963,4.9201,...
```

**Column rules:**

| Column | Required | Description |
|--------|----------|-------------|
| `patient_id` | No | Row identifier in output. If absent, row index is used. |
| `cancer_type_id` | No | Integer cancer type (0–5). Overrides `--cancer-type` CLI flag per-patient. |
| `treatment_id` | No | Integer treatment (0–3). Overrides `--treatment` CLI flag per-patient. |
| `<GENE>` | Yes (≥5) | Log-normalized expression. Column name = HGNC symbol. |

**Expression values:** Must be log-normalized with the same scheme used for the training data (e.g. `log1p(CPM)` or `log1p(TPM)`).

**Get the full column list:**

```bash
checkdyn predict --show-format
```

```python
from checkdyn.utils import load_checkpoint_genes
print(load_checkpoint_genes())
# ['PDCD1', 'CD274', 'PDCD1LG2', 'CTLA4', 'CD80', 'CD86', 'HAVCR2', 'LGALS9',
#  'LAG3', 'FGL1', 'TIGIT', 'CD155', 'CD112', 'VSIR', 'CD276', 'CD47', 'SIRPA',
#  'BTLA', 'TNFRSF14', 'CD28', 'ICOS', 'ICOSLG', 'CD27', 'CD70', 'TNFRSF4',
#  'TNFSF4', 'TNFRSF9', 'TNFSF9', 'IDO1', 'IDO2', 'TDO2', 'NT5E', 'ENTPD1',
#  'PTGS2', 'TOX', 'TOX2', 'EOMES', 'TBX21', 'NR4A1', 'NR4A2', 'NR4A3', 'PRDM1']
```

Missing gene columns are filled with zeros (with a warning). Partial panels work but may reduce prediction quality.
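
The zero-filling behaviour described above can be reproduced with `DataFrame.reindex`. This is handy for checking ahead of time how many panel genes a CSV is missing. In this sketch, `panel` stands in for `load_checkpoint_genes()`. The `align_to_panel` helper and its warning text are hypothetical.

```python
import warnings

import pandas as pd


def align_to_panel(df, panel):
    """Reorder gene columns to match the panel, zero-filling any that are missing."""
    missing = [g for g in panel if g not in df.columns]
    if missing:
        warnings.warn(f"{len(missing)} panel genes missing, filled with 0: {missing}")
    meta = df[[c for c in ("patient_id",) if c in df.columns]]
    # Columns not in the panel are dropped; absent panel genes become 0.0
    genes = df.reindex(columns=panel, fill_value=0.0)
    return pd.concat([meta, genes], axis=1)
```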

---

## Sample Data

The `data/sample/` directory contains ready-to-use synthetic data:

| File | Description | Use with |
|------|-------------|---------|
| `data/sample/patients_sample.csv` | 15 patients × 42 checkpoint genes (pre-treatment) | `checkdyn predict` |
| `data/sample/demo_paired.h5ad` | 60 samples: 30 patients × 2 timepoints (pre/post) | `checkdyn analyze`, `checkdyn train` |

**Try them:**

```bash
# Analyze paired cohort
checkdyn analyze --input data/sample/demo_paired.h5ad --output results/

# Predict resistance risk
checkdyn predict --input data/sample/patients_sample.csv --cancer-type melanoma

# Generate fresh synthetic data with custom parameters
checkdyn demo --output my_demo/ --n-patients 100 --seed 123
```

**Generate programmatically:**

```python
from checkdyn.utils import load_checkpoint_genes
import numpy as np, pandas as pd

genes = load_checkpoint_genes()
rng = np.random.default_rng(42)

df = pd.DataFrame(
    rng.lognormal(1.5, 0.8, size=(10, len(genes))),
    columns=genes
)
df.insert(0, "patient_id", [f"P{i+1:03d}" for i in range(10)])
df.to_csv("my_patients.csv", index=False)
```

---

## Checkpoint Gene Panel

CheckDyn uses a curated panel of **42 immune checkpoint genes** across 4 functional categories:

| Category | Genes |
|----------|-------|
| **Inhibitory checkpoints** | PDCD1, CD274, PDCD1LG2, CTLA4, CD80, CD86, HAVCR2, LGALS9, LAG3, FGL1, TIGIT, CD155, CD112, VSIR, CD276, CD47, SIRPA, BTLA, TNFRSF14 |
| **Co-stimulatory** | CD28, ICOS, ICOSLG, CD27, CD70, TNFRSF4, TNFSF4, TNFRSF9, TNFSF9 |
| **Metabolic checkpoints** | IDO1, IDO2, TDO2, NT5E, ENTPD1, PTGS2 |
| **Exhaustion markers** | TOX, TOX2, EOMES, TBX21, NR4A1, NR4A2, NR4A3, PRDM1 |

Load by category:

```python
from checkdyn.utils import load_checkpoint_genes

inhibitory = load_checkpoint_genes(categories=["inhibitory_checkpoints"])
exhaustion = load_checkpoint_genes(categories=["exhaustion_markers"])
all_genes  = load_checkpoint_genes()   # all 42

# Resolve aliases
from checkdyn.utils import resolve_gene_alias
resolve_gene_alias("TIM-3")   # → "HAVCR2"
resolve_gene_alias("PD-1")    # → "PDCD1"
```

---

## Architecture Overview

```
checkdyn/
├── data/
│   ├── collector.py        ← GEO download, 8 registered cohorts
│   ├── harmonizer.py       ← batch correction (pyComBat), normalization
│   └── splitter.py         ← patient/dataset/LOCO CV splits
├── dynamics/
│   ├── paired_analysis.py  ← Wilcoxon signed-rank DE, LME interaction model
│   ├── compensation.py     ← compensatory upregulation detection, k-means patterns
│   ├── trajectory.py       ← temporal dynamics, slope clustering
│   └── network_rewiring.py ← Spearman co-expression networks, edge rewiring
├── model/
│   ├── checkpoint_transformer.py  ← Transformer encoder (<5M params)
│   ├── trainer.py                 ← multi-task training loop
│   ├── evaluator.py               ← metrics, baselines, LODO evaluation
│   └── resistance_predictor.py   ← clinical risk score + therapy suggestions
├── validation/
│   ├── biological_validation.py  ← PD-L1 signal, exhaustion marker checks
│   └── external_validation.py   ← external cohort validation
├── visualization/
│   ├── dynamics_plots.py  ← volcano, spaghetti, heatmap, network, KM
│   └── paper_figures.py   ← publication-ready panels
└── utils/
    ├── gene_sets.py   ← 42-gene checkpoint panel, alias resolution
    └── statistics.py  ← Wilcoxon, BH-FDR, log2FC, LME, Spearman
```

**Model architecture:**

```
pre-treatment checkpoint expression (42 genes)
        ↓  linear projection
Transformer Encoder (3 layers, d=128, 4 heads)
        │
        │      Context features       Cancer/treatment
        │      (200 TME genes)        embeddings
        │            │                      │
        └────────────┼──────────────────────┘
                     ↓
                MLP fusion
                     ↓
       ┌─────────────┼─────────────┐
       ↓             ↓             ↓
     delta        response    compensation
  prediction      (binary)  pattern (5-class)
  (42 values)
```

**Training objective (multi-task):**

```
L = 1.0 × MSE(delta) + 0.5 × BCE(response) + 0.3 × CE(compensation_pattern)
```
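
Written out for a single batch, the objective combines the three task losses with those fixed weights. The NumPy sketch below computes each term from scratch: mean-reduced MSE, binary cross-entropy on a predicted probability, and categorical cross-entropy on logits. This keeps the arithmetic easy to check. A real training loop would use the PyTorch equivalents, and the function name and clipping constant are assumptions.

```python
import numpy as np


def multitask_loss(delta_pred, delta_true, p_response, y_response,
                   pattern_logits, y_pattern, w=(1.0, 0.5, 0.3), eps=1e-7):
    """L = w0*MSE(delta) + w1*BCE(response) + w2*CE(compensation pattern)."""
    mse = np.mean((delta_pred - delta_true) ** 2)
    p = np.clip(p_response, eps, 1 - eps)  # avoid log(0)
    bce = -np.mean(y_response * np.log(p) + (1 - y_response) * np.log(1 - p))
    # Softmax cross-entropy over the pattern classes, computed stably via log-sum-exp
    shifted = pattern_logits - pattern_logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    ce = -np.mean(log_probs[np.arange(len(y_pattern)), y_pattern])
    return w[0] * mse + w[1] * bce + w[2] * ce
```

With an uninformative model (all-zero delta predictions against unit deltas, p = 0.5, uniform logits over 5 patterns) the loss is 1 + 0.5·ln 2 + 0.3·ln 5, about 1.83. This gives a rough reference point for early-epoch values.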

---

## Supported Datasets

| Dataset | Cancer | Treatment | Platform | n (paired) | GEO |
|---------|--------|-----------|----------|------------|-----|
| Riaz2017 | Melanoma | anti-PD-1 (nivolumab) | RNA-seq | 23 | GSE91061 |
| Gide2019 | Melanoma | anti-PD-1 ± anti-CTLA-4 | RNA-seq | 32 | PRJEB23709 (ENA) |
| Chen2016 | Melanoma | anti-CTLA-4 (ipilimumab) | RNA-seq | 18 | GSE79691 |
| Sade-Feldman2018 | Melanoma | anti-PD-1 ± anti-CTLA-4 | scRNA-seq | 48 | GSE120575 |
| Yost2019 | BCC | anti-PD-1 | scRNA-seq | — | GSE123813 |
| Caushi2021 | NSCLC | neoadjuvant anti-PD-1 | scRNA-seq + TCR | — | GSE176021 |
| Hanna2018 | HNSCC | anti-PD-1 (pembrolizumab) | NanoString | — | GSE112927 |
| Miao2018 | RCC | anti-PD-1/PD-L1 | WES + RNA-seq | — | Supp. |

Target: ≥ 150 paired samples across ≥ 4 cancer types.

---

## Dependencies

| Package | Version | Purpose |
|---------|---------|---------|
| `torch` | ≥ 2.0 | Transformer model |
| `scanpy` | ≥ 1.10 | AnnData I/O, preprocessing |
| `anndata` | ≥ 0.10 | Data structures |
| `statsmodels` | ≥ 0.14 | Linear mixed-effects models |
| `scipy` | ≥ 1.12 | Wilcoxon signed-rank test |
| `pycomBat` | ≥ 0.3 | Batch correction |
| `networkx` | ≥ 3.0 | Co-expression networks |
| `lifelines` | ≥ 0.28 | Kaplan-Meier / survival analysis |
| `GEOparse` | ≥ 2.0 | GEO dataset download |
| `scikit-learn` | ≥ 1.4 | K-means, baselines |
| `matplotlib` | ≥ 3.8 | Figures |
| `seaborn` | ≥ 0.13 | Figure styling |
| `typer` + `rich` | ≥ 0.9 / 13.0 | CLI |

---

## License

MIT
