"""Batch dose-response fitting for high-throughput screening (HTS).
This is the primary GPU showcase: fit thousands of 4PL curves simultaneously.
Each compound's curve fit is independent — perfect for GPU batching.
**CPU path**: loops over compounds calling :func:`fit_drm` for each.
**GPU path**: batched Levenberg-Marquardt in PyTorch. All K compounds are
fit simultaneously using vectorised forward passes, finite-difference
Jacobians, and batched linear solves. EC50 is parameterised on the log
scale so the optimisation is unconstrained.
"""
from __future__ import annotations
import numpy as np
from numpy.typing import NDArray
from pystatsbio.doseresponse._common import BatchDoseResponseResult
# ---------------------------------------------------------------------------
# CPU fallback
# ---------------------------------------------------------------------------
def _batch_cpu(
    dose_matrix: NDArray,
    response_matrix: NDArray,
    model: str,
    max_iter: int,
    tol: float,
) -> BatchDoseResponseResult:
    """Fit every compound one after another on the CPU using :func:`fit_drm`.

    Parameters
    ----------
    dose_matrix : array
        Doses; row ``i`` is passed to :func:`fit_drm` for compound ``i``.
    response_matrix : array
        Responses; row ``i`` is paired with row ``i`` of ``dose_matrix``.
    model : str
        Model name forwarded to :func:`fit_drm`.
    max_iter : int
        Accepted for signature parity with the GPU path; not forwarded
        to :func:`fit_drm` here.
    tol : float
        Accepted for signature parity with the GPU path; not forwarded
        to :func:`fit_drm` here.

    Returns
    -------
    BatchDoseResponseResult
        Per-compound estimates. A compound whose fit raises any
        ``Exception`` gets ``NaN`` for every estimate and for ``rss``,
        and ``converged=False``.
    """
    from pystatsbio.doseresponse._fit import fit_drm

    n_compounds = dose_matrix.shape[0]
    param_names = ("ec50", "hill", "top", "bottom")
    estimates = {name: np.empty(n_compounds) for name in param_names}
    rss = np.empty(n_compounds)
    converged = np.empty(n_compounds, dtype=bool)

    for row in range(n_compounds):
        try:
            fit = fit_drm(dose_matrix[row], response_matrix[row], model=model)
            for name in param_names:
                estimates[name][row] = getattr(fit.params, name)
            converged[row] = fit.converged
            rss[row] = fit.rss
        except Exception:
            # Best effort: one failing compound must not abort the batch,
            # so blank out its row (including any partially written values).
            for values in estimates.values():
                values[row] = np.nan
            rss[row] = np.nan
            converged[row] = False

    return BatchDoseResponseResult(
        **estimates,
        converged=converged,
        rss=rss,
        n_compounds=n_compounds,
    )
