In [36]:
from pprint import pprint

import pandas as pd
import polars as pl

from continuousvi.continuousVI import ContinuousVI
from continuousvi.simulation import BatchSimulation
In [40]:
def evaluate(result: pd.DataFrame, target_genes: list[str], use_col_name: str, up: bool = True) -> pl.DataFrame:
    """Score a gene ranking against a ground-truth gene list.

    The rows of ``result`` are sorted by ``use_col_name`` and the first
    ``len(target_genes)`` rows of that ordering are treated as the predicted
    positives. A row is a true positive when its ``"gene"`` value is also
    contained in ``target_genes``.

    Args:
        result: Table with a ``"gene"`` column and the column named by
            ``use_col_name``; it is converted with ``pl.DataFrame(result)``.
        target_genes: Ground-truth genes. Its length also fixes how many
            top-ranked rows are counted as predictions.
        use_col_name: Column used to rank the rows.
        up: If True the ranking is descending (largest values first),
            otherwise ascending.

    Returns:
        A one-row ``pl.DataFrame`` with the columns ``"Accuracy[%]"``,
        ``"Precision[%]"`` and ``"Recall[%]"``. A metric whose denominator is
        zero is reported as ``0.0`` so that every column is always a float.
    """
    n_predicted = len(target_genes)
    ranked = (
        pl.DataFrame(result)
        .sort(use_col_name, descending=up)
        .with_columns(pl.col("gene").is_in(target_genes).alias("gt"))
        # The row index after sorting is the rank; the top `n_predicted` ranks are the predictions.
        .with_row_index("index")
        .with_columns((pl.col("index") < n_predicted).alias("pred"))
        .drop("index")
    )

    is_gt = pl.col("gt")
    is_pred = pl.col("pred")
    tp = ranked.filter(is_gt & is_pred).height
    tn = ranked.filter(~is_gt & ~is_pred).height
    fp = ranked.filter(~is_gt & is_pred).height
    fn = ranked.filter(is_gt & ~is_pred).height

    # Fall back to the float 0.0 (not the int 0) so the output columns keep a float dtype.
    total = tp + tn + fp + fn
    accuracy = (tp + tn) / total if total else 0.0
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0

    return pl.DataFrame({
        "Accuracy[%]": [accuracy * 100],
        "Precision[%]": [precision * 100],
        "Recall[%]": [recall * 100],
    })
In [104]:
# Build the simulated dataset; every argument other than library_size_sigma is left at its default.
# NOTE(review): no random seed is passed here - confirm whether BatchSimulation is reproducible across runs.
sim = BatchSimulation(
    library_size_sigma=0.5,
)
Initializing BatchSimulation with parameters:
  cell_types=['Tcell', 'Bcell', 'Macrophage'], projects=['A', 'B']
  library_size_logmean={'A': 0.3, 'B': -0.3}, dropout_rates={'A': 0.05, 'B': 0.1}
=== Genes Count Summary ===
  Tcell: Up=50, Down=50
  Bcell: Up=50, Down=50
  Macrophage: Up=50, Down=50
  Neutral=700
Project loop:   0%|          | 0/2 [00:00<?, ?it/s]  cell_types=['Tcell', 'Bcell', 'Macrophage'], projects=['A', 'B']
  library_size_logmean={'A': 0.3, 'B': -0.3}, dropout_rates={'A': 0.05, 'B': 0.1}
=== Genes Count Summary ===
  Tcell: Up=50, Down=50
  Bcell: Up=50, Down=50
  Macrophage: Up=50, Down=50
  Neutral=700
Project loop: 100%|██████████| 2/2 [00:00<00:00, 14.29it/s]
/home/yuyasato/work3/libs/ContinuousVI/.venv/lib/python3.10/site-packages/anndata/_core/aligned_df.py:68: ImplicitModificationWarning: Transforming to str index.
  warnings.warn("Transforming to str index.", ImplicitModificationWarning)
=== Final AnnData ===
AnnData object with n_obs × n_vars = 600 × 1000
    obs: 'project_id', 'cell_type', 'age'
