Coverage for polar/event/service.py: 10%

330 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-12-05 17:15 +0000

1import uuid 1a

2from collections import defaultdict 1a

3from collections.abc import Callable, Sequence 1a

4from datetime import date, datetime 1a

5from typing import Any 1a

6from zoneinfo import ZoneInfo 1a

7 

8import structlog 1a

9from sqlalchemy import ( 1a

10 Select, 

11 String, 

12 UnaryExpression, 

13 asc, 

14 cast, 

15 desc, 

16 func, 

17 or_, 

18 select, 

19 text, 

20) 

21from sqlalchemy.dialects.postgresql import insert 1a

22 

23from polar.auth.models import AuthSubject, is_organization, is_user 1a

24from polar.customer.repository import CustomerRepository 1a

25from polar.event_type.repository import EventTypeRepository 1a

26from polar.exceptions import PolarError, PolarRequestValidationError, ValidationError 1a

27from polar.kit.metadata import MetadataQuery, apply_metadata_clause 1a

28from polar.kit.pagination import PaginationParams, paginate 1a

29from polar.kit.sorting import Sorting 1a

30from polar.kit.time_queries import TimeInterval, get_timestamp_series_cte 1a

31from polar.logging import Logger 1a

32from polar.meter.filter import Filter 1a

33from polar.meter.repository import MeterRepository 1a

34from polar.models import ( 1a

35 Customer, 

36 Event, 

37 EventClosure, 

38 Organization, 

39 User, 

40 UserOrganization, 

41) 

42from polar.models.event import EventSource 1a

43from polar.postgres import AsyncSession 1a

44from polar.worker import enqueue_events 1a

45 

46from .repository import EventRepository 1a

47from .schemas import ( 1a

48 EventCreateCustomer, 

49 EventName, 

50 EventsIngest, 

51 EventsIngestResponse, 

52 EventStatistics, 

53 ListStatisticsTimeseries, 

54 StatisticsPeriod, 

55) 

56from .sorting import EventNamesSortProperty, EventSortProperty 1a

57 

58log: Logger = structlog.get_logger() 1a

59 

60 

class EventError(PolarError):
    """Base error type for this module; other event errors derive from it."""

62 

63 

class EventIngestValidationError(EventError):
    """
    Error carrying the validation failures found for an ingested event.

    The failures are kept on ``errors`` so that the caller can collect them.
    """

    def __init__(self, errors: list[ValidationError]) -> None:
        message = "Event ingest validation failed."
        self.errors = errors
        super().__init__(message)

68 

69 

70def _topological_sort_events(events: list[dict[str, Any]]) -> list[dict[str, Any]]: 1a

71 """ 

72 Sort events by dependency order so parents come before children. 

73 Events without parents come first, followed by their children in order. 

74 

75 Handles parent_id references that can be either Polar IDs or external_id strings. 

76 Uses Kahn's algorithm for topological sorting. 

77 """ 

78 if not events: 

79 return [] 

80 

81 id_to_index: dict[uuid.UUID | str, int] = {} 

82 for idx, event in enumerate(events): 

83 if "id" in event: 

84 id_to_index[event["id"]] = idx 

85 if "external_id" in event and event["external_id"] is not None: 

86 id_to_index[event["external_id"]] = idx 

87 

88 graph: dict[int, list[int]] = defaultdict(list) 

89 in_degree: dict[int, int] = {} 

90 

91 for idx in range(len(events)): 

92 in_degree[idx] = 0 

93 

94 for idx, event in enumerate(events): 

95 parent_id = event.get("parent_id") 

96 if parent_id and parent_id in id_to_index: 

97 parent_idx = id_to_index[parent_id] 

98 graph[parent_idx].append(idx) 

99 in_degree[idx] += 1 

100 

101 queue = [idx for idx in range(len(events)) if in_degree[idx] == 0] 

102 sorted_indices = [] 

103 

104 while queue: 

105 current_idx = queue.pop(0) 

106 sorted_indices.append(current_idx) 

107 

108 for child_idx in graph[current_idx]: 

