Coverage for /usr/local/lib/python3.12/site-packages/prefect/server/models/task_runs.py: 69%

148 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 10:48 +0000

1""" 

2Functions for interacting with task run ORM objects. 

3Intended for internal use by the Prefect REST API. 

4""" 

5 

6import contextlib 1d

7from typing import ( 1d

8 TYPE_CHECKING, 

9 Any, 

10 Dict, 

11 Optional, 

12 Sequence, 

13 Type, 

14 TypeVar, 

15 Union, 

16 cast, 

17) 

18from uuid import UUID 1d

19 

20import sqlalchemy as sa 1d

21from sqlalchemy import delete, select 1d

22from sqlalchemy.ext.asyncio import AsyncSession 1d

23from sqlalchemy.sql import Select 1d

24 

25import prefect.server.models as models 1d

26import prefect.server.schemas as schemas 1d

27from prefect.logging import get_logger 1d

28from prefect.server.database import PrefectDBInterface, db_injector, orm_models 1d

29from prefect.server.exceptions import ObjectNotFoundError 1d

30from prefect.server.orchestration.core_policy import ( 1d

31 BackgroundTaskPolicy, 

32 MinimalTaskPolicy, 

33) 

34from prefect.server.orchestration.global_policy import GlobalTaskPolicy 1d

35from prefect.server.orchestration.policies import ( 1d

36 TaskRunOrchestrationPolicy, 

37) 

38from prefect.server.orchestration.rules import TaskOrchestrationContext 1d

39from prefect.server.schemas.responses import OrchestrationResult 1d

40from prefect.types._datetime import now 1d

41 

if TYPE_CHECKING:
    # Needed only for the string annotation on `logger` below, so it is never
    # imported at runtime.
    import logging

# Type variable for the row-tuple type of a `Select`; `_apply_task_run_filters`
# uses it so the returned query keeps the same type as the query passed in.
T = TypeVar("T", bound=tuple[Any, ...])

# Module-level logger obtained from `get_logger("server")`.
logger: "logging.Logger" = get_logger("server")

48 

49 

@db_injector
async def create_task_run(
    db: PrefectDBInterface,
    session: AsyncSession,
    task_run: schemas.core.TaskRun,
    orchestration_parameters: Optional[Dict[str, Any]] = None,
) -> orm_models.TaskRun:
    """
    Create a task run, or return the matching one that already exists.

    Task runs are matched on (flow_run_id, task_key, dynamic_key). When the
    provided task run carries a state, that state is set (with `force=True`)
    only if the stored row's `created` timestamp equals the timestamp
    generated by this call.

    Args:
        session: a database session
        task_run: a task run model; its `labels` are overwritten with the
            result of `with_system_labels_for_task_run`
        orchestration_parameters: forwarded to `set_task_run_state` when the
            initial state is set

    Returns:
        orm_models.TaskRun: the newly-created or existing task run
    """
    created_at = now("UTC")

    task_run.labels = await with_system_labels_for_task_run(
        session=session, task_run=task_run
    )

    def _matching_run_query(flow_run_clause: Any) -> Any:
        # Select (at most) one task run with this task_key / dynamic_key whose
        # flow_run_id satisfies `flow_run_clause`, refreshing any instance
        # already loaded in the session.
        return (
            sa.select(db.TaskRun)
            .where(
                sa.and_(
                    flow_run_clause,
                    db.TaskRun.task_key == task_run.task_key,
                    db.TaskRun.dynamic_key == task_run.dynamic_key,
                )
            )
            .limit(1)
            .execution_options(populate_existing=True)
        )

    model: Union[orm_models.TaskRun, None]

    if task_run.flow_run_id:
        # Attached to a flow run: insert and let the database ignore conflicts
        # on `db.orm.task_run_unique_upsert_columns`, then read the row back
        # (either the one just inserted or the pre-existing one).
        column_values = task_run.model_dump_for_orm(
            exclude={"state", "created"}, exclude_unset=True
        )
        await session.execute(
            db.queries.insert(db.TaskRun)
            .values(created=created_at, **column_values)
            .on_conflict_do_nothing(
                index_elements=db.orm.task_run_unique_upsert_columns,
            )
        )

        lookup = await session.execute(
            _matching_run_query(db.TaskRun.flow_run_id == task_run.flow_run_id)
        )
        model = lookup.scalar_one()
    else:
        # No flow run: upsert on (task_key, dynamic_key) in application logic.
        lookup = await session.execute(
            _matching_run_query(db.TaskRun.flow_run_id.is_(None))
        )
        model = lookup.scalar()

        if model is None:
            model = db.TaskRun(
                created=created_at,
                **task_run.model_dump_for_orm(
                    exclude={"state", "created"}, exclude_unset=True
                ),
                state=None,
            )
            session.add(model)
            await session.flush()

    # A `created` timestamp equal to ours means this call created the row, so
    # the requested initial state (if any) is applied now.
    if model.created == created_at and task_run.state:
        await models.task_runs.set_task_run_state(
            session=session,
            task_run_id=model.id,
            state=task_run.state,
            force=True,
            orchestration_parameters=orchestration_parameters,
        )

    return model

