Coverage for fastblocks / applications.py: 40%

182 statements  

« prev     ^ index     » next       coverage.py v7.12.0, created at 2025-11-26 03:30 -0800

1import typing as t 

2from platform import system 

3 

4from acb.config import AdapterBase 

5from acb.depends import depends 

6from starception import add_link_template, set_editor 

7from starlette.applications import Starlette 

8from starlette.middleware import Middleware 

9from starlette.middleware.errors import ServerErrorMiddleware 

10from starlette.middleware.exceptions import ExceptionMiddleware 

11from starlette.types import ASGIApp, ExceptionHandler, Lifespan 

12 

13from .initializers import ApplicationInitializer 

14from .middleware import MiddlewarePosition 

15 

16 

class FastBlocksSettings:
    """Mixin that turns every settings subclass into an ``AdapterBase``.

    Subclasses do not have to list ``AdapterBase`` themselves. When a subclass
    is created, ``AdapterBase`` is prepended to its direct bases unless the
    class already derives from it, directly or indirectly.
    """

    def __init_subclass__(cls, **kwargs: t.Any) -> None:
        # Check the full inheritance chain, not just the direct bases.
        # A subclass of an already-patched subclass (or of any other
        # AdapterBase subclass) inherits AdapterBase indirectly. Prepending it
        # again would make the C3 MRO inconsistent, and the __bases__
        # assignment would raise TypeError.
        if not issubclass(cls, AdapterBase):
            cls.__bases__ = (AdapterBase, *cls.__bases__)
        super().__init_subclass__(**kwargs)

22 

23 

AppType = t.TypeVar("AppType", bound="FastBlocks")

# Register a starception "pycharm" link template using the PyCharm launcher
# name for the current OS. Platforms not listed here register nothing.
_pycharm_launch_command = {
    "Windows": "pycharm64.exe --line {lineno} {path}",
    "Darwin": "pycharm --line {lineno} {path}",
    "Linux": "pycharm.sh --line {lineno} {path}",
}.get(system())

if _pycharm_launch_command is not None:
    add_link_template("pycharm", _pycharm_launch_command)

35 

36 

37class MiddlewareManager: 

38 def __init__(self) -> None: 

39 self._system_middleware: dict[MiddlewarePosition, t.Any] = {} 

40 self._middleware_stack_cache: list[Middleware] | None = None 

41 self.user_middleware: list[Middleware] = [] 

42 

43 def add_user_middleware( 

44 self, 

45 middleware_class: t.Any, 

46 *args: t.Any, 

47 **kwargs: t.Any, 

48 ) -> None: 

49 position = kwargs.pop("position", None) 

50 

51 middleware = Middleware(middleware_class, *args, **kwargs) 

52 

53 if not hasattr(self, "user_middleware"): 

54 self.user_middleware = [] 

55 

56 if position is not None and isinstance(position, int): 

57 self.user_middleware.insert(position, middleware) 

58 else: 

59 self.user_middleware.append(middleware) 

60 

61 self._middleware_stack_cache = None 

62 

63 def add_system_middleware( 

64 self, 

65 middleware_class: t.Any, 

66 position: MiddlewarePosition, 

67 **kwargs: t.Any, 

68 ) -> None: 

69 self._system_middleware[position] = (middleware_class, kwargs) 

70 self._middleware_stack_cache = None 

71 

72 def get_middleware_stack(self) -> dict[str, t.Any]: 

73 return { 

74 "user_middleware": [ 

75 self._extract_middleware_info(middleware) 

76 for middleware in self.user_middleware 

77 ], 

78 "system_middleware": { 

79 pos.name: self._extract_middleware_info(middleware) 

80 for pos, middleware in self._system_middleware.items() 

81 }, 

82 } 

83 

84 def _extract_middleware_info(self, middleware: t.Any) -> dict[str, t.Any]: 

85 if isinstance(middleware, Middleware): 

86 return { 

87 "class": getattr(middleware.cls, "__name__", str(middleware.cls)), 

88 "args": middleware.args, 

89 "kwargs": middleware.kwargs, 

90 } 

91 if isinstance(middleware, tuple) and len(middleware) >= 2: 

92 cls, kwargs = middleware[0], middleware[1] 

93 return { 

94 "class": cls.__name__ if hasattr(cls, "__name__") else str(cls), 

95 "kwargs": kwargs, 

96 } 

97 return { 

98 "class": middleware.__class__.__name__, 

99 "raw": str(middleware), 

100 } 

101 

102 

class FastBlocks(Starlette):
    """Starlette application whose middleware is tracked by a ``MiddlewareManager``.

    The ``user_middleware``, ``_system_middleware`` and
    ``_middleware_stack_cache`` attributes are properties that read and write
    through to ``self.middleware_manager``. ``build_middleware_stack`` caches
    the ASGI app it builds until that cache is cleared.
    """

    middleware_manager: MiddlewareManager
    templates: t.Any
    models: t.Any
    # Maps each MiddlewarePosition to its index in the system middleware list.
    _middleware_position_map: dict[MiddlewarePosition, int]

    def __init__(
        self,
        middleware: t.Sequence[Middleware] | None = None,
        exception_handlers: t.Mapping[t.Any, ExceptionHandler] | None = None,
        lifespan: Lifespan["t.Self"] | None = None,
        config: t.Any | None = None,
        logger: t.Any | None = None,
    ) -> None:
        """Create the application and run ``ApplicationInitializer`` on it.

        All arguments are passed to ``ApplicationInitializer``.
        ``middleware_manager`` is installed before ``initializer.initialize()``
        runs. This is presumably required because initialization may assign
        ``user_middleware``, which delegates to the manager; confirm against
        ApplicationInitializer.
        """
        initializer = ApplicationInitializer(
            self,
            middleware=middleware,
            exception_handlers=exception_handlers,
            lifespan=lifespan,
            config=config,
            logger=logger,
        )

        object.__setattr__(self, "middleware_manager", MiddlewareManager())

        self._middleware_position_map = {pos: pos.value for pos in MiddlewarePosition}
        self.templates = None
        self.models = None

        initializer.initialize()

        # Use the "pycharm" link template that is registered at module import
        # on Windows, macOS and Linux.
        set_editor("pycharm")

    def add_middleware(
        self,
        middleware_class: t.Any,
        *args: t.Any,
        **kwargs: t.Any,
    ) -> None:
        """Register a user middleware through the middleware manager.

        An ``int`` ``position`` keyword inserts the middleware at that index
        instead of appending it. The cached stack is invalidated.
        """
        self.middleware_manager.add_user_middleware(middleware_class, *args, **kwargs)

    @property
    def user_middleware(self) -> list[Middleware]:
        """User middleware list owned by the middleware manager."""
        return self.middleware_manager.user_middleware

    @user_middleware.setter
    def user_middleware(self, value: list[Middleware]) -> None:
        self.middleware_manager.user_middleware = value
        # Replacing the list wholesale must not leave a stale stack cached.
        self.middleware_manager._middleware_stack_cache = None

    @property
    def _system_middleware(self) -> dict[MiddlewarePosition, t.Any]:
        """Position-keyed system middleware overrides owned by the manager."""
        return self.middleware_manager._system_middleware

    @_system_middleware.setter
    def _system_middleware(self, value: dict[MiddlewarePosition, t.Any]) -> None:
        self.middleware_manager._system_middleware = value
        # New overrides change the stack, so drop any cached build.
        self.middleware_manager._middleware_stack_cache = None

    @property
    def _middleware_stack_cache(self) -> list[Middleware] | None:
        """Cached result of ``build_middleware_stack`` (``None`` if not built).

        NOTE: despite the annotation, this holds the built ASGI app.
        """
        return self.middleware_manager._middleware_stack_cache

    @_middleware_stack_cache.setter
    def _middleware_stack_cache(self, value: list[Middleware] | None) -> None:
        self.middleware_manager._middleware_stack_cache = value

    def add_system_middleware(
        self,
        middleware_class: type,
        *,
        position: MiddlewarePosition,
        **options: t.Any,
    ) -> None:
        """Override the system middleware at ``position``.

        ``options`` become the middleware's keyword arguments.
        """
        self.middleware_manager.add_system_middleware(
            middleware_class,
            position,
            **options,
        )

    def _extract_middleware_info(self, middleware: t.Any) -> tuple[str, type] | None:
        """Return ``(class_name, class)`` for a middleware entry.

        Accepts objects with a ``cls`` attribute (e.g. ``Middleware``) or
        non-empty tuples whose first item is the class. Returns ``None`` for
        anything else.
        """
        try:
            if hasattr(middleware, "cls"):
                cls = middleware.cls
            elif isinstance(middleware, tuple) and len(middleware) > 0:
                cls = middleware[0]
            else:
                return None
            cls_name = str(getattr(cls, "__name__", cls))
            return cls_name, cls
        except (AttributeError, IndexError, TypeError):
            return None

    @staticmethod
    def _as_middleware(middleware: t.Any) -> t.Any:
        """Convert a ``(class, kwargs)`` override entry into a ``Middleware``.

        ``add_system_middleware`` stores 2-tuples. ``_apply_middleware_to_app``
        unpacks every entry as ``(cls, args, kwargs)``, so those tuples must be
        normalized first. Other entries are returned unchanged.
        """
        if (
            isinstance(middleware, tuple)
            and len(middleware) == 2
            and isinstance(middleware[1], dict)
        ):
            cls, options = middleware
            return Middleware(cls, **options)
        return middleware

    def _get_system_middleware_with_overrides(self) -> list[t.Any]:
        """Return a new list of system middleware with overrides applied."""
        from .middleware import middlewares

        return list(self._apply_system_middleware_overrides(middlewares(), None))

    def get_middleware_stack(self) -> list[tuple[str, type]]:
        """List ``(name, class)`` pairs describing the middleware stack.

        Entries appear in this order:

        1. ``ExceptionMiddleware``
        2. system middleware (with overrides)
        3. user middleware
        4. ``ServerErrorMiddleware``, exactly once

        This is roughly innermost to outermost relative to
        ``build_middleware_stack``. The system and user groups are each listed
        in their own forward order.
        """
        middleware_list: list[tuple[str, type]] = [
            ("ExceptionMiddleware", ExceptionMiddleware),
        ]
        for middleware in (
            *self._get_system_middleware_with_overrides(),
            *self.user_middleware,
        ):
            info = self._extract_middleware_info(middleware)
            if info:
                middleware_list.append(info)
        # The built stack contains a single outermost ServerErrorMiddleware,
        # however many user middleware are registered.
        middleware_list.append(
            ("ServerErrorMiddleware", t.cast("type", ServerErrorMiddleware)),
        )
        return middleware_list

    def _get_dependencies(self, config: t.Any, logger: t.Any) -> tuple[t.Any, t.Any]:
        """Fill in missing ``config``/``logger`` from the ``depends`` container.

        Resolution is best-effort. Any exception during lookup yields
        ``None``. A logger without a ``debug`` attribute is also discarded.
        """
        if config is None:
            try:
                config = depends.get_sync("config")
            except Exception:
                config = None
        if logger is None:
            try:
                logger = depends.get_sync("logger")
            except Exception:
                logger = None
        if logger is not None and not hasattr(logger, "debug"):
            logger = None
        return config, logger

    def _separate_exception_handlers(
        self,
    ) -> tuple[t.Any, dict[t.Any, ExceptionHandler]]:
        """Split ``self.exception_handlers`` into ``(error_handler, others)``.

        The handler registered for ``500`` or ``Exception`` becomes the
        server-error handler. If both are present, the one iterated last
        wins. All other handlers are returned in a new dict.
        """
        error_handler = None
        exception_handlers: dict[t.Any, ExceptionHandler] = {}
        for key, value in self.exception_handlers.items():
            if key in (500, Exception):
                error_handler = value
            else:
                exception_handlers[key] = value
        return error_handler, exception_handlers

    def _build_base_middleware_list(self, error_handler: t.Any) -> list[Middleware]:
        """Start the stack with ``ServerErrorMiddleware`` followed by user middleware."""
        middleware_list = [
            Middleware(
                ServerErrorMiddleware,
                handler=error_handler,
                debug=self.debug,
            ),
        ]
        middleware_list.extend(self.user_middleware)
        return middleware_list

    def _apply_system_middleware_overrides(
        self,
        system_middleware: list[t.Any],
        logger: t.Any,
    ) -> list[t.Any]:
        """Apply position overrides to ``system_middleware``.

        Each override replaces the entry at its position index, or is
        appended if the index is out of range. Overrides are normalized to
        ``Middleware`` objects. The input list is not mutated; it is returned
        as-is when there are no overrides.
        """
        if not (hasattr(self, "_system_middleware") and self._system_middleware):
            return system_middleware

        modified_system_middleware = system_middleware.copy()

        for position, middleware in self._system_middleware.items():
            position_index = self._middleware_position_map[position]
            entry = self._as_middleware(middleware)

            if 0 <= position_index < len(modified_system_middleware):
                if logger:
                    logger.debug(f"Replacing middleware at position {position.name}")
                modified_system_middleware[position_index] = entry
            else:
                if logger:
                    logger.debug(f"Adding middleware at position {position.name}")
                modified_system_middleware.append(entry)

        return modified_system_middleware

    def _apply_middleware_to_app(
        self,
        middleware_list: list[t.Any],
        logger: t.Any,
    ) -> ASGIApp:
        """Wrap ``self.router`` in every middleware from ``middleware_list``.

        The list is walked in reverse, so its first entry ends up outermost.
        Each entry must unpack to ``(cls, args, kwargs)``.
        """
        app = self.router
        for cls, args, kwargs in reversed(middleware_list):
            if logger:
                name = getattr(cls, "__name__", repr(cls))
                logger.debug(f"Adding middleware: {name}")
            app = cls(*args, app=app, **kwargs)
        return app

    def build_middleware_stack(
        self,
        config: t.Any | None = None,
        logger: t.Any | None = None,
    ) -> ASGIApp:
        """Build, or return the cached, ASGI app wrapping ``self.router``.

        Layers from outermost to innermost:

        1. ``ServerErrorMiddleware`` (using the 500/``Exception`` handler)
        2. user middleware
        3. system middleware from ``middlewares()`` with overrides
        4. ``ExceptionMiddleware`` (with the remaining handlers)

        ``config``/``logger`` fall back to the ``depends`` container.
        ``config`` is resolved but not otherwise used here. A cached stack is
        returned without consulting either argument.
        """
        if self._middleware_stack_cache is not None:
            return self._middleware_stack_cache  # type: ignore[return-value] # Cached middleware stack

        config, logger = self._get_dependencies(config, logger)
        error_handler, exception_handlers = self._separate_exception_handlers()

        from .middleware import middlewares

        middleware_list = self._build_base_middleware_list(error_handler)
        system_middleware = self._apply_system_middleware_overrides(
            middlewares(),
            logger,
        )

        middleware_list.extend(system_middleware)
        middleware_list.append(
            Middleware(
                ExceptionMiddleware,
                handlers=exception_handlers,
                debug=self.debug,
            ),
        )

        app = self._apply_middleware_to_app(middleware_list, logger)

        if logger:
            logger.info("Middleware stack built")

        object.__setattr__(self, "_middleware_stack_cache", app)
        return app