Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/api/middleware.py: 33%
25 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 10:48 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 10:48 +0000
1from typing import Awaitable, Callable 1a
3from fastapi import status 1a
4from starlette.middleware.base import BaseHTTPMiddleware 1a
5from starlette.requests import Request 1a
6from starlette.responses import JSONResponse, Response 1a
8from prefect import settings 1a
9from prefect.server import models 1a
10from prefect.server.database import provide_database_interface 1a
# Signature of the `call_next` callable handed to `dispatch`: it takes the
# incoming request and asynchronously yields the downstream response.
NextMiddlewareFunction = Callable[[Request], Awaitable[Response]]
class CsrfMiddleware(BaseHTTPMiddleware):
    """
    Middleware for CSRF protection. This middleware will check for a CSRF token
    in the headers of any POST, PUT, PATCH, or DELETE request. If the token is
    not present or does not match the token stored in the database for the
    client, the request will be rejected with a 403 status code.

    Protection is only enforced when the
    `PREFECT_SERVER_CSRF_PROTECTION_ENABLED` setting is truthy; otherwise every
    request is passed straight through to the next handler.
    """

    async def dispatch(
        self, request: Request, call_next: NextMiddlewareFunction
    ) -> Response:
        """
        Validate the CSRF headers of a state-changing request before
        forwarding it.

        Args:
            request: The incoming request. The `Prefect-Csrf-Token` and
                `Prefect-Csrf-Client` headers are read from it.
            call_next: Callable that forwards the request to the next
                middleware / route handler and returns its response.

        Returns:
            A 403 JSON response if either header is missing or if the token
            does not match the one stored for the client; otherwise the
            response produced by `call_next`.
        """
        # Function-scope import: the comparison helper is only needed here.
        from hmac import compare_digest

        request_needs_csrf_protection = request.method in {
            "POST",
            "PUT",
            "PATCH",
            "DELETE",
        }

        if (
            settings.PREFECT_SERVER_CSRF_PROTECTION_ENABLED.value()
            and request_needs_csrf_protection
        ):
            incoming_token = request.headers.get("Prefect-Csrf-Token")
            incoming_client = request.headers.get("Prefect-Csrf-Client")

            if incoming_token is None:
                return JSONResponse(
                    {"detail": "Missing CSRF token."},
                    status_code=status.HTTP_403_FORBIDDEN,
                )

            if incoming_client is None:
                return JSONResponse(
                    {"detail": "Missing client identifier."},
                    status_code=status.HTTP_403_FORBIDDEN,
                )

            db = provide_database_interface()
            async with db.session_context() as session:
                token = await models.csrf_token.read_token_for_client(
                    session=session, client=incoming_client
                )

                # Compare in constant time so response latency does not reveal
                # how many leading characters of a guessed token were correct.
                # Both sides are encoded to bytes because the str form of
                # compare_digest rejects non-ASCII input with a TypeError.
                # NOTE(review): assumes `token.token` is a plain str - confirm
                # against the csrf_token model.
                token_matches = token is not None and compare_digest(
                    token.token.encode("utf-8"),
                    incoming_token.encode("utf-8"),
                )

                if not token_matches:
                    # NOTE(review): only this rejection carries a CORS header,
                    # presumably so browser clients can read the error body -
                    # confirm whether the two 403s above should match.
                    return JSONResponse(
                        {"detail": "Invalid CSRF token or client identifier."},
                        status_code=status.HTTP_403_FORBIDDEN,
                        headers={"Access-Control-Allow-Origin": "*"},
                    )

        return await call_next(request)