148 

149 

@db_injector
async def update_task_run(
    db: PrefectDBInterface,
    session: AsyncSession,
    task_run_id: UUID,
    task_run: schemas.actions.TaskRunUpdate,
) -> bool:
    """
    Update the task run with the given id.

    Args:
        session: a database session
        task_run_id: the task run id to update
        task_run: the update model; only fields that were explicitly set are
            written

    Returns:
        bool: True if the UPDATE matched at least one row
    """
    # exclude_unset=True restricts the update to values the caller provided,
    # ignoring any defaults on the model.
    changed_fields = task_run.model_dump_for_orm(exclude_unset=True)

    statement = (
        sa.update(db.TaskRun)
        .where(db.TaskRun.id == task_run_id)
        .values(**changed_fields)
    )
    outcome = await session.execute(statement)
    return outcome.rowcount > 0

177 

178 

@db_injector
async def read_task_run(
    db: PrefectDBInterface, session: AsyncSession, task_run_id: UUID
) -> Union[orm_models.TaskRun, None]:
    """
    Fetch a single task run by its id.

    Args:
        session: a database session
        task_run_id: the task run id

    Returns:
        orm_models.TaskRun: the task run, or None when `session.get` finds
            no match
    """
    return await session.get(db.TaskRun, task_run_id)

196 

197 

@db_injector
async def read_task_run_with_flow_run_name(
    db: PrefectDBInterface, session: AsyncSession, task_run_id: UUID
) -> Union[orm_models.TaskRun, None]:
    """
    Fetch a task run by id together with the name of its flow run.

    Args:
        session: a database session
        task_run_id: the task run id

    Returns:
        orm_models.TaskRun: the task run, with a `flow_run_name` attribute
            set when a non-empty flow run name was found; None if no task run
            has this id
    """
    # Outer join so the task run is still returned when no flow run matches.
    statement = (
        select(orm_models.TaskRun, orm_models.FlowRun.name.label("flow_run_name"))
        .outerjoin(
            orm_models.FlowRun, orm_models.TaskRun.flow_run_id == orm_models.FlowRun.id
        )
        .where(orm_models.TaskRun.id == task_run_id)
    )
    first_row = (await session.execute(statement)).first()
    if not first_row:
        return None

    task_run, flow_run_name = first_row[0], first_row[1]
    if flow_run_name:
        setattr(task_run, "flow_run_name", flow_run_name)
    return task_run

229 

230 

