Coverage for polar/middlewares.py: 74%
64 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 17:15 +0000
1import functools 1b
2import re 1b
4import dramatiq 1b
5import structlog 1b
6from starlette.datastructures import MutableHeaders 1b
7from starlette.types import ASGIApp, Message, Receive, Scope, Send 1b
9from polar.logging import Logger, generate_correlation_id 1b
10from polar.worker import JobQueueManager 1b
class LogCorrelationIdMiddleware:
    """ASGI middleware that binds per-request logging context for HTTP scopes.

    For every ``http`` scope, a freshly generated correlation ID together with
    the request method and path is bound to structlog's context variables for
    the duration of the wrapped application call, then unbound again.
    Every other scope type is forwarded to the wrapped application untouched.
    """

    def __init__(self, app: ASGIApp) -> None:
        """Store the wrapped ASGI application.

        Args:
            app: The next ASGI application in the chain.
        """
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Call the wrapped app, with logging context bound for HTTP scopes.

        Args:
            scope: The ASGI connection scope.
            receive: The ASGI receive callable, forwarded unchanged.
            send: The ASGI send callable, forwarded unchanged.

        Raises:
            Whatever the wrapped application raises; the bound keys are
            removed before the exception propagates.
        """
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        structlog.contextvars.bind_contextvars(
            correlation_id=generate_correlation_id(),
            method=scope["method"],
            path=scope["path"],
        )
        try:
            await self.app(scope, receive, send)
        finally:
            # Unbind in ``finally`` so the three keys are removed even when
            # the wrapped app raises, instead of staying bound in the
            # current context.
            structlog.contextvars.unbind_contextvars(
                "correlation_id", "method", "path"
            )
class FlushEnqueuedWorkerJobsMiddleware:
    """ASGI middleware that wraps HTTP/WebSocket handling in a job queue manager.

    For ``http`` and ``websocket`` scopes the wrapped application runs inside
    the async context manager returned by ``JobQueueManager.open``, which is
    given the current dramatiq broker and the ``redis`` entry of the scope
    state. Other scope types bypass the manager entirely.

    NOTE(review): judging by the class name, the manager presumably flushes
    jobs enqueued during the request when it exits -- confirm against
    ``JobQueueManager.open``.
    """

    def __init__(self, app: ASGIApp) -> None:
        """Store the wrapped ASGI application."""
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Run the wrapped app, inside a job queue manager when applicable.

        Args:
            scope: The ASGI connection scope; ``scope["state"]["redis"]`` is
                read for ``http`` and ``websocket`` scopes.
            receive: The ASGI receive callable, forwarded unchanged.
            send: The ASGI send callable, forwarded unchanged.
        """
        needs_job_queue = scope["type"] in ("http", "websocket")

        if needs_job_queue:
            broker = dramatiq.get_broker()
            redis = scope["state"]["redis"]
            async with JobQueueManager.open(broker, redis):
                await self.app(scope, receive, send)
        else:
            await self.app(scope, receive, send)
class PathRewriteMiddleware:
    """ASGI middleware that rewrites the request path via regex substitution.

    For ``http`` and ``websocket`` scopes, ``scope["path"]`` is replaced in
    place by the result of ``re.subn(pattern, replacement, path)``. When at
    least one substitution happened, a warning is logged and the response
    start message gets an ``X-Polar-Deprecation-Notice`` header.
    """

    def __init__(
        self, app: ASGIApp, pattern: str | re.Pattern[str], replacement: str
    ) -> None:
        """Store the wrapped app and the rewrite rule.

        Args:
            app: The next ASGI application in the chain.
            pattern: Regular expression (string or compiled) to search for
                in the request path.
            replacement: Replacement passed to ``re.subn``.
        """
        self.app = app
        self.pattern = pattern
        self.replacement = replacement
        self.logger: Logger = structlog.get_logger()

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Rewrite the path for HTTP/WebSocket scopes, then call the wrapped app.

        Args:
            scope: The ASGI connection scope; its ``path`` entry is
                overwritten in place for ``http`` and ``websocket`` scopes.
            receive: The ASGI receive callable, forwarded unchanged.
            send: The ASGI send callable; wrapped with :meth:`send` for
                ``http`` and ``websocket`` scopes.
        """
        if scope["type"] in ("http", "websocket"):
            rewritten_path, hits = re.subn(
                self.pattern, self.replacement, scope["path"]
            )
            scope["path"] = rewritten_path

            if hits > 0:
                self.logger.warning(
                    "PathRewriteMiddleware",
                    pattern=self.pattern,
                    replacement=self.replacement,
                    path=rewritten_path,
                )

            # Carry the substitution count to the response side so the
            # deprecation header is only added to rewritten requests.
            send = functools.partial(self.send, send=send, replacements=hits)

        await self.app(scope, receive, send)

    async def send(self, message: Message, send: Send, replacements: int) -> None:
        """Forward ``message``, adding the deprecation header when applicable.

        Args:
            message: The outgoing ASGI message.
            send: The downstream ASGI send callable.
            replacements: Number of substitutions made on the request path.
        """
        if message["type"] == "http.response.start":
            message.setdefault("headers", [])
            # Built for every response start message, rewritten or not, so
            # the message is handled identically in both cases.
            response_headers = MutableHeaders(scope=message)
            if replacements > 0:
                response_headers["X-Polar-Deprecation-Notice"] = (
                    "The API root has moved from /api/v1 to /v1. "
                    "Please update your integration."
                )

        await send(message)
class SandboxResponseHeaderMiddleware:
    """ASGI middleware that stamps responses with an ``X-Polar-Sandbox`` header.

    For ``http`` and ``websocket`` scopes the send callable is wrapped so that
    every ``http.response.start`` message carries ``X-Polar-Sandbox: 1``.
    Other scope types are forwarded with the original send callable.
    """

    def __init__(self, app: ASGIApp) -> None:
        """Store the wrapped ASGI application."""
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Call the wrapped app, tagging the response start message if needed.

        Args:
            scope: The ASGI connection scope.
            receive: The ASGI receive callable, forwarded unchanged.
            send: The ASGI send callable, wrapped for ``http`` and
                ``websocket`` scopes.
        """
        downstream_send = send

        if scope["type"] in ("http", "websocket"):

            async def tag_sandbox_response(message: Message) -> None:
                """Set the sandbox header on response start, then forward."""
                is_response_start = message["type"] == "http.response.start"
                if is_response_start:
                    message.setdefault("headers", [])
                    response_headers = MutableHeaders(scope=message)
                    response_headers["X-Polar-Sandbox"] = "1"
                await send(message)

            downstream_send = tag_sandbox_response

        await self.app(scope, receive, downstream_send)