109 in_degree[child_idx] -= 1 

110 if in_degree[child_idx] == 0: 

111 queue.append(child_idx) 

112 

113 if len(sorted_indices) != len(events): 

114 raise EventError("Circular dependency detected in event parent relationships") 

115 

116 return [events[idx] for idx in sorted_indices] 

117 

118 

119class EventService: 1a

120 async def _build_filtered_statement( 1a

121 self, 

122 session: AsyncSession, 

123 auth_subject: AuthSubject[User | Organization], 

124 repository: EventRepository, 

125 *, 

126 filter: Filter | None = None, 

127 start_timestamp: datetime | None = None, 

128 end_timestamp: datetime | None = None, 

129 organization_id: Sequence[uuid.UUID] | None = None, 

130 customer_id: Sequence[uuid.UUID] | None = None, 

131 external_customer_id: Sequence[str] | None = None, 

132 meter_id: uuid.UUID | None = None, 

133 name: Sequence[str] | None = None, 

134 source: Sequence[EventSource] | None = None, 

135 event_type_id: uuid.UUID | None = None, 

136 metadata: MetadataQuery | None = None, 

137 sorting: Sequence[Sorting[EventSortProperty]] = ( 

138 (EventSortProperty.timestamp, True), 

139 ), 

140 query: str | None = None, 

141 ) -> Select[tuple[Event]]: 

142 statement = repository.get_readable_statement(auth_subject).options( 

143 *repository.get_eager_options() 

144 ) 

145 

146 if filter is not None: 

147 statement = statement.where(filter.get_sql_clause(Event)) 

148 

149 if start_timestamp is not None: 

150 statement = statement.where(Event.timestamp > start_timestamp) 

151 

152 if end_timestamp is not None: 

153 statement = statement.where(Event.timestamp < end_timestamp) 

154 

155 if organization_id is not None: 

156 statement = statement.where(Event.organization_id.in_(organization_id)) 

157 

158 if customer_id is not None: 

159 statement = statement.where( 

160 repository.get_customer_id_filter_clause(customer_id) 

161 ) 

162 

163 if external_customer_id is not None: 

164 statement = statement.where( 

165 repository.get_external_customer_id_filter_clause(external_customer_id) 

166 ) 

167 

168 if meter_id is not None: 

169 meter_repository = MeterRepository.from_session(session) 

170 meter = await meter_repository.get_readable_by_id(meter_id, auth_subject) 

171 if meter is None: 

172 raise PolarRequestValidationError( 

173 [ 

174 { 

175 "type": "meter_id", 

176 "msg": "Meter not found.", 

177 "loc": ("query", "meter_id"), 

178 "input": meter_id, 

179 } 

180 ] 

181 ) 

182 statement = statement.where(repository.get_meter_clause(meter)) 

183 

184 if name is not None: 

185 statement = statement.where(Event.name.in_(name)) 

186 

187 if source is not None: 

188 statement = statement.where(Event.source.in_(source)) 

189 

190 if event_type_id is not None: 

191 statement = statement.where(Event.event_type_id == event_type_id) 

192 

193 if query is not None: 

194 statement = statement.where( 

195 or_( 

196 Event.name.ilike(f"%{query}%"), 

197 Event.source.ilike(f"%{query}%"), 

198 # Load customers and match against their name/email 

199 Event.customer_id.in_( 

200 select(Customer.id).where( 

201 or_( 

202 cast(Customer.id, String).ilike(f"%{query}%"), 

203 Customer.external_id.ilike(f"%{query}%"), 

204 Customer.name.ilike(f"%{query}%"), 

205 Customer.email.ilike(f"%{query}%"), 

206 ) 

207 ) 

208 ), 

209 func.to_tsvector("simple", cast(Event.user_metadata, String)).op( 

210 "@@" 

211 )(func.plainto_tsquery(query)), 

212 ) 

213 ) 

214 

215 if metadata is not None: 

216 statement = apply_metadata_clause(Event, statement, metadata) 

217 