async def _apply_task_run_filters(
    db: PrefectDBInterface,
    query: Select[T],
    flow_filter: Optional[schemas.filters.FlowFilter] = None,
    flow_run_filter: Optional[schemas.filters.FlowRunFilter] = None,
    task_run_filter: Optional[schemas.filters.TaskRunFilter] = None,
    deployment_filter: Optional[schemas.filters.DeploymentFilter] = None,
    work_pool_filter: Optional[schemas.filters.WorkPoolFilter] = None,
    work_queue_filter: Optional[schemas.filters.WorkQueueFilter] = None,
) -> Select[T]:
    """
    Narrow a task run query with the given filters.

    The task run filter is applied directly. All other filters are combined
    into one EXISTS subquery over the flow run that owns the task run, except
    when the only constraint is a list of flow run ids, in which case a plain
    `flow_run_id IN (...)` condition is used instead.

    Args:
        db: the database interface providing the ORM models
        query: the task run query to filter
        flow_filter: restrict to task runs whose flow matches
        flow_run_filter: restrict to task runs whose flow run matches
        task_run_filter: restrict to task runs that match
        deployment_filter: restrict to task runs whose deployment matches
        work_pool_filter: restrict to task runs whose work pool matches
        work_queue_filter: restrict to task runs whose work queue matches

    Returns:
        Select[T]: the filtered query
    """
    if task_run_filter:
        query = query.where(task_run_filter.as_sql_filter())

    # Filters that can only be evaluated by joining other tables onto the
    # flow run inside the EXISTS subquery.
    join_filters = [flow_filter, deployment_filter, work_pool_filter, work_queue_filter]

    # Fast path: the request ONLY filters on flow_run_id (plus, optionally,
    # the task run filter). The EXISTS subquery is unnecessary here and the
    # simple IN condition is much more efficient.
    if (
        flow_run_filter
        and flow_run_filter.only_filters_on_id()
        and flow_run_filter.id
        and flow_run_filter.id.any_
        and not any(join_filters)
    ):
        return query.where(db.TaskRun.flow_run_id.in_(flow_run_filter.id.any_))

    # Nothing beyond the task run filter was requested.
    if not (flow_run_filter or any(join_filters)):
        return query

    owning_flow_run = select(db.FlowRun).where(
        db.FlowRun.id == db.TaskRun.flow_run_id
    )

    if flow_run_filter:
        owning_flow_run = owning_flow_run.where(flow_run_filter.as_sql_filter())

    if flow_filter:
        owning_flow_run = owning_flow_run.join(
            db.Flow,
            db.Flow.id == db.FlowRun.flow_id,
        ).where(flow_filter.as_sql_filter())

    if deployment_filter:
        owning_flow_run = owning_flow_run.join(
            db.Deployment,
            db.Deployment.id == db.FlowRun.deployment_id,
        ).where(deployment_filter.as_sql_filter())

    if work_queue_filter:
        owning_flow_run = owning_flow_run.join(
            db.WorkQueue,
            db.WorkQueue.id == db.FlowRun.work_queue_id,
        ).where(work_queue_filter.as_sql_filter())

    if work_pool_filter:
        # NOTE(review): this ON clause references db.WorkQueue, which is only
        # joined above when `work_queue_filter` is also given -- confirm the
        # generated SQL is valid when only `work_pool_filter` is supplied.
        owning_flow_run = owning_flow_run.join(
            db.WorkPool,
            sa.and_(
                db.WorkPool.id == db.WorkQueue.work_pool_id,
                db.WorkQueue.id == db.FlowRun.work_queue_id,
            ),
        ).where(work_pool_filter.as_sql_filter())

    return query.where(owning_flow_run.exists())

307 

308 

@db_injector
async def read_task_runs(
    db: PrefectDBInterface,
    session: AsyncSession,
    flow_filter: Optional[schemas.filters.FlowFilter] = None,
    flow_run_filter: Optional[schemas.filters.FlowRunFilter] = None,
    task_run_filter: Optional[schemas.filters.TaskRunFilter] = None,
    deployment_filter: Optional[schemas.filters.DeploymentFilter] = None,
    offset: Optional[int] = None,
    limit: Optional[int] = None,
    sort: schemas.sorting.TaskRunSort = schemas.sorting.TaskRunSort.ID_DESC,
) -> Sequence[orm_models.TaskRun]:
    """
    Read task runs.

    Args:
        session: a database session
        flow_filter: only select task runs whose flows match these filters
        flow_run_filter: only select task runs whose flow runs match these filters
        task_run_filter: only select task runs that match these filters
        deployment_filter: only select task runs whose deployments match these filters
        offset: Query offset; not applied when None
        limit: Query limit; not applied when None
        sort: Query sort

    Returns:
        Sequence[orm_models.TaskRun]: the matching task runs, de-duplicated
    """

    query = select(db.TaskRun).order_by(*sort.as_sql_sort())

    query = await _apply_task_run_filters(
        db,
        query,
        flow_filter=flow_filter,
        flow_run_filter=flow_run_filter,
        task_run_filter=task_run_filter,
        deployment_filter=deployment_filter,
    )

    if offset is not None:
        query = query.offset(offset)

    if limit is not None:
        query = query.limit(limit)

    # Pass the query as a lazy logging argument: it is only rendered to a
    # string when DEBUG logging is actually enabled, instead of on every call.
    logger.debug("In read_task_runs, query generated is:\n%s", query)
    result = await session.execute(query)
    return result.scalars().unique().all()

