Coverage for polar/middlewares.py: 74%

64 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 15:52 +0000

1import functools 1b

2import re 1b

3 

4import dramatiq 1b

5import structlog 1b

6from starlette.datastructures import MutableHeaders 1b

7from starlette.types import ASGIApp, Message, Receive, Scope, Send 1b

8 

9from polar.logging import Logger, generate_correlation_id 1b

10from polar.worker import JobQueueManager 1b

11 

12 

class LogCorrelationIdMiddleware:
    """ASGI middleware that binds per-request structlog context for HTTP requests.

    For ``http`` scopes it binds a freshly generated ``correlation_id`` together
    with the request ``method`` and ``path`` into structlog's contextvars. It then
    runs the downstream app and unbinds those keys afterwards. Every other scope
    type is delegated to the wrapped app unchanged.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] == "http":
            request_context = {
                "correlation_id": generate_correlation_id(),
                "method": scope["method"],
                "path": scope["path"],
            }
            structlog.contextvars.bind_contextvars(**request_context)
            await self.app(scope, receive, send)
            # NOTE: this is not in a ``finally`` block. If the app raises, the
            # keys stay bound, so handlers further out in the same context still
            # see them when logging the error.
            structlog.contextvars.unbind_contextvars(*request_context)
            return None

        return await self.app(scope, receive, send)

30 

31 

class FlushEnqueuedWorkerJobsMiddleware:
    """ASGI middleware that runs each HTTP/WebSocket connection inside a job queue scope.

    For ``http`` and ``websocket`` scopes, the downstream app is awaited inside
    ``JobQueueManager.open(...)``. That call receives the current dramatiq broker
    and the Redis client stored at ``scope["state"]["redis"]``. Going by the class
    name, the manager presumably flushes jobs enqueued during the request when it
    exits; that logic lives in ``JobQueueManager`` (not visible here). Any other
    scope type is forwarded directly.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] in ("http", "websocket"):
            broker = dramatiq.get_broker()
            # Assumes application state (e.g. set up during lifespan) provides a
            # "redis" entry — TODO confirm against the app factory.
            redis = scope["state"]["redis"]
            async with JobQueueManager.open(broker, redis):
                await self.app(scope, receive, send)
        else:
            await self.app(scope, receive, send)

43 

44 

class PathRewriteMiddleware:
    """ASGI middleware that rewrites the request path with a regex substitution.

    For ``http`` and ``websocket`` scopes, every match of ``pattern`` in
    ``scope["path"]`` is replaced with ``replacement`` before the downstream app
    sees the scope. When at least one substitution happened:

    * a warning is logged with both the rewritten and the original path;
    * HTTP responses get an ``X-Polar-Deprecation-Notice`` header.

    Other scope types are forwarded untouched.

    Args:
        app: The downstream ASGI application.
        pattern: Regex (string or pre-compiled pattern) matched against the path.
        replacement: Replacement template, as accepted by ``re.sub``.
    """

    def __init__(
        self, app: ASGIApp, pattern: str | re.Pattern[str], replacement: str
    ) -> None:
        self.app = app
        self.pattern = pattern
        self.replacement = replacement
        # Compile once up front: invalid patterns fail at construction time and
        # requests avoid the per-call ``re`` cache lookup. ``re.compile`` returns
        # an already-compiled pattern unchanged.
        self._regex: re.Pattern[str] = re.compile(pattern)
        self.logger: Logger = structlog.get_logger()

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] not in ("http", "websocket"):
            await self.app(scope, receive, send)
            return

        original_path = scope["path"]
        scope["path"], replacements = self._regex.subn(
            self.replacement, original_path
        )

        if replacements == 0:
            # Nothing was rewritten, so there is no header to add: skip wrapping
            # ``send`` and hand the messages through untouched.
            await self.app(scope, receive, send)
            return

        self.logger.warning(
            "PathRewriteMiddleware",
            pattern=self.pattern,
            replacement=self.replacement,
            path=scope["path"],
            # Keep the path the client actually requested; ``path`` above is
            # already the rewritten one.
            original_path=original_path,
        )

        wrapped_send = functools.partial(
            self.send, send=send, replacements=replacements
        )
        await self.app(scope, receive, wrapped_send)

    async def send(self, message: Message, send: Send, replacements: int) -> None:
        """Forward ``message`` to ``send``, tagging rewritten HTTP responses.

        On ``http.response.start`` messages a ``headers`` list is ensured. When
        ``replacements > 0``, ``X-Polar-Deprecation-Notice`` is set. All other
        message types are forwarded unchanged.

        Args:
            message: The ASGI message emitted by the downstream app.
            send: The original ASGI ``send`` callable.
            replacements: Number of substitutions made to the request path.
        """
        if message["type"] != "http.response.start":
            await send(message)
            return

        message.setdefault("headers", [])
        headers = MutableHeaders(scope=message)
        if replacements > 0:
            headers["X-Polar-Deprecation-Notice"] = (
                "The API root has moved from /api/v1 to /v1. "
                "Please update your integration."
            )

        await send(message)

88 

89 

class SandboxResponseHeaderMiddleware:
    """ASGI middleware that adds ``X-Polar-Sandbox: 1`` to HTTP responses.

    For ``http`` and ``websocket`` scopes, ``send`` is wrapped so that every
    ``http.response.start`` message carries the header; all other messages pass
    through as-is. Scopes of any other type (e.g. ``lifespan``) are forwarded
    without wrapping.
    """

    def __init__(self, app: ASGIApp) -> None:
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] in ("http", "websocket"):

            async def tagged_send(message: Message) -> None:
                if message["type"] != "http.response.start":
                    await send(message)
                    return
                # MutableHeaders writes into message["headers"], so make sure
                # that list exists first.
                message.setdefault("headers", [])
                MutableHeaders(scope=message)["X-Polar-Sandbox"] = "1"
                await send(message)

            await self.app(scope, receive, tagged_send)
        else:
            await self.app(scope, receive, send)