218 order_by_clauses: list[UnaryExpression[Any]] = [] 

219 for criterion, is_desc in sorting: 

220 clause_function = desc if is_desc else asc 

221 if criterion == EventSortProperty.timestamp: 

222 order_by_clauses.append(clause_function(Event.timestamp)) 

223 statement = statement.order_by(*order_by_clauses) 

224 

225 return statement 

226 

227 async def list( 1a

228 self, 

229 session: AsyncSession, 

230 auth_subject: AuthSubject[User | Organization], 

231 *, 

232 filter: Filter | None = None, 

233 start_timestamp: datetime | None = None, 

234 end_timestamp: datetime | None = None, 

235 organization_id: Sequence[uuid.UUID] | None = None, 

236 customer_id: Sequence[uuid.UUID] | None = None, 

237 external_customer_id: Sequence[str] | None = None, 

238 meter_id: uuid.UUID | None = None, 

239 name: Sequence[str] | None = None, 

240 source: Sequence[EventSource] | None = None, 

241 event_type_id: uuid.UUID | None = None, 

242 metadata: MetadataQuery | None = None, 

243 pagination: PaginationParams, 

244 sorting: Sequence[Sorting[EventSortProperty]] = ( 

245 (EventSortProperty.timestamp, True), 

246 ), 

247 query: str | None = None, 

248 parent_id: uuid.UUID | None = None, 

249 hierarchical: bool = False, 

250 aggregate_fields: Sequence[str] = (), 

251 ) -> tuple[Sequence[Event], int]: 

252 repository = EventRepository.from_session(session) 

253 statement = await self._build_filtered_statement( 

254 session, 

255 auth_subject, 

256 repository, 

257 filter=filter, 

258 start_timestamp=start_timestamp, 

259 end_timestamp=end_timestamp, 

260 organization_id=organization_id, 

261 customer_id=customer_id, 

262 external_customer_id=external_customer_id, 

263 meter_id=meter_id, 

264 name=name, 

265 source=source, 

266 event_type_id=event_type_id, 

267 metadata=metadata, 

268 sorting=sorting, 

269 query=query, 

270 ) 

271 

272 if hierarchical: 

273 if parent_id is not None: 

274 statement = statement.where(Event.parent_id == parent_id) 

275 else: 

276 statement = statement.where(Event.parent_id.is_(None)) 

277 

278 return await repository.list_with_closure_table( 

279 statement, 

280 limit=pagination.limit, 

281 page=pagination.page, 

282 aggregate_fields=aggregate_fields, 

283 ) 

284 

285 async def get( 1a

286 self, 

287 session: AsyncSession, 

288 auth_subject: AuthSubject[User | Organization], 

289 id: uuid.UUID, 

290 ) -> Event | None: 

291 repository = EventRepository.from_session(session) 

292 statement = ( 

293 repository.get_readable_statement(auth_subject) 

294 .where(Event.id == id) 

295 .options(*repository.get_eager_options()) 

296 ) 

297 return await repository.get_one_or_none(statement) 

298 

299 async def list_statistics_timeseries( 1a

300 self, 

301 session: AsyncSession, 

302 auth_subject: AuthSubject[User | Organization], 

303 *, 

304 start_date: date, 

305 end_date: date, 

306 timezone: ZoneInfo, 

307 interval: TimeInterval, 

308 filter: Filter | None = None, 

309 organization_id: Sequence[uuid.UUID] | None = None, 

310 customer_id: Sequence[uuid.UUID] | None = None, 

311 external_customer_id: Sequence[str] | None = None, 

312 meter_id: uuid.UUID | None = None, 

313 name: Sequence[str] | None = None, 

314 source: Sequence[EventSource] | None = None, 

315 event_type_id: uuid.UUID | None = None, 

316 metadata: MetadataQuery | None = None, 

317 sorting: Sequence[Sorting[EventSortProperty]] = ( 

318 (EventSortProperty.timestamp, True), 

319 ), 

320 query: str | None = None, 

321 aggregate_fields: Sequence[str] = ("_cost.amount",), 

322 hierarchy_stats_sorting: Sequence[tuple[str, bool]] = (("total", True),), 

323 ) -> ListStatisticsTimeseries: 