358 

359 

@db_injector
async def count_task_runs(
    db: PrefectDBInterface,
    session: AsyncSession,
    flow_filter: Optional[schemas.filters.FlowFilter] = None,
    flow_run_filter: Optional[schemas.filters.FlowRunFilter] = None,
    task_run_filter: Optional[schemas.filters.TaskRunFilter] = None,
    deployment_filter: Optional[schemas.filters.DeploymentFilter] = None,
) -> int:
    """
    Count the task runs that match the given filters.

    Args:
        session: a database session
        flow_filter: only count task runs whose flows match these filters
        flow_run_filter: only count task runs whose flow runs match these filters
        task_run_filter: only count task runs that match these filters
        deployment_filter: only count task runs whose deployments match these filters

    Returns:
        int: count of task runs
    """
    count_query = await _apply_task_run_filters(
        db,
        select(sa.func.count(None)).select_from(db.TaskRun),
        flow_filter=flow_filter,
        flow_run_filter=flow_run_filter,
        task_run_filter=task_run_filter,
        deployment_filter=deployment_filter,
    )

    outcome = await session.execute(count_query)
    return outcome.scalar_one()

395 

396 

@db_injector
async def count_task_runs_by_state(
    db: PrefectDBInterface,
    session: AsyncSession,
    flow_filter: Optional[schemas.filters.FlowFilter] = None,
    flow_run_filter: Optional[schemas.filters.FlowRunFilter] = None,
    task_run_filter: Optional[schemas.filters.TaskRunFilter] = None,
    deployment_filter: Optional[schemas.filters.DeploymentFilter] = None,
) -> schemas.states.CountByState:
    """
    Count task runs by state.

    Args:
        session: a database session
        flow_filter: only count task runs whose flows match these filters
        flow_run_filter: only count task runs whose flow runs match these filters
        task_run_filter: only count task runs that match these filters
        deployment_filter: only count task runs whose deployments match these filters

    Returns:
        schemas.states.CountByState: count of task runs by state; task runs
            whose `state_type` is NULL are not included
    """

    base_query = (
        select(db.TaskRun.state_type, sa.func.count(None).label("count"))
        .select_from(db.TaskRun)
        .group_by(db.TaskRun.state_type)
    )

    query = await _apply_task_run_filters(
        db,
        base_query,
        flow_filter=flow_filter,
        flow_run_filter=flow_run_filter,
        task_run_filter=task_run_filter,
        deployment_filter=deployment_filter,
    )

    result = await session.execute(query)

    counts = schemas.states.CountByState()

    # Unpack each row by position (state_type, count) rather than reading
    # `row.count`: result rows are sequences with their own `count()` method,
    # so attribute access by that name is ambiguous with the "count" label.
    for state_type, state_count in result:
        if state_type is None:
            # A NULL state_type group cannot be mapped onto a CountByState
            # attribute (setattr requires a string name), so it is skipped.
            continue
        setattr(counts, state_type, state_count)

    return counts

442 

443 

@db_injector
async def delete_task_run(
    db: PrefectDBInterface, session: AsyncSession, task_run_id: UUID
) -> bool:
    """
    Delete the task run with the given id.

    Args:
        session: a database session
        task_run_id: the task run id to delete

    Returns:
        bool: True if the DELETE removed at least one row
    """
    statement = delete(db.TaskRun).where(db.TaskRun.id == task_run_id)
    outcome = await session.execute(statement)
    return outcome.rowcount > 0

463 

464 