# ---------------------------------------------------------------------------
# GPU batched Levenberg-Marquardt
# ---------------------------------------------------------------------------
def _batch_gpu(
    dose_matrix: NDArray,
    response_matrix: NDArray,
    max_iter: int,
    tol: float,
) -> BatchDoseResponseResult:
    """Fit LL.4 curves for all compounds at once with batched Levenberg-Marquardt.

    The parameter vector per compound is ``[bottom, top, log_ec50, hill]``;
    EC50 is optimised on the log scale and exponentiated on return.

    Parameters
    ----------
    dose_matrix : array, shape ``(K, N)``
        Doses, one row per compound. Entries that are not ``> 0`` get a
        log-dose of ``-inf``.
    response_matrix : array
        Responses aligned element-wise with ``dose_matrix``.
    max_iter : int
        Upper bound on LM iterations (shared by the whole batch).
    tol : float
        Threshold on the relative RSS change. In float32 the effective
        threshold is ``max(tol, 1e-5)``.

    Returns
    -------
    BatchDoseResponseResult
        NumPy arrays of per-compound estimates, the convergence mask and
        the RSS evaluated at the final parameters.

    Notes
    -----
    The device is CUDA if available, otherwise MPS if available, otherwise
    CPU tensors. MPS runs in float32; every other device uses float64.
    A compound is flagged converged only on an iteration whose trial step
    lowered its RSS *and* changed it by less than the tolerance (relative).
    Once flagged, its parameters are frozen.
    """
    import torch

    # Device preference: CUDA, then Apple MPS, then plain CPU tensors.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")

    # MPS (Apple Silicon) does not support float64, so it runs in float32.
    dtype = torch.float32 if device.type == "mps" else torch.float64

    # Numerical constants, loosened when working in single precision.
    if dtype == torch.float32:
        fd_step = 1e-3                    # finite-difference step
        rel_tol = max(tol, 1e-5)          # convergence tolerance
        diag_floor = 1e-6                 # minimum diagonal damping
        lam_min, lam_max = 1e-10, 1e10    # bounds on lambda
    else:
        fd_step = 1e-5
        rel_tol = tol
        diag_floor = 1e-12
        lam_min, lam_max = 1e-15, 1e15

    n_compounds, _ = dose_matrix.shape
    doses = torch.from_numpy(dose_matrix).to(device=device, dtype=dtype)
    responses = torch.from_numpy(response_matrix).to(device=device, dtype=dtype)

    # log(dose), with non-positive doses mapped to -inf.
    neg_inf = torch.tensor(float("-inf"), device=device, dtype=dtype)
    log_doses = torch.where(doses > 0, torch.log(doses), neg_inf)

    # Vectorised self-start, shape (K, 4).
    params = _batch_init_gpu(doses, responses, log_doses, device, dtype)
    # Per-compound lambda; starting at 1.0 is a conservative (heavily damped) start.
    lam = torch.full((n_compounds,), 1.0, device=device, dtype=dtype)
    done = torch.zeros(n_compounds, dtype=torch.bool, device=device)

    def _predict(ld: torch.Tensor, th: torch.Tensor) -> torch.Tensor:
        """Evaluate LL.4 for every compound.

        Columns of ``th``: 0=bottom, 1=top, 2=log_ec50, 3=hill.
        """
        lower = th[:, 0].unsqueeze(1)
        upper = th[:, 1].unsqueeze(1)
        log_mid = th[:, 2].unsqueeze(1)
        slope = th[:, 3].unsqueeze(1)
        return lower + (upper - lower) / (1.0 + torch.exp(-slope * (ld - log_mid)))

    def _rss(th: torch.Tensor) -> torch.Tensor:
        """Residual sum of squares per compound at parameters ``th``."""
        return ((responses - _predict(log_doses, th)) ** 2).sum(dim=1)

    for _ in range(max_iter):
        fitted = _predict(log_doses, params)
        resid = responses - fitted
        rss_before = (resid**2).sum(dim=1)

        # Forward-difference Jacobian of the model: jac[k, n, p] = d f / d theta_p.
        columns = []
        for col in range(4):
            bumped = params.clone()
            bumped[:, col] += fd_step
            columns.append((_predict(log_doses, bumped) - fitted) / fd_step)
        jac = torch.stack(columns, dim=2)            # (K, N, 4)

        jac_t = jac.transpose(1, 2)                  # (K, 4, N)
        normal = jac_t @ jac                         # (K, 4, 4)
        grad = jac_t @ resid.unsqueeze(2)            # (K, 4, 1)

        # Damped system: normal + lambda * (diag(normal) + mu).
        # mu (at least 1.0) keeps the damping effective when the diagonal is tiny.
        diag = torch.diagonal(normal, dim1=-2, dim2=-1).clamp(min=diag_floor)  # (K, 4)
        mu = diag.mean(dim=1, keepdim=True).clamp(min=1.0)                     # (K, 1)
        lhs = normal + torch.diag_embed(lam.unsqueeze(1) * (diag + mu))

        try:
            step = torch.linalg.solve(lhs, grad).squeeze(2)   # (K, 4)
        except RuntimeError:
            # Solve failed for the batch: damp harder everywhere and retry.
            lam *= 10.0
            continue

        candidate = params + step
        rss_after = _rss(candidate)

        # Take the step only where it lowers RSS and the compound is still active.
        better = rss_after < rss_before
        take = better & ~done
        params = torch.where(take.unsqueeze(1), candidate, params)

        # Relax damping after an improvement, tighten it otherwise.
        lam = torch.where(better, lam * 0.1, lam * 10.0).clamp(lam_min, lam_max)

        # Converged: an improving step with a small relative RSS change.
        rel_change = torch.abs(rss_after - rss_before) / (rss_before + diag_floor)
        done = done | ((rel_change < rel_tol) & better)
        if done.all():
            break

    rss_final = _rss(params)
    ec50 = torch.exp(params[:, 2])

    def _to_numpy(t: torch.Tensor):
        return t.cpu().numpy()

    return BatchDoseResponseResult(
        ec50=_to_numpy(ec50),
        hill=_to_numpy(params[:, 3]),
        top=_to_numpy(params[:, 1]),
        bottom=_to_numpy(params[:, 0]),
        converged=_to_numpy(done),
        rss=_to_numpy(rss_final),
        n_compounds=n_compounds,
    )
