Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/models/task_runs.py: 21%

148 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 13:38 +0000

1""" 

2Functions for interacting with task run ORM objects. 

3Intended for internal use by the Prefect REST API. 

4""" 

5 

6import contextlib 1a

7from typing import ( 1a

8 TYPE_CHECKING, 

9 Any, 

10 Dict, 

11 Optional, 

12 Sequence, 

13 Type, 

14 TypeVar, 

15 Union, 

16 cast, 

17) 

18from uuid import UUID 1a

19 

20import sqlalchemy as sa 1a

21from sqlalchemy import delete, select 1a

22from sqlalchemy.ext.asyncio import AsyncSession 1a

23from sqlalchemy.sql import Select 1a

24 

25import prefect.server.models as models 1a

26import prefect.server.schemas as schemas 1a

27from prefect.logging import get_logger 1a

28from prefect.server.database import PrefectDBInterface, db_injector, orm_models 1a

29from prefect.server.exceptions import ObjectNotFoundError 1a

30from prefect.server.orchestration.core_policy import ( 1a

31 BackgroundTaskPolicy, 

32 MinimalTaskPolicy, 

33) 

34from prefect.server.orchestration.global_policy import GlobalTaskPolicy 1a

35from prefect.server.orchestration.policies import ( 1a

36 TaskRunOrchestrationPolicy, 

37) 

38from prefect.server.orchestration.rules import TaskOrchestrationContext 1a

39from prefect.server.schemas.responses import OrchestrationResult 1a

40from prefect.types._datetime import now 1a

41 

42if TYPE_CHECKING: 42 ↛ 43line 42 didn't jump to line 43 because the condition on line 42 was never true1a

43 import logging 

44 

# Row-type variable for SQLAlchemy `Select[T]` statements, so that
# `_apply_task_run_filters` returns a query with the same row type it was given.
T = TypeVar("T", bound=tuple[Any, ...])

# Module logger, obtained from Prefect's logging setup under the "server" name.
logger: "logging.Logger" = get_logger("server")

48 

49 

@db_injector
async def create_task_run(
    db: PrefectDBInterface,
    session: AsyncSession,
    task_run: schemas.core.TaskRun,
    orchestration_parameters: Optional[Dict[str, Any]] = None,
) -> orm_models.TaskRun:
    """
    Creates a new task run.

    If a task run with the same flow_run_id, task_key, and dynamic_key already exists,
    the existing task run will be returned. If the provided task run has a state
    attached and the task run was created by this call, that state is also set
    (with `force=True`).

    Args:
        session: a database session
        task_run: a task run model; its `labels` are replaced with the result of
            `with_system_labels_for_task_run` before the row is written
        orchestration_parameters: forwarded to `set_task_run_state` when the
            initial state is set

    Returns:
        orm_models.TaskRun: the newly-created or existing task run
    """

    right_now = now("UTC")
    model: Union[orm_models.TaskRun, None]

    task_run.labels = await with_system_labels_for_task_run(
        session=session, task_run=task_run
    )

    if task_run.flow_run_id:
        # Task runs that belong to a flow run are deduplicated by the database:
        # the insert is a no-op on conflict with the unique upsert columns, and
        # whichever row now exists is read back below.
        insert_stmt = (
            db.queries.insert(db.TaskRun)
            .values(
                created=right_now,
                **task_run.model_dump_for_orm(
                    exclude={"state", "created"}, exclude_unset=True
                ),
            )
            .on_conflict_do_nothing(
                index_elements=db.orm.task_run_unique_upsert_columns,
            )
        )
        await session.execute(insert_stmt)

        query = (
            sa.select(db.TaskRun)
            .where(
                sa.and_(
                    db.TaskRun.flow_run_id == task_run.flow_run_id,
                    db.TaskRun.task_key == task_run.task_key,
                    db.TaskRun.dynamic_key == task_run.dynamic_key,
                )
            )
            .limit(1)
            .execution_options(populate_existing=True)
        )
        result = await session.execute(query)
        model = result.scalar_one()
    else:
        # No flow run: emulate an upsert on (task_key, dynamic_key) in application
        # code. NOTE(review): this lookup-then-insert does not use ON CONFLICT, so
        # it is not atomic; concurrent callers may both insert. Confirm this is
        # acceptable.
        query = (
            sa.select(db.TaskRun)
            .where(
                sa.and_(
                    db.TaskRun.flow_run_id.is_(None),
                    db.TaskRun.task_key == task_run.task_key,
                    db.TaskRun.dynamic_key == task_run.dynamic_key,
                )
            )
            .limit(1)
            .execution_options(populate_existing=True)
        )

        result = await session.execute(query)
        model = result.scalar()

        if model is None:
            model = db.TaskRun(
                created=right_now,
                **task_run.model_dump_for_orm(
                    exclude={"state", "created"}, exclude_unset=True
                ),
                state=None,
            )
            session.add(model)
            await session.flush()

    # Only a row created by this call carries `created == right_now`, so the
    # initial state is set exclusively for newly-created task runs.
    if model.created == right_now and task_run.state:
        await models.task_runs.set_task_run_state(
            session=session,
            task_run_id=model.id,
            state=task_run.state,
            force=True,
            orchestration_parameters=orchestration_parameters,
        )

    return model

