Coverage for polar/event/service.py: 10%
330 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 16:17 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-12-05 16:17 +0000
1import uuid 1a
2from collections import defaultdict 1a
3from collections.abc import Callable, Sequence 1a
4from datetime import date, datetime 1a
5from typing import Any 1a
6from zoneinfo import ZoneInfo 1a
8import structlog 1a
9from sqlalchemy import ( 1a
10 Select,
11 String,
12 UnaryExpression,
13 asc,
14 cast,
15 desc,
16 func,
17 or_,
18 select,
19 text,
20)
21from sqlalchemy.dialects.postgresql import insert 1a
23from polar.auth.models import AuthSubject, is_organization, is_user 1a
24from polar.customer.repository import CustomerRepository 1a
25from polar.event_type.repository import EventTypeRepository 1a
26from polar.exceptions import PolarError, PolarRequestValidationError, ValidationError 1a
27from polar.kit.metadata import MetadataQuery, apply_metadata_clause 1a
28from polar.kit.pagination import PaginationParams, paginate 1a
29from polar.kit.sorting import Sorting 1a
30from polar.kit.time_queries import TimeInterval, get_timestamp_series_cte 1a
31from polar.logging import Logger 1a
32from polar.meter.filter import Filter 1a
33from polar.meter.repository import MeterRepository 1a
34from polar.models import ( 1a
35 Customer,
36 Event,
37 EventClosure,
38 Organization,
39 User,
40 UserOrganization,
41)
42from polar.models.event import EventSource 1a
43from polar.postgres import AsyncSession 1a
44from polar.worker import enqueue_events 1a
46from .repository import EventRepository 1a
47from .schemas import ( 1a
48 EventCreateCustomer,
49 EventName,
50 EventsIngest,
51 EventsIngestResponse,
52 EventStatistics,
53 ListStatisticsTimeseries,
54 StatisticsPeriod,
55)
56from .sorting import EventNamesSortProperty, EventSortProperty 1a
# Module-level structlog logger used by the event service.
log: Logger = structlog.get_logger()
class EventError(PolarError):
    """Base class for errors raised by the event service."""

    ...
class EventIngestValidationError(EventError):
    """Raised when an event of an ingest batch fails validation.

    The structured validation errors are available on ``errors``.
    """

    def __init__(self, errors: list[ValidationError]) -> None:
        # Keep the structured errors so the caller can collect them.
        self.errors = errors
        super().__init__("Event ingest validation failed.")