324 start_timestamp = datetime( 

325 start_date.year, start_date.month, start_date.day, 0, 0, 0, 0, timezone 

326 ) 

327 end_timestamp = datetime( 

328 end_date.year, end_date.month, end_date.day, 23, 59, 59, 999999, timezone 

329 ) 

330 

331 timestamp_series_cte = get_timestamp_series_cte( 

332 start_timestamp, end_timestamp, interval 

333 ) 

334 

335 repository = EventRepository.from_session(session) 

336 statement = await self._build_filtered_statement( 

337 session, 

338 auth_subject, 

339 repository, 

340 filter=filter, 

341 start_timestamp=start_timestamp, 

342 end_timestamp=end_timestamp, 

343 organization_id=organization_id, 

344 customer_id=customer_id, 

345 external_customer_id=external_customer_id, 

346 meter_id=meter_id, 

347 name=name, 

348 source=source, 

349 event_type_id=event_type_id, 

350 metadata=metadata, 

351 sorting=sorting, 

352 query=query, 

353 ) 

354 

355 timeseries_stats = await repository.get_hierarchy_stats( 

356 statement, 

357 aggregate_fields, 

358 hierarchy_stats_sorting, 

359 timestamp_series=timestamp_series_cte, 

360 ) 

361 

362 result = await session.execute(select(timestamp_series_cte.c.timestamp)) 

363 timestamps = [row[0] for row in result.all()] 

364 

365 stats_by_timestamp: dict[datetime, list[dict[str, Any]]] = {} 

366 all_event_types: dict[tuple[str, str, uuid.UUID], dict[str, Any]] = {} 

367 

368 for stat in timeseries_stats: 

369 ts = stat.pop("timestamp") 

370 if stat["name"] is None: 

371 continue 

372 if ts not in stats_by_timestamp: 

373 stats_by_timestamp[ts] = [] 

374 stats_by_timestamp[ts].append(stat) 

375 

376 # Track all unique event types 

377 event_key = (stat["name"], stat["label"], stat["event_type_id"]) 

378 if event_key not in all_event_types: 

379 all_event_types[event_key] = { 

380 "name": stat["name"], 

381 "label": stat["label"], 

382 "event_type_id": stat["event_type_id"], 

383 } 

384 

385 # Convert field names from dot notation to underscore (e.g., "_cost.amount" -> "_cost_amount") 

386 zero_values = {field.replace(".", "_"): "0" for field in aggregate_fields} 

387 

388 periods = [] 

389 for i, period_start in enumerate(timestamps): 

390 if i + 1 < len(timestamps): 

391 period_end = timestamps[i + 1] 

392 else: 

393 period_end = end_timestamp 

394 

395 period_stats = stats_by_timestamp.get(period_start, []) 

396 

397 # Fill in missing event types with zeros 

398 stats_by_name = {s["name"]: s for s in period_stats} 

399 complete_stats = [] 

400 for event_type_info in all_event_types.values(): 

401 if event_type_info["name"] in stats_by_name: 

402 complete_stats.append(stats_by_name[event_type_info["name"]]) 

403 else: 

404 complete_stats.append( 

405 { 

406 **event_type_info, 

407 "occurrences": 0, 

408 "totals": zero_values, 

409 "averages": zero_values, 

410 "p50": zero_values, 

411 "p95": zero_values, 

412 "p99": zero_values, 

413 } 

414 ) 

415 

416 periods.append( 

417 StatisticsPeriod( 

418 timestamp=period_start, 

419 period_start=period_start, 

420 period_end=period_end, 

421 stats=[EventStatistics(**s) for s in complete_stats], 

422 ) 

423 ) 

424 

425 totals = await repository.get_hierarchy_stats( 

426 statement, aggregate_fields, hierarchy_stats_sorting 

427 ) 

428 