def _batch_init_gpu(
    dose_t: "torch.Tensor",
    resp_t: "torch.Tensor",
    log_dose: "torch.Tensor",
    device: "torch.device",
    dtype: "torch.dtype",
) -> "torch.Tensor":
    """Build starting values for all compounds in one vectorised pass.

    Parameters
    ----------
    dose_t : torch.Tensor, shape ``(K, N)``
        Doses, one row per compound.
    resp_t : torch.Tensor
        Responses aligned element-wise with ``dose_t``.
    log_dose : torch.Tensor
        Not used by this function; kept for the call signature.
    device : torch.device
        Device on which the ``+1`` / ``-1`` hill constants are created.
    dtype : torch.dtype
        Floating dtype for the hill constants and the positive-dose mask.

    Returns
    -------
    torch.Tensor, shape ``(K, 4)``
        Columns ``[bottom, top, log_ec50, hill]``.
    """
    import torch

    n_compounds, n_points = dose_t.shape

    # Responses reordered by ascending dose within each compound.
    by_dose = torch.gather(resp_t, 1, torch.argsort(dose_t, dim=1))

    # Mean response over the lowest and the highest quarter of the doses
    # (at least one point each).
    edge = max(1, n_points // 4)
    resp_at_low = by_dose[:, :edge].mean(dim=1)
    resp_at_high = by_dose[:, -edge:].mean(dim=1)

    # Curve direction decides which plateau is "top" and which is "bottom".
    rising = resp_at_high > resp_at_low
    bottom = torch.where(rising, resp_at_low, resp_at_high)
    top = torch.where(rising, resp_at_high, resp_at_low)

    # log(EC50) start: mean log of the positive doses (their geometric mean).
    # Non-positive doses are swapped for 1 so log() stays finite, then
    # weighted by zero so they do not contribute to the sum.
    positive = dose_t > 0
    safe_dose = torch.where(positive, dose_t, torch.ones_like(dose_t))
    weights = positive.to(dtype=dtype)
    n_positive = weights.sum(dim=1).clamp(min=1)
    log_ec50 = (torch.log(safe_dose) * weights).sum(dim=1) / n_positive

    # Hill slope start: +1 for rising curves, -1 for falling ones.
    plus_one = torch.ones(n_compounds, device=device, dtype=dtype)
    hill = torch.where(rising, plus_one, -plus_one)

    return torch.stack((bottom, top, log_ec50, hill), dim=1)
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
# [docs] -- documentation-viewer link marker left over from extraction; not part of the code.
def fit_drm_batch(
    dose_matrix: NDArray[np.floating],
    response_matrix: NDArray[np.floating],
    *,
    model: str = "LL.4",
    backend: str = "auto",
    max_iter: int = 100,
    tol: float = 1e-8,
) -> BatchDoseResponseResult:
    """Batch-fit dose-response curves across many compounds.

    Parameters
    ----------
    dose_matrix : array, shape ``(n_compounds, n_doses)``
        Dose values for each compound.
    response_matrix : array, shape ``(n_compounds, n_doses)``
        Response values for each compound.
    model : str
        Model name (currently only ``'LL.4'`` for batch fitting).
    backend : str
        ``'cpu'``, ``'gpu'``, or ``'auto'``. ``'gpu'`` always uses the
        batched PyTorch Levenberg-Marquardt fitter (which itself falls
        back to CPU tensors when no accelerator is present). ``'auto'``
        uses that fitter only when PyTorch is importable and a CUDA or
        MPS device is available, and the sequential CPU path otherwise.
    max_iter : int
        Maximum LM iterations (default 100). Used by the batched PyTorch
        fitter; the sequential CPU path does not forward it to
        :func:`fit_drm`.
    tol : float
        Convergence tolerance on relative RSS change (default 1e-8).
        Used by the batched PyTorch fitter; the sequential CPU path does
        not forward it to :func:`fit_drm`.

    Returns
    -------
    BatchDoseResponseResult

    Raises
    ------
    ValueError
        If ``dose_matrix`` is not 2-D, the two matrices differ in shape,
        ``model`` is not ``'LL.4'``, or ``backend`` is not one of
        ``'cpu'``, ``'gpu'``, ``'auto'``.
    ImportError
        If ``backend='gpu'`` is requested and PyTorch is not installed.

    Notes
    -----
    GPU backend requires ``pip install pystatsbio[gpu]`` (PyTorch).
    On CPU, curves are fit sequentially, one :func:`fit_drm` call per
    compound. On GPU, all curves are fit simultaneously using batched
    Jacobian computation and batched normal equations.
    """
    dose_matrix = np.asarray(dose_matrix, dtype=np.float64)
    response_matrix = np.asarray(response_matrix, dtype=np.float64)

    if dose_matrix.ndim != 2:
        raise ValueError(
            f"dose_matrix must be 2-D (n_compounds, n_doses), got shape {dose_matrix.shape}"
        )
    if dose_matrix.shape != response_matrix.shape:
        raise ValueError(
            f"dose_matrix and response_matrix must have same shape, "
            f"got {dose_matrix.shape} and {response_matrix.shape}"
        )
    if model != "LL.4":
        raise ValueError(f"Batch fitting currently supports only 'LL.4', got {model!r}")
    # Reject unknown backends up front; otherwise a typo such as "cuda"
    # would silently be treated as "auto".
    if backend not in ("cpu", "gpu", "auto"):
        raise ValueError(
            f"backend must be one of 'cpu', 'gpu', or 'auto', got {backend!r}"
        )

    if backend == "cpu":
        return _batch_cpu(dose_matrix, response_matrix, model, max_iter, tol)
    if backend == "gpu":
        return _batch_gpu(dose_matrix, response_matrix, max_iter, tol)

    # auto: use the batched fitter when an accelerator exists, else fall back to CPU.
    try:
        import torch

        has_gpu = torch.cuda.is_available() or (
            hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        )
        if has_gpu:
            return _batch_gpu(dose_matrix, response_matrix, max_iter, tol)
    except ImportError:
        # PyTorch is optional; without it the sequential CPU path is used.
        pass
    return _batch_cpu(dose_matrix, response_matrix, model, max_iter, tol)