148 

149 

@db_injector
async def update_task_run(
    db: PrefectDBInterface,
    session: AsyncSession,
    task_run_id: UUID,
    task_run: schemas.actions.TaskRunUpdate,
) -> bool:
    """
    Apply a partial update to the task run identified by `task_run_id`.

    Only fields explicitly set on `task_run` are written (`exclude_unset=True`),
    so model defaults never overwrite stored values.

    Args:
        session: a database session
        task_run_id: id of the task run to modify
        task_run: the update payload

    Returns:
        bool: True if at least one row matched the id and was updated
    """
    changed_fields = task_run.model_dump_for_orm(exclude_unset=True)
    statement = sa.update(db.TaskRun).where(db.TaskRun.id == task_run_id)
    statement = statement.values(**changed_fields)
    outcome = await session.execute(statement)
    return outcome.rowcount > 0

177 

178 

@db_injector
async def read_task_run(
    db: PrefectDBInterface, session: AsyncSession, task_run_id: UUID
) -> Union[orm_models.TaskRun, None]:
    """
    Fetch a single task run by primary key.

    Args:
        session: a database session
        task_run_id: the task run id

    Returns:
        the matching `orm_models.TaskRun`, or None if no task run has that id
    """
    return await session.get(db.TaskRun, task_run_id)

196 

197 

@db_injector
async def read_task_run_with_flow_run_name(
    db: PrefectDBInterface, session: AsyncSession, task_run_id: UUID
) -> Union[orm_models.TaskRun, None]:
    """
    Read a task run by id, annotated with the name of its parent flow run.

    The flow run is outer-joined, so task runs without a flow run are still
    returned. When a non-empty flow run name is found, it is attached to the
    returned object as a `flow_run_name` attribute; otherwise that attribute is
    not set by this function.

    Args:
        session: a database session
        task_run_id: the task run id

    Returns:
        orm_models.TaskRun: the task run (with `flow_run_name` attached when
            available), or None if no task run has that id
    """
    # Use the injected ORM classes, like every other query in this module.
    query = (
        select(db.TaskRun, db.FlowRun.name.label("flow_run_name"))
        .outerjoin(db.FlowRun, db.TaskRun.flow_run_id == db.FlowRun.id)
        .where(db.TaskRun.id == task_run_id)
    )
    result = await session.execute(query)
    row = result.first()
    if row is None:
        return None

    task_run, flow_run_name = row
    if flow_run_name:
        # Not a mapped column; attached so callers can read it off the instance.
        setattr(task_run, "flow_run_name", flow_run_name)
    return task_run

229 

230 

