Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/api/middleware.py: 33%

25 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 10:48 +0000

import hmac
from typing import Awaitable, Callable

from fastapi import status
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response

from prefect import settings
from prefect.server import models
from prefect.server.database import provide_database_interface

11 

# Type of the `call_next` argument received by `CsrfMiddleware.dispatch`:
# a callable that takes the incoming request and asynchronously returns the
# response produced by the rest of the middleware/route chain.
NextMiddlewareFunction = Callable[[Request], Awaitable[Response]]

13 

14 

class CsrfMiddleware(BaseHTTPMiddleware):
    """
    Middleware for CSRF protection.

    When the `PREFECT_SERVER_CSRF_PROTECTION_ENABLED` setting is on, every
    POST, PUT, PATCH, or DELETE request must carry both a `Prefect-Csrf-Token`
    and a `Prefect-Csrf-Client` header. The token must match the token stored
    in the database for that client. Otherwise the request is rejected with a
    403 status code and is never forwarded to the route handler.
    """

    # HTTP methods that require a valid CSRF token/client pair when CSRF
    # protection is enabled. All other methods pass straight through.
    _PROTECTED_METHODS = frozenset({"POST", "PUT", "PATCH", "DELETE"})

    @staticmethod
    def _tokens_match(expected: str, incoming: str) -> bool:
        """
        Compare two CSRF token strings in constant time.

        `hmac.compare_digest` keeps the comparison time from revealing how many
        leading characters of a guessed token were correct. That leak is
        possible with a plain `==`/`!=` on strings. Both values are encoded to
        UTF-8 bytes first because `compare_digest` raises `TypeError` for `str`
        arguments containing non-ASCII characters, and `incoming` comes from a
        client-controlled header.

        Args:
            expected: the token stored for the client (presumably a `str`;
                confirm against the csrf_token model).
            incoming: the token value supplied in the request header.

        Returns:
            True if the two tokens are identical, False otherwise.
        """
        return hmac.compare_digest(
            expected.encode("utf-8"), incoming.encode("utf-8")
        )

    async def dispatch(
        self, request: Request, call_next: NextMiddlewareFunction
    ) -> Response:
        """
        Enforce the CSRF check for state-changing requests.

        If CSRF protection is enabled and the request method is one of
        POST/PUT/PATCH/DELETE, the `Prefect-Csrf-Token` and
        `Prefect-Csrf-Client` headers are required. The token is compared
        against the one stored in the database for that client.

        Args:
            request: the incoming request.
            call_next: forwards the request to the next middleware / route
                handler and returns its response.

        Returns:
            A 403 `JSONResponse` if a header is missing or the token does not
            match the stored one. Otherwise, the response produced by
            `call_next`.
        """
        request_needs_csrf_protection = request.method in self._PROTECTED_METHODS

        if (
            settings.PREFECT_SERVER_CSRF_PROTECTION_ENABLED.value()
            and request_needs_csrf_protection
        ):
            incoming_token = request.headers.get("Prefect-Csrf-Token")
            incoming_client = request.headers.get("Prefect-Csrf-Client")

            # NOTE(review): only the "invalid token" 403 below sets
            # `Access-Control-Allow-Origin: *`. These two "missing header"
            # 403s do not, so a cross-origin browser caller would see a CORS
            # failure rather than the JSON detail. Confirm this is intentional.
            if incoming_token is None:
                return JSONResponse(
                    {"detail": "Missing CSRF token."},
                    status_code=status.HTTP_403_FORBIDDEN,
                )

            if incoming_client is None:
                return JSONResponse(
                    {"detail": "Missing client identifier."},
                    status_code=status.HTTP_403_FORBIDDEN,
                )

            db = provide_database_interface()
            async with db.session_context() as session:
                token = await models.csrf_token.read_token_for_client(
                    session=session, client=incoming_client
                )

                # The stored token is read while the session is still open, so
                # the ORM attribute access cannot hit a closed session. The
                # comparison is constant-time to avoid a timing side channel.
                if token is None or not self._tokens_match(
                    token.token, incoming_token
                ):
                    return JSONResponse(
                        {"detail": "Invalid CSRF token or client identifier."},
                        status_code=status.HTTP_403_FORBIDDEN,
                        headers={"Access-Control-Allow-Origin": "*"},
                    )

        return await call_next(request)