70def _topological_sort_events(events: list[dict[str, Any]]) -> list[dict[str, Any]]: 1a
71 """
72 Sort events by dependency order so parents come before children.
73 Events without parents come first, followed by their children in order.
75 Handles parent_id references that can be either Polar IDs or external_id strings.
76 Uses Kahn's algorithm for topological sorting.
77 """
78 if not events:
79 return []
81 id_to_index: dict[uuid.UUID | str, int] = {}
82 for idx, event in enumerate(events):
83 if "id" in event:
84 id_to_index[event["id"]] = idx
85 if "external_id" in event and event["external_id"] is not None:
86 id_to_index[event["external_id"]] = idx
88 graph: dict[int, list[int]] = defaultdict(list)
89 in_degree: dict[int, int] = {}
91 for idx in range(len(events)):
92 in_degree[idx] = 0
94 for idx, event in enumerate(events):
95 parent_id = event.get("parent_id")
96 if parent_id and parent_id in id_to_index:
97 parent_idx = id_to_index[parent_id]
98 graph[parent_idx].append(idx)
99 in_degree[idx] += 1
101 queue = [idx for idx in range(len(events)) if in_degree[idx] == 0]
102 sorted_indices = []
104 while queue:
105 current_idx = queue.pop(0)
106 sorted_indices.append(current_idx)
108 for child_idx in graph[current_idx]:
109 in_degree[child_idx] -= 1
110 if in_degree[child_idx] == 0:
111 queue.append(child_idx)
113 if len(sorted_indices) != len(events):
114 raise EventError("Circular dependency detected in event parent relationships")
116 return [events[idx] for idx in sorted_indices]
119class EventService: 1a
120 async def _build_filtered_statement( 1a
121 self,
122 session: AsyncSession,
123 auth_subject: AuthSubject[User | Organization],
124 repository: EventRepository,
125 *,
126 filter: Filter | None = None,
127 start_timestamp: datetime | None = None,
128 end_timestamp: datetime | None = None,
129 organization_id: Sequence[uuid.UUID] | None = None,
130 customer_id: Sequence[uuid.UUID] | None = None,
131 external_customer_id: Sequence[str] | None = None,
132 meter_id: uuid.UUID | None = None,
133 name: Sequence[str] | None = None,
134 source: Sequence[EventSource] | None = None,
135 event_type_id: uuid.UUID | None = None,
136 metadata: MetadataQuery | None = None,
137 sorting: Sequence[Sorting[EventSortProperty]] = (
138 (EventSortProperty.timestamp, True),
139 ),
140 query: str | None = None,
141 ) -> Select[tuple[Event]]:
142 statement = repository.get_readable_statement(auth_subject).options(
143 *repository.get_eager_options()
144 )
146 if filter is not None:
147 statement = statement.where(filter.get_sql_clause(Event))
149 if start_timestamp is not None:
150 statement = statement.where(Event.timestamp > start_timestamp)
152 if end_timestamp is not None:
153 statement = statement.where(Event.timestamp < end_timestamp)
155 if organization_id is not None:
156 statement = statement.where(Event.organization_id.in_(organization_id))
158 if customer_id is not None:
159 statement = statement.where(
160 repository.get_customer_id_filter_clause(customer_id)
161 )
163 if external_customer_id is not None:
164 statement = statement.where(
165 repository.get_external_customer_id_filter_clause(external_customer_id)
166 )
168 if meter_id is not None:
169 meter_repository = MeterRepository.from_session(session)
170 meter = await meter_repository.get_readable_by_id(meter_id, auth_subject)
171 if meter is None:
172 raise PolarRequestValidationError(
173 [
174 {
175 "type": "meter_id",
176 "msg": "Meter not found.",
177 "loc": ("query", "meter_id"),
178 "input": meter_id,
179 }
180 ]
181 )
182 statement = statement.where(repository.get_meter_clause(meter))
184 if name is not None:
185 statement = statement.where(Event.name.in_(name))
187 if source is not None:
188 statement = statement.where(Event.source.in_(source))
190 if event_type_id is not None:
191 statement = statement.where(Event.event_type_id == event_type_id)
193 if query is not None:
194 statement = statement.where(
195 or_(
196 Event.name.ilike(f"%{query}%"),
197 Event.source.ilike(f"%{query}%"),
198 # Load customers and match against their name/email
199 Event.customer_id.in_(
200 select(Customer.id).where(
201 or_(
202 cast(Customer.id, String).ilike(f"%{query}%"),
203 Customer.external_id.ilike(f"%{query}%"),
204 Customer.name.ilike(f"%{query}%"),
205 Customer.email.ilike(f"%{query}%"),
206 )
207 )
208 ),
209 func.to_tsvector("simple", cast(Event.user_metadata, String)).op(
210 "@@"
211 )(func.plainto_tsquery(query)),
212 )
213 )
215 if metadata is not None:
216 statement = apply_metadata_clause(Event, statement, metadata)
218 order_by_clauses: list[UnaryExpression[Any]] = []
219 for criterion, is_desc in sorting:
220 clause_function = desc if is_desc else asc
221 if criterion == EventSortProperty.timestamp:
222 order_by_clauses.append(clause_function(Event.timestamp))
223 statement = statement.order_by(*order_by_clauses)
225 return statement
227 async def list( 1a
228 self,
229 session: AsyncSession,
230 auth_subject: AuthSubject[User | Organization],
231 *,
232 filter: Filter | None = None,
233 start_timestamp: datetime | None = None,
234 end_timestamp: datetime | None = None,
235 organization_id: Sequence[uuid.UUID] | None = None,
236 customer_id: Sequence[uuid.UUID] | None = None,
237 external_customer_id: Sequence[str] | None = None,
238 meter_id: uuid.UUID | None = None,
239 name: Sequence[str] | None = None,
240 source: Sequence[EventSource] | None = None,
241 event_type_id: uuid.UUID | None = None,
242 metadata: MetadataQuery | None = None,
243 pagination: PaginationParams,
244 sorting: Sequence[Sorting[EventSortProperty]] = (
245 (EventSortProperty.timestamp, True),
246 ),
247 query: str | None = None,
248 parent_id: uuid.UUID | None = None,
249 hierarchical: bool = False,
250 aggregate_fields: Sequence[str] = (),
251 ) -> tuple[Sequence[Event], int]:
252 repository = EventRepository.from_session(session)
253 statement = await self._build_filtered_statement(
254 session,
255 auth_subject,
256 repository,
257 filter=filter,
258 start_timestamp=start_timestamp,
259 end_timestamp=end_timestamp,
260 organization_id=organization_id,
261 customer_id=customer_id,
262 external_customer_id=external_customer_id,
263 meter_id=meter_id,
264 name=name,
265 source=source,
266 event_type_id=event_type_id,
267 metadata=metadata,
268 sorting=sorting,
269 query=query,
270 )
272 if hierarchical:
273 if parent_id is not None:
274 statement = statement.where(Event.parent_id == parent_id)
275 else:
276 statement = statement.where(Event.parent_id.is_(None))
278 return await repository.list_with_closure_table(
279 statement,
280 limit=pagination.limit,
281 page=pagination.page,
282 aggregate_fields=aggregate_fields,
283 )
285 async def get( 1a
286 self,
287 session: AsyncSession,
288 auth_subject: AuthSubject[User | Organization],
289 id: uuid.UUID,
290 ) -> Event | None:
291 repository = EventRepository.from_session(session)
292 statement = (
293 repository.get_readable_statement(auth_subject)
294 .where(Event.id == id)
295 .options(*repository.get_eager_options())
296 )
297 return await repository.get_one_or_none(statement)
299 async def list_statistics_timeseries( 1a
300 self,
301 session: AsyncSession,
302 auth_subject: AuthSubject[User | Organization],
303 *,
304 start_date: date,
305 end_date: date,
306 timezone: ZoneInfo,
307 interval: TimeInterval,
308 filter: Filter | None = None,
309 organization_id: Sequence[uuid.UUID] | None = None,
310 customer_id: Sequence[uuid.UUID] | None = None,
311 external_customer_id: Sequence[str] | None = None,
312 meter_id: uuid.UUID | None = None,
313 name: Sequence[str] | None = None,
314 source: Sequence[EventSource] | None = None,
315 event_type_id: uuid.UUID | None = None,
316 metadata: MetadataQuery | None = None,
317 sorting: Sequence[Sorting[EventSortProperty]] = (
318 (EventSortProperty.timestamp, True),
319 ),
320 query: str | None = None,
321 aggregate_fields: Sequence[str] = ("_cost.amount",),
322 hierarchy_stats_sorting: Sequence[tuple[str, bool]] = (("total", True),),
323 ) -> ListStatisticsTimeseries:
324 start_timestamp = datetime(
325 start_date.year, start_date.month, start_date.day, 0, 0, 0, 0, timezone
326 )
327 end_timestamp = datetime(
328 end_date.year, end_date.month, end_date.day, 23, 59, 59, 999999, timezone
329 )
331 timestamp_series_cte = get_timestamp_series_cte(
332 start_timestamp, end_timestamp, interval
333 )
335 repository = EventRepository.from_session(session)
336 statement = await self._build_filtered_statement(
337 session,
338 auth_subject,
339 repository,
340 filter=filter,
341 start_timestamp=start_timestamp,
342 end_timestamp=end_timestamp,
343 organization_id=organization_id,
344 customer_id=customer_id,
345 external_customer_id=external_customer_id,
346 meter_id=meter_id,
347 name=name,
348 source=source,
349 event_type_id=event_type_id,
350 metadata=metadata,
351 sorting=sorting,
352 query=query,
353 )
355 timeseries_stats = await repository.get_hierarchy_stats(
356 statement,
357 aggregate_fields,
358 hierarchy_stats_sorting,
359 timestamp_series=timestamp_series_cte,
360 )
362 result = await session.execute(select(timestamp_series_cte.c.timestamp))
363 timestamps = [row[0] for row in result.all()]
365 stats_by_timestamp: dict[datetime, list[dict[str, Any]]] = {}
366 all_event_types: dict[tuple[str, str, uuid.UUID], dict[str, Any]] = {}
368 for stat in timeseries_stats:
369 ts = stat.pop("timestamp")
370 if stat["name"] is None:
371 continue
372 if ts not in stats_by_timestamp:
373 stats_by_timestamp[ts] = []
374 stats_by_timestamp[ts].append(stat)
376 # Track all unique event types
377 event_key = (stat["name"], stat["label"], stat["event_type_id"])
378 if event_key not in all_event_types:
379 all_event_types[event_key] = {
380 "name": stat["name"],
381 "label": stat["label"],
382 "event_type_id": stat["event_type_id"],
383 }
385 # Convert field names from dot notation to underscore (e.g., "_cost.amount" -> "_cost_amount")
386 zero_values = {field.replace(".", "_"): "0" for field in aggregate_fields}
388 periods = []
389 for i, period_start in enumerate(timestamps):
390 if i + 1 < len(timestamps):
391 period_end = timestamps[i + 1]
392 else:
393 period_end = end_timestamp
395 period_stats = stats_by_timestamp.get(period_start, [])
397 # Fill in missing event types with zeros
398 stats_by_name = {s["name"]: s for s in period_stats}
399 complete_stats = []
400 for event_type_info in all_event_types.values():
401 if event_type_info["name"] in stats_by_name:
402 complete_stats.append(stats_by_name[event_type_info["name"]])
403 else:
404 complete_stats.append(
405 {
406 **event_type_info,
407 "occurrences": 0,
408 "totals": zero_values,
409 "averages": zero_values,
410 "p50": zero_values,
411 "p95": zero_values,
412 "p99": zero_values,
413 }
414 )
416 periods.append(
417 StatisticsPeriod(
418 timestamp=period_start,
419 period_start=period_start,
420 period_end=period_end,
421 stats=[EventStatistics(**s) for s in complete_stats],
422 )
423 )
425 totals = await repository.get_hierarchy_stats(
426 statement, aggregate_fields, hierarchy_stats_sorting
427 )
429 return ListStatisticsTimeseries(
430 periods=periods,
431 totals=[EventStatistics(**s) for s in totals],
432 )
434 async def list_names( 1a
435 self,
436 session: AsyncSession,
437 auth_subject: AuthSubject[User | Organization],
438 *,
439 organization_id: Sequence[uuid.UUID] | None = None,
440 customer_id: Sequence[uuid.UUID] | None = None,
441 external_customer_id: Sequence[str] | None = None,
442 source: Sequence[EventSource] | None = None,
443 query: str | None = None,
444 pagination: PaginationParams,
445 sorting: Sequence[Sorting[EventNamesSortProperty]] = [
446 (EventNamesSortProperty.last_seen, True)
447 ],
448 ) -> tuple[Sequence[EventName], int]:
449 repository = EventRepository.from_session(session)
450 statement = repository.get_event_names_statement(auth_subject)
452 if organization_id is not None:
453 statement = statement.where(Event.organization_id.in_(organization_id))
455 if customer_id is not None:
456 statement = statement.where(
457 repository.get_customer_id_filter_clause(customer_id)
458 )
460 if external_customer_id is not None:
461 statement = statement.where(
462 repository.get_external_customer_id_filter_clause(external_customer_id)
463 )
465 if source is not None:
466 statement = statement.where(Event.source.in_(source))
468 if query is not None:
469 statement = statement.where(Event.name.ilike(f"%{query}%"))
471 order_by_clauses: list[UnaryExpression[Any]] = []
472 for criterion, is_desc in sorting:
473 clause_function = desc if is_desc else asc
474 if criterion == EventNamesSortProperty.event_name:
475 order_by_clauses.append(clause_function(Event.name))
476 elif criterion == EventNamesSortProperty.first_seen:
477 order_by_clauses.append(clause_function(text("first_seen")))
478 elif criterion == EventNamesSortProperty.last_seen:
479 order_by_clauses.append(clause_function(text("last_seen")))
480 elif criterion == EventNamesSortProperty.occurrences:
481 order_by_clauses.append(clause_function(text("occurrences")))
482 statement = statement.order_by(*order_by_clauses)
484 results, count = await paginate(session, statement, pagination=pagination)
486 event_names: list[EventName] = []
487 for result in results:
488 event_name, event_source, occurrences, first_seen, last_seen = result
489 event_names.append(
490 EventName(
491 name=event_name,
492 source=event_source,
493 occurrences=occurrences,
494 first_seen=first_seen,
495 last_seen=last_seen,
496 )
497 )
499 return event_names, count
501 async def ingest( 1a
502 self,
503 session: AsyncSession,
504 auth_subject: AuthSubject[User | Organization],
505 ingest: EventsIngest,
506 ) -> EventsIngestResponse:
507 validate_organization_id = await self._get_organization_validation_function(
508 session, auth_subject
509 )
510 validate_customer_id = await self._get_customer_validation_function(
511 session, auth_subject
512 )
514 event_type_repository = EventTypeRepository.from_session(session)
515 event_types_cache: dict[tuple[str, uuid.UUID], uuid.UUID] = {}
517 batch_external_id_map: dict[str, uuid.UUID] = {}
518 for event_create in ingest.events:
519 if event_create.external_id is not None:
520 batch_external_id_map[event_create.external_id] = uuid.uuid4()
522 # Build lightweight event metadata for sorting
523 event_metadata: list[dict[str, Any]] = []
524 for index, event_create in enumerate(ingest.events):
525 metadata: dict[str, Any] = {
526 "index": index,
527 "external_id": event_create.external_id,
528 "parent_id": event_create.parent_id,
529 }
530 if event_create.external_id:
531 metadata["id"] = batch_external_id_map[event_create.external_id]
532 event_metadata.append(metadata)
534 sorted_metadata = _topological_sort_events(event_metadata)
536 # Process events in sorted order
537 events: list[dict[str, Any]] = []
538 errors: list[ValidationError] = []
539 processed_events: dict[uuid.UUID, dict[str, Any]] = {}
541 for metadata in sorted_metadata:
542 index = metadata["index"]
543 event_create = ingest.events[index]
545 try:
546 organization_id = validate_organization_id(
547 index, event_create.organization_id
548 )
549 if isinstance(event_create, EventCreateCustomer):
550 validate_customer_id(index, event_create.customer_id)
552 parent_event: Event | None = None
553 parent_id_in_batch: uuid.UUID | None = None
554 if event_create.parent_id is not None:
555 parent_event, parent_id_in_batch = await self._resolve_parent(
556 session,
557 index,
558 event_create.parent_id,
559 organization_id,
560 batch_external_id_map,
561 )
563 event_label_cache_key = (event_create.name, organization_id)
564 if event_label_cache_key not in event_types_cache:
565 event_type = await event_type_repository.get_or_create(
566 event_create.name, organization_id
567 )
568 event_types_cache[event_label_cache_key] = event_type.id
569 event_type_id = event_types_cache[event_label_cache_key]
570 except EventIngestValidationError as e:
571 errors.extend(e.errors)
572 continue
573 else:
574 event_dict = event_create.model_dump(
575 exclude={"organization_id", "parent_id"}, by_alias=True
576 )
577 event_dict["source"] = EventSource.user
578 event_dict["organization_id"] = organization_id
579 event_dict["event_type_id"] = event_type_id
581 if event_create.external_id is not None:
582 event_dict["id"] = batch_external_id_map[event_create.external_id]
584 if parent_event is not None:
585 event_dict["parent_id"] = parent_event.id
586 event_dict["root_id"] = parent_event.root_id or parent_event.id
587 elif parent_id_in_batch is not None:
588 event_dict["parent_id"] = parent_id_in_batch
589 # Parent was already processed, look it up
590 parent_dict = processed_events.get(parent_id_in_batch)
591 if parent_dict:
592 event_dict["root_id"] = parent_dict.get(
593 "root_id", parent_id_in_batch
594 )
596 events.append(event_dict)
597 if event_dict.get("id"):
598 processed_events[event_dict["id"]] = event_dict
600 if len(errors) > 0:
601 raise PolarRequestValidationError(errors)
603 repository = EventRepository.from_session(session)
604 event_ids, duplicates_count = await repository.insert_batch(events)
606 enqueue_events(*event_ids)
608 return EventsIngestResponse(
609 inserted=len(event_ids), duplicates=duplicates_count
610 )
612 async def create_event(self, session: AsyncSession, event: Event) -> Event: 1a
613 repository = EventRepository.from_session(session)
614 event = await repository.create(event, flush=True)
615 enqueue_events(event.id)
617 log.debug(
618 "Event created",
619 id=event.id,
620 name=event.name,
621 source=event.source,
622 metadata=event.user_metadata,
623 )
624 return event
626 async def populate_event_closures_batch( 1a
627 self, session: AsyncSession, event_ids: Sequence[uuid.UUID]
628 ) -> None:
629 if not event_ids:
630 return
632 result = await session.execute(
633 select(Event.id, Event.parent_id).where(Event.id.in_(event_ids))
634 )
635 events_data = result.all()
637 events_list = [
638 {"id": event_id, "parent_id": parent_id}
639 for event_id, parent_id in events_data
640 ]
641 sorted_events = _topological_sort_events(events_list)
643 all_closure_entries = []
644 # Map event_id -> list of its ancestor closures (including self)
645 event_closures: dict[uuid.UUID, list[tuple[uuid.UUID, int]]] = {}
647 for event in sorted_events:
648 event_id = event["id"]
649 parent_id = event.get("parent_id")
651 # Self-reference
652 event_closures[event_id] = [(event_id, 0)]
653 all_closure_entries.append(
654 {
655 "ancestor_id": event_id,
656 "descendant_id": event_id,
657 "depth": 0,
658 }
659 )
661 if parent_id is not None:
662 # Check if parent is in current batch
663 if parent_id in event_closures:
664 # Parent is in current batch, use in-memory closures
665 for ancestor_id, depth in event_closures[parent_id]:
666 event_closures[event_id].append((ancestor_id, depth + 1))
667 all_closure_entries.append(
668 {
669 "ancestor_id": ancestor_id,
670 "descendant_id": event_id,
671 "depth": depth + 1,
672 }
673 )
674 else:
675 # Parent is from previous batch, query database
676 parent_closures_result = await session.execute(
677 select(
678 EventClosure.ancestor_id,
679 EventClosure.depth,
680 ).where(EventClosure.descendant_id == parent_id)
681 )
683 for ancestor_id, depth in parent_closures_result:
684 event_closures[event_id].append((ancestor_id, depth + 1))
685 all_closure_entries.append(
686 {
687 "ancestor_id": ancestor_id,
688 "descendant_id": event_id,
689 "depth": depth + 1,
690 }
691 )
693 # Single bulk insert
694 if all_closure_entries:
695 await session.execute(
696 insert(EventClosure)
697 .values(all_closure_entries)
698 .on_conflict_do_nothing(index_elements=["ancestor_id", "descendant_id"])
699 )
701 async def ingested( 1a
702 self, session: AsyncSession, event_ids: Sequence[uuid.UUID]
703 ) -> None:
704 await self.populate_event_closures_batch(session, event_ids)
705 repository = EventRepository.from_session(session)
706 statement = (
707 repository.get_base_statement()
708 .where(Event.id.in_(event_ids), Event.customer.is_not(None))
709 .options(*repository.get_eager_options())
710 )
711 events = await repository.get_all(statement)
712 customers: set[Customer] = set()
713 for event in events:
714 assert event.customer is not None
715 customers.add(event.customer)
717 customer_repository = CustomerRepository.from_session(session)
718 await customer_repository.touch_meters(customers)
720 async def _get_organization_validation_function( 1a
721 self, session: AsyncSession, auth_subject: AuthSubject[User | Organization]
722 ) -> Callable[[int, uuid.UUID | None], uuid.UUID]:
723 if is_organization(auth_subject):
725 def _validate_organization_id_by_organization(
726 index: int, organization_id: uuid.UUID | None
727 ) -> uuid.UUID:
728 if organization_id is not None:
729 raise EventIngestValidationError(
730 [
731 {
732 "type": "organization_token",
733 "msg": (
734 "Setting organization_id is disallowed "
735 "when using an organization token."
736 ),
737 "loc": ("body", "events", index, "organization_id"),
738 "input": organization_id,
739 }
740 ]
741 )
742 return auth_subject.subject.id
744 return _validate_organization_id_by_organization
746 statement = select(Organization.id).where(
747 Organization.id.in_(
748 select(UserOrganization.organization_id).where(
749 UserOrganization.user_id == auth_subject.subject.id,
750 UserOrganization.deleted_at.is_(None),
751 )
752 ),
753 )
754 result = await session.execute(statement)
755 allowed_organizations = set(result.scalars().all())
757 def _validate_organization_id_by_user(
758 index: int, organization_id: uuid.UUID | None
759 ) -> uuid.UUID:
760 if organization_id is None:
761 raise EventIngestValidationError(
762 [
763 {
764 "type": "missing",
765 "msg": "organization_id is required.",
766 "loc": ("body", "events", index, "organization_id"),
767 "input": None,
768 }
769 ]
770 )
771 if organization_id not in allowed_organizations:
772 raise EventIngestValidationError(
773 [
774 {
775 "type": "organization_id",
776 "msg": "Organization not found.",
777 "loc": ("body", "events", index, "organization_id"),
778 "input": organization_id,
779 }
780 ]
781 )
783 return organization_id
785 return _validate_organization_id_by_user
787 async def _get_customer_validation_function( 1a
788 self, session: AsyncSession, auth_subject: AuthSubject[User | Organization]
789 ) -> Callable[[int, uuid.UUID], uuid.UUID]:
790 statement = select(Customer.id).where(Customer.deleted_at.is_(None))
791 if is_user(auth_subject):
792 statement = statement.where(
793 Customer.organization_id.in_(
794 select(UserOrganization.organization_id).where(
795 UserOrganization.user_id == auth_subject.subject.id,
796 UserOrganization.deleted_at.is_(None),
797 )
798 )
799 )
800 else:
801 statement = statement.where(
802 Customer.organization_id == auth_subject.subject.id
803 )
804 result = await session.execute(statement)
805 allowed_customers = set(result.scalars().all())
807 def _validate_customer_id(index: int, customer_id: uuid.UUID) -> uuid.UUID:
808 if customer_id not in allowed_customers:
809 raise EventIngestValidationError(
810 [
811 {
812 "type": "customer_id",
813 "msg": "Customer not found.",
814 "loc": ("body", "events", index, "customer_id"),
815 "input": customer_id,
816 }
817 ]
818 )
820 return customer_id
822 return _validate_customer_id
824 async def _resolve_parent( 1a
825 self,
826 session: AsyncSession,
827 index: int,
828 parent_id: str,
829 organization_id: uuid.UUID,
830 batch_external_id_map: dict[str, uuid.UUID],
831 ) -> tuple[Event | None, uuid.UUID | None]:
832 """
833 Resolve and return the parent event.
834 Returns a tuple of (parent_event_from_db, parent_id_from_batch).
835 Only one of these will be set - if the parent is in the current batch,
836 parent_id_from_batch will be set. Otherwise, parent_event_from_db will be set.
837 """
838 # Check if parent is in current batch
839 if parent_id in batch_external_id_map:
840 return None, batch_external_id_map[parent_id]
842 # Look up parent in database by ID or external_id
843 try:
844 parent_uuid = uuid.UUID(parent_id)
845 except ValueError:
846 parent_uuid = None
848 if parent_uuid:
849 statement = select(Event).where(
850 Event.organization_id == organization_id,
851 or_(Event.id == parent_uuid, Event.external_id == parent_id),
852 )
853 else:
854 statement = select(Event).where(
855 Event.organization_id == organization_id,
856 Event.external_id == parent_id,
857 )
859 result = await session.execute(statement)
860 parent_event = result.scalar_one_or_none()
862 if parent_event is not None:
863 return parent_event, None
865 raise EventIngestValidationError(
866 [
867 {
868 "type": "parent_id",
869 "msg": "Parent event not found.",
870 "loc": ("body", "events", index, "parent_id"),
871 "input": parent_id,
872 }
873 ]
874 )
# Module-level EventService instance.
event = EventService()