async def _apply_task_run_filters(
    db: PrefectDBInterface,
    query: Select[T],
    flow_filter: Optional[schemas.filters.FlowFilter] = None,
    flow_run_filter: Optional[schemas.filters.FlowRunFilter] = None,
    task_run_filter: Optional[schemas.filters.TaskRunFilter] = None,
    deployment_filter: Optional[schemas.filters.DeploymentFilter] = None,
    work_pool_filter: Optional[schemas.filters.WorkPoolFilter] = None,
    work_queue_filter: Optional[schemas.filters.WorkQueueFilter] = None,
) -> Select[T]:
    """
    Applies filters to a task run query.

    `task_run_filter` is applied directly to `query`. Filters on related objects
    (flow, flow run, deployment, work queue, work pool) are combined into a single
    correlated EXISTS subquery over the task run's flow run. The exception is when
    the only related filter is a list of flow run ids; that case is applied as a
    plain `flow_run_id IN (...)` predicate instead.

    Args:
        db: the database interface providing the ORM classes
        query: the select statement to filter
        flow_filter: only keep task runs whose flows match these filters
        flow_run_filter: only keep task runs whose flow runs match these filters
        task_run_filter: only keep task runs that match these filters
        deployment_filter: only keep task runs whose deployments match these filters
        work_pool_filter: only keep task runs whose flow run's work queue belongs
            to a work pool matching these filters
        work_queue_filter: only keep task runs whose flow run's work queue matches
            these filters

    Returns:
        the filtered select statement
    """

    if task_run_filter:
        query = query.where(task_run_filter.as_sql_filter())

    # Return a simplified query when the request ONLY filters on flow_run_id (plus
    # task_run_filter). No EXISTS subquery is needed, and the generated query is
    # much more efficient.
    if (
        flow_run_filter
        and flow_run_filter.only_filters_on_id()
        and flow_run_filter.id
        and flow_run_filter.id.any_
        and not any(
            [flow_filter, deployment_filter, work_pool_filter, work_queue_filter]
        )
    ):
        return query.where(db.TaskRun.flow_run_id.in_(flow_run_filter.id.any_))

    if not (
        flow_filter
        or flow_run_filter
        or deployment_filter
        or work_pool_filter
        or work_queue_filter
    ):
        return query

    # Correlated to the outer task run query through TaskRun.flow_run_id.
    exists_clause = select(db.FlowRun).where(db.FlowRun.id == db.TaskRun.flow_run_id)

    if flow_run_filter:
        exists_clause = exists_clause.where(flow_run_filter.as_sql_filter())

    if flow_filter:
        exists_clause = exists_clause.join(
            db.Flow,
            db.Flow.id == db.FlowRun.flow_id,
        ).where(flow_filter.as_sql_filter())

    if deployment_filter:
        exists_clause = exists_clause.join(
            db.Deployment,
            db.Deployment.id == db.FlowRun.deployment_id,
        ).where(deployment_filter.as_sql_filter())

    # The work pool is only reachable through the flow run's work queue. The work
    # queue must therefore be joined whenever either filter is present; otherwise
    # the work pool ON clause would reference a table missing from the FROM list.
    if work_queue_filter or work_pool_filter:
        exists_clause = exists_clause.join(
            db.WorkQueue,
            db.WorkQueue.id == db.FlowRun.work_queue_id,
        )
        if work_queue_filter:
            exists_clause = exists_clause.where(work_queue_filter.as_sql_filter())

    if work_pool_filter:
        exists_clause = exists_clause.join(
            db.WorkPool,
            db.WorkPool.id == db.WorkQueue.work_pool_id,
        ).where(work_pool_filter.as_sql_filter())

    return query.where(exists_clause.exists())

307 

308 

@db_injector
async def read_task_runs(
    db: PrefectDBInterface,
    session: AsyncSession,
    flow_filter: Optional[schemas.filters.FlowFilter] = None,
    flow_run_filter: Optional[schemas.filters.FlowRunFilter] = None,
    task_run_filter: Optional[schemas.filters.TaskRunFilter] = None,
    deployment_filter: Optional[schemas.filters.DeploymentFilter] = None,
    offset: Optional[int] = None,
    limit: Optional[int] = None,
    sort: schemas.sorting.TaskRunSort = schemas.sorting.TaskRunSort.ID_DESC,
) -> Sequence[orm_models.TaskRun]:
    """
    Read task runs.

    Args:
        session: a database session
        flow_filter: only select task runs whose flows match these filters
        flow_run_filter: only select task runs whose flow runs match these filters
        task_run_filter: only select task runs that match these filters
        deployment_filter: only select task runs whose deployments match these filters
        offset: Query offset (ignored when None)
        limit: Query limit (ignored when None)
        sort: Query sort

    Returns:
        Sequence[orm_models.TaskRun]: the task runs, de-duplicated by identity
    """

    query = select(db.TaskRun).order_by(*sort.as_sql_sort())

    query = await _apply_task_run_filters(
        db,
        query,
        flow_filter=flow_filter,
        flow_run_filter=flow_run_filter,
        task_run_filter=task_run_filter,
        deployment_filter=deployment_filter,
    )

    if offset is not None:
        query = query.offset(offset)

    if limit is not None:
        query = query.limit(limit)

    # Pass the query as a lazy %-style argument. Rendering SQL via str(query) is
    # relatively expensive and should only happen when DEBUG logging is enabled.
    logger.debug("In read_task_runs, query generated is:\n%s", query)
    result = await session.execute(query)
    return result.scalars().unique().all()

