Coverage for fastblocks / middleware.py: 69%

312 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-11-26 03:58 -0800

1import sys 

2import typing as t 

3from collections.abc import Mapping, Sequence 

4from contextvars import ContextVar 

5from enum import IntEnum 

6 

7from acb.debug import debug 

8from acb.depends import depends 

9from brotli_asgi import BrotliMiddleware 

10from secure import Secure 

11from starlette.datastructures import URL, Headers, MutableHeaders 

12from starlette.middleware import Middleware 

13from starlette.middleware.sessions import SessionMiddleware 

14from starlette.requests import Request 

15from starlette.types import ASGIApp, Message, Receive, Scope, Send 

16from starlette_csrf.middleware import CSRFMiddleware 

17 

18from .caching import ( 

19 CacheControlResponder, 

20 CacheDirectives, 

21 CacheResponder, 

22 Rule, 

23 delete_from_cache, 

24) 

25from .htmx import HtmxDetails 

26 

# A middleware factory: receives the wrapped ASGI app and returns the wrapper.
MiddlewareCallable: t.TypeAlias = t.Callable[[ASGIApp], ASGIApp]
# Any class that can be placed in a ``Middleware(...)`` spec.
MiddlewareClass: t.TypeAlias = type[t.Any]
# Keyword arguments passed to a middleware class when the stack is built.
MiddlewareOptions: t.TypeAlias = dict[str, t.Any]

30from .exceptions import MissingCaching 

31 

32 

class MiddlewarePosition(IntEnum):
    """Integer slot assigned to each built-in middleware.

    ``MiddlewareStackManager.build_stack`` orders the final middleware list by
    sorting these values in ascending order.
    """

    CSRF = 0  # CSRF protection
    SESSION = 1  # cookie sessions (registered only when an auth adapter exists)
    HTMX = 2  # attaches HTMX request details to the scope
    CURRENT_REQUEST = 3  # publishes the current scope via a context variable
    COMPRESSION = 4  # response compression
    SECURITY_HEADERS = 5  # security response headers

40 

41 

class HtmxMiddleware:
    """ASGI middleware that attaches an ``HtmxDetails`` object to the scope.

    For ``http`` and ``websocket`` connections the details are stored under
    ``scope["htmx"]`` before the wrapped application runs; any other scope
    type is handed to the application unchanged.
    """

    def __init__(self, app: ASGIApp) -> None:
        self._app = app
        debug("HtmxMiddleware: Initialized FastBlocks native HTMX middleware")

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        wants_htmx = scope["type"] in ("http", "websocket")
        if wants_htmx:
            await self._process_htmx_request(scope)
        await self._app(scope, receive, send)

    async def _process_htmx_request(self, scope: Scope) -> None:
        """Build ``HtmxDetails`` for *scope* and store it under ``"htmx"``."""
        details = HtmxDetails(scope)
        scope["htmx"] = details
        if not debug.enabled:
            return
        self._log_htmx_details(scope, details)

    def _log_htmx_details(self, scope: Scope, htmx_details: HtmxDetails) -> None:
        """Emit a debug line for the request, plus one per HTMX header if present."""
        method = scope.get("method", "UNKNOWN")
        path = scope.get("path", "unknown")
        is_htmx = bool(htmx_details)
        debug(f"HtmxMiddleware: {method} {path} - HTMX: {is_htmx}")
        if not is_htmx:
            return
        for header_name, header_value in htmx_details.get_all_headers().items():
            debug(f"HtmxMiddleware: {header_name}: {header_value}")

69 

70 

class HtmxResponseMiddleware:
    """ASGI middleware hook for responses to HTMX requests.

    Non-HTTP scopes bypass it entirely. For HTTP, every outgoing message is
    routed through ``_process_response_message``. The only change currently
    made is that, for HTMX requests, the ``http.response.start`` headers are
    copied into a ``list``.
    """

    def __init__(self, app: ASGIApp) -> None:
        self._app = app
        debug("HtmxResponseMiddleware: Initialized FastBlocks HTMX response middleware")

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] == "http":

            async def send_wrapper(message: Message) -> None:
                await self._process_response_message(message, scope, send)

            await self._app(scope, receive, send_wrapper)
        else:
            await self._app(scope, receive, send)

    async def _process_response_message(
        self, message: Message, scope: Scope, send: Send
    ) -> None:
        """Forward *message* to *send*, normalising start headers for HTMX requests."""
        starts_response = message["type"] == "http.response.start"
        if starts_response and scope.get("htmx"):
            debug("HtmxResponseMiddleware: Processing HTMX response")
            message["headers"] = list(message.get("headers", []))
        await send(message)

97 

98 

class MiddlewareUtils:
    """Shared constants and request-context helpers for this module's middleware."""

    # Cache backend type; currently just ``t.Any``.
    Cache = t.Any

    # Single ``Secure`` instance whose headers SecureHeadersMiddleware appends.
    secure_headers = Secure()

    # Scope key under which CacheMiddleware registers itself.
    scope_name = "__starlette_caches__"

    # Scope of the request being handled in the current context (None if unset).
    _request_ctx_var: ContextVar[Scope | None] = ContextVar("request", default=None)

    # Interned ASGI scope keys / values.
    HTTP = sys.intern("http")
    WEBSOCKET = sys.intern("websocket")
    TYPE = sys.intern("type")
    METHOD = sys.intern("method")
    PATH = sys.intern("path")

    # Interned HTTP method names.
    GET = sys.intern("GET")
    HEAD = sys.intern("HEAD")
    POST = sys.intern("POST")
    PUT = sys.intern("PUT")
    PATCH = sys.intern("PATCH")
    DELETE = sys.intern("DELETE")

    @classmethod
    def get_request(cls) -> Scope | None:
        """Return the scope stored for the current context, or ``None``."""
        ctx_var = cls._request_ctx_var
        return ctx_var.get()

    @classmethod
    def set_request(cls, scope: Scope | None) -> None:
        """Store *scope* for the current context.

        The token returned by ``ContextVar.set`` is discarded, so this change
        cannot later be undone with ``ContextVar.reset``.
        """
        cls._request_ctx_var.set(scope)

127 

128 

# Module-level shortcuts to the shared MiddlewareUtils attributes; the
# middleware classes in this module refer to these names directly.
(Cache, secure_headers, scope_name, _request_ctx_var) = (
    MiddlewareUtils.Cache,
    MiddlewareUtils.secure_headers,
    MiddlewareUtils.scope_name,
    MiddlewareUtils._request_ctx_var,
)

133 

134 

def get_request() -> Scope | None:
    """Return the ASGI scope of the request in the current context, or ``None``.

    Module-level shortcut for ``MiddlewareUtils.get_request``.
    """
    lookup = MiddlewareUtils.get_request
    return lookup()

137 

138 

class CurrentRequestMiddleware:
    """Publish the current ASGI scope through the request context variable.

    While the wrapped application handles an ``http`` or ``websocket``
    connection, ``get_request()`` returns that connection's scope. Other scope
    types are passed through without touching the context variable.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope[MiddlewareUtils.TYPE] not in (
            MiddlewareUtils.HTTP,
            MiddlewareUtils.WEBSOCKET,
        ):
            await self.app(scope, receive, send)
            return
        token = _request_ctx_var.set(scope)
        try:
            await self.app(scope, receive, send)
        finally:
            # Restore the previous value even when the app raises, so a failed
            # request's scope does not stay visible in this context.
            _request_ctx_var.reset(token)

153 

154 

class SecureHeadersMiddleware:
    """Append the shared ``Secure`` headers to every HTTP response.

    The headers are *appended* to the ``http.response.start`` message, so a
    header of the same name already set by the application is kept alongside
    them. Non-HTTP scopes are passed through unchanged.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app
        # The logger is optional: any lookup failure leaves it as None.
        try:
            self.logger = depends.get("logger")
        except Exception:
            self.logger = None

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] == "http":

            async def send_with_secure_headers(message: Message) -> None:
                if message["type"] == "http.response.start":
                    response_headers = MutableHeaders(scope=message)
                    for name, value in secure_headers.headers.items():
                        response_headers.append(name, value)
                await send(message)

            await self.app(scope, receive, send_with_secure_headers)
            return None
        return await self.app(scope, receive, send)

176 

177 

class CacheValidator:
    """Hold caching rules and detect ``CacheMiddleware`` being applied twice."""

    def __init__(self, rules: Sequence[Rule] | None = None) -> None:
        # A missing or empty rule list falls back to a single default ``Rule()``.
        self.rules = rules if rules else [Rule()]

    def check_for_duplicate_middleware(self, app: ASGIApp) -> None:
        """Raise ``DuplicateCaching`` if *app*'s ``middleware`` list holds a CacheMiddleware.

        Apps without a ``middleware`` attribute, or whose ``middleware`` is
        callable (presumably a registration method rather than a list), are
        not inspected.
        """
        if hasattr(app, "middleware"):
            stack = app.middleware  # type: ignore[attr-defined]
            if not callable(stack):
                self._check_for_cache_middleware_duplicates(stack)

    def _check_for_cache_middleware_duplicates(self, middleware: t.Any) -> None:
        """Raise ``DuplicateCaching`` on the first ``CacheMiddleware`` instance found."""
        if any(isinstance(entry, CacheMiddleware) for entry in middleware):
            from .exceptions import DuplicateCaching

            raise DuplicateCaching("CacheMiddleware detected in middleware stack")

    def is_duplicate_in_scope(self, scope: Scope) -> bool:
        """Return True if a CacheMiddleware has already registered itself in *scope*."""
        return scope_name in scope

204 

205 

206class CacheKeyManager: 

207 def __init__(self, cache: t.Any | None = None) -> None: 

208 self.cache = cache 

209 self._cache_dict: dict[t.Any, t.Any] = {} 

210 

211 def get_cache_instance(self) -> t.Any: 

212 if self.cache is None: 

213 from .exceptions import safe_depends_get 

214 

215 self.cache = safe_depends_get("cache", self._cache_dict) 

216 return self.cache 

217 

218 

class CacheMiddleware:
    """ASGI middleware that serves HTTP requests through a ``CacheResponder``.

    On each HTTP request it registers itself in the scope under ``scope_name``
    (where request helpers look it up). It raises ``DuplicateCaching`` if a
    CacheMiddleware is already registered there.
    """

    def __init__(
        self,
        app: ASGIApp,
        *,
        cache: t.Any | None = None,
        rules: Sequence[Rule] | None = None,
    ) -> None:
        self.app = app
        self.validator = CacheValidator(rules)
        self.key_manager = CacheKeyManager(cache)
        self.cache = cache
        # Rules after CacheValidator has applied its default.
        self.rules = self.validator.rules
        self.validator.check_for_duplicate_middleware(app)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        # The backend is (re)resolved on every call, for every scope type.
        self.cache = self.key_manager.get_cache_instance()
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
        if self.validator.is_duplicate_in_scope(scope):
            from .exceptions import DuplicateCaching

            raise DuplicateCaching(
                "Another `CacheMiddleware` was detected in the middleware stack.\n"
                "HINT: this exception probably occurred because:\n"
                "- You wrapped an application around `CacheMiddleware` multiple times.\n"
                "- You tried to apply `@cached()` onto an endpoint, but the application "
                "is already wrapped around a `CacheMiddleware`."
            )
        scope[scope_name] = self
        await CacheResponder(self.app, rules=self.rules)(scope, receive, send)

260 

261 

class _BaseCacheMiddlewareHelper:
    """Base for helpers that need the ``CacheMiddleware`` serving a request.

    Raises:
        MissingCaching: if the request scope has no ``scope_name`` entry, or
            the entry is not a ``CacheMiddleware`` instance.
    """

    def __init__(self, request: Request) -> None:
        self.request = request
        request_scope = request.scope
        if scope_name not in request_scope:
            raise MissingCaching(
                "No CacheMiddleware instance found in the ASGI scope. Did you forget to wrap the ASGI application with `CacheMiddleware`?",
            )
        candidate = request_scope[scope_name]
        if not isinstance(candidate, CacheMiddleware):
            raise MissingCaching(
                f"A scope variable named {scope_name!r} was found, but it does not contain a `CacheMiddleware` instance. It is likely that an incompatible middleware was added to the middleware stack.",
            )
        self.middleware = candidate

277 

278 

class CacheHelper(_BaseCacheMiddlewareHelper):
    """Request-bound helper for invalidating cached responses."""

    async def invalidate_cache_for(
        self,
        url: str | URL,
        *,
        headers: Mapping[str, str] | None = None,
    ) -> None:
        """Call ``delete_from_cache`` for *url* using the middleware's cache.

        Args:
            url: A ``URL``, or a route name resolved via ``request.url_for``.
            headers: Passed as ``vary``; anything that is not already a
                ``Headers`` (including ``None``) is wrapped in ``Headers(...)``.
        """
        target = url if isinstance(url, URL) else self.request.url_for(url)
        vary = headers if isinstance(headers, Headers) else Headers(headers)
        await delete_from_cache(target, vary=vary, cache=self.middleware.cache)

291 

292 

class CacheControlMiddleware:
    """ASGI middleware that applies ``Cache-Control`` directives.

    For every HTTP request the constructor keyword arguments are forwarded to
    a ``CacheControlResponder``; non-HTTP scopes pass straight through. The
    same directives are also stored as attributes so ``process_response`` can
    render them onto an already-built response object.
    """

    app: ASGIApp
    kwargs: CacheDirectives
    max_age: int | None
    s_maxage: int | None
    no_cache: bool
    no_store: bool
    no_transform: bool
    must_revalidate: bool
    proxy_revalidate: bool
    must_understand: bool
    private: bool
    public: bool
    immutable: bool
    stale_while_revalidate: int | None
    stale_if_error: int | None

    # (attribute, header directive, takes "=value"), in output order. The first
    # four keep the order process_response has always used; the rest are the
    # remaining directives accepted by __init__ that used to be dropped.
    _DIRECTIVE_ORDER = (
        ("no_cache", "no-cache", False),
        ("no_store", "no-store", False),
        ("must_revalidate", "must-revalidate", False),
        ("max_age", "max-age", True),
        ("s_maxage", "s-maxage", True),
        ("no_transform", "no-transform", False),
        ("proxy_revalidate", "proxy-revalidate", False),
        ("must_understand", "must-understand", False),
        ("immutable", "immutable", False),
        ("stale_while_revalidate", "stale-while-revalidate", True),
        ("stale_if_error", "stale-if-error", True),
    )

    def __init__(self, app: ASGIApp, **kwargs: t.Unpack[CacheDirectives]) -> None:
        self.app = app
        self.kwargs = kwargs
        # Every directive starts unset; supplied kwargs override below.
        self.max_age = None
        self.s_maxage = None
        self.no_cache = False
        self.no_store = False
        self.no_transform = False
        self.must_revalidate = False
        self.proxy_revalidate = False
        self.must_understand = False
        self.private = False
        self.public = False
        self.immutable = False
        self.stale_while_revalidate = None
        self.stale_if_error = None
        for key, value in kwargs.items():
            setattr(self, key, value)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
        responder = CacheControlResponder(self.app, **self.kwargs)
        await responder(scope, receive, send)

    def process_response(self, response: t.Any) -> None:
        """Set ``Cache-Control`` on *response* from this instance's directives.

        ``public`` wins over ``private`` when both are set. Boolean directives
        are emitted when truthy; valued directives (``max-age`` etc.) when not
        ``None``. The header is left untouched if no directive is set.

        Args:
            response: object with a mutable ``headers`` mapping (assumed to be
                a Starlette-style response — confirm at call sites).
        """
        cache_control_parts: list[str] = []
        if getattr(self, "public", False):
            cache_control_parts.append("public")
        elif getattr(self, "private", False):
            cache_control_parts.append("private")
        for attr, directive, takes_value in self._DIRECTIVE_ORDER:
            value = getattr(self, attr, None)
            if takes_value:
                if value is not None:
                    cache_control_parts.append(f"{directive}={value}")
            elif value:
                cache_control_parts.append(directive)
        if cache_control_parts:
            response.headers["Cache-Control"] = ", ".join(cache_control_parts)

353 

354 

def get_middleware_positions() -> dict[str, int]:
    """Map each ``MiddlewarePosition`` member name to its integer slot."""
    positions: dict[str, int] = {}
    for member in MiddlewarePosition:
        positions[member.name] = member.value
    return positions

357 

358 