429 return ListStatisticsTimeseries( 

430 periods=periods, 

431 totals=[EventStatistics(**s) for s in totals], 

432 ) 

433 

434 async def list_names( 1a

435 self, 

436 session: AsyncSession, 

437 auth_subject: AuthSubject[User | Organization], 

438 *, 

439 organization_id: Sequence[uuid.UUID] | None = None, 

440 customer_id: Sequence[uuid.UUID] | None = None, 

441 external_customer_id: Sequence[str] | None = None, 

442 source: Sequence[EventSource] | None = None, 

443 query: str | None = None, 

444 pagination: PaginationParams, 

445 sorting: Sequence[Sorting[EventNamesSortProperty]] = [ 

446 (EventNamesSortProperty.last_seen, True) 

447 ], 

448 ) -> tuple[Sequence[EventName], int]: 

449 repository = EventRepository.from_session(session) 

450 statement = repository.get_event_names_statement(auth_subject) 

451 

452 if organization_id is not None: 

453 statement = statement.where(Event.organization_id.in_(organization_id)) 

454 

455 if customer_id is not None: 

456 statement = statement.where( 

457 repository.get_customer_id_filter_clause(customer_id) 

458 ) 

459 

460 if external_customer_id is not None: 

461 statement = statement.where( 

462 repository.get_external_customer_id_filter_clause(external_customer_id) 

463 ) 

464 

465 if source is not None: 

466 statement = statement.where(Event.source.in_(source)) 

467 

468 if query is not None: 

469 statement = statement.where(Event.name.ilike(f"%{query}%")) 

470 

471 order_by_clauses: list[UnaryExpression[Any]] = [] 

472 for criterion, is_desc in sorting: 

473 clause_function = desc if is_desc else asc 

474 if criterion == EventNamesSortProperty.event_name: 

475 order_by_clauses.append(clause_function(Event.name)) 

476 elif criterion == EventNamesSortProperty.first_seen: 

477 order_by_clauses.append(clause_function(text("first_seen"))) 

478 elif criterion == EventNamesSortProperty.last_seen: 

479 order_by_clauses.append(clause_function(text("last_seen"))) 

480 elif criterion == EventNamesSortProperty.occurrences: 

481 order_by_clauses.append(clause_function(text("occurrences"))) 

482 statement = statement.order_by(*order_by_clauses) 

483 

484 results, count = await paginate(session, statement, pagination=pagination) 

485 

486 event_names: list[EventName] = [] 

487 for result in results: 

488 event_name, event_source, occurrences, first_seen, last_seen = result 

489 event_names.append( 

490 EventName( 

491 name=event_name, 

492 source=event_source, 

493 occurrences=occurrences, 

494 first_seen=first_seen, 

495 last_seen=last_seen, 

496 ) 

497 ) 

498 

499 return event_names, count 

500 

501 async def ingest( 1a

502 self, 

503 session: AsyncSession, 

504 auth_subject: AuthSubject[User | Organization], 

505 ingest: EventsIngest, 

506 ) -> EventsIngestResponse: 

507 validate_organization_id = await self._get_organization_validation_function( 

508 session, auth_subject 

509 ) 

510 validate_customer_id = await self._get_customer_validation_function( 

511 session, auth_subject 

512 ) 

513 

514 event_type_repository = EventTypeRepository.from_session(session) 

515 event_types_cache: dict[tuple[str, uuid.UUID], uuid.UUID] = {} 

516 

517 batch_external_id_map: dict[str, uuid.UUID] = {} 

518 for event_create in ingest.events: 

519 if event_create.external_id is not None: 

520 batch_external_id_map[event_create.external_id] = uuid.uuid4() 

521 

522 # Build lightweight event metadata for sorting 

523 event_metadata: list[dict[str, Any]] = [] 

524 for index, event_create in enumerate(ingest.events): 

525 metadata: dict[str, Any] = { 

526 "index": index, 

527 "external_id": event_create.external_id, 

528 "parent_id": event_create.parent_id, 

529 } 

530 if event_create.external_id: 