async def set_task_run_state(
    session: AsyncSession,
    task_run_id: UUID,
    state: schemas.states.State,
    force: bool = False,
    task_policy: Optional[Type[TaskRunOrchestrationPolicy]] = None,
    orchestration_parameters: Optional[Dict[str, Any]] = None,
) -> OrchestrationResult:
    """
    Propose a new state for a task run and run it through orchestration.

    Setting a new state on a run is one of the principal actions governed by
    Prefect's orchestration logic. Proposing a state does not guarantee that
    it is created: the orchestration rules compiled for the transition decide
    what happens to the proposed `state`. If the state is considered valid it
    is written to the database; otherwise it is possible that a different
    state, or no state, is created.

    Args:
        session: a database session
        task_run_id: the task run id
        state: a task run state model
        force: if True, the supplied `task_policy` is ignored and
            `MinimalTaskPolicy` is used instead
        task_policy: the orchestration policy whose rules are applied;
            `MinimalTaskPolicy` is used when omitted. Regardless of `force`
            and `task_policy`, `BackgroundTaskPolicy` is used when
            `state.state_details.deferred` is set.
        orchestration_parameters: when not None, assigned to the
            orchestration context's `parameters`

    Returns:
        OrchestrationResult object

    Raises:
        ObjectNotFoundError: if no task run with `task_run_id` exists
        Exception: whatever error the orchestration context recorded in
            `orchestration_error` is re-raised
    """
    # Load the task run the state is being proposed for.
    run = await models.task_runs.read_task_run(session=session, task_run_id=task_run_id)
    if not run:
        raise ObjectNotFoundError(f"Task run with id {task_run_id} not found")

    # Describe the transition as (current state type, proposed state type).
    if run.state:
        initial_state = run.state.as_state()
    else:
        initial_state = None
    initial_state_type = initial_state.type if initial_state else None
    proposed_state_type = state.type if state else None
    intended_transition = (initial_state_type, proposed_state_type)

    # Choose the policy; deferred states take precedence over `force`.
    if state.state_details.deferred:
        task_policy = BackgroundTaskPolicy  # CoreTaskPolicy + prevent `Running` -> `Running` transition
    elif force or task_policy is None:
        task_policy = MinimalTaskPolicy

    orchestration_rules = task_policy.compile_transition_rules(*intended_transition)  # type: ignore
    global_rules = GlobalTaskPolicy.compile_transition_rules(*intended_transition)

    context = TaskOrchestrationContext(
        session=session,
        run=run,
        initial_state=initial_state,
        proposed_state=state,
    )
    if orchestration_parameters is not None:
        context.parameters = orchestration_parameters

    # Enter every rule as an async context manager -- policy rules first, then
    # global rules -- threading the context through each one, and validate the
    # proposed state while all of them are still active.
    async with contextlib.AsyncExitStack() as stack:
        for rule_set in (orchestration_rules, global_rules):
            for rule in rule_set:
                context = await stack.enter_async_context(
                    rule(context, *intended_transition)
                )

        await context.validate_proposed_state()

    if context.orchestration_error is not None:
        raise context.orchestration_error

    return OrchestrationResult(
        state=context.validated_state,
        status=context.response_status,
        details=context.response_details,
    )

547 

548 

async def with_system_labels_for_task_run(
    session: AsyncSession,
    task_run: schemas.core.TaskRun,
) -> schemas.core.KeyValueLabels:
    """Combine a task run's own labels with system and flow run labels.

    When the task run belongs to a flow run, the flow run's labels and a
    `prefect.flow-run.id` label are added. On key collisions the task run's
    own labels win over the system label, which wins over the flow run's
    labels.

    Args:
        session: a database session
        task_run: the task run whose labels are being built

    Returns:
        schemas.core.KeyValueLabels: the merged labels
    """
    labels_from_client = task_run.labels or {}
    system_labels: schemas.core.KeyValueLabels = {}
    labels_from_flow_run: schemas.core.KeyValueLabels = {}

    if task_run.flow_run_id:
        system_labels["prefect.flow-run.id"] = str(task_run.flow_run_id)
        flow_run = await models.flow_runs.read_flow_run(
            session=session, flow_run_id=task_run.flow_run_id
        )
        if flow_run and flow_run.labels:
            labels_from_flow_run = flow_run.labels

    # Later operands of `|` override earlier ones.
    return labels_from_flow_run | system_labels | labels_from_client