358 

359 

@db_injector
async def count_task_runs(
    db: PrefectDBInterface,
    session: AsyncSession,
    flow_filter: Optional[schemas.filters.FlowFilter] = None,
    flow_run_filter: Optional[schemas.filters.FlowRunFilter] = None,
    task_run_filter: Optional[schemas.filters.TaskRunFilter] = None,
    deployment_filter: Optional[schemas.filters.DeploymentFilter] = None,
) -> int:
    """
    Return how many task runs satisfy the given filters.

    Args:
        session: a database session
        flow_filter: restrict to task runs whose flows match
        flow_run_filter: restrict to task runs whose flow runs match
        task_run_filter: restrict to task runs matching these filters
        deployment_filter: restrict to task runs whose deployments match

    Returns:
        int: the number of matching task runs
    """
    counting_query = select(sa.func.count(None)).select_from(db.TaskRun)
    filtered_query = await _apply_task_run_filters(
        db,
        counting_query,
        flow_filter=flow_filter,
        flow_run_filter=flow_run_filter,
        task_run_filter=task_run_filter,
        deployment_filter=deployment_filter,
    )
    return (await session.execute(filtered_query)).scalar_one()

395 

396 

@db_injector
async def count_task_runs_by_state(
    db: PrefectDBInterface,
    session: AsyncSession,
    flow_filter: Optional[schemas.filters.FlowFilter] = None,
    flow_run_filter: Optional[schemas.filters.FlowRunFilter] = None,
    task_run_filter: Optional[schemas.filters.TaskRunFilter] = None,
    deployment_filter: Optional[schemas.filters.DeploymentFilter] = None,
) -> schemas.states.CountByState:
    """
    Count task runs by state.

    Task runs without a state type (NULL `state_type`) have no corresponding
    `CountByState` field and are not included in the result.

    Args:
        session: a database session
        flow_filter: only count task runs whose flows match these filters
        flow_run_filter: only count task runs whose flow runs match these filters
        task_run_filter: only count task runs that match these filters
        deployment_filter: only count task runs whose deployments match these filters

    Returns:
        schemas.states.CountByState: count of task runs by state
    """

    base_query = (
        select(db.TaskRun.state_type, sa.func.count(None).label("count"))
        .select_from(db.TaskRun)
        .group_by(db.TaskRun.state_type)
    )

    query = await _apply_task_run_filters(
        db,
        base_query,
        flow_filter=flow_filter,
        flow_run_filter=flow_run_filter,
        task_run_filter=task_run_filter,
        deployment_filter=deployment_filter,
    )

    result = await session.execute(query)

    counts = schemas.states.CountByState()

    for row in result:
        # GROUP BY yields a NULL group for task runs with no state (e.g. those
        # inserted with state=None and never transitioned). setattr() with a
        # None attribute name would raise TypeError, so skip that group.
        if row.state_type is None:
            continue
        setattr(counts, row.state_type, row.count)

    return counts

442 

443 

@db_injector
async def delete_task_run(
    db: PrefectDBInterface, session: AsyncSession, task_run_id: UUID
) -> bool:
    """
    Remove the task run with the given id.

    Args:
        session: a database session
        task_run_id: id of the task run to remove

    Returns:
        bool: True if a row was deleted, False if none matched
    """
    statement = delete(db.TaskRun).where(db.TaskRun.id == task_run_id)
    outcome = await session.execute(statement)
    deleted_rows = outcome.rowcount
    return deleted_rows > 0

463 

464 