531 metadata["id"] = batch_external_id_map[event_create.external_id] 

532 event_metadata.append(metadata) 

533 

534 sorted_metadata = _topological_sort_events(event_metadata) 

535 

536 # Process events in sorted order 

537 events: list[dict[str, Any]] = [] 

538 errors: list[ValidationError] = [] 

539 processed_events: dict[uuid.UUID, dict[str, Any]] = {} 

540 

541 for metadata in sorted_metadata: 

542 index = metadata["index"] 

543 event_create = ingest.events[index] 

544 

545 try: 

546 organization_id = validate_organization_id( 

547 index, event_create.organization_id 

548 ) 

549 if isinstance(event_create, EventCreateCustomer): 

550 validate_customer_id(index, event_create.customer_id) 

551 

552 parent_event: Event | None = None 

553 parent_id_in_batch: uuid.UUID | None = None 

554 if event_create.parent_id is not None: 

555 parent_event, parent_id_in_batch = await self._resolve_parent( 

556 session, 

557 index, 

558 event_create.parent_id, 

559 organization_id, 

560 batch_external_id_map, 

561 ) 

562 

563 event_label_cache_key = (event_create.name, organization_id) 

564 if event_label_cache_key not in event_types_cache: 

565 event_type = await event_type_repository.get_or_create( 

566 event_create.name, organization_id 

567 ) 

568 event_types_cache[event_label_cache_key] = event_type.id 

569 event_type_id = event_types_cache[event_label_cache_key] 

570 except EventIngestValidationError as e: 

571 errors.extend(e.errors) 

572 continue 

573 else: 

574 event_dict = event_create.model_dump( 

575 exclude={"organization_id", "parent_id"}, by_alias=True 

576 ) 

577 event_dict["source"] = EventSource.user 

578 event_dict["organization_id"] = organization_id 

579 event_dict["event_type_id"] = event_type_id 

580 

581 if event_create.external_id is not None: 

582 event_dict["id"] = batch_external_id_map[event_create.external_id] 

583 

584 if parent_event is not None: 

585 event_dict["parent_id"] = parent_event.id 

586 event_dict["root_id"] = parent_event.root_id or parent_event.id 

587 elif parent_id_in_batch is not None: 

588 event_dict["parent_id"] = parent_id_in_batch 

589 # Parent was already processed, look it up 

590 parent_dict = processed_events.get(parent_id_in_batch) 

591 if parent_dict: 

592 event_dict["root_id"] = parent_dict.get( 

593 "root_id", parent_id_in_batch 

594 ) 

595 

596 events.append(event_dict) 

597 if event_dict.get("id"): 

598 processed_events[event_dict["id"]] = event_dict 

599 

600 if len(errors) > 0: 

601 raise PolarRequestValidationError(errors) 

602 

603 repository = EventRepository.from_session(session) 

604 event_ids, duplicates_count = await repository.insert_batch(events) 

605 

606 enqueue_events(*event_ids) 

607 

608 return EventsIngestResponse( 

609 inserted=len(event_ids), duplicates=duplicates_count 

610 ) 

611 

612 async def create_event(self, session: AsyncSession, event: Event) -> Event: 1a

613 repository = EventRepository.from_session(session) 

614 event = await repository.create(event, flush=True) 

615 enqueue_events(event.id) 

616 

617 log.debug( 

618 "Event created", 

619 id=event.id, 

620 name=event.name, 

621 source=event.source, 

622 metadata=event.user_metadata, 

623 ) 

624 return event 

625 

626 async def populate_event_closures_batch( 1a

627 self, session: AsyncSession, event_ids: Sequence[uuid.UUID] 

628 ) -> None: 

629 if not event_ids: 

630 return 

631 

632 result = await session.execute( 

633 select(Event.id, Event.parent_id).where(Event.id.in_(event_ids)) 

634 ) 

635 events_data = result.all() 

636 

637 events_list = [ 

638 {"id": event_id, "parent_id": parent_id} 

639 for event_id, parent_id in events_data 

640 ] 

641 sorted_events = _topological_sort_events(events_list) 

