252 lines
9.4 KiB
Python
252 lines
9.4 KiB
Python
|
|
from datetime import datetime, timezone
|
||
|
|
from decimal import Decimal
|
||
|
|
|
||
|
|
from sqlalchemy import case, func, select
|
||
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
|
from sqlalchemy.orm import selectinload
|
||
|
|
|
||
|
|
from app.experiments.registry import EXPERIMENT_REGISTRY
|
||
|
|
from app.models.experiment import Experiment, ExperimentEvent, ExperimentVariant
|
||
|
|
from app.schemas.experiments import (
|
||
|
|
ExperimentConversionCreate,
|
||
|
|
ExperimentDefinitionResponse,
|
||
|
|
ExperimentDefinitionUpdate,
|
||
|
|
ExperimentEventCreate,
|
||
|
|
ExperimentImpressionCreate,
|
||
|
|
ExperimentResult,
|
||
|
|
ExperimentVariantResult,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
def experiment_exists(experiment_key: str, variant_key: str) -> bool:
    """Report whether ``variant_key`` is a declared variant of ``experiment_key``.

    Only the in-code ``EXPERIMENT_REGISTRY`` is consulted; the database is
    not queried.

    Args:
        experiment_key: Registry key of the experiment.
        variant_key: Variant key to look for within that experiment.

    Returns:
        ``True`` if the experiment is registered and one of its variants has
        a matching ``"variant_key"``; ``False`` otherwise, including when the
        experiment is unknown.
    """
    definition = EXPERIMENT_REGISTRY.get(experiment_key)
    if definition:
        for candidate in definition["variants"]:
            if candidate["variant_key"] == variant_key:
                return True
    return False
|
||
|
|
|
||
|
|
|
||
|
|
async def sync_experiment_registry(db: AsyncSession) -> None:
    """Seed the database from the in-code ``EXPERIMENT_REGISTRY``.

    For each registry definition:

    * If no ``Experiment`` row with that ``experiment_key`` exists, one is
      created from the definition. It is flushed immediately so that its
      ``id`` can be used for the variant rows.
    * An ``Experiment`` row that already exists is NOT modified. Its name,
      cookie, routes and enabled flag keep their stored values.
    * Each variant in the definition is inserted when missing. Otherwise its
      ``label``, ``allocation`` and ``is_control`` are overwritten from the
      registry.

    Variants stored in the database but absent from the registry are left in
    place. The session is flushed but not committed; committing is presumably
    the caller's responsibility.
    """
    loaded = await db.execute(select(Experiment).options(selectinload(Experiment.variants)))
    by_key = {row.experiment_key: row for row in loaded.scalars().all()}

    for definition in EXPERIMENT_REGISTRY.values():
        experiment = by_key.get(definition["experiment_key"])

        if experiment is not None:
            known_variants: dict[str, ExperimentVariant] = {
                stored.variant_key: stored for stored in experiment.variants
            }
        else:
            experiment = Experiment(
                experiment_key=definition["experiment_key"],
                cookie_name=definition["cookie_name"],
                name=definition["name"],
                description=definition.get("description"),
                enabled=definition["enabled"],
                eligible_routes=definition["eligible_routes"],
            )
            db.add(experiment)
            # Flush so experiment.id is assigned before variant rows reference it.
            await db.flush()
            known_variants = {}

        for spec in definition["variants"]:
            current = known_variants.get(spec["variant_key"])
            if current is not None:
                current.label = spec["label"]
                current.allocation = spec["allocation"]
                current.is_control = spec["is_control"]
            else:
                db.add(
                    ExperimentVariant(
                        experiment_id=experiment.id,
                        variant_key=spec["variant_key"],
                        label=spec["label"],
                        allocation=spec["allocation"],
                        is_control=spec["is_control"],
                    )
                )

    await db.flush()
|
||
|
|
|
||
|
|
|
||
|
|
async def list_experiment_definitions(db: AsyncSession) -> list[ExperimentDefinitionResponse]:
    """Return every stored experiment as a response schema, ordered by ``experiment_key``.

    Variants are eager-loaded with ``selectinload``. Each variant is passed
    to the response as a plain dict with ``variant_key``, ``label``,
    ``allocation`` and ``is_control``. Variant order is whatever order the
    ``Experiment.variants`` relationship yields.
    """
    stmt = (
        select(Experiment)
        .options(selectinload(Experiment.variants))
        .order_by(Experiment.experiment_key)
    )
    rows = (await db.execute(stmt)).scalars().all()

    responses: list[ExperimentDefinitionResponse] = []
    for stored in rows:
        variant_payloads = []
        for v in stored.variants:
            variant_payloads.append(
                {
                    "variant_key": v.variant_key,
                    "label": v.label,
                    "allocation": v.allocation,
                    "is_control": v.is_control,
                }
            )
        responses.append(
            ExperimentDefinitionResponse(
                experiment_key=stored.experiment_key,
                cookie_name=stored.cookie_name,
                name=stored.name,
                description=stored.description,
                enabled=stored.enabled,
                eligible_routes=stored.eligible_routes,
                variants=variant_payloads,
            )
        )
    return responses
|
||
|
|
|
||
|
|
|
||
|
|
async def record_experiment_event(
    db: AsyncSession,
    payload: ExperimentImpressionCreate | ExperimentEventCreate | ExperimentConversionCreate,
) -> ExperimentEvent:
    """Persist one experiment event and return it refreshed from the database.

    ``payload.event_name`` is stored in the ``event_type`` column.
    ``conversion_value`` is read from the payload only if the attribute
    exists; otherwise ``None`` is stored. Presumably only conversion payloads
    carry it, so verify against the schemas.

    The payload timestamp is normalised to naive UTC before storage. A naive
    timestamp is interpreted as already being UTC, and an aware one is
    converted to UTC.

    The experiment/variant pair is not validated against the registry here.
    Callers presumably do that, e.g. via ``experiment_exists``; confirm.
    """
    occurred_at = payload.timestamp
    if occurred_at.tzinfo is None:
        occurred_at = occurred_at.replace(tzinfo=timezone.utc)
    # NOTE(review): tzinfo is dropped after converting to UTC, presumably
    # because the created_at column is timezone-naive; confirm.
    created_at_utc = occurred_at.astimezone(timezone.utc).replace(tzinfo=None)

    event = ExperimentEvent(
        experiment_key=payload.experiment_key,
        variant_key=payload.variant_key,
        session_id=payload.session_id,
        user_id=payload.user_id,
        path=payload.path,
        event_type=payload.event_name,
        conversion_value=getattr(payload, "conversion_value", None),
        metadata_=payload.metadata,
        created_at=created_at_utc,
    )
    db.add(event)
    # Flush then refresh so server-generated columns (e.g. id) are populated.
    await db.flush()
    await db.refresh(event)
    return event
|
||
|
|
|
||
|
|
|
||
|
|
async def get_experiment_results(db: AsyncSession, experiment_key: str | None = None) -> list[ExperimentResult]:
    """Aggregate stored events into per-variant funnel metrics.

    Events are grouped by ``(experiment_key, variant_key)``. Each group gets:
    counts of the ``impression``, ``cta_click``, ``form_start``,
    ``form_submit`` and ``conversion`` event types; the number of distinct
    ``session_id`` values; and the sum of ``conversion_value`` (0 when there
    are none).

    Args:
        db: Active async session.
        experiment_key: If truthy, restricts the aggregation to that
            experiment. ``None`` or an empty string means all experiments.

    Returns:
        One ``ExperimentResult`` per experiment that has at least one event,
        ordered by experiment key. Variants are ordered by variant key.
        ``conversion_rate`` is conversions / impressions rounded to 4 places,
        or 0.0 when there are no impressions. All results share the same
        ``generated_at`` UTC timestamp.
    """

    def count_of(event_type: str, label: str):
        # SUM(CASE WHEN event_type = :event_type THEN 1 ELSE 0 END) AS label
        return func.sum(case((ExperimentEvent.event_type == event_type, 1), else_=0)).label(label)

    stmt = select(
        ExperimentEvent.experiment_key,
        ExperimentEvent.variant_key,
        count_of("impression", "impressions"),
        count_of("cta_click", "cta_clicks"),
        count_of("form_start", "form_starts"),
        count_of("form_submit", "form_submits"),
        count_of("conversion", "conversions"),
        func.count(func.distinct(ExperimentEvent.session_id)).label("unique_sessions"),
        func.coalesce(func.sum(ExperimentEvent.conversion_value), Decimal("0")).label("conversion_value_total"),
    )
    if experiment_key:
        stmt = stmt.where(ExperimentEvent.experiment_key == experiment_key)
    stmt = stmt.group_by(ExperimentEvent.experiment_key, ExperimentEvent.variant_key).order_by(
        ExperimentEvent.experiment_key,
        ExperimentEvent.variant_key,
    )

    result = await db.execute(stmt)

    grouped: dict[str, list[ExperimentVariantResult]] = {}
    for row in result.all():
        impression_count = int(row.impressions or 0)
        conversion_count = int(row.conversions or 0)
        rate = 0.0 if not impression_count else conversion_count / impression_count

        bucket = grouped.setdefault(row.experiment_key, [])
        bucket.append(
            ExperimentVariantResult(
                variant_key=row.variant_key,
                impressions=impression_count,
                cta_clicks=int(row.cta_clicks or 0),
                form_starts=int(row.form_starts or 0),
                form_submits=int(row.form_submits or 0),
                conversions=conversion_count,
                unique_sessions=int(row.unique_sessions or 0),
                conversion_rate=round(rate, 4),
                conversion_value_total=float(row.conversion_value_total or 0),
            )
        )

    snapshot_time = datetime.now(timezone.utc)
    results: list[ExperimentResult] = []
    for key, variant_results in grouped.items():
        results.append(
            ExperimentResult(
                experiment_key=key,
                generated_at=snapshot_time,
                variants=variant_results,
            )
        )
    return results
|
||
|
|
|
||
|
|
|
||
|
|
async def get_experiment_definition(db: AsyncSession, experiment_key: str) -> Experiment | None:
    """Load a single experiment by key, with its variants eager-loaded.

    Returns ``None`` when no row matches. If several rows matched, the first
    one returned by the database would be used.
    """
    stmt = select(Experiment).where(Experiment.experiment_key == experiment_key)
    stmt = stmt.options(selectinload(Experiment.variants))
    rows = await db.execute(stmt)
    return rows.scalars().first()
|
||
|
|
|
||
|
|
|
||
|
|
async def upsert_experiment_definition(
    db: AsyncSession,
    experiment_key: str,
    payload: ExperimentDefinitionUpdate,
) -> Experiment:
    """Create or fully replace the stored definition of ``experiment_key``.

    The experiment row is created if missing; otherwise every field is
    overwritten from ``payload``. Variants are reconciled with
    ``payload.variants`` as follows:

    * stored variants whose key is absent from the payload are deleted;
    * new keys are inserted;
    * existing keys get their ``label``, ``allocation`` and ``is_control``
      overwritten.

    The session is flushed, not committed.

    Args:
        db: Active async session.
        experiment_key: Key of the experiment to create or update.
        payload: Full replacement definition, including all variants.

    Returns:
        The experiment re-loaded with an up-to-date ``variants`` collection.

    Raises:
        ValueError: If ``payload.variants`` repeats a ``variant_key``, or if
            ``payload.cookie_name`` is already used by a different experiment.
    """
    # Reject repeated variant keys up front. Otherwise a repeated new key
    # would insert two rows, and a repeated existing key would silently let
    # the last entry win.
    incoming_keys: set[str] = set()
    for variant_payload in payload.variants:
        if variant_payload.variant_key in incoming_keys:
            raise ValueError(f"duplicate variant_key in payload: {variant_payload.variant_key}")
        incoming_keys.add(variant_payload.variant_key)

    experiment = await get_experiment_definition(db, experiment_key)

    duplicate_cookie = await db.execute(
        select(Experiment).where(
            Experiment.cookie_name == payload.cookie_name,
            Experiment.experiment_key != experiment_key,
        )
    )
    if duplicate_cookie.scalars().first():
        raise ValueError("cookie_name is already used by another experiment")

    if experiment is None:
        experiment = Experiment(
            experiment_key=experiment_key,
            cookie_name=payload.cookie_name,
            name=payload.name,
            description=payload.description,
            enabled=payload.enabled,
            eligible_routes=payload.eligible_routes,
        )
        db.add(experiment)
        # Flush so experiment.id is assigned before variant rows reference it.
        await db.flush()
        existing_variants: dict[str, ExperimentVariant] = {}
    else:
        experiment.cookie_name = payload.cookie_name
        experiment.name = payload.name
        experiment.description = payload.description
        experiment.enabled = payload.enabled
        experiment.eligible_routes = payload.eligible_routes
        existing_variants = {variant.variant_key: variant for variant in experiment.variants}

    for key, variant in existing_variants.items():
        if key not in incoming_keys:
            await db.delete(variant)

    for variant_payload in payload.variants:
        variant = existing_variants.get(variant_payload.variant_key)
        if variant is None:
            db.add(
                ExperimentVariant(
                    experiment_id=experiment.id,
                    variant_key=variant_payload.variant_key,
                    label=variant_payload.label,
                    allocation=variant_payload.allocation,
                    is_control=variant_payload.is_control,
                )
            )
            continue

        variant.label = variant_payload.label
        variant.allocation = variant_payload.allocation
        variant.is_control = variant_payload.is_control

    await db.flush()

    # The in-memory `variants` collection is out of date at this point:
    # db.delete() does not remove children from a loaded collection, and
    # variants added via experiment_id are not appended to it. Re-selecting
    # an object already in the identity map does not overwrite a loaded
    # collection, so expire it to make the selectinload below reload it.
    db.expire(experiment, ["variants"])
    refreshed = await get_experiment_definition(db, experiment_key)
    assert refreshed is not None  # the row was just flushed in this session
    return refreshed
|