async def set_task_run_state(
    session: AsyncSession,
    task_run_id: UUID,
    state: schemas.states.State,
    force: bool = False,
    task_policy: Optional[Type[TaskRunOrchestrationPolicy]] = None,
    orchestration_parameters: Optional[Dict[str, Any]] = None,
) -> OrchestrationResult:
    """
    Creates a new orchestrated task run state.

    Setting a new state on a run is one of the principal actions governed by
    Prefect's orchestration logic. Setting a new run state does not guarantee
    creation; instead it triggers orchestration rules that govern the proposed
    `state`. If the state is considered valid, it is written to the database.
    Otherwise, a different state, or no state, may be created.

    Policy selection:
        - if `state.state_details.deferred` is set, `BackgroundTaskPolicy` is used,
          regardless of `force` and `task_policy`;
        - otherwise, if `force` is True or no `task_policy` is given,
          `MinimalTaskPolicy` is used;
        - otherwise the supplied `task_policy` is used.
    `GlobalTaskPolicy` rules are always applied after the selected policy's rules.

    Args:
        session: a database session
        task_run_id: the task run id
        state: a task run state model
        force: if True (and the state is not deferred), only `MinimalTaskPolicy`
            is applied instead of `task_policy`, bypassing a subset of
            orchestration logic. Global rules still apply.
        task_policy: the orchestration policy to apply when the state is neither
            deferred nor forced
        orchestration_parameters: if not None, assigned to the orchestration
            context's `parameters`

    Returns:
        OrchestrationResult object

    Raises:
        ObjectNotFoundError: if no task run with `task_run_id` exists
        Exception: the context's `orchestration_error`, if one was recorded
            while applying the rules
    """

    # load the task run
    run = await models.task_runs.read_task_run(session=session, task_run_id=task_run_id)

    if not run:
        raise ObjectNotFoundError(f"Task run with id {task_run_id} not found")

    initial_state = run.state.as_state() if run.state else None
    initial_state_type = initial_state.type if initial_state else None
    proposed_state_type = state.type if state else None
    intended_transition = (initial_state_type, proposed_state_type)

    if state.state_details.deferred:
        task_policy = BackgroundTaskPolicy  # CoreTaskPolicy + prevent `Running` -> `Running` transition
    elif force or task_policy is None:
        task_policy = MinimalTaskPolicy

    orchestration_rules = task_policy.compile_transition_rules(*intended_transition)  # type: ignore
    global_rules = GlobalTaskPolicy.compile_transition_rules(*intended_transition)

    context = TaskOrchestrationContext(
        session=session,
        run=run,
        initial_state=initial_state,
        proposed_state=state,
    )

    if orchestration_parameters is not None:
        context.parameters = orchestration_parameters

    # Apply orchestration rules and create the new task run state. Each rule is an
    # async context manager; entering them in order (policy rules, then global
    # rules) on one exit stack means they are exited in reverse order when the
    # block ends.
    async with contextlib.AsyncExitStack() as stack:
        for rule in orchestration_rules:
            context = await stack.enter_async_context(
                rule(context, *intended_transition)
            )

        for rule in global_rules:
            context = await stack.enter_async_context(
                rule(context, *intended_transition)
            )

        await context.validate_proposed_state()

    if context.orchestration_error is not None:
        raise context.orchestration_error

    result = OrchestrationResult(
        state=context.validated_state,
        status=context.response_status,
        details=context.response_details,
    )

    return result

547 

548 

async def with_system_labels_for_task_run(
    session: AsyncSession,
    task_run: schemas.core.TaskRun,
) -> schemas.core.KeyValueLabels:
    """
    Combine a task run's own labels with system-provided labels.

    Precedence, lowest to highest:
        1. labels of the parent flow run (when `flow_run_id` is set and the flow
           run has labels),
        2. system defaults (`prefect.flow-run.id` when `flow_run_id` is set),
        3. labels supplied on `task_run` itself.
    """
    own_labels = task_run.labels or {}
    inherited: schemas.core.KeyValueLabels = {}
    system = cast(schemas.core.KeyValueLabels, {})

    if not task_run.flow_run_id:
        return inherited | system | own_labels

    system["prefect.flow-run.id"] = str(task_run.flow_run_id)
    parent_run = await models.flow_runs.read_flow_run(
        session=session, flow_run_id=task_run.flow_run_id
    )
    if parent_run and parent_run.labels:
        inherited = parent_run.labels

    return inherited | system | own_labels