642 

643 all_closure_entries = [] 

644 # Map event_id -> list of its ancestor closures (including self) 

645 event_closures: dict[uuid.UUID, list[tuple[uuid.UUID, int]]] = {} 

646 

647 for event in sorted_events: 

648 event_id = event["id"] 

649 parent_id = event.get("parent_id") 

650 

651 # Self-reference 

652 event_closures[event_id] = [(event_id, 0)] 

653 all_closure_entries.append( 

654 { 

655 "ancestor_id": event_id, 

656 "descendant_id": event_id, 

657 "depth": 0, 

658 } 

659 ) 

660 

661 if parent_id is not None: 

662 # Check if parent is in current batch 

663 if parent_id in event_closures: 

664 # Parent is in current batch, use in-memory closures 

665 for ancestor_id, depth in event_closures[parent_id]: 

666 event_closures[event_id].append((ancestor_id, depth + 1)) 

667 all_closure_entries.append( 

668 { 

669 "ancestor_id": ancestor_id, 

670 "descendant_id": event_id, 

671 "depth": depth + 1, 

672 } 

673 ) 

674 else: 

675 # Parent is from previous batch, query database 

676 parent_closures_result = await session.execute( 

677 select( 

678 EventClosure.ancestor_id, 

679 EventClosure.depth, 

680 ).where(EventClosure.descendant_id == parent_id) 

681 ) 

682 

683 for ancestor_id, depth in parent_closures_result: 

684 event_closures[event_id].append((ancestor_id, depth + 1)) 

685 all_closure_entries.append( 

686 { 

687 "ancestor_id": ancestor_id, 

688 "descendant_id": event_id, 

689 "depth": depth + 1, 

690 } 

691 ) 

692 

693 # Single bulk insert 

694 if all_closure_entries: 

695 await session.execute( 

696 insert(EventClosure) 

697 .values(all_closure_entries) 

698 .on_conflict_do_nothing(index_elements=["ancestor_id", "descendant_id"]) 

699 ) 

700 

701 async def ingested( 1a

702 self, session: AsyncSession, event_ids: Sequence[uuid.UUID] 

703 ) -> None: 

704 await self.populate_event_closures_batch(session, event_ids) 

705 repository = EventRepository.from_session(session) 

706 statement = ( 

707 repository.get_base_statement() 

708 .where(Event.id.in_(event_ids), Event.customer.is_not(None)) 

709 .options(*repository.get_eager_options()) 

710 ) 

711 events = await repository.get_all(statement) 

712 customers: set[Customer] = set() 

713 for event in events: 

714 assert event.customer is not None 

715 customers.add(event.customer) 

716 

717 customer_repository = CustomerRepository.from_session(session) 

718 await customer_repository.touch_meters(customers) 

719 

720 async def _get_organization_validation_function( 1a

721 self, session: AsyncSession, auth_subject: AuthSubject[User | Organization] 

722 ) -> Callable[[int, uuid.UUID | None], uuid.UUID]: 

723 if is_organization(auth_subject): 

724 

725 def _validate_organization_id_by_organization( 

726 index: int, organization_id: uuid.UUID | None 

727 ) -> uuid.UUID: 

728 if organization_id is not None: 

729 raise EventIngestValidationError( 

730 [ 

731 { 

732 "type": "organization_token", 

733 "msg": ( 

734 "Setting organization_id is disallowed " 

735 "when using an organization token." 

736 ), 

737 "loc": ("body", "events", index, "organization_id"), 

738 "input": organization_id, 

739 } 

740 ] 

741 ) 

742 return auth_subject.subject.id 

743 

744 return _validate_organization_id_by_organization 

745 

746 statement = select(Organization.id).where( 

747 Organization.id.in_( 

748 select(UserOrganization.organization_id).where( 

749 UserOrganization.user_id == auth_subject.subject.id, 

750 UserOrganization.deleted_at.is_(None), 

751 ) 

752 ), 

753 ) 

754 result = await session.execute(statement) 