In [110]:
# Plot the first gene yielded when iterating the "Tcell" up-gene collection.
sim.plot_genes_by_age(
    genes=next(iter(sim.ct_up_genes["Tcell"])),
)
In [113]:
# Box plot, via plot_boxplot_by_batch, of the first gene yielded from the "Tcell" up-gene collection.
sim.plot_boxplot_by_batch(
    genes=next(iter(sim.ct_up_genes["Tcell"])),
)
In [106]:
# Wrap the simulated AnnData, naming the obs columns to use as batch, label and continuous covariate.
cont = ContinuousVI(
    sim.adata,
    batch_key="project_id",
    label_key="cell_type",
    continuous_key="age",
)
In [107]:
# NOTE(review): this import sits mid-notebook; consider moving it to the import cell at the top.
import warnings

# Suppress warnings whose message starts with "Can't initialize NVML"
# (the training log below reports "GPU available: False").
# The filter is installed globally, so it also stays active for every later cell.
warnings.filterwarnings("ignore", message="Can't initialize NVML")
# Train with n_train=1; the progress bar below counts a single model (0/1 -> 1).
trained = cont.train(n_train=1)
Training multiple scVI models:   0%|          | 0/1 [00:00<?, ?it/s]GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/yuyasato/work3/libs/ContinuousVI/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.
/home/yuyasato/work3/libs/ContinuousVI/.venv/lib/python3.10/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (5) is smaller than the logging interval Trainer(log_every_n_steps=10). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
/home/yuyasato/work3/libs/ContinuousVI/.venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=23` in the `DataLoader` to improve performance.
Epoch 329/800:  41%|████      | 329/800 [00:26<00:38, 12.19it/s, v_num=1, train_loss_step=3.81e+3, train_loss_epoch=3.81e+3]
                                                                            
Monitored metric elbo_validation did not improve in the last 45 records. Best score: 4026.364. Signaling Trainer to stop.

In [6]:
# Persist the trained model under a path relative to the working directory, replacing any earlier save.
trained.save("./data/continuous/simulation/models", overwrite=True)
Saving trained model.: 100%|██████████| 5/5 [00:01<00:00,  3.79it/s]
Out[6]:
<continuousvi.continuousVI.TrainedContinuousVI at 0x7387f0a8ee90>
In [108]:
# Keep only the observations belonging to the first simulated cell type.
# This overwrites trained.adata, so later cells only see this subset.
first_cell_type_mask = trained.adata.obs["cell_type"] == sim.cell_types[0]
trained.adata = trained.adata[first_cell_type_mask]
In [109]:
# Score against the ground truth of the same cell type that trained.adata was subset to
# (sim.cell_types[0]) instead of a hard-coded name, so the two cannot drift apart.
target_cell_type = sim.cell_types[0]
up_genes = list(sim.ct_up_genes[target_cell_type])
down_genes = list(sim.ct_down_genes[target_cell_type])

result = trained.regress_advanced(n_samples=25, n_threads=1)

# NOTE(review): both directions rank by "Slope_97.5pct"; confirm that this is the intended column
# for the up-regulated case as well as the down-regulated one.
pprint("@Up (mode=ols)")  # noqa: T203
pprint(evaluate(result, up_genes, use_col_name="Slope_97.5pct", up=True))  # noqa: T203
pprint("@Down (mode=ols)")  # noqa: T203
pprint(evaluate(result, down_genes, use_col_name="Slope_97.5pct", up=False))  # noqa: T203
Sampling px: 100%|██████████| 1/1 [00:00<00:00, 24.71it/s]
Fitting regressions (threads): 100%|██████████| 25000/25000 [00:04<00:00, 5553.07it/s]
'@Up (mode=ols)'
shape: (1, 3)
┌─────────────┬──────────────┬───────────┐
│ Accuracy[%] ┆ Precision[%] ┆ Recall[%] │
│ ---         ┆ ---          ┆ ---       │
│ f64         ┆ f64          ┆ f64       │
╞═════════════╪══════════════╪═══════════╡
│ 100.0       ┆ 100.0        ┆ 100.0     │
└─────────────┴──────────────┴───────────┘
'@Down (mode=ols)'
shape: (1, 3)
┌─────────────┬──────────────┬───────────┐
│ Accuracy[%] ┆ Precision[%] ┆ Recall[%] │
│ ---         ┆ ---          ┆ ---       │
│ f64         ┆ f64          ┆ f64       │
╞═════════════╪══════════════╪═══════════╡
│ 100.0       ┆ 100.0        ┆ 100.0     │
└─────────────┴──────────────┴───────────┘
In [ ]: