Coverage for fastblocks / middleware.py: 69%
312 statements
« prev ^ index » next coverage.py v7.12.0, created at 2025-11-26 03:58 -0800
« prev ^ index » next coverage.py v7.12.0, created at 2025-11-26 03:58 -0800
1import sys
2import typing as t
3from collections.abc import Mapping, Sequence
4from contextvars import ContextVar
5from enum import IntEnum
7from acb.debug import debug
8from acb.depends import depends
9from brotli_asgi import BrotliMiddleware
10from secure import Secure
11from starlette.datastructures import URL, Headers, MutableHeaders
12from starlette.middleware import Middleware
13from starlette.middleware.sessions import SessionMiddleware
14from starlette.requests import Request
15from starlette.types import ASGIApp, Message, Receive, Scope, Send
16from starlette_csrf.middleware import CSRFMiddleware
18from .caching import (
19 CacheControlResponder,
20 CacheDirectives,
21 CacheResponder,
22 Rule,
23 delete_from_cache,
24)
25from .htmx import HtmxDetails
# Type aliases used in this module's annotations.
# A callable that takes an ASGI app and returns an ASGI app.
MiddlewareCallable = t.Callable[[ASGIApp], ASGIApp]
# A middleware class object (not an instance).
MiddlewareClass = type[t.Any]
# Keyword options stored per middleware and forwarded on construction.
MiddlewareOptions = dict[str, t.Any]
30from .exceptions import MissingCaching
class MiddlewarePosition(IntEnum):
    """Named slots for the middleware stack.

    ``MiddlewareStackManager.build_stack`` sorts the configured middleware by
    these integer values, so a member with a lower value appears earlier in
    the list it returns.
    """

    CSRF = 0
    SESSION = 1
    HTMX = 2
    CURRENT_REQUEST = 3
    COMPRESSION = 4
    SECURITY_HEADERS = 5
class HtmxMiddleware:
    """ASGI middleware that stores an ``HtmxDetails`` object in the scope.

    For ``http`` and ``websocket`` scopes the details object is placed under
    the ``"htmx"`` scope key before the wrapped application is called.
    """

    def __init__(self, app: ASGIApp) -> None:
        """Remember the wrapped application and emit a debug line."""
        self._app = app
        debug("HtmxMiddleware: Initialized FastBlocks native HTMX middleware")

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Annotate supported scopes, then delegate to the wrapped app."""
        wants_htmx = scope["type"] in ("http", "websocket")
        if wants_htmx:
            await self._process_htmx_request(scope)
        await self._app(scope, receive, send)

    async def _process_htmx_request(self, scope: Scope) -> None:
        """Build ``HtmxDetails`` for *scope* and store it as ``scope["htmx"]``."""
        details = HtmxDetails(scope)
        scope["htmx"] = details
        if not debug.enabled:
            return
        self._log_htmx_details(scope, details)

    def _log_htmx_details(self, scope: Scope, htmx_details: HtmxDetails) -> None:
        """Write the request line and, for HTMX requests, every HTMX header."""
        verb = scope.get("method", "UNKNOWN")
        route = scope.get("path", "unknown")
        is_htmx = bool(htmx_details)
        debug(f"HtmxMiddleware: {verb} {route} - HTMX: {is_htmx}")
        if not is_htmx:
            return
        for name, value in htmx_details.get_all_headers().items():
            debug(f"HtmxMiddleware: {name}: {value}")
class HtmxResponseMiddleware:
    """ASGI middleware that intercepts outgoing HTTP response messages.

    Non-HTTP scopes are passed straight through. For HTTP scopes every message
    sent by the wrapped application is routed through
    ``_process_response_message`` before reaching the real ``send``.
    """

    def __init__(self, app: ASGIApp) -> None:
        """Remember the wrapped application and emit a debug line."""
        self._app = app
        debug("HtmxResponseMiddleware: Initialized FastBlocks HTMX response middleware")

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Delegate to the wrapped app, wrapping ``send`` for HTTP scopes."""
        if scope["type"] != "http":
            await self._app(scope, receive, send)
            return

        async def send_wrapper(message: Message) -> None:
            await self._process_response_message(message, scope, send)

        await self._app(scope, receive, send_wrapper)

    async def _process_response_message(
        self, message: Message, scope: Scope, send: Send
    ) -> None:
        """Forward *message* to *send*, normalising headers for HTMX responses.

        When the message starts a response and ``scope["htmx"]`` is truthy,
        the message's ``headers`` entry is replaced by a fresh ``list`` copy
        (an empty list if the key is absent). The message is always forwarded.
        """
        # A single truthiness test is enough; the original evaluated it twice
        # via ``htmx_details and bool(htmx_details)``.
        if message["type"] == "http.response.start" and scope.get("htmx"):
            debug("HtmxResponseMiddleware: Processing HTMX response")
            message["headers"] = list(message.get("headers", []))
        await send(message)
class MiddlewareUtils:
    """Namespace for shared middleware state and interned string constants."""

    Cache = t.Any

    secure_headers = Secure()

    # Scope key under which ``CacheMiddleware`` stores itself.
    scope_name = "__starlette_caches__"

    _request_ctx_var: ContextVar[Scope | None] = ContextVar("request", default=None)

    # Interned copies of frequently compared ASGI strings.
    HTTP, WEBSOCKET = map(sys.intern, ("http", "websocket"))
    TYPE, METHOD, PATH = map(sys.intern, ("type", "method", "path"))
    GET, HEAD, POST, PUT, PATCH, DELETE = map(
        sys.intern, ("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE")
    )

    @classmethod
    def get_request(cls) -> Scope | None:
        """Return the scope held by the request context variable (or ``None``)."""
        ctx_var = cls._request_ctx_var
        return ctx_var.get()

    @classmethod
    def set_request(cls, scope: Scope | None) -> None:
        """Store *scope* in the request context variable; the token is discarded."""
        ctx_var = cls._request_ctx_var
        ctx_var.set(scope)
# Module-level aliases of the ``MiddlewareUtils`` class attributes; the
# middleware classes in this module refer to these names directly.
Cache = MiddlewareUtils.Cache
secure_headers = MiddlewareUtils.secure_headers
scope_name = MiddlewareUtils.scope_name
_request_ctx_var = MiddlewareUtils._request_ctx_var
def get_request() -> Scope | None:
    """Return the ASGI scope held by ``MiddlewareUtils``' request context variable.

    The result is ``None`` when no scope has been stored in the current context.
    """
    current_scope = MiddlewareUtils.get_request()
    return current_scope
class CurrentRequestMiddleware:
    """ASGI middleware that exposes the active scope through a context variable.

    For ``http`` and ``websocket`` scopes the scope is stored in the module's
    request context variable while the wrapped application runs, and the
    previous value is restored afterwards.
    """

    def __init__(self, app: ASGIApp) -> None:
        """Remember the wrapped application."""
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Run the wrapped app with *scope* bound to the request context variable."""
        if scope[MiddlewareUtils.TYPE] not in (
            MiddlewareUtils.HTTP,
            MiddlewareUtils.WEBSOCKET,
        ):
            await self.app(scope, receive, send)
            return
        token = _request_ctx_var.set(scope)
        try:
            await self.app(scope, receive, send)
        finally:
            # Restore the previous value even when the wrapped app raises, so
            # the scope does not stay bound in this context.
            _request_ctx_var.reset(token)
class SecureHeadersMiddleware:
    """ASGI middleware that appends the ``secure_headers`` set to HTTP responses."""

    def __init__(self, app: ASGIApp) -> None:
        """Remember the wrapped app and look up a logger on a best-effort basis."""
        self.app = app
        try:
            logger = depends.get("logger")
        except Exception:
            # The logger is optional here; any lookup failure leaves it unset.
            logger = None
        self.logger = logger

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Delegate to the wrapped app, decorating the response-start message."""
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        async def send_with_secure_headers(message: Message) -> None:
            is_response_start = message["type"] == "http.response.start"
            if is_response_start:
                response_headers = MutableHeaders(scope=message)
                for name, value in secure_headers.headers.items():
                    response_headers.append(name, value)
            await send(message)

        await self.app(scope, receive, send_with_secure_headers)
        return None
class CacheValidator:
    """Holds the caching rules and detects duplicate ``CacheMiddleware`` use."""

    def __init__(self, rules: Sequence[Rule] | None = None) -> None:
        """Store *rules*, falling back to a single default ``Rule()`` when falsy."""
        self.rules = rules if rules else [Rule()]

    def check_for_duplicate_middleware(self, app: ASGIApp) -> None:
        """Inspect ``app.middleware`` (when present and not callable) for duplicates."""
        if hasattr(app, "middleware"):
            candidate = app.middleware  # type: ignore[attr-defined]
            # A callable ``middleware`` attribute is not a stack to iterate.
            if not callable(candidate):
                self._check_for_cache_middleware_duplicates(candidate)

    def _check_for_cache_middleware_duplicates(self, middleware: t.Any) -> None:
        """Raise ``DuplicateCaching`` if *middleware* holds a ``CacheMiddleware``."""
        if any(isinstance(entry, CacheMiddleware) for entry in middleware):
            from .exceptions import DuplicateCaching

            raise DuplicateCaching("CacheMiddleware detected in middleware stack")

    def is_duplicate_in_scope(self, scope: Scope) -> bool:
        """Return whether the cache scope key is already present in *scope*."""
        already_present = scope_name in scope
        return already_present
206class CacheKeyManager:
207 def __init__(self, cache: t.Any | None = None) -> None:
208 self.cache = cache
209 self._cache_dict: dict[t.Any, t.Any] = {}
211 def get_cache_instance(self) -> t.Any:
212 if self.cache is None:
213 from .exceptions import safe_depends_get
215 self.cache = safe_depends_get("cache", self._cache_dict)
216 return self.cache
class CacheMiddleware:
    """ASGI middleware that serves HTTP requests through a ``CacheResponder``."""

    def __init__(
        self,
        app: ASGIApp,
        *,
        cache: t.Any | None = None,
        rules: Sequence[Rule] | None = None,
    ) -> None:
        """Set up the validator and key manager, then check *app* for duplicates."""
        self.app = app

        validator = CacheValidator(rules)
        self.validator = validator
        self.key_manager = CacheKeyManager(cache)

        self.cache = cache

        self.rules = validator.rules

        validator.check_for_duplicate_middleware(app)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Resolve the cache, then handle the request via ``CacheResponder``.

        Raises ``DuplicateCaching`` when the cache scope key is already set.
        """
        self.cache = self.key_manager.get_cache_instance()

        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        if self.validator.is_duplicate_in_scope(scope):
            from .exceptions import DuplicateCaching

            msg = "\n".join(
                (
                    "Another `CacheMiddleware` was detected in the middleware stack.",
                    "HINT: this exception probably occurred because:",
                    "- You wrapped an application around `CacheMiddleware` multiple times.",
                    "- You tried to apply `@cached()` onto an endpoint, but the application "
                    "is already wrapped around a `CacheMiddleware`.",
                )
            )
            raise DuplicateCaching(msg)

        # Advertise this instance to downstream helpers via the scope.
        scope[scope_name] = self
        await CacheResponder(self.app, rules=self.rules)(scope, receive, send)
class _BaseCacheMiddlewareHelper:
    """Base helper that locates the ``CacheMiddleware`` stored in a request scope."""

    def __init__(self, request: Request) -> None:
        """Bind *request* and the ``CacheMiddleware`` found in its scope.

        Raises ``MissingCaching`` when the scope key is absent or holds an
        object that is not a ``CacheMiddleware``.
        """
        self.request = request
        asgi_scope = request.scope

        if scope_name not in asgi_scope:
            raise MissingCaching(
                "No CacheMiddleware instance found in the ASGI scope. Did you forget "
                "to wrap the ASGI application with `CacheMiddleware`?"
            )

        found = asgi_scope[scope_name]
        if not isinstance(found, CacheMiddleware):
            raise MissingCaching(
                f"A scope variable named {scope_name!r} was found, but it does not "
                "contain a `CacheMiddleware` instance. It is likely that an "
                "incompatible middleware was added to the middleware stack."
            )

        self.middleware = found
class CacheHelper(_BaseCacheMiddlewareHelper):
    """Request-bound helper for invalidating cached responses."""

    async def invalidate_cache_for(
        self,
        url: str | URL,
        *,
        headers: Mapping[str, str] | None = None,
    ) -> None:
        """Delete the cache entry for *url* from the bound middleware's cache.

        A non-``URL`` *url* is passed to ``request.url_for`` to obtain the
        target; *headers* not already a ``Headers`` instance is wrapped in one
        and passed to ``delete_from_cache`` as ``vary``.
        """
        target = url if isinstance(url, URL) else self.request.url_for(url)
        vary_headers = headers if isinstance(headers, Headers) else Headers(headers)
        await delete_from_cache(
            target,
            vary=vary_headers,
            cache=self.middleware.cache,
        )
class CacheControlMiddleware:
    """ASGI middleware that applies ``Cache-Control`` directives to HTTP responses.

    Directives are supplied as keyword arguments. They are kept both as the
    raw ``kwargs`` mapping (forwarded to ``CacheControlResponder``) and as
    individual instance attributes (read by ``process_response``).
    """

    app: ASGIApp
    kwargs: CacheDirectives
    max_age: int | None
    s_maxage: int | None
    no_cache: bool
    no_store: bool
    no_transform: bool
    must_revalidate: bool
    proxy_revalidate: bool
    must_understand: bool
    private: bool
    public: bool
    immutable: bool
    stale_while_revalidate: int | None
    stale_if_error: int | None

    def __init__(self, app: ASGIApp, **kwargs: t.Unpack[CacheDirectives]) -> None:
        """Store *app* and the directives, defaulting every directive to "unset"."""
        self.app = app
        self.kwargs = kwargs
        self.max_age = None
        self.s_maxage = None
        self.no_cache = False
        self.no_store = False
        self.no_transform = False
        self.must_revalidate = False
        self.proxy_revalidate = False
        self.must_understand = False
        self.private = False
        self.public = False
        self.immutable = False
        self.stale_while_revalidate = None
        self.stale_if_error = None
        # Overlay the caller-supplied directives on top of the defaults.
        for key, value in kwargs.items():
            setattr(self, key, value)

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Handle HTTP scopes via ``CacheControlResponder``; pass others through."""
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return
        responder = CacheControlResponder(self.app, **self.kwargs)
        await responder(scope, receive, send)

    def process_response(self, response: t.Any) -> None:
        """Set ``response.headers["Cache-Control"]`` from the stored directives.

        ``public`` takes precedence over ``private``. The header is left
        untouched when no directive is set. Directives that the earlier
        implementation ignored (``s-maxage``, ``no-transform``,
        ``proxy-revalidate``, ``must-understand``, ``immutable``,
        ``stale-while-revalidate``, ``stale-if-error``) are appended after the
        original ones, so output is unchanged when they are not configured.
        """
        parts: list[str] = []

        if getattr(self, "public", False):
            parts.append("public")
        elif getattr(self, "private", False):
            parts.append("private")

        for attr, token in (
            ("no_cache", "no-cache"),
            ("no_store", "no-store"),
            ("must_revalidate", "must-revalidate"),
        ):
            if getattr(self, attr, False):
                parts.append(token)

        max_age = getattr(self, "max_age", None)
        if max_age is not None:
            parts.append(f"max-age={max_age}")

        s_maxage = getattr(self, "s_maxage", None)
        if s_maxage is not None:
            parts.append(f"s-maxage={s_maxage}")

        for attr, token in (
            ("no_transform", "no-transform"),
            ("proxy_revalidate", "proxy-revalidate"),
            ("must_understand", "must-understand"),
            ("immutable", "immutable"),
        ):
            if getattr(self, attr, False):
                parts.append(token)

        for attr, token in (
            ("stale_while_revalidate", "stale-while-revalidate"),
            ("stale_if_error", "stale-if-error"),
        ):
            seconds = getattr(self, attr, None)
            if seconds is not None:
                parts.append(f"{token}={seconds}")

        if parts:
            response.headers["Cache-Control"] = ", ".join(parts)
def get_middleware_positions() -> dict[str, int]:
    """Return a mapping of each ``MiddlewarePosition`` name to its integer value."""
    return {member.name: member.value for member in MiddlewarePosition}
class MiddlewareStackManager:
    """Registry that assembles the ordered list of Starlette ``Middleware``.

    Middleware classes (plus optional constructor options) are registered per
    ``MiddlewarePosition``. Pre-built ``Middleware`` objects can be added as
    "custom" entries, which override a registered class at the same position
    when the stack is built.
    """

    def __init__(
        self,
        config: t.Any | None = None,
        logger: t.Any | None = None,
    ) -> None:
        """Store optional *config*/*logger* and start with empty registries."""
        self.config = config
        self.logger = logger
        self._middleware_registry: dict[MiddlewarePosition, MiddlewareClass] = {}
        self._middleware_options: dict[MiddlewarePosition, MiddlewareOptions] = {}
        self._custom_middleware: dict[MiddlewarePosition, Middleware] = {}
        self._initialized = False

    def _ensure_dependencies(self) -> None:
        """Fill in a missing config and logger from ``depends``.

        A failing logger lookup is tolerated (the logger stays ``None``); a
        failing config lookup propagates.
        """
        if self.config is None:
            self.config = depends.get("config")
        if self.logger is None:
            try:
                self.logger = depends.get("logger")
            except Exception:
                # Best effort: the logger is optional.
                self.logger = None

    def _register_default_middleware(self) -> None:
        """Register the HTMX, current-request and compression middleware."""
        self._middleware_registry.update(
            {
                MiddlewarePosition.HTMX: HtmxMiddleware,
                MiddlewarePosition.CURRENT_REQUEST: CurrentRequestMiddleware,
                MiddlewarePosition.COMPRESSION: BrotliMiddleware,
            },
        )
        self._middleware_options[MiddlewarePosition.COMPRESSION] = {"quality": 3}

    def _register_conditional_middleware(self) -> None:
        """Register the middleware whose presence depends on configuration.

        Does nothing when the config is falsy. Otherwise CSRF is always
        registered, session middleware only when ``get_adapter("auth")`` is
        truthy, and secure headers only when ``config.deployed`` or
        ``config.debug.production`` is truthy.
        """
        self._ensure_dependencies()
        if not self.config:
            return
        from acb.adapters import get_adapter

        app_config = self.config.app
        secret = app_config.secret_key.get_secret_value()
        token_prefix = getattr(app_config, "token_id", "_fb_")

        self._middleware_registry[MiddlewarePosition.CSRF] = CSRFMiddleware
        self._middleware_options[MiddlewarePosition.CSRF] = {
            "secret": secret,
            "cookie_name": f"{token_prefix}_csrf",
            "cookie_secure": self.config.deployed,
        }
        if get_adapter("auth"):
            self._middleware_registry[MiddlewarePosition.SESSION] = SessionMiddleware
            self._middleware_options[MiddlewarePosition.SESSION] = {
                "secret_key": secret,
                "session_cookie": f"{token_prefix}_app",
                "https_only": self.config.deployed,
            }
        if self.config.deployed or getattr(self.config.debug, "production", False):
            self._middleware_registry[MiddlewarePosition.SECURITY_HEADERS] = (
                SecureHeadersMiddleware
            )

    def initialize(self) -> None:
        """Register default and conditional middleware once; later calls no-op."""
        if self._initialized:
            return
        self._register_default_middleware()
        self._register_conditional_middleware()
        self._initialized = True

    def register_middleware(
        self,
        middleware_class: MiddlewareClass,
        position: MiddlewarePosition,
        **options: t.Any,
    ) -> None:
        """Register *middleware_class* at *position*, replacing any previous class.

        *options* are stored as the constructor keyword arguments for that
        position. When no options are given, options left over from a
        previously registered class are discarded so they are not passed to
        the new class.
        """
        self._middleware_registry[position] = middleware_class
        if options:
            self._middleware_options[position] = options
        else:
            # Drop stale options belonging to the class being replaced.
            self._middleware_options.pop(position, None)

    def add_custom_middleware(
        self,
        middleware: Middleware,
        position: MiddlewarePosition,
    ) -> None:
        """Place a ready-made ``Middleware`` at *position*."""
        self._custom_middleware[position] = middleware

    def build_stack(self) -> list[Middleware]:
        """Return the middleware ordered by ascending position.

        Initializes the manager on first use. Custom middleware replace
        registered classes that share their position.
        """
        if not self._initialized:
            self.initialize()

        middleware_stack: dict[MiddlewarePosition, Middleware] = {}
        self._build_middleware_stack(middleware_stack)
        middleware_stack.update(self._custom_middleware)

        return [middleware_stack[position] for position in sorted(middleware_stack)]

    def _build_middleware_stack(
        self, middleware_stack: dict[MiddlewarePosition, Middleware]
    ) -> None:
        """Fill *middleware_stack* with a ``Middleware`` per registered class."""
        for position, middleware_class in self._middleware_registry.items():
            options = self._middleware_options.get(position, {})
            middleware_stack[position] = Middleware(middleware_class, **options)

    def get_middleware_info(self) -> dict[str, t.Any]:
        """Return a summary of registered classes, custom entries and positions.

        Initializes the manager on first use.
        """
        if not self._initialized:
            self.initialize()

        return {
            "registered": {
                pos.name: cls.__name__ for pos, cls in self._middleware_registry.items()
            },
            "custom": {
                pos.name: str(middleware)
                for pos, middleware in self._custom_middleware.items()
            },
            "positions": get_middleware_positions(),
        }
def middlewares() -> list[Middleware]:
    """Build and return the middleware stack from a fresh ``MiddlewareStackManager``."""
    manager = MiddlewareStackManager()
    return manager.build_stack()