755 allowed_organizations = set(result.scalars().all()) 

756 

757 def _validate_organization_id_by_user( 

758 index: int, organization_id: uuid.UUID | None 

759 ) -> uuid.UUID: 

760 if organization_id is None: 

761 raise EventIngestValidationError( 

762 [ 

763 { 

764 "type": "missing", 

765 "msg": "organization_id is required.", 

766 "loc": ("body", "events", index, "organization_id"), 

767 "input": None, 

768 } 

769 ] 

770 ) 

771 if organization_id not in allowed_organizations: 

772 raise EventIngestValidationError( 

773 [ 

774 { 

775 "type": "organization_id", 

776 "msg": "Organization not found.", 

777 "loc": ("body", "events", index, "organization_id"), 

778 "input": organization_id, 

779 } 

780 ] 

781 ) 

782 

783 return organization_id 

784 

785 return _validate_organization_id_by_user 

786 

787 async def _get_customer_validation_function( 1a

788 self, session: AsyncSession, auth_subject: AuthSubject[User | Organization] 

789 ) -> Callable[[int, uuid.UUID], uuid.UUID]: 

790 statement = select(Customer.id).where(Customer.deleted_at.is_(None)) 

791 if is_user(auth_subject): 

792 statement = statement.where( 

793 Customer.organization_id.in_( 

794 select(UserOrganization.organization_id).where( 

795 UserOrganization.user_id == auth_subject.subject.id, 

796 UserOrganization.deleted_at.is_(None), 

797 ) 

798 ) 

799 ) 

800 else: 

801 statement = statement.where( 

802 Customer.organization_id == auth_subject.subject.id 

803 ) 

804 result = await session.execute(statement) 

805 allowed_customers = set(result.scalars().all()) 

806 

807 def _validate_customer_id(index: int, customer_id: uuid.UUID) -> uuid.UUID: 

808 if customer_id not in allowed_customers: 

809 raise EventIngestValidationError( 

810 [ 

811 { 

812 "type": "customer_id", 

813 "msg": "Customer not found.", 

814 "loc": ("body", "events", index, "customer_id"), 

815 "input": customer_id, 

816 } 

817 ] 

818 ) 

819 

820 return customer_id 

821 

822 return _validate_customer_id 

823 

824 async def _resolve_parent( 1a

825 self, 

826 session: AsyncSession, 

827 index: int, 

828 parent_id: str, 

829 organization_id: uuid.UUID, 

830 batch_external_id_map: dict[str, uuid.UUID], 

831 ) -> tuple[Event | None, uuid.UUID | None]: 

832 """ 

833 Resolve and return the parent event. 

834 Returns a tuple of (parent_event_from_db, parent_id_from_batch). 

835 Only one of these will be set - if the parent is in the current batch, 

836 parent_id_from_batch will be set. Otherwise, parent_event_from_db will be set. 

837 """ 

838 # Check if parent is in current batch 

839 if parent_id in batch_external_id_map: 

840 return None, batch_external_id_map[parent_id] 

841 

842 # Look up parent in database by ID or external_id 

843 try: 

844 parent_uuid = uuid.UUID(parent_id) 

845 except ValueError: 

846 parent_uuid = None 

847 

848 if parent_uuid: 

849 statement = select(Event).where( 

850 Event.organization_id == organization_id, 

851 or_(Event.id == parent_uuid, Event.external_id == parent_id), 

852 ) 

853 else: 

854 statement = select(Event).where( 

855 Event.organization_id == organization_id, 

856 Event.external_id == parent_id, 

857 ) 

858 

859 result = await session.execute(statement) 

860 parent_event = result.scalar_one_or_none() 

861 

862 if parent_event is not None: 

863 return parent_event, None 

864 

865 raise EventIngestValidationError( 

866 [ 

867 { 

868 "type": "parent_id", 

869 "msg": "Parent event not found.", 

870 "loc": ("body", "events", index, "parent_id"), 

871 "input": parent_id, 

872 } 

873 ] 

874 ) 

875 

876 

877event = EventService() 1a