class MiddlewareStackManager:
    """Assemble the ordered list of Starlette ``Middleware`` specs.

    Built-in middleware is registered lazily on the first ``initialize()``
    (triggered by ``build_stack`` / ``get_middleware_info`` as well). A class
    registered explicitly with ``register_middleware`` always takes precedence
    over the built-in choice for that position, whether it was registered
    before or after initialization.
    """

    def __init__(
        self,
        config: t.Any | None = None,
        logger: t.Any | None = None,
    ) -> None:
        self.config = config
        self.logger = logger
        self._middleware_registry: dict[MiddlewarePosition, MiddlewareClass] = {}
        self._middleware_options: dict[MiddlewarePosition, MiddlewareOptions] = {}
        self._custom_middleware: dict[MiddlewarePosition, Middleware] = {}
        # Positions claimed through register_middleware(); built-in
        # registration must never overwrite them.
        self._user_positions: set[MiddlewarePosition] = set()
        self._initialized = False

    def _ensure_dependencies(self) -> None:
        """Resolve a missing ``config`` (required) and ``logger`` (optional) via ``depends``."""
        if self.config is None:
            self.config = depends.get("config")
        if self.logger is None:
            try:
                self.logger = depends.get("logger")
            except Exception:
                # A logger is optional; carry on without one.
                self.logger = None

    def _register_builtin(
        self,
        position: MiddlewarePosition,
        middleware_class: MiddlewareClass,
        options: MiddlewareOptions | None = None,
    ) -> None:
        """Register a built-in middleware unless the user already claimed *position*."""
        if position in self._user_positions:
            return
        self._middleware_registry[position] = middleware_class
        if options is not None:
            self._middleware_options[position] = options

    def _register_default_middleware(self) -> None:
        """Register the always-on HTMX, current-request and Brotli middleware."""
        self._register_builtin(MiddlewarePosition.HTMX, HtmxMiddleware)
        self._register_builtin(
            MiddlewarePosition.CURRENT_REQUEST, CurrentRequestMiddleware
        )
        self._register_builtin(
            MiddlewarePosition.COMPRESSION, BrotliMiddleware, {"quality": 3}
        )

    def _register_conditional_middleware(self) -> None:
        """Register config-dependent middleware: CSRF, sessions, security headers.

        Nothing is registered when ``config`` is falsy. Sessions are added only
        if ``get_adapter("auth")`` is truthy; security headers only when
        deployed or when ``config.debug.production`` is set.
        """
        self._ensure_dependencies()
        if not self.config:
            return
        from acb.adapters import get_adapter

        config = self.config
        token_id = getattr(config.app, "token_id", "_fb_")
        secret = config.app.secret_key.get_secret_value()

        self._register_builtin(
            MiddlewarePosition.CSRF,
            CSRFMiddleware,
            {
                "secret": secret,
                "cookie_name": f"{token_id}_csrf",
                "cookie_secure": config.deployed,
            },
        )
        if get_adapter("auth"):
            self._register_builtin(
                MiddlewarePosition.SESSION,
                SessionMiddleware,
                {
                    "secret_key": secret,
                    "session_cookie": f"{token_id}_app",
                    "https_only": config.deployed,
                },
            )
        if config.deployed or getattr(config.debug, "production", False):
            self._register_builtin(
                MiddlewarePosition.SECURITY_HEADERS, SecureHeadersMiddleware
            )

    def initialize(self) -> None:
        """Register built-in middleware once; later calls are no-ops."""
        if self._initialized:
            return
        self._register_default_middleware()
        self._register_conditional_middleware()
        self._initialized = True

    def register_middleware(
        self,
        middleware_class: MiddlewareClass,
        position: MiddlewarePosition,
        **options: t.Any,
    ) -> None:
        """Register *middleware_class* at *position*, replacing any earlier entry.

        *options* replace whatever options were stored for the position
        (including built-in defaults such as Brotli's ``quality``), so stale
        options are never passed to a class they were not meant for.
        """
        self._user_positions.add(position)
        self._middleware_registry[position] = middleware_class
        if options:
            self._middleware_options[position] = options
        else:
            self._middleware_options.pop(position, None)

    def add_custom_middleware(
        self,
        middleware: Middleware,
        position: MiddlewarePosition,
    ) -> None:
        """Place a prebuilt ``Middleware`` spec at *position* (overrides the registry)."""
        self._custom_middleware[position] = middleware

    def build_stack(self) -> list[Middleware]:
        """Return the middleware specs sorted by ascending position."""
        if not self._initialized:
            self.initialize()

        middleware_stack: dict[MiddlewarePosition, Middleware] = {}
        self._build_middleware_stack(middleware_stack)
        # Custom specs win over registry entries at the same position.
        middleware_stack.update(self._custom_middleware)

        return [
            middleware_stack[position] for position in sorted(middleware_stack.keys())
        ]

    def _build_middleware_stack(
        self, middleware_stack: dict[MiddlewarePosition, Middleware]
    ) -> None:
        """Fill *middleware_stack* with a ``Middleware`` spec per registered class."""
        for position, middleware_class in self._middleware_registry.items():
            options = self._middleware_options.get(position, {})
            middleware_stack[position] = Middleware(middleware_class, **options)

    def get_middleware_info(self) -> dict[str, t.Any]:
        """Describe registered classes, custom specs and the position table."""
        if not self._initialized:
            self.initialize()

        return {
            "registered": {
                pos.name: cls.__name__ for pos, cls in self._middleware_registry.items()
            },
            "custom": {
                pos.name: str(middleware)
                for pos, middleware in self._custom_middleware.items()
            },
            "positions": get_middleware_positions(),
        }

474 

475 

def middlewares() -> list[Middleware]:
    """Build the application's middleware list with a fresh ``MiddlewareStackManager``.

    ``config`` and ``logger`` are resolved through ``depends`` while building.
    """
    manager = MiddlewareStackManager()
    return manager.build_stack()